diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index 99ccdcf..f3c5364 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -188,6 +188,8 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee sessionUser := domain.GetUserInfo(ctx) + peer.Identifier = domain.PeerIdentifier(peer.Interface.PublicKey) // ensure that identifier corresponds to the public key + // Enforce peer limit for non-admin users if LimitAdditionalUserPeers is set if m.cfg.Core.SelfProvisioningAllowed && !sessionUser.IsAdmin && m.cfg.Advanced.LimitAdditionalUserPeers > 0 { peers, err := m.db.GetUserPeers(ctx, peer.UserIdentifier) diff --git a/internal/app/wireguard/wireguard_peers_test.go b/internal/app/wireguard/wireguard_peers_test.go new file mode 100644 index 0000000..707d015 --- /dev/null +++ b/internal/app/wireguard/wireguard_peers_test.go @@ -0,0 +1,194 @@ +package wireguard + +import ( + "context" + "testing" + + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" +) + +// --- Test mocks --- + +type mockBus struct{} + +func (f *mockBus) Publish(topic string, args ...any) {} +func (f *mockBus) Subscribe(topic string, fn interface{}) error { return nil } + +type mockController struct{} + +func (f *mockController) GetId() domain.InterfaceBackend { return "local" } +func (f *mockController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) { + return nil, nil +} +func (f *mockController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) ( + *domain.PhysicalInterface, + error, +) { + return &domain.PhysicalInterface{Identifier: id}, nil +} +func (f *mockController) GetPeers(_ context.Context, _ domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) { + return nil, nil +} +func (f *mockController) SaveInterface( + _ context.Context, + _ domain.InterfaceIdentifier, + updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), +) error { + _, _ = updateFunc(&domain.PhysicalInterface{}) + return nil +} +func (f *mockController) DeleteInterface(_ context.Context, _ domain.InterfaceIdentifier) error { + return nil +} +func (f *mockController) SavePeer( + _ context.Context, + _ domain.InterfaceIdentifier, + _ domain.PeerIdentifier, + updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), +) error { + _, _ = updateFunc(&domain.PhysicalPeer{}) + return nil +} +func (f *mockController) DeletePeer(_ context.Context, _ domain.InterfaceIdentifier, _ domain.PeerIdentifier) error { + return nil +} +func (f *mockController) PingAddresses(_ context.Context, _ string) (*domain.PingerResult, error) { + return nil, nil +} + +type mockDB struct { + savedPeers map[domain.PeerIdentifier]*domain.Peer + iface *domain.Interface +} + +func (f *mockDB) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) { + if f.iface != nil && f.iface.Identifier == id { + return f.iface, nil + } + return &domain.Interface{Identifier: id}, nil +} +func (f *mockDB) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) ( + *domain.Interface, + []domain.Peer, + error, +) { + return f.iface, nil, nil +} +func (f *mockDB) GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error) { + return nil, nil +} +func (f *mockDB) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) { return nil, nil } +func (f *mockDB) GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error) { + return nil, nil +} +func (f *mockDB) SaveInterface( + ctx context.Context, + id domain.InterfaceIdentifier, + updateFunc func(in *domain.Interface) (*domain.Interface, error), +) error { + if f.iface == nil { + f.iface = &domain.Interface{Identifier: id} + } + var err error + f.iface, err = updateFunc(f.iface) + return err +} +func (f *mockDB) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { + return nil +} +func (f *mockDB) GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) { + return nil, nil +} +func (f *mockDB) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) { + return nil, nil +} +func (f *mockDB) SavePeer( + ctx context.Context, + id domain.PeerIdentifier, + updateFunc func(in *domain.Peer) (*domain.Peer, error), +) error { + if f.savedPeers == nil { + f.savedPeers = make(map[domain.PeerIdentifier]*domain.Peer) + } + existing := f.savedPeers[id] + if existing == nil { + existing = &domain.Peer{Identifier: id} + } + updated, err := updateFunc(existing) + if err != nil { + return err + } + f.savedPeers[updated.Identifier] = updated + return nil +} +func (f *mockDB) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error { return nil } +func (f *mockDB) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) { + return nil, domain.ErrNotFound +} +func (f *mockDB) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) ( + map[domain.Cidr][]domain.Cidr, + error, +) { + return map[domain.Cidr][]domain.Cidr{}, nil +} + +// --- Test --- + +func TestCreatePeer_SetsIdentifier_FromPublicKey(t *testing.T) { + // Arrange + cfg := &config.Config{} + cfg.Core.SelfProvisioningAllowed = true + cfg.Core.EditableKeys = true + cfg.Advanced.LimitAdditionalUserPeers = 0 + + bus := &mockBus{} + + // Prepare a controller manager with our mock controller + ctrlMgr := &ControllerManager{ + controllers: map[domain.InterfaceBackend]backendInstance{ + config.LocalBackendName: {Implementation: &mockController{}}, + }, + } + + db := &mockDB{iface: &domain.Interface{Identifier: "wg0", Type: domain.InterfaceTypeServer}} + + m := Manager{ + cfg: cfg, + bus: bus, + db: db, + wg: ctrlMgr, + } + + userId := domain.UserIdentifier("user@example.com") + ctx := domain.SetUserInfo(context.Background(), &domain.ContextUserInfo{Id: userId, IsAdmin: false}) + + pubKey := "TEST_PUBLIC_KEY_ABC123" + + input := &domain.Peer{ + Identifier: "should_be_overwritten", + UserIdentifier: userId, + InterfaceIdentifier: domain.InterfaceIdentifier("wg0"), + Interface: domain.PeerInterfaceConfig{ + KeyPair: domain.KeyPair{PublicKey: pubKey}, + }, + } + + // Act + out, err := m.CreatePeer(ctx, input) + + // Assert + if err != nil { + t.Fatalf("CreatePeer returned error: %v", err) + } + + expectedId := domain.PeerIdentifier(pubKey) + if out.Identifier != expectedId { + t.Fatalf("expected Identifier to be set from public key %q, got %q", expectedId, out.Identifier) + } + + // Ensure the saved peer in DB also has the expected identifier + if db.savedPeers[expectedId] == nil { + t.Fatalf("expected peer with identifier %q to be saved in DB", expectedId) + } +}