mirror of
https://github.com/h44z/wg-portal.git
synced 2025-11-21 08:16:18 +00:00
fix race condition during ldap initialization (#571)
This commit is contained in:
@@ -20,18 +20,7 @@ type LdapAuthenticator struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newLdapAuthenticator(_ context.Context, cfg *config.LdapProvider) (*LdapAuthenticator, error) {
|
func newLdapAuthenticator(_ context.Context, cfg *config.LdapProvider) (*LdapAuthenticator, error) {
|
||||||
var provider = &LdapAuthenticator{}
|
return &LdapAuthenticator{cfg: cfg}, nil
|
||||||
|
|
||||||
provider.cfg = cfg
|
|
||||||
|
|
||||||
dn, err := ldap.ParseDN(cfg.AdminGroupDN)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse admin group DN: %w", err)
|
|
||||||
}
|
|
||||||
provider.cfg.FieldMap = provider.getLdapFieldMapping(cfg.FieldMap)
|
|
||||||
provider.cfg.ParsedAdminGroupDN = dn
|
|
||||||
|
|
||||||
return provider, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetName returns the name of the LDAP authenticator.
|
// GetName returns the name of the LDAP authenticator.
|
||||||
@@ -154,40 +143,3 @@ func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.Authentica
|
|||||||
|
|
||||||
return userInfo, nil
|
return userInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l LdapAuthenticator) getLdapFieldMapping(f config.LdapFields) config.LdapFields {
|
|
||||||
defaultMap := config.LdapFields{
|
|
||||||
BaseFields: config.BaseFields{
|
|
||||||
UserIdentifier: "mail",
|
|
||||||
Email: "mail",
|
|
||||||
Firstname: "givenName",
|
|
||||||
Lastname: "sn",
|
|
||||||
Phone: "telephoneNumber",
|
|
||||||
Department: "department",
|
|
||||||
},
|
|
||||||
GroupMembership: "memberOf",
|
|
||||||
}
|
|
||||||
if f.UserIdentifier != "" {
|
|
||||||
defaultMap.UserIdentifier = f.UserIdentifier
|
|
||||||
}
|
|
||||||
if f.Email != "" {
|
|
||||||
defaultMap.Email = f.Email
|
|
||||||
}
|
|
||||||
if f.Firstname != "" {
|
|
||||||
defaultMap.Firstname = f.Firstname
|
|
||||||
}
|
|
||||||
if f.Lastname != "" {
|
|
||||||
defaultMap.Lastname = f.Lastname
|
|
||||||
}
|
|
||||||
if f.Phone != "" {
|
|
||||||
defaultMap.Phone = f.Phone
|
|
||||||
}
|
|
||||||
if f.Department != "" {
|
|
||||||
defaultMap.Department = f.Department
|
|
||||||
}
|
|
||||||
if f.GroupMembership != "" {
|
|
||||||
defaultMap.GroupMembership = f.GroupMembership
|
|
||||||
}
|
|
||||||
|
|
||||||
return defaultMap
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"regexp"
|
"regexp"
|
||||||
"time"
|
"time"
|
||||||
@@ -125,6 +126,45 @@ type LdapFields struct {
|
|||||||
GroupMembership string `yaml:"memberof"`
|
GroupMembership string `yaml:"memberof"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getMappingWithDefaults returns a full field mapping for the LDAP provider.
|
||||||
|
// If specific fields are not set, the default values are used.
|
||||||
|
func (f LdapFields) getMappingWithDefaults() LdapFields {
|
||||||
|
defaultMap := LdapFields{
|
||||||
|
BaseFields: BaseFields{
|
||||||
|
UserIdentifier: "mail",
|
||||||
|
Email: "mail",
|
||||||
|
Firstname: "givenName",
|
||||||
|
Lastname: "sn",
|
||||||
|
Phone: "telephoneNumber",
|
||||||
|
Department: "department",
|
||||||
|
},
|
||||||
|
GroupMembership: "memberOf",
|
||||||
|
}
|
||||||
|
if f.UserIdentifier != "" {
|
||||||
|
defaultMap.UserIdentifier = f.UserIdentifier
|
||||||
|
}
|
||||||
|
if f.Email != "" {
|
||||||
|
defaultMap.Email = f.Email
|
||||||
|
}
|
||||||
|
if f.Firstname != "" {
|
||||||
|
defaultMap.Firstname = f.Firstname
|
||||||
|
}
|
||||||
|
if f.Lastname != "" {
|
||||||
|
defaultMap.Lastname = f.Lastname
|
||||||
|
}
|
||||||
|
if f.Phone != "" {
|
||||||
|
defaultMap.Phone = f.Phone
|
||||||
|
}
|
||||||
|
if f.Department != "" {
|
||||||
|
defaultMap.Department = f.Department
|
||||||
|
}
|
||||||
|
if f.GroupMembership != "" {
|
||||||
|
defaultMap.GroupMembership = f.GroupMembership
|
||||||
|
}
|
||||||
|
|
||||||
|
return defaultMap
|
||||||
|
}
|
||||||
|
|
||||||
// LdapProvider contains the configuration for the LDAP connection.
|
// LdapProvider contains the configuration for the LDAP connection.
|
||||||
type LdapProvider struct {
|
type LdapProvider struct {
|
||||||
// ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters.
|
// ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters.
|
||||||
@@ -178,6 +218,19 @@ type LdapProvider struct {
|
|||||||
LogUserInfo bool `yaml:"log_user_info"`
|
LogUserInfo bool `yaml:"log_user_info"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sanitize checks the LDAP configuration and sets default values for missing fields.
|
||||||
|
func (l *LdapProvider) Sanitize() error {
|
||||||
|
l.FieldMap = l.FieldMap.getMappingWithDefaults()
|
||||||
|
|
||||||
|
dn, err := ldap.ParseDN(l.AdminGroupDN)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse admin group DN: %w", err)
|
||||||
|
}
|
||||||
|
l.ParsedAdminGroupDN = dn
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// OpenIDConnectProvider contains the configuration for the OpenID Connect provider.
|
// OpenIDConnectProvider contains the configuration for the OpenID Connect provider.
|
||||||
type OpenIDConnectProvider struct {
|
type OpenIDConnectProvider struct {
|
||||||
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
|
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
|
||||||
|
|||||||
@@ -177,7 +177,8 @@ func defaultConfig() *Config {
|
|||||||
cfg.Statistics.PingCheckWorkers = getEnvInt("WG_PORTAL_STATISTICS_PING_CHECK_WORKERS", 10)
|
cfg.Statistics.PingCheckWorkers = getEnvInt("WG_PORTAL_STATISTICS_PING_CHECK_WORKERS", 10)
|
||||||
cfg.Statistics.PingUnprivileged = getEnvBool("WG_PORTAL_STATISTICS_PING_UNPRIVILEGED", false)
|
cfg.Statistics.PingUnprivileged = getEnvBool("WG_PORTAL_STATISTICS_PING_UNPRIVILEGED", false)
|
||||||
cfg.Statistics.PingCheckInterval = getEnvDuration("WG_PORTAL_STATISTICS_PING_CHECK_INTERVAL", 1*time.Minute)
|
cfg.Statistics.PingCheckInterval = getEnvDuration("WG_PORTAL_STATISTICS_PING_CHECK_INTERVAL", 1*time.Minute)
|
||||||
cfg.Statistics.DataCollectionInterval = getEnvDuration("WG_PORTAL_STATISTICS_DATA_COLLECTION_INTERVAL", 1*time.Minute)
|
cfg.Statistics.DataCollectionInterval = getEnvDuration("WG_PORTAL_STATISTICS_DATA_COLLECTION_INTERVAL",
|
||||||
|
1*time.Minute)
|
||||||
cfg.Statistics.CollectInterfaceData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_INTERFACE_DATA", true)
|
cfg.Statistics.CollectInterfaceData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_INTERFACE_DATA", true)
|
||||||
cfg.Statistics.CollectPeerData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_PEER_DATA", true)
|
cfg.Statistics.CollectPeerData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_PEER_DATA", true)
|
||||||
cfg.Statistics.CollectAuditData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_AUDIT_DATA", true)
|
cfg.Statistics.CollectAuditData = getEnvBool("WG_PORTAL_STATISTICS_COLLECT_AUDIT_DATA", true)
|
||||||
@@ -235,6 +236,11 @@ func GetConfig() (*Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
for i := range cfg.Auth.Ldap {
|
||||||
|
if err := cfg.Auth.Ldap[i].Sanitize(); err != nil {
|
||||||
|
return nil, fmt.Errorf("sanitizing of ldap config for %s failed: %w", cfg.Auth.Ldap[i].ProviderName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user