mirror of
https://github.com/h44z/wg-portal.git
synced 2025-04-19 08:55:12 +00:00
596 lines
18 KiB
Go
596 lines
18 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/url"
|
|
"path"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"golang.org/x/oauth2"
|
|
|
|
"github.com/h44z/wg-portal/internal/app"
|
|
"github.com/h44z/wg-portal/internal/app/audit"
|
|
"github.com/h44z/wg-portal/internal/config"
|
|
"github.com/h44z/wg-portal/internal/domain"
|
|
)
|
|
|
|
// region dependencies
|
|
|
|
type UserManager interface {
|
|
// GetUser returns a user by its identifier.
|
|
GetUser(context.Context, domain.UserIdentifier) (*domain.User, error)
|
|
// RegisterUser creates a new user in the database.
|
|
RegisterUser(ctx context.Context, user *domain.User) error
|
|
// UpdateUser updates an existing user in the database.
|
|
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
|
}
|
|
|
|
type EventBus interface {
|
|
// Publish sends a message to the message bus.
|
|
Publish(topic string, args ...any)
|
|
}
|
|
|
|
// endregion dependencies
|
|
|
|
type AuthenticatorType string
|
|
|
|
const (
|
|
AuthenticatorTypeOAuth AuthenticatorType = "oauth"
|
|
AuthenticatorTypeOidc AuthenticatorType = "oidc"
|
|
)
|
|
|
|
// AuthenticatorOauth is the interface for all OAuth authenticators.
|
|
type AuthenticatorOauth interface {
|
|
// GetName returns the name of the authenticator.
|
|
GetName() string
|
|
// GetType returns the type of the authenticator. It can be either AuthenticatorTypeOAuth or AuthenticatorTypeOidc.
|
|
GetType() AuthenticatorType
|
|
// AuthCodeURL returns the URL for the authentication flow.
|
|
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
|
|
// Exchange exchanges the OAuth code for an access token.
|
|
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
|
|
// GetUserInfo fetches the user information from the OAuth or OIDC provider.
|
|
GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]any, error)
|
|
// ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct.
|
|
ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error)
|
|
// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator.
|
|
RegistrationEnabled() bool
|
|
}
|
|
|
|
// AuthenticatorLdap is the interface for all LDAP authenticators.
|
|
type AuthenticatorLdap interface {
|
|
// GetName returns the name of the authenticator.
|
|
GetName() string
|
|
// PlaintextAuthentication performs a plaintext authentication against the LDAP server.
|
|
PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error
|
|
// GetUserInfo fetches the user information from the LDAP server.
|
|
GetUserInfo(ctx context.Context, username domain.UserIdentifier) (map[string]any, error)
|
|
// ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct.
|
|
ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error)
|
|
// RegistrationEnabled returns whether registration is enabled for the LDAP authenticator.
|
|
RegistrationEnabled() bool
|
|
}
|
|
|
|
// Authenticator is the main entry point for all authentication related tasks.
|
|
// This includes password authentication and external authentication providers (OIDC, OAuth, LDAP).
|
|
type Authenticator struct {
|
|
cfg *config.Auth
|
|
bus EventBus
|
|
|
|
oauthAuthenticators map[string]AuthenticatorOauth
|
|
ldapAuthenticators map[string]AuthenticatorLdap
|
|
|
|
// URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix
|
|
callbackUrlPrefix string
|
|
|
|
users UserManager
|
|
}
|
|
|
|
// NewAuthenticator creates a new Authenticator instance.
|
|
func NewAuthenticator(cfg *config.Auth, extUrl string, bus EventBus, users UserManager) (
|
|
*Authenticator,
|
|
error,
|
|
) {
|
|
a := &Authenticator{
|
|
cfg: cfg,
|
|
bus: bus,
|
|
users: users,
|
|
callbackUrlPrefix: fmt.Sprintf("%s/api/v0", extUrl),
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
err := a.setupExternalAuthProviders(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return a, nil
|
|
}
|
|
|
|
func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
|
|
extUrl, err := url.Parse(a.callbackUrlPrefix)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse external url: %w", err)
|
|
}
|
|
|
|
a.oauthAuthenticators = make(map[string]AuthenticatorOauth, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth))
|
|
a.ldapAuthenticators = make(map[string]AuthenticatorLdap, len(a.cfg.Ldap))
|
|
|
|
for i := range a.cfg.OpenIDConnect { // OIDC
|
|
providerCfg := &a.cfg.OpenIDConnect[i]
|
|
providerId := strings.ToLower(providerCfg.ProviderName)
|
|
|
|
if _, exists := a.oauthAuthenticators[providerId]; exists {
|
|
return fmt.Errorf("auth provider with name %s is already registerd", providerId)
|
|
}
|
|
|
|
redirectUrl := *extUrl
|
|
redirectUrl.Path = path.Join(redirectUrl.Path, "/auth/login/", providerId, "/callback")
|
|
|
|
provider, err := newOidcAuthenticator(ctx, redirectUrl.String(), providerCfg)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to setup oidc authentication provider %s: %w", providerCfg.ProviderName, err)
|
|
}
|
|
a.oauthAuthenticators[providerId] = provider
|
|
}
|
|
for i := range a.cfg.OAuth { // PLAIN OAUTH
|
|
providerCfg := &a.cfg.OAuth[i]
|
|
providerId := strings.ToLower(providerCfg.ProviderName)
|
|
|
|
if _, exists := a.oauthAuthenticators[providerId]; exists {
|
|
return fmt.Errorf("auth provider with name %s is already registerd", providerId)
|
|
}
|
|
|
|
redirectUrl := *extUrl
|
|
redirectUrl.Path = path.Join(redirectUrl.Path, "/auth/login/", providerId, "/callback")
|
|
|
|
provider, err := newPlainOauthAuthenticator(ctx, redirectUrl.String(), providerCfg)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to setup oauth authentication provider %s: %w", providerId, err)
|
|
}
|
|
a.oauthAuthenticators[providerId] = provider
|
|
}
|
|
for i := range a.cfg.Ldap { // LDAP
|
|
providerCfg := &a.cfg.Ldap[i]
|
|
providerId := strings.ToLower(providerCfg.URL)
|
|
|
|
if _, exists := a.ldapAuthenticators[providerId]; exists {
|
|
return fmt.Errorf("auth provider with name %s is already registerd", providerId)
|
|
}
|
|
|
|
provider, err := newLdapAuthenticator(ctx, providerCfg)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to setup ldap authentication provider %s: %w", providerId, err)
|
|
}
|
|
a.ldapAuthenticators[providerId] = provider
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetExternalLoginProviders returns a list of all available external login providers.
|
|
func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo {
|
|
authProviders := make([]domain.LoginProviderInfo, 0, len(a.cfg.OAuth)+len(a.cfg.OpenIDConnect))
|
|
|
|
for _, provider := range a.cfg.OpenIDConnect {
|
|
providerId := strings.ToLower(provider.ProviderName)
|
|
providerName := provider.DisplayName
|
|
if providerName == "" {
|
|
providerName = provider.ProviderName
|
|
}
|
|
authProviders = append(authProviders, domain.LoginProviderInfo{
|
|
Identifier: providerId,
|
|
Name: providerName,
|
|
ProviderUrl: fmt.Sprintf("/auth/login/%s/init", providerId),
|
|
CallbackUrl: fmt.Sprintf("/auth/login/%s/callback", providerId),
|
|
})
|
|
}
|
|
|
|
for _, provider := range a.cfg.OAuth {
|
|
providerId := strings.ToLower(provider.ProviderName)
|
|
providerName := provider.DisplayName
|
|
if providerName == "" {
|
|
providerName = provider.ProviderName
|
|
}
|
|
authProviders = append(authProviders, domain.LoginProviderInfo{
|
|
Identifier: providerId,
|
|
Name: providerName,
|
|
ProviderUrl: fmt.Sprintf("/auth/login/%s/init", providerId),
|
|
CallbackUrl: fmt.Sprintf("/auth/login/%s/callback", providerId),
|
|
})
|
|
}
|
|
|
|
return authProviders
|
|
}
|
|
|
|
// IsUserValid checks if a user is valid and not locked or disabled.
|
|
func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool {
|
|
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context
|
|
user, err := a.users.GetUser(ctx, id)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
if user.IsDisabled() {
|
|
return false
|
|
}
|
|
|
|
if user.IsLocked() {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// region password authentication
|
|
|
|
// PlainLogin performs a password authentication for a user. The username and password are trimmed before usage.
|
|
// If the login is successful, the user is returned, otherwise an error.
|
|
func (a *Authenticator) PlainLogin(ctx context.Context, username, password string) (*domain.User, error) {
|
|
// Validate form input
|
|
username = strings.TrimSpace(username)
|
|
password = strings.TrimSpace(password)
|
|
if username == "" || password == "" {
|
|
return nil, fmt.Errorf("missing username or password")
|
|
}
|
|
|
|
user, err := a.passwordAuthentication(ctx, domain.UserIdentifier(username), password)
|
|
if err != nil {
|
|
a.bus.Publish(app.TopicAuditLoginFailed, domain.AuditEventWrapper[audit.AuthEvent]{
|
|
Ctx: ctx,
|
|
Source: "plain",
|
|
Event: audit.AuthEvent{
|
|
Username: username, Error: err.Error(),
|
|
},
|
|
})
|
|
return nil, fmt.Errorf("login failed: %w", err)
|
|
}
|
|
|
|
a.bus.Publish(app.TopicAuthLogin, user.Identifier)
|
|
a.bus.Publish(app.TopicAuditLoginSuccess, domain.AuditEventWrapper[audit.AuthEvent]{
|
|
Ctx: ctx,
|
|
Source: "plain",
|
|
Event: audit.AuthEvent{
|
|
Username: string(user.Identifier),
|
|
},
|
|
})
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (a *Authenticator) passwordAuthentication(
|
|
ctx context.Context,
|
|
identifier domain.UserIdentifier,
|
|
password string,
|
|
) (*domain.User, error) {
|
|
ctx = domain.SetUserInfo(ctx,
|
|
domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
|
|
|
|
var ldapUserInfo *domain.AuthenticatorUserInfo
|
|
var ldapProvider AuthenticatorLdap
|
|
|
|
var userInDatabase = false
|
|
var userSource domain.UserSource
|
|
existingUser, err := a.users.GetUser(ctx, identifier)
|
|
if err == nil {
|
|
userInDatabase = true
|
|
userSource = existingUser.Source
|
|
}
|
|
if userInDatabase && (existingUser.IsLocked() || existingUser.IsDisabled()) {
|
|
return nil, errors.New("user is locked")
|
|
}
|
|
|
|
if !userInDatabase || userSource == domain.UserSourceLdap {
|
|
// search user in ldap if registration is enabled
|
|
for _, ldapAuth := range a.ldapAuthenticators {
|
|
if !userInDatabase && !ldapAuth.RegistrationEnabled() {
|
|
continue
|
|
}
|
|
|
|
rawUserInfo, err := ldapAuth.GetUserInfo(context.Background(), identifier)
|
|
if err != nil {
|
|
if !errors.Is(err, domain.ErrNotFound) {
|
|
slog.Warn("failed to fetch ldap user info", "identifier", identifier, "error", err)
|
|
}
|
|
continue // user not found / other ldap error
|
|
}
|
|
ldapUserInfo, err = ldapAuth.ParseUserInfo(rawUserInfo)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// ldap user found
|
|
userSource = domain.UserSourceLdap
|
|
ldapProvider = ldapAuth
|
|
|
|
break
|
|
}
|
|
}
|
|
|
|
if userSource == "" {
|
|
return nil, errors.New("user not found")
|
|
}
|
|
|
|
if userSource == domain.UserSourceLdap && ldapProvider == nil {
|
|
return nil, errors.New("ldap provider not found")
|
|
}
|
|
|
|
switch userSource {
|
|
case domain.UserSourceDatabase:
|
|
err = existingUser.CheckPassword(password)
|
|
case domain.UserSourceLdap:
|
|
err = ldapProvider.PlaintextAuthentication(identifier, password)
|
|
default:
|
|
err = errors.New("no authentication backend available")
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to authenticate: %w", err)
|
|
}
|
|
|
|
if !userInDatabase {
|
|
user, err := a.processUserInfo(ctx, ldapUserInfo, domain.UserSourceLdap, ldapProvider.GetName(),
|
|
ldapProvider.RegistrationEnabled())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to process user information: %w", err)
|
|
}
|
|
return user, nil
|
|
} else {
|
|
return existingUser, nil
|
|
}
|
|
}
|
|
|
|
// endregion password authentication
|
|
|
|
// region oauth authentication
|
|
|
|
// OauthLoginStep1 starts the oauth authentication flow by returning the authentication URL, state and nonce.
|
|
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
|
|
authCodeUrl, state, nonce string,
|
|
err error,
|
|
) {
|
|
oauthProvider, ok := a.oauthAuthenticators[providerId]
|
|
if !ok {
|
|
return "", "", "", fmt.Errorf("missing oauth provider %s", providerId)
|
|
}
|
|
|
|
// Prepare authentication flow, set state cookies
|
|
state, err = a.randString(16)
|
|
if err != nil {
|
|
return "", "", "", fmt.Errorf("failed to generate state: %w", err)
|
|
}
|
|
|
|
switch oauthProvider.GetType() {
|
|
case AuthenticatorTypeOAuth:
|
|
authCodeUrl = oauthProvider.AuthCodeURL(state)
|
|
case AuthenticatorTypeOidc:
|
|
nonce, err = a.randString(16)
|
|
if err != nil {
|
|
return "", "", "", fmt.Errorf("failed to generate nonce: %w", err)
|
|
}
|
|
|
|
authCodeUrl = oauthProvider.AuthCodeURL(state, oidc.Nonce(nonce))
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (a *Authenticator) randString(nByte int) (string, error) {
|
|
b := make([]byte, nByte)
|
|
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
|
return "", err
|
|
}
|
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
}
|
|
|
|
// OauthLoginStep2 finishes the oauth authentication flow by exchanging the code for an access token and
|
|
// fetching the user information.
|
|
func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error) {
|
|
oauthProvider, ok := a.oauthAuthenticators[providerId]
|
|
if !ok {
|
|
return nil, fmt.Errorf("missing oauth provider %s", providerId)
|
|
}
|
|
|
|
oauth2Token, err := oauthProvider.Exchange(ctx, code)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to exchange code: %w", err)
|
|
}
|
|
|
|
rawUserInfo, err := oauthProvider.GetUserInfo(ctx, oauth2Token, nonce)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to fetch user information: %w", err)
|
|
}
|
|
|
|
userInfo, err := oauthProvider.ParseUserInfo(rawUserInfo)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to parse user information: %w", err)
|
|
}
|
|
|
|
ctx = domain.SetUserInfo(ctx,
|
|
domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
|
|
user, err := a.processUserInfo(ctx, userInfo, domain.UserSourceOauth, oauthProvider.GetName(),
|
|
oauthProvider.RegistrationEnabled())
|
|
if err != nil {
|
|
a.bus.Publish(app.TopicAuditLoginFailed, domain.AuditEventWrapper[audit.AuthEvent]{
|
|
Ctx: ctx,
|
|
Source: "oauth " + providerId,
|
|
Event: audit.AuthEvent{
|
|
Username: string(userInfo.Identifier),
|
|
Error: err.Error(),
|
|
},
|
|
})
|
|
return nil, fmt.Errorf("unable to process user information: %w", err)
|
|
}
|
|
|
|
if user.IsLocked() || user.IsDisabled() {
|
|
a.bus.Publish(app.TopicAuditLoginFailed, domain.AuditEventWrapper[audit.AuthEvent]{
|
|
Ctx: ctx,
|
|
Source: "oauth " + providerId,
|
|
Event: audit.AuthEvent{
|
|
Username: string(user.Identifier),
|
|
Error: "user is locked",
|
|
},
|
|
})
|
|
return nil, errors.New("user is locked")
|
|
}
|
|
|
|
a.bus.Publish(app.TopicAuthLogin, user.Identifier)
|
|
a.bus.Publish(app.TopicAuditLoginSuccess, domain.AuditEventWrapper[audit.AuthEvent]{
|
|
Ctx: ctx,
|
|
Source: "oauth " + providerId,
|
|
Event: audit.AuthEvent{
|
|
Username: string(user.Identifier),
|
|
},
|
|
})
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (a *Authenticator) processUserInfo(
|
|
ctx context.Context,
|
|
userInfo *domain.AuthenticatorUserInfo,
|
|
source domain.UserSource,
|
|
provider string,
|
|
withReg bool,
|
|
) (*domain.User, error) {
|
|
// Search user in backend
|
|
user, err := a.users.GetUser(ctx, userInfo.Identifier)
|
|
switch {
|
|
case err != nil && withReg:
|
|
user, err = a.registerNewUser(ctx, userInfo, source, provider)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to register user: %w", err)
|
|
}
|
|
case err != nil:
|
|
return nil, fmt.Errorf("registration disabled, cannot create missing user: %w", err)
|
|
default:
|
|
err = a.updateExternalUser(ctx, user, userInfo, source, provider)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to update user: %w", err)
|
|
}
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (a *Authenticator) registerNewUser(
|
|
ctx context.Context,
|
|
userInfo *domain.AuthenticatorUserInfo,
|
|
source domain.UserSource,
|
|
provider string,
|
|
) (*domain.User, error) {
|
|
// convert user info to domain.User
|
|
user := &domain.User{
|
|
Identifier: userInfo.Identifier,
|
|
Email: userInfo.Email,
|
|
Source: source,
|
|
ProviderName: provider,
|
|
IsAdmin: userInfo.IsAdmin,
|
|
Firstname: userInfo.Firstname,
|
|
Lastname: userInfo.Lastname,
|
|
Phone: userInfo.Phone,
|
|
Department: userInfo.Department,
|
|
}
|
|
|
|
err := a.users.RegisterUser(ctx, user)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to register new user: %w", err)
|
|
}
|
|
|
|
slog.Debug("registered user from external authentication provider",
|
|
"user", user.Identifier,
|
|
"isAdmin", user.IsAdmin,
|
|
"provider", source)
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (a *Authenticator) getAuthenticatorConfig(id string) (any, error) {
|
|
for i := range a.cfg.OpenIDConnect {
|
|
if a.cfg.OpenIDConnect[i].ProviderName == id {
|
|
return a.cfg.OpenIDConnect[i], nil
|
|
}
|
|
}
|
|
|
|
for i := range a.cfg.OAuth {
|
|
if a.cfg.OAuth[i].ProviderName == id {
|
|
return a.cfg.OAuth[i], nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("no configuration for Authenticator id %s", id)
|
|
}
|
|
|
|
func (a *Authenticator) updateExternalUser(
|
|
ctx context.Context,
|
|
existingUser *domain.User,
|
|
userInfo *domain.AuthenticatorUserInfo,
|
|
source domain.UserSource,
|
|
provider string,
|
|
) error {
|
|
if existingUser.IsLocked() || existingUser.IsDisabled() {
|
|
return nil // user is locked or disabled, do not update
|
|
}
|
|
|
|
isChanged := false
|
|
if existingUser.Email != userInfo.Email {
|
|
existingUser.Email = userInfo.Email
|
|
isChanged = true
|
|
}
|
|
if existingUser.Firstname != userInfo.Firstname {
|
|
existingUser.Firstname = userInfo.Firstname
|
|
isChanged = true
|
|
}
|
|
if existingUser.Lastname != userInfo.Lastname {
|
|
existingUser.Lastname = userInfo.Lastname
|
|
isChanged = true
|
|
}
|
|
if existingUser.Phone != userInfo.Phone {
|
|
existingUser.Phone = userInfo.Phone
|
|
isChanged = true
|
|
}
|
|
if existingUser.Department != userInfo.Department {
|
|
existingUser.Department = userInfo.Department
|
|
isChanged = true
|
|
}
|
|
if existingUser.IsAdmin != userInfo.IsAdmin {
|
|
existingUser.IsAdmin = userInfo.IsAdmin
|
|
isChanged = true
|
|
}
|
|
if existingUser.Source != source {
|
|
existingUser.Source = source
|
|
isChanged = true
|
|
}
|
|
if existingUser.ProviderName != provider {
|
|
existingUser.ProviderName = provider
|
|
isChanged = true
|
|
}
|
|
|
|
if !isChanged {
|
|
return nil // nothing to update
|
|
}
|
|
|
|
_, err := a.users.UpdateUser(ctx, existingUser)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update user: %w", err)
|
|
}
|
|
|
|
slog.Debug("updated user with data from external authentication provider",
|
|
"user", existingUser.Identifier,
|
|
"isAdmin", existingUser.IsAdmin,
|
|
"provider", source)
|
|
|
|
return nil
|
|
}
|
|
|
|
// endregion oauth authentication
|