diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go
index 83512ec..0592b82 100644
--- a/cmd/wg-portal/main.go
+++ b/cmd/wg-portal/main.go
@@ -135,6 +135,7 @@ func main() {
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard)
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
+ apiV0EndpointWebsocket := handlersV0.NewWebsocketEndpoint(cfg, apiV0Auth, eventBus)
apiFrontend := handlersV0.NewRestApi(apiV0Session,
apiV0EndpointAuth,
@@ -144,6 +145,7 @@ func main() {
apiV0EndpointPeers,
apiV0EndpointConfig,
apiV0EndpointTest,
+ apiV0EndpointWebsocket,
)
// endregion API v0 (SPA frontend)
diff --git a/frontend/src/helpers/websocket-wrapper.js b/frontend/src/helpers/websocket-wrapper.js
new file mode 100644
index 0000000..f5aa0c1
--- /dev/null
+++ b/frontend/src/helpers/websocket-wrapper.js
@@ -0,0 +1,86 @@
+import { peerStore } from '@/stores/peers';
+import { interfaceStore } from '@/stores/interfaces';
+import { authStore } from '@/stores/auth';
+
+let socket = null;
+let reconnectTimer = null;
+let failureCount = 0;
+
+export const websocketWrapper = {
+ connect() {
+ if (socket) {
+ console.log('WebSocket already connected, re-using existing connection.');
+ return;
+ }
+
+ const protocol = WGPORTAL_BACKEND_BASE_URL.startsWith('https://') ? 'wss://' : 'ws://';
+ const baseUrl = WGPORTAL_BACKEND_BASE_URL.replace(/^https?:\/\//, '');
+ const url = `${protocol}${baseUrl}/ws`;
+
+ socket = new WebSocket(url);
+
+ socket.onopen = () => {
+ console.log('WebSocket connected');
+ failureCount = 0;
+ if (reconnectTimer) {
+ clearInterval(reconnectTimer);
+ reconnectTimer = null;
+ }
+ };
+
+ socket.onclose = () => {
+ console.log('WebSocket disconnected');
+ failureCount++;
+ socket = null;
+ this.scheduleReconnect();
+ };
+
+ socket.onerror = (error) => {
+ console.error('WebSocket error:', error);
+ failureCount++;
+ socket.close();
+ socket = null;
+ };
+
+ socket.onmessage = (event) => {
+ const message = JSON.parse(event.data);
+ switch (message.type) {
+ case 'peer_stats':
+ peerStore().updatePeerTrafficStats(message.data);
+ break;
+ case 'interface_stats':
+ interfaceStore().updateInterfaceTrafficStats(message.data);
+ break;
+ }
+ };
+ },
+
+ disconnect() {
+ if (socket) {
+ socket.close();
+ socket = null;
+ }
+ if (reconnectTimer) {
+ clearInterval(reconnectTimer);
+ reconnectTimer = null;
+ failureCount = 0;
+ }
+ },
+
+ scheduleReconnect() {
+ if (reconnectTimer) return;
+ if (!authStore().IsAuthenticated) return; // Don't reconnect if not logged in
+
+ reconnectTimer = setInterval(() => {
+ if (failureCount > 2) {
+ console.log('WebSocket connection unavailable, giving up.');
+ clearInterval(reconnectTimer);
+ reconnectTimer = null;
+ return;
+ }
+
+ console.log('Attempting to reconnect WebSocket...');
+ this.connect();
+ }, 5000);
+ }
+};
diff --git a/frontend/src/stores/auth.js b/frontend/src/stores/auth.js
index 50c0d09..83d97d2 100644
--- a/frontend/src/stores/auth.js
+++ b/frontend/src/stores/auth.js
@@ -2,6 +2,7 @@ import { defineStore } from 'pinia'
import { notify } from "@kyvg/vue3-notification";
import { apiWrapper } from '@/helpers/fetch-wrapper'
+import { websocketWrapper } from '@/helpers/websocket-wrapper'
import router from '../router'
import { browserSupportsWebAuthn,startRegistration,startAuthentication } from '@simplewebauthn/browser';
import {base64_url_encode} from "@/helpers/encoding";
@@ -295,9 +296,11 @@ export const authStore = defineStore('auth',{
}
}
localStorage.setItem('user', JSON.stringify(this.user))
+ websocketWrapper.connect()
} else {
this.user = null
localStorage.removeItem('user')
+ websocketWrapper.disconnect()
}
},
setWebAuthnCredentials(credentials) {
diff --git a/frontend/src/stores/interfaces.js b/frontend/src/stores/interfaces.js
index 6f1f529..efe75c2 100644
--- a/frontend/src/stores/interfaces.js
+++ b/frontend/src/stores/interfaces.js
@@ -14,6 +14,7 @@ export const interfaceStore = defineStore('interfaces', {
configuration: "",
selected: "",
fetching: false,
+ trafficStats: {},
}),
getters: {
Count: (state) => state.interfaces.length,
@@ -24,6 +25,9 @@ export const interfaceStore = defineStore('interfaces', {
},
GetSelected: (state) => state.interfaces.find((i) => i.Identifier === state.selected) || state.interfaces[0],
isFetching: (state) => state.fetching,
+ TrafficStats: (state) => {
+ return (state.selected in state.trafficStats) ? state.trafficStats[state.selected] : { Received: 0, Transmitted: 0 }
+ },
},
actions: {
setInterfaces(interfaces) {
@@ -34,6 +38,14 @@ export const interfaceStore = defineStore('interfaces', {
this.selected = ""
}
this.fetching = false
+ this.trafficStats = {}
+ },
+ updateInterfaceTrafficStats(interfaceStats) {
+ const id = interfaceStats.EntityId;
+ this.trafficStats[id] = {
+ Received: interfaceStats.BytesReceived,
+ Transmitted: interfaceStats.BytesTransmitted,
+ };
},
async LoadInterfaces() {
this.fetching = true
diff --git a/frontend/src/stores/peers.js b/frontend/src/stores/peers.js
index 2f71656..8e80b2c 100644
--- a/frontend/src/stores/peers.js
+++ b/frontend/src/stores/peers.js
@@ -23,6 +23,7 @@ export const peerStore = defineStore('peers', {
fetching: false,
sortKey: 'IsConnected', // Default sort key
sortOrder: -1, // 1 for ascending, -1 for descending
+ trafficStats: {},
}),
getters: {
Find: (state) => {
@@ -76,6 +77,9 @@ export const peerStore = defineStore('peers', {
Statistics: (state) => {
return (id) => state.statsEnabled && (id in state.stats) ? state.stats[id] : freshStats()
},
+ TrafficStats: (state) => {
+ return (id) => (id in state.trafficStats) ? state.trafficStats[id] : { Received: 0, Transmitted: 0 }
+ },
hasStatistics: (state) => state.statsEnabled,
},
@@ -111,6 +115,7 @@ export const peerStore = defineStore('peers', {
this.peers = peers
this.calculatePages()
this.fetching = false
+ this.trafficStats = {}
},
setPeer(peer) {
this.peer = peer
@@ -126,11 +131,19 @@ export const peerStore = defineStore('peers', {
if (!statsResponse) {
this.stats = {}
this.statsEnabled = false
+ this.trafficStats = {}
} else {
this.stats = statsResponse.Stats
this.statsEnabled = statsResponse.Enabled
}
},
+ updatePeerTrafficStats(peerStats) {
+ const id = peerStats.EntityId;
+ this.trafficStats[id] = {
+ Received: peerStats.BytesReceived,
+ Transmitted: peerStats.BytesTransmitted,
+ };
+ },
async Reset() {
this.setPeers([])
this.setStats(undefined)
diff --git a/frontend/src/views/InterfaceView.vue b/frontend/src/views/InterfaceView.vue
index ed9dc35..559a056 100644
--- a/frontend/src/views/InterfaceView.vue
+++ b/frontend/src/views/InterfaceView.vue
@@ -210,6 +210,12 @@ onMounted(async () => {
{{ $t('interfaces.interface.headline') }}
{{interfaces.GetSelected.Identifier}} ({{ $t('modals.interface-edit.mode.' + interfaces.GetSelected.Mode )}} | {{ $t('interfaces.interface.backend') + ": " + calculateBackendName }}
)
+
+
+ Traffic: {{ humanFileSize(interfaces.TrafficStats.Received) }}/s
+ {{ humanFileSize(interfaces.TrafficStats.Transmitted) }}/s
+
+
@@ -451,14 +457,19 @@ onMounted(async () => {
{{peer.Endpoint.Value}} |
- {{ $t('interfaces.peer-connected') }}
+
|
- {{ humanFileSize(peers.Statistics(peer.Identifier).BytesReceived) }} / {{ humanFileSize(peers.Statistics(peer.Identifier).BytesTransmitted) }}
+
+
+ {{ humanFileSize(peers.TrafficStats(peer.Identifier).Received) }}/s
+ {{ humanFileSize(peers.TrafficStats(peer.Identifier).Transmitted) }}/s
+
+
|
diff --git a/go.mod b/go.mod
index 41130f4..8a3ad3e 100644
--- a/go.mod
+++ b/go.mod
@@ -12,6 +12,7 @@ require (
github.com/go-playground/validator/v10 v10.30.1
github.com/go-webauthn/webauthn v0.15.0
github.com/google/uuid v1.6.0
+ github.com/gorilla/websocket v1.5.3
github.com/prometheus-community/pro-bing v0.7.0
github.com/prometheus/client_golang v1.23.2
github.com/stretchr/testify v1.11.1
diff --git a/go.sum b/go.sum
index 93ad84f..c60bc20 100644
--- a/go.sum
+++ b/go.sum
@@ -130,6 +130,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
+github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
+github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
diff --git a/internal/app/api/core/middleware/logging/writer.go b/internal/app/api/core/middleware/logging/writer.go
index 4e3c42f..d77a7f1 100644
--- a/internal/app/api/core/middleware/logging/writer.go
+++ b/internal/app/api/core/middleware/logging/writer.go
@@ -1,6 +1,8 @@
package logging
import (
+ "bufio"
+ "net"
"net/http"
)
@@ -38,6 +40,12 @@ func (w *writerWrapper) Write(data []byte) (int, error) {
return n, err
}
+// Hijack wraps the Hijack method of the ResponseWriter and returns the hijacked connection.
+// This is required for websockets to work.
+func (w *writerWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ return http.NewResponseController(w.ResponseWriter).Hijack()
+}
+
// newWriterWrapper returns a new writerWrapper that wraps the given http.ResponseWriter.
// It initializes the StatusCode to http.StatusOK.
func newWriterWrapper(w http.ResponseWriter) *writerWrapper {
diff --git a/internal/app/api/v0/handlers/endpoint_websocket.go b/internal/app/api/v0/handlers/endpoint_websocket.go
new file mode 100644
index 0000000..5dcc35a
--- /dev/null
+++ b/internal/app/api/v0/handlers/endpoint_websocket.go
@@ -0,0 +1,100 @@
+package handlers
+
+import (
+ "context"
+ "net/http"
+ "strings"
+ "sync"
+
+ "github.com/go-pkgz/routegroup"
+ "github.com/gorilla/websocket"
+
+ "github.com/h44z/wg-portal/internal/app"
+ "github.com/h44z/wg-portal/internal/config"
+ "github.com/h44z/wg-portal/internal/domain"
+)
+
+type WebsocketEventBus interface {
+ Subscribe(topic string, fn any) error
+ Unsubscribe(topic string, fn any) error
+}
+
+type WebsocketEndpoint struct {
+ authenticator Authenticator
+ bus WebsocketEventBus
+
+ upgrader websocket.Upgrader
+}
+
+func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketEventBus) *WebsocketEndpoint {
+ return &WebsocketEndpoint{
+ authenticator: auth,
+ bus: bus,
+ upgrader: websocket.Upgrader{
+ ReadBufferSize: 1024,
+ WriteBufferSize: 1024,
+ CheckOrigin: func(r *http.Request) bool {
+ origin := r.Header.Get("Origin")
+ return strings.HasPrefix(origin, cfg.Web.ExternalUrl)
+ },
+ },
+ }
+}
+
+func (e WebsocketEndpoint) GetName() string {
+ return "WebsocketEndpoint"
+}
+
+func (e WebsocketEndpoint) RegisterRoutes(g *routegroup.Bundle) {
+ g.With(e.authenticator.LoggedIn()).HandleFunc("GET /ws", e.handleWebsocket())
+}
+
+// wsMessage represents a message sent over websocket to the frontend
+type wsMessage struct {
+ Type string `json:"type"` // either "peer_stats" or "interface_stats"
+ Data any `json:"data"` // domain.TrafficDelta
+}
+
+func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ conn, err := e.upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+
+ ctx, cancel := context.WithCancel(r.Context())
+ defer cancel()
+
+ writeMutex := sync.Mutex{}
+ writeJSON := func(msg wsMessage) error {
+ writeMutex.Lock()
+ defer writeMutex.Unlock()
+ return conn.WriteJSON(msg)
+ }
+
+ peerStatsHandler := func(status domain.TrafficDelta) {
+ _ = writeJSON(wsMessage{Type: "peer_stats", Data: status})
+ }
+ interfaceStatsHandler := func(status domain.TrafficDelta) {
+ _ = writeJSON(wsMessage{Type: "interface_stats", Data: status})
+ }
+
+ _ = e.bus.Subscribe(app.TopicPeerStatsUpdated, peerStatsHandler)
+ defer e.bus.Unsubscribe(app.TopicPeerStatsUpdated, peerStatsHandler)
+ _ = e.bus.Subscribe(app.TopicInterfaceStatsUpdated, interfaceStatsHandler)
+ defer e.bus.Unsubscribe(app.TopicInterfaceStatsUpdated, interfaceStatsHandler)
+
+ // Keep connection open until client disconnects or context is cancelled
+ go func() {
+ for {
+ if _, _, err := conn.ReadMessage(); err != nil {
+ cancel()
+ return
+ }
+ }
+ }()
+
+ <-ctx.Done()
+ }
+}
diff --git a/internal/app/eventbus.go b/internal/app/eventbus.go
index d411aa6..4e265fe 100644
--- a/internal/app/eventbus.go
+++ b/internal/app/eventbus.go
@@ -26,6 +26,7 @@ const TopicUserEnabled = "user:enabled"
const TopicInterfaceCreated = "interface:created"
const TopicInterfaceUpdated = "interface:updated"
const TopicInterfaceDeleted = "interface:deleted"
+const TopicInterfaceStatsUpdated = "interface:stats:updated"
// endregion interface-events
@@ -37,6 +38,7 @@ const TopicPeerUpdated = "peer:updated"
const TopicPeerInterfaceUpdated = "peer:interface:updated"
const TopicPeerIdentifierUpdated = "peer:identifier:updated"
const TopicPeerStateChanged = "peer:state:changed"
+const TopicPeerStatsUpdated = "peer:stats:updated"
// endregion peer-events
diff --git a/internal/app/wireguard/statistics.go b/internal/app/wireguard/statistics.go
index 78ca6eb..bbc0764 100644
--- a/internal/app/wireguard/statistics.go
+++ b/internal/app/wireguard/statistics.go
@@ -121,15 +121,25 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
"error", err)
continue
}
+ now := time.Now()
err = c.db.UpdateInterfaceStatus(ctx, in.Identifier,
func(i *domain.InterfaceStatus) (*domain.InterfaceStatus, error) {
- i.UpdatedAt = time.Now()
+ td := domain.CalculateTrafficDelta(
+ string(in.Identifier),
+ i.UpdatedAt, now,
+ i.BytesTransmitted, physicalInterface.BytesUpload,
+ i.BytesReceived, physicalInterface.BytesDownload,
+ )
+ i.UpdatedAt = now
i.BytesReceived = physicalInterface.BytesDownload
i.BytesTransmitted = physicalInterface.BytesUpload
// Update prometheus metrics
go c.updateInterfaceMetrics(*i)
+ // Publish stats update event
+ c.bus.Publish(app.TopicInterfaceStatsUpdated, td)
+
return i, nil
})
if err != nil {
@@ -172,6 +182,7 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err)
continue
}
+ now := time.Now()
for _, peer := range peers {
var connectionStateChanged bool
var newPeerStatus domain.PeerStatus
@@ -184,8 +195,15 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
lastHandshake = &peer.LastHandshake
}
+ td := domain.CalculateTrafficDelta(
+ string(peer.Identifier),
+ p.UpdatedAt, now,
+ p.BytesTransmitted, peer.BytesDownload,
+ p.BytesReceived, peer.BytesUpload,
+ )
+
// calculate if session was restarted
- p.UpdatedAt = time.Now()
+ p.UpdatedAt = now
p.LastSessionStart = getSessionStartTime(*p, peer.BytesUpload, peer.BytesDownload,
lastHandshake)
p.BytesReceived = peer.BytesUpload // store bytes that where uploaded from the peer and received by the server
@@ -195,7 +213,8 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
p.CalcConnected()
if wasConnected != p.IsConnected {
- slog.Debug("peer connection state changed", "peer", peer.Identifier, "connected", p.IsConnected)
+ slog.Debug("peer connection state changed",
+ "peer", peer.Identifier, "connected", p.IsConnected)
connectionStateChanged = true
newPeerStatus = *p // store new status for event publishing
}
@@ -203,6 +222,9 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
// Update prometheus metrics
go c.updatePeerMetrics(ctx, *p)
+ // Publish stats update event
+ c.bus.Publish(app.TopicPeerStatsUpdated, td)
+
return p, nil
})
if err != nil {
diff --git a/internal/domain/statistics.go b/internal/domain/statistics.go
index cbc987d..aa205e8 100644
--- a/internal/domain/statistics.go
+++ b/internal/domain/statistics.go
@@ -61,3 +61,25 @@ func (r PingerResult) AverageRtt() time.Duration {
}
return total / time.Duration(len(r.Rtts))
}
+
+type TrafficDelta struct {
+ EntityId string `json:"EntityId"` // Either peerId or interfaceId
+ BytesReceivedPerSecond uint64 `json:"BytesReceived"`
+ BytesTransmittedPerSecond uint64 `json:"BytesTransmitted"`
+}
+
+func CalculateTrafficDelta(id string, oldTime, newTime time.Time, oldTx, newTx, oldRx, newRx uint64) TrafficDelta {
+ timeDiff := uint64(newTime.Sub(oldTime).Seconds())
+ if timeDiff == 0 {
+ return TrafficDelta{
+ EntityId: id,
+ BytesReceivedPerSecond: 0,
+ BytesTransmittedPerSecond: 0,
+ }
+ }
+ return TrafficDelta{
+ EntityId: id,
+ BytesReceivedPerSecond: (newRx - oldRx) / timeDiff,
+ BytesTransmittedPerSecond: (newTx - oldTx) / timeDiff,
+ }
+}
|