mirror of
https://github.com/h44z/wg-portal.git
synced 2026-06-07 00:56:22 +00:00
Merge commit from fork
* sec: do not expose traffic stats to all users, harden origin check in websocket endpoint * add tests to validate new logic
This commit is contained in:
@@ -135,7 +135,7 @@ func main() {
|
|||||||
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
|
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
|
||||||
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard)
|
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard)
|
||||||
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
|
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
|
||||||
apiV0EndpointWebsocket := handlersV0.NewWebsocketEndpoint(cfg, apiV0Auth, eventBus)
|
apiV0EndpointWebsocket := handlersV0.NewWebsocketEndpoint(cfg, apiV0Auth, eventBus, apiV0BackendPeers)
|
||||||
|
|
||||||
apiFrontend := handlersV0.NewRestApi(apiV0Session,
|
apiFrontend := handlersV0.NewRestApi(apiV0Session,
|
||||||
apiV0EndpointAuth,
|
apiV0EndpointAuth,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -19,23 +20,28 @@ type WebsocketEventBus interface {
|
|||||||
Unsubscribe(topic string, fn any) error
|
Unsubscribe(topic string, fn any) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WebsocketPeerService interface {
|
||||||
|
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||||
|
}
|
||||||
|
|
||||||
type WebsocketEndpoint struct {
|
type WebsocketEndpoint struct {
|
||||||
authenticator Authenticator
|
authenticator Authenticator
|
||||||
bus WebsocketEventBus
|
bus WebsocketEventBus
|
||||||
|
peerService WebsocketPeerService
|
||||||
|
|
||||||
upgrader websocket.Upgrader
|
upgrader websocket.Upgrader
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketEventBus) *WebsocketEndpoint {
|
func NewWebsocketEndpoint(cfg *config.Config, auth Authenticator, bus WebsocketEventBus, peerService WebsocketPeerService) *WebsocketEndpoint {
|
||||||
return &WebsocketEndpoint{
|
return &WebsocketEndpoint{
|
||||||
authenticator: auth,
|
authenticator: auth,
|
||||||
bus: bus,
|
bus: bus,
|
||||||
|
peerService: peerService,
|
||||||
upgrader: websocket.Upgrader{
|
upgrader: websocket.Upgrader{
|
||||||
ReadBufferSize: 1024,
|
ReadBufferSize: 1024,
|
||||||
WriteBufferSize: 1024,
|
WriteBufferSize: 1024,
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
origin := r.Header.Get("Origin")
|
return matchOrigin(cfg.Web.ExternalUrl, r.Header.Get("Origin"))
|
||||||
return strings.HasPrefix(origin, cfg.Web.ExternalUrl)
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -57,6 +63,8 @@ type wsMessage struct {
|
|||||||
|
|
||||||
func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
|
func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userInfo := domain.GetUserInfo(r.Context())
|
||||||
|
|
||||||
conn, err := e.upgrader.Upgrade(w, r, nil)
|
conn, err := e.upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@@ -74,9 +82,29 @@ func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
peerStatsHandler := func(status domain.TrafficDelta) {
|
peerStatsHandler := func(status domain.TrafficDelta) {
|
||||||
|
if !userInfo.IsAdmin {
|
||||||
|
// lookup peer user-info to validate ownership
|
||||||
|
peer, err := e.peerService.GetPeer(ctx, domain.PeerIdentifier(status.EntityId))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.UserIdentifier == "" {
|
||||||
|
return // if peer is not assigned to any user, dont send stats
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.UserIdentifier != userInfo.Id {
|
||||||
|
return // only expose stats for own peers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
_ = writeJSON(wsMessage{Type: "peer_stats", Data: status})
|
_ = writeJSON(wsMessage{Type: "peer_stats", Data: status})
|
||||||
}
|
}
|
||||||
interfaceStatsHandler := func(status domain.TrafficDelta) {
|
interfaceStatsHandler := func(status domain.TrafficDelta) {
|
||||||
|
if !userInfo.IsAdmin {
|
||||||
|
return // interface stats will only be exposed to admins
|
||||||
|
}
|
||||||
|
|
||||||
_ = writeJSON(wsMessage{Type: "interface_stats", Data: status})
|
_ = writeJSON(wsMessage{Type: "interface_stats", Data: status})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,3 +126,18 @@ func (e WebsocketEndpoint) handleWebsocket() http.HandlerFunc {
|
|||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func matchOrigin(externalBaseUrl, origin string) bool {
|
||||||
|
originURL, err := url.Parse(origin)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
externalURL, err := url.Parse(externalBaseUrl)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return originURL.Scheme == externalURL.Scheme &&
|
||||||
|
strings.EqualFold(originURL.Host, externalURL.Host)
|
||||||
|
}
|
||||||
|
|||||||
224
internal/app/api/v0/handlers/endpoint_websocket_test.go
Normal file
224
internal/app/api/v0/handlers/endpoint_websocket_test.go
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
evbus "github.com/vardius/message-bus"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/app"
|
||||||
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// region test-helper
|
||||||
|
|
||||||
|
type websocketTestPeerService struct {
|
||||||
|
peers map[domain.PeerIdentifier]*domain.Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s websocketTestPeerService) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
|
||||||
|
peer, ok := s.peers[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("peer not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestWebsocketConnection(
|
||||||
|
t *testing.T,
|
||||||
|
bus evbus.MessageBus,
|
||||||
|
userInfo *domain.ContextUserInfo,
|
||||||
|
peers map[domain.PeerIdentifier]*domain.Peer,
|
||||||
|
) (*websocket.Conn, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
endpoint := NewWebsocketEndpoint(cfg, nil, bus, websocketTestPeerService{peers: peers})
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r = r.WithContext(domain.SetUserInfo(r.Context(), userInfo))
|
||||||
|
endpoint.handleWebsocket()(w, r)
|
||||||
|
}))
|
||||||
|
cfg.Web.ExternalUrl = server.URL
|
||||||
|
|
||||||
|
wsURL := "ws" + server.URL[len("http"):]
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, http.Header{"Origin": []string{server.URL}})
|
||||||
|
if err != nil {
|
||||||
|
server.Close()
|
||||||
|
t.Fatalf("failed to dial websocket: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
conn.Close()
|
||||||
|
server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertWebsocketMessage(t *testing.T, conn *websocket.Conn, messageType string, entityId string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
|
||||||
|
t.Fatalf("failed to set read deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var message wsMessage
|
||||||
|
if err := conn.ReadJSON(&message); err != nil {
|
||||||
|
t.Fatalf("failed to read websocket message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.Type != messageType {
|
||||||
|
t.Fatalf("unexpected message type: got %q, want %q", message.Type, messageType)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok := message.Data.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("unexpected message data type: %T", message.Data)
|
||||||
|
}
|
||||||
|
if data["EntityId"] != entityId {
|
||||||
|
t.Fatalf("unexpected entity id: got %v, want %q", data["EntityId"], entityId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertNoWebsocketMessage(t *testing.T, conn *websocket.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)); err != nil {
|
||||||
|
t.Fatalf("failed to set read deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var message wsMessage
|
||||||
|
if err := conn.ReadJSON(&message); err == nil {
|
||||||
|
t.Fatalf("unexpected websocket message: %+v", message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion test-helper
|
||||||
|
|
||||||
|
func TestWebsocketEndpointAllowsOwnPeerStatsForNonAdmin(t *testing.T) {
|
||||||
|
bus := evbus.New(10)
|
||||||
|
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"},
|
||||||
|
map[domain.PeerIdentifier]*domain.Peer{
|
||||||
|
"own-peer": {Identifier: "own-peer", UserIdentifier: "user-a"},
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "own-peer", BytesReceivedPerSecond: 1})
|
||||||
|
assertWebsocketMessage(t, conn, "peer_stats", "own-peer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketEndpointFiltersOtherPeerStatsForNonAdmin(t *testing.T) {
|
||||||
|
bus := evbus.New(10)
|
||||||
|
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"},
|
||||||
|
map[domain.PeerIdentifier]*domain.Peer{
|
||||||
|
"other-peer": {Identifier: "other-peer", UserIdentifier: "user-b"},
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "other-peer", BytesReceivedPerSecond: 1})
|
||||||
|
assertNoWebsocketMessage(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketEndpointFiltersUnknownPeerStatsForNonAdmin(t *testing.T) {
|
||||||
|
bus := evbus.New(10)
|
||||||
|
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"},
|
||||||
|
map[domain.PeerIdentifier]*domain.Peer{
|
||||||
|
"other-peer": {Identifier: "other-peer", UserIdentifier: ""},
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "other-peer", BytesReceivedPerSecond: 1})
|
||||||
|
assertNoWebsocketMessage(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketEndpointFiltersUnknownPeerStatsForNonAdmin2(t *testing.T) {
|
||||||
|
bus := evbus.New(10)
|
||||||
|
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"}, nil)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "unknown-peer", BytesReceivedPerSecond: 1})
|
||||||
|
assertNoWebsocketMessage(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketEndpointFiltersInterfaceStatsForNonAdmin(t *testing.T) {
|
||||||
|
bus := evbus.New(10)
|
||||||
|
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "user-a"}, nil)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
bus.Publish(app.TopicInterfaceStatsUpdated, domain.TrafficDelta{EntityId: "wg0", BytesReceivedPerSecond: 1})
|
||||||
|
assertNoWebsocketMessage(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketEndpointAllowsAllStatsForAdmin(t *testing.T) {
|
||||||
|
bus := evbus.New(10)
|
||||||
|
conn, cleanup := newTestWebsocketConnection(t, bus, &domain.ContextUserInfo{Id: "admin", IsAdmin: true}, nil)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
bus.Publish(app.TopicPeerStatsUpdated, domain.TrafficDelta{EntityId: "other-peer", BytesReceivedPerSecond: 1})
|
||||||
|
assertWebsocketMessage(t, conn, "peer_stats", "other-peer")
|
||||||
|
|
||||||
|
bus.Publish(app.TopicInterfaceStatsUpdated, domain.TrafficDelta{EntityId: "wg0", BytesReceivedPerSecond: 1})
|
||||||
|
assertWebsocketMessage(t, conn, "interface_stats", "wg0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_matchOrigin(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
externalBaseUrl string
|
||||||
|
origin string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "matching origin",
|
||||||
|
externalBaseUrl: "https://example.com",
|
||||||
|
origin: "https://example.com",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "matching origin with path",
|
||||||
|
externalBaseUrl: "https://example.com/app1",
|
||||||
|
origin: "https://example.com/app2",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-matching origin with different host",
|
||||||
|
externalBaseUrl: "https://example.com",
|
||||||
|
origin: "https://example.com.malicious.com",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-matching origin with different scheme",
|
||||||
|
externalBaseUrl: "https://example.com",
|
||||||
|
origin: "http://example.com",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid origin URL",
|
||||||
|
externalBaseUrl: "https://example.com",
|
||||||
|
origin: "://invalid-url",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid externalBaseUrl",
|
||||||
|
externalBaseUrl: "://invalid-url",
|
||||||
|
origin: "https://example.com",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := matchOrigin(tt.externalBaseUrl, tt.origin)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("matchOrigin() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user