From 1fc7e352abae6a5413f15b79e2e46af94f979243 Mon Sep 17 00:00:00 2001 From: Christoph Haas Date: Mon, 6 Oct 2025 22:17:39 +0200 Subject: [PATCH] mikrotik: allow to set DNS, wip: handle routes in wg-controller --- cmd/wg-portal/main.go | 6 +- docs/documentation/configuration/examples.md | 3 + docs/documentation/configuration/overview.md | 6 + internal/adapters/wgcontroller/local.go | 228 +-------- internal/adapters/wgcontroller/mikrotik.go | 127 ++++- internal/adapters/wgquick.go | 113 ---- internal/app/route/routes.go | 481 +++--------------- internal/app/wireguard/controller_manager.go | 31 +- internal/app/wireguard/wireguard.go | 17 +- .../app/wireguard/wireguard_interfaces.go | 111 ++-- internal/app/wireguard/wireguard_peers.go | 43 +- internal/config/backend.go | 1 + internal/config/config.go | 3 + internal/domain/interface.go | 8 +- internal/domain/interface_controller.go | 27 + internal/domain/ip.go | 19 + internal/lowlevel/mikrotik.go | 1 + 17 files changed, 394 insertions(+), 831 deletions(-) delete mode 100644 internal/adapters/wgquick.go create mode 100644 internal/domain/interface_controller.go diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index 97f0b67..cd5fd7e 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -53,8 +53,6 @@ func main() { wireGuard, err := wireguard.NewControllerManager(cfg) internal.AssertNoError(err) - wgQuick := adapters.NewWgQuickRepo() - mailer := adapters.NewSmtpMailRepo(cfg.Mail) metricsServer := adapters.NewMetricsServer(cfg) @@ -93,7 +91,7 @@ func main() { webAuthn, err := auth.NewWebAuthnAuthenticator(cfg, eventBus, userManager) internal.AssertNoError(err) - wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database) + wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, database) internal.AssertNoError(err) wireGuardManager.StartBackgroundJobs(ctx) @@ -107,7 +105,7 @@ func main() { mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database) internal.AssertNoError(err) - routeManager, err := route.NewRouteManager(cfg, eventBus, database) + routeManager, err := route.NewRouteManager(cfg, eventBus, database, wireGuard) internal.AssertNoError(err) routeManager.StartBackgroundJobs(ctx) diff --git a/docs/documentation/configuration/examples.md b/docs/documentation/configuration/examples.md index 3909d48..4cf1e56 100644 --- a/docs/documentation/configuration/examples.md +++ b/docs/documentation/configuration/examples.md @@ -15,6 +15,9 @@ backend: # default backend decides where new interfaces are created default: mikrotik + # A prefix for resolvconf. Usually it is "tun.". If you are using systemd, the prefix should be empty. + local_resolvconf_prefix: "tun." + mikrotik: - id: mikrotik # unique id, not "local" display_name: RouterOS RB5009 # optional nice name diff --git a/docs/documentation/configuration/overview.md b/docs/documentation/configuration/overview.md index 0345338..f5d605e 100644 --- a/docs/documentation/configuration/overview.md +++ b/docs/documentation/configuration/overview.md @@ -28,6 +28,7 @@ core: backend: default: local + local_resolvconf_prefix: tun. advanced: log_level: info @@ -184,6 +185,11 @@ The current MikroTik backend is in **BETA** and may not support all features. - **Description:** The default backend to use for managing WireGuard interfaces. Valid options are: `local`, or other backend id's configured in the `mikrotik` section. +### `local_resolvconf_prefix` +- **Default:** `tun.` +- **Description:** Interface name prefix for WireGuard interfaces on the local system which is used to configure DNS servers with *resolvconf*. + It depends on the *resolvconf* implementation you are using, most use a prefix of `tun.`, but some have an empty prefix (e.g., systemd). + ### `ignored_local_interfaces` - **Default:** *(empty)* - **Description:** A list of interface names to exclude when enumerating local interfaces. diff --git a/internal/adapters/wgcontroller/local.go b/internal/adapters/wgcontroller/local.go index 7f2e7fa..c977a57 100644 --- a/internal/adapters/wgcontroller/local.go +++ b/internal/adapters/wgcontroller/local.go @@ -14,7 +14,6 @@ import ( probing "github.com/prometheus-community/pro-bing" "github.com/vishvananda/netlink" - "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -84,8 +83,8 @@ func NewLocalController(cfg *config.Config) (*LocalController, error) { wg: wg, nl: nl, - shellCmd: "bash", // we only support bash at the moment - resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf + shellCmd: "bash", // we only support bash at the moment + resolvConfIfacePrefix: cfg.Backend.LocalResolvconfPrefix, // WireGuard interfaces have a tun. prefix in resolvconf } return repo, nil @@ -546,7 +545,11 @@ func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id doma // region wg-quick-related -func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { +func (c LocalController) ExecuteInterfaceHook( + _ context.Context, + id domain.InterfaceIdentifier, + hookCmd string, +) error { if hookCmd == "" { return nil } @@ -560,7 +563,7 @@ func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hoo return nil } -func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { +func (c LocalController) SetDNS(_ context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { if dnsStr == "" && dnsSearchStr == "" { return nil } @@ -589,7 +592,7 @@ func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearch return nil } -func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error { +func (c LocalController) UnsetDNS(_ context.Context, id domain.InterfaceIdentifier, _, _ string) error { dnsCommand := "resolvconf -d %resPref%i -f" err := c.exec(dnsCommand, id) @@ -611,7 +614,7 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti if len(stdin) > 0 { b := &bytes.Buffer{} for _, ln := range stdin { - if _, err := fmt.Fprint(b, ln); err != nil { + if _, err := fmt.Fprint(b, ln+"\n"); err != nil { return err } } @@ -619,6 +622,8 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti } out, err := cmd.CombinedOutput() // execute and wait for output if err != nil { + slog.Warn("failed to executed shell command", + "command", commandWithInterfaceName, "stdin", stdin, "output", string(out), "error", err) return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err) } slog.Debug("executed shell command", @@ -631,205 +636,28 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti // region routing-related -func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { - // update fwmark rules - if err := c.setFwMarkRules(rules); err != nil { - return err - } - - // update main rule - if err := c.setMainRule(rules); err != nil { - return err - } - - // cleanup old main rules - if err := c.cleanupMainRule(rules); err != nil { - return err - } - +// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op. +func (c LocalController) SetRoutes( + ctx context.Context, + interfaceId domain.InterfaceIdentifier, + table int, + fwMark uint32, + cidrs []domain.Cidr, +) error { return nil } -func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error { - for _, rule := range rules { - existingRules, err := c.nl.RuleList(int(rule.IpFamily)) - if err != nil { - return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err) - } - - ruleExists := false - for _, existingRule := range existingRules { - if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table { - ruleExists = true - break - } - } - - if ruleExists { - continue // rule already exists, no need to recreate it - } - - // create a missing rule - if err := c.nl.RuleAdd(&netlink.Rule{ - Family: int(rule.IpFamily), - Table: rule.Table, - Mark: rule.FwMark, - Invert: true, - SuppressIfgroup: -1, - SuppressPrefixlen: -1, - Priority: c.getRulePriority(existingRules), - Mask: nil, - Goto: -1, - Flow: -1, - }); err != nil { - return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w", - rule.IpFamily, rule.FwMark, rule.Table, err) - } - } +// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op. +func (c LocalController) RemoveRoutes( + ctx context.Context, + interfaceId domain.InterfaceIdentifier, + table int, + fwMark uint32, + oldCidrs []domain.Cidr, +) error { return nil } -func (c LocalController) getRulePriority(existingRules []netlink.Rule) int { - prio := 32700 // linux main rule has a priority of 32766 - for { - isFresh := true - for _, existingRule := range existingRules { - if existingRule.Priority == prio { - isFresh = false - break - } - } - if isFresh { - break - } else { - prio-- - } - } - return prio -} - -func (c LocalController) setMainRule(rules []domain.RouteRule) error { - var family domain.IpFamily - shouldHaveMainRule := false - for _, rule := range rules { - family = rule.IpFamily - if rule.HasDefault == true { - shouldHaveMainRule = true - break - } - } - if !shouldHaveMainRule { - return nil - } - - existingRules, err := c.nl.RuleList(int(family)) - if err != nil { - return fmt.Errorf("failed to get existing rules for family %s: %w", family, err) - } - - ruleExists := false - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - ruleExists = true - break - } - } - - if ruleExists { - return nil // rule already exists, skip re-creation - } - - if err := c.nl.RuleAdd(&netlink.Rule{ - Family: int(family), - Table: unix.RT_TABLE_MAIN, - SuppressIfgroup: -1, - SuppressPrefixlen: 0, - Priority: c.getMainRulePriority(existingRules), - Mark: 0, - Mask: nil, - Goto: -1, - Flow: -1, - }); err != nil { - return fmt.Errorf("failed to setup rule for main table: %w", err) - } - - return nil -} - -func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int { - priority := c.cfg.Advanced.RulePrioOffset - for { - isFresh := true - for _, existingRule := range existingRules { - if existingRule.Priority == priority { - isFresh = false - break - } - } - if isFresh { - break - } else { - priority++ - } - } - return priority -} - -func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error { - var family domain.IpFamily - for _, rule := range rules { - family = rule.IpFamily - break - } - - existingRules, err := c.nl.RuleList(int(family)) - if err != nil { - return fmt.Errorf("failed to get existing rules for family %s: %w", family, err) - } - - shouldHaveMainRule := false - for _, rule := range rules { - if rule.HasDefault == true { - shouldHaveMainRule = true - break - } - } - - mainRules := 0 - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - mainRules++ - } - } - - removalCount := 0 - if mainRules > 1 { - removalCount = mainRules - 1 // we only want one single rule - } - if !shouldHaveMainRule { - removalCount = mainRules - } - - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - if removalCount > 0 { - existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field - if err := c.nl.RuleDel(&existingRule); err != nil { - return fmt.Errorf("failed to delete main rule: %w", err) - } - removalCount-- - } - } - } - - return nil -} - -func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { - // TODO implement me - panic("implement me") -} - // endregion routing-related // region statistics-related diff --git a/internal/adapters/wgcontroller/mikrotik.go b/internal/adapters/wgcontroller/mikrotik.go index ac98094..004f2f2 100644 --- a/internal/adapters/wgcontroller/mikrotik.go +++ b/internal/adapters/wgcontroller/mikrotik.go @@ -22,8 +22,9 @@ type MikrotikController struct { client *lowlevel.MikrotikApiClient // Add mutexes to prevent race conditions - interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex - peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex + interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex + peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex + coreMutex sync.Mutex // for updating the core configuration such as routing table or DNS settings } func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) { @@ -40,6 +41,7 @@ func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) interfaceMutexes: sync.Map{}, peerMutexes: sync.Map{}, + coreMutex: sync.Mutex{}, }, nil } @@ -763,33 +765,126 @@ func (c *MikrotikController) DeletePeer( // region wg-quick-related -func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { +func (c *MikrotikController) ExecuteInterfaceHook( + _ context.Context, + _ domain.InterfaceIdentifier, + _ string, +) error { // TODO implement me - panic("implement me") + slog.Error("interface hooks are not yet supported for Mikrotik backends, please open an issue on GitHub") + return nil } -func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { - // TODO implement me - panic("implement me") +func (c *MikrotikController) SetDNS( + ctx context.Context, + _ domain.InterfaceIdentifier, + dnsStr, _ string, +) error { + // Lock the interface to prevent concurrent modifications + c.coreMutex.Lock() + defer c.coreMutex.Unlock() + + // check if the server is already configured + wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{ + PropList: []string{"servers"}, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error) + } + + var existingServers []string + existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...) + + newServers := strings.Split(dnsStr, ",") + + mergedServers := slices.Clone(existingServers) + for _, s := range newServers { + if s == "" { + continue + } + if !slices.Contains(mergedServers, s) { + mergedServers = append(mergedServers, s) + } + } + mergedServersStr := strings.Join(mergedServers, ",") + + reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{ + "servers": mergedServersStr, + }) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error) + } + + return nil } -func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error { - // TODO implement me - panic("implement me") +func (c *MikrotikController) UnsetDNS( + ctx context.Context, + _ domain.InterfaceIdentifier, + dnsStr, _ string, +) error { + // Lock the interface to prevent concurrent modifications + c.coreMutex.Lock() + defer c.coreMutex.Unlock() + + // retrieve current DNS settings + wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{ + PropList: []string{"servers"}, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error) + } + + var existingServers []string + existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...) + + oldServers := strings.Split(dnsStr, ",") + + mergedServers := make([]string, 0, len(existingServers)) + for _, s := range existingServers { + if s == "" { + continue + } + if !slices.Contains(oldServers, s) { + mergedServers = append(mergedServers, s) // only keep the servers that are not in the old list + } + } + mergedServersStr := strings.Join(mergedServers, ",") + + reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{ + "servers": mergedServersStr, + }) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error) + } + + return nil } // endregion wg-quick-related // region routing-related -func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { - // TODO implement me - panic("implement me") +// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op. +func (c *MikrotikController) SetRoutes( + ctx context.Context, + interfaceId domain.InterfaceIdentifier, + table int, + fwMark uint32, + cidrs []domain.Cidr, +) error { + return nil } -func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { - // TODO implement me - panic("implement me") +// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op. +func (c *MikrotikController) RemoveRoutes( + ctx context.Context, + interfaceId domain.InterfaceIdentifier, + table int, + fwMark uint32, + oldCidrs []domain.Cidr, +) error { + return nil } // endregion routing-related diff --git a/internal/adapters/wgquick.go b/internal/adapters/wgquick.go deleted file mode 100644 index 992e69a..0000000 --- a/internal/adapters/wgquick.go +++ /dev/null @@ -1,113 +0,0 @@ -package adapters - -import ( - "bytes" - "fmt" - "log/slog" - "os/exec" - "strings" - - "github.com/h44z/wg-portal/internal" - "github.com/h44z/wg-portal/internal/domain" -) - -// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks. -type WgQuickRepo struct { - shellCmd string - resolvConfIfacePrefix string -} - -// NewWgQuickRepo creates a new WgQuickRepo instance. -func NewWgQuickRepo() *WgQuickRepo { - return &WgQuickRepo{ - shellCmd: "bash", - resolvConfIfacePrefix: "tun.", - } -} - -// ExecuteInterfaceHook executes the given hook command. -// The hook command can contain the following placeholders: -// -// %i: the interface identifier. -func (r *WgQuickRepo) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { - if hookCmd == "" { - return nil - } - - slog.Debug("executing interface hook", "interface", id, "hook", hookCmd) - err := r.exec(hookCmd, id) - if err != nil { - return fmt.Errorf("failed to exec hook: %w", err) - } - - return nil -} - -// SetDNS sets the DNS settings for the given interface. It uses resolvconf to set the DNS settings. -func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { - if dnsStr == "" && dnsSearchStr == "" { - return nil - } - - dnsServers := internal.SliceString(dnsStr) - dnsSearchDomains := internal.SliceString(dnsSearchStr) - - dnsCommand := "resolvconf -a %resPref%i -m 0 -x" - dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains)) - - for _, dnsServer := range dnsServers { - dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer)) - } - for _, searchDomain := range dnsSearchDomains { - dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain)) - } - - err := r.exec(dnsCommand, id, dnsCommandInput...) - if err != nil { - return fmt.Errorf( - "failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w", - err, - ) - } - - return nil -} - -// UnsetDNS unsets the DNS settings for the given interface. It uses resolvconf to unset the DNS settings. -func (r *WgQuickRepo) UnsetDNS(id domain.InterfaceIdentifier) error { - dnsCommand := "resolvconf -d %resPref%i -f" - - err := r.exec(dnsCommand, id) - if err != nil { - return fmt.Errorf("failed to unset dns settings: %w", err) - } - - return nil -} - -func (r *WgQuickRepo) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string { - command = strings.ReplaceAll(command, "%resPref", r.resolvConfIfacePrefix) - return strings.ReplaceAll(command, "%i", string(interfaceId)) -} - -func (r *WgQuickRepo) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error { - commandWithInterfaceName := r.replaceCommandPlaceHolders(command, interfaceId) - cmd := exec.Command(r.shellCmd, "-ce", commandWithInterfaceName) - if len(stdin) > 0 { - b := &bytes.Buffer{} - for _, ln := range stdin { - if _, err := fmt.Fprint(b, ln); err != nil { - return err - } - } - cmd.Stdin = b - } - out, err := cmd.CombinedOutput() // execute and wait for output - if err != nil { - return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err) - } - slog.Debug("executed shell command", - "command", commandWithInterfaceName, - "output", string(out)) - return nil -} diff --git a/internal/app/route/routes.go b/internal/app/route/routes.go index c87bcaf..9e52e23 100644 --- a/internal/app/route/routes.go +++ b/internal/app/route/routes.go @@ -5,24 +5,21 @@ import ( "fmt" "log/slog" - "github.com/vishvananda/netlink" - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/wgctrl" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" - "github.com/h44z/wg-portal/internal/lowlevel" ) // region dependencies +type ControllerManager interface { + // GetController returns the controller for the given interface. + GetController(iface domain.Interface) domain.InterfaceController +} + type InterfaceAndPeerDatabaseRepo interface { - // GetAllInterfaces returns all interfaces - GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) - // GetInterfacePeers returns all peers for a given interface - GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) + // GetInterface returns the interface with the given identifier. + GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) } type EventBus interface { @@ -30,6 +27,25 @@ type EventBus interface { Subscribe(topic string, fn interface{}) error } +type RoutesController interface { + // SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op. + SetRoutes( + ctx context.Context, + interfaceId domain.InterfaceIdentifier, + table int, + fwMark uint32, + cidrs []domain.Cidr, + ) error + // RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op. + RemoveRoutes( + ctx context.Context, + interfaceId domain.InterfaceIdentifier, + table int, + fwMark uint32, + oldCidrs []domain.Cidr, + ) error +} + // endregion dependencies type routeRuleInfo struct { @@ -45,28 +61,24 @@ type routeRuleInfo struct { type Manager struct { cfg *config.Config - bus EventBus - wg lowlevel.WireGuardClient - nl lowlevel.NetlinkClient - db InterfaceAndPeerDatabaseRepo + bus EventBus + db InterfaceAndPeerDatabaseRepo + wgController ControllerManager } // NewRouteManager creates a new route manager instance. -func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) { - wg, err := wgctrl.New() - if err != nil { - panic("failed to init wgctrl: " + err.Error()) - } - - nl := &lowlevel.NetlinkManager{} - +func NewRouteManager( + cfg *config.Config, + bus EventBus, + db InterfaceAndPeerDatabaseRepo, + wgController ControllerManager, +) (*Manager, error) { m := &Manager{ cfg: cfg, bus: bus, - db: db, - wg: wg, - nl: nl, + db: db, + wgController: wgController, } m.connectToMessageBus() @@ -85,17 +97,21 @@ func (m Manager) StartBackgroundJobs(_ context.Context) { // this is a no-op for now } -func (m Manager) handleRouteUpdateEvent(srcDescription string) { - slog.Debug("handling route update event", "source", srcDescription) +func (m Manager) handleRouteUpdateEvent(info domain.RoutingTableInfo) { + slog.Debug("handling route update event", "info", info.String()) - err := m.syncRoutes(context.Background()) - if err != nil { - slog.Error("failed to synchronize routes", - "source", srcDescription, - "error", err) + if !info.ManagementEnabled() { + return // route management disabled } - slog.Debug("routes synchronized", "source", srcDescription) + err := m.syncRoutes(context.Background(), info) + if err != nil { + slog.Error("failed to synchronize routes", + "info", info.String(), "error", err) + return + } + + slog.Debug("routes synchronized", "info", info.String()) } func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) { @@ -105,399 +121,40 @@ func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) { return // route management disabled } - if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V4); err != nil { - slog.Error("failed to remove v4 fwmark rules", "error", err) - } - if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V6); err != nil { - slog.Error("failed to remove v6 fwmark rules", "error", err) - } - - slog.Debug("routes removed", "table", info.String()) -} - -func (m Manager) syncRoutes(ctx context.Context) error { - interfaces, err := m.db.GetAllInterfaces(ctx) + err := m.removeRoutes(context.Background(), info) if err != nil { - return fmt.Errorf("failed to find all interfaces: %w", err) + slog.Error("failed to synchronize routes", + "info", info.String(), "error", err) + return } - rules := map[int][]routeRuleInfo{ - netlink.FAMILY_V4: nil, - netlink.FAMILY_V6: nil, - } - for _, iface := range interfaces { - if iface.IsDisabled() { - continue // disabled interface does not need route entries - } - if !iface.ManageRoutingTable() { - continue - } - - peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) - if err != nil { - return fmt.Errorf("failed to find peers for %s: %w", iface.Identifier, err) - } - allowedIPs := iface.GetAllowedIPs(peers) - defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs) - - link, err := m.nl.LinkByName(string(iface.Identifier)) - if err != nil { - return fmt.Errorf("failed to find physical link for %s: %w", iface.Identifier, err) - } - - table, fwmark, err := m.getRoutingTableAndFwMark(&iface, link) - if err != nil { - return fmt.Errorf("failed to get table and fwmark for %s: %w", iface.Identifier, err) - } - - if err := m.setInterfaceRoutes(link, table, allowedIPs); err != nil { - return fmt.Errorf("failed to set routes for %s: %w", iface.Identifier, err) - } - - if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V4, allowedIPs); err != nil { - return fmt.Errorf("failed to remove deprecated v4 routes for %s: %w", iface.Identifier, err) - } - if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V6, allowedIPs); err != nil { - return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err) - } - - if table != 0 { - rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], routeRuleInfo{ - ifaceId: iface.Identifier, - fwMark: fwmark, - table: table, - family: netlink.FAMILY_V4, - hasDefault: defRouteV4, - }) - } - if table != 0 { - rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], routeRuleInfo{ - ifaceId: iface.Identifier, - fwMark: fwmark, - table: table, - family: netlink.FAMILY_V6, - hasDefault: defRouteV6, - }) - } - } - - return m.syncRouteRules(rules) + slog.Debug("routes removed", "info", info.String()) } -func (m Manager) syncRouteRules(allRules map[int][]routeRuleInfo) error { - for family, rules := range allRules { - // update fwmark rules - if err := m.setFwMarkRules(rules, family); err != nil { - return err - } - - // update main rule - if err := m.setMainRule(rules, family); err != nil { - return err - } - - // cleanup old main rules - if err := m.cleanupMainRule(rules, family); err != nil { - return err - } - } - - return nil -} - -func (m Manager) setFwMarkRules(rules []routeRuleInfo, family int) error { - for _, rule := range rules { - existingRules, err := m.nl.RuleList(family) - if err != nil { - return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) - } - - ruleExists := false - for _, existingRule := range existingRules { - if rule.fwMark == existingRule.Mark && rule.table == existingRule.Table { - ruleExists = true - break - } - } - - if ruleExists { - continue // rule already exists, no need to recreate it - } - - // create missing rule - if err := m.nl.RuleAdd(&netlink.Rule{ - Family: family, - Table: rule.table, - Mark: rule.fwMark, - Invert: true, - SuppressIfgroup: -1, - SuppressPrefixlen: -1, - Priority: m.getRulePriority(existingRules), - Mask: nil, - Goto: -1, - Flow: -1, - }); err != nil { - return fmt.Errorf("failed to setup rule for fwmark %d and table %d: %w", rule.fwMark, rule.table, err) - } - } - return nil -} - -func (m Manager) removeFwMarkRules(fwmark uint32, table int, family int) error { - existingRules, err := m.nl.RuleList(family) - if err != nil { - return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) - } - - for _, existingRule := range existingRules { - if fwmark == existingRule.Mark && table == existingRule.Table { - existingRule.Family = family // set family, somehow the RuleList method does not populate the family field - if err := m.nl.RuleDel(&existingRule); err != nil { - return fmt.Errorf("failed to delete fwmark rule: %w", err) - } - } - } - return nil -} - -func (m Manager) setMainRule(rules []routeRuleInfo, family int) error { - shouldHaveMainRule := false - for _, rule := range rules { - if rule.hasDefault == true { - shouldHaveMainRule = true - break - } - } - if !shouldHaveMainRule { +func (m Manager) syncRoutes(ctx context.Context, info domain.RoutingTableInfo) error { + rc, ok := m.wgController.GetController(info.Interface).(RoutesController) + if !ok { + slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier) return nil } - existingRules, err := m.nl.RuleList(family) + err := rc.SetRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps) if err != nil { - return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) + return fmt.Errorf("failed to set routes for interface %s: %w", info.Interface.Identifier, err) } - - ruleExists := false - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - ruleExists = true - break - } - } - - if ruleExists { - return nil // rule already exists, skip re-creation - } - - if err := m.nl.RuleAdd(&netlink.Rule{ - Family: family, - Table: unix.RT_TABLE_MAIN, - SuppressIfgroup: -1, - SuppressPrefixlen: 0, - Priority: m.getMainRulePriority(existingRules), - Mark: 0, - Mask: nil, - Goto: -1, - Flow: -1, - }); err != nil { - return fmt.Errorf("failed to setup rule for main table: %w", err) - } - return nil } -func (m Manager) cleanupMainRule(rules []routeRuleInfo, family int) error { - existingRules, err := m.nl.RuleList(family) +func (m Manager) removeRoutes(ctx context.Context, info domain.RoutingTableInfo) error { + rc, ok := m.wgController.GetController(info.Interface).(RoutesController) + if !ok { + slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier) + return nil + } + + err := rc.RemoveRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps) if err != nil { - return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) - } - - shouldHaveMainRule := false - for _, rule := range rules { - if rule.hasDefault == true { - shouldHaveMainRule = true - break - } - } - - mainRules := 0 - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - mainRules++ - } - } - - removalCount := 0 - if mainRules > 1 { - removalCount = mainRules - 1 // we only want one single rule - } - if !shouldHaveMainRule { - removalCount = mainRules - } - - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - if removalCount > 0 { - existingRule.Family = family // set family, somehow the RuleList method does not populate the family field - if err := m.nl.RuleDel(&existingRule); err != nil { - return fmt.Errorf("failed to delete main rule: %w", err) - } - removalCount-- - } - } - } - - return nil -} - -func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int { - prio := m.cfg.Advanced.RulePrioOffset - for { - isFresh := true - for _, existingRule := range existingRules { - if existingRule.Priority == prio { - isFresh = false - break - } - } - if isFresh { - break - } else { - prio++ - } - } - return prio -} - -func (m Manager) getRulePriority(existingRules []netlink.Rule) int { - prio := 32700 // linux main rule has a prio of 32766 - for { - isFresh := true - for _, existingRule := range existingRules { - if existingRule.Priority == prio { - isFresh = false - break - } - } - if isFresh { - break - } else { - prio-- - } - } - return prio -} - -func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error { - for _, allowedIP := range allowedIPs { - err := m.nl.RouteReplace(&netlink.Route{ - LinkIndex: link.Attrs().Index, - Dst: allowedIP.IpNet(), - Table: table, - Scope: unix.RT_SCOPE_LINK, - Type: unix.RTN_UNICAST, - }) - if err != nil { - return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err) - } - } - - return nil -} - -func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIPs []domain.Cidr) error { - rawRoutes, err := m.nl.RouteListFiltered(family, &netlink.Route{ - LinkIndex: link.Attrs().Index, - Table: unix.RT_TABLE_UNSPEC, // all tables - Scope: unix.RT_SCOPE_LINK, - Type: unix.RTN_UNICAST, - }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF) - if err != nil { - return fmt.Errorf("failed to fetch raw routes: %w", err) - } - for _, rawRoute := range rawRoutes { - if rawRoute.Dst == nil { // handle default route - var netlinkAddr domain.Cidr - if family == netlink.FAMILY_V4 { - netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0") - } else { - netlinkAddr, _ = domain.CidrFromString("::/0") - } - rawRoute.Dst = netlinkAddr.IpNet() - } - - netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst) - remove := true - for _, allowedIP := range allowedIPs { - if netlinkAddr == allowedIP { - remove = false - break - } - } - - if !remove { - continue - } - - err := m.nl.RouteDel(&rawRoute) - if err != nil { - return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err) - } + return fmt.Errorf("failed to remove routes for interface %s: %w", info.Interface.Identifier, err) } return nil } - -func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, link netlink.Link) ( - table int, - fwmark uint32, - err error, -) { - table = iface.GetRoutingTable() - fwmark = iface.FirewallMark - - if fwmark == 0 { - // generate a new (temporary) firewall mark based on the interface index - fwmark = uint32(m.cfg.Advanced.RouteTableOffset + link.Attrs().Index) - slog.Debug("using fwmark to handle routes", - "interface", iface.Identifier, - "fwmark", fwmark) - - // apply the temporary fwmark to the wireguard interface - err = m.setFwMark(iface.Identifier, int(fwmark)) - } - if table == 0 { - table = int(fwmark) // generate a new routing table base on interface index - slog.Debug("using routing table to handle default routes", - "interface", iface.Identifier, - "table", table) - } - return -} - -func (m Manager) setFwMark(id domain.InterfaceIdentifier, fwmark int) error { - err := m.wg.ConfigureDevice(string(id), wgtypes.Config{ - FirewallMark: &fwmark, - }) - if err != nil { - return fmt.Errorf("failed to update fwmark to: %d: %w", fwmark, err) - } - return nil -} - -func (m Manager) containsDefaultRoute(allowedIPs []domain.Cidr) (ipV4, ipV6 bool) { - for _, allowedIP := range allowedIPs { - if ipV4 && ipV6 { - break // speed up - } - - if allowedIP.Prefix().Bits() == 0 { - if allowedIP.IsV4() { - ipV4 = true - } else { - ipV6 = true - } - } - } - - return -} diff --git a/internal/app/wireguard/controller_manager.go b/internal/app/wireguard/controller_manager.go index 2eea6af..0f6bd23 100644 --- a/internal/app/wireguard/controller_manager.go +++ b/internal/app/wireguard/controller_manager.go @@ -1,7 +1,6 @@ package wireguard import ( - "context" "fmt" "log/slog" "maps" @@ -12,33 +11,9 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) -type InterfaceController interface { - GetId() domain.InterfaceBackend - GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) - GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) - GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) - SaveInterface( - _ context.Context, - id domain.InterfaceIdentifier, - updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), - ) error - DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error - SavePeer( - _ context.Context, - deviceId domain.InterfaceIdentifier, - id domain.PeerIdentifier, - updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), - ) error - DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error - PingAddresses( - ctx context.Context, - addr string, - ) (*domain.PingerResult, error) -} - type backendInstance struct { Config config.BackendBase // Config is the configuration for the backend instance. - Implementation InterfaceController + Implementation domain.InterfaceController } type ControllerManager struct { @@ -118,11 +93,11 @@ func (c *ControllerManager) logRegisteredControllers() { } } -func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController { +func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) domain.InterfaceController { return c.getController(backend, "").Implementation } -func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController { +func (c *ControllerManager) GetController(iface domain.Interface) domain.InterfaceController { return c.getController(iface.Backend, iface.Identifier).Implementation } diff --git a/internal/app/wireguard/wireguard.go b/internal/app/wireguard/wireguard.go index b28f70e..e1e9dfa 100644 --- a/internal/app/wireguard/wireguard.go +++ b/internal/app/wireguard/wireguard.go @@ -38,9 +38,9 @@ type InterfaceAndPeerDatabaseRepo interface { } type WgQuickController interface { - ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error - SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error - UnsetDNS(id domain.InterfaceIdentifier) error + ExecuteInterfaceHook(ctx context.Context, id domain.InterfaceIdentifier, hookCmd string) error + SetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error + UnsetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error } type EventBus interface { @@ -53,11 +53,10 @@ type EventBus interface { // endregion dependencies type Manager struct { - cfg *config.Config - bus EventBus - db InterfaceAndPeerDatabaseRepo - wg *ControllerManager - quick WgQuickController + cfg *config.Config + bus EventBus + db InterfaceAndPeerDatabaseRepo + wg *ControllerManager userLockMap *sync.Map } @@ -66,7 +65,6 @@ func NewWireGuardManager( cfg *config.Config, bus EventBus, wg *ControllerManager, - quick WgQuickController, db InterfaceAndPeerDatabaseRepo, ) (*Manager, error) { m := &Manager{ @@ -74,7 +72,6 @@ func NewWireGuardManager( bus: bus, wg: wg, db: db, - quick: quick, userLockMap: &sync.Map{}, } diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index 368d1eb..e40c164 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -453,7 +453,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif return err } - existingInterface, err := m.db.GetInterface(ctx, id) + existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, id) if err != nil { return fmt.Errorf("unable to find interface %s: %w", id, err) } @@ -468,15 +468,16 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id) - if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil { + if err := m.handleInterfacePreSaveHooks(ctx, existingInterface, !existingInterface.IsDisabled(), + false); err != nil { return fmt.Errorf("pre-delete hooks failed: %w", err) } - if err := m.handleInterfacePreSaveActions(existingInterface); err != nil { + if err := m.handleInterfacePreSaveActions(ctx, existingInterface); err != nil { return fmt.Errorf("pre-delete actions failed: %w", err) } - if err := m.deleteInterfacePeers(ctx, id); err != nil { + if err := m.deleteInterfacePeers(ctx, existingInterface, existingPeers); err != nil { return fmt.Errorf("peer deletion failure: %w", err) } @@ -493,11 +494,18 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif fwMark = physicalInterface.FirewallMark } m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ - FwMark: fwMark, - Table: existingInterface.GetRoutingTable(), + Interface: *existingInterface, + AllowedIps: existingInterface.GetAllowedIPs(existingPeers), + FwMark: fwMark, + Table: existingInterface.GetRoutingTable(), }) - if err := m.handleInterfacePostSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil { + if err := m.handleInterfacePostSaveHooks( + ctx, + existingInterface, + !existingInterface.IsDisabled(), + false, + ); err != nil { return fmt.Errorf("post-delete hooks failed: %w", err) } @@ -518,11 +526,11 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) ( oldEnabled, newEnabled := m.getInterfaceStateHistory(ctx, iface) - if err := m.handleInterfacePreSaveHooks(iface, oldEnabled, newEnabled); err != nil { + if err := m.handleInterfacePreSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil { return nil, fmt.Errorf("pre-save hooks failed: %w", err) } - if err := m.handleInterfacePreSaveActions(iface); err != nil { + if err := m.handleInterfacePreSaveActions(ctx, iface); err != nil { return nil, fmt.Errorf("pre-save actions failed: %w", err) } @@ -575,14 +583,21 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) ( fwMark = physicalInterface.FirewallMark } m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ - FwMark: fwMark, - Table: iface.GetRoutingTable(), + Interface: *iface, + AllowedIps: iface.GetAllowedIPs(peers), + FwMark: fwMark, + Table: iface.GetRoutingTable(), }) } else { - m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier)) + m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{ + Interface: *iface, + AllowedIps: iface.GetAllowedIPs(peers), + FwMark: iface.FirewallMark, + Table: iface.GetRoutingTable(), + }) } - if err := m.handleInterfacePostSaveHooks(iface, oldEnabled, newEnabled); err != nil { + if err := m.handleInterfacePostSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil { return nil, fmt.Errorf("post-save hooks failed: %w", err) } @@ -627,51 +642,83 @@ func (m Manager) getInterfaceStateHistory(ctx context.Context, iface *domain.Int return !oldInterface.IsDisabled(), !iface.IsDisabled() } -func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error { - if !iface.IsDisabled() { - if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil { - return fmt.Errorf("failed to update dns settings: %w", err) - } - } else { - if err := m.quick.UnsetDNS(iface.Identifier); err != nil { - return fmt.Errorf("failed to clear dns settings: %w", err) +func (m Manager) handleInterfacePreSaveActions(ctx context.Context, iface *domain.Interface) error { + wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController) + if !ok { + slog.Warn("failed to perform pre-save actions", "interface", iface.Identifier, + "error", "no capable controller found") + return nil + } + + // update DNS settings only for client interfaces + if iface.Type == domain.InterfaceTypeClient || iface.Type == domain.InterfaceTypeAny { + if !iface.IsDisabled() { + if err := wgQuickController.SetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil { + return fmt.Errorf("failed to update dns settings: %w", err) + } + } else { + if err := wgQuickController.UnsetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil { + return fmt.Errorf("failed to clear dns settings: %w", err) + } } } return nil } -func (m Manager) handleInterfacePreSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error { +func (m Manager) handleInterfacePreSaveHooks( + ctx context.Context, + iface *domain.Interface, + oldEnabled, newEnabled bool, +) error { if oldEnabled == newEnabled { return nil // do nothing if state did not change } slog.Debug("executing pre-save hooks", "interface", iface.Identifier, "up", newEnabled) + wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController) + if !ok { + slog.Warn("failed to execute pre-save hooks", "interface", iface.Identifier, "up", newEnabled, + "error", "no capable controller found") + return nil + } + if newEnabled { - if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil { + if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreUp); err != nil { return fmt.Errorf("failed to execute pre-up hook: %w", err) } } else { - if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil { + if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreDown); err != nil { return fmt.Errorf("failed to execute pre-down hook: %w", err) } } return nil } -func (m Manager) handleInterfacePostSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error { +func (m Manager) handleInterfacePostSaveHooks( + ctx context.Context, + iface *domain.Interface, + oldEnabled, newEnabled bool, +) error { if oldEnabled == newEnabled { return nil // do nothing if state did not change } slog.Debug("executing post-save hooks", "interface", iface.Identifier, "up", newEnabled) + wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController) + if !ok { + slog.Warn("failed to execute post-save hooks", "interface", iface.Identifier, "up", newEnabled, + "error", "no capable controller found") + return nil + } + if newEnabled { - if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil { + if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostUp); err != nil { return fmt.Errorf("failed to execute post-up hook: %w", err) } } else { - if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil { + if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostDown); err != nil { return fmt.Errorf("failed to execute post-down hook: %w", err) } } @@ -799,7 +846,7 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) { func (m Manager) importInterface( ctx context.Context, - backend InterfaceController, + backend domain.InterfaceController, in *domain.PhysicalInterface, peers []domain.PhysicalPeer, ) error { @@ -901,13 +948,9 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain return nil } -func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error { - iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id) - if err != nil { - return err - } +func (m Manager) deleteInterfacePeers(ctx context.Context, iface *domain.Interface, allPeers []domain.Peer) error { for _, peer := range allPeers { - err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier) + err := m.wg.GetController(*iface).DeletePeer(ctx, iface.Identifier, peer.Identifier) if err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err) } diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index f3c5364..bde31ce 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -388,9 +388,19 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error return fmt.Errorf("failed to delete peer %s: %w", id, err) } + peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) + if err != nil { + return fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err) + } + m.bus.Publish(app.TopicPeerDeleted, *peer) // Update routes after peers have changed - m.bus.Publish(app.TopicRouteUpdate, "peers updated") + m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{ + Interface: *iface, + AllowedIps: iface.GetAllowedIPs(peers), + FwMark: iface.FirewallMark, + Table: iface.GetRoutingTable(), + }) // Update interface after peers have changed m.bus.Publish(app.TopicPeerInterfaceUpdated, peer.InterfaceIdentifier) @@ -438,20 +448,28 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) // region helper-functions func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { - interfaces := make(map[domain.InterfaceIdentifier]struct{}) + interfaces := make(map[domain.InterfaceIdentifier]domain.Interface) + interfacePeers := make(map[domain.InterfaceIdentifier][]domain.Peer) for _, peer := range peers { - iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier) - if err != nil { - return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err) + // get interface from db if it is not yet in the map + if _, ok := interfaces[peer.InterfaceIdentifier]; !ok { + iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier) + if err != nil { + return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err) + } + interfaces[peer.InterfaceIdentifier] = *iface } + iface := interfaces[peer.InterfaceIdentifier] + interfacePeers[iface.Identifier] = append(interfacePeers[iface.Identifier], *peer) + // Always save the peer to the backend, regardless of disabled/expired state // The backend will handle the disabled state appropriately - err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { + err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { peer.CopyCalculatedAttributes(p) - err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier, + err := m.wg.GetController(iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier, func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { domain.MergeToPhysicalPeer(pp, peer) return pp, nil @@ -475,13 +493,16 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { Peer: *peer, }, }) - - interfaces[peer.InterfaceIdentifier] = struct{}{} } // Update routes after peers have changed - if len(interfaces) != 0 { - m.bus.Publish(app.TopicRouteUpdate, "peers updated") + for id, iface := range interfaces { + m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{ + Interface: iface, + AllowedIps: iface.GetAllowedIPs(interfacePeers[id]), + FwMark: iface.FirewallMark, + Table: iface.GetRoutingTable(), + }) } for iface := range interfaces { diff --git a/internal/config/backend.go b/internal/config/backend.go index fa8ff2e..cee7c8b 100644 --- a/internal/config/backend.go +++ b/internal/config/backend.go @@ -13,6 +13,7 @@ type Backend struct { // Local Backend-specific configuration IgnoredLocalInterfaces []string `yaml:"ignored_local_interfaces"` // A list of interface names that should be ignored by this backend (e.g., "wg0") + LocalResolvconfPrefix string `yaml:"local_resolvconf_prefix"` // The prefix to use for interface names when passing them to resolvconf. // External Backend-specific configuration diff --git a/internal/config/config.go b/internal/config/config.go index 4203099..338dbf6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -134,6 +134,9 @@ func defaultConfig() *Config { cfg.Backend = Backend{ Default: LocalBackendName, // local backend is the default (using wgcrtl) + // Most resolconf implementations use "tun." as a prefix for interface names. + // But systemd's implementation uses no prefix, for example. + LocalResolvconfPrefix: "tun.", } cfg.Web = WebConfig{ diff --git a/internal/domain/interface.go b/internal/domain/interface.go index 32fc1c0..1ea4305 100644 --- a/internal/domain/interface.go +++ b/internal/domain/interface.go @@ -308,12 +308,14 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) { } type RoutingTableInfo struct { - FwMark uint32 - Table int + Interface Interface + AllowedIps []Cidr + FwMark uint32 + Table int } func (r RoutingTableInfo) String() string { - return fmt.Sprintf("%d -> %d", r.FwMark, r.Table) + return fmt.Sprintf("%s: %d -> %d", r.Interface.Identifier, r.FwMark, r.Table) } func (r RoutingTableInfo) ManagementEnabled() bool { diff --git a/internal/domain/interface_controller.go b/internal/domain/interface_controller.go new file mode 100644 index 0000000..efd3207 --- /dev/null +++ b/internal/domain/interface_controller.go @@ -0,0 +1,27 @@ +package domain + +import "context" + +type InterfaceController interface { + GetId() InterfaceBackend + GetInterfaces(_ context.Context) ([]PhysicalInterface, error) + GetInterface(_ context.Context, id InterfaceIdentifier) (*PhysicalInterface, error) + GetPeers(_ context.Context, deviceId InterfaceIdentifier) ([]PhysicalPeer, error) + SaveInterface( + _ context.Context, + id InterfaceIdentifier, + updateFunc func(pi *PhysicalInterface) (*PhysicalInterface, error), + ) error + DeleteInterface(_ context.Context, id InterfaceIdentifier) error + SavePeer( + _ context.Context, + deviceId InterfaceIdentifier, + id PeerIdentifier, + updateFunc func(pp *PhysicalPeer) (*PhysicalPeer, error), + ) error + DeletePeer(_ context.Context, deviceId InterfaceIdentifier, id PeerIdentifier) error + PingAddresses( + ctx context.Context, + addr string, + ) (*PingerResult, error) +} diff --git a/internal/domain/ip.go b/internal/domain/ip.go index ee67413..acdaf02 100644 --- a/internal/domain/ip.go +++ b/internal/domain/ip.go @@ -199,3 +199,22 @@ func (c Cidr) Contains(other Cidr) bool { return subnet.Contains(otherIP) } + +// ContainsDefaultRoute returns true if the given CIDRs contain a default route. +func ContainsDefaultRoute(cidrs []Cidr) (ipV4, ipV6 bool) { + for _, allowedIP := range cidrs { + if ipV4 && ipV6 { + break // speed up + } + + if allowedIP.Prefix().Bits() == 0 { + if allowedIP.IsV4() { + ipV4 = true + } else { + ipV6 = true + } + } + } + + return +} diff --git a/internal/lowlevel/mikrotik.go b/internal/lowlevel/mikrotik.go index 49ef1d7..0c86a13 100644 --- a/internal/lowlevel/mikrotik.go +++ b/internal/lowlevel/mikrotik.go @@ -267,6 +267,7 @@ func parseHttpResponse[T any](resp *http.Response, err error) MikrotikApiRespons } defer func(Body io.ReadCloser) { + _, _ = io.Copy(io.Discard, Body) // ensure to empty the body err := Body.Close() if err != nil { slog.Error("failed to close response body", "error", err)