diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index 327e327..b1c6aa9 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -135,7 +135,7 @@ func main() { apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers) apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard) apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth) - apiV0EndpointWebsocket := handlersV0.NewWebsocketEndpoint(cfg, apiV0Auth, eventBus) + apiV0EndpointWebsocket := handlersV0.NewWebsocketEndpoint(cfg, apiV0Auth, eventBus, apiV0BackendPeers) apiFrontend := handlersV0.NewRestApi(apiV0Session, apiV0EndpointAuth, diff --git a/internal/app/api/v0/handlers/endpoint_websocket.go b/internal/app/api/v0/handlers/endpoint_websocket.go index 5dcc35a..c3afa28 100644 --- a/internal/app/api/v0/handlers/endpoint_websocket.go +++ b/internal/app/api/v0/handlers/endpoint_websocket.go @@ -3,6 +3,7 @@ package handlers import ( "context" "net/http" + "net/url" "strings" "sync" @@ -19,23 +20,28 @@ type WebsocketEventBus interface { Unsubscribe(topic string, fn any) error } +type WebsocketPeerService interface { + GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) +} + type WebsocketEndpoint struct { authenticator Authenticator bus WebsocketEventBus + peerService WebsocketPeerService upgrader websocket.Upgrader } -func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketEventBus) *WebsocketEndpoint { +func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketEventBus, peerService WebsocketPeerService) *WebsocketEndpoint { return &WebsocketEndpoint{ authenticator: auth, bus: bus, + peerService: peerService, upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { - origin := r.Header.Get("Origin") - return strings.HasPrefix(origin, cfg.Web.ExternalUrl) + return matchOrigin(cfg.Web.ExternalUrl, r.Header.Get("Origin")) }, }, } @@ -57,6 +63,8 @@ type wsMessage struct { func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + userInfo := domain.GetUserInfo(r.Context()) + conn, err := e.upgrader.Upgrade(w, r, nil) if err != nil { return @@ -74,9 +82,29 @@ func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc { } peerStatsHandler := func(status domain.TrafficDelta) { + if !userInfo.IsAdmin { + // lookup peer user-info to validate ownership + peer, err := e.peerService.GetPeer(ctx, domain.PeerIdentifier(status.EntityId)) + if err != nil { + return + } + + if peer.UserIdentifier == "" { + return // if peer is not assigned to any user, dont send stats + } + + if peer.UserIdentifier != userInfo.Id { + return // only expose stats for own peers + } + } + _ = writeJSON(wsMessage{Type: "peer_stats", Data: status}) } interfaceStatsHandler := func(status domain.TrafficDelta) { + if !userInfo.IsAdmin { + return // interface stats will only be exposed to admins + } + _ = writeJSON(wsMessage{Type: "interface_stats", Data: status}) } @@ -98,3 +126,18 @@ func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc { <-ctx.Done() } } + +func matchOrigin(externalBaseUrl, origin string) bool { + originURL, err := url.Parse(origin) + if err != nil { + return false + } + + externalURL, err := url.Parse(externalBaseUrl) + if err != nil { + return false + } + + return originURL.Scheme == externalURL.Scheme && + strings.EqualFold(originURL.Host, externalURL.Host) +} diff --git a/internal/app/api/v0/handlers/endpoint_websocket_test.go b/internal/app/api/v0/handlers/endpoint_websocket_test.go new file mode 100644 index 0000000..ad98689 --- /dev/null +++ b/internal/app/api/v0/handlers/endpoint_websocket_test.go @@ -0,0 +1,224 @@ +package handlers + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gorilla/websocket" + evbus "github.com/vardius/message-bus" + + "github.com/h44z/wg-portal/internal/app" + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" +) + +// region test-helper + +type websocketTestPeerService struct { + peers map[domain.PeerIdentifier]*domain.Peer +} + +func (s websocketTestPeerService) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) { + peer, ok := s.peers[id] + if !ok { + return nil, errors.New("peer not found") + } + + return peer, nil +} + +func newTestWebsocketConnection( + t *testing.T, + bus evbus.MessageBus, + userInfo *domain.ContextUserInfo, + peers map[domain.PeerIdentifier]*domain.Peer, +) (*websocket.Conn, func()) { + t.Helper() + + cfg := &config.Config{} + endpoint := NewWebsocketEndpoint(cfg, nil, bus, websocketTestPeerService{peers: peers}) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(domain.SetUserInfo(r.Context(), userInfo)) + endpoint.handleWebsocket()(w, r) + })) + cfg.Web.ExternalUrl = server.URL + + wsURL := "ws" + server.URL[len("http"):] + conn, _, err := websocket.DefaultDialer.Dial(wsURL, http.Header{"Origin": []string{server.URL}}) + if err != nil { + server.Close() + t.Fatalf("failed to dial websocket: %v", err) + } + + cleanup := func() { + conn.Close() + server.Close() + } + + return conn, cleanup +} + +func assertWebsocketMessage(t *testing.T, conn *websocket.Conn, messageType string, entityId string) { + t.Helper() + + if err := conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("failed to set read deadline: %v", err) + } + + var message wsMessage + if err := conn.ReadJSON(&message); err != nil { + t.Fatalf("failed to read websocket message: %v", err) + } + + if message.Type != messageType { + t.Fatalf("unexpected message type: got %q, want %q", message.Type, messageType) + } + + data, ok := message.Data.(map[string]any) + if !ok { + t.Fatalf("unexpected message data type: %T", message.Data) + } + if data["EntityId"] != entityId { + t.Fatalf("unexpected entity id: got %v, want %q", data["EntityId"], entityId) + } +} + +func assertNoWebsocketMessage(t *testing.T, conn *websocket.Conn) { + t.Helper() + + if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil { + t.Fatalf("failed to set read deadline: %v", err) + } + + var message wsMessage + if err := conn.ReadJSON(&message); err == nil { + t.Fatalf("unexpected websocket message: %+v", message) + } +} + +// endregion test-helper + +func TestWebsocketEndpointAllowsOwnPeerStatsForNonAdmin(t *testing.T) { + bus := evbus.New(10) + conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"}, + map[domain.PeerIdentifier]*domain.Peer{ + "own-peer": {Identifier: "own-peer", UserIdentifier: "user-a"}, + }) + defer cleanup() + + bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "own-peer", BytesReceivedPerSecond: 1}) + assertWebsocketMessage(t, conn, "peer_stats", "own-peer") +} + +func TestWebsocketEndpointFiltersOtherPeerStatsForNonAdmin(t *testing.T) { + bus := evbus.New(10) + conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"}, + map[domain.PeerIdentifier]*domain.Peer{ + "other-peer": {Identifier: "other-peer", UserIdentifier: "user-b"}, + }) + defer cleanup() + + bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "other-peer", BytesReceivedPerSecond: 1}) + assertNoWebsocketMessage(t, conn) +} + +func TestWebsocketEndpointFiltersUnknownPeerStatsForNonAdmin(t *testing.T) { + bus := evbus.New(10) + conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"}, + map[domain.PeerIdentifier]*domain.Peer{ + "other-peer": {Identifier: "other-peer", UserIdentifier: ""}, + }) + defer cleanup() + + bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "other-peer", BytesReceivedPerSecond: 1}) + assertNoWebsocketMessage(t, conn) +} + +func TestWebsocketEndpointFiltersUnknownPeerStatsForNonAdmin2(t *testing.T) { + bus := evbus.New(10) + conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"}, nil) + defer cleanup() + + bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "unknown-peer", BytesReceivedPerSecond: 1}) + assertNoWebsocketMessage(t, conn) +} + +func TestWebsocketEndpointFiltersInterfaceStatsForNonAdmin(t *testing.T) { + bus := evbus.New(10) + conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"}, nil) + defer cleanup() + + bus.Publish(app.TopicInterfaceStatsUpdated, domain.TrafficDelta{EntityId: "wg0", BytesReceivedPerSecond: 1}) + assertNoWebsocketMessage(t, conn) +} + +func TestWebsocketEndpointAllowsAllStatsForAdmin(t *testing.T) { + bus := evbus.New(10) + conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "admin", IsAdmin: true}, nil) + defer cleanup() + + bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "other-peer", BytesReceivedPerSecond: 1}) + assertWebsocketMessage(t, conn, "peer_stats", "other-peer") + + bus.Publish(app.TopicInterfaceStatsUpdated, domain.TrafficDelta{EntityId: "wg0", BytesReceivedPerSecond: 1}) + assertWebsocketMessage(t, conn, "interface_stats", "wg0") +} + +func Test_matchOrigin(t *testing.T) { + tests := []struct { + name string + externalBaseUrl string + origin string + want bool + }{ + { + name: "matching origin", + externalBaseUrl: "https://example.com", + origin: "https://example.com", + want: true, + }, + { + name: "matching origin with path", + externalBaseUrl: "https://example.com/app1", + origin: "https://example.com/app2", + want: true, + }, + { + name: "non-matching origin with different host", + externalBaseUrl: "https://example.com", + origin: "https://example.com.malicious.com", + want: false, + }, + { + name: "non-matching origin with different scheme", + externalBaseUrl: "https://example.com", + origin: "http://example.com", + want: false, + }, + { + name: "invalid origin URL", + externalBaseUrl: "https://example.com", + origin: "://invalid-url", + want: false, + }, + { + name: "invalid externalBaseUrl", + externalBaseUrl: "://invalid-url", + origin: "https://example.com", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := matchOrigin(tt.externalBaseUrl, tt.origin) + if got != tt.want { + t.Errorf("matchOrigin() = %v, want %v", got, tt.want) + } + }) + } +}