chore: use interfaces for all other services

This commit is contained in:
Christoph Haas 2025-03-23 23:09:47 +01:00
parent 02ed7b19df
commit 7d0da4e7ad
40 changed files with 1337 additions and 406 deletions

View File

@ -14,6 +14,7 @@ import (
"github.com/h44z/wg-portal/internal/adapters"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/api/core"
backendV0 "github.com/h44z/wg-portal/internal/app/api/v0/backend"
handlersV0 "github.com/h44z/wg-portal/internal/app/api/v0/handlers"
backendV1 "github.com/h44z/wg-portal/internal/app/api/v1/backend"
handlersV1 "github.com/h44z/wg-portal/internal/app/api/v1/handlers"
@ -70,17 +71,24 @@ func main() {
queueSize := 100
eventBus := evbus.New(queueSize)
auditRecorder, err := audit.NewAuditRecorder(cfg, eventBus, database)
internal.AssertNoError(err)
auditRecorder.StartBackgroundJobs(ctx)
userManager, err := users.NewUserManager(cfg, eventBus, database, database)
internal.AssertNoError(err)
userManager.StartBackgroundJobs(ctx)
authenticator, err := auth.NewAuthenticator(&cfg.Auth, cfg.Web.ExternalUrl, eventBus, userManager)
internal.AssertNoError(err)
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database)
internal.AssertNoError(err)
wireGuardManager.StartBackgroundJobs(ctx)
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, eventBus, database, wireGuard, metricsServer)
internal.AssertNoError(err)
statisticsCollector.StartBackgroundJobs(ctx)
cfgFileManager, err := configfile.NewConfigFileManager(cfg, eventBus, database, database, cfgFileSystem)
internal.AssertNoError(err)
@ -88,18 +96,11 @@ func main() {
mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database)
internal.AssertNoError(err)
auditRecorder, err := audit.NewAuditRecorder(cfg, eventBus, database)
internal.AssertNoError(err)
auditRecorder.StartBackgroundJobs(ctx)
routeManager, err := route.NewRouteManager(cfg, eventBus, database)
internal.AssertNoError(err)
routeManager.StartBackgroundJobs(ctx)
backend, err := app.New(cfg, eventBus, authenticator, userManager, wireGuardManager,
statisticsCollector, cfgFileManager, mailManager)
internal.AssertNoError(err)
err = backend.Startup(ctx)
err = app.Initialize(cfg, wireGuardManager, userManager)
internal.AssertNoError(err)
validatorManager := validator.New()
@ -109,10 +110,14 @@ func main() {
apiV0Session := handlersV0.NewSessionWrapper(cfg)
apiV0Auth := handlersV0.NewAuthenticationHandler(authenticator, apiV0Session)
apiV0EndpointAuth := handlersV0.NewAuthEndpoint(cfg, apiV0Auth, apiV0Session, validatorManager, backend)
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, backend)
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, backend)
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, backend)
apiV0BackendUsers := backendV0.NewUserService(cfg, userManager, wireGuardManager)
apiV0BackendInterfaces := backendV0.NewInterfaceService(cfg, wireGuardManager, cfgFileManager)
apiV0BackendPeers := backendV0.NewPeerService(cfg, wireGuardManager, cfgFileManager, mailManager)
apiV0EndpointAuth := handlersV0.NewAuthEndpoint(cfg, apiV0Auth, apiV0Session, validatorManager, authenticator)
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers)
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces)
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth)
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)

View File

@ -0,0 +1,91 @@
package backend
import (
"context"
"io"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// region dependencies
type InterfaceServiceInterfaceManager interface {
GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error)
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error)
UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error)
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
PrepareInterface(ctx context.Context) (*domain.Interface, error)
ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error
}
type InterfaceServiceConfigFileManager interface {
PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
}
// endregion dependencies
type InterfaceService struct {
cfg *config.Config
interfaces InterfaceServiceInterfaceManager
configFile InterfaceServiceConfigFileManager
}
func NewInterfaceService(
cfg *config.Config,
interfaces InterfaceServiceInterfaceManager,
configFile InterfaceServiceConfigFileManager,
) *InterfaceService {
return &InterfaceService{
cfg: cfg,
interfaces: interfaces,
configFile: configFile,
}
}
func (i InterfaceService) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
[]domain.Peer,
error,
) {
return i.interfaces.GetInterfaceAndPeers(ctx, id)
}
func (i InterfaceService) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
return i.interfaces.PrepareInterface(ctx)
}
func (i InterfaceService) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
return i.interfaces.CreateInterface(ctx, in)
}
func (i InterfaceService) UpdateInterface(ctx context.Context, in *domain.Interface) (
*domain.Interface,
[]domain.Peer,
error,
) {
return i.interfaces.UpdateInterface(ctx, in)
}
func (i InterfaceService) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
return i.interfaces.DeleteInterface(ctx, id)
}
func (i InterfaceService) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
return i.interfaces.GetAllInterfacesAndPeers(ctx)
}
func (i InterfaceService) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
return i.configFile.GetInterfaceConfig(ctx, id)
}
func (i InterfaceService) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
return i.configFile.PersistInterfaceConfig(ctx, id)
}
func (i InterfaceService) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
return i.interfaces.ApplyPeerDefaults(ctx, in)
}

View File

@ -0,0 +1,112 @@
package backend
import (
"context"
"io"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// region dependencies
type PeerServicePeerManager interface {
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error)
CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error)
UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error)
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
CreateMultiplePeers(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
r *domain.PeerCreationRequest,
) ([]domain.Peer, error)
GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error)
}
type PeerServiceConfigFileManager interface {
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
}
type PeerServiceMailManager interface {
SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error
}
// endregion dependencies
type PeerService struct {
cfg *config.Config
peers PeerServicePeerManager
configFile PeerServiceConfigFileManager
mailer PeerServiceMailManager
}
func NewPeerService(
cfg *config.Config,
peers PeerServicePeerManager,
configFile PeerServiceConfigFileManager,
mailer PeerServiceMailManager,
) *PeerService {
return &PeerService{
cfg: cfg,
peers: peers,
configFile: configFile,
mailer: mailer,
}
}
func (p PeerService) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
[]domain.Peer,
error,
) {
return p.peers.GetInterfaceAndPeers(ctx, id)
}
func (p PeerService) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
return p.peers.PreparePeer(ctx, id)
}
func (p PeerService) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
return p.peers.GetPeer(ctx, id)
}
func (p PeerService) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
return p.peers.CreatePeer(ctx, peer)
}
func (p PeerService) CreateMultiplePeers(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
r *domain.PeerCreationRequest,
) ([]domain.Peer, error) {
return p.peers.CreateMultiplePeers(ctx, interfaceId, r)
}
func (p PeerService) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
return p.peers.UpdatePeer(ctx, peer)
}
func (p PeerService) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
return p.peers.DeletePeer(ctx, id)
}
func (p PeerService) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
return p.configFile.GetPeerConfig(ctx, id)
}
func (p PeerService) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
return p.configFile.GetPeerConfigQrCode(ctx, id)
}
func (p PeerService) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error {
return p.mailer.SendPeerEmail(ctx, linkOnly, peers...)
}
func (p PeerService) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) {
return p.peers.GetPeerStats(ctx, id)
}

View File

@ -0,0 +1,83 @@
package backend
import (
"context"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// region dependencies
type UserServiceUserManager interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
GetAllUsers(ctx context.Context) ([]domain.User, error)
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type UserServiceWireGuardManager interface {
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier) ([]domain.Interface, error)
GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error)
}
// endregion dependencies
type UserService struct {
cfg *config.Config
users UserServiceUserManager
wg UserServiceWireGuardManager
}
func NewUserService(cfg *config.Config, users UserServiceUserManager, wg UserServiceWireGuardManager) *UserService {
return &UserService{
cfg: cfg,
users: users,
wg: wg,
}
}
func (u UserService) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
return u.users.GetUser(ctx, id)
}
func (u UserService) GetAllUsers(ctx context.Context) ([]domain.User, error) {
return u.users.GetAllUsers(ctx)
}
func (u UserService) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
return u.users.UpdateUser(ctx, user)
}
func (u UserService) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
return u.users.CreateUser(ctx, user)
}
func (u UserService) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
return u.users.DeleteUser(ctx, id)
}
func (u UserService) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
return u.users.ActivateApi(ctx, id)
}
func (u UserService) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
return u.users.DeactivateApi(ctx, id)
}
func (u UserService) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
return u.wg.GetUserPeers(ctx, id)
}
func (u UserService) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
return u.wg.GetUserPeerStats(ctx, id)
}
func (u UserService) GetUserInterfaces(ctx context.Context, id domain.UserIdentifier) ([]domain.Interface, error) {
return u.wg.GetUserInterfaces(ctx, id)
}

View File

@ -7,46 +7,43 @@ import (
"log/slog"
"time"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
type App struct {
Config *config.Config
bus evbus.MessageBus
// region dependencies
Authenticator
UserManager
WireGuardManager
StatisticsCollector
ConfigFileManager
MailManager
ApiV1Manager
type WireGuardManager interface {
ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error)
RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error
}
func New(
type UserManager interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
}
// endregion dependencies
// App is the main application struct.
type App struct {
cfg *config.Config
wg WireGuardManager
users UserManager
}
// Initialize creates a new App instance and initializes it.
func Initialize(
cfg *config.Config,
bus evbus.MessageBus,
authenticator Authenticator,
wg WireGuardManager,
users UserManager,
wireGuard WireGuardManager,
stats StatisticsCollector,
cfgFiles ConfigFileManager,
mailer MailManager,
) (*App, error) {
) error {
a := &App{
Config: cfg,
bus: bus,
cfg: cfg,
Authenticator: authenticator,
UserManager: users,
WireGuardManager: wireGuard,
StatisticsCollector: stats,
ConfigFileManager: cfgFiles,
MailManager: mailer,
wg: wg,
users: users,
}
startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second)
@ -56,36 +53,27 @@ func New(
startupContext = domain.SetUserInfo(startupContext, domain.SystemAdminContextUserInfo())
if err := a.createDefaultUser(startupContext); err != nil {
return nil, fmt.Errorf("failed to create default user: %w", err)
return fmt.Errorf("failed to create default user: %w", err)
}
if err := a.importNewInterfaces(startupContext); err != nil {
return nil, fmt.Errorf("failed to import new interfaces: %w", err)
return fmt.Errorf("failed to import new interfaces: %w", err)
}
if err := a.restoreInterfaceState(startupContext); err != nil {
return nil, fmt.Errorf("failed to restore interface state: %w", err)
return fmt.Errorf("failed to restore interface state: %w", err)
}
return a, nil
}
func (a *App) Startup(ctx context.Context) error {
a.UserManager.StartBackgroundJobs(ctx)
a.StatisticsCollector.StartBackgroundJobs(ctx)
a.WireGuardManager.StartBackgroundJobs(ctx)
return nil
}
func (a *App) importNewInterfaces(ctx context.Context) error {
if !a.Config.Core.ImportExisting {
if !a.cfg.Core.ImportExisting {
slog.Debug("skipping interface import - feature disabled")
return nil // feature disabled
}
importedCount, err := a.ImportNewInterfaces(ctx)
importedCount, err := a.wg.ImportNewInterfaces(ctx)
if err != nil {
return err
}
@ -97,12 +85,12 @@ func (a *App) importNewInterfaces(ctx context.Context) error {
}
func (a *App) restoreInterfaceState(ctx context.Context) error {
if !a.Config.Core.RestoreState {
if !a.cfg.Core.RestoreState {
slog.Debug("skipping interface state restore - feature disabled")
return nil // feature disabled
}
err := a.RestoreInterfaceState(ctx, true)
err := a.wg.RestoreInterfaceState(ctx, true)
if err != nil {
return err
}
@ -112,13 +100,13 @@ func (a *App) restoreInterfaceState(ctx context.Context) error {
}
func (a *App) createDefaultUser(ctx context.Context) error {
adminUserId := domain.UserIdentifier(a.Config.Core.AdminUser)
adminUserId := domain.UserIdentifier(a.cfg.Core.AdminUser)
if adminUserId == "" {
slog.Debug("skipping default user creation - admin user is blank")
return nil // empty admin user - do not create
}
_, err := a.GetUser(ctx, adminUserId)
_, err := a.users.GetUser(ctx, adminUserId)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return err
}
@ -145,22 +133,22 @@ func (a *App) createDefaultUser(ctx context.Context) error {
Phone: "",
Department: "",
Notes: "default administrator user",
Password: domain.PrivateString(a.Config.Core.AdminPassword),
Password: domain.PrivateString(a.cfg.Core.AdminPassword),
Disabled: nil,
DisabledReason: "",
Locked: nil,
LockedReason: "",
LinkedPeerCount: 0,
}
if a.Config.Core.AdminApiToken != "" {
if len(a.Config.Core.AdminApiToken) < 18 {
if a.cfg.Core.AdminApiToken != "" {
if len(a.cfg.Core.AdminApiToken) < 18 {
slog.Warn("admin API token is too short, should be at least 18 characters long")
}
defaultAdmin.ApiToken = a.Config.Core.AdminApiToken
defaultAdmin.ApiToken = a.cfg.Core.AdminApiToken
defaultAdmin.ApiTokenCreated = &now
}
admin, err := a.CreateUser(ctx, defaultAdmin)
admin, err := a.users.CreateUser(ctx, defaultAdmin)
if err != nil {
return err
}

View File

@ -6,21 +6,35 @@ import (
"log/slog"
"time"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// region dependencies
type DatabaseRepo interface {
// SaveAuditEntry saves an audit entry to the database
SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error
}
type EventBus interface {
// Subscribe subscribes to a topic
Subscribe(topic string, fn interface{}) error
}
// endregion dependencies
// Recorder is responsible for recording audit events to the database.
type Recorder struct {
cfg *config.Config
bus evbus.MessageBus
bus EventBus
db DatabaseRepo
}
func NewAuditRecorder(cfg *config.Config, bus evbus.MessageBus, db DatabaseRepo) (*Recorder, error) {
// NewAuditRecorder creates a new audit recorder instance.
func NewAuditRecorder(cfg *config.Config, bus EventBus, db DatabaseRepo) (*Recorder, error) {
r := &Recorder{
cfg: cfg,
bus: bus,
@ -36,6 +50,8 @@ func NewAuditRecorder(cfg *config.Config, bus evbus.MessageBus, db DatabaseRepo)
return r, nil
}
// StartBackgroundJobs starts background jobs for the audit recorder.
// This method is non-blocking and returns immediately.
func (r *Recorder) StartBackgroundJobs(ctx context.Context) {
if !r.cfg.Statistics.CollectAuditData {
return // noting to do

View File

@ -1,11 +0,0 @@
package audit
import (
"context"
"github.com/h44z/wg-portal/internal/domain"
)
type DatabaseRepo interface {
SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error
}

View File

@ -14,25 +14,78 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
evbus "github.com/vardius/message-bus"
"golang.org/x/oauth2"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// region dependencies
type UserManager interface {
// GetUser returns a user by its identifier.
GetUser(context.Context, domain.UserIdentifier) (*domain.User, error)
// RegisterUser creates a new user in the database.
RegisterUser(ctx context.Context, user *domain.User) error
// UpdateUser updates an existing user in the database.
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
}
type EventBus interface {
// Publish sends a message to the message bus.
Publish(topic string, args ...any)
}
// endregion dependencies
type AuthenticatorType string
const (
AuthenticatorTypeOAuth AuthenticatorType = "oauth"
AuthenticatorTypeOidc AuthenticatorType = "oidc"
)
// AuthenticatorOauth is the interface for all OAuth authenticators.
type AuthenticatorOauth interface {
// GetName returns the name of the authenticator.
GetName() string
// GetType returns the type of the authenticator. It can be either AuthenticatorTypeOAuth or AuthenticatorTypeOidc.
GetType() AuthenticatorType
// AuthCodeURL returns the URL for the authentication flow.
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
// Exchange exchanges the OAuth code for an access token.
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
// GetUserInfo fetches the user information from the OAuth or OIDC provider.
GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]any, error)
// ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct.
ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error)
// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator.
RegistrationEnabled() bool
}
// AuthenticatorLdap is the interface for all LDAP authenticators.
type AuthenticatorLdap interface {
// GetName returns the name of the authenticator.
GetName() string
// PlaintextAuthentication performs a plaintext authentication against the LDAP server.
PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error
// GetUserInfo fetches the user information from the LDAP server.
GetUserInfo(ctx context.Context, username domain.UserIdentifier) (map[string]any, error)
// ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct.
ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error)
// RegistrationEnabled returns whether registration is enabled for the LDAP authenticator.
RegistrationEnabled() bool
}
// Authenticator is the main entry point for all authentication related tasks.
// This includes password authentication and external authentication providers (OIDC, OAuth, LDAP).
type Authenticator struct {
cfg *config.Auth
bus evbus.MessageBus
bus EventBus
oauthAuthenticators map[string]domain.OauthAuthenticator
ldapAuthenticators map[string]domain.LdapAuthenticator
oauthAuthenticators map[string]AuthenticatorOauth
ldapAuthenticators map[string]AuthenticatorLdap
// URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix
callbackUrlPrefix string
@ -40,7 +93,8 @@ type Authenticator struct {
users UserManager
}
func NewAuthenticator(cfg *config.Auth, extUrl string, bus evbus.MessageBus, users UserManager) (
// NewAuthenticator creates a new Authenticator instance.
func NewAuthenticator(cfg *config.Auth, extUrl string, bus EventBus, users UserManager) (
*Authenticator,
error,
) {
@ -68,8 +122,8 @@ func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
return fmt.Errorf("failed to parse external url: %w", err)
}
a.oauthAuthenticators = make(map[string]domain.OauthAuthenticator, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth))
a.ldapAuthenticators = make(map[string]domain.LdapAuthenticator, len(a.cfg.Ldap))
a.oauthAuthenticators = make(map[string]AuthenticatorOauth, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth))
a.ldapAuthenticators = make(map[string]AuthenticatorLdap, len(a.cfg.Ldap))
for i := range a.cfg.OpenIDConnect { // OIDC
providerCfg := &a.cfg.OpenIDConnect[i]
@ -123,6 +177,7 @@ func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
return nil
}
// GetExternalLoginProviders returns a list of all available external login providers.
func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo {
authProviders := make([]domain.LoginProviderInfo, 0, len(a.cfg.OAuth)+len(a.cfg.OpenIDConnect))
@ -157,6 +212,7 @@ func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.Lo
return authProviders
}
// IsUserValid checks if a user is valid and not locked or disabled.
func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool {
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context
user, err := a.users.GetUser(ctx, id)
@ -177,6 +233,8 @@ func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifie
// region password authentication
// PlainLogin performs a password authentication for a user. The username and password are trimmed before usage.
// If the login is successful, the user is returned, otherwise an error.
func (a *Authenticator) PlainLogin(ctx context.Context, username, password string) (*domain.User, error) {
// Validate form input
username = strings.TrimSpace(username)
@ -204,7 +262,7 @@ func (a *Authenticator) passwordAuthentication(
domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
var ldapUserInfo *domain.AuthenticatorUserInfo
var ldapProvider domain.LdapAuthenticator
var ldapProvider AuthenticatorLdap
var userInDatabase = false
var userSource domain.UserSource
@ -280,6 +338,7 @@ func (a *Authenticator) passwordAuthentication(
// region oauth authentication
// OauthLoginStep1 starts the oauth authentication flow by returning the authentication URL, state and nonce.
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
authCodeUrl, state, nonce string,
err error,
@ -296,9 +355,9 @@ func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
}
switch oauthProvider.GetType() {
case domain.AuthenticatorTypeOAuth:
case AuthenticatorTypeOAuth:
authCodeUrl = oauthProvider.AuthCodeURL(state)
case domain.AuthenticatorTypeOidc:
case AuthenticatorTypeOidc:
nonce, err = a.randString(16)
if err != nil {
return "", "", "", fmt.Errorf("failed to generate nonce: %w", err)
@ -318,6 +377,8 @@ func (a *Authenticator) randString(nByte int) (string, error) {
return base64.RawURLEncoding.EncodeToString(b), nil
}
// OauthLoginStep2 finishes the oauth authentication flow by exchanging the code for an access token and
// fetching the user information.
func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error) {
oauthProvider, ok := a.oauthAuthenticators[providerId]
if !ok {

View File

@ -14,6 +14,7 @@ import (
"github.com/h44z/wg-portal/internal/domain"
)
// LdapAuthenticator is an authenticator that uses LDAP for authentication.
type LdapAuthenticator struct {
cfg *config.LdapProvider
}
@ -33,14 +34,17 @@ func newLdapAuthenticator(_ context.Context, cfg *config.LdapProvider) (*LdapAut
return provider, nil
}
// GetName returns the name of the LDAP authenticator.
func (l LdapAuthenticator) GetName() string {
return l.cfg.ProviderName
}
// RegistrationEnabled returns whether registration is enabled for the LDAP authenticator.
func (l LdapAuthenticator) RegistrationEnabled() bool {
return l.cfg.RegistrationEnabled
}
// PlaintextAuthentication performs a plaintext authentication against the LDAP server.
func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error {
conn, err := internal.LdapConnect(l.cfg)
if err != nil {
@ -81,6 +85,9 @@ func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier,
return nil
}
// GetUserInfo retrieves user information from the LDAP server.
// If the user is not found, domain.ErrNotFound is returned.
// If multiple users are found, domain.ErrNotUnique is returned.
func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIdentifier) (
map[string]any,
error,
@ -126,6 +133,7 @@ func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIden
return users[0], nil
}
// ParseUserInfo parses the user information from the LDAP server into a domain.AuthenticatorUserInfo struct.
func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
isAdmin, err := internal.LdapIsMemberOf(raw[l.cfg.FieldMap.GroupMembership].([][]byte), l.cfg.ParsedAdminGroupDN)
if err != nil {

View File

@ -16,6 +16,8 @@ import (
"github.com/h44z/wg-portal/internal/domain"
)
// PlainOauthAuthenticator is an authenticator that uses OAuth for authentication.
// User information is retrieved from the specified user info endpoint.
type PlainOauthAuthenticator struct {
name string
cfg *oauth2.Config
@ -58,22 +60,27 @@ func newPlainOauthAuthenticator(
return provider, nil
}
// GetName returns the name of the OAuth authenticator.
func (p PlainOauthAuthenticator) GetName() string {
return p.name
}
// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator.
func (p PlainOauthAuthenticator) RegistrationEnabled() bool {
return p.registrationEnabled
}
func (p PlainOauthAuthenticator) GetType() domain.AuthenticatorType {
return domain.AuthenticatorTypeOAuth
// GetType returns the type of the authenticator.
func (p PlainOauthAuthenticator) GetType() AuthenticatorType {
return AuthenticatorTypeOAuth
}
// AuthCodeURL returns the URL to redirect the user to for authentication.
func (p PlainOauthAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
return p.cfg.AuthCodeURL(state, opts...)
}
// Exchange exchanges the OAuth code for a token.
func (p PlainOauthAuthenticator) Exchange(
ctx context.Context,
code string,
@ -82,6 +89,7 @@ func (p PlainOauthAuthenticator) Exchange(
return p.cfg.Exchange(ctx, code, opts...)
}
// GetUserInfo retrieves the user information from the user info endpoint.
func (p PlainOauthAuthenticator) GetUserInfo(
ctx context.Context,
token *oauth2.Token,
@ -119,6 +127,7 @@ func (p PlainOauthAuthenticator) GetUserInfo(
return userFields, nil
}
// ParseUserInfo parses the user information from the raw data.
func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw)
}

View File

@ -14,6 +14,7 @@ import (
"github.com/h44z/wg-portal/internal/domain"
)
// OidcAuthenticator is an authenticator for OpenID Connect providers.
type OidcAuthenticator struct {
name string
provider *oidc.Provider
@ -60,22 +61,27 @@ func newOidcAuthenticator(
return provider, nil
}
// GetName returns the name of the authenticator.
func (o OidcAuthenticator) GetName() string {
return o.name
}
// RegistrationEnabled returns whether registration is enabled for this authenticator.
func (o OidcAuthenticator) RegistrationEnabled() bool {
return o.registrationEnabled
}
func (o OidcAuthenticator) GetType() domain.AuthenticatorType {
return domain.AuthenticatorTypeOidc
// GetType returns the type of the authenticator.
func (o OidcAuthenticator) GetType() AuthenticatorType {
return AuthenticatorTypeOidc
}
// AuthCodeURL returns the URL for the OAuth2 flow.
func (o OidcAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
return o.cfg.AuthCodeURL(state, opts...)
}
// Exchange exchanges the code for a token.
func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (
*oauth2.Token,
error,
@ -83,6 +89,7 @@ func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oa
return o.cfg.Exchange(ctx, code, opts...)
}
// GetUserInfo retrieves the user info from the token.
func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (
map[string]any,
error,
@ -114,6 +121,7 @@ func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token,
return tokenFields, nil
}
// ParseUserInfo parses the user info.
func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw)
}

View File

@ -10,7 +10,6 @@ import (
"os"
"strings"
evbus "github.com/vardius/message-bus"
"github.com/yeqown/go-qrcode/v2"
"github.com/yeqown/go-qrcode/writer/compressed"
@ -19,19 +18,56 @@ import (
"github.com/h44z/wg-portal/internal/domain"
)
type Manager struct {
cfg *config.Config
bus evbus.MessageBus
tplHandler *TemplateHandler
// region dependencies
fsRepo FileSystemRepo
users UserDatabaseRepo
wg WireguardDatabaseRepo
type UserDatabaseRepo interface {
// GetUser returns the user with the given identifier from the SQL database.
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type WireguardDatabaseRepo interface {
// GetInterfaceAndPeers returns the interface and all peers associated with it.
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
// GetPeer returns the peer with the given identifier.
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
// GetInterface returns the interface with the given identifier.
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}
type FileSystemRepo interface {
// WriteFile writes the contents to the file at the given path.
WriteFile(path string, contents io.Reader) error
}
type TemplateRenderer interface {
// GetInterfaceConfig returns the configuration file for the given interface.
GetInterfaceConfig(iface *domain.Interface, peers []domain.Peer) (io.Reader, error)
// GetPeerConfig returns the configuration file for the given peer.
GetPeerConfig(peer *domain.Peer) (io.Reader, error)
}
type EventBus interface {
// Subscribe subscribes to the given topic.
Subscribe(topic string, fn any) error
}
// endregion dependencies
// Manager is responsible for managing the configuration files of the WireGuard interfaces and peers.
type Manager struct {
cfg *config.Config
bus EventBus
tplHandler TemplateRenderer
fsRepo FileSystemRepo
users UserDatabaseRepo
wg WireguardDatabaseRepo
}
// NewConfigFileManager creates a new Manager instance.
func NewConfigFileManager(
cfg *config.Config,
bus evbus.MessageBus,
bus EventBus,
users UserDatabaseRepo,
wg WireguardDatabaseRepo,
fsRepo FileSystemRepo,
@ -115,6 +151,8 @@ func (m Manager) handlePeerInterfaceUpdatedEvent(id domain.InterfaceIdentifier)
}
}
// GetInterfaceConfig returns the configuration file for the given interface.
// The file is structured in wg-quick format.
func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
@ -128,6 +166,8 @@ func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIden
return m.tplHandler.GetInterfaceConfig(iface, peers)
}
// GetPeerConfig returns the configuration file for the given peer.
// The file is structured in wg-quick format.
func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
peer, err := m.wg.GetPeer(ctx, id)
if err != nil {
@ -141,6 +181,7 @@ func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (i
return m.tplHandler.GetPeerConfig(peer)
}
// GetPeerConfigQrCode returns a QR code image containing the configuration for the given peer.
func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
peer, err := m.wg.GetPeer(ctx, id)
if err != nil {
@ -191,6 +232,7 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
return buf, nil
}
// PersistInterfaceConfig writes the configuration file for the given interface to the file system.
func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
iface, peers, err := m.wg.GetInterfaceAndPeers(ctx, id)
if err != nil {
@ -213,4 +255,5 @@ type nopCloser struct {
io.Writer
}
// Close is a no-op for the nopCloser.
func (nopCloser) Close() error { return nil }

View File

@ -1,22 +0,0 @@
package configfile
import (
"context"
"io"
"github.com/h44z/wg-portal/internal/domain"
)
type UserDatabaseRepo interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type WireguardDatabaseRepo interface {
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}
type FileSystemRepo interface {
WriteFile(path string, contents io.Reader) error
}

View File

@ -13,6 +13,8 @@ import (
//go:embed tpl_files/*
var TemplateFiles embed.FS
// TemplateHandler is responsible for rendering the WireGuard configuration files
// based on the provided templates.
type TemplateHandler struct {
templates *template.Template
}
@ -34,6 +36,7 @@ func newTemplateHandler() (*TemplateHandler, error) {
return handler, nil
}
// GetInterfaceConfig returns the rendered configuration file for a WireGuard interface.
func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domain.Peer) (io.Reader, error) {
var tplBuff bytes.Buffer
@ -51,6 +54,7 @@ func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domai
return &tplBuff, nil
}
// GetPeerConfig returns the rendered configuration file for a WireGuard peer.
func (c TemplateHandler) GetPeerConfig(peer *domain.Peer) (io.Reader, error) {
var tplBuff bytes.Buffer

View File

@ -10,16 +10,60 @@ import (
"github.com/h44z/wg-portal/internal/domain"
)
type Manager struct {
cfg *config.Config
tplHandler *TemplateHandler
// region dependencies
type Mailer interface {
// Send sends an email with the given subject and body to the given recipients.
Send(ctx context.Context, subject, body string, to []string, options *domain.MailOptions) error
}
type ConfigFileManager interface {
// GetInterfaceConfig returns the configuration for the given interface.
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
// GetPeerConfig returns the configuration for the given peer.
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
// GetPeerConfigQrCode returns the QR code for the given peer.
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
}
type UserDatabaseRepo interface {
// GetUser returns the user with the given identifier.
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type WireguardDatabaseRepo interface {
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
// GetPeer returns the peer with the given identifier.
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
// GetInterface returns the interface with the given identifier.
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}
type TemplateRenderer interface {
// GetConfigMail returns the text and html template for the mail with a link.
GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error)
// GetConfigMailWithAttachment returns the text and html template for the mail with an attachment.
GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) (
io.Reader,
io.Reader,
error,
)
}
// endregion dependencies
type Manager struct {
cfg *config.Config
tplHandler TemplateRenderer
mailer Mailer
configFiles ConfigFileManager
users UserDatabaseRepo
wg WireguardDatabaseRepo
}
// NewMailManager creates a new mail manager.
func NewMailManager(
cfg *config.Config,
mailer Mailer,
@ -44,6 +88,7 @@ func NewMailManager(
return m, nil
}
// SendPeerEmail sends an email to the user linked to the given peers.
func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error {
for _, peerId := range peers {
peer, err := m.wg.GetPeer(ctx, peerId)

View File

@ -1,28 +0,0 @@
package mail
import (
"context"
"io"
"github.com/h44z/wg-portal/internal/domain"
)
type Mailer interface {
Send(ctx context.Context, subject, body string, to []string, options *domain.MailOptions) error
}
type ConfigFileManager interface {
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
}
type UserDatabaseRepo interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type WireguardDatabaseRepo interface {
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}

View File

@ -14,6 +14,7 @@ import (
//go:embed tpl_files/*
var TemplateFiles embed.FS
// TemplateHandler is a struct that holds the html and text templates.
type TemplateHandler struct {
portalUrl string
htmlTemplates *htmlTemplate.Template
@ -40,6 +41,7 @@ func newTemplateHandler(portalUrl string) (*TemplateHandler, error) {
return handler, nil
}
// GetConfigMail returns the text and html template for the mail with a link.
func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error) {
var tplBuff bytes.Buffer
var htmlTplBuff bytes.Buffer
@ -65,6 +67,7 @@ func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reade
return &tplBuff, &htmlTplBuff, nil
}
// GetConfigMailWithAttachment returns the text and html template for the mail with an attachment.
func (c TemplateHandler) GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) (
io.Reader,
io.Reader,

View File

@ -1,77 +0,0 @@
package app
import (
"context"
"io"
"github.com/h44z/wg-portal/internal/domain"
)
type Authenticator interface {
GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo
IsUserValid(ctx context.Context, id domain.UserIdentifier) bool
PlainLogin(ctx context.Context, username, password string) (*domain.User, error)
OauthLoginStep1(_ context.Context, providerId string) (authCodeUrl, state, nonce string, err error)
OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error)
}
type UserManager interface {
RegisterUser(ctx context.Context, user *domain.User) error
NewUser(ctx context.Context, user *domain.User) error
StartBackgroundJobs(ctx context.Context)
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
GetAllUsers(ctx context.Context) ([]domain.User, error)
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
}
type WireGuardManager interface {
StartBackgroundJobs(ctx context.Context)
GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error)
ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error)
RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error
CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error)
GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error)
GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
GetUserInterfaces(ctx context.Context, id domain.UserIdentifier) ([]domain.Interface, error)
PrepareInterface(ctx context.Context) (*domain.Interface, error)
CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error)
UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error)
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
CreatePeer(ctx context.Context, p *domain.Peer) (*domain.Peer, error)
CreateMultiplePeers(
ctx context.Context,
id domain.InterfaceIdentifier,
r *domain.PeerCreationRequest,
) ([]domain.Peer, error)
UpdatePeer(ctx context.Context, p *domain.Peer) (*domain.Peer, error)
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error
}
type StatisticsCollector interface {
StartBackgroundJobs(ctx context.Context)
}
type ConfigFileManager interface {
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error
}
type MailManager interface {
SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error
}
type ApiV1Manager interface {
ApiV1GetUsers(ctx context.Context) ([]domain.User, error)
}

View File

@ -1,12 +0,0 @@
package route
import (
"context"
"github.com/h44z/wg-portal/internal/domain"
)
type InterfaceAndPeerDatabaseRepo interface {
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
}

View File

@ -5,7 +5,6 @@ import (
"fmt"
"log/slog"
evbus "github.com/vardius/message-bus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl"
@ -17,6 +16,22 @@ import (
"github.com/h44z/wg-portal/internal/lowlevel"
)
// region dependencies
type InterfaceAndPeerDatabaseRepo interface {
// GetAllInterfaces returns all interfaces
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
// GetInterfacePeers returns all peers for a given interface
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
}
type EventBus interface {
// Subscribe subscribes to a topic
Subscribe(topic string, fn interface{}) error
}
// endregion dependencies
type routeRuleInfo struct {
ifaceId domain.InterfaceIdentifier
fwMark uint32
@ -29,14 +44,15 @@ type routeRuleInfo struct {
// for default routes.
type Manager struct {
cfg *config.Config
bus evbus.MessageBus
wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient
db InterfaceAndPeerDatabaseRepo
bus EventBus
wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient
db InterfaceAndPeerDatabaseRepo
}
func NewRouteManager(cfg *config.Config, bus evbus.MessageBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
// NewRouteManager creates a new route manager instance.
func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
wg, err := wgctrl.New()
if err != nil {
panic("failed to init wgctrl: " + err.Error())
@ -63,7 +79,10 @@ func (m Manager) connectToMessageBus() {
_ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteRemoveEvent)
}
// StartBackgroundJobs starts background jobs for the route manager.
// This method is non-blocking and returns immediately.
func (m Manager) StartBackgroundJobs(_ context.Context) {
// this is a no-op for now
}
func (m Manager) handleRouteUpdateEvent(srcDescription string) {

View File

@ -1,20 +0,0 @@
package users
import (
"context"
"github.com/h44z/wg-portal/internal/domain"
)
type UserDatabaseRepo interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
GetUserByEmail(ctx context.Context, email string) (*domain.User, error)
GetAllUsers(ctx context.Context) ([]domain.User, error)
FindUsers(ctx context.Context, search string) ([]domain.User, error)
SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
}
type PeerDatabaseRepo interface {
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
}

View File

@ -11,7 +11,6 @@ import (
"github.com/go-ldap/ldap/v3"
"github.com/google/uuid"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/app"
@ -19,15 +18,46 @@ import (
"github.com/h44z/wg-portal/internal/domain"
)
// region dependencies
type UserDatabaseRepo interface {
// GetUser returns the user with the given identifier.
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
// GetUserByEmail returns the user with the given email address.
GetUserByEmail(ctx context.Context, email string) (*domain.User, error)
// GetAllUsers returns all users.
GetAllUsers(ctx context.Context) ([]domain.User, error)
// FindUsers returns all users matching the search string.
FindUsers(ctx context.Context, search string) ([]domain.User, error)
// SaveUser saves the user with the given identifier.
SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error
// DeleteUser deletes the user with the given identifier.
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
}
type PeerDatabaseRepo interface {
// GetUserPeers returns all peers linked to the given user.
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
}
type EventBus interface {
// Publish sends a message to the message bus.
Publish(topic string, args ...any)
}
// endregion dependencies
// Manager is the user manager.
type Manager struct {
cfg *config.Config
bus evbus.MessageBus
bus EventBus
users UserDatabaseRepo
peers PeerDatabaseRepo
}
func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabaseRepo, peers PeerDatabaseRepo) (
// NewUserManager creates a new user manager instance.
func NewUserManager(cfg *config.Config, bus EventBus, users UserDatabaseRepo, peers PeerDatabaseRepo) (
*Manager,
error,
) {
@ -41,6 +71,7 @@ func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabase
return m, nil
}
// RegisterUser registers a new user.
func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
@ -56,6 +87,7 @@ func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
return nil
}
// NewUser creates a new user.
func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
if user.Identifier == "" {
return errors.New("missing user identifier")
@ -90,12 +122,13 @@ func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
return nil
}
// StartBackgroundJobs starts the background jobs.
// This method is non-blocking and returns immediately.
func (m Manager) StartBackgroundJobs(ctx context.Context) {
go m.runLdapSynchronizationService(ctx)
}
// GetUser returns the user with the given identifier.
func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
return nil, err
@ -112,6 +145,7 @@ func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain
return user, nil
}
// GetUserByEmail returns the user with the given email address.
func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) {
user, err := m.users.GetUserByEmail(ctx, email)
@ -130,6 +164,7 @@ func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User
return user, nil
}
// GetAllUsers returns all users.
func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
@ -162,6 +197,7 @@ func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
return users, nil
}
// UpdateUser updates the user with the given identifier.
func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil {
return nil, err
@ -203,6 +239,7 @@ func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.Use
return user, nil
}
// CreateUser creates a new user.
func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
@ -236,6 +273,7 @@ func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.Use
return user, nil
}
// DeleteUser deletes the user with the given identifier.
func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
@ -260,6 +298,7 @@ func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error
return nil
}
// ActivateApi activates the API access for the user with the given identifier.
func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
user, err := m.users.GetUser(ctx, id)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
@ -287,6 +326,7 @@ func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*do
return user, nil
}
// DeactivateApi deactivates the API access for the user with the given identifier.
func (m Manager) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
user, err := m.users.GetUser(ctx, id)
if err != nil && !errors.Is(err, domain.ErrNotFound) {

View File

@ -1,87 +0,0 @@
package wireguard
import (
"context"
"github.com/h44z/wg-portal/internal/domain"
)
type InterfaceAndPeerDatabaseRepo interface {
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error)
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error)
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.Interface) (*domain.Interface, error),
) error
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error)
SavePeer(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.Peer) (*domain.Peer, error),
) error
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
}
type StatisticsDatabaseRepo interface {
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
UpdatePeerStatus(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
) error
UpdateInterfaceStatus(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
) error
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
}
type InterfaceController interface {
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
*domain.PhysicalPeer,
error,
)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}
type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(id domain.InterfaceIdentifier) error
}
type MetricsServer interface {
UpdateInterfaceMetrics(status domain.InterfaceStatus)
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
}

View File

@ -7,31 +7,63 @@ import (
"time"
probing "github.com/prometheus-community/pro-bing"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
type StatisticsDatabaseRepo interface {
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
UpdatePeerStatus(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
) error
UpdateInterfaceStatus(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
) error
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
}
type StatisticsInterfaceController interface {
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
}
type StatisticsMetricsServer interface {
UpdateInterfaceMetrics(status domain.InterfaceStatus)
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
}
type StatisticsEventBus interface {
// Subscribe subscribes to a topic
Subscribe(topic string, fn interface{}) error
}
type StatisticsCollector struct {
cfg *config.Config
bus evbus.MessageBus
bus StatisticsEventBus
pingWaitGroup sync.WaitGroup
pingJobs chan domain.Peer
db StatisticsDatabaseRepo
wg InterfaceController
ms MetricsServer
wg StatisticsInterfaceController
ms StatisticsMetricsServer
}
// NewStatisticsCollector creates a new statistics collector.
func NewStatisticsCollector(
cfg *config.Config,
bus evbus.MessageBus,
bus StatisticsEventBus,
db StatisticsDatabaseRepo,
wg InterfaceController,
ms MetricsServer,
wg StatisticsInterfaceController,
ms StatisticsMetricsServer,
) (*StatisticsCollector, error) {
c := &StatisticsCollector{
cfg: cfg,
@ -47,6 +79,8 @@ func NewStatisticsCollector(
return c, nil
}
// StartBackgroundJobs starts the background jobs for the statistics collector.
// This method is non-blocking and returns immediately.
func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) {
c.startPingWorkers(ctx)
c.startInterfaceDataFetcher(ctx)

View File

@ -5,17 +5,74 @@ import (
"log/slog"
"time"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
type Manager struct {
cfg *config.Config
bus evbus.MessageBus
// region dependencies
type InterfaceAndPeerDatabaseRepo interface {
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error)
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.Interface) (*domain.Interface, error),
) error
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
SavePeer(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.Peer) (*domain.Peer, error),
) error
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
}
type InterfaceController interface {
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}
type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(id domain.InterfaceIdentifier) error
}
type EventBus interface {
// Publish sends a message to the message bus.
Publish(topic string, args ...any)
// Subscribe subscribes to a topic
Subscribe(topic string, fn interface{}) error
}
// endregion dependencies
type Manager struct {
cfg *config.Config
bus EventBus
db InterfaceAndPeerDatabaseRepo
wg InterfaceController
quick WgQuickController
@ -23,7 +80,7 @@ type Manager struct {
func NewWireGuardManager(
cfg *config.Config,
bus evbus.MessageBus,
bus EventBus,
wg InterfaceController,
quick WgQuickController,
db InterfaceAndPeerDatabaseRepo,
@ -41,6 +98,8 @@ func NewWireGuardManager(
return m, nil
}
// StartBackgroundJobs starts background jobs like the expired peers check.
// This method is non-blocking.
func (m Manager) StartBackgroundJobs(ctx context.Context) {
go m.runExpiredPeersCheck(ctx)
}

View File

@ -13,6 +13,8 @@ import (
"github.com/h44z/wg-portal/internal/domain"
)
// GetImportableInterfaces returns all physical interfaces that are available on the system.
// This function also returns interfaces that are already available in the database.
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
@ -26,6 +28,7 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
return physicalInterfaces, nil
}
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
[]domain.Peer,
@ -38,6 +41,7 @@ func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceId
return m.db.GetInterfaceAndPeers(ctx, id)
}
// GetAllInterfaces returns all interfaces that are available in the database.
func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
@ -46,6 +50,7 @@ func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, erro
return m.db.GetAllInterfaces(ctx)
}
// GetAllInterfacesAndPeers returns all interfaces and their peers.
func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, nil, err
@ -97,6 +102,7 @@ func (m Manager) GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier)
return userInterfaces, nil
}
// ImportNewInterfaces imports all new physical interfaces that are available on the system.
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return 0, err
@ -148,6 +154,7 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
return imported, nil
}
// ApplyPeerDefaults applies the interface defaults to all peers of the given interface.
func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
@ -179,6 +186,8 @@ func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) er
return nil
}
// RestoreInterfaceState restores the state of all physical interfaces and their peers.
// The final state of the interfaces and peers will be the same as stored in the database.
func (m Manager) RestoreInterfaceState(
ctx context.Context,
updateDbOnError bool,
@ -296,6 +305,7 @@ func (m Manager) RestoreInterfaceState(
return nil
}
// PrepareInterface generates a new interface with fresh keys, ip addresses and a listen port.
func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
@ -376,6 +386,7 @@ func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error
return freshInterface, nil
}
// CreateInterface creates a new interface with the given configuration.
func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
@ -401,6 +412,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
return in, nil
}
// UpdateInterface updates the given interface with the new configuration.
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, nil, err
@ -423,6 +435,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
return in, existingPeers, nil
}
// DeleteInterface deletes the given interface.
func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err

View File

@ -11,6 +11,7 @@ import (
"github.com/h44z/wg-portal/internal/domain"
)
// CreateDefaultPeer creates a default peer for the given user on all server interfaces.
func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return err
@ -55,6 +56,7 @@ func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdenti
return nil
}
// GetUserPeers returns all peers for the given user.
func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
return nil, err
@ -63,6 +65,7 @@ func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]
return m.db.GetUserPeers(ctx, id)
}
// PreparePeer prepares a new peer for the given interface with fresh keys and ip addresses.
func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
if !m.cfg.Core.SelfProvisioningAllowed {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
@ -143,6 +146,7 @@ func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier)
return freshPeer, nil
}
// GetPeer returns the peer with the given identifier.
func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
peer, err := m.db.GetPeer(ctx, id)
if err != nil {
@ -156,6 +160,7 @@ func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain
return peer, nil
}
// CreatePeer creates a new peer.
func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
if !m.cfg.Core.SelfProvisioningAllowed {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
@ -201,6 +206,8 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
return peer, nil
}
// CreateMultiplePeers creates multiple new peers for the given user identifiers.
// It calls PreparePeer for each user identifier in the request.
func (m Manager) CreateMultiplePeers(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
@ -243,6 +250,7 @@ func (m Manager) CreateMultiplePeers(
return createdPeers, nil
}
// UpdatePeer updates the given peer.
func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
if err != nil {
@ -309,6 +317,7 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
return peer, nil
}
// DeletePeer deletes the peer with the given identifier.
func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
peer, err := m.db.GetPeer(ctx, id)
if err != nil {
@ -341,6 +350,7 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return nil
}
// GetPeerStats returns the status of the peer with the given identifier.
func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) {
_, peers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil {
@ -359,6 +369,7 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier
return m.db.GetPeersStats(ctx, peerIds...)
}
// GetUserPeerStats returns the status of all peers for the given user.
func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
return nil, err

View File

@ -8,6 +8,7 @@ import (
"github.com/go-ldap/ldap/v3"
)
// Auth contains all authentication providers.
type Auth struct {
// OpenIDConnect contains a list of OpenID Connect providers.
OpenIDConnect []OpenIDConnectProvider `yaml:"oidc"`
@ -17,6 +18,7 @@ type Auth struct {
Ldap []LdapProvider `yaml:"ldap"`
}
// BaseFields contains the basic fields that are used to map user information from the authentication providers.
type BaseFields struct {
// UserIdentifier is the name of the field that contains the user identifier.
UserIdentifier string `yaml:"user_identifier"`
@ -32,6 +34,7 @@ type BaseFields struct {
Department string `yaml:"department"`
}
// OauthFields contains extra fields that are used to map user information from OAuth providers.
type OauthFields struct {
BaseFields `yaml:",inline"`
// IsAdmin is the name of the field that contains the admin flag.
@ -107,12 +110,14 @@ func (o *OauthAdminMapping) GetAdminGroupRegex() *regexp.Regexp {
return o.adminGroupRegex
}
// LdapFields contains extra fields that are used to map user information from LDAP providers.
type LdapFields struct {
BaseFields `yaml:",inline"`
// GroupMembership is the name of the LDAP field that contains the groups to which the user belongs.
GroupMembership string `yaml:"memberof"`
}
// LdapProvider contains the configuration for the LDAP connection.
type LdapProvider struct {
// ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters.
ProviderName string `yaml:"provider_name"`
@ -163,6 +168,7 @@ type LdapProvider struct {
LogUserInfo bool `yaml:"log_user_info"`
}
// OpenIDConnectProvider contains the configuration for the OpenID Connect provider.
type OpenIDConnectProvider struct {
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
ProviderName string `yaml:"provider_name"`
@ -196,6 +202,7 @@ type OpenIDConnectProvider struct {
LogUserInfo bool `yaml:"log_user_info"`
}
// OAuthProvider contains the configuration for the OAuth provider.
type OAuthProvider struct {
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
ProviderName string `yaml:"provider_name"`

View File

@ -10,6 +10,7 @@ import (
"gopkg.in/yaml.v3"
)
// Config is the main configuration struct.
type Config struct {
Core struct {
// AdminUser defines the default administrator account that will be created
@ -179,6 +180,7 @@ func GetConfig() (*Config, error) {
return cfg, nil
}
// loadConfigFile loads the configuration from a YAML file into the given cfg struct.
func loadConfigFile(cfg any, filename string) error {
data, err := envsubst.ReadFile(filename)
if err != nil {

View File

@ -2,6 +2,8 @@ package config
import "time"
// SupportedDatabase is a type for the supported database types.
// Supported: mysql, mssql, postgres, sqlite
type SupportedDatabase string
const (
@ -11,6 +13,7 @@ const (
DatabaseSQLite SupportedDatabase = "sqlite"
)
// DatabaseConfig contains the configuration for the database connection.
type DatabaseConfig struct {
// Debug enables logging of all database statements
Debug bool `yaml:"debug"`

View File

@ -1,5 +1,7 @@
package config
// MailEncryption is the type of the SMTP encryption.
// Supported: none, tls, starttls
type MailEncryption string
const (
@ -8,6 +10,8 @@ const (
MailEncryptionStartTLS MailEncryption = "starttls"
)
// MailAuthType is the type of the SMTP authentication.
// Supported: plain, login, crammd5
type MailAuthType string
const (
@ -16,6 +20,7 @@ const (
MailAuthCramMD5 MailAuthType = "crammd5"
)
// MailConfig contains the configuration for the mail server which is used to send emails.
type MailConfig struct {
// Host is the hostname or IP of the SMTP server
Host string `yaml:"host"`

View File

@ -1,5 +1,6 @@
package config
// WebConfig contains the configuration for the web server.
type WebConfig struct {
// RequestLogging enables logging of all HTTP requests.
RequestLogging bool `yaml:"request_logging"`

View File

@ -1,11 +1,5 @@
package domain
import (
"context"
"golang.org/x/oauth2"
)
type LoginProvider string
type LoginProviderInfo struct {
@ -24,28 +18,3 @@ type AuthenticatorUserInfo struct {
Department string
IsAdmin bool
}
type AuthenticatorType string
const (
AuthenticatorTypeOAuth AuthenticatorType = "oauth"
AuthenticatorTypeOidc AuthenticatorType = "oidc"
)
type OauthAuthenticator interface {
GetName() string
GetType() AuthenticatorType
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]any, error)
ParseUserInfo(raw map[string]any) (*AuthenticatorUserInfo, error)
RegistrationEnabled() bool
}
type LdapAuthenticator interface {
GetName() string
PlaintextAuthentication(userId UserIdentifier, plainPassword string) error
GetUserInfo(ctx context.Context, username UserIdentifier) (map[string]any, error)
ParseUserInfo(raw map[string]any) (*AuthenticatorUserInfo, error)
RegistrationEnabled() bool
}

View File

@ -33,6 +33,7 @@ func (p KeyPair) GetPublicKey() wgtypes.Key {
type PreSharedKey string
// NewFreshKeypair generates a new key pair.
func NewFreshKeypair() (KeyPair, error) {
privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
@ -45,6 +46,7 @@ func NewFreshKeypair() (KeyPair, error) {
}, nil
}
// NewPreSharedKey generates a new pre-shared key.
func NewPreSharedKey() (PreSharedKey, error) {
preSharedKey, err := wgtypes.GenerateKey()
if err != nil {
@ -54,6 +56,8 @@ func NewPreSharedKey() (PreSharedKey, error) {
return PreSharedKey(preSharedKey.String()), nil
}
// PublicKeyFromPrivateKey returns the public key for a given private key.
// If the private key is invalid, an empty string is returned.
func PublicKeyFromPrivateKey(key string) string {
privKey, err := wgtypes.ParseKey(key)
if err != nil {

View File

@ -0,0 +1,56 @@
package domain
import (
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func TestKeyPair_GetPrivateKeyBytesReturnsCorrectBytes(t *testing.T) {
keyPair := KeyPair{PrivateKey: base64.StdEncoding.EncodeToString([]byte("privateKey"))}
expected := []byte("privateKey")
assert.Equal(t, expected, keyPair.GetPrivateKeyBytes())
}
func TestKeyPair_GetPublicKeyBytesReturnsCorrectBytes(t *testing.T) {
keyPair := KeyPair{PublicKey: base64.StdEncoding.EncodeToString([]byte("publicKey"))}
expected := []byte("publicKey")
assert.Equal(t, expected, keyPair.GetPublicKeyBytes())
}
func TestKeyPair_GetPrivateKeyReturnsCorrectKey(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
keyPair := KeyPair{PrivateKey: privateKey.String()}
assert.Equal(t, privateKey, keyPair.GetPrivateKey())
}
func TestKeyPair_GetPublicKeyReturnsCorrectKey(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
keyPair := KeyPair{PublicKey: privateKey.PublicKey().String()}
assert.Equal(t, privateKey.PublicKey(), keyPair.GetPublicKey())
}
func TestNewFreshKeypairGeneratesValidKeypair(t *testing.T) {
keyPair, err := NewFreshKeypair()
assert.NoError(t, err)
assert.NotEmpty(t, keyPair.PrivateKey)
assert.NotEmpty(t, keyPair.PublicKey)
}
func TestNewPreSharedKeyGeneratesValidKey(t *testing.T) {
preSharedKey, err := NewPreSharedKey()
assert.NoError(t, err)
assert.NotEmpty(t, preSharedKey)
}
func TestPublicKeyFromPrivateKeyReturnsCorrectPublicKey(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
expected := privateKey.PublicKey().String()
assert.Equal(t, expected, PublicKeyFromPrivateKey(privateKey.String()))
}
func TestPublicKeyFromPrivateKeyReturnsEmptyStringOnInvalidKey(t *testing.T) {
assert.Equal(t, "", PublicKeyFromPrivateKey("invalidKey"))
}

View File

@ -0,0 +1,83 @@
package domain
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) {
iface := &Interface{}
assert.False(t, iface.IsDisabled())
now := time.Now()
iface.Disabled = &now
assert.True(t, iface.IsDisabled())
}
func TestInterface_AddressStrReturnsCorrectString(t *testing.T) {
iface := &Interface{
Addresses: []Cidr{
{Cidr: "192.168.1.1/24", Addr: "192.168.1.1", NetLength: 24},
{Cidr: "10.0.0.1/24", Addr: "10.0.0.1", NetLength: 24},
},
}
expected := "192.168.1.1/24,10.0.0.1/24"
assert.Equal(t, expected, iface.AddressStr())
}
func TestInterface_GetConfigFileNameReturnsCorrectFileName(t *testing.T) {
iface := &Interface{Identifier: "wg0"}
expected := "wg0.conf"
assert.Equal(t, expected, iface.GetConfigFileName())
iface.Identifier = "wg0@123"
expected = "wg0123.conf"
assert.Equal(t, expected, iface.GetConfigFileName())
}
func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
peer1 := Peer{
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
},
},
}
peer2 := Peer{
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
},
},
}
iface := &Interface{}
expected := []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
}
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
}
func TestInterface_ManageRoutingTableReturnsCorrectValue(t *testing.T) {
iface := &Interface{RoutingTable: "off"}
assert.False(t, iface.ManageRoutingTable())
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
}
func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
iface := &Interface{RoutingTable: ""}
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "off"
assert.Equal(t, -1, iface.GetRoutingTable())
iface.RoutingTable = "0x64"
assert.Equal(t, 100, iface.GetRoutingTable())
iface.RoutingTable = "200"
assert.Equal(t, 200, iface.GetRoutingTable())
}

View File

@ -0,0 +1,42 @@
package domain
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConfigOption_GetValueReturnsCorrectValue(t *testing.T) {
option := ConfigOption[int]{Value: 42}
assert.Equal(t, 42, option.GetValue())
}
func TestConfigOption_SetValueUpdatesValue(t *testing.T) {
option := ConfigOption[int]{Value: 42}
option.SetValue(100)
assert.Equal(t, 100, option.GetValue())
}
func TestConfigOption_TrySetValueUpdatesValueWhenOverridable(t *testing.T) {
option := ConfigOption[int]{Value: 42, Overridable: true}
result := option.TrySetValue(100)
assert.True(t, result)
assert.Equal(t, 100, option.GetValue())
}
func TestConfigOption_TrySetValueDoesNotUpdateValueWhenNotOverridable(t *testing.T) {
option := ConfigOption[int]{Value: 42, Overridable: false}
result := option.TrySetValue(100)
assert.False(t, result)
assert.Equal(t, 42, option.GetValue())
}
func TestNewConfigOptionCreatesCorrectOption(t *testing.T) {
option := NewConfigOption(42, true)
assert.Equal(t, 42, option.GetValue())
assert.True(t, option.Overridable)
option2 := NewConfigOption("str", false)
assert.Equal(t, "str", option2.GetValue())
assert.False(t, option2.Overridable)
}

View File

@ -0,0 +1,165 @@
package domain
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestPeer_IsDisabled(t *testing.T) {
peer := &Peer{}
assert.False(t, peer.IsDisabled())
now := time.Now()
peer.Disabled = &now
assert.True(t, peer.IsDisabled())
}
func TestPeer_IsExpired(t *testing.T) {
peer := &Peer{}
assert.False(t, peer.IsExpired())
expiredTime := time.Now().Add(-time.Hour)
peer.ExpiresAt = &expiredTime
assert.True(t, peer.IsExpired())
futureTime := time.Now().Add(time.Hour)
peer.ExpiresAt = &futureTime
assert.False(t, peer.IsExpired())
}
func TestPeer_CheckAliveAddress(t *testing.T) {
peer := &Peer{}
assert.Equal(t, "", peer.CheckAliveAddress())
peer.Interface.CheckAliveAddress = "192.168.1.1"
assert.Equal(t, "192.168.1.1", peer.CheckAliveAddress())
peer.Interface.CheckAliveAddress = ""
peer.Interface.Addresses = []Cidr{{Addr: "10.0.0.1"}}
assert.Equal(t, "10.0.0.1", peer.CheckAliveAddress())
}
func TestPeer_GetConfigFileName(t *testing.T) {
peer := &Peer{DisplayName: "Test Peer"}
expected := "Test_Peer.conf"
assert.Equal(t, expected, peer.GetConfigFileName())
peer.DisplayName = ""
peer.Identifier = "12345678"
expected = "wg_12345678.conf"
assert.Equal(t, expected, peer.GetConfigFileName())
}
func TestPeer_ApplyInterfaceDefaults(t *testing.T) {
peer := &Peer{
Endpoint: ConfigOption[string]{
Value: "",
Overridable: true,
},
EndpointPublicKey: ConfigOption[string]{
Value: "",
Overridable: true,
},
AllowedIPsStr: ConfigOption[string]{
Value: "1.1.1.1/32",
Overridable: false,
},
}
iface := &Interface{
PeerDefEndpoint: "192.168.1.1",
KeyPair: KeyPair{
PublicKey: "publicKey",
},
PeerDefAllowedIPsStr: "8.8.8.8/32",
}
peer.ApplyInterfaceDefaults(iface)
assert.Equal(t, "192.168.1.1", peer.Endpoint.GetValue())
assert.Equal(t, "publicKey", peer.EndpointPublicKey.GetValue())
assert.Equal(t, "1.1.1.1/32", peer.AllowedIPsStr.GetValue())
}
func TestPeer_GenerateDisplayName(t *testing.T) {
peer := &Peer{Identifier: "12345678"}
peer.GenerateDisplayName("Prefix")
expected := "Prefix Peer 12345678"
assert.Equal(t, expected, peer.DisplayName)
peer.GenerateDisplayName("")
expected = "Peer 12345678"
assert.Equal(t, expected, peer.DisplayName)
}
func TestPeer_OverwriteUserEditableFields(t *testing.T) {
peer := &Peer{}
userPeer := &Peer{
DisplayName: "New DisplayName",
}
peer.OverwriteUserEditableFields(userPeer)
assert.Equal(t, "New DisplayName", peer.DisplayName)
}
func TestPeer_GetPresharedKey(t *testing.T) {
physicalPeer := PhysicalPeer{}
assert.Nil(t, physicalPeer.GetPresharedKey())
physicalPeer.PresharedKey = "Q0evIJTOjhyy2o5J7whvrsvQC+FRL8A74vrw44YHUAk="
key := physicalPeer.GetPresharedKey()
assert.NotNil(t, key)
}
func TestPeer_GetEndpointAddress(t *testing.T) {
physicalPeer := PhysicalPeer{}
assert.Nil(t, physicalPeer.GetEndpointAddress())
physicalPeer.Endpoint = "192.168.1.1:51820"
addr := physicalPeer.GetEndpointAddress()
assert.NotNil(t, addr)
assert.Equal(t, "192.168.1.1:51820", addr.String())
}
func TestPeer_GetPersistentKeepaliveTime(t *testing.T) {
physicalPeer := PhysicalPeer{}
assert.Nil(t, physicalPeer.GetPersistentKeepaliveTime())
physicalPeer.PersistentKeepalive = 25
duration := physicalPeer.GetPersistentKeepaliveTime()
assert.NotNil(t, duration)
assert.Equal(t, 25*time.Second, *duration)
}
func TestPeer_GetAllowedIPs(t *testing.T) {
physicalPeer := PhysicalPeer{}
assert.Empty(t, physicalPeer.GetAllowedIPs())
physicalPeer.AllowedIPs = []Cidr{
{
Cidr: "192.168.1.0/24",
Addr: "192.168.1.0",
NetLength: 24,
},
}
ips := physicalPeer.GetAllowedIPs()
assert.Len(t, ips, 1)
assert.Equal(t, "192.168.1.0/24", ips[0].String())
physicalPeer.AllowedIPs = []Cidr{
{
Cidr: "192.168.1.0/24",
Addr: "192.168.1.0",
NetLength: 24,
},
{
Cidr: "fe80::/64",
Addr: "fe80::",
NetLength: 64,
},
}
ips2 := physicalPeer.GetAllowedIPs()
assert.Len(t, ips2, 2)
assert.Equal(t, "192.168.1.0/24", ips2[0].String())
assert.Equal(t, "fe80::/64", ips2[1].String())
}

View File

@ -0,0 +1,74 @@
package domain
import (
"testing"
"time"
)
func TestPeerStatus_IsConnected(t *testing.T) {
now := time.Now()
past := now.Add(-3 * time.Minute)
recent := now.Add(-1 * time.Minute)
tests := []struct {
name string
status PeerStatus
want bool
}{
{
name: "Pingable and recent handshake",
status: PeerStatus{
IsPingable: true,
LastHandshake: &recent,
},
want: true,
},
{
name: "Not pingable but recent handshake",
status: PeerStatus{
IsPingable: false,
LastHandshake: &recent,
},
want: true,
},
{
name: "Pingable but old handshake",
status: PeerStatus{
IsPingable: true,
LastHandshake: &past,
},
want: true,
},
{
name: "Not pingable and old handshake",
status: PeerStatus{
IsPingable: false,
LastHandshake: &past,
},
want: false,
},
{
name: "Pingable and no handshake",
status: PeerStatus{
IsPingable: true,
LastHandshake: nil,
},
want: true,
},
{
name: "Not pingable and no handshake",
status: PeerStatus{
IsPingable: false,
LastHandshake: nil,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.status.IsConnected(); got != tt.want {
t.Errorf("IsConnected() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,125 @@
package domain
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
)
func TestUser_IsDisabled(t *testing.T) {
user := &User{}
assert.False(t, user.IsDisabled())
now := time.Now()
user.Disabled = &now
assert.True(t, user.IsDisabled())
}
func TestUser_IsLocked(t *testing.T) {
user := &User{}
assert.False(t, user.IsLocked())
now := time.Now()
user.Locked = &now
assert.True(t, user.IsLocked())
}
func TestUser_IsApiEnabled(t *testing.T) {
user := &User{}
assert.False(t, user.IsApiEnabled())
user.ApiToken = "token"
assert.True(t, user.IsApiEnabled())
}
func TestUser_CanChangePassword(t *testing.T) {
user := &User{Source: UserSourceDatabase}
assert.NoError(t, user.CanChangePassword())
user.Source = UserSourceLdap
assert.Error(t, user.CanChangePassword())
user.Source = UserSourceOauth
assert.Error(t, user.CanChangePassword())
}
func TestUser_EditAllowed(t *testing.T) {
user := &User{Source: UserSourceDatabase}
newUser := &User{Source: UserSourceDatabase}
assert.NoError(t, user.EditAllowed(newUser))
newUser.Notes = "notes can be changed"
assert.NoError(t, user.EditAllowed(newUser))
newUser.Disabled = &time.Time{}
assert.NoError(t, user.EditAllowed(newUser))
newUser.Lastname = "lastname or other fields can be changed"
assert.NoError(t, user.EditAllowed(newUser))
user.Source = UserSourceLdap
newUser.Source = UserSourceLdap
newUser.Disabled = nil
newUser.Lastname = ""
newUser.Notes = "notes can be changed"
assert.NoError(t, user.EditAllowed(newUser))
newUser.Disabled = &time.Time{}
assert.NoError(t, user.EditAllowed(newUser))
newUser.Lastname = "lastname or other fields can not be changed"
assert.Error(t, user.EditAllowed(newUser))
user.Source = UserSourceOauth
newUser.Source = UserSourceOauth
newUser.Disabled = nil
newUser.Lastname = ""
newUser.Notes = "notes can be changed"
assert.NoError(t, user.EditAllowed(newUser))
newUser.Disabled = &time.Time{}
assert.NoError(t, user.EditAllowed(newUser))
newUser.Lastname = "lastname or other fields can not be changed"
assert.Error(t, user.EditAllowed(newUser))
}
func TestUser_DeleteAllowed(t *testing.T) {
user := &User{}
assert.NoError(t, user.DeleteAllowed())
}
func TestUser_CheckPassword(t *testing.T) {
password := "password"
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
user := &User{Source: UserSourceDatabase, Password: PrivateString(hashedPassword)}
assert.NoError(t, user.CheckPassword(password))
user.Password = ""
assert.Error(t, user.CheckPassword(password))
user.Source = UserSourceLdap
assert.Error(t, user.CheckPassword(password))
}
func TestUser_CheckApiToken(t *testing.T) {
user := &User{}
assert.Error(t, user.CheckApiToken("token"))
user.ApiToken = "token"
assert.NoError(t, user.CheckApiToken("token"))
assert.Error(t, user.CheckApiToken("wrong_token"))
}
func TestUser_HashPassword(t *testing.T) {
user := &User{Password: "password"}
assert.NoError(t, user.HashPassword())
assert.NotEmpty(t, user.Password)
user.Password = ""
assert.NoError(t, user.HashPassword())
}