wip routes

This commit is contained in:
Christoph Haas 2023-07-30 18:55:44 +02:00
parent f4e5072f97
commit 2113999b22
5 changed files with 160 additions and 22 deletions

View File

@ -299,38 +299,106 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
return nil return nil
} }
func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error { func (r *WgRepo) SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error {
if table == -1 { table := iface.GetRoutingTable()
logrus.Trace("ignoring route update") if table == -2 {
logrus.Trace("ignoring route update, feature disabled")
return nil return nil
} }
link, err := r.nl.LinkByName(string(interfaceId)) if !iface.IsDisabled() {
return r.setupRoutes(iface, peers)
} else {
return r.cleanupRoutes(iface, peers)
}
}
func (r *WgRepo) setupRoutes(iface *domain.Interface, peers []domain.Peer) error {
link, err := r.nl.LinkByName(string(iface.Identifier))
if err != nil { if err != nil {
return err return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
} }
if link.Attrs().OperState == netlink.OperDown { if link.Attrs().OperState == netlink.OperDown {
return nil // cannot set route for interface that is down return nil // cannot set route for interface that is down
} }
table, fwmark, err := r.getRoutingTableAndFwMark(iface, peers, link)
if err != nil {
return fmt.Errorf("failed to get table and fwmark: %w", err)
}
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash) // try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
allowedIPs := iface.GetAllowedIPs(peers)
for _, allowedIP := range allowedIPs { for _, allowedIP := range allowedIPs {
// if allowedIP.Prefix().Bits() == 0 { // default route handling - TODO if allowedIP.Prefix().Bits() == 0 { // default route handling
err := r.nl.RouteReplace(&netlink.Route{ if err := r.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index, LinkIndex: link.Attrs().Index,
Dst: allowedIP.IpNet(), Dst: allowedIP.IpNet(),
Table: table, Table: table,
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST, Type: unix.RTN_UNICAST,
}) }); err != nil {
if err != nil { return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err) }
family := netlink.FAMILY_V4
if !allowedIP.IsV4() {
family = netlink.FAMILY_V6
}
if err := r.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: table,
Mark: fwmark,
Invert: true,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Priority: -1,
Mask: -1,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for fwmark %d: %w", fwmark, err)
}
if err := r.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: -1,
Mark: -1,
Mask: -1,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for main table: %w", err)
}
} else {
err := r.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: allowedIP.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
})
if err != nil {
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
}
} }
} }
// Remove unwanted routes if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V4); err != nil {
rawRoutes, err := r.nl.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{ return fmt.Errorf("failed to remove deprecated v4 routes: %w", err)
}
if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V6); err != nil {
return fmt.Errorf("failed to remove deprecated v6 routes: %w", err)
}
return nil
}
func (r *WgRepo) removeDeprecatedRoutes(link netlink.Link, allowedIPs []domain.Cidr, family int) error {
rawRoutes, err := r.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index, LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK, Scope: unix.RT_SCOPE_LINK,
@ -340,7 +408,16 @@ func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdent
return fmt.Errorf("failed to fetch raw routes: %w", err) return fmt.Errorf("failed to fetch raw routes: %w", err)
} }
for _, rawRoute := range rawRoutes { for _, rawRoute := range rawRoutes {
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst) var netlinkAddr domain.Cidr
if rawRoute.Dst == nil {
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
} else {
netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst)
}
remove := true remove := true
for _, allowedIP := range allowedIPs { for _, allowedIP := range allowedIPs {
if netlinkAddr == allowedIP { if netlinkAddr == allowedIP {
@ -358,10 +435,66 @@ func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdent
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err) return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
} }
} }
return nil
}
func (r *WgRepo) cleanupRoutes(iface *domain.Interface, peers []domain.Peer) error {
link, err := r.nl.LinkByName(string(iface.Identifier))
if err != nil {
return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
}
table, _, err := r.getRoutingTableAndFwMark(iface, peers, link)
if err != nil {
return fmt.Errorf("failed to get table and fwmark: %w", err)
}
if table == 0 {
return nil // noting to remove
}
delRule := netlink.NewRule()
delRule.Family = netlink.FAMILY_ALL
delRule.Table = table
if err := r.nl.RuleDel(delRule); err != nil && !errors.Is(err, unix.ENOENT) {
return fmt.Errorf("failed to delete rule for table %d: %w", table, err)
}
return nil return nil
} }
func (r *WgRepo) getRoutingTableAndFwMark(iface *domain.Interface, peers []domain.Peer, link netlink.Link) (table, fwmark int, err error) {
allowedIPs := iface.GetAllowedIPs(peers)
containsDefaultRoute := false
for _, allowedIP := range allowedIPs {
if allowedIP.Prefix().Bits() == 0 {
containsDefaultRoute = true
break
}
}
table = iface.GetRoutingTable()
fwmark = int(iface.FirewallMark)
if containsDefaultRoute && table <= 0 {
table = 20000 + link.Attrs().Index // generate a new routing table base on interface index
logrus.Debugf("using routing table %d to handle default routes", table)
}
if containsDefaultRoute && fwmark == 0 {
fwmark = 20000 + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index
logrus.Debugf("using fwmark %d to handle default routes", table)
// apply the fwmark
err = r.wg.ConfigureDevice(string(iface.Identifier), wgtypes.Config{
FirewallMark: &fwmark,
})
if err != nil {
return 0, 0, fmt.Errorf("failed to update temporary fwmark to: %d: %w", fwmark, err)
}
}
return
}
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error { func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
if err := r.deleteLowLevelInterface(id); err != nil { if err := r.deleteLowLevelInterface(id); err != nil {
return err return err

View File

@ -41,7 +41,7 @@ type InterfaceController interface {
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
SaveRoutes(_ context.Context, deviceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error
} }
type WgQuickController interface { type WgQuickController interface {

View File

@ -384,7 +384,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee
return nil, fmt.Errorf("failed to save interface: %w", err) return nil, fmt.Errorf("failed to save interface: %w", err)
} }
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers)) err = m.wg.SaveRoutes(ctx, iface, peers)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to save routes: %w", err) return nil, fmt.Errorf("failed to save routes: %w", err)
} }

View File

@ -284,7 +284,7 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to load peer interface %s: %w", ifaceId, err) return fmt.Errorf("failed to load peer interface %s: %w", ifaceId, err)
} }
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(ifacePeers)) err = m.wg.SaveRoutes(ctx, iface, ifacePeers)
if err != nil { if err != nil {
return fmt.Errorf("failed to update peer routes on interface %s: %w", ifaceId, err) return fmt.Errorf("failed to update peer routes on interface %s: %w", ifaceId, err)
} }

View File

@ -114,12 +114,17 @@ func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
return allowedCidrs return allowedCidrs
} }
// GetRoutingTable returns the routing table number or -1 if an error occurred or RoutingTable was set to "off" // GetRoutingTable returns the routing table number or
//
// -1 if an error occurred
// -2 if RoutingTable was set to "off"
func (i *Interface) GetRoutingTable() int { func (i *Interface) GetRoutingTable() int {
routingTableStr := strings.ToLower(i.RoutingTable) routingTableStr := strings.ToLower(i.RoutingTable)
switch { switch {
case routingTableStr == "": case routingTableStr == "":
return 0 return 0
case routingTableStr == "off":
return -2
case strings.HasPrefix(routingTableStr, "0x"): case strings.HasPrefix(routingTableStr, "0x"):
numberStr := strings.ReplaceAll(routingTableStr, "0x", "") numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
routingTable, err := strconv.ParseUint(numberStr, 16, 64) routingTable, err := strconv.ParseUint(numberStr, 16, 64)