mirror of
				https://github.com/h44z/wg-portal.git
				synced 2025-11-03 23:56:18 +00:00 
			
		
		
		
	fix change of peer identifier (public key) (#265)
This commit is contained in:
		@@ -4,15 +4,16 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"gorm.io/gorm/clause"
 | 
			
		||||
	"gorm.io/gorm/logger"
 | 
			
		||||
	"gorm.io/gorm/utils"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"gorm.io/gorm/clause"
 | 
			
		||||
	"gorm.io/gorm/logger"
 | 
			
		||||
	"gorm.io/gorm/utils"
 | 
			
		||||
 | 
			
		||||
	"github.com/glebarez/sqlite"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/config"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
@@ -204,7 +205,8 @@ func (r *SqlRepo) preCheck() error {
 | 
			
		||||
		return nil // we probably don't have a V1 database =)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first", lastVersion.Version)
 | 
			
		||||
	return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first",
 | 
			
		||||
		lastVersion.Version)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) migrate() error {
 | 
			
		||||
@@ -249,7 +251,11 @@ func (r *SqlRepo) GetInterface(ctx context.Context, id domain.InterfaceIdentifie
 | 
			
		||||
	return &in, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) {
 | 
			
		||||
func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
 | 
			
		||||
	*domain.Interface,
 | 
			
		||||
	[]domain.Peer,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	in, err := r.GetInterface(ctx, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, fmt.Errorf("failed to load interface: %w", err)
 | 
			
		||||
@@ -305,7 +311,11 @@ func (r *SqlRepo) FindInterfaces(ctx context.Context, search string) ([]domain.I
 | 
			
		||||
	return users, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error {
 | 
			
		||||
func (r *SqlRepo) SaveInterface(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	id domain.InterfaceIdentifier,
 | 
			
		||||
	updateFunc func(in *domain.Interface) (*domain.Interface, error),
 | 
			
		||||
) error {
 | 
			
		||||
	userInfo := domain.GetUserInfo(ctx)
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		in, err := r.getOrCreateInterface(userInfo, tx, id)
 | 
			
		||||
@@ -333,7 +343,11 @@ func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifi
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreateInterface(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.Interface, error) {
 | 
			
		||||
func (r *SqlRepo) getOrCreateInterface(
 | 
			
		||||
	ui *domain.ContextUserInfo,
 | 
			
		||||
	tx *gorm.DB,
 | 
			
		||||
	id domain.InterfaceIdentifier,
 | 
			
		||||
) (*domain.Interface, error) {
 | 
			
		||||
	var in domain.Interface
 | 
			
		||||
 | 
			
		||||
	// interfaceDefaults will be applied to newly created interface records
 | 
			
		||||
@@ -449,7 +463,10 @@ func (r *SqlRepo) GetInterfacePeers(ctx context.Context, id domain.InterfaceIden
 | 
			
		||||
	return peers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error) {
 | 
			
		||||
func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) (
 | 
			
		||||
	[]domain.Peer,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	var peers []domain.Peer
 | 
			
		||||
 | 
			
		||||
	searchValue := "%" + strings.ToLower(search) + "%"
 | 
			
		||||
@@ -492,7 +509,11 @@ func (r *SqlRepo) FindUserPeers(ctx context.Context, id domain.UserIdentifier, s
 | 
			
		||||
	return peers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error {
 | 
			
		||||
func (r *SqlRepo) SavePeer(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	id domain.PeerIdentifier,
 | 
			
		||||
	updateFunc func(in *domain.Peer) (*domain.Peer, error),
 | 
			
		||||
) error {
 | 
			
		||||
	userInfo := domain.GetUserInfo(ctx)
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		peer, err := r.getOrCreatePeer(userInfo, tx, id)
 | 
			
		||||
@@ -520,7 +541,10 @@ func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, update
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (*domain.Peer, error) {
 | 
			
		||||
func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (
 | 
			
		||||
	*domain.Peer,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	var peer domain.Peer
 | 
			
		||||
 | 
			
		||||
	// interfaceDefaults will be applied to newly created interface records
 | 
			
		||||
@@ -601,7 +625,10 @@ func (r *SqlRepo) GetPeerIps(ctx context.Context) (map[domain.PeerIdentifier][]d
 | 
			
		||||
	return result, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) {
 | 
			
		||||
func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (
 | 
			
		||||
	map[domain.Cidr][]domain.Cidr,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	var peerIps []struct {
 | 
			
		||||
		domain.Cidr
 | 
			
		||||
		PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"`
 | 
			
		||||
@@ -699,7 +726,11 @@ func (r *SqlRepo) FindUsers(ctx context.Context, search string) ([]domain.User,
 | 
			
		||||
	return users, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error {
 | 
			
		||||
func (r *SqlRepo) SaveUser(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	id domain.UserIdentifier,
 | 
			
		||||
	updateFunc func(u *domain.User) (*domain.User, error),
 | 
			
		||||
) error {
 | 
			
		||||
	userInfo := domain.GetUserInfo(ctx)
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
@@ -737,7 +768,10 @@ func (r *SqlRepo) DeleteUser(ctx context.Context, id domain.UserIdentifier) erro
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (*domain.User, error) {
 | 
			
		||||
func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (
 | 
			
		||||
	*domain.User,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	var user domain.User
 | 
			
		||||
 | 
			
		||||
	// userDefaults will be applied to newly created user records
 | 
			
		||||
@@ -777,7 +811,11 @@ func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *doma
 | 
			
		||||
 | 
			
		||||
// region statistics
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error {
 | 
			
		||||
func (r *SqlRepo) UpdateInterfaceStatus(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	id domain.InterfaceIdentifier,
 | 
			
		||||
	updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
 | 
			
		||||
) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		in, err := r.getOrCreateInterfaceStatus(tx, id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -804,7 +842,10 @@ func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.Interface
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.InterfaceStatus, error) {
 | 
			
		||||
func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (
 | 
			
		||||
	*domain.InterfaceStatus,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	var in domain.InterfaceStatus
 | 
			
		||||
 | 
			
		||||
	// defaults will be applied to newly created record
 | 
			
		||||
@@ -830,7 +871,11 @@ func (r *SqlRepo) upsertInterfaceStatus(tx *gorm.DB, in *domain.InterfaceStatus)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error {
 | 
			
		||||
func (r *SqlRepo) UpdatePeerStatus(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	id domain.PeerIdentifier,
 | 
			
		||||
	updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
 | 
			
		||||
) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		in, err := r.getOrCreatePeerStatus(tx, id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -883,6 +928,15 @@ func (r *SqlRepo) upsertPeerStatus(tx *gorm.DB, in *domain.PeerStatus) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Delete(&domain.PeerStatus{}, id).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// endregion statistics
 | 
			
		||||
 | 
			
		||||
// region audit
 | 
			
		||||
 
 | 
			
		||||
@@ -10,3 +10,4 @@ const TopicRouteUpdate = "route:update"
 | 
			
		||||
const TopicRouteRemove = "route:remove"
 | 
			
		||||
const TopicInterfaceUpdated = "interface:updated"
 | 
			
		||||
const TopicPeerInterfaceUpdated = "peer:interface:updated"
 | 
			
		||||
const TopicPeerIdentifierUpdated = "peer:identifier:updated"
 | 
			
		||||
 
 | 
			
		||||
@@ -13,13 +13,21 @@ type InterfaceAndPeerDatabaseRepo interface {
 | 
			
		||||
	GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
 | 
			
		||||
	FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error)
 | 
			
		||||
	GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
 | 
			
		||||
	SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error
 | 
			
		||||
	SaveInterface(
 | 
			
		||||
		ctx context.Context,
 | 
			
		||||
		id domain.InterfaceIdentifier,
 | 
			
		||||
		updateFunc func(in *domain.Interface) (*domain.Interface, error),
 | 
			
		||||
	) error
 | 
			
		||||
	DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
 | 
			
		||||
	GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
 | 
			
		||||
	FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error)
 | 
			
		||||
	GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
 | 
			
		||||
	FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error)
 | 
			
		||||
	SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error
 | 
			
		||||
	SavePeer(
 | 
			
		||||
		ctx context.Context,
 | 
			
		||||
		id domain.PeerIdentifier,
 | 
			
		||||
		updateFunc func(in *domain.Peer) (*domain.Peer, error),
 | 
			
		||||
	) error
 | 
			
		||||
	DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
 | 
			
		||||
	GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
 | 
			
		||||
	GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
 | 
			
		||||
@@ -30,18 +38,40 @@ type StatisticsDatabaseRepo interface {
 | 
			
		||||
	GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
 | 
			
		||||
	GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
 | 
			
		||||
 | 
			
		||||
	UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error
 | 
			
		||||
	UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error
 | 
			
		||||
	UpdatePeerStatus(
 | 
			
		||||
		ctx context.Context,
 | 
			
		||||
		id domain.PeerIdentifier,
 | 
			
		||||
		updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
 | 
			
		||||
	) error
 | 
			
		||||
	UpdateInterfaceStatus(
 | 
			
		||||
		ctx context.Context,
 | 
			
		||||
		id domain.InterfaceIdentifier,
 | 
			
		||||
		updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
 | 
			
		||||
	) error
 | 
			
		||||
 | 
			
		||||
	DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type InterfaceController interface {
 | 
			
		||||
	GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
 | 
			
		||||
	GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
 | 
			
		||||
	GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
 | 
			
		||||
	GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error)
 | 
			
		||||
	SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
 | 
			
		||||
	GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
 | 
			
		||||
		*domain.PhysicalPeer,
 | 
			
		||||
		error,
 | 
			
		||||
	)
 | 
			
		||||
	SaveInterface(
 | 
			
		||||
		_ context.Context,
 | 
			
		||||
		id domain.InterfaceIdentifier,
 | 
			
		||||
		updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
 | 
			
		||||
	) error
 | 
			
		||||
	DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
 | 
			
		||||
	SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
 | 
			
		||||
	SavePeer(
 | 
			
		||||
		_ context.Context,
 | 
			
		||||
		deviceId domain.InterfaceIdentifier,
 | 
			
		||||
		id domain.PeerIdentifier,
 | 
			
		||||
		updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
 | 
			
		||||
	) error
 | 
			
		||||
	DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -5,14 +5,17 @@ import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/app"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/config"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
	probing "github.com/prometheus-community/pro-bing"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	evbus "github.com/vardius/message-bus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type StatisticsCollector struct {
 | 
			
		||||
	cfg *config.Config
 | 
			
		||||
	bus evbus.MessageBus
 | 
			
		||||
 | 
			
		||||
	pingWaitGroup sync.WaitGroup
 | 
			
		||||
	pingJobs      chan domain.Peer
 | 
			
		||||
@@ -22,14 +25,25 @@ type StatisticsCollector struct {
 | 
			
		||||
	ms MetricsServer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewStatisticsCollector(cfg *config.Config, db StatisticsDatabaseRepo, wg InterfaceController, ms MetricsServer) (*StatisticsCollector, error) {
 | 
			
		||||
	return &StatisticsCollector{
 | 
			
		||||
func NewStatisticsCollector(
 | 
			
		||||
	cfg *config.Config,
 | 
			
		||||
	bus evbus.MessageBus,
 | 
			
		||||
	db StatisticsDatabaseRepo,
 | 
			
		||||
	wg InterfaceController,
 | 
			
		||||
	ms MetricsServer,
 | 
			
		||||
) (*StatisticsCollector, error) {
 | 
			
		||||
	c := &StatisticsCollector{
 | 
			
		||||
		cfg: cfg,
 | 
			
		||||
		bus: bus,
 | 
			
		||||
 | 
			
		||||
		db: db,
 | 
			
		||||
		wg: wg,
 | 
			
		||||
		ms: ms,
 | 
			
		||||
	}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.connectToMessageBus()
 | 
			
		||||
 | 
			
		||||
	return c, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) {
 | 
			
		||||
@@ -69,16 +83,17 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
 | 
			
		||||
					logrus.Warnf("failed to load physical interface %s for data collection: %v", in.Identifier, err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				err = c.db.UpdateInterfaceStatus(ctx, in.Identifier, func(i *domain.InterfaceStatus) (*domain.InterfaceStatus, error) {
 | 
			
		||||
					i.UpdatedAt = time.Now()
 | 
			
		||||
					i.BytesReceived = physicalInterface.BytesDownload
 | 
			
		||||
					i.BytesTransmitted = physicalInterface.BytesUpload
 | 
			
		||||
				err = c.db.UpdateInterfaceStatus(ctx, in.Identifier,
 | 
			
		||||
					func(i *domain.InterfaceStatus) (*domain.InterfaceStatus, error) {
 | 
			
		||||
						i.UpdatedAt = time.Now()
 | 
			
		||||
						i.BytesReceived = physicalInterface.BytesDownload
 | 
			
		||||
						i.BytesTransmitted = physicalInterface.BytesUpload
 | 
			
		||||
 | 
			
		||||
					// Update prometheus metrics
 | 
			
		||||
					go c.updateInterfaceMetrics(*i)
 | 
			
		||||
						// Update prometheus metrics
 | 
			
		||||
						go c.updateInterfaceMetrics(*i)
 | 
			
		||||
 | 
			
		||||
					return i, nil
 | 
			
		||||
				})
 | 
			
		||||
						return i, nil
 | 
			
		||||
					})
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logrus.Warnf("failed to update interface status for %s: %v", in.Identifier, err)
 | 
			
		||||
				}
 | 
			
		||||
@@ -120,36 +135,43 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				for _, peer := range peers {
 | 
			
		||||
					err = c.db.UpdatePeerStatus(ctx, peer.Identifier, func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
 | 
			
		||||
						var lastHandshake *time.Time
 | 
			
		||||
						if !peer.LastHandshake.IsZero() {
 | 
			
		||||
							lastHandshake = &peer.LastHandshake
 | 
			
		||||
						}
 | 
			
		||||
					err = c.db.UpdatePeerStatus(ctx, peer.Identifier,
 | 
			
		||||
						func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
 | 
			
		||||
							var lastHandshake *time.Time
 | 
			
		||||
							if !peer.LastHandshake.IsZero() {
 | 
			
		||||
								lastHandshake = &peer.LastHandshake
 | 
			
		||||
							}
 | 
			
		||||
 | 
			
		||||
						// calculate if session was restarted
 | 
			
		||||
						p.UpdatedAt = time.Now()
 | 
			
		||||
						p.LastSessionStart = getSessionStartTime(*p, peer.BytesUpload, peer.BytesDownload, lastHandshake)
 | 
			
		||||
						p.BytesReceived = peer.BytesUpload      // store bytes that where uploaded from the peer and received by the server
 | 
			
		||||
						p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server
 | 
			
		||||
						p.Endpoint = peer.Endpoint
 | 
			
		||||
						p.LastHandshake = lastHandshake
 | 
			
		||||
							// calculate if session was restarted
 | 
			
		||||
							p.UpdatedAt = time.Now()
 | 
			
		||||
							p.LastSessionStart = getSessionStartTime(*p, peer.BytesUpload, peer.BytesDownload,
 | 
			
		||||
								lastHandshake)
 | 
			
		||||
							p.BytesReceived = peer.BytesUpload      // store bytes that where uploaded from the peer and received by the server
 | 
			
		||||
							p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server
 | 
			
		||||
							p.Endpoint = peer.Endpoint
 | 
			
		||||
							p.LastHandshake = lastHandshake
 | 
			
		||||
 | 
			
		||||
						// Update prometheus metrics
 | 
			
		||||
						go c.updatePeerMetrics(ctx, *p)
 | 
			
		||||
							// Update prometheus metrics
 | 
			
		||||
							go c.updatePeerMetrics(ctx, *p)
 | 
			
		||||
 | 
			
		||||
						return p, nil
 | 
			
		||||
					})
 | 
			
		||||
							return p, nil
 | 
			
		||||
						})
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logrus.Warnf("failed to update interface status for %s: %v", in.Identifier, err)
 | 
			
		||||
						logrus.Warnf("failed to update peer status for %s: %v", peer.Identifier, err)
 | 
			
		||||
					} else {
 | 
			
		||||
						logrus.Tracef("updated peer status for %s", peer.Identifier)
 | 
			
		||||
					}
 | 
			
		||||
					logrus.Tracef("updated peer status for %s", peer.Identifier)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getSessionStartTime(oldStats domain.PeerStatus, newReceived, newTransmitted uint64, latestHandshake *time.Time) *time.Time {
 | 
			
		||||
func getSessionStartTime(
 | 
			
		||||
	oldStats domain.PeerStatus,
 | 
			
		||||
	newReceived, newTransmitted uint64,
 | 
			
		||||
	latestHandshake *time.Time,
 | 
			
		||||
) *time.Time {
 | 
			
		||||
	if latestHandshake == nil {
 | 
			
		||||
		return nil // currently not connected
 | 
			
		||||
	}
 | 
			
		||||
@@ -242,6 +264,28 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
 | 
			
		||||
	for peer := range c.pingJobs {
 | 
			
		||||
		peerPingable := c.isPeerPingable(ctx, peer)
 | 
			
		||||
		logrus.Tracef("peer %s pingable: %t", peer.Identifier, peerPingable)
 | 
			
		||||
 | 
			
		||||
		now := time.Now()
 | 
			
		||||
		err := c.db.UpdatePeerStatus(ctx, peer.Identifier,
 | 
			
		||||
			func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
 | 
			
		||||
				if peerPingable {
 | 
			
		||||
					p.IsPingable = true
 | 
			
		||||
					p.LastPing = &now
 | 
			
		||||
				} else {
 | 
			
		||||
					p.IsPingable = false
 | 
			
		||||
					p.LastPing = nil
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// Update prometheus metrics
 | 
			
		||||
				go c.updatePeerMetrics(ctx, *p)
 | 
			
		||||
 | 
			
		||||
				return p, nil
 | 
			
		||||
			})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logrus.Warnf("failed to update peer ping status for %s: %v", peer.Identifier, err)
 | 
			
		||||
		} else {
 | 
			
		||||
			logrus.Tracef("updated peer ping status for %s", peer.Identifier)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -257,7 +301,7 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
 | 
			
		||||
 | 
			
		||||
	pinger, err := probing.NewPinger(checkAddr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logrus.Tracef("failed to instatiate pinger for %s: %v", checkAddr, err)
 | 
			
		||||
		logrus.Tracef("failed to instatiate pinger for %s (%s): %v", peer.Identifier, checkAddr, err)
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -267,7 +311,7 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
 | 
			
		||||
	pinger.Timeout = 2 * time.Second
 | 
			
		||||
	err = pinger.RunWithContext(ctx) // Blocks until finished.
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logrus.Tracef("pinger for %s exited unexpectedly: %v", checkAddr, err)
 | 
			
		||||
		logrus.Tracef("pinger for peer %s (%s) exited unexpectedly: %v", peer.Identifier, checkAddr, err)
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	stats := pinger.Statistics()
 | 
			
		||||
@@ -287,3 +331,18 @@ func (c *StatisticsCollector) updatePeerMetrics(ctx context.Context, status doma
 | 
			
		||||
	}
 | 
			
		||||
	c.ms.UpdatePeerMetrics(peer, status)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *StatisticsCollector) connectToMessageBus() {
 | 
			
		||||
	_ = c.bus.Subscribe(app.TopicPeerIdentifierUpdated, c.handlePeerIdentifierChangeEvent)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *StatisticsCollector) handlePeerIdentifierChangeEvent(oldIdentifier, newIdentifier domain.PeerIdentifier) {
 | 
			
		||||
	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | 
			
		||||
 | 
			
		||||
	// remove potential left-over status data
 | 
			
		||||
	err := c.db.DeletePeerStatus(ctx, oldIdentifier)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logrus.Errorf("failed to delete old peer status for migrated peer, %s -> %s: %v",
 | 
			
		||||
			oldIdentifier, newIdentifier, err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,11 @@
 | 
			
		||||
package wireguard
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Test_getSessionStartTime(t *testing.T) {
 | 
			
		||||
@@ -66,7 +67,9 @@ func Test_getSessionStartTime(t *testing.T) {
 | 
			
		||||
		{
 | 
			
		||||
			name: "still connected",
 | 
			
		||||
			args: args{
 | 
			
		||||
				oldStats:       domain.PeerStatus{LastSessionStart: &nowMinus1, BytesReceived: 10, BytesTransmitted: 10},
 | 
			
		||||
				oldStats: domain.PeerStatus{
 | 
			
		||||
					LastSessionStart: &nowMinus1, BytesReceived: 10, BytesTransmitted: 10,
 | 
			
		||||
				},
 | 
			
		||||
				newReceived:    100,
 | 
			
		||||
				newTransmitted: 100,
 | 
			
		||||
				lastHandshake:  &now,
 | 
			
		||||
@@ -76,7 +79,9 @@ func Test_getSessionStartTime(t *testing.T) {
 | 
			
		||||
		{
 | 
			
		||||
			name: "no longer connected",
 | 
			
		||||
			args: args{
 | 
			
		||||
				oldStats:       domain.PeerStatus{LastSessionStart: &nowMinus5, BytesReceived: 100, BytesTransmitted: 100},
 | 
			
		||||
				oldStats: domain.PeerStatus{
 | 
			
		||||
					LastSessionStart: &nowMinus5, BytesReceived: 100, BytesTransmitted: 100,
 | 
			
		||||
				},
 | 
			
		||||
				newReceived:    100,
 | 
			
		||||
				newTransmitted: 100,
 | 
			
		||||
				lastHandshake:  &nowMinus3,
 | 
			
		||||
@@ -116,7 +121,9 @@ func Test_getSessionStartTime(t *testing.T) {
 | 
			
		||||
		{
 | 
			
		||||
			name: "reconnect (sent)",
 | 
			
		||||
			args: args{
 | 
			
		||||
				oldStats:       domain.PeerStatus{LastSessionStart: &nowMinus1, BytesReceived: 100, BytesTransmitted: 100},
 | 
			
		||||
				oldStats: domain.PeerStatus{
 | 
			
		||||
					LastSessionStart: &nowMinus1, BytesReceived: 100, BytesTransmitted: 100,
 | 
			
		||||
				},
 | 
			
		||||
				newReceived:    100,
 | 
			
		||||
				newTransmitted: 10,
 | 
			
		||||
				lastHandshake:  &now,
 | 
			
		||||
@@ -126,7 +133,8 @@ func Test_getSessionStartTime(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			if got := getSessionStartTime(tt.args.oldStats, tt.args.newReceived, tt.args.newTransmitted, tt.args.lastHandshake); !reflect.DeepEqual(got, tt.want) {
 | 
			
		||||
			if got := getSessionStartTime(tt.args.oldStats, tt.args.newReceived, tt.args.newTransmitted,
 | 
			
		||||
				tt.args.lastHandshake); !reflect.DeepEqual(got, tt.want) {
 | 
			
		||||
				t.Errorf("getSessionStartTime() = %v, want %v", got, tt.want)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
 
 | 
			
		||||
@@ -230,9 +230,31 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
 | 
			
		||||
		return nil, fmt.Errorf("update not allowed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = m.savePeers(ctx, peer)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("update failure: %w", err)
 | 
			
		||||
	// handle peer identifier change (new public key)
 | 
			
		||||
	if existingPeer.Identifier != domain.PeerIdentifier(peer.Interface.PublicKey) {
 | 
			
		||||
		peer.Identifier = domain.PeerIdentifier(peer.Interface.PublicKey) // set new identifier
 | 
			
		||||
 | 
			
		||||
		// delete old peer
 | 
			
		||||
		err = m.DeletePeer(ctx, existingPeer.Identifier)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("failed to delete old peer %s for %s: %w",
 | 
			
		||||
				existingPeer.Identifier, peer.Identifier, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// save new peer
 | 
			
		||||
		err = m.savePeers(ctx, peer)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("update failure for re-identified peer %s (was %s): %w",
 | 
			
		||||
				peer.Identifier, existingPeer.Identifier, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// publish event
 | 
			
		||||
		m.bus.Publish(app.TopicPeerIdentifierUpdated, existingPeer.Identifier, peer.Identifier)
 | 
			
		||||
	} else { // normal update
 | 
			
		||||
		err = m.savePeers(ctx, peer)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("update failure: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return peer, nil
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user