mirror of
https://github.com/h44z/wg-portal.git
synced 2026-05-28 17:06:18 +00:00
feat: add support for PKCE (#686)
This commit is contained in:
@@ -24,9 +24,9 @@ type AuthenticationService interface {
|
||||
// PlainLogin authenticates a user with a username and password.
|
||||
PlainLogin(ctx context.Context, username, password string) (*domain.User, error)
|
||||
// OauthLoginStep1 initiates the OAuth login flow.
|
||||
OauthLoginStep1(_ context.Context, providerId string) (authCodeUrl, state, nonce string, err error)
|
||||
OauthLoginStep1(_ context.Context, providerId string) (authCodeUrl, state, nonce, codeVerifier string, err error)
|
||||
// OauthLoginStep2 completes the OAuth login flow and logins the user in.
|
||||
OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, string, error)
|
||||
OauthLoginStep2(ctx context.Context, providerId, nonce, code, codeVerifier string) (*domain.User, string, error)
|
||||
// OauthProviderLogoutUrl returns an IdP logout URL for the given provider if supported.
|
||||
OauthProviderLogoutUrl(providerId, idTokenHint, postLogoutRedirectUri string) (string, bool)
|
||||
}
|
||||
@@ -231,7 +231,7 @@ func (e AuthEndpoint) handleOauthInitiateGet() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
authCodeUrl, state, nonce, err := e.authService.OauthLoginStep1(context.Background(), provider)
|
||||
authCodeUrl, state, nonce, codeVerifier, err := e.authService.OauthLoginStep1(context.Background(), provider)
|
||||
if err != nil {
|
||||
slog.Debug("failed to create oauth auth code URL",
|
||||
"provider", provider, "error", err)
|
||||
@@ -247,6 +247,7 @@ func (e AuthEndpoint) handleOauthInitiateGet() http.HandlerFunc {
|
||||
authSession := e.session.GetData(r.Context())
|
||||
authSession.OauthState = state
|
||||
authSession.OauthNonce = nonce
|
||||
authSession.OauthCodeVerifier = codeVerifier
|
||||
authSession.OauthProvider = provider
|
||||
authSession.OauthReturnTo = returnTo
|
||||
e.session.SetData(r.Context(), authSession)
|
||||
@@ -323,7 +324,7 @@ func (e AuthEndpoint) handleOauthCallbackGet() http.HandlerFunc {
|
||||
|
||||
loginCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) // avoid long waits
|
||||
user, idTokenHint, err := e.authService.OauthLoginStep2(loginCtx, provider, currentSession.OauthNonce,
|
||||
oauthCode)
|
||||
oauthCode, currentSession.OauthCodeVerifier)
|
||||
cancel()
|
||||
if err != nil {
|
||||
slog.Debug("failed to process oauth code",
|
||||
@@ -362,6 +363,7 @@ func (e AuthEndpoint) setAuthenticatedUser(r *http.Request, user *domain.User, o
|
||||
|
||||
currentSession.OauthState = ""
|
||||
currentSession.OauthNonce = ""
|
||||
currentSession.OauthCodeVerifier = ""
|
||||
currentSession.OauthProvider = oauthProvider
|
||||
currentSession.OauthReturnTo = ""
|
||||
currentSession.OauthIdToken = idTokenHint
|
||||
|
||||
@@ -26,11 +26,12 @@ type SessionData struct {
|
||||
Lastname string
|
||||
Email string
|
||||
|
||||
OauthState string
|
||||
OauthNonce string
|
||||
OauthProvider string
|
||||
OauthReturnTo string
|
||||
OauthIdToken string
|
||||
OauthState string
|
||||
OauthNonce string
|
||||
OauthCodeVerifier string
|
||||
OauthProvider string
|
||||
OauthReturnTo string
|
||||
OauthIdToken string
|
||||
|
||||
WebAuthnData string
|
||||
|
||||
@@ -80,16 +81,17 @@ func (s *SessionWrapper) DestroyData(ctx context.Context) {
|
||||
|
||||
func (s *SessionWrapper) defaultSessionData() SessionData {
|
||||
return SessionData{
|
||||
LoggedIn: false,
|
||||
IsAdmin: false,
|
||||
UserIdentifier: "",
|
||||
Firstname: "",
|
||||
Lastname: "",
|
||||
Email: "",
|
||||
OauthState: "",
|
||||
OauthNonce: "",
|
||||
OauthProvider: "",
|
||||
OauthReturnTo: "",
|
||||
OauthIdToken: "",
|
||||
LoggedIn: false,
|
||||
IsAdmin: false,
|
||||
UserIdentifier: "",
|
||||
Firstname: "",
|
||||
Lastname: "",
|
||||
Email: "",
|
||||
OauthState: "",
|
||||
OauthNonce: "",
|
||||
OauthCodeVerifier: "",
|
||||
OauthProvider: "",
|
||||
OauthReturnTo: "",
|
||||
OauthIdToken: "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +47,11 @@ const (
|
||||
AuthenticatorTypeOidc AuthenticatorType = "oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
pkceMethodS256 = "S256" // SHA-256 hashing
|
||||
pkceMethodPlain = "plain" // plain text
|
||||
)
|
||||
|
||||
// AuthenticatorOauth is the interface for all OAuth authenticators.
|
||||
type AuthenticatorOauth interface {
|
||||
// GetName returns the name of the authenticator.
|
||||
@@ -70,6 +75,10 @@ type AuthenticatorOauth interface {
|
||||
GetAllowedUserGroups() []string
|
||||
// GetLogoutUrl returns an IdP logout URL if supported by the provider.
|
||||
GetLogoutUrl(idTokenHint, postLogoutRedirectUri string) (string, bool)
|
||||
// PKCEAuthCodeOptions returns PKCE options for the authorization request and the verifier for the token exchange.
|
||||
PKCEAuthCodeOptions() ([]oauth2.AuthCodeOption, string)
|
||||
// PKCETokenOptions returns PKCE options for the token exchange.
|
||||
PKCETokenOptions(verifier string) []oauth2.AuthCodeOption
|
||||
}
|
||||
|
||||
// AuthenticatorLdap is the interface for all LDAP authenticators.
|
||||
@@ -448,30 +457,34 @@ func (a *Authenticator) passwordAuthentication(
|
||||
|
||||
// OauthLoginStep1 starts the oauth authentication flow by returning the authentication URL, state and nonce.
|
||||
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
|
||||
authCodeUrl, state, nonce string,
|
||||
authCodeUrl, state, nonce, codeVerifier string,
|
||||
err error,
|
||||
) {
|
||||
oauthProvider, ok := a.oauthAuthenticators[providerId]
|
||||
if !ok {
|
||||
return "", "", "", fmt.Errorf("missing oauth provider %s", providerId)
|
||||
return "", "", "", "", fmt.Errorf("missing oauth provider %s", providerId)
|
||||
}
|
||||
|
||||
// Prepare authentication flow, set state cookies
|
||||
state, err = a.randString(16)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("failed to generate state: %w", err)
|
||||
return "", "", "", "", fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Generate PKCE code verifier and challenge if enabled. Otherwise, options will be empty.
|
||||
authCodeOptions, codeVerifier := oauthProvider.PKCEAuthCodeOptions()
|
||||
|
||||
switch oauthProvider.GetType() {
|
||||
case AuthenticatorTypeOAuth:
|
||||
authCodeUrl = oauthProvider.AuthCodeURL(state)
|
||||
authCodeUrl = oauthProvider.AuthCodeURL(state, authCodeOptions...)
|
||||
case AuthenticatorTypeOidc:
|
||||
nonce, err = a.randString(16)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
return "", "", "", "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
authCodeUrl = oauthProvider.AuthCodeURL(state, oidc.Nonce(nonce))
|
||||
authCodeOptions = append(authCodeOptions, oidc.Nonce(nonce))
|
||||
authCodeUrl = oauthProvider.AuthCodeURL(state, authCodeOptions...)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -531,13 +544,16 @@ func isAnyAllowedUserGroup(userGroups, allowedUserGroups []string) bool {
|
||||
|
||||
// OauthLoginStep2 finishes the oauth authentication flow by exchanging the code for an access token and
|
||||
// fetching the user information.
|
||||
func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, string, error) {
|
||||
func (a *Authenticator) OauthLoginStep2(
|
||||
ctx context.Context,
|
||||
providerId, nonce, code, codeVerifier string,
|
||||
) (*domain.User, string, error) {
|
||||
oauthProvider, ok := a.oauthAuthenticators[providerId]
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("missing oauth provider %s", providerId)
|
||||
}
|
||||
|
||||
oauth2Token, err := oauthProvider.Exchange(ctx, code)
|
||||
oauth2Token, err := oauthProvider.Exchange(ctx, code, oauthProvider.PKCETokenOptions(codeVerifier)...)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unable to exchange code: %w", err)
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ type PlainOauthAuthenticator struct {
|
||||
sensitiveInfoLogging bool
|
||||
allowedDomains []string
|
||||
allowedUserGroups []string
|
||||
usePKCE bool
|
||||
pkceMethod string
|
||||
}
|
||||
|
||||
func newPlainOauthAuthenticator(
|
||||
@@ -62,6 +64,14 @@ func newPlainOauthAuthenticator(
|
||||
provider.sensitiveInfoLogging = cfg.LogSensitiveInfo
|
||||
provider.allowedDomains = cfg.AllowedDomains
|
||||
provider.allowedUserGroups = cfg.AllowedUserGroups
|
||||
provider.usePKCE = cfg.UsePKCE == nil || *cfg.UsePKCE
|
||||
provider.pkceMethod = cfg.PKCEMethod
|
||||
if provider.pkceMethod == "" {
|
||||
provider.pkceMethod = pkceMethodS256
|
||||
}
|
||||
if provider.usePKCE && provider.pkceMethod != pkceMethodS256 && provider.pkceMethod != pkceMethodPlain {
|
||||
return nil, fmt.Errorf("unsupported PKCE method %q, allowed: S256, plain", provider.pkceMethod)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
@@ -83,6 +93,32 @@ func (p PlainOauthAuthenticator) GetLogoutUrl(_, _ string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// PKCEAuthCodeOptions returns PKCE options for the authorization request and the verifier for the token exchange.
|
||||
func (p PlainOauthAuthenticator) PKCEAuthCodeOptions() ([]oauth2.AuthCodeOption, string) {
|
||||
if !p.usePKCE {
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
verifier := oauth2.GenerateVerifier()
|
||||
if p.pkceMethod == pkceMethodPlain {
|
||||
return []oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam("code_challenge", verifier),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", pkceMethodPlain),
|
||||
}, verifier
|
||||
}
|
||||
|
||||
return []oauth2.AuthCodeOption{oauth2.S256ChallengeOption(verifier)}, verifier
|
||||
}
|
||||
|
||||
// PKCETokenOptions returns PKCE options for the token exchange.
|
||||
func (p PlainOauthAuthenticator) PKCETokenOptions(verifier string) []oauth2.AuthCodeOption {
|
||||
if !p.usePKCE || verifier == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
|
||||
}
|
||||
|
||||
// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator.
|
||||
func (p PlainOauthAuthenticator) RegistrationEnabled() bool {
|
||||
return p.registrationEnabled
|
||||
|
||||
61
internal/app/auth/auth_oauth_test.go
Normal file
61
internal/app/auth/auth_oauth_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestPlainOauthAuthenticatorPKCES256Options(t *testing.T) {
|
||||
authenticator := PlainOauthAuthenticator{usePKCE: true, pkceMethod: "S256"}
|
||||
|
||||
options, verifier := authenticator.PKCEAuthCodeOptions()
|
||||
if verifier == "" {
|
||||
t.Fatal("expected verifier")
|
||||
}
|
||||
|
||||
values := authCodeValues(t, options)
|
||||
|
||||
if values.Get("code_challenge") == "" {
|
||||
t.Fatal("expected code_challenge")
|
||||
}
|
||||
if values.Get("code_challenge_method") != "S256" {
|
||||
t.Fatalf("expected S256 challenge method, got %q", values.Get("code_challenge_method"))
|
||||
}
|
||||
|
||||
tokenOptions := authenticator.PKCETokenOptions(verifier)
|
||||
if len(tokenOptions) != 1 {
|
||||
t.Fatalf("expected one token option, got %d", len(tokenOptions))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlainOauthAuthenticatorPKCEPlainOptions(t *testing.T) {
|
||||
authenticator := PlainOauthAuthenticator{usePKCE: true, pkceMethod: "plain"}
|
||||
|
||||
options, verifier := authenticator.PKCEAuthCodeOptions()
|
||||
values := authCodeValues(t, options)
|
||||
|
||||
if values.Get("code_challenge") != verifier {
|
||||
t.Fatalf("expected plain challenge %q, got %q", verifier, values.Get("code_challenge"))
|
||||
}
|
||||
if values.Get("code_challenge_method") != "plain" {
|
||||
t.Fatalf("expected plain challenge method, got %q", values.Get("code_challenge_method"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlainOauthAuthenticatorPKCEDisabled(t *testing.T) {
|
||||
authenticator := PlainOauthAuthenticator{usePKCE: false, pkceMethod: "S256"}
|
||||
|
||||
options, verifier := authenticator.PKCEAuthCodeOptions()
|
||||
if len(options) != 0 {
|
||||
t.Fatalf("expected no auth code options, got %d", len(options))
|
||||
}
|
||||
if verifier != "" {
|
||||
t.Fatalf("expected empty verifier, got %q", verifier)
|
||||
}
|
||||
|
||||
tokenOptions := authenticator.PKCETokenOptions(oauth2.GenerateVerifier())
|
||||
if len(tokenOptions) != 0 {
|
||||
t.Fatalf("expected no token options, got %d", len(tokenOptions))
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,8 @@ type OidcAuthenticator struct {
|
||||
allowedUserGroups []string
|
||||
endSessionEndpoint string
|
||||
logoutIdpSession bool
|
||||
usePKCE bool
|
||||
pkceMethod string
|
||||
}
|
||||
|
||||
func newOidcAuthenticator(
|
||||
@@ -67,6 +69,14 @@ func newOidcAuthenticator(
|
||||
provider.allowedDomains = cfg.AllowedDomains
|
||||
provider.allowedUserGroups = cfg.AllowedUserGroups
|
||||
provider.logoutIdpSession = cfg.LogoutIdpSession == nil || *cfg.LogoutIdpSession
|
||||
provider.usePKCE = cfg.UsePKCE == nil || *cfg.UsePKCE
|
||||
provider.pkceMethod = cfg.PKCEMethod
|
||||
if provider.pkceMethod == "" {
|
||||
provider.pkceMethod = pkceMethodS256
|
||||
}
|
||||
if provider.usePKCE && provider.pkceMethod != pkceMethodS256 && provider.pkceMethod != pkceMethodPlain {
|
||||
return nil, fmt.Errorf("unsupported PKCE method %q, allowed: S256, plain", provider.pkceMethod)
|
||||
}
|
||||
|
||||
var providerMetadata struct {
|
||||
EndSessionEndpoint string `json:"end_session_endpoint"`
|
||||
@@ -121,6 +131,32 @@ func (o OidcAuthenticator) GetLogoutUrl(idTokenHint, postLogoutRedirectUri strin
|
||||
return logoutUrl.String(), true
|
||||
}
|
||||
|
||||
// PKCEAuthCodeOptions returns PKCE options for the authorization request and the verifier for the token exchange.
|
||||
func (o OidcAuthenticator) PKCEAuthCodeOptions() ([]oauth2.AuthCodeOption, string) {
|
||||
if !o.usePKCE {
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
verifier := oauth2.GenerateVerifier()
|
||||
if o.pkceMethod == pkceMethodPlain {
|
||||
return []oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam("code_challenge", verifier),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", pkceMethodPlain),
|
||||
}, verifier
|
||||
}
|
||||
|
||||
return []oauth2.AuthCodeOption{oauth2.S256ChallengeOption(verifier)}, verifier
|
||||
}
|
||||
|
||||
// PKCETokenOptions returns PKCE options for the token exchange.
|
||||
func (o OidcAuthenticator) PKCETokenOptions(verifier string) []oauth2.AuthCodeOption {
|
||||
if !o.usePKCE || verifier == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []oauth2.AuthCodeOption{oauth2.VerifierOption(verifier)}
|
||||
}
|
||||
|
||||
// RegistrationEnabled returns whether registration is enabled for this authenticator.
|
||||
func (o OidcAuthenticator) RegistrationEnabled() bool {
|
||||
return o.registrationEnabled
|
||||
|
||||
79
internal/app/auth/auth_oidc_test.go
Normal file
79
internal/app/auth/auth_oidc_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func authCodeValues(t *testing.T, options []oauth2.AuthCodeOption) url.Values {
|
||||
t.Helper()
|
||||
|
||||
config := oauth2.Config{
|
||||
ClientID: "client-id",
|
||||
Endpoint: oauth2.Endpoint{AuthURL: "https://example.com/auth"},
|
||||
RedirectURL: "https://wg.example.com/callback",
|
||||
}
|
||||
authCodeURL, err := url.Parse(config.AuthCodeURL("state", options...))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse auth code URL: %v", err)
|
||||
}
|
||||
|
||||
return authCodeURL.Query()
|
||||
}
|
||||
|
||||
func TestOidcAuthenticatorPKCES256Options(t *testing.T) {
|
||||
authenticator := OidcAuthenticator{usePKCE: true, pkceMethod: "S256"}
|
||||
|
||||
options, verifier := authenticator.PKCEAuthCodeOptions()
|
||||
if verifier == "" {
|
||||
t.Fatal("expected verifier")
|
||||
}
|
||||
|
||||
values := authCodeValues(t, options)
|
||||
|
||||
if values.Get("code_challenge") == "" {
|
||||
t.Fatal("expected code_challenge")
|
||||
}
|
||||
if values.Get("code_challenge_method") != "S256" {
|
||||
t.Fatalf("expected S256 challenge method, got %q", values.Get("code_challenge_method"))
|
||||
}
|
||||
|
||||
tokenOptions := authenticator.PKCETokenOptions(verifier)
|
||||
if len(tokenOptions) != 1 {
|
||||
t.Fatalf("expected one token option, got %d", len(tokenOptions))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestOidcAuthenticatorPKCEPlainOptions(t *testing.T) {
|
||||
authenticator := OidcAuthenticator{usePKCE: true, pkceMethod: "plain"}
|
||||
|
||||
options, verifier := authenticator.PKCEAuthCodeOptions()
|
||||
values := authCodeValues(t, options)
|
||||
|
||||
if values.Get("code_challenge") != verifier {
|
||||
t.Fatalf("expected plain challenge %q, got %q", verifier, values.Get("code_challenge"))
|
||||
}
|
||||
if values.Get("code_challenge_method") != "plain" {
|
||||
t.Fatalf("expected plain challenge method, got %q", values.Get("code_challenge_method"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOidcAuthenticatorPKCEDisabled(t *testing.T) {
|
||||
authenticator := OidcAuthenticator{usePKCE: false, pkceMethod: "S256"}
|
||||
|
||||
options, verifier := authenticator.PKCEAuthCodeOptions()
|
||||
if len(options) != 0 {
|
||||
t.Fatalf("expected no auth code options, got %d", len(options))
|
||||
}
|
||||
if verifier != "" {
|
||||
t.Fatalf("expected empty verifier, got %q", verifier)
|
||||
}
|
||||
|
||||
tokenOptions := authenticator.PKCETokenOptions(oauth2.GenerateVerifier())
|
||||
if len(tokenOptions) != 0 {
|
||||
t.Fatalf("expected no token options, got %d", len(tokenOptions))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user