many more improvements and cleanup

This commit is contained in:
Christoph Haas 2023-07-24 21:00:45 +02:00
parent 5153f602ab
commit 2a5b4fe31d
21 changed files with 201 additions and 108 deletions

View File

@ -140,7 +140,7 @@ The following configuration options are available:
* [Gin, HTTP web framework written in Go](https://github.com/gin-gonic/gin)
* [Bootstrap, for the HTML templates](https://getbootstrap.com/)
* [Vue.JS, for the frontend](hhttps://vuejs.org/)
* [Vue.JS, for the frontend](https://vuejs.org/)
## License

View File

@ -23,6 +23,7 @@ import (
evbus "github.com/vardius/message-bus"
)
// main entry point for WireGuard Portal
func main() {
ctx := internal.SignalAwareContext(context.Background(), syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
@ -44,6 +45,9 @@ func main() {
mailer := adapters.NewSmtpMailRepo(cfg.Mail)
cfgFileSystem, err := adapters.NewFileSystemRepository(cfg.Advanced.ConfigStoragePath)
internal.AssertNoError(err)
shouldExit, err := app.HandleProgramArgs(cfg, rawDb)
switch {
case shouldExit && err == nil:
@ -70,7 +74,7 @@ func main() {
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, database, wireGuard)
internal.AssertNoError(err)
cfgFileManager, err := configfile.NewConfigFileManager(cfg, database, database)
cfgFileManager, err := configfile.NewConfigFileManager(cfg, database, database, cfgFileSystem)
internal.AssertNoError(err)
mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database)

View File

@ -22,8 +22,10 @@ import (
"gorm.io/gorm"
)
// SchemaVersion describes the current database schema version. It must be incremented if a manual migration is needed.
var SchemaVersion uint64 = 1
// SysStat stores the current database schema version and the timestamp when it was applied.
type SysStat struct {
MigratedAt time.Time `gorm:"column:migrated_at"`
SchemaVersion uint64 `gorm:"primaryKey,column:schema_version"`
@ -143,6 +145,8 @@ func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) {
return gormDb, nil
}
// SqlRepo is a SQL database repository implementation.
// Currently, it supports MySQL, SQLite, Microsoft SQL and Postgresql database systems.
type SqlRepo struct {
db *gorm.DB
}

View File

@ -1,43 +1,52 @@
package adapters
import (
"context"
"fmt"
"github.com/sirupsen/logrus"
"io"
"os"
"path/filepath"
)
type filesystemRepo struct {
type FilesystemRepo struct {
basePath string
}
func NewFileSystemRepository(basePath string) (*filesystemRepo, error) {
r := &filesystemRepo{basePath: basePath}
func NewFileSystemRepository(basePath string) (*FilesystemRepo, error) {
if basePath == "" {
return nil, nil // no path, return empty repository
}
r := &FilesystemRepo{basePath: basePath}
if err := os.MkdirAll(r.basePath, os.ModePerm); err != nil {
return nil, err
return nil, fmt.Errorf("failed to create base directory %s: %w", basePath, err)
}
return r, nil
}
func (r *filesystemRepo) WriteFile(_ context.Context, path string, contents io.Reader) error {
func (r *FilesystemRepo) WriteFile(path string, contents io.Reader) error {
filePath := filepath.Join(r.basePath, path)
parentDirectory := filepath.Dir(filePath)
err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm)
if err != nil {
return err
if err := os.MkdirAll(parentDirectory, os.ModePerm); err != nil {
return fmt.Errorf("failed to create parent directory %s: %w", parentDirectory, err)
}
file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.ModePerm)
if err != nil {
return err
return fmt.Errorf("failed to open file %s: %w", file.Name(), err)
}
defer file.Close()
defer func(file *os.File) {
if err := file.Close(); err != nil {
logrus.Errorf("failed to close file %s: %v", file.Name(), err)
}
}(file)
_, err = io.Copy(file, contents)
if err != nil {
return err
return fmt.Errorf("failed to write file contents: %w", err)
}
return nil

View File

@ -1 +0,0 @@
DROP TABLE IF EXISTS test;

View File

@ -1,3 +0,0 @@
CREATE TABLE IF NOT EXISTS test (
firstname VARCHAR(16)
);

View File

@ -13,12 +13,13 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type wgRepo struct {
// WgRepo implements all low-level WireGuard interactions.
type WgRepo struct {
wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient
}
func NewWireGuardRepository() *wgRepo {
func NewWireGuardRepository() *WgRepo {
wg, err := wgctrl.New()
if err != nil {
panic("failed to init wgctrl: " + err.Error())
@ -26,7 +27,7 @@ func NewWireGuardRepository() *wgRepo {
nl := &lowlevel.NetlinkManager{}
repo := &wgRepo{
repo := &WgRepo{
wg: wg,
nl: nl,
}
@ -34,7 +35,7 @@ func NewWireGuardRepository() *wgRepo {
return repo
}
func (r *wgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
func (r *WgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
devices, err := r.wg.Devices()
if err != nil {
return nil, fmt.Errorf("device list error: %w", err)
@ -52,11 +53,11 @@ func (r *wgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, e
return interfaces, nil
}
func (r *wgRepo) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
func (r *WgRepo) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
return r.getInterface(id)
}
func (r *wgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) {
func (r *WgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) {
device, err := r.wg.Device(string(deviceId))
if err != nil {
return nil, fmt.Errorf("device error: %w", err)
@ -74,11 +75,11 @@ func (r *wgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier
return peers, nil
}
func (r *wgRepo) GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
func (r *WgRepo) GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
return r.getPeer(deviceId, id)
}
func (r *wgRepo) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) {
func (r *WgRepo) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) {
// read data from wgctrl interface
iface := domain.PhysicalInterface{
@ -122,7 +123,7 @@ func (r *wgRepo) convertWireGuardInterface(device *wgtypes.Device) (domain.Physi
return iface, nil
}
func (r *wgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) {
func (r *WgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) {
peerModel := domain.PhysicalPeer{
Identifier: domain.PeerIdentifier(peer.PublicKey.String()),
Endpoint: "",
@ -151,7 +152,7 @@ func (r *wgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer,
return peerModel, nil
}
func (r *wgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
physicalInterface, err := r.getOrCreateInterface(id)
if err != nil {
return err
@ -174,7 +175,7 @@ func (r *wgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier,
return nil
}
func (r *wgRepo) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
func (r *WgRepo) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
device, err := r.getInterface(id)
if err == nil {
return device, nil
@ -192,7 +193,7 @@ func (r *wgRepo) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.Ph
return device, err
}
func (r *wgRepo) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
func (r *WgRepo) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
device, err := r.wg.Device(string(id))
if err != nil {
return nil, err
@ -202,7 +203,7 @@ func (r *wgRepo) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalIn
return &pi, err
}
func (r *wgRepo) createLowLevelInterface(id domain.InterfaceIdentifier) error {
func (r *WgRepo) createLowLevelInterface(id domain.InterfaceIdentifier) error {
link := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{
Name: string(id),
@ -217,7 +218,7 @@ func (r *wgRepo) createLowLevelInterface(id domain.InterfaceIdentifier) error {
return nil
}
func (r *wgRepo) updateLowLevelInterface(pi *domain.PhysicalInterface) error {
func (r *WgRepo) updateLowLevelInterface(pi *domain.PhysicalInterface) error {
link, err := r.nl.LinkByName(string(pi.Identifier))
if err != nil {
return err
@ -274,7 +275,7 @@ func (r *wgRepo) updateLowLevelInterface(pi *domain.PhysicalInterface) error {
return nil
}
func (r *wgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes())
if err != nil {
return err
@ -297,7 +298,7 @@ func (r *wgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
return nil
}
func (r *wgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
if err := r.deleteLowLevelInterface(id); err != nil {
return err
}
@ -305,7 +306,7 @@ func (r *wgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifie
return nil
}
func (r *wgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
func (r *WgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
link, err := r.nl.LinkByName(string(id))
if err != nil {
return fmt.Errorf("unable to find low level interface: %w", err)
@ -319,7 +320,7 @@ func (r *wgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
return nil
}
func (r *wgRepo) SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error {
func (r *WgRepo) SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error {
physicalPeer, err := r.getOrCreatePeer(deviceId, id)
if err != nil {
return err
@ -337,7 +338,7 @@ func (r *wgRepo) SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier
return nil
}
func (r *wgRepo) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
func (r *WgRepo) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
peer, err := r.getPeer(deviceId, id)
if err == nil {
return peer, nil
@ -355,7 +356,7 @@ func (r *wgRepo) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.
return peer, nil
}
func (r *wgRepo) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
func (r *WgRepo) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
if !id.IsPublicKey() {
return nil, errors.New("invalid public key")
}
@ -378,7 +379,7 @@ func (r *wgRepo) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIden
return nil, os.ErrNotExist
}
func (r *wgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error {
func (r *WgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error {
cfg := wgtypes.PeerConfig{
PublicKey: pp.GetPublicKey(),
Remove: false,
@ -404,7 +405,7 @@ func (r *wgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.Phys
return nil
}
func (r *wgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
func (r *WgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
if !id.IsPublicKey() {
return errors.New("invalid public key")
}
@ -417,7 +418,7 @@ func (r *wgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifi
return nil
}
func (r *wgRepo) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
func (r *WgRepo) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
cfg := wgtypes.PeerConfig{
PublicKey: id.ToPublicKey(),
Remove: true,

View File

@ -20,7 +20,7 @@ import (
)
// setup WireGuard manager with no linked store
func setup(t *testing.T) *wgRepo {
func setup(t *testing.T) *WgRepo {
if getProcessOwner() != "root" {
t.Fatalf("this tests need to be executed as root user")
}

View File

@ -53,7 +53,7 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
return
}
c.JSON(http.StatusOK, model.NewInterface(in))
c.JSON(http.StatusOK, model.NewInterface(in, nil))
}
}
@ -68,7 +68,7 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
// @Router /interface/all [get]
func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
return func(c *gin.Context) {
interfaces, err := e.app.GetAllInterfaces(c.Request.Context())
interfaces, peers, err := e.app.GetAllInterfacesAndPeers(c.Request.Context())
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
@ -76,7 +76,7 @@ func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
return
}
c.JSON(http.StatusOK, model.NewInterfaces(interfaces))
c.JSON(http.StatusOK, model.NewInterfaces(interfaces, peers))
}
}
@ -100,7 +100,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
return
}
iface, _, err := e.app.GetInterfaceAndPeers(c.Request.Context(), domain.InterfaceIdentifier(id))
iface, peers, err := e.app.GetInterfaceAndPeers(c.Request.Context(), domain.InterfaceIdentifier(id))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
@ -108,7 +108,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
return
}
c.JSON(http.StatusOK, model.NewInterface(iface))
c.JSON(http.StatusOK, model.NewInterface(iface, peers))
}
}
@ -186,7 +186,7 @@ func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
return
}
updatedInterface, err := e.app.UpdateInterface(ctx, model.NewDomainInterface(&in))
updatedInterface, peers, err := e.app.UpdateInterface(ctx, model.NewDomainInterface(&in))
if err != nil {
c.JSON(http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
@ -194,7 +194,7 @@ func (e interfaceEndpoint) handleUpdatePut() gin.HandlerFunc {
return
}
c.JSON(http.StatusOK, model.NewInterface(updatedInterface))
c.JSON(http.StatusOK, model.NewInterface(updatedInterface, peers))
}
}
@ -228,7 +228,7 @@ func (e interfaceEndpoint) handleCreatePost() gin.HandlerFunc {
return
}
c.JSON(http.StatusOK, model.NewInterface(newInterface))
c.JSON(http.StatusOK, model.NewInterface(newInterface, nil))
}
}

View File

@ -51,8 +51,8 @@ type Interface struct {
TotalPeers int `json:"TotalPeers"`
}
func NewInterface(src *domain.Interface) *Interface {
return &Interface{
func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
iface := &Interface{
Identifier: string(src.Identifier),
DisplayName: src.DisplayName,
Mode: string(src.Type),
@ -86,15 +86,29 @@ func NewInterface(src *domain.Interface) *Interface {
PeerDefPreDown: src.PeerDefPreDown,
PeerDefPostDown: src.PeerDefPostDown,
EnabledPeers: 0, // TODO
TotalPeers: 0, // TODO
EnabledPeers: 0,
TotalPeers: 0,
}
if len(peers) > 0 {
iface.TotalPeers = len(peers)
activePeers := 0
for _, peer := range peers {
if !peer.IsDisabled() {
activePeers++
}
}
iface.EnabledPeers = activePeers
}
return iface
}
func NewInterfaces(src []domain.Interface) []Interface {
func NewInterfaces(src []domain.Interface, srcPeers [][]domain.Peer) []Interface {
results := make([]Interface, len(src))
for i := range src {
results[i] = *NewInterface(&src[i])
results[i] = *NewInterface(&src[i], srcPeers[i])
}
return results

View File

@ -171,8 +171,8 @@ type MultiPeerRequest struct {
func NewDomainPeerCreationRequest(src *MultiPeerRequest) *domain.PeerCreationRequest {
return &domain.PeerCreationRequest{
Identifiers: src.Identifiers,
Suffix: src.Suffix,
UserIdentifiers: src.Identifiers,
Suffix: src.Suffix,
}
}

View File

@ -72,12 +72,14 @@ func (a *App) importNewInterfaces(ctx context.Context) error {
return nil // feature disabled
}
err := a.ImportNewInterfaces(ctx)
importedCount, err := a.ImportNewInterfaces(ctx)
if err != nil {
return err
}
logrus.Trace("potential new interfaces imported")
if importedCount > 0 {
logrus.Infof("%d new interfaces imported", importedCount)
}
return nil
}
@ -92,7 +94,7 @@ func (a *App) restoreInterfaceState(ctx context.Context) error {
return err
}
logrus.Trace("interface state restored")
logrus.Info("interface state restored")
return nil
}
@ -141,7 +143,7 @@ func (a *App) createDefaultUser(ctx context.Context) error {
return err
}
logrus.Tracef("admin user %s created", admin.Identifier)
logrus.Infof("admin user %s created", admin.Identifier)
return nil
}

View File

@ -7,11 +7,9 @@ import (
"fmt"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/sirupsen/logrus"
"github.com/yeqown/go-qrcode/v2"
"io"
"os"
"path/filepath"
"strings"
)
@ -19,11 +17,12 @@ type Manager struct {
cfg *config.Config
tplHandler *TemplateHandler
users UserDatabaseRepo
wg WireguardDatabaseRepo
fsRepo FileSystemRepo // can be nil if storing the configuration is disabled
users UserDatabaseRepo
wg WireguardDatabaseRepo
}
func NewConfigFileManager(cfg *config.Config, users UserDatabaseRepo, wg WireguardDatabaseRepo) (*Manager, error) {
func NewConfigFileManager(cfg *config.Config, users UserDatabaseRepo, wg WireguardDatabaseRepo, fsRepo FileSystemRepo) (*Manager, error) {
tplHandler, err := newTemplateHandler()
if err != nil {
return nil, fmt.Errorf("failed to initialize template handler: %w", err)
@ -33,8 +32,9 @@ func NewConfigFileManager(cfg *config.Config, users UserDatabaseRepo, wg Wiregua
cfg: cfg,
tplHandler: tplHandler,
users: users,
wg: wg,
fsRepo: fsRepo,
users: users,
wg: wg,
}
if err := m.createStorageDirectory(); err != nil {
@ -122,7 +122,7 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
}
func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
if m.cfg.Advanced.ConfigStoragePath == "" {
if m.fsRepo == nil {
return fmt.Errorf("peristing configuration is not supported")
}
@ -136,19 +136,7 @@ func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.Interface
return fmt.Errorf("failed to get interface config: %w", err)
}
file, err := os.Create(filepath.Join(m.cfg.Advanced.ConfigStoragePath, iface.GetConfigFileName()))
if err != nil {
return fmt.Errorf("failed to create interface config file: %w", err)
}
defer func(file *os.File) {
err := file.Close()
if err != nil {
logrus.Warn("failed to close interface config file: %v", err)
}
}(file)
_, err = io.Copy(file, cfg)
if err != nil {
if err := m.fsRepo.WriteFile(iface.GetConfigFileName(), cfg); err != nil {
return fmt.Errorf("failed to write interface config: %w", err)
}

View File

@ -3,6 +3,7 @@ package configfile
import (
"context"
"github.com/h44z/wg-portal/internal/domain"
"io"
)
type UserDatabaseRepo interface {
@ -14,3 +15,7 @@ type WireguardDatabaseRepo interface {
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}
type FileSystemRepo interface {
WriteFile(path string, contents io.Reader) error
}

View File

@ -28,17 +28,17 @@ type UserManager interface {
type WireGuardManager interface {
StartBackgroundJobs(ctx context.Context)
GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error)
ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) error
ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error)
RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error
CreateDefaultPeer(ctx context.Context, user *domain.User) error
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error)
GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error)
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
PrepareInterface(ctx context.Context) (*domain.Interface, error)
CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error)
UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error)
UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error)
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)

View File

@ -27,10 +27,28 @@ func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, erro
return m.db.GetAllInterfaces(ctx)
}
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) error {
func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
interfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return nil, nil, fmt.Errorf("unable to load all interfaces: %w", err)
}
allPeers := make([][]domain.Peer, len(interfaces))
for i, iface := range interfaces {
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return nil, nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
allPeers[i] = peers
}
return interfaces, allPeers, nil
}
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
if err != nil {
return err
return 0, err
}
// if no filter is given, exclude already existing interfaces
@ -38,13 +56,14 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
if len(filter) == 0 {
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return err
return 0, err
}
for _, existingInterface := range existingInterfaces {
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
}
}
imported := 0
for _, physicalInterface := range physicalInterfaces {
if internal.SliceContains(excludedInterfaces, physicalInterface.Identifier) {
continue
@ -58,18 +77,19 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier)
if err != nil {
return err
return 0, err
}
err = m.importInterface(ctx, &physicalInterface, physicalPeers)
if err != nil {
return fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
}
logrus.Infof("imported new interface %s and %d peers", physicalInterface.Identifier, len(physicalPeers))
imported++
}
return nil
return imported, nil
}
func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
@ -117,6 +137,8 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
physicalInterface, err := m.wg.GetInterface(ctx, iface.Identifier)
if err != nil {
logrus.Debugf("creating missing interface %s...", iface.Identifier)
// try to create a new interface
err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
domain.MergeToPhysicalInterface(pi, &iface)
@ -148,6 +170,8 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
}
} else {
if physicalInterface.DeviceUp != !iface.IsDisabled() {
logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
// try to move interface to stored state
err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
pi.DeviceUp = !iface.IsDisabled()
@ -287,14 +311,14 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
return in, nil
}
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
existingInterface, err := m.db.GetInterface(ctx, in.Identifier)
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) {
existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, in.Identifier)
if err != nil {
return nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
return nil, nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
}
if err := m.validateInterfaceModifications(ctx, existingInterface, in); err != nil {
return nil, fmt.Errorf("update not allowed: %w", err)
return nil, nil, fmt.Errorf("update not allowed: %w", err)
}
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
@ -311,10 +335,10 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
return in, nil
})
if err != nil {
return nil, fmt.Errorf("update failure: %w", err)
return nil, nil, fmt.Errorf("update failure: %w", err)
}
return in, nil
return in, existingPeers, nil
}
func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {

View File

@ -4,13 +4,47 @@ import (
"context"
"errors"
"fmt"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/domain"
"github.com/sirupsen/logrus"
"time"
)
func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error {
// TODO: implement
return fmt.Errorf("IMPLEMENT ME")
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return fmt.Errorf("failed to fetch all interfaces: %w", err)
}
var newPeers []domain.Peer
for _, iface := range existingInterfaces {
if iface.Type != domain.InterfaceTypeServer {
continue // only create default peers for server interfaces
}
peer, err := m.PreparePeer(ctx, iface.Identifier)
if err != nil {
return fmt.Errorf("failed to create default peer for interface %s: %w", iface.Identifier, err)
}
peer.UserIdentifier = user.Identifier
peer.DisplayName = fmt.Sprintf("Default Peer %s", internal.TruncateString(string(peer.Identifier), 8))
peer.Notes = fmt.Sprintf("Default peer created for user %s", user.Identifier)
newPeers = append(newPeers, *peer)
}
for i, peer := range newPeers {
_, err := m.CreatePeer(ctx, &newPeers[i])
if err != nil {
return fmt.Errorf("failed to create default peer %s on interface %s: %w",
peer.Identifier, peer.InterfaceIdentifier, err)
}
}
logrus.Infof("created %d default peers for user %s", len(newPeers), user.Identifier)
return nil
}
func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
@ -59,7 +93,7 @@ func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier)
ExtraAllowedIPsStr: "",
PresharedKey: pk,
PersistentKeepalive: domain.NewIntConfigOption(iface.PeerDefPersistentKeepalive, true),
DisplayName: fmt.Sprintf("Peer %s", peerId[0:8]),
DisplayName: fmt.Sprintf("Peer %s", internal.TruncateString(string(peerId), 8)),
Identifier: peerId,
UserIdentifier: currentUser.Id,
InterfaceIdentifier: iface.Identifier,
@ -133,7 +167,7 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) {
var newPeers []domain.Peer
for _, id := range r.Identifiers {
for _, id := range r.UserIdentifiers {
freshPeer, err := m.PreparePeer(ctx, interfaceId)
if err != nil {
return nil, fmt.Errorf("failed to prepare peer for interface %s: %w", interfaceId, err)

View File

@ -153,7 +153,11 @@ func loadConfigFile(cfg any, filename string) error {
if err != nil {
return err
}
defer f.Close()
defer func(f *os.File) {
if err := f.Close(); err != nil {
logrus.Errorf("failed to close configuration file %s: %v", filename, err)
}
}(f)
decoder := yaml.NewDecoder(f)
err = decoder.Decode(cfg)

View File

@ -8,6 +8,11 @@ import (
const CtxUserInfo = "userInfo"
const (
CtxSystemAdminId = "_WG_SYS_ADMIN_"
CtxUnknownUserId = "_WG_SYS_UNKNOWN_"
)
type ContextUserInfo struct {
Id UserIdentifier
IsAdmin bool
@ -15,14 +20,14 @@ type ContextUserInfo struct {
func DefaultContextUserInfo() *ContextUserInfo {
return &ContextUserInfo{
Id: "_WG_SYS_UNKNOWN_",
Id: CtxUnknownUserId,
IsAdmin: false,
}
}
func SystemAdminContextUserInfo() *ContextUserInfo {
return &ContextUserInfo{
Id: "_WG_SYS_ADMIN_",
Id: CtxSystemAdminId,
IsAdmin: true,
}
}

View File

@ -232,6 +232,6 @@ func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) {
}
type PeerCreationRequest struct {
Identifiers []string
Suffix string
UserIdentifiers []string
Suffix string
}

View File

@ -3,6 +3,7 @@ package internal
import (
"crypto/tls"
"fmt"
"github.com/sirupsen/logrus"
"os"
"github.com/go-ldap/ldap/v3"
@ -72,7 +73,9 @@ func LdapConnect(cfg *config.LdapProvider) (*ldap.Conn, error) {
func LdapDisconnect(conn *ldap.Conn) {
if conn != nil {
conn.Close()
if err := conn.Close(); err != nil {
logrus.Errorf("failed to close ldap connection: %v", err)
}
}
}