mirror of
https://github.com/h44z/wg-portal.git
synced 2025-04-19 08:55:12 +00:00
chore: use interfaces for all other services
This commit is contained in:
parent
02ed7b19df
commit
7d0da4e7ad
@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/adapters"
|
"github.com/h44z/wg-portal/internal/adapters"
|
||||||
"github.com/h44z/wg-portal/internal/app"
|
"github.com/h44z/wg-portal/internal/app"
|
||||||
"github.com/h44z/wg-portal/internal/app/api/core"
|
"github.com/h44z/wg-portal/internal/app/api/core"
|
||||||
|
backendV0 "github.com/h44z/wg-portal/internal/app/api/v0/backend"
|
||||||
handlersV0 "github.com/h44z/wg-portal/internal/app/api/v0/handlers"
|
handlersV0 "github.com/h44z/wg-portal/internal/app/api/v0/handlers"
|
||||||
backendV1 "github.com/h44z/wg-portal/internal/app/api/v1/backend"
|
backendV1 "github.com/h44z/wg-portal/internal/app/api/v1/backend"
|
||||||
handlersV1 "github.com/h44z/wg-portal/internal/app/api/v1/handlers"
|
handlersV1 "github.com/h44z/wg-portal/internal/app/api/v1/handlers"
|
||||||
@ -70,17 +71,24 @@ func main() {
|
|||||||
queueSize := 100
|
queueSize := 100
|
||||||
eventBus := evbus.New(queueSize)
|
eventBus := evbus.New(queueSize)
|
||||||
|
|
||||||
|
auditRecorder, err := audit.NewAuditRecorder(cfg, eventBus, database)
|
||||||
|
internal.AssertNoError(err)
|
||||||
|
auditRecorder.StartBackgroundJobs(ctx)
|
||||||
|
|
||||||
userManager, err := users.NewUserManager(cfg, eventBus, database, database)
|
userManager, err := users.NewUserManager(cfg, eventBus, database, database)
|
||||||
internal.AssertNoError(err)
|
internal.AssertNoError(err)
|
||||||
|
userManager.StartBackgroundJobs(ctx)
|
||||||
|
|
||||||
authenticator, err := auth.NewAuthenticator(&cfg.Auth, cfg.Web.ExternalUrl, eventBus, userManager)
|
authenticator, err := auth.NewAuthenticator(&cfg.Auth, cfg.Web.ExternalUrl, eventBus, userManager)
|
||||||
internal.AssertNoError(err)
|
internal.AssertNoError(err)
|
||||||
|
|
||||||
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database)
|
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database)
|
||||||
internal.AssertNoError(err)
|
internal.AssertNoError(err)
|
||||||
|
wireGuardManager.StartBackgroundJobs(ctx)
|
||||||
|
|
||||||
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, eventBus, database, wireGuard, metricsServer)
|
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, eventBus, database, wireGuard, metricsServer)
|
||||||
internal.AssertNoError(err)
|
internal.AssertNoError(err)
|
||||||
|
statisticsCollector.StartBackgroundJobs(ctx)
|
||||||
|
|
||||||
cfgFileManager, err := configfile.NewConfigFileManager(cfg, eventBus, database, database, cfgFileSystem)
|
cfgFileManager, err := configfile.NewConfigFileManager(cfg, eventBus, database, database, cfgFileSystem)
|
||||||
internal.AssertNoError(err)
|
internal.AssertNoError(err)
|
||||||
@ -88,18 +96,11 @@ func main() {
|
|||||||
mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database)
|
mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database)
|
||||||
internal.AssertNoError(err)
|
internal.AssertNoError(err)
|
||||||
|
|
||||||
auditRecorder, err := audit.NewAuditRecorder(cfg, eventBus, database)
|
|
||||||
internal.AssertNoError(err)
|
|
||||||
auditRecorder.StartBackgroundJobs(ctx)
|
|
||||||
|
|
||||||
routeManager, err := route.NewRouteManager(cfg, eventBus, database)
|
routeManager, err := route.NewRouteManager(cfg, eventBus, database)
|
||||||
internal.AssertNoError(err)
|
internal.AssertNoError(err)
|
||||||
routeManager.StartBackgroundJobs(ctx)
|
routeManager.StartBackgroundJobs(ctx)
|
||||||
|
|
||||||
backend, err := app.New(cfg, eventBus, authenticator, userManager, wireGuardManager,
|
err = app.Initialize(cfg, wireGuardManager, userManager)
|
||||||
statisticsCollector, cfgFileManager, mailManager)
|
|
||||||
internal.AssertNoError(err)
|
|
||||||
err = backend.Startup(ctx)
|
|
||||||
internal.AssertNoError(err)
|
internal.AssertNoError(err)
|
||||||
|
|
||||||
validatorManager := validator.New()
|
validatorManager := validator.New()
|
||||||
@ -109,10 +110,14 @@ func main() {
|
|||||||
apiV0Session := handlersV0.NewSessionWrapper(cfg)
|
apiV0Session := handlersV0.NewSessionWrapper(cfg)
|
||||||
apiV0Auth := handlersV0.NewAuthenticationHandler(authenticator, apiV0Session)
|
apiV0Auth := handlersV0.NewAuthenticationHandler(authenticator, apiV0Session)
|
||||||
|
|
||||||
apiV0EndpointAuth := handlersV0.NewAuthEndpoint(cfg, apiV0Auth, apiV0Session, validatorManager, backend)
|
apiV0BackendUsers := backendV0.NewUserService(cfg, userManager, wireGuardManager)
|
||||||
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, backend)
|
apiV0BackendInterfaces := backendV0.NewInterfaceService(cfg, wireGuardManager, cfgFileManager)
|
||||||
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, backend)
|
apiV0BackendPeers := backendV0.NewPeerService(cfg, wireGuardManager, cfgFileManager, mailManager)
|
||||||
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, backend)
|
|
||||||
|
apiV0EndpointAuth := handlersV0.NewAuthEndpoint(cfg, apiV0Auth, apiV0Session, validatorManager, authenticator)
|
||||||
|
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers)
|
||||||
|
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces)
|
||||||
|
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
|
||||||
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth)
|
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth)
|
||||||
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
|
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
|
||||||
|
|
||||||
|
91
internal/app/api/v0/backend/interface_service.go
Normal file
91
internal/app/api/v0/backend/interface_service.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// region dependencies
|
||||||
|
|
||||||
|
type InterfaceServiceInterfaceManager interface {
|
||||||
|
GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error)
|
||||||
|
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||||
|
CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error)
|
||||||
|
UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error)
|
||||||
|
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
|
||||||
|
PrepareInterface(ctx context.Context) (*domain.Interface, error)
|
||||||
|
ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type InterfaceServiceConfigFileManager interface {
|
||||||
|
PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error
|
||||||
|
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
|
type InterfaceService struct {
|
||||||
|
cfg *config.Config
|
||||||
|
|
||||||
|
interfaces InterfaceServiceInterfaceManager
|
||||||
|
configFile InterfaceServiceConfigFileManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInterfaceService(
|
||||||
|
cfg *config.Config,
|
||||||
|
interfaces InterfaceServiceInterfaceManager,
|
||||||
|
configFile InterfaceServiceConfigFileManager,
|
||||||
|
) *InterfaceService {
|
||||||
|
return &InterfaceService{
|
||||||
|
cfg: cfg,
|
||||||
|
interfaces: interfaces,
|
||||||
|
configFile: configFile,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i InterfaceService) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
|
||||||
|
*domain.Interface,
|
||||||
|
[]domain.Peer,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
return i.interfaces.GetInterfaceAndPeers(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i InterfaceService) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
|
||||||
|
return i.interfaces.PrepareInterface(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i InterfaceService) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
|
||||||
|
return i.interfaces.CreateInterface(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i InterfaceService) UpdateInterface(ctx context.Context, in *domain.Interface) (
|
||||||
|
*domain.Interface,
|
||||||
|
[]domain.Peer,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
return i.interfaces.UpdateInterface(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i InterfaceService) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||||
|
return i.interfaces.DeleteInterface(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i InterfaceService) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
|
||||||
|
return i.interfaces.GetAllInterfacesAndPeers(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i InterfaceService) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
|
||||||
|
return i.configFile.GetInterfaceConfig(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i InterfaceService) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||||
|
return i.configFile.PersistInterfaceConfig(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i InterfaceService) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
|
||||||
|
return i.interfaces.ApplyPeerDefaults(ctx, in)
|
||||||
|
}
|
112
internal/app/api/v0/backend/peer_service.go
Normal file
112
internal/app/api/v0/backend/peer_service.go
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// region dependencies
|
||||||
|
|
||||||
|
type PeerServicePeerManager interface {
|
||||||
|
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||||
|
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||||
|
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||||
|
PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error)
|
||||||
|
CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error)
|
||||||
|
UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error)
|
||||||
|
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||||
|
CreateMultiplePeers(
|
||||||
|
ctx context.Context,
|
||||||
|
interfaceId domain.InterfaceIdentifier,
|
||||||
|
r *domain.PeerCreationRequest,
|
||||||
|
) ([]domain.Peer, error)
|
||||||
|
GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeerServiceConfigFileManager interface {
|
||||||
|
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||||
|
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeerServiceMailManager interface {
|
||||||
|
SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
|
type PeerService struct {
|
||||||
|
cfg *config.Config
|
||||||
|
|
||||||
|
peers PeerServicePeerManager
|
||||||
|
configFile PeerServiceConfigFileManager
|
||||||
|
mailer PeerServiceMailManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPeerService(
|
||||||
|
cfg *config.Config,
|
||||||
|
peers PeerServicePeerManager,
|
||||||
|
configFile PeerServiceConfigFileManager,
|
||||||
|
mailer PeerServiceMailManager,
|
||||||
|
) *PeerService {
|
||||||
|
return &PeerService{
|
||||||
|
cfg: cfg,
|
||||||
|
peers: peers,
|
||||||
|
configFile: configFile,
|
||||||
|
mailer: mailer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerService) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
|
||||||
|
*domain.Interface,
|
||||||
|
[]domain.Peer,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
return p.peers.GetInterfaceAndPeers(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerService) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
|
||||||
|
return p.peers.PreparePeer(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerService) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
|
||||||
|
return p.peers.GetPeer(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerService) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
||||||
|
return p.peers.CreatePeer(ctx, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerService) CreateMultiplePeers(
|
||||||
|
ctx context.Context,
|
||||||
|
interfaceId domain.InterfaceIdentifier,
|
||||||
|
r *domain.PeerCreationRequest,
|
||||||
|
) ([]domain.Peer, error) {
|
||||||
|
return p.peers.CreateMultiplePeers(ctx, interfaceId, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerService) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
||||||
|
return p.peers.UpdatePeer(ctx, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerService) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
|
||||||
|
return p.peers.DeletePeer(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerService) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
|
||||||
|
return p.configFile.GetPeerConfig(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerService) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
|
||||||
|
return p.configFile.GetPeerConfigQrCode(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerService) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error {
|
||||||
|
return p.mailer.SendPeerEmail(ctx, linkOnly, peers...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p PeerService) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) {
|
||||||
|
return p.peers.GetPeerStats(ctx, id)
|
||||||
|
}
|
83
internal/app/api/v0/backend/user_service.go
Normal file
83
internal/app/api/v0/backend/user_service.go
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
package backend
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// region dependencies
|
||||||
|
|
||||||
|
type UserServiceUserManager interface {
|
||||||
|
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||||
|
GetAllUsers(ctx context.Context) ([]domain.User, error)
|
||||||
|
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
||||||
|
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
||||||
|
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
|
||||||
|
ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||||
|
DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserServiceWireGuardManager interface {
|
||||||
|
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||||
|
GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier) ([]domain.Interface, error)
|
||||||
|
GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
|
type UserService struct {
|
||||||
|
cfg *config.Config
|
||||||
|
|
||||||
|
users UserServiceUserManager
|
||||||
|
wg UserServiceWireGuardManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserService(cfg *config.Config, users UserServiceUserManager, wg UserServiceWireGuardManager) *UserService {
|
||||||
|
return &UserService{
|
||||||
|
cfg: cfg,
|
||||||
|
users: users,
|
||||||
|
wg: wg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UserService) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||||
|
return u.users.GetUser(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UserService) GetAllUsers(ctx context.Context) ([]domain.User, error) {
|
||||||
|
return u.users.GetAllUsers(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UserService) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
||||||
|
return u.users.UpdateUser(ctx, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UserService) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
||||||
|
return u.users.CreateUser(ctx, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UserService) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
|
||||||
|
return u.users.DeleteUser(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UserService) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||||
|
return u.users.ActivateApi(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UserService) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||||
|
return u.users.DeactivateApi(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UserService) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
|
||||||
|
return u.wg.GetUserPeers(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UserService) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
|
||||||
|
return u.wg.GetUserPeerStats(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UserService) GetUserInterfaces(ctx context.Context, id domain.UserIdentifier) ([]domain.Interface, error) {
|
||||||
|
return u.wg.GetUserInterfaces(ctx, id)
|
||||||
|
}
|
@ -7,46 +7,43 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
evbus "github.com/vardius/message-bus"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/config"
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type App struct {
|
// region dependencies
|
||||||
Config *config.Config
|
|
||||||
bus evbus.MessageBus
|
|
||||||
|
|
||||||
Authenticator
|
type WireGuardManager interface {
|
||||||
UserManager
|
ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error)
|
||||||
WireGuardManager
|
RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error
|
||||||
StatisticsCollector
|
|
||||||
ConfigFileManager
|
|
||||||
MailManager
|
|
||||||
ApiV1Manager
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(
|
type UserManager interface {
|
||||||
|
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||||
|
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
|
// App is the main application struct.
|
||||||
|
type App struct {
|
||||||
|
cfg *config.Config
|
||||||
|
|
||||||
|
wg WireGuardManager
|
||||||
|
users UserManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize creates a new App instance and initializes it.
|
||||||
|
func Initialize(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
bus evbus.MessageBus,
|
wg WireGuardManager,
|
||||||
authenticator Authenticator,
|
|
||||||
users UserManager,
|
users UserManager,
|
||||||
wireGuard WireGuardManager,
|
) error {
|
||||||
stats StatisticsCollector,
|
|
||||||
cfgFiles ConfigFileManager,
|
|
||||||
mailer MailManager,
|
|
||||||
) (*App, error) {
|
|
||||||
|
|
||||||
a := &App{
|
a := &App{
|
||||||
Config: cfg,
|
cfg: cfg,
|
||||||
bus: bus,
|
|
||||||
|
|
||||||
Authenticator: authenticator,
|
wg: wg,
|
||||||
UserManager: users,
|
users: users,
|
||||||
WireGuardManager: wireGuard,
|
|
||||||
StatisticsCollector: stats,
|
|
||||||
ConfigFileManager: cfgFiles,
|
|
||||||
MailManager: mailer,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
@ -56,36 +53,27 @@ func New(
|
|||||||
startupContext = domain.SetUserInfo(startupContext, domain.SystemAdminContextUserInfo())
|
startupContext = domain.SetUserInfo(startupContext, domain.SystemAdminContextUserInfo())
|
||||||
|
|
||||||
if err := a.createDefaultUser(startupContext); err != nil {
|
if err := a.createDefaultUser(startupContext); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create default user: %w", err)
|
return fmt.Errorf("failed to create default user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := a.importNewInterfaces(startupContext); err != nil {
|
if err := a.importNewInterfaces(startupContext); err != nil {
|
||||||
return nil, fmt.Errorf("failed to import new interfaces: %w", err)
|
return fmt.Errorf("failed to import new interfaces: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := a.restoreInterfaceState(startupContext); err != nil {
|
if err := a.restoreInterfaceState(startupContext); err != nil {
|
||||||
return nil, fmt.Errorf("failed to restore interface state: %w", err)
|
return fmt.Errorf("failed to restore interface state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return a, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) Startup(ctx context.Context) error {
|
|
||||||
|
|
||||||
a.UserManager.StartBackgroundJobs(ctx)
|
|
||||||
a.StatisticsCollector.StartBackgroundJobs(ctx)
|
|
||||||
a.WireGuardManager.StartBackgroundJobs(ctx)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) importNewInterfaces(ctx context.Context) error {
|
func (a *App) importNewInterfaces(ctx context.Context) error {
|
||||||
if !a.Config.Core.ImportExisting {
|
if !a.cfg.Core.ImportExisting {
|
||||||
slog.Debug("skipping interface import - feature disabled")
|
slog.Debug("skipping interface import - feature disabled")
|
||||||
return nil // feature disabled
|
return nil // feature disabled
|
||||||
}
|
}
|
||||||
|
|
||||||
importedCount, err := a.ImportNewInterfaces(ctx)
|
importedCount, err := a.wg.ImportNewInterfaces(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -97,12 +85,12 @@ func (a *App) importNewInterfaces(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) restoreInterfaceState(ctx context.Context) error {
|
func (a *App) restoreInterfaceState(ctx context.Context) error {
|
||||||
if !a.Config.Core.RestoreState {
|
if !a.cfg.Core.RestoreState {
|
||||||
slog.Debug("skipping interface state restore - feature disabled")
|
slog.Debug("skipping interface state restore - feature disabled")
|
||||||
return nil // feature disabled
|
return nil // feature disabled
|
||||||
}
|
}
|
||||||
|
|
||||||
err := a.RestoreInterfaceState(ctx, true)
|
err := a.wg.RestoreInterfaceState(ctx, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -112,13 +100,13 @@ func (a *App) restoreInterfaceState(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) createDefaultUser(ctx context.Context) error {
|
func (a *App) createDefaultUser(ctx context.Context) error {
|
||||||
adminUserId := domain.UserIdentifier(a.Config.Core.AdminUser)
|
adminUserId := domain.UserIdentifier(a.cfg.Core.AdminUser)
|
||||||
if adminUserId == "" {
|
if adminUserId == "" {
|
||||||
slog.Debug("skipping default user creation - admin user is blank")
|
slog.Debug("skipping default user creation - admin user is blank")
|
||||||
return nil // empty admin user - do not create
|
return nil // empty admin user - do not create
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := a.GetUser(ctx, adminUserId)
|
_, err := a.users.GetUser(ctx, adminUserId)
|
||||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -145,22 +133,22 @@ func (a *App) createDefaultUser(ctx context.Context) error {
|
|||||||
Phone: "",
|
Phone: "",
|
||||||
Department: "",
|
Department: "",
|
||||||
Notes: "default administrator user",
|
Notes: "default administrator user",
|
||||||
Password: domain.PrivateString(a.Config.Core.AdminPassword),
|
Password: domain.PrivateString(a.cfg.Core.AdminPassword),
|
||||||
Disabled: nil,
|
Disabled: nil,
|
||||||
DisabledReason: "",
|
DisabledReason: "",
|
||||||
Locked: nil,
|
Locked: nil,
|
||||||
LockedReason: "",
|
LockedReason: "",
|
||||||
LinkedPeerCount: 0,
|
LinkedPeerCount: 0,
|
||||||
}
|
}
|
||||||
if a.Config.Core.AdminApiToken != "" {
|
if a.cfg.Core.AdminApiToken != "" {
|
||||||
if len(a.Config.Core.AdminApiToken) < 18 {
|
if len(a.cfg.Core.AdminApiToken) < 18 {
|
||||||
slog.Warn("admin API token is too short, should be at least 18 characters long")
|
slog.Warn("admin API token is too short, should be at least 18 characters long")
|
||||||
}
|
}
|
||||||
defaultAdmin.ApiToken = a.Config.Core.AdminApiToken
|
defaultAdmin.ApiToken = a.cfg.Core.AdminApiToken
|
||||||
defaultAdmin.ApiTokenCreated = &now
|
defaultAdmin.ApiTokenCreated = &now
|
||||||
}
|
}
|
||||||
|
|
||||||
admin, err := a.CreateUser(ctx, defaultAdmin)
|
admin, err := a.users.CreateUser(ctx, defaultAdmin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -6,21 +6,35 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
evbus "github.com/vardius/message-bus"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/app"
|
"github.com/h44z/wg-portal/internal/app"
|
||||||
"github.com/h44z/wg-portal/internal/config"
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// region dependencies
|
||||||
|
|
||||||
|
type DatabaseRepo interface {
|
||||||
|
// SaveAuditEntry saves an audit entry to the database
|
||||||
|
SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type EventBus interface {
|
||||||
|
// Subscribe subscribes to a topic
|
||||||
|
Subscribe(topic string, fn interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
|
// Recorder is responsible for recording audit events to the database.
|
||||||
type Recorder struct {
|
type Recorder struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
bus evbus.MessageBus
|
bus EventBus
|
||||||
|
|
||||||
db DatabaseRepo
|
db DatabaseRepo
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuditRecorder(cfg *config.Config, bus evbus.MessageBus, db DatabaseRepo) (*Recorder, error) {
|
// NewAuditRecorder creates a new audit recorder instance.
|
||||||
|
func NewAuditRecorder(cfg *config.Config, bus EventBus, db DatabaseRepo) (*Recorder, error) {
|
||||||
r := &Recorder{
|
r := &Recorder{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
bus: bus,
|
bus: bus,
|
||||||
@ -36,6 +50,8 @@ func NewAuditRecorder(cfg *config.Config, bus evbus.MessageBus, db DatabaseRepo)
|
|||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartBackgroundJobs starts background jobs for the audit recorder.
|
||||||
|
// This method is non-blocking and returns immediately.
|
||||||
func (r *Recorder) StartBackgroundJobs(ctx context.Context) {
|
func (r *Recorder) StartBackgroundJobs(ctx context.Context) {
|
||||||
if !r.cfg.Statistics.CollectAuditData {
|
if !r.cfg.Statistics.CollectAuditData {
|
||||||
return // noting to do
|
return // noting to do
|
||||||
|
@ -1,11 +0,0 @@
|
|||||||
package audit
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
|
||||||
)
|
|
||||||
|
|
||||||
type DatabaseRepo interface {
|
|
||||||
SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error
|
|
||||||
}
|
|
@ -14,25 +14,78 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
evbus "github.com/vardius/message-bus"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/app"
|
"github.com/h44z/wg-portal/internal/app"
|
||||||
"github.com/h44z/wg-portal/internal/config"
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// region dependencies
|
||||||
|
|
||||||
type UserManager interface {
|
type UserManager interface {
|
||||||
|
// GetUser returns a user by its identifier.
|
||||||
GetUser(context.Context, domain.UserIdentifier) (*domain.User, error)
|
GetUser(context.Context, domain.UserIdentifier) (*domain.User, error)
|
||||||
|
// RegisterUser creates a new user in the database.
|
||||||
RegisterUser(ctx context.Context, user *domain.User) error
|
RegisterUser(ctx context.Context, user *domain.User) error
|
||||||
|
// UpdateUser updates an existing user in the database.
|
||||||
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EventBus interface {
|
||||||
|
// Publish sends a message to the message bus.
|
||||||
|
Publish(topic string, args ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
|
type AuthenticatorType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuthenticatorTypeOAuth AuthenticatorType = "oauth"
|
||||||
|
AuthenticatorTypeOidc AuthenticatorType = "oidc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticatorOauth is the interface for all OAuth authenticators.
|
||||||
|
type AuthenticatorOauth interface {
|
||||||
|
// GetName returns the name of the authenticator.
|
||||||
|
GetName() string
|
||||||
|
// GetType returns the type of the authenticator. It can be either AuthenticatorTypeOAuth or AuthenticatorTypeOidc.
|
||||||
|
GetType() AuthenticatorType
|
||||||
|
// AuthCodeURL returns the URL for the authentication flow.
|
||||||
|
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
|
||||||
|
// Exchange exchanges the OAuth code for an access token.
|
||||||
|
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
|
||||||
|
// GetUserInfo fetches the user information from the OAuth or OIDC provider.
|
||||||
|
GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]any, error)
|
||||||
|
// ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct.
|
||||||
|
ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error)
|
||||||
|
// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator.
|
||||||
|
RegistrationEnabled() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticatorLdap is the interface for all LDAP authenticators.
|
||||||
|
type AuthenticatorLdap interface {
|
||||||
|
// GetName returns the name of the authenticator.
|
||||||
|
GetName() string
|
||||||
|
// PlaintextAuthentication performs a plaintext authentication against the LDAP server.
|
||||||
|
PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error
|
||||||
|
// GetUserInfo fetches the user information from the LDAP server.
|
||||||
|
GetUserInfo(ctx context.Context, username domain.UserIdentifier) (map[string]any, error)
|
||||||
|
// ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct.
|
||||||
|
ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error)
|
||||||
|
// RegistrationEnabled returns whether registration is enabled for the LDAP authenticator.
|
||||||
|
RegistrationEnabled() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticator is the main entry point for all authentication related tasks.
|
||||||
|
// This includes password authentication and external authentication providers (OIDC, OAuth, LDAP).
|
||||||
type Authenticator struct {
|
type Authenticator struct {
|
||||||
cfg *config.Auth
|
cfg *config.Auth
|
||||||
bus evbus.MessageBus
|
bus EventBus
|
||||||
|
|
||||||
oauthAuthenticators map[string]domain.OauthAuthenticator
|
oauthAuthenticators map[string]AuthenticatorOauth
|
||||||
ldapAuthenticators map[string]domain.LdapAuthenticator
|
ldapAuthenticators map[string]AuthenticatorLdap
|
||||||
|
|
||||||
// URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix
|
// URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix
|
||||||
callbackUrlPrefix string
|
callbackUrlPrefix string
|
||||||
@ -40,7 +93,8 @@ type Authenticator struct {
|
|||||||
users UserManager
|
users UserManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthenticator(cfg *config.Auth, extUrl string, bus evbus.MessageBus, users UserManager) (
|
// NewAuthenticator creates a new Authenticator instance.
|
||||||
|
func NewAuthenticator(cfg *config.Auth, extUrl string, bus EventBus, users UserManager) (
|
||||||
*Authenticator,
|
*Authenticator,
|
||||||
error,
|
error,
|
||||||
) {
|
) {
|
||||||
@ -68,8 +122,8 @@ func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
|
|||||||
return fmt.Errorf("failed to parse external url: %w", err)
|
return fmt.Errorf("failed to parse external url: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
a.oauthAuthenticators = make(map[string]domain.OauthAuthenticator, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth))
|
a.oauthAuthenticators = make(map[string]AuthenticatorOauth, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth))
|
||||||
a.ldapAuthenticators = make(map[string]domain.LdapAuthenticator, len(a.cfg.Ldap))
|
a.ldapAuthenticators = make(map[string]AuthenticatorLdap, len(a.cfg.Ldap))
|
||||||
|
|
||||||
for i := range a.cfg.OpenIDConnect { // OIDC
|
for i := range a.cfg.OpenIDConnect { // OIDC
|
||||||
providerCfg := &a.cfg.OpenIDConnect[i]
|
providerCfg := &a.cfg.OpenIDConnect[i]
|
||||||
@ -123,6 +177,7 @@ func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetExternalLoginProviders returns a list of all available external login providers.
|
||||||
func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo {
|
func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo {
|
||||||
authProviders := make([]domain.LoginProviderInfo, 0, len(a.cfg.OAuth)+len(a.cfg.OpenIDConnect))
|
authProviders := make([]domain.LoginProviderInfo, 0, len(a.cfg.OAuth)+len(a.cfg.OpenIDConnect))
|
||||||
|
|
||||||
@ -157,6 +212,7 @@ func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.Lo
|
|||||||
return authProviders
|
return authProviders
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsUserValid checks if a user is valid and not locked or disabled.
|
||||||
func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool {
|
func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool {
|
||||||
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context
|
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context
|
||||||
user, err := a.users.GetUser(ctx, id)
|
user, err := a.users.GetUser(ctx, id)
|
||||||
@ -177,6 +233,8 @@ func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifie
|
|||||||
|
|
||||||
// region password authentication
|
// region password authentication
|
||||||
|
|
||||||
|
// PlainLogin performs a password authentication for a user. The username and password are trimmed before usage.
|
||||||
|
// If the login is successful, the user is returned, otherwise an error.
|
||||||
func (a *Authenticator) PlainLogin(ctx context.Context, username, password string) (*domain.User, error) {
|
func (a *Authenticator) PlainLogin(ctx context.Context, username, password string) (*domain.User, error) {
|
||||||
// Validate form input
|
// Validate form input
|
||||||
username = strings.TrimSpace(username)
|
username = strings.TrimSpace(username)
|
||||||
@ -204,7 +262,7 @@ func (a *Authenticator) passwordAuthentication(
|
|||||||
domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
|
domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
|
||||||
|
|
||||||
var ldapUserInfo *domain.AuthenticatorUserInfo
|
var ldapUserInfo *domain.AuthenticatorUserInfo
|
||||||
var ldapProvider domain.LdapAuthenticator
|
var ldapProvider AuthenticatorLdap
|
||||||
|
|
||||||
var userInDatabase = false
|
var userInDatabase = false
|
||||||
var userSource domain.UserSource
|
var userSource domain.UserSource
|
||||||
@ -280,6 +338,7 @@ func (a *Authenticator) passwordAuthentication(
|
|||||||
|
|
||||||
// region oauth authentication
|
// region oauth authentication
|
||||||
|
|
||||||
|
// OauthLoginStep1 starts the oauth authentication flow by returning the authentication URL, state and nonce.
|
||||||
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
|
func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
|
||||||
authCodeUrl, state, nonce string,
|
authCodeUrl, state, nonce string,
|
||||||
err error,
|
err error,
|
||||||
@ -296,9 +355,9 @@ func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch oauthProvider.GetType() {
|
switch oauthProvider.GetType() {
|
||||||
case domain.AuthenticatorTypeOAuth:
|
case AuthenticatorTypeOAuth:
|
||||||
authCodeUrl = oauthProvider.AuthCodeURL(state)
|
authCodeUrl = oauthProvider.AuthCodeURL(state)
|
||||||
case domain.AuthenticatorTypeOidc:
|
case AuthenticatorTypeOidc:
|
||||||
nonce, err = a.randString(16)
|
nonce, err = a.randString(16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", fmt.Errorf("failed to generate nonce: %w", err)
|
return "", "", "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||||
@ -318,6 +377,8 @@ func (a *Authenticator) randString(nByte int) (string, error) {
|
|||||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OauthLoginStep2 finishes the oauth authentication flow by exchanging the code for an access token and
|
||||||
|
// fetching the user information.
|
||||||
func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error) {
|
func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error) {
|
||||||
oauthProvider, ok := a.oauthAuthenticators[providerId]
|
oauthProvider, ok := a.oauthAuthenticators[providerId]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// LdapAuthenticator is an authenticator that uses LDAP for authentication.
|
||||||
type LdapAuthenticator struct {
|
type LdapAuthenticator struct {
|
||||||
cfg *config.LdapProvider
|
cfg *config.LdapProvider
|
||||||
}
|
}
|
||||||
@ -33,14 +34,17 @@ func newLdapAuthenticator(_ context.Context, cfg *config.LdapProvider) (*LdapAut
|
|||||||
return provider, nil
|
return provider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetName returns the name of the LDAP authenticator.
|
||||||
func (l LdapAuthenticator) GetName() string {
|
func (l LdapAuthenticator) GetName() string {
|
||||||
return l.cfg.ProviderName
|
return l.cfg.ProviderName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegistrationEnabled returns whether registration is enabled for the LDAP authenticator.
|
||||||
func (l LdapAuthenticator) RegistrationEnabled() bool {
|
func (l LdapAuthenticator) RegistrationEnabled() bool {
|
||||||
return l.cfg.RegistrationEnabled
|
return l.cfg.RegistrationEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PlaintextAuthentication performs a plaintext authentication against the LDAP server.
|
||||||
func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error {
|
func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error {
|
||||||
conn, err := internal.LdapConnect(l.cfg)
|
conn, err := internal.LdapConnect(l.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -81,6 +85,9 @@ func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserInfo retrieves user information from the LDAP server.
|
||||||
|
// If the user is not found, domain.ErrNotFound is returned.
|
||||||
|
// If multiple users are found, domain.ErrNotUnique is returned.
|
||||||
func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIdentifier) (
|
func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIdentifier) (
|
||||||
map[string]any,
|
map[string]any,
|
||||||
error,
|
error,
|
||||||
@ -126,6 +133,7 @@ func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIden
|
|||||||
return users[0], nil
|
return users[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseUserInfo parses the user information from the LDAP server into a domain.AuthenticatorUserInfo struct.
|
||||||
func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
||||||
isAdmin, err := internal.LdapIsMemberOf(raw[l.cfg.FieldMap.GroupMembership].([][]byte), l.cfg.ParsedAdminGroupDN)
|
isAdmin, err := internal.LdapIsMemberOf(raw[l.cfg.FieldMap.GroupMembership].([][]byte), l.cfg.ParsedAdminGroupDN)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -16,6 +16,8 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PlainOauthAuthenticator is an authenticator that uses OAuth for authentication.
|
||||||
|
// User information is retrieved from the specified user info endpoint.
|
||||||
type PlainOauthAuthenticator struct {
|
type PlainOauthAuthenticator struct {
|
||||||
name string
|
name string
|
||||||
cfg *oauth2.Config
|
cfg *oauth2.Config
|
||||||
@ -58,22 +60,27 @@ func newPlainOauthAuthenticator(
|
|||||||
return provider, nil
|
return provider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetName returns the name of the OAuth authenticator.
|
||||||
func (p PlainOauthAuthenticator) GetName() string {
|
func (p PlainOauthAuthenticator) GetName() string {
|
||||||
return p.name
|
return p.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator.
|
||||||
func (p PlainOauthAuthenticator) RegistrationEnabled() bool {
|
func (p PlainOauthAuthenticator) RegistrationEnabled() bool {
|
||||||
return p.registrationEnabled
|
return p.registrationEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p PlainOauthAuthenticator) GetType() domain.AuthenticatorType {
|
// GetType returns the type of the authenticator.
|
||||||
return domain.AuthenticatorTypeOAuth
|
func (p PlainOauthAuthenticator) GetType() AuthenticatorType {
|
||||||
|
return AuthenticatorTypeOAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthCodeURL returns the URL to redirect the user to for authentication.
|
||||||
func (p PlainOauthAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
func (p PlainOauthAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
||||||
return p.cfg.AuthCodeURL(state, opts...)
|
return p.cfg.AuthCodeURL(state, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Exchange exchanges the OAuth code for a token.
|
||||||
func (p PlainOauthAuthenticator) Exchange(
|
func (p PlainOauthAuthenticator) Exchange(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
code string,
|
code string,
|
||||||
@ -82,6 +89,7 @@ func (p PlainOauthAuthenticator) Exchange(
|
|||||||
return p.cfg.Exchange(ctx, code, opts...)
|
return p.cfg.Exchange(ctx, code, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserInfo retrieves the user information from the user info endpoint.
|
||||||
func (p PlainOauthAuthenticator) GetUserInfo(
|
func (p PlainOauthAuthenticator) GetUserInfo(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
token *oauth2.Token,
|
token *oauth2.Token,
|
||||||
@ -119,6 +127,7 @@ func (p PlainOauthAuthenticator) GetUserInfo(
|
|||||||
return userFields, nil
|
return userFields, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseUserInfo parses the user information from the raw data.
|
||||||
func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
||||||
return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw)
|
return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw)
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OidcAuthenticator is an authenticator for OpenID Connect providers.
|
||||||
type OidcAuthenticator struct {
|
type OidcAuthenticator struct {
|
||||||
name string
|
name string
|
||||||
provider *oidc.Provider
|
provider *oidc.Provider
|
||||||
@ -60,22 +61,27 @@ func newOidcAuthenticator(
|
|||||||
return provider, nil
|
return provider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetName returns the name of the authenticator.
|
||||||
func (o OidcAuthenticator) GetName() string {
|
func (o OidcAuthenticator) GetName() string {
|
||||||
return o.name
|
return o.name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegistrationEnabled returns whether registration is enabled for this authenticator.
|
||||||
func (o OidcAuthenticator) RegistrationEnabled() bool {
|
func (o OidcAuthenticator) RegistrationEnabled() bool {
|
||||||
return o.registrationEnabled
|
return o.registrationEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o OidcAuthenticator) GetType() domain.AuthenticatorType {
|
// GetType returns the type of the authenticator.
|
||||||
return domain.AuthenticatorTypeOidc
|
func (o OidcAuthenticator) GetType() AuthenticatorType {
|
||||||
|
return AuthenticatorTypeOidc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuthCodeURL returns the URL for the OAuth2 flow.
|
||||||
func (o OidcAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
func (o OidcAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
||||||
return o.cfg.AuthCodeURL(state, opts...)
|
return o.cfg.AuthCodeURL(state, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Exchange exchanges the code for a token.
|
||||||
func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (
|
func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (
|
||||||
*oauth2.Token,
|
*oauth2.Token,
|
||||||
error,
|
error,
|
||||||
@ -83,6 +89,7 @@ func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oa
|
|||||||
return o.cfg.Exchange(ctx, code, opts...)
|
return o.cfg.Exchange(ctx, code, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserInfo retrieves the user info from the token.
|
||||||
func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (
|
func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (
|
||||||
map[string]any,
|
map[string]any,
|
||||||
error,
|
error,
|
||||||
@ -114,6 +121,7 @@ func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token,
|
|||||||
return tokenFields, nil
|
return tokenFields, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseUserInfo parses the user info.
|
||||||
func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) {
|
||||||
return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw)
|
return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw)
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
evbus "github.com/vardius/message-bus"
|
|
||||||
"github.com/yeqown/go-qrcode/v2"
|
"github.com/yeqown/go-qrcode/v2"
|
||||||
"github.com/yeqown/go-qrcode/writer/compressed"
|
"github.com/yeqown/go-qrcode/writer/compressed"
|
||||||
|
|
||||||
@ -19,19 +18,56 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager struct {
|
// region dependencies
|
||||||
cfg *config.Config
|
|
||||||
bus evbus.MessageBus
|
|
||||||
tplHandler *TemplateHandler
|
|
||||||
|
|
||||||
fsRepo FileSystemRepo
|
type UserDatabaseRepo interface {
|
||||||
users UserDatabaseRepo
|
// GetUser returns the user with the given identifier from the SQL database.
|
||||||
wg WireguardDatabaseRepo
|
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WireguardDatabaseRepo interface {
|
||||||
|
// GetInterfaceAndPeers returns the interface and all peers associated with it.
|
||||||
|
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||||
|
// GetPeer returns the peer with the given identifier.
|
||||||
|
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||||
|
// GetInterface returns the interface with the given identifier.
|
||||||
|
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type FileSystemRepo interface {
|
||||||
|
// WriteFile writes the contents to the file at the given path.
|
||||||
|
WriteFile(path string, contents io.Reader) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TemplateRenderer interface {
|
||||||
|
// GetInterfaceConfig returns the configuration file for the given interface.
|
||||||
|
GetInterfaceConfig(iface *domain.Interface, peers []domain.Peer) (io.Reader, error)
|
||||||
|
// GetPeerConfig returns the configuration file for the given peer.
|
||||||
|
GetPeerConfig(peer *domain.Peer) (io.Reader, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type EventBus interface {
|
||||||
|
// Subscribe subscribes to the given topic.
|
||||||
|
Subscribe(topic string, fn any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
|
// Manager is responsible for managing the configuration files of the WireGuard interfaces and peers.
|
||||||
|
type Manager struct {
|
||||||
|
cfg *config.Config
|
||||||
|
bus EventBus
|
||||||
|
|
||||||
|
tplHandler TemplateRenderer
|
||||||
|
fsRepo FileSystemRepo
|
||||||
|
users UserDatabaseRepo
|
||||||
|
wg WireguardDatabaseRepo
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConfigFileManager creates a new Manager instance.
|
||||||
func NewConfigFileManager(
|
func NewConfigFileManager(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
bus evbus.MessageBus,
|
bus EventBus,
|
||||||
users UserDatabaseRepo,
|
users UserDatabaseRepo,
|
||||||
wg WireguardDatabaseRepo,
|
wg WireguardDatabaseRepo,
|
||||||
fsRepo FileSystemRepo,
|
fsRepo FileSystemRepo,
|
||||||
@ -115,6 +151,8 @@ func (m Manager) handlePeerInterfaceUpdatedEvent(id domain.InterfaceIdentifier)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetInterfaceConfig returns the configuration file for the given interface.
|
||||||
|
// The file is structured in wg-quick format.
|
||||||
func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
|
func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -128,6 +166,8 @@ func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIden
|
|||||||
return m.tplHandler.GetInterfaceConfig(iface, peers)
|
return m.tplHandler.GetInterfaceConfig(iface, peers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerConfig returns the configuration file for the given peer.
|
||||||
|
// The file is structured in wg-quick format.
|
||||||
func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
|
func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
|
||||||
peer, err := m.wg.GetPeer(ctx, id)
|
peer, err := m.wg.GetPeer(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -141,6 +181,7 @@ func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (i
|
|||||||
return m.tplHandler.GetPeerConfig(peer)
|
return m.tplHandler.GetPeerConfig(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerConfigQrCode returns a QR code image containing the configuration for the given peer.
|
||||||
func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
|
func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
|
||||||
peer, err := m.wg.GetPeer(ctx, id)
|
peer, err := m.wg.GetPeer(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -191,6 +232,7 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
|
|||||||
return buf, nil
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PersistInterfaceConfig writes the configuration file for the given interface to the file system.
|
||||||
func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
|
func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||||
iface, peers, err := m.wg.GetInterfaceAndPeers(ctx, id)
|
iface, peers, err := m.wg.GetInterfaceAndPeers(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -213,4 +255,5 @@ type nopCloser struct {
|
|||||||
io.Writer
|
io.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close is a no-op for the nopCloser.
|
||||||
func (nopCloser) Close() error { return nil }
|
func (nopCloser) Close() error { return nil }
|
||||||
|
@ -1,22 +0,0 @@
|
|||||||
package configfile
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
|
||||||
)
|
|
||||||
|
|
||||||
type UserDatabaseRepo interface {
|
|
||||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type WireguardDatabaseRepo interface {
|
|
||||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
|
||||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
|
||||||
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type FileSystemRepo interface {
|
|
||||||
WriteFile(path string, contents io.Reader) error
|
|
||||||
}
|
|
@ -13,6 +13,8 @@ import (
|
|||||||
//go:embed tpl_files/*
|
//go:embed tpl_files/*
|
||||||
var TemplateFiles embed.FS
|
var TemplateFiles embed.FS
|
||||||
|
|
||||||
|
// TemplateHandler is responsible for rendering the WireGuard configuration files
|
||||||
|
// based on the provided templates.
|
||||||
type TemplateHandler struct {
|
type TemplateHandler struct {
|
||||||
templates *template.Template
|
templates *template.Template
|
||||||
}
|
}
|
||||||
@ -34,6 +36,7 @@ func newTemplateHandler() (*TemplateHandler, error) {
|
|||||||
return handler, nil
|
return handler, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetInterfaceConfig returns the rendered configuration file for a WireGuard interface.
|
||||||
func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domain.Peer) (io.Reader, error) {
|
func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domain.Peer) (io.Reader, error) {
|
||||||
var tplBuff bytes.Buffer
|
var tplBuff bytes.Buffer
|
||||||
|
|
||||||
@ -51,6 +54,7 @@ func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domai
|
|||||||
return &tplBuff, nil
|
return &tplBuff, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerConfig returns the rendered configuration file for a WireGuard peer.
|
||||||
func (c TemplateHandler) GetPeerConfig(peer *domain.Peer) (io.Reader, error) {
|
func (c TemplateHandler) GetPeerConfig(peer *domain.Peer) (io.Reader, error) {
|
||||||
var tplBuff bytes.Buffer
|
var tplBuff bytes.Buffer
|
||||||
|
|
||||||
|
@ -10,16 +10,60 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager struct {
|
// region dependencies
|
||||||
cfg *config.Config
|
|
||||||
tplHandler *TemplateHandler
|
|
||||||
|
|
||||||
|
type Mailer interface {
|
||||||
|
// Send sends an email with the given subject and body to the given recipients.
|
||||||
|
Send(ctx context.Context, subject, body string, to []string, options *domain.MailOptions) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConfigFileManager interface {
|
||||||
|
// GetInterfaceConfig returns the configuration for the given interface.
|
||||||
|
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
|
||||||
|
// GetPeerConfig returns the configuration for the given peer.
|
||||||
|
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||||
|
// GetPeerConfigQrCode returns the QR code for the given peer.
|
||||||
|
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserDatabaseRepo interface {
|
||||||
|
// GetUser returns the user with the given identifier.
|
||||||
|
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type WireguardDatabaseRepo interface {
|
||||||
|
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
|
||||||
|
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||||
|
// GetPeer returns the peer with the given identifier.
|
||||||
|
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||||
|
// GetInterface returns the interface with the given identifier.
|
||||||
|
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TemplateRenderer interface {
|
||||||
|
// GetConfigMail returns the text and html template for the mail with a link.
|
||||||
|
GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error)
|
||||||
|
// GetConfigMailWithAttachment returns the text and html template for the mail with an attachment.
|
||||||
|
GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) (
|
||||||
|
io.Reader,
|
||||||
|
io.Reader,
|
||||||
|
error,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
cfg *config.Config
|
||||||
|
|
||||||
|
tplHandler TemplateRenderer
|
||||||
mailer Mailer
|
mailer Mailer
|
||||||
configFiles ConfigFileManager
|
configFiles ConfigFileManager
|
||||||
users UserDatabaseRepo
|
users UserDatabaseRepo
|
||||||
wg WireguardDatabaseRepo
|
wg WireguardDatabaseRepo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewMailManager creates a new mail manager.
|
||||||
func NewMailManager(
|
func NewMailManager(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
mailer Mailer,
|
mailer Mailer,
|
||||||
@ -44,6 +88,7 @@ func NewMailManager(
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SendPeerEmail sends an email to the user linked to the given peers.
|
||||||
func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error {
|
func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error {
|
||||||
for _, peerId := range peers {
|
for _, peerId := range peers {
|
||||||
peer, err := m.wg.GetPeer(ctx, peerId)
|
peer, err := m.wg.GetPeer(ctx, peerId)
|
||||||
|
@ -1,28 +0,0 @@
|
|||||||
package mail
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Mailer interface {
|
|
||||||
Send(ctx context.Context, subject, body string, to []string, options *domain.MailOptions) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConfigFileManager interface {
|
|
||||||
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
|
|
||||||
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
|
||||||
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserDatabaseRepo interface {
|
|
||||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type WireguardDatabaseRepo interface {
|
|
||||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
|
||||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
|
||||||
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
|
||||||
}
|
|
@ -14,6 +14,7 @@ import (
|
|||||||
//go:embed tpl_files/*
|
//go:embed tpl_files/*
|
||||||
var TemplateFiles embed.FS
|
var TemplateFiles embed.FS
|
||||||
|
|
||||||
|
// TemplateHandler is a struct that holds the html and text templates.
|
||||||
type TemplateHandler struct {
|
type TemplateHandler struct {
|
||||||
portalUrl string
|
portalUrl string
|
||||||
htmlTemplates *htmlTemplate.Template
|
htmlTemplates *htmlTemplate.Template
|
||||||
@ -40,6 +41,7 @@ func newTemplateHandler(portalUrl string) (*TemplateHandler, error) {
|
|||||||
return handler, nil
|
return handler, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfigMail returns the text and html template for the mail with a link.
|
||||||
func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error) {
|
func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error) {
|
||||||
var tplBuff bytes.Buffer
|
var tplBuff bytes.Buffer
|
||||||
var htmlTplBuff bytes.Buffer
|
var htmlTplBuff bytes.Buffer
|
||||||
@ -65,6 +67,7 @@ func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reade
|
|||||||
return &tplBuff, &htmlTplBuff, nil
|
return &tplBuff, &htmlTplBuff, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfigMailWithAttachment returns the text and html template for the mail with an attachment.
|
||||||
func (c TemplateHandler) GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) (
|
func (c TemplateHandler) GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) (
|
||||||
io.Reader,
|
io.Reader,
|
||||||
io.Reader,
|
io.Reader,
|
||||||
|
@ -1,77 +0,0 @@
|
|||||||
package app
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Authenticator interface {
|
|
||||||
GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo
|
|
||||||
IsUserValid(ctx context.Context, id domain.UserIdentifier) bool
|
|
||||||
PlainLogin(ctx context.Context, username, password string) (*domain.User, error)
|
|
||||||
OauthLoginStep1(_ context.Context, providerId string) (authCodeUrl, state, nonce string, err error)
|
|
||||||
OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserManager interface {
|
|
||||||
RegisterUser(ctx context.Context, user *domain.User) error
|
|
||||||
NewUser(ctx context.Context, user *domain.User) error
|
|
||||||
StartBackgroundJobs(ctx context.Context)
|
|
||||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
|
||||||
GetAllUsers(ctx context.Context) ([]domain.User, error)
|
|
||||||
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
|
||||||
CreateUser(ctx context.Context, user *domain.User) (*domain.User, error)
|
|
||||||
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
|
|
||||||
ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
|
||||||
DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type WireGuardManager interface {
|
|
||||||
StartBackgroundJobs(ctx context.Context)
|
|
||||||
GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error)
|
|
||||||
ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error)
|
|
||||||
RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error
|
|
||||||
CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error
|
|
||||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
|
||||||
GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error)
|
|
||||||
GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error)
|
|
||||||
GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error)
|
|
||||||
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
|
||||||
GetUserInterfaces(ctx context.Context, id domain.UserIdentifier) ([]domain.Interface, error)
|
|
||||||
PrepareInterface(ctx context.Context) (*domain.Interface, error)
|
|
||||||
CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error)
|
|
||||||
UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error)
|
|
||||||
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
|
|
||||||
PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error)
|
|
||||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
|
||||||
CreatePeer(ctx context.Context, p *domain.Peer) (*domain.Peer, error)
|
|
||||||
CreateMultiplePeers(
|
|
||||||
ctx context.Context,
|
|
||||||
id domain.InterfaceIdentifier,
|
|
||||||
r *domain.PeerCreationRequest,
|
|
||||||
) ([]domain.Peer, error)
|
|
||||||
UpdatePeer(ctx context.Context, p *domain.Peer) (*domain.Peer, error)
|
|
||||||
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
|
||||||
ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type StatisticsCollector interface {
|
|
||||||
StartBackgroundJobs(ctx context.Context)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConfigFileManager interface {
|
|
||||||
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
|
|
||||||
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
|
||||||
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
|
|
||||||
PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type MailManager interface {
|
|
||||||
SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type ApiV1Manager interface {
|
|
||||||
ApiV1GetUsers(ctx context.Context) ([]domain.User, error)
|
|
||||||
}
|
|
@ -1,12 +0,0 @@
|
|||||||
package route
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
|
||||||
)
|
|
||||||
|
|
||||||
type InterfaceAndPeerDatabaseRepo interface {
|
|
||||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
|
||||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
|
||||||
}
|
|
@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
evbus "github.com/vardius/message-bus"
|
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
@ -17,6 +16,22 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// region dependencies
|
||||||
|
|
||||||
|
type InterfaceAndPeerDatabaseRepo interface {
|
||||||
|
// GetAllInterfaces returns all interfaces
|
||||||
|
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||||
|
// GetInterfacePeers returns all peers for a given interface
|
||||||
|
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type EventBus interface {
|
||||||
|
// Subscribe subscribes to a topic
|
||||||
|
Subscribe(topic string, fn interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
type routeRuleInfo struct {
|
type routeRuleInfo struct {
|
||||||
ifaceId domain.InterfaceIdentifier
|
ifaceId domain.InterfaceIdentifier
|
||||||
fwMark uint32
|
fwMark uint32
|
||||||
@ -29,14 +44,15 @@ type routeRuleInfo struct {
|
|||||||
// for default routes.
|
// for default routes.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
bus evbus.MessageBus
|
|
||||||
|
|
||||||
wg lowlevel.WireGuardClient
|
bus EventBus
|
||||||
nl lowlevel.NetlinkClient
|
wg lowlevel.WireGuardClient
|
||||||
db InterfaceAndPeerDatabaseRepo
|
nl lowlevel.NetlinkClient
|
||||||
|
db InterfaceAndPeerDatabaseRepo
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRouteManager(cfg *config.Config, bus evbus.MessageBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
|
// NewRouteManager creates a new route manager instance.
|
||||||
|
func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
|
||||||
wg, err := wgctrl.New()
|
wg, err := wgctrl.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("failed to init wgctrl: " + err.Error())
|
panic("failed to init wgctrl: " + err.Error())
|
||||||
@ -63,7 +79,10 @@ func (m Manager) connectToMessageBus() {
|
|||||||
_ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteRemoveEvent)
|
_ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteRemoveEvent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartBackgroundJobs starts background jobs for the route manager.
|
||||||
|
// This method is non-blocking and returns immediately.
|
||||||
func (m Manager) StartBackgroundJobs(_ context.Context) {
|
func (m Manager) StartBackgroundJobs(_ context.Context) {
|
||||||
|
// this is a no-op for now
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Manager) handleRouteUpdateEvent(srcDescription string) {
|
func (m Manager) handleRouteUpdateEvent(srcDescription string) {
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
package users
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
|
||||||
)
|
|
||||||
|
|
||||||
type UserDatabaseRepo interface {
|
|
||||||
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
|
||||||
GetUserByEmail(ctx context.Context, email string) (*domain.User, error)
|
|
||||||
GetAllUsers(ctx context.Context) ([]domain.User, error)
|
|
||||||
FindUsers(ctx context.Context, search string) ([]domain.User, error)
|
|
||||||
SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error
|
|
||||||
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type PeerDatabaseRepo interface {
|
|
||||||
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
|
||||||
}
|
|
@ -11,7 +11,6 @@ import (
|
|||||||
|
|
||||||
"github.com/go-ldap/ldap/v3"
|
"github.com/go-ldap/ldap/v3"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
evbus "github.com/vardius/message-bus"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal"
|
"github.com/h44z/wg-portal/internal"
|
||||||
"github.com/h44z/wg-portal/internal/app"
|
"github.com/h44z/wg-portal/internal/app"
|
||||||
@ -19,15 +18,46 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// region dependencies
|
||||||
|
|
||||||
|
type UserDatabaseRepo interface {
|
||||||
|
// GetUser returns the user with the given identifier.
|
||||||
|
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
|
||||||
|
// GetUserByEmail returns the user with the given email address.
|
||||||
|
GetUserByEmail(ctx context.Context, email string) (*domain.User, error)
|
||||||
|
// GetAllUsers returns all users.
|
||||||
|
GetAllUsers(ctx context.Context) ([]domain.User, error)
|
||||||
|
// FindUsers returns all users matching the search string.
|
||||||
|
FindUsers(ctx context.Context, search string) ([]domain.User, error)
|
||||||
|
// SaveUser saves the user with the given identifier.
|
||||||
|
SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error
|
||||||
|
// DeleteUser deletes the user with the given identifier.
|
||||||
|
DeleteUser(ctx context.Context, id domain.UserIdentifier) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeerDatabaseRepo interface {
|
||||||
|
// GetUserPeers returns all peers linked to the given user.
|
||||||
|
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type EventBus interface {
|
||||||
|
// Publish sends a message to the message bus.
|
||||||
|
Publish(topic string, args ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
|
// Manager is the user manager.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
bus evbus.MessageBus
|
|
||||||
|
|
||||||
|
bus EventBus
|
||||||
users UserDatabaseRepo
|
users UserDatabaseRepo
|
||||||
peers PeerDatabaseRepo
|
peers PeerDatabaseRepo
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabaseRepo, peers PeerDatabaseRepo) (
|
// NewUserManager creates a new user manager instance.
|
||||||
|
func NewUserManager(cfg *config.Config, bus EventBus, users UserDatabaseRepo, peers PeerDatabaseRepo) (
|
||||||
*Manager,
|
*Manager,
|
||||||
error,
|
error,
|
||||||
) {
|
) {
|
||||||
@ -41,6 +71,7 @@ func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabase
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterUser registers a new user.
|
||||||
func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
|
func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -56,6 +87,7 @@ func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewUser creates a new user.
|
||||||
func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
|
func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
|
||||||
if user.Identifier == "" {
|
if user.Identifier == "" {
|
||||||
return errors.New("missing user identifier")
|
return errors.New("missing user identifier")
|
||||||
@ -90,12 +122,13 @@ func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartBackgroundJobs starts the background jobs.
|
||||||
|
// This method is non-blocking and returns immediately.
|
||||||
func (m Manager) StartBackgroundJobs(ctx context.Context) {
|
func (m Manager) StartBackgroundJobs(ctx context.Context) {
|
||||||
|
|
||||||
go m.runLdapSynchronizationService(ctx)
|
go m.runLdapSynchronizationService(ctx)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUser returns the user with the given identifier.
|
||||||
func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||||
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -112,6 +145,7 @@ func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserByEmail returns the user with the given email address.
|
||||||
func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) {
|
func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||||
|
|
||||||
user, err := m.users.GetUserByEmail(ctx, email)
|
user, err := m.users.GetUserByEmail(ctx, email)
|
||||||
@ -130,6 +164,7 @@ func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAllUsers returns all users.
|
||||||
func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
|
func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -162,6 +197,7 @@ func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
|
|||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateUser updates the user with the given identifier.
|
||||||
func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
||||||
if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil {
|
if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -203,6 +239,7 @@ func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.Use
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateUser creates a new user.
|
||||||
func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -236,6 +273,7 @@ func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.Use
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteUser deletes the user with the given identifier.
|
||||||
func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
|
func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -260,6 +298,7 @@ func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ActivateApi activates the API access for the user with the given identifier.
|
||||||
func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||||
user, err := m.users.GetUser(ctx, id)
|
user, err := m.users.GetUser(ctx, id)
|
||||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||||
@ -287,6 +326,7 @@ func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*do
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeactivateApi deactivates the API access for the user with the given identifier.
|
||||||
func (m Manager) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
func (m Manager) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
||||||
user, err := m.users.GetUser(ctx, id)
|
user, err := m.users.GetUser(ctx, id)
|
||||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||||
|
@ -1,87 +0,0 @@
|
|||||||
package wireguard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
|
||||||
)
|
|
||||||
|
|
||||||
type InterfaceAndPeerDatabaseRepo interface {
|
|
||||||
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
|
||||||
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
|
||||||
GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error)
|
|
||||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
|
||||||
FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error)
|
|
||||||
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
|
|
||||||
SaveInterface(
|
|
||||||
ctx context.Context,
|
|
||||||
id domain.InterfaceIdentifier,
|
|
||||||
updateFunc func(in *domain.Interface) (*domain.Interface, error),
|
|
||||||
) error
|
|
||||||
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
|
|
||||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
|
||||||
FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error)
|
|
||||||
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
|
||||||
FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error)
|
|
||||||
SavePeer(
|
|
||||||
ctx context.Context,
|
|
||||||
id domain.PeerIdentifier,
|
|
||||||
updateFunc func(in *domain.Peer) (*domain.Peer, error),
|
|
||||||
) error
|
|
||||||
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
|
||||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
|
||||||
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type StatisticsDatabaseRepo interface {
|
|
||||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
|
||||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
|
||||||
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
|
||||||
|
|
||||||
UpdatePeerStatus(
|
|
||||||
ctx context.Context,
|
|
||||||
id domain.PeerIdentifier,
|
|
||||||
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
|
|
||||||
) error
|
|
||||||
UpdateInterfaceStatus(
|
|
||||||
ctx context.Context,
|
|
||||||
id domain.InterfaceIdentifier,
|
|
||||||
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
|
|
||||||
) error
|
|
||||||
|
|
||||||
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type InterfaceController interface {
|
|
||||||
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
|
||||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
|
||||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
|
||||||
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
|
|
||||||
*domain.PhysicalPeer,
|
|
||||||
error,
|
|
||||||
)
|
|
||||||
SaveInterface(
|
|
||||||
_ context.Context,
|
|
||||||
id domain.InterfaceIdentifier,
|
|
||||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
|
||||||
) error
|
|
||||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
|
||||||
SavePeer(
|
|
||||||
_ context.Context,
|
|
||||||
deviceId domain.InterfaceIdentifier,
|
|
||||||
id domain.PeerIdentifier,
|
|
||||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
|
||||||
) error
|
|
||||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type WgQuickController interface {
|
|
||||||
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
|
|
||||||
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
|
||||||
UnsetDNS(id domain.InterfaceIdentifier) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type MetricsServer interface {
|
|
||||||
UpdateInterfaceMetrics(status domain.InterfaceStatus)
|
|
||||||
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
|
|
||||||
}
|
|
@ -7,31 +7,63 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
probing "github.com/prometheus-community/pro-bing"
|
probing "github.com/prometheus-community/pro-bing"
|
||||||
evbus "github.com/vardius/message-bus"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/app"
|
"github.com/h44z/wg-portal/internal/app"
|
||||||
"github.com/h44z/wg-portal/internal/config"
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type StatisticsDatabaseRepo interface {
|
||||||
|
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||||
|
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||||
|
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||||
|
UpdatePeerStatus(
|
||||||
|
ctx context.Context,
|
||||||
|
id domain.PeerIdentifier,
|
||||||
|
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
|
||||||
|
) error
|
||||||
|
UpdateInterfaceStatus(
|
||||||
|
ctx context.Context,
|
||||||
|
id domain.InterfaceIdentifier,
|
||||||
|
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
|
||||||
|
) error
|
||||||
|
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type StatisticsInterfaceController interface {
|
||||||
|
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||||
|
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type StatisticsMetricsServer interface {
|
||||||
|
UpdateInterfaceMetrics(status domain.InterfaceStatus)
|
||||||
|
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
type StatisticsEventBus interface {
|
||||||
|
// Subscribe subscribes to a topic
|
||||||
|
Subscribe(topic string, fn interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
type StatisticsCollector struct {
|
type StatisticsCollector struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
bus evbus.MessageBus
|
bus StatisticsEventBus
|
||||||
|
|
||||||
pingWaitGroup sync.WaitGroup
|
pingWaitGroup sync.WaitGroup
|
||||||
pingJobs chan domain.Peer
|
pingJobs chan domain.Peer
|
||||||
|
|
||||||
db StatisticsDatabaseRepo
|
db StatisticsDatabaseRepo
|
||||||
wg InterfaceController
|
wg StatisticsInterfaceController
|
||||||
ms MetricsServer
|
ms StatisticsMetricsServer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewStatisticsCollector creates a new statistics collector.
|
||||||
func NewStatisticsCollector(
|
func NewStatisticsCollector(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
bus evbus.MessageBus,
|
bus StatisticsEventBus,
|
||||||
db StatisticsDatabaseRepo,
|
db StatisticsDatabaseRepo,
|
||||||
wg InterfaceController,
|
wg StatisticsInterfaceController,
|
||||||
ms MetricsServer,
|
ms StatisticsMetricsServer,
|
||||||
) (*StatisticsCollector, error) {
|
) (*StatisticsCollector, error) {
|
||||||
c := &StatisticsCollector{
|
c := &StatisticsCollector{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
@ -47,6 +79,8 @@ func NewStatisticsCollector(
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartBackgroundJobs starts the background jobs for the statistics collector.
|
||||||
|
// This method is non-blocking and returns immediately.
|
||||||
func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) {
|
func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) {
|
||||||
c.startPingWorkers(ctx)
|
c.startPingWorkers(ctx)
|
||||||
c.startInterfaceDataFetcher(ctx)
|
c.startInterfaceDataFetcher(ctx)
|
||||||
|
@ -5,17 +5,74 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
evbus "github.com/vardius/message-bus"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/app"
|
"github.com/h44z/wg-portal/internal/app"
|
||||||
"github.com/h44z/wg-portal/internal/config"
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager struct {
|
// region dependencies
|
||||||
cfg *config.Config
|
|
||||||
bus evbus.MessageBus
|
|
||||||
|
|
||||||
|
type InterfaceAndPeerDatabaseRepo interface {
|
||||||
|
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
||||||
|
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
|
||||||
|
GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error)
|
||||||
|
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||||
|
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
|
||||||
|
SaveInterface(
|
||||||
|
ctx context.Context,
|
||||||
|
id domain.InterfaceIdentifier,
|
||||||
|
updateFunc func(in *domain.Interface) (*domain.Interface, error),
|
||||||
|
) error
|
||||||
|
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
|
||||||
|
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||||
|
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
|
||||||
|
SavePeer(
|
||||||
|
ctx context.Context,
|
||||||
|
id domain.PeerIdentifier,
|
||||||
|
updateFunc func(in *domain.Peer) (*domain.Peer, error),
|
||||||
|
) error
|
||||||
|
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||||
|
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
|
||||||
|
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type InterfaceController interface {
|
||||||
|
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
||||||
|
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||||
|
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||||
|
SaveInterface(
|
||||||
|
_ context.Context,
|
||||||
|
id domain.InterfaceIdentifier,
|
||||||
|
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||||
|
) error
|
||||||
|
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||||
|
SavePeer(
|
||||||
|
_ context.Context,
|
||||||
|
deviceId domain.InterfaceIdentifier,
|
||||||
|
id domain.PeerIdentifier,
|
||||||
|
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||||
|
) error
|
||||||
|
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type WgQuickController interface {
|
||||||
|
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
|
||||||
|
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||||
|
UnsetDNS(id domain.InterfaceIdentifier) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type EventBus interface {
|
||||||
|
// Publish sends a message to the message bus.
|
||||||
|
Publish(topic string, args ...any)
|
||||||
|
// Subscribe subscribes to a topic
|
||||||
|
Subscribe(topic string, fn interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
cfg *config.Config
|
||||||
|
bus EventBus
|
||||||
db InterfaceAndPeerDatabaseRepo
|
db InterfaceAndPeerDatabaseRepo
|
||||||
wg InterfaceController
|
wg InterfaceController
|
||||||
quick WgQuickController
|
quick WgQuickController
|
||||||
@ -23,7 +80,7 @@ type Manager struct {
|
|||||||
|
|
||||||
func NewWireGuardManager(
|
func NewWireGuardManager(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
bus evbus.MessageBus,
|
bus EventBus,
|
||||||
wg InterfaceController,
|
wg InterfaceController,
|
||||||
quick WgQuickController,
|
quick WgQuickController,
|
||||||
db InterfaceAndPeerDatabaseRepo,
|
db InterfaceAndPeerDatabaseRepo,
|
||||||
@ -41,6 +98,8 @@ func NewWireGuardManager(
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartBackgroundJobs starts background jobs like the expired peers check.
|
||||||
|
// This method is non-blocking.
|
||||||
func (m Manager) StartBackgroundJobs(ctx context.Context) {
|
func (m Manager) StartBackgroundJobs(ctx context.Context) {
|
||||||
go m.runExpiredPeersCheck(ctx)
|
go m.runExpiredPeersCheck(ctx)
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,8 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GetImportableInterfaces returns all physical interfaces that are available on the system.
|
||||||
|
// This function also returns interfaces that are already available in the database.
|
||||||
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
|
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -26,6 +28,7 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
|
|||||||
return physicalInterfaces, nil
|
return physicalInterfaces, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
|
||||||
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
|
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
|
||||||
*domain.Interface,
|
*domain.Interface,
|
||||||
[]domain.Peer,
|
[]domain.Peer,
|
||||||
@ -38,6 +41,7 @@ func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceId
|
|||||||
return m.db.GetInterfaceAndPeers(ctx, id)
|
return m.db.GetInterfaceAndPeers(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAllInterfaces returns all interfaces that are available in the database.
|
||||||
func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
|
func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -46,6 +50,7 @@ func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, erro
|
|||||||
return m.db.GetAllInterfaces(ctx)
|
return m.db.GetAllInterfaces(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAllInterfacesAndPeers returns all interfaces and their peers.
|
||||||
func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
|
func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@ -97,6 +102,7 @@ func (m Manager) GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier)
|
|||||||
return userInterfaces, nil
|
return userInterfaces, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImportNewInterfaces imports all new physical interfaces that are available on the system.
|
||||||
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
|
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@ -148,6 +154,7 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
|
|||||||
return imported, nil
|
return imported, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ApplyPeerDefaults applies the interface defaults to all peers of the given interface.
|
||||||
func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
|
func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -179,6 +186,8 @@ func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RestoreInterfaceState restores the state of all physical interfaces and their peers.
|
||||||
|
// The final state of the interfaces and peers will be the same as stored in the database.
|
||||||
func (m Manager) RestoreInterfaceState(
|
func (m Manager) RestoreInterfaceState(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
updateDbOnError bool,
|
updateDbOnError bool,
|
||||||
@ -296,6 +305,7 @@ func (m Manager) RestoreInterfaceState(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrepareInterface generates a new interface with fresh keys, ip addresses and a listen port.
|
||||||
func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
|
func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -376,6 +386,7 @@ func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error
|
|||||||
return freshInterface, nil
|
return freshInterface, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateInterface creates a new interface with the given configuration.
|
||||||
func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
|
func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -401,6 +412,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
|
|||||||
return in, nil
|
return in, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateInterface updates the given interface with the new configuration.
|
||||||
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) {
|
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@ -423,6 +435,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
|
|||||||
return in, existingPeers, nil
|
return in, existingPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteInterface deletes the given interface.
|
||||||
func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
|
func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// CreateDefaultPeer creates a default peer for the given user on all server interfaces.
|
||||||
func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error {
|
func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -55,6 +56,7 @@ func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdenti
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserPeers returns all peers for the given user.
|
||||||
func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
|
func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
|
||||||
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -63,6 +65,7 @@ func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]
|
|||||||
return m.db.GetUserPeers(ctx, id)
|
return m.db.GetUserPeers(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PreparePeer prepares a new peer for the given interface with fresh keys and ip addresses.
|
||||||
func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
|
func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
|
||||||
if !m.cfg.Core.SelfProvisioningAllowed {
|
if !m.cfg.Core.SelfProvisioningAllowed {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
@ -143,6 +146,7 @@ func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier)
|
|||||||
return freshPeer, nil
|
return freshPeer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeer returns the peer with the given identifier.
|
||||||
func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
|
func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
|
||||||
peer, err := m.db.GetPeer(ctx, id)
|
peer, err := m.db.GetPeer(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -156,6 +160,7 @@ func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain
|
|||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreatePeer creates a new peer.
|
||||||
func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
||||||
if !m.cfg.Core.SelfProvisioningAllowed {
|
if !m.cfg.Core.SelfProvisioningAllowed {
|
||||||
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
|
||||||
@ -201,6 +206,8 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
|
|||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateMultiplePeers creates multiple new peers for the given user identifiers.
|
||||||
|
// It calls PreparePeer for each user identifier in the request.
|
||||||
func (m Manager) CreateMultiplePeers(
|
func (m Manager) CreateMultiplePeers(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
interfaceId domain.InterfaceIdentifier,
|
interfaceId domain.InterfaceIdentifier,
|
||||||
@ -243,6 +250,7 @@ func (m Manager) CreateMultiplePeers(
|
|||||||
return createdPeers, nil
|
return createdPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdatePeer updates the given peer.
|
||||||
func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
||||||
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
|
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -309,6 +317,7 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
|
|||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeletePeer deletes the peer with the given identifier.
|
||||||
func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
|
func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
|
||||||
peer, err := m.db.GetPeer(ctx, id)
|
peer, err := m.db.GetPeer(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -341,6 +350,7 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerStats returns the status of the peer with the given identifier.
|
||||||
func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) {
|
func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) {
|
||||||
_, peers, err := m.db.GetInterfaceAndPeers(ctx, id)
|
_, peers, err := m.db.GetInterfaceAndPeers(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -359,6 +369,7 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier
|
|||||||
return m.db.GetPeersStats(ctx, peerIds...)
|
return m.db.GetPeersStats(ctx, peerIds...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserPeerStats returns the status of all peers for the given user.
|
||||||
func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
|
func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
|
||||||
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/go-ldap/ldap/v3"
|
"github.com/go-ldap/ldap/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Auth contains all authentication providers.
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
// OpenIDConnect contains a list of OpenID Connect providers.
|
// OpenIDConnect contains a list of OpenID Connect providers.
|
||||||
OpenIDConnect []OpenIDConnectProvider `yaml:"oidc"`
|
OpenIDConnect []OpenIDConnectProvider `yaml:"oidc"`
|
||||||
@ -17,6 +18,7 @@ type Auth struct {
|
|||||||
Ldap []LdapProvider `yaml:"ldap"`
|
Ldap []LdapProvider `yaml:"ldap"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BaseFields contains the basic fields that are used to map user information from the authentication providers.
|
||||||
type BaseFields struct {
|
type BaseFields struct {
|
||||||
// UserIdentifier is the name of the field that contains the user identifier.
|
// UserIdentifier is the name of the field that contains the user identifier.
|
||||||
UserIdentifier string `yaml:"user_identifier"`
|
UserIdentifier string `yaml:"user_identifier"`
|
||||||
@ -32,6 +34,7 @@ type BaseFields struct {
|
|||||||
Department string `yaml:"department"`
|
Department string `yaml:"department"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OauthFields contains extra fields that are used to map user information from OAuth providers.
|
||||||
type OauthFields struct {
|
type OauthFields struct {
|
||||||
BaseFields `yaml:",inline"`
|
BaseFields `yaml:",inline"`
|
||||||
// IsAdmin is the name of the field that contains the admin flag.
|
// IsAdmin is the name of the field that contains the admin flag.
|
||||||
@ -107,12 +110,14 @@ func (o *OauthAdminMapping) GetAdminGroupRegex() *regexp.Regexp {
|
|||||||
return o.adminGroupRegex
|
return o.adminGroupRegex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LdapFields contains extra fields that are used to map user information from LDAP providers.
|
||||||
type LdapFields struct {
|
type LdapFields struct {
|
||||||
BaseFields `yaml:",inline"`
|
BaseFields `yaml:",inline"`
|
||||||
// GroupMembership is the name of the LDAP field that contains the groups to which the user belongs.
|
// GroupMembership is the name of the LDAP field that contains the groups to which the user belongs.
|
||||||
GroupMembership string `yaml:"memberof"`
|
GroupMembership string `yaml:"memberof"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LdapProvider contains the configuration for the LDAP connection.
|
||||||
type LdapProvider struct {
|
type LdapProvider struct {
|
||||||
// ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters.
|
// ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters.
|
||||||
ProviderName string `yaml:"provider_name"`
|
ProviderName string `yaml:"provider_name"`
|
||||||
@ -163,6 +168,7 @@ type LdapProvider struct {
|
|||||||
LogUserInfo bool `yaml:"log_user_info"`
|
LogUserInfo bool `yaml:"log_user_info"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenIDConnectProvider contains the configuration for the OpenID Connect provider.
|
||||||
type OpenIDConnectProvider struct {
|
type OpenIDConnectProvider struct {
|
||||||
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
|
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
|
||||||
ProviderName string `yaml:"provider_name"`
|
ProviderName string `yaml:"provider_name"`
|
||||||
@ -196,6 +202,7 @@ type OpenIDConnectProvider struct {
|
|||||||
LogUserInfo bool `yaml:"log_user_info"`
|
LogUserInfo bool `yaml:"log_user_info"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OAuthProvider contains the configuration for the OAuth provider.
|
||||||
type OAuthProvider struct {
|
type OAuthProvider struct {
|
||||||
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
|
// ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters.
|
||||||
ProviderName string `yaml:"provider_name"`
|
ProviderName string `yaml:"provider_name"`
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Config is the main configuration struct.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Core struct {
|
Core struct {
|
||||||
// AdminUser defines the default administrator account that will be created
|
// AdminUser defines the default administrator account that will be created
|
||||||
@ -179,6 +180,7 @@ func GetConfig() (*Config, error) {
|
|||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loadConfigFile loads the configuration from a YAML file into the given cfg struct.
|
||||||
func loadConfigFile(cfg any, filename string) error {
|
func loadConfigFile(cfg any, filename string) error {
|
||||||
data, err := envsubst.ReadFile(filename)
|
data, err := envsubst.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2,6 +2,8 @@ package config
|
|||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
|
// SupportedDatabase is a type for the supported database types.
|
||||||
|
// Supported: mysql, mssql, postgres, sqlite
|
||||||
type SupportedDatabase string
|
type SupportedDatabase string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -11,6 +13,7 @@ const (
|
|||||||
DatabaseSQLite SupportedDatabase = "sqlite"
|
DatabaseSQLite SupportedDatabase = "sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DatabaseConfig contains the configuration for the database connection.
|
||||||
type DatabaseConfig struct {
|
type DatabaseConfig struct {
|
||||||
// Debug enables logging of all database statements
|
// Debug enables logging of all database statements
|
||||||
Debug bool `yaml:"debug"`
|
Debug bool `yaml:"debug"`
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
|
// MailEncryption is the type of the SMTP encryption.
|
||||||
|
// Supported: none, tls, starttls
|
||||||
type MailEncryption string
|
type MailEncryption string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -8,6 +10,8 @@ const (
|
|||||||
MailEncryptionStartTLS MailEncryption = "starttls"
|
MailEncryptionStartTLS MailEncryption = "starttls"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MailAuthType is the type of the SMTP authentication.
|
||||||
|
// Supported: plain, login, crammd5
|
||||||
type MailAuthType string
|
type MailAuthType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -16,6 +20,7 @@ const (
|
|||||||
MailAuthCramMD5 MailAuthType = "crammd5"
|
MailAuthCramMD5 MailAuthType = "crammd5"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MailConfig contains the configuration for the mail server which is used to send emails.
|
||||||
type MailConfig struct {
|
type MailConfig struct {
|
||||||
// Host is the hostname or IP of the SMTP server
|
// Host is the hostname or IP of the SMTP server
|
||||||
Host string `yaml:"host"`
|
Host string `yaml:"host"`
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
|
// WebConfig contains the configuration for the web server.
|
||||||
type WebConfig struct {
|
type WebConfig struct {
|
||||||
// RequestLogging enables logging of all HTTP requests.
|
// RequestLogging enables logging of all HTTP requests.
|
||||||
RequestLogging bool `yaml:"request_logging"`
|
RequestLogging bool `yaml:"request_logging"`
|
||||||
|
@ -1,11 +1,5 @@
|
|||||||
package domain
|
package domain
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type LoginProvider string
|
type LoginProvider string
|
||||||
|
|
||||||
type LoginProviderInfo struct {
|
type LoginProviderInfo struct {
|
||||||
@ -24,28 +18,3 @@ type AuthenticatorUserInfo struct {
|
|||||||
Department string
|
Department string
|
||||||
IsAdmin bool
|
IsAdmin bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthenticatorType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
AuthenticatorTypeOAuth AuthenticatorType = "oauth"
|
|
||||||
AuthenticatorTypeOidc AuthenticatorType = "oidc"
|
|
||||||
)
|
|
||||||
|
|
||||||
type OauthAuthenticator interface {
|
|
||||||
GetName() string
|
|
||||||
GetType() AuthenticatorType
|
|
||||||
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
|
|
||||||
Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error)
|
|
||||||
GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]any, error)
|
|
||||||
ParseUserInfo(raw map[string]any) (*AuthenticatorUserInfo, error)
|
|
||||||
RegistrationEnabled() bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type LdapAuthenticator interface {
|
|
||||||
GetName() string
|
|
||||||
PlaintextAuthentication(userId UserIdentifier, plainPassword string) error
|
|
||||||
GetUserInfo(ctx context.Context, username UserIdentifier) (map[string]any, error)
|
|
||||||
ParseUserInfo(raw map[string]any) (*AuthenticatorUserInfo, error)
|
|
||||||
RegistrationEnabled() bool
|
|
||||||
}
|
|
||||||
|
@ -33,6 +33,7 @@ func (p KeyPair) GetPublicKey() wgtypes.Key {
|
|||||||
|
|
||||||
type PreSharedKey string
|
type PreSharedKey string
|
||||||
|
|
||||||
|
// NewFreshKeypair generates a new key pair.
|
||||||
func NewFreshKeypair() (KeyPair, error) {
|
func NewFreshKeypair() (KeyPair, error) {
|
||||||
privateKey, err := wgtypes.GeneratePrivateKey()
|
privateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -45,6 +46,7 @@ func NewFreshKeypair() (KeyPair, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewPreSharedKey generates a new pre-shared key.
|
||||||
func NewPreSharedKey() (PreSharedKey, error) {
|
func NewPreSharedKey() (PreSharedKey, error) {
|
||||||
preSharedKey, err := wgtypes.GenerateKey()
|
preSharedKey, err := wgtypes.GenerateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -54,6 +56,8 @@ func NewPreSharedKey() (PreSharedKey, error) {
|
|||||||
return PreSharedKey(preSharedKey.String()), nil
|
return PreSharedKey(preSharedKey.String()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PublicKeyFromPrivateKey returns the public key for a given private key.
|
||||||
|
// If the private key is invalid, an empty string is returned.
|
||||||
func PublicKeyFromPrivateKey(key string) string {
|
func PublicKeyFromPrivateKey(key string) string {
|
||||||
privKey, err := wgtypes.ParseKey(key)
|
privKey, err := wgtypes.ParseKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
56
internal/domain/crypto_test.go
Normal file
56
internal/domain/crypto_test.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKeyPair_GetPrivateKeyBytesReturnsCorrectBytes(t *testing.T) {
|
||||||
|
keyPair := KeyPair{PrivateKey: base64.StdEncoding.EncodeToString([]byte("privateKey"))}
|
||||||
|
expected := []byte("privateKey")
|
||||||
|
assert.Equal(t, expected, keyPair.GetPrivateKeyBytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyPair_GetPublicKeyBytesReturnsCorrectBytes(t *testing.T) {
|
||||||
|
keyPair := KeyPair{PublicKey: base64.StdEncoding.EncodeToString([]byte("publicKey"))}
|
||||||
|
expected := []byte("publicKey")
|
||||||
|
assert.Equal(t, expected, keyPair.GetPublicKeyBytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyPair_GetPrivateKeyReturnsCorrectKey(t *testing.T) {
|
||||||
|
privateKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
|
keyPair := KeyPair{PrivateKey: privateKey.String()}
|
||||||
|
assert.Equal(t, privateKey, keyPair.GetPrivateKey())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyPair_GetPublicKeyReturnsCorrectKey(t *testing.T) {
|
||||||
|
privateKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
|
keyPair := KeyPair{PublicKey: privateKey.PublicKey().String()}
|
||||||
|
assert.Equal(t, privateKey.PublicKey(), keyPair.GetPublicKey())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewFreshKeypairGeneratesValidKeypair(t *testing.T) {
|
||||||
|
keyPair, err := NewFreshKeypair()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, keyPair.PrivateKey)
|
||||||
|
assert.NotEmpty(t, keyPair.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewPreSharedKeyGeneratesValidKey(t *testing.T) {
|
||||||
|
preSharedKey, err := NewPreSharedKey()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, preSharedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublicKeyFromPrivateKeyReturnsCorrectPublicKey(t *testing.T) {
|
||||||
|
privateKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
|
expected := privateKey.PublicKey().String()
|
||||||
|
assert.Equal(t, expected, PublicKeyFromPrivateKey(privateKey.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublicKeyFromPrivateKeyReturnsEmptyStringOnInvalidKey(t *testing.T) {
|
||||||
|
assert.Equal(t, "", PublicKeyFromPrivateKey("invalidKey"))
|
||||||
|
}
|
83
internal/domain/interface_test.go
Normal file
83
internal/domain/interface_test.go
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) {
|
||||||
|
iface := &Interface{}
|
||||||
|
assert.False(t, iface.IsDisabled())
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
iface.Disabled = &now
|
||||||
|
assert.True(t, iface.IsDisabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterface_AddressStrReturnsCorrectString(t *testing.T) {
|
||||||
|
iface := &Interface{
|
||||||
|
Addresses: []Cidr{
|
||||||
|
{Cidr: "192.168.1.1/24", Addr: "192.168.1.1", NetLength: 24},
|
||||||
|
{Cidr: "10.0.0.1/24", Addr: "10.0.0.1", NetLength: 24},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
expected := "192.168.1.1/24,10.0.0.1/24"
|
||||||
|
assert.Equal(t, expected, iface.AddressStr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterface_GetConfigFileNameReturnsCorrectFileName(t *testing.T) {
|
||||||
|
iface := &Interface{Identifier: "wg0"}
|
||||||
|
expected := "wg0.conf"
|
||||||
|
assert.Equal(t, expected, iface.GetConfigFileName())
|
||||||
|
|
||||||
|
iface.Identifier = "wg0@123"
|
||||||
|
expected = "wg0123.conf"
|
||||||
|
assert.Equal(t, expected, iface.GetConfigFileName())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
|
||||||
|
peer1 := Peer{
|
||||||
|
Interface: PeerInterfaceConfig{
|
||||||
|
Addresses: []Cidr{
|
||||||
|
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
peer2 := Peer{
|
||||||
|
Interface: PeerInterfaceConfig{
|
||||||
|
Addresses: []Cidr{
|
||||||
|
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
iface := &Interface{}
|
||||||
|
expected := []Cidr{
|
||||||
|
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
|
||||||
|
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
|
||||||
|
}
|
||||||
|
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterface_ManageRoutingTableReturnsCorrectValue(t *testing.T) {
|
||||||
|
iface := &Interface{RoutingTable: "off"}
|
||||||
|
assert.False(t, iface.ManageRoutingTable())
|
||||||
|
|
||||||
|
iface.RoutingTable = "100"
|
||||||
|
assert.True(t, iface.ManageRoutingTable())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
|
||||||
|
iface := &Interface{RoutingTable: ""}
|
||||||
|
assert.Equal(t, 0, iface.GetRoutingTable())
|
||||||
|
|
||||||
|
iface.RoutingTable = "off"
|
||||||
|
assert.Equal(t, -1, iface.GetRoutingTable())
|
||||||
|
|
||||||
|
iface.RoutingTable = "0x64"
|
||||||
|
assert.Equal(t, 100, iface.GetRoutingTable())
|
||||||
|
|
||||||
|
iface.RoutingTable = "200"
|
||||||
|
assert.Equal(t, 200, iface.GetRoutingTable())
|
||||||
|
}
|
42
internal/domain/options_test.go
Normal file
42
internal/domain/options_test.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfigOption_GetValueReturnsCorrectValue(t *testing.T) {
|
||||||
|
option := ConfigOption[int]{Value: 42}
|
||||||
|
assert.Equal(t, 42, option.GetValue())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigOption_SetValueUpdatesValue(t *testing.T) {
|
||||||
|
option := ConfigOption[int]{Value: 42}
|
||||||
|
option.SetValue(100)
|
||||||
|
assert.Equal(t, 100, option.GetValue())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigOption_TrySetValueUpdatesValueWhenOverridable(t *testing.T) {
|
||||||
|
option := ConfigOption[int]{Value: 42, Overridable: true}
|
||||||
|
result := option.TrySetValue(100)
|
||||||
|
assert.True(t, result)
|
||||||
|
assert.Equal(t, 100, option.GetValue())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigOption_TrySetValueDoesNotUpdateValueWhenNotOverridable(t *testing.T) {
|
||||||
|
option := ConfigOption[int]{Value: 42, Overridable: false}
|
||||||
|
result := option.TrySetValue(100)
|
||||||
|
assert.False(t, result)
|
||||||
|
assert.Equal(t, 42, option.GetValue())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewConfigOptionCreatesCorrectOption(t *testing.T) {
|
||||||
|
option := NewConfigOption(42, true)
|
||||||
|
assert.Equal(t, 42, option.GetValue())
|
||||||
|
assert.True(t, option.Overridable)
|
||||||
|
|
||||||
|
option2 := NewConfigOption("str", false)
|
||||||
|
assert.Equal(t, "str", option2.GetValue())
|
||||||
|
assert.False(t, option2.Overridable)
|
||||||
|
}
|
165
internal/domain/peer_test.go
Normal file
165
internal/domain/peer_test.go
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPeer_IsDisabled(t *testing.T) {
|
||||||
|
peer := &Peer{}
|
||||||
|
assert.False(t, peer.IsDisabled())
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
peer.Disabled = &now
|
||||||
|
assert.True(t, peer.IsDisabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_IsExpired(t *testing.T) {
|
||||||
|
peer := &Peer{}
|
||||||
|
assert.False(t, peer.IsExpired())
|
||||||
|
|
||||||
|
expiredTime := time.Now().Add(-time.Hour)
|
||||||
|
peer.ExpiresAt = &expiredTime
|
||||||
|
assert.True(t, peer.IsExpired())
|
||||||
|
|
||||||
|
futureTime := time.Now().Add(time.Hour)
|
||||||
|
peer.ExpiresAt = &futureTime
|
||||||
|
assert.False(t, peer.IsExpired())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_CheckAliveAddress(t *testing.T) {
|
||||||
|
peer := &Peer{}
|
||||||
|
assert.Equal(t, "", peer.CheckAliveAddress())
|
||||||
|
|
||||||
|
peer.Interface.CheckAliveAddress = "192.168.1.1"
|
||||||
|
assert.Equal(t, "192.168.1.1", peer.CheckAliveAddress())
|
||||||
|
|
||||||
|
peer.Interface.CheckAliveAddress = ""
|
||||||
|
peer.Interface.Addresses = []Cidr{{Addr: "10.0.0.1"}}
|
||||||
|
assert.Equal(t, "10.0.0.1", peer.CheckAliveAddress())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_GetConfigFileName(t *testing.T) {
|
||||||
|
peer := &Peer{DisplayName: "Test Peer"}
|
||||||
|
expected := "Test_Peer.conf"
|
||||||
|
assert.Equal(t, expected, peer.GetConfigFileName())
|
||||||
|
|
||||||
|
peer.DisplayName = ""
|
||||||
|
peer.Identifier = "12345678"
|
||||||
|
expected = "wg_12345678.conf"
|
||||||
|
assert.Equal(t, expected, peer.GetConfigFileName())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_ApplyInterfaceDefaults(t *testing.T) {
|
||||||
|
peer := &Peer{
|
||||||
|
Endpoint: ConfigOption[string]{
|
||||||
|
Value: "",
|
||||||
|
Overridable: true,
|
||||||
|
},
|
||||||
|
EndpointPublicKey: ConfigOption[string]{
|
||||||
|
Value: "",
|
||||||
|
Overridable: true,
|
||||||
|
},
|
||||||
|
AllowedIPsStr: ConfigOption[string]{
|
||||||
|
Value: "1.1.1.1/32",
|
||||||
|
Overridable: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
iface := &Interface{
|
||||||
|
PeerDefEndpoint: "192.168.1.1",
|
||||||
|
KeyPair: KeyPair{
|
||||||
|
PublicKey: "publicKey",
|
||||||
|
},
|
||||||
|
PeerDefAllowedIPsStr: "8.8.8.8/32",
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.ApplyInterfaceDefaults(iface)
|
||||||
|
assert.Equal(t, "192.168.1.1", peer.Endpoint.GetValue())
|
||||||
|
assert.Equal(t, "publicKey", peer.EndpointPublicKey.GetValue())
|
||||||
|
assert.Equal(t, "1.1.1.1/32", peer.AllowedIPsStr.GetValue())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_GenerateDisplayName(t *testing.T) {
|
||||||
|
peer := &Peer{Identifier: "12345678"}
|
||||||
|
peer.GenerateDisplayName("Prefix")
|
||||||
|
expected := "Prefix Peer 12345678"
|
||||||
|
assert.Equal(t, expected, peer.DisplayName)
|
||||||
|
|
||||||
|
peer.GenerateDisplayName("")
|
||||||
|
expected = "Peer 12345678"
|
||||||
|
assert.Equal(t, expected, peer.DisplayName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_OverwriteUserEditableFields(t *testing.T) {
|
||||||
|
peer := &Peer{}
|
||||||
|
userPeer := &Peer{
|
||||||
|
DisplayName: "New DisplayName",
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.OverwriteUserEditableFields(userPeer)
|
||||||
|
assert.Equal(t, "New DisplayName", peer.DisplayName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_GetPresharedKey(t *testing.T) {
|
||||||
|
physicalPeer := PhysicalPeer{}
|
||||||
|
assert.Nil(t, physicalPeer.GetPresharedKey())
|
||||||
|
|
||||||
|
physicalPeer.PresharedKey = "Q0evIJTOjhyy2o5J7whvrsvQC+FRL8A74vrw44YHUAk="
|
||||||
|
key := physicalPeer.GetPresharedKey()
|
||||||
|
assert.NotNil(t, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_GetEndpointAddress(t *testing.T) {
|
||||||
|
physicalPeer := PhysicalPeer{}
|
||||||
|
assert.Nil(t, physicalPeer.GetEndpointAddress())
|
||||||
|
|
||||||
|
physicalPeer.Endpoint = "192.168.1.1:51820"
|
||||||
|
addr := physicalPeer.GetEndpointAddress()
|
||||||
|
assert.NotNil(t, addr)
|
||||||
|
assert.Equal(t, "192.168.1.1:51820", addr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_GetPersistentKeepaliveTime(t *testing.T) {
|
||||||
|
physicalPeer := PhysicalPeer{}
|
||||||
|
assert.Nil(t, physicalPeer.GetPersistentKeepaliveTime())
|
||||||
|
|
||||||
|
physicalPeer.PersistentKeepalive = 25
|
||||||
|
duration := physicalPeer.GetPersistentKeepaliveTime()
|
||||||
|
assert.NotNil(t, duration)
|
||||||
|
assert.Equal(t, 25*time.Second, *duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeer_GetAllowedIPs(t *testing.T) {
|
||||||
|
physicalPeer := PhysicalPeer{}
|
||||||
|
assert.Empty(t, physicalPeer.GetAllowedIPs())
|
||||||
|
|
||||||
|
physicalPeer.AllowedIPs = []Cidr{
|
||||||
|
{
|
||||||
|
Cidr: "192.168.1.0/24",
|
||||||
|
Addr: "192.168.1.0",
|
||||||
|
NetLength: 24,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ips := physicalPeer.GetAllowedIPs()
|
||||||
|
assert.Len(t, ips, 1)
|
||||||
|
assert.Equal(t, "192.168.1.0/24", ips[0].String())
|
||||||
|
|
||||||
|
physicalPeer.AllowedIPs = []Cidr{
|
||||||
|
{
|
||||||
|
Cidr: "192.168.1.0/24",
|
||||||
|
Addr: "192.168.1.0",
|
||||||
|
NetLength: 24,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Cidr: "fe80::/64",
|
||||||
|
Addr: "fe80::",
|
||||||
|
NetLength: 64,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ips2 := physicalPeer.GetAllowedIPs()
|
||||||
|
assert.Len(t, ips2, 2)
|
||||||
|
assert.Equal(t, "192.168.1.0/24", ips2[0].String())
|
||||||
|
assert.Equal(t, "fe80::/64", ips2[1].String())
|
||||||
|
}
|
74
internal/domain/statistics_test.go
Normal file
74
internal/domain/statistics_test.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPeerStatus_IsConnected(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
past := now.Add(-3 * time.Minute)
|
||||||
|
recent := now.Add(-1 * time.Minute)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
status PeerStatus
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Pingable and recent handshake",
|
||||||
|
status: PeerStatus{
|
||||||
|
IsPingable: true,
|
||||||
|
LastHandshake: &recent,
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not pingable but recent handshake",
|
||||||
|
status: PeerStatus{
|
||||||
|
IsPingable: false,
|
||||||
|
LastHandshake: &recent,
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Pingable but old handshake",
|
||||||
|
status: PeerStatus{
|
||||||
|
IsPingable: true,
|
||||||
|
LastHandshake: &past,
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not pingable and old handshake",
|
||||||
|
status: PeerStatus{
|
||||||
|
IsPingable: false,
|
||||||
|
LastHandshake: &past,
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Pingable and no handshake",
|
||||||
|
status: PeerStatus{
|
||||||
|
IsPingable: true,
|
||||||
|
LastHandshake: nil,
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Not pingable and no handshake",
|
||||||
|
status: PeerStatus{
|
||||||
|
IsPingable: false,
|
||||||
|
LastHandshake: nil,
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.status.IsConnected(); got != tt.want {
|
||||||
|
t.Errorf("IsConnected() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
125
internal/domain/user_test.go
Normal file
125
internal/domain/user_test.go
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUser_IsDisabled(t *testing.T) {
|
||||||
|
user := &User{}
|
||||||
|
assert.False(t, user.IsDisabled())
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
user.Disabled = &now
|
||||||
|
assert.True(t, user.IsDisabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_IsLocked(t *testing.T) {
|
||||||
|
user := &User{}
|
||||||
|
assert.False(t, user.IsLocked())
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
user.Locked = &now
|
||||||
|
assert.True(t, user.IsLocked())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_IsApiEnabled(t *testing.T) {
|
||||||
|
user := &User{}
|
||||||
|
assert.False(t, user.IsApiEnabled())
|
||||||
|
|
||||||
|
user.ApiToken = "token"
|
||||||
|
assert.True(t, user.IsApiEnabled())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_CanChangePassword(t *testing.T) {
|
||||||
|
user := &User{Source: UserSourceDatabase}
|
||||||
|
assert.NoError(t, user.CanChangePassword())
|
||||||
|
|
||||||
|
user.Source = UserSourceLdap
|
||||||
|
assert.Error(t, user.CanChangePassword())
|
||||||
|
|
||||||
|
user.Source = UserSourceOauth
|
||||||
|
assert.Error(t, user.CanChangePassword())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_EditAllowed(t *testing.T) {
|
||||||
|
user := &User{Source: UserSourceDatabase}
|
||||||
|
newUser := &User{Source: UserSourceDatabase}
|
||||||
|
assert.NoError(t, user.EditAllowed(newUser))
|
||||||
|
|
||||||
|
newUser.Notes = "notes can be changed"
|
||||||
|
assert.NoError(t, user.EditAllowed(newUser))
|
||||||
|
|
||||||
|
newUser.Disabled = &time.Time{}
|
||||||
|
assert.NoError(t, user.EditAllowed(newUser))
|
||||||
|
|
||||||
|
newUser.Lastname = "lastname or other fields can be changed"
|
||||||
|
assert.NoError(t, user.EditAllowed(newUser))
|
||||||
|
|
||||||
|
user.Source = UserSourceLdap
|
||||||
|
newUser.Source = UserSourceLdap
|
||||||
|
newUser.Disabled = nil
|
||||||
|
newUser.Lastname = ""
|
||||||
|
newUser.Notes = "notes can be changed"
|
||||||
|
assert.NoError(t, user.EditAllowed(newUser))
|
||||||
|
|
||||||
|
newUser.Disabled = &time.Time{}
|
||||||
|
assert.NoError(t, user.EditAllowed(newUser))
|
||||||
|
|
||||||
|
newUser.Lastname = "lastname or other fields can not be changed"
|
||||||
|
assert.Error(t, user.EditAllowed(newUser))
|
||||||
|
|
||||||
|
user.Source = UserSourceOauth
|
||||||
|
newUser.Source = UserSourceOauth
|
||||||
|
newUser.Disabled = nil
|
||||||
|
newUser.Lastname = ""
|
||||||
|
newUser.Notes = "notes can be changed"
|
||||||
|
assert.NoError(t, user.EditAllowed(newUser))
|
||||||
|
|
||||||
|
newUser.Disabled = &time.Time{}
|
||||||
|
assert.NoError(t, user.EditAllowed(newUser))
|
||||||
|
|
||||||
|
newUser.Lastname = "lastname or other fields can not be changed"
|
||||||
|
assert.Error(t, user.EditAllowed(newUser))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_DeleteAllowed(t *testing.T) {
|
||||||
|
user := &User{}
|
||||||
|
assert.NoError(t, user.DeleteAllowed())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_CheckPassword(t *testing.T) {
|
||||||
|
password := "password"
|
||||||
|
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
|
||||||
|
user := &User{Source: UserSourceDatabase, Password: PrivateString(hashedPassword)}
|
||||||
|
assert.NoError(t, user.CheckPassword(password))
|
||||||
|
|
||||||
|
user.Password = ""
|
||||||
|
assert.Error(t, user.CheckPassword(password))
|
||||||
|
|
||||||
|
user.Source = UserSourceLdap
|
||||||
|
assert.Error(t, user.CheckPassword(password))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_CheckApiToken(t *testing.T) {
|
||||||
|
user := &User{}
|
||||||
|
assert.Error(t, user.CheckApiToken("token"))
|
||||||
|
|
||||||
|
user.ApiToken = "token"
|
||||||
|
assert.NoError(t, user.CheckApiToken("token"))
|
||||||
|
|
||||||
|
assert.Error(t, user.CheckApiToken("wrong_token"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_HashPassword(t *testing.T) {
|
||||||
|
user := &User{Password: "password"}
|
||||||
|
assert.NoError(t, user.HashPassword())
|
||||||
|
assert.NotEmpty(t, user.Password)
|
||||||
|
|
||||||
|
user.Password = ""
|
||||||
|
assert.NoError(t, user.HashPassword())
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user