many more improvements and cleanup

This commit is contained in:
Christoph Haas 2023-07-24 21:00:45 +02:00
parent 5153f602ab
commit 2a5b4fe31d
21 changed files with 201 additions and 108 deletions

View File

@ -140,7 +140,7 @@ The following configuration options are available:
* [Gin, HTTP web framework written in Go](https://github.com/gin-gonic/gin) * [Gin, HTTP web framework written in Go](https://github.com/gin-gonic/gin)
* [Bootstrap, for the HTML templates](https://getbootstrap.com/) * [Bootstrap, for the HTML templates](https://getbootstrap.com/)
* [Vue.JS, for the frontend](hhttps://vuejs.org/) * [Vue.JS, for the frontend](https://vuejs.org/)
## License ## License

View File

@ -23,6 +23,7 @@ import (
evbus "github.com/vardius/message-bus" evbus "github.com/vardius/message-bus"
) )
// main entry point for WireGuard Portal
func main() { func main() {
ctx := internal.SignalAwareContext(context.Background(), syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM) ctx := internal.SignalAwareContext(context.Background(), syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
@ -44,6 +45,9 @@ func main() {
mailer := adapters.NewSmtpMailRepo(cfg.Mail) mailer := adapters.NewSmtpMailRepo(cfg.Mail)
cfgFileSystem, err := adapters.NewFileSystemRepository(cfg.Advanced.ConfigStoragePath)
internal.AssertNoError(err)
shouldExit, err := app.HandleProgramArgs(cfg, rawDb) shouldExit, err := app.HandleProgramArgs(cfg, rawDb)
switch { switch {
case shouldExit && err == nil: case shouldExit && err == nil:
@ -70,7 +74,7 @@ func main() {
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, database, wireGuard) statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, database, wireGuard)
internal.AssertNoError(err) internal.AssertNoError(err)
cfgFileManager, err := configfile.NewConfigFileManager(cfg, database, database) cfgFileManager, err := configfile.NewConfigFileManager(cfg, database, database, cfgFileSystem)
internal.AssertNoError(err) internal.AssertNoError(err)
mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database) mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database)

View File

@ -22,8 +22,10 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
// SchemaVersion describes the current database schema version. It must be incremented if a manual migration is needed.
var SchemaVersion uint64 = 1 var SchemaVersion uint64 = 1
// SysStat stores the current database schema version and the timestamp when it was applied.
type SysStat struct { type SysStat struct {
MigratedAt time.Time `gorm:"column:migrated_at"` MigratedAt time.Time `gorm:"column:migrated_at"`
SchemaVersion uint64 `gorm:"primaryKey,column:schema_version"` SchemaVersion uint64 `gorm:"primaryKey,column:schema_version"`
@ -143,6 +145,8 @@ func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) {
return gormDb, nil return gormDb, nil
} }
// SqlRepo is a SQL database repository implementation.
// Currently, it supports MySQL, SQLite, Microsoft SQL and Postgresql database systems.
type SqlRepo struct { type SqlRepo struct {
db *gorm.DB db *gorm.DB
} }

View File

@ -1,43 +1,52 @@
package adapters package adapters
import ( import (
"context" "fmt"
"github.com/sirupsen/logrus"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
) )
type filesystemRepo struct { type FilesystemRepo struct {
basePath string basePath string
} }
func NewFileSystemRepository(basePath string) (*filesystemRepo, error) { func NewFileSystemRepository(basePath string) (*FilesystemRepo, error) {
r := &filesystemRepo{basePath: basePath} if basePath == "" {
return nil, nil // no path, return empty repository
}
r := &FilesystemRepo{basePath: basePath}
if err := os.MkdirAll(r.basePath, os.ModePerm); err != nil { if err := os.MkdirAll(r.basePath, os.ModePerm); err != nil {
return nil, err return nil, fmt.Errorf("failed to create base directory %s: %w", basePath, err)
} }
return r, nil return r, nil
} }
func (r *filesystemRepo) WriteFile(_ context.Context, path string, contents io.Reader) error { func (r *FilesystemRepo) WriteFile(path string, contents io.Reader) error {
filePath := filepath.Join(r.basePath, path) filePath := filepath.Join(r.basePath, path)
parentDirectory := filepath.Dir(filePath)
err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm) if err := os.MkdirAll(parentDirectory, os.ModePerm); err != nil {
if err != nil { return fmt.Errorf("failed to create parent directory %s: %w", parentDirectory, err)
return err
} }
file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.ModePerm) file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.ModePerm)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to open file %s: %w", file.Name(), err)
} }
defer file.Close() defer func(file *os.File) {
if err := file.Close(); err != nil {
logrus.Errorf("failed to close file %s: %v", file.Name(), err)
}
}(file)
_, err = io.Copy(file, contents) _, err = io.Copy(file, contents)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to write file contents: %w", err)
} }
return nil return nil

View File

@ -1 +0,0 @@
DROP TABLE IF EXISTS test;

View File

@ -1,3 +0,0 @@
CREATE TABLE IF NOT EXISTS test (
firstname VARCHAR(16)
);

View File

@ -13,12 +13,13 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type wgRepo struct { // WgRepo implements all low-level WireGuard interactions.
type WgRepo struct {
wg lowlevel.WireGuardClient wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient nl lowlevel.NetlinkClient
} }
func NewWireGuardRepository() *wgRepo { func NewWireGuardRepository() *WgRepo {
wg, err := wgctrl.New() wg, err := wgctrl.New()
if err != nil { if err != nil {
panic("failed to init wgctrl: " + err.Error()) panic("failed to init wgctrl: " + err.Error())
@ -26,7 +27,7 @@ func NewWireGuardRepository() *wgRepo {
nl := &lowlevel.NetlinkManager{} nl := &lowlevel.NetlinkManager{}
repo := &wgRepo{ repo := &WgRepo{
wg: wg, wg: wg,
nl: nl, nl: nl,
} }
@ -34,7 +35,7 @@ func NewWireGuardRepository() *wgRepo {
return repo return repo
} }
func (r *wgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) { func (r *WgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
devices, err := r.wg.Devices() devices, err := r.wg.Devices()
if err != nil { if err != nil {
return nil, fmt.Errorf("device list error: %w", err) return nil, fmt.Errorf("device list error: %w", err)
@ -52,11 +53,11 @@ func (r *wgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, e
return interfaces, nil return interfaces, nil
} }
func (r *wgRepo) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { func (r *WgRepo) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
return r.getInterface(id) return r.getInterface(id)
} }
func (r *wgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) { func (r *WgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) {
device, err := r.wg.Device(string(deviceId)) device, err := r.wg.Device(string(deviceId))
if err != nil { if err != nil {
return nil, fmt.Errorf("device error: %w", err) return nil, fmt.Errorf("device error: %w", err)
@ -74,11 +75,11 @@ func (r *wgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier
return peers, nil return peers, nil
} }
func (r *wgRepo) GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) { func (r *WgRepo) GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
return r.getPeer(deviceId, id) return r.getPeer(deviceId, id)
} }
func (r *wgRepo) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) { func (r *WgRepo) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) {
// read data from wgctrl interface // read data from wgctrl interface
iface := domain.PhysicalInterface{ iface := domain.PhysicalInterface{
@ -122,7 +123,7 @@ func (r *wgRepo) convertWireGuardInterface(device *wgtypes.Device) (domain.Physi
return iface, nil return iface, nil
} }
func (r *wgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) { func (r *WgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) {
peerModel := domain.PhysicalPeer{ peerModel := domain.PhysicalPeer{
Identifier: domain.PeerIdentifier(peer.PublicKey.String()), Identifier: domain.PeerIdentifier(peer.PublicKey.String()),
Endpoint: "", Endpoint: "",
@ -151,7 +152,7 @@ func (r *wgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer,
return peerModel, nil return peerModel, nil
} }
func (r *wgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error { func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
physicalInterface, err := r.getOrCreateInterface(id) physicalInterface, err := r.getOrCreateInterface(id)
if err != nil { if err != nil {
return err return err
@ -174,7 +175,7 @@ func (r *wgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier,
return nil return nil
} }
func (r *wgRepo) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { func (r *WgRepo) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
device, err := r.getInterface(id) device, err := r.getInterface(id)
if err == nil { if err == nil {
return device, nil return device, nil
@ -192,7 +193,7 @@ func (r *wgRepo) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.Ph
return device, err return device, err
} }
func (r *wgRepo) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { func (r *WgRepo) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
device, err := r.wg.Device(string(id)) device, err := r.wg.Device(string(id))
if err != nil { if err != nil {
return nil, err return nil, err
@ -202,7 +203,7 @@ func (r *wgRepo) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalIn
return &pi, err return &pi, err
} }
func (r *wgRepo) createLowLevelInterface(id domain.InterfaceIdentifier) error { func (r *WgRepo) createLowLevelInterface(id domain.InterfaceIdentifier) error {
link := &netlink.GenericLink{ link := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{ LinkAttrs: netlink.LinkAttrs{
Name: string(id), Name: string(id),
@ -217,7 +218,7 @@ func (r *wgRepo) createLowLevelInterface(id domain.InterfaceIdentifier) error {
return nil return nil
} }
func (r *wgRepo) updateLowLevelInterface(pi *domain.PhysicalInterface) error { func (r *WgRepo) updateLowLevelInterface(pi *domain.PhysicalInterface) error {
link, err := r.nl.LinkByName(string(pi.Identifier)) link, err := r.nl.LinkByName(string(pi.Identifier))
if err != nil { if err != nil {
return err return err
@ -274,7 +275,7 @@ func (r *wgRepo) updateLowLevelInterface(pi *domain.PhysicalInterface) error {
return nil return nil
} }
func (r *wgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error { func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes()) pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes())
if err != nil { if err != nil {
return err return err
@ -297,7 +298,7 @@ func (r *wgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
return nil return nil
} }
func (r *wgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error { func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
if err := r.deleteLowLevelInterface(id); err != nil { if err := r.deleteLowLevelInterface(id); err != nil {
return err return err
} }
@ -305,7 +306,7 @@ func (r *wgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifie
return nil return nil
} }
func (r *wgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error { func (r *WgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
link, err := r.nl.LinkByName(string(id)) link, err := r.nl.LinkByName(string(id))
if err != nil { if err != nil {
return fmt.Errorf("unable to find low level interface: %w", err) return fmt.Errorf("unable to find low level interface: %w", err)
@ -319,7 +320,7 @@ func (r *wgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
return nil return nil
} }
func (r *wgRepo) SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error { func (r *WgRepo) SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error {
physicalPeer, err := r.getOrCreatePeer(deviceId, id) physicalPeer, err := r.getOrCreatePeer(deviceId, id)
if err != nil { if err != nil {
return err return err
@ -337,7 +338,7 @@ func (r *wgRepo) SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier
return nil return nil
} }
func (r *wgRepo) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) { func (r *WgRepo) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
peer, err := r.getPeer(deviceId, id) peer, err := r.getPeer(deviceId, id)
if err == nil { if err == nil {
return peer, nil return peer, nil
@ -355,7 +356,7 @@ func (r *wgRepo) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.
return peer, nil return peer, nil
} }
func (r *wgRepo) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) { func (r *WgRepo) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
if !id.IsPublicKey() { if !id.IsPublicKey() {
return nil, errors.New("invalid public key") return nil, errors.New("invalid public key")
} }
@ -378,7 +379,7 @@ func (r *wgRepo) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIden
return nil, os.ErrNotExist return nil, os.ErrNotExist
} }
func (r *wgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error { func (r *WgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error {
cfg := wgtypes.PeerConfig{ cfg := wgtypes.PeerConfig{
PublicKey: pp.GetPublicKey(), PublicKey: pp.GetPublicKey(),
Remove: false, Remove: false,
@ -404,7 +405,7 @@ func (r *wgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.Phys
return nil return nil
} }
func (r *wgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error { func (r *WgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
if !id.IsPublicKey() { if !id.IsPublicKey() {
return errors.New("invalid public key") return errors.New("invalid public key")
} }
@ -417,7 +418,7 @@ func (r *wgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifi
return nil return nil
} }
func (r *wgRepo) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error { func (r *WgRepo) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
cfg := wgtypes.PeerConfig{ cfg := wgtypes.PeerConfig{
PublicKey: id.ToPublicKey(), PublicKey: id.ToPublicKey(),
Remove: true, Remove: true,

View File

@ -20,7 +20,7 @@ import (
) )
// setup WireGuard manager with no linked store // setup WireGuard manager with no linked store
func setup(t *testing.T) *wgRepo { func setup(t *testing.T) *WgRepo {
if getProcessOwner() != "root" { if getProcessOwner() != "root" {
t.Fatalf("this tests need to be executed as root user") t.Fatalf("this tests need to be executed as root user")
} }

View File

@ -53,7 +53,7 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
return return
} }
c.JSON(http.StatusOK, model.NewInterface(in)) c.JSON(http.StatusOK, model.NewInterface(in, nil))
} }
} }
@ -68,7 +68,7 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
// @Router /interface/all [get] // @Router /interface/all [get]
func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc { func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
interfaces, err := e.app.GetAllInterfaces(c.Request.Context()) interfaces, peers, err := e.app.GetAllInterfacesAndPeers(c.Request.Context())
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{ c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(), Code: http.StatusInternalServerError, Message: err.Error(),
@ -76,7 +76,7 @@ func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
return return
} }
c.JSON(http.StatusOK, model.NewInterfaces(interfaces)) c.JSON(http.StatusOK, model.NewInterfaces(interfaces, peers))
} }
} }
@ -100,7 +100,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
return return
} }
iface, _, err := e.app.GetInterfaceAndPeers(c.Request.Context(), domain.InterfaceIdentifier(id)) iface, peers, err := e.app.GetInterfaceAndPeers(c.Request.Context(), domain.InterfaceIdentifier(id))
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{ c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(), Code: http.StatusInternalServerError, Message: err.Error(),
@ -108,7 +108,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
return return
} }
c.JSON(http.StatusOK, model.NewInterface(iface)) c.JSON(http.StatusOK, model.NewInterface(iface, peers))
} }
} }
@ -186,7 +186,7 @@ func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
return return
} }
updatedInterface, err := e.app.UpdateInterface(ctx, model.NewDomainInterface(&in)) updatedInterface, peers, err := e.app.UpdateInterface(ctx, model.NewDomainInterface(&in))
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{ c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(), Code: http.StatusInternalServerError, Message: err.Error(),
@ -194,7 +194,7 @@ func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
return return
} }
c.JSON(http.StatusOK, model.NewInterface(updatedInterface)) c.JSON(http.StatusOK, model.NewInterface(updatedInterface, peers))
} }
} }
@ -228,7 +228,7 @@ func (e interfaceEndpoint) handleCreatePost() gin.HandlerFunc {
return return
} }
c.JSON(http.StatusOK, model.NewInterface(newInterface)) c.JSON(http.StatusOK, model.NewInterface(newInterface, nil))
} }
} }

View File

@ -51,8 +51,8 @@ type Interface struct {
TotalPeers int `json:"TotalPeers"` TotalPeers int `json:"TotalPeers"`
} }
func NewInterface(src *domain.Interface) *Interface { func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
return &Interface{ iface := &Interface{
Identifier: string(src.Identifier), Identifier: string(src.Identifier),
DisplayName: src.DisplayName, DisplayName: src.DisplayName,
Mode: string(src.Type), Mode: string(src.Type),
@ -86,15 +86,29 @@ func NewInterface(src *domain.Interface) *Interface {
PeerDefPreDown: src.PeerDefPreDown, PeerDefPreDown: src.PeerDefPreDown,
PeerDefPostDown: src.PeerDefPostDown, PeerDefPostDown: src.PeerDefPostDown,
EnabledPeers: 0, // TODO EnabledPeers: 0,
TotalPeers: 0, // TODO TotalPeers: 0,
} }
if len(peers) > 0 {
iface.TotalPeers = len(peers)
activePeers := 0
for _, peer := range peers {
if !peer.IsDisabled() {
activePeers++
}
}
iface.EnabledPeers = activePeers
}
return iface
} }
func NewInterfaces(src []domain.Interface) []Interface { func NewInterfaces(src []domain.Interface, srcPeers [][]domain.Peer) []Interface {
results := make([]Interface, len(src)) results := make([]Interface, len(src))
for i := range src { for i := range src {
results[i] = *NewInterface(&src[i]) results[i] = *NewInterface(&src[i], srcPeers[i])
} }
return results return results

View File

@ -171,7 +171,7 @@ type MultiPeerRequest struct {
func NewDomainPeerCreationRequest(src *MultiPeerRequest) *domain.PeerCreationRequest { func NewDomainPeerCreationRequest(src *MultiPeerRequest) *domain.PeerCreationRequest {
return &domain.PeerCreationRequest{ return &domain.PeerCreationRequest{
Identifiers: src.Identifiers, UserIdentifiers: src.Identifiers,
Suffix: src.Suffix, Suffix: src.Suffix,
} }
} }

View File

@ -72,12 +72,14 @@ func (a *App) importNewInterfaces(ctx context.Context) error {
return nil // feature disabled return nil // feature disabled
} }
err := a.ImportNewInterfaces(ctx) importedCount, err := a.ImportNewInterfaces(ctx)
if err != nil { if err != nil {
return err return err
} }
logrus.Trace("potential new interfaces imported") if importedCount > 0 {
logrus.Infof("%d new interfaces imported", importedCount)
}
return nil return nil
} }
@ -92,7 +94,7 @@ func (a *App) restoreInterfaceState(ctx context.Context) error {
return err return err
} }
logrus.Trace("interface state restored") logrus.Info("interface state restored")
return nil return nil
} }
@ -141,7 +143,7 @@ func (a *App) createDefaultUser(ctx context.Context) error {
return err return err
} }
logrus.Tracef("admin user %s created", admin.Identifier) logrus.Infof("admin user %s created", admin.Identifier)
return nil return nil
} }

View File

@ -7,11 +7,9 @@ import (
"fmt" "fmt"
"github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
"github.com/sirupsen/logrus"
"github.com/yeqown/go-qrcode/v2" "github.com/yeqown/go-qrcode/v2"
"io" "io"
"os" "os"
"path/filepath"
"strings" "strings"
) )
@ -19,11 +17,12 @@ type Manager struct {
cfg *config.Config cfg *config.Config
tplHandler *TemplateHandler tplHandler *TemplateHandler
fsRepo FileSystemRepo // can be nil if storing the configuration is disabled
users UserDatabaseRepo users UserDatabaseRepo
wg WireguardDatabaseRepo wg WireguardDatabaseRepo
} }
func NewConfigFileManager(cfg *config.Config, users UserDatabaseRepo, wg WireguardDatabaseRepo) (*Manager, error) { func NewConfigFileManager(cfg *config.Config, users UserDatabaseRepo, wg WireguardDatabaseRepo, fsRepo FileSystemRepo) (*Manager, error) {
tplHandler, err := newTemplateHandler() tplHandler, err := newTemplateHandler()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize template handler: %w", err) return nil, fmt.Errorf("failed to initialize template handler: %w", err)
@ -33,6 +32,7 @@ func NewConfigFileManager(cfg *config.Config, users UserDatabaseRepo, wg Wiregua
cfg: cfg, cfg: cfg,
tplHandler: tplHandler, tplHandler: tplHandler,
fsRepo: fsRepo,
users: users, users: users,
wg: wg, wg: wg,
} }
@ -122,7 +122,7 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
} }
func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error { func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
if m.cfg.Advanced.ConfigStoragePath == "" { if m.fsRepo == nil {
return fmt.Errorf("peristing configuration is not supported") return fmt.Errorf("peristing configuration is not supported")
} }
@ -136,19 +136,7 @@ func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.Interface
return fmt.Errorf("failed to get interface config: %w", err) return fmt.Errorf("failed to get interface config: %w", err)
} }
file, err := os.Create(filepath.Join(m.cfg.Advanced.ConfigStoragePath, iface.GetConfigFileName())) if err := m.fsRepo.WriteFile(iface.GetConfigFileName(), cfg); err != nil {
if err != nil {
return fmt.Errorf("failed to create interface config file: %w", err)
}
defer func(file *os.File) {
err := file.Close()
if err != nil {
logrus.Warn("failed to close interface config file: %v", err)
}
}(file)
_, err = io.Copy(file, cfg)
if err != nil {
return fmt.Errorf("failed to write interface config: %w", err) return fmt.Errorf("failed to write interface config: %w", err)
} }

View File

@ -3,6 +3,7 @@ package configfile
import ( import (
"context" "context"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
"io"
) )
type UserDatabaseRepo interface { type UserDatabaseRepo interface {
@ -14,3 +15,7 @@ type WireguardDatabaseRepo interface {
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
} }
type FileSystemRepo interface {
WriteFile(path string, contents io.Reader) error
}

View File

@ -28,17 +28,17 @@ type UserManager interface {
type WireGuardManager interface { type WireGuardManager interface {
StartBackgroundJobs(ctx context.Context) StartBackgroundJobs(ctx context.Context)
GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error)
ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) error ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error)
RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error
CreateDefaultPeer(ctx context.Context, user *domain.User) error CreateDefaultPeer(ctx context.Context, user *domain.User) error
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error)
GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error)
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
PrepareInterface(ctx context.Context) (*domain.Interface, error) PrepareInterface(ctx context.Context) (*domain.Interface, error)
CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error)
UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error)
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)

View File

@ -27,10 +27,28 @@ func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, erro
return m.db.GetAllInterfaces(ctx) return m.db.GetAllInterfaces(ctx)
} }
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) error { func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
interfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return nil, nil, fmt.Errorf("unable to load all interfaces: %w", err)
}
allPeers := make([][]domain.Peer, len(interfaces))
for i, iface := range interfaces {
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return nil, nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
allPeers[i] = peers
}
return interfaces, allPeers, nil
}
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
physicalInterfaces, err := m.wg.GetInterfaces(ctx) physicalInterfaces, err := m.wg.GetInterfaces(ctx)
if err != nil { if err != nil {
return err return 0, err
} }
// if no filter is given, exclude already existing interfaces // if no filter is given, exclude already existing interfaces
@ -38,13 +56,14 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
if len(filter) == 0 { if len(filter) == 0 {
existingInterfaces, err := m.db.GetAllInterfaces(ctx) existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil { if err != nil {
return err return 0, err
} }
for _, existingInterface := range existingInterfaces { for _, existingInterface := range existingInterfaces {
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier) excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
} }
} }
imported := 0
for _, physicalInterface := range physicalInterfaces { for _, physicalInterface := range physicalInterfaces {
if internal.SliceContains(excludedInterfaces, physicalInterface.Identifier) { if internal.SliceContains(excludedInterfaces, physicalInterface.Identifier) {
continue continue
@ -58,18 +77,19 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier) physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier)
if err != nil { if err != nil {
return err return 0, err
} }
err = m.importInterface(ctx, &physicalInterface, physicalPeers) err = m.importInterface(ctx, &physicalInterface, physicalPeers)
if err != nil { if err != nil {
return fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err) return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
} }
logrus.Infof("imported new interface %s and %d peers", physicalInterface.Identifier, len(physicalPeers)) logrus.Infof("imported new interface %s and %d peers", physicalInterface.Identifier, len(physicalPeers))
imported++
} }
return nil return imported, nil
} }
func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error { func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
@ -117,6 +137,8 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
physicalInterface, err := m.wg.GetInterface(ctx, iface.Identifier) physicalInterface, err := m.wg.GetInterface(ctx, iface.Identifier)
if err != nil { if err != nil {
logrus.Debugf("creating missing interface %s...", iface.Identifier)
// try to create a new interface // try to create a new interface
err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
domain.MergeToPhysicalInterface(pi, &iface) domain.MergeToPhysicalInterface(pi, &iface)
@ -148,6 +170,8 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
} }
} else { } else {
if physicalInterface.DeviceUp != !iface.IsDisabled() { if physicalInterface.DeviceUp != !iface.IsDisabled() {
logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
// try to move interface to stored state // try to move interface to stored state
err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
pi.DeviceUp = !iface.IsDisabled() pi.DeviceUp = !iface.IsDisabled()
@ -287,14 +311,14 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
return in, nil return in, nil
} }
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) { func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) {
existingInterface, err := m.db.GetInterface(ctx, in.Identifier) existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, in.Identifier)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err) return nil, nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
} }
if err := m.validateInterfaceModifications(ctx, existingInterface, in); err != nil { if err := m.validateInterfaceModifications(ctx, existingInterface, in); err != nil {
return nil, fmt.Errorf("update not allowed: %w", err) return nil, nil, fmt.Errorf("update not allowed: %w", err)
} }
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) { err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
@ -311,10 +335,10 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
return in, nil return in, nil
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("update failure: %w", err) return nil, nil, fmt.Errorf("update failure: %w", err)
} }
return in, nil return in, existingPeers, nil
} }
func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {

View File

@ -4,13 +4,47 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
"github.com/sirupsen/logrus"
"time" "time"
) )
func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error { func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error {
// TODO: implement existingInterfaces, err := m.db.GetAllInterfaces(ctx)
return fmt.Errorf("IMPLEMENT ME") if err != nil {
return fmt.Errorf("failed to fetch all interfaces: %w", err)
}
var newPeers []domain.Peer
for _, iface := range existingInterfaces {
if iface.Type != domain.InterfaceTypeServer {
continue // only create default peers for server interfaces
}
peer, err := m.PreparePeer(ctx, iface.Identifier)
if err != nil {
return fmt.Errorf("failed to create default peer for interface %s: %w", iface.Identifier, err)
}
peer.UserIdentifier = user.Identifier
peer.DisplayName = fmt.Sprintf("Default Peer %s", internal.TruncateString(string(peer.Identifier), 8))
peer.Notes = fmt.Sprintf("Default peer created for user %s", user.Identifier)
newPeers = append(newPeers, *peer)
}
for i, peer := range newPeers {
_, err := m.CreatePeer(ctx, &newPeers[i])
if err != nil {
return fmt.Errorf("failed to create default peer %s on interface %s: %w",
peer.Identifier, peer.InterfaceIdentifier, err)
}
}
logrus.Infof("created %d default peers for user %s", len(newPeers), user.Identifier)
return nil
} }
func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) { func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
@ -59,7 +93,7 @@ func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier)
ExtraAllowedIPsStr: "", ExtraAllowedIPsStr: "",
PresharedKey: pk, PresharedKey: pk,
PersistentKeepalive: domain.NewIntConfigOption(iface.PeerDefPersistentKeepalive, true), PersistentKeepalive: domain.NewIntConfigOption(iface.PeerDefPersistentKeepalive, true),
DisplayName: fmt.Sprintf("Peer %s", peerId[0:8]), DisplayName: fmt.Sprintf("Peer %s", internal.TruncateString(string(peerId), 8)),
Identifier: peerId, Identifier: peerId,
UserIdentifier: currentUser.Id, UserIdentifier: currentUser.Id,
InterfaceIdentifier: iface.Identifier, InterfaceIdentifier: iface.Identifier,
@ -133,7 +167,7 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) { func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) {
var newPeers []domain.Peer var newPeers []domain.Peer
for _, id := range r.Identifiers { for _, id := range r.UserIdentifiers {
freshPeer, err := m.PreparePeer(ctx, interfaceId) freshPeer, err := m.PreparePeer(ctx, interfaceId)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to prepare peer for interface %s: %w", interfaceId, err) return nil, fmt.Errorf("failed to prepare peer for interface %s: %w", interfaceId, err)

View File

@ -153,7 +153,11 @@ func loadConfigFile(cfg any, filename string) error {
if err != nil { if err != nil {
return err return err
} }
defer f.Close() defer func(f *os.File) {
if err := f.Close(); err != nil {
logrus.Errorf("failed to close configuration file %s: %v", filename, err)
}
}(f)
decoder := yaml.NewDecoder(f) decoder := yaml.NewDecoder(f)
err = decoder.Decode(cfg) err = decoder.Decode(cfg)

View File

@ -8,6 +8,11 @@ import (
const CtxUserInfo = "userInfo" const CtxUserInfo = "userInfo"
const (
CtxSystemAdminId = "_WG_SYS_ADMIN_"
CtxUnknownUserId = "_WG_SYS_UNKNOWN_"
)
type ContextUserInfo struct { type ContextUserInfo struct {
Id UserIdentifier Id UserIdentifier
IsAdmin bool IsAdmin bool
@ -15,14 +20,14 @@ type ContextUserInfo struct {
func DefaultContextUserInfo() *ContextUserInfo { func DefaultContextUserInfo() *ContextUserInfo {
return &ContextUserInfo{ return &ContextUserInfo{
Id: "_WG_SYS_UNKNOWN_", Id: CtxUnknownUserId,
IsAdmin: false, IsAdmin: false,
} }
} }
func SystemAdminContextUserInfo() *ContextUserInfo { func SystemAdminContextUserInfo() *ContextUserInfo {
return &ContextUserInfo{ return &ContextUserInfo{
Id: "_WG_SYS_ADMIN_", Id: CtxSystemAdminId,
IsAdmin: true, IsAdmin: true,
} }
} }

View File

@ -232,6 +232,6 @@ func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) {
} }
type PeerCreationRequest struct { type PeerCreationRequest struct {
Identifiers []string UserIdentifiers []string
Suffix string Suffix string
} }

View File

@ -3,6 +3,7 @@ package internal
import ( import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"github.com/sirupsen/logrus"
"os" "os"
"github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
@ -72,7 +73,9 @@ func LdapConnect(cfg *config.LdapProvider) (*ldap.Conn, error) {
func LdapDisconnect(conn *ldap.Conn) { func LdapDisconnect(conn *ldap.Conn) {
if conn != nil { if conn != nil {
conn.Close() if err := conn.Close(); err != nil {
logrus.Errorf("failed to close ldap connection: %v", err)
}
} }
} }