diff --git a/internal/app/wireguard/wireguard.go b/internal/app/wireguard/wireguard.go index 46ded30..eb76bc9 100644 --- a/internal/app/wireguard/wireguard.go +++ b/internal/app/wireguard/wireguard.go @@ -3,6 +3,7 @@ package wireguard import ( "context" "log/slog" + "sync" "time" "github.com/h44z/wg-portal/internal/app" @@ -76,6 +77,8 @@ type Manager struct { db InterfaceAndPeerDatabaseRepo wg InterfaceController quick WgQuickController + + userLockMap *sync.Map } func NewWireGuardManager( @@ -86,11 +89,12 @@ func NewWireGuardManager( db InterfaceAndPeerDatabaseRepo, ) (*Manager, error) { m := &Manager{ - cfg: cfg, - bus: bus, - wg: wg, - db: db, - quick: quick, + cfg: cfg, + bus: bus, + wg: wg, + db: db, + quick: quick, + userLockMap: &sync.Map{}, } m.connectToMessageBus() @@ -117,6 +121,12 @@ func (m Manager) handleUserCreationEvent(user domain.User) { return } + _, loaded := m.userLockMap.LoadOrStore(user.Identifier, "create") + if loaded { + return // another goroutine is already handling this user + } + defer m.userLockMap.Delete(user.Identifier) + slog.Debug("handling new user event", "user", user.Identifier) ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) @@ -132,6 +142,12 @@ func (m Manager) handleUserLoginEvent(userId domain.UserIdentifier) { return } + _, loaded := m.userLockMap.LoadOrStore(userId, "login") + if loaded { + return // another goroutine is already handling this user + } + defer m.userLockMap.Delete(userId) + userPeers, err := m.db.GetUserPeers(context.Background(), userId) if err != nil { slog.Error("failed to retrieve existing peers prior to default peer creation", diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index ade8fa4..2131323 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "slices" "time" "github.com/h44z/wg-portal/internal/app" @@ -23,12 +24,24 @@ func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdenti return fmt.Errorf("failed to fetch all interfaces: %w", err) } + userPeers, err := m.db.GetUserPeers(context.Background(), userId) + if err != nil { + return fmt.Errorf("failed to retrieve existing peers prior to default peer creation: %w", err) + } + var newPeers []domain.Peer for _, iface := range existingInterfaces { if iface.Type != domain.InterfaceTypeServer { continue // only create default peers for server interfaces } + peerAlreadyCreated := slices.ContainsFunc(userPeers, func(peer domain.Peer) bool { + return peer.InterfaceIdentifier == iface.Identifier + }) + if peerAlreadyCreated { + continue // skip creation if a peer already exists for this interface + } + peer, err := m.PreparePeer(ctx, iface.Identifier) if err != nil { return fmt.Errorf("failed to create default peer for interface %s: %w", iface.Identifier, err) diff --git a/internal/config/config.go b/internal/config/config.go index dedebac..44ef8e4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -190,6 +190,8 @@ func GetConfig() (*Config, error) { return nil, fmt.Errorf("failed to load config from yaml: %w", err) } + cfg.Web.Sanitize() + return cfg, nil } diff --git a/internal/config/web.go b/internal/config/web.go index 1743305..e4d8dd3 100644 --- a/internal/config/web.go +++ b/internal/config/web.go @@ -1,5 +1,7 @@ package config +import "strings" + // WebConfig contains the configuration for the web server. type WebConfig struct { // RequestLogging enables logging of all HTTP requests. @@ -26,3 +28,7 @@ type WebConfig struct { // KeyFile is the path to the TLS certificate key file. KeyFile string `yaml:"key_file"` } + +func (c *WebConfig) Sanitize() { + c.ExternalUrl = strings.TrimRight(c.ExternalUrl, "/") +}