mirror of
https://github.com/h44z/wg-portal.git
synced 2025-04-19 00:45:17 +00:00
chore: use interfaces for all other services
This commit is contained in:
parent
02ed7b19df
commit
7d0da4e7ad
@ -14,6 +14,7 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/adapters"
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/app/api/core"
|
||||
backendV0 "github.com/h44z/wg-portal/internal/app/api/v0/backend"
|
||||
handlersV0 "github.com/h44z/wg-portal/internal/app/api/v0/handlers"
|
||||
backendV1 "github.com/h44z/wg-portal/internal/app/api/v1/backend"
|
||||
handlersV1 "github.com/h44z/wg-portal/internal/app/api/v1/handlers"
|
||||
@ -70,17 +71,24 @@ func main() {
|
||||
queueSize := 100
|
||||
eventBus := evbus.New(queueSize)
|
||||
|
||||
auditRecorder, err := audit.NewAuditRecorder(cfg, eventBus, database)
|
||||
internal.AssertNoError(err)
|
||||
auditRecorder.StartBackgroundJobs(ctx)
|
||||
|
||||
userManager, err := users.NewUserManager(cfg, eventBus, database, database)
|
||||
internal.AssertNoError(err)
|
||||
userManager.StartBackgroundJobs(ctx)
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(&cfg.Auth, cfg.Web.ExternalUrl, eventBus, userManager)
|
||||
internal.AssertNoError(err)
|
||||
|
||||
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database)
|
||||
internal.AssertNoError(err)
|
||||
wireGuardManager.StartBackgroundJobs(ctx)
|
||||
|
||||
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, eventBus, database, wireGuard, metricsServer)
|
||||
internal.AssertNoError(err)
|
||||
statisticsCollector.StartBackgroundJobs(ctx)
|
||||
|
||||
cfgFileManager, err := configfile.NewConfigFileManager(cfg, eventBus, database, database, cfgFileSystem)
|
||||
internal.AssertNoError(err)
|
||||
@ -88,18 +96,11 @@ func main() {
|
||||
mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database)
|
||||
internal.AssertNoError(err)
|
||||
|
||||
auditRecorder, err := audit.NewAuditRecorder(cfg, eventBus, database)
|
||||
internal.AssertNoError(err)
|
||||
auditRecorder.StartBackgroundJobs(ctx)
|
||||
|
||||
routeManager, err := route.NewRouteManager(cfg, eventBus, database)
|
||||
internal.AssertNoError(err)
|
||||
routeManager.StartBackgroundJobs(ctx)
|
||||
|
||||
backend, err := app.New(cfg, eventBus, authenticator, userManager, wireGuardManager,
|
||||
statisticsCollector, cfgFileManager, mailManager)
|
||||
internal.AssertNoError(err)
|
||||
err = backend.Startup(ctx)
|
||||
err = app.Initialize(cfg, wireGuardManager, userManager)
|
||||
internal.AssertNoError(err)
|
||||
|
||||
validatorManager := validator.New()
|
||||
@ -109,10 +110,14 @@ func main() {
|
||||
apiV0Session := handlersV0.NewSessionWrapper(cfg)
|
||||
apiV0Auth := handlersV0.NewAuthenticationHandler(authenticator, apiV0Session)
|
||||
|
||||
apiV0EndpointAuth := handlersV0.NewAuthEndpoint(cfg, apiV0Auth, apiV0Session, validatorManager, backend)
|
||||
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, backend)
|
||||
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, backend)
|
||||
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, backend)
|
||||
apiV0BackendUsers := backendV0.NewUserService(cfg, userManager, wireGuardManager)
|
||||
apiV0BackendInterfaces := backendV0.NewInterfaceService(cfg, wireGuardManager, cfgFileManager)
|
||||
apiV0BackendPeers := backendV0.NewPeerService(cfg, wireGuardManager, cfgFileManager, mailManager)
|
||||
|
||||
apiV0EndpointAuth := handlersV0.NewAuthEndpoint(cfg, apiV0Auth, apiV0Session, validatorManager, authenticator)
|
||||
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers)
|
||||
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces)
|
||||
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
|
||||
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth)
|
||||
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
|
||||
|
||||
|
91
internal/app/api/v0/backend/interface_service.go
Normal file
91
internal/app/api/v0/backend/interface_service.go
Normal file
@ -0,0 +1,91 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// region dependencies
|
||||
|
||||
type InterfaceServiceInterfaceManager interface {
|
||||
GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error)
|
||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||
CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error)
|
||||
UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error)
|
||||
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
|
||||
PrepareInterface(ctx context.Context) (*domain.Interface, error)
|
||||
ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error
|
||||
}
|
||||
|
||||
type InterfaceServiceConfigFileManager interface {
|
||||
PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error
|
||||
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
type InterfaceService struct {
|
||||
cfg *config.Config
|
||||
|
||||
interfaces InterfaceServiceInterfaceManager
|
||||
configFile InterfaceServiceConfigFileManager
|
||||
}
|
||||
|
||||
func NewInterfaceService(
|
||||
cfg *config.Config,
|
||||
interfaces InterfaceServiceInterfaceManager,
|
||||
configFile InterfaceServiceConfigFileManager,
|
||||
) *InterfaceService {
|
||||
return &InterfaceService{
|
||||
cfg: cfg,
|
||||
interfaces: interfaces,
|
||||
configFile: configFile,
|
||||
}
|
||||
}
|
||||
|
||||
func (i InterfaceService) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
|
||||
*domain.Interface,
|
||||
[]domain.Peer,
|
||||
error,
|
||||
) {
|
||||
return i.interfaces.GetInterfaceAndPeers(ctx, id)
|
||||
}
|
||||
|
||||
func (i InterfaceService) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
|
||||
return i.interfaces.PrepareInterface(ctx)
|
||||
}
|
||||
|
||||
func (i InterfaceService) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
|
||||
return i.interfaces.CreateInterface(ctx, in)
|
||||
}
|
||||
|
||||
func (i InterfaceService) UpdateInterface(ctx context.Context, in *domain.Interface) (
|
||||
*domain.Interface,
|
||||
[]domain.Peer,
|
||||
error,
|
||||
) {
|
||||
return i.interfaces.UpdateInterface(ctx, in)
|
||||
}
|
||||
|
||||
func (i InterfaceService) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
return i.interfaces.DeleteInterface(ctx, id)
|
||||
}
|
||||
|
||||
func (i InterfaceService) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
|
||||
return i.interfaces.GetAllInterfacesAndPeers(ctx)
|
||||
}
|
||||
|
||||
func (i InterfaceService) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
|
||||
return i.configFile.GetInterfaceConfig(ctx, id)
|
||||
}
|
||||
|
||||
func (i InterfaceService) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
return i.configFile.PersistInterfaceConfig(ctx, id)
|
||||
}
|
||||
|
||||
func (i InterfaceService) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
|
||||
return i.interfaces.ApplyPeerDefaults(ctx, in)
|
||||
}
|
112
internal/app/api/v0/backend/peer_service.go
Normal file
112
internal/app/api/v0/backend/peer_service.go
Normal file
@ -0,0 +1,112 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// region dependencies
|
||||
|
||||
type PeerServicePeerManager interface {
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||
PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error)
|
||||
CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error)
|
||||
UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error)
|
||||
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
CreateMultiplePeers(
|
||||
ctx context.Context,
|
||||
interfaceId domain.InterfaceIdentifier,
|
||||
r *domain.PeerCreationRequest,
|
||||
) ([]domain.Peer, error)
|
||||
GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error)
|
||||
}
|
||||
|
||||
type PeerServiceConfigFileManager interface {
|
||||
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||
}
|
||||
|
||||
type PeerServiceMailManager interface {
|
||||
SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
type PeerService struct {
|
||||
cfg *config.Config
|
||||
|
||||
peers PeerServicePeerManager
|
||||
configFile PeerServiceConfigFileManager
|
||||
mailer PeerServiceMailManager
|
||||
}
|
||||
|
||||
func NewPeerService(
|
||||
cfg *config.Config,
|
||||
peers PeerServicePeerManager,
|
||||
configFile PeerServiceConfigFileManager,
|
||||
mailer PeerServiceMailManager,
|
||||
) *PeerService {
|
||||
return &PeerService{
|
||||
cfg: cfg,
|
||||
peers: peers,
|
||||
configFile: configFile,
|
||||
mailer: mailer,
|
||||
}
|
||||
}
|
||||
|
||||
func (p PeerService) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
|
||||
*domain.Interface,
|
||||
[]domain.Peer,
|
||||
error,
|
||||
) {
|
||||
return p.peers.GetInterfaceAndPeers(ctx, id)
|
||||
}
|
||||
|
||||
func (p PeerService) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
|
||||
return p.peers.PreparePeer(ctx, id)
|
||||
}
|
||||
|
||||
func (p PeerService) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
|
||||
return p.peers.GetPeer(ctx, id)
|
||||
}
|
||||
|
||||
func (p PeerService) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
||||
return p.peers.CreatePeer(ctx, peer)
|
||||
}
|
||||
|
||||
func (p PeerService) CreateMultiplePeers(
|
||||
ctx context.Context,
|
||||
interfaceId domain.InterfaceIdentifier,
|
||||
r *domain.PeerCreationRequest,
|
||||
) ([]domain.Peer, error) {
|
||||
return p.peers.CreateMultiplePeers(ctx, interfaceId, r)
|
||||
}
|
||||
|
||||
func (p PeerService) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
||||
return p.peers.UpdatePeer(ctx, peer)
|
||||
}
|
||||
|
||||
func (p PeerService) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
|
||||
return p.peers.DeletePeer(ctx, id)
|
||||
}
|
||||
|
||||
func (p PeerService) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
|
||||
return p.configFile.GetPeerConfig(ctx, id)
|
||||
}
|
||||
|
||||
func (p PeerService) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
|
||||
return p.configFile.GetPeerConfigQrCode(ctx, id)
|
||||
}
|
||||
|
||||
func (p PeerService) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error {
|
||||
return p.mailer.SendPeerEmail(ctx, linkOnly, peers...)
|
||||
}
|
||||
|
||||
func (p PeerService) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) {
|
||||
return p.peers.GetPeerStats(ctx, id)
|
||||
}
|
83
internal/app/api/v0/backend/user_service.go
Normal file
83
internal/app/api/v0/backend/user_service.go
Normal file
@ -0,0 +1,83 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// region dependencies
|
||||
|
||||
type UserServiceUserManager interface {
|
||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
GetAllUsers(ctx context.Context) ([]domain.User, error)
|
||||
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
||||
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
||||
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
|
||||
ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
}
|
||||
|
||||
type UserServiceWireGuardManager interface {
|
||||
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||
GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier) ([]domain.Interface, error)
|
||||
GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error)
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
type UserService struct {
|
||||
cfg *config.Config
|
||||
|
||||
users UserServiceUserManager
|
||||
wg UserServiceWireGuardManager
|
||||
}
|
||||
|
||||
func NewUserService(cfg *config.Config, users UserServiceUserManager, wg UserServiceWireGuardManager) *UserService {
|
||||
return &UserService{
|
||||
cfg: cfg,
|
||||
users: users,
|
||||
wg: wg,
|
||||
}
|
||||
}
|
||||
|
||||
func (u UserService) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||
return u.users.GetUser(ctx, id)
|
||||
}
|
||||
|
||||
func (u UserService) GetAllUsers(ctx context.Context) ([]domain.User, error) {
|
||||
return u.users.GetAllUsers(ctx)
|
||||
}
|
||||
|
||||
func (u UserService) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
||||
return u.users.UpdateUser(ctx, user)
|
||||
}
|
||||
|
||||
func (u UserService) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
||||
return u.users.CreateUser(ctx, user)
|
||||
}
|
||||
|
||||
func (u UserService) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
|
||||
return u.users.DeleteUser(ctx, id)
|
||||
}
|
||||
|
||||
func (u UserService) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||
return u.users.ActivateApi(ctx, id)
|
||||
}
|
||||
|
||||
func (u UserService) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||
return u.users.DeactivateApi(ctx, id)
|
||||
}
|
||||
|
||||
func (u UserService) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
|
||||
return u.wg.GetUserPeers(ctx, id)
|
||||
}
|
||||
|
||||
func (u UserService) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
|
||||
return u.wg.GetUserPeerStats(ctx, id)
|
||||
}
|
||||
|
||||
func (u UserService) GetUserInterfaces(ctx context.Context, id domain.UserIdentifier) ([]domain.Interface, error) {
|
||||
return u.wg.GetUserInterfaces(ctx, id)
|
||||
}
|
@ -7,46 +7,43 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
evbus "github.com/vardius/message-bus"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
Config *config.Config
|
||||
bus evbus.MessageBus
|
||||
// region dependencies
|
||||
|
||||
Authenticator
|
||||
UserManager
|
||||
WireGuardManager
|
||||
StatisticsCollector
|
||||
ConfigFileManager
|
||||
MailManager
|
||||
ApiV1Manager
|
||||
type WireGuardManager interface {
|
||||
ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error)
|
||||
RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error
|
||||
}
|
||||
|
||||
func New(
|
||||
type UserManager interface {
|
||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
// App is the main application struct.
|
||||
type App struct {
|
||||
cfg *config.Config
|
||||
|
||||
wg WireGuardManager
|
||||
users UserManager
|
||||
}
|
||||
|
||||
// Initialize creates a new App instance and initializes it.
|
||||
func Initialize(
|
||||
cfg *config.Config,
|
||||
bus evbus.MessageBus,
|
||||
authenticator Authenticator,
|
||||
wg WireGuardManager,
|
||||
users UserManager,
|
||||
wireGuard WireGuardManager,
|
||||
stats StatisticsCollector,
|
||||
cfgFiles ConfigFileManager,
|
||||
mailer MailManager,
|
||||
) (*App, error) {
|
||||
|
||||
) error {
|
||||
a := &App{
|
||||
Config: cfg,
|
||||
bus: bus,
|
||||
cfg: cfg,
|
||||
|
||||
Authenticator: authenticator,
|
||||
UserManager: users,
|
||||
WireGuardManager: wireGuard,
|
||||
StatisticsCollector: stats,
|
||||
ConfigFileManager: cfgFiles,
|
||||
MailManager: mailer,
|
||||
wg: wg,
|
||||
users: users,
|
||||
}
|
||||
|
||||
startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
@ -56,36 +53,27 @@ func New(
|
||||
startupContext = domain.SetUserInfo(startupContext, domain.SystemAdminContextUserInfo())
|
||||
|
||||
if err := a.createDefaultUser(startupContext); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default user: %w", err)
|
||||
return fmt.Errorf("failed to create default user: %w", err)
|
||||
}
|
||||
|
||||
if err := a.importNewInterfaces(startupContext); err != nil {
|
||||
return nil, fmt.Errorf("failed to import new interfaces: %w", err)
|
||||
return fmt.Errorf("failed to import new interfaces: %w", err)
|
||||
}
|
||||
|
||||
if err := a.restoreInterfaceState(startupContext); err != nil {
|
||||
return nil, fmt.Errorf("failed to restore interface state: %w", err)
|
||||
return fmt.Errorf("failed to restore interface state: %w", err)
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *App) Startup(ctx context.Context) error {
|
||||
|
||||
a.UserManager.StartBackgroundJobs(ctx)
|
||||
a.StatisticsCollector.StartBackgroundJobs(ctx)
|
||||
a.WireGuardManager.StartBackgroundJobs(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) importNewInterfaces(ctx context.Context) error {
|
||||
if !a.Config.Core.ImportExisting {
|
||||
if !a.cfg.Core.ImportExisting {
|
||||
slog.Debug("skipping interface import - feature disabled")
|
||||
return nil // feature disabled
|
||||
}
|
||||
|
||||
importedCount, err := a.ImportNewInterfaces(ctx)
|
||||
importedCount, err := a.wg.ImportNewInterfaces(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -97,12 +85,12 @@ func (a *App) importNewInterfaces(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (a *App) restoreInterfaceState(ctx context.Context) error {
|
||||
if !a.Config.Core.RestoreState {
|
||||
if !a.cfg.Core.RestoreState {
|
||||
slog.Debug("skipping interface state restore - feature disabled")
|
||||
return nil // feature disabled
|
||||
}
|
||||
|
||||
err := a.RestoreInterfaceState(ctx, true)
|
||||
err := a.wg.RestoreInterfaceState(ctx, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -112,13 +100,13 @@ func (a *App) restoreInterfaceState(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (a *App) createDefaultUser(ctx context.Context) error {
|
||||
adminUserId := domain.UserIdentifier(a.Config.Core.AdminUser)
|
||||
adminUserId := domain.UserIdentifier(a.cfg.Core.AdminUser)
|
||||
if adminUserId == "" {
|
||||
slog.Debug("skipping default user creation - admin user is blank")
|
||||
return nil // empty admin user - do not create
|
||||
}
|
||||
|
||||
_, err := a.GetUser(ctx, adminUserId)
|
||||
_, err := a.users.GetUser(ctx, adminUserId)
|
||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
@ -145,22 +133,22 @@ func (a *App) createDefaultUser(ctx context.Context) error {
|
||||
Phone: "",
|
||||
Department: "",
|
||||
Notes: "default administrator user",
|
||||
Password: domain.PrivateString(a.Config.Core.AdminPassword),
|
||||
Password: domain.PrivateString(a.cfg.Core.AdminPassword),
|
||||
Disabled: nil,
|
||||
DisabledReason: "",
|
||||
Locked: nil,
|
||||
LockedReason: "",
|
||||
LinkedPeerCount: 0,
|
||||
}
|
||||
if a.Config.Core.AdminApiToken != "" {
|
||||
if len(a.Config.Core.AdminApiToken) < 18 {
|
||||
if a.cfg.Core.AdminApiToken != "" {
|
||||
if len(a.cfg.Core.AdminApiToken) < 18 {
|
||||
slog.Warn("admin API token is too short, should be at least 18 characters long")
|
||||
}
|
||||
defaultAdmin.ApiToken = a.Config.Core.AdminApiToken
|
||||
defaultAdmin.ApiToken = a.cfg.Core.AdminApiToken
|
||||
defaultAdmin.ApiTokenCreated = &now
|
||||
}
|
||||
|
||||
admin, err := a.CreateUser(ctx, defaultAdmin)
|
||||
admin, err := a.users.CreateUser(ctx, defaultAdmin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -6,21 +6,35 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
evbus "github.com/vardius/message-bus"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// region dependencies
|
||||
|
||||
type DatabaseRepo interface {
|
||||
// SaveAuditEntry saves an audit entry to the database
|
||||
SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error
|
||||
}
|
||||
|
||||
type EventBus interface {
|
||||
// Subscribe subscribes to a topic
|
||||
Subscribe(topic string, fn interface{}) error
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
// Recorder is responsible for recording audit events to the database.
|
||||
type Recorder struct {
|
||||
cfg *config.Config
|
||||
bus evbus.MessageBus
|
||||
bus EventBus
|
||||
|
||||
db DatabaseRepo
|
||||
}
|
||||
|
||||
func NewAuditRecorder(cfg *config.Config, bus evbus.MessageBus, db DatabaseRepo) (*Recorder, error) {
|
||||
// NewAuditRecorder creates a new audit recorder instance.
|
||||
func NewAuditRecorder(cfg *config.Config, bus EventBus, db DatabaseRepo) (*Recorder, error) {
|
||||
r := &Recorder{
|
||||
cfg: cfg,
|
||||
bus: bus,
|
||||
@ -36,6 +50,8 @@ func NewAuditRecorder(cfg *config.Config, bus evbus.MessageBus, db DatabaseRepo)
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// StartBackgroundJobs starts background jobs for the audit recorder.
|
||||
// This method is non-blocking and returns immediately.
|
||||
func (r *Recorder) StartBackgroundJobs(ctx context.Context) {
|
||||
if !r.cfg.Statistics.CollectAuditData {
|
||||
return // noting to do
|
||||
|
@ -1,11 +0,0 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type DatabaseRepo interface {
|
||||
SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error
|
||||
}
|
@ -14,25 +14,78 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
evbus "github.com/vardius/message-bus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// region dependencies
|
||||
|
||||
type UserManager interface {
|
||||
// GetUser returns a user by its identifier.
|
||||
GetUser(context.Context, domain.UserIdentifier) (*domain.User, error)
|
||||
// RegisterUser creates a new user in the database.
|
||||
RegisterUser(ctx context.Context, user *domain.User) error
|
||||
// UpdateUser updates an existing user in the database.
|
||||
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
||||
}
|
||||
|
||||
type EventBus interface {
|
||||
// Publish sends a message to the message bus.
|
||||
Publish(topic string, args ...any)
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
type AuthenticatorType string
|
||||
|
||||
const (
|
||||
AuthenticatorTypeOAuth AuthenticatorType = "oauth"
|
||||
AuthenticatorTypeOidc AuthenticatorType = "oidc"
|
||||
)
|
||||
|
||||
// AuthenticatorOauth is the interface for all OAuth authenticators.
|
||||
type AuthenticatorOauth interface {
|
||||
// GetName returns the name of the authenticator.
|
||||
GetName() string
|
||||
// GetType returns the type of the authenticator. It can be either AuthenticatorTypeOAuth or AuthenticatorTypeOidc.
|
||||
GetType() AuthenticatorType
|
||||
// AuthCodeURL returns the URL for the authentication flow.
|
||||
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
|
||||
// Exchange exchanges the OAuth code for an access token.
|
||||
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
|
||||
// GetUserInfo fetches the user information from the OAuth or OIDC provider.
|
||||
GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]any, error)
|
||||
// ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct.
|
||||
ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error)
|
||||
// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator.
|
||||
RegistrationEnabled() bool
|
||||
}
|
||||
|
||||
// AuthenticatorLdap is the interface for all LDAP authenticators.
|
||||
type AuthenticatorLdap interface {
|
||||
// GetName returns the name of the authenticator.
|
||||
GetName() string
|
||||
// PlaintextAuthentication performs a plaintext authentication against the LDAP server.
|
||||
PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error
|
||||
// GetUserInfo fetches the user information from the LDAP server.
|
||||
GetUserInfo(ctx context.Context, username domain.UserIdentifier) (map[string]any, error)
|
||||
// ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct.
|
||||
ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error)
|
||||
// RegistrationEnabled returns whether registration is enabled for the LDAP authenticator.
|
||||
RegistrationEnabled() bool
|
||||
}
|
||||
|
||||
// Authenticator is the main entry point for all authentication related tasks.
|
||||
// This includes password authentication and external authentication providers (OIDC, OAuth, LDAP).
|
||||
type Authenticator struct {
|
||||
cfg *config.Auth
|
||||
bus evbus.MessageBus
|
||||
bus EventBus
|
||||
|
||||
oauthAuthenticators map[string]domain.OauthAuthenticator
|
||||
ldapAuthenticators map[string]domain.LdapAuthenticator
|
||||
oauthAuthenticators map[string]AuthenticatorOauth
|
||||
ldapAuthenticators map[string]AuthenticatorLdap
|
||||
|
||||
// URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix
|
||||
callbackUrlPrefix string
|
||||
@ -40,7 +93,8 @@ type Authenticator struct {
|
||||
users UserManager
|
||||
}
|
||||
|
||||
func NewAuthenticator(cfg *config.Auth, extUrl string, bus evbus.MessageBus, users UserManager) (
|
||||
// NewAuthenticator creates a new Authenticator instance.
|
||||
func NewAuthenticator(cfg *config.Auth, extUrl string, bus EventBus, users UserManager) (
|
||||
*Authenticator,
|
||||
error,
|
||||
) {
|
||||
@ -68,8 +122,8 @@ func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to parse external url: %w", err)
|
||||
}
|
||||
|
||||
a.oauthAuthenticators = make(map[string]domain.OauthAuthenticator, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth))
|
||||
a.ldapAuthenticators = make(map[string]domain.LdapAuthenticator, len(a.cfg.Ldap))
|
||||
a.oauthAuthenticators = make(map[string]AuthenticatorOauth, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth))
|
||||
a.ldapAuthenticators = make(map[string]AuthenticatorLdap, len(a.cfg.Ldap))
|
||||
|
||||
for i := range a.cfg.OpenIDConnect { // OIDC
|
||||
providerCfg := &a.cfg.OpenIDConnect[i]
|
||||
@ -123,6 +177,7 @@ func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExternalLoginProviders returns a list of all available external login providers.
|
||||
func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo {
|
||||
authProviders := make([]domain.LoginProviderInfo, 0, len(a.cfg.OAuth)+len(a.cfg.OpenIDConnect))
|
||||
|
||||
@ -157,6 +212,7 @@ func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.Lo
|
||||
return authProviders
|
||||
}
|
||||
|
||||
// IsUserValid checks if a user is valid and not locked or disabled.
|
||||
func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool {
|
||||
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context
|
||||
user, err := a.users.GetUser(ctx, id)
|
||||
@ -177,6 +233,8 @@ func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifie
|
||||
|
||||
// region password authentication
|
||||
|
||||
// PlainLogin performs a password authentication for a user. The username and password are trimmed before usage.
|
||||
// If the login is successful, the user is returned, otherwise an error.
|
||||
func (a *Authenticator) PlainLogin(ctx context.Context, username, password string) (*domain.User, error) {
|
||||
// Validate form input
|
||||
username = strings.TrimSpace(username)
|
||||
@ -204,7 +262,7 @@ func (a *Authenticator) passwordAuthentication(
|
||||
domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
|
||||
|
||||
var ldapUserInfo *domain.AuthenticatorUserInfo
|
||||
var ldapProvider domain.LdapAuthenticator
|
||||
var ldapProvider AuthenticatorLdap
|
||||
|
||||
var userInDatabase = false
|
||||
var userSource domain.UserSource
|
||||
@ -280,6 +338,7 @@ func (a *Authenticator) passwordAuthentication(
|
||||
|
||||
// region oauth authentication
|
||||
|
||||
// OauthLoginStep1 starts the oauth authentication flow by returning the authentication URL, state and nonce.
|
||||
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
|
||||
authCodeUrl, state, nonce string,
|
||||
err error,
|
||||
@ -296,9 +355,9 @@ func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
|
||||
}
|
||||
|
||||
switch oauthProvider.GetType() {
|
||||
case domain.AuthenticatorTypeOAuth:
|
||||
case AuthenticatorTypeOAuth:
|
||||
authCodeUrl = oauthProvider.AuthCodeURL(state)
|
||||
case domain.AuthenticatorTypeOidc:
|
||||
case AuthenticatorTypeOidc:
|
||||
nonce, err = a.randString(16)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
@ -318,6 +377,8 @@ func (a *Authenticator) randString(nByte int) (string, error) {
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// OauthLoginStep2 finishes the oauth authentication flow by exchanging the code for an access token and
|
||||
// fetching the user information.
|
||||
func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error) {
|
||||
oauthProvider, ok := a.oauthAuthenticators[providerId]
|
||||
if !ok {
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// LdapAuthenticator is an authenticator that uses LDAP for authentication.
|
||||
type LdapAuthenticator struct {
|
||||
cfg *config.LdapProvider
|
||||
}
|
||||
@ -33,14 +34,17 @@ func newLdapAuthenticator(_ context.Context, cfg *config.LdapProvider) (*LdapAut
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// GetName returns the name of the LDAP authenticator.
|
||||
func (l LdapAuthenticator) GetName() string {
|
||||
return l.cfg.ProviderName
|
||||
}
|
||||
|
||||
// RegistrationEnabled returns whether registration is enabled for the LDAP authenticator.
|
||||
func (l LdapAuthenticator) RegistrationEnabled() bool {
|
||||
return l.cfg.RegistrationEnabled
|
||||
}
|
||||
|
||||
// PlaintextAuthentication performs a plaintext authentication against the LDAP server.
|
||||
func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error {
|
||||
conn, err := internal.LdapConnect(l.cfg)
|
||||
if err != nil {
|
||||
@ -81,6 +85,9 @@ func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier,
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserInfo retrieves user information from the LDAP server.
|
||||
// If the user is not found, domain.ErrNotFound is returned.
|
||||
// If multiple users are found, domain.ErrNotUnique is returned.
|
||||
func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIdentifier) (
|
||||
map[string]any,
|
||||
error,
|
||||
@ -126,6 +133,7 @@ func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIden
|
||||
return users[0], nil
|
||||
}
|
||||
|
||||
// ParseUserInfo parses the user information from the LDAP server into a domain.AuthenticatorUserInfo struct.
|
||||
func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
||||
isAdmin, err := internal.LdapIsMemberOf(raw[l.cfg.FieldMap.GroupMembership].([][]byte), l.cfg.ParsedAdminGroupDN)
|
||||
if err != nil {
|
||||
|
@ -16,6 +16,8 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// PlainOauthAuthenticator is an authenticator that uses OAuth for authentication.
|
||||
// User information is retrieved from the specified user info endpoint.
|
||||
type PlainOauthAuthenticator struct {
|
||||
name string
|
||||
cfg *oauth2.Config
|
||||
@ -58,22 +60,27 @@ func newPlainOauthAuthenticator(
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// GetName returns the name of the OAuth authenticator.
|
||||
func (p PlainOauthAuthenticator) GetName() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator.
|
||||
func (p PlainOauthAuthenticator) RegistrationEnabled() bool {
|
||||
return p.registrationEnabled
|
||||
}
|
||||
|
||||
func (p PlainOauthAuthenticator) GetType() domain.AuthenticatorType {
|
||||
return domain.AuthenticatorTypeOAuth
|
||||
// GetType returns the type of the authenticator.
|
||||
func (p PlainOauthAuthenticator) GetType() AuthenticatorType {
|
||||
return AuthenticatorTypeOAuth
|
||||
}
|
||||
|
||||
// AuthCodeURL returns the URL to redirect the user to for authentication.
|
||||
func (p PlainOauthAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
||||
return p.cfg.AuthCodeURL(state, opts...)
|
||||
}
|
||||
|
||||
// Exchange exchanges the OAuth code for a token.
|
||||
func (p PlainOauthAuthenticator) Exchange(
|
||||
ctx context.Context,
|
||||
code string,
|
||||
@ -82,6 +89,7 @@ func (p PlainOauthAuthenticator) Exchange(
|
||||
return p.cfg.Exchange(ctx, code, opts...)
|
||||
}
|
||||
|
||||
// GetUserInfo retrieves the user information from the user info endpoint.
|
||||
func (p PlainOauthAuthenticator) GetUserInfo(
|
||||
ctx context.Context,
|
||||
token *oauth2.Token,
|
||||
@ -119,6 +127,7 @@ func (p PlainOauthAuthenticator) GetUserInfo(
|
||||
return userFields, nil
|
||||
}
|
||||
|
||||
// ParseUserInfo parses the user information from the raw data.
|
||||
func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
||||
return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw)
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// OidcAuthenticator is an authenticator for OpenID Connect providers.
|
||||
type OidcAuthenticator struct {
|
||||
name string
|
||||
provider *oidc.Provider
|
||||
@ -60,22 +61,27 @@ func newOidcAuthenticator(
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// GetName returns the name of the authenticator.
|
||||
func (o OidcAuthenticator) GetName() string {
|
||||
return o.name
|
||||
}
|
||||
|
||||
// RegistrationEnabled returns whether registration is enabled for this authenticator.
|
||||
func (o OidcAuthenticator) RegistrationEnabled() bool {
|
||||
return o.registrationEnabled
|
||||
}
|
||||
|
||||
func (o OidcAuthenticator) GetType() domain.AuthenticatorType {
|
||||
return domain.AuthenticatorTypeOidc
|
||||
// GetType returns the type of the authenticator.
|
||||
func (o OidcAuthenticator) GetType() AuthenticatorType {
|
||||
return AuthenticatorTypeOidc
|
||||
}
|
||||
|
||||
// AuthCodeURL returns the URL for the OAuth2 flow.
|
||||
func (o OidcAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
||||
return o.cfg.AuthCodeURL(state, opts...)
|
||||
}
|
||||
|
||||
// Exchange exchanges the code for a token.
|
||||
func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (
|
||||
*oauth2.Token,
|
||||
error,
|
||||
@ -83,6 +89,7 @@ func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oa
|
||||
return o.cfg.Exchange(ctx, code, opts...)
|
||||
}
|
||||
|
||||
// GetUserInfo retrieves the user info from the token.
|
||||
func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (
|
||||
map[string]any,
|
||||
error,
|
||||
@ -114,6 +121,7 @@ func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token,
|
||||
return tokenFields, nil
|
||||
}
|
||||
|
||||
// ParseUserInfo parses the user info.
|
||||
func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
||||
return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw)
|
||||
}
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
evbus "github.com/vardius/message-bus"
|
||||
"github.com/yeqown/go-qrcode/v2"
|
||||
"github.com/yeqown/go-qrcode/writer/compressed"
|
||||
|
||||
@ -19,19 +18,56 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
bus evbus.MessageBus
|
||||
tplHandler *TemplateHandler
|
||||
// region dependencies
|
||||
|
||||
fsRepo FileSystemRepo
|
||||
users UserDatabaseRepo
|
||||
wg WireguardDatabaseRepo
|
||||
type UserDatabaseRepo interface {
|
||||
// GetUser returns the user with the given identifier from the SQL database.
|
||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
}
|
||||
|
||||
type WireguardDatabaseRepo interface {
|
||||
// GetInterfaceAndPeers returns the interface and all peers associated with it.
|
||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||
// GetPeer returns the peer with the given identifier.
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
// GetInterface returns the interface with the given identifier.
|
||||
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
||||
}
|
||||
|
||||
type FileSystemRepo interface {
|
||||
// WriteFile writes the contents to the file at the given path.
|
||||
WriteFile(path string, contents io.Reader) error
|
||||
}
|
||||
|
||||
type TemplateRenderer interface {
|
||||
// GetInterfaceConfig returns the configuration file for the given interface.
|
||||
GetInterfaceConfig(iface *domain.Interface, peers []domain.Peer) (io.Reader, error)
|
||||
// GetPeerConfig returns the configuration file for the given peer.
|
||||
GetPeerConfig(peer *domain.Peer) (io.Reader, error)
|
||||
}
|
||||
|
||||
type EventBus interface {
|
||||
// Subscribe subscribes to the given topic.
|
||||
Subscribe(topic string, fn any) error
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
// Manager is responsible for managing the configuration files of the WireGuard interfaces and peers.
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
bus EventBus
|
||||
|
||||
tplHandler TemplateRenderer
|
||||
fsRepo FileSystemRepo
|
||||
users UserDatabaseRepo
|
||||
wg WireguardDatabaseRepo
|
||||
}
|
||||
|
||||
// NewConfigFileManager creates a new Manager instance.
|
||||
func NewConfigFileManager(
|
||||
cfg *config.Config,
|
||||
bus evbus.MessageBus,
|
||||
bus EventBus,
|
||||
users UserDatabaseRepo,
|
||||
wg WireguardDatabaseRepo,
|
||||
fsRepo FileSystemRepo,
|
||||
@ -115,6 +151,8 @@ func (m Manager) handlePeerInterfaceUpdatedEvent(id domain.InterfaceIdentifier)
|
||||
}
|
||||
}
|
||||
|
||||
// GetInterfaceConfig returns the configuration file for the given interface.
|
||||
// The file is structured in wg-quick format.
|
||||
func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
@ -128,6 +166,8 @@ func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIden
|
||||
return m.tplHandler.GetInterfaceConfig(iface, peers)
|
||||
}
|
||||
|
||||
// GetPeerConfig returns the configuration file for the given peer.
|
||||
// The file is structured in wg-quick format.
|
||||
func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
|
||||
peer, err := m.wg.GetPeer(ctx, id)
|
||||
if err != nil {
|
||||
@ -141,6 +181,7 @@ func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (i
|
||||
return m.tplHandler.GetPeerConfig(peer)
|
||||
}
|
||||
|
||||
// GetPeerConfigQrCode returns a QR code image containing the configuration for the given peer.
|
||||
func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
|
||||
peer, err := m.wg.GetPeer(ctx, id)
|
||||
if err != nil {
|
||||
@ -191,6 +232,7 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// PersistInterfaceConfig writes the configuration file for the given interface to the file system.
|
||||
func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
iface, peers, err := m.wg.GetInterfaceAndPeers(ctx, id)
|
||||
if err != nil {
|
||||
@ -213,4 +255,5 @@ type nopCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
// Close is a no-op for the nopCloser.
|
||||
func (nopCloser) Close() error { return nil }
|
||||
|
@ -1,22 +0,0 @@
|
||||
package configfile
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type UserDatabaseRepo interface {
|
||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
}
|
||||
|
||||
type WireguardDatabaseRepo interface {
|
||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
||||
}
|
||||
|
||||
type FileSystemRepo interface {
|
||||
WriteFile(path string, contents io.Reader) error
|
||||
}
|
@ -13,6 +13,8 @@ import (
|
||||
//go:embed tpl_files/*
|
||||
var TemplateFiles embed.FS
|
||||
|
||||
// TemplateHandler is responsible for rendering the WireGuard configuration files
|
||||
// based on the provided templates.
|
||||
type TemplateHandler struct {
|
||||
templates *template.Template
|
||||
}
|
||||
@ -34,6 +36,7 @@ func newTemplateHandler() (*TemplateHandler, error) {
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
// GetInterfaceConfig returns the rendered configuration file for a WireGuard interface.
|
||||
func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domain.Peer) (io.Reader, error) {
|
||||
var tplBuff bytes.Buffer
|
||||
|
||||
@ -51,6 +54,7 @@ func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domai
|
||||
return &tplBuff, nil
|
||||
}
|
||||
|
||||
// GetPeerConfig returns the rendered configuration file for a WireGuard peer.
|
||||
func (c TemplateHandler) GetPeerConfig(peer *domain.Peer) (io.Reader, error) {
|
||||
var tplBuff bytes.Buffer
|
||||
|
||||
|
@ -10,16 +10,60 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
tplHandler *TemplateHandler
|
||||
// region dependencies
|
||||
|
||||
type Mailer interface {
|
||||
// Send sends an email with the given subject and body to the given recipients.
|
||||
Send(ctx context.Context, subject, body string, to []string, options *domain.MailOptions) error
|
||||
}
|
||||
|
||||
type ConfigFileManager interface {
|
||||
// GetInterfaceConfig returns the configuration for the given interface.
|
||||
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
|
||||
// GetPeerConfig returns the configuration for the given peer.
|
||||
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||
// GetPeerConfigQrCode returns the QR code for the given peer.
|
||||
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||
}
|
||||
|
||||
type UserDatabaseRepo interface {
|
||||
// GetUser returns the user with the given identifier.
|
||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
}
|
||||
|
||||
type WireguardDatabaseRepo interface {
|
||||
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
|
||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||
// GetPeer returns the peer with the given identifier.
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
// GetInterface returns the interface with the given identifier.
|
||||
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
||||
}
|
||||
|
||||
type TemplateRenderer interface {
|
||||
// GetConfigMail returns the text and html template for the mail with a link.
|
||||
GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error)
|
||||
// GetConfigMailWithAttachment returns the text and html template for the mail with an attachment.
|
||||
GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) (
|
||||
io.Reader,
|
||||
io.Reader,
|
||||
error,
|
||||
)
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
|
||||
tplHandler TemplateRenderer
|
||||
mailer Mailer
|
||||
configFiles ConfigFileManager
|
||||
users UserDatabaseRepo
|
||||
wg WireguardDatabaseRepo
|
||||
}
|
||||
|
||||
// NewMailManager creates a new mail manager.
|
||||
func NewMailManager(
|
||||
cfg *config.Config,
|
||||
mailer Mailer,
|
||||
@ -44,6 +88,7 @@ func NewMailManager(
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// SendPeerEmail sends an email to the user linked to the given peers.
|
||||
func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error {
|
||||
for _, peerId := range peers {
|
||||
peer, err := m.wg.GetPeer(ctx, peerId)
|
||||
|
@ -1,28 +0,0 @@
|
||||
package mail
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type Mailer interface {
|
||||
Send(ctx context.Context, subject, body string, to []string, options *domain.MailOptions) error
|
||||
}
|
||||
|
||||
type ConfigFileManager interface {
|
||||
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
|
||||
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||
}
|
||||
|
||||
type UserDatabaseRepo interface {
|
||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
}
|
||||
|
||||
type WireguardDatabaseRepo interface {
|
||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
||||
}
|
@ -14,6 +14,7 @@ import (
|
||||
//go:embed tpl_files/*
|
||||
var TemplateFiles embed.FS
|
||||
|
||||
// TemplateHandler is a struct that holds the html and text templates.
|
||||
type TemplateHandler struct {
|
||||
portalUrl string
|
||||
htmlTemplates *htmlTemplate.Template
|
||||
@ -40,6 +41,7 @@ func newTemplateHandler(portalUrl string) (*TemplateHandler, error) {
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
// GetConfigMail returns the text and html template for the mail with a link.
|
||||
func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error) {
|
||||
var tplBuff bytes.Buffer
|
||||
var htmlTplBuff bytes.Buffer
|
||||
@ -65,6 +67,7 @@ func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reade
|
||||
return &tplBuff, &htmlTplBuff, nil
|
||||
}
|
||||
|
||||
// GetConfigMailWithAttachment returns the text and html template for the mail with an attachment.
|
||||
func (c TemplateHandler) GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) (
|
||||
io.Reader,
|
||||
io.Reader,
|
||||
|
@ -1,77 +0,0 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type Authenticator interface {
|
||||
GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo
|
||||
IsUserValid(ctx context.Context, id domain.UserIdentifier) bool
|
||||
PlainLogin(ctx context.Context, username, password string) (*domain.User, error)
|
||||
OauthLoginStep1(_ context.Context, providerId string) (authCodeUrl, state, nonce string, err error)
|
||||
OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error)
|
||||
}
|
||||
|
||||
type UserManager interface {
|
||||
RegisterUser(ctx context.Context, user *domain.User) error
|
||||
NewUser(ctx context.Context, user *domain.User) error
|
||||
StartBackgroundJobs(ctx context.Context)
|
||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
GetAllUsers(ctx context.Context) ([]domain.User, error)
|
||||
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
||||
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
||||
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
|
||||
ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
}
|
||||
|
||||
type WireGuardManager interface {
|
||||
StartBackgroundJobs(ctx context.Context)
|
||||
GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error)
|
||||
ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error)
|
||||
RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error
|
||||
CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error
|
||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||
GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error)
|
||||
GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error)
|
||||
GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error)
|
||||
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||
GetUserInterfaces(ctx context.Context, id domain.UserIdentifier) ([]domain.Interface, error)
|
||||
PrepareInterface(ctx context.Context) (*domain.Interface, error)
|
||||
CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error)
|
||||
UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error)
|
||||
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
|
||||
PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error)
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
CreatePeer(ctx context.Context, p *domain.Peer) (*domain.Peer, error)
|
||||
CreateMultiplePeers(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
r *domain.PeerCreationRequest,
|
||||
) ([]domain.Peer, error)
|
||||
UpdatePeer(ctx context.Context, p *domain.Peer) (*domain.Peer, error)
|
||||
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error
|
||||
}
|
||||
|
||||
type StatisticsCollector interface {
|
||||
StartBackgroundJobs(ctx context.Context)
|
||||
}
|
||||
|
||||
type ConfigFileManager interface {
|
||||
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
|
||||
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||
PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error
|
||||
}
|
||||
|
||||
type MailManager interface {
|
||||
SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type ApiV1Manager interface {
|
||||
ApiV1GetUsers(ctx context.Context) ([]domain.User, error)
|
||||
}
|
@ -1,12 +0,0 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type InterfaceAndPeerDatabaseRepo interface {
|
||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||
}
|
@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
evbus "github.com/vardius/message-bus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
@ -17,6 +16,22 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||
)
|
||||
|
||||
// region dependencies
|
||||
|
||||
type InterfaceAndPeerDatabaseRepo interface {
|
||||
// GetAllInterfaces returns all interfaces
|
||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||
// GetInterfacePeers returns all peers for a given interface
|
||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||
}
|
||||
|
||||
type EventBus interface {
|
||||
// Subscribe subscribes to a topic
|
||||
Subscribe(topic string, fn interface{}) error
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
type routeRuleInfo struct {
|
||||
ifaceId domain.InterfaceIdentifier
|
||||
fwMark uint32
|
||||
@ -29,14 +44,15 @@ type routeRuleInfo struct {
|
||||
// for default routes.
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
bus evbus.MessageBus
|
||||
|
||||
wg lowlevel.WireGuardClient
|
||||
nl lowlevel.NetlinkClient
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
bus EventBus
|
||||
wg lowlevel.WireGuardClient
|
||||
nl lowlevel.NetlinkClient
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
}
|
||||
|
||||
func NewRouteManager(cfg *config.Config, bus evbus.MessageBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
|
||||
// NewRouteManager creates a new route manager instance.
|
||||
func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
panic("failed to init wgctrl: " + err.Error())
|
||||
@ -63,7 +79,10 @@ func (m Manager) connectToMessageBus() {
|
||||
_ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteRemoveEvent)
|
||||
}
|
||||
|
||||
// StartBackgroundJobs starts background jobs for the route manager.
|
||||
// This method is non-blocking and returns immediately.
|
||||
func (m Manager) StartBackgroundJobs(_ context.Context) {
|
||||
// this is a no-op for now
|
||||
}
|
||||
|
||||
func (m Manager) handleRouteUpdateEvent(srcDescription string) {
|
||||
|
@ -1,20 +0,0 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type UserDatabaseRepo interface {
|
||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
GetUserByEmail(ctx context.Context, email string) (*domain.User, error)
|
||||
GetAllUsers(ctx context.Context) ([]domain.User, error)
|
||||
FindUsers(ctx context.Context, search string) ([]domain.User, error)
|
||||
SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error
|
||||
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
|
||||
}
|
||||
|
||||
type PeerDatabaseRepo interface {
|
||||
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||
}
|
@ -11,7 +11,6 @@ import (
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/google/uuid"
|
||||
evbus "github.com/vardius/message-bus"
|
||||
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
@ -19,15 +18,46 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// region dependencies
|
||||
|
||||
type UserDatabaseRepo interface {
|
||||
// GetUser returns the user with the given identifier.
|
||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||
// GetUserByEmail returns the user with the given email address.
|
||||
GetUserByEmail(ctx context.Context, email string) (*domain.User, error)
|
||||
// GetAllUsers returns all users.
|
||||
GetAllUsers(ctx context.Context) ([]domain.User, error)
|
||||
// FindUsers returns all users matching the search string.
|
||||
FindUsers(ctx context.Context, search string) ([]domain.User, error)
|
||||
// SaveUser saves the user with the given identifier.
|
||||
SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error
|
||||
// DeleteUser deletes the user with the given identifier.
|
||||
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
|
||||
}
|
||||
|
||||
type PeerDatabaseRepo interface {
|
||||
// GetUserPeers returns all peers linked to the given user.
|
||||
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||
}
|
||||
|
||||
type EventBus interface {
|
||||
// Publish sends a message to the message bus.
|
||||
Publish(topic string, args ...any)
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
// Manager is the user manager.
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
bus evbus.MessageBus
|
||||
|
||||
bus EventBus
|
||||
users UserDatabaseRepo
|
||||
peers PeerDatabaseRepo
|
||||
}
|
||||
|
||||
func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabaseRepo, peers PeerDatabaseRepo) (
|
||||
// NewUserManager creates a new user manager instance.
|
||||
func NewUserManager(cfg *config.Config, bus EventBus, users UserDatabaseRepo, peers PeerDatabaseRepo) (
|
||||
*Manager,
|
||||
error,
|
||||
) {
|
||||
@ -41,6 +71,7 @@ func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabase
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// RegisterUser registers a new user.
|
||||
func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
@ -56,6 +87,7 @@ func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewUser creates a new user.
|
||||
func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
|
||||
if user.Identifier == "" {
|
||||
return errors.New("missing user identifier")
|
||||
@ -90,12 +122,13 @@ func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartBackgroundJobs starts the background jobs.
|
||||
// This method is non-blocking and returns immediately.
|
||||
func (m Manager) StartBackgroundJobs(ctx context.Context) {
|
||||
|
||||
go m.runLdapSynchronizationService(ctx)
|
||||
|
||||
}
|
||||
|
||||
// GetUser returns the user with the given identifier.
|
||||
func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
||||
return nil, err
|
||||
@ -112,6 +145,7 @@ func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail returns the user with the given email address.
|
||||
func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
|
||||
user, err := m.users.GetUserByEmail(ctx, email)
|
||||
@ -130,6 +164,7 @@ func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetAllUsers returns all users.
|
||||
func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
@ -162,6 +197,7 @@ func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates the user with the given identifier.
|
||||
func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
||||
if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil {
|
||||
return nil, err
|
||||
@ -203,6 +239,7 @@ func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.Use
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user.
|
||||
func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
@ -236,6 +273,7 @@ func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.Use
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes the user with the given identifier.
|
||||
func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
@ -260,6 +298,7 @@ func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// ActivateApi activates the API access for the user with the given identifier.
|
||||
func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||
user, err := m.users.GetUser(ctx, id)
|
||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||
@ -287,6 +326,7 @@ func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*do
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// DeactivateApi deactivates the API access for the user with the given identifier.
|
||||
func (m Manager) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||
user, err := m.users.GetUser(ctx, id)
|
||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||
|
@ -1,87 +0,0 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type InterfaceAndPeerDatabaseRepo interface {
|
||||
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||
GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error)
|
||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||
FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error)
|
||||
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
|
||||
SaveInterface(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(in *domain.Interface) (*domain.Interface, error),
|
||||
) error
|
||||
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
|
||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||
FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error)
|
||||
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||
FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error)
|
||||
SavePeer(
|
||||
ctx context.Context,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(in *domain.Peer) (*domain.Peer, error),
|
||||
) error
|
||||
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
|
||||
}
|
||||
|
||||
type StatisticsDatabaseRepo interface {
|
||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
|
||||
UpdatePeerStatus(
|
||||
ctx context.Context,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
|
||||
) error
|
||||
UpdateInterfaceStatus(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
|
||||
) error
|
||||
|
||||
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type InterfaceController interface {
|
||||
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
|
||||
*domain.PhysicalPeer,
|
||||
error,
|
||||
)
|
||||
SaveInterface(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type WgQuickController interface {
|
||||
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
|
||||
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||
UnsetDNS(id domain.InterfaceIdentifier) error
|
||||
}
|
||||
|
||||
type MetricsServer interface {
|
||||
UpdateInterfaceMetrics(status domain.InterfaceStatus)
|
||||
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
|
||||
}
|
@ -7,31 +7,63 @@ import (
|
||||
"time"
|
||||
|
||||
probing "github.com/prometheus-community/pro-bing"
|
||||
evbus "github.com/vardius/message-bus"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type StatisticsDatabaseRepo interface {
|
||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
UpdatePeerStatus(
|
||||
ctx context.Context,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
|
||||
) error
|
||||
UpdateInterfaceStatus(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
|
||||
) error
|
||||
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type StatisticsInterfaceController interface {
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
}
|
||||
|
||||
type StatisticsMetricsServer interface {
|
||||
UpdateInterfaceMetrics(status domain.InterfaceStatus)
|
||||
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
|
||||
}
|
||||
|
||||
type StatisticsEventBus interface {
|
||||
// Subscribe subscribes to a topic
|
||||
Subscribe(topic string, fn interface{}) error
|
||||
}
|
||||
|
||||
type StatisticsCollector struct {
|
||||
cfg *config.Config
|
||||
bus evbus.MessageBus
|
||||
bus StatisticsEventBus
|
||||
|
||||
pingWaitGroup sync.WaitGroup
|
||||
pingJobs chan domain.Peer
|
||||
|
||||
db StatisticsDatabaseRepo
|
||||
wg InterfaceController
|
||||
ms MetricsServer
|
||||
wg StatisticsInterfaceController
|
||||
ms StatisticsMetricsServer
|
||||
}
|
||||
|
||||
// NewStatisticsCollector creates a new statistics collector.
|
||||
func NewStatisticsCollector(
|
||||
cfg *config.Config,
|
||||
bus evbus.MessageBus,
|
||||
bus StatisticsEventBus,
|
||||
db StatisticsDatabaseRepo,
|
||||
wg InterfaceController,
|
||||
ms MetricsServer,
|
||||
wg StatisticsInterfaceController,
|
||||
ms StatisticsMetricsServer,
|
||||
) (*StatisticsCollector, error) {
|
||||
c := &StatisticsCollector{
|
||||
cfg: cfg,
|
||||
@ -47,6 +79,8 @@ func NewStatisticsCollector(
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// StartBackgroundJobs starts the background jobs for the statistics collector.
|
||||
// This method is non-blocking and returns immediately.
|
||||
func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) {
|
||||
c.startPingWorkers(ctx)
|
||||
c.startInterfaceDataFetcher(ctx)
|
||||
|
@ -5,17 +5,74 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
evbus "github.com/vardius/message-bus"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
bus evbus.MessageBus
|
||||
// region dependencies
|
||||
|
||||
type InterfaceAndPeerDatabaseRepo interface {
|
||||
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||
GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error)
|
||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
|
||||
SaveInterface(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(in *domain.Interface) (*domain.Interface, error),
|
||||
) error
|
||||
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
|
||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||
SavePeer(
|
||||
ctx context.Context,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(in *domain.Peer) (*domain.Peer, error),
|
||||
) error
|
||||
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
|
||||
}
|
||||
|
||||
type InterfaceController interface {
|
||||
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
SaveInterface(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type WgQuickController interface {
|
||||
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
|
||||
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||
UnsetDNS(id domain.InterfaceIdentifier) error
|
||||
}
|
||||
|
||||
type EventBus interface {
|
||||
// Publish sends a message to the message bus.
|
||||
Publish(topic string, args ...any)
|
||||
// Subscribe subscribes to a topic
|
||||
Subscribe(topic string, fn interface{}) error
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
bus EventBus
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
wg InterfaceController
|
||||
quick WgQuickController
|
||||
@ -23,7 +80,7 @@ type Manager struct {
|
||||
|
||||
func NewWireGuardManager(
|
||||
cfg *config.Config,
|
||||
bus evbus.MessageBus,
|
||||
bus EventBus,
|
||||
wg InterfaceController,
|
||||
quick WgQuickController,
|
||||
db InterfaceAndPeerDatabaseRepo,
|
||||
@ -41,6 +98,8 @@ func NewWireGuardManager(
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// StartBackgroundJobs starts background jobs like the expired peers check.
|
||||
// This method is non-blocking.
|
||||
func (m Manager) StartBackgroundJobs(ctx context.Context) {
|
||||
go m.runExpiredPeersCheck(ctx)
|
||||
}
|
||||
|
@ -13,6 +13,8 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// GetImportableInterfaces returns all physical interfaces that are available on the system.
|
||||
// This function also returns interfaces that are already available in the database.
|
||||
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
@ -26,6 +28,7 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
|
||||
return physicalInterfaces, nil
|
||||
}
|
||||
|
||||
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
|
||||
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
|
||||
*domain.Interface,
|
||||
[]domain.Peer,
|
||||
@ -38,6 +41,7 @@ func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceId
|
||||
return m.db.GetInterfaceAndPeers(ctx, id)
|
||||
}
|
||||
|
||||
// GetAllInterfaces returns all interfaces that are available in the database.
|
||||
func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
@ -46,6 +50,7 @@ func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, erro
|
||||
return m.db.GetAllInterfaces(ctx)
|
||||
}
|
||||
|
||||
// GetAllInterfacesAndPeers returns all interfaces and their peers.
|
||||
func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, nil, err
|
||||
@ -97,6 +102,7 @@ func (m Manager) GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier)
|
||||
return userInterfaces, nil
|
||||
}
|
||||
|
||||
// ImportNewInterfaces imports all new physical interfaces that are available on the system.
|
||||
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return 0, err
|
||||
@ -148,6 +154,7 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
|
||||
return imported, nil
|
||||
}
|
||||
|
||||
// ApplyPeerDefaults applies the interface defaults to all peers of the given interface.
|
||||
func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
@ -179,6 +186,8 @@ func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) er
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreInterfaceState restores the state of all physical interfaces and their peers.
|
||||
// The final state of the interfaces and peers will be the same as stored in the database.
|
||||
func (m Manager) RestoreInterfaceState(
|
||||
ctx context.Context,
|
||||
updateDbOnError bool,
|
||||
@ -296,6 +305,7 @@ func (m Manager) RestoreInterfaceState(
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrepareInterface generates a new interface with fresh keys, ip addresses and a listen port.
|
||||
func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
@ -376,6 +386,7 @@ func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error
|
||||
return freshInterface, nil
|
||||
}
|
||||
|
||||
// CreateInterface creates a new interface with the given configuration.
|
||||
func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, err
|
||||
@ -401,6 +412,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
|
||||
return in, nil
|
||||
}
|
||||
|
||||
// UpdateInterface updates the given interface with the new configuration.
|
||||
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return nil, nil, err
|
||||
@ -423,6 +435,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
|
||||
return in, existingPeers, nil
|
||||
}
|
||||
|
||||
// DeleteInterface deletes the given interface.
|
||||
func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// CreateDefaultPeer creates a default peer for the given user on all server interfaces.
|
||||
func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
return err
|
||||
@ -55,6 +56,7 @@ func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdenti
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserPeers returns all peers for the given user.
|
||||
func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
|
||||
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
||||
return nil, err
|
||||
@ -63,6 +65,7 @@ func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]
|
||||
return m.db.GetUserPeers(ctx, id)
|
||||
}
|
||||
|
||||
// PreparePeer prepares a new peer for the given interface with fresh keys and ip addresses.
|
||||
func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
|
||||
if !m.cfg.Core.SelfProvisioningAllowed {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
@ -143,6 +146,7 @@ func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier)
|
||||
return freshPeer, nil
|
||||
}
|
||||
|
||||
// GetPeer returns the peer with the given identifier.
|
||||
func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
|
||||
peer, err := m.db.GetPeer(ctx, id)
|
||||
if err != nil {
|
||||
@ -156,6 +160,7 @@ func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// CreatePeer creates a new peer.
|
||||
func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
||||
if !m.cfg.Core.SelfProvisioningAllowed {
|
||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||
@ -201,6 +206,8 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// CreateMultiplePeers creates multiple new peers for the given user identifiers.
|
||||
// It calls PreparePeer for each user identifier in the request.
|
||||
func (m Manager) CreateMultiplePeers(
|
||||
ctx context.Context,
|
||||
interfaceId domain.InterfaceIdentifier,
|
||||
@ -243,6 +250,7 @@ func (m Manager) CreateMultiplePeers(
|
||||
return createdPeers, nil
|
||||
}
|
||||
|
||||
// UpdatePeer updates the given peer.
|
||||
func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
||||
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
|
||||
if err != nil {
|
||||
@ -309,6 +317,7 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// DeletePeer deletes the peer with the given identifier.
|
||||
func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
|
||||
peer, err := m.db.GetPeer(ctx, id)
|
||||
if err != nil {
|
||||
@ -341,6 +350,7 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeerStats returns the status of the peer with the given identifier.
|
||||
func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) {
|
||||
_, peers, err := m.db.GetInterfaceAndPeers(ctx, id)
|
||||
if err != nil {
|
||||
@ -359,6 +369,7 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier
|
||||
return m.db.GetPeersStats(ctx, peerIds...)
|
||||
}
|
||||
|
||||
// GetUserPeerStats returns the status of all peers for the given user.
|
||||
func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
|
||||
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
||||
return nil, err
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
)
|
||||
|
||||
// Auth contains all authentication providers.
|
||||
type Auth struct {
|
||||
// OpenIDConnect contains a list of OpenID Connect providers.
|
||||
OpenIDConnect []OpenIDConnectProvider `yaml:"oidc"`
|
||||
@ -17,6 +18,7 @@ type Auth struct {
|
||||
Ldap []LdapProvider `yaml:"ldap"`
|
||||
}
|
||||
|
||||
// BaseFields contains the basic fields that are used to map user information from the authentication providers.
|
||||
type BaseFields struct {
|
||||
// UserIdentifier is the name of the field that contains the user identifier.
|
||||
UserIdentifier string `yaml:"user_identifier"`
|
||||
@ -32,6 +34,7 @@ type BaseFields struct {
|
||||
Department string `yaml:"department"`
|
||||
}
|
||||
|
||||
// OauthFields contains extra fields that are used to map user information from OAuth providers.
|
||||
type OauthFields struct {
|
||||
BaseFields `yaml:",inline"`
|
||||
// IsAdmin is the name of the field that contains the admin flag.
|
||||
@ -107,12 +110,14 @@ func (o *OauthAdminMapping) GetAdminGroupRegex() *regexp.Regexp {
|
||||
return o.adminGroupRegex
|
||||
}
|
||||
|
||||
// LdapFields contains extra fields that are used to map user information from LDAP providers.
|
||||
type LdapFields struct {
|
||||
BaseFields `yaml:",inline"`
|
||||
// GroupMembership is the name of the LDAP field that contains the groups to which the user belongs.
|
||||
GroupMembership string `yaml:"memberof"`
|
||||
}
|
||||
|
||||
// LdapProvider contains the configuration for the LDAP connection.
|
||||
type LdapProvider struct {
|
||||
// ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters.
|
||||
ProviderName string `yaml:"provider_name"`
|
||||
@ -163,6 +168,7 @@ type LdapProvider struct {
|
||||
LogUserInfo bool `yaml:"log_user_info"`
|
||||
}
|
||||
|
||||
// OpenIDConnectProvider contains the configuration for the OpenID Connect provider.
|
||||
type OpenIDConnectProvider struct {
|
||||
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
|
||||
ProviderName string `yaml:"provider_name"`
|
||||
@ -196,6 +202,7 @@ type OpenIDConnectProvider struct {
|
||||
LogUserInfo bool `yaml:"log_user_info"`
|
||||
}
|
||||
|
||||
// OAuthProvider contains the configuration for the OAuth provider.
|
||||
type OAuthProvider struct {
|
||||
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
|
||||
ProviderName string `yaml:"provider_name"`
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config is the main configuration struct.
|
||||
type Config struct {
|
||||
Core struct {
|
||||
// AdminUser defines the default administrator account that will be created
|
||||
@ -179,6 +180,7 @@ func GetConfig() (*Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// loadConfigFile loads the configuration from a YAML file into the given cfg struct.
|
||||
func loadConfigFile(cfg any, filename string) error {
|
||||
data, err := envsubst.ReadFile(filename)
|
||||
if err != nil {
|
||||
|
@ -2,6 +2,8 @@ package config
|
||||
|
||||
import "time"
|
||||
|
||||
// SupportedDatabase is a type for the supported database types.
|
||||
// Supported: mysql, mssql, postgres, sqlite
|
||||
type SupportedDatabase string
|
||||
|
||||
const (
|
||||
@ -11,6 +13,7 @@ const (
|
||||
DatabaseSQLite SupportedDatabase = "sqlite"
|
||||
)
|
||||
|
||||
// DatabaseConfig contains the configuration for the database connection.
|
||||
type DatabaseConfig struct {
|
||||
// Debug enables logging of all database statements
|
||||
Debug bool `yaml:"debug"`
|
||||
|
@ -1,5 +1,7 @@
|
||||
package config
|
||||
|
||||
// MailEncryption is the type of the SMTP encryption.
|
||||
// Supported: none, tls, starttls
|
||||
type MailEncryption string
|
||||
|
||||
const (
|
||||
@ -8,6 +10,8 @@ const (
|
||||
MailEncryptionStartTLS MailEncryption = "starttls"
|
||||
)
|
||||
|
||||
// MailAuthType is the type of the SMTP authentication.
|
||||
// Supported: plain, login, crammd5
|
||||
type MailAuthType string
|
||||
|
||||
const (
|
||||
@ -16,6 +20,7 @@ const (
|
||||
MailAuthCramMD5 MailAuthType = "crammd5"
|
||||
)
|
||||
|
||||
// MailConfig contains the configuration for the mail server which is used to send emails.
|
||||
type MailConfig struct {
|
||||
// Host is the hostname or IP of the SMTP server
|
||||
Host string `yaml:"host"`
|
||||
|
@ -1,5 +1,6 @@
|
||||
package config
|
||||
|
||||
// WebConfig contains the configuration for the web server.
|
||||
type WebConfig struct {
|
||||
// RequestLogging enables logging of all HTTP requests.
|
||||
RequestLogging bool `yaml:"request_logging"`
|
||||
|
@ -1,11 +1,5 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type LoginProvider string
|
||||
|
||||
type LoginProviderInfo struct {
|
||||
@ -24,28 +18,3 @@ type AuthenticatorUserInfo struct {
|
||||
Department string
|
||||
IsAdmin bool
|
||||
}
|
||||
|
||||
type AuthenticatorType string
|
||||
|
||||
const (
|
||||
AuthenticatorTypeOAuth AuthenticatorType = "oauth"
|
||||
AuthenticatorTypeOidc AuthenticatorType = "oidc"
|
||||
)
|
||||
|
||||
type OauthAuthenticator interface {
|
||||
GetName() string
|
||||
GetType() AuthenticatorType
|
||||
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
|
||||
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
|
||||
GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]any, error)
|
||||
ParseUserInfo(raw map[string]any) (*AuthenticatorUserInfo, error)
|
||||
RegistrationEnabled() bool
|
||||
}
|
||||
|
||||
type LdapAuthenticator interface {
|
||||
GetName() string
|
||||
PlaintextAuthentication(userId UserIdentifier, plainPassword string) error
|
||||
GetUserInfo(ctx context.Context, username UserIdentifier) (map[string]any, error)
|
||||
ParseUserInfo(raw map[string]any) (*AuthenticatorUserInfo, error)
|
||||
RegistrationEnabled() bool
|
||||
}
|
||||
|
@ -33,6 +33,7 @@ func (p KeyPair) GetPublicKey() wgtypes.Key {
|
||||
|
||||
type PreSharedKey string
|
||||
|
||||
// NewFreshKeypair generates a new key pair.
|
||||
func NewFreshKeypair() (KeyPair, error) {
|
||||
privateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
@ -45,6 +46,7 @@ func NewFreshKeypair() (KeyPair, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewPreSharedKey generates a new pre-shared key.
|
||||
func NewPreSharedKey() (PreSharedKey, error) {
|
||||
preSharedKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
@ -54,6 +56,8 @@ func NewPreSharedKey() (PreSharedKey, error) {
|
||||
return PreSharedKey(preSharedKey.String()), nil
|
||||
}
|
||||
|
||||
// PublicKeyFromPrivateKey returns the public key for a given private key.
|
||||
// If the private key is invalid, an empty string is returned.
|
||||
func PublicKeyFromPrivateKey(key string) string {
|
||||
privKey, err := wgtypes.ParseKey(key)
|
||||
if err != nil {
|
||||
|
56
internal/domain/crypto_test.go
Normal file
56
internal/domain/crypto_test.go
Normal file
@ -0,0 +1,56 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func TestKeyPair_GetPrivateKeyBytesReturnsCorrectBytes(t *testing.T) {
|
||||
keyPair := KeyPair{PrivateKey: base64.StdEncoding.EncodeToString([]byte("privateKey"))}
|
||||
expected := []byte("privateKey")
|
||||
assert.Equal(t, expected, keyPair.GetPrivateKeyBytes())
|
||||
}
|
||||
|
||||
func TestKeyPair_GetPublicKeyBytesReturnsCorrectBytes(t *testing.T) {
|
||||
keyPair := KeyPair{PublicKey: base64.StdEncoding.EncodeToString([]byte("publicKey"))}
|
||||
expected := []byte("publicKey")
|
||||
assert.Equal(t, expected, keyPair.GetPublicKeyBytes())
|
||||
}
|
||||
|
||||
func TestKeyPair_GetPrivateKeyReturnsCorrectKey(t *testing.T) {
|
||||
privateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
keyPair := KeyPair{PrivateKey: privateKey.String()}
|
||||
assert.Equal(t, privateKey, keyPair.GetPrivateKey())
|
||||
}
|
||||
|
||||
func TestKeyPair_GetPublicKeyReturnsCorrectKey(t *testing.T) {
|
||||
privateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
keyPair := KeyPair{PublicKey: privateKey.PublicKey().String()}
|
||||
assert.Equal(t, privateKey.PublicKey(), keyPair.GetPublicKey())
|
||||
}
|
||||
|
||||
func TestNewFreshKeypairGeneratesValidKeypair(t *testing.T) {
|
||||
keyPair, err := NewFreshKeypair()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, keyPair.PrivateKey)
|
||||
assert.NotEmpty(t, keyPair.PublicKey)
|
||||
}
|
||||
|
||||
func TestNewPreSharedKeyGeneratesValidKey(t *testing.T) {
|
||||
preSharedKey, err := NewPreSharedKey()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, preSharedKey)
|
||||
}
|
||||
|
||||
func TestPublicKeyFromPrivateKeyReturnsCorrectPublicKey(t *testing.T) {
|
||||
privateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
expected := privateKey.PublicKey().String()
|
||||
assert.Equal(t, expected, PublicKeyFromPrivateKey(privateKey.String()))
|
||||
}
|
||||
|
||||
func TestPublicKeyFromPrivateKeyReturnsEmptyStringOnInvalidKey(t *testing.T) {
|
||||
assert.Equal(t, "", PublicKeyFromPrivateKey("invalidKey"))
|
||||
}
|
83
internal/domain/interface_test.go
Normal file
83
internal/domain/interface_test.go
Normal file
@ -0,0 +1,83 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) {
|
||||
iface := &Interface{}
|
||||
assert.False(t, iface.IsDisabled())
|
||||
|
||||
now := time.Now()
|
||||
iface.Disabled = &now
|
||||
assert.True(t, iface.IsDisabled())
|
||||
}
|
||||
|
||||
func TestInterface_AddressStrReturnsCorrectString(t *testing.T) {
|
||||
iface := &Interface{
|
||||
Addresses: []Cidr{
|
||||
{Cidr: "192.168.1.1/24", Addr: "192.168.1.1", NetLength: 24},
|
||||
{Cidr: "10.0.0.1/24", Addr: "10.0.0.1", NetLength: 24},
|
||||
},
|
||||
}
|
||||
expected := "192.168.1.1/24,10.0.0.1/24"
|
||||
assert.Equal(t, expected, iface.AddressStr())
|
||||
}
|
||||
|
||||
func TestInterface_GetConfigFileNameReturnsCorrectFileName(t *testing.T) {
|
||||
iface := &Interface{Identifier: "wg0"}
|
||||
expected := "wg0.conf"
|
||||
assert.Equal(t, expected, iface.GetConfigFileName())
|
||||
|
||||
iface.Identifier = "wg0@123"
|
||||
expected = "wg0123.conf"
|
||||
assert.Equal(t, expected, iface.GetConfigFileName())
|
||||
}
|
||||
|
||||
func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
|
||||
peer1 := Peer{
|
||||
Interface: PeerInterfaceConfig{
|
||||
Addresses: []Cidr{
|
||||
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
|
||||
},
|
||||
},
|
||||
}
|
||||
peer2 := Peer{
|
||||
Interface: PeerInterfaceConfig{
|
||||
Addresses: []Cidr{
|
||||
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
|
||||
},
|
||||
},
|
||||
}
|
||||
iface := &Interface{}
|
||||
expected := []Cidr{
|
||||
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
|
||||
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
|
||||
}
|
||||
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
|
||||
}
|
||||
|
||||
func TestInterface_ManageRoutingTableReturnsCorrectValue(t *testing.T) {
|
||||
iface := &Interface{RoutingTable: "off"}
|
||||
assert.False(t, iface.ManageRoutingTable())
|
||||
|
||||
iface.RoutingTable = "100"
|
||||
assert.True(t, iface.ManageRoutingTable())
|
||||
}
|
||||
|
||||
func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
|
||||
iface := &Interface{RoutingTable: ""}
|
||||
assert.Equal(t, 0, iface.GetRoutingTable())
|
||||
|
||||
iface.RoutingTable = "off"
|
||||
assert.Equal(t, -1, iface.GetRoutingTable())
|
||||
|
||||
iface.RoutingTable = "0x64"
|
||||
assert.Equal(t, 100, iface.GetRoutingTable())
|
||||
|
||||
iface.RoutingTable = "200"
|
||||
assert.Equal(t, 200, iface.GetRoutingTable())
|
||||
}
|
42
internal/domain/options_test.go
Normal file
42
internal/domain/options_test.go
Normal file
@ -0,0 +1,42 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfigOption_GetValueReturnsCorrectValue(t *testing.T) {
|
||||
option := ConfigOption[int]{Value: 42}
|
||||
assert.Equal(t, 42, option.GetValue())
|
||||
}
|
||||
|
||||
func TestConfigOption_SetValueUpdatesValue(t *testing.T) {
|
||||
option := ConfigOption[int]{Value: 42}
|
||||
option.SetValue(100)
|
||||
assert.Equal(t, 100, option.GetValue())
|
||||
}
|
||||
|
||||
func TestConfigOption_TrySetValueUpdatesValueWhenOverridable(t *testing.T) {
|
||||
option := ConfigOption[int]{Value: 42, Overridable: true}
|
||||
result := option.TrySetValue(100)
|
||||
assert.True(t, result)
|
||||
assert.Equal(t, 100, option.GetValue())
|
||||
}
|
||||
|
||||
func TestConfigOption_TrySetValueDoesNotUpdateValueWhenNotOverridable(t *testing.T) {
|
||||
option := ConfigOption[int]{Value: 42, Overridable: false}
|
||||
result := option.TrySetValue(100)
|
||||
assert.False(t, result)
|
||||
assert.Equal(t, 42, option.GetValue())
|
||||
}
|
||||
|
||||
func TestNewConfigOptionCreatesCorrectOption(t *testing.T) {
|
||||
option := NewConfigOption(42, true)
|
||||
assert.Equal(t, 42, option.GetValue())
|
||||
assert.True(t, option.Overridable)
|
||||
|
||||
option2 := NewConfigOption("str", false)
|
||||
assert.Equal(t, "str", option2.GetValue())
|
||||
assert.False(t, option2.Overridable)
|
||||
}
|
165
internal/domain/peer_test.go
Normal file
165
internal/domain/peer_test.go
Normal file
@ -0,0 +1,165 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPeer_IsDisabled(t *testing.T) {
|
||||
peer := &Peer{}
|
||||
assert.False(t, peer.IsDisabled())
|
||||
|
||||
now := time.Now()
|
||||
peer.Disabled = &now
|
||||
assert.True(t, peer.IsDisabled())
|
||||
}
|
||||
|
||||
func TestPeer_IsExpired(t *testing.T) {
|
||||
peer := &Peer{}
|
||||
assert.False(t, peer.IsExpired())
|
||||
|
||||
expiredTime := time.Now().Add(-time.Hour)
|
||||
peer.ExpiresAt = &expiredTime
|
||||
assert.True(t, peer.IsExpired())
|
||||
|
||||
futureTime := time.Now().Add(time.Hour)
|
||||
peer.ExpiresAt = &futureTime
|
||||
assert.False(t, peer.IsExpired())
|
||||
}
|
||||
|
||||
func TestPeer_CheckAliveAddress(t *testing.T) {
|
||||
peer := &Peer{}
|
||||
assert.Equal(t, "", peer.CheckAliveAddress())
|
||||
|
||||
peer.Interface.CheckAliveAddress = "192.168.1.1"
|
||||
assert.Equal(t, "192.168.1.1", peer.CheckAliveAddress())
|
||||
|
||||
peer.Interface.CheckAliveAddress = ""
|
||||
peer.Interface.Addresses = []Cidr{{Addr: "10.0.0.1"}}
|
||||
assert.Equal(t, "10.0.0.1", peer.CheckAliveAddress())
|
||||
}
|
||||
|
||||
func TestPeer_GetConfigFileName(t *testing.T) {
|
||||
peer := &Peer{DisplayName: "Test Peer"}
|
||||
expected := "Test_Peer.conf"
|
||||
assert.Equal(t, expected, peer.GetConfigFileName())
|
||||
|
||||
peer.DisplayName = ""
|
||||
peer.Identifier = "12345678"
|
||||
expected = "wg_12345678.conf"
|
||||
assert.Equal(t, expected, peer.GetConfigFileName())
|
||||
}
|
||||
|
||||
func TestPeer_ApplyInterfaceDefaults(t *testing.T) {
|
||||
peer := &Peer{
|
||||
Endpoint: ConfigOption[string]{
|
||||
Value: "",
|
||||
Overridable: true,
|
||||
},
|
||||
EndpointPublicKey: ConfigOption[string]{
|
||||
Value: "",
|
||||
Overridable: true,
|
||||
},
|
||||
AllowedIPsStr: ConfigOption[string]{
|
||||
Value: "1.1.1.1/32",
|
||||
Overridable: false,
|
||||
},
|
||||
}
|
||||
iface := &Interface{
|
||||
PeerDefEndpoint: "192.168.1.1",
|
||||
KeyPair: KeyPair{
|
||||
PublicKey: "publicKey",
|
||||
},
|
||||
PeerDefAllowedIPsStr: "8.8.8.8/32",
|
||||
}
|
||||
|
||||
peer.ApplyInterfaceDefaults(iface)
|
||||
assert.Equal(t, "192.168.1.1", peer.Endpoint.GetValue())
|
||||
assert.Equal(t, "publicKey", peer.EndpointPublicKey.GetValue())
|
||||
assert.Equal(t, "1.1.1.1/32", peer.AllowedIPsStr.GetValue())
|
||||
}
|
||||
|
||||
func TestPeer_GenerateDisplayName(t *testing.T) {
|
||||
peer := &Peer{Identifier: "12345678"}
|
||||
peer.GenerateDisplayName("Prefix")
|
||||
expected := "Prefix Peer 12345678"
|
||||
assert.Equal(t, expected, peer.DisplayName)
|
||||
|
||||
peer.GenerateDisplayName("")
|
||||
expected = "Peer 12345678"
|
||||
assert.Equal(t, expected, peer.DisplayName)
|
||||
}
|
||||
|
||||
func TestPeer_OverwriteUserEditableFields(t *testing.T) {
|
||||
peer := &Peer{}
|
||||
userPeer := &Peer{
|
||||
DisplayName: "New DisplayName",
|
||||
}
|
||||
|
||||
peer.OverwriteUserEditableFields(userPeer)
|
||||
assert.Equal(t, "New DisplayName", peer.DisplayName)
|
||||
}
|
||||
|
||||
func TestPeer_GetPresharedKey(t *testing.T) {
|
||||
physicalPeer := PhysicalPeer{}
|
||||
assert.Nil(t, physicalPeer.GetPresharedKey())
|
||||
|
||||
physicalPeer.PresharedKey = "Q0evIJTOjhyy2o5J7whvrsvQC+FRL8A74vrw44YHUAk="
|
||||
key := physicalPeer.GetPresharedKey()
|
||||
assert.NotNil(t, key)
|
||||
}
|
||||
|
||||
func TestPeer_GetEndpointAddress(t *testing.T) {
|
||||
physicalPeer := PhysicalPeer{}
|
||||
assert.Nil(t, physicalPeer.GetEndpointAddress())
|
||||
|
||||
physicalPeer.Endpoint = "192.168.1.1:51820"
|
||||
addr := physicalPeer.GetEndpointAddress()
|
||||
assert.NotNil(t, addr)
|
||||
assert.Equal(t, "192.168.1.1:51820", addr.String())
|
||||
}
|
||||
|
||||
func TestPeer_GetPersistentKeepaliveTime(t *testing.T) {
|
||||
physicalPeer := PhysicalPeer{}
|
||||
assert.Nil(t, physicalPeer.GetPersistentKeepaliveTime())
|
||||
|
||||
physicalPeer.PersistentKeepalive = 25
|
||||
duration := physicalPeer.GetPersistentKeepaliveTime()
|
||||
assert.NotNil(t, duration)
|
||||
assert.Equal(t, 25*time.Second, *duration)
|
||||
}
|
||||
|
||||
func TestPeer_GetAllowedIPs(t *testing.T) {
|
||||
physicalPeer := PhysicalPeer{}
|
||||
assert.Empty(t, physicalPeer.GetAllowedIPs())
|
||||
|
||||
physicalPeer.AllowedIPs = []Cidr{
|
||||
{
|
||||
Cidr: "192.168.1.0/24",
|
||||
Addr: "192.168.1.0",
|
||||
NetLength: 24,
|
||||
},
|
||||
}
|
||||
ips := physicalPeer.GetAllowedIPs()
|
||||
assert.Len(t, ips, 1)
|
||||
assert.Equal(t, "192.168.1.0/24", ips[0].String())
|
||||
|
||||
physicalPeer.AllowedIPs = []Cidr{
|
||||
{
|
||||
Cidr: "192.168.1.0/24",
|
||||
Addr: "192.168.1.0",
|
||||
NetLength: 24,
|
||||
},
|
||||
{
|
||||
Cidr: "fe80::/64",
|
||||
Addr: "fe80::",
|
||||
NetLength: 64,
|
||||
},
|
||||
}
|
||||
ips2 := physicalPeer.GetAllowedIPs()
|
||||
assert.Len(t, ips2, 2)
|
||||
assert.Equal(t, "192.168.1.0/24", ips2[0].String())
|
||||
assert.Equal(t, "fe80::/64", ips2[1].String())
|
||||
}
|
74
internal/domain/statistics_test.go
Normal file
74
internal/domain/statistics_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPeerStatus_IsConnected(t *testing.T) {
|
||||
now := time.Now()
|
||||
past := now.Add(-3 * time.Minute)
|
||||
recent := now.Add(-1 * time.Minute)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
status PeerStatus
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Pingable and recent handshake",
|
||||
status: PeerStatus{
|
||||
IsPingable: true,
|
||||
LastHandshake: &recent,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Not pingable but recent handshake",
|
||||
status: PeerStatus{
|
||||
IsPingable: false,
|
||||
LastHandshake: &recent,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Pingable but old handshake",
|
||||
status: PeerStatus{
|
||||
IsPingable: true,
|
||||
LastHandshake: &past,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Not pingable and old handshake",
|
||||
status: PeerStatus{
|
||||
IsPingable: false,
|
||||
LastHandshake: &past,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Pingable and no handshake",
|
||||
status: PeerStatus{
|
||||
IsPingable: true,
|
||||
LastHandshake: nil,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Not pingable and no handshake",
|
||||
status: PeerStatus{
|
||||
IsPingable: false,
|
||||
LastHandshake: nil,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.status.IsConnected(); got != tt.want {
|
||||
t.Errorf("IsConnected() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
125
internal/domain/user_test.go
Normal file
125
internal/domain/user_test.go
Normal file
@ -0,0 +1,125 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func TestUser_IsDisabled(t *testing.T) {
|
||||
user := &User{}
|
||||
assert.False(t, user.IsDisabled())
|
||||
|
||||
now := time.Now()
|
||||
user.Disabled = &now
|
||||
assert.True(t, user.IsDisabled())
|
||||
}
|
||||
|
||||
func TestUser_IsLocked(t *testing.T) {
|
||||
user := &User{}
|
||||
assert.False(t, user.IsLocked())
|
||||
|
||||
now := time.Now()
|
||||
user.Locked = &now
|
||||
assert.True(t, user.IsLocked())
|
||||
}
|
||||
|
||||
func TestUser_IsApiEnabled(t *testing.T) {
|
||||
user := &User{}
|
||||
assert.False(t, user.IsApiEnabled())
|
||||
|
||||
user.ApiToken = "token"
|
||||
assert.True(t, user.IsApiEnabled())
|
||||
}
|
||||
|
||||
func TestUser_CanChangePassword(t *testing.T) {
|
||||
user := &User{Source: UserSourceDatabase}
|
||||
assert.NoError(t, user.CanChangePassword())
|
||||
|
||||
user.Source = UserSourceLdap
|
||||
assert.Error(t, user.CanChangePassword())
|
||||
|
||||
user.Source = UserSourceOauth
|
||||
assert.Error(t, user.CanChangePassword())
|
||||
}
|
||||
|
||||
func TestUser_EditAllowed(t *testing.T) {
|
||||
user := &User{Source: UserSourceDatabase}
|
||||
newUser := &User{Source: UserSourceDatabase}
|
||||
assert.NoError(t, user.EditAllowed(newUser))
|
||||
|
||||
newUser.Notes = "notes can be changed"
|
||||
assert.NoError(t, user.EditAllowed(newUser))
|
||||
|
||||
newUser.Disabled = &time.Time{}
|
||||
assert.NoError(t, user.EditAllowed(newUser))
|
||||
|
||||
newUser.Lastname = "lastname or other fields can be changed"
|
||||
assert.NoError(t, user.EditAllowed(newUser))
|
||||
|
||||
user.Source = UserSourceLdap
|
||||
newUser.Source = UserSourceLdap
|
||||
newUser.Disabled = nil
|
||||
newUser.Lastname = ""
|
||||
newUser.Notes = "notes can be changed"
|
||||
assert.NoError(t, user.EditAllowed(newUser))
|
||||
|
||||
newUser.Disabled = &time.Time{}
|
||||
assert.NoError(t, user.EditAllowed(newUser))
|
||||
|
||||
newUser.Lastname = "lastname or other fields can not be changed"
|
||||
assert.Error(t, user.EditAllowed(newUser))
|
||||
|
||||
user.Source = UserSourceOauth
|
||||
newUser.Source = UserSourceOauth
|
||||
newUser.Disabled = nil
|
||||
newUser.Lastname = ""
|
||||
newUser.Notes = "notes can be changed"
|
||||
assert.NoError(t, user.EditAllowed(newUser))
|
||||
|
||||
newUser.Disabled = &time.Time{}
|
||||
assert.NoError(t, user.EditAllowed(newUser))
|
||||
|
||||
newUser.Lastname = "lastname or other fields can not be changed"
|
||||
assert.Error(t, user.EditAllowed(newUser))
|
||||
}
|
||||
|
||||
func TestUser_DeleteAllowed(t *testing.T) {
|
||||
user := &User{}
|
||||
assert.NoError(t, user.DeleteAllowed())
|
||||
}
|
||||
|
||||
func TestUser_CheckPassword(t *testing.T) {
|
||||
password := "password"
|
||||
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
|
||||
user := &User{Source: UserSourceDatabase, Password: PrivateString(hashedPassword)}
|
||||
assert.NoError(t, user.CheckPassword(password))
|
||||
|
||||
user.Password = ""
|
||||
assert.Error(t, user.CheckPassword(password))
|
||||
|
||||
user.Source = UserSourceLdap
|
||||
assert.Error(t, user.CheckPassword(password))
|
||||
}
|
||||
|
||||
func TestUser_CheckApiToken(t *testing.T) {
|
||||
user := &User{}
|
||||
assert.Error(t, user.CheckApiToken("token"))
|
||||
|
||||
user.ApiToken = "token"
|
||||
assert.NoError(t, user.CheckApiToken("token"))
|
||||
|
||||
assert.Error(t, user.CheckApiToken("wrong_token"))
|
||||
}
|
||||
|
||||
func TestUser_HashPassword(t *testing.T) {
|
||||
user := &User{Password: "password"}
|
||||
assert.NoError(t, user.HashPassword())
|
||||
assert.NotEmpty(t, user.Password)
|
||||
|
||||
user.Password = ""
|
||||
assert.NoError(t, user.HashPassword())
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user