From e53b8c808799270b1a9d66be72d777f200039c58 Mon Sep 17 00:00:00 2001 From: h44z Date: Sun, 25 Jan 2026 00:33:33 +0100 Subject: [PATCH] fix: improve import of existing allowed-IPs (#615) --- .../app/wireguard/wireguard_interfaces.go | 21 ++++- .../wireguard/wireguard_interfaces_test.go | 94 +++++++++++++++++++ 2 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 internal/app/wireguard/wireguard_interfaces_test.go diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index 867fab4..77f6f96 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -985,7 +985,26 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain peer.InterfaceIdentifier = in.Identifier peer.EndpointPublicKey = domain.NewConfigOption(in.PublicKey, true) peer.AllowedIPsStr = domain.NewConfigOption(in.PeerDefAllowedIPsStr, true) - peer.Interface.Addresses = p.AllowedIPs // use allowed IP's as the peer IP's TODO: Should this also match server interface address' prefix length? + + // split allowed IP's into interface addresses and extra allowed IP's + var interfaceAddresses []domain.Cidr + var extraAllowedIPs []domain.Cidr + for _, allowedIP := range p.AllowedIPs { + isHost := (allowedIP.IsV4() && allowedIP.NetLength == 32) || (!allowedIP.IsV4() && allowedIP.NetLength == 128) + isNetworkAddr := allowedIP.Addr == allowedIP.NetworkAddr().Addr + + // Network addresses (e.g. 10.0.0.0/24) will always be extra allowed IP's. + // For IP addresses, such as 10.0.0.1/24, it is challenging to tell whether it is an interface address or + // an extra allowed IP, therefore we treat such addresses as interface addresses. + if !isHost && isNetworkAddr { + extraAllowedIPs = append(extraAllowedIPs, allowedIP) + } else { + interfaceAddresses = append(interfaceAddresses, allowedIP) + } + } + peer.Interface.Addresses = interfaceAddresses + peer.ExtraAllowedIPsStr = domain.CidrsToString(extraAllowedIPs) + peer.Interface.DnsStr = domain.NewConfigOption(in.PeerDefDnsStr, true) peer.Interface.DnsSearchStr = domain.NewConfigOption(in.PeerDefDnsSearchStr, true) peer.Interface.Mtu = domain.NewConfigOption(in.PeerDefMtu, true) diff --git a/internal/app/wireguard/wireguard_interfaces_test.go b/internal/app/wireguard/wireguard_interfaces_test.go new file mode 100644 index 0000000..95f202b --- /dev/null +++ b/internal/app/wireguard/wireguard_interfaces_test.go @@ -0,0 +1,94 @@ +package wireguard + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/h44z/wg-portal/internal/domain" +) + +func TestImportPeer_AddressMapping(t *testing.T) { + tests := []struct { + name string + allowedIPs []string + expectedInterface []string + expectedExtraAllowed string + }{ + { + name: "IPv4 host address", + allowedIPs: []string{"10.0.0.1/32"}, + expectedInterface: []string{"10.0.0.1/32"}, + expectedExtraAllowed: "", + }, + { + name: "IPv6 host address", + allowedIPs: []string{"fd00::1/128"}, + expectedInterface: []string{"fd00::1/128"}, + expectedExtraAllowed: "", + }, + { + name: "IPv4 network address", + allowedIPs: []string{"10.0.1.0/24"}, + expectedInterface: []string{}, + expectedExtraAllowed: "10.0.1.0/24", + }, + { + name: "IPv4 normal address with mask", + allowedIPs: []string{"10.0.1.5/24"}, + expectedInterface: []string{"10.0.1.5/24"}, + expectedExtraAllowed: "", + }, + { + name: "Mixed addresses", + allowedIPs: []string{ + "10.0.0.1/32", "192.168.1.0/24", "172.16.0.5/24", "fd00::1/128", "fd00:1::/64", + }, + expectedInterface: []string{"10.0.0.1/32", "172.16.0.5/24", "fd00::1/128"}, + expectedExtraAllowed: "192.168.1.0/24,fd00:1::/64", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := &mockDB{} + m := Manager{ + db: db, + } + + iface := &domain.Interface{ + Identifier: "wg0", + Type: domain.InterfaceTypeServer, + } + + allowedIPs := make([]domain.Cidr, len(tt.allowedIPs)) + for i, s := range tt.allowedIPs { + cidr, _ := domain.CidrFromString(s) + allowedIPs[i] = cidr + } + + p := &domain.PhysicalPeer{ + Identifier: "peer1", + KeyPair: domain.KeyPair{PublicKey: "peer1-public-key-is-long-enough"}, + AllowedIPs: allowedIPs, + } + + err := m.importPeer(context.Background(), iface, p) + assert.NoError(t, err) + + savedPeer := db.savedPeers["peer1"] + assert.NotNil(t, savedPeer) + + // Check interface addresses + actualInterface := make([]string, len(savedPeer.Interface.Addresses)) + for i, addr := range savedPeer.Interface.Addresses { + actualInterface[i] = addr.String() + } + assert.ElementsMatch(t, tt.expectedInterface, actualInterface) + + // Check extra allowed IPs + assert.Equal(t, tt.expectedExtraAllowed, savedPeer.ExtraAllowedIPsStr) + }) + } +}