feat: Implement LDAP interface-specific provisioning filters (#642)
Some checks failed
Docker / Build and Push (push) Has been cancelled
github-pages / deploy (push) Has been cancelled
Docker / release (push) Has been cancelled

* Implement LDAP filter-based access control for interface provisioning

* test: add unit tests for LDAP interface filtering logic

* smaller improvements / cleanup

---------

Co-authored-by: jc <37738506+theguy147@users.noreply.github.com>
Co-authored-by: Christoph Haas <christoph.h@sprinternet.at>
This commit is contained in:
Jacopo Clark
2026-03-19 23:13:19 +01:00
committed by GitHub
parent f70f60a3f5
commit 402cc1b5f3
16 changed files with 339 additions and 18 deletions

View File

@@ -35,6 +35,7 @@ type InterfaceAndPeerDatabaseRepo interface {
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type WgQuickController interface {

View File

@@ -16,6 +16,11 @@ import (
"github.com/h44z/wg-portal/internal/domain"
)
// GetInterface returns the interface for the given interface identifier.
func (m Manager) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) {
return m.db.GetInterface(ctx, id)
}
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
@@ -63,12 +68,17 @@ func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interfa
// GetUserInterfaces returns all interfaces that are available for users to create new peers.
// If self-provisioning is disabled, this function will return an empty list.
// At the moment, there are no interfaces specific to single users, thus the user id is not used.
func (m Manager) GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier) ([]domain.Interface, error) {
func (m Manager) GetUserInterfaces(ctx context.Context, userId domain.UserIdentifier) ([]domain.Interface, error) {
if !m.cfg.Core.SelfProvisioningAllowed {
return nil, nil // self-provisioning is disabled - no interfaces for users
}
user, err := m.db.GetUser(ctx, userId)
if err != nil {
slog.Error("failed to load user for interface group verification", "user", userId, "error", err)
return nil, nil // fail closed
}
interfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return nil, fmt.Errorf("unable to load all interfaces: %w", err)
@@ -83,6 +93,9 @@ func (m Manager) GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier)
if iface.Type != domain.InterfaceTypeServer {
continue // skip client interfaces
}
if !user.IsAdmin && !iface.IsUserAllowed(userId, m.cfg) {
continue // user not allowed due to LDAP group filter
}
userInterfaces = append(userInterfaces, iface.PublicInfo())
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
@@ -92,3 +93,126 @@ func TestImportPeer_AddressMapping(t *testing.T) {
})
}
}
func (f *mockDB) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
return &domain.User{
Identifier: id,
IsAdmin: false,
}, nil
}
func TestInterface_IsUserAllowed(t *testing.T) {
cfg := &config.Config{
Auth: config.Auth{
Ldap: []config.LdapProvider{
{
ProviderName: "ldap1",
InterfaceFilter: map[string]string{
"wg0": "(memberOf=CN=VPNUsers,...)",
},
},
},
},
}
tests := []struct {
name string
iface domain.Interface
userId domain.UserIdentifier
expect bool
}{
{
name: "Unrestricted interface",
iface: domain.Interface{
Identifier: "wg1",
},
userId: "user1",
expect: true,
},
{
name: "Restricted interface - user allowed",
iface: domain.Interface{
Identifier: "wg0",
LdapAllowedUsers: map[string][]domain.UserIdentifier{
"ldap1": {"user1"},
},
},
userId: "user1",
expect: true,
},
{
name: "Restricted interface - user allowed (at least one match)",
iface: domain.Interface{
Identifier: "wg0",
LdapAllowedUsers: map[string][]domain.UserIdentifier{
"ldap1": {"user2"},
"ldap2": {"user1"},
},
},
userId: "user1",
expect: true,
},
{
name: "Restricted interface - user NOT allowed",
iface: domain.Interface{
Identifier: "wg0",
LdapAllowedUsers: map[string][]domain.UserIdentifier{
"ldap1": {"user2"},
},
},
userId: "user1",
expect: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expect, tt.iface.IsUserAllowed(tt.userId, cfg))
})
}
}
func TestManager_GetUserInterfaces_Filtering(t *testing.T) {
cfg := &config.Config{}
cfg.Core.SelfProvisioningAllowed = true
cfg.Auth.Ldap = []config.LdapProvider{
{
ProviderName: "ldap1",
InterfaceFilter: map[string]string{
"wg_restricted": "(some-filter)",
},
},
}
db := &mockDB{
interfaces: []domain.Interface{
{Identifier: "wg_public", Type: domain.InterfaceTypeServer},
{
Identifier: "wg_restricted",
Type: domain.InterfaceTypeServer,
LdapAllowedUsers: map[string][]domain.UserIdentifier{
"ldap1": {"allowed_user"},
},
},
},
}
m := Manager{
cfg: cfg,
db: db,
}
t.Run("Allowed user sees both", func(t *testing.T) {
ifaces, err := m.GetUserInterfaces(context.Background(), "allowed_user")
assert.NoError(t, err)
assert.Equal(t, 2, len(ifaces))
})
t.Run("Unallowed user sees only public", func(t *testing.T) {
ifaces, err := m.GetUserInterfaces(context.Background(), "other_user")
assert.NoError(t, err)
assert.Equal(t, 1, len(ifaces))
if len(ifaces) > 0 {
assert.Equal(t, domain.InterfaceIdentifier("wg_public"), ifaces[0].Identifier)
}
})
}

View File

@@ -93,6 +93,10 @@ func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier)
currentUser := domain.GetUserInfo(ctx)
if err := m.checkInterfaceAccess(ctx, id); err != nil {
return nil, err
}
iface, err := m.db.GetInterface(ctx, id)
if err != nil {
return nil, fmt.Errorf("unable to find interface %s: %w", id, err)
@@ -188,6 +192,9 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
return nil, err
}
if err := m.checkInterfaceAccess(ctx, peer.InterfaceIdentifier); err != nil {
return nil, err
}
}
sessionUser := domain.GetUserInfo(ctx)
@@ -304,6 +311,10 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
return nil, err
}
if err := m.checkInterfaceAccess(ctx, existingPeer.InterfaceIdentifier); err != nil {
return nil, err
}
if err := m.validatePeerModifications(ctx, existingPeer, peer); err != nil {
return nil, fmt.Errorf("update not allowed: %w", err)
}
@@ -373,6 +384,10 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return err
}
if err := m.checkInterfaceAccess(ctx, peer.InterfaceIdentifier); err != nil {
return err
}
if err := m.validatePeerDeletion(ctx, peer); err != nil {
return fmt.Errorf("delete not allowed: %w", err)
}
@@ -606,4 +621,22 @@ func (m Manager) validatePeerDeletion(ctx context.Context, _ *domain.Peer) error
return nil
}
func (m Manager) checkInterfaceAccess(ctx context.Context, id domain.InterfaceIdentifier) error {
user := domain.GetUserInfo(ctx)
if user.IsAdmin {
return nil
}
iface, err := m.db.GetInterface(ctx, id)
if err != nil {
return fmt.Errorf("failed to get interface %s: %w", id, err)
}
if !iface.IsUserAllowed(user.Id, m.cfg) {
return fmt.Errorf("user %s is not allowed to access interface %s: %w", user.Id, id, domain.ErrNoPermission)
}
return nil
}
// endregion helper-functions

View File

@@ -60,6 +60,7 @@ func (f *mockController) PingAddresses(_ context.Context, _ string) (*domain.Pin
type mockDB struct {
savedPeers map[domain.PeerIdentifier]*domain.Peer
iface *domain.Interface
interfaces []domain.Interface
}
func (f *mockDB) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) {
@@ -79,6 +80,9 @@ func (f *mockDB) GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier
return nil, nil
}
func (f *mockDB) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
if f.interfaces != nil {
return f.interfaces, nil
}
if f.iface != nil {
return []domain.Interface{*f.iface}, nil
}