chore: use interfaces for all other services

This commit is contained in:
Christoph Haas 2025-03-23 23:09:47 +01:00
parent 02ed7b19df
commit 7d0da4e7ad
40 changed files with 1337 additions and 406 deletions

View File

@ -14,6 +14,7 @@ import (
"github.com/h44z/wg-portal/internal/adapters" "github.com/h44z/wg-portal/internal/adapters"
"github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/api/core" "github.com/h44z/wg-portal/internal/app/api/core"
backendV0 "github.com/h44z/wg-portal/internal/app/api/v0/backend"
handlersV0 "github.com/h44z/wg-portal/internal/app/api/v0/handlers" handlersV0 "github.com/h44z/wg-portal/internal/app/api/v0/handlers"
backendV1 "github.com/h44z/wg-portal/internal/app/api/v1/backend" backendV1 "github.com/h44z/wg-portal/internal/app/api/v1/backend"
handlersV1 "github.com/h44z/wg-portal/internal/app/api/v1/handlers" handlersV1 "github.com/h44z/wg-portal/internal/app/api/v1/handlers"
@ -70,17 +71,24 @@ func main() {
queueSize := 100 queueSize := 100
eventBus := evbus.New(queueSize) eventBus := evbus.New(queueSize)
auditRecorder, err := audit.NewAuditRecorder(cfg, eventBus, database)
internal.AssertNoError(err)
auditRecorder.StartBackgroundJobs(ctx)
userManager, err := users.NewUserManager(cfg, eventBus, database, database) userManager, err := users.NewUserManager(cfg, eventBus, database, database)
internal.AssertNoError(err) internal.AssertNoError(err)
userManager.StartBackgroundJobs(ctx)
authenticator, err := auth.NewAuthenticator(&cfg.Auth, cfg.Web.ExternalUrl, eventBus, userManager) authenticator, err := auth.NewAuthenticator(&cfg.Auth, cfg.Web.ExternalUrl, eventBus, userManager)
internal.AssertNoError(err) internal.AssertNoError(err)
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database) wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database)
internal.AssertNoError(err) internal.AssertNoError(err)
wireGuardManager.StartBackgroundJobs(ctx)
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, eventBus, database, wireGuard, metricsServer) statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, eventBus, database, wireGuard, metricsServer)
internal.AssertNoError(err) internal.AssertNoError(err)
statisticsCollector.StartBackgroundJobs(ctx)
cfgFileManager, err := configfile.NewConfigFileManager(cfg, eventBus, database, database, cfgFileSystem) cfgFileManager, err := configfile.NewConfigFileManager(cfg, eventBus, database, database, cfgFileSystem)
internal.AssertNoError(err) internal.AssertNoError(err)
@ -88,18 +96,11 @@ func main() {
mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database) mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database)
internal.AssertNoError(err) internal.AssertNoError(err)
auditRecorder, err := audit.NewAuditRecorder(cfg, eventBus, database)
internal.AssertNoError(err)
auditRecorder.StartBackgroundJobs(ctx)
routeManager, err := route.NewRouteManager(cfg, eventBus, database) routeManager, err := route.NewRouteManager(cfg, eventBus, database)
internal.AssertNoError(err) internal.AssertNoError(err)
routeManager.StartBackgroundJobs(ctx) routeManager.StartBackgroundJobs(ctx)
backend, err := app.New(cfg, eventBus, authenticator, userManager, wireGuardManager, err = app.Initialize(cfg, wireGuardManager, userManager)
statisticsCollector, cfgFileManager, mailManager)
internal.AssertNoError(err)
err = backend.Startup(ctx)
internal.AssertNoError(err) internal.AssertNoError(err)
validatorManager := validator.New() validatorManager := validator.New()
@ -109,10 +110,14 @@ func main() {
apiV0Session := handlersV0.NewSessionWrapper(cfg) apiV0Session := handlersV0.NewSessionWrapper(cfg)
apiV0Auth := handlersV0.NewAuthenticationHandler(authenticator, apiV0Session) apiV0Auth := handlersV0.NewAuthenticationHandler(authenticator, apiV0Session)
apiV0EndpointAuth := handlersV0.NewAuthEndpoint(cfg, apiV0Auth, apiV0Session, validatorManager, backend) apiV0BackendUsers := backendV0.NewUserService(cfg, userManager, wireGuardManager)
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, backend) apiV0BackendInterfaces := backendV0.NewInterfaceService(cfg, wireGuardManager, cfgFileManager)
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, backend) apiV0BackendPeers := backendV0.NewPeerService(cfg, wireGuardManager, cfgFileManager, mailManager)
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, backend)
apiV0EndpointAuth := handlersV0.NewAuthEndpoint(cfg, apiV0Auth, apiV0Session, validatorManager, authenticator)
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers)
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces)
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth) apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth)
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth) apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)

View File

@ -0,0 +1,91 @@
package backend
import (
"context"
"io"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// region dependencies
type InterfaceServiceInterfaceManager interface {
GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error)
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error)
UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error)
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
PrepareInterface(ctx context.Context) (*domain.Interface, error)
ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error
}
type InterfaceServiceConfigFileManager interface {
PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
}
// endregion dependencies
type InterfaceService struct {
cfg *config.Config
interfaces InterfaceServiceInterfaceManager
configFile InterfaceServiceConfigFileManager
}
func NewInterfaceService(
cfg *config.Config,
interfaces InterfaceServiceInterfaceManager,
configFile InterfaceServiceConfigFileManager,
) *InterfaceService {
return &InterfaceService{
cfg: cfg,
interfaces: interfaces,
configFile: configFile,
}
}
func (i InterfaceService) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
[]domain.Peer,
error,
) {
return i.interfaces.GetInterfaceAndPeers(ctx, id)
}
func (i InterfaceService) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
return i.interfaces.PrepareInterface(ctx)
}
func (i InterfaceService) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
return i.interfaces.CreateInterface(ctx, in)
}
func (i InterfaceService) UpdateInterface(ctx context.Context, in *domain.Interface) (
*domain.Interface,
[]domain.Peer,
error,
) {
return i.interfaces.UpdateInterface(ctx, in)
}
func (i InterfaceService) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
return i.interfaces.DeleteInterface(ctx, id)
}
func (i InterfaceService) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
return i.interfaces.GetAllInterfacesAndPeers(ctx)
}
func (i InterfaceService) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
return i.configFile.GetInterfaceConfig(ctx, id)
}
func (i InterfaceService) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
return i.configFile.PersistInterfaceConfig(ctx, id)
}
func (i InterfaceService) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
return i.interfaces.ApplyPeerDefaults(ctx, in)
}

View File

@ -0,0 +1,112 @@
package backend
import (
"context"
"io"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// region dependencies
type PeerServicePeerManager interface {
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error)
CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error)
UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error)
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
CreateMultiplePeers(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
r *domain.PeerCreationRequest,
) ([]domain.Peer, error)
GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error)
}
type PeerServiceConfigFileManager interface {
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
}
type PeerServiceMailManager interface {
SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error
}
// endregion dependencies
type PeerService struct {
cfg *config.Config
peers PeerServicePeerManager
configFile PeerServiceConfigFileManager
mailer PeerServiceMailManager
}
func NewPeerService(
cfg *config.Config,
peers PeerServicePeerManager,
configFile PeerServiceConfigFileManager,
mailer PeerServiceMailManager,
) *PeerService {
return &PeerService{
cfg: cfg,
peers: peers,
configFile: configFile,
mailer: mailer,
}
}
func (p PeerService) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
[]domain.Peer,
error,
) {
return p.peers.GetInterfaceAndPeers(ctx, id)
}
func (p PeerService) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
return p.peers.PreparePeer(ctx, id)
}
func (p PeerService) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
return p.peers.GetPeer(ctx, id)
}
func (p PeerService) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
return p.peers.CreatePeer(ctx, peer)
}
func (p PeerService) CreateMultiplePeers(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
r *domain.PeerCreationRequest,
) ([]domain.Peer, error) {
return p.peers.CreateMultiplePeers(ctx, interfaceId, r)
}
func (p PeerService) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
return p.peers.UpdatePeer(ctx, peer)
}
func (p PeerService) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
return p.peers.DeletePeer(ctx, id)
}
func (p PeerService) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
return p.configFile.GetPeerConfig(ctx, id)
}
func (p PeerService) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
return p.configFile.GetPeerConfigQrCode(ctx, id)
}
func (p PeerService) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error {
return p.mailer.SendPeerEmail(ctx, linkOnly, peers...)
}
func (p PeerService) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) {
return p.peers.GetPeerStats(ctx, id)
}

View File

@ -0,0 +1,83 @@
package backend
import (
"context"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// region dependencies
type UserServiceUserManager interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
GetAllUsers(ctx context.Context) ([]domain.User, error)
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type UserServiceWireGuardManager interface {
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier) ([]domain.Interface, error)
GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error)
}
// endregion dependencies
type UserService struct {
cfg *config.Config
users UserServiceUserManager
wg UserServiceWireGuardManager
}
func NewUserService(cfg *config.Config, users UserServiceUserManager, wg UserServiceWireGuardManager) *UserService {
return &UserService{
cfg: cfg,
users: users,
wg: wg,
}
}
func (u UserService) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
return u.users.GetUser(ctx, id)
}
func (u UserService) GetAllUsers(ctx context.Context) ([]domain.User, error) {
return u.users.GetAllUsers(ctx)
}
func (u UserService) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
return u.users.UpdateUser(ctx, user)
}
func (u UserService) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
return u.users.CreateUser(ctx, user)
}
func (u UserService) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
return u.users.DeleteUser(ctx, id)
}
func (u UserService) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
return u.users.ActivateApi(ctx, id)
}
func (u UserService) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
return u.users.DeactivateApi(ctx, id)
}
func (u UserService) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
return u.wg.GetUserPeers(ctx, id)
}
func (u UserService) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
return u.wg.GetUserPeerStats(ctx, id)
}
func (u UserService) GetUserInterfaces(ctx context.Context, id domain.UserIdentifier) ([]domain.Interface, error) {
return u.wg.GetUserInterfaces(ctx, id)
}

View File

@ -7,46 +7,43 @@ import (
"log/slog" "log/slog"
"time" "time"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
type App struct { // region dependencies
Config *config.Config
bus evbus.MessageBus
Authenticator type WireGuardManager interface {
UserManager ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error)
WireGuardManager RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error
StatisticsCollector
ConfigFileManager
MailManager
ApiV1Manager
} }
func New( type UserManager interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
}
// endregion dependencies
// App is the main application struct.
type App struct {
cfg *config.Config
wg WireGuardManager
users UserManager
}
// Initialize creates a new App instance and initializes it.
func Initialize(
cfg *config.Config, cfg *config.Config,
bus evbus.MessageBus, wg WireGuardManager,
authenticator Authenticator,
users UserManager, users UserManager,
wireGuard WireGuardManager, ) error {
stats StatisticsCollector,
cfgFiles ConfigFileManager,
mailer MailManager,
) (*App, error) {
a := &App{ a := &App{
Config: cfg, cfg: cfg,
bus: bus,
Authenticator: authenticator, wg: wg,
UserManager: users, users: users,
WireGuardManager: wireGuard,
StatisticsCollector: stats,
ConfigFileManager: cfgFiles,
MailManager: mailer,
} }
startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second) startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second)
@ -56,36 +53,27 @@ func New(
startupContext = domain.SetUserInfo(startupContext, domain.SystemAdminContextUserInfo()) startupContext = domain.SetUserInfo(startupContext, domain.SystemAdminContextUserInfo())
if err := a.createDefaultUser(startupContext); err != nil { if err := a.createDefaultUser(startupContext); err != nil {
return nil, fmt.Errorf("failed to create default user: %w", err) return fmt.Errorf("failed to create default user: %w", err)
} }
if err := a.importNewInterfaces(startupContext); err != nil { if err := a.importNewInterfaces(startupContext); err != nil {
return nil, fmt.Errorf("failed to import new interfaces: %w", err) return fmt.Errorf("failed to import new interfaces: %w", err)
} }
if err := a.restoreInterfaceState(startupContext); err != nil { if err := a.restoreInterfaceState(startupContext); err != nil {
return nil, fmt.Errorf("failed to restore interface state: %w", err) return fmt.Errorf("failed to restore interface state: %w", err)
} }
return a, nil
}
func (a *App) Startup(ctx context.Context) error {
a.UserManager.StartBackgroundJobs(ctx)
a.StatisticsCollector.StartBackgroundJobs(ctx)
a.WireGuardManager.StartBackgroundJobs(ctx)
return nil return nil
} }
func (a *App) importNewInterfaces(ctx context.Context) error { func (a *App) importNewInterfaces(ctx context.Context) error {
if !a.Config.Core.ImportExisting { if !a.cfg.Core.ImportExisting {
slog.Debug("skipping interface import - feature disabled") slog.Debug("skipping interface import - feature disabled")
return nil // feature disabled return nil // feature disabled
} }
importedCount, err := a.ImportNewInterfaces(ctx) importedCount, err := a.wg.ImportNewInterfaces(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -97,12 +85,12 @@ func (a *App) importNewInterfaces(ctx context.Context) error {
} }
func (a *App) restoreInterfaceState(ctx context.Context) error { func (a *App) restoreInterfaceState(ctx context.Context) error {
if !a.Config.Core.RestoreState { if !a.cfg.Core.RestoreState {
slog.Debug("skipping interface state restore - feature disabled") slog.Debug("skipping interface state restore - feature disabled")
return nil // feature disabled return nil // feature disabled
} }
err := a.RestoreInterfaceState(ctx, true) err := a.wg.RestoreInterfaceState(ctx, true)
if err != nil { if err != nil {
return err return err
} }
@ -112,13 +100,13 @@ func (a *App) restoreInterfaceState(ctx context.Context) error {
} }
func (a *App) createDefaultUser(ctx context.Context) error { func (a *App) createDefaultUser(ctx context.Context) error {
adminUserId := domain.UserIdentifier(a.Config.Core.AdminUser) adminUserId := domain.UserIdentifier(a.cfg.Core.AdminUser)
if adminUserId == "" { if adminUserId == "" {
slog.Debug("skipping default user creation - admin user is blank") slog.Debug("skipping default user creation - admin user is blank")
return nil // empty admin user - do not create return nil // empty admin user - do not create
} }
_, err := a.GetUser(ctx, adminUserId) _, err := a.users.GetUser(ctx, adminUserId)
if err != nil && !errors.Is(err, domain.ErrNotFound) { if err != nil && !errors.Is(err, domain.ErrNotFound) {
return err return err
} }
@ -145,22 +133,22 @@ func (a *App) createDefaultUser(ctx context.Context) error {
Phone: "", Phone: "",
Department: "", Department: "",
Notes: "default administrator user", Notes: "default administrator user",
Password: domain.PrivateString(a.Config.Core.AdminPassword), Password: domain.PrivateString(a.cfg.Core.AdminPassword),
Disabled: nil, Disabled: nil,
DisabledReason: "", DisabledReason: "",
Locked: nil, Locked: nil,
LockedReason: "", LockedReason: "",
LinkedPeerCount: 0, LinkedPeerCount: 0,
} }
if a.Config.Core.AdminApiToken != "" { if a.cfg.Core.AdminApiToken != "" {
if len(a.Config.Core.AdminApiToken) < 18 { if len(a.cfg.Core.AdminApiToken) < 18 {
slog.Warn("admin API token is too short, should be at least 18 characters long") slog.Warn("admin API token is too short, should be at least 18 characters long")
} }
defaultAdmin.ApiToken = a.Config.Core.AdminApiToken defaultAdmin.ApiToken = a.cfg.Core.AdminApiToken
defaultAdmin.ApiTokenCreated = &now defaultAdmin.ApiTokenCreated = &now
} }
admin, err := a.CreateUser(ctx, defaultAdmin) admin, err := a.users.CreateUser(ctx, defaultAdmin)
if err != nil { if err != nil {
return err return err
} }

View File

@ -6,21 +6,35 @@ import (
"log/slog" "log/slog"
"time" "time"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
// region dependencies
type DatabaseRepo interface {
// SaveAuditEntry saves an audit entry to the database
SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error
}
type EventBus interface {
// Subscribe subscribes to a topic
Subscribe(topic string, fn interface{}) error
}
// endregion dependencies
// Recorder is responsible for recording audit events to the database.
type Recorder struct { type Recorder struct {
cfg *config.Config cfg *config.Config
bus evbus.MessageBus bus EventBus
db DatabaseRepo db DatabaseRepo
} }
func NewAuditRecorder(cfg *config.Config, bus evbus.MessageBus, db DatabaseRepo) (*Recorder, error) { // NewAuditRecorder creates a new audit recorder instance.
func NewAuditRecorder(cfg *config.Config, bus EventBus, db DatabaseRepo) (*Recorder, error) {
r := &Recorder{ r := &Recorder{
cfg: cfg, cfg: cfg,
bus: bus, bus: bus,
@ -36,6 +50,8 @@ func NewAuditRecorder(cfg *config.Config, bus evbus.MessageBus, db DatabaseRepo)
return r, nil return r, nil
} }
// StartBackgroundJobs starts background jobs for the audit recorder.
// This method is non-blocking and returns immediately.
func (r *Recorder) StartBackgroundJobs(ctx context.Context) { func (r *Recorder) StartBackgroundJobs(ctx context.Context) {
if !r.cfg.Statistics.CollectAuditData { if !r.cfg.Statistics.CollectAuditData {
return // noting to do return // noting to do

View File

@ -1,11 +0,0 @@
package audit
import (
"context"
"github.com/h44z/wg-portal/internal/domain"
)
type DatabaseRepo interface {
SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error
}

View File

@ -14,25 +14,78 @@ import (
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
evbus "github.com/vardius/message-bus" "golang.org/x/oauth2"
"github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
// region dependencies
type UserManager interface { type UserManager interface {
// GetUser returns a user by its identifier.
GetUser(context.Context, domain.UserIdentifier) (*domain.User, error) GetUser(context.Context, domain.UserIdentifier) (*domain.User, error)
// RegisterUser creates a new user in the database.
RegisterUser(ctx context.Context, user *domain.User) error RegisterUser(ctx context.Context, user *domain.User) error
// UpdateUser updates an existing user in the database.
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
} }
type EventBus interface {
// Publish sends a message to the message bus.
Publish(topic string, args ...any)
}
// endregion dependencies
type AuthenticatorType string
const (
AuthenticatorTypeOAuth AuthenticatorType = "oauth"
AuthenticatorTypeOidc AuthenticatorType = "oidc"
)
// AuthenticatorOauth is the interface for all OAuth authenticators.
type AuthenticatorOauth interface {
// GetName returns the name of the authenticator.
GetName() string
// GetType returns the type of the authenticator. It can be either AuthenticatorTypeOAuth or AuthenticatorTypeOidc.
GetType() AuthenticatorType
// AuthCodeURL returns the URL for the authentication flow.
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
// Exchange exchanges the OAuth code for an access token.
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
// GetUserInfo fetches the user information from the OAuth or OIDC provider.
GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]any, error)
// ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct.
ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error)
// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator.
RegistrationEnabled() bool
}
// AuthenticatorLdap is the interface for all LDAP authenticators.
type AuthenticatorLdap interface {
// GetName returns the name of the authenticator.
GetName() string
// PlaintextAuthentication performs a plaintext authentication against the LDAP server.
PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error
// GetUserInfo fetches the user information from the LDAP server.
GetUserInfo(ctx context.Context, username domain.UserIdentifier) (map[string]any, error)
// ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct.
ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error)
// RegistrationEnabled returns whether registration is enabled for the LDAP authenticator.
RegistrationEnabled() bool
}
// Authenticator is the main entry point for all authentication related tasks.
// This includes password authentication and external authentication providers (OIDC, OAuth, LDAP).
type Authenticator struct { type Authenticator struct {
cfg *config.Auth cfg *config.Auth
bus evbus.MessageBus bus EventBus
oauthAuthenticators map[string]domain.OauthAuthenticator oauthAuthenticators map[string]AuthenticatorOauth
ldapAuthenticators map[string]domain.LdapAuthenticator ldapAuthenticators map[string]AuthenticatorLdap
// URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix // URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix
callbackUrlPrefix string callbackUrlPrefix string
@ -40,7 +93,8 @@ type Authenticator struct {
users UserManager users UserManager
} }
func NewAuthenticator(cfg *config.Auth, extUrl string, bus evbus.MessageBus, users UserManager) ( // NewAuthenticator creates a new Authenticator instance.
func NewAuthenticator(cfg *config.Auth, extUrl string, bus EventBus, users UserManager) (
*Authenticator, *Authenticator,
error, error,
) { ) {
@ -68,8 +122,8 @@ func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
return fmt.Errorf("failed to parse external url: %w", err) return fmt.Errorf("failed to parse external url: %w", err)
} }
a.oauthAuthenticators = make(map[string]domain.OauthAuthenticator, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth)) a.oauthAuthenticators = make(map[string]AuthenticatorOauth, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth))
a.ldapAuthenticators = make(map[string]domain.LdapAuthenticator, len(a.cfg.Ldap)) a.ldapAuthenticators = make(map[string]AuthenticatorLdap, len(a.cfg.Ldap))
for i := range a.cfg.OpenIDConnect { // OIDC for i := range a.cfg.OpenIDConnect { // OIDC
providerCfg := &a.cfg.OpenIDConnect[i] providerCfg := &a.cfg.OpenIDConnect[i]
@ -123,6 +177,7 @@ func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
return nil return nil
} }
// GetExternalLoginProviders returns a list of all available external login providers.
func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo { func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo {
authProviders := make([]domain.LoginProviderInfo, 0, len(a.cfg.OAuth)+len(a.cfg.OpenIDConnect)) authProviders := make([]domain.LoginProviderInfo, 0, len(a.cfg.OAuth)+len(a.cfg.OpenIDConnect))
@ -157,6 +212,7 @@ func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.Lo
return authProviders return authProviders
} }
// IsUserValid checks if a user is valid and not locked or disabled.
func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool { func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool {
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context
user, err := a.users.GetUser(ctx, id) user, err := a.users.GetUser(ctx, id)
@ -177,6 +233,8 @@ func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifie
// region password authentication // region password authentication
// PlainLogin performs a password authentication for a user. The username and password are trimmed before usage.
// If the login is successful, the user is returned, otherwise an error.
func (a *Authenticator) PlainLogin(ctx context.Context, username, password string) (*domain.User, error) { func (a *Authenticator) PlainLogin(ctx context.Context, username, password string) (*domain.User, error) {
// Validate form input // Validate form input
username = strings.TrimSpace(username) username = strings.TrimSpace(username)
@ -204,7 +262,7 @@ func (a *Authenticator) passwordAuthentication(
domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
var ldapUserInfo *domain.AuthenticatorUserInfo var ldapUserInfo *domain.AuthenticatorUserInfo
var ldapProvider domain.LdapAuthenticator var ldapProvider AuthenticatorLdap
var userInDatabase = false var userInDatabase = false
var userSource domain.UserSource var userSource domain.UserSource
@ -280,6 +338,7 @@ func (a *Authenticator) passwordAuthentication(
// region oauth authentication // region oauth authentication
// OauthLoginStep1 starts the oauth authentication flow by returning the authentication URL, state and nonce.
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) ( func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
authCodeUrl, state, nonce string, authCodeUrl, state, nonce string,
err error, err error,
@ -296,9 +355,9 @@ func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
} }
switch oauthProvider.GetType() { switch oauthProvider.GetType() {
case domain.AuthenticatorTypeOAuth: case AuthenticatorTypeOAuth:
authCodeUrl = oauthProvider.AuthCodeURL(state) authCodeUrl = oauthProvider.AuthCodeURL(state)
case domain.AuthenticatorTypeOidc: case AuthenticatorTypeOidc:
nonce, err = a.randString(16) nonce, err = a.randString(16)
if err != nil { if err != nil {
return "", "", "", fmt.Errorf("failed to generate nonce: %w", err) return "", "", "", fmt.Errorf("failed to generate nonce: %w", err)
@ -318,6 +377,8 @@ func (a *Authenticator) randString(nByte int) (string, error) {
return base64.RawURLEncoding.EncodeToString(b), nil return base64.RawURLEncoding.EncodeToString(b), nil
} }
// OauthLoginStep2 finishes the oauth authentication flow by exchanging the code for an access token and
// fetching the user information.
func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error) { func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error) {
oauthProvider, ok := a.oauthAuthenticators[providerId] oauthProvider, ok := a.oauthAuthenticators[providerId]
if !ok { if !ok {

View File

@ -14,6 +14,7 @@ import (
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
// LdapAuthenticator is an authenticator that uses LDAP for authentication.
type LdapAuthenticator struct { type LdapAuthenticator struct {
cfg *config.LdapProvider cfg *config.LdapProvider
} }
@ -33,14 +34,17 @@ func newLdapAuthenticator(_ context.Context, cfg *config.LdapProvider) (*LdapAut
return provider, nil return provider, nil
} }
// GetName returns the name of the LDAP authenticator.
func (l LdapAuthenticator) GetName() string { func (l LdapAuthenticator) GetName() string {
return l.cfg.ProviderName return l.cfg.ProviderName
} }
// RegistrationEnabled returns whether registration is enabled for the LDAP authenticator.
func (l LdapAuthenticator) RegistrationEnabled() bool { func (l LdapAuthenticator) RegistrationEnabled() bool {
return l.cfg.RegistrationEnabled return l.cfg.RegistrationEnabled
} }
// PlaintextAuthentication performs a plaintext authentication against the LDAP server.
func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error { func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error {
conn, err := internal.LdapConnect(l.cfg) conn, err := internal.LdapConnect(l.cfg)
if err != nil { if err != nil {
@ -81,6 +85,9 @@ func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier,
return nil return nil
} }
// GetUserInfo retrieves user information from the LDAP server.
// If the user is not found, domain.ErrNotFound is returned.
// If multiple users are found, domain.ErrNotUnique is returned.
func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIdentifier) ( func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIdentifier) (
map[string]any, map[string]any,
error, error,
@ -126,6 +133,7 @@ func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIden
return users[0], nil return users[0], nil
} }
// ParseUserInfo parses the user information from the LDAP server into a domain.AuthenticatorUserInfo struct.
func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) { func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
isAdmin, err := internal.LdapIsMemberOf(raw[l.cfg.FieldMap.GroupMembership].([][]byte), l.cfg.ParsedAdminGroupDN) isAdmin, err := internal.LdapIsMemberOf(raw[l.cfg.FieldMap.GroupMembership].([][]byte), l.cfg.ParsedAdminGroupDN)
if err != nil { if err != nil {

View File

@ -16,6 +16,8 @@ import (
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
// PlainOauthAuthenticator is an authenticator that uses OAuth for authentication.
// User information is retrieved from the specified user info endpoint.
type PlainOauthAuthenticator struct { type PlainOauthAuthenticator struct {
name string name string
cfg *oauth2.Config cfg *oauth2.Config
@ -58,22 +60,27 @@ func newPlainOauthAuthenticator(
return provider, nil return provider, nil
} }
// GetName returns the name of the OAuth authenticator.
func (p PlainOauthAuthenticator) GetName() string { func (p PlainOauthAuthenticator) GetName() string {
return p.name return p.name
} }
// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator.
func (p PlainOauthAuthenticator) RegistrationEnabled() bool { func (p PlainOauthAuthenticator) RegistrationEnabled() bool {
return p.registrationEnabled return p.registrationEnabled
} }
func (p PlainOauthAuthenticator) GetType() domain.AuthenticatorType { // GetType returns the type of the authenticator.
return domain.AuthenticatorTypeOAuth func (p PlainOauthAuthenticator) GetType() AuthenticatorType {
return AuthenticatorTypeOAuth
} }
// AuthCodeURL returns the URL to redirect the user to for authentication.
func (p PlainOauthAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { func (p PlainOauthAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
return p.cfg.AuthCodeURL(state, opts...) return p.cfg.AuthCodeURL(state, opts...)
} }
// Exchange exchanges the OAuth code for a token.
func (p PlainOauthAuthenticator) Exchange( func (p PlainOauthAuthenticator) Exchange(
ctx context.Context, ctx context.Context,
code string, code string,
@ -82,6 +89,7 @@ func (p PlainOauthAuthenticator) Exchange(
return p.cfg.Exchange(ctx, code, opts...) return p.cfg.Exchange(ctx, code, opts...)
} }
// GetUserInfo retrieves the user information from the user info endpoint.
func (p PlainOauthAuthenticator) GetUserInfo( func (p PlainOauthAuthenticator) GetUserInfo(
ctx context.Context, ctx context.Context,
token *oauth2.Token, token *oauth2.Token,
@ -119,6 +127,7 @@ func (p PlainOauthAuthenticator) GetUserInfo(
return userFields, nil return userFields, nil
} }
// ParseUserInfo parses the user information from the raw data.
func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) { func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw) return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw)
} }

View File

@ -14,6 +14,7 @@ import (
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
// OidcAuthenticator is an authenticator for OpenID Connect providers.
type OidcAuthenticator struct { type OidcAuthenticator struct {
name string name string
provider *oidc.Provider provider *oidc.Provider
@ -60,22 +61,27 @@ func newOidcAuthenticator(
return provider, nil return provider, nil
} }
// GetName returns the name of the authenticator.
func (o OidcAuthenticator) GetName() string { func (o OidcAuthenticator) GetName() string {
return o.name return o.name
} }
// RegistrationEnabled returns whether registration is enabled for this authenticator.
func (o OidcAuthenticator) RegistrationEnabled() bool { func (o OidcAuthenticator) RegistrationEnabled() bool {
return o.registrationEnabled return o.registrationEnabled
} }
func (o OidcAuthenticator) GetType() domain.AuthenticatorType { // GetType returns the type of the authenticator.
return domain.AuthenticatorTypeOidc func (o OidcAuthenticator) GetType() AuthenticatorType {
return AuthenticatorTypeOidc
} }
// AuthCodeURL returns the URL for the OAuth2 flow.
func (o OidcAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { func (o OidcAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
return o.cfg.AuthCodeURL(state, opts...) return o.cfg.AuthCodeURL(state, opts...)
} }
// Exchange exchanges the code for a token.
func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) ( func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (
*oauth2.Token, *oauth2.Token,
error, error,
@ -83,6 +89,7 @@ func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oa
return o.cfg.Exchange(ctx, code, opts...) return o.cfg.Exchange(ctx, code, opts...)
} }
// GetUserInfo retrieves the user info from the token.
func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) ( func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (
map[string]any, map[string]any,
error, error,
@ -114,6 +121,7 @@ func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token,
return tokenFields, nil return tokenFields, nil
} }
// ParseUserInfo parses the user info.
func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) { func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw) return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw)
} }

View File

@ -10,7 +10,6 @@ import (
"os" "os"
"strings" "strings"
evbus "github.com/vardius/message-bus"
"github.com/yeqown/go-qrcode/v2" "github.com/yeqown/go-qrcode/v2"
"github.com/yeqown/go-qrcode/writer/compressed" "github.com/yeqown/go-qrcode/writer/compressed"
@ -19,19 +18,56 @@ import (
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
type Manager struct { // region dependencies
cfg *config.Config
bus evbus.MessageBus
tplHandler *TemplateHandler
fsRepo FileSystemRepo type UserDatabaseRepo interface {
users UserDatabaseRepo // GetUser returns the user with the given identifier from the SQL database.
wg WireguardDatabaseRepo GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
} }
type WireguardDatabaseRepo interface {
// GetInterfaceAndPeers returns the interface and all peers associated with it.
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
// GetPeer returns the peer with the given identifier.
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
// GetInterface returns the interface with the given identifier.
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}
type FileSystemRepo interface {
// WriteFile writes the contents to the file at the given path.
WriteFile(path string, contents io.Reader) error
}
type TemplateRenderer interface {
// GetInterfaceConfig returns the configuration file for the given interface.
GetInterfaceConfig(iface *domain.Interface, peers []domain.Peer) (io.Reader, error)
// GetPeerConfig returns the configuration file for the given peer.
GetPeerConfig(peer *domain.Peer) (io.Reader, error)
}
type EventBus interface {
// Subscribe subscribes to the given topic.
Subscribe(topic string, fn any) error
}
// endregion dependencies
// Manager is responsible for managing the configuration files of the WireGuard interfaces and peers.
type Manager struct {
cfg *config.Config
bus EventBus
tplHandler TemplateRenderer
fsRepo FileSystemRepo
users UserDatabaseRepo
wg WireguardDatabaseRepo
}
// NewConfigFileManager creates a new Manager instance.
func NewConfigFileManager( func NewConfigFileManager(
cfg *config.Config, cfg *config.Config,
bus evbus.MessageBus, bus EventBus,
users UserDatabaseRepo, users UserDatabaseRepo,
wg WireguardDatabaseRepo, wg WireguardDatabaseRepo,
fsRepo FileSystemRepo, fsRepo FileSystemRepo,
@ -115,6 +151,8 @@ func (m Manager) handlePeerInterfaceUpdatedEvent(id domain.InterfaceIdentifier)
} }
} }
// GetInterfaceConfig returns the configuration file for the given interface.
// The file is structured in wg-quick format.
func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) { func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err return nil, err
@ -128,6 +166,8 @@ func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIden
return m.tplHandler.GetInterfaceConfig(iface, peers) return m.tplHandler.GetInterfaceConfig(iface, peers)
} }
// GetPeerConfig returns the configuration file for the given peer.
// The file is structured in wg-quick format.
func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) { func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
peer, err := m.wg.GetPeer(ctx, id) peer, err := m.wg.GetPeer(ctx, id)
if err != nil { if err != nil {
@ -141,6 +181,7 @@ func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (i
return m.tplHandler.GetPeerConfig(peer) return m.tplHandler.GetPeerConfig(peer)
} }
// GetPeerConfigQrCode returns a QR code image containing the configuration for the given peer.
func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) { func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
peer, err := m.wg.GetPeer(ctx, id) peer, err := m.wg.GetPeer(ctx, id)
if err != nil { if err != nil {
@ -191,6 +232,7 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
return buf, nil return buf, nil
} }
// PersistInterfaceConfig writes the configuration file for the given interface to the file system.
func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error { func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
iface, peers, err := m.wg.GetInterfaceAndPeers(ctx, id) iface, peers, err := m.wg.GetInterfaceAndPeers(ctx, id)
if err != nil { if err != nil {
@ -213,4 +255,5 @@ type nopCloser struct {
io.Writer io.Writer
} }
// Close is a no-op for the nopCloser.
func (nopCloser) Close() error { return nil } func (nopCloser) Close() error { return nil }

View File

@ -1,22 +0,0 @@
package configfile
import (
"context"
"io"
"github.com/h44z/wg-portal/internal/domain"
)
type UserDatabaseRepo interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type WireguardDatabaseRepo interface {
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}
type FileSystemRepo interface {
WriteFile(path string, contents io.Reader) error
}

View File

@ -13,6 +13,8 @@ import (
//go:embed tpl_files/* //go:embed tpl_files/*
var TemplateFiles embed.FS var TemplateFiles embed.FS
// TemplateHandler is responsible for rendering the WireGuard configuration files
// based on the provided templates.
type TemplateHandler struct { type TemplateHandler struct {
templates *template.Template templates *template.Template
} }
@ -34,6 +36,7 @@ func newTemplateHandler() (*TemplateHandler, error) {
return handler, nil return handler, nil
} }
// GetInterfaceConfig returns the rendered configuration file for a WireGuard interface.
func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domain.Peer) (io.Reader, error) { func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domain.Peer) (io.Reader, error) {
var tplBuff bytes.Buffer var tplBuff bytes.Buffer
@ -51,6 +54,7 @@ func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domai
return &tplBuff, nil return &tplBuff, nil
} }
// GetPeerConfig returns the rendered configuration file for a WireGuard peer.
func (c TemplateHandler) GetPeerConfig(peer *domain.Peer) (io.Reader, error) { func (c TemplateHandler) GetPeerConfig(peer *domain.Peer) (io.Reader, error) {
var tplBuff bytes.Buffer var tplBuff bytes.Buffer

View File

@ -10,16 +10,60 @@ import (
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
type Manager struct { // region dependencies
cfg *config.Config
tplHandler *TemplateHandler
type Mailer interface {
// Send sends an email with the given subject and body to the given recipients.
Send(ctx context.Context, subject, body string, to []string, options *domain.MailOptions) error
}
type ConfigFileManager interface {
// GetInterfaceConfig returns the configuration for the given interface.
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
// GetPeerConfig returns the configuration for the given peer.
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
// GetPeerConfigQrCode returns the QR code for the given peer.
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
}
type UserDatabaseRepo interface {
// GetUser returns the user with the given identifier.
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type WireguardDatabaseRepo interface {
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
// GetPeer returns the peer with the given identifier.
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
// GetInterface returns the interface with the given identifier.
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}
type TemplateRenderer interface {
// GetConfigMail returns the text and html template for the mail with a link.
GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error)
// GetConfigMailWithAttachment returns the text and html template for the mail with an attachment.
GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) (
io.Reader,
io.Reader,
error,
)
}
// endregion dependencies
type Manager struct {
cfg *config.Config
tplHandler TemplateRenderer
mailer Mailer mailer Mailer
configFiles ConfigFileManager configFiles ConfigFileManager
users UserDatabaseRepo users UserDatabaseRepo
wg WireguardDatabaseRepo wg WireguardDatabaseRepo
} }
// NewMailManager creates a new mail manager.
func NewMailManager( func NewMailManager(
cfg *config.Config, cfg *config.Config,
mailer Mailer, mailer Mailer,
@ -44,6 +88,7 @@ func NewMailManager(
return m, nil return m, nil
} }
// SendPeerEmail sends an email to the user linked to the given peers.
func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error { func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error {
for _, peerId := range peers { for _, peerId := range peers {
peer, err := m.wg.GetPeer(ctx, peerId) peer, err := m.wg.GetPeer(ctx, peerId)

View File

@ -1,28 +0,0 @@
package mail
import (
"context"
"io"
"github.com/h44z/wg-portal/internal/domain"
)
type Mailer interface {
Send(ctx context.Context, subject, body string, to []string, options *domain.MailOptions) error
}
type ConfigFileManager interface {
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
}
type UserDatabaseRepo interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type WireguardDatabaseRepo interface {
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}

View File

@ -14,6 +14,7 @@ import (
//go:embed tpl_files/* //go:embed tpl_files/*
var TemplateFiles embed.FS var TemplateFiles embed.FS
// TemplateHandler is a struct that holds the html and text templates.
type TemplateHandler struct { type TemplateHandler struct {
portalUrl string portalUrl string
htmlTemplates *htmlTemplate.Template htmlTemplates *htmlTemplate.Template
@ -40,6 +41,7 @@ func newTemplateHandler(portalUrl string) (*TemplateHandler, error) {
return handler, nil return handler, nil
} }
// GetConfigMail returns the text and html template for the mail with a link.
func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error) { func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error) {
var tplBuff bytes.Buffer var tplBuff bytes.Buffer
var htmlTplBuff bytes.Buffer var htmlTplBuff bytes.Buffer
@ -65,6 +67,7 @@ func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reade
return &tplBuff, &htmlTplBuff, nil return &tplBuff, &htmlTplBuff, nil
} }
// GetConfigMailWithAttachment returns the text and html template for the mail with an attachment.
func (c TemplateHandler) GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) ( func (c TemplateHandler) GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) (
io.Reader, io.Reader,
io.Reader, io.Reader,

View File

@ -1,77 +0,0 @@
package app
import (
"context"
"io"
"github.com/h44z/wg-portal/internal/domain"
)
type Authenticator interface {
GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo
IsUserValid(ctx context.Context, id domain.UserIdentifier) bool
PlainLogin(ctx context.Context, username, password string) (*domain.User, error)
OauthLoginStep1(_ context.Context, providerId string) (authCodeUrl, state, nonce string, err error)
OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error)
}
type UserManager interface {
RegisterUser(ctx context.Context, user *domain.User) error
NewUser(ctx context.Context, user *domain.User) error
StartBackgroundJobs(ctx context.Context)
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
GetAllUsers(ctx context.Context) ([]domain.User, error)
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type WireGuardManager interface {
StartBackgroundJobs(ctx context.Context)
GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error)
ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error)
RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error
CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error)
GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error)
GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
GetUserInterfaces(ctx context.Context, id domain.UserIdentifier) ([]domain.Interface, error)
PrepareInterface(ctx context.Context) (*domain.Interface, error)
CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error)
UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error)
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
CreatePeer(ctx context.Context, p *domain.Peer) (*domain.Peer, error)
CreateMultiplePeers(
ctx context.Context,
id domain.InterfaceIdentifier,
r *domain.PeerCreationRequest,
) ([]domain.Peer, error)
UpdatePeer(ctx context.Context, p *domain.Peer) (*domain.Peer, error)
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error
}
type StatisticsCollector interface {
StartBackgroundJobs(ctx context.Context)
}
type ConfigFileManager interface {
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error
}
type MailManager interface {
SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error
}
type ApiV1Manager interface {
ApiV1GetUsers(ctx context.Context) ([]domain.User, error)
}

View File

@ -1,12 +0,0 @@
package route
import (
"context"
"github.com/h44z/wg-portal/internal/domain"
)
type InterfaceAndPeerDatabaseRepo interface {
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
}

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
evbus "github.com/vardius/message-bus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
@ -17,6 +16,22 @@ import (
"github.com/h44z/wg-portal/internal/lowlevel" "github.com/h44z/wg-portal/internal/lowlevel"
) )
// region dependencies
type InterfaceAndPeerDatabaseRepo interface {
// GetAllInterfaces returns all interfaces
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
// GetInterfacePeers returns all peers for a given interface
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
}
type EventBus interface {
// Subscribe subscribes to a topic
Subscribe(topic string, fn interface{}) error
}
// endregion dependencies
type routeRuleInfo struct { type routeRuleInfo struct {
ifaceId domain.InterfaceIdentifier ifaceId domain.InterfaceIdentifier
fwMark uint32 fwMark uint32
@ -29,14 +44,15 @@ type routeRuleInfo struct {
// for default routes. // for default routes.
type Manager struct { type Manager struct {
cfg *config.Config cfg *config.Config
bus evbus.MessageBus
wg lowlevel.WireGuardClient bus EventBus
nl lowlevel.NetlinkClient wg lowlevel.WireGuardClient
db InterfaceAndPeerDatabaseRepo nl lowlevel.NetlinkClient
db InterfaceAndPeerDatabaseRepo
} }
func NewRouteManager(cfg *config.Config, bus evbus.MessageBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) { // NewRouteManager creates a new route manager instance.
func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
wg, err := wgctrl.New() wg, err := wgctrl.New()
if err != nil { if err != nil {
panic("failed to init wgctrl: " + err.Error()) panic("failed to init wgctrl: " + err.Error())
@ -63,7 +79,10 @@ func (m Manager) connectToMessageBus() {
_ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteRemoveEvent) _ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteRemoveEvent)
} }
// StartBackgroundJobs starts background jobs for the route manager.
// This method is non-blocking and returns immediately.
func (m Manager) StartBackgroundJobs(_ context.Context) { func (m Manager) StartBackgroundJobs(_ context.Context) {
// this is a no-op for now
} }
func (m Manager) handleRouteUpdateEvent(srcDescription string) { func (m Manager) handleRouteUpdateEvent(srcDescription string) {

View File

@ -1,20 +0,0 @@
package users
import (
"context"
"github.com/h44z/wg-portal/internal/domain"
)
type UserDatabaseRepo interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
GetUserByEmail(ctx context.Context, email string) (*domain.User, error)
GetAllUsers(ctx context.Context) ([]domain.User, error)
FindUsers(ctx context.Context, search string) ([]domain.User, error)
SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
}
type PeerDatabaseRepo interface {
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
}

View File

@ -11,7 +11,6 @@ import (
"github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
"github.com/google/uuid" "github.com/google/uuid"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal" "github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app"
@ -19,15 +18,46 @@ import (
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
// region dependencies
type UserDatabaseRepo interface {
// GetUser returns the user with the given identifier.
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
// GetUserByEmail returns the user with the given email address.
GetUserByEmail(ctx context.Context, email string) (*domain.User, error)
// GetAllUsers returns all users.
GetAllUsers(ctx context.Context) ([]domain.User, error)
// FindUsers returns all users matching the search string.
FindUsers(ctx context.Context, search string) ([]domain.User, error)
// SaveUser saves the user with the given identifier.
SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error
// DeleteUser deletes the user with the given identifier.
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
}
type PeerDatabaseRepo interface {
// GetUserPeers returns all peers linked to the given user.
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
}
type EventBus interface {
// Publish sends a message to the message bus.
Publish(topic string, args ...any)
}
// endregion dependencies
// Manager is the user manager.
type Manager struct { type Manager struct {
cfg *config.Config cfg *config.Config
bus evbus.MessageBus
bus EventBus
users UserDatabaseRepo users UserDatabaseRepo
peers PeerDatabaseRepo peers PeerDatabaseRepo
} }
func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabaseRepo, peers PeerDatabaseRepo) ( // NewUserManager creates a new user manager instance.
func NewUserManager(cfg *config.Config, bus EventBus, users UserDatabaseRepo, peers PeerDatabaseRepo) (
*Manager, *Manager,
error, error,
) { ) {
@ -41,6 +71,7 @@ func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabase
return m, nil return m, nil
} }
// RegisterUser registers a new user.
func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error { func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err return err
@ -56,6 +87,7 @@ func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
return nil return nil
} }
// NewUser creates a new user.
func (m Manager) NewUser(ctx context.Context, user *domain.User) error { func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
if user.Identifier == "" { if user.Identifier == "" {
return errors.New("missing user identifier") return errors.New("missing user identifier")
@ -90,12 +122,13 @@ func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
return nil return nil
} }
// StartBackgroundJobs starts the background jobs.
// This method is non-blocking and returns immediately.
func (m Manager) StartBackgroundJobs(ctx context.Context) { func (m Manager) StartBackgroundJobs(ctx context.Context) {
go m.runLdapSynchronizationService(ctx) go m.runLdapSynchronizationService(ctx)
} }
// GetUser returns the user with the given identifier.
func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) { func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
if err := domain.ValidateUserAccessRights(ctx, id); err != nil { if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
return nil, err return nil, err
@ -112,6 +145,7 @@ func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain
return user, nil return user, nil
} }
// GetUserByEmail returns the user with the given email address.
func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) { func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) {
user, err := m.users.GetUserByEmail(ctx, email) user, err := m.users.GetUserByEmail(ctx, email)
@ -130,6 +164,7 @@ func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User
return user, nil return user, nil
} }
// GetAllUsers returns all users.
func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) { func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err return nil, err
@ -162,6 +197,7 @@ func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
return users, nil return users, nil
} }
// UpdateUser updates the user with the given identifier.
func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) { func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil { if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil {
return nil, err return nil, err
@ -203,6 +239,7 @@ func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.Use
return user, nil return user, nil
} }
// CreateUser creates a new user.
func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) { func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err return nil, err
@ -236,6 +273,7 @@ func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.Use
return user, nil return user, nil
} }
// DeleteUser deletes the user with the given identifier.
func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error { func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err return err
@ -260,6 +298,7 @@ func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error
return nil return nil
} }
// ActivateApi activates the API access for the user with the given identifier.
func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) { func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
user, err := m.users.GetUser(ctx, id) user, err := m.users.GetUser(ctx, id)
if err != nil && !errors.Is(err, domain.ErrNotFound) { if err != nil && !errors.Is(err, domain.ErrNotFound) {
@ -287,6 +326,7 @@ func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*do
return user, nil return user, nil
} }
// DeactivateApi deactivates the API access for the user with the given identifier.
func (m Manager) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) { func (m Manager) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
user, err := m.users.GetUser(ctx, id) user, err := m.users.GetUser(ctx, id)
if err != nil && !errors.Is(err, domain.ErrNotFound) { if err != nil && !errors.Is(err, domain.ErrNotFound) {

View File

@ -1,87 +0,0 @@
package wireguard
import (
"context"
"github.com/h44z/wg-portal/internal/domain"
)
type InterfaceAndPeerDatabaseRepo interface {
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error)
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error)
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.Interface) (*domain.Interface, error),
) error
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error)
SavePeer(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.Peer) (*domain.Peer, error),
) error
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
}
type StatisticsDatabaseRepo interface {
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
UpdatePeerStatus(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
) error
UpdateInterfaceStatus(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
) error
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
}
type InterfaceController interface {
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
*domain.PhysicalPeer,
error,
)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}
type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(id domain.InterfaceIdentifier) error
}
type MetricsServer interface {
UpdateInterfaceMetrics(status domain.InterfaceStatus)
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
}

View File

@ -7,31 +7,63 @@ import (
"time" "time"
probing "github.com/prometheus-community/pro-bing" probing "github.com/prometheus-community/pro-bing"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
type StatisticsDatabaseRepo interface {
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
UpdatePeerStatus(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
) error
UpdateInterfaceStatus(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
) error
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
}
type StatisticsInterfaceController interface {
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
}
type StatisticsMetricsServer interface {
UpdateInterfaceMetrics(status domain.InterfaceStatus)
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
}
type StatisticsEventBus interface {
// Subscribe subscribes to a topic
Subscribe(topic string, fn interface{}) error
}
type StatisticsCollector struct { type StatisticsCollector struct {
cfg *config.Config cfg *config.Config
bus evbus.MessageBus bus StatisticsEventBus
pingWaitGroup sync.WaitGroup pingWaitGroup sync.WaitGroup
pingJobs chan domain.Peer pingJobs chan domain.Peer
db StatisticsDatabaseRepo db StatisticsDatabaseRepo
wg InterfaceController wg StatisticsInterfaceController
ms MetricsServer ms StatisticsMetricsServer
} }
// NewStatisticsCollector creates a new statistics collector.
func NewStatisticsCollector( func NewStatisticsCollector(
cfg *config.Config, cfg *config.Config,
bus evbus.MessageBus, bus StatisticsEventBus,
db StatisticsDatabaseRepo, db StatisticsDatabaseRepo,
wg InterfaceController, wg StatisticsInterfaceController,
ms MetricsServer, ms StatisticsMetricsServer,
) (*StatisticsCollector, error) { ) (*StatisticsCollector, error) {
c := &StatisticsCollector{ c := &StatisticsCollector{
cfg: cfg, cfg: cfg,
@ -47,6 +79,8 @@ func NewStatisticsCollector(
return c, nil return c, nil
} }
// StartBackgroundJobs starts the background jobs for the statistics collector.
// This method is non-blocking and returns immediately.
func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) { func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) {
c.startPingWorkers(ctx) c.startPingWorkers(ctx)
c.startInterfaceDataFetcher(ctx) c.startInterfaceDataFetcher(ctx)

View File

@ -5,17 +5,74 @@ import (
"log/slog" "log/slog"
"time" "time"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
type Manager struct { // region dependencies
cfg *config.Config
bus evbus.MessageBus
type InterfaceAndPeerDatabaseRepo interface {
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error)
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.Interface) (*domain.Interface, error),
) error
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
SavePeer(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.Peer) (*domain.Peer, error),
) error
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
}
type InterfaceController interface {
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}
type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(id domain.InterfaceIdentifier) error
}
type EventBus interface {
// Publish sends a message to the message bus.
Publish(topic string, args ...any)
// Subscribe subscribes to a topic
Subscribe(topic string, fn interface{}) error
}
// endregion dependencies
type Manager struct {
cfg *config.Config
bus EventBus
db InterfaceAndPeerDatabaseRepo db InterfaceAndPeerDatabaseRepo
wg InterfaceController wg InterfaceController
quick WgQuickController quick WgQuickController
@ -23,7 +80,7 @@ type Manager struct {
func NewWireGuardManager( func NewWireGuardManager(
cfg *config.Config, cfg *config.Config,
bus evbus.MessageBus, bus EventBus,
wg InterfaceController, wg InterfaceController,
quick WgQuickController, quick WgQuickController,
db InterfaceAndPeerDatabaseRepo, db InterfaceAndPeerDatabaseRepo,
@ -41,6 +98,8 @@ func NewWireGuardManager(
return m, nil return m, nil
} }
// StartBackgroundJobs starts background jobs like the expired peers check.
// This method is non-blocking.
func (m Manager) StartBackgroundJobs(ctx context.Context) { func (m Manager) StartBackgroundJobs(ctx context.Context) {
go m.runExpiredPeersCheck(ctx) go m.runExpiredPeersCheck(ctx)
} }

View File

@ -13,6 +13,8 @@ import (
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
// GetImportableInterfaces returns all physical interfaces that are available on the system.
// This function also returns interfaces that are already available in the database.
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) { func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err return nil, err
@ -26,6 +28,7 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
return physicalInterfaces, nil return physicalInterfaces, nil
} }
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) ( func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface, *domain.Interface,
[]domain.Peer, []domain.Peer,
@ -38,6 +41,7 @@ func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceId
return m.db.GetInterfaceAndPeers(ctx, id) return m.db.GetInterfaceAndPeers(ctx, id)
} }
// GetAllInterfaces returns all interfaces that are available in the database.
func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) { func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err return nil, err
@ -46,6 +50,7 @@ func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, erro
return m.db.GetAllInterfaces(ctx) return m.db.GetAllInterfaces(ctx)
} }
// GetAllInterfacesAndPeers returns all interfaces and their peers.
func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) { func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, nil, err return nil, nil, err
@ -97,6 +102,7 @@ func (m Manager) GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier)
return userInterfaces, nil return userInterfaces, nil
} }
// ImportNewInterfaces imports all new physical interfaces that are available on the system.
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) { func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return 0, err return 0, err
@ -148,6 +154,7 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
return imported, nil return imported, nil
} }
// ApplyPeerDefaults applies the interface defaults to all peers of the given interface.
func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error { func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err return err
@ -179,6 +186,8 @@ func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) er
return nil return nil
} }
// RestoreInterfaceState restores the state of all physical interfaces and their peers.
// The final state of the interfaces and peers will be the same as stored in the database.
func (m Manager) RestoreInterfaceState( func (m Manager) RestoreInterfaceState(
ctx context.Context, ctx context.Context,
updateDbOnError bool, updateDbOnError bool,
@ -296,6 +305,7 @@ func (m Manager) RestoreInterfaceState(
return nil return nil
} }
// PrepareInterface generates a new interface with fresh keys, ip addresses and a listen port.
func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) { func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err return nil, err
@ -376,6 +386,7 @@ func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error
return freshInterface, nil return freshInterface, nil
} }
// CreateInterface creates a new interface with the given configuration.
func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) { func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err return nil, err
@ -401,6 +412,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
return in, nil return in, nil
} }
// UpdateInterface updates the given interface with the new configuration.
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) { func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, nil, err return nil, nil, err
@ -423,6 +435,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
return in, existingPeers, nil return in, existingPeers, nil
} }
// DeleteInterface deletes the given interface.
func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err return err

View File

@ -11,6 +11,7 @@ import (
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
// CreateDefaultPeer creates a default peer for the given user on all server interfaces.
func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error { func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err return err
@ -55,6 +56,7 @@ func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdenti
return nil return nil
} }
// GetUserPeers returns all peers for the given user.
func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) { func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
if err := domain.ValidateUserAccessRights(ctx, id); err != nil { if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
return nil, err return nil, err
@ -63,6 +65,7 @@ func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]
return m.db.GetUserPeers(ctx, id) return m.db.GetUserPeers(ctx, id)
} }
// PreparePeer prepares a new peer for the given interface with fresh keys and ip addresses.
func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) { func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
if !m.cfg.Core.SelfProvisioningAllowed { if !m.cfg.Core.SelfProvisioningAllowed {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
@ -143,6 +146,7 @@ func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier)
return freshPeer, nil return freshPeer, nil
} }
// GetPeer returns the peer with the given identifier.
func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) { func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
peer, err := m.db.GetPeer(ctx, id) peer, err := m.db.GetPeer(ctx, id)
if err != nil { if err != nil {
@ -156,6 +160,7 @@ func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain
return peer, nil return peer, nil
} }
// CreatePeer creates a new peer.
func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) { func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
if !m.cfg.Core.SelfProvisioningAllowed { if !m.cfg.Core.SelfProvisioningAllowed {
if err := domain.ValidateAdminAccessRights(ctx); err != nil { if err := domain.ValidateAdminAccessRights(ctx); err != nil {
@ -201,6 +206,8 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
return peer, nil return peer, nil
} }
// CreateMultiplePeers creates multiple new peers for the given user identifiers.
// It calls PreparePeer for each user identifier in the request.
func (m Manager) CreateMultiplePeers( func (m Manager) CreateMultiplePeers(
ctx context.Context, ctx context.Context,
interfaceId domain.InterfaceIdentifier, interfaceId domain.InterfaceIdentifier,
@ -243,6 +250,7 @@ func (m Manager) CreateMultiplePeers(
return createdPeers, nil return createdPeers, nil
} }
// UpdatePeer updates the given peer.
func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) { func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier) existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
if err != nil { if err != nil {
@ -309,6 +317,7 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
return peer, nil return peer, nil
} }
// DeletePeer deletes the peer with the given identifier.
func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error { func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
peer, err := m.db.GetPeer(ctx, id) peer, err := m.db.GetPeer(ctx, id)
if err != nil { if err != nil {
@ -341,6 +350,7 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return nil return nil
} }
// GetPeerStats returns the status of the peer with the given identifier.
func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) { func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) {
_, peers, err := m.db.GetInterfaceAndPeers(ctx, id) _, peers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil { if err != nil {
@ -359,6 +369,7 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier
return m.db.GetPeersStats(ctx, peerIds...) return m.db.GetPeersStats(ctx, peerIds...)
} }
// GetUserPeerStats returns the status of all peers for the given user.
func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) { func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
if err := domain.ValidateUserAccessRights(ctx, id); err != nil { if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
return nil, err return nil, err

View File

@ -8,6 +8,7 @@ import (
"github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
) )
// Auth contains all authentication providers.
type Auth struct { type Auth struct {
// OpenIDConnect contains a list of OpenID Connect providers. // OpenIDConnect contains a list of OpenID Connect providers.
OpenIDConnect []OpenIDConnectProvider `yaml:"oidc"` OpenIDConnect []OpenIDConnectProvider `yaml:"oidc"`
@ -17,6 +18,7 @@ type Auth struct {
Ldap []LdapProvider `yaml:"ldap"` Ldap []LdapProvider `yaml:"ldap"`
} }
// BaseFields contains the basic fields that are used to map user information from the authentication providers.
type BaseFields struct { type BaseFields struct {
// UserIdentifier is the name of the field that contains the user identifier. // UserIdentifier is the name of the field that contains the user identifier.
UserIdentifier string `yaml:"user_identifier"` UserIdentifier string `yaml:"user_identifier"`
@ -32,6 +34,7 @@ type BaseFields struct {
Department string `yaml:"department"` Department string `yaml:"department"`
} }
// OauthFields contains extra fields that are used to map user information from OAuth providers.
type OauthFields struct { type OauthFields struct {
BaseFields `yaml:",inline"` BaseFields `yaml:",inline"`
// IsAdmin is the name of the field that contains the admin flag. // IsAdmin is the name of the field that contains the admin flag.
@ -107,12 +110,14 @@ func (o *OauthAdminMapping) GetAdminGroupRegex() *regexp.Regexp {
return o.adminGroupRegex return o.adminGroupRegex
} }
// LdapFields contains extra fields that are used to map user information from LDAP providers.
type LdapFields struct { type LdapFields struct {
BaseFields `yaml:",inline"` BaseFields `yaml:",inline"`
// GroupMembership is the name of the LDAP field that contains the groups to which the user belongs. // GroupMembership is the name of the LDAP field that contains the groups to which the user belongs.
GroupMembership string `yaml:"memberof"` GroupMembership string `yaml:"memberof"`
} }
// LdapProvider contains the configuration for the LDAP connection.
type LdapProvider struct { type LdapProvider struct {
// ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters. // ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters.
ProviderName string `yaml:"provider_name"` ProviderName string `yaml:"provider_name"`
@ -163,6 +168,7 @@ type LdapProvider struct {
LogUserInfo bool `yaml:"log_user_info"` LogUserInfo bool `yaml:"log_user_info"`
} }
// OpenIDConnectProvider contains the configuration for the OpenID Connect provider.
type OpenIDConnectProvider struct { type OpenIDConnectProvider struct {
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
ProviderName string `yaml:"provider_name"` ProviderName string `yaml:"provider_name"`
@ -196,6 +202,7 @@ type OpenIDConnectProvider struct {
LogUserInfo bool `yaml:"log_user_info"` LogUserInfo bool `yaml:"log_user_info"`
} }
// OAuthProvider contains the configuration for the OAuth provider.
type OAuthProvider struct { type OAuthProvider struct {
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
ProviderName string `yaml:"provider_name"` ProviderName string `yaml:"provider_name"`

View File

@ -10,6 +10,7 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
// Config is the main configuration struct.
type Config struct { type Config struct {
Core struct { Core struct {
// AdminUser defines the default administrator account that will be created // AdminUser defines the default administrator account that will be created
@ -179,6 +180,7 @@ func GetConfig() (*Config, error) {
return cfg, nil return cfg, nil
} }
// loadConfigFile loads the configuration from a YAML file into the given cfg struct.
func loadConfigFile(cfg any, filename string) error { func loadConfigFile(cfg any, filename string) error {
data, err := envsubst.ReadFile(filename) data, err := envsubst.ReadFile(filename)
if err != nil { if err != nil {

View File

@ -2,6 +2,8 @@ package config
import "time" import "time"
// SupportedDatabase is a type for the supported database types.
// Supported: mysql, mssql, postgres, sqlite
type SupportedDatabase string type SupportedDatabase string
const ( const (
@ -11,6 +13,7 @@ const (
DatabaseSQLite SupportedDatabase = "sqlite" DatabaseSQLite SupportedDatabase = "sqlite"
) )
// DatabaseConfig contains the configuration for the database connection.
type DatabaseConfig struct { type DatabaseConfig struct {
// Debug enables logging of all database statements // Debug enables logging of all database statements
Debug bool `yaml:"debug"` Debug bool `yaml:"debug"`

View File

@ -1,5 +1,7 @@
package config package config
// MailEncryption is the type of the SMTP encryption.
// Supported: none, tls, starttls
type MailEncryption string type MailEncryption string
const ( const (
@ -8,6 +10,8 @@ const (
MailEncryptionStartTLS MailEncryption = "starttls" MailEncryptionStartTLS MailEncryption = "starttls"
) )
// MailAuthType is the type of the SMTP authentication.
// Supported: plain, login, crammd5
type MailAuthType string type MailAuthType string
const ( const (
@ -16,6 +20,7 @@ const (
MailAuthCramMD5 MailAuthType = "crammd5" MailAuthCramMD5 MailAuthType = "crammd5"
) )
// MailConfig contains the configuration for the mail server which is used to send emails.
type MailConfig struct { type MailConfig struct {
// Host is the hostname or IP of the SMTP server // Host is the hostname or IP of the SMTP server
Host string `yaml:"host"` Host string `yaml:"host"`

View File

@ -1,5 +1,6 @@
package config package config
// WebConfig contains the configuration for the web server.
type WebConfig struct { type WebConfig struct {
// RequestLogging enables logging of all HTTP requests. // RequestLogging enables logging of all HTTP requests.
RequestLogging bool `yaml:"request_logging"` RequestLogging bool `yaml:"request_logging"`

View File

@ -1,11 +1,5 @@
package domain package domain
import (
"context"
"golang.org/x/oauth2"
)
type LoginProvider string type LoginProvider string
type LoginProviderInfo struct { type LoginProviderInfo struct {
@ -24,28 +18,3 @@ type AuthenticatorUserInfo struct {
Department string Department string
IsAdmin bool IsAdmin bool
} }
type AuthenticatorType string
const (
AuthenticatorTypeOAuth AuthenticatorType = "oauth"
AuthenticatorTypeOidc AuthenticatorType = "oidc"
)
type OauthAuthenticator interface {
GetName() string
GetType() AuthenticatorType
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]any, error)
ParseUserInfo(raw map[string]any) (*AuthenticatorUserInfo, error)
RegistrationEnabled() bool
}
type LdapAuthenticator interface {
GetName() string
PlaintextAuthentication(userId UserIdentifier, plainPassword string) error
GetUserInfo(ctx context.Context, username UserIdentifier) (map[string]any, error)
ParseUserInfo(raw map[string]any) (*AuthenticatorUserInfo, error)
RegistrationEnabled() bool
}

View File

@ -33,6 +33,7 @@ func (p KeyPair) GetPublicKey() wgtypes.Key {
type PreSharedKey string type PreSharedKey string
// NewFreshKeypair generates a new key pair.
func NewFreshKeypair() (KeyPair, error) { func NewFreshKeypair() (KeyPair, error) {
privateKey, err := wgtypes.GeneratePrivateKey() privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
@ -45,6 +46,7 @@ func NewFreshKeypair() (KeyPair, error) {
}, nil }, nil
} }
// NewPreSharedKey generates a new pre-shared key.
func NewPreSharedKey() (PreSharedKey, error) { func NewPreSharedKey() (PreSharedKey, error) {
preSharedKey, err := wgtypes.GenerateKey() preSharedKey, err := wgtypes.GenerateKey()
if err != nil { if err != nil {
@ -54,6 +56,8 @@ func NewPreSharedKey() (PreSharedKey, error) {
return PreSharedKey(preSharedKey.String()), nil return PreSharedKey(preSharedKey.String()), nil
} }
// PublicKeyFromPrivateKey returns the public key for a given private key.
// If the private key is invalid, an empty string is returned.
func PublicKeyFromPrivateKey(key string) string { func PublicKeyFromPrivateKey(key string) string {
privKey, err := wgtypes.ParseKey(key) privKey, err := wgtypes.ParseKey(key)
if err != nil { if err != nil {

View File

@ -0,0 +1,56 @@
package domain
import (
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func TestKeyPair_GetPrivateKeyBytesReturnsCorrectBytes(t *testing.T) {
keyPair := KeyPair{PrivateKey: base64.StdEncoding.EncodeToString([]byte("privateKey"))}
expected := []byte("privateKey")
assert.Equal(t, expected, keyPair.GetPrivateKeyBytes())
}
func TestKeyPair_GetPublicKeyBytesReturnsCorrectBytes(t *testing.T) {
keyPair := KeyPair{PublicKey: base64.StdEncoding.EncodeToString([]byte("publicKey"))}
expected := []byte("publicKey")
assert.Equal(t, expected, keyPair.GetPublicKeyBytes())
}
func TestKeyPair_GetPrivateKeyReturnsCorrectKey(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
keyPair := KeyPair{PrivateKey: privateKey.String()}
assert.Equal(t, privateKey, keyPair.GetPrivateKey())
}
func TestKeyPair_GetPublicKeyReturnsCorrectKey(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
keyPair := KeyPair{PublicKey: privateKey.PublicKey().String()}
assert.Equal(t, privateKey.PublicKey(), keyPair.GetPublicKey())
}
func TestNewFreshKeypairGeneratesValidKeypair(t *testing.T) {
keyPair, err := NewFreshKeypair()
assert.NoError(t, err)
assert.NotEmpty(t, keyPair.PrivateKey)
assert.NotEmpty(t, keyPair.PublicKey)
}
func TestNewPreSharedKeyGeneratesValidKey(t *testing.T) {
preSharedKey, err := NewPreSharedKey()
assert.NoError(t, err)
assert.NotEmpty(t, preSharedKey)
}
func TestPublicKeyFromPrivateKeyReturnsCorrectPublicKey(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
expected := privateKey.PublicKey().String()
assert.Equal(t, expected, PublicKeyFromPrivateKey(privateKey.String()))
}
func TestPublicKeyFromPrivateKeyReturnsEmptyStringOnInvalidKey(t *testing.T) {
assert.Equal(t, "", PublicKeyFromPrivateKey("invalidKey"))
}

View File

@ -0,0 +1,83 @@
package domain
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) {
iface := &Interface{}
assert.False(t, iface.IsDisabled())
now := time.Now()
iface.Disabled = &now
assert.True(t, iface.IsDisabled())
}
func TestInterface_AddressStrReturnsCorrectString(t *testing.T) {
iface := &Interface{
Addresses: []Cidr{
{Cidr: "192.168.1.1/24", Addr: "192.168.1.1", NetLength: 24},
{Cidr: "10.0.0.1/24", Addr: "10.0.0.1", NetLength: 24},
},
}
expected := "192.168.1.1/24,10.0.0.1/24"
assert.Equal(t, expected, iface.AddressStr())
}
func TestInterface_GetConfigFileNameReturnsCorrectFileName(t *testing.T) {
iface := &Interface{Identifier: "wg0"}
expected := "wg0.conf"
assert.Equal(t, expected, iface.GetConfigFileName())
iface.Identifier = "wg0@123"
expected = "wg0123.conf"
assert.Equal(t, expected, iface.GetConfigFileName())
}
func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
peer1 := Peer{
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
},
},
}
peer2 := Peer{
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
},
},
}
iface := &Interface{}
expected := []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
}
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
}
func TestInterface_ManageRoutingTableReturnsCorrectValue(t *testing.T) {
iface := &Interface{RoutingTable: "off"}
assert.False(t, iface.ManageRoutingTable())
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
}
func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
iface := &Interface{RoutingTable: ""}
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "off"
assert.Equal(t, -1, iface.GetRoutingTable())
iface.RoutingTable = "0x64"
assert.Equal(t, 100, iface.GetRoutingTable())
iface.RoutingTable = "200"
assert.Equal(t, 200, iface.GetRoutingTable())
}

View File

@ -0,0 +1,42 @@
package domain
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfigOption_GetValueReturnsCorrectValue(t *testing.T) {
option := ConfigOption[int]{Value: 42}
assert.Equal(t, 42, option.GetValue())
}
func TestConfigOption_SetValueUpdatesValue(t *testing.T) {
option := ConfigOption[int]{Value: 42}
option.SetValue(100)
assert.Equal(t, 100, option.GetValue())
}
func TestConfigOption_TrySetValueUpdatesValueWhenOverridable(t *testing.T) {
option := ConfigOption[int]{Value: 42, Overridable: true}
result := option.TrySetValue(100)
assert.True(t, result)
assert.Equal(t, 100, option.GetValue())
}
func TestConfigOption_TrySetValueDoesNotUpdateValueWhenNotOverridable(t *testing.T) {
option := ConfigOption[int]{Value: 42, Overridable: false}
result := option.TrySetValue(100)
assert.False(t, result)
assert.Equal(t, 42, option.GetValue())
}
func TestNewConfigOptionCreatesCorrectOption(t *testing.T) {
option := NewConfigOption(42, true)
assert.Equal(t, 42, option.GetValue())
assert.True(t, option.Overridable)
option2 := NewConfigOption("str", false)
assert.Equal(t, "str", option2.GetValue())
assert.False(t, option2.Overridable)
}

View File

@ -0,0 +1,165 @@
package domain
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestPeer_IsDisabled(t *testing.T) {
peer := &Peer{}
assert.False(t, peer.IsDisabled())
now := time.Now()
peer.Disabled = &now
assert.True(t, peer.IsDisabled())
}
func TestPeer_IsExpired(t *testing.T) {
peer := &Peer{}
assert.False(t, peer.IsExpired())
expiredTime := time.Now().Add(-time.Hour)
peer.ExpiresAt = &expiredTime
assert.True(t, peer.IsExpired())
futureTime := time.Now().Add(time.Hour)
peer.ExpiresAt = &futureTime
assert.False(t, peer.IsExpired())
}
func TestPeer_CheckAliveAddress(t *testing.T) {
peer := &Peer{}
assert.Equal(t, "", peer.CheckAliveAddress())
peer.Interface.CheckAliveAddress = "192.168.1.1"
assert.Equal(t, "192.168.1.1", peer.CheckAliveAddress())
peer.Interface.CheckAliveAddress = ""
peer.Interface.Addresses = []Cidr{{Addr: "10.0.0.1"}}
assert.Equal(t, "10.0.0.1", peer.CheckAliveAddress())
}
func TestPeer_GetConfigFileName(t *testing.T) {
peer := &Peer{DisplayName: "Test Peer"}
expected := "Test_Peer.conf"
assert.Equal(t, expected, peer.GetConfigFileName())
peer.DisplayName = ""
peer.Identifier = "12345678"
expected = "wg_12345678.conf"
assert.Equal(t, expected, peer.GetConfigFileName())
}
func TestPeer_ApplyInterfaceDefaults(t *testing.T) {
peer := &Peer{
Endpoint: ConfigOption[string]{
Value: "",
Overridable: true,
},
EndpointPublicKey: ConfigOption[string]{
Value: "",
Overridable: true,
},
AllowedIPsStr: ConfigOption[string]{
Value: "1.1.1.1/32",
Overridable: false,
},
}
iface := &Interface{
PeerDefEndpoint: "192.168.1.1",
KeyPair: KeyPair{
PublicKey: "publicKey",
},
PeerDefAllowedIPsStr: "8.8.8.8/32",
}
peer.ApplyInterfaceDefaults(iface)
assert.Equal(t, "192.168.1.1", peer.Endpoint.GetValue())
assert.Equal(t, "publicKey", peer.EndpointPublicKey.GetValue())
assert.Equal(t, "1.1.1.1/32", peer.AllowedIPsStr.GetValue())
}
func TestPeer_GenerateDisplayName(t *testing.T) {
peer := &Peer{Identifier: "12345678"}
peer.GenerateDisplayName("Prefix")
expected := "Prefix Peer 12345678"
assert.Equal(t, expected, peer.DisplayName)
peer.GenerateDisplayName("")
expected = "Peer 12345678"
assert.Equal(t, expected, peer.DisplayName)
}
func TestPeer_OverwriteUserEditableFields(t *testing.T) {
peer := &Peer{}
userPeer := &Peer{
DisplayName: "New DisplayName",
}
peer.OverwriteUserEditableFields(userPeer)
assert.Equal(t, "New DisplayName", peer.DisplayName)
}
func TestPeer_GetPresharedKey(t *testing.T) {
physicalPeer := PhysicalPeer{}
assert.Nil(t, physicalPeer.GetPresharedKey())
physicalPeer.PresharedKey = "Q0evIJTOjhyy2o5J7whvrsvQC+FRL8A74vrw44YHUAk="
key := physicalPeer.GetPresharedKey()
assert.NotNil(t, key)
}
func TestPeer_GetEndpointAddress(t *testing.T) {
physicalPeer := PhysicalPeer{}
assert.Nil(t, physicalPeer.GetEndpointAddress())
physicalPeer.Endpoint = "192.168.1.1:51820"
addr := physicalPeer.GetEndpointAddress()
assert.NotNil(t, addr)
assert.Equal(t, "192.168.1.1:51820", addr.String())
}
func TestPeer_GetPersistentKeepaliveTime(t *testing.T) {
physicalPeer := PhysicalPeer{}
assert.Nil(t, physicalPeer.GetPersistentKeepaliveTime())
physicalPeer.PersistentKeepalive = 25
duration := physicalPeer.GetPersistentKeepaliveTime()
assert.NotNil(t, duration)
assert.Equal(t, 25*time.Second, *duration)
}
func TestPeer_GetAllowedIPs(t *testing.T) {
physicalPeer := PhysicalPeer{}
assert.Empty(t, physicalPeer.GetAllowedIPs())
physicalPeer.AllowedIPs = []Cidr{
{
Cidr: "192.168.1.0/24",
Addr: "192.168.1.0",
NetLength: 24,
},
}
ips := physicalPeer.GetAllowedIPs()
assert.Len(t, ips, 1)
assert.Equal(t, "192.168.1.0/24", ips[0].String())
physicalPeer.AllowedIPs = []Cidr{
{
Cidr: "192.168.1.0/24",
Addr: "192.168.1.0",
NetLength: 24,
},
{
Cidr: "fe80::/64",
Addr: "fe80::",
NetLength: 64,
},
}
ips2 := physicalPeer.GetAllowedIPs()
assert.Len(t, ips2, 2)
assert.Equal(t, "192.168.1.0/24", ips2[0].String())
assert.Equal(t, "fe80::/64", ips2[1].String())
}

View File

@ -0,0 +1,74 @@
package domain
import (
"testing"
"time"
)
func TestPeerStatus_IsConnected(t *testing.T) {
now := time.Now()
past := now.Add(-3 * time.Minute)
recent := now.Add(-1 * time.Minute)
tests := []struct {
name string
status PeerStatus
want bool
}{
{
name: "Pingable and recent handshake",
status: PeerStatus{
IsPingable: true,
LastHandshake: &recent,
},
want: true,
},
{
name: "Not pingable but recent handshake",
status: PeerStatus{
IsPingable: false,
LastHandshake: &recent,
},
want: true,
},
{
name: "Pingable but old handshake",
status: PeerStatus{
IsPingable: true,
LastHandshake: &past,
},
want: true,
},
{
name: "Not pingable and old handshake",
status: PeerStatus{
IsPingable: false,
LastHandshake: &past,
},
want: false,
},
{
name: "Pingable and no handshake",
status: PeerStatus{
IsPingable: true,
LastHandshake: nil,
},
want: true,
},
{
name: "Not pingable and no handshake",
status: PeerStatus{
IsPingable: false,
LastHandshake: nil,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.status.IsConnected(); got != tt.want {
t.Errorf("IsConnected() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,125 @@
package domain
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
)
func TestUser_IsDisabled(t *testing.T) {
user := &User{}
assert.False(t, user.IsDisabled())
now := time.Now()
user.Disabled = &now
assert.True(t, user.IsDisabled())
}
func TestUser_IsLocked(t *testing.T) {
user := &User{}
assert.False(t, user.IsLocked())
now := time.Now()
user.Locked = &now
assert.True(t, user.IsLocked())
}
func TestUser_IsApiEnabled(t *testing.T) {
user := &User{}
assert.False(t, user.IsApiEnabled())
user.ApiToken = "token"
assert.True(t, user.IsApiEnabled())
}
func TestUser_CanChangePassword(t *testing.T) {
user := &User{Source: UserSourceDatabase}
assert.NoError(t, user.CanChangePassword())
user.Source = UserSourceLdap
assert.Error(t, user.CanChangePassword())
user.Source = UserSourceOauth
assert.Error(t, user.CanChangePassword())
}
func TestUser_EditAllowed(t *testing.T) {
user := &User{Source: UserSourceDatabase}
newUser := &User{Source: UserSourceDatabase}
assert.NoError(t, user.EditAllowed(newUser))
newUser.Notes = "notes can be changed"
assert.NoError(t, user.EditAllowed(newUser))
newUser.Disabled = &time.Time{}
assert.NoError(t, user.EditAllowed(newUser))
newUser.Lastname = "lastname or other fields can be changed"
assert.NoError(t, user.EditAllowed(newUser))
user.Source = UserSourceLdap
newUser.Source = UserSourceLdap
newUser.Disabled = nil
newUser.Lastname = ""
newUser.Notes = "notes can be changed"
assert.NoError(t, user.EditAllowed(newUser))
newUser.Disabled = &time.Time{}
assert.NoError(t, user.EditAllowed(newUser))
newUser.Lastname = "lastname or other fields can not be changed"
assert.Error(t, user.EditAllowed(newUser))
user.Source = UserSourceOauth
newUser.Source = UserSourceOauth
newUser.Disabled = nil
newUser.Lastname = ""
newUser.Notes = "notes can be changed"
assert.NoError(t, user.EditAllowed(newUser))
newUser.Disabled = &time.Time{}
assert.NoError(t, user.EditAllowed(newUser))
newUser.Lastname = "lastname or other fields can not be changed"
assert.Error(t, user.EditAllowed(newUser))
}
func TestUser_DeleteAllowed(t *testing.T) {
user := &User{}
assert.NoError(t, user.DeleteAllowed())
}
func TestUser_CheckPassword(t *testing.T) {
password := "password"
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
user := &User{Source: UserSourceDatabase, Password: PrivateString(hashedPassword)}
assert.NoError(t, user.CheckPassword(password))
user.Password = ""
assert.Error(t, user.CheckPassword(password))
user.Source = UserSourceLdap
assert.Error(t, user.CheckPassword(password))
}
func TestUser_CheckApiToken(t *testing.T) {
user := &User{}
assert.Error(t, user.CheckApiToken("token"))
user.ApiToken = "token"
assert.NoError(t, user.CheckApiToken("token"))
assert.Error(t, user.CheckApiToken("wrong_token"))
}
func TestUser_HashPassword(t *testing.T) {
user := &User{Password: "password"}
assert.NoError(t, user.HashPassword())
assert.NotEmpty(t, user.Password)
user.Password = ""
assert.NoError(t, user.HashPassword())
}