diff --git a/frontend/src/components/InterfaceEditModal.vue b/frontend/src/components/InterfaceEditModal.vue index 9707f95..586290d 100644 --- a/frontend/src/components/InterfaceEditModal.vue +++ b/frontend/src/components/InterfaceEditModal.vue @@ -50,6 +50,9 @@ const currentTags = ref({ PeerDefDnsSearch: "" }) const formData = ref(freshInterface()) +const isSaving = ref(false) +const isDeleting = ref(false) +const isApplyingDefaults = ref(false) const isBackendValid = computed(() => { if (!props.visible || !selectedInterface.value) { @@ -258,6 +261,8 @@ function handleChangePeerDefDnsSearch(tags) { } async function save() { + if (isSaving.value) return + isSaving.value = true try { if (props.interfaceId!=='#NEW#') { await interfaces.UpdateInterface(selectedInterface.value.Identifier, formData.value) @@ -272,6 +277,8 @@ async function save() { text: e.toString(), type: 'error', }) + } finally { + isSaving.value = false } } @@ -280,6 +287,8 @@ async function applyPeerDefaults() { return; // do nothing for new interfaces } + if (isApplyingDefaults.value) return + isApplyingDefaults.value = true try { await interfaces.ApplyPeerDefaults(selectedInterface.value.Identifier, formData.value) @@ -297,10 +306,14 @@ async function applyPeerDefaults() { text: e.toString(), type: 'error', }) + } finally { + isApplyingDefaults.value = false } } async function del() { + if (isDeleting.value) return + isDeleting.value = true try { await interfaces.DeleteInterface(selectedInterface.value.Identifier) close() @@ -311,6 +324,8 @@ async function del() { text: e.toString(), type: 'error', }) + } finally { + isDeleting.value = false } } @@ -562,16 +577,25 @@ async function del() {

- +
diff --git a/frontend/src/components/PeerEditModal.vue b/frontend/src/components/PeerEditModal.vue index 554de5a..7c50edf 100644 --- a/frontend/src/components/PeerEditModal.vue +++ b/frontend/src/components/PeerEditModal.vue @@ -73,6 +73,8 @@ const currentTags = ref({ DnsSearch: "" }) const formData = ref(freshPeer()) +const isSaving = ref(false) +const isDeleting = ref(false) // functions @@ -270,6 +272,8 @@ function handleChangeDnsSearch(tags) { } async function save() { + if (isSaving.value) return + isSaving.value = true try { if (props.peerId !== '#NEW#') { await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value) @@ -278,26 +282,30 @@ async function save() { } close() } catch (e) { - // console.log(e) notify({ title: "Failed to save peer!", text: e.toString(), type: 'error', }) + } finally { + isSaving.value = false } } async function del() { + if (isDeleting.value) return + isDeleting.value = true try { await peers.DeletePeer(selectedPeer.value.Identifier) close() } catch (e) { - // console.log(e) notify({ title: "Failed to delete peer!", text: e.toString(), type: 'error', }) + } finally { + isDeleting.value = false } } @@ -470,10 +478,15 @@ async function del() { diff --git a/frontend/src/components/PeerMultiCreateModal.vue b/frontend/src/components/PeerMultiCreateModal.vue index 5201f03..97ec25d 100644 --- a/frontend/src/components/PeerMultiCreateModal.vue +++ b/frontend/src/components/PeerMultiCreateModal.vue @@ -38,6 +38,7 @@ function freshForm() { const currentTag = ref("") const formData = ref(freshForm()) +const isSaving = ref(false) const title = computed(() => { if (!props.visible) { @@ -60,12 +61,15 @@ function handleChangeUserIdentifiers(tags) { } async function save() { + if (isSaving.value) return + isSaving.value = true if (formData.value.Identifiers.length === 0) { notify({ title: "Missing Identifiers", text: "At least one identifier is required to create a new peer.", type: 'error', }) + isSaving.value = false return } @@ -79,6 +83,8 @@ async function save() { text: e.toString(), type: 'error', }) + } finally { + isSaving.value = false } } @@ -108,7 +114,10 @@ async function save() { diff --git a/frontend/src/components/UserEditModal.vue b/frontend/src/components/UserEditModal.vue index 6a4a7bc..340dfe2 100644 --- a/frontend/src/components/UserEditModal.vue +++ b/frontend/src/components/UserEditModal.vue @@ -34,6 +34,8 @@ const title = computed(() => { }) const formData = ref(freshUser()) +const isSaving = ref(false) +const isDeleting = ref(false) const passwordWeak = computed(() => { return formData.value.Password && formData.value.Password.length > 0 && formData.value.Password.length < settings.Setting('MinPasswordLength') @@ -89,6 +91,8 @@ function close() { } async function save() { + if (isSaving.value) return + isSaving.value = true try { if (props.userId!=='#NEW#') { await users.UpdateUser(selectedUser.value.Identifier, formData.value) @@ -102,10 +106,14 @@ async function save() { text: e.toString(), type: 'error', }) + } finally { + isSaving.value = false } } async function del() { + if (isDeleting.value) return + isDeleting.value = true try { await users.DeleteUser(selectedUser.value.Identifier) close() @@ -115,6 +123,8 @@ async function del() { text: e.toString(), type: 'error', }) + } finally { + isDeleting.value = false } } @@ -193,9 +203,15 @@ async function del() { diff --git a/frontend/src/components/UserPeerEditModal.vue b/frontend/src/components/UserPeerEditModal.vue index 7594d7b..15f2f83 100644 --- a/frontend/src/components/UserPeerEditModal.vue +++ b/frontend/src/components/UserPeerEditModal.vue @@ -55,6 +55,8 @@ const title = computed(() => { }) const formData = ref(freshPeer()) +const isSaving = ref(false) +const isDeleting = ref(false) // functions @@ -163,6 +165,8 @@ function close() { } async function save() { + if (isSaving.value) return + isSaving.value = true try { if (props.peerId !== '#NEW#') { await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value) @@ -171,26 +175,30 @@ async function save() { } close() } catch (e) { - // console.log(e) notify({ title: "Failed to save peer!", text: e.toString(), type: 'error', }) + } finally { + isSaving.value = false } } async function del() { + if (isDeleting.value) return + isDeleting.value = true try { await peers.DeletePeer(selectedPeer.value.Identifier) close() } catch (e) { - // console.log(e) notify({ title: "Failed to delete peer!", text: e.toString(), type: 'error', }) + } finally { + isDeleting.value = false } } @@ -283,10 +291,15 @@ async function del() { diff --git a/internal/adapters/wgcontroller/local.go b/internal/adapters/wgcontroller/local.go index 9e73d21..7f2e7fa 100644 --- a/internal/adapters/wgcontroller/local.go +++ b/internal/adapters/wgcontroller/local.go @@ -204,6 +204,11 @@ func (c LocalController) convertWireGuardPeer(peer *wgtypes.Peer) (domain.Physic ImportSource: domain.ControllerTypeLocal, } + // Set local extras - local peers are never disabled in the kernel + peerModel.SetExtras(domain.LocalPeerExtras{ + Disabled: false, + }) + for _, addr := range peer.AllowedIPs { peerModel.AllowedIPs = append(peerModel.AllowedIPs, domain.CidrFromIpNet(addr)) } @@ -410,6 +415,18 @@ func (c LocalController) SavePeer( return err } + // Check if the peer is disabled by looking at the backend extras + // For local controller, disabled peers should be deleted + if physicalPeer.GetExtras() != nil { + switch extras := physicalPeer.GetExtras().(type) { + case domain.LocalPeerExtras: + if extras.Disabled { + // Delete the peer instead of updating it + return c.deletePeer(deviceId, id) + } + } + } + if err := c.updatePeer(deviceId, physicalPeer); err != nil { return err } diff --git a/internal/adapters/wgcontroller/mikrotik.go b/internal/adapters/wgcontroller/mikrotik.go index 647f6ad..085281f 100644 --- a/internal/adapters/wgcontroller/mikrotik.go +++ b/internal/adapters/wgcontroller/mikrotik.go @@ -6,8 +6,11 @@ import ( "slices" "strconv" "strings" + "sync" "time" + "log/slog" + "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/lowlevel" @@ -18,6 +21,10 @@ type MikrotikController struct { cfg *config.BackendMikrotik client *lowlevel.MikrotikApiClient + + // Add mutexes to prevent race conditions + interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex + peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex } func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) { @@ -31,16 +38,31 @@ func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) cfg: cfg, client: client, + + interfaceMutexes: sync.Map{}, + peerMutexes: sync.Map{}, }, nil } -func (c MikrotikController) GetId() domain.InterfaceBackend { +func (c *MikrotikController) GetId() domain.InterfaceBackend { return domain.InterfaceBackend(c.cfg.Id) } +// getInterfaceMutex returns a mutex for the given interface to prevent concurrent modifications +func (c *MikrotikController) getInterfaceMutex(id domain.InterfaceIdentifier) *sync.Mutex { + mutex, _ := c.interfaceMutexes.LoadOrStore(id, &sync.Mutex{}) + return mutex.(*sync.Mutex) +} + +// getPeerMutex returns a mutex for the given peer to prevent concurrent modifications +func (c *MikrotikController) getPeerMutex(id domain.PeerIdentifier) *sync.Mutex { + mutex, _ := c.peerMutexes.LoadOrStore(id, &sync.Mutex{}) + return mutex.(*sync.Mutex) +} + // region wireguard-related -func (c MikrotikController) GetInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) { +func (c *MikrotikController) GetInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) { wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{ PropList: []string{ ".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running", "comment", @@ -62,7 +84,7 @@ func (c MikrotikController) GetInterfaces(ctx context.Context) ([]domain.Physica return interfaces, nil } -func (c MikrotikController) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) ( +func (c *MikrotikController) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) ( *domain.PhysicalInterface, error, ) { @@ -85,7 +107,7 @@ func (c MikrotikController) GetInterface(ctx context.Context, id domain.Interfac return c.loadInterfaceData(ctx, wgReply.Data[0]) } -func (c MikrotikController) loadInterfaceData( +func (c *MikrotikController) loadInterfaceData( ctx context.Context, wireGuardObj lowlevel.GenericJsonObject, ) (*domain.PhysicalInterface, error) { @@ -113,7 +135,7 @@ func (c MikrotikController) loadInterfaceData( return &interfaceModel, nil } -func (c MikrotikController) loadIpAddresses( +func (c *MikrotikController) loadIpAddresses( ctx context.Context, deviceName string, ) (ipv4 []lowlevel.GenericJsonObject, ipv6 []lowlevel.GenericJsonObject, err error) { @@ -150,7 +172,7 @@ func (c MikrotikController) loadIpAddresses( return addrV4Reply.Data, addrV6Reply.Data, nil } -func (c MikrotikController) convertIpAddresses( +func (c *MikrotikController) convertIpAddresses( ipv4, ipv6 []lowlevel.GenericJsonObject, ) []domain.Cidr { addresses := make([]domain.Cidr, 0, len(ipv4)+len(ipv6)) @@ -170,7 +192,7 @@ func (c MikrotikController) convertIpAddresses( return addresses } -func (c MikrotikController) convertWireGuardInterface( +func (c *MikrotikController) convertWireGuardInterface( wg, iface lowlevel.GenericJsonObject, addresses []domain.Cidr, ) ( @@ -203,7 +225,7 @@ func (c MikrotikController) convertWireGuardInterface( return pi, nil } -func (c MikrotikController) GetPeers(ctx context.Context, deviceId domain.InterfaceIdentifier) ( +func (c *MikrotikController) GetPeers(ctx context.Context, deviceId domain.InterfaceIdentifier) ( []domain.PhysicalPeer, error, ) { @@ -237,7 +259,7 @@ func (c MikrotikController) GetPeers(ctx context.Context, deviceId domain.Interf return peers, nil } -func (c MikrotikController) convertWireGuardPeer(peer lowlevel.GenericJsonObject) ( +func (c *MikrotikController) convertWireGuardPeer(peer lowlevel.GenericJsonObject) ( domain.PhysicalPeer, error, ) { @@ -300,11 +322,16 @@ func (c MikrotikController) convertWireGuardPeer(peer lowlevel.GenericJsonObject return peerModel, nil } -func (c MikrotikController) SaveInterface( +func (c *MikrotikController) SaveInterface( ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), ) error { + // Lock the interface to prevent concurrent modifications + mutex := c.getInterfaceMutex(id) + mutex.Lock() + defer mutex.Unlock() + physicalInterface, err := c.getOrCreateInterface(ctx, id) if err != nil { return err @@ -328,7 +355,7 @@ func (c MikrotikController) SaveInterface( return nil } -func (c MikrotikController) getOrCreateInterface( +func (c *MikrotikController) getOrCreateInterface( ctx context.Context, id domain.InterfaceIdentifier, ) (*domain.PhysicalInterface, error) { @@ -355,7 +382,7 @@ func (c MikrotikController) getOrCreateInterface( return nil, fmt.Errorf("failed to create interface %s: %v", id, createReply.Error) } -func (c MikrotikController) updateInterface(ctx context.Context, pi *domain.PhysicalInterface) error { +func (c *MikrotikController) updateInterface(ctx context.Context, pi *domain.PhysicalInterface) error { extras := pi.GetExtras().(domain.MikrotikInterfaceExtras) interfaceId := extras.Id wgReply := c.client.Update(ctx, "/interface/wireguard/"+interfaceId, lowlevel.GenericJsonObject{ @@ -403,7 +430,7 @@ func (c MikrotikController) updateInterface(ctx context.Context, pi *domain.Phys return nil } -func (c MikrotikController) updateIpAddresses( +func (c *MikrotikController) updateIpAddresses( ctx context.Context, deviceName string, currentV4, currentV6 []lowlevel.GenericJsonObject, @@ -459,7 +486,12 @@ func (c MikrotikController) updateIpAddresses( return nil } -func (c MikrotikController) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { +func (c *MikrotikController) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { + // Lock the interface to prevent concurrent modifications + mutex := c.getInterfaceMutex(id) + mutex.Lock() + defer mutex.Unlock() + // delete the interface's addresses currentV4, currentV6, err := c.loadIpAddresses(ctx, string(id)) if err != nil { @@ -494,8 +526,8 @@ func (c MikrotikController) DeleteInterface(ctx context.Context, id domain.Inter return nil // interface does not exist, nothing to delete } - deviceId := wgReply.Data[0].GetString(".id") - deleteReply := c.client.Delete(ctx, "/interface/wireguard/"+deviceId) + interfaceId := wgReply.Data[0].GetString(".id") + deleteReply := c.client.Delete(ctx, "/interface/wireguard/"+interfaceId) if deleteReply.Status != lowlevel.MikrotikApiStatusOk { return fmt.Errorf("failed to delete WireGuard interface %s: %v", id, deleteReply.Error) } @@ -503,12 +535,17 @@ func (c MikrotikController) DeleteInterface(ctx context.Context, id domain.Inter return nil } -func (c MikrotikController) SavePeer( +func (c *MikrotikController) SavePeer( ctx context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), ) error { + // Lock the peer to prevent concurrent modifications + mutex := c.getPeerMutex(id) + mutex.Lock() + defer mutex.Unlock() + physicalPeer, err := c.getOrCreatePeer(ctx, deviceId, id) if err != nil { return err @@ -530,7 +567,7 @@ func (c MikrotikController) SavePeer( return nil } -func (c MikrotikController) getOrCreatePeer( +func (c *MikrotikController) getOrCreatePeer( ctx context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, @@ -546,6 +583,7 @@ func (c MikrotikController) getOrCreatePeer( }, }) if wgReply.Status == lowlevel.MikrotikApiStatusOk && len(wgReply.Data) > 0 { + slog.Debug("found existing Mikrotik peer", "peer", id, "interface", deviceId) existingPeer, err := c.convertWireGuardPeer(wgReply.Data[0]) if err != nil { return nil, err @@ -554,24 +592,26 @@ func (c MikrotikController) getOrCreatePeer( } // create a new peer if it does not exist + slog.Debug("creating new Mikrotik peer", "peer", id, "interface", deviceId) createReply := c.client.Create(ctx, "/interface/wireguard/peers", lowlevel.GenericJsonObject{ "name": fmt.Sprintf("tmp-wg-%s", id[0:8]), "interface": string(deviceId), - "public-key": string(id), // public key will be set later - "allowed-address": "169.254.254.254/32", // allowed addresses will be set later + "public-key": string(id), + "allowed-address": "0.0.0.0/0", // Use 0.0.0.0/0 as default, will be updated by updatePeer }) if createReply.Status == lowlevel.MikrotikApiStatusOk { newPeer, err := c.convertWireGuardPeer(createReply.Data) if err != nil { return nil, err } + slog.Debug("successfully created Mikrotik peer", "peer", id, "interface", deviceId) return &newPeer, nil } return nil, fmt.Errorf("failed to create peer %s for interface %s: %v", id, deviceId, createReply.Error) } -func (c MikrotikController) updatePeer( +func (c *MikrotikController) updatePeer( ctx context.Context, deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer, @@ -586,6 +626,14 @@ func (c MikrotikController) updatePeer( endpointPort = s[1] } + allowedAddressStr := domain.CidrsToString(pp.AllowedIPs) + slog.Debug("updating Mikrotik peer", + "peer", pp.Identifier, + "interface", deviceId, + "allowed-address", allowedAddressStr, + "allowed-ips-count", len(pp.AllowedIPs), + "disabled", extras.Disabled) + wgReply := c.client.Update(ctx, "/interface/wireguard/peers/"+peerId, lowlevel.GenericJsonObject{ "name": extras.Name, "comment": extras.Comment, @@ -601,19 +649,31 @@ func (c MikrotikController) updatePeer( "client-dns": extras.ClientDns, "endpoint-address": endpoint, "endpoint-port": endpointPort, + "allowed-address": allowedAddressStr, // Add the missing allowed-address field }) if wgReply.Status != lowlevel.MikrotikApiStatusOk { return fmt.Errorf("failed to update peer %s on interface %s: %v", pp.Identifier, deviceId, wgReply.Error) } + if extras.Disabled { + slog.Debug("successfully disabled Mikrotik peer", "peer", pp.Identifier, "interface", deviceId) + } else { + slog.Debug("successfully updated Mikrotik peer", "peer", pp.Identifier, "interface", deviceId) + } + return nil } -func (c MikrotikController) DeletePeer( +func (c *MikrotikController) DeletePeer( ctx context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, ) error { + // Lock the peer to prevent concurrent modifications + mutex := c.getPeerMutex(id) + mutex.Lock() + defer mutex.Unlock() + wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{ PropList: []string{".id"}, Filters: map[string]string{ @@ -641,17 +701,17 @@ func (c MikrotikController) DeletePeer( // region wg-quick-related -func (c MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { +func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { // TODO implement me panic("implement me") } -func (c MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { +func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { // TODO implement me panic("implement me") } -func (c MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error { +func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error { // TODO implement me panic("implement me") } @@ -660,12 +720,12 @@ func (c MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error { // region routing-related -func (c MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { +func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { // TODO implement me panic("implement me") } -func (c MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { +func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { // TODO implement me panic("implement me") } @@ -674,7 +734,7 @@ func (c MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.R // region statistics-related -func (c MikrotikController) PingAddresses( +func (c *MikrotikController) PingAddresses( ctx context.Context, addr string, ) (*domain.PingerResult, error) { diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index ed35cf4..9daa3c7 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -188,29 +188,28 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee sessionUser := domain.GetUserInfo(ctx) - // Enforce peer limit for non-admin users if LimitAdditionalUserPeers is set - if m.cfg.Core.SelfProvisioningAllowed && !sessionUser.IsAdmin && m.cfg.Advanced.LimitAdditionalUserPeers > 0 { - peers, err := m.db.GetUserPeers(ctx, peer.UserIdentifier) - if err != nil { - return nil, fmt.Errorf("failed to fetch peers for user %s: %w", peer.UserIdentifier, err) - } - // Count enabled peers (disabled IS NULL) - peerCount := 0 - for _, p := range peers { - if !p.IsDisabled() { - peerCount++ - } - } - totalAllowedPeers := 1 + m.cfg.Advanced.LimitAdditionalUserPeers // 1 default peer + x additional peers - if peerCount >= totalAllowedPeers { - slog.WarnContext(ctx, "peer creation blocked due to limit", - "user", peer.UserIdentifier, - "current_count", peerCount, - "allowed_count", totalAllowedPeers) - return nil, fmt.Errorf("peer limit reached (%d peers allowed): %w", totalAllowedPeers, domain.ErrNoPermission) - } - } - + // Enforce peer limit for non-admin users if LimitAdditionalUserPeers is set + if m.cfg.Core.SelfProvisioningAllowed && !sessionUser.IsAdmin && m.cfg.Advanced.LimitAdditionalUserPeers > 0 { + peers, err := m.db.GetUserPeers(ctx, peer.UserIdentifier) + if err != nil { + return nil, fmt.Errorf("failed to fetch peers for user %s: %w", peer.UserIdentifier, err) + } + // Count enabled peers (disabled IS NULL) + peerCount := 0 + for _, p := range peers { + if !p.IsDisabled() { + peerCount++ + } + } + totalAllowedPeers := 1 + m.cfg.Advanced.LimitAdditionalUserPeers // 1 default peer + x additional peers + if peerCount >= totalAllowedPeers { + slog.WarnContext(ctx, "peer creation blocked due to limit", + "user", peer.UserIdentifier, + "current_count", peerCount, + "allowed_count", totalAllowedPeers) + return nil, fmt.Errorf("peer limit reached (%d peers allowed): %w", totalAllowedPeers, domain.ErrNoPermission) + } + } existingPeer, err := m.db.GetPeer(ctx, peer.Identifier) if err != nil && !errors.Is(err, domain.ErrNotFound) { @@ -449,33 +448,22 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err) } - if peer.IsDisabled() || peer.IsExpired() { - err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { - peer.CopyCalculatedAttributes(p) + // Always save the peer to the backend, regardless of disabled/expired state + // The backend will handle the disabled state appropriately + err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { + peer.CopyCalculatedAttributes(p) - if err := m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, - peer.Identifier); err != nil { - return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err) - } + err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier, + func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { + domain.MergeToPhysicalPeer(pp, peer) + return pp, nil + }) + if err != nil { + return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err) + } - return peer, nil - }) - } else { - err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { - peer.CopyCalculatedAttributes(p) - - err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier, - func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { - domain.MergeToPhysicalPeer(pp, peer) - return pp, nil - }) - if err != nil { - return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err) - } - - return peer, nil - }) - } + return peer, nil + }) if err != nil { return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err) } diff --git a/internal/domain/controller.go b/internal/domain/controller.go index a94e116..eaefe32 100644 --- a/internal/domain/controller.go +++ b/internal/domain/controller.go @@ -26,3 +26,7 @@ type MikrotikPeerExtras struct { ClientDns string ClientKeepalive int } + +type LocalPeerExtras struct { + Disabled bool +} diff --git a/internal/domain/peer.go b/internal/domain/peer.go index 1da5997..3f65c96 100644 --- a/internal/domain/peer.go +++ b/internal/domain/peer.go @@ -236,7 +236,8 @@ func (p *PhysicalPeer) GetExtras() any { func (p *PhysicalPeer) SetExtras(extras any) { switch extras.(type) { case MikrotikPeerExtras: // OK - default: // we only support MikrotikPeerExtras for now + case LocalPeerExtras: // OK + default: // we only support MikrotikPeerExtras and LocalPeerExtras for now panic(fmt.Sprintf("unsupported peer backend extras type %T", extras)) } @@ -288,6 +289,15 @@ func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer { peer.Disabled = nil peer.DisabledReason = "" } + case ControllerTypeLocal: + extras := pp.GetExtras().(LocalPeerExtras) + if extras.Disabled { + peer.Disabled = &now + peer.DisabledReason = "Disabled by Local controller" + } else { + peer.Disabled = nil + peer.DisabledReason = "" + } } return peer @@ -326,6 +336,11 @@ func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) { ClientKeepalive: p.PersistentKeepalive.GetValue(), } pp.SetExtras(extras) + case ControllerTypeLocal: + extras := LocalPeerExtras{ + Disabled: p.IsDisabled(), + } + pp.SetExtras(extras) } }