feat: add support for PKCE (#686)

This commit is contained in:
Christoph Haas
2026-05-26 22:47:38 +02:00
parent 0cf04d07e0
commit 4c986cc74c
10 changed files with 295 additions and 28 deletions

View File

@@ -47,6 +47,11 @@ const (
AuthenticatorTypeOidc AuthenticatorType = "oidc"
)
const (
pkceMethodS256 = "S256" // SHA-256 hashing
pkceMethodPlain = "plain" // plain text
)
// AuthenticatorOauth is the interface for all OAuth authenticators.
type AuthenticatorOauth interface {
// GetName returns the name of the authenticator.
@@ -70,6 +75,10 @@ type AuthenticatorOauth interface {
GetAllowedUserGroups() []string
// GetLogoutUrl returns an IdP logout URL if supported by the provider.
GetLogoutUrl(idTokenHint, postLogoutRedirectUri string) (string, bool)
// PKCEAuthCodeOptions returns PKCE options for the authorization request and the verifier for the token exchange.
PKCEAuthCodeOptions() ([]oauth2.AuthCodeOption, string)
// PKCETokenOptions returns PKCE options for the token exchange.
PKCETokenOptions(verifier string) []oauth2.AuthCodeOption
}
// AuthenticatorLdap is the interface for all LDAP authenticators.
@@ -448,30 +457,34 @@ func (a *Authenticator) passwordAuthentication(
// OauthLoginStep1 starts the oauth authentication flow by returning the authentication URL, state and nonce.
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
authCodeUrl, state, nonce string,
authCodeUrl, state, nonce, codeVerifier string,
err error,
) {
oauthProvider, ok := a.oauthAuthenticators[providerId]
if !ok {
return "", "", "", fmt.Errorf("missing oauth provider %s", providerId)
return "", "", "", "", fmt.Errorf("missing oauth provider %s", providerId)
}
// Prepare authentication flow, set state cookies
state, err = a.randString(16)
if err != nil {
return "", "", "", fmt.Errorf("failed to generate state: %w", err)
return "", "", "", "", fmt.Errorf("failed to generate state: %w", err)
}
// Generate PKCE code verifier and challenge if enabled. Otherwise, options will be empty.
authCodeOptions, codeVerifier := oauthProvider.PKCEAuthCodeOptions()
switch oauthProvider.GetType() {
case AuthenticatorTypeOAuth:
authCodeUrl = oauthProvider.AuthCodeURL(state)
authCodeUrl = oauthProvider.AuthCodeURL(state, authCodeOptions...)
case AuthenticatorTypeOidc:
nonce, err = a.randString(16)
if err != nil {
return "", "", "", fmt.Errorf("failed to generate nonce: %w", err)
return "", "", "", "", fmt.Errorf("failed to generate nonce: %w", err)
}
authCodeUrl = oauthProvider.AuthCodeURL(state, oidc.Nonce(nonce))
authCodeOptions = append(authCodeOptions, oidc.Nonce(nonce))
authCodeUrl = oauthProvider.AuthCodeURL(state, authCodeOptions...)
}
return
@@ -531,13 +544,16 @@ func isAnyAllowedUserGroup(userGroups, allowedUserGroups []string) bool {
// OauthLoginStep2 finishes the oauth authentication flow by exchanging the code for an access token and
// fetching the user information.
func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, string, error) {
func (a *Authenticator) OauthLoginStep2(
ctx context.Context,
providerId, nonce, code, codeVerifier string,
) (*domain.User, string, error) {
oauthProvider, ok := a.oauthAuthenticators[providerId]
if !ok {
return nil, "", fmt.Errorf("missing oauth provider %s", providerId)
}
oauth2Token, err := oauthProvider.Exchange(ctx, code)
oauth2Token, err := oauthProvider.Exchange(ctx, code, oauthProvider.PKCETokenOptions(codeVerifier)...)
if err != nil {
return nil, "", fmt.Errorf("unable to exchange code: %w", err)
}

View File

@@ -30,6 +30,8 @@ type PlainOauthAuthenticator struct {
sensitiveInfoLogging bool
allowedDomains []string
allowedUserGroups []string
usePKCE bool
pkceMethod string
}
func newPlainOauthAuthenticator(
@@ -62,6 +64,14 @@ func newPlainOauthAuthenticator(
provider.sensitiveInfoLogging = cfg.LogSensitiveInfo
provider.allowedDomains = cfg.AllowedDomains
provider.allowedUserGroups = cfg.AllowedUserGroups
provider.usePKCE = cfg.UsePKCE == nil || *cfg.UsePKCE
provider.pkceMethod = cfg.PKCEMethod
if provider.pkceMethod == "" {
provider.pkceMethod = pkceMethodS256
}
if provider.usePKCE && provider.pkceMethod != pkceMethodS256 && provider.pkceMethod != pkceMethodPlain {
return nil, fmt.Errorf("unsupported PKCE method %q, allowed: S256, plain", provider.pkceMethod)
}
return provider, nil
}
@@ -83,6 +93,32 @@ func (p PlainOauthAuthenticator) GetLogoutUrl(_, _ string) (string, bool) {
return "", false
}
// PKCEAuthCodeOptions returns PKCE options for the authorization request and the verifier for the token exchange.
func (p PlainOauthAuthenticator) PKCEAuthCodeOptions() ([]oauth2.AuthCodeOption, string) {
if !p.usePKCE {
return nil, ""
}
verifier := oauth2.GenerateVerifier()
if p.pkceMethod == pkceMethodPlain {
return []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", verifier),
oauth2.SetAuthURLParam("code_challenge_method", pkceMethodPlain),
}, verifier
}
return []oauth2.AuthCodeOption{oauth2.S256ChallengeOption(verifier)}, verifier
}
// PKCETokenOptions returns PKCE options for the token exchange.
func (p PlainOauthAuthenticator) PKCETokenOptions(verifier string) []oauth2.AuthCodeOption {
if !p.usePKCE || verifier == "" {
return nil
}
return []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
}
// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator.
func (p PlainOauthAuthenticator) RegistrationEnabled() bool {
return p.registrationEnabled

View File

@@ -0,0 +1,61 @@
package auth
import (
"testing"
"golang.org/x/oauth2"
)
func TestPlainOauthAuthenticatorPKCES256Options(t *testing.T) {
authenticator := PlainOauthAuthenticator{usePKCE: true, pkceMethod: "S256"}
options, verifier := authenticator.PKCEAuthCodeOptions()
if verifier == "" {
t.Fatal("expected verifier")
}
values := authCodeValues(t, options)
if values.Get("code_challenge") == "" {
t.Fatal("expected code_challenge")
}
if values.Get("code_challenge_method") != "S256" {
t.Fatalf("expected S256 challenge method, got %q", values.Get("code_challenge_method"))
}
tokenOptions := authenticator.PKCETokenOptions(verifier)
if len(tokenOptions) != 1 {
t.Fatalf("expected one token option, got %d", len(tokenOptions))
}
}
func TestPlainOauthAuthenticatorPKCEPlainOptions(t *testing.T) {
authenticator := PlainOauthAuthenticator{usePKCE: true, pkceMethod: "plain"}
options, verifier := authenticator.PKCEAuthCodeOptions()
values := authCodeValues(t, options)
if values.Get("code_challenge") != verifier {
t.Fatalf("expected plain challenge %q, got %q", verifier, values.Get("code_challenge"))
}
if values.Get("code_challenge_method") != "plain" {
t.Fatalf("expected plain challenge method, got %q", values.Get("code_challenge_method"))
}
}
func TestPlainOauthAuthenticatorPKCEDisabled(t *testing.T) {
authenticator := PlainOauthAuthenticator{usePKCE: false, pkceMethod: "S256"}
options, verifier := authenticator.PKCEAuthCodeOptions()
if len(options) != 0 {
t.Fatalf("expected no auth code options, got %d", len(options))
}
if verifier != "" {
t.Fatalf("expected empty verifier, got %q", verifier)
}
tokenOptions := authenticator.PKCETokenOptions(oauth2.GenerateVerifier())
if len(tokenOptions) != 0 {
t.Fatalf("expected no token options, got %d", len(tokenOptions))
}
}

View File

@@ -30,6 +30,8 @@ type OidcAuthenticator struct {
allowedUserGroups []string
endSessionEndpoint string
logoutIdpSession bool
usePKCE bool
pkceMethod string
}
func newOidcAuthenticator(
@@ -67,6 +69,14 @@ func newOidcAuthenticator(
provider.allowedDomains = cfg.AllowedDomains
provider.allowedUserGroups = cfg.AllowedUserGroups
provider.logoutIdpSession = cfg.LogoutIdpSession == nil || *cfg.LogoutIdpSession
provider.usePKCE = cfg.UsePKCE == nil || *cfg.UsePKCE
provider.pkceMethod = cfg.PKCEMethod
if provider.pkceMethod == "" {
provider.pkceMethod = pkceMethodS256
}
if provider.usePKCE && provider.pkceMethod != pkceMethodS256 && provider.pkceMethod != pkceMethodPlain {
return nil, fmt.Errorf("unsupported PKCE method %q, allowed: S256, plain", provider.pkceMethod)
}
var providerMetadata struct {
EndSessionEndpoint string `json:"end_session_endpoint"`
@@ -121,6 +131,32 @@ func (o OidcAuthenticator) GetLogoutUrl(idTokenHint, postLogoutRedirectUri strin
return logoutUrl.String(), true
}
// PKCEAuthCodeOptions returns PKCE options for the authorization request and the verifier for the token exchange.
func (o OidcAuthenticator) PKCEAuthCodeOptions() ([]oauth2.AuthCodeOption, string) {
if !o.usePKCE {
return nil, ""
}
verifier := oauth2.GenerateVerifier()
if o.pkceMethod == pkceMethodPlain {
return []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", verifier),
oauth2.SetAuthURLParam("code_challenge_method", pkceMethodPlain),
}, verifier
}
return []oauth2.AuthCodeOption{oauth2.S256ChallengeOption(verifier)}, verifier
}
// PKCETokenOptions returns PKCE options for the token exchange.
func (o OidcAuthenticator) PKCETokenOptions(verifier string) []oauth2.AuthCodeOption {
if !o.usePKCE || verifier == "" {
return nil
}
return []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
}
// RegistrationEnabled returns whether registration is enabled for this authenticator.
func (o OidcAuthenticator) RegistrationEnabled() bool {
return o.registrationEnabled

View File

@@ -0,0 +1,79 @@
package auth
import (
"net/url"
"testing"
"golang.org/x/oauth2"
)
func authCodeValues(t *testing.T, options []oauth2.AuthCodeOption) url.Values {
t.Helper()
config := oauth2.Config{
ClientID: "client-id",
Endpoint: oauth2.Endpoint{AuthURL: "https://example.com/auth"},
RedirectURL: "https://wg.example.com/callback",
}
authCodeURL, err := url.Parse(config.AuthCodeURL("state", options...))
if err != nil {
t.Fatalf("failed to parse auth code URL: %v", err)
}
return authCodeURL.Query()
}
func TestOidcAuthenticatorPKCES256Options(t *testing.T) {
authenticator := OidcAuthenticator{usePKCE: true, pkceMethod: "S256"}
options, verifier := authenticator.PKCEAuthCodeOptions()
if verifier == "" {
t.Fatal("expected verifier")
}
values := authCodeValues(t, options)
if values.Get("code_challenge") == "" {
t.Fatal("expected code_challenge")
}
if values.Get("code_challenge_method") != "S256" {
t.Fatalf("expected S256 challenge method, got %q", values.Get("code_challenge_method"))
}
tokenOptions := authenticator.PKCETokenOptions(verifier)
if len(tokenOptions) != 1 {
t.Fatalf("expected one token option, got %d", len(tokenOptions))
}
}
func TestOidcAuthenticatorPKCEPlainOptions(t *testing.T) {
authenticator := OidcAuthenticator{usePKCE: true, pkceMethod: "plain"}
options, verifier := authenticator.PKCEAuthCodeOptions()
values := authCodeValues(t, options)
if values.Get("code_challenge") != verifier {
t.Fatalf("expected plain challenge %q, got %q", verifier, values.Get("code_challenge"))
}
if values.Get("code_challenge_method") != "plain" {
t.Fatalf("expected plain challenge method, got %q", values.Get("code_challenge_method"))
}
}
func TestOidcAuthenticatorPKCEDisabled(t *testing.T) {
authenticator := OidcAuthenticator{usePKCE: false, pkceMethod: "S256"}
options, verifier := authenticator.PKCEAuthCodeOptions()
if len(options) != 0 {
t.Fatalf("expected no auth code options, got %d", len(options))
}
if verifier != "" {
t.Fatalf("expected empty verifier, got %q", verifier)
}
tokenOptions := authenticator.PKCETokenOptions(oauth2.GenerateVerifier())
if len(tokenOptions) != 0 {
t.Fatalf("expected no token options, got %d", len(tokenOptions))
}
}