fix race condition during ldap initialization (#571)
Some checks are pending
Docker / Build and Push (push) Waiting to run
Docker / release (push) Blocked by required conditions
github-pages / deploy (push) Waiting to run

This commit is contained in:
Christoph Haas
2025-11-20 18:28:20 +01:00
parent d759fc7dc7
commit 907bb0599a
3 changed files with 61 additions and 50 deletions

View File

@@ -20,18 +20,7 @@ type LdapAuthenticator struct {
} }
func newLdapAuthenticator(_ context.Context, cfg *config.LdapProvider) (*LdapAuthenticator, error) { func newLdapAuthenticator(_ context.Context, cfg *config.LdapProvider) (*LdapAuthenticator, error) {
var provider = &LdapAuthenticator{} return &LdapAuthenticator{cfg: cfg}, nil
provider.cfg = cfg
dn, err := ldap.ParseDN(cfg.AdminGroupDN)
if err != nil {
return nil, fmt.Errorf("failed to parse admin group DN: %w", err)
}
provider.cfg.FieldMap = provider.getLdapFieldMapping(cfg.FieldMap)
provider.cfg.ParsedAdminGroupDN = dn
return provider, nil
} }
// GetName returns the name of the LDAP authenticator. // GetName returns the name of the LDAP authenticator.
@@ -154,40 +143,3 @@ func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.Authentica
return userInfo, nil return userInfo, nil
} }
func (l LdapAuthenticator) getLdapFieldMapping(f config.LdapFields) config.LdapFields {
defaultMap := config.LdapFields{
BaseFields: config.BaseFields{
UserIdentifier: "mail",
Email: "mail",
Firstname: "givenName",
Lastname: "sn",
Phone: "telephoneNumber",
Department: "department",
},
GroupMembership: "memberOf",
}
if f.UserIdentifier != "" {
defaultMap.UserIdentifier = f.UserIdentifier
}
if f.Email != "" {
defaultMap.Email = f.Email
}
if f.Firstname != "" {
defaultMap.Firstname = f.Firstname
}
if f.Lastname != "" {
defaultMap.Lastname = f.Lastname
}
if f.Phone != "" {
defaultMap.Phone = f.Phone
}
if f.Department != "" {
defaultMap.Department = f.Department
}
if f.GroupMembership != "" {
defaultMap.GroupMembership = f.GroupMembership
}
return defaultMap
}

View File

@@ -1,6 +1,7 @@
package config package config
import ( import (
"fmt"
"log/slog" "log/slog"
"regexp" "regexp"
"time" "time"
@@ -125,6 +126,45 @@ type LdapFields struct {
GroupMembership string `yaml:"memberof"` GroupMembership string `yaml:"memberof"`
} }
// getMappingWithDefaults returns a full field mapping for the LDAP provider.
// If specific fields are not set, the default values are used.
func (f LdapFields) getMappingWithDefaults() LdapFields {
defaultMap := LdapFields{
BaseFields: BaseFields{
UserIdentifier: "mail",
Email: "mail",
Firstname: "givenName",
Lastname: "sn",
Phone: "telephoneNumber",
Department: "department",
},
GroupMembership: "memberOf",
}
if f.UserIdentifier != "" {
defaultMap.UserIdentifier = f.UserIdentifier
}
if f.Email != "" {
defaultMap.Email = f.Email
}
if f.Firstname != "" {
defaultMap.Firstname = f.Firstname
}
if f.Lastname != "" {
defaultMap.Lastname = f.Lastname
}
if f.Phone != "" {
defaultMap.Phone = f.Phone
}
if f.Department != "" {
defaultMap.Department = f.Department
}
if f.GroupMembership != "" {
defaultMap.GroupMembership = f.GroupMembership
}
return defaultMap
}
// LdapProvider contains the configuration for the LDAP connection. // LdapProvider contains the configuration for the LDAP connection.
type LdapProvider struct { type LdapProvider struct {
// ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters. // ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters.
@@ -178,6 +218,19 @@ type LdapProvider struct {
LogUserInfo bool `yaml:"log_user_info"` LogUserInfo bool `yaml:"log_user_info"`
} }
// Sanitize checks the LDAP configuration and sets default values for missing fields.
func (l *LdapProvider) Sanitize() error {
l.FieldMap = l.FieldMap.getMappingWithDefaults()
dn, err := ldap.ParseDN(l.AdminGroupDN)
if err != nil {
return fmt.Errorf("failed to parse admin group DN: %w", err)
}
l.ParsedAdminGroupDN = dn
return nil
}
// OpenIDConnectProvider contains the configuration for the OpenID Connect provider. // OpenIDConnectProvider contains the configuration for the OpenID Connect provider.
type OpenIDConnectProvider struct { type OpenIDConnectProvider struct {
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.

View File

@@ -177,7 +177,8 @@ func defaultConfig() *Config {
cfg.Statistics.PingCheckWorkers = getEnvInt("WG_PORTAL_STATISTICS_PING_CHECK_WORKERS", 10) cfg.Statistics.PingCheckWorkers = getEnvInt("WG_PORTAL_STATISTICS_PING_CHECK_WORKERS", 10)
cfg.Statistics.PingUnprivileged = getEnvBool("WG_PORTAL_STATISTICS_PING_UNPRIVILEGED", false) cfg.Statistics.PingUnprivileged = getEnvBool("WG_PORTAL_STATISTICS_PING_UNPRIVILEGED", false)
cfg.Statistics.PingCheckInterval = getEnvDuration("WG_PORTAL_STATISTICS_PING_CHECK_INTERVAL", 1*time.Minute) cfg.Statistics.PingCheckInterval = getEnvDuration("WG_PORTAL_STATISTICS_PING_CHECK_INTERVAL", 1*time.Minute)
cfg.Statistics.DataCollectionInterval = getEnvDuration("WG_PORTAL_STATISTICS_DATA_COLLECTION_INTERVAL", 1*time.Minute) cfg.Statistics.DataCollectionInterval = getEnvDuration("WG_PORTAL_STATISTICS_DATA_COLLECTION_INTERVAL",
1*time.Minute)
cfg.Statistics.CollectInterfaceData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_INTERFACE_DATA", true) cfg.Statistics.CollectInterfaceData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_INTERFACE_DATA", true)
cfg.Statistics.CollectPeerData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_PEER_DATA", true) cfg.Statistics.CollectPeerData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_PEER_DATA", true)
cfg.Statistics.CollectAuditData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_AUDIT_DATA", true) cfg.Statistics.CollectAuditData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_AUDIT_DATA", true)
@@ -235,6 +236,11 @@ func GetConfig() (*Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
for i := range cfg.Auth.Ldap {
if err := cfg.Auth.Ldap[i].Sanitize(); err != nil {
return nil, fmt.Errorf("sanitizing of ldap config for %s failed: %w", cfg.Auth.Ldap[i].ProviderName, err)
}
}
return cfg, nil return cfg, nil
} }