From c00b34ac31df6f3f4621fbd9a8c8f670b89913db Mon Sep 17 00:00:00 2001 From: Christoph Haas Date: Tue, 7 Oct 2025 22:42:13 +0200 Subject: [PATCH] replace old route handling for local controller --- internal/adapters/wgcontroller/local.go | 304 +++++++++++++++++++++++- internal/app/route/routes.go | 12 + internal/domain/ip.go | 24 +- 3 files changed, 328 insertions(+), 12 deletions(-) diff --git a/internal/adapters/wgcontroller/local.go b/internal/adapters/wgcontroller/local.go index c977a57..38e22ca 100644 --- a/internal/adapters/wgcontroller/local.go +++ b/internal/adapters/wgcontroller/local.go @@ -9,11 +9,13 @@ import ( "log/slog" "os" "os/exec" + "slices" "strings" "time" probing "github.com/prometheus-community/pro-bing" "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -638,23 +640,321 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti // SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op. func (c LocalController) SetRoutes( - ctx context.Context, + _ context.Context, interfaceId domain.InterfaceIdentifier, table int, fwMark uint32, cidrs []domain.Cidr, ) error { + slog.Debug("setting linux routes", "interface", interfaceId, "table", table, "fwMark", fwMark, "cidrs", cidrs) + + link, err := c.nl.LinkByName(string(interfaceId)) + if err != nil { + return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err) + } + + cidrsV4, cidrsV6 := domain.CidrsPerFamily(cidrs) + realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark) + if err != nil { + return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err) + } + if realFwMark != fwMark { + slog.Debug("updating fwmark for interface", "interface", interfaceId, "oldFwMark", fwMark, + "newFwMark", realFwMark, "oldTable", table, "newTable", realTable) + if err := c.updateFwMarkOnInterface(interfaceId, int(realFwMark)); err != nil { + return fmt.Errorf("failed to update fwmark for interface %s to %d: %w", interfaceId, realFwMark, err) + } + } + + if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4); err != nil { + return fmt.Errorf("failed to set v4 routes: %w", err) + } + if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6); err != nil { + return fmt.Errorf("failed to set v6 routes: %w", err) + } + return nil } +func (c LocalController) setRoutesForFamily( + interfaceId domain.InterfaceIdentifier, + link netlink.Link, + family int, + table int, + fwMark uint32, + cidrs []domain.Cidr, +) error { + // first create or update the routes + for _, cidr := range cidrs { + err := c.nl.RouteReplace(&netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: cidr.IpNet(), + Table: table, + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }) + if err != nil { + return fmt.Errorf("failed to add/update route %s on table %d for interface %s: %w", + cidr.String(), table, interfaceId, err) + } + } + + // next remove old routes + rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{ + LinkIndex: link.Attrs().Index, + Table: unix.RT_TABLE_UNSPEC, // all tables + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF) + if err != nil { + return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w", + interfaceId, family, err) + } + for _, rawRoute := range rawRoutes { + if rawRoute.Dst == nil { // handle default route + var netlinkAddr domain.Cidr + if family == netlink.FAMILY_V4 { + netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0") + } else { + netlinkAddr, _ = domain.CidrFromString("::/0") + } + rawRoute.Dst = netlinkAddr.IpNet() + } + + route := domain.CidrFromIpNet(*rawRoute.Dst) + if slices.Contains(cidrs, route) { + continue + } + + if err := c.nl.RouteDel(&rawRoute); err != nil { + return fmt.Errorf("failed to remove deprecated route %s from interface %s: %w", route, interfaceId, err) + } + } + + // next, update route rules for normal routes + if table == 0 { + return nil // no need to update route rules as we are using the default table + } + existingRules, err := c.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing rules for family-id %d: %w", family, err) + } + ruleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool { + return rule.Mark == fwMark && rule.Table == table + }) + if !ruleExists { + if err := c.nl.RuleAdd(&netlink.Rule{ + Family: family, + Table: table, + Mark: fwMark, + Invert: true, + SuppressIfgroup: -1, + SuppressPrefixlen: -1, + Priority: c.getRulePriority(existingRules), + Mask: nil, + Goto: -1, + Flow: -1, + }); err != nil { + return fmt.Errorf("failed to setup rule for fwmark %d and table %d for family-id %d: %w", + fwMark, table, family, err) + } + } + mainRuleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool { + return rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN + }) + if !mainRuleExists && domain.ContainsDefaultRoute(cidrs) { + err = c.nl.RuleAdd(&netlink.Rule{ + Family: family, + Table: unix.RT_TABLE_MAIN, + SuppressIfgroup: -1, + SuppressPrefixlen: 0, + Priority: c.getMainRulePriority(existingRules), + Mark: 0, + Mask: nil, + Goto: -1, + Flow: -1, + }) + } + + // finally, clean up extra main rules - only one rule is allowed + existingRules, err = c.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing main rules for family-id %d: %w", family, err) + } + mainRuleCount := 0 + for _, rule := range existingRules { + if rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN { + mainRuleCount++ + } + if mainRuleCount > 1 { + if err := c.nl.RuleDel(&rule); err != nil { + return fmt.Errorf("failed to remove extra main rule for family-id %d: %w", family, err) + } + } + } + + return nil +} + +func (c LocalController) getOrCreateRoutingTableAndFwMark( + link netlink.Link, + tableIn int, + fwMarkIn uint32, +) ( + table int, + fwmark uint32, + err error, +) { + table = tableIn + fwmark = fwMarkIn + + if fwmark == 0 { + // generate a new (temporary) firewall mark based on the interface index + fwmark = uint32(c.cfg.Advanced.RouteTableOffset + link.Attrs().Index) + } + if table == 0 { + table = int(fwmark) // generate a new routing table base on interface index + } + return +} + +func (c LocalController) updateFwMarkOnInterface(interfaceId domain.InterfaceIdentifier, fwMark int) error { + // apply the new fwmark to the wireguard interface + err := c.wg.ConfigureDevice(string(interfaceId), wgtypes.Config{ + FirewallMark: &fwMark, + }) + if err != nil { + return fmt.Errorf("failed to update fwmark of interface %s to: %d: %w", interfaceId, fwMark, err) + } + + return nil +} + +func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int { + prio := c.cfg.Advanced.RulePrioOffset + for { + isFresh := true + for _, existingRule := range existingRules { + if existingRule.Priority == prio { + isFresh = false + break + } + } + if isFresh { + break + } else { + prio++ + } + } + return prio +} + +func (c LocalController) getRulePriority(existingRules []netlink.Rule) int { + prio := 32700 // linux main rule has a prio of 32766 + for { + isFresh := true + for _, existingRule := range existingRules { + if existingRule.Priority == prio { + isFresh = false + break + } + } + if isFresh { + break + } else { + prio-- + } + } + return prio +} + // RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op. func (c LocalController) RemoveRoutes( - ctx context.Context, + _ context.Context, interfaceId domain.InterfaceIdentifier, table int, fwMark uint32, oldCidrs []domain.Cidr, ) error { + slog.Debug("removing linux routes", "interface", interfaceId, "table", table, "fwMark", fwMark, "cidrs", oldCidrs) + + link, err := c.nl.LinkByName(string(interfaceId)) + if err != nil { + return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err) + } + + cidrsV4, cidrsV6 := domain.CidrsPerFamily(oldCidrs) + realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark) + if err != nil { + return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err) + } + + err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4) + if err != nil { + return fmt.Errorf("failed to remove v4 routes: %w", err) + } + err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6) + if err != nil { + return fmt.Errorf("failed to remove v6 routes: %w", err) + } + + return nil +} + +func (c LocalController) removeRoutesForFamily( + interfaceId domain.InterfaceIdentifier, + link netlink.Link, + family int, + table int, + fwMark uint32, + cidrs []domain.Cidr, +) error { + // first remove all rules + existingRules, err := c.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) + } + for _, existingRule := range existingRules { + if fwMark == existingRule.Mark && table == existingRule.Table { + existingRule.Family = family // set family, somehow the RuleList method does not populate the family field + if err := c.nl.RuleDel(&existingRule); err != nil { + return fmt.Errorf("failed to delete old fwmark rule: %w", err) + } + } + } + + // next remove all routes + rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{ + LinkIndex: link.Attrs().Index, + Table: unix.RT_TABLE_UNSPEC, // all tables + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF) + if err != nil { + return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w", + interfaceId, family, err) + } + for _, rawRoute := range rawRoutes { + if rawRoute.Dst == nil { // handle default route + var netlinkAddr domain.Cidr + if family == netlink.FAMILY_V4 { + netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0") + } else { + netlinkAddr, _ = domain.CidrFromString("::/0") + } + rawRoute.Dst = netlinkAddr.IpNet() + } + + route := domain.CidrFromIpNet(*rawRoute.Dst) + if !slices.Contains(cidrs, route) { + continue // only remove routes that were previously added + } + + if err := c.nl.RouteDel(&rawRoute); err != nil { + return fmt.Errorf("failed to remove old route %s from interface %s: %w", route, interfaceId, err) + } + } + return nil } diff --git a/internal/app/route/routes.go b/internal/app/route/routes.go index 9e52e23..6a202a9 100644 --- a/internal/app/route/routes.go +++ b/internal/app/route/routes.go @@ -138,6 +138,12 @@ func (m Manager) syncRoutes(ctx context.Context, info domain.RoutingTableInfo) e return nil } + if !info.Interface.ManageRoutingTable() { + slog.Debug("interface does not manage routing table, skipping route update", + "interface", info.Interface.Identifier) + return nil + } + err := rc.SetRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps) if err != nil { return fmt.Errorf("failed to set routes for interface %s: %w", info.Interface.Identifier, err) @@ -152,6 +158,12 @@ func (m Manager) removeRoutes(ctx context.Context, info domain.RoutingTableInfo) return nil } + if !info.Interface.ManageRoutingTable() { + slog.Debug("interface does not manage routing table, skipping route removal", + "interface", info.Interface.Identifier) + return nil + } + err := rc.RemoveRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps) if err != nil { return fmt.Errorf("failed to remove routes for interface %s: %w", info.Interface.Identifier, err) diff --git a/internal/domain/ip.go b/internal/domain/ip.go index acdaf02..a62fcb4 100644 --- a/internal/domain/ip.go +++ b/internal/domain/ip.go @@ -201,20 +201,24 @@ func (c Cidr) Contains(other Cidr) bool { } // ContainsDefaultRoute returns true if the given CIDRs contain a default route. -func ContainsDefaultRoute(cidrs []Cidr) (ipV4, ipV6 bool) { +func ContainsDefaultRoute(cidrs []Cidr) bool { for _, allowedIP := range cidrs { - if ipV4 && ipV6 { - break // speed up - } - if allowedIP.Prefix().Bits() == 0 { - if allowedIP.IsV4() { - ipV4 = true - } else { - ipV6 = true - } + return true } } + return false +} + +// CidrsPerFamily returns a slice of CIDRs, one for each family (IPv4 and IPv6). +func CidrsPerFamily(cidrs []Cidr) (ipv4, ipv6 []Cidr) { + for _, cidr := range cidrs { + if cidr.IsV4() { + ipv4 = append(ipv4, cidr) + } else { + ipv6 = append(ipv6, cidr) + } + } return }