diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index c4d1923..8f8d193 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -893,16 +893,7 @@ func (m Manager) importInterface( } } - // try to predict the interface type based on the number of peers - switch len(peers) { - case 0: - iface.Type = domain.InterfaceTypeAny // no peers means this is an unknown interface - case 1: - iface.Type = domain.InterfaceTypeClient // one peer means this is a client interface - default: // multiple peers means this is a server interface - - iface.Type = domain.InterfaceTypeServer - } + iface.Type = inferImportedInterfaceType(iface, peers) existingInterface, err := m.db.GetInterface(ctx, iface.Identifier) if err != nil && !errors.Is(err, domain.ErrNotFound) { @@ -930,6 +921,20 @@ func (m Manager) importInterface( return nil } +func inferImportedInterfaceType(iface *domain.Interface, peers []domain.PhysicalPeer) domain.InterfaceType { + switch len(peers) { + case 0: + return domain.InterfaceTypeAny // no peers means this is an unknown interface + case 1: + if iface.ListenPort > 0 { + return domain.InterfaceTypeServer // a listening interface with one peer is commonly a site-to-site server + } + return domain.InterfaceTypeClient + default: // multiple peers means this is a server interface + return domain.InterfaceTypeServer + } +} + // extractPfsenseDefaultsFromPeers extracts common endpoint and DNS information from peers // For server interfaces, peers typically have endpoints pointing to the server, so we use the most common one func extractPfsenseDefaultsFromPeers(peers []domain.PhysicalPeer, listenPort int) (endpoint, dns string) { diff --git a/internal/app/wireguard/wireguard_interfaces_test.go b/internal/app/wireguard/wireguard_interfaces_test.go index 4d15fd0..3709556 100644 --- a/internal/app/wireguard/wireguard_interfaces_test.go +++ b/internal/app/wireguard/wireguard_interfaces_test.go @@ -10,6 +10,49 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) +func TestInferImportedInterfaceType(t *testing.T) { + tests := []struct { + name string + listenPort int + peerCount int + expected domain.InterfaceType + }{ + { + name: "no peers stays unknown", + listenPort: 51820, + peerCount: 0, + expected: domain.InterfaceTypeAny, + }, + { + name: "single peer with listen port is server", + listenPort: 51820, + peerCount: 1, + expected: domain.InterfaceTypeServer, + }, + { + name: "single peer without listen port stays client", + listenPort: 0, + peerCount: 1, + expected: domain.InterfaceTypeClient, + }, + { + name: "multiple peers is server", + listenPort: 0, + peerCount: 2, + expected: domain.InterfaceTypeServer, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iface := &domain.Interface{ListenPort: tt.listenPort} + peers := make([]domain.PhysicalPeer, tt.peerCount) + + assert.Equal(t, tt.expected, inferImportedInterfaceType(iface, peers)) + }) + } +} + func TestImportPeer_AddressMapping(t *testing.T) { tests := []struct { name string