diff --git a/internal/adapters/wireguard.go b/internal/adapters/wireguard.go index b57a5de..07e5240 100644 --- a/internal/adapters/wireguard.go +++ b/internal/adapters/wireguard.go @@ -299,38 +299,106 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error { return nil } -func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error { - if table == -1 { - logrus.Trace("ignoring route update") +func (r *WgRepo) SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error { + table := iface.GetRoutingTable() + if table == -2 { + logrus.Trace("ignoring route update, feature disabled") return nil } - link, err := r.nl.LinkByName(string(interfaceId)) + if !iface.IsDisabled() { + return r.setupRoutes(iface, peers) + } else { + return r.cleanupRoutes(iface, peers) + } +} + +func (r *WgRepo) setupRoutes(iface *domain.Interface, peers []domain.Peer) error { + link, err := r.nl.LinkByName(string(iface.Identifier)) if err != nil { - return err + return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err) } if link.Attrs().OperState == netlink.OperDown { return nil // cannot set route for interface that is down } + table, fwmark, err := r.getRoutingTableAndFwMark(iface, peers, link) + if err != nil { + return fmt.Errorf("failed to get table and fwmark: %w", err) + } + // try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash) + allowedIPs := iface.GetAllowedIPs(peers) for _, allowedIP := range allowedIPs { - // if allowedIP.Prefix().Bits() == 0 { // default route handling - TODO - err := r.nl.RouteReplace(&netlink.Route{ - LinkIndex: link.Attrs().Index, - Dst: allowedIP.IpNet(), - Table: table, - Scope: unix.RT_SCOPE_LINK, - Type: unix.RTN_UNICAST, - }) - if err != nil { - return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err) + if allowedIP.Prefix().Bits() == 0 { // default route handling + if err := r.nl.RouteReplace(&netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: allowedIP.IpNet(), + Table: table, + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }); err != nil { + return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err) + } + + family := netlink.FAMILY_V4 + if !allowedIP.IsV4() { + family = netlink.FAMILY_V6 + } + if err := r.nl.RuleAdd(&netlink.Rule{ + Family: family, + Table: table, + Mark: fwmark, + Invert: true, + SuppressIfgroup: -1, + SuppressPrefixlen: -1, + Priority: -1, + Mask: -1, + Goto: -1, + Flow: -1, + }); err != nil { + return fmt.Errorf("failed to setup rule for fwmark %d: %w", fwmark, err) + } + if err := r.nl.RuleAdd(&netlink.Rule{ + Family: family, + Table: unix.RT_TABLE_MAIN, + SuppressIfgroup: -1, + SuppressPrefixlen: 0, + Priority: -1, + Mark: -1, + Mask: -1, + Goto: -1, + Flow: -1, + }); err != nil { + return fmt.Errorf("failed to setup rule for main table: %w", err) + } + } else { + err := r.nl.RouteReplace(&netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: allowedIP.IpNet(), + Table: table, + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }) + if err != nil { + return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err) + } } } - // Remove unwanted routes - rawRoutes, err := r.nl.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{ + if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V4); err != nil { + return fmt.Errorf("failed to remove deprecated v4 routes: %w", err) + } + if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V6); err != nil { + return fmt.Errorf("failed to remove deprecated v6 routes: %w", err) + } + + return nil +} + +func (r *WgRepo) removeDeprecatedRoutes(link netlink.Link, allowedIPs []domain.Cidr, family int) error { + rawRoutes, err := r.nl.RouteListFiltered(family, &netlink.Route{ LinkIndex: link.Attrs().Index, Table: unix.RT_TABLE_UNSPEC, // all tables Scope: unix.RT_SCOPE_LINK, @@ -340,7 +408,16 @@ func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdent return fmt.Errorf("failed to fetch raw routes: %w", err) } for _, rawRoute := range rawRoutes { - netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst) + var netlinkAddr domain.Cidr + if rawRoute.Dst == nil { + if family == netlink.FAMILY_V4 { + netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0") + } else { + netlinkAddr, _ = domain.CidrFromString("::/0") + } + } else { + netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst) + } remove := true for _, allowedIP := range allowedIPs { if netlinkAddr == allowedIP { @@ -358,10 +435,66 @@ func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdent return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err) } } + return nil +} + +func (r *WgRepo) cleanupRoutes(iface *domain.Interface, peers []domain.Peer) error { + link, err := r.nl.LinkByName(string(iface.Identifier)) + if err != nil { + return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err) + } + + table, _, err := r.getRoutingTableAndFwMark(iface, peers, link) + if err != nil { + return fmt.Errorf("failed to get table and fwmark: %w", err) + } + + if table == 0 { + return nil // noting to remove + } + + delRule := netlink.NewRule() + delRule.Family = netlink.FAMILY_ALL + delRule.Table = table + if err := r.nl.RuleDel(delRule); err != nil && !errors.Is(err, unix.ENOENT) { + return fmt.Errorf("failed to delete rule for table %d: %w", table, err) + } return nil } +func (r *WgRepo) getRoutingTableAndFwMark(iface *domain.Interface, peers []domain.Peer, link netlink.Link) (table, fwmark int, err error) { + allowedIPs := iface.GetAllowedIPs(peers) + containsDefaultRoute := false + for _, allowedIP := range allowedIPs { + if allowedIP.Prefix().Bits() == 0 { + containsDefaultRoute = true + break + } + } + + table = iface.GetRoutingTable() + fwmark = int(iface.FirewallMark) + + if containsDefaultRoute && table <= 0 { + table = 20000 + link.Attrs().Index // generate a new routing table base on interface index + logrus.Debugf("using routing table %d to handle default routes", table) + } + if containsDefaultRoute && fwmark == 0 { + fwmark = 20000 + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index + logrus.Debugf("using fwmark %d to handle default routes", table) + + // apply the fwmark + err = r.wg.ConfigureDevice(string(iface.Identifier), wgtypes.Config{ + FirewallMark: &fwmark, + }) + if err != nil { + return 0, 0, fmt.Errorf("failed to update temporary fwmark to: %d: %w", fwmark, err) + } + } + return +} + func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error { if err := r.deleteLowLevelInterface(id); err != nil { return err diff --git a/internal/app/wireguard/repos.go b/internal/app/wireguard/repos.go index 14693b9..46222d0 100644 --- a/internal/app/wireguard/repos.go +++ b/internal/app/wireguard/repos.go @@ -41,7 +41,7 @@ type InterfaceController interface { DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error - SaveRoutes(_ context.Context, deviceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error + SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error } type WgQuickController interface { diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index b3b4058..83fadb7 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -384,7 +384,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee return nil, fmt.Errorf("failed to save interface: %w", err) } - err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers)) + err = m.wg.SaveRoutes(ctx, iface, peers) if err != nil { return nil, fmt.Errorf("failed to save routes: %w", err) } diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index aca99eb..bec6db8 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -284,7 +284,7 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { if err != nil { return fmt.Errorf("failed to load peer interface %s: %w", ifaceId, err) } - err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(ifacePeers)) + err = m.wg.SaveRoutes(ctx, iface, ifacePeers) if err != nil { return fmt.Errorf("failed to update peer routes on interface %s: %w", ifaceId, err) } diff --git a/internal/domain/interface.go b/internal/domain/interface.go index 4386f41..b09c51a 100644 --- a/internal/domain/interface.go +++ b/internal/domain/interface.go @@ -114,12 +114,17 @@ func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr { return allowedCidrs } -// GetRoutingTable returns the routing table number or -1 if an error occurred or RoutingTable was set to "off" +// GetRoutingTable returns the routing table number or +// +// -1 if an error occurred +// -2 if RoutingTable was set to "off" func (i *Interface) GetRoutingTable() int { routingTableStr := strings.ToLower(i.RoutingTable) switch { case routingTableStr == "": return 0 + case routingTableStr == "off": + return -2 case strings.HasPrefix(routingTableStr, "0x"): numberStr := strings.ReplaceAll(routingTableStr, "0x", "") routingTable, err := strconv.ParseUint(numberStr, 16, 64)