feat: allow multiple auth sources per user (#500,#477) (#612)

* feat: allow multiple auth sources per user (#500,#477)

* only override isAdmin flag if it is provided by the authentication source
This commit is contained in:
h44z
2026-01-21 22:22:22 +01:00
committed by GitHub
parent d2fe267be7
commit e0f6c1d04b
44 changed files with 1158 additions and 798 deletions

View File

@@ -4,15 +4,12 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"math"
"sync"
"time"
"github.com/go-ldap/ldap/v3"
"github.com/google/uuid"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
@@ -79,7 +76,7 @@ func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
return err
}
createdUser, err := m.CreateUser(ctx, user)
createdUser, err := m.create(ctx, user)
if err != nil {
return err
}
@@ -101,20 +98,11 @@ func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain
return nil, err
}
user, err := m.users.GetUser(ctx, id)
if err != nil {
return nil, fmt.Errorf("unable to load user %s: %w", id, err)
}
peers, _ := m.peers.GetUserPeers(ctx, id) // ignore error, list will be empty in error case
user.LinkedPeerCount = len(peers)
return user, nil
return m.getUser(ctx, id)
}
// GetUserByEmail returns the user with the given email address.
func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) {
user, err := m.users.GetUserByEmail(ctx, email)
if err != nil {
return nil, fmt.Errorf("unable to load user for email %s: %w", email, err)
@@ -124,16 +112,11 @@ func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User
return nil, err
}
peers, _ := m.peers.GetUserPeers(ctx, user.Identifier) // ignore error, list will be empty in error case
user.LinkedPeerCount = len(peers)
return user, nil
return m.enrichUser(ctx, user), nil
}
// GetUserByWebAuthnCredential returns the user for the given WebAuthn credential.
func (m Manager) GetUserByWebAuthnCredential(ctx context.Context, credentialIdBase64 string) (*domain.User, error) {
user, err := m.users.GetUserByWebAuthnCredential(ctx, credentialIdBase64)
if err != nil {
return nil, fmt.Errorf("unable to load user for webauthn credential %s: %w", credentialIdBase64, err)
@@ -143,11 +126,7 @@ func (m Manager) GetUserByWebAuthnCredential(ctx context.Context, credentialIdBa
return nil, err
}
peers, _ := m.peers.GetUserPeers(ctx, user.Identifier) // ignore error, list will be empty in error case
user.LinkedPeerCount = len(peers)
return user, nil
return m.enrichUser(ctx, user), nil
}
// GetAllUsers returns all users.
@@ -169,8 +148,7 @@ func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
go func() {
defer wg.Done()
for user := range ch {
peers, _ := m.peers.GetUserPeers(ctx, user.Identifier) // ignore error, list will be empty in error case
user.LinkedPeerCount = len(peers)
m.enrichUser(ctx, user)
}
}()
}
@@ -194,77 +172,29 @@ func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.Use
return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err)
}
if err := m.validateModifications(ctx, existingUser, user); err != nil {
return nil, fmt.Errorf("update not allowed: %w", err)
}
user.CopyCalculatedAttributes(existingUser, true) // ensure that crucial attributes stay the same
user.CopyCalculatedAttributes(existingUser)
err = user.HashPassword()
return m.update(ctx, existingUser, user, true)
}
// UpdateUserInternal updates the user with the given identifier. This function must never be called from external.
// This function allows to override authentications and webauthn credentials.
func (m Manager) UpdateUserInternal(ctx context.Context, user *domain.User) (*domain.User, error) {
existingUser, err := m.users.GetUser(ctx, user.Identifier)
if err != nil {
return nil, err
}
if user.Password == "" { // keep old password
user.Password = existingUser.Password
return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err)
}
err = m.users.SaveUser(ctx, existingUser.Identifier, func(u *domain.User) (*domain.User, error) {
user.CopyCalculatedAttributes(u)
return user, nil
})
if err != nil {
return nil, fmt.Errorf("update failure: %w", err)
}
m.bus.Publish(app.TopicUserUpdated, *user)
switch {
case !existingUser.IsDisabled() && user.IsDisabled():
m.bus.Publish(app.TopicUserDisabled, *user)
case existingUser.IsDisabled() && !user.IsDisabled():
m.bus.Publish(app.TopicUserEnabled, *user)
}
return user, nil
return m.update(ctx, existingUser, user, false)
}
// CreateUser creates a new user.
func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
if user.Identifier == "" {
return nil, errors.New("missing user identifier")
}
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
}
existingUser, err := m.users.GetUser(ctx, user.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err)
}
if existingUser != nil {
return nil, errors.Join(fmt.Errorf("user %s already exists", user.Identifier), domain.ErrDuplicateEntry)
}
if err := m.validateCreation(ctx, user); err != nil {
return nil, fmt.Errorf("creation not allowed: %w", err)
}
err = user.HashPassword()
if err != nil {
return nil, err
}
err = m.users.SaveUser(ctx, user.Identifier, func(u *domain.User) (*domain.User, error) {
user.CopyCalculatedAttributes(u)
return user, nil
})
if err != nil {
return nil, fmt.Errorf("creation failure: %w", err)
}
m.bus.Publish(app.TopicUserCreated, *user)
return user, nil
return m.create(ctx, user)
}
// DeleteUser deletes the user with the given identifier.
@@ -307,15 +237,10 @@ func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*do
user.ApiToken = uuid.New().String()
user.ApiTokenCreated = &now
err = m.users.SaveUser(ctx, user.Identifier, func(u *domain.User) (*domain.User, error) {
user.CopyCalculatedAttributes(u)
return user, nil
})
user, err = m.update(ctx, user, user, true) // self-update
if err != nil {
return nil, fmt.Errorf("update failure: %w", err)
return nil, err
}
m.bus.Publish(app.TopicUserUpdated, *user)
m.bus.Publish(app.TopicUserApiEnabled, *user)
return user, nil
@@ -335,15 +260,10 @@ func (m Manager) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*
user.ApiToken = ""
user.ApiTokenCreated = nil
err = m.users.SaveUser(ctx, user.Identifier, func(u *domain.User) (*domain.User, error) {
user.CopyCalculatedAttributes(u)
return user, nil
})
user, err = m.update(ctx, user, user, true) // self-update
if err != nil {
return nil, fmt.Errorf("update failure: %w", err)
return nil, err
}
m.bus.Publish(app.TopicUserUpdated, *user)
m.bus.Publish(app.TopicUserApiDisabled, *user)
return user, nil
@@ -380,10 +300,6 @@ func (m Manager) validateModifications(ctx context.Context, old, new *domain.Use
return fmt.Errorf("cannot lock own user: %w", domain.ErrInvalidData)
}
if old.Source != new.Source {
return fmt.Errorf("cannot change user source: %w", domain.ErrInvalidData)
}
return nil
}
@@ -414,14 +330,19 @@ func (m Manager) validateCreation(ctx context.Context, new *domain.User) error {
return fmt.Errorf("reserved user identifier: %w", domain.ErrInvalidData)
}
if len(new.Authentications) != 1 {
return fmt.Errorf("invalid number of authentications: %d, expected 1: %w",
len(new.Authentications), domain.ErrInvalidData)
}
// Admins are allowed to create users for arbitrary sources.
if new.Source != domain.UserSourceDatabase && !currentUser.IsAdmin {
if new.Authentications[0].Source != domain.UserSourceDatabase && !currentUser.IsAdmin {
return fmt.Errorf("invalid user source: %s, only %s is allowed: %w",
new.Source, domain.UserSourceDatabase, domain.ErrInvalidData)
new.Authentications[0].Source, domain.UserSourceDatabase, domain.ErrInvalidData)
}
// database users must have a password
if new.Source == domain.UserSourceDatabase && string(new.Password) == "" {
if new.Authentications[0].Source == domain.UserSourceDatabase && string(new.Password) == "" {
return fmt.Errorf("missing password: %w", domain.ErrInvalidData)
}
@@ -460,214 +381,112 @@ func (m Manager) validateApiChange(ctx context.Context, user *domain.User) error
return nil
}
func (m Manager) runLdapSynchronizationService(ctx context.Context) {
ctx = domain.SetUserInfo(ctx, domain.LdapSyncContextUserInfo()) // switch to service context for LDAP sync
// region internal-modifiers
for _, ldapCfg := range m.cfg.Auth.Ldap { // LDAP Auth providers
go func(cfg config.LdapProvider) {
syncInterval := cfg.SyncInterval
if syncInterval == 0 {
slog.Debug("sync disabled for LDAP server", "provider", cfg.ProviderName)
return
}
// perform initial sync
err := m.synchronizeLdapUsers(ctx, &cfg)
if err != nil {
slog.Error("failed to synchronize LDAP users", "provider", cfg.ProviderName, "error", err)
} else {
slog.Debug("initial LDAP user sync completed", "provider", cfg.ProviderName)
}
// start periodic sync
running := true
for running {
select {
case <-ctx.Done():
running = false
continue
case <-time.After(syncInterval):
// select blocks until one of the cases evaluate to true
}
err := m.synchronizeLdapUsers(ctx, &cfg)
if err != nil {
slog.Error("failed to synchronize LDAP users", "provider", cfg.ProviderName, "error", err)
}
}
}(ldapCfg)
func (m Manager) enrichUser(ctx context.Context, user *domain.User) *domain.User {
if user == nil {
return nil
}
peers, _ := m.peers.GetUserPeers(ctx, user.Identifier) // ignore error, list will be empty in error case
user.LinkedPeerCount = len(peers)
return user
}
func (m Manager) synchronizeLdapUsers(ctx context.Context, provider *config.LdapProvider) error {
slog.Debug("starting to synchronize users", "provider", provider.ProviderName)
dn, err := ldap.ParseDN(provider.AdminGroupDN)
func (m Manager) getUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
user, err := m.users.GetUser(ctx, id)
if err != nil {
return fmt.Errorf("failed to parse admin group DN: %w", err)
return nil, fmt.Errorf("unable to load user %s: %w", id, err)
}
provider.ParsedAdminGroupDN = dn
conn, err := internal.LdapConnect(provider)
if err != nil {
return fmt.Errorf("failed to setup LDAP connection: %w", err)
}
defer internal.LdapDisconnect(conn)
rawUsers, err := internal.LdapFindAllUsers(conn, provider.BaseDN, provider.SyncFilter, &provider.FieldMap)
if err != nil {
return err
}
slog.Debug("fetched raw ldap users", "count", len(rawUsers), "provider", provider.ProviderName)
// Update existing LDAP users
err = m.updateLdapUsers(ctx, provider, rawUsers, &provider.FieldMap, provider.ParsedAdminGroupDN)
if err != nil {
return err
}
// Disable missing LDAP users
if provider.DisableMissing {
err = m.disableMissingLdapUsers(ctx, provider.ProviderName, rawUsers, &provider.FieldMap)
if err != nil {
return err
}
}
return nil
return m.enrichUser(ctx, user), nil
}
func (m Manager) updateLdapUsers(
ctx context.Context,
provider *config.LdapProvider,
rawUsers []internal.RawLdapUser,
fields *config.LdapFields,
adminGroupDN *ldap.DN,
) error {
for _, rawUser := range rawUsers {
user, err := convertRawLdapUser(provider.ProviderName, rawUser, fields, adminGroupDN)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return fmt.Errorf("failed to convert LDAP data for %v: %w", rawUser["dn"], err)
}
if provider.SyncLogUserInfo {
slog.Debug("ldap user data",
"raw-user", rawUser, "user", user.Identifier,
"is-admin", user.IsAdmin, "provider", provider.ProviderName)
}
existingUser, err := m.users.GetUser(ctx, user.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return fmt.Errorf("find error for user id %s: %w", user.Identifier, err)
}
tctx, cancel := context.WithTimeout(ctx, 30*time.Second)
tctx = domain.SetUserInfo(tctx, domain.SystemAdminContextUserInfo())
if existingUser == nil {
// create new user
slog.Debug("creating new user from provider", "user", user.Identifier, "provider", provider.ProviderName)
_, err := m.CreateUser(tctx, user)
if err != nil {
cancel()
return fmt.Errorf("create error for user id %s: %w", user.Identifier, err)
}
} else {
// update existing user
if provider.AutoReEnable && existingUser.DisabledReason == domain.DisabledReasonLdapMissing {
user.Disabled = nil
user.DisabledReason = ""
} else {
user.Disabled = existingUser.Disabled
user.DisabledReason = existingUser.DisabledReason
}
if existingUser.Source == domain.UserSourceLdap && userChangedInLdap(existingUser, user) {
err := m.users.SaveUser(tctx, user.Identifier, func(u *domain.User) (*domain.User, error) {
u.UpdatedAt = time.Now()
u.UpdatedBy = domain.CtxSystemLdapSyncer
u.Source = user.Source
u.ProviderName = user.ProviderName
u.Email = user.Email
u.Firstname = user.Firstname
u.Lastname = user.Lastname
u.Phone = user.Phone
u.Department = user.Department
u.IsAdmin = user.IsAdmin
u.Disabled = nil
u.DisabledReason = ""
return u, nil
})
if err != nil {
cancel()
return fmt.Errorf("update error for user id %s: %w", user.Identifier, err)
}
if existingUser.IsDisabled() && !user.IsDisabled() {
m.bus.Publish(app.TopicUserEnabled, *user)
}
}
}
cancel()
func (m Manager) update(ctx context.Context, existingUser, user *domain.User, keepAuthentications bool) (
*domain.User,
error,
) {
if err := m.validateModifications(ctx, existingUser, user); err != nil {
return nil, fmt.Errorf("update not allowed: %w", err)
}
return nil
err := user.HashPassword()
if err != nil {
return nil, err
}
if user.Password == "" { // keep old password
user.Password = existingUser.Password
}
err = m.users.SaveUser(ctx, existingUser.Identifier, func(u *domain.User) (*domain.User, error) {
user.CopyCalculatedAttributes(u, keepAuthentications)
return user, nil
})
if err != nil {
return nil, fmt.Errorf("update failure: %w", err)
}
m.bus.Publish(app.TopicUserUpdated, *user)
switch {
case !existingUser.IsDisabled() && user.IsDisabled():
m.bus.Publish(app.TopicUserDisabled, *user)
case existingUser.IsDisabled() && !user.IsDisabled():
m.bus.Publish(app.TopicUserEnabled, *user)
}
return user, nil
}
func (m Manager) disableMissingLdapUsers(
ctx context.Context,
providerName string,
rawUsers []internal.RawLdapUser,
fields *config.LdapFields,
) error {
allUsers, err := m.users.GetAllUsers(ctx)
if err != nil {
return err
func (m Manager) create(ctx context.Context, user *domain.User) (*domain.User, error) {
if user.Identifier == "" {
return nil, errors.New("missing user identifier")
}
for _, user := range allUsers {
if user.Source != domain.UserSourceLdap {
continue // ignore non ldap users
}
if user.ProviderName != providerName {
continue // user was synchronized through different provider
}
if user.IsDisabled() {
continue // ignore deactivated
}
existsInLDAP := false
for _, rawUser := range rawUsers {
userId := domain.UserIdentifier(internal.MapDefaultString(rawUser, fields.UserIdentifier, ""))
if user.Identifier == userId {
existsInLDAP = true
break
}
}
if existsInLDAP {
continue
}
slog.Debug("user is missing in ldap provider, disabling", "user", user.Identifier, "provider", providerName)
existingUser, err := m.users.GetUser(ctx, user.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err)
}
if existingUser != nil {
return nil, errors.Join(fmt.Errorf("user %s already exists", user.Identifier), domain.ErrDuplicateEntry)
}
// Add default authentication if missing
if len(user.Authentications) == 0 {
ctxUserInfo := domain.GetUserInfo(ctx)
now := time.Now()
user.Disabled = &now
user.DisabledReason = domain.DisabledReasonLdapMissing
err := m.users.SaveUser(ctx, user.Identifier, func(u *domain.User) (*domain.User, error) {
u.Disabled = user.Disabled
u.DisabledReason = user.DisabledReason
return u, nil
})
if err != nil {
return fmt.Errorf("disable error for user id %s: %w", user.Identifier, err)
user.Authentications = []domain.UserAuthentication{
{
BaseModel: domain.BaseModel{
CreatedBy: ctxUserInfo.UserId(),
UpdatedBy: ctxUserInfo.UserId(),
CreatedAt: now,
UpdatedAt: now,
},
UserIdentifier: user.Identifier,
Source: domain.UserSourceDatabase,
ProviderName: "",
},
}
m.bus.Publish(app.TopicUserDisabled, user)
}
return nil
if err := m.validateCreation(ctx, user); err != nil {
return nil, fmt.Errorf("creation not allowed: %w", err)
}
err = user.HashPassword()
if err != nil {
return nil, err
}
err = m.users.SaveUser(ctx, user.Identifier, func(u *domain.User) (*domain.User, error) {
return user, nil
})
if err != nil {
return nil, fmt.Errorf("creation failure: %w", err)
}
m.bus.Publish(app.TopicUserCreated, *user)
return user, nil
}
// endregion internal-modifiers