From 3eb84f0ee9a1f90b8e88f75dcb0a77306e953801 Mon Sep 17 00:00:00 2001 From: Vladimir Dombrovski Date: Mon, 5 May 2025 18:26:19 +0200 Subject: [PATCH] Enable allowed_domains in oauth and oidc providers (#416) * Enable allowed_domains in oauth and oidc providers Signed-off-by: Vladimir DOMBROVSKI * Domain check code cleanup * Run gofmt on domain validation code --------- Signed-off-by: Vladimir DOMBROVSKI --- internal/app/auth/auth.go | 23 +++++++++++++++++++++++ internal/app/auth/auth_oauth.go | 6 ++++++ internal/app/auth/auth_oidc.go | 6 ++++++ internal/config/auth.go | 6 ++++++ 4 files changed, 41 insertions(+) diff --git a/internal/app/auth/auth.go b/internal/app/auth/auth.go index bbda138..380fe81 100644 --- a/internal/app/auth/auth.go +++ b/internal/app/auth/auth.go @@ -63,6 +63,8 @@ type AuthenticatorOauth interface { ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) // RegistrationEnabled returns whether registration is enabled for the OAuth authenticator. RegistrationEnabled() bool + // GetAllowedDomains returns the list of whitelisted domains + GetAllowedDomains() []string } // AuthenticatorLdap is the interface for all LDAP authenticators. @@ -392,6 +394,23 @@ func (a *Authenticator) randString(nByte int) (string, error) { return base64.RawURLEncoding.EncodeToString(b), nil } +func isDomainAllowed(email string, allowedDomains []string) bool { + if len(allowedDomains) == 0 { + return true + } + parts := strings.Split(email, "@") + if len(parts) != 2 { + return false + } + domain := strings.ToLower(parts[1]) + for _, allowed := range allowedDomains { + if domain == strings.ToLower(allowed) { + return true + } + } + return false +} + // OauthLoginStep2 finishes the oauth authentication flow by exchanging the code for an access token and // fetching the user information. func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error) { @@ -431,6 +450,10 @@ func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, return nil, fmt.Errorf("unable to process user information: %w", err) } + if !isDomainAllowed(userInfo.Email, oauthProvider.GetAllowedDomains()) { + return nil, fmt.Errorf("user is not in allowed domains: %w", err) + } + if user.IsLocked() || user.IsDisabled() { a.bus.Publish(app.TopicAuditLoginFailed, domain.AuditEventWrapper[audit.AuthEvent]{ Ctx: ctx, diff --git a/internal/app/auth/auth_oauth.go b/internal/app/auth/auth_oauth.go index 7e730bc..56d53c5 100644 --- a/internal/app/auth/auth_oauth.go +++ b/internal/app/auth/auth_oauth.go @@ -27,6 +27,7 @@ type PlainOauthAuthenticator struct { userAdminMapping *config.OauthAdminMapping registrationEnabled bool userInfoLogging bool + allowedDomains []string } func newPlainOauthAuthenticator( @@ -56,6 +57,7 @@ func newPlainOauthAuthenticator( provider.userAdminMapping = &cfg.AdminMapping provider.registrationEnabled = cfg.RegistrationEnabled provider.userInfoLogging = cfg.LogUserInfo + provider.allowedDomains = cfg.AllowedDomains return provider, nil } @@ -65,6 +67,10 @@ func (p PlainOauthAuthenticator) GetName() string { return p.name } +func (p PlainOauthAuthenticator) GetAllowedDomains() []string { + return p.allowedDomains +} + // RegistrationEnabled returns whether registration is enabled for the OAuth authenticator. func (p PlainOauthAuthenticator) RegistrationEnabled() bool { return p.registrationEnabled diff --git a/internal/app/auth/auth_oidc.go b/internal/app/auth/auth_oidc.go index d832768..0a4ecb0 100644 --- a/internal/app/auth/auth_oidc.go +++ b/internal/app/auth/auth_oidc.go @@ -24,6 +24,7 @@ type OidcAuthenticator struct { userAdminMapping *config.OauthAdminMapping registrationEnabled bool userInfoLogging bool + allowedDomains []string } func newOidcAuthenticator( @@ -57,6 +58,7 @@ func newOidcAuthenticator( provider.userAdminMapping = &cfg.AdminMapping provider.registrationEnabled = cfg.RegistrationEnabled provider.userInfoLogging = cfg.LogUserInfo + provider.allowedDomains = cfg.AllowedDomains return provider, nil } @@ -66,6 +68,10 @@ func (o OidcAuthenticator) GetName() string { return o.name } +func (o OidcAuthenticator) GetAllowedDomains() []string { + return o.allowedDomains +} + // RegistrationEnabled returns whether registration is enabled for this authenticator. func (o OidcAuthenticator) RegistrationEnabled() bool { return o.registrationEnabled diff --git a/internal/config/auth.go b/internal/config/auth.go index 7e70dae..3132fb7 100644 --- a/internal/config/auth.go +++ b/internal/config/auth.go @@ -188,6 +188,9 @@ type OpenIDConnectProvider struct { // ExtraScopes specifies optional requested permissions. ExtraScopes []string `yaml:"extra_scopes"` + // AllowedDomains defines the list of allowed domains + AllowedDomains []string `yaml:"allowed_domains"` + // FieldMap is used to map the names of the user-info endpoint fields to wg-portal fields FieldMap OauthFields `yaml:"field_map"` @@ -226,6 +229,9 @@ type OAuthProvider struct { // Scope specifies optional requested permissions. Scopes []string `yaml:"scopes"` + // AllowedDomains defines the list of allowed domains + AllowedDomains []string `yaml:"allowed_domains"` + // FieldMap is used to map the names of the user-info endpoint fields to wg-portal fields FieldMap OauthFields `yaml:"field_map"`