fix change of peer identifier (public key) (#265)

This commit is contained in:
Christoph Haas
2025-01-05 11:30:34 +01:00
parent 6d86f15ff8
commit 3020fbca4e
7 changed files with 239 additions and 65 deletions

View File

@@ -4,15 +4,16 @@ import (
"context"
"errors"
"fmt"
"github.com/sirupsen/logrus"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/utils"
"os"
"path/filepath"
"strings"
"time"
"github.com/sirupsen/logrus"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/utils"
"github.com/glebarez/sqlite"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
@@ -204,7 +205,8 @@ func (r *SqlRepo) preCheck() error {
return nil // we probably don't have a V1 database =)
}
return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first", lastVersion.Version)
return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first",
lastVersion.Version)
}
func (r *SqlRepo) migrate() error {
@@ -249,7 +251,11 @@ func (r *SqlRepo) GetInterface(ctx context.Context, id domain.InterfaceIdentifie
return &in, nil
}
func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) {
func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
[]domain.Peer,
error,
) {
in, err := r.GetInterface(ctx, id)
if err != nil {
return nil, nil, fmt.Errorf("failed to load interface: %w", err)
@@ -305,7 +311,11 @@ func (r *SqlRepo) FindInterfaces(ctx context.Context, search string) ([]domain.I
return users, nil
}
func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error {
func (r *SqlRepo) SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.Interface) (*domain.Interface, error),
) error {
userInfo := domain.GetUserInfo(ctx)
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
in, err := r.getOrCreateInterface(userInfo, tx, id)
@@ -333,7 +343,11 @@ func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifi
return nil
}
func (r *SqlRepo) getOrCreateInterface(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.Interface, error) {
func (r *SqlRepo) getOrCreateInterface(
ui *domain.ContextUserInfo,
tx *gorm.DB,
id domain.InterfaceIdentifier,
) (*domain.Interface, error) {
var in domain.Interface
// interfaceDefaults will be applied to newly created interface records
@@ -449,7 +463,10 @@ func (r *SqlRepo) GetInterfacePeers(ctx context.Context, id domain.InterfaceIden
return peers, nil
}
func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error) {
func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) (
[]domain.Peer,
error,
) {
var peers []domain.Peer
searchValue := "%" + strings.ToLower(search) + "%"
@@ -492,7 +509,11 @@ func (r *SqlRepo) FindUserPeers(ctx context.Context, id domain.UserIdentifier, s
return peers, nil
}
func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error {
func (r *SqlRepo) SavePeer(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.Peer) (*domain.Peer, error),
) error {
userInfo := domain.GetUserInfo(ctx)
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
peer, err := r.getOrCreatePeer(userInfo, tx, id)
@@ -520,7 +541,10 @@ func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, update
return nil
}
func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (*domain.Peer, error) {
func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (
*domain.Peer,
error,
) {
var peer domain.Peer
// interfaceDefaults will be applied to newly created interface records
@@ -601,7 +625,10 @@ func (r *SqlRepo) GetPeerIps(ctx context.Context) (map[domain.PeerIdentifier][]d
return result, nil
}
func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) {
func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (
map[domain.Cidr][]domain.Cidr,
error,
) {
var peerIps []struct {
domain.Cidr
PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"`
@@ -699,7 +726,11 @@ func (r *SqlRepo) FindUsers(ctx context.Context, search string) ([]domain.User,
return users, nil
}
func (r *SqlRepo) SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error {
func (r *SqlRepo) SaveUser(
ctx context.Context,
id domain.UserIdentifier,
updateFunc func(u *domain.User) (*domain.User, error),
) error {
userInfo := domain.GetUserInfo(ctx)
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
@@ -737,7 +768,10 @@ func (r *SqlRepo) DeleteUser(ctx context.Context, id domain.UserIdentifier) erro
return nil
}
func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (*domain.User, error) {
func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (
*domain.User,
error,
) {
var user domain.User
// userDefaults will be applied to newly created user records
@@ -777,7 +811,11 @@ func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *doma
// region statistics
func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error {
func (r *SqlRepo) UpdateInterfaceStatus(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
) error {
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
in, err := r.getOrCreateInterfaceStatus(tx, id)
if err != nil {
@@ -804,7 +842,10 @@ func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.Interface
return nil
}
func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.InterfaceStatus, error) {
func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (
*domain.InterfaceStatus,
error,
) {
var in domain.InterfaceStatus
// defaults will be applied to newly created record
@@ -830,7 +871,11 @@ func (r *SqlRepo) upsertInterfaceStatus(tx *gorm.DB, in *domain.InterfaceStatus)
return nil
}
func (r *SqlRepo) UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error {
func (r *SqlRepo) UpdatePeerStatus(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
) error {
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
in, err := r.getOrCreatePeerStatus(tx, id)
if err != nil {
@@ -883,6 +928,15 @@ func (r *SqlRepo) upsertPeerStatus(tx *gorm.DB, in *domain.PeerStatus) error {
return nil
}
func (r *SqlRepo) DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error {
err := r.db.WithContext(ctx).Delete(&domain.PeerStatus{}, id).Error
if err != nil {
return err
}
return nil
}
// endregion statistics
// region audit