diff --git a/internal/app/auth/auth_ldap.go b/internal/app/auth/auth_ldap.go index 84bdcd1..faca76a 100644 --- a/internal/app/auth/auth_ldap.go +++ b/internal/app/auth/auth_ldap.go @@ -20,18 +20,7 @@ type LdapAuthenticator struct { } func newLdapAuthenticator(_ context.Context, cfg *config.LdapProvider) (*LdapAuthenticator, error) { - var provider = &LdapAuthenticator{} - - provider.cfg = cfg - - dn, err := ldap.ParseDN(cfg.AdminGroupDN) - if err != nil { - return nil, fmt.Errorf("failed to parse admin group DN: %w", err) - } - provider.cfg.FieldMap = provider.getLdapFieldMapping(cfg.FieldMap) - provider.cfg.ParsedAdminGroupDN = dn - - return provider, nil + return &LdapAuthenticator{cfg: cfg}, nil } // GetName returns the name of the LDAP authenticator. @@ -154,40 +143,3 @@ func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.Authentica return userInfo, nil } - -func (l LdapAuthenticator) getLdapFieldMapping(f config.LdapFields) config.LdapFields { - defaultMap := config.LdapFields{ - BaseFields: config.BaseFields{ - UserIdentifier: "mail", - Email: "mail", - Firstname: "givenName", - Lastname: "sn", - Phone: "telephoneNumber", - Department: "department", - }, - GroupMembership: "memberOf", - } - if f.UserIdentifier != "" { - defaultMap.UserIdentifier = f.UserIdentifier - } - if f.Email != "" { - defaultMap.Email = f.Email - } - if f.Firstname != "" { - defaultMap.Firstname = f.Firstname - } - if f.Lastname != "" { - defaultMap.Lastname = f.Lastname - } - if f.Phone != "" { - defaultMap.Phone = f.Phone - } - if f.Department != "" { - defaultMap.Department = f.Department - } - if f.GroupMembership != "" { - defaultMap.GroupMembership = f.GroupMembership - } - - return defaultMap -} diff --git a/internal/config/auth.go b/internal/config/auth.go index 34dfec6..4314b63 100644 --- a/internal/config/auth.go +++ b/internal/config/auth.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "log/slog" "regexp" "time" @@ -125,6 +126,45 @@ type LdapFields struct { GroupMembership string `yaml:"memberof"` } +// getMappingWithDefaults returns a full field mapping for the LDAP provider. +// If specific fields are not set, the default values are used. +func (f LdapFields) getMappingWithDefaults() LdapFields { + defaultMap := LdapFields{ + BaseFields: BaseFields{ + UserIdentifier: "mail", + Email: "mail", + Firstname: "givenName", + Lastname: "sn", + Phone: "telephoneNumber", + Department: "department", + }, + GroupMembership: "memberOf", + } + if f.UserIdentifier != "" { + defaultMap.UserIdentifier = f.UserIdentifier + } + if f.Email != "" { + defaultMap.Email = f.Email + } + if f.Firstname != "" { + defaultMap.Firstname = f.Firstname + } + if f.Lastname != "" { + defaultMap.Lastname = f.Lastname + } + if f.Phone != "" { + defaultMap.Phone = f.Phone + } + if f.Department != "" { + defaultMap.Department = f.Department + } + if f.GroupMembership != "" { + defaultMap.GroupMembership = f.GroupMembership + } + + return defaultMap +} + // LdapProvider contains the configuration for the LDAP connection. type LdapProvider struct { // ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters. @@ -178,6 +218,19 @@ type LdapProvider struct { LogUserInfo bool `yaml:"log_user_info"` } +// Sanitize checks the LDAP configuration and sets default values for missing fields. +func (l *LdapProvider) Sanitize() error { + l.FieldMap = l.FieldMap.getMappingWithDefaults() + + dn, err := ldap.ParseDN(l.AdminGroupDN) + if err != nil { + return fmt.Errorf("failed to parse admin group DN: %w", err) + } + l.ParsedAdminGroupDN = dn + + return nil +} + // OpenIDConnectProvider contains the configuration for the OpenID Connect provider. type OpenIDConnectProvider struct { // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. diff --git a/internal/config/config.go b/internal/config/config.go index b2d5b8c..5cd4bff 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -177,7 +177,8 @@ func defaultConfig() *Config { cfg.Statistics.PingCheckWorkers = getEnvInt("WG_PORTAL_STATISTICS_PING_CHECK_WORKERS", 10) cfg.Statistics.PingUnprivileged = getEnvBool("WG_PORTAL_STATISTICS_PING_UNPRIVILEGED", false) cfg.Statistics.PingCheckInterval = getEnvDuration("WG_PORTAL_STATISTICS_PING_CHECK_INTERVAL", 1*time.Minute) - cfg.Statistics.DataCollectionInterval = getEnvDuration("WG_PORTAL_STATISTICS_DATA_COLLECTION_INTERVAL", 1*time.Minute) + cfg.Statistics.DataCollectionInterval = getEnvDuration("WG_PORTAL_STATISTICS_DATA_COLLECTION_INTERVAL", + 1*time.Minute) cfg.Statistics.CollectInterfaceData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_INTERFACE_DATA", true) cfg.Statistics.CollectPeerData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_PEER_DATA", true) cfg.Statistics.CollectAuditData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_AUDIT_DATA", true) @@ -235,6 +236,11 @@ func GetConfig() (*Config, error) { if err != nil { return nil, err } + for i := range cfg.Auth.Ldap { + if err := cfg.Auth.Ldap[i].Sanitize(); err != nil { + return nil, fmt.Errorf("sanitizing of ldap config for %s failed: %w", cfg.Auth.Ldap[i].ProviderName, err) + } + } return cfg, nil }