From 2a5b4fe31dd44e5a734abacd499714db73522d03 Mon Sep 17 00:00:00 2001 From: Christoph Haas Date: Mon, 24 Jul 2023 21:00:45 +0200 Subject: [PATCH] many more improvements and cleanup --- README.md | 2 +- cmd/wg-portal/main.go | 6 ++- internal/adapters/database.go | 4 ++ internal/adapters/filesystem.go | 33 ++++++++----- internal/adapters/migrations/1_init.down.sql | 1 - internal/adapters/migrations/1_init.up.sql | 3 -- internal/adapters/wireguard.go | 47 +++++++++--------- .../adapters/wireguard_integration_test.go | 2 +- .../api/v0/handlers/endpoint_interfaces.go | 16 +++---- internal/app/api/v0/model/models_interface.go | 26 +++++++--- internal/app/api/v0/model/models_peer.go | 4 +- internal/app/app.go | 10 ++-- internal/app/configfile/manager.go | 30 ++++-------- internal/app/configfile/repos.go | 5 ++ internal/app/repos.go | 6 +-- .../app/wireguard/wireguard_interfaces.go | 48 ++++++++++++++----- internal/app/wireguard/wireguard_peers.go | 42 ++++++++++++++-- internal/config/config.go | 6 ++- internal/domain/context.go | 9 +++- internal/domain/peer.go | 4 +- internal/ldap_utils.go | 5 +- 21 files changed, 201 insertions(+), 108 deletions(-) delete mode 100644 internal/adapters/migrations/1_init.down.sql delete mode 100644 internal/adapters/migrations/1_init.up.sql diff --git a/README.md b/README.md index f4a9f15..5747fd1 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,7 @@ The following configuration options are available: * [Gin, HTTP web framework written in Go](https://github.com/gin-gonic/gin) * [Bootstrap, for the HTML templates](https://getbootstrap.com/) - * [Vue.JS, for the frontend](hhttps://vuejs.org/) + * [Vue.JS, for the frontend](https://vuejs.org/) ## License diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index 97aaef0..cfc3e10 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -23,6 +23,7 @@ import ( evbus "github.com/vardius/message-bus" ) +// main entry point for WireGuard Portal func main() { ctx := internal.SignalAwareContext(context.Background(), syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM) @@ -44,6 +45,9 @@ func main() { mailer := adapters.NewSmtpMailRepo(cfg.Mail) + cfgFileSystem, err := adapters.NewFileSystemRepository(cfg.Advanced.ConfigStoragePath) + internal.AssertNoError(err) + shouldExit, err := app.HandleProgramArgs(cfg, rawDb) switch { case shouldExit && err == nil: @@ -70,7 +74,7 @@ func main() { statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, database, wireGuard) internal.AssertNoError(err) - cfgFileManager, err := configfile.NewConfigFileManager(cfg, database, database) + cfgFileManager, err := configfile.NewConfigFileManager(cfg, database, database, cfgFileSystem) internal.AssertNoError(err) mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database) diff --git a/internal/adapters/database.go b/internal/adapters/database.go index fb6e4b7..be67108 100644 --- a/internal/adapters/database.go +++ b/internal/adapters/database.go @@ -22,8 +22,10 @@ import ( "gorm.io/gorm" ) +// SchemaVersion describes the current database schema version. It must be incremented if a manual migration is needed. var SchemaVersion uint64 = 1 +// SysStat stores the current database schema version and the timestamp when it was applied. type SysStat struct { MigratedAt time.Time `gorm:"column:migrated_at"` SchemaVersion uint64 `gorm:"primaryKey,column:schema_version"` @@ -143,6 +145,8 @@ func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) { return gormDb, nil } +// SqlRepo is a SQL database repository implementation. +// Currently, it supports MySQL, SQLite, Microsoft SQL and Postgresql database systems. type SqlRepo struct { db *gorm.DB } diff --git a/internal/adapters/filesystem.go b/internal/adapters/filesystem.go index f3f8a83..22bccc4 100644 --- a/internal/adapters/filesystem.go +++ b/internal/adapters/filesystem.go @@ -1,43 +1,52 @@ package adapters import ( - "context" + "fmt" + "github.com/sirupsen/logrus" "io" "os" "path/filepath" ) -type filesystemRepo struct { +type FilesystemRepo struct { basePath string } -func NewFileSystemRepository(basePath string) (*filesystemRepo, error) { - r := &filesystemRepo{basePath: basePath} +func NewFileSystemRepository(basePath string) (*FilesystemRepo, error) { + if basePath == "" { + return nil, nil // no path, return empty repository + } + + r := &FilesystemRepo{basePath: basePath} if err := os.MkdirAll(r.basePath, os.ModePerm); err != nil { - return nil, err + return nil, fmt.Errorf("failed to create base directory %s: %w", basePath, err) } return r, nil } -func (r *filesystemRepo) WriteFile(_ context.Context, path string, contents io.Reader) error { +func (r *FilesystemRepo) WriteFile(path string, contents io.Reader) error { filePath := filepath.Join(r.basePath, path) + parentDirectory := filepath.Dir(filePath) - err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm) - if err != nil { - return err + if err := os.MkdirAll(parentDirectory, os.ModePerm); err != nil { + return fmt.Errorf("failed to create parent directory %s: %w", parentDirectory, err) } file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.ModePerm) if err != nil { - return err + return fmt.Errorf("failed to open file %s: %w", file.Name(), err) } - defer file.Close() + defer func(file *os.File) { + if err := file.Close(); err != nil { + logrus.Errorf("failed to close file %s: %v", file.Name(), err) + } + }(file) _, err = io.Copy(file, contents) if err != nil { - return err + return fmt.Errorf("failed to write file contents: %w", err) } return nil diff --git a/internal/adapters/migrations/1_init.down.sql b/internal/adapters/migrations/1_init.down.sql deleted file mode 100644 index 1b10e6f..0000000 --- a/internal/adapters/migrations/1_init.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE IF EXISTS test; \ No newline at end of file diff --git a/internal/adapters/migrations/1_init.up.sql b/internal/adapters/migrations/1_init.up.sql deleted file mode 100644 index acae08b..0000000 --- a/internal/adapters/migrations/1_init.up.sql +++ /dev/null @@ -1,3 +0,0 @@ -CREATE TABLE IF NOT EXISTS test ( - firstname VARCHAR(16) -); \ No newline at end of file diff --git a/internal/adapters/wireguard.go b/internal/adapters/wireguard.go index b5002e3..7c43ae6 100644 --- a/internal/adapters/wireguard.go +++ b/internal/adapters/wireguard.go @@ -13,12 +13,13 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type wgRepo struct { +// WgRepo implements all low-level WireGuard interactions. +type WgRepo struct { wg lowlevel.WireGuardClient nl lowlevel.NetlinkClient } -func NewWireGuardRepository() *wgRepo { +func NewWireGuardRepository() *WgRepo { wg, err := wgctrl.New() if err != nil { panic("failed to init wgctrl: " + err.Error()) @@ -26,7 +27,7 @@ func NewWireGuardRepository() *wgRepo { nl := &lowlevel.NetlinkManager{} - repo := &wgRepo{ + repo := &WgRepo{ wg: wg, nl: nl, } @@ -34,7 +35,7 @@ func NewWireGuardRepository() *wgRepo { return repo } -func (r *wgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) { +func (r *WgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) { devices, err := r.wg.Devices() if err != nil { return nil, fmt.Errorf("device list error: %w", err) @@ -52,11 +53,11 @@ func (r *wgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, e return interfaces, nil } -func (r *wgRepo) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { +func (r *WgRepo) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { return r.getInterface(id) } -func (r *wgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) { +func (r *WgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) { device, err := r.wg.Device(string(deviceId)) if err != nil { return nil, fmt.Errorf("device error: %w", err) @@ -74,11 +75,11 @@ func (r *wgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier return peers, nil } -func (r *wgRepo) GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) { +func (r *WgRepo) GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) { return r.getPeer(deviceId, id) } -func (r *wgRepo) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) { +func (r *WgRepo) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) { // read data from wgctrl interface iface := domain.PhysicalInterface{ @@ -122,7 +123,7 @@ func (r *wgRepo) convertWireGuardInterface(device *wgtypes.Device) (domain.Physi return iface, nil } -func (r *wgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) { +func (r *WgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) { peerModel := domain.PhysicalPeer{ Identifier: domain.PeerIdentifier(peer.PublicKey.String()), Endpoint: "", @@ -151,7 +152,7 @@ func (r *wgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, return peerModel, nil } -func (r *wgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error { +func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error { physicalInterface, err := r.getOrCreateInterface(id) if err != nil { return err @@ -174,7 +175,7 @@ func (r *wgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, return nil } -func (r *wgRepo) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { +func (r *WgRepo) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { device, err := r.getInterface(id) if err == nil { return device, nil @@ -192,7 +193,7 @@ func (r *wgRepo) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.Ph return device, err } -func (r *wgRepo) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { +func (r *WgRepo) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { device, err := r.wg.Device(string(id)) if err != nil { return nil, err @@ -202,7 +203,7 @@ func (r *wgRepo) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalIn return &pi, err } -func (r *wgRepo) createLowLevelInterface(id domain.InterfaceIdentifier) error { +func (r *WgRepo) createLowLevelInterface(id domain.InterfaceIdentifier) error { link := &netlink.GenericLink{ LinkAttrs: netlink.LinkAttrs{ Name: string(id), @@ -217,7 +218,7 @@ func (r *wgRepo) createLowLevelInterface(id domain.InterfaceIdentifier) error { return nil } -func (r *wgRepo) updateLowLevelInterface(pi *domain.PhysicalInterface) error { +func (r *WgRepo) updateLowLevelInterface(pi *domain.PhysicalInterface) error { link, err := r.nl.LinkByName(string(pi.Identifier)) if err != nil { return err @@ -274,7 +275,7 @@ func (r *wgRepo) updateLowLevelInterface(pi *domain.PhysicalInterface) error { return nil } -func (r *wgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error { +func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error { pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes()) if err != nil { return err @@ -297,7 +298,7 @@ func (r *wgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error { return nil } -func (r *wgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error { +func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error { if err := r.deleteLowLevelInterface(id); err != nil { return err } @@ -305,7 +306,7 @@ func (r *wgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifie return nil } -func (r *wgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error { +func (r *WgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error { link, err := r.nl.LinkByName(string(id)) if err != nil { return fmt.Errorf("unable to find low level interface: %w", err) @@ -319,7 +320,7 @@ func (r *wgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error { return nil } -func (r *wgRepo) SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error { +func (r *WgRepo) SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error { physicalPeer, err := r.getOrCreatePeer(deviceId, id) if err != nil { return err @@ -337,7 +338,7 @@ func (r *wgRepo) SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier return nil } -func (r *wgRepo) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) { +func (r *WgRepo) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) { peer, err := r.getPeer(deviceId, id) if err == nil { return peer, nil @@ -355,7 +356,7 @@ func (r *wgRepo) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain. return peer, nil } -func (r *wgRepo) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) { +func (r *WgRepo) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) { if !id.IsPublicKey() { return nil, errors.New("invalid public key") } @@ -378,7 +379,7 @@ func (r *wgRepo) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIden return nil, os.ErrNotExist } -func (r *wgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error { +func (r *WgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error { cfg := wgtypes.PeerConfig{ PublicKey: pp.GetPublicKey(), Remove: false, @@ -404,7 +405,7 @@ func (r *wgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.Phys return nil } -func (r *wgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error { +func (r *WgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error { if !id.IsPublicKey() { return errors.New("invalid public key") } @@ -417,7 +418,7 @@ func (r *wgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifi return nil } -func (r *wgRepo) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error { +func (r *WgRepo) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error { cfg := wgtypes.PeerConfig{ PublicKey: id.ToPublicKey(), Remove: true, diff --git a/internal/adapters/wireguard_integration_test.go b/internal/adapters/wireguard_integration_test.go index 07aaac7..2c953cd 100644 --- a/internal/adapters/wireguard_integration_test.go +++ b/internal/adapters/wireguard_integration_test.go @@ -20,7 +20,7 @@ import ( ) // setup WireGuard manager with no linked store -func setup(t *testing.T) *wgRepo { +func setup(t *testing.T) *WgRepo { if getProcessOwner() != "root" { t.Fatalf("this tests need to be executed as root user") } diff --git a/internal/app/api/v0/handlers/endpoint_interfaces.go b/internal/app/api/v0/handlers/endpoint_interfaces.go index 1bef09f..707f8de 100644 --- a/internal/app/api/v0/handlers/endpoint_interfaces.go +++ b/internal/app/api/v0/handlers/endpoint_interfaces.go @@ -53,7 +53,7 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc { return } - c.JSON(http.StatusOK, model.NewInterface(in)) + c.JSON(http.StatusOK, model.NewInterface(in, nil)) } } @@ -68,7 +68,7 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc { // @Router /interface/all [get] func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc { return func(c *gin.Context) { - interfaces, err := e.app.GetAllInterfaces(c.Request.Context()) + interfaces, peers, err := e.app.GetAllInterfacesAndPeers(c.Request.Context()) if err != nil { c.JSON(http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), @@ -76,7 +76,7 @@ func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc { return } - c.JSON(http.StatusOK, model.NewInterfaces(interfaces)) + c.JSON(http.StatusOK, model.NewInterfaces(interfaces, peers)) } } @@ -100,7 +100,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc { return } - iface, _, err := e.app.GetInterfaceAndPeers(c.Request.Context(), domain.InterfaceIdentifier(id)) + iface, peers, err := e.app.GetInterfaceAndPeers(c.Request.Context(), domain.InterfaceIdentifier(id)) if err != nil { c.JSON(http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), @@ -108,7 +108,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc { return } - c.JSON(http.StatusOK, model.NewInterface(iface)) + c.JSON(http.StatusOK, model.NewInterface(iface, peers)) } } @@ -186,7 +186,7 @@ func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc { return } - updatedInterface, err := e.app.UpdateInterface(ctx, model.NewDomainInterface(&in)) + updatedInterface, peers, err := e.app.UpdateInterface(ctx, model.NewDomainInterface(&in)) if err != nil { c.JSON(http.StatusInternalServerError, model.Error{ Code: http.StatusInternalServerError, Message: err.Error(), @@ -194,7 +194,7 @@ func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc { return } - c.JSON(http.StatusOK, model.NewInterface(updatedInterface)) + c.JSON(http.StatusOK, model.NewInterface(updatedInterface, peers)) } } @@ -228,7 +228,7 @@ func (e interfaceEndpoint) handleCreatePost() gin.HandlerFunc { return } - c.JSON(http.StatusOK, model.NewInterface(newInterface)) + c.JSON(http.StatusOK, model.NewInterface(newInterface, nil)) } } diff --git a/internal/app/api/v0/model/models_interface.go b/internal/app/api/v0/model/models_interface.go index 3305f43..5755304 100644 --- a/internal/app/api/v0/model/models_interface.go +++ b/internal/app/api/v0/model/models_interface.go @@ -51,8 +51,8 @@ type Interface struct { TotalPeers int `json:"TotalPeers"` } -func NewInterface(src *domain.Interface) *Interface { - return &Interface{ +func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface { + iface := &Interface{ Identifier: string(src.Identifier), DisplayName: src.DisplayName, Mode: string(src.Type), @@ -86,15 +86,29 @@ func NewInterface(src *domain.Interface) *Interface { PeerDefPreDown: src.PeerDefPreDown, PeerDefPostDown: src.PeerDefPostDown, - EnabledPeers: 0, // TODO - TotalPeers: 0, // TODO + EnabledPeers: 0, + TotalPeers: 0, } + + if len(peers) > 0 { + iface.TotalPeers = len(peers) + + activePeers := 0 + for _, peer := range peers { + if !peer.IsDisabled() { + activePeers++ + } + } + iface.EnabledPeers = activePeers + } + + return iface } -func NewInterfaces(src []domain.Interface) []Interface { +func NewInterfaces(src []domain.Interface, srcPeers [][]domain.Peer) []Interface { results := make([]Interface, len(src)) for i := range src { - results[i] = *NewInterface(&src[i]) + results[i] = *NewInterface(&src[i], srcPeers[i]) } return results diff --git a/internal/app/api/v0/model/models_peer.go b/internal/app/api/v0/model/models_peer.go index fb8fe10..50f2640 100644 --- a/internal/app/api/v0/model/models_peer.go +++ b/internal/app/api/v0/model/models_peer.go @@ -171,8 +171,8 @@ type MultiPeerRequest struct { func NewDomainPeerCreationRequest(src *MultiPeerRequest) *domain.PeerCreationRequest { return &domain.PeerCreationRequest{ - Identifiers: src.Identifiers, - Suffix: src.Suffix, + UserIdentifiers: src.Identifiers, + Suffix: src.Suffix, } } diff --git a/internal/app/app.go b/internal/app/app.go index 2950246..756e1bd 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -72,12 +72,14 @@ func (a *App) importNewInterfaces(ctx context.Context) error { return nil // feature disabled } - err := a.ImportNewInterfaces(ctx) + importedCount, err := a.ImportNewInterfaces(ctx) if err != nil { return err } - logrus.Trace("potential new interfaces imported") + if importedCount > 0 { + logrus.Infof("%d new interfaces imported", importedCount) + } return nil } @@ -92,7 +94,7 @@ func (a *App) restoreInterfaceState(ctx context.Context) error { return err } - logrus.Trace("interface state restored") + logrus.Info("interface state restored") return nil } @@ -141,7 +143,7 @@ func (a *App) createDefaultUser(ctx context.Context) error { return err } - logrus.Tracef("admin user %s created", admin.Identifier) + logrus.Infof("admin user %s created", admin.Identifier) return nil } diff --git a/internal/app/configfile/manager.go b/internal/app/configfile/manager.go index 5048358..892501a 100644 --- a/internal/app/configfile/manager.go +++ b/internal/app/configfile/manager.go @@ -7,11 +7,9 @@ import ( "fmt" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" - "github.com/sirupsen/logrus" "github.com/yeqown/go-qrcode/v2" "io" "os" - "path/filepath" "strings" ) @@ -19,11 +17,12 @@ type Manager struct { cfg *config.Config tplHandler *TemplateHandler - users UserDatabaseRepo - wg WireguardDatabaseRepo + fsRepo FileSystemRepo // can be nil if storing the configuration is disabled + users UserDatabaseRepo + wg WireguardDatabaseRepo } -func NewConfigFileManager(cfg *config.Config, users UserDatabaseRepo, wg WireguardDatabaseRepo) (*Manager, error) { +func NewConfigFileManager(cfg *config.Config, users UserDatabaseRepo, wg WireguardDatabaseRepo, fsRepo FileSystemRepo) (*Manager, error) { tplHandler, err := newTemplateHandler() if err != nil { return nil, fmt.Errorf("failed to initialize template handler: %w", err) @@ -33,8 +32,9 @@ func NewConfigFileManager(cfg *config.Config, users UserDatabaseRepo, wg Wiregua cfg: cfg, tplHandler: tplHandler, - users: users, - wg: wg, + fsRepo: fsRepo, + users: users, + wg: wg, } if err := m.createStorageDirectory(); err != nil { @@ -122,7 +122,7 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi } func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error { - if m.cfg.Advanced.ConfigStoragePath == "" { + if m.fsRepo == nil { return fmt.Errorf("peristing configuration is not supported") } @@ -136,19 +136,7 @@ func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.Interface return fmt.Errorf("failed to get interface config: %w", err) } - file, err := os.Create(filepath.Join(m.cfg.Advanced.ConfigStoragePath, iface.GetConfigFileName())) - if err != nil { - return fmt.Errorf("failed to create interface config file: %w", err) - } - defer func(file *os.File) { - err := file.Close() - if err != nil { - logrus.Warn("failed to close interface config file: %v", err) - } - }(file) - - _, err = io.Copy(file, cfg) - if err != nil { + if err := m.fsRepo.WriteFile(iface.GetConfigFileName(), cfg); err != nil { return fmt.Errorf("failed to write interface config: %w", err) } diff --git a/internal/app/configfile/repos.go b/internal/app/configfile/repos.go index 19e109c..e43687b 100644 --- a/internal/app/configfile/repos.go +++ b/internal/app/configfile/repos.go @@ -3,6 +3,7 @@ package configfile import ( "context" "github.com/h44z/wg-portal/internal/domain" + "io" ) type UserDatabaseRepo interface { @@ -14,3 +15,7 @@ type WireguardDatabaseRepo interface { GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) } + +type FileSystemRepo interface { + WriteFile(path string, contents io.Reader) error +} diff --git a/internal/app/repos.go b/internal/app/repos.go index 2f5b34a..4f2ce2b 100644 --- a/internal/app/repos.go +++ b/internal/app/repos.go @@ -28,17 +28,17 @@ type UserManager interface { type WireGuardManager interface { StartBackgroundJobs(ctx context.Context) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) - ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) error + ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error CreateDefaultPeer(ctx context.Context, user *domain.User) error GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) - GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) + GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) PrepareInterface(ctx context.Context) (*domain.Interface, error) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) - UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) + UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index 1b2127b..4205e55 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -27,10 +27,28 @@ func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, erro return m.db.GetAllInterfaces(ctx) } -func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) error { +func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) { + interfaces, err := m.db.GetAllInterfaces(ctx) + if err != nil { + return nil, nil, fmt.Errorf("unable to load all interfaces: %w", err) + } + + allPeers := make([][]domain.Peer, len(interfaces)) + for i, iface := range interfaces { + peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) + if err != nil { + return nil, nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err) + } + allPeers[i] = peers + } + + return interfaces, allPeers, nil +} + +func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) { physicalInterfaces, err := m.wg.GetInterfaces(ctx) if err != nil { - return err + return 0, err } // if no filter is given, exclude already existing interfaces @@ -38,13 +56,14 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter if len(filter) == 0 { existingInterfaces, err := m.db.GetAllInterfaces(ctx) if err != nil { - return err + return 0, err } for _, existingInterface := range existingInterfaces { excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier) } } + imported := 0 for _, physicalInterface := range physicalInterfaces { if internal.SliceContains(excludedInterfaces, physicalInterface.Identifier) { continue @@ -58,18 +77,19 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier) if err != nil { - return err + return 0, err } err = m.importInterface(ctx, &physicalInterface, physicalPeers) if err != nil { - return fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err) + return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err) } logrus.Infof("imported new interface %s and %d peers", physicalInterface.Identifier, len(physicalPeers)) + imported++ } - return nil + return imported, nil } func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error { @@ -117,6 +137,8 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool physicalInterface, err := m.wg.GetInterface(ctx, iface.Identifier) if err != nil { + logrus.Debugf("creating missing interface %s...", iface.Identifier) + // try to create a new interface err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { domain.MergeToPhysicalInterface(pi, &iface) @@ -148,6 +170,8 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool } } else { if physicalInterface.DeviceUp != !iface.IsDisabled() { + logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled()) + // try to move interface to stored state err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { pi.DeviceUp = !iface.IsDisabled() @@ -287,14 +311,14 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do return in, nil } -func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) { - existingInterface, err := m.db.GetInterface(ctx, in.Identifier) +func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) { + existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, in.Identifier) if err != nil { - return nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err) + return nil, nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err) } if err := m.validateInterfaceModifications(ctx, existingInterface, in); err != nil { - return nil, fmt.Errorf("update not allowed: %w", err) + return nil, nil, fmt.Errorf("update not allowed: %w", err) } err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) { @@ -311,10 +335,10 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do return in, nil }) if err != nil { - return nil, fmt.Errorf("update failure: %w", err) + return nil, nil, fmt.Errorf("update failure: %w", err) } - return in, nil + return in, existingPeers, nil } func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index 7a49ffa..28029ae 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -4,13 +4,47 @@ import ( "context" "errors" "fmt" + "github.com/h44z/wg-portal/internal" "github.com/h44z/wg-portal/internal/domain" + "github.com/sirupsen/logrus" "time" ) func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error { - // TODO: implement - return fmt.Errorf("IMPLEMENT ME") + existingInterfaces, err := m.db.GetAllInterfaces(ctx) + if err != nil { + return fmt.Errorf("failed to fetch all interfaces: %w", err) + } + + var newPeers []domain.Peer + for _, iface := range existingInterfaces { + if iface.Type != domain.InterfaceTypeServer { + continue // only create default peers for server interfaces + } + + peer, err := m.PreparePeer(ctx, iface.Identifier) + if err != nil { + return fmt.Errorf("failed to create default peer for interface %s: %w", iface.Identifier, err) + } + + peer.UserIdentifier = user.Identifier + peer.DisplayName = fmt.Sprintf("Default Peer %s", internal.TruncateString(string(peer.Identifier), 8)) + peer.Notes = fmt.Sprintf("Default peer created for user %s", user.Identifier) + + newPeers = append(newPeers, *peer) + } + + for i, peer := range newPeers { + _, err := m.CreatePeer(ctx, &newPeers[i]) + if err != nil { + return fmt.Errorf("failed to create default peer %s on interface %s: %w", + peer.Identifier, peer.InterfaceIdentifier, err) + } + } + + logrus.Infof("created %d default peers for user %s", len(newPeers), user.Identifier) + + return nil } func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) { @@ -59,7 +93,7 @@ func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) ExtraAllowedIPsStr: "", PresharedKey: pk, PersistentKeepalive: domain.NewIntConfigOption(iface.PeerDefPersistentKeepalive, true), - DisplayName: fmt.Sprintf("Peer %s", peerId[0:8]), + DisplayName: fmt.Sprintf("Peer %s", internal.TruncateString(string(peerId), 8)), Identifier: peerId, UserIdentifier: currentUser.Id, InterfaceIdentifier: iface.Identifier, @@ -133,7 +167,7 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) { var newPeers []domain.Peer - for _, id := range r.Identifiers { + for _, id := range r.UserIdentifiers { freshPeer, err := m.PreparePeer(ctx, interfaceId) if err != nil { return nil, fmt.Errorf("failed to prepare peer for interface %s: %w", interfaceId, err) diff --git a/internal/config/config.go b/internal/config/config.go index e0ae608..63735e7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -153,7 +153,11 @@ func loadConfigFile(cfg any, filename string) error { if err != nil { return err } - defer f.Close() + defer func(f *os.File) { + if err := f.Close(); err != nil { + logrus.Errorf("failed to close configuration file %s: %v", filename, err) + } + }(f) decoder := yaml.NewDecoder(f) err = decoder.Decode(cfg) diff --git a/internal/domain/context.go b/internal/domain/context.go index a5c4e48..6fe004b 100644 --- a/internal/domain/context.go +++ b/internal/domain/context.go @@ -8,6 +8,11 @@ import ( const CtxUserInfo = "userInfo" +const ( + CtxSystemAdminId = "_WG_SYS_ADMIN_" + CtxUnknownUserId = "_WG_SYS_UNKNOWN_" +) + type ContextUserInfo struct { Id UserIdentifier IsAdmin bool @@ -15,14 +20,14 @@ type ContextUserInfo struct { func DefaultContextUserInfo() *ContextUserInfo { return &ContextUserInfo{ - Id: "_WG_SYS_UNKNOWN_", + Id: CtxUnknownUserId, IsAdmin: false, } } func SystemAdminContextUserInfo() *ContextUserInfo { return &ContextUserInfo{ - Id: "_WG_SYS_ADMIN_", + Id: CtxSystemAdminId, IsAdmin: true, } } diff --git a/internal/domain/peer.go b/internal/domain/peer.go index 3520117..9141a1a 100644 --- a/internal/domain/peer.go +++ b/internal/domain/peer.go @@ -232,6 +232,6 @@ func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) { } type PeerCreationRequest struct { - Identifiers []string - Suffix string + UserIdentifiers []string + Suffix string } diff --git a/internal/ldap_utils.go b/internal/ldap_utils.go index 0abad23..242ec0b 100644 --- a/internal/ldap_utils.go +++ b/internal/ldap_utils.go @@ -3,6 +3,7 @@ package internal import ( "crypto/tls" "fmt" + "github.com/sirupsen/logrus" "os" "github.com/go-ldap/ldap/v3" @@ -72,7 +73,9 @@ func LdapConnect(cfg *config.LdapProvider) (*ldap.Conn, error) { func LdapDisconnect(conn *ldap.Conn) { if conn != nil { - conn.Close() + if err := conn.Close(); err != nil { + logrus.Errorf("failed to close ldap connection: %v", err) + } } }