From 4c986cc74c696748b49a46f6ebfd015d4e60510d Mon Sep 17 00:00:00 2001 From: Christoph Haas Date: Tue, 26 May 2026 22:47:38 +0200 Subject: [PATCH] feat: add support for PKCE (#686) --- config.yml.sample | 3 + docs/documentation/configuration/overview.md | 16 ++++ .../v0/handlers/endpoint_authentication.go | 10 ++- internal/app/api/v0/handlers/web_session.go | 34 ++++---- internal/app/auth/auth.go | 32 ++++++-- internal/app/auth/auth_oauth.go | 36 +++++++++ internal/app/auth/auth_oauth_test.go | 61 ++++++++++++++ internal/app/auth/auth_oidc.go | 36 +++++++++ internal/app/auth/auth_oidc_test.go | 79 +++++++++++++++++++ internal/config/auth.go | 16 ++++ 10 files changed, 295 insertions(+), 28 deletions(-) create mode 100644 internal/app/auth/auth_oauth_test.go create mode 100644 internal/app/auth/auth_oidc_test.go diff --git a/config.yml.sample b/config.yml.sample index 3be9ce5..7baed32 100644 --- a/config.yml.sample +++ b/config.yml.sample @@ -47,6 +47,8 @@ auth: extra_scopes: - https://www.googleapis.com/auth/userinfo.email - https://www.googleapis.com/auth/userinfo.profile + use_pkce: true + pkce_method: S256 registration_enabled: true logout_idp_session: true - id: oidc2 @@ -79,6 +81,7 @@ auth: user_identifier: sub is_admin: this-attribute-must-be-true registration_enabled: true + use_pkce: false - id: google_plain_oauth_with_groups provider_name: google4 display_name: Login with
Google4 diff --git a/docs/documentation/configuration/overview.md b/docs/documentation/configuration/overview.md index 5157f9e..10ec936 100644 --- a/docs/documentation/configuration/overview.md +++ b/docs/documentation/configuration/overview.md @@ -617,6 +617,14 @@ Below are the properties for each OIDC provider entry inside `auth.oidc`: - **Description:** If `true`, sensitive OIDC user data, such as tokens and raw responses, will be logged at the trace level upon login (for debugging). - **Important:** Keep this setting disabled in production environments! Remove logs once you finished debugging authentication issues. +#### `use_pkce` +- **Default:** `true` +- **Description:** If `true`, Proof Key for Code Exchange (PKCE) is used for the OIDC authorization code flow. A fresh `code_verifier` is generated per login request, the matching `code_challenge` is sent with the authorization request, and the `code_verifier` is included in the token exchange. Set to `false` only for providers that do not support PKCE. + +#### `pkce_method` +- **Default:** `S256` +- **Description:** PKCE challenge method to use when `use_pkce` is enabled. Supported values are `S256` and `plain`. `S256` is recommended; use `plain` only for providers that explicitly require it. + #### `logout_idp_session` - **Default:** `true` - **Description:** If `true` (default), WireGuard Portal will redirect the user to the OIDC provider's `end_session_endpoint` after local logout, terminating the session at the IdP as well. Set to `false` to only invalidate the local WireGuard Portal session without touching the IdP session. @@ -703,6 +711,14 @@ Below are the properties for each OAuth provider entry inside `auth.oauth`: - **Description:** If `true`, sensitive OIDC user data, such as tokens and raw responses, will be logged at the trace level upon login (for debugging). - **Important:** Keep this setting disabled in production environments! Remove logs once you finished debugging authentication issues. +#### `use_pkce` +- **Default:** `true` +- **Description:** If `true`, Proof Key for Code Exchange (PKCE) is used for the OIDC authorization code flow. A fresh `code_verifier` is generated per login request, the matching `code_challenge` is sent with the authorization request, and the `code_verifier` is included in the token exchange. Set to `false` only for providers that do not support PKCE. + +#### `pkce_method` +- **Default:** `S256` +- **Description:** PKCE challenge method to use when `use_pkce` is enabled. Supported values are `S256` and `plain`. `S256` is recommended; use `plain` only for providers that explicitly require it. + --- ### LDAP diff --git a/internal/app/api/v0/handlers/endpoint_authentication.go b/internal/app/api/v0/handlers/endpoint_authentication.go index 7e9aedd..3b3f872 100644 --- a/internal/app/api/v0/handlers/endpoint_authentication.go +++ b/internal/app/api/v0/handlers/endpoint_authentication.go @@ -24,9 +24,9 @@ type AuthenticationService interface { // PlainLogin authenticates a user with a username and password. PlainLogin(ctx context.Context, username, password string) (*domain.User, error) // OauthLoginStep1 initiates the OAuth login flow. - OauthLoginStep1(_ context.Context, providerId string) (authCodeUrl, state, nonce string, err error) + OauthLoginStep1(_ context.Context, providerId string) (authCodeUrl, state, nonce, codeVerifier string, err error) // OauthLoginStep2 completes the OAuth login flow and logins the user in. - OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, string, error) + OauthLoginStep2(ctx context.Context, providerId, nonce, code, codeVerifier string) (*domain.User, string, error) // OauthProviderLogoutUrl returns an IdP logout URL for the given provider if supported. OauthProviderLogoutUrl(providerId, idTokenHint, postLogoutRedirectUri string) (string, bool) } @@ -231,7 +231,7 @@ func (e AuthEndpoint) handleOauthInitiateGet() http.HandlerFunc { return } - authCodeUrl, state, nonce, err := e.authService.OauthLoginStep1(context.Background(), provider) + authCodeUrl, state, nonce, codeVerifier, err := e.authService.OauthLoginStep1(context.Background(), provider) if err != nil { slog.Debug("failed to create oauth auth code URL", "provider", provider, "error", err) @@ -247,6 +247,7 @@ func (e AuthEndpoint) handleOauthInitiateGet() http.HandlerFunc { authSession := e.session.GetData(r.Context()) authSession.OauthState = state authSession.OauthNonce = nonce + authSession.OauthCodeVerifier = codeVerifier authSession.OauthProvider = provider authSession.OauthReturnTo = returnTo e.session.SetData(r.Context(), authSession) @@ -323,7 +324,7 @@ func (e AuthEndpoint) handleOauthCallbackGet() http.HandlerFunc { loginCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // avoid long waits user, idTokenHint, err := e.authService.OauthLoginStep2(loginCtx, provider, currentSession.OauthNonce, - oauthCode) + oauthCode, currentSession.OauthCodeVerifier) cancel() if err != nil { slog.Debug("failed to process oauth code", @@ -362,6 +363,7 @@ func (e AuthEndpoint) setAuthenticatedUser(r *http.Request, user *domain.User, o currentSession.OauthState = "" currentSession.OauthNonce = "" + currentSession.OauthCodeVerifier = "" currentSession.OauthProvider = oauthProvider currentSession.OauthReturnTo = "" currentSession.OauthIdToken = idTokenHint diff --git a/internal/app/api/v0/handlers/web_session.go b/internal/app/api/v0/handlers/web_session.go index 5b3ba26..cf893b0 100644 --- a/internal/app/api/v0/handlers/web_session.go +++ b/internal/app/api/v0/handlers/web_session.go @@ -26,11 +26,12 @@ type SessionData struct { Lastname string Email string - OauthState string - OauthNonce string - OauthProvider string - OauthReturnTo string - OauthIdToken string + OauthState string + OauthNonce string + OauthCodeVerifier string + OauthProvider string + OauthReturnTo string + OauthIdToken string WebAuthnData string @@ -80,16 +81,17 @@ func (s *SessionWrapper) DestroyData(ctx context.Context) { func (s *SessionWrapper) defaultSessionData() SessionData { return SessionData{ - LoggedIn: false, - IsAdmin: false, - UserIdentifier: "", - Firstname: "", - Lastname: "", - Email: "", - OauthState: "", - OauthNonce: "", - OauthProvider: "", - OauthReturnTo: "", - OauthIdToken: "", + LoggedIn: false, + IsAdmin: false, + UserIdentifier: "", + Firstname: "", + Lastname: "", + Email: "", + OauthState: "", + OauthNonce: "", + OauthCodeVerifier: "", + OauthProvider: "", + OauthReturnTo: "", + OauthIdToken: "", } } diff --git a/internal/app/auth/auth.go b/internal/app/auth/auth.go index c4ce8b3..1f4f9db 100644 --- a/internal/app/auth/auth.go +++ b/internal/app/auth/auth.go @@ -47,6 +47,11 @@ const ( AuthenticatorTypeOidc AuthenticatorType = "oidc" ) +const ( + pkceMethodS256 = "S256" // SHA-256 hashing + pkceMethodPlain = "plain" // plain text +) + // AuthenticatorOauth is the interface for all OAuth authenticators. type AuthenticatorOauth interface { // GetName returns the name of the authenticator. @@ -70,6 +75,10 @@ type AuthenticatorOauth interface { GetAllowedUserGroups() []string // GetLogoutUrl returns an IdP logout URL if supported by the provider. GetLogoutUrl(idTokenHint, postLogoutRedirectUri string) (string, bool) + // PKCEAuthCodeOptions returns PKCE options for the authorization request and the verifier for the token exchange. + PKCEAuthCodeOptions() ([]oauth2.AuthCodeOption, string) + // PKCETokenOptions returns PKCE options for the token exchange. + PKCETokenOptions(verifier string) []oauth2.AuthCodeOption } // AuthenticatorLdap is the interface for all LDAP authenticators. @@ -448,30 +457,34 @@ func (a *Authenticator) passwordAuthentication( // OauthLoginStep1 starts the oauth authentication flow by returning the authentication URL, state and nonce. func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) ( - authCodeUrl, state, nonce string, + authCodeUrl, state, nonce, codeVerifier string, err error, ) { oauthProvider, ok := a.oauthAuthenticators[providerId] if !ok { - return "", "", "", fmt.Errorf("missing oauth provider %s", providerId) + return "", "", "", "", fmt.Errorf("missing oauth provider %s", providerId) } // Prepare authentication flow, set state cookies state, err = a.randString(16) if err != nil { - return "", "", "", fmt.Errorf("failed to generate state: %w", err) + return "", "", "", "", fmt.Errorf("failed to generate state: %w", err) } + // Generate PKCE code verifier and challenge if enabled. Otherwise, options will be empty. + authCodeOptions, codeVerifier := oauthProvider.PKCEAuthCodeOptions() + switch oauthProvider.GetType() { case AuthenticatorTypeOAuth: - authCodeUrl = oauthProvider.AuthCodeURL(state) + authCodeUrl = oauthProvider.AuthCodeURL(state, authCodeOptions...) case AuthenticatorTypeOidc: nonce, err = a.randString(16) if err != nil { - return "", "", "", fmt.Errorf("failed to generate nonce: %w", err) + return "", "", "", "", fmt.Errorf("failed to generate nonce: %w", err) } - authCodeUrl = oauthProvider.AuthCodeURL(state, oidc.Nonce(nonce)) + authCodeOptions = append(authCodeOptions, oidc.Nonce(nonce)) + authCodeUrl = oauthProvider.AuthCodeURL(state, authCodeOptions...) } return @@ -531,13 +544,16 @@ func isAnyAllowedUserGroup(userGroups, allowedUserGroups []string) bool { // OauthLoginStep2 finishes the oauth authentication flow by exchanging the code for an access token and // fetching the user information. -func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, string, error) { +func (a *Authenticator) OauthLoginStep2( + ctx context.Context, + providerId, nonce, code, codeVerifier string, +) (*domain.User, string, error) { oauthProvider, ok := a.oauthAuthenticators[providerId] if !ok { return nil, "", fmt.Errorf("missing oauth provider %s", providerId) } - oauth2Token, err := oauthProvider.Exchange(ctx, code) + oauth2Token, err := oauthProvider.Exchange(ctx, code, oauthProvider.PKCETokenOptions(codeVerifier)...) if err != nil { return nil, "", fmt.Errorf("unable to exchange code: %w", err) } diff --git a/internal/app/auth/auth_oauth.go b/internal/app/auth/auth_oauth.go index 55eb7e4..d7be206 100644 --- a/internal/app/auth/auth_oauth.go +++ b/internal/app/auth/auth_oauth.go @@ -30,6 +30,8 @@ type PlainOauthAuthenticator struct { sensitiveInfoLogging bool allowedDomains []string allowedUserGroups []string + usePKCE bool + pkceMethod string } func newPlainOauthAuthenticator( @@ -62,6 +64,14 @@ func newPlainOauthAuthenticator( provider.sensitiveInfoLogging = cfg.LogSensitiveInfo provider.allowedDomains = cfg.AllowedDomains provider.allowedUserGroups = cfg.AllowedUserGroups + provider.usePKCE = cfg.UsePKCE == nil || *cfg.UsePKCE + provider.pkceMethod = cfg.PKCEMethod + if provider.pkceMethod == "" { + provider.pkceMethod = pkceMethodS256 + } + if provider.usePKCE && provider.pkceMethod != pkceMethodS256 && provider.pkceMethod != pkceMethodPlain { + return nil, fmt.Errorf("unsupported PKCE method %q, allowed: S256, plain", provider.pkceMethod) + } return provider, nil } @@ -83,6 +93,32 @@ func (p PlainOauthAuthenticator) GetLogoutUrl(_, _ string) (string, bool) { return "", false } +// PKCEAuthCodeOptions returns PKCE options for the authorization request and the verifier for the token exchange. +func (p PlainOauthAuthenticator) PKCEAuthCodeOptions() ([]oauth2.AuthCodeOption, string) { + if !p.usePKCE { + return nil, "" + } + + verifier := oauth2.GenerateVerifier() + if p.pkceMethod == pkceMethodPlain { + return []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_challenge", verifier), + oauth2.SetAuthURLParam("code_challenge_method", pkceMethodPlain), + }, verifier + } + + return []oauth2.AuthCodeOption{oauth2.S256ChallengeOption(verifier)}, verifier +} + +// PKCETokenOptions returns PKCE options for the token exchange. +func (p PlainOauthAuthenticator) PKCETokenOptions(verifier string) []oauth2.AuthCodeOption { + if !p.usePKCE || verifier == "" { + return nil + } + + return []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)} +} + // RegistrationEnabled returns whether registration is enabled for the OAuth authenticator. func (p PlainOauthAuthenticator) RegistrationEnabled() bool { return p.registrationEnabled diff --git a/internal/app/auth/auth_oauth_test.go b/internal/app/auth/auth_oauth_test.go new file mode 100644 index 0000000..431066c --- /dev/null +++ b/internal/app/auth/auth_oauth_test.go @@ -0,0 +1,61 @@ +package auth + +import ( + "testing" + + "golang.org/x/oauth2" +) + +func TestPlainOauthAuthenticatorPKCES256Options(t *testing.T) { + authenticator := PlainOauthAuthenticator{usePKCE: true, pkceMethod: "S256"} + + options, verifier := authenticator.PKCEAuthCodeOptions() + if verifier == "" { + t.Fatal("expected verifier") + } + + values := authCodeValues(t, options) + + if values.Get("code_challenge") == "" { + t.Fatal("expected code_challenge") + } + if values.Get("code_challenge_method") != "S256" { + t.Fatalf("expected S256 challenge method, got %q", values.Get("code_challenge_method")) + } + + tokenOptions := authenticator.PKCETokenOptions(verifier) + if len(tokenOptions) != 1 { + t.Fatalf("expected one token option, got %d", len(tokenOptions)) + } +} + +func TestPlainOauthAuthenticatorPKCEPlainOptions(t *testing.T) { + authenticator := PlainOauthAuthenticator{usePKCE: true, pkceMethod: "plain"} + + options, verifier := authenticator.PKCEAuthCodeOptions() + values := authCodeValues(t, options) + + if values.Get("code_challenge") != verifier { + t.Fatalf("expected plain challenge %q, got %q", verifier, values.Get("code_challenge")) + } + if values.Get("code_challenge_method") != "plain" { + t.Fatalf("expected plain challenge method, got %q", values.Get("code_challenge_method")) + } +} + +func TestPlainOauthAuthenticatorPKCEDisabled(t *testing.T) { + authenticator := PlainOauthAuthenticator{usePKCE: false, pkceMethod: "S256"} + + options, verifier := authenticator.PKCEAuthCodeOptions() + if len(options) != 0 { + t.Fatalf("expected no auth code options, got %d", len(options)) + } + if verifier != "" { + t.Fatalf("expected empty verifier, got %q", verifier) + } + + tokenOptions := authenticator.PKCETokenOptions(oauth2.GenerateVerifier()) + if len(tokenOptions) != 0 { + t.Fatalf("expected no token options, got %d", len(tokenOptions)) + } +} diff --git a/internal/app/auth/auth_oidc.go b/internal/app/auth/auth_oidc.go index 5bcdbc7..571a46d 100644 --- a/internal/app/auth/auth_oidc.go +++ b/internal/app/auth/auth_oidc.go @@ -30,6 +30,8 @@ type OidcAuthenticator struct { allowedUserGroups []string endSessionEndpoint string logoutIdpSession bool + usePKCE bool + pkceMethod string } func newOidcAuthenticator( @@ -67,6 +69,14 @@ func newOidcAuthenticator( provider.allowedDomains = cfg.AllowedDomains provider.allowedUserGroups = cfg.AllowedUserGroups provider.logoutIdpSession = cfg.LogoutIdpSession == nil || *cfg.LogoutIdpSession + provider.usePKCE = cfg.UsePKCE == nil || *cfg.UsePKCE + provider.pkceMethod = cfg.PKCEMethod + if provider.pkceMethod == "" { + provider.pkceMethod = pkceMethodS256 + } + if provider.usePKCE && provider.pkceMethod != pkceMethodS256 && provider.pkceMethod != pkceMethodPlain { + return nil, fmt.Errorf("unsupported PKCE method %q, allowed: S256, plain", provider.pkceMethod) + } var providerMetadata struct { EndSessionEndpoint string `json:"end_session_endpoint"` @@ -121,6 +131,32 @@ func (o OidcAuthenticator) GetLogoutUrl(idTokenHint, postLogoutRedirectUri strin return logoutUrl.String(), true } +// PKCEAuthCodeOptions returns PKCE options for the authorization request and the verifier for the token exchange. +func (o OidcAuthenticator) PKCEAuthCodeOptions() ([]oauth2.AuthCodeOption, string) { + if !o.usePKCE { + return nil, "" + } + + verifier := oauth2.GenerateVerifier() + if o.pkceMethod == pkceMethodPlain { + return []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("code_challenge", verifier), + oauth2.SetAuthURLParam("code_challenge_method", pkceMethodPlain), + }, verifier + } + + return []oauth2.AuthCodeOption{oauth2.S256ChallengeOption(verifier)}, verifier +} + +// PKCETokenOptions returns PKCE options for the token exchange. +func (o OidcAuthenticator) PKCETokenOptions(verifier string) []oauth2.AuthCodeOption { + if !o.usePKCE || verifier == "" { + return nil + } + + return []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)} +} + // RegistrationEnabled returns whether registration is enabled for this authenticator. func (o OidcAuthenticator) RegistrationEnabled() bool { return o.registrationEnabled diff --git a/internal/app/auth/auth_oidc_test.go b/internal/app/auth/auth_oidc_test.go new file mode 100644 index 0000000..d57de35 --- /dev/null +++ b/internal/app/auth/auth_oidc_test.go @@ -0,0 +1,79 @@ +package auth + +import ( + "net/url" + "testing" + + "golang.org/x/oauth2" +) + +func authCodeValues(t *testing.T, options []oauth2.AuthCodeOption) url.Values { + t.Helper() + + config := oauth2.Config{ + ClientID: "client-id", + Endpoint: oauth2.Endpoint{AuthURL: "https://example.com/auth"}, + RedirectURL: "https://wg.example.com/callback", + } + authCodeURL, err := url.Parse(config.AuthCodeURL("state", options...)) + if err != nil { + t.Fatalf("failed to parse auth code URL: %v", err) + } + + return authCodeURL.Query() +} + +func TestOidcAuthenticatorPKCES256Options(t *testing.T) { + authenticator := OidcAuthenticator{usePKCE: true, pkceMethod: "S256"} + + options, verifier := authenticator.PKCEAuthCodeOptions() + if verifier == "" { + t.Fatal("expected verifier") + } + + values := authCodeValues(t, options) + + if values.Get("code_challenge") == "" { + t.Fatal("expected code_challenge") + } + if values.Get("code_challenge_method") != "S256" { + t.Fatalf("expected S256 challenge method, got %q", values.Get("code_challenge_method")) + } + + tokenOptions := authenticator.PKCETokenOptions(verifier) + if len(tokenOptions) != 1 { + t.Fatalf("expected one token option, got %d", len(tokenOptions)) + } + +} + +func TestOidcAuthenticatorPKCEPlainOptions(t *testing.T) { + authenticator := OidcAuthenticator{usePKCE: true, pkceMethod: "plain"} + + options, verifier := authenticator.PKCEAuthCodeOptions() + values := authCodeValues(t, options) + + if values.Get("code_challenge") != verifier { + t.Fatalf("expected plain challenge %q, got %q", verifier, values.Get("code_challenge")) + } + if values.Get("code_challenge_method") != "plain" { + t.Fatalf("expected plain challenge method, got %q", values.Get("code_challenge_method")) + } +} + +func TestOidcAuthenticatorPKCEDisabled(t *testing.T) { + authenticator := OidcAuthenticator{usePKCE: false, pkceMethod: "S256"} + + options, verifier := authenticator.PKCEAuthCodeOptions() + if len(options) != 0 { + t.Fatalf("expected no auth code options, got %d", len(options)) + } + if verifier != "" { + t.Fatalf("expected empty verifier, got %q", verifier) + } + + tokenOptions := authenticator.PKCETokenOptions(oauth2.GenerateVerifier()) + if len(tokenOptions) != 0 { + t.Fatalf("expected no token options, got %d", len(tokenOptions)) + } +} diff --git a/internal/config/auth.go b/internal/config/auth.go index bd2bc34..fef4422 100644 --- a/internal/config/auth.go +++ b/internal/config/auth.go @@ -279,6 +279,14 @@ type OpenIDConnectProvider struct { // This also includes OAuth tokens! Keep this disabled in production! LogSensitiveInfo bool `yaml:"log_sensitive_info"` + // UsePKCE controls whether Proof Key for Code Exchange is used during the authorization code flow. + // If unset, PKCE is enabled by default. + UsePKCE *bool `yaml:"use_pkce"` + + // PKCEMethod controls which PKCE challenge method is used. Supported values are "S256" and "plain". + // If empty, "S256" is used. + PKCEMethod string `yaml:"pkce_method"` + // LogoutIdpSession controls whether the user's session at the OIDC provider is terminated on logout. // If set to true (default), the user will be redirected to the IdP's end_session_endpoint after local logout. // If set to false, only the local wg-portal session is invalidated. @@ -332,6 +340,14 @@ type OAuthProvider struct { // If LogSensitiveInfo is set to true, sensitive information retrieved from the OAuth provider will be logged in trace level. // This also includes OAuth tokens! Keep this disabled in production! LogSensitiveInfo bool `yaml:"log_sensitive_info"` + + // UsePKCE controls whether Proof Key for Code Exchange is used during the authorization code flow. + // If unset, PKCE is enabled by default. + UsePKCE *bool `yaml:"use_pkce"` + + // PKCEMethod controls which PKCE challenge method is used. Supported values are "S256" and "plain". + // If empty, "S256" is used. + PKCEMethod string `yaml:"pkce_method"` } // WebauthnConfig contains the configuration for the WebAuthn authenticator.