mirror of
https://github.com/h44z/wg-portal.git
synced 2026-01-29 06:36:24 +00:00
feat: allow multiple auth sources per user (#500,#477) (#612)
* feat: allow multiple auth sources per user (#500,#477) * only override isAdmin flag if it is provided by the authentication source
This commit is contained in:
@@ -23,9 +23,6 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// SchemaVersion describes the current database schema version. It must be incremented if a manual migration is needed.
|
||||
var SchemaVersion uint64 = 2
|
||||
|
||||
// SysStat stores the current database schema version and the timestamp when it was applied.
|
||||
type SysStat struct {
|
||||
MigratedAt time.Time `gorm:"column:migrated_at"`
|
||||
@@ -225,6 +222,8 @@ func (r *SqlRepo) preCheck() error {
|
||||
func (r *SqlRepo) migrate() error {
|
||||
slog.Debug("running migration: sys-stat", "result", r.db.AutoMigrate(&SysStat{}))
|
||||
slog.Debug("running migration: user", "result", r.db.AutoMigrate(&domain.User{}))
|
||||
slog.Debug("running migration: user authentications", "result",
|
||||
r.db.AutoMigrate(&domain.UserAuthentication{}))
|
||||
slog.Debug("running migration: user webauthn credentials", "result",
|
||||
r.db.AutoMigrate(&domain.UserWebauthnCredential{}))
|
||||
slog.Debug("running migration: interface", "result", r.db.AutoMigrate(&domain.Interface{}))
|
||||
@@ -238,35 +237,84 @@ func (r *SqlRepo) migrate() error {
|
||||
|
||||
// Migration: 0 --> 1
|
||||
if existingSysStat.SchemaVersion == 0 {
|
||||
const schemaVersion = 1
|
||||
sysStat := SysStat{
|
||||
MigratedAt: time.Now(),
|
||||
SchemaVersion: SchemaVersion,
|
||||
SchemaVersion: schemaVersion,
|
||||
}
|
||||
if err := r.db.Create(&sysStat).Error; err != nil {
|
||||
return fmt.Errorf("failed to write sysstat entry for schema version %d: %w", SchemaVersion, err)
|
||||
return fmt.Errorf("failed to write sysstat entry for schema version %d: %w", schemaVersion, err)
|
||||
}
|
||||
slog.Debug("sys-stat entry written", "schema_version", SchemaVersion)
|
||||
slog.Debug("sys-stat entry written", "schema_version", schemaVersion)
|
||||
existingSysStat = sysStat // ensure that follow-up checks test against the latest version
|
||||
}
|
||||
|
||||
// Migration: 1 --> 2
|
||||
if existingSysStat.SchemaVersion == 1 {
|
||||
const schemaVersion = 2
|
||||
// Preserve existing behavior for installations that had default-peer-creation enabled.
|
||||
if r.cfg.Core.CreateDefaultPeer {
|
||||
err := r.db.Model(&domain.Interface{}).
|
||||
Where("type = ?", domain.InterfaceTypeServer).
|
||||
Update("create_default_peer", true).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to migrate interface flags for schema version %d: %w", SchemaVersion, err)
|
||||
return fmt.Errorf("failed to migrate interface flags for schema version %d: %w", schemaVersion, err)
|
||||
}
|
||||
slog.Debug("migrated interface create_default_peer flags", "schema_version", SchemaVersion)
|
||||
slog.Debug("migrated interface create_default_peer flags", "schema_version", schemaVersion)
|
||||
}
|
||||
sysStat := SysStat{
|
||||
MigratedAt: time.Now(),
|
||||
SchemaVersion: SchemaVersion,
|
||||
SchemaVersion: schemaVersion,
|
||||
}
|
||||
if err := r.db.Create(&sysStat).Error; err != nil {
|
||||
return fmt.Errorf("failed to write sysstat entry for schema version %d: %w", SchemaVersion, err)
|
||||
return fmt.Errorf("failed to write sysstat entry for schema version %d: %w", schemaVersion, err)
|
||||
}
|
||||
existingSysStat = sysStat // ensure that follow-up checks test against the latest version
|
||||
}
|
||||
|
||||
// Migration: 2 --> 3
|
||||
if existingSysStat.SchemaVersion == 2 {
|
||||
const schemaVersion = 3
|
||||
// Migration to multi-auth
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var users []domain.User
|
||||
if err := tx.Find(&users).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for _, user := range users {
|
||||
auth := domain.UserAuthentication{
|
||||
BaseModel: domain.BaseModel{
|
||||
CreatedBy: domain.CtxSystemDBMigrator,
|
||||
UpdatedBy: domain.CtxSystemDBMigrator,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
},
|
||||
UserIdentifier: user.Identifier,
|
||||
Source: user.Source,
|
||||
ProviderName: user.ProviderName,
|
||||
}
|
||||
if err := tx.Create(&auth).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("migrated users to multi-auth model", "schema_version", schemaVersion)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to migrate to multi-auth: %w", err)
|
||||
}
|
||||
sysStat := SysStat{
|
||||
MigratedAt: time.Now(),
|
||||
SchemaVersion: schemaVersion,
|
||||
}
|
||||
if err := r.db.Create(&sysStat).Error; err != nil {
|
||||
return fmt.Errorf("failed to write sysstat entry for schema version %d: %w", schemaVersion, err)
|
||||
}
|
||||
existingSysStat = sysStat // ensure that follow-up checks test against the latest version
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -776,7 +824,7 @@ func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr
|
||||
func (r *SqlRepo) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||
var user domain.User
|
||||
|
||||
err := r.db.WithContext(ctx).Preload("WebAuthnCredentialList").First(&user, id).Error
|
||||
err := r.db.WithContext(ctx).Preload("WebAuthnCredentialList").Preload("Authentications").First(&user, id).Error
|
||||
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, domain.ErrNotFound
|
||||
@@ -794,7 +842,8 @@ func (r *SqlRepo) GetUser(ctx context.Context, id domain.UserIdentifier) (*domai
|
||||
func (r *SqlRepo) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
var users []domain.User
|
||||
|
||||
err := r.db.WithContext(ctx).Where("email = ?", email).Preload("WebAuthnCredentialList").Find(&users).Error
|
||||
err := r.db.WithContext(ctx).Where("email = ?",
|
||||
email).Preload("WebAuthnCredentialList").Preload("Authentications").Find(&users).Error
|
||||
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, domain.ErrNotFound
|
||||
}
|
||||
@@ -834,7 +883,7 @@ func (r *SqlRepo) GetUserByWebAuthnCredential(ctx context.Context, credentialIdB
|
||||
func (r *SqlRepo) GetAllUsers(ctx context.Context) ([]domain.User, error) {
|
||||
var users []domain.User
|
||||
|
||||
err := r.db.WithContext(ctx).Preload("WebAuthnCredentialList").Find(&users).Error
|
||||
err := r.db.WithContext(ctx).Preload("WebAuthnCredentialList").Preload("Authentications").Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -854,6 +903,7 @@ func (r *SqlRepo) FindUsers(ctx context.Context, search string) ([]domain.User,
|
||||
Or("lastname LIKE ?", searchValue).
|
||||
Or("email LIKE ?", searchValue).
|
||||
Preload("WebAuthnCredentialList").
|
||||
Preload("Authentications").
|
||||
Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -913,7 +963,17 @@ func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id do
|
||||
) {
|
||||
var user domain.User
|
||||
|
||||
// userDefaults will be applied to newly created user records
|
||||
result := tx.Model(&user).Preload("WebAuthnCredentialList").Preload("Authentications").Find(&user, id)
|
||||
if result.Error != nil {
|
||||
if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, result.Error
|
||||
}
|
||||
}
|
||||
if result.Error == nil && result.RowsAffected > 0 {
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// create a new user record if no user record exists yet
|
||||
userDefaults := domain.User{
|
||||
BaseModel: domain.BaseModel{
|
||||
CreatedBy: ui.UserId(),
|
||||
@@ -922,16 +982,15 @@ func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id do
|
||||
UpdatedAt: time.Now(),
|
||||
},
|
||||
Identifier: id,
|
||||
Source: domain.UserSourceDatabase,
|
||||
IsAdmin: false,
|
||||
}
|
||||
|
||||
err := tx.Attrs(userDefaults).FirstOrCreate(&user, id).Error
|
||||
err := tx.Create(&userDefaults).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
return &userDefaults, nil
|
||||
}
|
||||
|
||||
func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *domain.User) error {
|
||||
@@ -948,6 +1007,11 @@ func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *doma
|
||||
return fmt.Errorf("failed to update users webauthn credentials: %w", err)
|
||||
}
|
||||
|
||||
err = tx.Session(&gorm.Session{FullSaveAssociations: true}).Unscoped().Model(user).Association("Authentications").Unscoped().Replace(user.Authentications)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update users authentications: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user