Merge commit from fork

* sec: do not expose traffic stats to all users, harden origin check in websocket endpoint

* add tests to validate new logic
This commit is contained in:
h44z
2026-06-05 20:13:18 +02:00
committed by GitHub
parent e3dc31a133
commit 316f389f11
3 changed files with 271 additions and 4 deletions

View File

@@ -135,7 +135,7 @@ func main() {
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard)
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
apiV0EndpointWebsocket := handlersV0.NewWebsocketEndpoint(cfg, apiV0Auth, eventBus)
apiV0EndpointWebsocket := handlersV0.NewWebsocketEndpoint(cfg, apiV0Auth, eventBus, apiV0BackendPeers)
apiFrontend := handlersV0.NewRestApi(apiV0Session,
apiV0EndpointAuth,

View File

@@ -3,6 +3,7 @@ package handlers
import (
"context"
"net/http"
"net/url"
"strings"
"sync"
@@ -19,23 +20,28 @@ type WebsocketEventBus interface {
Unsubscribe(topic string, fn any) error
}
type WebsocketPeerService interface {
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
}
type WebsocketEndpoint struct {
authenticator Authenticator
bus WebsocketEventBus
peerService WebsocketPeerService
upgrader websocket.Upgrader
}
func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketEventBus) *WebsocketEndpoint {
func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketEventBus, peerService WebsocketPeerService) *WebsocketEndpoint {
return &WebsocketEndpoint{
authenticator: auth,
bus: bus,
peerService: peerService,
upgrader: websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
origin := r.Header.Get("Origin")
return strings.HasPrefix(origin, cfg.Web.ExternalUrl)
return matchOrigin(cfg.Web.ExternalUrl, r.Header.Get("Origin"))
},
},
}
@@ -57,6 +63,8 @@ type wsMessage struct {
func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userInfo := domain.GetUserInfo(r.Context())
conn, err := e.upgrader.Upgrade(w, r, nil)
if err != nil {
return
@@ -74,9 +82,29 @@ func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
}
peerStatsHandler := func(status domain.TrafficDelta) {
if !userInfo.IsAdmin {
// lookup peer user-info to validate ownership
peer, err := e.peerService.GetPeer(ctx, domain.PeerIdentifier(status.EntityId))
if err != nil {
return
}
if peer.UserIdentifier == "" {
return // if peer is not assigned to any user, dont send stats
}
if peer.UserIdentifier != userInfo.Id {
return // only expose stats for own peers
}
}
_ = writeJSON(wsMessage{Type: "peer_stats", Data: status})
}
interfaceStatsHandler := func(status domain.TrafficDelta) {
if !userInfo.IsAdmin {
return // interface stats will only be exposed to admins
}
_ = writeJSON(wsMessage{Type: "interface_stats", Data: status})
}
@@ -98,3 +126,18 @@ func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
<-ctx.Done()
}
}
func matchOrigin(externalBaseUrl, origin string) bool {
originURL, err := url.Parse(origin)
if err != nil {
return false
}
externalURL, err := url.Parse(externalBaseUrl)
if err != nil {
return false
}
return originURL.Scheme == externalURL.Scheme &&
strings.EqualFold(originURL.Host, externalURL.Host)
}

View File

@@ -0,0 +1,224 @@
package handlers
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/websocket"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// region test-helper
type websocketTestPeerService struct {
peers map[domain.PeerIdentifier]*domain.Peer
}
func (s websocketTestPeerService) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
peer, ok := s.peers[id]
if !ok {
return nil, errors.New("peer not found")
}
return peer, nil
}
func newTestWebsocketConnection(
t *testing.T,
bus evbus.MessageBus,
userInfo *domain.ContextUserInfo,
peers map[domain.PeerIdentifier]*domain.Peer,
) (*websocket.Conn, func()) {
t.Helper()
cfg := &config.Config{}
endpoint := NewWebsocketEndpoint(cfg, nil, bus, websocketTestPeerService{peers: peers})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(domain.SetUserInfo(r.Context(), userInfo))
endpoint.handleWebsocket()(w, r)
}))
cfg.Web.ExternalUrl = server.URL
wsURL := "ws" + server.URL[len("http"):]
conn, _, err := websocket.DefaultDialer.Dial(wsURL, http.Header{"Origin": []string{server.URL}})
if err != nil {
server.Close()
t.Fatalf("failed to dial websocket: %v", err)
}
cleanup := func() {
conn.Close()
server.Close()
}
return conn, cleanup
}
func assertWebsocketMessage(t *testing.T, conn *websocket.Conn, messageType string, entityId string) {
t.Helper()
if err := conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
var message wsMessage
if err := conn.ReadJSON(&message); err != nil {
t.Fatalf("failed to read websocket message: %v", err)
}
if message.Type != messageType {
t.Fatalf("unexpected message type: got %q, want %q", message.Type, messageType)
}
data, ok := message.Data.(map[string]any)
if !ok {
t.Fatalf("unexpected message data type: %T", message.Data)
}
if data["EntityId"] != entityId {
t.Fatalf("unexpected entity id: got %v, want %q", data["EntityId"], entityId)
}
}
func assertNoWebsocketMessage(t *testing.T, conn *websocket.Conn) {
t.Helper()
if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
var message wsMessage
if err := conn.ReadJSON(&message); err == nil {
t.Fatalf("unexpected websocket message: %+v", message)
}
}
// endregion test-helper
func TestWebsocketEndpointAllowsOwnPeerStatsForNonAdmin(t *testing.T) {
bus := evbus.New(10)
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"},
map[domain.PeerIdentifier]*domain.Peer{
"own-peer": {Identifier: "own-peer", UserIdentifier: "user-a"},
})
defer cleanup()
bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "own-peer", BytesReceivedPerSecond: 1})
assertWebsocketMessage(t, conn, "peer_stats", "own-peer")
}
func TestWebsocketEndpointFiltersOtherPeerStatsForNonAdmin(t *testing.T) {
bus := evbus.New(10)
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"},
map[domain.PeerIdentifier]*domain.Peer{
"other-peer": {Identifier: "other-peer", UserIdentifier: "user-b"},
})
defer cleanup()
bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "other-peer", BytesReceivedPerSecond: 1})
assertNoWebsocketMessage(t, conn)
}
func TestWebsocketEndpointFiltersUnknownPeerStatsForNonAdmin(t *testing.T) {
bus := evbus.New(10)
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"},
map[domain.PeerIdentifier]*domain.Peer{
"other-peer": {Identifier: "other-peer", UserIdentifier: ""},
})
defer cleanup()
bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "other-peer", BytesReceivedPerSecond: 1})
assertNoWebsocketMessage(t, conn)
}
func TestWebsocketEndpointFiltersUnknownPeerStatsForNonAdmin2(t *testing.T) {
bus := evbus.New(10)
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"}, nil)
defer cleanup()
bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "unknown-peer", BytesReceivedPerSecond: 1})
assertNoWebsocketMessage(t, conn)
}
func TestWebsocketEndpointFiltersInterfaceStatsForNonAdmin(t *testing.T) {
bus := evbus.New(10)
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"}, nil)
defer cleanup()
bus.Publish(app.TopicInterfaceStatsUpdated, domain.TrafficDelta{EntityId: "wg0", BytesReceivedPerSecond: 1})
assertNoWebsocketMessage(t, conn)
}
func TestWebsocketEndpointAllowsAllStatsForAdmin(t *testing.T) {
bus := evbus.New(10)
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "admin", IsAdmin: true}, nil)
defer cleanup()
bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "other-peer", BytesReceivedPerSecond: 1})
assertWebsocketMessage(t, conn, "peer_stats", "other-peer")
bus.Publish(app.TopicInterfaceStatsUpdated, domain.TrafficDelta{EntityId: "wg0", BytesReceivedPerSecond: 1})
assertWebsocketMessage(t, conn, "interface_stats", "wg0")
}
func Test_matchOrigin(t *testing.T) {
tests := []struct {
name string
externalBaseUrl string
origin string
want bool
}{
{
name: "matching origin",
externalBaseUrl: "https://example.com",
origin: "https://example.com",
want: true,
},
{
name: "matching origin with path",
externalBaseUrl: "https://example.com/app1",
origin: "https://example.com/app2",
want: true,
},
{
name: "non-matching origin with different host",
externalBaseUrl: "https://example.com",
origin: "https://example.com.malicious.com",
want: false,
},
{
name: "non-matching origin with different scheme",
externalBaseUrl: "https://example.com",
origin: "http://example.com",
want: false,
},
{
name: "invalid origin URL",
externalBaseUrl: "https://example.com",
origin: "://invalid-url",
want: false,
},
{
name: "invalid externalBaseUrl",
externalBaseUrl: "://invalid-url",
origin: "https://example.com",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := matchOrigin(tt.externalBaseUrl, tt.origin)
if got != tt.want {
t.Errorf("matchOrigin() = %v, want %v", got, tt.want)
}
})
}
}