V2 alpha - initial version (#172)

Initial alpha codebase for version 2 of WireGuard Portal.
This version is considered unstable and incomplete (for example, no public REST API)! 
Use with care!


Fixes/Implements the following issues:
 - OAuth support #154, #1 
 - New Web UI with internationalisation support #98, #107, #89, #62
 - Postgres Support #49 
 - Improved Email handling #47, #119 
 - DNS Search Domain support #46 
 - Bugfixes #94, #48 

---------

Co-authored-by: Fabian Wechselberger <wechselbergerf@hotmail.com>
This commit is contained in:
h44z
2023-08-04 13:34:18 +02:00
committed by GitHub
parent b3a5f2ac60
commit 8b820a5adf
788 changed files with 46139 additions and 11281 deletions

View File

@@ -0,0 +1,50 @@
package wireguard
import (
"context"
"github.com/h44z/wg-portal/internal/domain"
)
type InterfaceAndPeerDatabaseRepo interface {
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error)
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error)
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error)
SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetUsedIpsPerSubnet(ctx context.Context) (map[domain.Cidr][]domain.Cidr, error)
}
type StatisticsDatabaseRepo interface {
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error
UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error
}
type InterfaceController interface {
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error)
SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}
type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(id domain.InterfaceIdentifier) error
}

View File

@@ -0,0 +1,263 @@
package wireguard
import (
"context"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/prometheus-community/pro-bing"
"github.com/sirupsen/logrus"
"sync"
"time"
)
type StatisticsCollector struct {
cfg *config.Config
pingWaitGroup sync.WaitGroup
pingJobs chan domain.Peer
db StatisticsDatabaseRepo
wg InterfaceController
}
func NewStatisticsCollector(cfg *config.Config, db StatisticsDatabaseRepo, wg InterfaceController) (*StatisticsCollector, error) {
return &StatisticsCollector{
cfg: cfg,
db: db,
wg: wg,
}, nil
}
func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) {
c.startPingWorkers(ctx)
c.startInterfaceDataFetcher(ctx)
c.startPeerDataFetcher(ctx)
}
func (c *StatisticsCollector) startInterfaceDataFetcher(ctx context.Context) {
if !c.cfg.Statistics.CollectInterfaceData {
return
}
go c.collectInterfaceData(ctx)
logrus.Tracef("started interface data fetcher")
}
func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
// Start ticker
ticker := time.NewTicker(c.cfg.Statistics.DataCollectionInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return // program stopped
case <-ticker.C:
interfaces, err := c.db.GetAllInterfaces(ctx)
if err != nil {
logrus.Warnf("failed to fetch all interfaces for data collection: %v", err)
continue
}
for _, in := range interfaces {
physicalInterface, err := c.wg.GetInterface(ctx, in.Identifier)
if err != nil {
logrus.Warnf("failed to load physical interface %s for data collection: %v", in.Identifier, err)
continue
}
err = c.db.UpdateInterfaceStatus(ctx, in.Identifier, func(i *domain.InterfaceStatus) (*domain.InterfaceStatus, error) {
i.UpdatedAt = time.Now()
i.BytesReceived = physicalInterface.BytesDownload
i.BytesTransmitted = physicalInterface.BytesUpload
return i, nil
})
if err != nil {
logrus.Warnf("failed to update interface status for %s: %v", in.Identifier, err)
}
}
}
}
}
func (c *StatisticsCollector) startPeerDataFetcher(ctx context.Context) {
if !c.cfg.Statistics.CollectPeerData {
return
}
go c.collectPeerData(ctx)
logrus.Tracef("started peer data fetcher")
}
func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
// Start ticker
ticker := time.NewTicker(c.cfg.Statistics.DataCollectionInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return // program stopped
case <-ticker.C:
interfaces, err := c.db.GetAllInterfaces(ctx)
if err != nil {
logrus.Warnf("failed to fetch all interfaces for peer data collection: %v", err)
continue
}
for _, in := range interfaces {
peers, err := c.wg.GetPeers(ctx, in.Identifier)
if err != nil {
logrus.Warnf("failed to fetch peers for data collection (interface %s): %v", in.Identifier, err)
continue
}
for _, peer := range peers {
err = c.db.UpdatePeerStatus(ctx, peer.Identifier, func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
var lastHandshake *time.Time
if !peer.LastHandshake.IsZero() {
lastHandshake = &peer.LastHandshake
}
// calculate if session was restarted
p.UpdatedAt = time.Now()
p.LastSessionStart = getSessionStartTime(*p, peer.BytesUpload, peer.BytesDownload, lastHandshake)
p.BytesReceived = peer.BytesUpload // store bytes that where uploaded from the peer and received by the server
p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server
p.Endpoint = peer.Endpoint
p.LastHandshake = lastHandshake
return p, nil
})
if err != nil {
logrus.Warnf("failed to update interface status for %s: %v", in.Identifier, err)
}
}
}
}
}
}
func getSessionStartTime(oldStats domain.PeerStatus, newReceived, newTransmitted uint64, latestHandshake *time.Time) *time.Time {
if latestHandshake == nil {
return nil // currently not connected
}
oldestHandshakeTime := time.Now().Add(-2 * time.Minute) // if a handshake is older than 2 minutes, the peer is no longer connected
switch {
// old session was never initiated
case oldStats.BytesReceived == 0 && oldStats.BytesTransmitted == 0 && (newReceived > 0 || newTransmitted > 0):
return latestHandshake
// session never received bytes -> first receive
case oldStats.BytesReceived == 0 && newReceived > 0 && (oldStats.LastHandshake == nil || oldStats.LastHandshake.Before(oldestHandshakeTime)):
return latestHandshake
// session never transmitted bytes -> first transmit
case oldStats.BytesTransmitted == 0 && newTransmitted > 0 && (oldStats.LastSessionStart == nil || oldStats.LastHandshake.Before(oldestHandshakeTime)):
return latestHandshake
// session restarted as newer send or transmit counts are lower
case (newReceived != 0 && newReceived < oldStats.BytesReceived) || (newTransmitted != 0 && newTransmitted < oldStats.BytesTransmitted):
return latestHandshake
// session initiated (but some bytes were already transmitted
case oldStats.LastSessionStart == nil && (newReceived > oldStats.BytesReceived || newTransmitted > oldStats.BytesTransmitted):
return latestHandshake
default:
return oldStats.LastSessionStart
}
}
func (c *StatisticsCollector) startPingWorkers(ctx context.Context) {
if !c.cfg.Statistics.UsePingChecks {
return
}
if c.pingJobs != nil {
return // already started
}
c.pingWaitGroup = sync.WaitGroup{}
c.pingWaitGroup.Add(c.cfg.Statistics.PingCheckWorkers)
c.pingJobs = make(chan domain.Peer, c.cfg.Statistics.PingCheckWorkers)
// start workers
for i := 0; i < c.cfg.Statistics.PingCheckWorkers; i++ {
go c.pingWorker(ctx)
}
// start cleanup goroutine
go func() {
c.pingWaitGroup.Wait()
logrus.Tracef("stopped ping checks")
}()
go c.enqueuePingChecks(ctx)
logrus.Tracef("started ping checks")
}
func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
// Start ticker
ticker := time.NewTicker(c.cfg.Statistics.PingCheckInterval)
defer ticker.Stop()
defer close(c.pingJobs)
for {
select {
case <-ctx.Done():
return // program stopped
case <-ticker.C:
interfaces, err := c.db.GetAllInterfaces(ctx)
if err != nil {
logrus.Warnf("failed to fetch all interfaces for ping checks: %v", err)
continue
}
for _, in := range interfaces {
peers, err := c.db.GetInterfacePeers(ctx, in.Identifier)
if err != nil {
logrus.Warnf("failed to fetch peers for ping checks (interface %s): %v", in.Identifier, err)
continue
}
for _, peer := range peers {
c.pingJobs <- peer
}
}
}
}
}
func (c *StatisticsCollector) pingWorker(ctx context.Context) {
defer c.pingWaitGroup.Done()
for peer := range c.pingJobs {
peerPingable := c.isPeerPingable(ctx, peer)
logrus.Tracef("peer %s pingable: %t", peer.Identifier, peerPingable)
}
}
func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Peer) bool {
if c.cfg.Statistics.UsePingChecks == false {
return false
}
checkAddr := peer.CheckAliveAddress()
if checkAddr == "" {
return false
}
pinger, err := probing.NewPinger(checkAddr)
if err != nil {
logrus.Tracef("failed to instatiate pinger for %s: %v", checkAddr, err)
return false
}
checkCount := 1
pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged)
pinger.Count = checkCount
pinger.Timeout = 2 * time.Second
err = pinger.RunWithContext(ctx) // Blocks until finished.
if err != nil {
logrus.Tracef("pinger for %s exited unexpectedly: %v", checkAddr, err)
return false
}
stats := pinger.Statistics()
return stats.PacketsRecv == checkCount
}

View File

@@ -0,0 +1,134 @@
package wireguard
import (
"github.com/h44z/wg-portal/internal/domain"
"reflect"
"testing"
"time"
)
func Test_getSessionStartTime(t *testing.T) {
now := time.Now()
nowMinus1 := now.Add(-1 * time.Minute)
nowMinus3 := now.Add(-3 * time.Minute)
nowMinus5 := now.Add(-5 * time.Minute)
type args struct {
oldStats domain.PeerStatus
newReceived uint64
newTransmitted uint64
lastHandshake *time.Time
}
tests := []struct {
name string
args args
want *time.Time
}{
{
name: "not connected",
args: args{
newReceived: 0,
newTransmitted: 0,
lastHandshake: nil,
},
want: nil,
},
{
name: "freshly connected",
args: args{
oldStats: domain.PeerStatus{LastSessionStart: &nowMinus1},
newReceived: 100,
newTransmitted: 100,
lastHandshake: &now,
},
want: &now,
},
{
name: "freshly connected (no prev session)",
args: args{
oldStats: domain.PeerStatus{LastSessionStart: nil},
newReceived: 100,
newTransmitted: 100,
lastHandshake: &now,
},
want: &now,
},
{
name: "freshly connected (no prev session but bytes)",
args: args{
oldStats: domain.PeerStatus{LastSessionStart: nil, BytesReceived: 10, BytesTransmitted: 20},
newReceived: 100,
newTransmitted: 100,
lastHandshake: &now,
},
want: &now,
},
{
name: "still connected",
args: args{
oldStats: domain.PeerStatus{LastSessionStart: &nowMinus1, BytesReceived: 10, BytesTransmitted: 10},
newReceived: 100,
newTransmitted: 100,
lastHandshake: &now,
},
want: &nowMinus1,
},
{
name: "no longer connected",
args: args{
oldStats: domain.PeerStatus{LastSessionStart: &nowMinus5, BytesReceived: 100, BytesTransmitted: 100},
newReceived: 100,
newTransmitted: 100,
lastHandshake: &nowMinus3,
},
want: &nowMinus5,
},
{
name: "reconnect (recv, hs outdated)",
args: args{
oldStats: domain.PeerStatus{LastHandshake: &nowMinus5, BytesReceived: 100, BytesTransmitted: 100},
newReceived: 10,
newTransmitted: 100,
lastHandshake: &nowMinus1,
},
want: &nowMinus1,
},
{
name: "reconnect (recv)",
args: args{
oldStats: domain.PeerStatus{LastHandshake: &nowMinus1, BytesReceived: 100, BytesTransmitted: 100},
newReceived: 10,
newTransmitted: 100,
lastHandshake: &now,
},
want: &now,
},
{
name: "reconnect (sent, hs outdated)",
args: args{
oldStats: domain.PeerStatus{LastHandshake: &nowMinus5, BytesReceived: 100, BytesTransmitted: 100},
newReceived: 100,
newTransmitted: 10,
lastHandshake: &nowMinus1,
},
want: &nowMinus1,
},
{
name: "reconnect (sent)",
args: args{
oldStats: domain.PeerStatus{LastSessionStart: &nowMinus1, BytesReceived: 100, BytesTransmitted: 100},
newReceived: 100,
newTransmitted: 10,
lastHandshake: &now,
},
want: &now,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getSessionStartTime(tt.args.oldStats, tt.args.newReceived, tt.args.newTransmitted, tt.args.lastHandshake); !reflect.DeepEqual(got, tt.want) {
t.Errorf("getSessionStartTime() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,103 @@
package wireguard
import (
"context"
"github.com/h44z/wg-portal/internal/app"
"github.com/sirupsen/logrus"
"time"
evbus "github.com/vardius/message-bus"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
type Manager struct {
cfg *config.Config
bus evbus.MessageBus
db InterfaceAndPeerDatabaseRepo
wg InterfaceController
quick WgQuickController
}
func NewWireGuardManager(cfg *config.Config, bus evbus.MessageBus, wg InterfaceController, quick WgQuickController, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
m := &Manager{
cfg: cfg,
bus: bus,
wg: wg,
db: db,
quick: quick,
}
m.connectToMessageBus()
return m, nil
}
func (m Manager) StartBackgroundJobs(ctx context.Context) {
go m.runExpiredPeersCheck(ctx)
}
func (m Manager) connectToMessageBus() {
_ = m.bus.Subscribe(app.TopicUserCreated, m.handleUserCreationEvent)
}
func (m Manager) handleUserCreationEvent(user *domain.User) {
logrus.Errorf("handling new user event for %s", user.Identifier)
err := m.CreateDefaultPeer(context.Background(), user)
if err != nil {
logrus.Errorf("failed to create default peer for %s: %v", user.Identifier, err)
return
}
}
func (m Manager) runExpiredPeersCheck(ctx context.Context) {
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo())
running := true
for running {
select {
case <-ctx.Done():
running = false
continue
case <-time.After(m.cfg.Advanced.ExpiryCheckInterval):
// select blocks until one of the cases evaluate to true
}
interfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
logrus.Errorf("failed to fetch all interfaces for expiry check: %v", err)
continue
}
for _, iface := range interfaces {
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
logrus.Errorf("failed to fetch all peers from interface %s for expiry check: %v", iface.Identifier, err)
continue
}
m.checkExpiredPeers(ctx, peers)
}
}
}
func (m Manager) checkExpiredPeers(ctx context.Context, peers []domain.Peer) {
now := time.Now()
for _, peer := range peers {
if peer.IsExpired() && !peer.IsDisabled() {
logrus.Infof("peer %s has expired, disabling...", peer.Identifier)
peer.Disabled = &now
peer.DisabledReason = domain.DisabledReasonExpired
_, err := m.UpdatePeer(ctx, &peer)
if err != nil {
logrus.Errorf("failed to update expired peer %s: %v", peer.Identifier, err)
}
}
}
}

View File

@@ -0,0 +1,734 @@
package wireguard
import (
"context"
"errors"
"fmt"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/domain"
"github.com/sirupsen/logrus"
"os"
"time"
)
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
if err != nil {
return nil, err
}
return physicalInterfaces, nil
}
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) {
return m.db.GetInterfaceAndPeers(ctx, id)
}
func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
return m.db.GetAllInterfaces(ctx)
}
func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
interfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return nil, nil, fmt.Errorf("unable to load all interfaces: %w", err)
}
allPeers := make([][]domain.Peer, len(interfaces))
for i, iface := range interfaces {
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return nil, nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
allPeers[i] = peers
}
return interfaces, allPeers, nil
}
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
if err != nil {
return 0, err
}
// if no filter is given, exclude already existing interfaces
var excludedInterfaces []domain.InterfaceIdentifier
if len(filter) == 0 {
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return 0, err
}
for _, existingInterface := range existingInterfaces {
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
}
}
imported := 0
for _, physicalInterface := range physicalInterfaces {
if internal.SliceContains(excludedInterfaces, physicalInterface.Identifier) {
continue
}
if len(filter) != 0 && !internal.SliceContains(filter, physicalInterface.Identifier) {
continue
}
logrus.Infof("importing new interface %s...", physicalInterface.Identifier)
physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier)
if err != nil {
return 0, err
}
err = m.importInterface(ctx, &physicalInterface, physicalPeers)
if err != nil {
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
}
logrus.Infof("imported new interface %s and %d peers", physicalInterface.Identifier, len(physicalPeers))
imported++
}
return imported, nil
}
func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
existingInterface, err := m.db.GetInterface(ctx, in.Identifier)
if err != nil {
return fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
}
if err := m.validateInterfaceModifications(ctx, existingInterface, in); err != nil {
return fmt.Errorf("update not allowed: %w", err)
}
peers, err := m.db.GetInterfacePeers(ctx, in.Identifier)
if err != nil {
return fmt.Errorf("failed to find peers for interface %s: %w", in.Identifier, err)
}
for i := range peers {
(&peers[i]).ApplyInterfaceDefaults(in)
_, err := m.UpdatePeer(ctx, &peers[i])
if err != nil {
return fmt.Errorf("failed to apply interface defaults to peer %s: %w", peers[i].Identifier, err)
}
}
return nil
}
func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error {
interfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return err
}
for _, iface := range interfaces {
if len(filter) != 0 && !internal.SliceContains(filter, iface.Identifier) {
continue // ignore filtered interface
}
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
}
_, err = m.wg.GetInterface(ctx, iface.Identifier)
if err != nil {
logrus.Debugf("creating missing interface %s...", iface.Identifier)
// try to create a new interface
_, err = m.saveInterface(ctx, &iface, peers)
if err != nil {
return err
}
if err != nil {
if updateDbOnError {
// disable interface in database as no physical interface exists
_ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) {
now := time.Now()
in.Disabled = &now // set
in.DisabledReason = domain.DisabledReasonInterfaceMissing
return in, nil
})
}
return fmt.Errorf("failed to create physical interface %s: %w", iface.Identifier, err)
}
// restore peers
for _, peer := range peers {
err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier, func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, &peer)
return pp, nil
})
if err != nil {
return fmt.Errorf("failed to create physical peer %s: %w", peer.Identifier, err)
}
}
} else {
logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
// try to move interface to stored state
_, err = m.saveInterface(ctx, &iface, peers)
if err != nil {
return err
}
if err != nil {
if updateDbOnError {
// disable interface in database as no physical interface is available
_ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) {
if iface.IsDisabled() {
now := time.Now()
in.Disabled = &now // set
in.DisabledReason = domain.DisabledReasonInterfaceMissing
} else {
in.Disabled = nil
in.DisabledReason = ""
}
return in, nil
})
}
return fmt.Errorf("failed to change physical interface state for %s: %w", iface.Identifier, err)
}
}
}
return nil
}
func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
currentUser := domain.GetUserInfo(ctx)
kp, err := domain.NewFreshKeypair()
if err != nil {
return nil, fmt.Errorf("failed to generate keys: %w", err)
}
id, err := m.getNewInterfaceName(ctx)
if err != nil {
return nil, fmt.Errorf("failed to generate new identifier: %w", err)
}
ipv4, ipv6, err := m.getFreshInterfaceIpConfig(ctx)
if err != nil {
return nil, fmt.Errorf("failed to generate new ip config: %w", err)
}
port, err := m.getFreshListenPort(ctx)
if err != nil {
return nil, fmt.Errorf("failed to generate new listen port: %w", err)
}
ips := []domain.Cidr{ipv4}
if m.cfg.Advanced.UseIpV6 {
ips = append(ips, ipv6)
}
networks := []domain.Cidr{ipv4.NetworkAddr()}
if m.cfg.Advanced.UseIpV6 {
networks = append(networks, ipv6.NetworkAddr())
}
freshInterface := &domain.Interface{
BaseModel: domain.BaseModel{
CreatedBy: string(currentUser.Id),
UpdatedBy: string(currentUser.Id),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
Identifier: id,
KeyPair: kp,
ListenPort: port,
Addresses: ips,
DnsStr: "",
DnsSearchStr: "",
Mtu: 1420,
FirewallMark: 0,
RoutingTable: "",
PreUp: "",
PostUp: "",
PreDown: "",
PostDown: "",
SaveConfig: m.cfg.Advanced.ConfigStoragePath != "",
DisplayName: string(id),
Type: domain.InterfaceTypeServer,
DriverType: "",
Disabled: nil,
DisabledReason: "",
PeerDefNetworkStr: domain.CidrsToString(networks),
PeerDefDnsStr: "",
PeerDefDnsSearchStr: "",
PeerDefEndpoint: "",
PeerDefAllowedIPsStr: domain.CidrsToString(networks),
PeerDefMtu: 1420,
PeerDefPersistentKeepalive: 16,
PeerDefFirewallMark: 0,
PeerDefRoutingTable: "",
PeerDefPreUp: "",
PeerDefPostUp: "",
PeerDefPreDown: "",
PeerDefPostDown: "",
}
return freshInterface, nil
}
func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
existingInterface, err := m.db.GetInterface(ctx, in.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
}
if existingInterface != nil {
return nil, fmt.Errorf("interface %s already exists", in.Identifier)
}
if err := m.validateInterfaceCreation(ctx, existingInterface, in); err != nil {
return nil, fmt.Errorf("creation not allowed: %w", err)
}
in, err = m.saveInterface(ctx, in, nil)
if err != nil {
return nil, fmt.Errorf("creation failure: %w", err)
}
return in, nil
}
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) {
existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, in.Identifier)
if err != nil {
return nil, nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
}
if err := m.validateInterfaceModifications(ctx, existingInterface, in); err != nil {
return nil, nil, fmt.Errorf("update not allowed: %w", err)
}
in, err = m.saveInterface(ctx, in, existingPeers)
if err != nil {
return nil, nil, fmt.Errorf("update failure: %w", err)
}
return in, existingPeers, nil
}
func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
existingInterface, err := m.db.GetInterface(ctx, id)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", id, err)
}
if err := m.validateInterfaceDeletion(ctx, existingInterface); err != nil {
return fmt.Errorf("deletion not allowed: %w", err)
}
now := time.Now()
existingInterface.Disabled = &now // simulate a disabled interface
existingInterface.DisabledReason = domain.DisabledReasonDeleted
physicalInterface, _ := m.wg.GetInterface(ctx, id)
if err := m.handleInterfacePreSaveHooks(true, existingInterface); err != nil {
return fmt.Errorf("pre-delete hooks failed: %w", err)
}
if err := m.handleInterfacePreSaveActions(existingInterface); err != nil {
return fmt.Errorf("pre-delete actions failed: %w", err)
}
if err := m.deleteInterfacePeers(ctx, id); err != nil {
return fmt.Errorf("peer deletion failure: %w", err)
}
if err := m.wg.DeleteInterface(ctx, id); err != nil {
return fmt.Errorf("wireguard deletion failure: %w", err)
}
if err := m.db.DeleteInterface(ctx, id); err != nil {
return fmt.Errorf("deletion failure: %w", err)
}
fwMark := int(existingInterface.FirewallMark)
if physicalInterface != nil && fwMark == 0 {
fwMark = int(physicalInterface.FirewallMark)
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: fwMark,
Table: existingInterface.GetRoutingTable(),
})
if err := m.handleInterfacePostSaveHooks(true, existingInterface); err != nil {
return fmt.Errorf("post-delete hooks failed: %w", err)
}
return nil
}
// region helper-functions
func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, peers []domain.Peer) (*domain.Interface, error) {
stateChanged := m.hasInterfaceStateChanged(ctx, iface)
if err := m.handleInterfacePreSaveHooks(stateChanged, iface); err != nil {
return nil, fmt.Errorf("pre-save hooks failed: %w", err)
}
if err := m.handleInterfacePreSaveActions(iface); err != nil {
return nil, fmt.Errorf("pre-save actions failed: %w", err)
}
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
iface.CopyCalculatedAttributes(i)
err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
domain.MergeToPhysicalInterface(pi, iface)
return pi, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save physical interface %s: %w", iface.Identifier, err)
}
return iface, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save interface: %w", err)
}
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
if iface.IsDisabled() {
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
fwMark := int(iface.FirewallMark)
if physicalInterface != nil && fwMark == 0 {
fwMark = int(physicalInterface.FirewallMark)
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: fwMark,
Table: iface.GetRoutingTable(),
})
}
if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil {
return nil, fmt.Errorf("post-save hooks failed: %w", err)
}
m.bus.Publish(app.TopicInterfaceUpdated, iface)
return iface, nil
}
func (m Manager) hasInterfaceStateChanged(ctx context.Context, iface *domain.Interface) bool {
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err != nil {
return false
}
return oldInterface.IsDisabled() != iface.IsDisabled()
}
func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error {
if !iface.IsDisabled() {
if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
} else {
if err := m.quick.UnsetDNS(iface.Identifier); err != nil {
return fmt.Errorf("failed to clear dns settings: %w", err)
}
}
return nil
}
func (m Manager) handleInterfacePreSaveHooks(stateChanged bool, iface *domain.Interface) error {
if !stateChanged {
return nil // do nothing if state did not change
}
if !iface.IsDisabled() {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
return fmt.Errorf("failed to execute pre-up hook: %w", err)
}
} else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
return fmt.Errorf("failed to execute pre-down hook: %w", err)
}
}
return nil
}
func (m Manager) handleInterfacePostSaveHooks(stateChanged bool, iface *domain.Interface) error {
if !stateChanged {
return nil // do nothing if state did not change
}
if !iface.IsDisabled() {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
return fmt.Errorf("failed to execute post-up hook: %w", err)
}
} else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
return fmt.Errorf("failed to execute post-down hook: %w", err)
}
}
return nil
}
func (m Manager) getNewInterfaceName(ctx context.Context) (domain.InterfaceIdentifier, error) {
namePrefix := "wg"
nameSuffix := 0
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return "", err
}
var name domain.InterfaceIdentifier
for {
name = domain.InterfaceIdentifier(fmt.Sprintf("%s%d", namePrefix, nameSuffix))
conflict := false
for _, in := range existingInterfaces {
if in.Identifier == name {
conflict = true
break
}
}
if !conflict {
break
}
nameSuffix++
}
return name, nil
}
func (m Manager) getFreshInterfaceIpConfig(ctx context.Context) (ipV4, ipV6 domain.Cidr, err error) {
ips, err := m.db.GetInterfaceIps(ctx)
if err != nil {
err = fmt.Errorf("failed to get existing IP addresses: %w", err)
return
}
useV6 := m.cfg.Advanced.UseIpV6
ipV4, _ = domain.CidrFromString(m.cfg.Advanced.StartCidrV4)
ipV6, _ = domain.CidrFromString(m.cfg.Advanced.StartCidrV6)
ipV4 = ipV4.FirstAddr()
ipV6 = ipV6.FirstAddr()
netV4 := ipV4.NetworkAddr()
netV6 := ipV6.NetworkAddr()
for {
v4Conflict := false
v6Conflict := false
for _, usedIps := range ips {
for _, usedIp := range usedIps {
usedNetwork := usedIp.NetworkAddr()
if netV4 == usedNetwork {
v4Conflict = true
}
if netV6 == usedNetwork {
v6Conflict = true
}
}
}
if !v4Conflict && (!useV6 || !v6Conflict) {
break
}
if v4Conflict {
netV4 = netV4.NextSubnet()
}
if v6Conflict && useV6 {
netV6 = netV6.NextSubnet()
}
if !netV4.IsValid() {
return domain.Cidr{}, domain.Cidr{}, fmt.Errorf("IPv4 space exhausted")
}
if useV6 && !netV6.IsValid() {
return domain.Cidr{}, domain.Cidr{}, fmt.Errorf("IPv6 space exhausted")
}
}
// use first address in network for interface
ipV4 = netV4.NextAddr()
ipV6 = netV6.NextAddr()
return
}
func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return -1, err
}
port = m.cfg.Advanced.StartListenPort
for {
conflict := false
for _, in := range existingInterfaces {
if in.ListenPort == port {
conflict = true
break
}
}
if !conflict {
break
}
port++
}
if port > 65535 { // maximum allowed port number (16 bit uint)
return -1, fmt.Errorf("port space exhausted")
}
return
}
func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterface, peers []domain.PhysicalPeer) error {
now := time.Now()
iface := domain.ConvertPhysicalInterface(in)
iface.BaseModel = domain.BaseModel{
CreatedBy: "importer",
UpdatedBy: "importer",
CreatedAt: now,
UpdatedAt: now,
}
iface.PeerDefAllowedIPsStr = iface.AddressStr()
existingInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return err
}
if existingInterface != nil {
return errors.New("interface already exists")
}
err = m.db.SaveInterface(ctx, iface.Identifier, func(_ *domain.Interface) (*domain.Interface, error) {
return iface, nil
})
if err != nil {
return fmt.Errorf("database save failed: %w", err)
}
// import peers
for _, peer := range peers {
err = m.importPeer(ctx, iface, &peer)
if err != nil {
return fmt.Errorf("import of peer %s failed: %w", peer.Identifier, err)
}
}
return nil
}
func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain.PhysicalPeer) error {
now := time.Now()
peer := domain.ConvertPhysicalPeer(p)
peer.BaseModel = domain.BaseModel{
CreatedBy: "importer",
UpdatedBy: "importer",
CreatedAt: now,
UpdatedAt: now,
}
peer.InterfaceIdentifier = in.Identifier
peer.EndpointPublicKey = domain.StringConfigOption{Value: in.PublicKey, Overridable: true}
peer.AllowedIPsStr = domain.StringConfigOption{Value: in.PeerDefAllowedIPsStr, Overridable: true}
peer.Interface.Addresses = p.AllowedIPs // use allowed IP's as the peer IP's
peer.Interface.DnsStr = domain.StringConfigOption{Value: in.PeerDefDnsStr, Overridable: true}
peer.Interface.DnsSearchStr = domain.StringConfigOption{Value: in.PeerDefDnsSearchStr, Overridable: true}
peer.Interface.Mtu = domain.IntConfigOption{Value: in.PeerDefMtu, Overridable: true}
peer.Interface.FirewallMark = domain.Int32ConfigOption{Value: in.PeerDefFirewallMark, Overridable: true}
peer.Interface.RoutingTable = domain.StringConfigOption{Value: in.PeerDefRoutingTable, Overridable: true}
peer.Interface.PreUp = domain.StringConfigOption{Value: in.PeerDefPreUp, Overridable: true}
peer.Interface.PostUp = domain.StringConfigOption{Value: in.PeerDefPostUp, Overridable: true}
peer.Interface.PreDown = domain.StringConfigOption{Value: in.PeerDefPreDown, Overridable: true}
peer.Interface.PostDown = domain.StringConfigOption{Value: in.PeerDefPostDown, Overridable: true}
switch in.Type {
case domain.InterfaceTypeAny:
peer.Interface.Type = domain.InterfaceTypeAny
peer.DisplayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
case domain.InterfaceTypeClient:
peer.Interface.Type = domain.InterfaceTypeServer
peer.DisplayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
case domain.InterfaceTypeServer:
peer.Interface.Type = domain.InterfaceTypeClient
peer.DisplayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
}
err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) {
return peer, nil
})
if err != nil {
return fmt.Errorf("database save failed: %w", err)
}
return nil
}
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
allPeers, err := m.db.GetInterfacePeers(ctx, id)
if err != nil {
return err
}
for _, peer := range allPeers {
err = m.wg.DeletePeer(ctx, id, peer.Identifier)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
}
err = m.db.DeletePeer(ctx, peer.Identifier)
if err != nil {
return fmt.Errorf("peer deletion failure for %s: %w", peer.Identifier, err)
}
}
return nil
}
func (m Manager) validateInterfaceModifications(ctx context.Context, old, new *domain.Interface) error {
currentUser := domain.GetUserInfo(ctx)
if !currentUser.IsAdmin {
return fmt.Errorf("insufficient permissions")
}
return nil
}
func (m Manager) validateInterfaceCreation(ctx context.Context, old, new *domain.Interface) error {
currentUser := domain.GetUserInfo(ctx)
if new.Identifier == "" {
return fmt.Errorf("invalid interface identifier")
}
if !currentUser.IsAdmin {
return fmt.Errorf("insufficient permissions")
}
return nil
}
func (m Manager) validateInterfaceDeletion(ctx context.Context, del *domain.Interface) error {
currentUser := domain.GetUserInfo(ctx)
if !currentUser.IsAdmin {
return fmt.Errorf("insufficient permissions")
}
return nil
}
// endregion helper-functions

View File

@@ -0,0 +1,369 @@
package wireguard
import (
"context"
"errors"
"fmt"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/domain"
"github.com/sirupsen/logrus"
"time"
)
func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error {
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return fmt.Errorf("failed to fetch all interfaces: %w", err)
}
var newPeers []domain.Peer
for _, iface := range existingInterfaces {
if iface.Type != domain.InterfaceTypeServer {
continue // only create default peers for server interfaces
}
peer, err := m.PreparePeer(ctx, iface.Identifier)
if err != nil {
return fmt.Errorf("failed to create default peer for interface %s: %w", iface.Identifier, err)
}
peer.UserIdentifier = user.Identifier
peer.DisplayName = fmt.Sprintf("Default Peer %s", internal.TruncateString(string(peer.Identifier), 8))
peer.Notes = fmt.Sprintf("Default peer created for user %s", user.Identifier)
newPeers = append(newPeers, *peer)
}
for i, peer := range newPeers {
_, err := m.CreatePeer(ctx, &newPeers[i])
if err != nil {
return fmt.Errorf("failed to create default peer %s on interface %s: %w",
peer.Identifier, peer.InterfaceIdentifier, err)
}
}
logrus.Infof("created %d default peers for user %s", len(newPeers), user.Identifier)
return nil
}
func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
return m.db.GetUserPeers(ctx, id)
}
func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
currentUser := domain.GetUserInfo(ctx)
iface, err := m.db.GetInterface(ctx, id)
if err != nil {
return nil, fmt.Errorf("unable to find interface %s: %w", id, err)
}
ips, err := m.getFreshPeerIpConfig(ctx, iface)
if err != nil {
return nil, fmt.Errorf("unable to get fresh ip addresses: %w", err)
}
kp, err := domain.NewFreshKeypair()
if err != nil {
return nil, fmt.Errorf("failed to generate keys: %w", err)
}
pk, err := domain.NewPreSharedKey()
if err != nil {
return nil, fmt.Errorf("failed to generate preshared key: %w", err)
}
peerMode := domain.InterfaceTypeClient
if iface.Type == domain.InterfaceTypeClient {
peerMode = domain.InterfaceTypeServer
}
peerId := domain.PeerIdentifier(kp.PublicKey)
freshPeer := &domain.Peer{
BaseModel: domain.BaseModel{
CreatedBy: string(currentUser.Id),
UpdatedBy: string(currentUser.Id),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
},
Endpoint: domain.NewStringConfigOption(iface.PeerDefEndpoint, true),
EndpointPublicKey: domain.NewStringConfigOption(iface.PublicKey, true),
AllowedIPsStr: domain.NewStringConfigOption(iface.PeerDefAllowedIPsStr, true),
ExtraAllowedIPsStr: "",
PresharedKey: pk,
PersistentKeepalive: domain.NewIntConfigOption(iface.PeerDefPersistentKeepalive, true),
DisplayName: fmt.Sprintf("Peer %s", internal.TruncateString(string(peerId), 8)),
Identifier: peerId,
UserIdentifier: currentUser.Id,
InterfaceIdentifier: iface.Identifier,
Disabled: nil,
DisabledReason: "",
ExpiresAt: nil,
Notes: "",
Interface: domain.PeerInterfaceConfig{
KeyPair: kp,
Type: peerMode,
Addresses: ips,
CheckAliveAddress: "",
DnsStr: domain.NewStringConfigOption(iface.PeerDefDnsStr, true),
DnsSearchStr: domain.NewStringConfigOption(iface.PeerDefDnsSearchStr, true),
Mtu: domain.NewIntConfigOption(iface.PeerDefMtu, true),
FirewallMark: domain.NewInt32ConfigOption(iface.PeerDefFirewallMark, true),
RoutingTable: domain.NewStringConfigOption(iface.PeerDefRoutingTable, true),
PreUp: domain.NewStringConfigOption(iface.PeerDefPreUp, true),
PostUp: domain.NewStringConfigOption(iface.PeerDefPostUp, true),
PreDown: domain.NewStringConfigOption(iface.PeerDefPreUp, true),
PostDown: domain.NewStringConfigOption(iface.PeerDefPostUp, true),
},
}
return freshPeer, nil
}
func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
peer, err := m.db.GetPeer(ctx, id)
if err != nil {
return nil, fmt.Errorf("unable to find peer %s: %w", id, err)
}
return peer, nil
}
func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err)
}
if existingPeer != nil {
return nil, fmt.Errorf("peer %s already exists", peer.Identifier)
}
if err := m.validatePeerCreation(ctx, existingPeer, peer); err != nil {
return nil, fmt.Errorf("creation not allowed: %w", err)
}
err = m.savePeers(ctx, peer)
if err != nil {
return nil, fmt.Errorf("creation failure: %w", err)
}
return peer, nil
}
func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) {
var newPeers []*domain.Peer
for _, id := range r.UserIdentifiers {
freshPeer, err := m.PreparePeer(ctx, interfaceId)
if err != nil {
return nil, fmt.Errorf("failed to prepare peer for interface %s: %w", interfaceId, err)
}
freshPeer.UserIdentifier = domain.UserIdentifier(id) // use id as user identifier. peers are allowed to have invalid user identifiers
if r.Suffix != "" {
freshPeer.DisplayName += " " + r.Suffix
}
if err := m.validatePeerCreation(ctx, nil, freshPeer); err != nil {
return nil, fmt.Errorf("creation not allowed: %w", err)
}
newPeers = append(newPeers, freshPeer)
}
err := m.savePeers(ctx, newPeers...)
if err != nil {
return nil, fmt.Errorf("failed to create new peers: %w", err)
}
createdPeers := make([]domain.Peer, len(newPeers))
for i := range newPeers {
createdPeers[i] = *newPeers[i]
}
return createdPeers, nil
}
func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
if err != nil {
return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err)
}
if err := m.validatePeerModifications(ctx, existingPeer, peer); err != nil {
return nil, fmt.Errorf("update not allowed: %w", err)
}
err = m.savePeers(ctx, peer)
if err != nil {
return nil, fmt.Errorf("update failure: %w", err)
}
return peer, nil
}
func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
peer, err := m.db.GetPeer(ctx, id)
if err != nil {
return fmt.Errorf("unable to find peer %s: %w", id, err)
}
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
if err != nil {
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
}
err = m.db.DeletePeer(ctx, id)
if err != nil {
return fmt.Errorf("failed to delete peer %s: %w", id, err)
}
return nil
}
func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) {
_, peers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to fetch peers for interface %s: %w", id, err)
}
peerIds := make([]domain.PeerIdentifier, len(peers))
for i, peer := range peers {
peerIds[i] = peer.Identifier
}
return m.db.GetPeersStats(ctx, peerIds...)
}
func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
peers, err := m.db.GetUserPeers(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to fetch peers for user %s: %w", id, err)
}
peerIds := make([]domain.PeerIdentifier, len(peers))
for i, peer := range peers {
peerIds[i] = peer.Identifier
}
return m.db.GetPeersStats(ctx, peerIds...)
}
// region helper-functions
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
interfaces := make(map[domain.InterfaceIdentifier]struct{})
for i := range peers {
peer := peers[i]
err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
if err != nil {
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
}
interfaces[peer.InterfaceIdentifier] = struct{}{}
}
// Update routes after peers have changed
if len(interfaces) != 0 {
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
}
for iface := range interfaces {
m.bus.Publish(app.TopicPeerInterfaceUpdated, iface)
}
return nil
}
func (m Manager) getFreshPeerIpConfig(ctx context.Context, iface *domain.Interface) (ips []domain.Cidr, err error) {
networks, err := domain.CidrsFromString(iface.PeerDefNetworkStr)
if err != nil {
err = fmt.Errorf("failed to parse default network address: %w", err)
return
}
existingIps, err := m.db.GetUsedIpsPerSubnet(ctx)
if err != nil {
err = fmt.Errorf("failed to get existing IP addresses: %w", err)
return
}
for _, network := range networks {
ip := network.NextAddr()
for {
ipConflict := false
for _, usedIp := range existingIps[network] {
if usedIp == ip {
ipConflict = true
}
}
if !ipConflict {
break
}
ip = ip.NextAddr()
if !ip.IsValid() {
return nil, fmt.Errorf("ip space on subnet %s is exhausted", network.String())
}
}
ips = append(ips, ip.HostAddr())
}
return
}
func (m Manager) validatePeerModifications(ctx context.Context, old, new *domain.Peer) error {
currentUser := domain.GetUserInfo(ctx)
if !currentUser.IsAdmin {
return fmt.Errorf("insufficient permissions")
}
return nil
}
func (m Manager) validatePeerCreation(ctx context.Context, old, new *domain.Peer) error {
currentUser := domain.GetUserInfo(ctx)
if new.Identifier == "" {
return fmt.Errorf("invalid peer identifier")
}
if !currentUser.IsAdmin {
return fmt.Errorf("insufficient permissions")
}
return nil
}
func (m Manager) validatePeerDeletion(ctx context.Context, del *domain.Peer) error {
currentUser := domain.GetUserInfo(ctx)
if !currentUser.IsAdmin {
return fmt.Errorf("insufficient permissions")
}
return nil
}
// endregion helper-functions