diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index b1c6aa9..de4a49f 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -136,6 +136,7 @@ func main() { apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard) apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth) apiV0EndpointWebsocket := handlersV0.NewWebsocketEndpoint(cfg, apiV0Auth, eventBus, apiV0BackendPeers) + apiV0EndpointWebsocket.StartBackgroundJobs(ctx) apiFrontend := handlersV0.NewRestApi(apiV0Session, apiV0EndpointAuth, diff --git a/internal/app/api/v0/handlers/endpoint_websocket.go b/internal/app/api/v0/handlers/endpoint_websocket.go index c3afa28..b2653ad 100644 --- a/internal/app/api/v0/handlers/endpoint_websocket.go +++ b/internal/app/api/v0/handlers/endpoint_websocket.go @@ -6,6 +6,7 @@ import ( "net/url" "strings" "sync" + "time" "github.com/go-pkgz/routegroup" "github.com/gorilla/websocket" @@ -15,6 +16,11 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) +const ( + websocketPeerUserIdentifierCacheTTL = 90 * time.Second + websocketPeerUserIdentifierCacheCleanupInterval = websocketPeerUserIdentifierCacheTTL * 2 +) + type WebsocketEventBus interface { Subscribe(topic string, fn any) error Unsubscribe(topic string, fn any) error @@ -30,9 +36,17 @@ type WebsocketEndpoint struct { peerService WebsocketPeerService upgrader websocket.Upgrader + + ownershipCache map[domain.PeerIdentifier]peerUserIdentifierCacheEntry + ownershipCacheMux sync.Mutex } -func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketEventBus, peerService WebsocketPeerService) *WebsocketEndpoint { +func NewWebsocketEndpoint( + cfg *config.Config, + auth Authenticator, + bus WebsocketEventBus, + peerService WebsocketPeerService, +) *WebsocketEndpoint { return &WebsocketEndpoint{ authenticator: auth, bus: bus, @@ -44,24 +58,38 @@ func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketE return matchOrigin(cfg.Web.ExternalUrl, r.Header.Get("Origin")) }, }, + ownershipCache: make(map[domain.PeerIdentifier]peerUserIdentifierCacheEntry), + ownershipCacheMux: sync.Mutex{}, } } -func (e WebsocketEndpoint) GetName() string { +func (e *WebsocketEndpoint) GetName() string { return "WebsocketEndpoint" } -func (e WebsocketEndpoint) RegisterRoutes(g *routegroup.Bundle) { +func (e *WebsocketEndpoint) RegisterRoutes(g *routegroup.Bundle) { g.With(e.authenticator.LoggedIn()).HandleFunc("GET /ws", e.handleWebsocket()) } +// StartBackgroundJobs starts background jobs like the expired peers check. +// This method is non-blocking. +func (e *WebsocketEndpoint) StartBackgroundJobs(ctx context.Context) { + go e.startOwnerCacheCleanup(ctx) +} + // wsMessage represents a message sent over websocket to the frontend type wsMessage struct { Type string `json:"type"` // either "peer_stats" or "interface_stats" Data any `json:"data"` // domain.TrafficDelta } -func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc { +// peerUserIdentifierCacheEntry is a cache entry object that reduces database load when checking peer ownership. +type peerUserIdentifierCacheEntry struct { + userIdentifier domain.UserIdentifier + expiresAt time.Time +} + +func (e *WebsocketEndpoint) handleWebsocket() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { userInfo := domain.GetUserInfo(r.Context()) @@ -84,16 +112,16 @@ func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc { peerStatsHandler := func(status domain.TrafficDelta) { if !userInfo.IsAdmin { // lookup peer user-info to validate ownership - peer, err := e.peerService.GetPeer(ctx, domain.PeerIdentifier(status.EntityId)) + peerUserIdentifier, err := e.getPeerUserIdentifier(ctx, domain.PeerIdentifier(status.EntityId)) if err != nil { return } - if peer.UserIdentifier == "" { + if peerUserIdentifier == "" { return // if peer is not assigned to any user, dont send stats } - if peer.UserIdentifier != userInfo.Id { + if peerUserIdentifier != userInfo.Id { return // only expose stats for own peers } } @@ -127,6 +155,60 @@ func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc { } } +func (e *WebsocketEndpoint) getPeerUserIdentifier( + ctx context.Context, + peerIdentifier domain.PeerIdentifier, +) (domain.UserIdentifier, error) { + now := time.Now() + + e.ownershipCacheMux.Lock() + entry, ok := e.ownershipCache[peerIdentifier] + if ok && now.Before(entry.expiresAt) { + e.ownershipCacheMux.Unlock() + return entry.userIdentifier, nil + } + e.ownershipCacheMux.Unlock() + + peer, err := e.peerService.GetPeer(ctx, peerIdentifier) + if err != nil { + return "", err + } + + e.ownershipCacheMux.Lock() + defer e.ownershipCacheMux.Unlock() + e.ownershipCache[peerIdentifier] = peerUserIdentifierCacheEntry{ + userIdentifier: peer.UserIdentifier, + expiresAt: now.Add(websocketPeerUserIdentifierCacheTTL), + } + + return peer.UserIdentifier, nil +} + +func (e *WebsocketEndpoint) startOwnerCacheCleanup(ctx context.Context) { + ticker := time.NewTicker(websocketPeerUserIdentifierCacheCleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case now := <-ticker.C: + e.cleanupOwnerCache(now) + } + } +} + +func (e *WebsocketEndpoint) cleanupOwnerCache(now time.Time) { + e.ownershipCacheMux.Lock() + defer e.ownershipCacheMux.Unlock() + + for peerIdentifier, entry := range e.ownershipCache { + if !now.Before(entry.expiresAt) { + delete(e.ownershipCache, peerIdentifier) + } + } +} + func matchOrigin(externalBaseUrl, origin string) bool { originURL, err := url.Parse(origin) if err != nil { diff --git a/internal/app/api/v0/handlers/endpoint_websocket_test.go b/internal/app/api/v0/handlers/endpoint_websocket_test.go index ad98689..ba81901 100644 --- a/internal/app/api/v0/handlers/endpoint_websocket_test.go +++ b/internal/app/api/v0/handlers/endpoint_websocket_test.go @@ -115,6 +115,31 @@ func TestWebsocketEndpointAllowsOwnPeerStatsForNonAdmin(t *testing.T) { assertWebsocketMessage(t, conn, "peer_stats", "own-peer") } +func TestWebsocketEndpointCleansExpiredPeerUserIdentifierCache(t *testing.T) { + now := time.Now() + endpoint := &WebsocketEndpoint{ + ownershipCache: map[domain.PeerIdentifier]peerUserIdentifierCacheEntry{ + "expired-peer": { + userIdentifier: "user-a", + expiresAt: now.Add(-time.Second), + }, + "active-peer": { + userIdentifier: "user-b", + expiresAt: now.Add(time.Second), + }, + }, + } + + endpoint.cleanupOwnerCache(now) + + if _, ok := endpoint.ownershipCache["expired-peer"]; ok { + t.Fatal("expired peer cache entry was not removed") + } + if _, ok := endpoint.ownershipCache["active-peer"]; !ok { + t.Fatal("active peer cache entry was removed") + } +} + func TestWebsocketEndpointFiltersOtherPeerStatsForNonAdmin(t *testing.T) { bus := evbus.New(10) conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"},