diff --git a/internal/adapters/database.go b/internal/adapters/database.go index 5fe5e23..c428c13 100644 --- a/internal/adapters/database.go +++ b/internal/adapters/database.go @@ -232,21 +232,19 @@ func (r *SqlRepo) migrate() error { slog.Debug("running migration: interface status", "result", r.db.AutoMigrate(&domain.InterfaceStatus{})) slog.Debug("running migration: audit data", "result", r.db.AutoMigrate(&domain.AuditEntry{})) - existingSysStat := SysStat{} + var existingSysStat SysStat + var err error + r.db.Order("schema_version desc").First(&existingSysStat) // get latest version // Migration: 0 --> 1 if existingSysStat.SchemaVersion == 0 { const schemaVersion = 1 - sysStat := SysStat{ - MigratedAt: time.Now(), - SchemaVersion: schemaVersion, - } - if err := r.db.Create(&sysStat).Error; err != nil { - return fmt.Errorf("failed to write sysstat entry for schema version %d: %w", schemaVersion, err) + existingSysStat, err = r.addMigration(schemaVersion) // ensure that follow-up checks test against the latest version + if err != nil { + return err } slog.Debug("sys-stat entry written", "schema_version", schemaVersion) - existingSysStat = sysStat // ensure that follow-up checks test against the latest version } // Migration: 1 --> 2 @@ -262,14 +260,10 @@ func (r *SqlRepo) migrate() error { } slog.Debug("migrated interface create_default_peer flags", "schema_version", schemaVersion) } - sysStat := SysStat{ - MigratedAt: time.Now(), - SchemaVersion: schemaVersion, + existingSysStat, err = r.addMigration(schemaVersion) // ensure that follow-up checks test against the latest version + if err != nil { + return err } - if err := r.db.Create(&sysStat).Error; err != nil { - return fmt.Errorf("failed to write sysstat entry for schema version %d: %w", schemaVersion, err) - } - existingSysStat = sysStat // ensure that follow-up checks test against the latest version } // Migration: 2 --> 3 @@ -307,19 +301,45 @@ func (r *SqlRepo) migrate() error { if err != nil { return fmt.Errorf("failed to migrate to multi-auth: %w", err) } - sysStat := SysStat{ - MigratedAt: time.Now(), - SchemaVersion: schemaVersion, + existingSysStat, err = r.addMigration(schemaVersion) // ensure that follow-up checks test against the latest version + if err != nil { + return err } - if err := r.db.Create(&sysStat).Error; err != nil { - return fmt.Errorf("failed to write sysstat entry for schema version %d: %w", schemaVersion, err) + } + + // Migration: 3 --> 4 + if existingSysStat.SchemaVersion == 3 { + const schemaVersion = 4 + cutoff := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + + // Fix zero created_at timestamps for users. Set the to the last known update timestamp. + err := r.db.Model(&domain.User{}).Where("created_at < ?", cutoff). + Update("created_at", gorm.Expr("updated_at")).Error + if err != nil { + slog.Warn("failed to fix zero created_at for users", "error", err) + } + slog.Debug("fixed zero created_at timestamps for users", "schema_version", schemaVersion) + + existingSysStat, err = r.addMigration(schemaVersion) // ensure that follow-up checks test against the latest version + if err != nil { + return err } - existingSysStat = sysStat // ensure that follow-up checks test against the latest version } return nil } +func (r *SqlRepo) addMigration(schemaVersion uint64) (SysStat, error) { + sysStat := SysStat{ + MigratedAt: time.Now(), + SchemaVersion: schemaVersion, + } + if err := r.db.Create(&sysStat).Error; err != nil { + return SysStat{}, fmt.Errorf("failed to write sysstat entry for schema version %d: %w", schemaVersion, err) + } + return sysStat, nil +} + // region interfaces // GetInterface returns the interface with the given id. diff --git a/internal/adapters/database_created_at_test.go b/internal/adapters/database_created_at_test.go new file mode 100644 index 0000000..2a2761b --- /dev/null +++ b/internal/adapters/database_created_at_test.go @@ -0,0 +1,168 @@ +package adapters + +import ( + "context" + "testing" + "time" + + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" +) + +func newTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{}) + require.NoError(t, err) + return db +} + +func TestUpsertUser_SetsCreatedAtWhenZero(t *testing.T) { + db := newTestDB(t) + require.NoError(t, db.AutoMigrate(&domain.User{}, &domain.UserAuthentication{}, &domain.UserWebauthnCredential{})) + + repo := &SqlRepo{db: db, cfg: &config.Config{}} + ui := domain.SystemAdminContextUserInfo() + + user := &domain.User{ + Identifier: "test-user", + Email: "test@example.com", + // CreatedAt is zero + } + + err := repo.upsertUser(ui, db, user) + require.NoError(t, err) + + assert.False(t, user.CreatedAt.IsZero(), "CreatedAt should be set when it was zero") + assert.Equal(t, ui.UserId(), user.UpdatedBy, "UpdatedBy should be set when it was empty") + assert.WithinDuration(t, user.UpdatedAt, user.CreatedAt, time.Second, + "CreatedAt should be close to UpdatedAt for new user") +} + +func TestUpsertUser_PreservesExistingCreatedAt(t *testing.T) { + db := newTestDB(t) + require.NoError(t, db.AutoMigrate(&domain.User{}, &domain.UserAuthentication{}, &domain.UserWebauthnCredential{})) + + repo := &SqlRepo{db: db, cfg: &config.Config{}} + ui := domain.SystemAdminContextUserInfo() + + originalTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + user := &domain.User{ + Identifier: "test-user", + Email: "test@example.com", + BaseModel: domain.BaseModel{ + CreatedAt: originalTime, + CreatedBy: "original-creator", + }, + } + + err := repo.upsertUser(ui, db, user) + require.NoError(t, err) + + assert.Equal(t, originalTime, user.CreatedAt, "CreatedAt should not be overwritten") + assert.Equal(t, "original-creator", user.CreatedBy, "CreatedBy should not be overwritten") +} + +func TestSaveUser_NewUserGetsCreatedAt(t *testing.T) { + db := newTestDB(t) + require.NoError(t, db.AutoMigrate(&domain.User{}, &domain.UserAuthentication{}, &domain.UserWebauthnCredential{})) + + repo := &SqlRepo{db: db, cfg: &config.Config{}} + ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) + + before := time.Now().Add(-time.Second) + + err := repo.SaveUser(ctx, "new-user", func(u *domain.User) (*domain.User, error) { + u.Email = "new@example.com" + return u, nil + }) + require.NoError(t, err) + + var saved domain.User + require.NoError(t, db.First(&saved, "identifier = ?", "new-user").Error) + + assert.False(t, saved.CreatedAt.IsZero(), "CreatedAt should not be zero") + assert.True(t, saved.CreatedAt.After(before), "CreatedAt should be recent") + assert.NotEmpty(t, saved.CreatedBy, "CreatedBy should be set") +} + +func TestMigration_FixesZeroCreatedAt(t *testing.T) { + db := newTestDB(t) + + // Manually create tables and seed schema version 3 + require.NoError(t, db.AutoMigrate( + &SysStat{}, + &domain.User{}, + &domain.UserAuthentication{}, + &domain.Interface{}, + &domain.Cidr{}, + &domain.Peer{}, + &domain.AuditEntry{}, + &domain.UserWebauthnCredential{}, + )) + + // Insert schema versions 1, 2, 3 so migration starts at 3 + for v := uint64(1); v <= 3; v++ { + require.NoError(t, db.Create(&SysStat{SchemaVersion: v, MigratedAt: time.Now()}).Error) + } + + updatedAt := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC) + + // Insert a user with zero created_at but valid updated_at + require.NoError(t, db.Exec( + "INSERT INTO users (identifier, email, created_at, updated_at) VALUES (?, ?, ?, ?)", + "zero-user", "zero@example.com", time.Time{}, updatedAt, + ).Error) + + // Run migration + repo := &SqlRepo{db: db, cfg: &config.Config{}} + require.NoError(t, repo.migrate()) + + // Verify created_at was backfilled from updated_at + var user domain.User + require.NoError(t, db.First(&user, "identifier = ?", "zero-user").Error) + assert.Equal(t, updatedAt, user.CreatedAt, "created_at should be backfilled from updated_at") + + // Verify schema version advanced to 4 + var latest SysStat + require.NoError(t, db.Order("schema_version DESC").First(&latest).Error) + assert.Equal(t, uint64(4), latest.SchemaVersion) +} + +func TestMigration_DoesNotTouchValidCreatedAt(t *testing.T) { + db := newTestDB(t) + + require.NoError(t, db.AutoMigrate( + &SysStat{}, + &domain.User{}, + &domain.UserAuthentication{}, + &domain.Interface{}, + &domain.Cidr{}, + &domain.Peer{}, + &domain.AuditEntry{}, + &domain.UserWebauthnCredential{}, + )) + + for v := uint64(1); v <= 3; v++ { + require.NoError(t, db.Create(&SysStat{SchemaVersion: v, MigratedAt: time.Now()}).Error) + } + + createdAt := time.Date(2024, 3, 1, 8, 0, 0, 0, time.UTC) + updatedAt := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC) + + require.NoError(t, db.Exec( + "INSERT INTO users (identifier, email, created_at, updated_at) VALUES (?, ?, ?, ?)", + "valid-user", "valid@example.com", createdAt, updatedAt, + ).Error) + + repo := &SqlRepo{db: db, cfg: &config.Config{}} + require.NoError(t, repo.migrate()) + + var user domain.User + require.NoError(t, db.First(&user, "identifier = ?", "valid-user").Error) + assert.Equal(t, createdAt, user.CreatedAt, "valid created_at should not be modified") +} diff --git a/internal/app/users/user_manager.go b/internal/app/users/user_manager.go index 447fc9d..bb2d3a6 100644 --- a/internal/app/users/user_manager.go +++ b/internal/app/users/user_manager.go @@ -533,6 +533,7 @@ func (m Manager) create(ctx context.Context, user *domain.User) (*domain.User, e } err = m.users.SaveUser(ctx, user.Identifier, func(u *domain.User) (*domain.User, error) { + user.CopyCalculatedAttributes(u, false) return user, nil }) if err != nil {