From 22949963cfd886c7e96a6312900836ec15a0896e Mon Sep 17 00:00:00 2001 From: Christoph Haas Date: Sun, 30 Jul 2023 22:02:52 +0200 Subject: [PATCH] route management --- cmd/wg-portal/main.go | 5 + internal/adapters/wireguard.go | 198 -------- internal/app/eventbus.go | 2 + internal/app/route/repos.go | 11 + internal/app/route/routes.go | 453 ++++++++++++++++++ internal/app/wireguard/repos.go | 1 - .../app/wireguard/wireguard_interfaces.go | 15 +- internal/app/wireguard/wireguard_peers.go | 12 +- internal/config/config.go | 4 + internal/domain/interface.go | 9 + 10 files changed, 498 insertions(+), 212 deletions(-) create mode 100644 internal/app/route/repos.go create mode 100644 internal/app/route/routes.go diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index 1e1a61b..eb7c237 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -8,6 +8,7 @@ import ( "github.com/h44z/wg-portal/internal/app/auth" "github.com/h44z/wg-portal/internal/app/configfile" "github.com/h44z/wg-portal/internal/app/mail" + "github.com/h44z/wg-portal/internal/app/route" "github.com/h44z/wg-portal/internal/app/users" "github.com/h44z/wg-portal/internal/app/wireguard" "os" @@ -86,6 +87,10 @@ func main() { internal.AssertNoError(err) auditRecorder.StartBackgroundJobs(ctx) + routeManager, err := route.NewRouteManager(cfg, eventBus, database) + internal.AssertNoError(err) + routeManager.StartBackgroundJobs(ctx) + backend, err := app.New(cfg, eventBus, authenticator, userManager, wireGuardManager, statisticsCollector, cfgFileManager, mailManager) internal.AssertNoError(err) diff --git a/internal/adapters/wireguard.go b/internal/adapters/wireguard.go index 07e5240..76a8c1d 100644 --- a/internal/adapters/wireguard.go +++ b/internal/adapters/wireguard.go @@ -6,9 +6,7 @@ import ( "fmt" "github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/lowlevel" - "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" - "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "os" @@ -299,202 +297,6 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error { return nil } -func (r *WgRepo) SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error { - table := iface.GetRoutingTable() - if table == -2 { - logrus.Trace("ignoring route update, feature disabled") - return nil - } - - if !iface.IsDisabled() { - return r.setupRoutes(iface, peers) - } else { - return r.cleanupRoutes(iface, peers) - } -} - -func (r *WgRepo) setupRoutes(iface *domain.Interface, peers []domain.Peer) error { - link, err := r.nl.LinkByName(string(iface.Identifier)) - if err != nil { - return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err) - } - - if link.Attrs().OperState == netlink.OperDown { - return nil // cannot set route for interface that is down - } - - table, fwmark, err := r.getRoutingTableAndFwMark(iface, peers, link) - if err != nil { - return fmt.Errorf("failed to get table and fwmark: %w", err) - } - - // try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash) - allowedIPs := iface.GetAllowedIPs(peers) - for _, allowedIP := range allowedIPs { - if allowedIP.Prefix().Bits() == 0 { // default route handling - if err := r.nl.RouteReplace(&netlink.Route{ - LinkIndex: link.Attrs().Index, - Dst: allowedIP.IpNet(), - Table: table, - Scope: unix.RT_SCOPE_LINK, - Type: unix.RTN_UNICAST, - }); err != nil { - return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err) - } - - family := netlink.FAMILY_V4 - if !allowedIP.IsV4() { - family = netlink.FAMILY_V6 - } - if err := r.nl.RuleAdd(&netlink.Rule{ - Family: family, - Table: table, - Mark: fwmark, - Invert: true, - SuppressIfgroup: -1, - SuppressPrefixlen: -1, - Priority: -1, - Mask: -1, - Goto: -1, - Flow: -1, - }); err != nil { - return fmt.Errorf("failed to setup rule for fwmark %d: %w", fwmark, err) - } - if err := r.nl.RuleAdd(&netlink.Rule{ - Family: family, - Table: unix.RT_TABLE_MAIN, - SuppressIfgroup: -1, - SuppressPrefixlen: 0, - Priority: -1, - Mark: -1, - Mask: -1, - Goto: -1, - Flow: -1, - }); err != nil { - return fmt.Errorf("failed to setup rule for main table: %w", err) - } - } else { - err := r.nl.RouteReplace(&netlink.Route{ - LinkIndex: link.Attrs().Index, - Dst: allowedIP.IpNet(), - Table: table, - Scope: unix.RT_SCOPE_LINK, - Type: unix.RTN_UNICAST, - }) - if err != nil { - return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err) - } - } - } - - if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V4); err != nil { - return fmt.Errorf("failed to remove deprecated v4 routes: %w", err) - } - if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V6); err != nil { - return fmt.Errorf("failed to remove deprecated v6 routes: %w", err) - } - - return nil -} - -func (r *WgRepo) removeDeprecatedRoutes(link netlink.Link, allowedIPs []domain.Cidr, family int) error { - rawRoutes, err := r.nl.RouteListFiltered(family, &netlink.Route{ - LinkIndex: link.Attrs().Index, - Table: unix.RT_TABLE_UNSPEC, // all tables - Scope: unix.RT_SCOPE_LINK, - Type: unix.RTN_UNICAST, - }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF) - if err != nil { - return fmt.Errorf("failed to fetch raw routes: %w", err) - } - for _, rawRoute := range rawRoutes { - var netlinkAddr domain.Cidr - if rawRoute.Dst == nil { - if family == netlink.FAMILY_V4 { - netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0") - } else { - netlinkAddr, _ = domain.CidrFromString("::/0") - } - } else { - netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst) - } - remove := true - for _, allowedIP := range allowedIPs { - if netlinkAddr == allowedIP { - remove = false - break - } - } - - if !remove { - continue - } - - err := r.nl.RouteDel(&rawRoute) - if err != nil { - return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err) - } - } - return nil -} - -func (r *WgRepo) cleanupRoutes(iface *domain.Interface, peers []domain.Peer) error { - link, err := r.nl.LinkByName(string(iface.Identifier)) - if err != nil { - return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err) - } - - table, _, err := r.getRoutingTableAndFwMark(iface, peers, link) - if err != nil { - return fmt.Errorf("failed to get table and fwmark: %w", err) - } - - if table == 0 { - return nil // noting to remove - } - - delRule := netlink.NewRule() - delRule.Family = netlink.FAMILY_ALL - delRule.Table = table - if err := r.nl.RuleDel(delRule); err != nil && !errors.Is(err, unix.ENOENT) { - return fmt.Errorf("failed to delete rule for table %d: %w", table, err) - } - - return nil -} - -func (r *WgRepo) getRoutingTableAndFwMark(iface *domain.Interface, peers []domain.Peer, link netlink.Link) (table, fwmark int, err error) { - allowedIPs := iface.GetAllowedIPs(peers) - containsDefaultRoute := false - for _, allowedIP := range allowedIPs { - if allowedIP.Prefix().Bits() == 0 { - containsDefaultRoute = true - break - } - } - - table = iface.GetRoutingTable() - fwmark = int(iface.FirewallMark) - - if containsDefaultRoute && table <= 0 { - table = 20000 + link.Attrs().Index // generate a new routing table base on interface index - logrus.Debugf("using routing table %d to handle default routes", table) - } - if containsDefaultRoute && fwmark == 0 { - fwmark = 20000 + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index - logrus.Debugf("using fwmark %d to handle default routes", table) - - // apply the fwmark - err = r.wg.ConfigureDevice(string(iface.Identifier), wgtypes.Config{ - FirewallMark: &fwmark, - }) - if err != nil { - return 0, 0, fmt.Errorf("failed to update temporary fwmark to: %d: %w", fwmark, err) - } - } - return -} - func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error { if err := r.deleteLowLevelInterface(id); err != nil { return err diff --git a/internal/app/eventbus.go b/internal/app/eventbus.go index 9a59e6a..da21b9d 100644 --- a/internal/app/eventbus.go +++ b/internal/app/eventbus.go @@ -5,3 +5,5 @@ const TopicUserRegistered = "user:registered" const TopicUserDisabled = "user:disabled" const TopicUserDeleted = "user:deleted" const TopicAuthLogin = "auth:login" +const TopicRouteUpdate = "route:update" +const TopicRouteRemove = "route:remove" diff --git a/internal/app/route/repos.go b/internal/app/route/repos.go new file mode 100644 index 0000000..b6e0a5c --- /dev/null +++ b/internal/app/route/repos.go @@ -0,0 +1,11 @@ +package route + +import ( + "context" + "github.com/h44z/wg-portal/internal/domain" +) + +type InterfaceAndPeerDatabaseRepo interface { + GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) + GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) +} diff --git a/internal/app/route/routes.go b/internal/app/route/routes.go new file mode 100644 index 0000000..66942a0 --- /dev/null +++ b/internal/app/route/routes.go @@ -0,0 +1,453 @@ +package route + +import ( + "context" + "fmt" + "github.com/h44z/wg-portal/internal/app" + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" + "github.com/h44z/wg-portal/internal/lowlevel" + "github.com/sirupsen/logrus" + evbus "github.com/vardius/message-bus" + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type defaultRouteRule struct { + ifaceId domain.InterfaceIdentifier + fwMark int + table int + family int +} + +type Manager struct { + cfg *config.Config + bus evbus.MessageBus + + wg lowlevel.WireGuardClient + nl lowlevel.NetlinkClient + db InterfaceAndPeerDatabaseRepo +} + +func NewRouteManager(cfg *config.Config, bus evbus.MessageBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) { + wg, err := wgctrl.New() + if err != nil { + panic("failed to init wgctrl: " + err.Error()) + } + + nl := &lowlevel.NetlinkManager{} + + m := &Manager{ + cfg: cfg, + bus: bus, + + db: db, + wg: wg, + nl: nl, + } + + m.connectToMessageBus() + + return m, nil +} + +func (m Manager) connectToMessageBus() { + _ = m.bus.Subscribe(app.TopicRouteUpdate, m.handleRouteUpdateEvent) + _ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteUpdateEvent) +} + +func (m Manager) StartBackgroundJobs(ctx context.Context) { +} + +func (m Manager) handleRouteUpdateEvent(srcDescription string) { + logrus.Debugf("handling route update event: %s", srcDescription) + + err := m.syncRoutes(context.Background()) + if err != nil { + logrus.Errorf("failed to synchronize routes for event %s: %v", srcDescription, err) + } + + logrus.Debugf("routes synchronized, event: %s", srcDescription) +} + +func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) { + logrus.Debugf("handling route remove event for: %s", info.String()) + + if info.Table == -2 { + return // route management disabled + } + + if err := m.removeFwMarkRules(info.FwMark, info.FwMark, netlink.FAMILY_V4); err != nil { + logrus.Errorf("failed to remove v4 fwmark rules: %v", err) + } + if err := m.removeFwMarkRules(info.FwMark, info.FwMark, netlink.FAMILY_V6); err != nil { + logrus.Errorf("failed to remove v6 fwmark rules: %v", err) + } + + logrus.Debugf("routes removed, table: %s", info.String()) +} + +func (m Manager) syncRoutes(ctx context.Context) error { + interfaces, err := m.db.GetAllInterfaces(ctx) + if err != nil { + return fmt.Errorf("failed to find all interfaces: %w", err) + } + + rules := map[int][]defaultRouteRule{ + netlink.FAMILY_V4: nil, + netlink.FAMILY_V6: nil, + } + for _, iface := range interfaces { + if iface.IsDisabled() { + continue // disabled interface does not need route entries + } + + peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) + if err != nil { + return fmt.Errorf("failed to find peers for %s: %w", iface.Identifier, err) + } + allowedIPs := iface.GetAllowedIPs(peers) + defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs) + + link, err := m.nl.LinkByName(string(iface.Identifier)) + if err != nil { + return fmt.Errorf("failed to find physical link for %s: %w", iface.Identifier, err) + } + + table, fwmark, err := m.getRoutingTableAndFwMark(&iface, allowedIPs, link) + if err != nil { + return fmt.Errorf("failed to get table and fwmark for %s: %w", iface.Identifier, err) + } + + if err := m.setInterfaceRoutes(link, table, allowedIPs); err != nil { + return fmt.Errorf("failed to set routes for %s: %w", iface.Identifier, err) + } + + if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V4, allowedIPs); err != nil { + return fmt.Errorf("failed to remove deprecated v4 routes for %s: %w", iface.Identifier, err) + } + if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V6, allowedIPs); err != nil { + return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err) + } + + if defRouteV4 { + rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], defaultRouteRule{ + ifaceId: iface.Identifier, + fwMark: fwmark, + table: table, + family: netlink.FAMILY_V4, + }) + } + if defRouteV6 { + rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], defaultRouteRule{ + ifaceId: iface.Identifier, + fwMark: fwmark, + table: table, + family: netlink.FAMILY_V6, + }) + } + } + + return m.syncRouteRules(rules) +} + +func (m Manager) syncRouteRules(allRules map[int][]defaultRouteRule) error { + for family, rules := range allRules { + // update fwmark rules + if err := m.setFwMarkRules(rules, family); err != nil { + return err + } + + // update main rule + if err := m.setMainRule(rules, family); err != nil { + return err + } + + // cleanup old main rules + if err := m.cleanupMainRule(rules, family); err != nil { + return err + } + } + + return nil +} + +func (m Manager) setFwMarkRules(rules []defaultRouteRule, family int) error { + for _, rule := range rules { + existingRules, err := m.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) + } + + ruleExists := false + for _, existingRule := range existingRules { + if rule.fwMark == existingRule.Mark && rule.table == existingRule.Table { + ruleExists = true + break + } + } + + if ruleExists { + continue // rule already exists, no need to recreate it + } + + // create missing rule + if err := m.nl.RuleAdd(&netlink.Rule{ + Family: family, + Table: rule.table, + Mark: rule.fwMark, + Invert: true, + SuppressIfgroup: -1, + SuppressPrefixlen: -1, + Priority: m.getRulePriority(existingRules), + Mask: -1, + Goto: -1, + Flow: -1, + }); err != nil { + return fmt.Errorf("failed to setup rule for fwmark %d and table %d: %w", rule.fwMark, rule.table, err) + } + } + return nil +} + +func (m Manager) removeFwMarkRules(fwmark, table int, family int) error { + existingRules, err := m.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) + } + + for _, existingRule := range existingRules { + if fwmark == existingRule.Mark && table == existingRule.Table { + existingRule.Family = family // set family, somehow the RuleList method does not populate the family field + if err := m.nl.RuleDel(&existingRule); err != nil { + return fmt.Errorf("failed to delete fwmark rule: %w", err) + } + } + } + return nil +} + +func (m Manager) setMainRule(rules []defaultRouteRule, family int) error { + shouldHaveMainRule := len(rules) != 0 + if !shouldHaveMainRule { + return nil + } + + existingRules, err := m.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) + } + + ruleExists := false + for _, existingRule := range existingRules { + if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { + ruleExists = true + break + } + } + + if ruleExists { + return nil // rule already exists, skip re-creation + } + + if err := m.nl.RuleAdd(&netlink.Rule{ + Family: family, + Table: unix.RT_TABLE_MAIN, + SuppressIfgroup: -1, + SuppressPrefixlen: 0, + Priority: m.getMainRulePriority(existingRules), + Mark: -1, + Mask: -1, + Goto: -1, + Flow: -1, + }); err != nil { + return fmt.Errorf("failed to setup rule for main table: %w", err) + } + + return nil +} + +func (m Manager) cleanupMainRule(rules []defaultRouteRule, family int) error { + existingRules, err := m.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %d: %w", family, err) + } + + shouldHaveMainRule := len(rules) != 0 + + mainRules := 0 + for _, existingRule := range existingRules { + if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { + mainRules++ + } + } + + removalCount := 0 + if mainRules > 1 { + removalCount = mainRules - 1 // we only want one single rule + } + if !shouldHaveMainRule { + removalCount = mainRules + } + + for _, existingRule := range existingRules { + if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { + if removalCount > 0 { + existingRule.Family = family // set family, somehow the RuleList method does not populate the family field + if err := m.nl.RuleDel(&existingRule); err != nil { + return fmt.Errorf("failed to delete main rule: %w", err) + } + removalCount-- + } + } + } + + return nil +} + +func (m Manager) getRulePriority(existingRules []netlink.Rule) int { + prio := m.cfg.Advanced.RulePrioOffset + for { + isFresh := true + for _, existingRule := range existingRules { + if existingRule.Priority == prio { + isFresh = false + break + } + } + if isFresh { + break + } else { + prio++ + } + } + return prio +} + +func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int { + prio := 32700 // linux main rule has a prio of 32766 + for { + isFresh := true + for _, existingRule := range existingRules { + if existingRule.Priority == prio { + isFresh = false + break + } + } + if isFresh { + break + } else { + prio-- + } + } + return prio +} + +func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error { + // try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash) + for _, allowedIP := range allowedIPs { + err := m.nl.RouteReplace(&netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: allowedIP.IpNet(), + Table: table, + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }) + if err != nil { + return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err) + } + } + + return nil +} + +func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIPs []domain.Cidr) error { + rawRoutes, err := m.nl.RouteListFiltered(family, &netlink.Route{ + LinkIndex: link.Attrs().Index, + Table: unix.RT_TABLE_UNSPEC, // all tables + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF) + if err != nil { + return fmt.Errorf("failed to fetch raw routes: %w", err) + } + for _, rawRoute := range rawRoutes { + var netlinkAddr domain.Cidr + if rawRoute.Dst == nil { + if family == netlink.FAMILY_V4 { + netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0") + } else { + netlinkAddr, _ = domain.CidrFromString("::/0") + } + } else { + netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst) + } + remove := true + for _, allowedIP := range allowedIPs { + if netlinkAddr == allowedIP { + remove = false + break + } + } + + if !remove { + continue + } + + err := m.nl.RouteDel(&rawRoute) + if err != nil { + return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err) + } + } + return nil +} + +func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, allowedIPs []domain.Cidr, link netlink.Link) (table, fwmark int, err error) { + defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs) + + table = iface.GetRoutingTable() + fwmark = int(iface.FirewallMark) + + if (defRouteV4 || defRouteV6) && table <= 0 { + table = m.cfg.Advanced.RouteTableOffset + link.Attrs().Index // generate a new routing table base on interface index + logrus.Debugf("using routing table %d to handle default routes", table) + } + if (defRouteV4 || defRouteV6) && fwmark == 0 { + fwmark = m.cfg.Advanced.RouteTableOffset + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index + logrus.Debugf("using fwmark %d to handle default routes", table) + + // apply the fwmark + err = m.setFwMark(iface.Identifier, fwmark) + } + return +} + +func (m Manager) setFwMark(id domain.InterfaceIdentifier, fwmark int) error { + err := m.wg.ConfigureDevice(string(id), wgtypes.Config{ + FirewallMark: &fwmark, + }) + if err != nil { + return fmt.Errorf("failed to update fwmark to: %d: %w", fwmark, err) + } + return nil +} + +func (m Manager) containsDefaultRoute(allowedIPs []domain.Cidr) (ipV4, ipV6 bool) { + for _, allowedIP := range allowedIPs { + if ipV4 && ipV6 { + break // speed up + } + + if allowedIP.Prefix().Bits() == 0 { + if allowedIP.IsV4() { + ipV4 = true + } else { + ipV6 = true + } + } + } + + return +} diff --git a/internal/app/wireguard/repos.go b/internal/app/wireguard/repos.go index 46222d0..00b0b4b 100644 --- a/internal/app/wireguard/repos.go +++ b/internal/app/wireguard/repos.go @@ -41,7 +41,6 @@ type InterfaceController interface { DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error - SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error } type WgQuickController interface { diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index 83fadb7..5dcc0fc 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/h44z/wg-portal/internal" + "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/domain" "github.com/sirupsen/logrus" "time" @@ -327,6 +328,8 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif existingInterface.Disabled = &now // simulate a disabled interface existingInterface.DisabledReason = domain.DisabledReasonDeleted + physicalInterface, _ := m.wg.GetInterface(ctx, id) + if err := m.handleInterfacePreSaveHooks(true, existingInterface); err != nil { return fmt.Errorf("pre-delete hooks failed: %w", err) } @@ -347,6 +350,13 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif return fmt.Errorf("deletion failure: %w", err) } + if physicalInterface != nil { + m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ + FwMark: int(physicalInterface.FirewallMark), + Table: existingInterface.GetRoutingTable(), + }) + } + if err := m.handleInterfacePostSaveHooks(true, existingInterface); err != nil { return fmt.Errorf("post-delete hooks failed: %w", err) } @@ -384,10 +394,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee return nil, fmt.Errorf("failed to save interface: %w", err) } - err = m.wg.SaveRoutes(ctx, iface, peers) - if err != nil { - return nil, fmt.Errorf("failed to save routes: %w", err) - } + m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier)) if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil { return nil, fmt.Errorf("post-save hooks failed: %w", err) diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index bec6db8..87b8783 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/h44z/wg-portal/internal" + "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/domain" "github.com/sirupsen/logrus" "time" @@ -279,15 +280,8 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { } // Update routes after peers have changed - for ifaceId := range interfaces { - iface, ifacePeers, err := m.db.GetInterfaceAndPeers(ctx, ifaceId) - if err != nil { - return fmt.Errorf("failed to load peer interface %s: %w", ifaceId, err) - } - err = m.wg.SaveRoutes(ctx, iface, ifacePeers) - if err != nil { - return fmt.Errorf("failed to update peer routes on interface %s: %w", ifaceId, err) - } + if len(interfaces) != 0 { + m.bus.Publish(app.TopicRouteUpdate, "peers updated") } return nil diff --git a/internal/config/config.go b/internal/config/config.go index 63735e7..1beae78 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -33,6 +33,8 @@ type Config struct { UseIpV6 bool `yaml:"use_ip_v6"` ConfigStoragePath string `yaml:"config_storage_path"` // keep empty to disable config export to file ExpiryCheckInterval time.Duration `yaml:"expiry_check_interval"` + RulePrioOffset int `yaml:"rule_prio_offset"` + RouteTableOffset int `yaml:"route_table_offset"` } `yaml:"advanced"` Statistics struct { @@ -106,6 +108,8 @@ func defaultConfig() *Config { cfg.Advanced.StartCidrV6 = "fdfd:d3ad:c0de:1234::0/64" cfg.Advanced.UseIpV6 = true cfg.Advanced.ExpiryCheckInterval = 15 * time.Minute + cfg.Advanced.RulePrioOffset = 20000 + cfg.Advanced.RouteTableOffset = 20000 cfg.Statistics.UsePingChecks = true cfg.Statistics.PingCheckWorkers = 10 diff --git a/internal/domain/interface.go b/internal/domain/interface.go index b09c51a..03eabb6 100644 --- a/internal/domain/interface.go +++ b/internal/domain/interface.go @@ -211,3 +211,12 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) { pi.DeviceUp = !i.IsDisabled() pi.Addresses = i.Addresses } + +type RoutingTableInfo struct { + FwMark int + Table int +} + +func (r RoutingTableInfo) String() string { + return fmt.Sprintf("%d -> %d", r.FwMark, r.Table) +}