2025-03-23 23:09:47 +01:00

327 lines
9.3 KiB
Go

package wireguard
import (
"context"
"log/slog"
"time"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// region dependencies

// InterfaceAndPeerDatabaseRepo is the storage dependency of the Manager (it is
// held in Manager.db). Its methods read, save and delete domain.Interface and
// domain.Peer values and provide a few related lookups.
//
// NOTE(review): the implementation is not part of this file, so the comments
// below describe only what the signatures and the call sites in this file show.
type InterfaceAndPeerDatabaseRepo interface {
	// GetInterface returns the domain.Interface for the given identifier.
	GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
	// GetInterfaceAndPeers returns an interface together with a slice of peers
	// (presumably the peers attached to that interface; confirm against the
	// implementation).
	GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
	// GetPeersStats returns status records for the given peer identifiers.
	GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error)
	// GetAllInterfaces returns a slice of interfaces; the expiry check iterates
	// over this result to visit every interface.
	GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
	// GetInterfaceIps returns a map from interface identifier to a list of CIDRs.
	GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
	// SaveInterface persists the interface with the given identifier. The value
	// to store is produced by updateFunc, which may also return an error.
	// NOTE(review): presumably updateFunc receives the currently stored state; confirm.
	SaveInterface(
		ctx context.Context,
		id domain.InterfaceIdentifier,
		updateFunc func(in *domain.Interface) (*domain.Interface, error),
	) error
	// DeleteInterface removes the interface with the given identifier.
	DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
	// GetInterfacePeers returns the peers for the given interface identifier;
	// the expiry check uses it to collect the peers to inspect.
	GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
	// GetUserPeers returns the peers for the given user identifier; the user
	// event handlers use it to find the peers affected by a user change.
	GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
	// SavePeer persists the peer with the given identifier. The value to store
	// is produced by updateFunc, which may also return an error.
	// NOTE(review): presumably updateFunc receives the currently stored state; confirm.
	SavePeer(
		ctx context.Context,
		id domain.PeerIdentifier,
		updateFunc func(in *domain.Peer) (*domain.Peer, error),
	) error
	// DeletePeer removes the peer with the given identifier.
	DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
	// GetPeer returns the domain.Peer for the given identifier.
	GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
	// GetUsedIpsPerSubnet takes a list of subnets and returns a map keyed by
	// CIDR whose values are lists of CIDRs (judging by the name, the addresses
	// already in use within each subnet; confirm).
	GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
}
// InterfaceController is the Manager's dependency for working with
// domain.PhysicalInterface and domain.PhysicalPeer values (it is held in
// Manager.wg). None of its methods are called in this part of the file.
//
// NOTE(review): the context parameters are declared as blank identifiers here;
// whether implementations honour cancellation cannot be told from this file.
type InterfaceController interface {
	// GetInterfaces returns a slice of physical interfaces.
	GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
	// GetInterface returns the physical interface for the given identifier.
	GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
	// GetPeers returns the physical peers for the given device (interface) identifier.
	GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
	// SaveInterface applies the interface identified by id. The value to apply
	// is produced by updateFunc, which may also return an error.
	// NOTE(review): presumably updateFunc receives the current device state; confirm.
	SaveInterface(
		_ context.Context,
		id domain.InterfaceIdentifier,
		updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
	) error
	// DeleteInterface removes the interface with the given identifier.
	DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
	// SavePeer applies the peer identified by id on the device deviceId. The
	// value to apply is produced by updateFunc, which may also return an error.
	// NOTE(review): presumably updateFunc receives the current peer state; confirm.
	SavePeer(
		_ context.Context,
		deviceId domain.InterfaceIdentifier,
		id domain.PeerIdentifier,
		updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
	) error
	// DeletePeer removes the peer identified by id from the device deviceId.
	DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}
// WgQuickController is the Manager's dependency for hook execution and DNS
// handling of an interface (it is held in Manager.quick). None of its methods
// are called in this part of the file.
//
// NOTE(review): judging by the name, this mirrors functionality of the
// wg-quick tool; confirm against the implementation.
type WgQuickController interface {
	// ExecuteInterfaceHook takes an interface identifier and a hook command string.
	ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
	// SetDNS takes an interface identifier plus a DNS string and a DNS search string.
	SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
	// UnsetDNS takes an interface identifier (presumably reverting SetDNS; confirm).
	UnsetDNS(id domain.InterfaceIdentifier) error
}
// EventBus is the publish/subscribe dependency of the Manager (it is held in
// Manager.bus).
type EventBus interface {
	// Publish sends a message with the given arguments to the message bus
	// under the given topic.
	Publish(topic string, args ...any)
	// Subscribe registers fn for the given topic and reports whether the
	// registration succeeded.
	Subscribe(topic string, fn any) error
}
// endregion dependencies

// Manager bundles the dependencies used by the event handlers and background
// jobs defined in this file. Instances are created by NewWireGuardManager.
type Manager struct {
	// cfg is the application configuration; the handlers read its Core and
	// Advanced sections.
	cfg *config.Config
	// bus is the message bus the manager subscribes its event handlers to.
	bus EventBus
	// db is the storage used to look up interfaces and peers.
	db InterfaceAndPeerDatabaseRepo
	// wg is the controller for physical interfaces and peers.
	wg InterfaceController
	// quick is the controller for interface hooks and DNS handling.
	quick WgQuickController
}
// NewWireGuardManager builds a Manager from the given dependencies and
// subscribes its event handlers to the message bus before returning it.
//
// The returned error is always nil in the current implementation.
func NewWireGuardManager(
	cfg *config.Config,
	bus EventBus,
	wg InterfaceController,
	quick WgQuickController,
	db InterfaceAndPeerDatabaseRepo,
) (*Manager, error) {
	manager := new(Manager)
	manager.cfg = cfg
	manager.bus = bus
	manager.db = db
	manager.wg = wg
	manager.quick = quick

	// All fields must be populated before subscribing: the handlers are bound
	// to a copy of the manager taken at this point (value receiver).
	manager.connectToMessageBus()

	return manager, nil
}
// StartBackgroundJobs starts the manager's background work, which currently
// consists of the periodic expired-peers check, in a separate goroutine.
// This method is non-blocking; the goroutine stops once ctx is cancelled.
func (m Manager) StartBackgroundJobs(ctx context.Context) {
	go m.runExpiredPeersCheck(ctx)
}
// connectToMessageBus subscribes the manager's user-event handlers to their
// topics on the message bus.
//
// A failed subscription does not abort the remaining ones, but it is logged:
// without the log entry the affected handler would silently never run.
func (m Manager) connectToMessageBus() {
	subscriptions := []struct {
		topic   string
		handler any
	}{
		{app.TopicUserCreated, m.handleUserCreationEvent},
		{app.TopicAuthLogin, m.handleUserLoginEvent},
		{app.TopicUserDisabled, m.handleUserDisabledEvent},
		{app.TopicUserEnabled, m.handleUserEnabledEvent},
		{app.TopicUserDeleted, m.handleUserDeletedEvent},
	}

	for _, sub := range subscriptions {
		if err := m.bus.Subscribe(sub.topic, sub.handler); err != nil {
			slog.Error("failed to subscribe to message bus topic",
				"topic", sub.topic,
				"error", err)
		}
	}
}
// handleUserCreationEvent is the handler subscribed to app.TopicUserCreated.
// If Core.CreateDefaultPeerOnCreation is enabled, it creates a default peer
// for the new user through CreateDefaultPeer, using a context that carries the
// system-admin user info. A failure is logged; nothing is returned.
func (m Manager) handleUserCreationEvent(user *domain.User) {
	if !m.cfg.Core.CreateDefaultPeerOnCreation {
		return // feature is switched off
	}

	userID := user.Identifier
	slog.Debug("handling new user event", "user", userID)

	adminCtx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
	if err := m.CreateDefaultPeer(adminCtx, userID); err != nil {
		slog.Error("failed to create default peer", "user", userID, "error", err)
	}
}
// handleUserLoginEvent is the handler subscribed to app.TopicAuthLogin.
// If Core.CreateDefaultPeer is enabled and the user does not own any peer yet,
// it creates a default peer for that user through CreateDefaultPeer, using a
// context that carries the system-admin user info. Failures are logged;
// nothing is returned.
func (m Manager) handleUserLoginEvent(userId domain.UserIdentifier) {
	if !m.cfg.Core.CreateDefaultPeer {
		return // feature is switched off
	}

	// The lookup runs on a plain background context (no user info attached).
	existingPeers, err := m.db.GetUserPeers(context.Background(), userId)
	switch {
	case err != nil:
		slog.Error("failed to retrieve existing peers prior to default peer creation",
			"user", userId,
			"error", err)
		return
	case len(existingPeers) > 0:
		return // user already has peers, skip creation
	}

	slog.Debug("handling new user login", "user", userId)

	adminCtx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
	if err := m.CreateDefaultPeer(adminCtx, userId); err != nil {
		slog.Error("failed to create default peer", "user", userId, "error", err)
	}
}
// handleUserDisabledEvent is the handler subscribed to app.TopicUserDisabled.
// It disables every peer of the user that is not disabled yet: the peer takes
// over the user's Disabled timestamp and gets the reason
// domain.DisabledReasonUserDisabled. A failed update is logged and the
// remaining peers are still processed.
func (m Manager) handleUserDisabledEvent(user domain.User) {
	adminCtx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())

	peers, err := m.db.GetUserPeers(adminCtx, user.Identifier)
	if err != nil {
		slog.Error("failed to retrieve peers for disabled user",
			"user", user.Identifier,
			"error", err)
		return
	}

	for _, peer := range peers {
		if peer.IsDisabled() {
			continue // nothing to do, peer is already disabled
		}

		slog.Debug("disabling peer due to user being disabled",
			"peer", peer.Identifier,
			"user", user.Identifier)

		// Reuse the user's timestamp so peer and user share the same disable time.
		peer.Disabled = user.Disabled
		peer.DisabledReason = domain.DisabledReasonUserDisabled

		if _, updateErr := m.UpdatePeer(adminCtx, &peer); updateErr != nil {
			slog.Error("failed to disable peer for disabled user",
				"peer", peer.Identifier,
				"user", user.Identifier,
				"error", updateErr)
		}
	}
}
// handleUserEnabledEvent is the handler subscribed to app.TopicUserEnabled.
// If Core.ReEnablePeerAfterUserEnable is enabled, it re-activates those peers
// of the user that are disabled with the reason
// domain.DisabledReasonUserDisabled; peers disabled for any other reason are
// left alone. A failed update is logged and the remaining peers are still
// processed.
func (m Manager) handleUserEnabledEvent(user domain.User) {
	if !m.cfg.Core.ReEnablePeerAfterUserEnable {
		return // feature is switched off
	}

	adminCtx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())

	peers, err := m.db.GetUserPeers(adminCtx, user.Identifier)
	if err != nil {
		slog.Error("failed to retrieve peers for re-enabled user",
			"user", user.Identifier,
			"error", err)
		return
	}

	for _, peer := range peers {
		// Skip peers that are active, or that were disabled for another reason.
		if !peer.IsDisabled() || peer.DisabledReason != domain.DisabledReasonUserDisabled {
			continue
		}

		slog.Debug("enabling peer due to user being enabled",
			"peer", peer.Identifier,
			"user", user.Identifier)

		peer.Disabled = nil
		peer.DisabledReason = ""

		if _, updateErr := m.UpdatePeer(adminCtx, &peer); updateErr != nil {
			slog.Error("failed to enable peer for enabled user",
				"peer", peer.Identifier,
				"user", user.Identifier,
				"error", updateErr)
		}
	}
}
// handleUserDeletedEvent is the handler subscribed to app.TopicUserDeleted.
// For every peer of the user that is not disabled yet it either
//   - deletes the peer, if Core.DeletePeerAfterUserDeleted is enabled, or
//   - clears the peer's user reference and disables it with the reason
//     domain.DisabledReasonUserDeleted, using one timestamp for all peers.
//
// Failures are logged per peer and the remaining peers are still processed.
//
// NOTE(review): peers that are already disabled are skipped entirely, so they
// are neither deleted nor detached from the deleted user. Confirm that this is
// intended.
func (m Manager) handleUserDeletedEvent(user domain.User) {
	adminCtx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())

	peers, err := m.db.GetUserPeers(adminCtx, user.Identifier)
	if err != nil {
		slog.Error("failed to retrieve peers for deleted user",
			"user", user.Identifier,
			"error", err)
		return
	}

	disabledAt := time.Now()
	for _, peer := range peers {
		if peer.IsDisabled() {
			continue // peer is already disabled
		}

		if m.cfg.Core.DeletePeerAfterUserDeleted {
			slog.Debug("deleting peer due to user being deleted",
				"peer", peer.Identifier,
				"user", user.Identifier)

			if deleteErr := m.DeletePeer(adminCtx, peer.Identifier); deleteErr != nil {
				slog.Error("failed to delete peer for deleted user",
					"peer", peer.Identifier,
					"user", user.Identifier,
					"error", deleteErr)
			}
			continue
		}

		slog.Debug("disabling peer due to user being deleted",
			"peer", peer.Identifier,
			"user", user.Identifier)

		peer.UserIdentifier = "" // the user no longer exists, drop the reference
		peer.Disabled = &disabledAt
		peer.DisabledReason = domain.DisabledReasonUserDeleted

		if _, updateErr := m.UpdatePeer(adminCtx, &peer); updateErr != nil {
			slog.Error("failed to disable peer for deleted user",
				"peer", peer.Identifier,
				"user", user.Identifier,
				"error", updateErr)
		}
	}
}
// runExpiredPeersCheck periodically disables expired peers on all interfaces.
// It blocks until ctx is cancelled and is therefore meant to run in its own
// goroutine (see StartBackgroundJobs). All work is done with a context that
// carries the system-admin user info.
//
// A single timer is reused for all rounds instead of allocating a new one via
// time.After on every iteration, and it is stopped when the function returns.
// The wait interval (Advanced.ExpiryCheckInterval) is re-read for each round
// and starts counting once the previous round has finished.
func (m Manager) runExpiredPeersCheck(ctx context.Context) {
	ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo())

	timer := time.NewTimer(m.cfg.Advanced.ExpiryCheckInterval)
	defer timer.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-timer.C:
			// interval elapsed, run the next check
		}

		m.checkExpiredPeersOnAllInterfaces(ctx)

		// The timer has fired and its channel was drained above, so Reset is safe.
		timer.Reset(m.cfg.Advanced.ExpiryCheckInterval)
	}
}

// checkExpiredPeersOnAllInterfaces runs one expiry check round: it loads all
// interfaces and hands the peers of each one to checkExpiredPeers. Lookup
// failures are logged; a failure for one interface does not stop the others.
func (m Manager) checkExpiredPeersOnAllInterfaces(ctx context.Context) {
	interfaces, err := m.db.GetAllInterfaces(ctx)
	if err != nil {
		slog.Error("failed to fetch all interfaces for expiry check", "error", err)
		return
	}

	for _, iface := range interfaces {
		peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
		if err != nil {
			slog.Error("failed to fetch all peers from interface for expiry check",
				"interface", iface.Identifier,
				"error", err)
			continue
		}

		m.checkExpiredPeers(ctx, peers)
	}
}
// checkExpiredPeers disables every peer in peers that is expired but not yet
// disabled. Such a peer gets the reason domain.DisabledReasonExpired and is
// saved through UpdatePeer; all peers disabled by one call share the same
// timestamp. A failed update is logged and the remaining peers are still
// processed.
func (m Manager) checkExpiredPeers(ctx context.Context, peers []domain.Peer) {
	disabledAt := time.Now()

	for _, candidate := range peers {
		if !candidate.IsExpired() || candidate.IsDisabled() {
			continue // still valid, or already disabled
		}

		slog.Info("peer has expired, disabling", "peer", candidate.Identifier)

		candidate.Disabled = &disabledAt
		candidate.DisabledReason = domain.DisabledReasonExpired

		if _, err := m.UpdatePeer(ctx, &candidate); err != nil {
			slog.Error("failed to update expired peer", "peer", candidate.Identifier, "error", err)
		}
	}
}