mirror of
https://github.com/h44z/wg-portal.git
synced 2025-10-08 01:16:16 +00:00
replace old route handling for local controller
This commit is contained in:
@@ -9,11 +9,13 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
probing "github.com/prometheus-community/pro-bing"
|
probing "github.com/prometheus-community/pro-bing"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
@@ -638,23 +640,321 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
|
|||||||
|
|
||||||
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
|
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
|
||||||
func (c LocalController) SetRoutes(
|
func (c LocalController) SetRoutes(
|
||||||
ctx context.Context,
|
_ context.Context,
|
||||||
interfaceId domain.InterfaceIdentifier,
|
interfaceId domain.InterfaceIdentifier,
|
||||||
table int,
|
table int,
|
||||||
fwMark uint32,
|
fwMark uint32,
|
||||||
cidrs []domain.Cidr,
|
cidrs []domain.Cidr,
|
||||||
) error {
|
) error {
|
||||||
|
slog.Debug("setting linux routes", "interface", interfaceId, "table", table, "fwMark", fwMark, "cidrs", cidrs)
|
||||||
|
|
||||||
|
link, err := c.nl.LinkByName(string(interfaceId))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cidrsV4, cidrsV6 := domain.CidrsPerFamily(cidrs)
|
||||||
|
realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
|
||||||
|
}
|
||||||
|
if realFwMark != fwMark {
|
||||||
|
slog.Debug("updating fwmark for interface", "interface", interfaceId, "oldFwMark", fwMark,
|
||||||
|
"newFwMark", realFwMark, "oldTable", table, "newTable", realTable)
|
||||||
|
if err := c.updateFwMarkOnInterface(interfaceId, int(realFwMark)); err != nil {
|
||||||
|
return fmt.Errorf("failed to update fwmark for interface %s to %d: %w", interfaceId, realFwMark, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4); err != nil {
|
||||||
|
return fmt.Errorf("failed to set v4 routes: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6); err != nil {
|
||||||
|
return fmt.Errorf("failed to set v6 routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c LocalController) setRoutesForFamily(
|
||||||
|
interfaceId domain.InterfaceIdentifier,
|
||||||
|
link netlink.Link,
|
||||||
|
family int,
|
||||||
|
table int,
|
||||||
|
fwMark uint32,
|
||||||
|
cidrs []domain.Cidr,
|
||||||
|
) error {
|
||||||
|
// first create or update the routes
|
||||||
|
for _, cidr := range cidrs {
|
||||||
|
err := c.nl.RouteReplace(&netlink.Route{
|
||||||
|
LinkIndex: link.Attrs().Index,
|
||||||
|
Dst: cidr.IpNet(),
|
||||||
|
Table: table,
|
||||||
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
|
Type: unix.RTN_UNICAST,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add/update route %s on table %d for interface %s: %w",
|
||||||
|
cidr.String(), table, interfaceId, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// next remove old routes
|
||||||
|
rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
|
||||||
|
LinkIndex: link.Attrs().Index,
|
||||||
|
Table: unix.RT_TABLE_UNSPEC, // all tables
|
||||||
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
|
Type: unix.RTN_UNICAST,
|
||||||
|
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
|
||||||
|
interfaceId, family, err)
|
||||||
|
}
|
||||||
|
for _, rawRoute := range rawRoutes {
|
||||||
|
if rawRoute.Dst == nil { // handle default route
|
||||||
|
var netlinkAddr domain.Cidr
|
||||||
|
if family == netlink.FAMILY_V4 {
|
||||||
|
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
|
||||||
|
} else {
|
||||||
|
netlinkAddr, _ = domain.CidrFromString("::/0")
|
||||||
|
}
|
||||||
|
rawRoute.Dst = netlinkAddr.IpNet()
|
||||||
|
}
|
||||||
|
|
||||||
|
route := domain.CidrFromIpNet(*rawRoute.Dst)
|
||||||
|
if slices.Contains(cidrs, route) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.nl.RouteDel(&rawRoute); err != nil {
|
||||||
|
return fmt.Errorf("failed to remove deprecated route %s from interface %s: %w", route, interfaceId, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// next, update route rules for normal routes
|
||||||
|
if table == 0 {
|
||||||
|
return nil // no need to update route rules as we are using the default table
|
||||||
|
}
|
||||||
|
existingRules, err := c.nl.RuleList(family)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get existing rules for family-id %d: %w", family, err)
|
||||||
|
}
|
||||||
|
ruleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
|
||||||
|
return rule.Mark == fwMark && rule.Table == table
|
||||||
|
})
|
||||||
|
if !ruleExists {
|
||||||
|
if err := c.nl.RuleAdd(&netlink.Rule{
|
||||||
|
Family: family,
|
||||||
|
Table: table,
|
||||||
|
Mark: fwMark,
|
||||||
|
Invert: true,
|
||||||
|
SuppressIfgroup: -1,
|
||||||
|
SuppressPrefixlen: -1,
|
||||||
|
Priority: c.getRulePriority(existingRules),
|
||||||
|
Mask: nil,
|
||||||
|
Goto: -1,
|
||||||
|
Flow: -1,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to setup rule for fwmark %d and table %d for family-id %d: %w",
|
||||||
|
fwMark, table, family, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mainRuleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
|
||||||
|
return rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN
|
||||||
|
})
|
||||||
|
if !mainRuleExists && domain.ContainsDefaultRoute(cidrs) {
|
||||||
|
err = c.nl.RuleAdd(&netlink.Rule{
|
||||||
|
Family: family,
|
||||||
|
Table: unix.RT_TABLE_MAIN,
|
||||||
|
SuppressIfgroup: -1,
|
||||||
|
SuppressPrefixlen: 0,
|
||||||
|
Priority: c.getMainRulePriority(existingRules),
|
||||||
|
Mark: 0,
|
||||||
|
Mask: nil,
|
||||||
|
Goto: -1,
|
||||||
|
Flow: -1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// finally, clean up extra main rules - only one rule is allowed
|
||||||
|
existingRules, err = c.nl.RuleList(family)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get existing main rules for family-id %d: %w", family, err)
|
||||||
|
}
|
||||||
|
mainRuleCount := 0
|
||||||
|
for _, rule := range existingRules {
|
||||||
|
if rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN {
|
||||||
|
mainRuleCount++
|
||||||
|
}
|
||||||
|
if mainRuleCount > 1 {
|
||||||
|
if err := c.nl.RuleDel(&rule); err != nil {
|
||||||
|
return fmt.Errorf("failed to remove extra main rule for family-id %d: %w", family, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) getOrCreateRoutingTableAndFwMark(
|
||||||
|
link netlink.Link,
|
||||||
|
tableIn int,
|
||||||
|
fwMarkIn uint32,
|
||||||
|
) (
|
||||||
|
table int,
|
||||||
|
fwmark uint32,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
table = tableIn
|
||||||
|
fwmark = fwMarkIn
|
||||||
|
|
||||||
|
if fwmark == 0 {
|
||||||
|
// generate a new (temporary) firewall mark based on the interface index
|
||||||
|
fwmark = uint32(c.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
|
||||||
|
}
|
||||||
|
if table == 0 {
|
||||||
|
table = int(fwmark) // generate a new routing table base on interface index
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) updateFwMarkOnInterface(interfaceId domain.InterfaceIdentifier, fwMark int) error {
|
||||||
|
// apply the new fwmark to the wireguard interface
|
||||||
|
err := c.wg.ConfigureDevice(string(interfaceId), wgtypes.Config{
|
||||||
|
FirewallMark: &fwMark,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update fwmark of interface %s to: %d: %w", interfaceId, fwMark, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
|
||||||
|
prio := c.cfg.Advanced.RulePrioOffset
|
||||||
|
for {
|
||||||
|
isFresh := true
|
||||||
|
for _, existingRule := range existingRules {
|
||||||
|
if existingRule.Priority == prio {
|
||||||
|
isFresh = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isFresh {
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
prio++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return prio
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
|
||||||
|
prio := 32700 // linux main rule has a prio of 32766
|
||||||
|
for {
|
||||||
|
isFresh := true
|
||||||
|
for _, existingRule := range existingRules {
|
||||||
|
if existingRule.Priority == prio {
|
||||||
|
isFresh = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isFresh {
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
prio--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return prio
|
||||||
|
}
|
||||||
|
|
||||||
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
|
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
|
||||||
func (c LocalController) RemoveRoutes(
|
func (c LocalController) RemoveRoutes(
|
||||||
ctx context.Context,
|
_ context.Context,
|
||||||
interfaceId domain.InterfaceIdentifier,
|
interfaceId domain.InterfaceIdentifier,
|
||||||
table int,
|
table int,
|
||||||
fwMark uint32,
|
fwMark uint32,
|
||||||
oldCidrs []domain.Cidr,
|
oldCidrs []domain.Cidr,
|
||||||
) error {
|
) error {
|
||||||
|
slog.Debug("removing linux routes", "interface", interfaceId, "table", table, "fwMark", fwMark, "cidrs", oldCidrs)
|
||||||
|
|
||||||
|
link, err := c.nl.LinkByName(string(interfaceId))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cidrsV4, cidrsV6 := domain.CidrsPerFamily(oldCidrs)
|
||||||
|
realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove v4 routes: %w", err)
|
||||||
|
}
|
||||||
|
err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove v6 routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) removeRoutesForFamily(
|
||||||
|
interfaceId domain.InterfaceIdentifier,
|
||||||
|
link netlink.Link,
|
||||||
|
family int,
|
||||||
|
table int,
|
||||||
|
fwMark uint32,
|
||||||
|
cidrs []domain.Cidr,
|
||||||
|
) error {
|
||||||
|
// first remove all rules
|
||||||
|
existingRules, err := c.nl.RuleList(family)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
|
||||||
|
}
|
||||||
|
for _, existingRule := range existingRules {
|
||||||
|
if fwMark == existingRule.Mark && table == existingRule.Table {
|
||||||
|
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
|
||||||
|
if err := c.nl.RuleDel(&existingRule); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete old fwmark rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// next remove all routes
|
||||||
|
rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
|
||||||
|
LinkIndex: link.Attrs().Index,
|
||||||
|
Table: unix.RT_TABLE_UNSPEC, // all tables
|
||||||
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
|
Type: unix.RTN_UNICAST,
|
||||||
|
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
|
||||||
|
interfaceId, family, err)
|
||||||
|
}
|
||||||
|
for _, rawRoute := range rawRoutes {
|
||||||
|
if rawRoute.Dst == nil { // handle default route
|
||||||
|
var netlinkAddr domain.Cidr
|
||||||
|
if family == netlink.FAMILY_V4 {
|
||||||
|
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
|
||||||
|
} else {
|
||||||
|
netlinkAddr, _ = domain.CidrFromString("::/0")
|
||||||
|
}
|
||||||
|
rawRoute.Dst = netlinkAddr.IpNet()
|
||||||
|
}
|
||||||
|
|
||||||
|
route := domain.CidrFromIpNet(*rawRoute.Dst)
|
||||||
|
if !slices.Contains(cidrs, route) {
|
||||||
|
continue // only remove routes that were previously added
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.nl.RouteDel(&rawRoute); err != nil {
|
||||||
|
return fmt.Errorf("failed to remove old route %s from interface %s: %w", route, interfaceId, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -138,6 +138,12 @@ func (m Manager) syncRoutes(ctx context.Context, info domain.RoutingTableInfo) e
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !info.Interface.ManageRoutingTable() {
|
||||||
|
slog.Debug("interface does not manage routing table, skipping route update",
|
||||||
|
"interface", info.Interface.Identifier)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
err := rc.SetRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps)
|
err := rc.SetRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to set routes for interface %s: %w", info.Interface.Identifier, err)
|
return fmt.Errorf("failed to set routes for interface %s: %w", info.Interface.Identifier, err)
|
||||||
@@ -152,6 +158,12 @@ func (m Manager) removeRoutes(ctx context.Context, info domain.RoutingTableInfo)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !info.Interface.ManageRoutingTable() {
|
||||||
|
slog.Debug("interface does not manage routing table, skipping route removal",
|
||||||
|
"interface", info.Interface.Identifier)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
err := rc.RemoveRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps)
|
err := rc.RemoveRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to remove routes for interface %s: %w", info.Interface.Identifier, err)
|
return fmt.Errorf("failed to remove routes for interface %s: %w", info.Interface.Identifier, err)
|
||||||
|
@@ -201,20 +201,24 @@ func (c Cidr) Contains(other Cidr) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ContainsDefaultRoute returns true if the given CIDRs contain a default route.
|
// ContainsDefaultRoute returns true if the given CIDRs contain a default route.
|
||||||
func ContainsDefaultRoute(cidrs []Cidr) (ipV4, ipV6 bool) {
|
func ContainsDefaultRoute(cidrs []Cidr) bool {
|
||||||
for _, allowedIP := range cidrs {
|
for _, allowedIP := range cidrs {
|
||||||
if ipV4 && ipV6 {
|
|
||||||
break // speed up
|
|
||||||
}
|
|
||||||
|
|
||||||
if allowedIP.Prefix().Bits() == 0 {
|
if allowedIP.Prefix().Bits() == 0 {
|
||||||
if allowedIP.IsV4() {
|
return true
|
||||||
ipV4 = true
|
|
||||||
} else {
|
|
||||||
ipV6 = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// CidrsPerFamily returns a slice of CIDRs, one for each family (IPv4 and IPv6).
|
||||||
|
func CidrsPerFamily(cidrs []Cidr) (ipv4, ipv6 []Cidr) {
|
||||||
|
for _, cidr := range cidrs {
|
||||||
|
if cidr.IsV4() {
|
||||||
|
ipv4 = append(ipv4, cidr)
|
||||||
|
} else {
|
||||||
|
ipv6 = append(ipv6, cidr)
|
||||||
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user