mirror of
				https://github.com/h44z/wg-portal.git
				synced 2025-11-03 23:56:18 +00:00 
			
		
		
		
	fix plain oauth login (#317)
This commit is contained in:
		@@ -6,8 +6,6 @@ import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/app"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path"
 | 
			
		||||
@@ -15,10 +13,11 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/coreos/go-oidc/v3/oidc"
 | 
			
		||||
	evbus "github.com/vardius/message-bus"
 | 
			
		||||
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/app"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/config"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	evbus "github.com/vardius/message-bus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type UserManager interface {
 | 
			
		||||
@@ -33,14 +32,21 @@ type Authenticator struct {
 | 
			
		||||
	oauthAuthenticators map[string]domain.OauthAuthenticator
 | 
			
		||||
	ldapAuthenticators  map[string]domain.LdapAuthenticator
 | 
			
		||||
 | 
			
		||||
	// URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix
 | 
			
		||||
	callbackUrlPrefix string
 | 
			
		||||
 | 
			
		||||
	users UserManager
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewAuthenticator(cfg *config.Auth, bus evbus.MessageBus, users UserManager) (*Authenticator, error) {
 | 
			
		||||
func NewAuthenticator(cfg *config.Auth, extUrl string, bus evbus.MessageBus, users UserManager) (
 | 
			
		||||
	*Authenticator,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	a := &Authenticator{
 | 
			
		||||
		cfg:   cfg,
 | 
			
		||||
		bus:   bus,
 | 
			
		||||
		users: users,
 | 
			
		||||
		cfg:               cfg,
 | 
			
		||||
		bus:               bus,
 | 
			
		||||
		users:             users,
 | 
			
		||||
		callbackUrlPrefix: fmt.Sprintf("%s/api/v0", extUrl),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 | 
			
		||||
@@ -55,7 +61,7 @@ func NewAuthenticator(cfg *config.Auth, bus evbus.MessageBus, users UserManager)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
 | 
			
		||||
	extUrl, err := url.Parse(a.cfg.CallbackUrlPrefix)
 | 
			
		||||
	extUrl, err := url.Parse(a.callbackUrlPrefix)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to parse external url: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -141,8 +147,8 @@ func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.Lo
 | 
			
		||||
		authProviders = append(authProviders, domain.LoginProviderInfo{
 | 
			
		||||
			Identifier:  providerId,
 | 
			
		||||
			Name:        providerName,
 | 
			
		||||
			ProviderUrl: fmt.Sprintf("%s/%s/init", a.cfg.CallbackUrlPrefix, providerId),
 | 
			
		||||
			CallbackUrl: fmt.Sprintf("%s/%s/callback", a.cfg.CallbackUrlPrefix, providerId),
 | 
			
		||||
			ProviderUrl: fmt.Sprintf("/auth/login/%s/init", providerId),
 | 
			
		||||
			CallbackUrl: fmt.Sprintf("/auth/login/%s/callback", providerId),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -187,8 +193,13 @@ func (a *Authenticator) PlainLogin(ctx context.Context, username, password strin
 | 
			
		||||
	return user, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Authenticator) passwordAuthentication(ctx context.Context, identifier domain.UserIdentifier, password string) (*domain.User, error) {
 | 
			
		||||
	ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
 | 
			
		||||
func (a *Authenticator) passwordAuthentication(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	identifier domain.UserIdentifier,
 | 
			
		||||
	password string,
 | 
			
		||||
) (*domain.User, error) {
 | 
			
		||||
	ctx = domain.SetUserInfo(ctx,
 | 
			
		||||
		domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
 | 
			
		||||
 | 
			
		||||
	var ldapUserInfo *domain.AuthenticatorUserInfo
 | 
			
		||||
	var ldapProvider domain.LdapAuthenticator
 | 
			
		||||
@@ -248,7 +259,8 @@ func (a *Authenticator) passwordAuthentication(ctx context.Context, identifier d
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !userInDatabase {
 | 
			
		||||
		user, err := a.processUserInfo(ctx, ldapUserInfo, domain.UserSourceLdap, ldapProvider.GetName(), ldapProvider.RegistrationEnabled())
 | 
			
		||||
		user, err := a.processUserInfo(ctx, ldapUserInfo, domain.UserSourceLdap, ldapProvider.GetName(),
 | 
			
		||||
			ldapProvider.RegistrationEnabled())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("unable to process user information: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
@@ -262,7 +274,10 @@ func (a *Authenticator) passwordAuthentication(ctx context.Context, identifier d
 | 
			
		||||
 | 
			
		||||
// region oauth authentication
 | 
			
		||||
 | 
			
		||||
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (authCodeUrl, state, nonce string, err error) {
 | 
			
		||||
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
 | 
			
		||||
	authCodeUrl, state, nonce string,
 | 
			
		||||
	err error,
 | 
			
		||||
) {
 | 
			
		||||
	oauthProvider, ok := a.oauthAuthenticators[providerId]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return "", "", "", fmt.Errorf("missing oauth provider %s", providerId)
 | 
			
		||||
@@ -318,8 +333,10 @@ func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce,
 | 
			
		||||
		return nil, fmt.Errorf("unable to parse user information: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
 | 
			
		||||
	user, err := a.processUserInfo(ctx, userInfo, domain.UserSourceOauth, oauthProvider.GetName(), oauthProvider.RegistrationEnabled())
 | 
			
		||||
	ctx = domain.SetUserInfo(ctx,
 | 
			
		||||
		domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
 | 
			
		||||
	user, err := a.processUserInfo(ctx, userInfo, domain.UserSourceOauth, oauthProvider.GetName(),
 | 
			
		||||
		oauthProvider.RegistrationEnabled())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("unable to process user information: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -333,7 +350,13 @@ func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce,
 | 
			
		||||
	return user, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Authenticator) processUserInfo(ctx context.Context, userInfo *domain.AuthenticatorUserInfo, source domain.UserSource, provider string, withReg bool) (*domain.User, error) {
 | 
			
		||||
func (a *Authenticator) processUserInfo(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	userInfo *domain.AuthenticatorUserInfo,
 | 
			
		||||
	source domain.UserSource,
 | 
			
		||||
	provider string,
 | 
			
		||||
	withReg bool,
 | 
			
		||||
) (*domain.User, error) {
 | 
			
		||||
	// Search user in backend
 | 
			
		||||
	user, err := a.users.GetUser(ctx, userInfo.Identifier)
 | 
			
		||||
	switch {
 | 
			
		||||
@@ -349,7 +372,12 @@ func (a *Authenticator) processUserInfo(ctx context.Context, userInfo *domain.Au
 | 
			
		||||
	return user, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Authenticator) registerNewUser(ctx context.Context, userInfo *domain.AuthenticatorUserInfo, source domain.UserSource, provider string) (*domain.User, error) {
 | 
			
		||||
func (a *Authenticator) registerNewUser(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	userInfo *domain.AuthenticatorUserInfo,
 | 
			
		||||
	source domain.UserSource,
 | 
			
		||||
	provider string,
 | 
			
		||||
) (*domain.User, error) {
 | 
			
		||||
	// convert user info to domain.User
 | 
			
		||||
	user := &domain.User{
 | 
			
		||||
		Identifier:   userInfo.Identifier,
 | 
			
		||||
 
 | 
			
		||||
@@ -7,10 +7,9 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Auth struct {
 | 
			
		||||
	OpenIDConnect     []OpenIDConnectProvider `yaml:"oidc"`
 | 
			
		||||
	OAuth             []OAuthProvider         `yaml:"oauth"`
 | 
			
		||||
	Ldap              []LdapProvider          `yaml:"ldap"`
 | 
			
		||||
	CallbackUrlPrefix string                  `yaml:"callback_url_prefix"`
 | 
			
		||||
	OpenIDConnect []OpenIDConnectProvider `yaml:"oidc"`
 | 
			
		||||
	OAuth         []OAuthProvider         `yaml:"oauth"`
 | 
			
		||||
	Ldap          []LdapProvider          `yaml:"ldap"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaseFields struct {
 | 
			
		||||
@@ -24,7 +23,7 @@ type BaseFields struct {
 | 
			
		||||
 | 
			
		||||
type OauthFields struct {
 | 
			
		||||
	BaseFields `yaml:",inline"`
 | 
			
		||||
	IsAdmin    string `yaml:"is_admin"`
 | 
			
		||||
	IsAdmin    string `yaml:"is_admin"` // If the value is "true", the user is an admin.
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type LdapFields struct {
 | 
			
		||||
@@ -93,8 +92,6 @@ type OAuthProvider struct {
 | 
			
		||||
	// DisplayName is shown to the user on the login page. If it is empty, ProviderName will be displayed.
 | 
			
		||||
	DisplayName string `yaml:"display_name"`
 | 
			
		||||
 | 
			
		||||
	BaseUrl string `yaml:"base_url"`
 | 
			
		||||
 | 
			
		||||
	// ClientID is the application's ID.
 | 
			
		||||
	ClientID string `yaml:"client_id"`
 | 
			
		||||
 | 
			
		||||
@@ -105,10 +102,6 @@ type OAuthProvider struct {
 | 
			
		||||
	TokenURL    string `yaml:"token_url"`
 | 
			
		||||
	UserInfoURL string `yaml:"user_info_url"`
 | 
			
		||||
 | 
			
		||||
	// RedirectURL is the URL to redirect users going through
 | 
			
		||||
	// the OAuth flow, after the resource owner's URLs.
 | 
			
		||||
	RedirectURL string `yaml:"redirect_url"`
 | 
			
		||||
 | 
			
		||||
	// Scope specifies optional requested permissions.
 | 
			
		||||
	Scopes []string `yaml:"scopes"`
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -104,8 +104,6 @@ func defaultConfig() *Config {
 | 
			
		||||
		SiteCompanyName:   "WireGuard Portal",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cfg.Auth.CallbackUrlPrefix = "/api/v0"
 | 
			
		||||
 | 
			
		||||
	cfg.Advanced.StartListenPort = 51820
 | 
			
		||||
	cfg.Advanced.StartCidrV4 = "10.11.12.0/24"
 | 
			
		||||
	cfg.Advanced.StartCidrV6 = "fdfd:d3ad:c0de:1234::0/64"
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user