From cdf3a49801d956e47c97972f490ec38d6848746b Mon Sep 17 00:00:00 2001 From: h44z Date: Sun, 12 Oct 2025 14:31:19 +0200 Subject: [PATCH] Cleanup route handling (#542) * mikrotik: allow to set DNS, wip: handle routes in wg-controller * replace old route handling for local controller * cleanup route handling for local backend * implement route handling for mikrotik controller --- cmd/wg-portal/main.go | 6 +- docs/documentation/configuration/examples.md | 3 + docs/documentation/configuration/overview.md | 6 + .../src/components/InterfaceEditModal.vue | 9 +- internal/adapters/wgcontroller/local.go | 456 +++++++++++----- internal/adapters/wgcontroller/mikrotik.go | 408 ++++++++++++++- internal/adapters/wgquick.go | 113 ---- internal/app/route/routes.go | 491 +++--------------- internal/app/wireguard/controller_manager.go | 31 +- internal/app/wireguard/wireguard.go | 17 +- .../app/wireguard/wireguard_interfaces.go | 163 ++++-- internal/app/wireguard/wireguard_peers.go | 48 +- internal/config/backend.go | 1 + internal/config/config.go | 3 + internal/domain/interface.go | 45 +- internal/domain/interface_controller.go | 27 + internal/domain/interface_test.go | 64 ++- internal/domain/ip.go | 27 + internal/lowlevel/mikrotik.go | 1 + 19 files changed, 1116 insertions(+), 803 deletions(-) delete mode 100644 internal/adapters/wgquick.go create mode 100644 internal/domain/interface_controller.go diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index 97f0b67..cd5fd7e 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -53,8 +53,6 @@ func main() { wireGuard, err := wireguard.NewControllerManager(cfg) internal.AssertNoError(err) - wgQuick := adapters.NewWgQuickRepo() - mailer := adapters.NewSmtpMailRepo(cfg.Mail) metricsServer := adapters.NewMetricsServer(cfg) @@ -93,7 +91,7 @@ func main() { webAuthn, err := auth.NewWebAuthnAuthenticator(cfg, eventBus, userManager) internal.AssertNoError(err) - wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database) + wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, database) internal.AssertNoError(err) wireGuardManager.StartBackgroundJobs(ctx) @@ -107,7 +105,7 @@ func main() { mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database) internal.AssertNoError(err) - routeManager, err := route.NewRouteManager(cfg, eventBus, database) + routeManager, err := route.NewRouteManager(cfg, eventBus, database, wireGuard) internal.AssertNoError(err) routeManager.StartBackgroundJobs(ctx) diff --git a/docs/documentation/configuration/examples.md b/docs/documentation/configuration/examples.md index 3909d48..4cf1e56 100644 --- a/docs/documentation/configuration/examples.md +++ b/docs/documentation/configuration/examples.md @@ -15,6 +15,9 @@ backend: # default backend decides where new interfaces are created default: mikrotik + # A prefix for resolvconf. Usually it is "tun.". If you are using systemd, the prefix should be empty. + local_resolvconf_prefix: "tun." + mikrotik: - id: mikrotik # unique id, not "local" display_name: RouterOS RB5009 # optional nice name diff --git a/docs/documentation/configuration/overview.md b/docs/documentation/configuration/overview.md index 44ce206..4f1f49a 100644 --- a/docs/documentation/configuration/overview.md +++ b/docs/documentation/configuration/overview.md @@ -28,6 +28,7 @@ core: backend: default: local + local_resolvconf_prefix: tun. advanced: log_level: info @@ -185,6 +186,11 @@ The current MikroTik backend is in **BETA** and may not support all features. - **Description:** The default backend to use for managing WireGuard interfaces. Valid options are: `local`, or other backend id's configured in the `mikrotik` section. +### `local_resolvconf_prefix` +- **Default:** `tun.` +- **Description:** Interface name prefix for WireGuard interfaces on the local system which is used to configure DNS servers with *resolvconf*. + It depends on the *resolvconf* implementation you are using, most use a prefix of `tun.`, but some have an empty prefix (e.g., systemd). + ### `ignored_local_interfaces` - **Default:** *(empty)* - **Description:** A list of interface names to exclude when enumerating local interfaces. diff --git a/frontend/src/components/InterfaceEditModal.vue b/frontend/src/components/InterfaceEditModal.vue index 6782b9c..2f94433 100644 --- a/frontend/src/components/InterfaceEditModal.vue +++ b/frontend/src/components/InterfaceEditModal.vue @@ -444,6 +444,11 @@ async function del() { +
+ + + {{ $t('modals.interface-edit.routing-table.description') }} +
@@ -457,7 +462,7 @@ async function del() { -
+
{{ $t('modals.interface-edit.header-hooks') }}
@@ -482,7 +487,7 @@ async function del() {
-
+
diff --git a/internal/adapters/wgcontroller/local.go b/internal/adapters/wgcontroller/local.go index 7f2e7fa..d91a5f6 100644 --- a/internal/adapters/wgcontroller/local.go +++ b/internal/adapters/wgcontroller/local.go @@ -9,6 +9,7 @@ import ( "log/slog" "os" "os/exec" + "slices" "strings" "time" @@ -84,8 +85,8 @@ func NewLocalController(cfg *config.Config) (*LocalController, error) { wg: wg, nl: nl, - shellCmd: "bash", // we only support bash at the moment - resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf + shellCmd: "bash", // we only support bash at the moment + resolvConfIfacePrefix: cfg.Backend.LocalResolvconfPrefix, // WireGuard interfaces have a tun. prefix in resolvconf } return repo, nil @@ -546,7 +547,11 @@ func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id doma // region wg-quick-related -func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { +func (c LocalController) ExecuteInterfaceHook( + _ context.Context, + id domain.InterfaceIdentifier, + hookCmd string, +) error { if hookCmd == "" { return nil } @@ -560,7 +565,7 @@ func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hoo return nil } -func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { +func (c LocalController) SetDNS(_ context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { if dnsStr == "" && dnsSearchStr == "" { return nil } @@ -589,7 +594,7 @@ func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearch return nil } -func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error { +func (c LocalController) UnsetDNS(_ context.Context, id domain.InterfaceIdentifier, _, _ string) error { dnsCommand := "resolvconf -d %resPref%i -f" err := c.exec(dnsCommand, id) @@ -611,7 +616,7 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti if len(stdin) > 0 { b := &bytes.Buffer{} for _, ln := range stdin { - if _, err := fmt.Fprint(b, ln); err != nil { + if _, err := fmt.Fprint(b, ln+"\n"); err != nil { return err } } @@ -619,6 +624,8 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti } out, err := cmd.CombinedOutput() // execute and wait for output if err != nil { + slog.Warn("failed to executed shell command", + "command", commandWithInterfaceName, "stdin", stdin, "output", string(out), "error", err) return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err) } slog.Debug("executed shell command", @@ -631,49 +638,116 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti // region routing-related -func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { - // update fwmark rules - if err := c.setFwMarkRules(rules); err != nil { - return err +// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op. +func (c LocalController) SetRoutes(_ context.Context, info domain.RoutingTableInfo) error { + interfaceId := info.Interface.Identifier + slog.Debug("setting linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark, + "cidrs", info.AllowedIps) + + link, err := c.nl.LinkByName(string(interfaceId)) + if err != nil { + return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err) } - // update main rule - if err := c.setMainRule(rules); err != nil { - return err + cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps) + realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, info.Table, info.FwMark) + if err != nil { + return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err) + } + wgDev, err := c.wg.Device(string(interfaceId)) + if err != nil { + return fmt.Errorf("failed to get wg device for %s: %w", interfaceId, err) + } + currentFwMark := wgDev.FirewallMark + if int(realFwMark) != currentFwMark { + slog.Debug("updating fwmark for interface", "interface", interfaceId, "oldFwMark", currentFwMark, + "newFwMark", realFwMark, "oldTable", info.Table, "newTable", realTable) + if err := c.updateFwMarkOnInterface(interfaceId, int(realFwMark)); err != nil { + return fmt.Errorf("failed to update fwmark for interface %s to %d: %w", interfaceId, realFwMark, err) + } } - // cleanup old main rules - if err := c.cleanupMainRule(rules); err != nil { - return err + if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4); err != nil { + return fmt.Errorf("failed to set v4 routes: %w", err) + } + if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6); err != nil { + return fmt.Errorf("failed to set v6 routes: %w", err) } return nil } -func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error { - for _, rule := range rules { - existingRules, err := c.nl.RuleList(int(rule.IpFamily)) +func (c LocalController) setRoutesForFamily( + interfaceId domain.InterfaceIdentifier, + link netlink.Link, + family int, + table int, + fwMark uint32, + cidrs []domain.Cidr, +) error { + // first create or update the routes + for _, cidr := range cidrs { + err := c.nl.RouteReplace(&netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: cidr.IpNet(), + Table: table, + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }) if err != nil { - return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err) + return fmt.Errorf("failed to add/update route %s on table %d for interface %s: %w", + cidr.String(), table, interfaceId, err) } + } - ruleExists := false - for _, existingRule := range existingRules { - if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table { - ruleExists = true - break + // next remove old routes + rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{ + LinkIndex: link.Attrs().Index, + Table: unix.RT_TABLE_UNSPEC, // all tables + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF) + if err != nil { + return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w", + interfaceId, family, err) + } + for _, rawRoute := range rawRoutes { + if rawRoute.Dst == nil { // handle default route + var netlinkAddr domain.Cidr + if family == netlink.FAMILY_V4 { + netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0") + } else { + netlinkAddr, _ = domain.CidrFromString("::/0") } + rawRoute.Dst = netlinkAddr.IpNet() } - if ruleExists { - continue // rule already exists, no need to recreate it + route := domain.CidrFromIpNet(*rawRoute.Dst) + if slices.Contains(cidrs, route) { + continue } - // create a missing rule + if err := c.nl.RouteDel(&rawRoute); err != nil { + return fmt.Errorf("failed to remove deprecated route %s from interface %s: %w", route, interfaceId, err) + } + } + + // next, update route rules for normal routes + if table == 0 { + return nil // no need to update route rules as we are using the default table + } + existingRules, err := c.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing rules for family-id %d: %w", family, err) + } + ruleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool { + return rule.Mark == fwMark && rule.Table == table + }) + if !ruleExists { if err := c.nl.RuleAdd(&netlink.Rule{ - Family: int(rule.IpFamily), - Table: rule.Table, - Mark: rule.FwMark, + Family: family, + Table: table, + Mark: fwMark, Invert: true, SuppressIfgroup: -1, SuppressPrefixlen: -1, @@ -682,15 +756,102 @@ func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error { Goto: -1, Flow: -1, }); err != nil { - return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w", - rule.IpFamily, rule.FwMark, rule.Table, err) + return fmt.Errorf("failed to setup rule for fwmark %d and table %d for family-id %d: %w", + fwMark, table, family, err) } } + mainRuleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool { + return rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN + }) + if !mainRuleExists && domain.ContainsDefaultRoute(cidrs) { + err = c.nl.RuleAdd(&netlink.Rule{ + Family: family, + Table: unix.RT_TABLE_MAIN, + SuppressIfgroup: -1, + SuppressPrefixlen: 0, + Priority: c.getMainRulePriority(existingRules), + Mark: 0, + Mask: nil, + Goto: -1, + Flow: -1, + }) + } + + // finally, clean up extra main rules - only one rule is allowed + existingRules, err = c.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing main rules for family-id %d: %w", family, err) + } + mainRuleCount := 0 + for _, rule := range existingRules { + if rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN { + mainRuleCount++ + } + if mainRuleCount > 1 { + if err := c.nl.RuleDel(&rule); err != nil { + return fmt.Errorf("failed to remove extra main rule for family-id %d: %w", family, err) + } + } + } + return nil } +func (c LocalController) getOrCreateRoutingTableAndFwMark( + link netlink.Link, + tableIn int, + fwMarkIn uint32, +) ( + table int, + fwmark uint32, + err error, +) { + table = tableIn + fwmark = fwMarkIn + + if fwmark == 0 { + // generate a new (temporary) firewall mark based on the interface index + fwmark = uint32(c.cfg.Advanced.RouteTableOffset + link.Attrs().Index) + } + if table == 0 { + table = int(fwmark) // generate a new routing table base on interface index + } + return +} + +func (c LocalController) updateFwMarkOnInterface(interfaceId domain.InterfaceIdentifier, fwMark int) error { + // apply the new fwmark to the wireguard interface + err := c.wg.ConfigureDevice(string(interfaceId), wgtypes.Config{ + FirewallMark: &fwMark, + }) + if err != nil { + return fmt.Errorf("failed to update fwmark of interface %s to: %d: %w", interfaceId, fwMark, err) + } + + return nil +} + +func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int { + prio := c.cfg.Advanced.RulePrioOffset + for { + isFresh := true + for _, existingRule := range existingRules { + if existingRule.Priority == prio { + isFresh = false + break + } + } + if isFresh { + break + } else { + prio++ + } + } + return prio +} + func (c LocalController) getRulePriority(existingRules []netlink.Rule) int { - prio := 32700 // linux main rule has a priority of 32766 + prio := 32700 // linux main rule has a prio of 32766 for { isFresh := true for _, existingRule := range existingRules { @@ -708,126 +869,145 @@ func (c LocalController) getRulePriority(existingRules []netlink.Rule) int { return prio } -func (c LocalController) setMainRule(rules []domain.RouteRule) error { - var family domain.IpFamily - shouldHaveMainRule := false - for _, rule := range rules { - family = rule.IpFamily - if rule.HasDefault == true { - shouldHaveMainRule = true - break - } - } - if !shouldHaveMainRule { - return nil - } +// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op. +func (c LocalController) RemoveRoutes(_ context.Context, info domain.RoutingTableInfo) error { + interfaceId := info.Interface.Identifier + slog.Debug("removing linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark, + "cidrs", info.AllowedIps) - existingRules, err := c.nl.RuleList(int(family)) + wgDev, err := c.wg.Device(string(interfaceId)) if err != nil { - return fmt.Errorf("failed to get existing rules for family %s: %w", family, err) + slog.Debug("wg device already removed, route cleanup might be incomplete", "interface", interfaceId) + wgDev = nil } - - ruleExists := false - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - ruleExists = true - break - } - } - - if ruleExists { - return nil // rule already exists, skip re-creation - } - - if err := c.nl.RuleAdd(&netlink.Rule{ - Family: int(family), - Table: unix.RT_TABLE_MAIN, - SuppressIfgroup: -1, - SuppressPrefixlen: 0, - Priority: c.getMainRulePriority(existingRules), - Mark: 0, - Mask: nil, - Goto: -1, - Flow: -1, - }); err != nil { - return fmt.Errorf("failed to setup rule for main table: %w", err) - } - - return nil -} - -func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int { - priority := c.cfg.Advanced.RulePrioOffset - for { - isFresh := true - for _, existingRule := range existingRules { - if existingRule.Priority == priority { - isFresh = false - break - } - } - if isFresh { - break - } else { - priority++ - } - } - return priority -} - -func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error { - var family domain.IpFamily - for _, rule := range rules { - family = rule.IpFamily - break - } - - existingRules, err := c.nl.RuleList(int(family)) + link, err := c.nl.LinkByName(string(interfaceId)) if err != nil { - return fmt.Errorf("failed to get existing rules for family %s: %w", family, err) + slog.Debug("physical link already removed, route cleanup might be incomplete", "interface", interfaceId) + link = nil } - shouldHaveMainRule := false - for _, rule := range rules { - if rule.HasDefault == true { - shouldHaveMainRule = true - break + fwMark := info.FwMark + if wgDev != nil && info.FwMark == 0 { + fwMark = uint32(wgDev.FirewallMark) + } + table := info.Table + if wgDev != nil && info.Table == 0 { + table = wgDev.FirewallMark // use the fwMark as table, this is the default behavior + } + linkIndex := -1 + if link != nil { + linkIndex = link.Attrs().Index + } + + cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps) + realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark) + if err != nil { + return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err) + } + + if linkIndex > 0 { + err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4) + if err != nil { + return fmt.Errorf("failed to remove v4 routes: %w", err) + } + err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6) + if err != nil { + return fmt.Errorf("failed to remove v6 routes: %w", err) } } - mainRules := 0 - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - mainRules++ + if table > 0 { + err = c.removeRouteRulesForTable(netlink.FAMILY_V4, realTable) + if err != nil { + return fmt.Errorf("failed to remove v4 route rules for %s: %w", interfaceId, err) } - } - - removalCount := 0 - if mainRules > 1 { - removalCount = mainRules - 1 // we only want one single rule - } - if !shouldHaveMainRule { - removalCount = mainRules - } - - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - if removalCount > 0 { - existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field - if err := c.nl.RuleDel(&existingRule); err != nil { - return fmt.Errorf("failed to delete main rule: %w", err) - } - removalCount-- - } + err = c.removeRouteRulesForTable(netlink.FAMILY_V6, realTable) + if err != nil { + return fmt.Errorf("failed to remove v6 route rules for %s: %w", interfaceId, err) } } return nil } -func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { - // TODO implement me - panic("implement me") +func (c LocalController) removeRoutesForFamily( + interfaceId domain.InterfaceIdentifier, + link netlink.Link, + family int, + table int, + fwMark uint32, + cidrs []domain.Cidr, +) error { + // first remove all rules + existingRules, err := c.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) + } + for _, existingRule := range existingRules { + if fwMark == existingRule.Mark && table == existingRule.Table { + existingRule.Family = family // set family, somehow the RuleList method does not populate the family field + if err := c.nl.RuleDel(&existingRule); err != nil { + return fmt.Errorf("failed to delete old fwmark rule: %w", err) + } + } + } + + // next remove all routes + rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{ + LinkIndex: link.Attrs().Index, + Table: unix.RT_TABLE_UNSPEC, // all tables + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF) + if err != nil { + return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w", + interfaceId, family, err) + } + for _, rawRoute := range rawRoutes { + if rawRoute.Dst == nil { // handle default route + var netlinkAddr domain.Cidr + if family == netlink.FAMILY_V4 { + netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0") + } else { + netlinkAddr, _ = domain.CidrFromString("::/0") + } + rawRoute.Dst = netlinkAddr.IpNet() + } + + if rawRoute.Table != table { + continue // ignore routes from other tables + } + + route := domain.CidrFromIpNet(*rawRoute.Dst) + if !slices.Contains(cidrs, route) { + continue // only remove routes that were previously added + } + + if err := c.nl.RouteDel(&rawRoute); err != nil { + return fmt.Errorf("failed to remove old route %s from interface %s: %w", route, interfaceId, err) + } + } + + return nil +} + +func (c LocalController) removeRouteRulesForTable( + family int, + table int, +) error { + existingRules, err := c.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing route rules for family-id %d: %w", family, err) + } + for _, existingRule := range existingRules { + if existingRule.Table == table { + err := c.nl.RuleDel(&existingRule) + if err != nil { + return fmt.Errorf("failed to delete old rule for table %d and family-id %d: %w", table, family, err) + } + } + } + return nil } // endregion routing-related diff --git a/internal/adapters/wgcontroller/mikrotik.go b/internal/adapters/wgcontroller/mikrotik.go index ac98094..e730bd7 100644 --- a/internal/adapters/wgcontroller/mikrotik.go +++ b/internal/adapters/wgcontroller/mikrotik.go @@ -15,6 +15,9 @@ import ( "github.com/h44z/wg-portal/internal/lowlevel" ) +const MikrotikRouteDistance = 5 +const MikrotikDefaultRoutingTable = "main" + type MikrotikController struct { coreCfg *config.Config cfg *config.BackendMikrotik @@ -22,8 +25,9 @@ type MikrotikController struct { client *lowlevel.MikrotikApiClient // Add mutexes to prevent race conditions - interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex - peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex + interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex + peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex + coreMutex sync.Mutex // for updating the core configuration such as routing table or DNS settings } func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) { @@ -40,6 +44,7 @@ func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) interfaceMutexes: sync.Map{}, peerMutexes: sync.Map{}, + coreMutex: sync.Mutex{}, }, nil } @@ -763,33 +768,404 @@ func (c *MikrotikController) DeletePeer( // region wg-quick-related -func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { +func (c *MikrotikController) ExecuteInterfaceHook( + _ context.Context, + _ domain.InterfaceIdentifier, + _ string, +) error { // TODO implement me - panic("implement me") + slog.Error("interface hooks are not yet supported for Mikrotik backends, please open an issue on GitHub") + return nil } -func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { - // TODO implement me - panic("implement me") +func (c *MikrotikController) SetDNS( + ctx context.Context, + _ domain.InterfaceIdentifier, + dnsStr, _ string, +) error { + // Lock the interface to prevent concurrent modifications + c.coreMutex.Lock() + defer c.coreMutex.Unlock() + + // check if the server is already configured + wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{ + PropList: []string{"servers"}, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error) + } + + var existingServers []string + existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...) + + newServers := strings.Split(dnsStr, ",") + + mergedServers := slices.Clone(existingServers) + for _, s := range newServers { + if s == "" { + continue + } + if !slices.Contains(mergedServers, s) { + mergedServers = append(mergedServers, s) + } + } + mergedServersStr := strings.Join(mergedServers, ",") + + reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{ + "servers": mergedServersStr, + }) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error) + } + + return nil } -func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error { - // TODO implement me - panic("implement me") +func (c *MikrotikController) UnsetDNS( + ctx context.Context, + _ domain.InterfaceIdentifier, + dnsStr, _ string, +) error { + // Lock the interface to prevent concurrent modifications + c.coreMutex.Lock() + defer c.coreMutex.Unlock() + + // retrieve current DNS settings + wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{ + PropList: []string{"servers"}, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error) + } + + var existingServers []string + existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...) + + oldServers := strings.Split(dnsStr, ",") + + mergedServers := make([]string, 0, len(existingServers)) + for _, s := range existingServers { + if s == "" { + continue + } + if !slices.Contains(oldServers, s) { + mergedServers = append(mergedServers, s) // only keep the servers that are not in the old list + } + } + mergedServersStr := strings.Join(mergedServers, ",") + + reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{ + "servers": mergedServersStr, + }) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error) + } + + return nil } // endregion wg-quick-related // region routing-related -func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { - // TODO implement me - panic("implement me") +// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op. +func (c *MikrotikController) SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error { + interfaceId := info.Interface.Identifier + slog.Debug("setting mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps) + + // Mikrotik needs some time to apply the changes. + // If we don't wait, the routes might get created multiple times as the dynamic routes are not yet available. + time.Sleep(2 * time.Second) + + tableName, err := c.getOrCreateRoutingTables(ctx, info.Interface.Identifier, info.TableStr) + if err != nil { + return fmt.Errorf("failed to get or create routing table for %s: %v", interfaceId, err) + } + + cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps) + + err = c.setRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4) + if err != nil { + return fmt.Errorf("failed to set IPv4 routes for %s: %v", interfaceId, err) + } + + err = c.setRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6) + if err != nil { + return fmt.Errorf("failed to set IPv6 routes for %s: %v", interfaceId, err) + } + + return nil } -func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { - // TODO implement me - panic("implement me") +func (c *MikrotikController) resolveRouteTableName(name string) string { + name = strings.TrimSpace(name) + + var mikrotikTableName string + switch strings.ToLower(name) { + case "", "0": + mikrotikTableName = MikrotikDefaultRoutingTable + case MikrotikDefaultRoutingTable: + return fmt.Sprintf("wgportal-%s", + MikrotikDefaultRoutingTable) // if the Mikrotik Main table should be used, the table-name should be left empty or set to "0". + default: + mikrotikTableName = name + } + + return mikrotikTableName +} + +func (c *MikrotikController) getOrCreateRoutingTables( + ctx context.Context, + interfaceId domain.InterfaceIdentifier, + table string, +) (string, error) { + // retrieve current routing tables + wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "dynamic", "fib", "name", + }, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return "", fmt.Errorf("unable to query routing tables: %v", wgReply.Error) + } + + wantedTableName := c.resolveRouteTableName(table) + + // check if the table already exists + for _, table := range wgReply.Data { + if table.GetString("name") == wantedTableName { + return wantedTableName, nil // already exists, nothing to do + } + } + + // create the table if it does not exist + createReply := c.client.Create(ctx, "/routing/table", lowlevel.GenericJsonObject{ + "name": wantedTableName, + "comment": fmt.Sprintf("Routing Table for %s", interfaceId), + "fib": strconv.FormatBool(true), + }) + if createReply.Status != lowlevel.MikrotikApiStatusOk { + return "", fmt.Errorf("failed to create routing table %s: %v", wantedTableName, createReply.Error) + } + + return wantedTableName, nil +} + +func (c *MikrotikController) setRoutesForFamily( + ctx context.Context, + interfaceId domain.InterfaceIdentifier, + ipV6 bool, + table string, + cidrs []domain.Cidr, +) error { + apiPath := "/ip/route" + if ipV6 { + apiPath = "/ipv6/route" + } + + // retrieve current routes + wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw", + "routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder", + }, + Filters: map[string]string{ + "gateway": string(interfaceId), + }, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error) + } + + // first create or update the routes + for _, cidr := range cidrs { + // check if the route already exists + exists := false + for _, route := range wgReply.Data { + existingRoute, err := domain.CidrFromString(route.GetString("dst-address")) + if err != nil { + slog.Warn("failed to parse route destination address", + "cidr", route.GetString("dst-address"), "error", err) + continue + } + if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table { + exists = true + break + } + } + if exists { + continue // route already exists, nothing to do + } + + // create the route + reply := c.client.Create(ctx, apiPath, lowlevel.GenericJsonObject{ + "gateway": string(interfaceId), + "dst-address": cidr.String(), + "distance": strconv.Itoa(MikrotikRouteDistance), + "disabled": strconv.FormatBool(false), + "routing-table": table, + }) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to create new route %s via %s: %v", cidr.String(), interfaceId, reply.Error) + } + } + + // finally, remove the routes that are not in the new list + for _, route := range wgReply.Data { + if route.GetBool("dynamic") { + continue // dynamic routes are not managed by the controller, nothing to do + } + + existingRoute, err := domain.CidrFromString(route.GetString("dst-address")) + if err != nil { + slog.Warn("failed to parse route destination address", + "cidr", route.GetString("dst-address"), "error", err) + continue + } + + valid := false + for _, cidr := range cidrs { + if existingRoute.EqualPrefix(cidr) { + valid = true + break + } + } + if valid { + continue // route is still valid, nothing to do + } + + // remove the route + reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id")) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to remove outdated route %s: %v", existingRoute.String(), reply.Error) + } + } + + return nil +} + +// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op. +func (c *MikrotikController) RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error { + interfaceId := info.Interface.Identifier + slog.Debug("removing mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps) + + tableName := c.resolveRouteTableName(info.TableStr) + + cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps) + + err := c.removeRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4) + if err != nil { + return fmt.Errorf("failed to remove IPv4 routes for %s: %v", interfaceId, err) + } + + err = c.removeRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6) + if err != nil { + return fmt.Errorf("failed to remove IPv6 routes for %s: %v", interfaceId, err) + } + + err = c.removeRoutingTable(ctx, tableName) + if err != nil { + return fmt.Errorf("failed to remove routing table for %s: %v", interfaceId, err) + } + + return nil +} + +func (c *MikrotikController) removeRoutesForFamily( + ctx context.Context, + interfaceId domain.InterfaceIdentifier, + ipV6 bool, + table string, + cidrs []domain.Cidr, +) error { + apiPath := "/ip/route" + if ipV6 { + apiPath = "/ipv6/route" + } + + // retrieve current routes + wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw", + "routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder", + }, + Filters: map[string]string{ + "gateway": string(interfaceId), + }, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error) + } + + // remove the routes from the list + for _, route := range wgReply.Data { + if route.GetBool("dynamic") { + continue // dynamic routes are not managed by the controller, nothing to do + } + + existingRoute, err := domain.CidrFromString(route.GetString("dst-address")) + if err != nil { + slog.Warn("failed to parse route destination address", + "cidr", route.GetString("dst-address"), "error", err) + continue + } + + remove := false + for _, cidr := range cidrs { + if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table { + remove = true + break + } + } + if !remove { + continue // route is still valid, nothing to do + } + + // remove the route + reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id")) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to remove old route %s: %v", existingRoute.String(), reply.Error) + } + } + + return nil +} + +func (c *MikrotikController) removeRoutingTable( + ctx context.Context, + table string, +) error { + if table == MikrotikDefaultRoutingTable { + return nil // we cannot remove the default table + } + + // retrieve current routing tables + wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "dynamic", "fib", "name", + }, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("unable to query routing tables: %v", wgReply.Error) + } + + for _, existingTable := range wgReply.Data { + if existingTable.GetBool("dynamic") { + continue // dynamic tables are not managed by the controller, nothing to do + } + if existingTable.GetString("name") != table { + continue // not the table we want to remove + } + + // remove the table + reply := c.client.Delete(ctx, "/routing/table/"+existingTable.GetString(".id")) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to remove routing table %s: %v", table, reply.Error) + } + return nil + } + + return nil } // endregion routing-related diff --git a/internal/adapters/wgquick.go b/internal/adapters/wgquick.go deleted file mode 100644 index 992e69a..0000000 --- a/internal/adapters/wgquick.go +++ /dev/null @@ -1,113 +0,0 @@ -package adapters - -import ( - "bytes" - "fmt" - "log/slog" - "os/exec" - "strings" - - "github.com/h44z/wg-portal/internal" - "github.com/h44z/wg-portal/internal/domain" -) - -// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks. -type WgQuickRepo struct { - shellCmd string - resolvConfIfacePrefix string -} - -// NewWgQuickRepo creates a new WgQuickRepo instance. -func NewWgQuickRepo() *WgQuickRepo { - return &WgQuickRepo{ - shellCmd: "bash", - resolvConfIfacePrefix: "tun.", - } -} - -// ExecuteInterfaceHook executes the given hook command. -// The hook command can contain the following placeholders: -// -// %i: the interface identifier. -func (r *WgQuickRepo) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { - if hookCmd == "" { - return nil - } - - slog.Debug("executing interface hook", "interface", id, "hook", hookCmd) - err := r.exec(hookCmd, id) - if err != nil { - return fmt.Errorf("failed to exec hook: %w", err) - } - - return nil -} - -// SetDNS sets the DNS settings for the given interface. It uses resolvconf to set the DNS settings. -func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { - if dnsStr == "" && dnsSearchStr == "" { - return nil - } - - dnsServers := internal.SliceString(dnsStr) - dnsSearchDomains := internal.SliceString(dnsSearchStr) - - dnsCommand := "resolvconf -a %resPref%i -m 0 -x" - dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains)) - - for _, dnsServer := range dnsServers { - dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer)) - } - for _, searchDomain := range dnsSearchDomains { - dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain)) - } - - err := r.exec(dnsCommand, id, dnsCommandInput...) - if err != nil { - return fmt.Errorf( - "failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w", - err, - ) - } - - return nil -} - -// UnsetDNS unsets the DNS settings for the given interface. It uses resolvconf to unset the DNS settings. -func (r *WgQuickRepo) UnsetDNS(id domain.InterfaceIdentifier) error { - dnsCommand := "resolvconf -d %resPref%i -f" - - err := r.exec(dnsCommand, id) - if err != nil { - return fmt.Errorf("failed to unset dns settings: %w", err) - } - - return nil -} - -func (r *WgQuickRepo) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string { - command = strings.ReplaceAll(command, "%resPref", r.resolvConfIfacePrefix) - return strings.ReplaceAll(command, "%i", string(interfaceId)) -} - -func (r *WgQuickRepo) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error { - commandWithInterfaceName := r.replaceCommandPlaceHolders(command, interfaceId) - cmd := exec.Command(r.shellCmd, "-ce", commandWithInterfaceName) - if len(stdin) > 0 { - b := &bytes.Buffer{} - for _, ln := range stdin { - if _, err := fmt.Fprint(b, ln); err != nil { - return err - } - } - cmd.Stdin = b - } - out, err := cmd.CombinedOutput() // execute and wait for output - if err != nil { - return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err) - } - slog.Debug("executed shell command", - "command", commandWithInterfaceName, - "output", string(out)) - return nil -} diff --git a/internal/app/route/routes.go b/internal/app/route/routes.go index c87bcaf..62cd67e 100644 --- a/internal/app/route/routes.go +++ b/internal/app/route/routes.go @@ -4,25 +4,23 @@ import ( "context" "fmt" "log/slog" - - "github.com/vishvananda/netlink" - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/wgctrl" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "sync" "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" - "github.com/h44z/wg-portal/internal/lowlevel" ) // region dependencies +type ControllerManager interface { + // GetController returns the controller for the given interface. + GetController(iface domain.Interface) domain.InterfaceController +} + type InterfaceAndPeerDatabaseRepo interface { - // GetAllInterfaces returns all interfaces - GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) - // GetInterfacePeers returns all peers for a given interface - GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) + // GetInterface returns the interface with the given identifier. + GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) } type EventBus interface { @@ -30,6 +28,13 @@ type EventBus interface { Subscribe(topic string, fn interface{}) error } +type RoutesController interface { + // SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op. + SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error + // RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op. + RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error +} + // endregion dependencies type routeRuleInfo struct { @@ -45,28 +50,27 @@ type routeRuleInfo struct { type Manager struct { cfg *config.Config - bus EventBus - wg lowlevel.WireGuardClient - nl lowlevel.NetlinkClient - db InterfaceAndPeerDatabaseRepo + bus EventBus + db InterfaceAndPeerDatabaseRepo + wgController ControllerManager + + mux *sync.Mutex } // NewRouteManager creates a new route manager instance. -func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) { - wg, err := wgctrl.New() - if err != nil { - panic("failed to init wgctrl: " + err.Error()) - } - - nl := &lowlevel.NetlinkManager{} - +func NewRouteManager( + cfg *config.Config, + bus EventBus, + db InterfaceAndPeerDatabaseRepo, + wgController ControllerManager, +) (*Manager, error) { m := &Manager{ cfg: cfg, bus: bus, - db: db, - wg: wg, - nl: nl, + db: db, + wgController: wgController, + mux: &sync.Mutex{}, } m.connectToMessageBus() @@ -85,419 +89,82 @@ func (m Manager) StartBackgroundJobs(_ context.Context) { // this is a no-op for now } -func (m Manager) handleRouteUpdateEvent(srcDescription string) { - slog.Debug("handling route update event", "source", srcDescription) +func (m Manager) handleRouteUpdateEvent(info domain.RoutingTableInfo) { + m.mux.Lock() // ensure that only one route update is processed at a time + defer m.mux.Unlock() - err := m.syncRoutes(context.Background()) - if err != nil { - slog.Error("failed to synchronize routes", - "source", srcDescription, - "error", err) + slog.Debug("handling route update event", "info", info.String()) + + if !info.ManagementEnabled() { + return // route management disabled } - slog.Debug("routes synchronized", "source", srcDescription) + err := m.syncRoutes(context.Background(), info) + if err != nil { + slog.Error("failed to synchronize routes", + "info", info.String(), "error", err) + return + } + + slog.Debug("routes synchronized", "info", info.String()) } func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) { + m.mux.Lock() // ensure that only one route update is processed at a time + defer m.mux.Unlock() + slog.Debug("handling route remove event", "info", info.String()) if !info.ManagementEnabled() { return // route management disabled } - if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V4); err != nil { - slog.Error("failed to remove v4 fwmark rules", "error", err) - } - if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V6); err != nil { - slog.Error("failed to remove v6 fwmark rules", "error", err) - } - - slog.Debug("routes removed", "table", info.String()) -} - -func (m Manager) syncRoutes(ctx context.Context) error { - interfaces, err := m.db.GetAllInterfaces(ctx) + err := m.removeRoutes(context.Background(), info) if err != nil { - return fmt.Errorf("failed to find all interfaces: %w", err) + slog.Error("failed to synchronize routes", + "info", info.String(), "error", err) + return } - rules := map[int][]routeRuleInfo{ - netlink.FAMILY_V4: nil, - netlink.FAMILY_V6: nil, - } - for _, iface := range interfaces { - if iface.IsDisabled() { - continue // disabled interface does not need route entries - } - if !iface.ManageRoutingTable() { - continue - } - - peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) - if err != nil { - return fmt.Errorf("failed to find peers for %s: %w", iface.Identifier, err) - } - allowedIPs := iface.GetAllowedIPs(peers) - defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs) - - link, err := m.nl.LinkByName(string(iface.Identifier)) - if err != nil { - return fmt.Errorf("failed to find physical link for %s: %w", iface.Identifier, err) - } - - table, fwmark, err := m.getRoutingTableAndFwMark(&iface, link) - if err != nil { - return fmt.Errorf("failed to get table and fwmark for %s: %w", iface.Identifier, err) - } - - if err := m.setInterfaceRoutes(link, table, allowedIPs); err != nil { - return fmt.Errorf("failed to set routes for %s: %w", iface.Identifier, err) - } - - if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V4, allowedIPs); err != nil { - return fmt.Errorf("failed to remove deprecated v4 routes for %s: %w", iface.Identifier, err) - } - if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V6, allowedIPs); err != nil { - return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err) - } - - if table != 0 { - rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], routeRuleInfo{ - ifaceId: iface.Identifier, - fwMark: fwmark, - table: table, - family: netlink.FAMILY_V4, - hasDefault: defRouteV4, - }) - } - if table != 0 { - rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], routeRuleInfo{ - ifaceId: iface.Identifier, - fwMark: fwmark, - table: table, - family: netlink.FAMILY_V6, - hasDefault: defRouteV6, - }) - } - } - - return m.syncRouteRules(rules) + slog.Debug("routes removed", "info", info.String()) } -func (m Manager) syncRouteRules(allRules map[int][]routeRuleInfo) error { - for family, rules := range allRules { - // update fwmark rules - if err := m.setFwMarkRules(rules, family); err != nil { - return err - } - - // update main rule - if err := m.setMainRule(rules, family); err != nil { - return err - } - - // cleanup old main rules - if err := m.cleanupMainRule(rules, family); err != nil { - return err - } - } - - return nil -} - -func (m Manager) setFwMarkRules(rules []routeRuleInfo, family int) error { - for _, rule := range rules { - existingRules, err := m.nl.RuleList(family) - if err != nil { - return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) - } - - ruleExists := false - for _, existingRule := range existingRules { - if rule.fwMark == existingRule.Mark && rule.table == existingRule.Table { - ruleExists = true - break - } - } - - if ruleExists { - continue // rule already exists, no need to recreate it - } - - // create missing rule - if err := m.nl.RuleAdd(&netlink.Rule{ - Family: family, - Table: rule.table, - Mark: rule.fwMark, - Invert: true, - SuppressIfgroup: -1, - SuppressPrefixlen: -1, - Priority: m.getRulePriority(existingRules), - Mask: nil, - Goto: -1, - Flow: -1, - }); err != nil { - return fmt.Errorf("failed to setup rule for fwmark %d and table %d: %w", rule.fwMark, rule.table, err) - } - } - return nil -} - -func (m Manager) removeFwMarkRules(fwmark uint32, table int, family int) error { - existingRules, err := m.nl.RuleList(family) - if err != nil { - return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) - } - - for _, existingRule := range existingRules { - if fwmark == existingRule.Mark && table == existingRule.Table { - existingRule.Family = family // set family, somehow the RuleList method does not populate the family field - if err := m.nl.RuleDel(&existingRule); err != nil { - return fmt.Errorf("failed to delete fwmark rule: %w", err) - } - } - } - return nil -} - -func (m Manager) setMainRule(rules []routeRuleInfo, family int) error { - shouldHaveMainRule := false - for _, rule := range rules { - if rule.hasDefault == true { - shouldHaveMainRule = true - break - } - } - if !shouldHaveMainRule { +func (m Manager) syncRoutes(ctx context.Context, info domain.RoutingTableInfo) error { + rc, ok := m.wgController.GetController(info.Interface).(RoutesController) + if !ok { + slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier) return nil } - existingRules, err := m.nl.RuleList(family) + if !info.Interface.ManageRoutingTable() { + slog.Debug("interface does not manage routing table, skipping route update", + "interface", info.Interface.Identifier) + return nil + } + + err := rc.SetRoutes(ctx, info) if err != nil { - return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) + return fmt.Errorf("failed to set routes for interface %s: %w", info.Interface.Identifier, err) } - - ruleExists := false - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - ruleExists = true - break - } - } - - if ruleExists { - return nil // rule already exists, skip re-creation - } - - if err := m.nl.RuleAdd(&netlink.Rule{ - Family: family, - Table: unix.RT_TABLE_MAIN, - SuppressIfgroup: -1, - SuppressPrefixlen: 0, - Priority: m.getMainRulePriority(existingRules), - Mark: 0, - Mask: nil, - Goto: -1, - Flow: -1, - }); err != nil { - return fmt.Errorf("failed to setup rule for main table: %w", err) - } - return nil } -func (m Manager) cleanupMainRule(rules []routeRuleInfo, family int) error { - existingRules, err := m.nl.RuleList(family) +func (m Manager) removeRoutes(ctx context.Context, info domain.RoutingTableInfo) error { + rc, ok := m.wgController.GetController(info.Interface).(RoutesController) + if !ok { + slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier) + return nil + } + + if !info.Interface.ManageRoutingTable() { + slog.Debug("interface does not manage routing table, skipping route removal", + "interface", info.Interface.Identifier) + return nil + } + + err := rc.RemoveRoutes(ctx, info) if err != nil { - return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) - } - - shouldHaveMainRule := false - for _, rule := range rules { - if rule.hasDefault == true { - shouldHaveMainRule = true - break - } - } - - mainRules := 0 - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - mainRules++ - } - } - - removalCount := 0 - if mainRules > 1 { - removalCount = mainRules - 1 // we only want one single rule - } - if !shouldHaveMainRule { - removalCount = mainRules - } - - for _, existingRule := range existingRules { - if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { - if removalCount > 0 { - existingRule.Family = family // set family, somehow the RuleList method does not populate the family field - if err := m.nl.RuleDel(&existingRule); err != nil { - return fmt.Errorf("failed to delete main rule: %w", err) - } - removalCount-- - } - } - } - - return nil -} - -func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int { - prio := m.cfg.Advanced.RulePrioOffset - for { - isFresh := true - for _, existingRule := range existingRules { - if existingRule.Priority == prio { - isFresh = false - break - } - } - if isFresh { - break - } else { - prio++ - } - } - return prio -} - -func (m Manager) getRulePriority(existingRules []netlink.Rule) int { - prio := 32700 // linux main rule has a prio of 32766 - for { - isFresh := true - for _, existingRule := range existingRules { - if existingRule.Priority == prio { - isFresh = false - break - } - } - if isFresh { - break - } else { - prio-- - } - } - return prio -} - -func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error { - for _, allowedIP := range allowedIPs { - err := m.nl.RouteReplace(&netlink.Route{ - LinkIndex: link.Attrs().Index, - Dst: allowedIP.IpNet(), - Table: table, - Scope: unix.RT_SCOPE_LINK, - Type: unix.RTN_UNICAST, - }) - if err != nil { - return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err) - } - } - - return nil -} - -func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIPs []domain.Cidr) error { - rawRoutes, err := m.nl.RouteListFiltered(family, &netlink.Route{ - LinkIndex: link.Attrs().Index, - Table: unix.RT_TABLE_UNSPEC, // all tables - Scope: unix.RT_SCOPE_LINK, - Type: unix.RTN_UNICAST, - }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF) - if err != nil { - return fmt.Errorf("failed to fetch raw routes: %w", err) - } - for _, rawRoute := range rawRoutes { - if rawRoute.Dst == nil { // handle default route - var netlinkAddr domain.Cidr - if family == netlink.FAMILY_V4 { - netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0") - } else { - netlinkAddr, _ = domain.CidrFromString("::/0") - } - rawRoute.Dst = netlinkAddr.IpNet() - } - - netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst) - remove := true - for _, allowedIP := range allowedIPs { - if netlinkAddr == allowedIP { - remove = false - break - } - } - - if !remove { - continue - } - - err := m.nl.RouteDel(&rawRoute) - if err != nil { - return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err) - } + return fmt.Errorf("failed to remove routes for interface %s: %w", info.Interface.Identifier, err) } return nil } - -func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, link netlink.Link) ( - table int, - fwmark uint32, - err error, -) { - table = iface.GetRoutingTable() - fwmark = iface.FirewallMark - - if fwmark == 0 { - // generate a new (temporary) firewall mark based on the interface index - fwmark = uint32(m.cfg.Advanced.RouteTableOffset + link.Attrs().Index) - slog.Debug("using fwmark to handle routes", - "interface", iface.Identifier, - "fwmark", fwmark) - - // apply the temporary fwmark to the wireguard interface - err = m.setFwMark(iface.Identifier, int(fwmark)) - } - if table == 0 { - table = int(fwmark) // generate a new routing table base on interface index - slog.Debug("using routing table to handle default routes", - "interface", iface.Identifier, - "table", table) - } - return -} - -func (m Manager) setFwMark(id domain.InterfaceIdentifier, fwmark int) error { - err := m.wg.ConfigureDevice(string(id), wgtypes.Config{ - FirewallMark: &fwmark, - }) - if err != nil { - return fmt.Errorf("failed to update fwmark to: %d: %w", fwmark, err) - } - return nil -} - -func (m Manager) containsDefaultRoute(allowedIPs []domain.Cidr) (ipV4, ipV6 bool) { - for _, allowedIP := range allowedIPs { - if ipV4 && ipV6 { - break // speed up - } - - if allowedIP.Prefix().Bits() == 0 { - if allowedIP.IsV4() { - ipV4 = true - } else { - ipV6 = true - } - } - } - - return -} diff --git a/internal/app/wireguard/controller_manager.go b/internal/app/wireguard/controller_manager.go index 2eea6af..0f6bd23 100644 --- a/internal/app/wireguard/controller_manager.go +++ b/internal/app/wireguard/controller_manager.go @@ -1,7 +1,6 @@ package wireguard import ( - "context" "fmt" "log/slog" "maps" @@ -12,33 +11,9 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) -type InterfaceController interface { - GetId() domain.InterfaceBackend - GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) - GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) - GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) - SaveInterface( - _ context.Context, - id domain.InterfaceIdentifier, - updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), - ) error - DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error - SavePeer( - _ context.Context, - deviceId domain.InterfaceIdentifier, - id domain.PeerIdentifier, - updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), - ) error - DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error - PingAddresses( - ctx context.Context, - addr string, - ) (*domain.PingerResult, error) -} - type backendInstance struct { Config config.BackendBase // Config is the configuration for the backend instance. - Implementation InterfaceController + Implementation domain.InterfaceController } type ControllerManager struct { @@ -118,11 +93,11 @@ func (c *ControllerManager) logRegisteredControllers() { } } -func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController { +func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) domain.InterfaceController { return c.getController(backend, "").Implementation } -func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController { +func (c *ControllerManager) GetController(iface domain.Interface) domain.InterfaceController { return c.getController(iface.Backend, iface.Identifier).Implementation } diff --git a/internal/app/wireguard/wireguard.go b/internal/app/wireguard/wireguard.go index b28f70e..e1e9dfa 100644 --- a/internal/app/wireguard/wireguard.go +++ b/internal/app/wireguard/wireguard.go @@ -38,9 +38,9 @@ type InterfaceAndPeerDatabaseRepo interface { } type WgQuickController interface { - ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error - SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error - UnsetDNS(id domain.InterfaceIdentifier) error + ExecuteInterfaceHook(ctx context.Context, id domain.InterfaceIdentifier, hookCmd string) error + SetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error + UnsetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error } type EventBus interface { @@ -53,11 +53,10 @@ type EventBus interface { // endregion dependencies type Manager struct { - cfg *config.Config - bus EventBus - db InterfaceAndPeerDatabaseRepo - wg *ControllerManager - quick WgQuickController + cfg *config.Config + bus EventBus + db InterfaceAndPeerDatabaseRepo + wg *ControllerManager userLockMap *sync.Map } @@ -66,7 +65,6 @@ func NewWireGuardManager( cfg *config.Config, bus EventBus, wg *ControllerManager, - quick WgQuickController, db InterfaceAndPeerDatabaseRepo, ) (*Manager, error) { m := &Manager{ @@ -74,7 +72,6 @@ func NewWireGuardManager( bus: bus, wg: wg, db: db, - quick: quick, userLockMap: &sync.Map{}, } diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index 368d1eb..3dbe53f 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -453,7 +453,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif return err } - existingInterface, err := m.db.GetInterface(ctx, id) + existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, id) if err != nil { return fmt.Errorf("unable to find interface %s: %w", id, err) } @@ -462,21 +462,29 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif return fmt.Errorf("deletion not allowed: %w", err) } + m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ + Interface: *existingInterface, + AllowedIps: existingInterface.GetAllowedIPs(existingPeers), + FwMark: existingInterface.FirewallMark, + Table: existingInterface.GetRoutingTable(), + TableStr: existingInterface.RoutingTable, + IsDeleted: true, + }) + now := time.Now() existingInterface.Disabled = &now // simulate a disabled interface existingInterface.DisabledReason = domain.DisabledReasonDeleted - physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id) - - if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil { + if err := m.handleInterfacePreSaveHooks(ctx, existingInterface, !existingInterface.IsDisabled(), + false); err != nil { return fmt.Errorf("pre-delete hooks failed: %w", err) } - if err := m.handleInterfacePreSaveActions(existingInterface); err != nil { + if err := m.handleInterfacePreSaveActions(ctx, existingInterface); err != nil { return fmt.Errorf("pre-delete actions failed: %w", err) } - if err := m.deleteInterfacePeers(ctx, id); err != nil { + if err := m.deleteInterfacePeers(ctx, existingInterface, existingPeers); err != nil { return fmt.Errorf("peer deletion failure: %w", err) } @@ -488,16 +496,12 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif return fmt.Errorf("deletion failure: %w", err) } - fwMark := existingInterface.FirewallMark - if physicalInterface != nil && fwMark == 0 { - fwMark = physicalInterface.FirewallMark - } - m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ - FwMark: fwMark, - Table: existingInterface.GetRoutingTable(), - }) - - if err := m.handleInterfacePostSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil { + if err := m.handleInterfacePostSaveHooks( + ctx, + existingInterface, + !existingInterface.IsDisabled(), + false, + ); err != nil { return fmt.Errorf("post-delete hooks failed: %w", err) } @@ -516,17 +520,21 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) ( return nil, fmt.Errorf("interface validation failed: %w", err) } - oldEnabled, newEnabled := m.getInterfaceStateHistory(ctx, iface) + oldEnabled, newEnabled, routeTableChanged := false, !iface.IsDisabled(), false // if the interface did not exist, we assume it was not enabled + oldInterface, err := m.db.GetInterface(ctx, iface.Identifier) + if err == nil { + oldEnabled, newEnabled, routeTableChanged = m.getInterfaceStateHistory(oldInterface, iface) + } - if err := m.handleInterfacePreSaveHooks(iface, oldEnabled, newEnabled); err != nil { + if err := m.handleInterfacePreSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil { return nil, fmt.Errorf("pre-save hooks failed: %w", err) } - if err := m.handleInterfacePreSaveActions(iface); err != nil { + if err := m.handleInterfacePreSaveActions(ctx, iface); err != nil { return nil, fmt.Errorf("pre-save actions failed: %w", err) } - err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) { + err = m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) { iface.CopyCalculatedAttributes(i) err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier, @@ -569,20 +577,35 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) ( } if iface.IsDisabled() { - physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier) - fwMark := iface.FirewallMark - if physicalInterface != nil && fwMark == 0 { - fwMark = physicalInterface.FirewallMark - } m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ - FwMark: fwMark, - Table: iface.GetRoutingTable(), + Interface: *iface, + AllowedIps: iface.GetAllowedIPs(peers), + FwMark: iface.FirewallMark, + Table: iface.GetRoutingTable(), + TableStr: iface.RoutingTable, }) } else { - m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier)) + m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{ + Interface: *iface, + AllowedIps: iface.GetAllowedIPs(peers), + FwMark: iface.FirewallMark, + Table: iface.GetRoutingTable(), + TableStr: iface.RoutingTable, + }) + // if the route table changed, ensure that the old entries are remove + if routeTableChanged { + m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ + Interface: *oldInterface, + AllowedIps: oldInterface.GetAllowedIPs(peers), + FwMark: oldInterface.FirewallMark, + Table: oldInterface.GetRoutingTable(), + TableStr: oldInterface.RoutingTable, + IsDeleted: true, // mark the old entries as deleted + }) + } } - if err := m.handleInterfacePostSaveHooks(iface, oldEnabled, newEnabled); err != nil { + if err := m.handleInterfacePostSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil { return nil, fmt.Errorf("post-save hooks failed: %w", err) } @@ -618,60 +641,90 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) ( return iface, nil } -func (m Manager) getInterfaceStateHistory(ctx context.Context, iface *domain.Interface) (oldEnabled, newEnabled bool) { - oldInterface, err := m.db.GetInterface(ctx, iface.Identifier) - if err != nil { - return false, !iface.IsDisabled() // if the interface did not exist, we assume it was not enabled - } - - return !oldInterface.IsDisabled(), !iface.IsDisabled() +func (m Manager) getInterfaceStateHistory( + oldInterface *domain.Interface, + iface *domain.Interface, +) (oldEnabled, newEnabled, routeTableChanged bool) { + return !oldInterface.IsDisabled(), !iface.IsDisabled(), oldInterface.RoutingTable != iface.RoutingTable } -func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error { - if !iface.IsDisabled() { - if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil { - return fmt.Errorf("failed to update dns settings: %w", err) - } - } else { - if err := m.quick.UnsetDNS(iface.Identifier); err != nil { - return fmt.Errorf("failed to clear dns settings: %w", err) +func (m Manager) handleInterfacePreSaveActions(ctx context.Context, iface *domain.Interface) error { + wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController) + if !ok { + slog.Warn("failed to perform pre-save actions", "interface", iface.Identifier, + "error", "no capable controller found") + return nil + } + + // update DNS settings only for client interfaces + if iface.Type == domain.InterfaceTypeClient || iface.Type == domain.InterfaceTypeAny { + if !iface.IsDisabled() { + if err := wgQuickController.SetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil { + return fmt.Errorf("failed to update dns settings: %w", err) + } + } else { + if err := wgQuickController.UnsetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil { + return fmt.Errorf("failed to clear dns settings: %w", err) + } } } return nil } -func (m Manager) handleInterfacePreSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error { +func (m Manager) handleInterfacePreSaveHooks( + ctx context.Context, + iface *domain.Interface, + oldEnabled, newEnabled bool, +) error { if oldEnabled == newEnabled { return nil // do nothing if state did not change } slog.Debug("executing pre-save hooks", "interface", iface.Identifier, "up", newEnabled) + wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController) + if !ok { + slog.Warn("failed to execute pre-save hooks", "interface", iface.Identifier, "up", newEnabled, + "error", "no capable controller found") + return nil + } + if newEnabled { - if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil { + if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreUp); err != nil { return fmt.Errorf("failed to execute pre-up hook: %w", err) } } else { - if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil { + if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreDown); err != nil { return fmt.Errorf("failed to execute pre-down hook: %w", err) } } return nil } -func (m Manager) handleInterfacePostSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error { +func (m Manager) handleInterfacePostSaveHooks( + ctx context.Context, + iface *domain.Interface, + oldEnabled, newEnabled bool, +) error { if oldEnabled == newEnabled { return nil // do nothing if state did not change } slog.Debug("executing post-save hooks", "interface", iface.Identifier, "up", newEnabled) + wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController) + if !ok { + slog.Warn("failed to execute post-save hooks", "interface", iface.Identifier, "up", newEnabled, + "error", "no capable controller found") + return nil + } + if newEnabled { - if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil { + if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostUp); err != nil { return fmt.Errorf("failed to execute post-up hook: %w", err) } } else { - if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil { + if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostDown); err != nil { return fmt.Errorf("failed to execute post-down hook: %w", err) } } @@ -799,7 +852,7 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) { func (m Manager) importInterface( ctx context.Context, - backend InterfaceController, + backend domain.InterfaceController, in *domain.PhysicalInterface, peers []domain.PhysicalPeer, ) error { @@ -901,13 +954,9 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain return nil } -func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error { - iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id) - if err != nil { - return err - } +func (m Manager) deleteInterfacePeers(ctx context.Context, iface *domain.Interface, allPeers []domain.Peer) error { for _, peer := range allPeers { - err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier) + err := m.wg.GetController(*iface).DeletePeer(ctx, iface.Identifier, peer.Identifier) if err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err) } diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index f3c5364..e42af28 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -388,9 +388,20 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error return fmt.Errorf("failed to delete peer %s: %w", id, err) } + peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) + if err != nil { + return fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err) + } + m.bus.Publish(app.TopicPeerDeleted, *peer) // Update routes after peers have changed - m.bus.Publish(app.TopicRouteUpdate, "peers updated") + m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{ + Interface: *iface, + AllowedIps: iface.GetAllowedIPs(peers), + FwMark: iface.FirewallMark, + Table: iface.GetRoutingTable(), + TableStr: iface.RoutingTable, + }) // Update interface after peers have changed m.bus.Publish(app.TopicPeerInterfaceUpdated, peer.InterfaceIdentifier) @@ -438,20 +449,26 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) // region helper-functions func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { - interfaces := make(map[domain.InterfaceIdentifier]struct{}) + interfaces := make(map[domain.InterfaceIdentifier]domain.Interface) for _, peer := range peers { - iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier) - if err != nil { - return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err) + // get interface from db if it is not yet in the map + if _, ok := interfaces[peer.InterfaceIdentifier]; !ok { + iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier) + if err != nil { + return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err) + } + interfaces[peer.InterfaceIdentifier] = *iface } + iface := interfaces[peer.InterfaceIdentifier] + // Always save the peer to the backend, regardless of disabled/expired state // The backend will handle the disabled state appropriately - err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { + err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { peer.CopyCalculatedAttributes(p) - err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier, + err := m.wg.GetController(iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier, func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { domain.MergeToPhysicalPeer(pp, peer) return pp, nil @@ -475,13 +492,22 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { Peer: *peer, }, }) - - interfaces[peer.InterfaceIdentifier] = struct{}{} } // Update routes after peers have changed - if len(interfaces) != 0 { - m.bus.Publish(app.TopicRouteUpdate, "peers updated") + for id, iface := range interfaces { + interfacePeers, err := m.db.GetInterfacePeers(ctx, id) + if err != nil { + return fmt.Errorf("failed to re-load peers for interface %s: %w", id, err) + } + + m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{ + Interface: iface, + AllowedIps: iface.GetAllowedIPs(interfacePeers), + FwMark: iface.FirewallMark, + Table: iface.GetRoutingTable(), + TableStr: iface.RoutingTable, + }) } for iface := range interfaces { diff --git a/internal/config/backend.go b/internal/config/backend.go index fa8ff2e..cee7c8b 100644 --- a/internal/config/backend.go +++ b/internal/config/backend.go @@ -13,6 +13,7 @@ type Backend struct { // Local Backend-specific configuration IgnoredLocalInterfaces []string `yaml:"ignored_local_interfaces"` // A list of interface names that should be ignored by this backend (e.g., "wg0") + LocalResolvconfPrefix string `yaml:"local_resolvconf_prefix"` // The prefix to use for interface names when passing them to resolvconf. // External Backend-specific configuration diff --git a/internal/config/config.go b/internal/config/config.go index 4203099..338dbf6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -134,6 +134,9 @@ func defaultConfig() *Config { cfg.Backend = Backend{ Default: LocalBackendName, // local backend is the default (using wgcrtl) + // Most resolconf implementations use "tun." as a prefix for interface names. + // But systemd's implementation uses no prefix, for example. + LocalResolvconfPrefix: "tun.", } cfg.Web = WebConfig{ diff --git a/internal/domain/interface.go b/internal/domain/interface.go index 32fc1c0..01c720c 100644 --- a/internal/domain/interface.go +++ b/internal/domain/interface.go @@ -13,6 +13,7 @@ import ( "golang.org/x/sys/unix" "github.com/h44z/wg-portal/internal" + "github.com/h44z/wg-portal/internal/config" ) const ( @@ -132,17 +133,30 @@ func (i *Interface) GetConfigFileName() string { return filename } +// GetAllowedIPs returns the allowed IPs for the interface depending on the interface type and peers. +// For example, if the interface type is Server, the allowed IPs are the IPs of the peers. +// If the interface type is Client, the allowed IPs correspond to the AllowedIPsStr of the peers. func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr { var allowedCidrs []Cidr - for _, peer := range peers { - for _, ip := range peer.Interface.Addresses { - allowedCidrs = append(allowedCidrs, ip.HostAddr()) + switch i.Type { + case InterfaceTypeServer, InterfaceTypeAny: + for _, peer := range peers { + for _, ip := range peer.Interface.Addresses { + allowedCidrs = append(allowedCidrs, ip.HostAddr()) + } + if peer.ExtraAllowedIPsStr != "" { + extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr) + if err == nil { + allowedCidrs = append(allowedCidrs, extraIPs...) + } + } } - if peer.ExtraAllowedIPsStr != "" { - extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr) + case InterfaceTypeClient: + for _, peer := range peers { + allowedIPs, err := CidrsFromString(peer.AllowedIPsStr.GetValue()) if err == nil { - allowedCidrs = append(allowedCidrs, extraIPs...) + allowedCidrs = append(allowedCidrs, allowedIPs...) } } } @@ -159,6 +173,7 @@ func (i *Interface) ManageRoutingTable() bool { // // -1 if RoutingTable was set to "off" or an error occurred func (i *Interface) GetRoutingTable() int { + routingTableStr := strings.ToLower(i.RoutingTable) switch { case routingTableStr == "": @@ -166,6 +181,9 @@ func (i *Interface) GetRoutingTable() int { case routingTableStr == "off": return -1 case strings.HasPrefix(routingTableStr, "0x"): + if i.Backend != config.LocalBackendName { + return 0 // ignore numeric routing table numbers for non-local controllers + } numberStr := strings.ReplaceAll(routingTableStr, "0x", "") routingTable, err := strconv.ParseUint(numberStr, 16, 64) if err != nil { @@ -178,6 +196,9 @@ func (i *Interface) GetRoutingTable() int { } return int(routingTable) default: + if i.Backend != config.LocalBackendName { + return 0 // ignore numeric routing table numbers for non-local controllers + } routingTable, err := strconv.Atoi(routingTableStr) if err != nil { slog.Error("failed to parse routing table number", "table", routingTableStr, "error", err) @@ -308,12 +329,18 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) { } type RoutingTableInfo struct { - FwMark uint32 - Table int + Interface Interface + AllowedIps []Cidr + FwMark uint32 + Table int + TableStr string // the routing table number as string (used by mikrotik, linux uses the numeric value) + IsDeleted bool // true if the interface was deleted, false otherwise } func (r RoutingTableInfo) String() string { - return fmt.Sprintf("%d -> %d", r.FwMark, r.Table) + v4, v6 := CidrsPerFamily(r.AllowedIps) + return fmt.Sprintf("%s: fwmark=%d; table=%d; routes_4=%d; routes_6=%d", r.Interface.Identifier, r.FwMark, r.Table, + len(v4), len(v6)) } func (r RoutingTableInfo) ManagementEnabled() bool { diff --git a/internal/domain/interface_controller.go b/internal/domain/interface_controller.go new file mode 100644 index 0000000..efd3207 --- /dev/null +++ b/internal/domain/interface_controller.go @@ -0,0 +1,27 @@ +package domain + +import "context" + +type InterfaceController interface { + GetId() InterfaceBackend + GetInterfaces(_ context.Context) ([]PhysicalInterface, error) + GetInterface(_ context.Context, id InterfaceIdentifier) (*PhysicalInterface, error) + GetPeers(_ context.Context, deviceId InterfaceIdentifier) ([]PhysicalPeer, error) + SaveInterface( + _ context.Context, + id InterfaceIdentifier, + updateFunc func(pi *PhysicalInterface) (*PhysicalInterface, error), + ) error + DeleteInterface(_ context.Context, id InterfaceIdentifier) error + SavePeer( + _ context.Context, + deviceId InterfaceIdentifier, + id PeerIdentifier, + updateFunc func(pp *PhysicalPeer) (*PhysicalPeer, error), + ) error + DeletePeer(_ context.Context, deviceId InterfaceIdentifier, id PeerIdentifier) error + PingAddresses( + ctx context.Context, + addr string, + ) (*PingerResult, error) +} diff --git a/internal/domain/interface_test.go b/internal/domain/interface_test.go index 54aa74d..9f0ee50 100644 --- a/internal/domain/interface_test.go +++ b/internal/domain/interface_test.go @@ -5,6 +5,8 @@ import ( "time" "github.com/stretchr/testify/assert" + + "github.com/h44z/wg-portal/internal/config" ) func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) { @@ -37,8 +39,9 @@ func TestInterface_GetConfigFileNameReturnsCorrectFileName(t *testing.T) { assert.Equal(t, expected, iface.GetConfigFileName()) } -func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) { +func TestInterface_GetAllowedIPsReturnsCorrectCidrsServerMode(t *testing.T) { peer1 := Peer{ + AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"}, Interface: PeerInterfaceConfig{ Addresses: []Cidr{ {Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32}, @@ -46,16 +49,45 @@ func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) { }, } peer2 := Peer{ + AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"}, + ExtraAllowedIPsStr: "10.20.2.2/32", Interface: PeerInterfaceConfig{ Addresses: []Cidr{ {Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32}, }, }, } - iface := &Interface{} + iface := &Interface{Type: InterfaceTypeServer} expected := []Cidr{ {Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32}, {Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32}, + {Cidr: "10.20.2.2/32", Addr: "10.20.2.2", NetLength: 32}, + } + assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2})) +} + +func TestInterface_GetAllowedIPsReturnsCorrectCidrsClientMode(t *testing.T) { + peer1 := Peer{ + AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"}, + Interface: PeerInterfaceConfig{ + Addresses: []Cidr{ + {Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32}, + }, + }, + } + peer2 := Peer{ + AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"}, + ExtraAllowedIPsStr: "10.20.2.2/32", + Interface: PeerInterfaceConfig{ + Addresses: []Cidr{ + {Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32}, + }, + }, + } + iface := &Interface{Type: InterfaceTypeClient} + expected := []Cidr{ + {Cidr: "192.168.2.2/32", Addr: "192.168.2.2", NetLength: 32}, + {Cidr: "10.0.2.2/32", Addr: "10.0.2.2", NetLength: 32}, } assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2})) } @@ -66,10 +98,22 @@ func TestInterface_ManageRoutingTableReturnsCorrectValue(t *testing.T) { iface.RoutingTable = "100" assert.True(t, iface.ManageRoutingTable()) + + iface = &Interface{RoutingTable: "off", Backend: config.LocalBackendName} + assert.False(t, iface.ManageRoutingTable()) + + iface.RoutingTable = "100" + assert.True(t, iface.ManageRoutingTable()) + + iface = &Interface{RoutingTable: "off", Backend: "mikrotik-xxx"} + assert.False(t, iface.ManageRoutingTable()) + + iface.RoutingTable = "100" + assert.True(t, iface.ManageRoutingTable()) } func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) { - iface := &Interface{RoutingTable: ""} + iface := &Interface{RoutingTable: "", Backend: config.LocalBackendName} assert.Equal(t, 0, iface.GetRoutingTable()) iface.RoutingTable = "off" @@ -81,3 +125,17 @@ func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) { iface.RoutingTable = "200" assert.Equal(t, 200, iface.GetRoutingTable()) } + +func TestInterface_GetRoutingTableNonLocal(t *testing.T) { + iface := &Interface{RoutingTable: "off", Backend: "something different"} + assert.Equal(t, -1, iface.GetRoutingTable()) + + iface.RoutingTable = "0" + assert.Equal(t, 0, iface.GetRoutingTable()) + + iface.RoutingTable = "100" + assert.Equal(t, 0, iface.GetRoutingTable()) + + iface.RoutingTable = "abc" + assert.Equal(t, 0, iface.GetRoutingTable()) +} diff --git a/internal/domain/ip.go b/internal/domain/ip.go index ee67413..6081b70 100644 --- a/internal/domain/ip.go +++ b/internal/domain/ip.go @@ -26,6 +26,10 @@ func (c Cidr) IsValid() bool { return c.Prefix().IsValid() } +func (c Cidr) EqualPrefix(other Cidr) bool { + return c.Addr == other.Addr && c.NetLength == other.NetLength +} + func CidrFromString(str string) (Cidr, error) { prefix, err := netip.ParsePrefix(strings.TrimSpace(str)) if err != nil { @@ -199,3 +203,26 @@ func (c Cidr) Contains(other Cidr) bool { return subnet.Contains(otherIP) } + +// ContainsDefaultRoute returns true if the given CIDRs contain a default route. +func ContainsDefaultRoute(cidrs []Cidr) bool { + for _, allowedIP := range cidrs { + if allowedIP.Prefix().Bits() == 0 { + return true + } + } + + return false +} + +// CidrsPerFamily returns a slice of CIDRs, one for each family (IPv4 and IPv6). +func CidrsPerFamily(cidrs []Cidr) (ipv4, ipv6 []Cidr) { + for _, cidr := range cidrs { + if cidr.IsV4() { + ipv4 = append(ipv4, cidr) + } else { + ipv6 = append(ipv6, cidr) + } + } + return +} diff --git a/internal/lowlevel/mikrotik.go b/internal/lowlevel/mikrotik.go index 49ef1d7..0c86a13 100644 --- a/internal/lowlevel/mikrotik.go +++ b/internal/lowlevel/mikrotik.go @@ -267,6 +267,7 @@ func parseHttpResponse[T any](resp *http.Response, err error) MikrotikApiRespons } defer func(Body io.ReadCloser) { + _, _ = io.Copy(io.Discard, Body) // ensure to empty the body err := Body.Close() if err != nil { slog.Error("failed to close response body", "error", err)