diff --git a/frontend/src/components/InterfaceEditModal.vue b/frontend/src/components/InterfaceEditModal.vue
index 9707f95..586290d 100644
--- a/frontend/src/components/InterfaceEditModal.vue
+++ b/frontend/src/components/InterfaceEditModal.vue
@@ -50,6 +50,9 @@ const currentTags = ref({
PeerDefDnsSearch: ""
})
const formData = ref(freshInterface())
+const isSaving = ref(false)
+const isDeleting = ref(false)
+const isApplyingDefaults = ref(false)
const isBackendValid = computed(() => {
if (!props.visible || !selectedInterface.value) {
@@ -258,6 +261,8 @@ function handleChangePeerDefDnsSearch(tags) {
}
async function save() {
+ if (isSaving.value) return
+ isSaving.value = true
try {
if (props.interfaceId!=='#NEW#') {
await interfaces.UpdateInterface(selectedInterface.value.Identifier, formData.value)
@@ -272,6 +277,8 @@ async function save() {
text: e.toString(),
type: 'error',
})
+ } finally {
+ isSaving.value = false
}
}
@@ -280,6 +287,8 @@ async function applyPeerDefaults() {
return; // do nothing for new interfaces
}
+ if (isApplyingDefaults.value) return
+ isApplyingDefaults.value = true
try {
await interfaces.ApplyPeerDefaults(selectedInterface.value.Identifier, formData.value)
@@ -297,10 +306,14 @@ async function applyPeerDefaults() {
text: e.toString(),
type: 'error',
})
+ } finally {
+ isApplyingDefaults.value = false
}
}
async function del() {
+ if (isDeleting.value) return
+ isDeleting.value = true
try {
await interfaces.DeleteInterface(selectedInterface.value.Identifier)
close()
@@ -311,6 +324,8 @@ async function del() {
text: e.toString(),
type: 'error',
})
+ } finally {
+ isDeleting.value = false
}
}
@@ -562,16 +577,25 @@ async function del() {
-
+
-
+
diff --git a/frontend/src/components/PeerEditModal.vue b/frontend/src/components/PeerEditModal.vue
index 554de5a..7c50edf 100644
--- a/frontend/src/components/PeerEditModal.vue
+++ b/frontend/src/components/PeerEditModal.vue
@@ -73,6 +73,8 @@ const currentTags = ref({
DnsSearch: ""
})
const formData = ref(freshPeer())
+const isSaving = ref(false)
+const isDeleting = ref(false)
// functions
@@ -270,6 +272,8 @@ function handleChangeDnsSearch(tags) {
}
async function save() {
+ if (isSaving.value) return
+ isSaving.value = true
try {
if (props.peerId !== '#NEW#') {
await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value)
@@ -278,26 +282,30 @@ async function save() {
}
close()
} catch (e) {
- // console.log(e)
notify({
title: "Failed to save peer!",
text: e.toString(),
type: 'error',
})
+ } finally {
+ isSaving.value = false
}
}
async function del() {
+ if (isDeleting.value) return
+ isDeleting.value = true
try {
await peers.DeletePeer(selectedPeer.value.Identifier)
close()
} catch (e) {
- // console.log(e)
notify({
title: "Failed to delete peer!",
text: e.toString(),
type: 'error',
})
+ } finally {
+ isDeleting.value = false
}
}
@@ -470,10 +478,15 @@ async function del() {
-
+
-
+
diff --git a/frontend/src/components/PeerMultiCreateModal.vue b/frontend/src/components/PeerMultiCreateModal.vue
index 5201f03..97ec25d 100644
--- a/frontend/src/components/PeerMultiCreateModal.vue
+++ b/frontend/src/components/PeerMultiCreateModal.vue
@@ -38,6 +38,7 @@ function freshForm() {
const currentTag = ref("")
const formData = ref(freshForm())
+const isSaving = ref(false)
const title = computed(() => {
if (!props.visible) {
@@ -60,12 +61,15 @@ function handleChangeUserIdentifiers(tags) {
}
async function save() {
+ if (isSaving.value) return
+ isSaving.value = true
if (formData.value.Identifiers.length === 0) {
notify({
title: "Missing Identifiers",
text: "At least one identifier is required to create a new peer.",
type: 'error',
})
+ isSaving.value = false
return
}
@@ -79,6 +83,8 @@ async function save() {
text: e.toString(),
type: 'error',
})
+ } finally {
+ isSaving.value = false
}
}
@@ -108,7 +114,10 @@ async function save() {
-
+
diff --git a/frontend/src/components/UserEditModal.vue b/frontend/src/components/UserEditModal.vue
index 6a4a7bc..340dfe2 100644
--- a/frontend/src/components/UserEditModal.vue
+++ b/frontend/src/components/UserEditModal.vue
@@ -34,6 +34,8 @@ const title = computed(() => {
})
const formData = ref(freshUser())
+const isSaving = ref(false)
+const isDeleting = ref(false)
const passwordWeak = computed(() => {
return formData.value.Password && formData.value.Password.length > 0 && formData.value.Password.length < settings.Setting('MinPasswordLength')
@@ -89,6 +91,8 @@ function close() {
}
async function save() {
+ if (isSaving.value) return
+ isSaving.value = true
try {
if (props.userId!=='#NEW#') {
await users.UpdateUser(selectedUser.value.Identifier, formData.value)
@@ -102,10 +106,14 @@ async function save() {
text: e.toString(),
type: 'error',
})
+ } finally {
+ isSaving.value = false
}
}
async function del() {
+ if (isDeleting.value) return
+ isDeleting.value = true
try {
await users.DeleteUser(selectedUser.value.Identifier)
close()
@@ -115,6 +123,8 @@ async function del() {
text: e.toString(),
type: 'error',
})
+ } finally {
+ isDeleting.value = false
}
}
@@ -193,9 +203,15 @@ async function del() {
-
+
-
+
diff --git a/frontend/src/components/UserPeerEditModal.vue b/frontend/src/components/UserPeerEditModal.vue
index 7594d7b..15f2f83 100644
--- a/frontend/src/components/UserPeerEditModal.vue
+++ b/frontend/src/components/UserPeerEditModal.vue
@@ -55,6 +55,8 @@ const title = computed(() => {
})
const formData = ref(freshPeer())
+const isSaving = ref(false)
+const isDeleting = ref(false)
// functions
@@ -163,6 +165,8 @@ function close() {
}
async function save() {
+ if (isSaving.value) return
+ isSaving.value = true
try {
if (props.peerId !== '#NEW#') {
await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value)
@@ -171,26 +175,30 @@ async function save() {
}
close()
} catch (e) {
- // console.log(e)
notify({
title: "Failed to save peer!",
text: e.toString(),
type: 'error',
})
+ } finally {
+ isSaving.value = false
}
}
async function del() {
+ if (isDeleting.value) return
+ isDeleting.value = true
try {
await peers.DeletePeer(selectedPeer.value.Identifier)
close()
} catch (e) {
- // console.log(e)
notify({
title: "Failed to delete peer!",
text: e.toString(),
type: 'error',
})
+ } finally {
+ isDeleting.value = false
}
}
@@ -283,10 +291,15 @@ async function del() {
-
+
-
+
diff --git a/internal/adapters/wgcontroller/local.go b/internal/adapters/wgcontroller/local.go
index 9e73d21..7f2e7fa 100644
--- a/internal/adapters/wgcontroller/local.go
+++ b/internal/adapters/wgcontroller/local.go
@@ -204,6 +204,11 @@ func (c LocalController) convertWireGuardPeer(peer *wgtypes.Peer) (domain.Physic
ImportSource: domain.ControllerTypeLocal,
}
+ // Set local extras - local peers are never disabled in the kernel
+ peerModel.SetExtras(domain.LocalPeerExtras{
+ Disabled: false,
+ })
+
for _, addr := range peer.AllowedIPs {
peerModel.AllowedIPs = append(peerModel.AllowedIPs, domain.CidrFromIpNet(addr))
}
@@ -410,6 +415,18 @@ func (c LocalController) SavePeer(
return err
}
+ // Check if the peer is disabled by looking at the backend extras
+ // For local controller, disabled peers should be deleted
+ if physicalPeer.GetExtras() != nil {
+ switch extras := physicalPeer.GetExtras().(type) {
+ case domain.LocalPeerExtras:
+ if extras.Disabled {
+ // Delete the peer instead of updating it
+ return c.deletePeer(deviceId, id)
+ }
+ }
+ }
+
if err := c.updatePeer(deviceId, physicalPeer); err != nil {
return err
}
diff --git a/internal/adapters/wgcontroller/mikrotik.go b/internal/adapters/wgcontroller/mikrotik.go
index 647f6ad..085281f 100644
--- a/internal/adapters/wgcontroller/mikrotik.go
+++ b/internal/adapters/wgcontroller/mikrotik.go
@@ -6,8 +6,11 @@ import (
"slices"
"strconv"
"strings"
+ "sync"
"time"
+ "log/slog"
+
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/h44z/wg-portal/internal/lowlevel"
@@ -18,6 +21,10 @@ type MikrotikController struct {
cfg *config.BackendMikrotik
client *lowlevel.MikrotikApiClient
+
+ // Add mutexes to prevent race conditions
+ interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex
+ peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex
}
func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) {
@@ -31,16 +38,31 @@ func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik)
cfg: cfg,
client: client,
+
+ interfaceMutexes: sync.Map{},
+ peerMutexes: sync.Map{},
}, nil
}
-func (c MikrotikController) GetId() domain.InterfaceBackend {
+func (c *MikrotikController) GetId() domain.InterfaceBackend {
return domain.InterfaceBackend(c.cfg.Id)
}
+// getInterfaceMutex returns a mutex for the given interface to prevent concurrent modifications
+func (c *MikrotikController) getInterfaceMutex(id domain.InterfaceIdentifier) *sync.Mutex {
+ mutex, _ := c.interfaceMutexes.LoadOrStore(id, &sync.Mutex{})
+ return mutex.(*sync.Mutex)
+}
+
+// getPeerMutex returns a mutex for the given peer to prevent concurrent modifications
+func (c *MikrotikController) getPeerMutex(id domain.PeerIdentifier) *sync.Mutex {
+ mutex, _ := c.peerMutexes.LoadOrStore(id, &sync.Mutex{})
+ return mutex.(*sync.Mutex)
+}
+
// region wireguard-related
-func (c MikrotikController) GetInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
+func (c *MikrotikController) GetInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running", "comment",
@@ -62,7 +84,7 @@ func (c MikrotikController) GetInterfaces(ctx context.Context) ([]domain.Physica
return interfaces, nil
}
-func (c MikrotikController) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (
+func (c *MikrotikController) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.PhysicalInterface,
error,
) {
@@ -85,7 +107,7 @@ func (c MikrotikController) GetInterface(ctx context.Context, id domain.Interfac
return c.loadInterfaceData(ctx, wgReply.Data[0])
}
-func (c MikrotikController) loadInterfaceData(
+func (c *MikrotikController) loadInterfaceData(
ctx context.Context,
wireGuardObj lowlevel.GenericJsonObject,
) (*domain.PhysicalInterface, error) {
@@ -113,7 +135,7 @@ func (c MikrotikController) loadInterfaceData(
return &interfaceModel, nil
}
-func (c MikrotikController) loadIpAddresses(
+func (c *MikrotikController) loadIpAddresses(
ctx context.Context,
deviceName string,
) (ipv4 []lowlevel.GenericJsonObject, ipv6 []lowlevel.GenericJsonObject, err error) {
@@ -150,7 +172,7 @@ func (c MikrotikController) loadIpAddresses(
return addrV4Reply.Data, addrV6Reply.Data, nil
}
-func (c MikrotikController) convertIpAddresses(
+func (c *MikrotikController) convertIpAddresses(
ipv4, ipv6 []lowlevel.GenericJsonObject,
) []domain.Cidr {
addresses := make([]domain.Cidr, 0, len(ipv4)+len(ipv6))
@@ -170,7 +192,7 @@ func (c MikrotikController) convertIpAddresses(
return addresses
}
-func (c MikrotikController) convertWireGuardInterface(
+func (c *MikrotikController) convertWireGuardInterface(
wg, iface lowlevel.GenericJsonObject,
addresses []domain.Cidr,
) (
@@ -203,7 +225,7 @@ func (c MikrotikController) convertWireGuardInterface(
return pi, nil
}
-func (c MikrotikController) GetPeers(ctx context.Context, deviceId domain.InterfaceIdentifier) (
+func (c *MikrotikController) GetPeers(ctx context.Context, deviceId domain.InterfaceIdentifier) (
[]domain.PhysicalPeer,
error,
) {
@@ -237,7 +259,7 @@ func (c MikrotikController) GetPeers(ctx context.Context, deviceId domain.Interf
return peers, nil
}
-func (c MikrotikController) convertWireGuardPeer(peer lowlevel.GenericJsonObject) (
+func (c *MikrotikController) convertWireGuardPeer(peer lowlevel.GenericJsonObject) (
domain.PhysicalPeer,
error,
) {
@@ -300,11 +322,16 @@ func (c MikrotikController) convertWireGuardPeer(peer lowlevel.GenericJsonObject
return peerModel, nil
}
-func (c MikrotikController) SaveInterface(
+func (c *MikrotikController) SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error {
+ // Lock the interface to prevent concurrent modifications
+ mutex := c.getInterfaceMutex(id)
+ mutex.Lock()
+ defer mutex.Unlock()
+
physicalInterface, err := c.getOrCreateInterface(ctx, id)
if err != nil {
return err
@@ -328,7 +355,7 @@ func (c MikrotikController) SaveInterface(
return nil
}
-func (c MikrotikController) getOrCreateInterface(
+func (c *MikrotikController) getOrCreateInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
) (*domain.PhysicalInterface, error) {
@@ -355,7 +382,7 @@ func (c MikrotikController) getOrCreateInterface(
return nil, fmt.Errorf("failed to create interface %s: %v", id, createReply.Error)
}
-func (c MikrotikController) updateInterface(ctx context.Context, pi *domain.PhysicalInterface) error {
+func (c *MikrotikController) updateInterface(ctx context.Context, pi *domain.PhysicalInterface) error {
extras := pi.GetExtras().(domain.MikrotikInterfaceExtras)
interfaceId := extras.Id
wgReply := c.client.Update(ctx, "/interface/wireguard/"+interfaceId, lowlevel.GenericJsonObject{
@@ -403,7 +430,7 @@ func (c MikrotikController) updateInterface(ctx context.Context, pi *domain.Phys
return nil
}
-func (c MikrotikController) updateIpAddresses(
+func (c *MikrotikController) updateIpAddresses(
ctx context.Context,
deviceName string,
currentV4, currentV6 []lowlevel.GenericJsonObject,
@@ -459,7 +486,12 @@ func (c MikrotikController) updateIpAddresses(
return nil
}
-func (c MikrotikController) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
+func (c *MikrotikController) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
+ // Lock the interface to prevent concurrent modifications
+ mutex := c.getInterfaceMutex(id)
+ mutex.Lock()
+ defer mutex.Unlock()
+
// delete the interface's addresses
currentV4, currentV6, err := c.loadIpAddresses(ctx, string(id))
if err != nil {
@@ -494,8 +526,8 @@ func (c MikrotikController) DeleteInterface(ctx context.Context, id domain.Inter
return nil // interface does not exist, nothing to delete
}
- deviceId := wgReply.Data[0].GetString(".id")
- deleteReply := c.client.Delete(ctx, "/interface/wireguard/"+deviceId)
+ interfaceId := wgReply.Data[0].GetString(".id")
+ deleteReply := c.client.Delete(ctx, "/interface/wireguard/"+interfaceId)
if deleteReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to delete WireGuard interface %s: %v", id, deleteReply.Error)
}
@@ -503,12 +535,17 @@ func (c MikrotikController) DeleteInterface(ctx context.Context, id domain.Inter
return nil
}
-func (c MikrotikController) SavePeer(
+func (c *MikrotikController) SavePeer(
ctx context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error {
+ // Lock the peer to prevent concurrent modifications
+ mutex := c.getPeerMutex(id)
+ mutex.Lock()
+ defer mutex.Unlock()
+
physicalPeer, err := c.getOrCreatePeer(ctx, deviceId, id)
if err != nil {
return err
@@ -530,7 +567,7 @@ func (c MikrotikController) SavePeer(
return nil
}
-func (c MikrotikController) getOrCreatePeer(
+func (c *MikrotikController) getOrCreatePeer(
ctx context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
@@ -546,6 +583,7 @@ func (c MikrotikController) getOrCreatePeer(
},
})
if wgReply.Status == lowlevel.MikrotikApiStatusOk && len(wgReply.Data) > 0 {
+ slog.Debug("found existing Mikrotik peer", "peer", id, "interface", deviceId)
existingPeer, err := c.convertWireGuardPeer(wgReply.Data[0])
if err != nil {
return nil, err
@@ -554,24 +592,26 @@ func (c MikrotikController) getOrCreatePeer(
}
// create a new peer if it does not exist
+ slog.Debug("creating new Mikrotik peer", "peer", id, "interface", deviceId)
createReply := c.client.Create(ctx, "/interface/wireguard/peers", lowlevel.GenericJsonObject{
"name": fmt.Sprintf("tmp-wg-%s", id[0:8]),
"interface": string(deviceId),
- "public-key": string(id), // public key will be set later
- "allowed-address": "169.254.254.254/32", // allowed addresses will be set later
+ "public-key": string(id),
+ "allowed-address": "0.0.0.0/0", // Use 0.0.0.0/0 as default, will be updated by updatePeer
})
if createReply.Status == lowlevel.MikrotikApiStatusOk {
newPeer, err := c.convertWireGuardPeer(createReply.Data)
if err != nil {
return nil, err
}
+ slog.Debug("successfully created Mikrotik peer", "peer", id, "interface", deviceId)
return &newPeer, nil
}
return nil, fmt.Errorf("failed to create peer %s for interface %s: %v", id, deviceId, createReply.Error)
}
-func (c MikrotikController) updatePeer(
+func (c *MikrotikController) updatePeer(
ctx context.Context,
deviceId domain.InterfaceIdentifier,
pp *domain.PhysicalPeer,
@@ -586,6 +626,14 @@ func (c MikrotikController) updatePeer(
endpointPort = s[1]
}
+ allowedAddressStr := domain.CidrsToString(pp.AllowedIPs)
+ slog.Debug("updating Mikrotik peer",
+ "peer", pp.Identifier,
+ "interface", deviceId,
+ "allowed-address", allowedAddressStr,
+ "allowed-ips-count", len(pp.AllowedIPs),
+ "disabled", extras.Disabled)
+
wgReply := c.client.Update(ctx, "/interface/wireguard/peers/"+peerId, lowlevel.GenericJsonObject{
"name": extras.Name,
"comment": extras.Comment,
@@ -601,19 +649,31 @@ func (c MikrotikController) updatePeer(
"client-dns": extras.ClientDns,
"endpoint-address": endpoint,
"endpoint-port": endpointPort,
+ "allowed-address": allowedAddressStr, // Add the missing allowed-address field
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to update peer %s on interface %s: %v", pp.Identifier, deviceId, wgReply.Error)
}
+ if extras.Disabled {
+ slog.Debug("successfully disabled Mikrotik peer", "peer", pp.Identifier, "interface", deviceId)
+ } else {
+ slog.Debug("successfully updated Mikrotik peer", "peer", pp.Identifier, "interface", deviceId)
+ }
+
return nil
}
-func (c MikrotikController) DeletePeer(
+func (c *MikrotikController) DeletePeer(
ctx context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
) error {
+ // Lock the peer to prevent concurrent modifications
+ mutex := c.getPeerMutex(id)
+ mutex.Lock()
+ defer mutex.Unlock()
+
wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{
PropList: []string{".id"},
Filters: map[string]string{
@@ -641,17 +701,17 @@ func (c MikrotikController) DeletePeer(
// region wg-quick-related
-func (c MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
+func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
// TODO implement me
panic("implement me")
}
-func (c MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
+func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
// TODO implement me
panic("implement me")
}
-func (c MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error {
+func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error {
// TODO implement me
panic("implement me")
}
@@ -660,12 +720,12 @@ func (c MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error {
// region routing-related
-func (c MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
+func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
// TODO implement me
panic("implement me")
}
-func (c MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
+func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
// TODO implement me
panic("implement me")
}
@@ -674,7 +734,7 @@ func (c MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.R
// region statistics-related
-func (c MikrotikController) PingAddresses(
+func (c *MikrotikController) PingAddresses(
ctx context.Context,
addr string,
) (*domain.PingerResult, error) {
diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go
index ed35cf4..9daa3c7 100644
--- a/internal/app/wireguard/wireguard_peers.go
+++ b/internal/app/wireguard/wireguard_peers.go
@@ -188,29 +188,28 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
sessionUser := domain.GetUserInfo(ctx)
- // Enforce peer limit for non-admin users if LimitAdditionalUserPeers is set
- if m.cfg.Core.SelfProvisioningAllowed && !sessionUser.IsAdmin && m.cfg.Advanced.LimitAdditionalUserPeers > 0 {
- peers, err := m.db.GetUserPeers(ctx, peer.UserIdentifier)
- if err != nil {
- return nil, fmt.Errorf("failed to fetch peers for user %s: %w", peer.UserIdentifier, err)
- }
- // Count enabled peers (disabled IS NULL)
- peerCount := 0
- for _, p := range peers {
- if !p.IsDisabled() {
- peerCount++
- }
- }
- totalAllowedPeers := 1 + m.cfg.Advanced.LimitAdditionalUserPeers // 1 default peer + x additional peers
- if peerCount >= totalAllowedPeers {
- slog.WarnContext(ctx, "peer creation blocked due to limit",
- "user", peer.UserIdentifier,
- "current_count", peerCount,
- "allowed_count", totalAllowedPeers)
- return nil, fmt.Errorf("peer limit reached (%d peers allowed): %w", totalAllowedPeers, domain.ErrNoPermission)
- }
- }
-
+ // Enforce peer limit for non-admin users if LimitAdditionalUserPeers is set
+ if m.cfg.Core.SelfProvisioningAllowed && !sessionUser.IsAdmin && m.cfg.Advanced.LimitAdditionalUserPeers > 0 {
+ peers, err := m.db.GetUserPeers(ctx, peer.UserIdentifier)
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch peers for user %s: %w", peer.UserIdentifier, err)
+ }
+ // Count enabled peers (disabled IS NULL)
+ peerCount := 0
+ for _, p := range peers {
+ if !p.IsDisabled() {
+ peerCount++
+ }
+ }
+ totalAllowedPeers := 1 + m.cfg.Advanced.LimitAdditionalUserPeers // 1 default peer + x additional peers
+ if peerCount >= totalAllowedPeers {
+ slog.WarnContext(ctx, "peer creation blocked due to limit",
+ "user", peer.UserIdentifier,
+ "current_count", peerCount,
+ "allowed_count", totalAllowedPeers)
+ return nil, fmt.Errorf("peer limit reached (%d peers allowed): %w", totalAllowedPeers, domain.ErrNoPermission)
+ }
+ }
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
@@ -449,33 +448,22 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
}
- if peer.IsDisabled() || peer.IsExpired() {
- err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
- peer.CopyCalculatedAttributes(p)
+ // Always save the peer to the backend, regardless of disabled/expired state
+ // The backend will handle the disabled state appropriately
+ err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
+ peer.CopyCalculatedAttributes(p)
- if err := m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier,
- peer.Identifier); err != nil {
- return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
- }
+ err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
+ func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
+ domain.MergeToPhysicalPeer(pp, peer)
+ return pp, nil
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
+ }
- return peer, nil
- })
- } else {
- err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
- peer.CopyCalculatedAttributes(p)
-
- err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
- func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
- domain.MergeToPhysicalPeer(pp, peer)
- return pp, nil
- })
- if err != nil {
- return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
- }
-
- return peer, nil
- })
- }
+ return peer, nil
+ })
if err != nil {
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
}
diff --git a/internal/domain/controller.go b/internal/domain/controller.go
index a94e116..eaefe32 100644
--- a/internal/domain/controller.go
+++ b/internal/domain/controller.go
@@ -26,3 +26,7 @@ type MikrotikPeerExtras struct {
ClientDns string
ClientKeepalive int
}
+
+type LocalPeerExtras struct {
+ Disabled bool
+}
diff --git a/internal/domain/peer.go b/internal/domain/peer.go
index 1da5997..3f65c96 100644
--- a/internal/domain/peer.go
+++ b/internal/domain/peer.go
@@ -236,7 +236,8 @@ func (p *PhysicalPeer) GetExtras() any {
func (p *PhysicalPeer) SetExtras(extras any) {
switch extras.(type) {
case MikrotikPeerExtras: // OK
- default: // we only support MikrotikPeerExtras for now
+ case LocalPeerExtras: // OK
+ default: // we only support MikrotikPeerExtras and LocalPeerExtras for now
panic(fmt.Sprintf("unsupported peer backend extras type %T", extras))
}
@@ -288,6 +289,15 @@ func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
peer.Disabled = nil
peer.DisabledReason = ""
}
+ case ControllerTypeLocal:
+ extras := pp.GetExtras().(LocalPeerExtras)
+ if extras.Disabled {
+ peer.Disabled = &now
+ peer.DisabledReason = "Disabled by Local controller"
+ } else {
+ peer.Disabled = nil
+ peer.DisabledReason = ""
+ }
}
return peer
@@ -326,6 +336,11 @@ func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) {
ClientKeepalive: p.PersistentKeepalive.GetValue(),
}
pp.SetExtras(extras)
+ case ControllerTypeLocal:
+ extras := LocalPeerExtras{
+ Disabled: p.IsDisabled(),
+ }
+ pp.SetExtras(extras)
}
}