mirror of
https://github.com/h44z/wg-portal.git
synced 2026-06-07 09:06:20 +00:00
feat: add short-lived cache for peer-ownership checks
This commit is contained in:
@@ -136,6 +136,7 @@ func main() {
|
|||||||
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard)
|
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard)
|
||||||
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
|
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
|
||||||
apiV0EndpointWebsocket := handlersV0.NewWebsocketEndpoint(cfg, apiV0Auth, eventBus, apiV0BackendPeers)
|
apiV0EndpointWebsocket := handlersV0.NewWebsocketEndpoint(cfg, apiV0Auth, eventBus, apiV0BackendPeers)
|
||||||
|
apiV0EndpointWebsocket.StartBackgroundJobs(ctx)
|
||||||
|
|
||||||
apiFrontend := handlersV0.NewRestApi(apiV0Session,
|
apiFrontend := handlersV0.NewRestApi(apiV0Session,
|
||||||
apiV0EndpointAuth,
|
apiV0EndpointAuth,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-pkgz/routegroup"
|
"github.com/go-pkgz/routegroup"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
@@ -15,6 +16,11 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
websocketPeerUserIdentifierCacheTTL = 90 * time.Second
|
||||||
|
websocketPeerUserIdentifierCacheCleanupInterval = websocketPeerUserIdentifierCacheTTL * 2
|
||||||
|
)
|
||||||
|
|
||||||
type WebsocketEventBus interface {
|
type WebsocketEventBus interface {
|
||||||
Subscribe(topic string, fn any) error
|
Subscribe(topic string, fn any) error
|
||||||
Unsubscribe(topic string, fn any) error
|
Unsubscribe(topic string, fn any) error
|
||||||
@@ -30,9 +36,17 @@ type WebsocketEndpoint struct {
|
|||||||
peerService WebsocketPeerService
|
peerService WebsocketPeerService
|
||||||
|
|
||||||
upgrader websocket.Upgrader
|
upgrader websocket.Upgrader
|
||||||
|
|
||||||
|
ownershipCache map[domain.PeerIdentifier]peerUserIdentifierCacheEntry
|
||||||
|
ownershipCacheMux sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketEventBus, peerService WebsocketPeerService) *WebsocketEndpoint {
|
func NewWebsocketEndpoint(
|
||||||
|
cfg *config.Config,
|
||||||
|
auth Authenticator,
|
||||||
|
bus WebsocketEventBus,
|
||||||
|
peerService WebsocketPeerService,
|
||||||
|
) *WebsocketEndpoint {
|
||||||
return &WebsocketEndpoint{
|
return &WebsocketEndpoint{
|
||||||
authenticator: auth,
|
authenticator: auth,
|
||||||
bus: bus,
|
bus: bus,
|
||||||
@@ -44,24 +58,38 @@ func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketE
|
|||||||
return matchOrigin(cfg.Web.ExternalUrl, r.Header.Get("Origin"))
|
return matchOrigin(cfg.Web.ExternalUrl, r.Header.Get("Origin"))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
ownershipCache: make(map[domain.PeerIdentifier]peerUserIdentifierCacheEntry),
|
||||||
|
ownershipCacheMux: sync.Mutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e WebsocketEndpoint) GetName() string {
|
func (e *WebsocketEndpoint) GetName() string {
|
||||||
return "WebsocketEndpoint"
|
return "WebsocketEndpoint"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e WebsocketEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
func (e *WebsocketEndpoint) RegisterRoutes(g *routegroup.Bundle) {
|
||||||
g.With(e.authenticator.LoggedIn()).HandleFunc("GET /ws", e.handleWebsocket())
|
g.With(e.authenticator.LoggedIn()).HandleFunc("GET /ws", e.handleWebsocket())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartBackgroundJobs starts background jobs like the expired peers check.
|
||||||
|
// This method is non-blocking.
|
||||||
|
func (e *WebsocketEndpoint) StartBackgroundJobs(ctx context.Context) {
|
||||||
|
go e.startOwnerCacheCleanup(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// wsMessage represents a message sent over websocket to the frontend
|
// wsMessage represents a message sent over websocket to the frontend
|
||||||
type wsMessage struct {
|
type wsMessage struct {
|
||||||
Type string `json:"type"` // either "peer_stats" or "interface_stats"
|
Type string `json:"type"` // either "peer_stats" or "interface_stats"
|
||||||
Data any `json:"data"` // domain.TrafficDelta
|
Data any `json:"data"` // domain.TrafficDelta
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
|
// peerUserIdentifierCacheEntry is a cache entry object that reduces database load when checking peer ownership.
|
||||||
|
type peerUserIdentifierCacheEntry struct {
|
||||||
|
userIdentifier domain.UserIdentifier
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
userInfo := domain.GetUserInfo(r.Context())
|
userInfo := domain.GetUserInfo(r.Context())
|
||||||
|
|
||||||
@@ -84,16 +112,16 @@ func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
|
|||||||
peerStatsHandler := func(status domain.TrafficDelta) {
|
peerStatsHandler := func(status domain.TrafficDelta) {
|
||||||
if !userInfo.IsAdmin {
|
if !userInfo.IsAdmin {
|
||||||
// lookup peer user-info to validate ownership
|
// lookup peer user-info to validate ownership
|
||||||
peer, err := e.peerService.GetPeer(ctx, domain.PeerIdentifier(status.EntityId))
|
peerUserIdentifier, err := e.getPeerUserIdentifier(ctx, domain.PeerIdentifier(status.EntityId))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.UserIdentifier == "" {
|
if peerUserIdentifier == "" {
|
||||||
return // if peer is not assigned to any user, dont send stats
|
return // if peer is not assigned to any user, dont send stats
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.UserIdentifier != userInfo.Id {
|
if peerUserIdentifier != userInfo.Id {
|
||||||
return // only expose stats for own peers
|
return // only expose stats for own peers
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -127,6 +155,60 @@ func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *WebsocketEndpoint) getPeerUserIdentifier(
|
||||||
|
ctx context.Context,
|
||||||
|
peerIdentifier domain.PeerIdentifier,
|
||||||
|
) (domain.UserIdentifier, error) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
e.ownershipCacheMux.Lock()
|
||||||
|
entry, ok := e.ownershipCache[peerIdentifier]
|
||||||
|
if ok && now.Before(entry.expiresAt) {
|
||||||
|
e.ownershipCacheMux.Unlock()
|
||||||
|
return entry.userIdentifier, nil
|
||||||
|
}
|
||||||
|
e.ownershipCacheMux.Unlock()
|
||||||
|
|
||||||
|
peer, err := e.peerService.GetPeer(ctx, peerIdentifier)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
e.ownershipCacheMux.Lock()
|
||||||
|
defer e.ownershipCacheMux.Unlock()
|
||||||
|
e.ownershipCache[peerIdentifier] = peerUserIdentifierCacheEntry{
|
||||||
|
userIdentifier: peer.UserIdentifier,
|
||||||
|
expiresAt: now.Add(websocketPeerUserIdentifierCacheTTL),
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer.UserIdentifier, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *WebsocketEndpoint) startOwnerCacheCleanup(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(websocketPeerUserIdentifierCacheCleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case now := <-ticker.C:
|
||||||
|
e.cleanupOwnerCache(now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *WebsocketEndpoint) cleanupOwnerCache(now time.Time) {
|
||||||
|
e.ownershipCacheMux.Lock()
|
||||||
|
defer e.ownershipCacheMux.Unlock()
|
||||||
|
|
||||||
|
for peerIdentifier, entry := range e.ownershipCache {
|
||||||
|
if !now.Before(entry.expiresAt) {
|
||||||
|
delete(e.ownershipCache, peerIdentifier)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func matchOrigin(externalBaseUrl, origin string) bool {
|
func matchOrigin(externalBaseUrl, origin string) bool {
|
||||||
originURL, err := url.Parse(origin)
|
originURL, err := url.Parse(origin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -115,6 +115,31 @@ func TestWebsocketEndpointAllowsOwnPeerStatsForNonAdmin(t *testing.T) {
|
|||||||
assertWebsocketMessage(t, conn, "peer_stats", "own-peer")
|
assertWebsocketMessage(t, conn, "peer_stats", "own-peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWebsocketEndpointCleansExpiredPeerUserIdentifierCache(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
endpoint := &WebsocketEndpoint{
|
||||||
|
ownershipCache: map[domain.PeerIdentifier]peerUserIdentifierCacheEntry{
|
||||||
|
"expired-peer": {
|
||||||
|
userIdentifier: "user-a",
|
||||||
|
expiresAt: now.Add(-time.Second),
|
||||||
|
},
|
||||||
|
"active-peer": {
|
||||||
|
userIdentifier: "user-b",
|
||||||
|
expiresAt: now.Add(time.Second),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint.cleanupOwnerCache(now)
|
||||||
|
|
||||||
|
if _, ok := endpoint.ownershipCache["expired-peer"]; ok {
|
||||||
|
t.Fatal("expired peer cache entry was not removed")
|
||||||
|
}
|
||||||
|
if _, ok := endpoint.ownershipCache["active-peer"]; !ok {
|
||||||
|
t.Fatal("active peer cache entry was removed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWebsocketEndpointFiltersOtherPeerStatsForNonAdmin(t *testing.T) {
|
func TestWebsocketEndpointFiltersOtherPeerStatsForNonAdmin(t *testing.T) {
|
||||||
bus := evbus.New(10)
|
bus := evbus.New(10)
|
||||||
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"},
|
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"},
|
||||||
|
|||||||
Reference in New Issue
Block a user