Mikrotik integration (#467)
Some checks failed
Docker / Build and Push (push) Has been cancelled
github-pages / deploy (push) Has been cancelled
Docker / release (push) Has been cancelled

Allow MikroTik routes as WireGuard backends
This commit is contained in:
h44z
2025-08-10 14:42:02 +02:00
committed by GitHub
parent a86f83a219
commit 112f6bfb77
40 changed files with 3150 additions and 205 deletions

View File

@@ -21,17 +21,23 @@ import (
//go:embed frontend_config.js.gotpl
var frontendJs embed.FS
type ControllerManager interface {
GetControllerNames() []config.BackendBase
}
type ConfigEndpoint struct {
cfg *config.Config
authenticator Authenticator
controllerMgr ControllerManager
tpl *respond.TemplateRenderer
}
func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator) ConfigEndpoint {
func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator, ctrlMgr ControllerManager) ConfigEndpoint {
ep := ConfigEndpoint{
cfg: cfg,
authenticator: authenticator,
controllerMgr: ctrlMgr,
tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs,
"frontend_config.js.gotpl"))),
}
@@ -96,13 +102,36 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
sessionUser := domain.GetUserInfo(r.Context())
controllerFn := func() []model.SettingsBackendNames {
controllers := e.controllerMgr.GetControllerNames()
names := make([]model.SettingsBackendNames, 0, len(controllers))
for _, controller := range controllers {
displayName := controller.GetDisplayName()
if displayName == "" {
displayName = controller.Id // fallback to ID if no display name is set
}
if controller.Id == config.LocalBackendName {
displayName = "modals.interface-edit.backend.local" // use a localized string for the local backend
}
names = append(names, model.SettingsBackendNames{
Id: controller.Id,
Name: displayName,
})
}
return names
}
hasSocialLogin := len(e.cfg.Auth.OAuth) > 0 || len(e.cfg.Auth.OpenIDConnect) > 0 || e.cfg.Auth.WebAuthn.Enabled
// For anonymous users, we return the settings object with minimal information
if sessionUser.Id == domain.CtxUnknownUserId || sessionUser.Id == "" {
respond.JSON(w, http.StatusOK, model.Settings{
WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled,
LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin,
WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled,
AvailableBackends: []model.SettingsBackendNames{}, // return an empty list instead of null
LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin,
})
} else {
respond.JSON(w, http.StatusOK, model.Settings{
@@ -112,6 +141,7 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc {
ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly,
WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled,
MinPasswordLength: e.cfg.Auth.MinPasswordLength,
AvailableBackends: controllerFn(),
LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin,
})
}

View File

@@ -6,11 +6,17 @@ type Error struct {
}
type Settings struct {
MailLinkOnly bool `json:"MailLinkOnly"`
PersistentConfigSupported bool `json:"PersistentConfigSupported"`
SelfProvisioning bool `json:"SelfProvisioning"`
ApiAdminOnly bool `json:"ApiAdminOnly"`
WebAuthnEnabled bool `json:"WebAuthnEnabled"`
MinPasswordLength int `json:"MinPasswordLength"`
LoginFormVisible bool `json:"LoginFormVisible"`
MailLinkOnly bool `json:"MailLinkOnly"`
PersistentConfigSupported bool `json:"PersistentConfigSupported"`
SelfProvisioning bool `json:"SelfProvisioning"`
ApiAdminOnly bool `json:"ApiAdminOnly"`
WebAuthnEnabled bool `json:"WebAuthnEnabled"`
MinPasswordLength int `json:"MinPasswordLength"`
AvailableBackends []SettingsBackendNames `json:"AvailableBackends"`
LoginFormVisible bool `json:"LoginFormVisible"`
}
type SettingsBackendNames struct {
Id string `json:"Id"`
Name string `json:"Name"`
}

View File

@@ -4,6 +4,7 @@ import (
"time"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
@@ -11,6 +12,7 @@ type Interface struct {
Identifier string `json:"Identifier" example:"wg0"` // device name, for example: wg0
DisplayName string `json:"DisplayName"` // a nice display name/ description for the interface
Mode string `json:"Mode" example:"server"` // the interface type, either 'server', 'client' or 'any'
Backend string `json:"Backend" example:"local"` // the backend used for this interface e.g., local, mikrotik, ...
PrivateKey string `json:"PrivateKey" example:"abcdef=="` // private Key of the server interface
PublicKey string `json:"PublicKey" example:"abcdef=="` // public Key of the server interface
Disabled bool `json:"Disabled"` // flag that specifies if the interface is enabled (up) or not (down)
@@ -57,6 +59,7 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
Identifier: string(src.Identifier),
DisplayName: src.DisplayName,
Mode: string(src.Type),
Backend: string(src.Backend),
PrivateKey: src.PrivateKey,
PublicKey: src.PublicKey,
Disabled: src.IsDisabled(),
@@ -92,6 +95,10 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
Filename: src.GetConfigFileName(),
}
if iface.Backend == "" {
iface.Backend = config.LocalBackendName // default to local backend
}
if len(peers) > 0 {
iface.TotalPeers = len(peers)
@@ -146,6 +153,7 @@ func NewDomainInterface(src *Interface) *domain.Interface {
SaveConfig: src.SaveConfig,
DisplayName: src.DisplayName,
Type: domain.InterfaceType(src.Mode),
Backend: domain.InterfaceBackend(src.Backend),
DriverType: "", // currently unused
Disabled: nil, // set below
DisabledReason: src.DisabledReason,

View File

@@ -46,7 +46,7 @@ func Initialize(
users: users,
}
startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second)
startupContext, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Switch to admin user context

View File

@@ -0,0 +1,166 @@
package wireguard
import (
"context"
"fmt"
"log/slog"
"maps"
"slices"
"github.com/h44z/wg-portal/internal/adapters/wgcontroller"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
type InterfaceController interface {
GetId() domain.InterfaceBackend
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
PingAddresses(
ctx context.Context,
addr string,
) (*domain.PingerResult, error)
}
type backendInstance struct {
Config config.BackendBase // Config is the configuration for the backend instance.
Implementation InterfaceController
}
type ControllerManager struct {
cfg *config.Config
controllers map[domain.InterfaceBackend]backendInstance
}
func NewControllerManager(cfg *config.Config) (*ControllerManager, error) {
c := &ControllerManager{
cfg: cfg,
controllers: make(map[domain.InterfaceBackend]backendInstance),
}
err := c.init()
if err != nil {
return nil, err
}
return c, nil
}
func (c *ControllerManager) init() error {
if err := c.registerLocalController(); err != nil {
return err
}
if err := c.registerMikrotikControllers(); err != nil {
return err
}
c.logRegisteredControllers()
return nil
}
func (c *ControllerManager) registerLocalController() error {
localController, err := wgcontroller.NewLocalController(c.cfg)
if err != nil {
return fmt.Errorf("failed to create local WireGuard controller: %w", err)
}
c.controllers[config.LocalBackendName] = backendInstance{
Config: config.BackendBase{
Id: config.LocalBackendName,
DisplayName: "Local WireGuard Controller",
},
Implementation: localController,
}
return nil
}
func (c *ControllerManager) registerMikrotikControllers() error {
for _, backendConfig := range c.cfg.Backend.Mikrotik {
if backendConfig.Id == config.LocalBackendName {
slog.Warn("skipping registration of Mikrotik controller with reserved ID", "id", config.LocalBackendName)
continue
}
controller, err := wgcontroller.NewMikrotikController(c.cfg, &backendConfig)
if err != nil {
return fmt.Errorf("failed to create Mikrotik controller for backend %s: %w", backendConfig.Id, err)
}
c.controllers[domain.InterfaceBackend(backendConfig.Id)] = backendInstance{
Config: backendConfig.BackendBase,
Implementation: controller,
}
}
return nil
}
func (c *ControllerManager) logRegisteredControllers() {
for backend, controller := range c.controllers {
slog.Debug("backend controller registered",
"backend", backend, "type", fmt.Sprintf("%T", controller.Implementation))
}
}
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController {
return c.getController(backend, "")
}
func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController {
return c.getController(iface.Backend, iface.Identifier)
}
func (c *ControllerManager) getController(
backend domain.InterfaceBackend,
ifaceId domain.InterfaceIdentifier,
) InterfaceController {
if backend == "" {
// If no backend is specified, use the local controller.
// This might be the case for interfaces created in previous WireGuard Portal versions.
backend = config.LocalBackendName
}
controller, exists := c.controllers[backend]
if !exists {
controller, exists = c.controllers[config.LocalBackendName] // Fallback to local controller
if !exists {
// If the local controller is also not found, panic
panic(fmt.Sprintf("%s interface controller for backend %s not found", ifaceId, backend))
}
slog.Warn("controller for backend not found, using local controller",
"backend", backend, "interface", ifaceId)
}
return controller.Implementation
}
func (c *ControllerManager) GetAllControllers() []InterfaceController {
var backendInstances = make([]InterfaceController, 0, len(c.controllers))
for instance := range maps.Values(c.controllers) {
backendInstances = append(backendInstances, instance.Implementation)
}
return backendInstances
}
func (c *ControllerManager) GetControllerNames() []config.BackendBase {
var names []config.BackendBase
for _, id := range slices.Sorted(maps.Keys(c.controllers)) {
names = append(names, c.controllers[id].Config)
}
return names
}

View File

@@ -6,8 +6,6 @@ import (
"sync"
"time"
probing "github.com/prometheus-community/pro-bing"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
@@ -30,11 +28,6 @@ type StatisticsDatabaseRepo interface {
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
}
type StatisticsInterfaceController interface {
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
}
type StatisticsMetricsServer interface {
UpdateInterfaceMetrics(status domain.InterfaceStatus)
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
@@ -47,15 +40,20 @@ type StatisticsEventBus interface {
Publish(topic string, args ...any)
}
type pingJob struct {
Peer domain.Peer
Backend domain.InterfaceBackend
}
type StatisticsCollector struct {
cfg *config.Config
bus StatisticsEventBus
pingWaitGroup sync.WaitGroup
pingJobs chan domain.Peer
pingJobs chan pingJob
db StatisticsDatabaseRepo
wg StatisticsInterfaceController
wg *ControllerManager
ms StatisticsMetricsServer
peerChangeEvent chan domain.PeerIdentifier
@@ -66,7 +64,7 @@ func NewStatisticsCollector(
cfg *config.Config,
bus StatisticsEventBus,
db StatisticsDatabaseRepo,
wg StatisticsInterfaceController,
wg *ControllerManager,
ms StatisticsMetricsServer,
) (*StatisticsCollector, error) {
c := &StatisticsCollector{
@@ -117,7 +115,7 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
}
for _, in := range interfaces {
physicalInterface, err := c.wg.GetInterface(ctx, in.Identifier)
physicalInterface, err := c.wg.GetController(in).GetInterface(ctx, in.Identifier)
if err != nil {
slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier,
"error", err)
@@ -169,7 +167,7 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
}
for _, in := range interfaces {
peers, err := c.wg.GetPeers(ctx, in.Identifier)
peers, err := c.wg.GetController(in).GetPeers(ctx, in.Identifier)
if err != nil {
slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err)
continue
@@ -271,7 +269,7 @@ func (c *StatisticsCollector) startPingWorkers(ctx context.Context) {
c.pingWaitGroup = sync.WaitGroup{}
c.pingWaitGroup.Add(c.cfg.Statistics.PingCheckWorkers)
c.pingJobs = make(chan domain.Peer, c.cfg.Statistics.PingCheckWorkers)
c.pingJobs = make(chan pingJob, c.cfg.Statistics.PingCheckWorkers)
// start workers
for i := 0; i < c.cfg.Statistics.PingCheckWorkers; i++ {
@@ -314,7 +312,10 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
continue
}
for _, peer := range peers {
c.pingJobs <- peer
c.pingJobs <- pingJob{
Peer: peer,
Backend: in.Backend,
}
}
}
}
@@ -323,11 +324,14 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
func (c *StatisticsCollector) pingWorker(ctx context.Context) {
defer c.pingWaitGroup.Done()
for peer := range c.pingJobs {
for job := range c.pingJobs {
peer := job.Peer
backend := job.Backend
var connectionStateChanged bool
var newPeerStatus domain.PeerStatus
peerPingable := c.isPeerPingable(ctx, peer)
peerPingable := c.isPeerPingable(ctx, backend, peer)
slog.Debug("peer ping check completed", "peer", peer.Identifier, "pingable", peerPingable)
now := time.Now()
@@ -368,7 +372,11 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
}
}
func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Peer) bool {
func (c *StatisticsCollector) isPeerPingable(
ctx context.Context,
backend domain.InterfaceBackend,
peer domain.Peer,
) bool {
if !c.cfg.Statistics.UsePingChecks {
return false
}
@@ -378,23 +386,13 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
return false
}
pinger, err := probing.NewPinger(checkAddr)
stats, err := c.wg.GetControllerByName(backend).PingAddresses(ctx, checkAddr)
if err != nil {
slog.Debug("failed to instantiate pinger", "peer", peer.Identifier, "address", checkAddr, "error", err)
slog.Debug("failed to ping peer", "peer", peer.Identifier, "error", err)
return false
}
checkCount := 1
pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged)
pinger.Count = checkCount
pinger.Timeout = 2 * time.Second
err = pinger.RunWithContext(ctx) // Blocks until finished.
if err != nil {
slog.Debug("pinger for peer exited unexpectedly", "peer", peer.Identifier, "address", checkAddr, "error", err)
return false
}
stats := pinger.Statistics()
return stats.PacketsRecv == checkCount
return stats.IsPingable()
}
func (c *StatisticsCollector) updateInterfaceMetrics(status domain.InterfaceStatus) {

View File

@@ -37,25 +37,6 @@ type InterfaceAndPeerDatabaseRepo interface {
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
}
type InterfaceController interface {
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}
type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
@@ -75,7 +56,7 @@ type Manager struct {
cfg *config.Config
bus EventBus
db InterfaceAndPeerDatabaseRepo
wg InterfaceController
wg *ControllerManager
quick WgQuickController
userLockMap *sync.Map
@@ -84,7 +65,7 @@ type Manager struct {
func NewWireGuardManager(
cfg *config.Config,
bus EventBus,
wg InterfaceController,
wg *ControllerManager,
quick WgQuickController,
db InterfaceAndPeerDatabaseRepo,
) (*Manager, error) {

View File

@@ -11,6 +11,7 @@ import (
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/audit"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
@@ -21,12 +22,17 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
return nil, err
}
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
if err != nil {
return nil, err
var allPhysicalInterfaces []domain.PhysicalInterface
for _, wgBackend := range m.wg.GetAllControllers() {
physicalInterfaces, err := wgBackend.GetInterfaces(ctx)
if err != nil {
return nil, err
}
allPhysicalInterfaces = append(allPhysicalInterfaces, physicalInterfaces...)
}
return physicalInterfaces, nil
return allPhysicalInterfaces, nil
}
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
@@ -109,47 +115,49 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
return 0, err
}
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
if err != nil {
return 0, err
}
// if no filter is given, exclude already existing interfaces
var excludedInterfaces []domain.InterfaceIdentifier
if len(filter) == 0 {
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return 0, err
}
for _, existingInterface := range existingInterfaces {
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
}
}
imported := 0
for _, physicalInterface := range physicalInterfaces {
if slices.Contains(excludedInterfaces, physicalInterface.Identifier) {
continue
}
if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) {
continue
}
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier)
for _, wgBackend := range m.wg.GetAllControllers() {
physicalInterfaces, err := wgBackend.GetInterfaces(ctx)
if err != nil {
return 0, err
}
err = m.importInterface(ctx, &physicalInterface, physicalPeers)
if err != nil {
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
// if no filter is given, exclude already existing interfaces
var excludedInterfaces []domain.InterfaceIdentifier
if len(filter) == 0 {
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return 0, err
}
for _, existingInterface := range existingInterfaces {
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
}
}
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers))
imported++
for _, physicalInterface := range physicalInterfaces {
if slices.Contains(excludedInterfaces, physicalInterface.Identifier) {
continue
}
if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) {
continue
}
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
physicalPeers, err := wgBackend.GetPeers(ctx, physicalInterface.Identifier)
if err != nil {
return 0, err
}
err = m.importInterface(ctx, wgBackend, &physicalInterface, physicalPeers)
if err != nil {
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
}
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers))
imported++
}
}
return imported, nil
@@ -213,7 +221,7 @@ func (m Manager) RestoreInterfaceState(
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
}
_, err = m.wg.GetInterface(ctx, iface.Identifier)
_, err = m.wg.GetController(iface).GetInterface(ctx, iface.Identifier)
if err != nil && !iface.IsDisabled() {
slog.Debug("creating missing interface", "interface", iface.Identifier)
@@ -260,18 +268,14 @@ func (m Manager) RestoreInterfaceState(
// restore peers
for _, peer := range peers {
switch {
case iface.IsDisabled(): // if interface is disabled, delete all peers
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
case iface.IsDisabled() && iface.Backend == config.LocalBackendName: // if interface is disabled, delete all peers
if err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier,
peer.Identifier); err != nil {
return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w",
peer.Identifier, iface.Identifier, err)
}
case peer.IsDisabled(): // if peer is disabled, delete it
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w",
peer.Identifier, iface.Identifier, err)
}
default: // update peer
err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
err := m.wg.GetController(iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, &peer)
return pp, nil
@@ -284,7 +288,7 @@ func (m Manager) RestoreInterfaceState(
}
// remove non-wgportal peers
physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier)
physicalPeers, _ := m.wg.GetController(iface).GetPeers(ctx, iface.Identifier)
for _, physicalPeer := range physicalPeers {
isWgPortalPeer := false
for _, peer := range peers {
@@ -294,7 +298,8 @@ func (m Manager) RestoreInterfaceState(
}
}
if !isWgPortalPeer {
err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey))
err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier,
domain.PeerIdentifier(physicalPeer.PublicKey))
if err != nil {
return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w",
physicalPeer.PublicKey, iface.Identifier, err)
@@ -459,7 +464,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
existingInterface.Disabled = &now // simulate a disabled interface
existingInterface.DisabledReason = domain.DisabledReasonDeleted
physicalInterface, _ := m.wg.GetInterface(ctx, id)
physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id)
if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
return fmt.Errorf("pre-delete hooks failed: %w", err)
@@ -473,7 +478,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("peer deletion failure: %w", err)
}
if err := m.wg.DeleteInterface(ctx, id); err != nil {
if err := m.wg.GetController(*existingInterface).DeleteInterface(ctx, id); err != nil {
return fmt.Errorf("wireguard deletion failure: %w", err)
}
@@ -522,7 +527,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
iface.CopyCalculatedAttributes(i)
err := m.wg.SaveInterface(ctx, iface.Identifier,
err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
domain.MergeToPhysicalInterface(pi, iface)
return pi, nil
@@ -538,7 +543,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
}
if iface.IsDisabled() {
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier)
fwMark := iface.FirewallMark
if physicalInterface != nil && fwMark == 0 {
fwMark = physicalInterface.FirewallMark
@@ -556,13 +561,13 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
}
// If the interface has just been enabled, restore its peers on the physical controller
if !oldEnabled && newEnabled {
if !oldEnabled && newEnabled && iface.Backend == config.LocalBackendName {
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
for _, peer := range peers {
saveErr := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
saveErr := m.wg.GetController(*iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, &peer)
return pp, nil
@@ -766,7 +771,12 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
return
}
func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterface, peers []domain.PhysicalPeer) error {
func (m Manager) importInterface(
ctx context.Context,
backend InterfaceController,
in *domain.PhysicalInterface,
peers []domain.PhysicalPeer,
) error {
now := time.Now()
iface := domain.ConvertPhysicalInterface(in)
iface.BaseModel = domain.BaseModel{
@@ -775,8 +785,20 @@ func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterfa
CreatedAt: now,
UpdatedAt: now,
}
iface.Backend = backend.GetId()
iface.PeerDefAllowedIPsStr = iface.AddressStr()
// try to predict the interface type based on the number of peers
switch len(peers) {
case 0:
iface.Type = domain.InterfaceTypeAny // no peers means this is an unknown interface
case 1:
iface.Type = domain.InterfaceTypeClient // one peer means this is a client interface
default: // multiple peers means this is a server interface
iface.Type = domain.InterfaceTypeServer
}
existingInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return err
@@ -827,16 +849,20 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
peer.Interface.PreDown = domain.NewConfigOption(in.PeerDefPreDown, true)
peer.Interface.PostDown = domain.NewConfigOption(in.PeerDefPostDown, true)
var displayName string
switch in.Type {
case domain.InterfaceTypeAny:
peer.Interface.Type = domain.InterfaceTypeAny
peer.DisplayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
displayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
case domain.InterfaceTypeClient:
peer.Interface.Type = domain.InterfaceTypeServer
peer.DisplayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
displayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
case domain.InterfaceTypeServer:
peer.Interface.Type = domain.InterfaceTypeClient
peer.DisplayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
displayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
}
if peer.DisplayName == "" {
peer.DisplayName = displayName // use auto-generated display name if not set
}
err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) {
@@ -850,12 +876,12 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
}
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
allPeers, err := m.db.GetInterfacePeers(ctx, id)
iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil {
return err
}
for _, peer := range allPeers {
err = m.wg.DeletePeer(ctx, id, peer.Identifier)
err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
}

View File

@@ -371,7 +371,12 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return fmt.Errorf("delete not allowed: %w", err)
}
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
}
err = m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, id)
if err != nil {
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
}
@@ -433,35 +438,28 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
interfaces := make(map[domain.InterfaceIdentifier]struct{})
for i := range peers {
peer := peers[i]
var err error
if peer.IsDisabled() || peer.IsExpired() {
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil {
return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
} else {
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
for _, peer := range peers {
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
}
// Always save the peer to the backend, regardless of disabled/expired state
// The backend will handle the disabled state appropriately
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
if err != nil {
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
}