package wireguard import ( "context" "log/slog" "time" evbus "github.com/vardius/message-bus" "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" ) type Manager struct { cfg *config.Config bus evbus.MessageBus db InterfaceAndPeerDatabaseRepo wg InterfaceController quick WgQuickController } func NewWireGuardManager( cfg *config.Config, bus evbus.MessageBus, wg InterfaceController, quick WgQuickController, db InterfaceAndPeerDatabaseRepo, ) (*Manager, error) { m := &Manager{ cfg: cfg, bus: bus, wg: wg, db: db, quick: quick, } m.connectToMessageBus() return m, nil } func (m Manager) StartBackgroundJobs(ctx context.Context) { go m.runExpiredPeersCheck(ctx) } func (m Manager) connectToMessageBus() { _ = m.bus.Subscribe(app.TopicUserCreated, m.handleUserCreationEvent) _ = m.bus.Subscribe(app.TopicAuthLogin, m.handleUserLoginEvent) _ = m.bus.Subscribe(app.TopicUserDisabled, m.handleUserDisabledEvent) _ = m.bus.Subscribe(app.TopicUserEnabled, m.handleUserEnabledEvent) _ = m.bus.Subscribe(app.TopicUserDeleted, m.handleUserDeletedEvent) } func (m Manager) handleUserCreationEvent(user *domain.User) { if !m.cfg.Core.CreateDefaultPeerOnCreation { return } slog.Debug("handling new user event", "user", user.Identifier) ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) err := m.CreateDefaultPeer(ctx, user.Identifier) if err != nil { slog.Error("failed to create default peer", "user", user.Identifier, "error", err) return } } func (m Manager) handleUserLoginEvent(userId domain.UserIdentifier) { if !m.cfg.Core.CreateDefaultPeer { return } userPeers, err := m.db.GetUserPeers(context.Background(), userId) if err != nil { slog.Error("failed to retrieve existing peers prior to default peer creation", "user", userId, "error", err) return } if len(userPeers) > 0 { return // user already has peers, skip creation } slog.Debug("handling new user login", "user", userId) ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) err = m.CreateDefaultPeer(ctx, userId) if err != nil { slog.Error("failed to create default peer", "user", userId, "error", err) return } } func (m Manager) handleUserDisabledEvent(user domain.User) { ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) userPeers, err := m.db.GetUserPeers(ctx, user.Identifier) if err != nil { slog.Error("failed to retrieve peers for disabled user", "user", user.Identifier, "error", err) return } for _, peer := range userPeers { if peer.IsDisabled() { continue // peer is already disabled } slog.Debug("disabling peer due to user being disabled", "peer", peer.Identifier, "user", user.Identifier) peer.Disabled = user.Disabled // set to user disabled timestamp peer.DisabledReason = domain.DisabledReasonUserDisabled _, err := m.UpdatePeer(ctx, &peer) if err != nil { slog.Error("failed to disable peer for disabled user", "peer", peer.Identifier, "user", user.Identifier, "error", err) } } } func (m Manager) handleUserEnabledEvent(user domain.User) { if !m.cfg.Core.ReEnablePeerAfterUserEnable { return } ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) userPeers, err := m.db.GetUserPeers(ctx, user.Identifier) if err != nil { slog.Error("failed to retrieve peers for re-enabled user", "user", user.Identifier, "error", err) return } for _, peer := range userPeers { if !peer.IsDisabled() { continue // peer is already active } if peer.DisabledReason != domain.DisabledReasonUserDisabled { continue // peer was disabled for another reason } slog.Debug("enabling peer due to user being enabled", "peer", peer.Identifier, "user", user.Identifier) peer.Disabled = nil peer.DisabledReason = "" _, err := m.UpdatePeer(ctx, &peer) if err != nil { slog.Error("failed to enable peer for enabled user", "peer", peer.Identifier, "user", user.Identifier, "error", err) } } return } func (m Manager) handleUserDeletedEvent(user domain.User) { ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) userPeers, err := m.db.GetUserPeers(ctx, user.Identifier) if err != nil { slog.Error("failed to retrieve peers for deleted user", "user", user.Identifier, "error", err) return } deletionTime := time.Now() for _, peer := range userPeers { if peer.IsDisabled() { continue // peer is already disabled } if m.cfg.Core.DeletePeerAfterUserDeleted { slog.Debug("deleting peer due to user being deleted", "peer", peer.Identifier, "user", user.Identifier) if err := m.DeletePeer(ctx, peer.Identifier); err != nil { slog.Error("failed to delete peer for deleted user", "peer", peer.Identifier, "user", user.Identifier, "error", err) } } else { slog.Debug("disabling peer due to user being deleted", "peer", peer.Identifier, "user", user.Identifier) peer.UserIdentifier = "" // remove user reference peer.Disabled = &deletionTime peer.DisabledReason = domain.DisabledReasonUserDeleted _, err := m.UpdatePeer(ctx, &peer) if err != nil { slog.Error("failed to disable peer for deleted user", "peer", peer.Identifier, "user", user.Identifier, "error", err) } } } } func (m Manager) runExpiredPeersCheck(ctx context.Context) { ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) running := true for running { select { case <-ctx.Done(): running = false continue case <-time.After(m.cfg.Advanced.ExpiryCheckInterval): // select blocks until one of the cases evaluate to true } interfaces, err := m.db.GetAllInterfaces(ctx) if err != nil { slog.Error("failed to fetch all interfaces for expiry check", "error", err) continue } for _, iface := range interfaces { peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) if err != nil { slog.Error("failed to fetch all peers from interface for expiry check", "interface", iface.Identifier, "error", err) continue } m.checkExpiredPeers(ctx, peers) } } } func (m Manager) checkExpiredPeers(ctx context.Context, peers []domain.Peer) { now := time.Now() for _, peer := range peers { if peer.IsExpired() && !peer.IsDisabled() { slog.Info("peer has expired, disabling", "peer", peer.Identifier) peer.Disabled = &now peer.DisabledReason = domain.DisabledReasonExpired _, err := m.UpdatePeer(ctx, &peer) if err != nil { slog.Error("failed to update expired peer", "peer", peer.Identifier, "error", err) } } } }