routes and address fixes

This commit is contained in:
Christoph Haas 2023-07-30 12:11:00 +02:00
parent 5aa94999ab
commit f4e5072f97
9 changed files with 249 additions and 162 deletions

View File

@ -43,6 +43,8 @@ func main() {
wireGuard := adapters.NewWireGuardRepository() wireGuard := adapters.NewWireGuardRepository()
wgQuick := adapters.NewWgQuickRepo()
mailer := adapters.NewSmtpMailRepo(cfg.Mail) mailer := adapters.NewSmtpMailRepo(cfg.Mail)
cfgFileSystem, err := adapters.NewFileSystemRepository(cfg.Advanced.ConfigStoragePath) cfgFileSystem, err := adapters.NewFileSystemRepository(cfg.Advanced.ConfigStoragePath)
@ -68,7 +70,7 @@ func main() {
authenticator, err := auth.NewAuthenticator(&cfg.Auth, eventBus, userManager) authenticator, err := auth.NewAuthenticator(&cfg.Auth, eventBus, userManager)
internal.AssertNoError(err) internal.AssertNoError(err)
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, database) wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database)
internal.AssertNoError(err) internal.AssertNoError(err)
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, database, wireGuard) statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, database, wireGuard)

View File

@ -321,6 +321,11 @@ func (r *SqlRepo) upsertInterface(ui *domain.ContextUserInfo, tx *gorm.DB, in *d
return err return err
} }
err = tx.Model(in).Association("Addresses").Replace(in.Addresses)
if err != nil {
return fmt.Errorf("failed to update interface addresses: %w", err)
}
return nil return nil
} }
@ -503,6 +508,11 @@ func (r *SqlRepo) upsertPeer(ui *domain.ContextUserInfo, tx *gorm.DB, peer *doma
return err return err
} }
err = tx.Model(peer).Association("Addresses").Replace(peer.Interface.Addresses)
if err != nil {
return fmt.Errorf("failed to update peer addresses: %w", err)
}
return nil return nil
} }

View File

@ -16,9 +16,8 @@ import (
// WgRepo implements all low-level WireGuard interactions. // WgRepo implements all low-level WireGuard interactions.
type WgRepo struct { type WgRepo struct {
wg lowlevel.WireGuardClient wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient nl lowlevel.NetlinkClient
quick *WgQuickRepo
} }
func NewWireGuardRepository() *WgRepo { func NewWireGuardRepository() *WgRepo {
@ -30,9 +29,8 @@ func NewWireGuardRepository() *WgRepo {
nl := &lowlevel.NetlinkManager{} nl := &lowlevel.NetlinkManager{}
repo := &WgRepo{ repo := &WgRepo{
wg: wg, wg: wg,
nl: nl, nl: nl,
quick: NewWgQuickRepo(),
} }
return repo return repo
@ -155,40 +153,18 @@ func (r *WgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer,
return peerModel, nil return peerModel, nil
} }
func (r *WgRepo) SaveInterface(_ context.Context, iface *domain.Interface, peers []domain.Peer, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error { func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
physicalInterface, err := r.getOrCreateInterface(iface.Identifier) physicalInterface, err := r.getOrCreateInterface(id)
if err != nil { if err != nil {
return err return err
} }
wasUp := physicalInterface.DeviceUp
if updateFunc != nil { if updateFunc != nil {
physicalInterface, err = updateFunc(physicalInterface) physicalInterface, err = updateFunc(physicalInterface)
if err != nil { if err != nil {
return err return err
} }
} }
stateChanged := wasUp != physicalInterface.DeviceUp
if stateChanged {
if physicalInterface.DeviceUp {
if err := r.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
return fmt.Errorf("failed to execute pre-up hook: %w", err)
}
} else {
if err := r.quick.UnsetDNS(iface.Identifier); err != nil {
return fmt.Errorf("failed to clear dns settings: %w", err)
}
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
return fmt.Errorf("failed to execute pre-down hook: %w", err)
}
}
}
if err := r.updateLowLevelInterface(physicalInterface); err != nil { if err := r.updateLowLevelInterface(physicalInterface); err != nil {
return err return err
@ -196,21 +172,6 @@ func (r *WgRepo) SaveInterface(_ context.Context, iface *domain.Interface, peers
if err := r.updateWireGuardInterface(physicalInterface); err != nil { if err := r.updateWireGuardInterface(physicalInterface); err != nil {
return err return err
} }
if err := r.updateRoutes(iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers)); err != nil {
return err
}
if stateChanged {
if physicalInterface.DeviceUp {
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
return fmt.Errorf("failed to execute post-up hook: %w", err)
}
} else {
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
return fmt.Errorf("failed to execute post-down hook: %w", err)
}
}
}
return nil return nil
} }
@ -338,7 +299,7 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
return nil return nil
} }
func (r *WgRepo) updateRoutes(interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error { func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error {
if table == -1 { if table == -1 {
logrus.Trace("ignoring route update") logrus.Trace("ignoring route update")
return nil return nil
@ -355,19 +316,16 @@ func (r *WgRepo) updateRoutes(interfaceId domain.InterfaceIdentifier, table int,
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash) // try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
for _, allowedIP := range allowedIPs { for _, allowedIP := range allowedIPs {
if allowedIP.Prefix().Bits() == 0 { // default route // if allowedIP.Prefix().Bits() == 0 { // default route handling - TODO
// TODO err := r.nl.RouteReplace(&netlink.Route{
} else { LinkIndex: link.Attrs().Index,
err := r.nl.RouteReplace(&netlink.Route{ Dst: allowedIP.IpNet(),
LinkIndex: link.Attrs().Index, Table: table,
Dst: allowedIP.IpNet(), Scope: unix.RT_SCOPE_LINK,
Table: table, Type: unix.RTN_UNICAST,
Scope: unix.RT_SCOPE_LINK, })
Type: unix.RTN_UNICAST, if err != nil {
}) return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
if err != nil {
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
}
} }
} }

View File

@ -37,8 +37,15 @@ type InterfaceController interface {
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error)
SaveInterface(_ context.Context, iface *domain.Interface, peers []domain.Peer, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
SaveRoutes(_ context.Context, deviceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error
}
type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(id domain.InterfaceIdentifier) error
} }

View File

@ -16,16 +16,18 @@ type Manager struct {
cfg *config.Config cfg *config.Config
bus evbus.MessageBus bus evbus.MessageBus
db InterfaceAndPeerDatabaseRepo db InterfaceAndPeerDatabaseRepo
wg InterfaceController wg InterfaceController
quick WgQuickController
} }
func NewWireGuardManager(cfg *config.Config, bus evbus.MessageBus, wg InterfaceController, db InterfaceAndPeerDatabaseRepo) (*Manager, error) { func NewWireGuardManager(cfg *config.Config, bus evbus.MessageBus, wg InterfaceController, quick WgQuickController, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
m := &Manager{ m := &Manager{
cfg: cfg, cfg: cfg,
bus: bus, bus: bus,
wg: wg, wg: wg,
db: db, db: db,
quick: quick,
} }
m.connectToMessageBus() m.connectToMessageBus()

View File

@ -127,7 +127,7 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
for _, iface := range interfaces { for _, iface := range interfaces {
if len(filter) != 0 && !internal.SliceContains(filter, iface.Identifier) { if len(filter) != 0 && !internal.SliceContains(filter, iface.Identifier) {
continue continue // ignore filtered interface
} }
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
@ -140,18 +140,17 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
logrus.Debugf("creating missing interface %s...", iface.Identifier) logrus.Debugf("creating missing interface %s...", iface.Identifier)
// try to create a new interface // try to create a new interface
err := m.wg.SaveInterface(ctx, &iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { _, err = m.saveInterface(ctx, &iface, peers)
domain.MergeToPhysicalInterface(pi, &iface) if err != nil {
return err
return pi, nil }
})
if err != nil { if err != nil {
if updateDbOnError { if updateDbOnError {
// disable interface in database as no physical interface exists // disable interface in database as no physical interface exists
_ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) { _ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) {
now := time.Now() now := time.Now()
in.Disabled = &now // set in.Disabled = &now // set
in.DisabledReason = "no physical interface available" in.DisabledReason = domain.DisabledReasonInterfaceMissing
return in, nil return in, nil
}) })
} }
@ -172,11 +171,10 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled()) logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
// try to move interface to stored state // try to move interface to stored state
err := m.wg.SaveInterface(ctx, &iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { _, err = m.saveInterface(ctx, &iface, peers)
pi.DeviceUp = !iface.IsDisabled() if err != nil {
return err
return pi, nil }
})
if err != nil { if err != nil {
if updateDbOnError { if updateDbOnError {
// disable interface in database as no physical interface is available // disable interface in database as no physical interface is available
@ -184,7 +182,7 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
if iface.IsDisabled() { if iface.IsDisabled() {
now := time.Now() now := time.Now()
in.Disabled = &now // set in.Disabled = &now // set
in.DisabledReason = "no physical interface active" in.DisabledReason = domain.DisabledReasonInterfaceMissing
} else { } else {
in.Disabled = nil in.Disabled = nil
in.DisabledReason = "" in.DisabledReason = ""
@ -289,19 +287,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
return nil, fmt.Errorf("creation not allowed: %w", err) return nil, fmt.Errorf("creation not allowed: %w", err)
} }
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) { in, err = m.saveInterface(ctx, in, nil)
in.CopyCalculatedAttributes(i)
err = m.wg.SaveInterface(ctx, in, nil, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
domain.MergeToPhysicalInterface(pi, in)
return pi, nil
})
if err != nil {
return nil, fmt.Errorf("failed to create physical interface %s: %w", in.Identifier, err)
}
return in, nil
})
if err != nil { if err != nil {
return nil, fmt.Errorf("creation failure: %w", err) return nil, fmt.Errorf("creation failure: %w", err)
} }
@ -319,19 +305,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
return nil, nil, fmt.Errorf("update not allowed: %w", err) return nil, nil, fmt.Errorf("update not allowed: %w", err)
} }
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) { in, err = m.saveInterface(ctx, in, existingPeers)
in.CopyCalculatedAttributes(i)
err = m.wg.SaveInterface(ctx, in, existingPeers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
domain.MergeToPhysicalInterface(pi, in)
return pi, nil
})
if err != nil {
return nil, fmt.Errorf("failed to update physical interface %s: %w", in.Identifier, err)
}
return in, nil
})
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("update failure: %w", err) return nil, nil, fmt.Errorf("update failure: %w", err)
} }
@ -349,26 +323,135 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion not allowed: %w", err) return fmt.Errorf("deletion not allowed: %w", err)
} }
err = m.deleteInterfacePeers(ctx, id) now := time.Now()
if err != nil { existingInterface.Disabled = &now // simulate a disabled interface
existingInterface.DisabledReason = domain.DisabledReasonDeleted
if err := m.handleInterfacePreSaveHooks(true, existingInterface); err != nil {
return fmt.Errorf("pre-delete hooks failed: %w", err)
}
if err := m.handleInterfacePreSaveActions(existingInterface); err != nil {
return fmt.Errorf("pre-delete actions failed: %w", err)
}
if err := m.deleteInterfacePeers(ctx, id); err != nil {
return fmt.Errorf("peer deletion failure: %w", err) return fmt.Errorf("peer deletion failure: %w", err)
} }
err = m.wg.DeleteInterface(ctx, id) if err := m.wg.DeleteInterface(ctx, id); err != nil {
if err != nil {
return fmt.Errorf("wireguard deletion failure: %w", err) return fmt.Errorf("wireguard deletion failure: %w", err)
} }
err = m.db.DeleteInterface(ctx, id) if err := m.db.DeleteInterface(ctx, id); err != nil {
if err != nil {
return fmt.Errorf("deletion failure: %w", err) return fmt.Errorf("deletion failure: %w", err)
} }
if err := m.handleInterfacePostSaveHooks(true, existingInterface); err != nil {
return fmt.Errorf("post-delete hooks failed: %w", err)
}
return nil return nil
} }
// region helper-functions // region helper-functions
func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, peers []domain.Peer) (*domain.Interface, error) {
stateChanged := m.hasInterfaceStateChanged(ctx, iface)
if err := m.handleInterfacePreSaveHooks(stateChanged, iface); err != nil {
return nil, fmt.Errorf("pre-save hooks failed: %w", err)
}
if err := m.handleInterfacePreSaveActions(iface); err != nil {
return nil, fmt.Errorf("pre-save actions failed: %w", err)
}
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
iface.CopyCalculatedAttributes(i)
err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
domain.MergeToPhysicalInterface(pi, iface)
return pi, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save physical interface %s: %w", iface.Identifier, err)
}
return iface, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save interface: %w", err)
}
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers))
if err != nil {
return nil, fmt.Errorf("failed to save routes: %w", err)
}
if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil {
return nil, fmt.Errorf("post-save hooks failed: %w", err)
}
return iface, nil
}
func (m Manager) hasInterfaceStateChanged(ctx context.Context, iface *domain.Interface) bool {
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err != nil {
return false
}
return oldInterface.IsDisabled() != iface.IsDisabled()
}
func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error {
if !iface.IsDisabled() {
if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
} else {
if err := m.quick.UnsetDNS(iface.Identifier); err != nil {
return fmt.Errorf("failed to clear dns settings: %w", err)
}
}
return nil
}
func (m Manager) handleInterfacePreSaveHooks(stateChanged bool, iface *domain.Interface) error {
if !stateChanged {
return nil // do nothing if state did not change
}
if !iface.IsDisabled() {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
return fmt.Errorf("failed to execute pre-up hook: %w", err)
}
} else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
return fmt.Errorf("failed to execute pre-down hook: %w", err)
}
}
return nil
}
func (m Manager) handleInterfacePostSaveHooks(stateChanged bool, iface *domain.Interface) error {
if !stateChanged {
return nil // do nothing if state did not change
}
if !iface.IsDisabled() {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
return fmt.Errorf("failed to execute post-up hook: %w", err)
}
} else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
return fmt.Errorf("failed to execute post-down hook: %w", err)
}
}
return nil
}
func (m Manager) getNewInterfaceName(ctx context.Context) (domain.InterfaceIdentifier, error) { func (m Manager) getNewInterfaceName(ctx context.Context) (domain.InterfaceIdentifier, error) {
namePrefix := "wg" namePrefix := "wg"
nameSuffix := 0 nameSuffix := 0

View File

@ -143,20 +143,7 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
return nil, fmt.Errorf("creation not allowed: %w", err) return nil, fmt.Errorf("creation not allowed: %w", err)
} }
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { err = m.savePeers(ctx, peer)
peer.CopyCalculatedAttributes(p)
err = m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to create wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
if err != nil { if err != nil {
return nil, fmt.Errorf("creation failure: %w", err) return nil, fmt.Errorf("creation failure: %w", err)
} }
@ -165,7 +152,7 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
} }
func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) { func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) {
var newPeers []domain.Peer var newPeers []*domain.Peer
for _, id := range r.UserIdentifiers { for _, id := range r.UserIdentifiers {
freshPeer, err := m.PreparePeer(ctx, interfaceId) freshPeer, err := m.PreparePeer(ctx, interfaceId)
@ -178,17 +165,24 @@ func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.Int
freshPeer.DisplayName += " " + r.Suffix freshPeer.DisplayName += " " + r.Suffix
} }
newPeers = append(newPeers, *freshPeer) if err := m.validatePeerCreation(ctx, nil, freshPeer); err != nil {
} return nil, fmt.Errorf("creation not allowed: %w", err)
for i, peer := range newPeers {
_, err := m.CreatePeer(ctx, &newPeers[i])
if err != nil {
return nil, fmt.Errorf("failed to create peer %s (uid: %s) for interface %s: %w", peer.Identifier, peer.UserIdentifier, interfaceId, err)
} }
newPeers = append(newPeers, freshPeer)
} }
return newPeers, nil err := m.savePeers(ctx, newPeers...)
if err != nil {
return nil, fmt.Errorf("failed to create new peers: %w", err)
}
createdPeers := make([]domain.Peer, len(newPeers))
for i := range newPeers {
createdPeers[i] = *newPeers[i]
}
return createdPeers, nil
} }
func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) { func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
@ -201,20 +195,7 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
return nil, fmt.Errorf("update not allowed: %w", err) return nil, fmt.Errorf("update not allowed: %w", err)
} }
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { err = m.savePeers(ctx, peer)
peer.CopyCalculatedAttributes(p)
err = m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to update wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
if err != nil { if err != nil {
return nil, fmt.Errorf("update failure: %w", err) return nil, fmt.Errorf("update failure: %w", err)
} }
@ -271,6 +252,47 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
// region helper-functions // region helper-functions
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
interfaces := make(map[domain.InterfaceIdentifier]struct{})
for i := range peers {
peer := peers[i]
err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
if err != nil {
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
}
interfaces[peer.InterfaceIdentifier] = struct{}{}
}
// Update routes after peers have changed
for ifaceId := range interfaces {
iface, ifacePeers, err := m.db.GetInterfaceAndPeers(ctx, ifaceId)
if err != nil {
return fmt.Errorf("failed to load peer interface %s: %w", ifaceId, err)
}
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(ifacePeers))
if err != nil {
return fmt.Errorf("failed to update peer routes on interface %s: %w", ifaceId, err)
}
}
return nil
}
func (m Manager) getFreshPeerIpConfig(ctx context.Context, iface *domain.Interface) (ips []domain.Cidr, err error) { func (m Manager) getFreshPeerIpConfig(ctx context.Context, iface *domain.Interface) (ips []domain.Cidr, err error) {
networks, err := domain.CidrsFromString(iface.PeerDefNetworkStr) networks, err := domain.CidrsFromString(iface.PeerDefNetworkStr)
if err != nil { if err != nil {

View File

@ -22,14 +22,16 @@ func (PrivateString) String() string {
} }
const ( const (
DisabledReasonExpired = "expired" DisabledReasonExpired = "expired"
DisabledReasonUserEdit = "user edit action" DisabledReasonDeleted = "deleted"
DisabledReasonUserCreate = "user create action" DisabledReasonUserEdit = "user edit action"
DisabledReasonAdminEdit = "admin edit action" DisabledReasonUserCreate = "user create action"
DisabledReasonAdminCreate = "admin create action" DisabledReasonAdminEdit = "admin edit action"
DisabledReasonApiEdit = "api edit action" DisabledReasonAdminCreate = "admin create action"
DisabledReasonApiCreate = "api create action" DisabledReasonApiEdit = "api edit action"
DisabledReasonLdapMissing = "missing in ldap" DisabledReasonApiCreate = "api create action"
DisabledReasonUserMissing = "missing user" DisabledReasonLdapMissing = "missing in ldap"
DisabledReasonMigrationDummy = "migration dummy user" DisabledReasonUserMissing = "missing user"
DisabledReasonMigrationDummy = "migration dummy user"
DisabledReasonInterfaceMissing = "missing WireGuard interface"
) )

View File

@ -84,12 +84,13 @@ func CidrFromIpNet(ipNet net.IPNet) Cidr {
} }
func CidrFromNetlinkAddr(addr netlink.Addr) Cidr { func CidrFromNetlinkAddr(addr netlink.Addr) Cidr {
prefix, _ := CidrFromString(addr.String()) prefix, _ := CidrFromString(addr.IPNet.String())
return prefix return prefix
} }
func (c Cidr) IpNet() *net.IPNet { func (c Cidr) IpNet() *net.IPNet {
_, cidr, _ := net.ParseCIDR(c.String()) ip, cidr, _ := net.ParseCIDR(c.String())
cidr.IP = ip
return cidr return cidr
} }