diff --git a/internal/adapters/database.go b/internal/adapters/database.go index b4676e2..ff74207 100644 --- a/internal/adapters/database.go +++ b/internal/adapters/database.go @@ -331,17 +331,17 @@ func (r *SqlRepo) upsertInterface(ui *domain.ContextUserInfo, tx *gorm.DB, in *d func (r *SqlRepo) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - err := r.db.WithContext(ctx).Where("interface_identifier = ?", id).Delete(&domain.Peer{}).Error + err := tx.Where("interface_identifier = ?", id).Delete(&domain.Peer{}).Error if err != nil { return err } - err = r.db.WithContext(ctx).Delete(&domain.InterfaceStatus{InterfaceId: id}).Error + err = tx.Delete(&domain.InterfaceStatus{InterfaceId: id}).Error if err != nil { return err } - err = r.db.WithContext(ctx).Debug().Select(clause.Associations).Delete(&domain.Interface{Identifier: id}).Error + err = tx.Select(clause.Associations).Delete(&domain.Interface{Identifier: id}).Error if err != nil { return err } @@ -518,12 +518,12 @@ func (r *SqlRepo) upsertPeer(ui *domain.ContextUserInfo, tx *gorm.DB, peer *doma func (r *SqlRepo) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error { err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - err := r.db.WithContext(ctx).Delete(&domain.PeerStatus{PeerId: id}).Error + err := tx.Delete(&domain.PeerStatus{PeerId: id}).Error if err != nil { return err } - err = r.db.WithContext(ctx).Select(clause.Associations).Delete(&domain.Peer{Identifier: id}).Error + err = tx.Select(clause.Associations).Delete(&domain.Peer{Identifier: id}).Error if err != nil { return err } diff --git a/internal/adapters/wireguard.go b/internal/adapters/wireguard.go index 76a8c1d..434c246 100644 --- a/internal/adapters/wireguard.go +++ b/internal/adapters/wireguard.go @@ -308,6 +308,10 @@ func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifie func (r *WgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error { link, err := r.nl.LinkByName(string(id)) if err != nil { + var linkNotFoundError netlink.LinkNotFoundError + if errors.As(err, &linkNotFoundError) { + return nil // ignore not found error + } return fmt.Errorf("unable to find low level interface: %w", err) } diff --git a/internal/adapters/wireguard_integration_test.go b/internal/adapters/wireguard_integration_test.go index d9a20b9..07b35f2 100644 --- a/internal/adapters/wireguard_integration_test.go +++ b/internal/adapters/wireguard_integration_test.go @@ -119,42 +119,3 @@ func TestWireGuardUpdateInterface(t *testing.T) { assert.Contains(t, string(out), ipAddress) assert.Contains(t, string(out), ipV6Address) } - -func TestWireGuardCreateInterfaceWithRoutes(t *testing.T) { - mgr := setup(t) - - interfaceName := domain.InterfaceIdentifier("wg_test_001") - ipAddress := "10.11.12.13" - ipV6Address := "1337:d34d:b33f::2" - defer mgr.DeleteInterface(context.Background(), interfaceName) - - iface := &domain.Interface{ - Identifier: interfaceName, - //RoutingTable: "1234", - } - peers := []domain.Peer{ - { - Interface: domain.PeerInterfaceConfig{ - Addresses: domain.CidrsMust(domain.CidrsFromString("10.11.12.14/32,10.22.33.44/32")), - }, - }, - } - - err := mgr.SaveInterface2(context.Background(), iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { - pi.Addresses = []domain.Cidr{ - domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}), - domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}), - } - pi.DeviceUp = true - return pi, nil - }) - assert.NoError(t, err) - - // Validate that the interface has been created - cmd := exec.Command("ip", "addr") - out, err := cmd.CombinedOutput() - assert.NoError(t, err) - assert.Contains(t, string(out), interfaceName) - assert.Contains(t, string(out), ipAddress) - assert.Contains(t, string(out), ipV6Address) -} diff --git a/internal/app/route/routes.go b/internal/app/route/routes.go index 66942a0..ee5b0ae 100644 --- a/internal/app/route/routes.go +++ b/internal/app/route/routes.go @@ -15,13 +15,16 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type defaultRouteRule struct { - ifaceId domain.InterfaceIdentifier - fwMark int - table int - family int +type routeRuleInfo struct { + ifaceId domain.InterfaceIdentifier + fwMark int + table int + family int + hasDefault bool } +// Manager is try to mimic wg-quick behaviour (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash) +// for default routes. type Manager struct { cfg *config.Config bus evbus.MessageBus @@ -55,7 +58,7 @@ func NewRouteManager(cfg *config.Config, bus evbus.MessageBus, db InterfaceAndPe func (m Manager) connectToMessageBus() { _ = m.bus.Subscribe(app.TopicRouteUpdate, m.handleRouteUpdateEvent) - _ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteUpdateEvent) + _ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteRemoveEvent) } func (m Manager) StartBackgroundJobs(ctx context.Context) { @@ -75,14 +78,14 @@ func (m Manager) handleRouteUpdateEvent(srcDescription string) { func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) { logrus.Debugf("handling route remove event for: %s", info.String()) - if info.Table == -2 { + if !info.ManagementEnabled() { return // route management disabled } - if err := m.removeFwMarkRules(info.FwMark, info.FwMark, netlink.FAMILY_V4); err != nil { + if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V4); err != nil { logrus.Errorf("failed to remove v4 fwmark rules: %v", err) } - if err := m.removeFwMarkRules(info.FwMark, info.FwMark, netlink.FAMILY_V6); err != nil { + if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V6); err != nil { logrus.Errorf("failed to remove v6 fwmark rules: %v", err) } @@ -95,7 +98,7 @@ func (m Manager) syncRoutes(ctx context.Context) error { return fmt.Errorf("failed to find all interfaces: %w", err) } - rules := map[int][]defaultRouteRule{ + rules := map[int][]routeRuleInfo{ netlink.FAMILY_V4: nil, netlink.FAMILY_V6: nil, } @@ -103,6 +106,9 @@ func (m Manager) syncRoutes(ctx context.Context) error { if iface.IsDisabled() { continue // disabled interface does not need route entries } + if !iface.ManageRoutingTable() { + continue + } peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) if err != nil { @@ -132,20 +138,22 @@ func (m Manager) syncRoutes(ctx context.Context) error { return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err) } - if defRouteV4 { - rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], defaultRouteRule{ - ifaceId: iface.Identifier, - fwMark: fwmark, - table: table, - family: netlink.FAMILY_V4, + if table != 0 { + rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], routeRuleInfo{ + ifaceId: iface.Identifier, + fwMark: fwmark, + table: table, + family: netlink.FAMILY_V4, + hasDefault: defRouteV4, }) } - if defRouteV6 { - rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], defaultRouteRule{ - ifaceId: iface.Identifier, - fwMark: fwmark, - table: table, - family: netlink.FAMILY_V6, + if table != 0 { + rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], routeRuleInfo{ + ifaceId: iface.Identifier, + fwMark: fwmark, + table: table, + family: netlink.FAMILY_V6, + hasDefault: defRouteV6, }) } } @@ -153,7 +161,7 @@ func (m Manager) syncRoutes(ctx context.Context) error { return m.syncRouteRules(rules) } -func (m Manager) syncRouteRules(allRules map[int][]defaultRouteRule) error { +func (m Manager) syncRouteRules(allRules map[int][]routeRuleInfo) error { for family, rules := range allRules { // update fwmark rules if err := m.setFwMarkRules(rules, family); err != nil { @@ -174,7 +182,7 @@ func (m Manager) syncRouteRules(allRules map[int][]defaultRouteRule) error { return nil } -func (m Manager) setFwMarkRules(rules []defaultRouteRule, family int) error { +func (m Manager) setFwMarkRules(rules []routeRuleInfo, family int) error { for _, rule := range rules { existingRules, err := m.nl.RuleList(family) if err != nil { @@ -229,8 +237,14 @@ func (m Manager) removeFwMarkRules(fwmark, table int, family int) error { return nil } -func (m Manager) setMainRule(rules []defaultRouteRule, family int) error { - shouldHaveMainRule := len(rules) != 0 +func (m Manager) setMainRule(rules []routeRuleInfo, family int) error { + shouldHaveMainRule := false + for _, rule := range rules { + if rule.hasDefault == true { + shouldHaveMainRule = true + break + } + } if !shouldHaveMainRule { return nil } @@ -269,13 +283,19 @@ func (m Manager) setMainRule(rules []defaultRouteRule, family int) error { return nil } -func (m Manager) cleanupMainRule(rules []defaultRouteRule, family int) error { +func (m Manager) cleanupMainRule(rules []routeRuleInfo, family int) error { existingRules, err := m.nl.RuleList(family) if err != nil { return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) } - shouldHaveMainRule := len(rules) != 0 + shouldHaveMainRule := false + for _, rule := range rules { + if rule.hasDefault == true { + shouldHaveMainRule = true + break + } + } mainRules := 0 for _, existingRule := range existingRules { @@ -307,7 +327,7 @@ func (m Manager) cleanupMainRule(rules []defaultRouteRule, family int) error { return nil } -func (m Manager) getRulePriority(existingRules []netlink.Rule) int { +func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int { prio := m.cfg.Advanced.RulePrioOffset for { isFresh := true @@ -326,7 +346,7 @@ func (m Manager) getRulePriority(existingRules []netlink.Rule) int { return prio } -func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int { +func (m Manager) getRulePriority(existingRules []netlink.Rule) int { prio := 32700 // linux main rule has a prio of 32766 for { isFresh := true @@ -346,7 +366,6 @@ func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int { } func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error { - // try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash) for _, allowedIP := range allowedIPs { err := m.nl.RouteReplace(&netlink.Route{ LinkIndex: link.Attrs().Index, @@ -374,16 +393,17 @@ func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIP return fmt.Errorf("failed to fetch raw routes: %w", err) } for _, rawRoute := range rawRoutes { - var netlinkAddr domain.Cidr - if rawRoute.Dst == nil { + if rawRoute.Dst == nil { // handle default route + var netlinkAddr domain.Cidr if family == netlink.FAMILY_V4 { netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0") } else { netlinkAddr, _ = domain.CidrFromString("::/0") } - } else { - netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst) + rawRoute.Dst = netlinkAddr.IpNet() } + + netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst) remove := true for _, allowedIP := range allowedIPs { if netlinkAddr == allowedIP { @@ -405,22 +425,20 @@ func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIP } func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, allowedIPs []domain.Cidr, link netlink.Link) (table, fwmark int, err error) { - defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs) - table = iface.GetRoutingTable() fwmark = int(iface.FirewallMark) - if (defRouteV4 || defRouteV6) && table <= 0 { - table = m.cfg.Advanced.RouteTableOffset + link.Attrs().Index // generate a new routing table base on interface index - logrus.Debugf("using routing table %d to handle default routes", table) - } - if (defRouteV4 || defRouteV6) && fwmark == 0 { + if fwmark == 0 { fwmark = m.cfg.Advanced.RouteTableOffset + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index - logrus.Debugf("using fwmark %d to handle default routes", table) + logrus.Debugf("using fwmark %d to handle routes", table) - // apply the fwmark + // apply the temporary fwmark to the wireguard interface err = m.setFwMark(iface.Identifier, fwmark) } + if table == 0 { + table = fwmark // generate a new routing table base on interface index + logrus.Debugf("using routing table %d to handle default routes", table) + } return } diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index 5dcc0fc..e8f93de 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -8,6 +8,7 @@ import ( "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/domain" "github.com/sirupsen/logrus" + "os" "time" ) @@ -350,12 +351,14 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif return fmt.Errorf("deletion failure: %w", err) } - if physicalInterface != nil { - m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ - FwMark: int(physicalInterface.FirewallMark), - Table: existingInterface.GetRoutingTable(), - }) + fwMark := int(existingInterface.FirewallMark) + if physicalInterface != nil && fwMark == 0 { + fwMark = int(physicalInterface.FirewallMark) } + m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ + FwMark: fwMark, + Table: existingInterface.GetRoutingTable(), + }) if err := m.handleInterfacePostSaveHooks(true, existingInterface); err != nil { return fmt.Errorf("post-delete hooks failed: %w", err) @@ -395,6 +398,17 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee } m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier)) + if iface.IsDisabled() { + physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier) + fwMark := int(iface.FirewallMark) + if physicalInterface != nil && fwMark == 0 { + fwMark = int(physicalInterface.FirewallMark) + } + m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ + FwMark: fwMark, + Table: iface.GetRoutingTable(), + }) + } if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil { return nil, fmt.Errorf("post-save hooks failed: %w", err) @@ -668,7 +682,7 @@ func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceId } for _, peer := range allPeers { err = m.wg.DeletePeer(ctx, id, peer.Identifier) - if err != nil { + if err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err) } diff --git a/internal/domain/interface.go b/internal/domain/interface.go index 03eabb6..ff3ce65 100644 --- a/internal/domain/interface.go +++ b/internal/domain/interface.go @@ -3,6 +3,7 @@ package domain import ( "fmt" "github.com/h44z/wg-portal/internal" + "github.com/sirupsen/logrus" "math" "regexp" "strconv" @@ -34,7 +35,7 @@ type Interface struct { Mtu int // the device MTU FirewallMark int32 // a firewall mark - RoutingTable string // the routing table + RoutingTable string // the routing table number or "off" if the routing table should not be managed PreUp string // action that is executed before the device is up PostUp string // action that is executed after the device is up @@ -114,30 +115,37 @@ func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr { return allowedCidrs } +func (i *Interface) ManageRoutingTable() bool { + routingTableStr := strings.ToLower(i.RoutingTable) + return routingTableStr != "off" +} + // GetRoutingTable returns the routing table number or // -// -1 if an error occurred -// -2 if RoutingTable was set to "off" +// -1 if RoutingTable was set to "off" or an error occurred func (i *Interface) GetRoutingTable() int { routingTableStr := strings.ToLower(i.RoutingTable) switch { case routingTableStr == "": return 0 case routingTableStr == "off": - return -2 + return -1 case strings.HasPrefix(routingTableStr, "0x"): numberStr := strings.ReplaceAll(routingTableStr, "0x", "") routingTable, err := strconv.ParseUint(numberStr, 16, 64) if err != nil { + logrus.Errorf("invalid hex routing table %s: %w", routingTableStr, err) return -1 } if routingTable > math.MaxInt32 { + logrus.Errorf("invalid routing table %s, too big", routingTableStr) return -1 } return int(routingTable) default: routingTable, err := strconv.Atoi(routingTableStr) if err != nil { + logrus.Errorf("invalid routing table %s: %w", routingTableStr, err) return -1 } return routingTable @@ -220,3 +228,19 @@ type RoutingTableInfo struct { func (r RoutingTableInfo) String() string { return fmt.Sprintf("%d -> %d", r.FwMark, r.Table) } + +func (r RoutingTableInfo) ManagementEnabled() bool { + if r.Table == -1 { + return false + } + + return true +} + +func (r RoutingTableInfo) GetRoutingTable() int { + if r.Table <= 0 { + return r.FwMark // use the dynamic routing table which has the same number as the firewall mark + } + + return r.Table +}