mirror of
https://github.com/h44z/wg-portal.git
synced 2025-09-14 15:01:14 +00:00
fix plain oauth login (#317)
This commit is contained in:
@@ -6,8 +6,6 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/sirupsen/logrus"
|
||||
"io"
|
||||
"net/url"
|
||||
"path"
|
||||
@@ -15,10 +13,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
evbus "github.com/vardius/message-bus"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/sirupsen/logrus"
|
||||
evbus "github.com/vardius/message-bus"
|
||||
)
|
||||
|
||||
type UserManager interface {
|
||||
@@ -33,14 +32,21 @@ type Authenticator struct {
|
||||
oauthAuthenticators map[string]domain.OauthAuthenticator
|
||||
ldapAuthenticators map[string]domain.LdapAuthenticator
|
||||
|
||||
// URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix
|
||||
callbackUrlPrefix string
|
||||
|
||||
users UserManager
|
||||
}
|
||||
|
||||
func NewAuthenticator(cfg *config.Auth, bus evbus.MessageBus, users UserManager) (*Authenticator, error) {
|
||||
func NewAuthenticator(cfg *config.Auth, extUrl string, bus evbus.MessageBus, users UserManager) (
|
||||
*Authenticator,
|
||||
error,
|
||||
) {
|
||||
a := &Authenticator{
|
||||
cfg: cfg,
|
||||
bus: bus,
|
||||
users: users,
|
||||
cfg: cfg,
|
||||
bus: bus,
|
||||
users: users,
|
||||
callbackUrlPrefix: fmt.Sprintf("%s/api/v0", extUrl),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
@@ -55,7 +61,7 @@ func NewAuthenticator(cfg *config.Auth, bus evbus.MessageBus, users UserManager)
|
||||
}
|
||||
|
||||
func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
|
||||
extUrl, err := url.Parse(a.cfg.CallbackUrlPrefix)
|
||||
extUrl, err := url.Parse(a.callbackUrlPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse external url: %w", err)
|
||||
}
|
||||
@@ -141,8 +147,8 @@ func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.Lo
|
||||
authProviders = append(authProviders, domain.LoginProviderInfo{
|
||||
Identifier: providerId,
|
||||
Name: providerName,
|
||||
ProviderUrl: fmt.Sprintf("%s/%s/init", a.cfg.CallbackUrlPrefix, providerId),
|
||||
CallbackUrl: fmt.Sprintf("%s/%s/callback", a.cfg.CallbackUrlPrefix, providerId),
|
||||
ProviderUrl: fmt.Sprintf("/auth/login/%s/init", providerId),
|
||||
CallbackUrl: fmt.Sprintf("/auth/login/%s/callback", providerId),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -187,8 +193,13 @@ func (a *Authenticator) PlainLogin(ctx context.Context, username, password strin
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) passwordAuthentication(ctx context.Context, identifier domain.UserIdentifier, password string) (*domain.User, error) {
|
||||
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
|
||||
func (a *Authenticator) passwordAuthentication(
|
||||
ctx context.Context,
|
||||
identifier domain.UserIdentifier,
|
||||
password string,
|
||||
) (*domain.User, error) {
|
||||
ctx = domain.SetUserInfo(ctx,
|
||||
domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
|
||||
|
||||
var ldapUserInfo *domain.AuthenticatorUserInfo
|
||||
var ldapProvider domain.LdapAuthenticator
|
||||
@@ -248,7 +259,8 @@ func (a *Authenticator) passwordAuthentication(ctx context.Context, identifier d
|
||||
}
|
||||
|
||||
if !userInDatabase {
|
||||
user, err := a.processUserInfo(ctx, ldapUserInfo, domain.UserSourceLdap, ldapProvider.GetName(), ldapProvider.RegistrationEnabled())
|
||||
user, err := a.processUserInfo(ctx, ldapUserInfo, domain.UserSourceLdap, ldapProvider.GetName(),
|
||||
ldapProvider.RegistrationEnabled())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to process user information: %w", err)
|
||||
}
|
||||
@@ -262,7 +274,10 @@ func (a *Authenticator) passwordAuthentication(ctx context.Context, identifier d
|
||||
|
||||
// region oauth authentication
|
||||
|
||||
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (authCodeUrl, state, nonce string, err error) {
|
||||
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
|
||||
authCodeUrl, state, nonce string,
|
||||
err error,
|
||||
) {
|
||||
oauthProvider, ok := a.oauthAuthenticators[providerId]
|
||||
if !ok {
|
||||
return "", "", "", fmt.Errorf("missing oauth provider %s", providerId)
|
||||
@@ -318,8 +333,10 @@ func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce,
|
||||
return nil, fmt.Errorf("unable to parse user information: %w", err)
|
||||
}
|
||||
|
||||
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
|
||||
user, err := a.processUserInfo(ctx, userInfo, domain.UserSourceOauth, oauthProvider.GetName(), oauthProvider.RegistrationEnabled())
|
||||
ctx = domain.SetUserInfo(ctx,
|
||||
domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
|
||||
user, err := a.processUserInfo(ctx, userInfo, domain.UserSourceOauth, oauthProvider.GetName(),
|
||||
oauthProvider.RegistrationEnabled())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to process user information: %w", err)
|
||||
}
|
||||
@@ -333,7 +350,13 @@ func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce,
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) processUserInfo(ctx context.Context, userInfo *domain.AuthenticatorUserInfo, source domain.UserSource, provider string, withReg bool) (*domain.User, error) {
|
||||
func (a *Authenticator) processUserInfo(
|
||||
ctx context.Context,
|
||||
userInfo *domain.AuthenticatorUserInfo,
|
||||
source domain.UserSource,
|
||||
provider string,
|
||||
withReg bool,
|
||||
) (*domain.User, error) {
|
||||
// Search user in backend
|
||||
user, err := a.users.GetUser(ctx, userInfo.Identifier)
|
||||
switch {
|
||||
@@ -349,7 +372,12 @@ func (a *Authenticator) processUserInfo(ctx context.Context, userInfo *domain.Au
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) registerNewUser(ctx context.Context, userInfo *domain.AuthenticatorUserInfo, source domain.UserSource, provider string) (*domain.User, error) {
|
||||
func (a *Authenticator) registerNewUser(
|
||||
ctx context.Context,
|
||||
userInfo *domain.AuthenticatorUserInfo,
|
||||
source domain.UserSource,
|
||||
provider string,
|
||||
) (*domain.User, error) {
|
||||
// convert user info to domain.User
|
||||
user := &domain.User{
|
||||
Identifier: userInfo.Identifier,
|
||||
|
@@ -7,10 +7,9 @@ import (
|
||||
)
|
||||
|
||||
type Auth struct {
|
||||
OpenIDConnect []OpenIDConnectProvider `yaml:"oidc"`
|
||||
OAuth []OAuthProvider `yaml:"oauth"`
|
||||
Ldap []LdapProvider `yaml:"ldap"`
|
||||
CallbackUrlPrefix string `yaml:"callback_url_prefix"`
|
||||
OpenIDConnect []OpenIDConnectProvider `yaml:"oidc"`
|
||||
OAuth []OAuthProvider `yaml:"oauth"`
|
||||
Ldap []LdapProvider `yaml:"ldap"`
|
||||
}
|
||||
|
||||
type BaseFields struct {
|
||||
@@ -24,7 +23,7 @@ type BaseFields struct {
|
||||
|
||||
type OauthFields struct {
|
||||
BaseFields `yaml:",inline"`
|
||||
IsAdmin string `yaml:"is_admin"`
|
||||
IsAdmin string `yaml:"is_admin"` // If the value is "true", the user is an admin.
|
||||
}
|
||||
|
||||
type LdapFields struct {
|
||||
@@ -93,8 +92,6 @@ type OAuthProvider struct {
|
||||
// DisplayName is shown to the user on the login page. If it is empty, ProviderName will be displayed.
|
||||
DisplayName string `yaml:"display_name"`
|
||||
|
||||
BaseUrl string `yaml:"base_url"`
|
||||
|
||||
// ClientID is the application's ID.
|
||||
ClientID string `yaml:"client_id"`
|
||||
|
||||
@@ -105,10 +102,6 @@ type OAuthProvider struct {
|
||||
TokenURL string `yaml:"token_url"`
|
||||
UserInfoURL string `yaml:"user_info_url"`
|
||||
|
||||
// RedirectURL is the URL to redirect users going through
|
||||
// the OAuth flow, after the resource owner's URLs.
|
||||
RedirectURL string `yaml:"redirect_url"`
|
||||
|
||||
// Scope specifies optional requested permissions.
|
||||
Scopes []string `yaml:"scopes"`
|
||||
|
||||
|
@@ -104,8 +104,6 @@ func defaultConfig() *Config {
|
||||
SiteCompanyName: "WireGuard Portal",
|
||||
}
|
||||
|
||||
cfg.Auth.CallbackUrlPrefix = "/api/v0"
|
||||
|
||||
cfg.Advanced.StartListenPort = 51820
|
||||
cfg.Advanced.StartCidrV4 = "10.11.12.0/24"
|
||||
cfg.Advanced.StartCidrV6 = "fdfd:d3ad:c0de:1234::0/64"
|
||||
|
Reference in New Issue
Block a user