replace old route handling for local controller

This commit is contained in:
Christoph Haas
2025-10-07 22:42:13 +02:00
parent 1fc7e352ab
commit c00b34ac31
3 changed files with 328 additions and 12 deletions

View File

@@ -9,11 +9,13 @@ import (
"log/slog"
"os"
"os/exec"
"slices"
"strings"
"time"
probing "github.com/prometheus-community/pro-bing"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -638,23 +640,321 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
func (c LocalController) SetRoutes(
ctx context.Context,
_ context.Context,
interfaceId domain.InterfaceIdentifier,
table int,
fwMark uint32,
cidrs []domain.Cidr,
) error {
slog.Debug("setting linux routes", "interface", interfaceId, "table", table, "fwMark", fwMark, "cidrs", cidrs)
link, err := c.nl.LinkByName(string(interfaceId))
if err != nil {
return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err)
}
cidrsV4, cidrsV6 := domain.CidrsPerFamily(cidrs)
realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark)
if err != nil {
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
}
if realFwMark != fwMark {
slog.Debug("updating fwmark for interface", "interface", interfaceId, "oldFwMark", fwMark,
"newFwMark", realFwMark, "oldTable", table, "newTable", realTable)
if err := c.updateFwMarkOnInterface(interfaceId, int(realFwMark)); err != nil {
return fmt.Errorf("failed to update fwmark for interface %s to %d: %w", interfaceId, realFwMark, err)
}
}
if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4); err != nil {
return fmt.Errorf("failed to set v4 routes: %w", err)
}
if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6); err != nil {
return fmt.Errorf("failed to set v6 routes: %w", err)
}
return nil
}
func (c LocalController) setRoutesForFamily(
interfaceId domain.InterfaceIdentifier,
link netlink.Link,
family int,
table int,
fwMark uint32,
cidrs []domain.Cidr,
) error {
// first create or update the routes
for _, cidr := range cidrs {
err := c.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: cidr.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
})
if err != nil {
return fmt.Errorf("failed to add/update route %s on table %d for interface %s: %w",
cidr.String(), table, interfaceId, err)
}
}
// next remove old routes
rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
if err != nil {
return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
interfaceId, family, err)
}
for _, rawRoute := range rawRoutes {
if rawRoute.Dst == nil { // handle default route
var netlinkAddr domain.Cidr
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
rawRoute.Dst = netlinkAddr.IpNet()
}
route := domain.CidrFromIpNet(*rawRoute.Dst)
if slices.Contains(cidrs, route) {
continue
}
if err := c.nl.RouteDel(&rawRoute); err != nil {
return fmt.Errorf("failed to remove deprecated route %s from interface %s: %w", route, interfaceId, err)
}
}
// next, update route rules for normal routes
if table == 0 {
return nil // no need to update route rules as we are using the default table
}
existingRules, err := c.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family-id %d: %w", family, err)
}
ruleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
return rule.Mark == fwMark && rule.Table == table
})
if !ruleExists {
if err := c.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: table,
Mark: fwMark,
Invert: true,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Priority: c.getRulePriority(existingRules),
Mask: nil,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for fwmark %d and table %d for family-id %d: %w",
fwMark, table, family, err)
}
}
mainRuleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
return rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN
})
if !mainRuleExists && domain.ContainsDefaultRoute(cidrs) {
err = c.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: c.getMainRulePriority(existingRules),
Mark: 0,
Mask: nil,
Goto: -1,
Flow: -1,
})
}
// finally, clean up extra main rules - only one rule is allowed
existingRules, err = c.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing main rules for family-id %d: %w", family, err)
}
mainRuleCount := 0
for _, rule := range existingRules {
if rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN {
mainRuleCount++
}
if mainRuleCount > 1 {
if err := c.nl.RuleDel(&rule); err != nil {
return fmt.Errorf("failed to remove extra main rule for family-id %d: %w", family, err)
}
}
}
return nil
}
func (c LocalController) getOrCreateRoutingTableAndFwMark(
link netlink.Link,
tableIn int,
fwMarkIn uint32,
) (
table int,
fwmark uint32,
err error,
) {
table = tableIn
fwmark = fwMarkIn
if fwmark == 0 {
// generate a new (temporary) firewall mark based on the interface index
fwmark = uint32(c.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
}
if table == 0 {
table = int(fwmark) // generate a new routing table base on interface index
}
return
}
func (c LocalController) updateFwMarkOnInterface(interfaceId domain.InterfaceIdentifier, fwMark int) error {
// apply the new fwmark to the wireguard interface
err := c.wg.ConfigureDevice(string(interfaceId), wgtypes.Config{
FirewallMark: &fwMark,
})
if err != nil {
return fmt.Errorf("failed to update fwmark of interface %s to: %d: %w", interfaceId, fwMark, err)
}
return nil
}
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
prio := c.cfg.Advanced.RulePrioOffset
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == prio {
isFresh = false
break
}
}
if isFresh {
break
} else {
prio++
}
}
return prio
}
func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
prio := 32700 // linux main rule has a prio of 32766
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == prio {
isFresh = false
break
}
}
if isFresh {
break
} else {
prio--
}
}
return prio
}
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
func (c LocalController) RemoveRoutes(
ctx context.Context,
_ context.Context,
interfaceId domain.InterfaceIdentifier,
table int,
fwMark uint32,
oldCidrs []domain.Cidr,
) error {
slog.Debug("removing linux routes", "interface", interfaceId, "table", table, "fwMark", fwMark, "cidrs", oldCidrs)
link, err := c.nl.LinkByName(string(interfaceId))
if err != nil {
return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err)
}
cidrsV4, cidrsV6 := domain.CidrsPerFamily(oldCidrs)
realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark)
if err != nil {
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
}
err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4)
if err != nil {
return fmt.Errorf("failed to remove v4 routes: %w", err)
}
err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6)
if err != nil {
return fmt.Errorf("failed to remove v6 routes: %w", err)
}
return nil
}
func (c LocalController) removeRoutesForFamily(
interfaceId domain.InterfaceIdentifier,
link netlink.Link,
family int,
table int,
fwMark uint32,
cidrs []domain.Cidr,
) error {
// first remove all rules
existingRules, err := c.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
for _, existingRule := range existingRules {
if fwMark == existingRule.Mark && table == existingRule.Table {
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
if err := c.nl.RuleDel(&existingRule); err != nil {
return fmt.Errorf("failed to delete old fwmark rule: %w", err)
}
}
}
// next remove all routes
rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
if err != nil {
return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
interfaceId, family, err)
}
for _, rawRoute := range rawRoutes {
if rawRoute.Dst == nil { // handle default route
var netlinkAddr domain.Cidr
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
rawRoute.Dst = netlinkAddr.IpNet()
}
route := domain.CidrFromIpNet(*rawRoute.Dst)
if !slices.Contains(cidrs, route) {
continue // only remove routes that were previously added
}
if err := c.nl.RouteDel(&rawRoute); err != nil {
return fmt.Errorf("failed to remove old route %s from interface %s: %w", route, interfaceId, err)
}
}
return nil
}