diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index 83512ec..0592b82 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -135,6 +135,7 @@ func main() { apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers) apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard) apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth) + apiV0EndpointWebsocket := handlersV0.NewWebsocketEndpoint(cfg, apiV0Auth, eventBus) apiFrontend := handlersV0.NewRestApi(apiV0Session, apiV0EndpointAuth, @@ -144,6 +145,7 @@ func main() { apiV0EndpointPeers, apiV0EndpointConfig, apiV0EndpointTest, + apiV0EndpointWebsocket, ) // endregion API v0 (SPA frontend) diff --git a/frontend/src/helpers/websocket-wrapper.js b/frontend/src/helpers/websocket-wrapper.js new file mode 100644 index 0000000..f5aa0c1 --- /dev/null +++ b/frontend/src/helpers/websocket-wrapper.js @@ -0,0 +1,86 @@ +import { peerStore } from '@/stores/peers'; +import { interfaceStore } from '@/stores/interfaces'; +import { authStore } from '@/stores/auth'; + +let socket = null; +let reconnectTimer = null; +let failureCount = 0; + +export const websocketWrapper = { + connect() { + if (socket) { + console.log('WebSocket already connected, re-using existing connection.'); + return; + } + + const protocol = WGPORTAL_BACKEND_BASE_URL.startsWith('https://') ? 'wss://' : 'ws://'; + const baseUrl = WGPORTAL_BACKEND_BASE_URL.replace(/^https?:\/\//, ''); + const url = `${protocol}${baseUrl}/ws`; + + socket = new WebSocket(url); + + socket.onopen = () => { + console.log('WebSocket connected'); + failureCount = 0; + if (reconnectTimer) { + clearInterval(reconnectTimer); + reconnectTimer = null; + } + }; + + socket.onclose = () => { + console.log('WebSocket disconnected'); + failureCount++; + socket = null; + this.scheduleReconnect(); + }; + + socket.onerror = (error) => { + console.error('WebSocket error:', error); + failureCount++; + socket.close(); + socket = null; + }; + + socket.onmessage = (event) => { + const message = JSON.parse(event.data); + switch (message.type) { + case 'peer_stats': + peerStore().updatePeerTrafficStats(message.data); + break; + case 'interface_stats': + interfaceStore().updateInterfaceTrafficStats(message.data); + break; + } + }; + }, + + disconnect() { + if (socket) { + socket.close(); + socket = null; + } + if (reconnectTimer) { + clearInterval(reconnectTimer); + reconnectTimer = null; + failureCount = 0; + } + }, + + scheduleReconnect() { + if (reconnectTimer) return; + if (!authStore().IsAuthenticated) return; // Don't reconnect if not logged in + + reconnectTimer = setInterval(() => { + if (failureCount > 2) { + console.log('WebSocket connection unavailable, giving up.'); + clearInterval(reconnectTimer); + reconnectTimer = null; + return; + } + + console.log('Attempting to reconnect WebSocket...'); + this.connect(); + }, 5000); + } +}; diff --git a/frontend/src/stores/auth.js b/frontend/src/stores/auth.js index 50c0d09..83d97d2 100644 --- a/frontend/src/stores/auth.js +++ b/frontend/src/stores/auth.js @@ -2,6 +2,7 @@ import { defineStore } from 'pinia' import { notify } from "@kyvg/vue3-notification"; import { apiWrapper } from '@/helpers/fetch-wrapper' +import { websocketWrapper } from '@/helpers/websocket-wrapper' import router from '../router' import { browserSupportsWebAuthn,startRegistration,startAuthentication } from '@simplewebauthn/browser'; import {base64_url_encode} from "@/helpers/encoding"; @@ -295,9 +296,11 @@ export const authStore = defineStore('auth',{ } } localStorage.setItem('user', JSON.stringify(this.user)) + websocketWrapper.connect() } else { this.user = null localStorage.removeItem('user') + websocketWrapper.disconnect() } }, setWebAuthnCredentials(credentials) { diff --git a/frontend/src/stores/interfaces.js b/frontend/src/stores/interfaces.js index 6f1f529..efe75c2 100644 --- a/frontend/src/stores/interfaces.js +++ b/frontend/src/stores/interfaces.js @@ -14,6 +14,7 @@ export const interfaceStore = defineStore('interfaces', { configuration: "", selected: "", fetching: false, + trafficStats: {}, }), getters: { Count: (state) => state.interfaces.length, @@ -24,6 +25,9 @@ export const interfaceStore = defineStore('interfaces', { }, GetSelected: (state) => state.interfaces.find((i) => i.Identifier === state.selected) || state.interfaces[0], isFetching: (state) => state.fetching, + TrafficStats: (state) => { + return (state.selected in state.trafficStats) ? state.trafficStats[state.selected] : { Received: 0, Transmitted: 0 } + }, }, actions: { setInterfaces(interfaces) { @@ -34,6 +38,14 @@ export const interfaceStore = defineStore('interfaces', { this.selected = "" } this.fetching = false + this.trafficStats = {} + }, + updateInterfaceTrafficStats(interfaceStats) { + const id = interfaceStats.EntityId; + this.trafficStats[id] = { + Received: interfaceStats.BytesReceived, + Transmitted: interfaceStats.BytesTransmitted, + }; }, async LoadInterfaces() { this.fetching = true diff --git a/frontend/src/stores/peers.js b/frontend/src/stores/peers.js index 2f71656..8e80b2c 100644 --- a/frontend/src/stores/peers.js +++ b/frontend/src/stores/peers.js @@ -23,6 +23,7 @@ export const peerStore = defineStore('peers', { fetching: false, sortKey: 'IsConnected', // Default sort key sortOrder: -1, // 1 for ascending, -1 for descending + trafficStats: {}, }), getters: { Find: (state) => { @@ -76,6 +77,9 @@ export const peerStore = defineStore('peers', { Statistics: (state) => { return (id) => state.statsEnabled && (id in state.stats) ? state.stats[id] : freshStats() }, + TrafficStats: (state) => { + return (id) => (id in state.trafficStats) ? state.trafficStats[id] : { Received: 0, Transmitted: 0 } + }, hasStatistics: (state) => state.statsEnabled, }, @@ -111,6 +115,7 @@ export const peerStore = defineStore('peers', { this.peers = peers this.calculatePages() this.fetching = false + this.trafficStats = {} }, setPeer(peer) { this.peer = peer @@ -126,11 +131,19 @@ export const peerStore = defineStore('peers', { if (!statsResponse) { this.stats = {} this.statsEnabled = false + this.trafficStats = {} } else { this.stats = statsResponse.Stats this.statsEnabled = statsResponse.Enabled } }, + updatePeerTrafficStats(peerStats) { + const id = peerStats.EntityId; + this.trafficStats[id] = { + Received: peerStats.BytesReceived, + Transmitted: peerStats.BytesTransmitted, + }; + }, async Reset() { this.setPeers([]) this.setStats(undefined) diff --git a/frontend/src/views/InterfaceView.vue b/frontend/src/views/InterfaceView.vue index ed9dc35..559a056 100644 --- a/frontend/src/views/InterfaceView.vue +++ b/frontend/src/views/InterfaceView.vue @@ -210,6 +210,12 @@ onMounted(async () => {
{{ $t('interfaces.interface.headline') }} {{interfaces.GetSelected.Identifier}} ({{ $t('modals.interface-edit.mode.' + interfaces.GetSelected.Mode )}} | {{ $t('interfaces.interface.backend') + ": " + calculateBackendName }}) +
+ + Traffic: {{ humanFileSize(interfaces.TrafficStats.Received) }}/s + {{ humanFileSize(interfaces.TrafficStats.Transmitted) }}/s + +
@@ -451,14 +457,19 @@ onMounted(async () => { {{peer.Endpoint.Value}}
- {{ $t('interfaces.peer-connected') }} +
- {{ humanFileSize(peers.Statistics(peer.Identifier).BytesReceived) }} / {{ humanFileSize(peers.Statistics(peer.Identifier).BytesTransmitted) }} +
+ + {{ humanFileSize(peers.TrafficStats(peer.Identifier).Received) }}/s + {{ humanFileSize(peers.TrafficStats(peer.Identifier).Transmitted) }}/s + +
diff --git a/go.mod b/go.mod index 41130f4..8a3ad3e 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/go-playground/validator/v10 v10.30.1 github.com/go-webauthn/webauthn v0.15.0 github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 github.com/prometheus-community/pro-bing v0.7.0 github.com/prometheus/client_golang v1.23.2 github.com/stretchr/testify v1.11.1 diff --git a/go.sum b/go.sum index 93ad84f..c60bc20 100644 --- a/go.sum +++ b/go.sum @@ -130,6 +130,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= diff --git a/internal/app/api/core/middleware/logging/writer.go b/internal/app/api/core/middleware/logging/writer.go index 4e3c42f..d77a7f1 100644 --- a/internal/app/api/core/middleware/logging/writer.go +++ b/internal/app/api/core/middleware/logging/writer.go @@ -1,6 +1,8 @@ package logging import ( + "bufio" + "net" "net/http" ) @@ -38,6 +40,12 @@ func (w *writerWrapper) Write(data []byte) (int, error) { return n, err } +// Hijack wraps the Hijack method of the ResponseWriter and returns the hijacked connection. +// This is required for websockets to work. +func (w *writerWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return http.NewResponseController(w.ResponseWriter).Hijack() +} + // newWriterWrapper returns a new writerWrapper that wraps the given http.ResponseWriter. // It initializes the StatusCode to http.StatusOK. func newWriterWrapper(w http.ResponseWriter) *writerWrapper { diff --git a/internal/app/api/v0/handlers/endpoint_websocket.go b/internal/app/api/v0/handlers/endpoint_websocket.go new file mode 100644 index 0000000..5dcc35a --- /dev/null +++ b/internal/app/api/v0/handlers/endpoint_websocket.go @@ -0,0 +1,100 @@ +package handlers + +import ( + "context" + "net/http" + "strings" + "sync" + + "github.com/go-pkgz/routegroup" + "github.com/gorilla/websocket" + + "github.com/h44z/wg-portal/internal/app" + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" +) + +type WebsocketEventBus interface { + Subscribe(topic string, fn any) error + Unsubscribe(topic string, fn any) error +} + +type WebsocketEndpoint struct { + authenticator Authenticator + bus WebsocketEventBus + + upgrader websocket.Upgrader +} + +func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketEventBus) *WebsocketEndpoint { + return &WebsocketEndpoint{ + authenticator: auth, + bus: bus, + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + origin := r.Header.Get("Origin") + return strings.HasPrefix(origin, cfg.Web.ExternalUrl) + }, + }, + } +} + +func (e WebsocketEndpoint) GetName() string { + return "WebsocketEndpoint" +} + +func (e WebsocketEndpoint) RegisterRoutes(g *routegroup.Bundle) { + g.With(e.authenticator.LoggedIn()).HandleFunc("GET /ws", e.handleWebsocket()) +} + +// wsMessage represents a message sent over websocket to the frontend +type wsMessage struct { + Type string `json:"type"` // either "peer_stats" or "interface_stats" + Data any `json:"data"` // domain.TrafficDelta +} + +func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + conn, err := e.upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + writeMutex := sync.Mutex{} + writeJSON := func(msg wsMessage) error { + writeMutex.Lock() + defer writeMutex.Unlock() + return conn.WriteJSON(msg) + } + + peerStatsHandler := func(status domain.TrafficDelta) { + _ = writeJSON(wsMessage{Type: "peer_stats", Data: status}) + } + interfaceStatsHandler := func(status domain.TrafficDelta) { + _ = writeJSON(wsMessage{Type: "interface_stats", Data: status}) + } + + _ = e.bus.Subscribe(app.TopicPeerStatsUpdated, peerStatsHandler) + defer e.bus.Unsubscribe(app.TopicPeerStatsUpdated, peerStatsHandler) + _ = e.bus.Subscribe(app.TopicInterfaceStatsUpdated, interfaceStatsHandler) + defer e.bus.Unsubscribe(app.TopicInterfaceStatsUpdated, interfaceStatsHandler) + + // Keep connection open until client disconnects or context is cancelled + go func() { + for { + if _, _, err := conn.ReadMessage(); err != nil { + cancel() + return + } + } + }() + + <-ctx.Done() + } +} diff --git a/internal/app/eventbus.go b/internal/app/eventbus.go index d411aa6..4e265fe 100644 --- a/internal/app/eventbus.go +++ b/internal/app/eventbus.go @@ -26,6 +26,7 @@ const TopicUserEnabled = "user:enabled" const TopicInterfaceCreated = "interface:created" const TopicInterfaceUpdated = "interface:updated" const TopicInterfaceDeleted = "interface:deleted" +const TopicInterfaceStatsUpdated = "interface:stats:updated" // endregion interface-events @@ -37,6 +38,7 @@ const TopicPeerUpdated = "peer:updated" const TopicPeerInterfaceUpdated = "peer:interface:updated" const TopicPeerIdentifierUpdated = "peer:identifier:updated" const TopicPeerStateChanged = "peer:state:changed" +const TopicPeerStatsUpdated = "peer:stats:updated" // endregion peer-events diff --git a/internal/app/wireguard/statistics.go b/internal/app/wireguard/statistics.go index 78ca6eb..bbc0764 100644 --- a/internal/app/wireguard/statistics.go +++ b/internal/app/wireguard/statistics.go @@ -121,15 +121,25 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) { "error", err) continue } + now := time.Now() err = c.db.UpdateInterfaceStatus(ctx, in.Identifier, func(i *domain.InterfaceStatus) (*domain.InterfaceStatus, error) { - i.UpdatedAt = time.Now() + td := domain.CalculateTrafficDelta( + string(in.Identifier), + i.UpdatedAt, now, + i.BytesTransmitted, physicalInterface.BytesUpload, + i.BytesReceived, physicalInterface.BytesDownload, + ) + i.UpdatedAt = now i.BytesReceived = physicalInterface.BytesDownload i.BytesTransmitted = physicalInterface.BytesUpload // Update prometheus metrics go c.updateInterfaceMetrics(*i) + // Publish stats update event + c.bus.Publish(app.TopicInterfaceStatsUpdated, td) + return i, nil }) if err != nil { @@ -172,6 +182,7 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) { slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err) continue } + now := time.Now() for _, peer := range peers { var connectionStateChanged bool var newPeerStatus domain.PeerStatus @@ -184,8 +195,15 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) { lastHandshake = &peer.LastHandshake } + td := domain.CalculateTrafficDelta( + string(peer.Identifier), + p.UpdatedAt, now, + p.BytesTransmitted, peer.BytesDownload, + p.BytesReceived, peer.BytesUpload, + ) + // calculate if session was restarted - p.UpdatedAt = time.Now() + p.UpdatedAt = now p.LastSessionStart = getSessionStartTime(*p, peer.BytesUpload, peer.BytesDownload, lastHandshake) p.BytesReceived = peer.BytesUpload // store bytes that where uploaded from the peer and received by the server @@ -195,7 +213,8 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) { p.CalcConnected() if wasConnected != p.IsConnected { - slog.Debug("peer connection state changed", "peer", peer.Identifier, "connected", p.IsConnected) + slog.Debug("peer connection state changed", + "peer", peer.Identifier, "connected", p.IsConnected) connectionStateChanged = true newPeerStatus = *p // store new status for event publishing } @@ -203,6 +222,9 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) { // Update prometheus metrics go c.updatePeerMetrics(ctx, *p) + // Publish stats update event + c.bus.Publish(app.TopicPeerStatsUpdated, td) + return p, nil }) if err != nil { diff --git a/internal/domain/statistics.go b/internal/domain/statistics.go index cbc987d..aa205e8 100644 --- a/internal/domain/statistics.go +++ b/internal/domain/statistics.go @@ -61,3 +61,25 @@ func (r PingerResult) AverageRtt() time.Duration { } return total / time.Duration(len(r.Rtts)) } + +type TrafficDelta struct { + EntityId string `json:"EntityId"` // Either peerId or interfaceId + BytesReceivedPerSecond uint64 `json:"BytesReceived"` + BytesTransmittedPerSecond uint64 `json:"BytesTransmitted"` +} + +func CalculateTrafficDelta(id string, oldTime, newTime time.Time, oldTx, newTx, oldRx, newRx uint64) TrafficDelta { + timeDiff := uint64(newTime.Sub(oldTime).Seconds()) + if timeDiff == 0 { + return TrafficDelta{ + EntityId: id, + BytesReceivedPerSecond: 0, + BytesTransmittedPerSecond: 0, + } + } + return TrafficDelta{ + EntityId: id, + BytesReceivedPerSecond: (newRx - oldRx) / timeDiff, + BytesTransmittedPerSecond: (newTx - oldTx) / timeDiff, + } +}