mirror of
				https://github.com/h44z/wg-portal.git
				synced 2025-11-03 23:56:18 +00:00 
			
		
		
		
	fix change of peer identifier (public key) (#265)
This commit is contained in:
		@@ -4,15 +4,16 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"gorm.io/gorm/clause"
 | 
			
		||||
	"gorm.io/gorm/logger"
 | 
			
		||||
	"gorm.io/gorm/utils"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"gorm.io/gorm/clause"
 | 
			
		||||
	"gorm.io/gorm/logger"
 | 
			
		||||
	"gorm.io/gorm/utils"
 | 
			
		||||
 | 
			
		||||
	"github.com/glebarez/sqlite"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/config"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
@@ -204,7 +205,8 @@ func (r *SqlRepo) preCheck() error {
 | 
			
		||||
		return nil // we probably don't have a V1 database =)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first", lastVersion.Version)
 | 
			
		||||
	return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first",
 | 
			
		||||
		lastVersion.Version)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) migrate() error {
 | 
			
		||||
@@ -249,7 +251,11 @@ func (r *SqlRepo) GetInterface(ctx context.Context, id domain.InterfaceIdentifie
 | 
			
		||||
	return &in, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) {
 | 
			
		||||
func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
 | 
			
		||||
	*domain.Interface,
 | 
			
		||||
	[]domain.Peer,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	in, err := r.GetInterface(ctx, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, fmt.Errorf("failed to load interface: %w", err)
 | 
			
		||||
@@ -305,7 +311,11 @@ func (r *SqlRepo) FindInterfaces(ctx context.Context, search string) ([]domain.I
 | 
			
		||||
	return users, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error {
 | 
			
		||||
func (r *SqlRepo) SaveInterface(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	id domain.InterfaceIdentifier,
 | 
			
		||||
	updateFunc func(in *domain.Interface) (*domain.Interface, error),
 | 
			
		||||
) error {
 | 
			
		||||
	userInfo := domain.GetUserInfo(ctx)
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		in, err := r.getOrCreateInterface(userInfo, tx, id)
 | 
			
		||||
@@ -333,7 +343,11 @@ func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifi
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreateInterface(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.Interface, error) {
 | 
			
		||||
func (r *SqlRepo) getOrCreateInterface(
 | 
			
		||||
	ui *domain.ContextUserInfo,
 | 
			
		||||
	tx *gorm.DB,
 | 
			
		||||
	id domain.InterfaceIdentifier,
 | 
			
		||||
) (*domain.Interface, error) {
 | 
			
		||||
	var in domain.Interface
 | 
			
		||||
 | 
			
		||||
	// interfaceDefaults will be applied to newly created interface records
 | 
			
		||||
@@ -449,7 +463,10 @@ func (r *SqlRepo) GetInterfacePeers(ctx context.Context, id domain.InterfaceIden
 | 
			
		||||
	return peers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error) {
 | 
			
		||||
func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) (
 | 
			
		||||
	[]domain.Peer,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	var peers []domain.Peer
 | 
			
		||||
 | 
			
		||||
	searchValue := "%" + strings.ToLower(search) + "%"
 | 
			
		||||
@@ -492,7 +509,11 @@ func (r *SqlRepo) FindUserPeers(ctx context.Context, id domain.UserIdentifier, s
 | 
			
		||||
	return peers, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error {
 | 
			
		||||
func (r *SqlRepo) SavePeer(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	id domain.PeerIdentifier,
 | 
			
		||||
	updateFunc func(in *domain.Peer) (*domain.Peer, error),
 | 
			
		||||
) error {
 | 
			
		||||
	userInfo := domain.GetUserInfo(ctx)
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		peer, err := r.getOrCreatePeer(userInfo, tx, id)
 | 
			
		||||
@@ -520,7 +541,10 @@ func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, update
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (*domain.Peer, error) {
 | 
			
		||||
func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (
 | 
			
		||||
	*domain.Peer,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	var peer domain.Peer
 | 
			
		||||
 | 
			
		||||
	// interfaceDefaults will be applied to newly created interface records
 | 
			
		||||
@@ -601,7 +625,10 @@ func (r *SqlRepo) GetPeerIps(ctx context.Context) (map[domain.PeerIdentifier][]d
 | 
			
		||||
	return result, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) {
 | 
			
		||||
func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (
 | 
			
		||||
	map[domain.Cidr][]domain.Cidr,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	var peerIps []struct {
 | 
			
		||||
		domain.Cidr
 | 
			
		||||
		PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"`
 | 
			
		||||
@@ -699,7 +726,11 @@ func (r *SqlRepo) FindUsers(ctx context.Context, search string) ([]domain.User,
 | 
			
		||||
	return users, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error {
 | 
			
		||||
func (r *SqlRepo) SaveUser(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	id domain.UserIdentifier,
 | 
			
		||||
	updateFunc func(u *domain.User) (*domain.User, error),
 | 
			
		||||
) error {
 | 
			
		||||
	userInfo := domain.GetUserInfo(ctx)
 | 
			
		||||
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
@@ -737,7 +768,10 @@ func (r *SqlRepo) DeleteUser(ctx context.Context, id domain.UserIdentifier) erro
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (*domain.User, error) {
 | 
			
		||||
func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (
 | 
			
		||||
	*domain.User,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	var user domain.User
 | 
			
		||||
 | 
			
		||||
	// userDefaults will be applied to newly created user records
 | 
			
		||||
@@ -777,7 +811,11 @@ func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *doma
 | 
			
		||||
 | 
			
		||||
// region statistics
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error {
 | 
			
		||||
func (r *SqlRepo) UpdateInterfaceStatus(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	id domain.InterfaceIdentifier,
 | 
			
		||||
	updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
 | 
			
		||||
) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		in, err := r.getOrCreateInterfaceStatus(tx, id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -804,7 +842,10 @@ func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.Interface
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.InterfaceStatus, error) {
 | 
			
		||||
func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (
 | 
			
		||||
	*domain.InterfaceStatus,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	var in domain.InterfaceStatus
 | 
			
		||||
 | 
			
		||||
	// defaults will be applied to newly created record
 | 
			
		||||
@@ -830,7 +871,11 @@ func (r *SqlRepo) upsertInterfaceStatus(tx *gorm.DB, in *domain.InterfaceStatus)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error {
 | 
			
		||||
func (r *SqlRepo) UpdatePeerStatus(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	id domain.PeerIdentifier,
 | 
			
		||||
	updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
 | 
			
		||||
) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		in, err := r.getOrCreatePeerStatus(tx, id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -883,6 +928,15 @@ func (r *SqlRepo) upsertPeerStatus(tx *gorm.DB, in *domain.PeerStatus) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *SqlRepo) DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error {
 | 
			
		||||
	err := r.db.WithContext(ctx).Delete(&domain.PeerStatus{}, id).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// endregion statistics
 | 
			
		||||
 | 
			
		||||
// region audit
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user