wip routes

This commit is contained in:
Christoph Haas 2023-07-30 18:55:44 +02:00
parent f4e5072f97
commit 2113999b22
5 changed files with 160 additions and 22 deletions

View File

@ -299,38 +299,106 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
return nil
}
func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error {
if table == -1 {
logrus.Trace("ignoring route update")
func (r *WgRepo) SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error {
table := iface.GetRoutingTable()
if table == -2 {
logrus.Trace("ignoring route update, feature disabled")
return nil
}
link, err := r.nl.LinkByName(string(interfaceId))
if !iface.IsDisabled() {
return r.setupRoutes(iface, peers)
} else {
return r.cleanupRoutes(iface, peers)
}
}
func (r *WgRepo) setupRoutes(iface *domain.Interface, peers []domain.Peer) error {
link, err := r.nl.LinkByName(string(iface.Identifier))
if err != nil {
return err
return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
}
if link.Attrs().OperState == netlink.OperDown {
return nil // cannot set route for interface that is down
}
table, fwmark, err := r.getRoutingTableAndFwMark(iface, peers, link)
if err != nil {
return fmt.Errorf("failed to get table and fwmark: %w", err)
}
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
allowedIPs := iface.GetAllowedIPs(peers)
for _, allowedIP := range allowedIPs {
// if allowedIP.Prefix().Bits() == 0 { // default route handling - TODO
err := r.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: allowedIP.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
})
if err != nil {
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
if allowedIP.Prefix().Bits() == 0 { // default route handling
if err := r.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: allowedIP.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}); err != nil {
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
}
family := netlink.FAMILY_V4
if !allowedIP.IsV4() {
family = netlink.FAMILY_V6
}
if err := r.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: table,
Mark: fwmark,
Invert: true,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Priority: -1,
Mask: -1,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for fwmark %d: %w", fwmark, err)
}
if err := r.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: -1,
Mark: -1,
Mask: -1,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for main table: %w", err)
}
} else {
err := r.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: allowedIP.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
})
if err != nil {
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
}
}
}
// Remove unwanted routes
rawRoutes, err := r.nl.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{
if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V4); err != nil {
return fmt.Errorf("failed to remove deprecated v4 routes: %w", err)
}
if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V6); err != nil {
return fmt.Errorf("failed to remove deprecated v6 routes: %w", err)
}
return nil
}
func (r *WgRepo) removeDeprecatedRoutes(link netlink.Link, allowedIPs []domain.Cidr, family int) error {
rawRoutes, err := r.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK,
@ -340,7 +408,16 @@ func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdent
return fmt.Errorf("failed to fetch raw routes: %w", err)
}
for _, rawRoute := range rawRoutes {
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
var netlinkAddr domain.Cidr
if rawRoute.Dst == nil {
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
} else {
netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst)
}
remove := true
for _, allowedIP := range allowedIPs {
if netlinkAddr == allowedIP {
@ -358,10 +435,66 @@ func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdent
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
}
}
return nil
}
func (r *WgRepo) cleanupRoutes(iface *domain.Interface, peers []domain.Peer) error {
link, err := r.nl.LinkByName(string(iface.Identifier))
if err != nil {
return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
}
table, _, err := r.getRoutingTableAndFwMark(iface, peers, link)
if err != nil {
return fmt.Errorf("failed to get table and fwmark: %w", err)
}
if table == 0 {
return nil // noting to remove
}
delRule := netlink.NewRule()
delRule.Family = netlink.FAMILY_ALL
delRule.Table = table
if err := r.nl.RuleDel(delRule); err != nil && !errors.Is(err, unix.ENOENT) {
return fmt.Errorf("failed to delete rule for table %d: %w", table, err)
}
return nil
}
func (r *WgRepo) getRoutingTableAndFwMark(iface *domain.Interface, peers []domain.Peer, link netlink.Link) (table, fwmark int, err error) {
allowedIPs := iface.GetAllowedIPs(peers)
containsDefaultRoute := false
for _, allowedIP := range allowedIPs {
if allowedIP.Prefix().Bits() == 0 {
containsDefaultRoute = true
break
}
}
table = iface.GetRoutingTable()
fwmark = int(iface.FirewallMark)
if containsDefaultRoute && table <= 0 {
table = 20000 + link.Attrs().Index // generate a new routing table base on interface index
logrus.Debugf("using routing table %d to handle default routes", table)
}
if containsDefaultRoute && fwmark == 0 {
fwmark = 20000 + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index
logrus.Debugf("using fwmark %d to handle default routes", table)
// apply the fwmark
err = r.wg.ConfigureDevice(string(iface.Identifier), wgtypes.Config{
FirewallMark: &fwmark,
})
if err != nil {
return 0, 0, fmt.Errorf("failed to update temporary fwmark to: %d: %w", fwmark, err)
}
}
return
}
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
if err := r.deleteLowLevelInterface(id); err != nil {
return err

View File

@ -41,7 +41,7 @@ type InterfaceController interface {
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
SaveRoutes(_ context.Context, deviceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error
SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error
}
type WgQuickController interface {

View File

@ -384,7 +384,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee
return nil, fmt.Errorf("failed to save interface: %w", err)
}
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers))
err = m.wg.SaveRoutes(ctx, iface, peers)
if err != nil {
return nil, fmt.Errorf("failed to save routes: %w", err)
}

View File

@ -284,7 +284,7 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
if err != nil {
return fmt.Errorf("failed to load peer interface %s: %w", ifaceId, err)
}
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(ifacePeers))
err = m.wg.SaveRoutes(ctx, iface, ifacePeers)
if err != nil {
return fmt.Errorf("failed to update peer routes on interface %s: %w", ifaceId, err)
}

View File

@ -114,12 +114,17 @@ func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
return allowedCidrs
}
// GetRoutingTable returns the routing table number or -1 if an error occurred or RoutingTable was set to "off"
// GetRoutingTable returns the routing table number or
//
// -1 if an error occurred
// -2 if RoutingTable was set to "off"
func (i *Interface) GetRoutingTable() int {
routingTableStr := strings.ToLower(i.RoutingTable)
switch {
case routingTableStr == "":
return 0
case routingTableStr == "off":
return -2
case strings.HasPrefix(routingTableStr, "0x"):
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
routingTable, err := strconv.ParseUint(numberStr, 16, 64)