package wireguard import ( "context" "log/slog" "time" "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" ) // region dependencies type InterfaceAndPeerDatabaseRepo interface { GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error) SaveInterface( ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error), ) error DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) SavePeer( ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error), ) error DeletePeer(ctx context.Context, id domain.PeerIdentifier) error GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) } type InterfaceController interface { GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) SaveInterface( _ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), ) error DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error SavePeer( _ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), ) error DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error } type WgQuickController interface { ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error UnsetDNS(id domain.InterfaceIdentifier) error } type EventBus interface { // Publish sends a message to the message bus. Publish(topic string, args ...any) // Subscribe subscribes to a topic Subscribe(topic string, fn interface{}) error } // endregion dependencies type Manager struct { cfg *config.Config bus EventBus db InterfaceAndPeerDatabaseRepo wg InterfaceController quick WgQuickController } func NewWireGuardManager( cfg *config.Config, bus EventBus, wg InterfaceController, quick WgQuickController, db InterfaceAndPeerDatabaseRepo, ) (*Manager, error) { m := &Manager{ cfg: cfg, bus: bus, wg: wg, db: db, quick: quick, } m.connectToMessageBus() return m, nil } // StartBackgroundJobs starts background jobs like the expired peers check. // This method is non-blocking. func (m Manager) StartBackgroundJobs(ctx context.Context) { go m.runExpiredPeersCheck(ctx) } func (m Manager) connectToMessageBus() { _ = m.bus.Subscribe(app.TopicUserCreated, m.handleUserCreationEvent) _ = m.bus.Subscribe(app.TopicAuthLogin, m.handleUserLoginEvent) _ = m.bus.Subscribe(app.TopicUserDisabled, m.handleUserDisabledEvent) _ = m.bus.Subscribe(app.TopicUserEnabled, m.handleUserEnabledEvent) _ = m.bus.Subscribe(app.TopicUserDeleted, m.handleUserDeletedEvent) } func (m Manager) handleUserCreationEvent(user *domain.User) { if !m.cfg.Core.CreateDefaultPeerOnCreation { return } slog.Debug("handling new user event", "user", user.Identifier) ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) err := m.CreateDefaultPeer(ctx, user.Identifier) if err != nil { slog.Error("failed to create default peer", "user", user.Identifier, "error", err) return } } func (m Manager) handleUserLoginEvent(userId domain.UserIdentifier) { if !m.cfg.Core.CreateDefaultPeer { return } userPeers, err := m.db.GetUserPeers(context.Background(), userId) if err != nil { slog.Error("failed to retrieve existing peers prior to default peer creation", "user", userId, "error", err) return } if len(userPeers) > 0 { return // user already has peers, skip creation } slog.Debug("handling new user login", "user", userId) ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) err = m.CreateDefaultPeer(ctx, userId) if err != nil { slog.Error("failed to create default peer", "user", userId, "error", err) return } } func (m Manager) handleUserDisabledEvent(user domain.User) { ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) userPeers, err := m.db.GetUserPeers(ctx, user.Identifier) if err != nil { slog.Error("failed to retrieve peers for disabled user", "user", user.Identifier, "error", err) return } for _, peer := range userPeers { if peer.IsDisabled() { continue // peer is already disabled } slog.Debug("disabling peer due to user being disabled", "peer", peer.Identifier, "user", user.Identifier) peer.Disabled = user.Disabled // set to user disabled timestamp peer.DisabledReason = domain.DisabledReasonUserDisabled _, err := m.UpdatePeer(ctx, &peer) if err != nil { slog.Error("failed to disable peer for disabled user", "peer", peer.Identifier, "user", user.Identifier, "error", err) } } } func (m Manager) handleUserEnabledEvent(user domain.User) { if !m.cfg.Core.ReEnablePeerAfterUserEnable { return } ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) userPeers, err := m.db.GetUserPeers(ctx, user.Identifier) if err != nil { slog.Error("failed to retrieve peers for re-enabled user", "user", user.Identifier, "error", err) return } for _, peer := range userPeers { if !peer.IsDisabled() { continue // peer is already active } if peer.DisabledReason != domain.DisabledReasonUserDisabled { continue // peer was disabled for another reason } slog.Debug("enabling peer due to user being enabled", "peer", peer.Identifier, "user", user.Identifier) peer.Disabled = nil peer.DisabledReason = "" _, err := m.UpdatePeer(ctx, &peer) if err != nil { slog.Error("failed to enable peer for enabled user", "peer", peer.Identifier, "user", user.Identifier, "error", err) } } return } func (m Manager) handleUserDeletedEvent(user domain.User) { ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) userPeers, err := m.db.GetUserPeers(ctx, user.Identifier) if err != nil { slog.Error("failed to retrieve peers for deleted user", "user", user.Identifier, "error", err) return } deletionTime := time.Now() for _, peer := range userPeers { if peer.IsDisabled() { continue // peer is already disabled } if m.cfg.Core.DeletePeerAfterUserDeleted { slog.Debug("deleting peer due to user being deleted", "peer", peer.Identifier, "user", user.Identifier) if err := m.DeletePeer(ctx, peer.Identifier); err != nil { slog.Error("failed to delete peer for deleted user", "peer", peer.Identifier, "user", user.Identifier, "error", err) } } else { slog.Debug("disabling peer due to user being deleted", "peer", peer.Identifier, "user", user.Identifier) peer.UserIdentifier = "" // remove user reference peer.Disabled = &deletionTime peer.DisabledReason = domain.DisabledReasonUserDeleted _, err := m.UpdatePeer(ctx, &peer) if err != nil { slog.Error("failed to disable peer for deleted user", "peer", peer.Identifier, "user", user.Identifier, "error", err) } } } } func (m Manager) runExpiredPeersCheck(ctx context.Context) { ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) running := true for running { select { case <-ctx.Done(): running = false continue case <-time.After(m.cfg.Advanced.ExpiryCheckInterval): // select blocks until one of the cases evaluate to true } interfaces, err := m.db.GetAllInterfaces(ctx) if err != nil { slog.Error("failed to fetch all interfaces for expiry check", "error", err) continue } for _, iface := range interfaces { peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) if err != nil { slog.Error("failed to fetch all peers from interface for expiry check", "interface", iface.Identifier, "error", err) continue } m.checkExpiredPeers(ctx, peers) } } } func (m Manager) checkExpiredPeers(ctx context.Context, peers []domain.Peer) { now := time.Now() for _, peer := range peers { if peer.IsExpired() && !peer.IsDisabled() { slog.Info("peer has expired, disabling", "peer", peer.Identifier) peer.Disabled = &now peer.DisabledReason = domain.DisabledReasonExpired _, err := m.UpdatePeer(ctx, &peer) if err != nil { slog.Error("failed to update expired peer", "peer", peer.Identifier, "error", err) } } } }