Merge branch 'master' into stable

# Conflicts:
#	internal/domain/peer.go
This commit is contained in:
Christoph Haas
2025-10-19 13:25:07 +02:00
137 changed files with 10275 additions and 1996 deletions

View File

@@ -0,0 +1,142 @@
package wireguard
import (
"fmt"
"log/slog"
"maps"
"slices"
"github.com/h44z/wg-portal/internal/adapters/wgcontroller"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
type backendInstance struct {
Config config.BackendBase // Config is the configuration for the backend instance.
Implementation domain.InterfaceController
}
type ControllerManager struct {
cfg *config.Config
controllers map[domain.InterfaceBackend]backendInstance
}
func NewControllerManager(cfg *config.Config) (*ControllerManager, error) {
c := &ControllerManager{
cfg: cfg,
controllers: make(map[domain.InterfaceBackend]backendInstance),
}
err := c.init()
if err != nil {
return nil, err
}
return c, nil
}
func (c *ControllerManager) init() error {
if err := c.registerLocalController(); err != nil {
return err
}
if err := c.registerMikrotikControllers(); err != nil {
return err
}
c.logRegisteredControllers()
return nil
}
func (c *ControllerManager) registerLocalController() error {
localController, err := wgcontroller.NewLocalController(c.cfg)
if err != nil {
return fmt.Errorf("failed to create local WireGuard controller: %w", err)
}
c.controllers[config.LocalBackendName] = backendInstance{
Config: config.BackendBase{
Id: config.LocalBackendName,
DisplayName: "Local WireGuard Controller",
IgnoredInterfaces: c.cfg.Backend.IgnoredLocalInterfaces,
},
Implementation: localController,
}
return nil
}
func (c *ControllerManager) registerMikrotikControllers() error {
for _, backendConfig := range c.cfg.Backend.Mikrotik {
if backendConfig.Id == config.LocalBackendName {
slog.Warn("skipping registration of Mikrotik controller with reserved ID", "id", config.LocalBackendName)
continue
}
controller, err := wgcontroller.NewMikrotikController(c.cfg, &backendConfig)
if err != nil {
return fmt.Errorf("failed to create Mikrotik controller for backend %s: %w", backendConfig.Id, err)
}
c.controllers[domain.InterfaceBackend(backendConfig.Id)] = backendInstance{
Config: backendConfig.BackendBase,
Implementation: controller,
}
}
return nil
}
func (c *ControllerManager) logRegisteredControllers() {
for backend, controller := range c.controllers {
slog.Debug("backend controller registered",
"backend", backend, "type", fmt.Sprintf("%T", controller.Implementation))
}
}
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) domain.InterfaceController {
return c.getController(backend, "").Implementation
}
func (c *ControllerManager) GetController(iface domain.Interface) domain.InterfaceController {
return c.getController(iface.Backend, iface.Identifier).Implementation
}
func (c *ControllerManager) getController(
backend domain.InterfaceBackend,
ifaceId domain.InterfaceIdentifier,
) backendInstance {
if backend == "" {
// If no backend is specified, use the local controller.
// This might be the case for interfaces created in previous WireGuard Portal versions.
backend = config.LocalBackendName
}
controller, exists := c.controllers[backend]
if !exists {
controller, exists = c.controllers[config.LocalBackendName] // Fallback to local controller
if !exists {
// If the local controller is also not found, panic
panic(fmt.Sprintf("%s interface controller for backend %s not found", ifaceId, backend))
}
slog.Warn("controller for backend not found, using local controller",
"backend", backend, "interface", ifaceId)
}
return controller
}
func (c *ControllerManager) GetAllControllers() []backendInstance {
var backendInstances = make([]backendInstance, 0, len(c.controllers))
for instance := range maps.Values(c.controllers) {
backendInstances = append(backendInstances, instance)
}
return backendInstances
}
func (c *ControllerManager) GetControllerNames() []config.BackendBase {
var names []config.BackendBase
for _, id := range slices.Sorted(maps.Keys(c.controllers)) {
names = append(names, c.controllers[id].Config)
}
return names
}

View File

@@ -6,8 +6,6 @@ import (
"sync"
"time"
probing "github.com/prometheus-community/pro-bing"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
@@ -30,11 +28,6 @@ type StatisticsDatabaseRepo interface {
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
}
type StatisticsInterfaceController interface {
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
}
type StatisticsMetricsServer interface {
UpdateInterfaceMetrics(status domain.InterfaceStatus)
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
@@ -43,6 +36,13 @@ type StatisticsMetricsServer interface {
type StatisticsEventBus interface {
// Subscribe subscribes to a topic
Subscribe(topic string, fn interface{}) error
// Publish sends a message to the message bus.
Publish(topic string, args ...any)
}
type pingJob struct {
Peer domain.Peer
Backend domain.InterfaceBackend
}
type StatisticsCollector struct {
@@ -50,11 +50,13 @@ type StatisticsCollector struct {
bus StatisticsEventBus
pingWaitGroup sync.WaitGroup
pingJobs chan domain.Peer
pingJobs chan pingJob
db StatisticsDatabaseRepo
wg StatisticsInterfaceController
wg *ControllerManager
ms StatisticsMetricsServer
peerChangeEvent chan domain.PeerIdentifier
}
// NewStatisticsCollector creates a new statistics collector.
@@ -62,7 +64,7 @@ func NewStatisticsCollector(
cfg *config.Config,
bus StatisticsEventBus,
db StatisticsDatabaseRepo,
wg StatisticsInterfaceController,
wg *ControllerManager,
ms StatisticsMetricsServer,
) (*StatisticsCollector, error) {
c := &StatisticsCollector{
@@ -113,7 +115,7 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
}
for _, in := range interfaces {
physicalInterface, err := c.wg.GetInterface(ctx, in.Identifier)
physicalInterface, err := c.wg.GetController(in).GetInterface(ctx, in.Identifier)
if err != nil {
slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier,
"error", err)
@@ -165,14 +167,18 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
}
for _, in := range interfaces {
peers, err := c.wg.GetPeers(ctx, in.Identifier)
peers, err := c.wg.GetController(in).GetPeers(ctx, in.Identifier)
if err != nil {
slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err)
continue
}
for _, peer := range peers {
var connectionStateChanged bool
var newPeerStatus domain.PeerStatus
err = c.db.UpdatePeerStatus(ctx, peer.Identifier,
func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
wasConnected := p.IsConnected
var lastHandshake *time.Time
if !peer.LastHandshake.IsZero() {
lastHandshake = &peer.LastHandshake
@@ -186,6 +192,13 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server
p.Endpoint = peer.Endpoint
p.LastHandshake = lastHandshake
p.CalcConnected()
if wasConnected != p.IsConnected {
slog.Debug("peer connection state changed", "peer", peer.Identifier, "connected", p.IsConnected)
connectionStateChanged = true
newPeerStatus = *p // store new status for event publishing
}
// Update prometheus metrics
go c.updatePeerMetrics(ctx, *p)
@@ -197,6 +210,17 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
} else {
slog.Debug("updated peer status", "peer", peer.Identifier)
}
if connectionStateChanged {
peerModel, err := c.db.GetPeer(ctx, peer.Identifier)
if err != nil {
slog.Error("failed to fetch peer for data collection", "peer", peer.Identifier, "error",
err)
continue
}
// publish event if connection state changed
c.bus.Publish(app.TopicPeerStateChanged, newPeerStatus, *peerModel)
}
}
}
}
@@ -245,7 +269,7 @@ func (c *StatisticsCollector) startPingWorkers(ctx context.Context) {
c.pingWaitGroup = sync.WaitGroup{}
c.pingWaitGroup.Add(c.cfg.Statistics.PingCheckWorkers)
c.pingJobs = make(chan domain.Peer, c.cfg.Statistics.PingCheckWorkers)
c.pingJobs = make(chan pingJob, c.cfg.Statistics.PingCheckWorkers)
// start workers
for i := 0; i < c.cfg.Statistics.PingCheckWorkers; i++ {
@@ -288,7 +312,10 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
continue
}
for _, peer := range peers {
c.pingJobs <- peer
c.pingJobs <- pingJob{
Peer: peer,
Backend: in.Backend,
}
}
}
}
@@ -297,13 +324,21 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
func (c *StatisticsCollector) pingWorker(ctx context.Context) {
defer c.pingWaitGroup.Done()
for peer := range c.pingJobs {
peerPingable := c.isPeerPingable(ctx, peer)
for job := range c.pingJobs {
peer := job.Peer
backend := job.Backend
var connectionStateChanged bool
var newPeerStatus domain.PeerStatus
peerPingable := c.isPeerPingable(ctx, backend, peer)
slog.Debug("peer ping check completed", "peer", peer.Identifier, "pingable", peerPingable)
now := time.Now()
err := c.db.UpdatePeerStatus(ctx, peer.Identifier,
func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
wasConnected := p.IsConnected
if peerPingable {
p.IsPingable = true
p.LastPing = &now
@@ -311,6 +346,13 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
p.IsPingable = false
p.LastPing = nil
}
p.UpdatedAt = time.Now()
p.CalcConnected()
if wasConnected != p.IsConnected {
connectionStateChanged = true
newPeerStatus = *p // store new status for event publishing
}
// Update prometheus metrics
go c.updatePeerMetrics(ctx, *p)
@@ -322,10 +364,19 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
} else {
slog.Debug("updated peer ping status", "peer", peer.Identifier)
}
if connectionStateChanged {
// publish event if connection state changed
c.bus.Publish(app.TopicPeerStateChanged, newPeerStatus, peer)
}
}
}
func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Peer) bool {
func (c *StatisticsCollector) isPeerPingable(
ctx context.Context,
backend domain.InterfaceBackend,
peer domain.Peer,
) bool {
if !c.cfg.Statistics.UsePingChecks {
return false
}
@@ -335,23 +386,13 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
return false
}
pinger, err := probing.NewPinger(checkAddr)
stats, err := c.wg.GetControllerByName(backend).PingAddresses(ctx, checkAddr)
if err != nil {
slog.Debug("failed to instantiate pinger", "peer", peer.Identifier, "address", checkAddr, "error", err)
slog.Debug("failed to ping peer", "peer", peer.Identifier, "error", err)
return false
}
checkCount := 1
pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged)
pinger.Count = checkCount
pinger.Timeout = 2 * time.Second
err = pinger.RunWithContext(ctx) // Blocks until finished.
if err != nil {
slog.Debug("pinger for peer exited unexpectedly", "peer", peer.Identifier, "address", checkAddr, "error", err)
return false
}
stats := pinger.Statistics()
return stats.PacketsRecv == checkCount
return stats.IsPingable()
}
func (c *StatisticsCollector) updateInterfaceMetrics(status domain.InterfaceStatus) {

View File

@@ -37,29 +37,10 @@ type InterfaceAndPeerDatabaseRepo interface {
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
}
type InterfaceController interface {
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}
type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(id domain.InterfaceIdentifier) error
ExecuteInterfaceHook(ctx context.Context, id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
}
type EventBus interface {
@@ -72,11 +53,10 @@ type EventBus interface {
// endregion dependencies
type Manager struct {
cfg *config.Config
bus EventBus
db InterfaceAndPeerDatabaseRepo
wg InterfaceController
quick WgQuickController
cfg *config.Config
bus EventBus
db InterfaceAndPeerDatabaseRepo
wg *ControllerManager
userLockMap *sync.Map
}
@@ -84,8 +64,7 @@ type Manager struct {
func NewWireGuardManager(
cfg *config.Config,
bus EventBus,
wg InterfaceController,
quick WgQuickController,
wg *ControllerManager,
db InterfaceAndPeerDatabaseRepo,
) (*Manager, error) {
m := &Manager{
@@ -93,7 +72,6 @@ func NewWireGuardManager(
bus: bus,
wg: wg,
db: db,
quick: quick,
userLockMap: &sync.Map{},
}

View File

@@ -11,24 +11,10 @@ import (
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/audit"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// GetImportableInterfaces returns all physical interfaces that are available on the system.
// This function also returns interfaces that are already available in the database.
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
}
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
if err != nil {
return nil, err
}
return physicalInterfaces, nil
}
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
@@ -104,52 +90,64 @@ func (m Manager) GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier)
}
// ImportNewInterfaces imports all new physical interfaces that are available on the system.
// If a filter is set, only interfaces that match the filter will be imported.
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return 0, err
}
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
var existingInterfaceIds []domain.InterfaceIdentifier
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return 0, err
}
// if no filter is given, exclude already existing interfaces
var excludedInterfaces []domain.InterfaceIdentifier
if len(filter) == 0 {
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return 0, err
}
for _, existingInterface := range existingInterfaces {
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
}
for _, existingInterface := range existingInterfaces {
existingInterfaceIds = append(existingInterfaceIds, existingInterface.Identifier)
}
imported := 0
for _, physicalInterface := range physicalInterfaces {
if slices.Contains(excludedInterfaces, physicalInterface.Identifier) {
continue
}
if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) {
continue
}
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier)
for _, wgBackend := range m.wg.GetAllControllers() {
physicalInterfaces, err := wgBackend.Implementation.GetInterfaces(ctx)
if err != nil {
return 0, err
}
err = m.importInterface(ctx, &physicalInterface, physicalPeers)
if err != nil {
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
}
for _, physicalInterface := range physicalInterfaces {
if slices.Contains(wgBackend.Config.IgnoredInterfaces, string(physicalInterface.Identifier)) {
slog.Info("ignoring interface due to backend filter restrictions",
"interface", physicalInterface.Identifier, "filter", wgBackend.Config.IgnoredInterfaces,
"backend", wgBackend.Config.Id)
continue // skip ignored interfaces
}
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers))
imported++
if slices.Contains(existingInterfaceIds, physicalInterface.Identifier) {
continue // skip interfaces that already exist
}
if len(filter) > 0 && !slices.Contains(filter, physicalInterface.Identifier) {
slog.Info("ignoring interface due to filter restrictions",
"interface", physicalInterface.Identifier, "filter", wgBackend.Config.IgnoredInterfaces,
"backend", wgBackend.Config.Id)
continue
}
slog.Info("importing new interface",
"interface", physicalInterface.Identifier, "backend", wgBackend.Config.Id)
physicalPeers, err := wgBackend.Implementation.GetPeers(ctx, physicalInterface.Identifier)
if err != nil {
return 0, err
}
err = m.importInterface(ctx, wgBackend.Implementation, &physicalInterface, physicalPeers)
if err != nil {
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
}
slog.Info("imported new interface",
"interface", physicalInterface.Identifier, "peers", len(physicalPeers), "backend", wgBackend.Config.Id)
imported++
}
}
return imported, nil
@@ -213,9 +211,20 @@ func (m Manager) RestoreInterfaceState(
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
}
_, err = m.wg.GetInterface(ctx, iface.Identifier)
controller := m.wg.GetController(iface)
_, err = controller.GetInterface(ctx, iface.Identifier)
if err != nil && !iface.IsDisabled() {
slog.Debug("creating missing interface", "interface", iface.Identifier)
slog.Debug("creating missing interface", "interface", iface.Identifier, "backend", controller.GetId())
// temporarily disable interface in database so that the current state is reflected correctly
_ = m.db.SaveInterface(ctx, iface.Identifier,
func(in *domain.Interface) (*domain.Interface, error) {
now := time.Now()
in.Disabled = &now // set
in.DisabledReason = domain.DisabledReasonInterfaceMissing
return in, nil
})
// temporarily disable interface in database so that the current state is reflected correctly
_ = m.db.SaveInterface(ctx, iface.Identifier,
@@ -242,7 +251,8 @@ func (m Manager) RestoreInterfaceState(
return fmt.Errorf("failed to create physical interface %s: %w", iface.Identifier, err)
}
} else {
slog.Debug("restoring interface state", "interface", iface.Identifier, "disabled", iface.IsDisabled())
slog.Debug("restoring interface state",
"interface", iface.Identifier, "disabled", iface.IsDisabled(), "backend", controller.GetId())
// try to move interface to stored state
_, err = m.saveInterface(ctx, &iface)
@@ -269,18 +279,14 @@ func (m Manager) RestoreInterfaceState(
// restore peers
for _, peer := range peers {
switch {
case iface.IsDisabled(): // if interface is disabled, delete all peers
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
case iface.IsDisabled() && iface.Backend == config.LocalBackendName: // if interface is disabled, delete all peers
if err := controller.DeletePeer(ctx, iface.Identifier,
peer.Identifier); err != nil {
return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w",
peer.Identifier, iface.Identifier, err)
}
case peer.IsDisabled(): // if peer is disabled, delete it
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w",
peer.Identifier, iface.Identifier, err)
}
default: // update peer
err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
err := controller.SavePeer(ctx, iface.Identifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, &peer)
return pp, nil
@@ -293,7 +299,7 @@ func (m Manager) RestoreInterfaceState(
}
// remove non-wgportal peers
physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier)
physicalPeers, _ := controller.GetPeers(ctx, iface.Identifier)
for _, physicalPeer := range physicalPeers {
isWgPortalPeer := false
for _, peer := range peers {
@@ -303,7 +309,8 @@ func (m Manager) RestoreInterfaceState(
}
}
if !isWgPortalPeer {
err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey))
err := controller.DeletePeer(ctx, iface.Identifier,
domain.PeerIdentifier(physicalPeer.PublicKey))
if err != nil {
return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w",
physicalPeer.PublicKey, iface.Identifier, err)
@@ -455,7 +462,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return err
}
existingInterface, err := m.db.GetInterface(ctx, id)
existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", id, err)
}
@@ -464,25 +471,33 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion not allowed: %w", err)
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
Interface: *existingInterface,
AllowedIps: existingInterface.GetAllowedIPs(existingPeers),
FwMark: existingInterface.FirewallMark,
Table: existingInterface.GetRoutingTable(),
TableStr: existingInterface.RoutingTable,
IsDeleted: true,
})
now := time.Now()
existingInterface.Disabled = &now // simulate a disabled interface
existingInterface.DisabledReason = domain.DisabledReasonDeleted
physicalInterface, _ := m.wg.GetInterface(ctx, id)
if err := m.handleInterfacePreSaveHooks(true, existingInterface); err != nil {
if err := m.handleInterfacePreSaveHooks(ctx, existingInterface, !existingInterface.IsDisabled(),
false); err != nil {
return fmt.Errorf("pre-delete hooks failed: %w", err)
}
if err := m.handleInterfacePreSaveActions(existingInterface); err != nil {
if err := m.handleInterfacePreSaveActions(ctx, existingInterface); err != nil {
return fmt.Errorf("pre-delete actions failed: %w", err)
}
if err := m.deleteInterfacePeers(ctx, id); err != nil {
if err := m.deleteInterfacePeers(ctx, existingInterface, existingPeers); err != nil {
return fmt.Errorf("peer deletion failure: %w", err)
}
if err := m.wg.DeleteInterface(ctx, id); err != nil {
if err := m.wg.GetController(*existingInterface).DeleteInterface(ctx, id); err != nil {
return fmt.Errorf("wireguard deletion failure: %w", err)
}
@@ -490,16 +505,12 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion failure: %w", err)
}
fwMark := existingInterface.FirewallMark
if physicalInterface != nil && fwMark == 0 {
fwMark = physicalInterface.FirewallMark
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: fwMark,
Table: existingInterface.GetRoutingTable(),
})
if err := m.handleInterfacePostSaveHooks(true, existingInterface); err != nil {
if err := m.handleInterfacePostSaveHooks(
ctx,
existingInterface,
!existingInterface.IsDisabled(),
false,
); err != nil {
return fmt.Errorf("post-delete hooks failed: %w", err)
}
@@ -518,20 +529,24 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return nil, fmt.Errorf("interface validation failed: %w", err)
}
stateChanged := m.hasInterfaceStateChanged(ctx, iface)
oldEnabled, newEnabled, routeTableChanged := false, !iface.IsDisabled(), false // if the interface did not exist, we assume it was not enabled
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err == nil {
oldEnabled, newEnabled, routeTableChanged = m.getInterfaceStateHistory(oldInterface, iface)
}
if err := m.handleInterfacePreSaveHooks(stateChanged, iface); err != nil {
if err := m.handleInterfacePreSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
return nil, fmt.Errorf("pre-save hooks failed: %w", err)
}
if err := m.handleInterfacePreSaveActions(iface); err != nil {
if err := m.handleInterfacePreSaveActions(ctx, iface); err != nil {
return nil, fmt.Errorf("pre-save actions failed: %w", err)
}
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
err = m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
iface.CopyCalculatedAttributes(i)
err := m.wg.SaveInterface(ctx, iface.Identifier,
err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
domain.MergeToPhysicalInterface(pi, iface)
return pi, nil
@@ -546,24 +561,84 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return nil, fmt.Errorf("failed to save interface: %w", err)
}
if iface.IsDisabled() {
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
fwMark := iface.FirewallMark
if physicalInterface != nil && fwMark == 0 {
fwMark = physicalInterface.FirewallMark
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: fwMark,
Table: iface.GetRoutingTable(),
// update the interface type of peers in db
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
for _, peer := range peers {
err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) {
switch iface.Type {
case domain.InterfaceTypeAny:
peer.Interface.Type = domain.InterfaceTypeAny
case domain.InterfaceTypeClient:
peer.Interface.Type = domain.InterfaceTypeServer
case domain.InterfaceTypeServer:
peer.Interface.Type = domain.InterfaceTypeClient
}
return &peer, nil
})
} else {
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
if err != nil {
return nil, fmt.Errorf("failed to update peer %s for interface %s: %w", peer.Identifier,
iface.Identifier, err)
}
}
if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil {
if iface.IsDisabled() {
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
} else {
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
// if the route table changed, ensure that the old entries are remove
if routeTableChanged {
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
Interface: *oldInterface,
AllowedIps: oldInterface.GetAllowedIPs(peers),
FwMark: oldInterface.FirewallMark,
Table: oldInterface.GetRoutingTable(),
TableStr: oldInterface.RoutingTable,
IsDeleted: true, // mark the old entries as deleted
})
}
}
if err := m.handleInterfacePostSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
return nil, fmt.Errorf("post-save hooks failed: %w", err)
}
// If the interface has just been enabled, restore its peers on the physical controller
if !oldEnabled && newEnabled && iface.Backend == config.LocalBackendName {
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
for _, peer := range peers {
saveErr := m.wg.GetController(*iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, &peer)
return pp, nil
})
if saveErr != nil {
return nil, fmt.Errorf("failed to restore peer %s for interface %s: %w", peer.Identifier,
iface.Identifier, saveErr)
}
}
// notify that peers for this interface have changed so config/routes can be updated
m.bus.Publish(app.TopicPeerInterfaceUpdated, iface.Identifier)
}
m.bus.Publish(app.TopicAuditInterfaceChanged, domain.AuditEventWrapper[audit.InterfaceEvent]{
Ctx: ctx,
Event: audit.InterfaceEvent{
@@ -575,75 +650,90 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return iface, nil
}
func (m Manager) hasInterfaceStateChanged(ctx context.Context, iface *domain.Interface) bool {
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err != nil {
return false
}
if oldInterface.IsDisabled() != iface.IsDisabled() {
return true // interface in db has changed
}
wgInterface, err := m.wg.GetInterface(ctx, iface.Identifier)
if err != nil {
return true // interface might not exist - so we assume that there must be a change
}
// compare physical interface settings
if len(wgInterface.Addresses) != len(iface.Addresses) ||
wgInterface.Mtu != iface.Mtu ||
wgInterface.FirewallMark != iface.FirewallMark ||
wgInterface.ListenPort != iface.ListenPort ||
wgInterface.PrivateKey != iface.PrivateKey ||
wgInterface.PublicKey != iface.PublicKey {
return true
}
return false
func (m Manager) getInterfaceStateHistory(
oldInterface *domain.Interface,
iface *domain.Interface,
) (oldEnabled, newEnabled, routeTableChanged bool) {
return !oldInterface.IsDisabled(), !iface.IsDisabled(), oldInterface.RoutingTable != iface.RoutingTable
}
func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error {
if !iface.IsDisabled() {
if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
} else {
if err := m.quick.UnsetDNS(iface.Identifier); err != nil {
return fmt.Errorf("failed to clear dns settings: %w", err)
func (m Manager) handleInterfacePreSaveActions(ctx context.Context, iface *domain.Interface) error {
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to perform pre-save actions", "interface", iface.Identifier,
"error", "no capable controller found")
return nil
}
// update DNS settings only for client interfaces
if iface.Type == domain.InterfaceTypeClient || iface.Type == domain.InterfaceTypeAny {
if !iface.IsDisabled() {
if err := wgQuickController.SetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
} else {
if err := wgQuickController.UnsetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to clear dns settings: %w", err)
}
}
}
return nil
}
func (m Manager) handleInterfacePreSaveHooks(stateChanged bool, iface *domain.Interface) error {
if !stateChanged {
func (m Manager) handleInterfacePreSaveHooks(
ctx context.Context,
iface *domain.Interface,
oldEnabled, newEnabled bool,
) error {
if oldEnabled == newEnabled {
return nil // do nothing if state did not change
}
if !iface.IsDisabled() {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
slog.Debug("executing pre-save hooks", "interface", iface.Identifier, "up", newEnabled)
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to execute pre-save hooks", "interface", iface.Identifier, "up", newEnabled,
"error", "no capable controller found")
return nil
}
if newEnabled {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreUp); err != nil {
return fmt.Errorf("failed to execute pre-up hook: %w", err)
}
} else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreDown); err != nil {
return fmt.Errorf("failed to execute pre-down hook: %w", err)
}
}
return nil
}
func (m Manager) handleInterfacePostSaveHooks(stateChanged bool, iface *domain.Interface) error {
if !stateChanged {
func (m Manager) handleInterfacePostSaveHooks(
ctx context.Context,
iface *domain.Interface,
oldEnabled, newEnabled bool,
) error {
if oldEnabled == newEnabled {
return nil // do nothing if state did not change
}
if !iface.IsDisabled() {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
slog.Debug("executing post-save hooks", "interface", iface.Identifier, "up", newEnabled)
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to execute post-save hooks", "interface", iface.Identifier, "up", newEnabled,
"error", "no capable controller found")
return nil
}
if newEnabled {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostUp); err != nil {
return fmt.Errorf("failed to execute post-up hook: %w", err)
}
} else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostDown); err != nil {
return fmt.Errorf("failed to execute post-down hook: %w", err)
}
}
@@ -769,7 +859,12 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
return
}
func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterface, peers []domain.PhysicalPeer) error {
func (m Manager) importInterface(
ctx context.Context,
backend domain.InterfaceController,
in *domain.PhysicalInterface,
peers []domain.PhysicalPeer,
) error {
now := time.Now()
iface := domain.ConvertPhysicalInterface(in)
iface.BaseModel = domain.BaseModel{
@@ -778,8 +873,20 @@ func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterfa
CreatedAt: now,
UpdatedAt: now,
}
iface.Backend = backend.GetId()
iface.PeerDefAllowedIPsStr = iface.AddressStr()
// try to predict the interface type based on the number of peers
switch len(peers) {
case 0:
iface.Type = domain.InterfaceTypeAny // no peers means this is an unknown interface
case 1:
iface.Type = domain.InterfaceTypeClient // one peer means this is a client interface
default: // multiple peers means this is a server interface
iface.Type = domain.InterfaceTypeServer
}
existingInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return err
@@ -830,16 +937,20 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
peer.Interface.PreDown = domain.NewConfigOption(in.PeerDefPreDown, true)
peer.Interface.PostDown = domain.NewConfigOption(in.PeerDefPostDown, true)
var displayName string
switch in.Type {
case domain.InterfaceTypeAny:
peer.Interface.Type = domain.InterfaceTypeAny
peer.DisplayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
displayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
case domain.InterfaceTypeClient:
peer.Interface.Type = domain.InterfaceTypeServer
peer.DisplayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
displayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
case domain.InterfaceTypeServer:
peer.Interface.Type = domain.InterfaceTypeClient
peer.DisplayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
displayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
}
if peer.DisplayName == "" {
peer.DisplayName = displayName // use auto-generated display name if not set
}
err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) {
@@ -852,13 +963,9 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
return nil
}
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
allPeers, err := m.db.GetInterfacePeers(ctx, id)
if err != nil {
return err
}
func (m Manager) deleteInterfacePeers(ctx context.Context, iface *domain.Interface, allPeers []domain.Peer) error {
for _, peer := range allPeers {
err = m.wg.DeletePeer(ctx, id, peer.Identifier)
err := m.wg.GetController(*iface).DeletePeer(ctx, iface.Identifier, peer.Identifier)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
}

View File

@@ -188,6 +188,32 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
sessionUser := domain.GetUserInfo(ctx)
peer.Identifier = domain.PeerIdentifier(peer.Interface.PublicKey) // ensure that identifier corresponds to the public key
// Enforce peer limit for non-admin users if LimitAdditionalUserPeers is set
if m.cfg.Core.SelfProvisioningAllowed && !sessionUser.IsAdmin && m.cfg.Advanced.LimitAdditionalUserPeers > 0 {
peers, err := m.db.GetUserPeers(ctx, peer.UserIdentifier)
if err != nil {
return nil, fmt.Errorf("failed to fetch peers for user %s: %w", peer.UserIdentifier, err)
}
// Count enabled peers (disabled IS NULL)
peerCount := 0
for _, p := range peers {
if !p.IsDisabled() {
peerCount++
}
}
totalAllowedPeers := 1 + m.cfg.Advanced.LimitAdditionalUserPeers // 1 default peer + x additional peers
if peerCount >= totalAllowedPeers {
slog.WarnContext(ctx, "peer creation blocked due to limit",
"user", peer.UserIdentifier,
"current_count", peerCount,
"allowed_count", totalAllowedPeers)
return nil, fmt.Errorf("peer limit reached (%d peers allowed): %w", totalAllowedPeers,
domain.ErrNoPermission)
}
}
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err)
@@ -347,7 +373,12 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return fmt.Errorf("delete not allowed: %w", err)
}
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
}
err = m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, id)
if err != nil {
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
}
@@ -357,9 +388,20 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return fmt.Errorf("failed to delete peer %s: %w", id, err)
}
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
m.bus.Publish(app.TopicPeerDeleted, *peer)
// Update routes after peers have changed
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
// Update interface after peers have changed
m.bus.Publish(app.TopicPeerInterfaceUpdated, peer.InterfaceIdentifier)
@@ -407,37 +449,36 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
// region helper-functions
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
interfaces := make(map[domain.InterfaceIdentifier]struct{})
interfaces := make(map[domain.InterfaceIdentifier]domain.Interface)
for i := range peers {
peer := peers[i]
var err error
if peer.IsDisabled() || peer.IsExpired() {
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil {
return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
} else {
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
for _, peer := range peers {
// get interface from db if it is not yet in the map
if _, ok := interfaces[peer.InterfaceIdentifier]; !ok {
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
}
interfaces[peer.InterfaceIdentifier] = *iface
}
iface := interfaces[peer.InterfaceIdentifier]
// Always save the peer to the backend, regardless of disabled/expired state
// The backend will handle the disabled state appropriately
err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
err := m.wg.GetController(iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
if err != nil {
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
}
@@ -451,13 +492,22 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
Peer: *peer,
},
})
interfaces[peer.InterfaceIdentifier] = struct{}{}
}
// Update routes after peers have changed
if len(interfaces) != 0 {
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
for id, iface := range interfaces {
interfacePeers, err := m.db.GetInterfacePeers(ctx, id)
if err != nil {
return fmt.Errorf("failed to re-load peers for interface %s: %w", id, err)
}
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: iface,
AllowedIps: iface.GetAllowedIPs(interfacePeers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
}
for iface := range interfaces {

View File

@@ -0,0 +1,194 @@
package wireguard
import (
"context"
"testing"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// --- Test mocks ---
type mockBus struct{}
func (f *mockBus) Publish(topic string, args ...any) {}
func (f *mockBus) Subscribe(topic string, fn interface{}) error { return nil }
type mockController struct{}
func (f *mockController) GetId() domain.InterfaceBackend { return "local" }
func (f *mockController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
return nil, nil
}
func (f *mockController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (
*domain.PhysicalInterface,
error,
) {
return &domain.PhysicalInterface{Identifier: id}, nil
}
func (f *mockController) GetPeers(_ context.Context, _ domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) {
return nil, nil
}
func (f *mockController) SaveInterface(
_ context.Context,
_ domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error {
_, _ = updateFunc(&domain.PhysicalInterface{})
return nil
}
func (f *mockController) DeleteInterface(_ context.Context, _ domain.InterfaceIdentifier) error {
return nil
}
func (f *mockController) SavePeer(
_ context.Context,
_ domain.InterfaceIdentifier,
_ domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error {
_, _ = updateFunc(&domain.PhysicalPeer{})
return nil
}
func (f *mockController) DeletePeer(_ context.Context, _ domain.InterfaceIdentifier, _ domain.PeerIdentifier) error {
return nil
}
func (f *mockController) PingAddresses(_ context.Context, _ string) (*domain.PingerResult, error) {
return nil, nil
}
type mockDB struct {
savedPeers map[domain.PeerIdentifier]*domain.Peer
iface *domain.Interface
}
func (f *mockDB) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) {
if f.iface != nil && f.iface.Identifier == id {
return f.iface, nil
}
return &domain.Interface{Identifier: id}, nil
}
func (f *mockDB) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
[]domain.Peer,
error,
) {
return f.iface, nil, nil
}
func (f *mockDB) GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error) {
return nil, nil
}
func (f *mockDB) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) { return nil, nil }
func (f *mockDB) GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error) {
return nil, nil
}
func (f *mockDB) SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.Interface) (*domain.Interface, error),
) error {
if f.iface == nil {
f.iface = &domain.Interface{Identifier: id}
}
var err error
f.iface, err = updateFunc(f.iface)
return err
}
func (f *mockDB) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
return nil
}
func (f *mockDB) GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) {
return nil, nil
}
func (f *mockDB) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
return nil, nil
}
func (f *mockDB) SavePeer(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.Peer) (*domain.Peer, error),
) error {
if f.savedPeers == nil {
f.savedPeers = make(map[domain.PeerIdentifier]*domain.Peer)
}
existing := f.savedPeers[id]
if existing == nil {
existing = &domain.Peer{Identifier: id}
}
updated, err := updateFunc(existing)
if err != nil {
return err
}
f.savedPeers[updated.Identifier] = updated
return nil
}
func (f *mockDB) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error { return nil }
func (f *mockDB) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
return nil, domain.ErrNotFound
}
func (f *mockDB) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (
map[domain.Cidr][]domain.Cidr,
error,
) {
return map[domain.Cidr][]domain.Cidr{}, nil
}
// --- Test ---
func TestCreatePeer_SetsIdentifier_FromPublicKey(t *testing.T) {
// Arrange
cfg := &config.Config{}
cfg.Core.SelfProvisioningAllowed = true
cfg.Core.EditableKeys = true
cfg.Advanced.LimitAdditionalUserPeers = 0
bus := &mockBus{}
// Prepare a controller manager with our mock controller
ctrlMgr := &ControllerManager{
controllers: map[domain.InterfaceBackend]backendInstance{
config.LocalBackendName: {Implementation: &mockController{}},
},
}
db := &mockDB{iface: &domain.Interface{Identifier: "wg0", Type: domain.InterfaceTypeServer}}
m := Manager{
cfg: cfg,
bus: bus,
db: db,
wg: ctrlMgr,
}
userId := domain.UserIdentifier("user@example.com")
ctx := domain.SetUserInfo(context.Background(), &domain.ContextUserInfo{Id: userId, IsAdmin: false})
pubKey := "TEST_PUBLIC_KEY_ABC123"
input := &domain.Peer{
Identifier: "should_be_overwritten",
UserIdentifier: userId,
InterfaceIdentifier: domain.InterfaceIdentifier("wg0"),
Interface: domain.PeerInterfaceConfig{
KeyPair: domain.KeyPair{PublicKey: pubKey},
},
}
// Act
out, err := m.CreatePeer(ctx, input)
// Assert
if err != nil {
t.Fatalf("CreatePeer returned error: %v", err)
}
expectedId := domain.PeerIdentifier(pubKey)
if out.Identifier != expectedId {
t.Fatalf("expected Identifier to be set from public key %q, got %q", expectedId, out.Identifier)
}
// Ensure the saved peer in DB also has the expected identifier
if db.savedPeers[expectedId] == nil {
t.Fatalf("expected peer with identifier %q to be saved in DB", expectedId)
}
}