mirror of
https://github.com/h44z/wg-portal.git
synced 2025-09-15 07:11:15 +00:00
Mikrotik integration (#467)
Allow MikroTik routes as WireGuard backends
This commit is contained in:
166
internal/app/wireguard/controller_manager.go
Normal file
166
internal/app/wireguard/controller_manager.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"slices"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/adapters/wgcontroller"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type InterfaceController interface {
|
||||
GetId() domain.InterfaceBackend
|
||||
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
SaveInterface(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
PingAddresses(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
) (*domain.PingerResult, error)
|
||||
}
|
||||
|
||||
type backendInstance struct {
|
||||
Config config.BackendBase // Config is the configuration for the backend instance.
|
||||
Implementation InterfaceController
|
||||
}
|
||||
|
||||
type ControllerManager struct {
|
||||
cfg *config.Config
|
||||
controllers map[domain.InterfaceBackend]backendInstance
|
||||
}
|
||||
|
||||
func NewControllerManager(cfg *config.Config) (*ControllerManager, error) {
|
||||
c := &ControllerManager{
|
||||
cfg: cfg,
|
||||
controllers: make(map[domain.InterfaceBackend]backendInstance),
|
||||
}
|
||||
|
||||
err := c.init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) init() error {
|
||||
if err := c.registerLocalController(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.registerMikrotikControllers(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.logRegisteredControllers()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) registerLocalController() error {
|
||||
localController, err := wgcontroller.NewLocalController(c.cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create local WireGuard controller: %w", err)
|
||||
}
|
||||
|
||||
c.controllers[config.LocalBackendName] = backendInstance{
|
||||
Config: config.BackendBase{
|
||||
Id: config.LocalBackendName,
|
||||
DisplayName: "Local WireGuard Controller",
|
||||
},
|
||||
Implementation: localController,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) registerMikrotikControllers() error {
|
||||
for _, backendConfig := range c.cfg.Backend.Mikrotik {
|
||||
if backendConfig.Id == config.LocalBackendName {
|
||||
slog.Warn("skipping registration of Mikrotik controller with reserved ID", "id", config.LocalBackendName)
|
||||
continue
|
||||
}
|
||||
|
||||
controller, err := wgcontroller.NewMikrotikController(c.cfg, &backendConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Mikrotik controller for backend %s: %w", backendConfig.Id, err)
|
||||
}
|
||||
|
||||
c.controllers[domain.InterfaceBackend(backendConfig.Id)] = backendInstance{
|
||||
Config: backendConfig.BackendBase,
|
||||
Implementation: controller,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) logRegisteredControllers() {
|
||||
for backend, controller := range c.controllers {
|
||||
slog.Debug("backend controller registered",
|
||||
"backend", backend, "type", fmt.Sprintf("%T", controller.Implementation))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController {
|
||||
return c.getController(backend, "")
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController {
|
||||
return c.getController(iface.Backend, iface.Identifier)
|
||||
}
|
||||
|
||||
func (c *ControllerManager) getController(
|
||||
backend domain.InterfaceBackend,
|
||||
ifaceId domain.InterfaceIdentifier,
|
||||
) InterfaceController {
|
||||
if backend == "" {
|
||||
// If no backend is specified, use the local controller.
|
||||
// This might be the case for interfaces created in previous WireGuard Portal versions.
|
||||
backend = config.LocalBackendName
|
||||
}
|
||||
|
||||
controller, exists := c.controllers[backend]
|
||||
if !exists {
|
||||
controller, exists = c.controllers[config.LocalBackendName] // Fallback to local controller
|
||||
if !exists {
|
||||
// If the local controller is also not found, panic
|
||||
panic(fmt.Sprintf("%s interface controller for backend %s not found", ifaceId, backend))
|
||||
}
|
||||
slog.Warn("controller for backend not found, using local controller",
|
||||
"backend", backend, "interface", ifaceId)
|
||||
}
|
||||
return controller.Implementation
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetAllControllers() []InterfaceController {
|
||||
var backendInstances = make([]InterfaceController, 0, len(c.controllers))
|
||||
for instance := range maps.Values(c.controllers) {
|
||||
backendInstances = append(backendInstances, instance.Implementation)
|
||||
}
|
||||
return backendInstances
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetControllerNames() []config.BackendBase {
|
||||
var names []config.BackendBase
|
||||
for _, id := range slices.Sorted(maps.Keys(c.controllers)) {
|
||||
names = append(names, c.controllers[id].Config)
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
@@ -6,8 +6,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
probing "github.com/prometheus-community/pro-bing"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
@@ -30,11 +28,6 @@ type StatisticsDatabaseRepo interface {
|
||||
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type StatisticsInterfaceController interface {
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
}
|
||||
|
||||
type StatisticsMetricsServer interface {
|
||||
UpdateInterfaceMetrics(status domain.InterfaceStatus)
|
||||
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
|
||||
@@ -47,15 +40,20 @@ type StatisticsEventBus interface {
|
||||
Publish(topic string, args ...any)
|
||||
}
|
||||
|
||||
type pingJob struct {
|
||||
Peer domain.Peer
|
||||
Backend domain.InterfaceBackend
|
||||
}
|
||||
|
||||
type StatisticsCollector struct {
|
||||
cfg *config.Config
|
||||
bus StatisticsEventBus
|
||||
|
||||
pingWaitGroup sync.WaitGroup
|
||||
pingJobs chan domain.Peer
|
||||
pingJobs chan pingJob
|
||||
|
||||
db StatisticsDatabaseRepo
|
||||
wg StatisticsInterfaceController
|
||||
wg *ControllerManager
|
||||
ms StatisticsMetricsServer
|
||||
|
||||
peerChangeEvent chan domain.PeerIdentifier
|
||||
@@ -66,7 +64,7 @@ func NewStatisticsCollector(
|
||||
cfg *config.Config,
|
||||
bus StatisticsEventBus,
|
||||
db StatisticsDatabaseRepo,
|
||||
wg StatisticsInterfaceController,
|
||||
wg *ControllerManager,
|
||||
ms StatisticsMetricsServer,
|
||||
) (*StatisticsCollector, error) {
|
||||
c := &StatisticsCollector{
|
||||
@@ -117,7 +115,7 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
|
||||
}
|
||||
|
||||
for _, in := range interfaces {
|
||||
physicalInterface, err := c.wg.GetInterface(ctx, in.Identifier)
|
||||
physicalInterface, err := c.wg.GetController(in).GetInterface(ctx, in.Identifier)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier,
|
||||
"error", err)
|
||||
@@ -169,7 +167,7 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
|
||||
}
|
||||
|
||||
for _, in := range interfaces {
|
||||
peers, err := c.wg.GetPeers(ctx, in.Identifier)
|
||||
peers, err := c.wg.GetController(in).GetPeers(ctx, in.Identifier)
|
||||
if err != nil {
|
||||
slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err)
|
||||
continue
|
||||
@@ -271,7 +269,7 @@ func (c *StatisticsCollector) startPingWorkers(ctx context.Context) {
|
||||
|
||||
c.pingWaitGroup = sync.WaitGroup{}
|
||||
c.pingWaitGroup.Add(c.cfg.Statistics.PingCheckWorkers)
|
||||
c.pingJobs = make(chan domain.Peer, c.cfg.Statistics.PingCheckWorkers)
|
||||
c.pingJobs = make(chan pingJob, c.cfg.Statistics.PingCheckWorkers)
|
||||
|
||||
// start workers
|
||||
for i := 0; i < c.cfg.Statistics.PingCheckWorkers; i++ {
|
||||
@@ -314,7 +312,10 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
for _, peer := range peers {
|
||||
c.pingJobs <- peer
|
||||
c.pingJobs <- pingJob{
|
||||
Peer: peer,
|
||||
Backend: in.Backend,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -323,11 +324,14 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
|
||||
|
||||
func (c *StatisticsCollector) pingWorker(ctx context.Context) {
|
||||
defer c.pingWaitGroup.Done()
|
||||
for peer := range c.pingJobs {
|
||||
for job := range c.pingJobs {
|
||||
peer := job.Peer
|
||||
backend := job.Backend
|
||||
|
||||
var connectionStateChanged bool
|
||||
var newPeerStatus domain.PeerStatus
|
||||
|
||||
peerPingable := c.isPeerPingable(ctx, peer)
|
||||
peerPingable := c.isPeerPingable(ctx, backend, peer)
|
||||
slog.Debug("peer ping check completed", "peer", peer.Identifier, "pingable", peerPingable)
|
||||
|
||||
now := time.Now()
|
||||
@@ -368,7 +372,11 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Peer) bool {
|
||||
func (c *StatisticsCollector) isPeerPingable(
|
||||
ctx context.Context,
|
||||
backend domain.InterfaceBackend,
|
||||
peer domain.Peer,
|
||||
) bool {
|
||||
if !c.cfg.Statistics.UsePingChecks {
|
||||
return false
|
||||
}
|
||||
@@ -378,23 +386,13 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
|
||||
return false
|
||||
}
|
||||
|
||||
pinger, err := probing.NewPinger(checkAddr)
|
||||
stats, err := c.wg.GetControllerByName(backend).PingAddresses(ctx, checkAddr)
|
||||
if err != nil {
|
||||
slog.Debug("failed to instantiate pinger", "peer", peer.Identifier, "address", checkAddr, "error", err)
|
||||
slog.Debug("failed to ping peer", "peer", peer.Identifier, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
checkCount := 1
|
||||
pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged)
|
||||
pinger.Count = checkCount
|
||||
pinger.Timeout = 2 * time.Second
|
||||
err = pinger.RunWithContext(ctx) // Blocks until finished.
|
||||
if err != nil {
|
||||
slog.Debug("pinger for peer exited unexpectedly", "peer", peer.Identifier, "address", checkAddr, "error", err)
|
||||
return false
|
||||
}
|
||||
stats := pinger.Statistics()
|
||||
return stats.PacketsRecv == checkCount
|
||||
return stats.IsPingable()
|
||||
}
|
||||
|
||||
func (c *StatisticsCollector) updateInterfaceMetrics(status domain.InterfaceStatus) {
|
||||
|
@@ -37,25 +37,6 @@ type InterfaceAndPeerDatabaseRepo interface {
|
||||
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
|
||||
}
|
||||
|
||||
type InterfaceController interface {
|
||||
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
SaveInterface(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type WgQuickController interface {
|
||||
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
|
||||
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||
@@ -75,7 +56,7 @@ type Manager struct {
|
||||
cfg *config.Config
|
||||
bus EventBus
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
wg InterfaceController
|
||||
wg *ControllerManager
|
||||
quick WgQuickController
|
||||
|
||||
userLockMap *sync.Map
|
||||
@@ -84,7 +65,7 @@ type Manager struct {
|
||||
func NewWireGuardManager(
|
||||
cfg *config.Config,
|
||||
bus EventBus,
|
||||
wg InterfaceController,
|
||||
wg *ControllerManager,
|
||||
quick WgQuickController,
|
||||
db InterfaceAndPeerDatabaseRepo,
|
||||
) (*Manager, error) {
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/app/audit"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
@@ -21,12 +22,17 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
|
||||
return nil, err
|
||||
}
|
||||
|
||||
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var allPhysicalInterfaces []domain.PhysicalInterface
|
||||
for _, wgBackend := range m.wg.GetAllControllers() {
|
||||
physicalInterfaces, err := wgBackend.GetInterfaces(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allPhysicalInterfaces = append(allPhysicalInterfaces, physicalInterfaces...)
|
||||
}
|
||||
|
||||
return physicalInterfaces, nil
|
||||
return allPhysicalInterfaces, nil
|
||||
}
|
||||
|
||||
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
|
||||
@@ -109,47 +115,49 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
|
||||
return 0, err
|
||||
}
|
||||
|
||||
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// if no filter is given, exclude already existing interfaces
|
||||
var excludedInterfaces []domain.InterfaceIdentifier
|
||||
if len(filter) == 0 {
|
||||
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, existingInterface := range existingInterfaces {
|
||||
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
|
||||
}
|
||||
}
|
||||
|
||||
imported := 0
|
||||
for _, physicalInterface := range physicalInterfaces {
|
||||
if slices.Contains(excludedInterfaces, physicalInterface.Identifier) {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) {
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
|
||||
|
||||
physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier)
|
||||
for _, wgBackend := range m.wg.GetAllControllers() {
|
||||
physicalInterfaces, err := wgBackend.GetInterfaces(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = m.importInterface(ctx, &physicalInterface, physicalPeers)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
|
||||
// if no filter is given, exclude already existing interfaces
|
||||
var excludedInterfaces []domain.InterfaceIdentifier
|
||||
if len(filter) == 0 {
|
||||
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, existingInterface := range existingInterfaces {
|
||||
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers))
|
||||
imported++
|
||||
for _, physicalInterface := range physicalInterfaces {
|
||||
if slices.Contains(excludedInterfaces, physicalInterface.Identifier) {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) {
|
||||
continue
|
||||
}
|
||||
|
||||
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
|
||||
|
||||
physicalPeers, err := wgBackend.GetPeers(ctx, physicalInterface.Identifier)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = m.importInterface(ctx, wgBackend, &physicalInterface, physicalPeers)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
|
||||
}
|
||||
|
||||
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers))
|
||||
imported++
|
||||
}
|
||||
}
|
||||
|
||||
return imported, nil
|
||||
@@ -213,7 +221,7 @@ func (m Manager) RestoreInterfaceState(
|
||||
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
_, err = m.wg.GetInterface(ctx, iface.Identifier)
|
||||
_, err = m.wg.GetController(iface).GetInterface(ctx, iface.Identifier)
|
||||
if err != nil && !iface.IsDisabled() {
|
||||
slog.Debug("creating missing interface", "interface", iface.Identifier)
|
||||
|
||||
@@ -260,18 +268,14 @@ func (m Manager) RestoreInterfaceState(
|
||||
// restore peers
|
||||
for _, peer := range peers {
|
||||
switch {
|
||||
case iface.IsDisabled(): // if interface is disabled, delete all peers
|
||||
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
|
||||
case iface.IsDisabled() && iface.Backend == config.LocalBackendName: // if interface is disabled, delete all peers
|
||||
if err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier,
|
||||
peer.Identifier); err != nil {
|
||||
return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w",
|
||||
peer.Identifier, iface.Identifier, err)
|
||||
}
|
||||
case peer.IsDisabled(): // if peer is disabled, delete it
|
||||
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
|
||||
return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w",
|
||||
peer.Identifier, iface.Identifier, err)
|
||||
}
|
||||
default: // update peer
|
||||
err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
err := m.wg.GetController(iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, &peer)
|
||||
return pp, nil
|
||||
@@ -284,7 +288,7 @@ func (m Manager) RestoreInterfaceState(
|
||||
}
|
||||
|
||||
// remove non-wgportal peers
|
||||
physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier)
|
||||
physicalPeers, _ := m.wg.GetController(iface).GetPeers(ctx, iface.Identifier)
|
||||
for _, physicalPeer := range physicalPeers {
|
||||
isWgPortalPeer := false
|
||||
for _, peer := range peers {
|
||||
@@ -294,7 +298,8 @@ func (m Manager) RestoreInterfaceState(
|
||||
}
|
||||
}
|
||||
if !isWgPortalPeer {
|
||||
err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey))
|
||||
err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier,
|
||||
domain.PeerIdentifier(physicalPeer.PublicKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w",
|
||||
physicalPeer.PublicKey, iface.Identifier, err)
|
||||
@@ -459,7 +464,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
existingInterface.Disabled = &now // simulate a disabled interface
|
||||
existingInterface.DisabledReason = domain.DisabledReasonDeleted
|
||||
|
||||
physicalInterface, _ := m.wg.GetInterface(ctx, id)
|
||||
physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id)
|
||||
|
||||
if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
|
||||
return fmt.Errorf("pre-delete hooks failed: %w", err)
|
||||
@@ -473,7 +478,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
return fmt.Errorf("peer deletion failure: %w", err)
|
||||
}
|
||||
|
||||
if err := m.wg.DeleteInterface(ctx, id); err != nil {
|
||||
if err := m.wg.GetController(*existingInterface).DeleteInterface(ctx, id); err != nil {
|
||||
return fmt.Errorf("wireguard deletion failure: %w", err)
|
||||
}
|
||||
|
||||
@@ -522,7 +527,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
||||
iface.CopyCalculatedAttributes(i)
|
||||
|
||||
err := m.wg.SaveInterface(ctx, iface.Identifier,
|
||||
err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
|
||||
func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
domain.MergeToPhysicalInterface(pi, iface)
|
||||
return pi, nil
|
||||
@@ -538,7 +543,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
}
|
||||
|
||||
if iface.IsDisabled() {
|
||||
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
|
||||
physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier)
|
||||
fwMark := iface.FirewallMark
|
||||
if physicalInterface != nil && fwMark == 0 {
|
||||
fwMark = physicalInterface.FirewallMark
|
||||
@@ -556,13 +561,13 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
}
|
||||
|
||||
// If the interface has just been enabled, restore its peers on the physical controller
|
||||
if !oldEnabled && newEnabled {
|
||||
if !oldEnabled && newEnabled && iface.Backend == config.LocalBackendName {
|
||||
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
for _, peer := range peers {
|
||||
saveErr := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
saveErr := m.wg.GetController(*iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, &peer)
|
||||
return pp, nil
|
||||
@@ -766,7 +771,12 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterface, peers []domain.PhysicalPeer) error {
|
||||
func (m Manager) importInterface(
|
||||
ctx context.Context,
|
||||
backend InterfaceController,
|
||||
in *domain.PhysicalInterface,
|
||||
peers []domain.PhysicalPeer,
|
||||
) error {
|
||||
now := time.Now()
|
||||
iface := domain.ConvertPhysicalInterface(in)
|
||||
iface.BaseModel = domain.BaseModel{
|
||||
@@ -775,8 +785,20 @@ func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterfa
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
iface.Backend = backend.GetId()
|
||||
iface.PeerDefAllowedIPsStr = iface.AddressStr()
|
||||
|
||||
// try to predict the interface type based on the number of peers
|
||||
switch len(peers) {
|
||||
case 0:
|
||||
iface.Type = domain.InterfaceTypeAny // no peers means this is an unknown interface
|
||||
case 1:
|
||||
iface.Type = domain.InterfaceTypeClient // one peer means this is a client interface
|
||||
default: // multiple peers means this is a server interface
|
||||
|
||||
iface.Type = domain.InterfaceTypeServer
|
||||
}
|
||||
|
||||
existingInterface, err := m.db.GetInterface(ctx, iface.Identifier)
|
||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||
return err
|
||||
@@ -827,16 +849,20 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
|
||||
peer.Interface.PreDown = domain.NewConfigOption(in.PeerDefPreDown, true)
|
||||
peer.Interface.PostDown = domain.NewConfigOption(in.PeerDefPostDown, true)
|
||||
|
||||
var displayName string
|
||||
switch in.Type {
|
||||
case domain.InterfaceTypeAny:
|
||||
peer.Interface.Type = domain.InterfaceTypeAny
|
||||
peer.DisplayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
displayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
case domain.InterfaceTypeClient:
|
||||
peer.Interface.Type = domain.InterfaceTypeServer
|
||||
peer.DisplayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
displayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
case domain.InterfaceTypeServer:
|
||||
peer.Interface.Type = domain.InterfaceTypeClient
|
||||
peer.DisplayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
displayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
}
|
||||
if peer.DisplayName == "" {
|
||||
peer.DisplayName = displayName // use auto-generated display name if not set
|
||||
}
|
||||
|
||||
err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) {
|
||||
@@ -850,12 +876,12 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
|
||||
}
|
||||
|
||||
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
allPeers, err := m.db.GetInterfacePeers(ctx, id)
|
||||
iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, peer := range allPeers {
|
||||
err = m.wg.DeletePeer(ctx, id, peer.Identifier)
|
||||
err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
@@ -371,7 +371,12 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
return fmt.Errorf("delete not allowed: %w", err)
|
||||
}
|
||||
|
||||
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
|
||||
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
|
||||
}
|
||||
|
||||
err = m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
|
||||
}
|
||||
@@ -433,35 +438,28 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
|
||||
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||
interfaces := make(map[domain.InterfaceIdentifier]struct{})
|
||||
|
||||
for i := range peers {
|
||||
peer := peers[i]
|
||||
var err error
|
||||
if peer.IsDisabled() || peer.IsExpired() {
|
||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
|
||||
if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
})
|
||||
} else {
|
||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
|
||||
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, peer)
|
||||
return pp, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
})
|
||||
for _, peer := range peers {
|
||||
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
|
||||
}
|
||||
|
||||
// Always save the peer to the backend, regardless of disabled/expired state
|
||||
// The backend will handle the disabled state appropriately
|
||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
|
||||
err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, peer)
|
||||
return pp, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user