mirror of
https://github.com/h44z/wg-portal.git
synced 2025-08-12 16:22:23 +00:00
routes and address fixes
This commit is contained in:
parent
5aa94999ab
commit
f4e5072f97
@ -43,6 +43,8 @@ func main() {
|
|||||||
|
|
||||||
wireGuard := adapters.NewWireGuardRepository()
|
wireGuard := adapters.NewWireGuardRepository()
|
||||||
|
|
||||||
|
wgQuick := adapters.NewWgQuickRepo()
|
||||||
|
|
||||||
mailer := adapters.NewSmtpMailRepo(cfg.Mail)
|
mailer := adapters.NewSmtpMailRepo(cfg.Mail)
|
||||||
|
|
||||||
cfgFileSystem, err := adapters.NewFileSystemRepository(cfg.Advanced.ConfigStoragePath)
|
cfgFileSystem, err := adapters.NewFileSystemRepository(cfg.Advanced.ConfigStoragePath)
|
||||||
@ -68,7 +70,7 @@ func main() {
|
|||||||
authenticator, err := auth.NewAuthenticator(&cfg.Auth, eventBus, userManager)
|
authenticator, err := auth.NewAuthenticator(&cfg.Auth, eventBus, userManager)
|
||||||
internal.AssertNoError(err)
|
internal.AssertNoError(err)
|
||||||
|
|
||||||
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, database)
|
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database)
|
||||||
internal.AssertNoError(err)
|
internal.AssertNoError(err)
|
||||||
|
|
||||||
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, database, wireGuard)
|
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, database, wireGuard)
|
||||||
|
@ -321,6 +321,11 @@ func (r *SqlRepo) upsertInterface(ui *domain.ContextUserInfo, tx *gorm.DB, in *d
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = tx.Model(in).Association("Addresses").Replace(in.Addresses)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update interface addresses: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -503,6 +508,11 @@ func (r *SqlRepo) upsertPeer(ui *domain.ContextUserInfo, tx *gorm.DB, peer *doma
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = tx.Model(peer).Association("Addresses").Replace(peer.Interface.Addresses)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update peer addresses: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,9 +16,8 @@ import (
|
|||||||
|
|
||||||
// WgRepo implements all low-level WireGuard interactions.
|
// WgRepo implements all low-level WireGuard interactions.
|
||||||
type WgRepo struct {
|
type WgRepo struct {
|
||||||
wg lowlevel.WireGuardClient
|
wg lowlevel.WireGuardClient
|
||||||
nl lowlevel.NetlinkClient
|
nl lowlevel.NetlinkClient
|
||||||
quick *WgQuickRepo
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWireGuardRepository() *WgRepo {
|
func NewWireGuardRepository() *WgRepo {
|
||||||
@ -30,9 +29,8 @@ func NewWireGuardRepository() *WgRepo {
|
|||||||
nl := &lowlevel.NetlinkManager{}
|
nl := &lowlevel.NetlinkManager{}
|
||||||
|
|
||||||
repo := &WgRepo{
|
repo := &WgRepo{
|
||||||
wg: wg,
|
wg: wg,
|
||||||
nl: nl,
|
nl: nl,
|
||||||
quick: NewWgQuickRepo(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return repo
|
return repo
|
||||||
@ -155,40 +153,18 @@ func (r *WgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer,
|
|||||||
return peerModel, nil
|
return peerModel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *WgRepo) SaveInterface(_ context.Context, iface *domain.Interface, peers []domain.Peer, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
|
func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
|
||||||
physicalInterface, err := r.getOrCreateInterface(iface.Identifier)
|
physicalInterface, err := r.getOrCreateInterface(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
wasUp := physicalInterface.DeviceUp
|
|
||||||
if updateFunc != nil {
|
if updateFunc != nil {
|
||||||
physicalInterface, err = updateFunc(physicalInterface)
|
physicalInterface, err = updateFunc(physicalInterface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stateChanged := wasUp != physicalInterface.DeviceUp
|
|
||||||
|
|
||||||
if stateChanged {
|
|
||||||
if physicalInterface.DeviceUp {
|
|
||||||
if err := r.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
|
|
||||||
return fmt.Errorf("failed to update dns settings: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
|
|
||||||
return fmt.Errorf("failed to execute pre-up hook: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := r.quick.UnsetDNS(iface.Identifier); err != nil {
|
|
||||||
return fmt.Errorf("failed to clear dns settings: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
|
|
||||||
return fmt.Errorf("failed to execute pre-down hook: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.updateLowLevelInterface(physicalInterface); err != nil {
|
if err := r.updateLowLevelInterface(physicalInterface); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -196,21 +172,6 @@ func (r *WgRepo) SaveInterface(_ context.Context, iface *domain.Interface, peers
|
|||||||
if err := r.updateWireGuardInterface(physicalInterface); err != nil {
|
if err := r.updateWireGuardInterface(physicalInterface); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := r.updateRoutes(iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if stateChanged {
|
|
||||||
if physicalInterface.DeviceUp {
|
|
||||||
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
|
|
||||||
return fmt.Errorf("failed to execute post-up hook: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
|
|
||||||
return fmt.Errorf("failed to execute post-down hook: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -338,7 +299,7 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *WgRepo) updateRoutes(interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error {
|
func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error {
|
||||||
if table == -1 {
|
if table == -1 {
|
||||||
logrus.Trace("ignoring route update")
|
logrus.Trace("ignoring route update")
|
||||||
return nil
|
return nil
|
||||||
@ -355,19 +316,16 @@ func (r *WgRepo) updateRoutes(interfaceId domain.InterfaceIdentifier, table int,
|
|||||||
|
|
||||||
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
|
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
|
||||||
for _, allowedIP := range allowedIPs {
|
for _, allowedIP := range allowedIPs {
|
||||||
if allowedIP.Prefix().Bits() == 0 { // default route
|
// if allowedIP.Prefix().Bits() == 0 { // default route handling - TODO
|
||||||
// TODO
|
err := r.nl.RouteReplace(&netlink.Route{
|
||||||
} else {
|
LinkIndex: link.Attrs().Index,
|
||||||
err := r.nl.RouteReplace(&netlink.Route{
|
Dst: allowedIP.IpNet(),
|
||||||
LinkIndex: link.Attrs().Index,
|
Table: table,
|
||||||
Dst: allowedIP.IpNet(),
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
Table: table,
|
Type: unix.RTN_UNICAST,
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
})
|
||||||
Type: unix.RTN_UNICAST,
|
if err != nil {
|
||||||
})
|
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,8 +37,15 @@ type InterfaceController interface {
|
|||||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||||
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error)
|
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error)
|
||||||
SaveInterface(_ context.Context, iface *domain.Interface, peers []domain.Peer, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
|
SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
|
||||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||||
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
|
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
|
||||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||||
|
SaveRoutes(_ context.Context, deviceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type WgQuickController interface {
|
||||||
|
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
|
||||||
|
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||||
|
UnsetDNS(id domain.InterfaceIdentifier) error
|
||||||
}
|
}
|
||||||
|
@ -16,16 +16,18 @@ type Manager struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
bus evbus.MessageBus
|
bus evbus.MessageBus
|
||||||
|
|
||||||
db InterfaceAndPeerDatabaseRepo
|
db InterfaceAndPeerDatabaseRepo
|
||||||
wg InterfaceController
|
wg InterfaceController
|
||||||
|
quick WgQuickController
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWireGuardManager(cfg *config.Config, bus evbus.MessageBus, wg InterfaceController, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
|
func NewWireGuardManager(cfg *config.Config, bus evbus.MessageBus, wg InterfaceController, quick WgQuickController, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
bus: bus,
|
bus: bus,
|
||||||
wg: wg,
|
wg: wg,
|
||||||
db: db,
|
db: db,
|
||||||
|
quick: quick,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.connectToMessageBus()
|
m.connectToMessageBus()
|
||||||
|
@ -127,7 +127,7 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
|
|||||||
|
|
||||||
for _, iface := range interfaces {
|
for _, iface := range interfaces {
|
||||||
if len(filter) != 0 && !internal.SliceContains(filter, iface.Identifier) {
|
if len(filter) != 0 && !internal.SliceContains(filter, iface.Identifier) {
|
||||||
continue
|
continue // ignore filtered interface
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
|
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
|
||||||
@ -140,18 +140,17 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
|
|||||||
logrus.Debugf("creating missing interface %s...", iface.Identifier)
|
logrus.Debugf("creating missing interface %s...", iface.Identifier)
|
||||||
|
|
||||||
// try to create a new interface
|
// try to create a new interface
|
||||||
err := m.wg.SaveInterface(ctx, &iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
_, err = m.saveInterface(ctx, &iface, peers)
|
||||||
domain.MergeToPhysicalInterface(pi, &iface)
|
if err != nil {
|
||||||
|
return err
|
||||||
return pi, nil
|
}
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if updateDbOnError {
|
if updateDbOnError {
|
||||||
// disable interface in database as no physical interface exists
|
// disable interface in database as no physical interface exists
|
||||||
_ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) {
|
_ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
in.Disabled = &now // set
|
in.Disabled = &now // set
|
||||||
in.DisabledReason = "no physical interface available"
|
in.DisabledReason = domain.DisabledReasonInterfaceMissing
|
||||||
return in, nil
|
return in, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -172,11 +171,10 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
|
|||||||
logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
|
logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
|
||||||
|
|
||||||
// try to move interface to stored state
|
// try to move interface to stored state
|
||||||
err := m.wg.SaveInterface(ctx, &iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
_, err = m.saveInterface(ctx, &iface, peers)
|
||||||
pi.DeviceUp = !iface.IsDisabled()
|
if err != nil {
|
||||||
|
return err
|
||||||
return pi, nil
|
}
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if updateDbOnError {
|
if updateDbOnError {
|
||||||
// disable interface in database as no physical interface is available
|
// disable interface in database as no physical interface is available
|
||||||
@ -184,7 +182,7 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
|
|||||||
if iface.IsDisabled() {
|
if iface.IsDisabled() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
in.Disabled = &now // set
|
in.Disabled = &now // set
|
||||||
in.DisabledReason = "no physical interface active"
|
in.DisabledReason = domain.DisabledReasonInterfaceMissing
|
||||||
} else {
|
} else {
|
||||||
in.Disabled = nil
|
in.Disabled = nil
|
||||||
in.DisabledReason = ""
|
in.DisabledReason = ""
|
||||||
@ -289,19 +287,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
|
|||||||
return nil, fmt.Errorf("creation not allowed: %w", err)
|
return nil, fmt.Errorf("creation not allowed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
in, err = m.saveInterface(ctx, in, nil)
|
||||||
in.CopyCalculatedAttributes(i)
|
|
||||||
|
|
||||||
err = m.wg.SaveInterface(ctx, in, nil, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
|
||||||
domain.MergeToPhysicalInterface(pi, in)
|
|
||||||
return pi, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create physical interface %s: %w", in.Identifier, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return in, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creation failure: %w", err)
|
return nil, fmt.Errorf("creation failure: %w", err)
|
||||||
}
|
}
|
||||||
@ -319,19 +305,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
|
|||||||
return nil, nil, fmt.Errorf("update not allowed: %w", err)
|
return nil, nil, fmt.Errorf("update not allowed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
in, err = m.saveInterface(ctx, in, existingPeers)
|
||||||
in.CopyCalculatedAttributes(i)
|
|
||||||
|
|
||||||
err = m.wg.SaveInterface(ctx, in, existingPeers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
|
||||||
domain.MergeToPhysicalInterface(pi, in)
|
|
||||||
return pi, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to update physical interface %s: %w", in.Identifier, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return in, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("update failure: %w", err)
|
return nil, nil, fmt.Errorf("update failure: %w", err)
|
||||||
}
|
}
|
||||||
@ -349,26 +323,135 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
|||||||
return fmt.Errorf("deletion not allowed: %w", err)
|
return fmt.Errorf("deletion not allowed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.deleteInterfacePeers(ctx, id)
|
now := time.Now()
|
||||||
if err != nil {
|
existingInterface.Disabled = &now // simulate a disabled interface
|
||||||
|
existingInterface.DisabledReason = domain.DisabledReasonDeleted
|
||||||
|
|
||||||
|
if err := m.handleInterfacePreSaveHooks(true, existingInterface); err != nil {
|
||||||
|
return fmt.Errorf("pre-delete hooks failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.handleInterfacePreSaveActions(existingInterface); err != nil {
|
||||||
|
return fmt.Errorf("pre-delete actions failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.deleteInterfacePeers(ctx, id); err != nil {
|
||||||
return fmt.Errorf("peer deletion failure: %w", err)
|
return fmt.Errorf("peer deletion failure: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.wg.DeleteInterface(ctx, id)
|
if err := m.wg.DeleteInterface(ctx, id); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("wireguard deletion failure: %w", err)
|
return fmt.Errorf("wireguard deletion failure: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.db.DeleteInterface(ctx, id)
|
if err := m.db.DeleteInterface(ctx, id); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("deletion failure: %w", err)
|
return fmt.Errorf("deletion failure: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.handleInterfacePostSaveHooks(true, existingInterface); err != nil {
|
||||||
|
return fmt.Errorf("post-delete hooks failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// region helper-functions
|
// region helper-functions
|
||||||
|
|
||||||
|
func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, peers []domain.Peer) (*domain.Interface, error) {
|
||||||
|
stateChanged := m.hasInterfaceStateChanged(ctx, iface)
|
||||||
|
|
||||||
|
if err := m.handleInterfacePreSaveHooks(stateChanged, iface); err != nil {
|
||||||
|
return nil, fmt.Errorf("pre-save hooks failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.handleInterfacePreSaveActions(iface); err != nil {
|
||||||
|
return nil, fmt.Errorf("pre-save actions failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
||||||
|
iface.CopyCalculatedAttributes(i)
|
||||||
|
|
||||||
|
err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||||
|
domain.MergeToPhysicalInterface(pi, iface)
|
||||||
|
return pi, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save physical interface %s: %w", iface.Identifier, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return iface, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save interface: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil {
|
||||||
|
return nil, fmt.Errorf("post-save hooks failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return iface, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) hasInterfaceStateChanged(ctx context.Context, iface *domain.Interface) bool {
|
||||||
|
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return oldInterface.IsDisabled() != iface.IsDisabled()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error {
|
||||||
|
if !iface.IsDisabled() {
|
||||||
|
if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
|
||||||
|
return fmt.Errorf("failed to update dns settings: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := m.quick.UnsetDNS(iface.Identifier); err != nil {
|
||||||
|
return fmt.Errorf("failed to clear dns settings: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) handleInterfacePreSaveHooks(stateChanged bool, iface *domain.Interface) error {
|
||||||
|
if !stateChanged {
|
||||||
|
return nil // do nothing if state did not change
|
||||||
|
}
|
||||||
|
|
||||||
|
if !iface.IsDisabled() {
|
||||||
|
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute pre-up hook: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute pre-down hook: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Manager) handleInterfacePostSaveHooks(stateChanged bool, iface *domain.Interface) error {
|
||||||
|
if !stateChanged {
|
||||||
|
return nil // do nothing if state did not change
|
||||||
|
}
|
||||||
|
|
||||||
|
if !iface.IsDisabled() {
|
||||||
|
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute post-up hook: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute post-down hook: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m Manager) getNewInterfaceName(ctx context.Context) (domain.InterfaceIdentifier, error) {
|
func (m Manager) getNewInterfaceName(ctx context.Context) (domain.InterfaceIdentifier, error) {
|
||||||
namePrefix := "wg"
|
namePrefix := "wg"
|
||||||
nameSuffix := 0
|
nameSuffix := 0
|
||||||
|
@ -143,20 +143,7 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
|
|||||||
return nil, fmt.Errorf("creation not allowed: %w", err)
|
return nil, fmt.Errorf("creation not allowed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
err = m.savePeers(ctx, peer)
|
||||||
peer.CopyCalculatedAttributes(p)
|
|
||||||
|
|
||||||
err = m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
|
||||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
|
||||||
domain.MergeToPhysicalPeer(pp, peer)
|
|
||||||
return pp, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create wireguard peer %s: %w", peer.Identifier, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return peer, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creation failure: %w", err)
|
return nil, fmt.Errorf("creation failure: %w", err)
|
||||||
}
|
}
|
||||||
@ -165,7 +152,7 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) {
|
func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) {
|
||||||
var newPeers []domain.Peer
|
var newPeers []*domain.Peer
|
||||||
|
|
||||||
for _, id := range r.UserIdentifiers {
|
for _, id := range r.UserIdentifiers {
|
||||||
freshPeer, err := m.PreparePeer(ctx, interfaceId)
|
freshPeer, err := m.PreparePeer(ctx, interfaceId)
|
||||||
@ -178,17 +165,24 @@ func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.Int
|
|||||||
freshPeer.DisplayName += " " + r.Suffix
|
freshPeer.DisplayName += " " + r.Suffix
|
||||||
}
|
}
|
||||||
|
|
||||||
newPeers = append(newPeers, *freshPeer)
|
if err := m.validatePeerCreation(ctx, nil, freshPeer); err != nil {
|
||||||
}
|
return nil, fmt.Errorf("creation not allowed: %w", err)
|
||||||
|
|
||||||
for i, peer := range newPeers {
|
|
||||||
_, err := m.CreatePeer(ctx, &newPeers[i])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create peer %s (uid: %s) for interface %s: %w", peer.Identifier, peer.UserIdentifier, interfaceId, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
newPeers = append(newPeers, freshPeer)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newPeers, nil
|
err := m.savePeers(ctx, newPeers...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create new peers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
createdPeers := make([]domain.Peer, len(newPeers))
|
||||||
|
for i := range newPeers {
|
||||||
|
createdPeers[i] = *newPeers[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return createdPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
|
||||||
@ -201,20 +195,7 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
|
|||||||
return nil, fmt.Errorf("update not allowed: %w", err)
|
return nil, fmt.Errorf("update not allowed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
err = m.savePeers(ctx, peer)
|
||||||
peer.CopyCalculatedAttributes(p)
|
|
||||||
|
|
||||||
err = m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
|
||||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
|
||||||
domain.MergeToPhysicalPeer(pp, peer)
|
|
||||||
return pp, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to update wireguard peer %s: %w", peer.Identifier, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return peer, nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("update failure: %w", err)
|
return nil, fmt.Errorf("update failure: %w", err)
|
||||||
}
|
}
|
||||||
@ -271,6 +252,47 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
|
|||||||
|
|
||||||
// region helper-functions
|
// region helper-functions
|
||||||
|
|
||||||
|
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||||
|
interfaces := make(map[domain.InterfaceIdentifier]struct{})
|
||||||
|
|
||||||
|
for i := range peers {
|
||||||
|
peer := peers[i]
|
||||||
|
err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||||
|
peer.CopyCalculatedAttributes(p)
|
||||||
|
|
||||||
|
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||||
|
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||||
|
domain.MergeToPhysicalPeer(pp, peer)
|
||||||
|
return pp, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaces[peer.InterfaceIdentifier] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update routes after peers have changed
|
||||||
|
for ifaceId := range interfaces {
|
||||||
|
iface, ifacePeers, err := m.db.GetInterfaceAndPeers(ctx, ifaceId)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load peer interface %s: %w", ifaceId, err)
|
||||||
|
}
|
||||||
|
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(ifacePeers))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update peer routes on interface %s: %w", ifaceId, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m Manager) getFreshPeerIpConfig(ctx context.Context, iface *domain.Interface) (ips []domain.Cidr, err error) {
|
func (m Manager) getFreshPeerIpConfig(ctx context.Context, iface *domain.Interface) (ips []domain.Cidr, err error) {
|
||||||
networks, err := domain.CidrsFromString(iface.PeerDefNetworkStr)
|
networks, err := domain.CidrsFromString(iface.PeerDefNetworkStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -22,14 +22,16 @@ func (PrivateString) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DisabledReasonExpired = "expired"
|
DisabledReasonExpired = "expired"
|
||||||
DisabledReasonUserEdit = "user edit action"
|
DisabledReasonDeleted = "deleted"
|
||||||
DisabledReasonUserCreate = "user create action"
|
DisabledReasonUserEdit = "user edit action"
|
||||||
DisabledReasonAdminEdit = "admin edit action"
|
DisabledReasonUserCreate = "user create action"
|
||||||
DisabledReasonAdminCreate = "admin create action"
|
DisabledReasonAdminEdit = "admin edit action"
|
||||||
DisabledReasonApiEdit = "api edit action"
|
DisabledReasonAdminCreate = "admin create action"
|
||||||
DisabledReasonApiCreate = "api create action"
|
DisabledReasonApiEdit = "api edit action"
|
||||||
DisabledReasonLdapMissing = "missing in ldap"
|
DisabledReasonApiCreate = "api create action"
|
||||||
DisabledReasonUserMissing = "missing user"
|
DisabledReasonLdapMissing = "missing in ldap"
|
||||||
DisabledReasonMigrationDummy = "migration dummy user"
|
DisabledReasonUserMissing = "missing user"
|
||||||
|
DisabledReasonMigrationDummy = "migration dummy user"
|
||||||
|
DisabledReasonInterfaceMissing = "missing WireGuard interface"
|
||||||
)
|
)
|
||||||
|
@ -84,12 +84,13 @@ func CidrFromIpNet(ipNet net.IPNet) Cidr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CidrFromNetlinkAddr(addr netlink.Addr) Cidr {
|
func CidrFromNetlinkAddr(addr netlink.Addr) Cidr {
|
||||||
prefix, _ := CidrFromString(addr.String())
|
prefix, _ := CidrFromString(addr.IPNet.String())
|
||||||
return prefix
|
return prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Cidr) IpNet() *net.IPNet {
|
func (c Cidr) IpNet() *net.IPNet {
|
||||||
_, cidr, _ := net.ParseCIDR(c.String())
|
ip, cidr, _ := net.ParseCIDR(c.String())
|
||||||
|
cidr.IP = ip
|
||||||
return cidr
|
return cidr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user