mirror of
https://github.com/h44z/wg-portal.git
synced 2025-11-18 23:06:17 +00:00
Merge branch 'master' into stable
# Conflicts: # internal/domain/peer.go
This commit is contained in:
142
internal/app/wireguard/controller_manager.go
Normal file
142
internal/app/wireguard/controller_manager.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"slices"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/adapters/wgcontroller"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type backendInstance struct {
|
||||
Config config.BackendBase // Config is the configuration for the backend instance.
|
||||
Implementation domain.InterfaceController
|
||||
}
|
||||
|
||||
type ControllerManager struct {
|
||||
cfg *config.Config
|
||||
controllers map[domain.InterfaceBackend]backendInstance
|
||||
}
|
||||
|
||||
func NewControllerManager(cfg *config.Config) (*ControllerManager, error) {
|
||||
c := &ControllerManager{
|
||||
cfg: cfg,
|
||||
controllers: make(map[domain.InterfaceBackend]backendInstance),
|
||||
}
|
||||
|
||||
err := c.init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) init() error {
|
||||
if err := c.registerLocalController(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.registerMikrotikControllers(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.logRegisteredControllers()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) registerLocalController() error {
|
||||
localController, err := wgcontroller.NewLocalController(c.cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create local WireGuard controller: %w", err)
|
||||
}
|
||||
|
||||
c.controllers[config.LocalBackendName] = backendInstance{
|
||||
Config: config.BackendBase{
|
||||
Id: config.LocalBackendName,
|
||||
DisplayName: "Local WireGuard Controller",
|
||||
IgnoredInterfaces: c.cfg.Backend.IgnoredLocalInterfaces,
|
||||
},
|
||||
Implementation: localController,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) registerMikrotikControllers() error {
|
||||
for _, backendConfig := range c.cfg.Backend.Mikrotik {
|
||||
if backendConfig.Id == config.LocalBackendName {
|
||||
slog.Warn("skipping registration of Mikrotik controller with reserved ID", "id", config.LocalBackendName)
|
||||
continue
|
||||
}
|
||||
|
||||
controller, err := wgcontroller.NewMikrotikController(c.cfg, &backendConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Mikrotik controller for backend %s: %w", backendConfig.Id, err)
|
||||
}
|
||||
|
||||
c.controllers[domain.InterfaceBackend(backendConfig.Id)] = backendInstance{
|
||||
Config: backendConfig.BackendBase,
|
||||
Implementation: controller,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) logRegisteredControllers() {
|
||||
for backend, controller := range c.controllers {
|
||||
slog.Debug("backend controller registered",
|
||||
"backend", backend, "type", fmt.Sprintf("%T", controller.Implementation))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) domain.InterfaceController {
|
||||
return c.getController(backend, "").Implementation
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetController(iface domain.Interface) domain.InterfaceController {
|
||||
return c.getController(iface.Backend, iface.Identifier).Implementation
|
||||
}
|
||||
|
||||
func (c *ControllerManager) getController(
|
||||
backend domain.InterfaceBackend,
|
||||
ifaceId domain.InterfaceIdentifier,
|
||||
) backendInstance {
|
||||
if backend == "" {
|
||||
// If no backend is specified, use the local controller.
|
||||
// This might be the case for interfaces created in previous WireGuard Portal versions.
|
||||
backend = config.LocalBackendName
|
||||
}
|
||||
|
||||
controller, exists := c.controllers[backend]
|
||||
if !exists {
|
||||
controller, exists = c.controllers[config.LocalBackendName] // Fallback to local controller
|
||||
if !exists {
|
||||
// If the local controller is also not found, panic
|
||||
panic(fmt.Sprintf("%s interface controller for backend %s not found", ifaceId, backend))
|
||||
}
|
||||
slog.Warn("controller for backend not found, using local controller",
|
||||
"backend", backend, "interface", ifaceId)
|
||||
}
|
||||
return controller
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetAllControllers() []backendInstance {
|
||||
var backendInstances = make([]backendInstance, 0, len(c.controllers))
|
||||
for instance := range maps.Values(c.controllers) {
|
||||
backendInstances = append(backendInstances, instance)
|
||||
}
|
||||
return backendInstances
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetControllerNames() []config.BackendBase {
|
||||
var names []config.BackendBase
|
||||
for _, id := range slices.Sorted(maps.Keys(c.controllers)) {
|
||||
names = append(names, c.controllers[id].Config)
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
probing "github.com/prometheus-community/pro-bing"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
@@ -30,11 +28,6 @@ type StatisticsDatabaseRepo interface {
|
||||
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type StatisticsInterfaceController interface {
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
}
|
||||
|
||||
type StatisticsMetricsServer interface {
|
||||
UpdateInterfaceMetrics(status domain.InterfaceStatus)
|
||||
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
|
||||
@@ -43,6 +36,13 @@ type StatisticsMetricsServer interface {
|
||||
type StatisticsEventBus interface {
|
||||
// Subscribe subscribes to a topic
|
||||
Subscribe(topic string, fn interface{}) error
|
||||
// Publish sends a message to the message bus.
|
||||
Publish(topic string, args ...any)
|
||||
}
|
||||
|
||||
type pingJob struct {
|
||||
Peer domain.Peer
|
||||
Backend domain.InterfaceBackend
|
||||
}
|
||||
|
||||
type StatisticsCollector struct {
|
||||
@@ -50,11 +50,13 @@ type StatisticsCollector struct {
|
||||
bus StatisticsEventBus
|
||||
|
||||
pingWaitGroup sync.WaitGroup
|
||||
pingJobs chan domain.Peer
|
||||
pingJobs chan pingJob
|
||||
|
||||
db StatisticsDatabaseRepo
|
||||
wg StatisticsInterfaceController
|
||||
wg *ControllerManager
|
||||
ms StatisticsMetricsServer
|
||||
|
||||
peerChangeEvent chan domain.PeerIdentifier
|
||||
}
|
||||
|
||||
// NewStatisticsCollector creates a new statistics collector.
|
||||
@@ -62,7 +64,7 @@ func NewStatisticsCollector(
|
||||
cfg *config.Config,
|
||||
bus StatisticsEventBus,
|
||||
db StatisticsDatabaseRepo,
|
||||
wg StatisticsInterfaceController,
|
||||
wg *ControllerManager,
|
||||
ms StatisticsMetricsServer,
|
||||
) (*StatisticsCollector, error) {
|
||||
c := &StatisticsCollector{
|
||||
@@ -113,7 +115,7 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
|
||||
}
|
||||
|
||||
for _, in := range interfaces {
|
||||
physicalInterface, err := c.wg.GetInterface(ctx, in.Identifier)
|
||||
physicalInterface, err := c.wg.GetController(in).GetInterface(ctx, in.Identifier)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier,
|
||||
"error", err)
|
||||
@@ -165,14 +167,18 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
|
||||
}
|
||||
|
||||
for _, in := range interfaces {
|
||||
peers, err := c.wg.GetPeers(ctx, in.Identifier)
|
||||
peers, err := c.wg.GetController(in).GetPeers(ctx, in.Identifier)
|
||||
if err != nil {
|
||||
slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err)
|
||||
continue
|
||||
}
|
||||
for _, peer := range peers {
|
||||
var connectionStateChanged bool
|
||||
var newPeerStatus domain.PeerStatus
|
||||
err = c.db.UpdatePeerStatus(ctx, peer.Identifier,
|
||||
func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
|
||||
wasConnected := p.IsConnected
|
||||
|
||||
var lastHandshake *time.Time
|
||||
if !peer.LastHandshake.IsZero() {
|
||||
lastHandshake = &peer.LastHandshake
|
||||
@@ -186,6 +192,13 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
|
||||
p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server
|
||||
p.Endpoint = peer.Endpoint
|
||||
p.LastHandshake = lastHandshake
|
||||
p.CalcConnected()
|
||||
|
||||
if wasConnected != p.IsConnected {
|
||||
slog.Debug("peer connection state changed", "peer", peer.Identifier, "connected", p.IsConnected)
|
||||
connectionStateChanged = true
|
||||
newPeerStatus = *p // store new status for event publishing
|
||||
}
|
||||
|
||||
// Update prometheus metrics
|
||||
go c.updatePeerMetrics(ctx, *p)
|
||||
@@ -197,6 +210,17 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
|
||||
} else {
|
||||
slog.Debug("updated peer status", "peer", peer.Identifier)
|
||||
}
|
||||
|
||||
if connectionStateChanged {
|
||||
peerModel, err := c.db.GetPeer(ctx, peer.Identifier)
|
||||
if err != nil {
|
||||
slog.Error("failed to fetch peer for data collection", "peer", peer.Identifier, "error",
|
||||
err)
|
||||
continue
|
||||
}
|
||||
// publish event if connection state changed
|
||||
c.bus.Publish(app.TopicPeerStateChanged, newPeerStatus, *peerModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -245,7 +269,7 @@ func (c *StatisticsCollector) startPingWorkers(ctx context.Context) {
|
||||
|
||||
c.pingWaitGroup = sync.WaitGroup{}
|
||||
c.pingWaitGroup.Add(c.cfg.Statistics.PingCheckWorkers)
|
||||
c.pingJobs = make(chan domain.Peer, c.cfg.Statistics.PingCheckWorkers)
|
||||
c.pingJobs = make(chan pingJob, c.cfg.Statistics.PingCheckWorkers)
|
||||
|
||||
// start workers
|
||||
for i := 0; i < c.cfg.Statistics.PingCheckWorkers; i++ {
|
||||
@@ -288,7 +312,10 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
for _, peer := range peers {
|
||||
c.pingJobs <- peer
|
||||
c.pingJobs <- pingJob{
|
||||
Peer: peer,
|
||||
Backend: in.Backend,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -297,13 +324,21 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
|
||||
|
||||
func (c *StatisticsCollector) pingWorker(ctx context.Context) {
|
||||
defer c.pingWaitGroup.Done()
|
||||
for peer := range c.pingJobs {
|
||||
peerPingable := c.isPeerPingable(ctx, peer)
|
||||
for job := range c.pingJobs {
|
||||
peer := job.Peer
|
||||
backend := job.Backend
|
||||
|
||||
var connectionStateChanged bool
|
||||
var newPeerStatus domain.PeerStatus
|
||||
|
||||
peerPingable := c.isPeerPingable(ctx, backend, peer)
|
||||
slog.Debug("peer ping check completed", "peer", peer.Identifier, "pingable", peerPingable)
|
||||
|
||||
now := time.Now()
|
||||
err := c.db.UpdatePeerStatus(ctx, peer.Identifier,
|
||||
func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
|
||||
wasConnected := p.IsConnected
|
||||
|
||||
if peerPingable {
|
||||
p.IsPingable = true
|
||||
p.LastPing = &now
|
||||
@@ -311,6 +346,13 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
|
||||
p.IsPingable = false
|
||||
p.LastPing = nil
|
||||
}
|
||||
p.UpdatedAt = time.Now()
|
||||
p.CalcConnected()
|
||||
|
||||
if wasConnected != p.IsConnected {
|
||||
connectionStateChanged = true
|
||||
newPeerStatus = *p // store new status for event publishing
|
||||
}
|
||||
|
||||
// Update prometheus metrics
|
||||
go c.updatePeerMetrics(ctx, *p)
|
||||
@@ -322,10 +364,19 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
|
||||
} else {
|
||||
slog.Debug("updated peer ping status", "peer", peer.Identifier)
|
||||
}
|
||||
|
||||
if connectionStateChanged {
|
||||
// publish event if connection state changed
|
||||
c.bus.Publish(app.TopicPeerStateChanged, newPeerStatus, peer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Peer) bool {
|
||||
func (c *StatisticsCollector) isPeerPingable(
|
||||
ctx context.Context,
|
||||
backend domain.InterfaceBackend,
|
||||
peer domain.Peer,
|
||||
) bool {
|
||||
if !c.cfg.Statistics.UsePingChecks {
|
||||
return false
|
||||
}
|
||||
@@ -335,23 +386,13 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
|
||||
return false
|
||||
}
|
||||
|
||||
pinger, err := probing.NewPinger(checkAddr)
|
||||
stats, err := c.wg.GetControllerByName(backend).PingAddresses(ctx, checkAddr)
|
||||
if err != nil {
|
||||
slog.Debug("failed to instantiate pinger", "peer", peer.Identifier, "address", checkAddr, "error", err)
|
||||
slog.Debug("failed to ping peer", "peer", peer.Identifier, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
checkCount := 1
|
||||
pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged)
|
||||
pinger.Count = checkCount
|
||||
pinger.Timeout = 2 * time.Second
|
||||
err = pinger.RunWithContext(ctx) // Blocks until finished.
|
||||
if err != nil {
|
||||
slog.Debug("pinger for peer exited unexpectedly", "peer", peer.Identifier, "address", checkAddr, "error", err)
|
||||
return false
|
||||
}
|
||||
stats := pinger.Statistics()
|
||||
return stats.PacketsRecv == checkCount
|
||||
return stats.IsPingable()
|
||||
}
|
||||
|
||||
func (c *StatisticsCollector) updateInterfaceMetrics(status domain.InterfaceStatus) {
|
||||
|
||||
@@ -37,29 +37,10 @@ type InterfaceAndPeerDatabaseRepo interface {
|
||||
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
|
||||
}
|
||||
|
||||
type InterfaceController interface {
|
||||
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
SaveInterface(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type WgQuickController interface {
|
||||
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
|
||||
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||
UnsetDNS(id domain.InterfaceIdentifier) error
|
||||
ExecuteInterfaceHook(ctx context.Context, id domain.InterfaceIdentifier, hookCmd string) error
|
||||
SetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||
UnsetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||
}
|
||||
|
||||
type EventBus interface {
|
||||
@@ -72,11 +53,10 @@ type EventBus interface {
|
||||
// endregion dependencies
|
||||
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
bus EventBus
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
wg InterfaceController
|
||||
quick WgQuickController
|
||||
cfg *config.Config
|
||||
bus EventBus
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
wg *ControllerManager
|
||||
|
||||
userLockMap *sync.Map
|
||||
}
|
||||
@@ -84,8 +64,7 @@ type Manager struct {
|
||||
func NewWireGuardManager(
|
||||
cfg *config.Config,
|
||||
bus EventBus,
|
||||
wg InterfaceController,
|
||||
quick WgQuickController,
|
||||
wg *ControllerManager,
|
||||
db InterfaceAndPeerDatabaseRepo,
|
||||
) (*Manager, error) {
|
||||
m := &Manager{
|
||||
@@ -93,7 +72,6 @@ func NewWireGuardManager(
|
||||
bus: bus,
|
||||
wg: wg,
|
||||
db: db,
|
||||
quick: quick,
|
||||
userLockMap: &sync.Map{},
|
||||
}
|
||||
|
||||
|
||||
@@ -11,24 +11,10 @@ import (
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/app/audit"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// GetImportableInterfaces returns all physical interfaces that are available on the system.
|
||||
// This function also returns interfaces that are already available in the database.
|
||||
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return physicalInterfaces, nil
|
||||
}
|
||||
|
||||
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
|
||||
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
|
||||
*domain.Interface,
|
||||
@@ -104,52 +90,64 @@ func (m Manager) GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier)
|
||||
}
|
||||
|
||||
// ImportNewInterfaces imports all new physical interfaces that are available on the system.
|
||||
// If a filter is set, only interfaces that match the filter will be imported.
|
||||
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
|
||||
var existingInterfaceIds []domain.InterfaceIdentifier
|
||||
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// if no filter is given, exclude already existing interfaces
|
||||
var excludedInterfaces []domain.InterfaceIdentifier
|
||||
if len(filter) == 0 {
|
||||
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, existingInterface := range existingInterfaces {
|
||||
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
|
||||
}
|
||||
for _, existingInterface := range existingInterfaces {
|
||||
existingInterfaceIds = append(existingInterfaceIds, existingInterface.Identifier)
|
||||
}
|
||||
|
||||
imported := 0
|
||||
for _, physicalInterface := range physicalInterfaces {
|
||||
if slices.Contains(excludedInterfaces, physicalInterface.Identifier) {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) {
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
|
||||
|
||||
physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier)
|
||||
for _, wgBackend := range m.wg.GetAllControllers() {
|
||||
physicalInterfaces, err := wgBackend.Implementation.GetInterfaces(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = m.importInterface(ctx, &physicalInterface, physicalPeers)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
|
||||
}
|
||||
for _, physicalInterface := range physicalInterfaces {
|
||||
if slices.Contains(wgBackend.Config.IgnoredInterfaces, string(physicalInterface.Identifier)) {
|
||||
slog.Info("ignoring interface due to backend filter restrictions",
|
||||
"interface", physicalInterface.Identifier, "filter", wgBackend.Config.IgnoredInterfaces,
|
||||
"backend", wgBackend.Config.Id)
|
||||
continue // skip ignored interfaces
|
||||
}
|
||||
|
||||
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers))
|
||||
imported++
|
||||
if slices.Contains(existingInterfaceIds, physicalInterface.Identifier) {
|
||||
continue // skip interfaces that already exist
|
||||
}
|
||||
|
||||
if len(filter) > 0 && !slices.Contains(filter, physicalInterface.Identifier) {
|
||||
slog.Info("ignoring interface due to filter restrictions",
|
||||
"interface", physicalInterface.Identifier, "filter", wgBackend.Config.IgnoredInterfaces,
|
||||
"backend", wgBackend.Config.Id)
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("importing new interface",
|
||||
"interface", physicalInterface.Identifier, "backend", wgBackend.Config.Id)
|
||||
|
||||
physicalPeers, err := wgBackend.Implementation.GetPeers(ctx, physicalInterface.Identifier)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = m.importInterface(ctx, wgBackend.Implementation, &physicalInterface, physicalPeers)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
|
||||
}
|
||||
|
||||
slog.Info("imported new interface",
|
||||
"interface", physicalInterface.Identifier, "peers", len(physicalPeers), "backend", wgBackend.Config.Id)
|
||||
imported++
|
||||
}
|
||||
}
|
||||
|
||||
return imported, nil
|
||||
@@ -213,9 +211,20 @@ func (m Manager) RestoreInterfaceState(
|
||||
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
_, err = m.wg.GetInterface(ctx, iface.Identifier)
|
||||
controller := m.wg.GetController(iface)
|
||||
|
||||
_, err = controller.GetInterface(ctx, iface.Identifier)
|
||||
if err != nil && !iface.IsDisabled() {
|
||||
slog.Debug("creating missing interface", "interface", iface.Identifier)
|
||||
slog.Debug("creating missing interface", "interface", iface.Identifier, "backend", controller.GetId())
|
||||
|
||||
// temporarily disable interface in database so that the current state is reflected correctly
|
||||
_ = m.db.SaveInterface(ctx, iface.Identifier,
|
||||
func(in *domain.Interface) (*domain.Interface, error) {
|
||||
now := time.Now()
|
||||
in.Disabled = &now // set
|
||||
in.DisabledReason = domain.DisabledReasonInterfaceMissing
|
||||
return in, nil
|
||||
})
|
||||
|
||||
// temporarily disable interface in database so that the current state is reflected correctly
|
||||
_ = m.db.SaveInterface(ctx, iface.Identifier,
|
||||
@@ -242,7 +251,8 @@ func (m Manager) RestoreInterfaceState(
|
||||
return fmt.Errorf("failed to create physical interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
} else {
|
||||
slog.Debug("restoring interface state", "interface", iface.Identifier, "disabled", iface.IsDisabled())
|
||||
slog.Debug("restoring interface state",
|
||||
"interface", iface.Identifier, "disabled", iface.IsDisabled(), "backend", controller.GetId())
|
||||
|
||||
// try to move interface to stored state
|
||||
_, err = m.saveInterface(ctx, &iface)
|
||||
@@ -269,18 +279,14 @@ func (m Manager) RestoreInterfaceState(
|
||||
// restore peers
|
||||
for _, peer := range peers {
|
||||
switch {
|
||||
case iface.IsDisabled(): // if interface is disabled, delete all peers
|
||||
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
|
||||
case iface.IsDisabled() && iface.Backend == config.LocalBackendName: // if interface is disabled, delete all peers
|
||||
if err := controller.DeletePeer(ctx, iface.Identifier,
|
||||
peer.Identifier); err != nil {
|
||||
return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w",
|
||||
peer.Identifier, iface.Identifier, err)
|
||||
}
|
||||
case peer.IsDisabled(): // if peer is disabled, delete it
|
||||
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
|
||||
return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w",
|
||||
peer.Identifier, iface.Identifier, err)
|
||||
}
|
||||
default: // update peer
|
||||
err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
err := controller.SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, &peer)
|
||||
return pp, nil
|
||||
@@ -293,7 +299,7 @@ func (m Manager) RestoreInterfaceState(
|
||||
}
|
||||
|
||||
// remove non-wgportal peers
|
||||
physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier)
|
||||
physicalPeers, _ := controller.GetPeers(ctx, iface.Identifier)
|
||||
for _, physicalPeer := range physicalPeers {
|
||||
isWgPortalPeer := false
|
||||
for _, peer := range peers {
|
||||
@@ -303,7 +309,8 @@ func (m Manager) RestoreInterfaceState(
|
||||
}
|
||||
}
|
||||
if !isWgPortalPeer {
|
||||
err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey))
|
||||
err := controller.DeletePeer(ctx, iface.Identifier,
|
||||
domain.PeerIdentifier(physicalPeer.PublicKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w",
|
||||
physicalPeer.PublicKey, iface.Identifier, err)
|
||||
@@ -455,7 +462,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
return err
|
||||
}
|
||||
|
||||
existingInterface, err := m.db.GetInterface(ctx, id)
|
||||
existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find interface %s: %w", id, err)
|
||||
}
|
||||
@@ -464,25 +471,33 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
return fmt.Errorf("deletion not allowed: %w", err)
|
||||
}
|
||||
|
||||
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
|
||||
Interface: *existingInterface,
|
||||
AllowedIps: existingInterface.GetAllowedIPs(existingPeers),
|
||||
FwMark: existingInterface.FirewallMark,
|
||||
Table: existingInterface.GetRoutingTable(),
|
||||
TableStr: existingInterface.RoutingTable,
|
||||
IsDeleted: true,
|
||||
})
|
||||
|
||||
now := time.Now()
|
||||
existingInterface.Disabled = &now // simulate a disabled interface
|
||||
existingInterface.DisabledReason = domain.DisabledReasonDeleted
|
||||
|
||||
physicalInterface, _ := m.wg.GetInterface(ctx, id)
|
||||
|
||||
if err := m.handleInterfacePreSaveHooks(true, existingInterface); err != nil {
|
||||
if err := m.handleInterfacePreSaveHooks(ctx, existingInterface, !existingInterface.IsDisabled(),
|
||||
false); err != nil {
|
||||
return fmt.Errorf("pre-delete hooks failed: %w", err)
|
||||
}
|
||||
|
||||
if err := m.handleInterfacePreSaveActions(existingInterface); err != nil {
|
||||
if err := m.handleInterfacePreSaveActions(ctx, existingInterface); err != nil {
|
||||
return fmt.Errorf("pre-delete actions failed: %w", err)
|
||||
}
|
||||
|
||||
if err := m.deleteInterfacePeers(ctx, id); err != nil {
|
||||
if err := m.deleteInterfacePeers(ctx, existingInterface, existingPeers); err != nil {
|
||||
return fmt.Errorf("peer deletion failure: %w", err)
|
||||
}
|
||||
|
||||
if err := m.wg.DeleteInterface(ctx, id); err != nil {
|
||||
if err := m.wg.GetController(*existingInterface).DeleteInterface(ctx, id); err != nil {
|
||||
return fmt.Errorf("wireguard deletion failure: %w", err)
|
||||
}
|
||||
|
||||
@@ -490,16 +505,12 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
return fmt.Errorf("deletion failure: %w", err)
|
||||
}
|
||||
|
||||
fwMark := existingInterface.FirewallMark
|
||||
if physicalInterface != nil && fwMark == 0 {
|
||||
fwMark = physicalInterface.FirewallMark
|
||||
}
|
||||
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
|
||||
FwMark: fwMark,
|
||||
Table: existingInterface.GetRoutingTable(),
|
||||
})
|
||||
|
||||
if err := m.handleInterfacePostSaveHooks(true, existingInterface); err != nil {
|
||||
if err := m.handleInterfacePostSaveHooks(
|
||||
ctx,
|
||||
existingInterface,
|
||||
!existingInterface.IsDisabled(),
|
||||
false,
|
||||
); err != nil {
|
||||
return fmt.Errorf("post-delete hooks failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -518,20 +529,24 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
return nil, fmt.Errorf("interface validation failed: %w", err)
|
||||
}
|
||||
|
||||
stateChanged := m.hasInterfaceStateChanged(ctx, iface)
|
||||
oldEnabled, newEnabled, routeTableChanged := false, !iface.IsDisabled(), false // if the interface did not exist, we assume it was not enabled
|
||||
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
|
||||
if err == nil {
|
||||
oldEnabled, newEnabled, routeTableChanged = m.getInterfaceStateHistory(oldInterface, iface)
|
||||
}
|
||||
|
||||
if err := m.handleInterfacePreSaveHooks(stateChanged, iface); err != nil {
|
||||
if err := m.handleInterfacePreSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
|
||||
return nil, fmt.Errorf("pre-save hooks failed: %w", err)
|
||||
}
|
||||
|
||||
if err := m.handleInterfacePreSaveActions(iface); err != nil {
|
||||
if err := m.handleInterfacePreSaveActions(ctx, iface); err != nil {
|
||||
return nil, fmt.Errorf("pre-save actions failed: %w", err)
|
||||
}
|
||||
|
||||
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
||||
err = m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
||||
iface.CopyCalculatedAttributes(i)
|
||||
|
||||
err := m.wg.SaveInterface(ctx, iface.Identifier,
|
||||
err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
|
||||
func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
domain.MergeToPhysicalInterface(pi, iface)
|
||||
return pi, nil
|
||||
@@ -546,24 +561,84 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
return nil, fmt.Errorf("failed to save interface: %w", err)
|
||||
}
|
||||
|
||||
if iface.IsDisabled() {
|
||||
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
|
||||
fwMark := iface.FirewallMark
|
||||
if physicalInterface != nil && fwMark == 0 {
|
||||
fwMark = physicalInterface.FirewallMark
|
||||
}
|
||||
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
|
||||
FwMark: fwMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
// update the interface type of peers in db
|
||||
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
for _, peer := range peers {
|
||||
err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) {
|
||||
switch iface.Type {
|
||||
case domain.InterfaceTypeAny:
|
||||
peer.Interface.Type = domain.InterfaceTypeAny
|
||||
case domain.InterfaceTypeClient:
|
||||
peer.Interface.Type = domain.InterfaceTypeServer
|
||||
case domain.InterfaceTypeServer:
|
||||
peer.Interface.Type = domain.InterfaceTypeClient
|
||||
}
|
||||
|
||||
return &peer, nil
|
||||
})
|
||||
} else {
|
||||
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update peer %s for interface %s: %w", peer.Identifier,
|
||||
iface.Identifier, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil {
|
||||
if iface.IsDisabled() {
|
||||
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
|
||||
Interface: *iface,
|
||||
AllowedIps: iface.GetAllowedIPs(peers),
|
||||
FwMark: iface.FirewallMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
TableStr: iface.RoutingTable,
|
||||
})
|
||||
} else {
|
||||
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
|
||||
Interface: *iface,
|
||||
AllowedIps: iface.GetAllowedIPs(peers),
|
||||
FwMark: iface.FirewallMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
TableStr: iface.RoutingTable,
|
||||
})
|
||||
// if the route table changed, ensure that the old entries are remove
|
||||
if routeTableChanged {
|
||||
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
|
||||
Interface: *oldInterface,
|
||||
AllowedIps: oldInterface.GetAllowedIPs(peers),
|
||||
FwMark: oldInterface.FirewallMark,
|
||||
Table: oldInterface.GetRoutingTable(),
|
||||
TableStr: oldInterface.RoutingTable,
|
||||
IsDeleted: true, // mark the old entries as deleted
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.handleInterfacePostSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
|
||||
return nil, fmt.Errorf("post-save hooks failed: %w", err)
|
||||
}
|
||||
|
||||
// If the interface has just been enabled, restore its peers on the physical controller
|
||||
if !oldEnabled && newEnabled && iface.Backend == config.LocalBackendName {
|
||||
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
for _, peer := range peers {
|
||||
saveErr := m.wg.GetController(*iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, &peer)
|
||||
return pp, nil
|
||||
})
|
||||
if saveErr != nil {
|
||||
return nil, fmt.Errorf("failed to restore peer %s for interface %s: %w", peer.Identifier,
|
||||
iface.Identifier, saveErr)
|
||||
}
|
||||
}
|
||||
// notify that peers for this interface have changed so config/routes can be updated
|
||||
m.bus.Publish(app.TopicPeerInterfaceUpdated, iface.Identifier)
|
||||
}
|
||||
|
||||
m.bus.Publish(app.TopicAuditInterfaceChanged, domain.AuditEventWrapper[audit.InterfaceEvent]{
|
||||
Ctx: ctx,
|
||||
Event: audit.InterfaceEvent{
|
||||
@@ -575,75 +650,90 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
return iface, nil
|
||||
}
|
||||
|
||||
func (m Manager) hasInterfaceStateChanged(ctx context.Context, iface *domain.Interface) bool {
|
||||
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if oldInterface.IsDisabled() != iface.IsDisabled() {
|
||||
return true // interface in db has changed
|
||||
}
|
||||
|
||||
wgInterface, err := m.wg.GetInterface(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
return true // interface might not exist - so we assume that there must be a change
|
||||
}
|
||||
|
||||
// compare physical interface settings
|
||||
if len(wgInterface.Addresses) != len(iface.Addresses) ||
|
||||
wgInterface.Mtu != iface.Mtu ||
|
||||
wgInterface.FirewallMark != iface.FirewallMark ||
|
||||
wgInterface.ListenPort != iface.ListenPort ||
|
||||
wgInterface.PrivateKey != iface.PrivateKey ||
|
||||
wgInterface.PublicKey != iface.PublicKey {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
func (m Manager) getInterfaceStateHistory(
|
||||
oldInterface *domain.Interface,
|
||||
iface *domain.Interface,
|
||||
) (oldEnabled, newEnabled, routeTableChanged bool) {
|
||||
return !oldInterface.IsDisabled(), !iface.IsDisabled(), oldInterface.RoutingTable != iface.RoutingTable
|
||||
}
|
||||
|
||||
func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error {
|
||||
if !iface.IsDisabled() {
|
||||
if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
|
||||
return fmt.Errorf("failed to update dns settings: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := m.quick.UnsetDNS(iface.Identifier); err != nil {
|
||||
return fmt.Errorf("failed to clear dns settings: %w", err)
|
||||
func (m Manager) handleInterfacePreSaveActions(ctx context.Context, iface *domain.Interface) error {
|
||||
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
|
||||
if !ok {
|
||||
slog.Warn("failed to perform pre-save actions", "interface", iface.Identifier,
|
||||
"error", "no capable controller found")
|
||||
return nil
|
||||
}
|
||||
|
||||
// update DNS settings only for client interfaces
|
||||
if iface.Type == domain.InterfaceTypeClient || iface.Type == domain.InterfaceTypeAny {
|
||||
if !iface.IsDisabled() {
|
||||
if err := wgQuickController.SetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
|
||||
return fmt.Errorf("failed to update dns settings: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := wgQuickController.UnsetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
|
||||
return fmt.Errorf("failed to clear dns settings: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) handleInterfacePreSaveHooks(stateChanged bool, iface *domain.Interface) error {
|
||||
if !stateChanged {
|
||||
func (m Manager) handleInterfacePreSaveHooks(
|
||||
ctx context.Context,
|
||||
iface *domain.Interface,
|
||||
oldEnabled, newEnabled bool,
|
||||
) error {
|
||||
if oldEnabled == newEnabled {
|
||||
return nil // do nothing if state did not change
|
||||
}
|
||||
|
||||
if !iface.IsDisabled() {
|
||||
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
|
||||
slog.Debug("executing pre-save hooks", "interface", iface.Identifier, "up", newEnabled)
|
||||
|
||||
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
|
||||
if !ok {
|
||||
slog.Warn("failed to execute pre-save hooks", "interface", iface.Identifier, "up", newEnabled,
|
||||
"error", "no capable controller found")
|
||||
return nil
|
||||
}
|
||||
|
||||
if newEnabled {
|
||||
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreUp); err != nil {
|
||||
return fmt.Errorf("failed to execute pre-up hook: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
|
||||
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreDown); err != nil {
|
||||
return fmt.Errorf("failed to execute pre-down hook: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) handleInterfacePostSaveHooks(stateChanged bool, iface *domain.Interface) error {
|
||||
if !stateChanged {
|
||||
func (m Manager) handleInterfacePostSaveHooks(
|
||||
ctx context.Context,
|
||||
iface *domain.Interface,
|
||||
oldEnabled, newEnabled bool,
|
||||
) error {
|
||||
if oldEnabled == newEnabled {
|
||||
return nil // do nothing if state did not change
|
||||
}
|
||||
|
||||
if !iface.IsDisabled() {
|
||||
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
|
||||
slog.Debug("executing post-save hooks", "interface", iface.Identifier, "up", newEnabled)
|
||||
|
||||
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
|
||||
if !ok {
|
||||
slog.Warn("failed to execute post-save hooks", "interface", iface.Identifier, "up", newEnabled,
|
||||
"error", "no capable controller found")
|
||||
return nil
|
||||
}
|
||||
|
||||
if newEnabled {
|
||||
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostUp); err != nil {
|
||||
return fmt.Errorf("failed to execute post-up hook: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
|
||||
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostDown); err != nil {
|
||||
return fmt.Errorf("failed to execute post-down hook: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -769,7 +859,12 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterface, peers []domain.PhysicalPeer) error {
|
||||
func (m Manager) importInterface(
|
||||
ctx context.Context,
|
||||
backend domain.InterfaceController,
|
||||
in *domain.PhysicalInterface,
|
||||
peers []domain.PhysicalPeer,
|
||||
) error {
|
||||
now := time.Now()
|
||||
iface := domain.ConvertPhysicalInterface(in)
|
||||
iface.BaseModel = domain.BaseModel{
|
||||
@@ -778,8 +873,20 @@ func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterfa
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
iface.Backend = backend.GetId()
|
||||
iface.PeerDefAllowedIPsStr = iface.AddressStr()
|
||||
|
||||
// try to predict the interface type based on the number of peers
|
||||
switch len(peers) {
|
||||
case 0:
|
||||
iface.Type = domain.InterfaceTypeAny // no peers means this is an unknown interface
|
||||
case 1:
|
||||
iface.Type = domain.InterfaceTypeClient // one peer means this is a client interface
|
||||
default: // multiple peers means this is a server interface
|
||||
|
||||
iface.Type = domain.InterfaceTypeServer
|
||||
}
|
||||
|
||||
existingInterface, err := m.db.GetInterface(ctx, iface.Identifier)
|
||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||
return err
|
||||
@@ -830,16 +937,20 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
|
||||
peer.Interface.PreDown = domain.NewConfigOption(in.PeerDefPreDown, true)
|
||||
peer.Interface.PostDown = domain.NewConfigOption(in.PeerDefPostDown, true)
|
||||
|
||||
var displayName string
|
||||
switch in.Type {
|
||||
case domain.InterfaceTypeAny:
|
||||
peer.Interface.Type = domain.InterfaceTypeAny
|
||||
peer.DisplayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
displayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
case domain.InterfaceTypeClient:
|
||||
peer.Interface.Type = domain.InterfaceTypeServer
|
||||
peer.DisplayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
displayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
case domain.InterfaceTypeServer:
|
||||
peer.Interface.Type = domain.InterfaceTypeClient
|
||||
peer.DisplayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
displayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
}
|
||||
if peer.DisplayName == "" {
|
||||
peer.DisplayName = displayName // use auto-generated display name if not set
|
||||
}
|
||||
|
||||
err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) {
|
||||
@@ -852,13 +963,9 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
allPeers, err := m.db.GetInterfacePeers(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func (m Manager) deleteInterfacePeers(ctx context.Context, iface *domain.Interface, allPeers []domain.Peer) error {
|
||||
for _, peer := range allPeers {
|
||||
err = m.wg.DeletePeer(ctx, id, peer.Identifier)
|
||||
err := m.wg.GetController(*iface).DeletePeer(ctx, iface.Identifier, peer.Identifier)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
||||
@@ -188,6 +188,32 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
|
||||
|
||||
sessionUser := domain.GetUserInfo(ctx)
|
||||
|
||||
peer.Identifier = domain.PeerIdentifier(peer.Interface.PublicKey) // ensure that identifier corresponds to the public key
|
||||
|
||||
// Enforce peer limit for non-admin users if LimitAdditionalUserPeers is set
|
||||
if m.cfg.Core.SelfProvisioningAllowed && !sessionUser.IsAdmin && m.cfg.Advanced.LimitAdditionalUserPeers > 0 {
|
||||
peers, err := m.db.GetUserPeers(ctx, peer.UserIdentifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch peers for user %s: %w", peer.UserIdentifier, err)
|
||||
}
|
||||
// Count enabled peers (disabled IS NULL)
|
||||
peerCount := 0
|
||||
for _, p := range peers {
|
||||
if !p.IsDisabled() {
|
||||
peerCount++
|
||||
}
|
||||
}
|
||||
totalAllowedPeers := 1 + m.cfg.Advanced.LimitAdditionalUserPeers // 1 default peer + x additional peers
|
||||
if peerCount >= totalAllowedPeers {
|
||||
slog.WarnContext(ctx, "peer creation blocked due to limit",
|
||||
"user", peer.UserIdentifier,
|
||||
"current_count", peerCount,
|
||||
"allowed_count", totalAllowedPeers)
|
||||
return nil, fmt.Errorf("peer limit reached (%d peers allowed): %w", totalAllowedPeers,
|
||||
domain.ErrNoPermission)
|
||||
}
|
||||
}
|
||||
|
||||
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
|
||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||
return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err)
|
||||
@@ -347,7 +373,12 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
return fmt.Errorf("delete not allowed: %w", err)
|
||||
}
|
||||
|
||||
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
|
||||
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
|
||||
}
|
||||
|
||||
err = m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
|
||||
}
|
||||
@@ -357,9 +388,20 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
return fmt.Errorf("failed to delete peer %s: %w", id, err)
|
||||
}
|
||||
|
||||
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
m.bus.Publish(app.TopicPeerDeleted, *peer)
|
||||
// Update routes after peers have changed
|
||||
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
|
||||
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
|
||||
Interface: *iface,
|
||||
AllowedIps: iface.GetAllowedIPs(peers),
|
||||
FwMark: iface.FirewallMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
TableStr: iface.RoutingTable,
|
||||
})
|
||||
// Update interface after peers have changed
|
||||
m.bus.Publish(app.TopicPeerInterfaceUpdated, peer.InterfaceIdentifier)
|
||||
|
||||
@@ -407,37 +449,36 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
|
||||
// region helper-functions
|
||||
|
||||
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||
interfaces := make(map[domain.InterfaceIdentifier]struct{})
|
||||
interfaces := make(map[domain.InterfaceIdentifier]domain.Interface)
|
||||
|
||||
for i := range peers {
|
||||
peer := peers[i]
|
||||
var err error
|
||||
if peer.IsDisabled() || peer.IsExpired() {
|
||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
|
||||
if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
})
|
||||
} else {
|
||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
|
||||
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, peer)
|
||||
return pp, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
})
|
||||
for _, peer := range peers {
|
||||
// get interface from db if it is not yet in the map
|
||||
if _, ok := interfaces[peer.InterfaceIdentifier]; !ok {
|
||||
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
|
||||
}
|
||||
interfaces[peer.InterfaceIdentifier] = *iface
|
||||
}
|
||||
|
||||
iface := interfaces[peer.InterfaceIdentifier]
|
||||
|
||||
// Always save the peer to the backend, regardless of disabled/expired state
|
||||
// The backend will handle the disabled state appropriately
|
||||
err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
|
||||
err := m.wg.GetController(iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, peer)
|
||||
return pp, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
@@ -451,13 +492,22 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||
Peer: *peer,
|
||||
},
|
||||
})
|
||||
|
||||
interfaces[peer.InterfaceIdentifier] = struct{}{}
|
||||
}
|
||||
|
||||
// Update routes after peers have changed
|
||||
if len(interfaces) != 0 {
|
||||
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
|
||||
for id, iface := range interfaces {
|
||||
interfacePeers, err := m.db.GetInterfacePeers(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to re-load peers for interface %s: %w", id, err)
|
||||
}
|
||||
|
||||
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
|
||||
Interface: iface,
|
||||
AllowedIps: iface.GetAllowedIPs(interfacePeers),
|
||||
FwMark: iface.FirewallMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
TableStr: iface.RoutingTable,
|
||||
})
|
||||
}
|
||||
|
||||
for iface := range interfaces {
|
||||
|
||||
194
internal/app/wireguard/wireguard_peers_test.go
Normal file
194
internal/app/wireguard/wireguard_peers_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// --- Test mocks ---
|
||||
|
||||
type mockBus struct{}
|
||||
|
||||
func (f *mockBus) Publish(topic string, args ...any) {}
|
||||
func (f *mockBus) Subscribe(topic string, fn interface{}) error { return nil }
|
||||
|
||||
type mockController struct{}
|
||||
|
||||
func (f *mockController) GetId() domain.InterfaceBackend { return "local" }
|
||||
func (f *mockController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *mockController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (
|
||||
*domain.PhysicalInterface,
|
||||
error,
|
||||
) {
|
||||
return &domain.PhysicalInterface{Identifier: id}, nil
|
||||
}
|
||||
func (f *mockController) GetPeers(_ context.Context, _ domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *mockController) SaveInterface(
|
||||
_ context.Context,
|
||||
_ domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error {
|
||||
_, _ = updateFunc(&domain.PhysicalInterface{})
|
||||
return nil
|
||||
}
|
||||
func (f *mockController) DeleteInterface(_ context.Context, _ domain.InterfaceIdentifier) error {
|
||||
return nil
|
||||
}
|
||||
func (f *mockController) SavePeer(
|
||||
_ context.Context,
|
||||
_ domain.InterfaceIdentifier,
|
||||
_ domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error {
|
||||
_, _ = updateFunc(&domain.PhysicalPeer{})
|
||||
return nil
|
||||
}
|
||||
func (f *mockController) DeletePeer(_ context.Context, _ domain.InterfaceIdentifier, _ domain.PeerIdentifier) error {
|
||||
return nil
|
||||
}
|
||||
func (f *mockController) PingAddresses(_ context.Context, _ string) (*domain.PingerResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockDB struct {
|
||||
savedPeers map[domain.PeerIdentifier]*domain.Peer
|
||||
iface *domain.Interface
|
||||
}
|
||||
|
||||
func (f *mockDB) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) {
|
||||
if f.iface != nil && f.iface.Identifier == id {
|
||||
return f.iface, nil
|
||||
}
|
||||
return &domain.Interface{Identifier: id}, nil
|
||||
}
|
||||
func (f *mockDB) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
|
||||
*domain.Interface,
|
||||
[]domain.Peer,
|
||||
error,
|
||||
) {
|
||||
return f.iface, nil, nil
|
||||
}
|
||||
func (f *mockDB) GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *mockDB) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) { return nil, nil }
|
||||
func (f *mockDB) GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *mockDB) SaveInterface(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(in *domain.Interface) (*domain.Interface, error),
|
||||
) error {
|
||||
if f.iface == nil {
|
||||
f.iface = &domain.Interface{Identifier: id}
|
||||
}
|
||||
var err error
|
||||
f.iface, err = updateFunc(f.iface)
|
||||
return err
|
||||
}
|
||||
func (f *mockDB) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
return nil
|
||||
}
|
||||
func (f *mockDB) GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *mockDB) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *mockDB) SavePeer(
|
||||
ctx context.Context,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(in *domain.Peer) (*domain.Peer, error),
|
||||
) error {
|
||||
if f.savedPeers == nil {
|
||||
f.savedPeers = make(map[domain.PeerIdentifier]*domain.Peer)
|
||||
}
|
||||
existing := f.savedPeers[id]
|
||||
if existing == nil {
|
||||
existing = &domain.Peer{Identifier: id}
|
||||
}
|
||||
updated, err := updateFunc(existing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.savedPeers[updated.Identifier] = updated
|
||||
return nil
|
||||
}
|
||||
func (f *mockDB) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error { return nil }
|
||||
func (f *mockDB) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
|
||||
return nil, domain.ErrNotFound
|
||||
}
|
||||
func (f *mockDB) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (
|
||||
map[domain.Cidr][]domain.Cidr,
|
||||
error,
|
||||
) {
|
||||
return map[domain.Cidr][]domain.Cidr{}, nil
|
||||
}
|
||||
|
||||
// --- Test ---
|
||||
|
||||
func TestCreatePeer_SetsIdentifier_FromPublicKey(t *testing.T) {
|
||||
// Arrange
|
||||
cfg := &config.Config{}
|
||||
cfg.Core.SelfProvisioningAllowed = true
|
||||
cfg.Core.EditableKeys = true
|
||||
cfg.Advanced.LimitAdditionalUserPeers = 0
|
||||
|
||||
bus := &mockBus{}
|
||||
|
||||
// Prepare a controller manager with our mock controller
|
||||
ctrlMgr := &ControllerManager{
|
||||
controllers: map[domain.InterfaceBackend]backendInstance{
|
||||
config.LocalBackendName: {Implementation: &mockController{}},
|
||||
},
|
||||
}
|
||||
|
||||
db := &mockDB{iface: &domain.Interface{Identifier: "wg0", Type: domain.InterfaceTypeServer}}
|
||||
|
||||
m := Manager{
|
||||
cfg: cfg,
|
||||
bus: bus,
|
||||
db: db,
|
||||
wg: ctrlMgr,
|
||||
}
|
||||
|
||||
userId := domain.UserIdentifier("user@example.com")
|
||||
ctx := domain.SetUserInfo(context.Background(), &domain.ContextUserInfo{Id: userId, IsAdmin: false})
|
||||
|
||||
pubKey := "TEST_PUBLIC_KEY_ABC123"
|
||||
|
||||
input := &domain.Peer{
|
||||
Identifier: "should_be_overwritten",
|
||||
UserIdentifier: userId,
|
||||
InterfaceIdentifier: domain.InterfaceIdentifier("wg0"),
|
||||
Interface: domain.PeerInterfaceConfig{
|
||||
KeyPair: domain.KeyPair{PublicKey: pubKey},
|
||||
},
|
||||
}
|
||||
|
||||
// Act
|
||||
out, err := m.CreatePeer(ctx, input)
|
||||
|
||||
// Assert
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePeer returned error: %v", err)
|
||||
}
|
||||
|
||||
expectedId := domain.PeerIdentifier(pubKey)
|
||||
if out.Identifier != expectedId {
|
||||
t.Fatalf("expected Identifier to be set from public key %q, got %q", expectedId, out.Identifier)
|
||||
}
|
||||
|
||||
// Ensure the saved peer in DB also has the expected identifier
|
||||
if db.savedPeers[expectedId] == nil {
|
||||
t.Fatalf("expected peer with identifier %q to be saved in DB", expectedId)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user