fix change of peer identifier (public key) (#265)

This commit is contained in:
Christoph Haas
2025-01-05 11:30:34 +01:00
parent 6d86f15ff8
commit 3020fbca4e
7 changed files with 239 additions and 65 deletions

View File

@@ -13,13 +13,21 @@ type InterfaceAndPeerDatabaseRepo interface {
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error)
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error
SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.Interface) (*domain.Interface, error),
) error
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error)
SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error
SavePeer(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.Peer) (*domain.Peer, error),
) error
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
@@ -30,18 +38,40 @@ type StatisticsDatabaseRepo interface {
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error
UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error
UpdatePeerStatus(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
) error
UpdateInterfaceStatus(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
) error
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
}
type InterfaceController interface {
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error)
SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
*domain.PhysicalPeer,
error,
)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}

View File

@@ -5,14 +5,17 @@ import (
"sync"
"time"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
probing "github.com/prometheus-community/pro-bing"
"github.com/sirupsen/logrus"
evbus "github.com/vardius/message-bus"
)
type StatisticsCollector struct {
cfg *config.Config
bus evbus.MessageBus
pingWaitGroup sync.WaitGroup
pingJobs chan domain.Peer
@@ -22,14 +25,25 @@ type StatisticsCollector struct {
ms MetricsServer
}
func NewStatisticsCollector(cfg *config.Config, db StatisticsDatabaseRepo, wg InterfaceController, ms MetricsServer) (*StatisticsCollector, error) {
return &StatisticsCollector{
func NewStatisticsCollector(
cfg *config.Config,
bus evbus.MessageBus,
db StatisticsDatabaseRepo,
wg InterfaceController,
ms MetricsServer,
) (*StatisticsCollector, error) {
c := &StatisticsCollector{
cfg: cfg,
bus: bus,
db: db,
wg: wg,
ms: ms,
}, nil
}
c.connectToMessageBus()
return c, nil
}
func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) {
@@ -69,16 +83,17 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
logrus.Warnf("failed to load physical interface %s for data collection: %v", in.Identifier, err)
continue
}
err = c.db.UpdateInterfaceStatus(ctx, in.Identifier, func(i *domain.InterfaceStatus) (*domain.InterfaceStatus, error) {
i.UpdatedAt = time.Now()
i.BytesReceived = physicalInterface.BytesDownload
i.BytesTransmitted = physicalInterface.BytesUpload
err = c.db.UpdateInterfaceStatus(ctx, in.Identifier,
func(i *domain.InterfaceStatus) (*domain.InterfaceStatus, error) {
i.UpdatedAt = time.Now()
i.BytesReceived = physicalInterface.BytesDownload
i.BytesTransmitted = physicalInterface.BytesUpload
// Update prometheus metrics
go c.updateInterfaceMetrics(*i)
// Update prometheus metrics
go c.updateInterfaceMetrics(*i)
return i, nil
})
return i, nil
})
if err != nil {
logrus.Warnf("failed to update interface status for %s: %v", in.Identifier, err)
}
@@ -120,36 +135,43 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
continue
}
for _, peer := range peers {
err = c.db.UpdatePeerStatus(ctx, peer.Identifier, func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
var lastHandshake *time.Time
if !peer.LastHandshake.IsZero() {
lastHandshake = &peer.LastHandshake
}
err = c.db.UpdatePeerStatus(ctx, peer.Identifier,
func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
var lastHandshake *time.Time
if !peer.LastHandshake.IsZero() {
lastHandshake = &peer.LastHandshake
}
// calculate if session was restarted
p.UpdatedAt = time.Now()
p.LastSessionStart = getSessionStartTime(*p, peer.BytesUpload, peer.BytesDownload, lastHandshake)
p.BytesReceived = peer.BytesUpload // store bytes that where uploaded from the peer and received by the server
p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server
p.Endpoint = peer.Endpoint
p.LastHandshake = lastHandshake
// calculate if session was restarted
p.UpdatedAt = time.Now()
p.LastSessionStart = getSessionStartTime(*p, peer.BytesUpload, peer.BytesDownload,
lastHandshake)
p.BytesReceived = peer.BytesUpload // store bytes that where uploaded from the peer and received by the server
p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server
p.Endpoint = peer.Endpoint
p.LastHandshake = lastHandshake
// Update prometheus metrics
go c.updatePeerMetrics(ctx, *p)
// Update prometheus metrics
go c.updatePeerMetrics(ctx, *p)
return p, nil
})
return p, nil
})
if err != nil {
logrus.Warnf("failed to update interface status for %s: %v", in.Identifier, err)
logrus.Warnf("failed to update peer status for %s: %v", peer.Identifier, err)
} else {
logrus.Tracef("updated peer status for %s", peer.Identifier)
}
logrus.Tracef("updated peer status for %s", peer.Identifier)
}
}
}
}
}
func getSessionStartTime(oldStats domain.PeerStatus, newReceived, newTransmitted uint64, latestHandshake *time.Time) *time.Time {
func getSessionStartTime(
oldStats domain.PeerStatus,
newReceived, newTransmitted uint64,
latestHandshake *time.Time,
) *time.Time {
if latestHandshake == nil {
return nil // currently not connected
}
@@ -242,6 +264,28 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
for peer := range c.pingJobs {
peerPingable := c.isPeerPingable(ctx, peer)
logrus.Tracef("peer %s pingable: %t", peer.Identifier, peerPingable)
now := time.Now()
err := c.db.UpdatePeerStatus(ctx, peer.Identifier,
func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
if peerPingable {
p.IsPingable = true
p.LastPing = &now
} else {
p.IsPingable = false
p.LastPing = nil
}
// Update prometheus metrics
go c.updatePeerMetrics(ctx, *p)
return p, nil
})
if err != nil {
logrus.Warnf("failed to update peer ping status for %s: %v", peer.Identifier, err)
} else {
logrus.Tracef("updated peer ping status for %s", peer.Identifier)
}
}
}
@@ -257,7 +301,7 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
pinger, err := probing.NewPinger(checkAddr)
if err != nil {
logrus.Tracef("failed to instatiate pinger for %s: %v", checkAddr, err)
logrus.Tracef("failed to instatiate pinger for %s (%s): %v", peer.Identifier, checkAddr, err)
return false
}
@@ -267,7 +311,7 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
pinger.Timeout = 2 * time.Second
err = pinger.RunWithContext(ctx) // Blocks until finished.
if err != nil {
logrus.Tracef("pinger for %s exited unexpectedly: %v", checkAddr, err)
logrus.Tracef("pinger for peer %s (%s) exited unexpectedly: %v", peer.Identifier, checkAddr, err)
return false
}
stats := pinger.Statistics()
@@ -287,3 +331,18 @@ func (c *StatisticsCollector) updatePeerMetrics(ctx context.Context, status doma
}
c.ms.UpdatePeerMetrics(peer, status)
}
func (c *StatisticsCollector) connectToMessageBus() {
_ = c.bus.Subscribe(app.TopicPeerIdentifierUpdated, c.handlePeerIdentifierChangeEvent)
}
func (c *StatisticsCollector) handlePeerIdentifierChangeEvent(oldIdentifier, newIdentifier domain.PeerIdentifier) {
ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
// remove potential left-over status data
err := c.db.DeletePeerStatus(ctx, oldIdentifier)
if err != nil {
logrus.Errorf("failed to delete old peer status for migrated peer, %s -> %s: %v",
oldIdentifier, newIdentifier, err)
}
}

View File

@@ -1,10 +1,11 @@
package wireguard
import (
"github.com/h44z/wg-portal/internal/domain"
"reflect"
"testing"
"time"
"github.com/h44z/wg-portal/internal/domain"
)
func Test_getSessionStartTime(t *testing.T) {
@@ -66,7 +67,9 @@ func Test_getSessionStartTime(t *testing.T) {
{
name: "still connected",
args: args{
oldStats: domain.PeerStatus{LastSessionStart: &nowMinus1, BytesReceived: 10, BytesTransmitted: 10},
oldStats: domain.PeerStatus{
LastSessionStart: &nowMinus1, BytesReceived: 10, BytesTransmitted: 10,
},
newReceived: 100,
newTransmitted: 100,
lastHandshake: &now,
@@ -76,7 +79,9 @@ func Test_getSessionStartTime(t *testing.T) {
{
name: "no longer connected",
args: args{
oldStats: domain.PeerStatus{LastSessionStart: &nowMinus5, BytesReceived: 100, BytesTransmitted: 100},
oldStats: domain.PeerStatus{
LastSessionStart: &nowMinus5, BytesReceived: 100, BytesTransmitted: 100,
},
newReceived: 100,
newTransmitted: 100,
lastHandshake: &nowMinus3,
@@ -116,7 +121,9 @@ func Test_getSessionStartTime(t *testing.T) {
{
name: "reconnect (sent)",
args: args{
oldStats: domain.PeerStatus{LastSessionStart: &nowMinus1, BytesReceived: 100, BytesTransmitted: 100},
oldStats: domain.PeerStatus{
LastSessionStart: &nowMinus1, BytesReceived: 100, BytesTransmitted: 100,
},
newReceived: 100,
newTransmitted: 10,
lastHandshake: &now,
@@ -126,7 +133,8 @@ func Test_getSessionStartTime(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getSessionStartTime(tt.args.oldStats, tt.args.newReceived, tt.args.newTransmitted, tt.args.lastHandshake); !reflect.DeepEqual(got, tt.want) {
if got := getSessionStartTime(tt.args.oldStats, tt.args.newReceived, tt.args.newTransmitted,
tt.args.lastHandshake); !reflect.DeepEqual(got, tt.want) {
t.Errorf("getSessionStartTime() = %v, want %v", got, tt.want)
}
})

View File

@@ -230,9 +230,31 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
return nil, fmt.Errorf("update not allowed: %w", err)
}
err = m.savePeers(ctx, peer)
if err != nil {
return nil, fmt.Errorf("update failure: %w", err)
// handle peer identifier change (new public key)
if existingPeer.Identifier != domain.PeerIdentifier(peer.Interface.PublicKey) {
peer.Identifier = domain.PeerIdentifier(peer.Interface.PublicKey) // set new identifier
// delete old peer
err = m.DeletePeer(ctx, existingPeer.Identifier)
if err != nil {
return nil, fmt.Errorf("failed to delete old peer %s for %s: %w",
existingPeer.Identifier, peer.Identifier, err)
}
// save new peer
err = m.savePeers(ctx, peer)
if err != nil {
return nil, fmt.Errorf("update failure for re-identified peer %s (was %s): %w",
peer.Identifier, existingPeer.Identifier, err)
}
// publish event
m.bus.Publish(app.TopicPeerIdentifierUpdated, existingPeer.Identifier, peer.Identifier)
} else { // normal update
err = m.savePeers(ctx, peer)
if err != nil {
return nil, fmt.Errorf("update failure: %w", err)
}
}
return peer, nil