From 5aa94999ab2e045827dcba21082c4560d580df2d Mon Sep 17 00:00:00 2001 From: Christoph Haas Date: Sat, 29 Jul 2023 23:56:49 +0200 Subject: [PATCH] wip: wgquick --- internal/adapters/database.go | 6 +- internal/adapters/wgquick.go | 99 +++++++++++++ internal/adapters/wireguard.go | 132 +++++++++++++++--- .../adapters/wireguard_integration_test.go | 47 ++++++- internal/app/wireguard/repos.go | 2 +- .../app/wireguard/wireguard_interfaces.go | 52 ++++--- internal/app/wireguard/wireguard_peers.go | 2 +- internal/domain/interface.go | 44 ++++++ internal/domain/ip.go | 8 ++ internal/domain/peer.go | 16 ++- internal/lowlevel/netlink.go | 40 ++++++ 11 files changed, 391 insertions(+), 57 deletions(-) create mode 100644 internal/adapters/wgquick.go diff --git a/internal/adapters/database.go b/internal/adapters/database.go index c166ca4..77d8139 100644 --- a/internal/adapters/database.go +++ b/internal/adapters/database.go @@ -53,15 +53,15 @@ func (l *GormLogger) LogMode(logger.LogLevel) logger.Interface { } func (l *GormLogger) Info(ctx context.Context, s string, args ...interface{}) { - logrus.WithContext(ctx).Infof(s, args) + logrus.WithContext(ctx).Infof(s, args...) } func (l *GormLogger) Warn(ctx context.Context, s string, args ...interface{}) { - logrus.WithContext(ctx).Warnf(s, args) + logrus.WithContext(ctx).Warnf(s, args...) } func (l *GormLogger) Error(ctx context.Context, s string, args ...interface{}) { - logrus.WithContext(ctx).Errorf(s, args) + logrus.WithContext(ctx).Errorf(s, args...) } func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { diff --git a/internal/adapters/wgquick.go b/internal/adapters/wgquick.go new file mode 100644 index 0000000..58de124 --- /dev/null +++ b/internal/adapters/wgquick.go @@ -0,0 +1,99 @@ +package adapters + +import ( + "bytes" + "fmt" + "github.com/h44z/wg-portal/internal" + "github.com/h44z/wg-portal/internal/domain" + "github.com/sirupsen/logrus" + "os/exec" + "strings" +) + +// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks. +type WgQuickRepo struct { + shellCmd string + resolvConfIfacePrefix string +} + +func NewWgQuickRepo() *WgQuickRepo { + return &WgQuickRepo{ + shellCmd: "bash", + resolvConfIfacePrefix: "tun.", + } +} + +func (r *WgQuickRepo) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { + if hookCmd == "" { + return nil + } + + err := r.exec(hookCmd, id) + if err != nil { + return fmt.Errorf("failed to exec hook: %w", err) + } + + return nil +} + +func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { + if dnsStr == "" && dnsSearchStr == "" { + return nil + } + + dnsServers := internal.SliceString(dnsStr) + dnsSearchDomains := internal.SliceString(dnsSearchStr) + + dnsCommand := "resolvconf -a %resPref%i -m 0 -x" + dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains)) + + for _, dnsServer := range dnsServers { + dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer)) + } + for _, searchDomain := range dnsSearchDomains { + dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain)) + } + + err := r.exec(dnsCommand, id, dnsCommandInput...) + if err != nil { + return fmt.Errorf("failed to set dns settings: %w", err) + } + + return nil +} + +func (r *WgQuickRepo) UnsetDNS(id domain.InterfaceIdentifier) error { + dnsCommand := "resolvconf -d %resPref%i -f" + + err := r.exec(dnsCommand, id) + if err != nil { + return fmt.Errorf("failed to unset dns settings: %w", err) + } + + return nil +} + +func (r *WgQuickRepo) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string { + command = strings.ReplaceAll(command, "%resPref", r.resolvConfIfacePrefix) + return strings.ReplaceAll(command, "%i", string(interfaceId)) +} + +func (r *WgQuickRepo) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error { + commandWithInterfaceName := r.replaceCommandPlaceHolders(command, interfaceId) + cmd := exec.Command(r.shellCmd, "-ce", commandWithInterfaceName) + if len(stdin) > 0 { + b := &bytes.Buffer{} + for _, ln := range stdin { + if _, err := fmt.Fprint(b, ln); err != nil { + return err + } + } + cmd.Stdin = b + } + out, err := cmd.CombinedOutput() // execute and wait for output + if err != nil { + return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err) + } + logrus.Tracef("executed shell command %s, with output: %s", commandWithInterfaceName, string(out)) + return nil +} diff --git a/internal/adapters/wireguard.go b/internal/adapters/wireguard.go index 7c43ae6..f17dd95 100644 --- a/internal/adapters/wireguard.go +++ b/internal/adapters/wireguard.go @@ -4,19 +4,21 @@ import ( "context" "errors" "fmt" - "os" - "github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/lowlevel" + "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "os" ) // WgRepo implements all low-level WireGuard interactions. type WgRepo struct { - wg lowlevel.WireGuardClient - nl lowlevel.NetlinkClient + wg lowlevel.WireGuardClient + nl lowlevel.NetlinkClient + quick *WgQuickRepo } func NewWireGuardRepository() *WgRepo { @@ -28,8 +30,9 @@ func NewWireGuardRepository() *WgRepo { nl := &lowlevel.NetlinkManager{} repo := &WgRepo{ - wg: wg, - nl: nl, + wg: wg, + nl: nl, + quick: NewWgQuickRepo(), } return repo @@ -152,18 +155,40 @@ func (r *WgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, return peerModel, nil } -func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error { - physicalInterface, err := r.getOrCreateInterface(id) +func (r *WgRepo) SaveInterface(_ context.Context, iface *domain.Interface, peers []domain.Peer, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error { + physicalInterface, err := r.getOrCreateInterface(iface.Identifier) if err != nil { return err } + wasUp := physicalInterface.DeviceUp if updateFunc != nil { physicalInterface, err = updateFunc(physicalInterface) if err != nil { return err } } + stateChanged := wasUp != physicalInterface.DeviceUp + + if stateChanged { + if physicalInterface.DeviceUp { + if err := r.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil { + return fmt.Errorf("failed to update dns settings: %w", err) + } + + if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil { + return fmt.Errorf("failed to execute pre-up hook: %w", err) + } + } else { + if err := r.quick.UnsetDNS(iface.Identifier); err != nil { + return fmt.Errorf("failed to clear dns settings: %w", err) + } + + if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil { + return fmt.Errorf("failed to execute pre-down hook: %w", err) + } + } + } if err := r.updateLowLevelInterface(physicalInterface); err != nil { return err @@ -171,6 +196,21 @@ func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, if err := r.updateWireGuardInterface(physicalInterface); err != nil { return err } + if err := r.updateRoutes(iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers)); err != nil { + return err + } + + if stateChanged { + if physicalInterface.DeviceUp { + if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil { + return fmt.Errorf("failed to execute post-up hook: %w", err) + } + } else { + if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil { + return fmt.Errorf("failed to execute post-down hook: %w", err) + } + } + } return nil } @@ -298,6 +338,72 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error { return nil } +func (r *WgRepo) updateRoutes(interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error { + if table == -1 { + logrus.Trace("ignoring route update") + return nil + } + + link, err := r.nl.LinkByName(string(interfaceId)) + if err != nil { + return err + } + + if link.Attrs().OperState == netlink.OperDown { + return nil // cannot set route for interface that is down + } + + // try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash) + for _, allowedIP := range allowedIPs { + if allowedIP.Prefix().Bits() == 0 { // default route + // TODO + } else { + err := r.nl.RouteReplace(&netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: allowedIP.IpNet(), + Table: table, + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }) + if err != nil { + return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err) + } + } + } + + // Remove unwanted routes + rawRoutes, err := r.nl.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{ + LinkIndex: link.Attrs().Index, + Table: unix.RT_TABLE_UNSPEC, // all tables + Scope: unix.RT_SCOPE_LINK, + Type: unix.RTN_UNICAST, + }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF) + if err != nil { + return fmt.Errorf("failed to fetch raw routes: %w", err) + } + for _, rawRoute := range rawRoutes { + netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst) + remove := true + for _, allowedIP := range allowedIPs { + if netlinkAddr == allowedIP { + remove = false + break + } + } + + if !remove { + continue + } + + err := r.nl.RouteDel(&rawRoute) + if err != nil { + return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err) + } + } + + return nil +} + func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error { if err := r.deleteLowLevelInterface(id); err != nil { return err @@ -388,16 +494,10 @@ func (r *WgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.Phys Endpoint: pp.GetEndpointAddress(), PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(), ReplaceAllowedIPs: true, - AllowedIPs: nil, + AllowedIPs: pp.GetAllowedIPs(), } - ips, err := pp.GetAllowedIPs() - if err != nil { - return err - } - cfg.AllowedIPs = ips - - err = r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}}) + err := r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}}) if err != nil { return err } diff --git a/internal/adapters/wireguard_integration_test.go b/internal/adapters/wireguard_integration_test.go index 2c953cd..d9a20b9 100644 --- a/internal/adapters/wireguard_integration_test.go +++ b/internal/adapters/wireguard_integration_test.go @@ -70,8 +70,8 @@ func TestWireGuardCreateInterface(t *testing.T) { err := mgr.SaveInterface(context.Background(), interfaceName, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { pi.Addresses = []domain.Cidr{ - {Ip: domain.IpAddress(net.ParseIP(ipAddress)), NetLength: 24, Bits: 32}, - {Ip: domain.IpAddress(net.ParseIP(ipV6Address)), NetLength: 64, Bits: 128}, + domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}), + domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}), } return pi, nil }) @@ -104,8 +104,8 @@ func TestWireGuardUpdateInterface(t *testing.T) { ipV6Address := "1337:d34d:b33f::2" err = mgr.SaveInterface(context.Background(), interfaceName, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { pi.Addresses = []domain.Cidr{ - {Ip: domain.IpAddress(net.ParseIP(ipAddress)), NetLength: 24, Bits: 32}, - {Ip: domain.IpAddress(net.ParseIP(ipV6Address)), NetLength: 64, Bits: 128}, + domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}), + domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}), } return pi, nil }) @@ -119,3 +119,42 @@ func TestWireGuardUpdateInterface(t *testing.T) { assert.Contains(t, string(out), ipAddress) assert.Contains(t, string(out), ipV6Address) } + +func TestWireGuardCreateInterfaceWithRoutes(t *testing.T) { + mgr := setup(t) + + interfaceName := domain.InterfaceIdentifier("wg_test_001") + ipAddress := "10.11.12.13" + ipV6Address := "1337:d34d:b33f::2" + defer mgr.DeleteInterface(context.Background(), interfaceName) + + iface := &domain.Interface{ + Identifier: interfaceName, + //RoutingTable: "1234", + } + peers := []domain.Peer{ + { + Interface: domain.PeerInterfaceConfig{ + Addresses: domain.CidrsMust(domain.CidrsFromString("10.11.12.14/32,10.22.33.44/32")), + }, + }, + } + + err := mgr.SaveInterface2(context.Background(), iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { + pi.Addresses = []domain.Cidr{ + domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}), + domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}), + } + pi.DeviceUp = true + return pi, nil + }) + assert.NoError(t, err) + + // Validate that the interface has been created + cmd := exec.Command("ip", "addr") + out, err := cmd.CombinedOutput() + assert.NoError(t, err) + assert.Contains(t, string(out), interfaceName) + assert.Contains(t, string(out), ipAddress) + assert.Contains(t, string(out), ipV6Address) +} diff --git a/internal/app/wireguard/repos.go b/internal/app/wireguard/repos.go index 23b0a96..cf63cd3 100644 --- a/internal/app/wireguard/repos.go +++ b/internal/app/wireguard/repos.go @@ -37,7 +37,7 @@ type InterfaceController interface { GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) - SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error + SaveInterface(_ context.Context, iface *domain.Interface, peers []domain.Peer, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index 4205e55..1d79e2a 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -135,12 +135,12 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err) } - physicalInterface, err := m.wg.GetInterface(ctx, iface.Identifier) + _, err = m.wg.GetInterface(ctx, iface.Identifier) if err != nil { logrus.Debugf("creating missing interface %s...", iface.Identifier) // try to create a new interface - err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { + err := m.wg.SaveInterface(ctx, &iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { domain.MergeToPhysicalInterface(pi, &iface) return pi, nil @@ -169,32 +169,30 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool } } } else { - if physicalInterface.DeviceUp != !iface.IsDisabled() { - logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled()) + logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled()) - // try to move interface to stored state - err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { - pi.DeviceUp = !iface.IsDisabled() + // try to move interface to stored state + err := m.wg.SaveInterface(ctx, &iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { + pi.DeviceUp = !iface.IsDisabled() - return pi, nil - }) - if err != nil { - if updateDbOnError { - // disable interface in database as no physical interface is available - _ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) { - if iface.IsDisabled() { - now := time.Now() - in.Disabled = &now // set - in.DisabledReason = "no physical interface active" - } else { - in.Disabled = nil - in.DisabledReason = "" - } - return in, nil - }) - } - return fmt.Errorf("failed to change physical interface state for %s: %w", iface.Identifier, err) + return pi, nil + }) + if err != nil { + if updateDbOnError { + // disable interface in database as no physical interface is available + _ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) { + if iface.IsDisabled() { + now := time.Now() + in.Disabled = &now // set + in.DisabledReason = "no physical interface active" + } else { + in.Disabled = nil + in.DisabledReason = "" + } + return in, nil + }) } + return fmt.Errorf("failed to change physical interface state for %s: %w", iface.Identifier, err) } } } @@ -294,7 +292,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) { in.CopyCalculatedAttributes(i) - err = m.wg.SaveInterface(ctx, in.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { + err = m.wg.SaveInterface(ctx, in, nil, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { domain.MergeToPhysicalInterface(pi, in) return pi, nil }) @@ -324,7 +322,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) { in.CopyCalculatedAttributes(i) - err = m.wg.SaveInterface(ctx, in.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { + err = m.wg.SaveInterface(ctx, in, existingPeers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { domain.MergeToPhysicalInterface(pi, in) return pi, nil }) diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index 28029ae..2a35b74 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -306,7 +306,7 @@ func (m Manager) getFreshPeerIpConfig(ctx context.Context, iface *domain.Interfa } } - ips = append(ips, ip) + ips = append(ips, ip.HostAddr()) } return diff --git a/internal/domain/interface.go b/internal/domain/interface.go index a4fd241..4386f41 100644 --- a/internal/domain/interface.go +++ b/internal/domain/interface.go @@ -3,7 +3,10 @@ package domain import ( "fmt" "github.com/h44z/wg-portal/internal" + "math" "regexp" + "strconv" + "strings" "time" ) @@ -95,6 +98,47 @@ func (i *Interface) GetConfigFileName() string { return filename } +func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr { + var allowedCidrs []Cidr + + for _, peer := range peers { + allowedCidrs = append(allowedCidrs, peer.Interface.Addresses...) + if peer.ExtraAllowedIPsStr != "" { + extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr) + if err == nil { + allowedCidrs = append(allowedCidrs, extraIPs...) + } + } + } + + return allowedCidrs +} + +// GetRoutingTable returns the routing table number or -1 if an error occurred or RoutingTable was set to "off" +func (i *Interface) GetRoutingTable() int { + routingTableStr := strings.ToLower(i.RoutingTable) + switch { + case routingTableStr == "": + return 0 + case strings.HasPrefix(routingTableStr, "0x"): + numberStr := strings.ReplaceAll(routingTableStr, "0x", "") + routingTable, err := strconv.ParseUint(numberStr, 16, 64) + if err != nil { + return -1 + } + if routingTable > math.MaxInt32 { + return -1 + } + return int(routingTable) + default: + routingTable, err := strconv.Atoi(routingTableStr) + if err != nil { + return -1 + } + return routingTable + } +} + type PhysicalInterface struct { Identifier InterfaceIdentifier // device name, for example: wg0 KeyPair // private/public Key of the server interface diff --git a/internal/domain/ip.go b/internal/domain/ip.go index 294d526..529c7c9 100644 --- a/internal/domain/ip.go +++ b/internal/domain/ip.go @@ -163,6 +163,14 @@ func (c Cidr) NextAddr() Cidr { } } +func (c Cidr) HostAddr() Cidr { + return Cidr{ + Cidr: netip.PrefixFrom(c.Prefix().Addr(), c.Prefix().Addr().BitLen()).String(), + Addr: c.Addr, + NetLength: c.Prefix().Addr().BitLen(), + } +} + func (c Cidr) NextSubnet() Cidr { prefix := c.Prefix() nextAddr := c.BroadcastAddr().Prefix().Addr().Next() diff --git a/internal/domain/peer.go b/internal/domain/peer.go index 9141a1a..64250bc 100644 --- a/internal/domain/peer.go +++ b/internal/domain/peer.go @@ -190,13 +190,13 @@ func (p PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration { return &keepAliveDuration } -func (p PhysicalPeer) GetAllowedIPs() ([]net.IPNet, error) { +func (p PhysicalPeer) GetAllowedIPs() []net.IPNet { allowedIPs := make([]net.IPNet, len(p.AllowedIPs)) for i, ip := range p.AllowedIPs { allowedIPs[i] = *ip.IpNet() } - return allowedIPs, nil + return allowedIPs } func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer { @@ -223,9 +223,15 @@ func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer { func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) { pp.Identifier = p.Identifier pp.Endpoint = p.Endpoint.GetValue() - allowedIPs, _ := CidrsFromString(p.AllowedIPsStr.GetValue()) - extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr) - pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...) + if p.Interface.Type == InterfaceTypeServer { + allowedIPs, _ := CidrsFromString(p.AllowedIPsStr.GetValue()) + extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr) + pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...) + } else { + allowedIPs := p.Interface.Addresses + extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr) + pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...) + } pp.PresharedKey = p.PresharedKey pp.PublicKey = p.Interface.PublicKey pp.PersistentKeepalive = p.PersistentKeepalive.GetValue() diff --git a/internal/lowlevel/netlink.go b/internal/lowlevel/netlink.go index 409ebdd..4e9d312 100644 --- a/internal/lowlevel/netlink.go +++ b/internal/lowlevel/netlink.go @@ -16,6 +16,14 @@ type NetlinkClient interface { AddrAdd(link netlink.Link, addr *netlink.Addr) error AddrList(link netlink.Link) ([]netlink.Addr, error) AddrDel(link netlink.Link, addr *netlink.Addr) error + RouteAdd(route *netlink.Route) error + RouteDel(route *netlink.Route) error + RouteReplace(route *netlink.Route) error + RouteList(link netlink.Link, family int) ([]netlink.Route, error) + RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error) + RuleAdd(rule *netlink.Rule) error + RuleDel(rule *netlink.Rule) error + RuleList(family int) ([]netlink.Rule, error) } type NetlinkManager struct { @@ -66,3 +74,35 @@ func (n NetlinkManager) AddrList(link netlink.Link) ([]netlink.Addr, error) { func (n NetlinkManager) AddrDel(link netlink.Link, addr *netlink.Addr) error { return netlink.AddrDel(link, addr) } + +func (n NetlinkManager) RouteAdd(route *netlink.Route) error { + return netlink.RouteAdd(route) +} + +func (n NetlinkManager) RouteDel(route *netlink.Route) error { + return netlink.RouteDel(route) +} + +func (n NetlinkManager) RouteReplace(route *netlink.Route) error { + return netlink.RouteReplace(route) +} + +func (n NetlinkManager) RouteList(link netlink.Link, family int) ([]netlink.Route, error) { + return netlink.RouteList(link, family) +} + +func (n NetlinkManager) RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error) { + return netlink.RouteListFiltered(family, filter, filterMask) +} + +func (n NetlinkManager) RuleAdd(rule *netlink.Rule) error { + return netlink.RuleAdd(rule) +} + +func (n NetlinkManager) RuleDel(rule *netlink.Rule) error { + return netlink.RuleDel(rule) +} + +func (n NetlinkManager) RuleList(family int) ([]netlink.Rule, error) { + return netlink.RuleList(family) +}