mirror of
				https://github.com/h44z/wg-portal.git
				synced 2025-11-03 23:56:18 +00:00 
			
		
		
		
	V2 alpha - initial version (#172)
Initial alpha codebase for version 2 of WireGuard Portal. This version is considered unstable and incomplete (for example, no public REST API)! Use with care! Fixes/Implements the following issues: - OAuth support #154, #1 - New Web UI with internationalisation support #98, #107, #89, #62 - Postgres Support #49 - Improved Email handling #47, #119 - DNS Search Domain support #46 - Bugfixes #94, #48 --------- Co-authored-by: Fabian Wechselberger <wechselbergerf@hotmail.com>
This commit is contained in:
		
							
								
								
									
										887
									
								
								internal/adapters/database.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										887
									
								
								internal/adapters/database.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,887 @@
 | 
			
		||||
package adapters
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"gorm.io/gorm/clause"
 | 
			
		||||
	"gorm.io/gorm/logger"
 | 
			
		||||
	"gorm.io/gorm/utils"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/glebarez/sqlite"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/config"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
	gormMySQL "gorm.io/driver/mysql"
 | 
			
		||||
	"gorm.io/driver/postgres"
 | 
			
		||||
	"gorm.io/driver/sqlserver"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// SchemaVersion describes the current database schema version. It must be incremented if a manual migration is needed.
 | 
			
		||||
var SchemaVersion uint64 = 1
 | 
			
		||||
 | 
			
		||||
// SysStat stores the current database schema version and the timestamp when it was applied.
 | 
			
		||||
type SysStat struct {
 | 
			
		||||
	MigratedAt    time.Time `gorm:"column:migrated_at"`
 | 
			
		||||
	SchemaVersion uint64    `gorm:"primaryKey,column:schema_version"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GormLogger is a custom logger for Gorm, making it use logrus.
 | 
			
		||||
type GormLogger struct {
 | 
			
		||||
	SlowThreshold           time.Duration
 | 
			
		||||
	SourceField             string
 | 
			
		||||
	IgnoreErrRecordNotFound bool
 | 
			
		||||
	Debug                   bool
 | 
			
		||||
	Silent                  bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewLogger(slowThreshold time.Duration, debug bool) *GormLogger {
 | 
			
		||||
	return &GormLogger{
 | 
			
		||||
		SlowThreshold:           slowThreshold,
 | 
			
		||||
		Debug:                   debug,
 | 
			
		||||
		IgnoreErrRecordNotFound: true,
 | 
			
		||||
		Silent:                  false,
 | 
			
		||||
		SourceField:             "src",
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *GormLogger) LogMode(level logger.LogLevel) logger.Interface {
 | 
			
		||||
	if level == logger.Silent {
 | 
			
		||||
		l.Silent = true
 | 
			
		||||
	} else {
 | 
			
		||||
		l.Silent = false
 | 
			
		||||
	}
 | 
			
		||||
	return l
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *GormLogger) Info(ctx context.Context, s string, args ...interface{}) {
 | 
			
		||||
	if l.Silent {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	logrus.WithContext(ctx).Infof(s, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *GormLogger) Warn(ctx context.Context, s string, args ...interface{}) {
 | 
			
		||||
	if l.Silent {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	logrus.WithContext(ctx).Warnf(s, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *GormLogger) Error(ctx context.Context, s string, args ...interface{}) {
 | 
			
		||||
	if l.Silent {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	logrus.WithContext(ctx).Errorf(s, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
 | 
			
		||||
	if l.Silent {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	elapsed := time.Since(begin)
 | 
			
		||||
	sql, rows := fc()
 | 
			
		||||
	fields := logrus.Fields{
 | 
			
		||||
		"rows":     rows,
 | 
			
		||||
		"duration": elapsed,
 | 
			
		||||
	}
 | 
			
		||||
	if l.SourceField != "" {
 | 
			
		||||
		fields[l.SourceField] = utils.FileWithLineNum()
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.IgnoreErrRecordNotFound) {
 | 
			
		||||
		fields[logrus.ErrorKey] = err
 | 
			
		||||
		logrus.WithContext(ctx).WithFields(fields).Errorf("%s", sql)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if l.SlowThreshold != 0 && elapsed > l.SlowThreshold {
 | 
			
		||||
		logrus.WithContext(ctx).WithFields(fields).Warnf("%s", sql)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if l.Debug {
 | 
			
		||||
		logrus.WithContext(ctx).WithFields(fields).Tracef("%s", sql)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) {
 | 
			
		||||
	var gormDb *gorm.DB
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	switch cfg.Type {
 | 
			
		||||
	case config.DatabaseMySQL:
 | 
			
		||||
		gormDb, err = gorm.Open(gormMySQL.Open(cfg.DSN), &gorm.Config{
 | 
			
		||||
			Logger: NewLogger(cfg.SlowQueryThreshold, cfg.Debug),
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("failed to open MySQL database: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		sqlDB, _ := gormDb.DB()
 | 
			
		||||
		sqlDB.SetConnMaxLifetime(time.Minute * 5)
 | 
			
		||||
		sqlDB.SetMaxIdleConns(2)
 | 
			
		||||
		sqlDB.SetMaxOpenConns(10)
 | 
			
		||||
		err = sqlDB.Ping() // This DOES open a connection if necessary. This makes sure the database is accessible
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("failed to ping MySQL database: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	case config.DatabaseMsSQL:
 | 
			
		||||
		gormDb, err = gorm.Open(sqlserver.Open(cfg.DSN), &gorm.Config{
 | 
			
		||||
			Logger: NewLogger(cfg.SlowQueryThreshold, cfg.Debug),
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("failed to open sqlserver database: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	case config.DatabasePostgres:
 | 
			
		||||
		gormDb, err = gorm.Open(postgres.Open(cfg.DSN), &gorm.Config{
 | 
			
		||||
			Logger: NewLogger(cfg.SlowQueryThreshold, cfg.Debug),
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("failed to open Postgres database: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	case config.DatabaseSQLite:
 | 
			
		||||
		if _, err = os.Stat(filepath.Dir(cfg.DSN)); os.IsNotExist(err) {
 | 
			
		||||
			if err = os.MkdirAll(filepath.Dir(cfg.DSN), 0700); err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("failed to create database base directory: %w", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		gormDb, err = gorm.Open(sqlite.Open(cfg.DSN), &gorm.Config{
 | 
			
		||||
			Logger:                                   NewLogger(cfg.SlowQueryThreshold, cfg.Debug),
 | 
			
		||||
			DisableForeignKeyConstraintWhenMigrating: true,
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("failed to open sqlite database: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
		sqlDB, _ := gormDb.DB()
 | 
			
		||||
		sqlDB.SetMaxOpenConns(1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return gormDb, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SqlRepo is a SQL database repository implementation.
 | 
			
		||||
// Currently, it supports MySQL, SQLite, Microsoft SQL and Postgresql database systems.
 | 
			
		||||
type SqlRepo struct {
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSqlRepository(db *gorm.DB) (*SqlRepo, error) {
 | 
			
		||||
	repo := &SqlRepo{
 | 
			
		||||
		db: db,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := repo.preCheck(); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to initialize database: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := repo.migrate(); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to initialize database: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return repo, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) preCheck() error {
 | 
			
		||||
	// WireGuard Portal v1 database migration table
 | 
			
		||||
	type DatabaseMigrationInfo struct {
 | 
			
		||||
		Version string `gorm:"primaryKey"`
 | 
			
		||||
		Applied time.Time
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// temporarily disable logger as the next request might fail (intentionally)
 | 
			
		||||
	r.db.Logger.LogMode(logger.Silent)
 | 
			
		||||
	defer func() { r.db.Logger.LogMode(logger.Info) }()
 | 
			
		||||
 | 
			
		||||
	lastVersion := DatabaseMigrationInfo{}
 | 
			
		||||
	err := r.db.Order("applied desc, version desc").FirstOrInit(&lastVersion).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil // we probably don't have a V1 database =)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first", lastVersion.Version)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) migrate() error {
 | 
			
		||||
	logrus.Tracef("sysstat migration: %v", r.db.AutoMigrate(&SysStat{}))
 | 
			
		||||
	logrus.Tracef("user migration: %v", r.db.AutoMigrate(&domain.User{}))
 | 
			
		||||
	logrus.Tracef("interface migration: %v", r.db.AutoMigrate(&domain.Interface{}))
 | 
			
		||||
	logrus.Tracef("peer migration: %v", r.db.AutoMigrate(&domain.Peer{}))
 | 
			
		||||
	logrus.Tracef("peer status migration: %v", r.db.AutoMigrate(&domain.PeerStatus{}))
 | 
			
		||||
	logrus.Tracef("interface status migration: %v", r.db.AutoMigrate(&domain.InterfaceStatus{}))
 | 
			
		||||
	logrus.Tracef("audit data migration: %v", r.db.AutoMigrate(&domain.AuditEntry{}))
 | 
			
		||||
 | 
			
		||||
	existingSysStat := SysStat{}
 | 
			
		||||
	r.db.Where("schema_version = ?", SchemaVersion).First(&existingSysStat)
 | 
			
		||||
	if existingSysStat.SchemaVersion == 0 {
 | 
			
		||||
		sysStat := SysStat{
 | 
			
		||||
			MigratedAt:    time.Now(),
 | 
			
		||||
			SchemaVersion: SchemaVersion,
 | 
			
		||||
		}
 | 
			
		||||
		if err := r.db.Create(&sysStat).Error; err != nil {
 | 
			
		||||
			return fmt.Errorf("failed to write sysstat entry for schema version %d: %w", SchemaVersion, err)
 | 
			
		||||
		}
 | 
			
		||||
		logrus.Debugf("sysstat entry for schema version %d written", SchemaVersion)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// region interfaces
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) {
 | 
			
		||||
	var in domain.Interface
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).Preload("Addresses").First(&in, id).Error
 | 
			
		||||
 | 
			
		||||
	if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
 | 
			
		||||
		return nil, domain.ErrNotFound
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &in, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) {
 | 
			
		||||
	in, err := r.GetInterface(ctx, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, fmt.Errorf("failed to load interface: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	peers, err := r.GetInterfacePeers(ctx, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, fmt.Errorf("failed to load peers: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return in, peers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error) {
 | 
			
		||||
	if len(ids) == 0 {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var stats []domain.PeerStatus
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).Where("identifier IN ?", ids).Find(&stats).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return stats, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
 | 
			
		||||
	var interfaces []domain.Interface
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).Preload("Addresses").Find(&interfaces).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return interfaces, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error) {
 | 
			
		||||
	var users []domain.Interface
 | 
			
		||||
 | 
			
		||||
	searchValue := "%" + strings.ToLower(search) + "%"
 | 
			
		||||
	err := r.db.WithContext(ctx).
 | 
			
		||||
		Where("identifier LIKE ?", searchValue).
 | 
			
		||||
		Or("display_name LIKE ?", searchValue).
 | 
			
		||||
		Preload("Addresses").
 | 
			
		||||
		Find(&users).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return users, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error {
 | 
			
		||||
	userInfo := domain.GetUserInfo(ctx)
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		in, err := r.getOrCreateInterface(userInfo, tx, id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err // return any error will roll back
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		in, err = updateFunc(in)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = r.upsertInterface(userInfo, tx, in)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// return nil will commit the whole transaction
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreateInterface(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.Interface, error) {
 | 
			
		||||
	var in domain.Interface
 | 
			
		||||
 | 
			
		||||
	// interfaceDefaults will be applied to newly created interface records
 | 
			
		||||
	interfaceDefaults := domain.Interface{
 | 
			
		||||
		BaseModel: domain.BaseModel{
 | 
			
		||||
			CreatedBy: ui.UserId(),
 | 
			
		||||
			UpdatedBy: ui.UserId(),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
			UpdatedAt: time.Now(),
 | 
			
		||||
		},
 | 
			
		||||
		Identifier: id,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := tx.Attrs(interfaceDefaults).FirstOrCreate(&in, id).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &in, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) upsertInterface(ui *domain.ContextUserInfo, tx *gorm.DB, in *domain.Interface) error {
 | 
			
		||||
	in.UpdatedBy = ui.UserId()
 | 
			
		||||
	in.UpdatedAt = time.Now()
 | 
			
		||||
 | 
			
		||||
	err := tx.Save(in).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = tx.Model(in).Association("Addresses").Replace(in.Addresses)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to update interface addresses: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		err := tx.Where("interface_identifier = ?", id).Delete(&domain.Peer{}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = tx.Delete(&domain.InterfaceStatus{InterfaceId: id}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = tx.Select(clause.Associations).Delete(&domain.Interface{Identifier: id}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error) {
 | 
			
		||||
	var ips []struct {
 | 
			
		||||
		domain.Cidr
 | 
			
		||||
		InterfaceId domain.InterfaceIdentifier `gorm:"column:interface_identifier"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).
 | 
			
		||||
		Table("interface_addresses").
 | 
			
		||||
		Joins("LEFT JOIN cidrs ON interface_addresses.cidr_cidr = cidrs.cidr").
 | 
			
		||||
		Scan(&ips).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := make(map[domain.InterfaceIdentifier][]domain.Cidr)
 | 
			
		||||
	for _, ip := range ips {
 | 
			
		||||
		result[ip.InterfaceId] = append(result[ip.InterfaceId], ip.Cidr)
 | 
			
		||||
	}
 | 
			
		||||
	return result, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// endregion interfaces
 | 
			
		||||
 | 
			
		||||
// region peers
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
 | 
			
		||||
	var peer domain.Peer
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).Preload("Addresses").First(&peer, id).Error
 | 
			
		||||
 | 
			
		||||
	if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
 | 
			
		||||
		return nil, domain.ErrNotFound
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &peer, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) {
 | 
			
		||||
	var peers []domain.Peer
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).Preload("Addresses").Where("interface_identifier = ?", id).Find(&peers).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return peers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error) {
 | 
			
		||||
	var peers []domain.Peer
 | 
			
		||||
 | 
			
		||||
	searchValue := "%" + strings.ToLower(search) + "%"
 | 
			
		||||
	err := r.db.WithContext(ctx).Where("interface_identifier = ?", id).
 | 
			
		||||
		Where("identifier LIKE ?", searchValue).
 | 
			
		||||
		Or("display_name LIKE ?", searchValue).
 | 
			
		||||
		Or("iface_address_str_v LIKE ?", searchValue).
 | 
			
		||||
		Find(&peers).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return peers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
 | 
			
		||||
	var peers []domain.Peer
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).Preload("Addresses").Where("user_identifier = ?", id).Find(&peers).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return peers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error) {
 | 
			
		||||
	var peers []domain.Peer
 | 
			
		||||
 | 
			
		||||
	searchValue := "%" + strings.ToLower(search) + "%"
 | 
			
		||||
	err := r.db.WithContext(ctx).Where("user_identifier = ?", id).
 | 
			
		||||
		Where("identifier LIKE ?", searchValue).
 | 
			
		||||
		Or("display_name LIKE ?", searchValue).
 | 
			
		||||
		Or("iface_address_str_v LIKE ?", searchValue).
 | 
			
		||||
		Find(&peers).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return peers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error {
 | 
			
		||||
	userInfo := domain.GetUserInfo(ctx)
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		peer, err := r.getOrCreatePeer(userInfo, tx, id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err // return any error will roll back
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		peer, err = updateFunc(peer)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = r.upsertPeer(userInfo, tx, peer)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// return nil will commit the whole transaction
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (*domain.Peer, error) {
 | 
			
		||||
	var peer domain.Peer
 | 
			
		||||
 | 
			
		||||
	// interfaceDefaults will be applied to newly created interface records
 | 
			
		||||
	interfaceDefaults := domain.Peer{
 | 
			
		||||
		BaseModel: domain.BaseModel{
 | 
			
		||||
			CreatedBy: ui.UserId(),
 | 
			
		||||
			UpdatedBy: ui.UserId(),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
			UpdatedAt: time.Now(),
 | 
			
		||||
		},
 | 
			
		||||
		Identifier: id,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := tx.Attrs(interfaceDefaults).FirstOrCreate(&peer, id).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &peer, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) upsertPeer(ui *domain.ContextUserInfo, tx *gorm.DB, peer *domain.Peer) error {
 | 
			
		||||
	peer.UpdatedBy = ui.UserId()
 | 
			
		||||
	peer.UpdatedAt = time.Now()
 | 
			
		||||
 | 
			
		||||
	err := tx.Save(peer).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = tx.Model(peer).Association("Addresses").Replace(peer.Interface.Addresses)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to update peer addresses: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		err := tx.Delete(&domain.PeerStatus{PeerId: id}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = tx.Select(clause.Associations).Delete(&domain.Peer{Identifier: id}).Error
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetPeerIps(ctx context.Context) (map[domain.PeerIdentifier][]domain.Cidr, error) {
 | 
			
		||||
	var ips []struct {
 | 
			
		||||
		domain.Cidr
 | 
			
		||||
		PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).
 | 
			
		||||
		Table("peer_addresses").
 | 
			
		||||
		Joins("LEFT JOIN cidrs ON peer_addresses.cidr_cidr = cidrs.cidr").
 | 
			
		||||
		Scan(&ips).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := make(map[domain.PeerIdentifier][]domain.Cidr)
 | 
			
		||||
	for _, ip := range ips {
 | 
			
		||||
		result[ip.PeerId] = append(result[ip.PeerId], ip.Cidr)
 | 
			
		||||
	}
 | 
			
		||||
	return result, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context) (map[domain.Cidr][]domain.Cidr, error) {
 | 
			
		||||
	var peerIps []struct {
 | 
			
		||||
		domain.Cidr
 | 
			
		||||
		PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).
 | 
			
		||||
		Table("peer_addresses").
 | 
			
		||||
		Joins("LEFT JOIN cidrs ON peer_addresses.cidr_cidr = cidrs.cidr").
 | 
			
		||||
		Scan(&peerIps).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to fetch peer IP's: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var interfaceIps []struct {
 | 
			
		||||
		domain.Cidr
 | 
			
		||||
		InterfaceId domain.InterfaceIdentifier `gorm:"column:interface_identifier"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = r.db.WithContext(ctx).
 | 
			
		||||
		Table("interface_addresses").
 | 
			
		||||
		Joins("LEFT JOIN cidrs ON interface_addresses.cidr_cidr = cidrs.cidr").
 | 
			
		||||
		Scan(&interfaceIps).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to fetch interface IP's: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := make(map[domain.Cidr][]domain.Cidr)
 | 
			
		||||
	for _, ip := range interfaceIps {
 | 
			
		||||
		networkAddr := ip.Cidr.NetworkAddr()
 | 
			
		||||
		result[networkAddr] = append(result[networkAddr], ip.Cidr)
 | 
			
		||||
	}
 | 
			
		||||
	for _, ip := range peerIps {
 | 
			
		||||
		networkAddr := ip.Cidr.NetworkAddr()
 | 
			
		||||
		result[networkAddr] = append(result[networkAddr], ip.Cidr)
 | 
			
		||||
	}
 | 
			
		||||
	return result, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// endregion peers
 | 
			
		||||
 | 
			
		||||
// region users
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
 | 
			
		||||
	var user domain.User
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).First(&user, id).Error
 | 
			
		||||
 | 
			
		||||
	if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
 | 
			
		||||
		return nil, domain.ErrNotFound
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &user, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetAllUsers(ctx context.Context) ([]domain.User, error) {
 | 
			
		||||
	var users []domain.User
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).Find(&users).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return users, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) FindUsers(ctx context.Context, search string) ([]domain.User, error) {
 | 
			
		||||
	var users []domain.User
 | 
			
		||||
 | 
			
		||||
	searchValue := "%" + strings.ToLower(search) + "%"
 | 
			
		||||
	err := r.db.WithContext(ctx).
 | 
			
		||||
		Where("identifier LIKE ?", searchValue).
 | 
			
		||||
		Or("firstname LIKE ?", searchValue).
 | 
			
		||||
		Or("lastname LIKE ?", searchValue).
 | 
			
		||||
		Or("email LIKE ?", searchValue).
 | 
			
		||||
		Find(&users).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return users, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error {
 | 
			
		||||
	userInfo := domain.GetUserInfo(ctx)
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		user, err := r.getOrCreateUser(userInfo, tx, id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err // return any error will roll back
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		user, err = updateFunc(user)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = r.upsertUser(userInfo, tx, user)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// return nil will commit the whole transaction
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Delete(&domain.User{}, id).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (*domain.User, error) {
 | 
			
		||||
	var user domain.User
 | 
			
		||||
 | 
			
		||||
	// userDefaults will be applied to newly created user records
 | 
			
		||||
	userDefaults := domain.User{
 | 
			
		||||
		BaseModel: domain.BaseModel{
 | 
			
		||||
			CreatedBy: ui.UserId(),
 | 
			
		||||
			UpdatedBy: ui.UserId(),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
			UpdatedAt: time.Now(),
 | 
			
		||||
		},
 | 
			
		||||
		Identifier: id,
 | 
			
		||||
		Source:     domain.UserSourceDatabase,
 | 
			
		||||
		IsAdmin:    false,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := tx.Attrs(userDefaults).FirstOrCreate(&user, id).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &user, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *domain.User) error {
 | 
			
		||||
	user.UpdatedBy = ui.UserId()
 | 
			
		||||
	user.UpdatedAt = time.Now()
 | 
			
		||||
 | 
			
		||||
	err := tx.Save(user).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// endregion users
 | 
			
		||||
 | 
			
		||||
// region statistics
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		in, err := r.getOrCreateInterfaceStatus(tx, id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err // return any error will roll back
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		in, err = updateFunc(in)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = r.upsertInterfaceStatus(tx, in)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// return nil will commit the whole transaction
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.InterfaceStatus, error) {
 | 
			
		||||
	var in domain.InterfaceStatus
 | 
			
		||||
 | 
			
		||||
	// defaults will be applied to newly created record
 | 
			
		||||
	defaults := domain.InterfaceStatus{
 | 
			
		||||
		InterfaceId: id,
 | 
			
		||||
		UpdatedAt:   time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := tx.Attrs(defaults).FirstOrCreate(&in, id).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &in, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) upsertInterfaceStatus(tx *gorm.DB, in *domain.InterfaceStatus) error {
 | 
			
		||||
	err := tx.Save(in).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		in, err := r.getOrCreatePeerStatus(tx, id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err // return any error will roll back
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		in, err = updateFunc(in)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = r.upsertPeerStatus(tx, in)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// return nil will commit the whole transaction
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreatePeerStatus(tx *gorm.DB, id domain.PeerIdentifier) (*domain.PeerStatus, error) {
 | 
			
		||||
	var in domain.PeerStatus
 | 
			
		||||
 | 
			
		||||
	// defaults will be applied to newly created record
 | 
			
		||||
	defaults := domain.PeerStatus{
 | 
			
		||||
		PeerId:    id,
 | 
			
		||||
		UpdatedAt: time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := tx.Attrs(defaults).FirstOrCreate(&in, id).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &in, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) upsertPeerStatus(tx *gorm.DB, in *domain.PeerStatus) error {
 | 
			
		||||
	err := tx.Save(in).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// endregion statistics
 | 
			
		||||
 | 
			
		||||
// region audit
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Save(entry).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// endregion audit
 | 
			
		||||
							
								
								
									
										43
									
								
								internal/adapters/database_integration_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								internal/adapters/database_integration_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,43 @@
 | 
			
		||||
//go:build integration
 | 
			
		||||
 | 
			
		||||
package adapters
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/glebarez/sqlite"
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func tempSqliteDb(t *testing.T) *gorm.DB {
 | 
			
		||||
 | 
			
		||||
	// github.com/mattn/go-sqlite3
 | 
			
		||||
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	return db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Test_sqlRepo_migrate(t *testing.T) {
 | 
			
		||||
	db := tempSqliteDb(t)
 | 
			
		||||
 | 
			
		||||
	r := SqlRepo{db: db}
 | 
			
		||||
 | 
			
		||||
	err := r.migrate()
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	// check result
 | 
			
		||||
	var sqlStatement []sql.NullString
 | 
			
		||||
	db.Raw("SELECT sql FROM sqlite_master").Find(&sqlStatement)
 | 
			
		||||
	fmt.Println("Table Schemas:")
 | 
			
		||||
	for _, stm := range sqlStatement {
 | 
			
		||||
		if stm.Valid {
 | 
			
		||||
			fmt.Println(stm.String)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										54
									
								
								internal/adapters/filesystem.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								internal/adapters/filesystem.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,54 @@
 | 
			
		||||
package adapters
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"io"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type FilesystemRepo struct {
 | 
			
		||||
	basePath string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewFileSystemRepository(basePath string) (*FilesystemRepo, error) {
 | 
			
		||||
	if basePath == "" {
 | 
			
		||||
		return nil, nil // no path, return empty repository
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r := &FilesystemRepo{basePath: basePath}
 | 
			
		||||
 | 
			
		||||
	if err := os.MkdirAll(r.basePath, os.ModePerm); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to create base directory %s: %w", basePath, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return r, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *FilesystemRepo) WriteFile(path string, contents io.Reader) error {
 | 
			
		||||
	filePath := filepath.Join(r.basePath, path)
 | 
			
		||||
	parentDirectory := filepath.Dir(filePath)
 | 
			
		||||
 | 
			
		||||
	if err := os.MkdirAll(parentDirectory, os.ModePerm); err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to create parent directory %s: %w", parentDirectory, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, os.ModePerm)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to open file %s: %w", file.Name(), err)
 | 
			
		||||
	}
 | 
			
		||||
	defer func(file *os.File) {
 | 
			
		||||
		if err := file.Close(); err != nil {
 | 
			
		||||
			logrus.Errorf("failed to close file %s: %v", file.Name(), err)
 | 
			
		||||
		}
 | 
			
		||||
	}(file)
 | 
			
		||||
 | 
			
		||||
	_, err = io.Copy(file, contents)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to write file contents: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										138
									
								
								internal/adapters/mailer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								internal/adapters/mailer.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,138 @@
 | 
			
		||||
package adapters
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/config"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
	mail "github.com/xhit/go-simple-mail/v2"
 | 
			
		||||
	"io"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MailRepo struct {
 | 
			
		||||
	cfg *config.MailConfig
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSmtpMailRepo(cfg config.MailConfig) MailRepo {
 | 
			
		||||
	return MailRepo{cfg: &cfg}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Send sends a mail.
 | 
			
		||||
func (r MailRepo) Send(_ context.Context, subject, body string, to []string, options *domain.MailOptions) error {
 | 
			
		||||
	if options == nil {
 | 
			
		||||
		options = &domain.MailOptions{}
 | 
			
		||||
	}
 | 
			
		||||
	r.setDefaultOptions(r.cfg.From, options)
 | 
			
		||||
 | 
			
		||||
	if len(to) == 0 {
 | 
			
		||||
		return errors.New("missing email recipient")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	uniqueTo := internal.UniqueStringSlice(to)
 | 
			
		||||
	email := mail.NewMSG()
 | 
			
		||||
	email.SetFrom(r.cfg.From).
 | 
			
		||||
		AddTo(uniqueTo...).
 | 
			
		||||
		SetReplyTo(options.ReplyTo).
 | 
			
		||||
		SetSubject(subject).
 | 
			
		||||
		SetBody(mail.TextPlain, body)
 | 
			
		||||
 | 
			
		||||
	if len(options.Cc) > 0 {
 | 
			
		||||
		// the underlying mail library does not allow the same address to appear in TO and CC... so filter entries that are already included
 | 
			
		||||
		// in the TO addresses
 | 
			
		||||
		cc := RemoveDuplicates(internal.UniqueStringSlice(options.Cc), uniqueTo)
 | 
			
		||||
		email.AddCc(cc...)
 | 
			
		||||
	}
 | 
			
		||||
	if len(options.Bcc) > 0 {
 | 
			
		||||
		// the underlying mail library does not allow the same address to appear in TO or CC and BCC... so filter entries that are already
 | 
			
		||||
		// included in the TO and CC addresses
 | 
			
		||||
		bcc := RemoveDuplicates(internal.UniqueStringSlice(options.Bcc), uniqueTo)
 | 
			
		||||
		bcc = RemoveDuplicates(bcc, options.Cc)
 | 
			
		||||
 | 
			
		||||
		email.AddCc(internal.UniqueStringSlice(options.Bcc)...)
 | 
			
		||||
	}
 | 
			
		||||
	if options.HtmlBody != "" {
 | 
			
		||||
		email.AddAlternative(mail.TextHTML, options.HtmlBody)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, attachment := range options.Attachments {
 | 
			
		||||
		attachmentData, err := io.ReadAll(attachment.Data)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("failed to read attachment data for %s: %w", attachment.Name, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if attachment.Embedded {
 | 
			
		||||
			email.AddInlineData(attachmentData, attachment.Name, attachment.ContentType)
 | 
			
		||||
		} else {
 | 
			
		||||
			email.AddAttachmentData(attachmentData, attachment.Name, attachment.ContentType)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Call Send and pass the client
 | 
			
		||||
	srv := r.getMailServer()
 | 
			
		||||
	client, err := srv.Connect()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to connect to SMTP server: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = email.Send(client)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to send email: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r MailRepo) setDefaultOptions(sender string, options *domain.MailOptions) {
 | 
			
		||||
	if options.ReplyTo == "" {
 | 
			
		||||
		options.ReplyTo = sender
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r MailRepo) getMailServer() *mail.SMTPServer {
 | 
			
		||||
	srv := mail.NewSMTPClient()
 | 
			
		||||
 | 
			
		||||
	srv.ConnectTimeout = 30 * time.Second
 | 
			
		||||
	srv.SendTimeout = 30 * time.Second
 | 
			
		||||
	srv.Host = r.cfg.Host
 | 
			
		||||
	srv.Port = r.cfg.Port
 | 
			
		||||
	srv.Username = r.cfg.Username
 | 
			
		||||
	srv.Password = r.cfg.Password
 | 
			
		||||
 | 
			
		||||
	switch r.cfg.Encryption {
 | 
			
		||||
	case config.MailEncryptionTLS:
 | 
			
		||||
		srv.Encryption = mail.EncryptionSSLTLS
 | 
			
		||||
	case config.MailEncryptionStartTLS:
 | 
			
		||||
		srv.Encryption = mail.EncryptionSTARTTLS
 | 
			
		||||
	default: // MailEncryptionNone
 | 
			
		||||
		srv.Encryption = mail.EncryptionNone
 | 
			
		||||
	}
 | 
			
		||||
	srv.TLSConfig = &tls.Config{ServerName: srv.Host, InsecureSkipVerify: !r.cfg.CertValidation}
 | 
			
		||||
	switch r.cfg.AuthType {
 | 
			
		||||
	case config.MailAuthPlain:
 | 
			
		||||
		srv.Authentication = mail.AuthPlain
 | 
			
		||||
	case config.MailAuthLogin:
 | 
			
		||||
		srv.Authentication = mail.AuthLogin
 | 
			
		||||
	case config.MailAuthCramMD5:
 | 
			
		||||
		srv.Authentication = mail.AuthCRAMMD5
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return srv
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveDuplicates removes addresses from the given string slice which are contained in the remove slice.
 | 
			
		||||
func RemoveDuplicates(slice []string, remove []string) []string {
 | 
			
		||||
	uniqueSlice := make([]string, 0, len(slice))
 | 
			
		||||
 | 
			
		||||
	for _, i := range remove {
 | 
			
		||||
		for _, j := range slice {
 | 
			
		||||
			if i != j {
 | 
			
		||||
				uniqueSlice = append(uniqueSlice, j)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return uniqueSlice
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										99
									
								
								internal/adapters/wgquick.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								internal/adapters/wgquick.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,99 @@
 | 
			
		||||
package adapters
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks.
 | 
			
		||||
type WgQuickRepo struct {
 | 
			
		||||
	shellCmd              string
 | 
			
		||||
	resolvConfIfacePrefix string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWgQuickRepo() *WgQuickRepo {
 | 
			
		||||
	return &WgQuickRepo{
 | 
			
		||||
		shellCmd:              "bash",
 | 
			
		||||
		resolvConfIfacePrefix: "tun.",
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgQuickRepo) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
 | 
			
		||||
	if hookCmd == "" {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := r.exec(hookCmd, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to exec hook: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
 | 
			
		||||
	if dnsStr == "" && dnsSearchStr == "" {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dnsServers := internal.SliceString(dnsStr)
 | 
			
		||||
	dnsSearchDomains := internal.SliceString(dnsSearchStr)
 | 
			
		||||
 | 
			
		||||
	dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
 | 
			
		||||
	dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
 | 
			
		||||
 | 
			
		||||
	for _, dnsServer := range dnsServers {
 | 
			
		||||
		dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
 | 
			
		||||
	}
 | 
			
		||||
	for _, searchDomain := range dnsSearchDomains {
 | 
			
		||||
		dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := r.exec(dnsCommand, id, dnsCommandInput...)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to set dns settings: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgQuickRepo) UnsetDNS(id domain.InterfaceIdentifier) error {
 | 
			
		||||
	dnsCommand := "resolvconf -d %resPref%i -f"
 | 
			
		||||
 | 
			
		||||
	err := r.exec(dnsCommand, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to unset dns settings: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgQuickRepo) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
 | 
			
		||||
	command = strings.ReplaceAll(command, "%resPref", r.resolvConfIfacePrefix)
 | 
			
		||||
	return strings.ReplaceAll(command, "%i", string(interfaceId))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgQuickRepo) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
 | 
			
		||||
	commandWithInterfaceName := r.replaceCommandPlaceHolders(command, interfaceId)
 | 
			
		||||
	cmd := exec.Command(r.shellCmd, "-ce", commandWithInterfaceName)
 | 
			
		||||
	if len(stdin) > 0 {
 | 
			
		||||
		b := &bytes.Buffer{}
 | 
			
		||||
		for _, ln := range stdin {
 | 
			
		||||
			if _, err := fmt.Fprint(b, ln); err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		cmd.Stdin = b
 | 
			
		||||
	}
 | 
			
		||||
	out, err := cmd.CombinedOutput() // execute and wait for output
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
 | 
			
		||||
	}
 | 
			
		||||
	logrus.Tracef("executed shell command %s, with output: %s", commandWithInterfaceName, string(out))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										430
									
								
								internal/adapters/wireguard.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										430
									
								
								internal/adapters/wireguard.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,430 @@
 | 
			
		||||
package adapters
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/lowlevel"
 | 
			
		||||
	"github.com/vishvananda/netlink"
 | 
			
		||||
	"golang.zx2c4.com/wireguard/wgctrl"
 | 
			
		||||
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 | 
			
		||||
	"os"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// WgRepo implements all low-level WireGuard interactions.
 | 
			
		||||
type WgRepo struct {
 | 
			
		||||
	wg lowlevel.WireGuardClient
 | 
			
		||||
	nl lowlevel.NetlinkClient
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWireGuardRepository() *WgRepo {
 | 
			
		||||
	wg, err := wgctrl.New()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic("failed to init wgctrl: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nl := &lowlevel.NetlinkManager{}
 | 
			
		||||
 | 
			
		||||
	repo := &WgRepo{
 | 
			
		||||
		wg: wg,
 | 
			
		||||
		nl: nl,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return repo
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
 | 
			
		||||
	devices, err := r.wg.Devices()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("device list error: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	interfaces := make([]domain.PhysicalInterface, 0, len(devices))
 | 
			
		||||
	for _, device := range devices {
 | 
			
		||||
		interfaceModel, err := r.convertWireGuardInterface(device)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("interface convert failed for %s: %w", device.Name, err)
 | 
			
		||||
		}
 | 
			
		||||
		interfaces = append(interfaces, interfaceModel)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return interfaces, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
 | 
			
		||||
	return r.getInterface(id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) {
 | 
			
		||||
	device, err := r.wg.Device(string(deviceId))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("device error: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	peers := make([]domain.PhysicalPeer, 0, len(device.Peers))
 | 
			
		||||
	for _, peer := range device.Peers {
 | 
			
		||||
		peerModel, err := r.convertWireGuardPeer(&peer)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("peer convert failed for %v: %w", peer.PublicKey, err)
 | 
			
		||||
		}
 | 
			
		||||
		peers = append(peers, peerModel)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return peers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
 | 
			
		||||
	return r.getPeer(deviceId, id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) {
 | 
			
		||||
	// read data from wgctrl interface
 | 
			
		||||
 | 
			
		||||
	iface := domain.PhysicalInterface{
 | 
			
		||||
		Identifier: domain.InterfaceIdentifier(device.Name),
 | 
			
		||||
		KeyPair: domain.KeyPair{
 | 
			
		||||
			PrivateKey: device.PrivateKey.String(),
 | 
			
		||||
			PublicKey:  device.PublicKey.String(),
 | 
			
		||||
		},
 | 
			
		||||
		ListenPort:    device.ListenPort,
 | 
			
		||||
		Addresses:     nil,
 | 
			
		||||
		Mtu:           0,
 | 
			
		||||
		FirewallMark:  int32(device.FirewallMark),
 | 
			
		||||
		DeviceUp:      false,
 | 
			
		||||
		ImportSource:  "wgctrl",
 | 
			
		||||
		DeviceType:    device.Type.String(),
 | 
			
		||||
		BytesUpload:   0,
 | 
			
		||||
		BytesDownload: 0,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// read data from netlink interface
 | 
			
		||||
 | 
			
		||||
	lowLevelInterface, err := r.nl.LinkByName(device.Name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return domain.PhysicalInterface{}, fmt.Errorf("netlink error for %s: %w", device.Name, err)
 | 
			
		||||
	}
 | 
			
		||||
	ipAddresses, err := r.nl.AddrList(lowLevelInterface)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return domain.PhysicalInterface{}, fmt.Errorf("ip read error for %s: %w", device.Name, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, addr := range ipAddresses {
 | 
			
		||||
		iface.Addresses = append(iface.Addresses, domain.CidrFromNetlinkAddr(addr))
 | 
			
		||||
	}
 | 
			
		||||
	iface.Mtu = lowLevelInterface.Attrs().MTU
 | 
			
		||||
	iface.DeviceUp = lowLevelInterface.Attrs().OperState == netlink.OperUnknown // wg only supports unknown
 | 
			
		||||
	if stats := lowLevelInterface.Attrs().Statistics; stats != nil {
 | 
			
		||||
		iface.BytesUpload = stats.TxBytes
 | 
			
		||||
		iface.BytesDownload = stats.RxBytes
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return iface, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) {
 | 
			
		||||
	peerModel := domain.PhysicalPeer{
 | 
			
		||||
		Identifier: domain.PeerIdentifier(peer.PublicKey.String()),
 | 
			
		||||
		Endpoint:   "",
 | 
			
		||||
		AllowedIPs: nil,
 | 
			
		||||
		KeyPair: domain.KeyPair{
 | 
			
		||||
			PublicKey: peer.PublicKey.String(),
 | 
			
		||||
		},
 | 
			
		||||
		PresharedKey:        "",
 | 
			
		||||
		PersistentKeepalive: int(peer.PersistentKeepaliveInterval.Seconds()),
 | 
			
		||||
		LastHandshake:       peer.LastHandshakeTime,
 | 
			
		||||
		ProtocolVersion:     peer.ProtocolVersion,
 | 
			
		||||
		BytesUpload:         uint64(peer.ReceiveBytes),
 | 
			
		||||
		BytesDownload:       uint64(peer.TransmitBytes),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, addr := range peer.AllowedIPs {
 | 
			
		||||
		peerModel.AllowedIPs = append(peerModel.AllowedIPs, domain.CidrFromIpNet(addr))
 | 
			
		||||
	}
 | 
			
		||||
	if peer.Endpoint != nil {
 | 
			
		||||
		peerModel.Endpoint = peer.Endpoint.String()
 | 
			
		||||
	}
 | 
			
		||||
	if peer.PresharedKey != (wgtypes.Key{}) {
 | 
			
		||||
		peerModel.PresharedKey = domain.PreSharedKey(peer.PresharedKey.String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return peerModel, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
 | 
			
		||||
	physicalInterface, err := r.getOrCreateInterface(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if updateFunc != nil {
 | 
			
		||||
		physicalInterface, err = updateFunc(physicalInterface)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := r.updateLowLevelInterface(physicalInterface); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err := r.updateWireGuardInterface(physicalInterface); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
 | 
			
		||||
	device, err := r.getInterface(id)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return device, nil
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil && !errors.Is(err, os.ErrNotExist) {
 | 
			
		||||
		return nil, fmt.Errorf("device error: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// create new device
 | 
			
		||||
	if err := r.createLowLevelInterface(id); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	device, err = r.getInterface(id)
 | 
			
		||||
	return device, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
 | 
			
		||||
	device, err := r.wg.Device(string(id))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pi, err := r.convertWireGuardInterface(device)
 | 
			
		||||
	return &pi, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) createLowLevelInterface(id domain.InterfaceIdentifier) error {
 | 
			
		||||
	link := &netlink.GenericLink{
 | 
			
		||||
		LinkAttrs: netlink.LinkAttrs{
 | 
			
		||||
			Name: string(id),
 | 
			
		||||
		},
 | 
			
		||||
		LinkType: "wireguard",
 | 
			
		||||
	}
 | 
			
		||||
	err := r.nl.LinkAdd(link)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("link add failed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) updateLowLevelInterface(pi *domain.PhysicalInterface) error {
 | 
			
		||||
	link, err := r.nl.LinkByName(string(pi.Identifier))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if pi.Mtu != 0 {
 | 
			
		||||
		if err := r.nl.LinkSetMTU(link, pi.Mtu); err != nil {
 | 
			
		||||
			return fmt.Errorf("mtu error: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, addr := range pi.Addresses {
 | 
			
		||||
		err := r.nl.AddrReplace(link, addr.NetlinkAddr())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("failed to set ip %s: %w", addr.String(), err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Remove unwanted IP addresses
 | 
			
		||||
	rawAddresses, err := r.nl.AddrList(link)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to fetch interface ips: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	for _, rawAddr := range rawAddresses {
 | 
			
		||||
		netlinkAddr := domain.CidrFromNetlinkAddr(rawAddr)
 | 
			
		||||
		remove := true
 | 
			
		||||
		for _, addr := range pi.Addresses {
 | 
			
		||||
			if addr == netlinkAddr {
 | 
			
		||||
				remove = false
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !remove {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err := r.nl.AddrDel(link, &rawAddr)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("failed to remove deprecated ip %s: %w", netlinkAddr.String(), err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Update link state
 | 
			
		||||
	if pi.DeviceUp {
 | 
			
		||||
		if err := r.nl.LinkSetUp(link); err != nil {
 | 
			
		||||
			return fmt.Errorf("failed to bring up device: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		if err := r.nl.LinkSetDown(link); err != nil {
 | 
			
		||||
			return fmt.Errorf("failed to bring down device: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
 | 
			
		||||
	pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var fwMark *int
 | 
			
		||||
	if pi.FirewallMark != 0 {
 | 
			
		||||
		*fwMark = int(pi.FirewallMark)
 | 
			
		||||
	}
 | 
			
		||||
	err = r.wg.ConfigureDevice(string(pi.Identifier), wgtypes.Config{
 | 
			
		||||
		PrivateKey:   &pKey,
 | 
			
		||||
		ListenPort:   &pi.ListenPort,
 | 
			
		||||
		FirewallMark: fwMark,
 | 
			
		||||
		ReplacePeers: false,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
 | 
			
		||||
	if err := r.deleteLowLevelInterface(id); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
 | 
			
		||||
	link, err := r.nl.LinkByName(string(id))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		var linkNotFoundError netlink.LinkNotFoundError
 | 
			
		||||
		if errors.As(err, &linkNotFoundError) {
 | 
			
		||||
			return nil // ignore not found error
 | 
			
		||||
		}
 | 
			
		||||
		return fmt.Errorf("unable to find low level interface: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = r.nl.LinkDel(link)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to delete low level interface: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error {
 | 
			
		||||
	physicalPeer, err := r.getOrCreatePeer(deviceId, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	physicalPeer, err = updateFunc(physicalPeer)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := r.updatePeer(deviceId, physicalPeer); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
 | 
			
		||||
	peer, err := r.getPeer(deviceId, id)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return peer, nil
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil && !errors.Is(err, os.ErrNotExist) {
 | 
			
		||||
		return nil, fmt.Errorf("peer error: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// create new peer
 | 
			
		||||
	err = r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{Peers: []wgtypes.PeerConfig{{
 | 
			
		||||
		PublicKey: id.ToPublicKey(),
 | 
			
		||||
	}}})
 | 
			
		||||
 | 
			
		||||
	peer, err = r.getPeer(deviceId, id)
 | 
			
		||||
	return peer, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
 | 
			
		||||
	if !id.IsPublicKey() {
 | 
			
		||||
		return nil, errors.New("invalid public key")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	device, err := r.wg.Device(string(deviceId))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	publicKey := id.ToPublicKey()
 | 
			
		||||
	for _, peer := range device.Peers {
 | 
			
		||||
		if peer.PublicKey != publicKey {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		peerModel, err := r.convertWireGuardPeer(&peer)
 | 
			
		||||
		return &peerModel, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, os.ErrNotExist
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error {
 | 
			
		||||
	cfg := wgtypes.PeerConfig{
 | 
			
		||||
		PublicKey:                   pp.GetPublicKey(),
 | 
			
		||||
		Remove:                      false,
 | 
			
		||||
		UpdateOnly:                  true,
 | 
			
		||||
		PresharedKey:                pp.GetPresharedKey(),
 | 
			
		||||
		Endpoint:                    pp.GetEndpointAddress(),
 | 
			
		||||
		PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(),
 | 
			
		||||
		ReplaceAllowedIPs:           true,
 | 
			
		||||
		AllowedIPs:                  pp.GetAllowedIPs(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
 | 
			
		||||
	if !id.IsPublicKey() {
 | 
			
		||||
		return errors.New("invalid public key")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := r.deletePeer(deviceId, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *WgRepo) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
 | 
			
		||||
	cfg := wgtypes.PeerConfig{
 | 
			
		||||
		PublicKey: id.ToPublicKey(),
 | 
			
		||||
		Remove:    true,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										121
									
								
								internal/adapters/wireguard_integration_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								internal/adapters/wireguard_integration_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,121 @@
 | 
			
		||||
//go:build integration
 | 
			
		||||
 | 
			
		||||
package adapters
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// setup WireGuard manager with no linked store
 | 
			
		||||
func setup(t *testing.T) *WgRepo {
 | 
			
		||||
	if getProcessOwner() != "root" {
 | 
			
		||||
		t.Fatalf("this tests need to be executed as root user")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	repo := NewWireGuardRepository()
 | 
			
		||||
 | 
			
		||||
	return repo
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getProcessOwner() string {
 | 
			
		||||
	stdout, err := exec.Command("ps", "-o", "user=", "-p", strconv.Itoa(os.Getpid())).Output()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println(err)
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
	}
 | 
			
		||||
	return strings.TrimSpace(string(stdout))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Test_wgRepository_GetInterfaces(t *testing.T) {
 | 
			
		||||
	mgr := setup(t)
 | 
			
		||||
 | 
			
		||||
	interfaceName := domain.InterfaceIdentifier("wg_test_001")
 | 
			
		||||
	defer mgr.DeleteInterface(context.Background(), interfaceName)
 | 
			
		||||
	err := mgr.SaveInterface(context.Background(), interfaceName, nil)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	interfaceName2 := domain.InterfaceIdentifier("wg_test_002")
 | 
			
		||||
	defer mgr.DeleteInterface(context.Background(), interfaceName2)
 | 
			
		||||
	err = mgr.SaveInterface(context.Background(), interfaceName2, nil)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	interfaces, err := mgr.GetInterfaces(context.Background())
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Len(t, interfaces, 2)
 | 
			
		||||
	for _, iface := range interfaces {
 | 
			
		||||
		assert.True(t, iface.Identifier == interfaceName || iface.Identifier == interfaceName2)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestWireGuardCreateInterface(t *testing.T) {
 | 
			
		||||
	mgr := setup(t)
 | 
			
		||||
 | 
			
		||||
	interfaceName := domain.InterfaceIdentifier("wg_test_001")
 | 
			
		||||
	ipAddress := "10.11.12.13"
 | 
			
		||||
	ipV6Address := "1337:d34d:b33f::2"
 | 
			
		||||
	defer mgr.DeleteInterface(context.Background(), interfaceName)
 | 
			
		||||
 | 
			
		||||
	err := mgr.SaveInterface(context.Background(), interfaceName, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
 | 
			
		||||
		pi.Addresses = []domain.Cidr{
 | 
			
		||||
			domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}),
 | 
			
		||||
			domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}),
 | 
			
		||||
		}
 | 
			
		||||
		return pi, nil
 | 
			
		||||
	})
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	// Validate that the interface has been created
 | 
			
		||||
	cmd := exec.Command("ip", "addr")
 | 
			
		||||
	out, err := cmd.CombinedOutput()
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Contains(t, string(out), interfaceName)
 | 
			
		||||
	assert.Contains(t, string(out), ipAddress)
 | 
			
		||||
	assert.Contains(t, string(out), ipV6Address)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestWireGuardUpdateInterface(t *testing.T) {
 | 
			
		||||
	mgr := setup(t)
 | 
			
		||||
 | 
			
		||||
	interfaceName := domain.InterfaceIdentifier("wg_test_001")
 | 
			
		||||
	defer mgr.DeleteInterface(context.Background(), interfaceName)
 | 
			
		||||
 | 
			
		||||
	err := mgr.SaveInterface(context.Background(), interfaceName, nil)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	cmd := exec.Command("ip", "addr")
 | 
			
		||||
	out, err := cmd.CombinedOutput()
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	require.Contains(t, string(out), interfaceName)
 | 
			
		||||
 | 
			
		||||
	ipAddress := "10.11.12.13"
 | 
			
		||||
	ipV6Address := "1337:d34d:b33f::2"
 | 
			
		||||
	err = mgr.SaveInterface(context.Background(), interfaceName, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
 | 
			
		||||
		pi.Addresses = []domain.Cidr{
 | 
			
		||||
			domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}),
 | 
			
		||||
			domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}),
 | 
			
		||||
		}
 | 
			
		||||
		return pi, nil
 | 
			
		||||
	})
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	// Validate that the interface has been updated
 | 
			
		||||
	cmd = exec.Command("ip", "addr")
 | 
			
		||||
	out, err = cmd.CombinedOutput()
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Contains(t, string(out), interfaceName)
 | 
			
		||||
	assert.Contains(t, string(out), ipAddress)
 | 
			
		||||
	assert.Contains(t, string(out), ipV6Address)
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user