Mikrotik integration (#467)
Some checks failed
Docker / Build and Push (push) Has been cancelled
github-pages / deploy (push) Has been cancelled
Docker / release (push) Has been cancelled

Allow MikroTik routes as WireGuard backends
This commit is contained in:
h44z 2025-08-10 14:42:02 +02:00 committed by GitHub
parent a86f83a219
commit 112f6bfb77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 3150 additions and 205 deletions

View File

@ -32,6 +32,7 @@ The configuration portal supports using a database (SQLite, MySQL, MsSQL, or Pos
* Docker ready * Docker ready
* Can be used with existing WireGuard setups * Can be used with existing WireGuard setups
* Support for multiple WireGuard interfaces * Support for multiple WireGuard interfaces
* Supports multiple WireGuard backends (wgctrl or MikroTik [BETA])
* Peer Expiry Feature * Peer Expiry Feature
* Handles route and DNS settings like wg-quick does * Handles route and DNS settings like wg-quick does
* Exposes Prometheus metrics for monitoring and alerting * Exposes Prometheus metrics for monitoring and alerting

View File

@ -50,7 +50,8 @@ func main() {
database, err := adapters.NewSqlRepository(rawDb) database, err := adapters.NewSqlRepository(rawDb)
internal.AssertNoError(err) internal.AssertNoError(err)
wireGuard := adapters.NewWireGuardRepository() wireGuard, err := wireguard.NewControllerManager(cfg)
internal.AssertNoError(err)
wgQuick := adapters.NewWgQuickRepo() wgQuick := adapters.NewWgQuickRepo()
@ -134,7 +135,7 @@ func main() {
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers) apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers)
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces) apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces)
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers) apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth) apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard)
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth) apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
apiFrontend := handlersV0.NewRestApi(apiV0Session, apiFrontend := handlersV0.NewRestApi(apiV0Session,

View File

@ -24,6 +24,9 @@ core:
self_provisioning_allowed: false self_provisioning_allowed: false
import_existing: true import_existing: true
restore_state: true restore_state: true
backend:
default: local
advanced: advanced:
log_level: info log_level: info
@ -102,6 +105,7 @@ webhook:
Below you will find sections like Below you will find sections like
[`core`](#core), [`core`](#core),
[`backend`](#backend),
[`advanced`](#advanced), [`advanced`](#advanced),
[`database`](#database), [`database`](#database),
[`statistics`](#statistics), [`statistics`](#statistics),
@ -165,6 +169,65 @@ More advanced options are found in the subsequent `Advanced` section.
--- ---
## Backend
Configuration options for the WireGuard backend, which manages the WireGuard interfaces and peers.
The current MikroTik backend is in **BETA** and may not support all features.
### `default`
- **Default:** `local`
- **Description:** The default backend to use for managing WireGuard interfaces.
Valid options are: `local`, or other backend id's configured in the `mikrotik` section.
### Mikrotik
The `mikrotik` array contains a list of MikroTik backend definitions. Each entry describes how to connect to a MikroTik RouterOS instance that hosts WireGuard interfaces.
Below are the properties for each entry inside `backend.mikrotik`:
#### `id`
- **Default:** *(empty)*
- **Description:** A unique identifier for this backend.
This value can be referenced by `backend.default` to use this backend as default.
The identifier must be unique across all backends and must not use the reserved keyword `local`.
#### `display_name`
- **Default:** *(empty)*
- **Description:** A human-friendly display name for this backend. If omitted, the `id` will be used as the display name.
#### `api_url`
- **Default:** *(empty)*
- **Description:** Base URL of the MikroTik REST API, including scheme and path, e.g., `https://10.10.10.10:8729/rest`.
#### `api_user`
- **Default:** *(empty)*
- **Description:** Username for authenticating against the MikroTik API.
Ensure that the user has sufficient permissions to manage WireGuard interfaces and peers.
#### `api_password`
- **Default:** *(empty)*
- **Description:** Password for the specified API user.
#### `api_verify_tls`
- **Default:** `false`
- **Description:** Whether to verify the TLS certificate of the MikroTik API endpoint. Set to `false` to allow self-signed certificates (not recommended for production).
#### `api_timeout`
- **Default:** `30s`
- **Description:** Timeout for API requests to the MikroTik device. Uses Go duration format (e.g., `10s`, `1m`). If omitted, a default of 30 seconds is used.
#### `concurrency`
- **Default:** `5`
- **Description:** Maximum number of concurrent API requests the backend will issue when enumerating interfaces and their details. If `0` or negative, a sane default of `5` is used.
#### `debug`
- **Default:** `false`
- **Description:** Enable verbose debug logging for the MikroTik backend.
For more details on configuring the MikroTik backend, see the [Backends](../usage/backends.md) documentation.
---
## Advanced ## Advanced
Additional or more specialized configuration options for logging and interface creation details. Additional or more specialized configuration options for logging and interface creation details.

View File

@ -0,0 +1,57 @@
# Backends
WireGuard Portal can manage WireGuard interfaces and peers on different backends.
Each backend represents a system where interfaces actually live.
You can register multiple backends and choose which one to use per interface.
A global default backend determines where newly created interfaces go (unless you explicitly choose another in the UI).
**Supported backends:**
- **Local** (default): Manages interfaces on the host running WireGuard Portal (Linux WireGuard via wgctrl). Use this when the portal should directly configure wg devices on the same server.
- **MikroTik** RouterOS (_beta_): Manages interfaces and peers on MikroTik devices via the RouterOS REST API. Use this to control WG interfaces on RouterOS v7+.
How backend selection works:
- The default backend is configured at `backend.default` (_local_ or the id of a defined MikroTik backend).
New interfaces created in the UI will use this backend by default.
- Each interface stores its backend. You can select a different backend when creating a new interface.
## Configuring MikroTik backends (RouterOS v7+)
> :warning: The MikroTik backend is currently marked beta. While basic functionality is implemented, some advanced features are not yet implemented or contain bugs. Please test carefully before using in production.
The MikroTik backend uses the [REST API](https://help.mikrotik.com/docs/spaces/ROS/pages/47579162/REST+API) under a base URL ending with /rest.
You can register one or more MikroTik devices as backends for a single WireGuard Portal instance.
### Prerequisites on MikroTik:
- RouterOS v7 with WireGuard support.
- REST API enabled and reachable over HTTP(S). A typical base URL is https://<router-address>:8729/rest or https://<router-address>/rest depending on your service setup.
- A dedicated RouterOS user with the following group permissions:
- **api** (for logging in via REST API)
- **rest-api** (for logging in via REST API)
- **read** (to read interface and peer data)
- **write** (to create/update interfaces and peers)
- **test** (to perform ping checks)
- **sensitive** (to read private keys)
- TLS certificate on the device is recommended. If you use a self-signed certificate during testing, set `api_verify_tls`: _false_ in wg-portal (not recommended for production).
Example WireGuard Portal configuration (config/config.yaml):
```yaml
backend:
# default backend decides where new interfaces are created
default: mikrotik-prod
mikrotik:
- id: mikrotik-prod # unique id, not "local"
display_name: RouterOS RB5009 # optional nice name
api_url: https://10.10.10.10/rest
api_user: wgportal
api_password: a-super-secret-password
api_verify_tls: true # set to false only if using self-signed during testing
api_timeout: 30s # maximum request duration
concurrency: 5 # limit parallel REST calls to device
debug: false # verbose logging for this backend
```
### Known limitations:
- The MikroTik backend is still in beta. Some features may not work as expected.
- Not all WireGuard Portal features are supported yet (e.g., no support for interface hooks)

View File

@ -10,11 +10,13 @@ import isCidr from "is-cidr";
import {isIP} from 'is-ip'; import {isIP} from 'is-ip';
import { freshInterface } from '@/helpers/models'; import { freshInterface } from '@/helpers/models';
import {peerStore} from "@/stores/peers"; import {peerStore} from "@/stores/peers";
import {settingsStore} from "@/stores/settings";
const { t } = useI18n() const { t } = useI18n()
const interfaces = interfaceStore() const interfaces = interfaceStore()
const peers = peerStore() const peers = peerStore()
const settings = settingsStore()
const props = defineProps({ const props = defineProps({
interfaceId: String, interfaceId: String,
@ -48,6 +50,26 @@ const currentTags = ref({
PeerDefDnsSearch: "" PeerDefDnsSearch: ""
}) })
const formData = ref(freshInterface()) const formData = ref(freshInterface())
const isSaving = ref(false)
const isDeleting = ref(false)
const isApplyingDefaults = ref(false)
const isBackendValid = computed(() => {
if (!props.visible || !selectedInterface.value) {
return true // if modal is not visible or no interface is selected, we don't care about backend validity
}
let backendId = selectedInterface.value.Backend
let valid = false
let availableBackends = settings.Setting('AvailableBackends') || []
availableBackends.forEach(backend => {
if (backend.Id === backendId) {
valid = true
}
})
return valid
})
// functions // functions
@ -61,6 +83,7 @@ watch(() => props.visible, async (newValue, oldValue) => {
formData.value.Identifier = interfaces.Prepared.Identifier formData.value.Identifier = interfaces.Prepared.Identifier
formData.value.DisplayName = interfaces.Prepared.DisplayName formData.value.DisplayName = interfaces.Prepared.DisplayName
formData.value.Mode = interfaces.Prepared.Mode formData.value.Mode = interfaces.Prepared.Mode
formData.value.Backend = interfaces.Prepared.Backend
formData.value.PublicKey = interfaces.Prepared.PublicKey formData.value.PublicKey = interfaces.Prepared.PublicKey
formData.value.PrivateKey = interfaces.Prepared.PrivateKey formData.value.PrivateKey = interfaces.Prepared.PrivateKey
@ -99,6 +122,7 @@ watch(() => props.visible, async (newValue, oldValue) => {
formData.value.Identifier = selectedInterface.value.Identifier formData.value.Identifier = selectedInterface.value.Identifier
formData.value.DisplayName = selectedInterface.value.DisplayName formData.value.DisplayName = selectedInterface.value.DisplayName
formData.value.Mode = selectedInterface.value.Mode formData.value.Mode = selectedInterface.value.Mode
formData.value.Backend = selectedInterface.value.Backend
formData.value.PublicKey = selectedInterface.value.PublicKey formData.value.PublicKey = selectedInterface.value.PublicKey
formData.value.PrivateKey = selectedInterface.value.PrivateKey formData.value.PrivateKey = selectedInterface.value.PrivateKey
@ -237,6 +261,8 @@ function handleChangePeerDefDnsSearch(tags) {
} }
async function save() { async function save() {
if (isSaving.value) return
isSaving.value = true
try { try {
if (props.interfaceId!=='#NEW#') { if (props.interfaceId!=='#NEW#') {
await interfaces.UpdateInterface(selectedInterface.value.Identifier, formData.value) await interfaces.UpdateInterface(selectedInterface.value.Identifier, formData.value)
@ -251,6 +277,8 @@ async function save() {
text: e.toString(), text: e.toString(),
type: 'error', type: 'error',
}) })
} finally {
isSaving.value = false
} }
} }
@ -259,6 +287,8 @@ async function applyPeerDefaults() {
return; // do nothing for new interfaces return; // do nothing for new interfaces
} }
if (isApplyingDefaults.value) return
isApplyingDefaults.value = true
try { try {
await interfaces.ApplyPeerDefaults(selectedInterface.value.Identifier, formData.value) await interfaces.ApplyPeerDefaults(selectedInterface.value.Identifier, formData.value)
@ -276,10 +306,14 @@ async function applyPeerDefaults() {
text: e.toString(), text: e.toString(),
type: 'error', type: 'error',
}) })
} finally {
isApplyingDefaults.value = false
} }
} }
async function del() { async function del() {
if (isDeleting.value) return
isDeleting.value = true
try { try {
await interfaces.DeleteInterface(selectedInterface.value.Identifier) await interfaces.DeleteInterface(selectedInterface.value.Identifier)
close() close()
@ -290,6 +324,8 @@ async function del() {
text: e.toString(), text: e.toString(),
type: 'error', type: 'error',
}) })
} finally {
isDeleting.value = false
} }
} }
@ -314,13 +350,22 @@ async function del() {
<label class="form-label mt-4">{{ $t('modals.interface-edit.identifier.label') }}</label> <label class="form-label mt-4">{{ $t('modals.interface-edit.identifier.label') }}</label>
<input v-model="formData.Identifier" class="form-control" :placeholder="$t('modals.interface-edit.identifier.placeholder')" type="text"> <input v-model="formData.Identifier" class="form-control" :placeholder="$t('modals.interface-edit.identifier.placeholder')" type="text">
</div> </div>
<div class="form-group"> <div class="row">
<label class="form-label mt-4">{{ $t('modals.interface-edit.mode.label') }}</label> <div class="form-group col-md-6">
<select v-model="formData.Mode" class="form-select"> <label class="form-label mt-4">{{ $t('modals.interface-edit.mode.label') }}</label>
<option value="server">{{ $t('modals.interface-edit.mode.server') }}</option> <select v-model="formData.Mode" class="form-select">
<option value="client">{{ $t('modals.interface-edit.mode.client') }}</option> <option value="server">{{ $t('modals.interface-edit.mode.server') }}</option>
<option value="any">{{ $t('modals.interface-edit.mode.any') }}</option> <option value="client">{{ $t('modals.interface-edit.mode.client') }}</option>
</select> <option value="any">{{ $t('modals.interface-edit.mode.any') }}</option>
</select>
</div>
<div class="form-group col-md-6">
<label class="form-label mt-4" for="ifaceBackendSelector">{{ $t('modals.interface-edit.backend.label') }}</label>
<select id="ifaceBackendSelector" v-model="formData.Backend" class="form-select" aria-describedby="backendHelp">
<option v-for="backend in settings.Setting('AvailableBackends')" :value="backend.Id">{{ backend.Id === 'local' ? $t(backend.Name) : backend.Name }}</option>
</select>
<small v-if="!isBackendValid" id="backendHelp" class="form-text text-warning">{{ $t('modals.interface-edit.backend.invalid-label') }}</small>
</div>
</div> </div>
<div class="form-group"> <div class="form-group">
<label class="form-label mt-4">{{ $t('modals.interface-edit.display-name.label') }}</label> <label class="form-label mt-4">{{ $t('modals.interface-edit.display-name.label') }}</label>
@ -385,12 +430,14 @@ async function del() {
<label class="form-label mt-4">{{ $t('modals.interface-edit.mtu.label') }}</label> <label class="form-label mt-4">{{ $t('modals.interface-edit.mtu.label') }}</label>
<input v-model="formData.Mtu" class="form-control" :placeholder="$t('modals.interface-edit.mtu.placeholder')" type="number"> <input v-model="formData.Mtu" class="form-control" :placeholder="$t('modals.interface-edit.mtu.placeholder')" type="number">
</div> </div>
<div class="form-group col-md-6"> <div class="form-group col-md-6" v-if="formData.Backend==='local'">
<label class="form-label mt-4">{{ $t('modals.interface-edit.firewall-mark.label') }}</label> <label class="form-label mt-4">{{ $t('modals.interface-edit.firewall-mark.label') }}</label>
<input v-model="formData.FirewallMark" class="form-control" :placeholder="$t('modals.interface-edit.firewall-mark.placeholder')" type="number"> <input v-model="formData.FirewallMark" class="form-control" :placeholder="$t('modals.interface-edit.firewall-mark.placeholder')" type="number">
</div> </div>
<div class="form-group col-md-6" v-else>
</div>
</div> </div>
<div class="row"> <div class="row" v-if="formData.Backend==='local'">
<div class="form-group col-md-6"> <div class="form-group col-md-6">
<label class="form-label mt-4">{{ $t('modals.interface-edit.routing-table.label') }}</label> <label class="form-label mt-4">{{ $t('modals.interface-edit.routing-table.label') }}</label>
<input v-model="formData.RoutingTable" aria-describedby="routingTableHelp" class="form-control" :placeholder="$t('modals.interface-edit.routing-table.placeholder')" type="text"> <input v-model="formData.RoutingTable" aria-describedby="routingTableHelp" class="form-control" :placeholder="$t('modals.interface-edit.routing-table.placeholder')" type="text">
@ -530,16 +577,25 @@ async function del() {
</fieldset> </fieldset>
<fieldset v-if="props.interfaceId!=='#NEW#'" class="text-end"> <fieldset v-if="props.interfaceId!=='#NEW#'" class="text-end">
<hr class="mt-4"> <hr class="mt-4">
<button class="btn btn-primary me-1" type="button" @click.prevent="applyPeerDefaults">{{ $t('modals.interface-edit.button-apply-defaults') }}</button> <button class="btn btn-primary me-1" type="button" @click.prevent="applyPeerDefaults" :disabled="isApplyingDefaults">
<span v-if="isApplyingDefaults" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
{{ $t('modals.interface-edit.button-apply-defaults') }}
</button>
</fieldset> </fieldset>
</div> </div>
</div> </div>
</template> </template>
<template #footer> <template #footer>
<div class="flex-fill text-start"> <div class="flex-fill text-start">
<button v-if="props.interfaceId!=='#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del">{{ $t('general.delete') }}</button> <button v-if="props.interfaceId!=='#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del" :disabled="isDeleting">
<span v-if="isDeleting" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
{{ $t('general.delete') }}
</button>
</div> </div>
<button class="btn btn-primary me-1" type="button" @click.prevent="save">{{ $t('general.save') }}</button> <button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="isSaving">
<span v-if="isSaving" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
{{ $t('general.save') }}
</button>
<button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button> <button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button>
</template> </template>
</Modal> </Modal>

View File

@ -73,6 +73,8 @@ const currentTags = ref({
DnsSearch: "" DnsSearch: ""
}) })
const formData = ref(freshPeer()) const formData = ref(freshPeer())
const isSaving = ref(false)
const isDeleting = ref(false)
// functions // functions
@ -270,6 +272,8 @@ function handleChangeDnsSearch(tags) {
} }
async function save() { async function save() {
if (isSaving.value) return
isSaving.value = true
try { try {
if (props.peerId !== '#NEW#') { if (props.peerId !== '#NEW#') {
await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value) await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value)
@ -278,26 +282,30 @@ async function save() {
} }
close() close()
} catch (e) { } catch (e) {
// console.log(e)
notify({ notify({
title: "Failed to save peer!", title: "Failed to save peer!",
text: e.toString(), text: e.toString(),
type: 'error', type: 'error',
}) })
} finally {
isSaving.value = false
} }
} }
async function del() { async function del() {
if (isDeleting.value) return
isDeleting.value = true
try { try {
await peers.DeletePeer(selectedPeer.value.Identifier) await peers.DeletePeer(selectedPeer.value.Identifier)
close() close()
} catch (e) { } catch (e) {
// console.log(e)
notify({ notify({
title: "Failed to delete peer!", title: "Failed to delete peer!",
text: e.toString(), text: e.toString(),
type: 'error', type: 'error',
}) })
} finally {
isDeleting.value = false
} }
} }
@ -470,10 +478,15 @@ async function del() {
</template> </template>
<template #footer> <template #footer>
<div class="flex-fill text-start"> <div class="flex-fill text-start">
<button v-if="props.peerId !== '#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del">{{ <button v-if="props.peerId !== '#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del" :disabled="isDeleting">
$t('general.delete') }}</button> <span v-if="isDeleting" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
{{ $t('general.delete') }}
</button>
</div> </div>
<button class="btn btn-primary me-1" type="button" @click.prevent="save">{{ $t('general.save') }}</button> <button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="isSaving">
<span v-if="isSaving" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
{{ $t('general.save') }}
</button>
<button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button> <button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button>
</template> </template>
</Modal> </Modal>

View File

@ -38,6 +38,7 @@ function freshForm() {
const currentTag = ref("") const currentTag = ref("")
const formData = ref(freshForm()) const formData = ref(freshForm())
const isSaving = ref(false)
const title = computed(() => { const title = computed(() => {
if (!props.visible) { if (!props.visible) {
@ -60,12 +61,15 @@ function handleChangeUserIdentifiers(tags) {
} }
async function save() { async function save() {
if (isSaving.value) return
isSaving.value = true
if (formData.value.Identifiers.length === 0) { if (formData.value.Identifiers.length === 0) {
notify({ notify({
title: "Missing Identifiers", title: "Missing Identifiers",
text: "At least one identifier is required to create a new peer.", text: "At least one identifier is required to create a new peer.",
type: 'error', type: 'error',
}) })
isSaving.value = false
return return
} }
@ -79,6 +83,8 @@ async function save() {
text: e.toString(), text: e.toString(),
type: 'error', type: 'error',
}) })
} finally {
isSaving.value = false
} }
} }
@ -108,7 +114,10 @@ async function save() {
</fieldset> </fieldset>
</template> </template>
<template #footer> <template #footer>
<button class="btn btn-primary me-1" type="button" @click.prevent="save">{{ $t('general.save') }}</button> <button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="isSaving">
<span v-if="isSaving" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
{{ $t('general.save') }}
</button>
<button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button> <button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button>
</template> </template>
</Modal> </Modal>

View File

@ -34,6 +34,8 @@ const title = computed(() => {
}) })
const formData = ref(freshUser()) const formData = ref(freshUser())
const isSaving = ref(false)
const isDeleting = ref(false)
const passwordWeak = computed(() => { const passwordWeak = computed(() => {
return formData.value.Password && formData.value.Password.length > 0 && formData.value.Password.length < settings.Setting('MinPasswordLength') return formData.value.Password && formData.value.Password.length > 0 && formData.value.Password.length < settings.Setting('MinPasswordLength')
@ -89,6 +91,8 @@ function close() {
} }
async function save() { async function save() {
if (isSaving.value) return
isSaving.value = true
try { try {
if (props.userId!=='#NEW#') { if (props.userId!=='#NEW#') {
await users.UpdateUser(selectedUser.value.Identifier, formData.value) await users.UpdateUser(selectedUser.value.Identifier, formData.value)
@ -102,10 +106,14 @@ async function save() {
text: e.toString(), text: e.toString(),
type: 'error', type: 'error',
}) })
} finally {
isSaving.value = false
} }
} }
async function del() { async function del() {
if (isDeleting.value) return
isDeleting.value = true
try { try {
await users.DeleteUser(selectedUser.value.Identifier) await users.DeleteUser(selectedUser.value.Identifier)
close() close()
@ -115,6 +123,8 @@ async function del() {
text: e.toString(), text: e.toString(),
type: 'error', type: 'error',
}) })
} finally {
isDeleting.value = false
} }
} }
@ -193,9 +203,15 @@ async function del() {
</template> </template>
<template #footer> <template #footer>
<div class="flex-fill text-start"> <div class="flex-fill text-start">
<button v-if="props.userId!=='#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del">{{ $t('general.delete') }}</button> <button v-if="props.userId!=='#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del" :disabled="isDeleting">
<span v-if="isDeleting" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
{{ $t('general.delete') }}
</button>
</div> </div>
<button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="!formValid">{{ $t('general.save') }}</button> <button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="!formValid || isSaving">
<span v-if="isSaving" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
{{ $t('general.save') }}
</button>
<button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button> <button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button>
</template> </template>
</Modal> </Modal>

View File

@ -55,6 +55,8 @@ const title = computed(() => {
}) })
const formData = ref(freshPeer()) const formData = ref(freshPeer())
const isSaving = ref(false)
const isDeleting = ref(false)
// functions // functions
@ -163,6 +165,8 @@ function close() {
} }
async function save() { async function save() {
if (isSaving.value) return
isSaving.value = true
try { try {
if (props.peerId !== '#NEW#') { if (props.peerId !== '#NEW#') {
await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value) await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value)
@ -171,26 +175,30 @@ async function save() {
} }
close() close()
} catch (e) { } catch (e) {
// console.log(e)
notify({ notify({
title: "Failed to save peer!", title: "Failed to save peer!",
text: e.toString(), text: e.toString(),
type: 'error', type: 'error',
}) })
} finally {
isSaving.value = false
} }
} }
async function del() { async function del() {
if (isDeleting.value) return
isDeleting.value = true
try { try {
await peers.DeletePeer(selectedPeer.value.Identifier) await peers.DeletePeer(selectedPeer.value.Identifier)
close() close()
} catch (e) { } catch (e) {
// console.log(e)
notify({ notify({
title: "Failed to delete peer!", title: "Failed to delete peer!",
text: e.toString(), text: e.toString(),
type: 'error', type: 'error',
}) })
} finally {
isDeleting.value = false
} }
} }
@ -283,10 +291,15 @@ async function del() {
</template> </template>
<template #footer> <template #footer>
<div class="flex-fill text-start"> <div class="flex-fill text-start">
<button v-if="props.peerId !== '#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del">{{ <button v-if="props.peerId !== '#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del" :disabled="isDeleting">
$t('general.delete') }}</button> <span v-if="isDeleting" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
{{ $t('general.delete') }}
</button>
</div> </div>
<button class="btn btn-primary me-1" type="button" @click.prevent="save">{{ $t('general.save') }}</button> <button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="isSaving">
<span v-if="isSaving" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
{{ $t('general.save') }}
</button>
<button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button> <button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button>
</template> </template>
</Modal> </Modal>

View File

@ -5,6 +5,7 @@ export function freshInterface() {
DisplayName: "", DisplayName: "",
Identifier: "", Identifier: "",
Mode: "server", Mode: "server",
Backend: "local",
PublicKey: "", PublicKey: "",
PrivateKey: "", PrivateKey: "",

View File

@ -102,7 +102,9 @@
}, },
"interface": { "interface": {
"headline": "Schnittstellenstatus für", "headline": "Schnittstellenstatus für",
"mode": "Modus", "backend": "Backend",
"unknown-backend": "Unbekannt",
"wrong-backend": "Ungültiges Backend, das lokale WireGuard Backend wird stattdessen verwendet!",
"key": "Öffentlicher Schlüssel", "key": "Öffentlicher Schlüssel",
"endpoint": "Öffentlicher Endpunkt", "endpoint": "Öffentlicher Endpunkt",
"port": "Port", "port": "Port",
@ -357,6 +359,11 @@
"client": "Client-Modus", "client": "Client-Modus",
"any": "Unbekannter Modus" "any": "Unbekannter Modus"
}, },
"backend": {
"label": "Schnittstellenbackend",
"invalid-label": "Ursprüngliches Backend ist ungültig, das lokale WireGuard Backend wird stattdessen verwendet!",
"local": "Lokales WireGuard Backend"
},
"display-name": { "display-name": {
"label": "Anzeigename", "label": "Anzeigename",
"placeholder": "Der beschreibende Name für die Schnittstelle" "placeholder": "Der beschreibende Name für die Schnittstelle"

View File

@ -102,7 +102,9 @@
}, },
"interface": { "interface": {
"headline": "Interface status for", "headline": "Interface status for",
"mode": "mode", "backend": "Backend",
"unknown-backend": "Unknown",
"wrong-backend": "Invalid backend, using local WireGuard backend instead!",
"key": "Public Key", "key": "Public Key",
"endpoint": "Public Endpoint", "endpoint": "Public Endpoint",
"port": "Listening Port", "port": "Listening Port",
@ -357,6 +359,11 @@
"client": "Client Mode", "client": "Client Mode",
"any": "Unknown Mode" "any": "Unknown Mode"
}, },
"backend": {
"label": "Interface Backend",
"invalid-label": "Original backend is no longer available, using local WireGuard backend instead!",
"local": "Local WireGuard Backend"
},
"display-name": { "display-name": {
"label": "Display Name", "label": "Display Name",
"placeholder": "The descriptive name for the interface" "placeholder": "The descriptive name for the interface"

View File

@ -99,7 +99,7 @@
}, },
"interface": { "interface": {
"headline": "État de l'interface pour", "headline": "État de l'interface pour",
"mode": "mode", "backend": "backend",
"key": "Clé publique", "key": "Clé publique",
"endpoint": "Point de terminaison public", "endpoint": "Point de terminaison public",
"port": "Port d'écoute", "port": "Port d'écoute",

View File

@ -100,7 +100,7 @@
}, },
"interface": { "interface": {
"headline": "인터페이스 상태:", "headline": "인터페이스 상태:",
"mode": "모드", "backend": "백엔드",
"key": "공개 키", "key": "공개 키",
"endpoint": "공개 엔드포인트", "endpoint": "공개 엔드포인트",
"port": "수신 포트", "port": "수신 포트",

View File

@ -101,7 +101,7 @@
}, },
"interface": { "interface": {
"headline": "Status da interface para", "headline": "Status da interface para",
"mode": "modo", "mode": "backend",
"key": "Chave Pública", "key": "Chave Pública",
"endpoint": "Endpoint Público", "endpoint": "Endpoint Público",
"port": "Porta de Escuta", "port": "Porta de Escuta",

View File

@ -99,7 +99,7 @@
}, },
"interface": { "interface": {
"headline": "Статус интерфейса для", "headline": "Статус интерфейса для",
"mode": "режим", "backend": "бэкэнд",
"key": "Публичный ключ", "key": "Публичный ключ",
"endpoint": "Публичная конечная точка", "endpoint": "Публичная конечная точка",
"port": "Порт прослушивания", "port": "Порт прослушивания",

View File

@ -99,7 +99,7 @@
}, },
"interface": { "interface": {
"headline": "Статус інтерфейсу для", "headline": "Статус інтерфейсу для",
"mode": "режим", "backend": "бекенд",
"key": "Публічний ключ", "key": "Публічний ключ",
"endpoint": "Публічна кінцева точка", "endpoint": "Публічна кінцева точка",
"port": "Порт прослуховування", "port": "Порт прослуховування",

View File

@ -98,7 +98,7 @@
}, },
"interface": { "interface": {
"headline": "Trạng thái giao diện cho", "headline": "Trạng thái giao diện cho",
"mode": "chế độ", "backend": "phần sau",
"key": "Khóa Công khai", "key": "Khóa Công khai",
"endpoint": "Điểm cuối Công khai", "endpoint": "Điểm cuối Công khai",
"port": "Cổng Nghe", "port": "Cổng Nghe",

View File

@ -98,7 +98,7 @@
}, },
"interface": { "interface": {
"headline": "接口状态", "headline": "接口状态",
"mode": "模式", "backend": "后端",
"key": "公钥", "key": "公钥",
"endpoint": "公开节点", "endpoint": "公开节点",
"port": "监听端口", "port": "监听端口",

View File

@ -5,17 +5,20 @@ import PeerMultiCreateModal from "../components/PeerMultiCreateModal.vue";
import InterfaceEditModal from "../components/InterfaceEditModal.vue"; import InterfaceEditModal from "../components/InterfaceEditModal.vue";
import InterfaceViewModal from "../components/InterfaceViewModal.vue"; import InterfaceViewModal from "../components/InterfaceViewModal.vue";
import {onMounted, ref} from "vue"; import {computed, onMounted, ref} from "vue";
import {peerStore} from "@/stores/peers"; import {peerStore} from "@/stores/peers";
import {interfaceStore} from "@/stores/interfaces"; import {interfaceStore} from "@/stores/interfaces";
import {notify} from "@kyvg/vue3-notification"; import {notify} from "@kyvg/vue3-notification";
import {settingsStore} from "@/stores/settings"; import {settingsStore} from "@/stores/settings";
import {humanFileSize} from '@/helpers/utils'; import {humanFileSize} from '@/helpers/utils';
import {useI18n} from "vue-i18n";
const settings = settingsStore() const settings = settingsStore()
const interfaces = interfaceStore() const interfaces = interfaceStore()
const peers = peerStore() const peers = peerStore()
const { t } = useI18n()
const viewedPeerId = ref("") const viewedPeerId = ref("")
const editPeerId = ref("") const editPeerId = ref("")
const multiCreatePeerId = ref("") const multiCreatePeerId = ref("")
@ -45,6 +48,33 @@ function calculateInterfaceName(id, name) {
return result return result
} }
const calculateBackendName = computed(() => {
let backendId = interfaces.GetSelected.Backend
let backendName = t('interfaces.interface.unknown-backend')
let availableBackends = settings.Setting('AvailableBackends') || []
availableBackends.forEach(backend => {
if (backend.Id === backendId) {
backendName = backend.Id === 'local' ? t(backend.Name) : backend.Name
}
})
return backendName
})
const isBackendValid = computed(() => {
let backendId = interfaces.GetSelected.Backend
let valid = false
let availableBackends = settings.Setting('AvailableBackends') || []
availableBackends.forEach(backend => {
if (backend.Id === backendId) {
valid = true
}
})
return valid
})
async function download() { async function download() {
await interfaces.LoadInterfaceConfig(interfaces.GetSelected.Identifier) await interfaces.LoadInterfaceConfig(interfaces.GetSelected.Identifier)
@ -141,7 +171,7 @@ onMounted(async () => {
<div class="card-header"> <div class="card-header">
<div class="row"> <div class="row">
<div class="col-12 col-lg-8"> <div class="col-12 col-lg-8">
{{ $t('interfaces.interface.headline') }} <strong>{{interfaces.GetSelected.Identifier}}</strong> ({{interfaces.GetSelected.Mode}} {{ $t('interfaces.interface.mode') }}) {{ $t('interfaces.interface.headline') }} <strong>{{interfaces.GetSelected.Identifier}}</strong> ({{ $t('modals.interface-edit.mode.' + interfaces.GetSelected.Mode )}} | {{ $t('interfaces.interface.backend') + ": " + calculateBackendName }}<span v-if="!isBackendValid" :title="t('interfaces.interface.wrong-backend')" class="ms-1 me-1"><i class="fa-solid fa-triangle-exclamation"></i></span>)
<span v-if="interfaces.GetSelected.Disabled" class="text-danger"><i class="fa fa-circle-xmark" :title="interfaces.GetSelected.DisabledReason"></i></span> <span v-if="interfaces.GetSelected.Disabled" class="text-danger"><i class="fa fa-circle-xmark" :title="interfaces.GetSelected.DisabledReason"></i></span>
</div> </div>
<div class="col-12 col-lg-4 text-lg-end"> <div class="col-12 col-lg-4 text-lg-end">

View File

@ -0,0 +1,864 @@
package wgcontroller
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"strings"
"time"
probing "github.com/prometheus-community/pro-bing"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/h44z/wg-portal/internal/lowlevel"
)
// region dependencies
// WgCtrlRepo is used to control local WireGuard devices via the wgctrl-go library.
type WgCtrlRepo interface {
io.Closer
Devices() ([]*wgtypes.Device, error)
Device(name string) (*wgtypes.Device, error)
ConfigureDevice(name string, cfg wgtypes.Config) error
}
// A NetlinkClient is a type which can control a netlink device.
type NetlinkClient interface {
LinkAdd(link netlink.Link) error
LinkDel(link netlink.Link) error
LinkByName(name string) (netlink.Link, error)
LinkSetUp(link netlink.Link) error
LinkSetDown(link netlink.Link) error
LinkSetMTU(link netlink.Link, mtu int) error
AddrReplace(link netlink.Link, addr *netlink.Addr) error
AddrAdd(link netlink.Link, addr *netlink.Addr) error
AddrList(link netlink.Link) ([]netlink.Addr, error)
AddrDel(link netlink.Link, addr *netlink.Addr) error
RouteAdd(route *netlink.Route) error
RouteDel(route *netlink.Route) error
RouteReplace(route *netlink.Route) error
RouteList(link netlink.Link, family int) ([]netlink.Route, error)
RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error)
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
RuleList(family int) ([]netlink.Rule, error)
}
// endregion dependencies
type LocalController struct {
cfg *config.Config
wg WgCtrlRepo
nl NetlinkClient
shellCmd string
resolvConfIfacePrefix string
}
// NewLocalController creates a new local controller instance.
// This repository is used to interact with the WireGuard kernel or userspace module.
func NewLocalController(cfg *config.Config) (*LocalController, error) {
wg, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("failed to create wgctrl client: %w", err)
}
nl := &lowlevel.NetlinkManager{}
repo := &LocalController{
cfg: cfg,
wg: wg,
nl: nl,
shellCmd: "bash", // we only support bash at the moment
resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf
}
return repo, nil
}
func (c LocalController) GetId() domain.InterfaceBackend {
return config.LocalBackendName
}
// region wireguard-related
func (c LocalController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
devices, err := c.wg.Devices()
if err != nil {
return nil, fmt.Errorf("device list error: %w", err)
}
interfaces := make([]domain.PhysicalInterface, 0, len(devices))
for _, device := range devices {
interfaceModel, err := c.convertWireGuardInterface(device)
if err != nil {
return nil, fmt.Errorf("interface convert failed for %s: %w", device.Name, err)
}
interfaces = append(interfaces, interfaceModel)
}
return interfaces, nil
}
func (c LocalController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (
*domain.PhysicalInterface,
error,
) {
return c.getInterface(id)
}
func (c LocalController) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) {
// read data from wgctrl interface
iface := domain.PhysicalInterface{
Identifier: domain.InterfaceIdentifier(device.Name),
KeyPair: domain.KeyPair{
PrivateKey: device.PrivateKey.String(),
PublicKey: device.PublicKey.String(),
},
ListenPort: device.ListenPort,
Addresses: nil,
Mtu: 0,
FirewallMark: uint32(device.FirewallMark),
DeviceUp: false,
ImportSource: domain.ControllerTypeLocal,
DeviceType: device.Type.String(),
BytesUpload: 0,
BytesDownload: 0,
}
// read data from netlink interface
lowLevelInterface, err := c.nl.LinkByName(device.Name)
if err != nil {
return domain.PhysicalInterface{}, fmt.Errorf("netlink error for %s: %w", device.Name, err)
}
ipAddresses, err := c.nl.AddrList(lowLevelInterface)
if err != nil {
return domain.PhysicalInterface{}, fmt.Errorf("ip read error for %s: %w", device.Name, err)
}
for _, addr := range ipAddresses {
iface.Addresses = append(iface.Addresses, domain.CidrFromNetlinkAddr(addr))
}
iface.Mtu = lowLevelInterface.Attrs().MTU
iface.DeviceUp = lowLevelInterface.Attrs().OperState == netlink.OperUnknown // wg only supports unknown
if stats := lowLevelInterface.Attrs().Statistics; stats != nil {
iface.BytesUpload = stats.TxBytes
iface.BytesDownload = stats.RxBytes
}
return iface, nil
}
func (c LocalController) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) (
[]domain.PhysicalPeer,
error,
) {
device, err := c.wg.Device(string(deviceId))
if err != nil {
return nil, fmt.Errorf("device error: %w", err)
}
peers := make([]domain.PhysicalPeer, 0, len(device.Peers))
for _, peer := range device.Peers {
peerModel, err := c.convertWireGuardPeer(&peer)
if err != nil {
return nil, fmt.Errorf("peer convert failed for %v: %w", peer.PublicKey, err)
}
peers = append(peers, peerModel)
}
return peers, nil
}
func (c LocalController) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) {
peerModel := domain.PhysicalPeer{
Identifier: domain.PeerIdentifier(peer.PublicKey.String()),
Endpoint: "",
AllowedIPs: nil,
KeyPair: domain.KeyPair{
PublicKey: peer.PublicKey.String(),
},
PresharedKey: "",
PersistentKeepalive: int(peer.PersistentKeepaliveInterval.Seconds()),
LastHandshake: peer.LastHandshakeTime,
ProtocolVersion: peer.ProtocolVersion,
BytesUpload: uint64(peer.ReceiveBytes),
BytesDownload: uint64(peer.TransmitBytes),
ImportSource: domain.ControllerTypeLocal,
}
// Set local extras - local peers are never disabled in the kernel
peerModel.SetExtras(domain.LocalPeerExtras{
Disabled: false,
})
for _, addr := range peer.AllowedIPs {
peerModel.AllowedIPs = append(peerModel.AllowedIPs, domain.CidrFromIpNet(addr))
}
if peer.Endpoint != nil {
peerModel.Endpoint = peer.Endpoint.String()
}
if peer.PresharedKey != (wgtypes.Key{}) {
peerModel.PresharedKey = domain.PreSharedKey(peer.PresharedKey.String())
}
return peerModel, nil
}
func (c LocalController) SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error {
physicalInterface, err := c.getOrCreateInterface(id)
if err != nil {
return err
}
if updateFunc != nil {
physicalInterface, err = updateFunc(physicalInterface)
if err != nil {
return err
}
}
if err := c.updateLowLevelInterface(physicalInterface); err != nil {
return err
}
if err := c.updateWireGuardInterface(physicalInterface); err != nil {
return err
}
return nil
}
func (c LocalController) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
device, err := c.getInterface(id)
if err == nil {
return device, nil // interface exists
}
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("device error: %w", err) // unknown error
}
// create new device
if err := c.createLowLevelInterface(id); err != nil {
return nil, err
}
device, err = c.getInterface(id)
return device, err
}
func (c LocalController) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
device, err := c.wg.Device(string(id))
if err != nil {
return nil, err
}
pi, err := c.convertWireGuardInterface(device)
return &pi, err
}
func (c LocalController) createLowLevelInterface(id domain.InterfaceIdentifier) error {
link := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{
Name: string(id),
},
LinkType: "wireguard",
}
err := c.nl.LinkAdd(link)
if err != nil {
return fmt.Errorf("link add failed: %w", err)
}
return nil
}
func (c LocalController) updateLowLevelInterface(pi *domain.PhysicalInterface) error {
link, err := c.nl.LinkByName(string(pi.Identifier))
if err != nil {
return err
}
if pi.Mtu != 0 {
if err := c.nl.LinkSetMTU(link, pi.Mtu); err != nil {
return fmt.Errorf("mtu error: %w", err)
}
}
for _, addr := range pi.Addresses {
err := c.nl.AddrReplace(link, addr.NetlinkAddr())
if err != nil {
return fmt.Errorf("failed to set ip %s: %w", addr.String(), err)
}
}
// Remove unwanted IP addresses
rawAddresses, err := c.nl.AddrList(link)
if err != nil {
return fmt.Errorf("failed to fetch interface ips: %w", err)
}
for _, rawAddr := range rawAddresses {
netlinkAddr := domain.CidrFromNetlinkAddr(rawAddr)
remove := true
for _, addr := range pi.Addresses {
if addr == netlinkAddr {
remove = false
break
}
}
if !remove {
continue
}
err := c.nl.AddrDel(link, &rawAddr)
if err != nil {
return fmt.Errorf("failed to remove deprecated ip %s: %w", netlinkAddr.String(), err)
}
}
// Update link state
if pi.DeviceUp {
if err := c.nl.LinkSetUp(link); err != nil {
return fmt.Errorf("failed to bring up device: %w", err)
}
} else {
if err := c.nl.LinkSetDown(link); err != nil {
return fmt.Errorf("failed to bring down device: %w", err)
}
}
return nil
}
func (c LocalController) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes())
if err != nil {
return err
}
var fwMark *int
if pi.FirewallMark != 0 {
intFwMark := int(pi.FirewallMark)
fwMark = &intFwMark
}
err = c.wg.ConfigureDevice(string(pi.Identifier), wgtypes.Config{
PrivateKey: &pKey,
ListenPort: &pi.ListenPort,
FirewallMark: fwMark,
ReplacePeers: false,
})
if err != nil {
return err
}
return nil
}
func (c LocalController) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
if err := c.deleteLowLevelInterface(id); err != nil {
return err
}
return nil
}
func (c LocalController) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
link, err := c.nl.LinkByName(string(id))
if err != nil {
var linkNotFoundError netlink.LinkNotFoundError
if errors.As(err, &linkNotFoundError) {
return nil // ignore not found error
}
return fmt.Errorf("unable to find low level interface: %w", err)
}
err = c.nl.LinkDel(link)
if err != nil {
return fmt.Errorf("failed to delete low level interface: %w", err)
}
return nil
}
func (c LocalController) SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error {
physicalPeer, err := c.getOrCreatePeer(deviceId, id)
if err != nil {
return err
}
physicalPeer, err = updateFunc(physicalPeer)
if err != nil {
return err
}
// Check if the peer is disabled by looking at the backend extras
// For local controller, disabled peers should be deleted
if physicalPeer.GetExtras() != nil {
switch extras := physicalPeer.GetExtras().(type) {
case domain.LocalPeerExtras:
if extras.Disabled {
// Delete the peer instead of updating it
return c.deletePeer(deviceId, id)
}
}
}
if err := c.updatePeer(deviceId, physicalPeer); err != nil {
return err
}
return nil
}
func (c LocalController) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
*domain.PhysicalPeer,
error,
) {
peer, err := c.getPeer(deviceId, id)
if err == nil {
return peer, nil // peer exists
}
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("peer error: %w", err) // unknown error
}
// create new peer
err = c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{
Peers: []wgtypes.PeerConfig{
{
PublicKey: id.ToPublicKey(),
},
},
})
if err != nil {
return nil, fmt.Errorf("peer create error for %s: %w", id.ToPublicKey(), err)
}
peer, err = c.getPeer(deviceId, id)
if err != nil {
return nil, fmt.Errorf("peer error after create: %w", err)
}
return peer, nil
}
func (c LocalController) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
*domain.PhysicalPeer,
error,
) {
if !id.IsPublicKey() {
return nil, errors.New("invalid public key")
}
device, err := c.wg.Device(string(deviceId))
if err != nil {
return nil, err
}
publicKey := id.ToPublicKey()
for _, peer := range device.Peers {
if peer.PublicKey != publicKey {
continue
}
peerModel, err := c.convertWireGuardPeer(&peer)
return &peerModel, err
}
return nil, os.ErrNotExist
}
func (c LocalController) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error {
cfg := wgtypes.PeerConfig{
PublicKey: pp.GetPublicKey(),
Remove: false,
UpdateOnly: true,
PresharedKey: pp.GetPresharedKey(),
Endpoint: pp.GetEndpointAddress(),
PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(),
ReplaceAllowedIPs: true,
AllowedIPs: pp.GetAllowedIPs(),
}
err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
if err != nil {
return err
}
return nil
}
func (c LocalController) DeletePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
) error {
if !id.IsPublicKey() {
return errors.New("invalid public key")
}
err := c.deletePeer(deviceId, id)
if err != nil {
return err
}
return nil
}
func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
cfg := wgtypes.PeerConfig{
PublicKey: id.ToPublicKey(),
Remove: true,
}
err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
if err != nil {
return err
}
return nil
}
// endregion wireguard-related
// region wg-quick-related
func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
if hookCmd == "" {
return nil
}
slog.Debug("executing interface hook", "interface", id, "hook", hookCmd)
err := c.exec(hookCmd, id)
if err != nil {
return fmt.Errorf("failed to exec hook: %w", err)
}
return nil
}
func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
if dnsStr == "" && dnsSearchStr == "" {
return nil
}
dnsServers := internal.SliceString(dnsStr)
dnsSearchDomains := internal.SliceString(dnsSearchStr)
dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
for _, dnsServer := range dnsServers {
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
}
for _, searchDomain := range dnsSearchDomains {
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
}
err := c.exec(dnsCommand, id, dnsCommandInput...)
if err != nil {
return fmt.Errorf(
"failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w",
err,
)
}
return nil
}
func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error {
dnsCommand := "resolvconf -d %resPref%i -f"
err := c.exec(dnsCommand, id)
if err != nil {
return fmt.Errorf("failed to unset dns settings: %w", err)
}
return nil
}
func (c LocalController) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
command = strings.ReplaceAll(command, "%resPref", c.resolvConfIfacePrefix)
return strings.ReplaceAll(command, "%i", string(interfaceId))
}
func (c LocalController) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
commandWithInterfaceName := c.replaceCommandPlaceHolders(command, interfaceId)
cmd := exec.Command(c.shellCmd, "-ce", commandWithInterfaceName)
if len(stdin) > 0 {
b := &bytes.Buffer{}
for _, ln := range stdin {
if _, err := fmt.Fprint(b, ln); err != nil {
return err
}
}
cmd.Stdin = b
}
out, err := cmd.CombinedOutput() // execute and wait for output
if err != nil {
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
}
slog.Debug("executed shell command",
"command", commandWithInterfaceName,
"output", string(out))
return nil
}
// endregion wg-quick-related
// region routing-related
func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
// update fwmark rules
if err := c.setFwMarkRules(rules); err != nil {
return err
}
// update main rule
if err := c.setMainRule(rules); err != nil {
return err
}
// cleanup old main rules
if err := c.cleanupMainRule(rules); err != nil {
return err
}
return nil
}
func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error {
for _, rule := range rules {
existingRules, err := c.nl.RuleList(int(rule.IpFamily))
if err != nil {
return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err)
}
ruleExists := false
for _, existingRule := range existingRules {
if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table {
ruleExists = true
break
}
}
if ruleExists {
continue // rule already exists, no need to recreate it
}
// create a missing rule
if err := c.nl.RuleAdd(&netlink.Rule{
Family: int(rule.IpFamily),
Table: rule.Table,
Mark: rule.FwMark,
Invert: true,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Priority: c.getRulePriority(existingRules),
Mask: nil,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w",
rule.IpFamily, rule.FwMark, rule.Table, err)
}
}
return nil
}
func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
prio := 32700 // linux main rule has a priority of 32766
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == prio {
isFresh = false
break
}
}
if isFresh {
break
} else {
prio--
}
}
return prio
}
func (c LocalController) setMainRule(rules []domain.RouteRule) error {
var family domain.IpFamily
shouldHaveMainRule := false
for _, rule := range rules {
family = rule.IpFamily
if rule.HasDefault == true {
shouldHaveMainRule = true
break
}
}
if !shouldHaveMainRule {
return nil
}
existingRules, err := c.nl.RuleList(int(family))
if err != nil {
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
}
ruleExists := false
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
ruleExists = true
break
}
}
if ruleExists {
return nil // rule already exists, skip re-creation
}
if err := c.nl.RuleAdd(&netlink.Rule{
Family: int(family),
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: c.getMainRulePriority(existingRules),
Mark: 0,
Mask: nil,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for main table: %w", err)
}
return nil
}
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
priority := c.cfg.Advanced.RulePrioOffset
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == priority {
isFresh = false
break
}
}
if isFresh {
break
} else {
priority++
}
}
return priority
}
func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error {
var family domain.IpFamily
for _, rule := range rules {
family = rule.IpFamily
break
}
existingRules, err := c.nl.RuleList(int(family))
if err != nil {
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
}
shouldHaveMainRule := false
for _, rule := range rules {
if rule.HasDefault == true {
shouldHaveMainRule = true
break
}
}
mainRules := 0
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
mainRules++
}
}
removalCount := 0
if mainRules > 1 {
removalCount = mainRules - 1 // we only want one single rule
}
if !shouldHaveMainRule {
removalCount = mainRules
}
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
if removalCount > 0 {
existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field
if err := c.nl.RuleDel(&existingRule); err != nil {
return fmt.Errorf("failed to delete main rule: %w", err)
}
removalCount--
}
}
}
return nil
}
func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
// TODO implement me
panic("implement me")
}
// endregion routing-related
// region statistics-related
func (c LocalController) PingAddresses(
ctx context.Context,
addr string,
) (*domain.PingerResult, error) {
pinger, err := probing.NewPinger(addr)
if err != nil {
return nil, fmt.Errorf("failed to instantiate pinger for %s: %w", addr, err)
}
checkCount := 1
pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged)
pinger.Count = checkCount
pinger.Timeout = 2 * time.Second
err = pinger.RunWithContext(ctx) // Blocks until finished.
if err != nil {
return nil, fmt.Errorf("failed to ping %s: %w", addr, err)
}
stats := pinger.Statistics()
return &domain.PingerResult{
PacketsRecv: stats.PacketsRecv,
PacketsSent: stats.PacketsSent,
Rtts: stats.Rtts,
}, nil
}
// endregion statistics-related

View File

@ -0,0 +1,829 @@
package wgcontroller
import (
"context"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"log/slog"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/h44z/wg-portal/internal/lowlevel"
)
type MikrotikController struct {
coreCfg *config.Config
cfg *config.BackendMikrotik
client *lowlevel.MikrotikApiClient
// Add mutexes to prevent race conditions
interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex
peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex
}
func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) {
client, err := lowlevel.NewMikrotikApiClient(coreCfg, cfg)
if err != nil {
return nil, fmt.Errorf("failed to create Mikrotik API client: %w", err)
}
return &MikrotikController{
coreCfg: coreCfg,
cfg: cfg,
client: client,
interfaceMutexes: sync.Map{},
peerMutexes: sync.Map{},
}, nil
}
func (c *MikrotikController) GetId() domain.InterfaceBackend {
return domain.InterfaceBackend(c.cfg.Id)
}
// getInterfaceMutex returns a mutex for the given interface to prevent concurrent modifications
func (c *MikrotikController) getInterfaceMutex(id domain.InterfaceIdentifier) *sync.Mutex {
mutex, _ := c.interfaceMutexes.LoadOrStore(id, &sync.Mutex{})
return mutex.(*sync.Mutex)
}
// getPeerMutex returns a mutex for the given peer to prevent concurrent modifications
func (c *MikrotikController) getPeerMutex(id domain.PeerIdentifier) *sync.Mutex {
mutex, _ := c.peerMutexes.LoadOrStore(id, &sync.Mutex{})
return mutex.(*sync.Mutex)
}
// region wireguard-related
func (c *MikrotikController) GetInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running", "comment",
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return nil, fmt.Errorf("failed to query interfaces: %v", wgReply.Error)
}
// Parallelize loading of interface details to speed up overall latency.
// Use a bounded semaphore to avoid overloading the MikroTik device.
maxConcurrent := c.cfg.GetConcurrency()
sem := make(chan struct{}, maxConcurrent)
interfaces := make([]domain.PhysicalInterface, 0, len(wgReply.Data))
var mu sync.Mutex
var wgWait sync.WaitGroup
var firstErr error
ctx2, cancel := context.WithCancel(ctx)
defer cancel()
for _, wgObj := range wgReply.Data {
wgWait.Add(1)
sem <- struct{}{} // block if more than maxConcurrent requests are processing
go func(wg lowlevel.GenericJsonObject) {
defer wgWait.Done()
defer func() { <-sem }() // read from the semaphore and make space for the next entry
if firstErr != nil {
return
}
pi, err := c.loadInterfaceData(ctx2, wg)
if err != nil {
mu.Lock()
if firstErr == nil {
firstErr = err
cancel()
}
mu.Unlock()
return
}
mu.Lock()
interfaces = append(interfaces, *pi)
mu.Unlock()
}(wgObj)
}
wgWait.Wait()
if firstErr != nil {
return nil, firstErr
}
return interfaces, nil
}
func (c *MikrotikController) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.PhysicalInterface,
error,
) {
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running",
},
Filters: map[string]string{
"name": string(id),
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return nil, fmt.Errorf("failed to query interface %s: %v", id, wgReply.Error)
}
if len(wgReply.Data) == 0 {
return nil, fmt.Errorf("interface %s not found", id)
}
return c.loadInterfaceData(ctx, wgReply.Data[0])
}
func (c *MikrotikController) loadInterfaceData(
ctx context.Context,
wireGuardObj lowlevel.GenericJsonObject,
) (*domain.PhysicalInterface, error) {
deviceId := wireGuardObj.GetString(".id")
deviceName := wireGuardObj.GetString("name")
ifaceReply := c.client.Get(ctx, "/interface/"+deviceId, &lowlevel.MikrotikRequestOptions{
PropList: []string{
"name", "rx-byte", "tx-byte",
},
})
if ifaceReply.Status != lowlevel.MikrotikApiStatusOk {
return nil, fmt.Errorf("failed to query interface %s: %v", deviceId, ifaceReply.Error)
}
ipv4, ipv6, err := c.loadIpAddresses(ctx, deviceName)
if err != nil {
return nil, fmt.Errorf("failed to query IP addresses for interface %s: %v", deviceId, err)
}
addresses := c.convertIpAddresses(ipv4, ipv6)
interfaceModel, err := c.convertWireGuardInterface(wireGuardObj, ifaceReply.Data, addresses)
if err != nil {
return nil, fmt.Errorf("interface convert failed for %s: %w", deviceName, err)
}
return &interfaceModel, nil
}
func (c *MikrotikController) loadIpAddresses(
ctx context.Context,
deviceName string,
) (ipv4 []lowlevel.GenericJsonObject, ipv6 []lowlevel.GenericJsonObject, err error) {
// Query IPv4 and IPv6 addresses in parallel to reduce latency.
var (
v4 []lowlevel.GenericJsonObject
v6 []lowlevel.GenericJsonObject
v4Err error
v6Err error
wg sync.WaitGroup
)
wg.Add(2)
go func() {
defer wg.Done()
addrV4Reply := c.client.Query(ctx, "/ip/address", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "address", "network",
},
Filters: map[string]string{
"interface": deviceName,
"dynamic": "false", // we only want static addresses
"disabled": "false", // we only want addresses that are not disabled
},
})
if addrV4Reply.Status != lowlevel.MikrotikApiStatusOk {
v4Err = fmt.Errorf("failed to query IPv4 addresses for interface %s: %v", deviceName, addrV4Reply.Error)
return
}
v4 = addrV4Reply.Data
}()
go func() {
defer wg.Done()
addrV6Reply := c.client.Query(ctx, "/ipv6/address", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "address", "network",
},
Filters: map[string]string{
"interface": deviceName,
"dynamic": "false", // we only want static addresses
"disabled": "false", // we only want addresses that are not disabled
},
})
if addrV6Reply.Status != lowlevel.MikrotikApiStatusOk {
v6Err = fmt.Errorf("failed to query IPv6 addresses for interface %s: %v", deviceName, addrV6Reply.Error)
return
}
v6 = addrV6Reply.Data
}()
wg.Wait()
if v4Err != nil {
return nil, nil, v4Err
}
if v6Err != nil {
return nil, nil, v6Err
}
return v4, v6, nil
}
func (c *MikrotikController) convertIpAddresses(
ipv4, ipv6 []lowlevel.GenericJsonObject,
) []domain.Cidr {
addresses := make([]domain.Cidr, 0, len(ipv4)+len(ipv6))
for _, addr := range append(ipv4, ipv6...) {
addrStr := addr.GetString("address")
if addrStr == "" {
continue
}
cidr, err := domain.CidrFromString(addrStr)
if err != nil {
continue
}
addresses = append(addresses, cidr)
}
return addresses
}
func (c *MikrotikController) convertWireGuardInterface(
wg, iface lowlevel.GenericJsonObject,
addresses []domain.Cidr,
) (
domain.PhysicalInterface,
error,
) {
pi := domain.PhysicalInterface{
Identifier: domain.InterfaceIdentifier(wg.GetString("name")),
KeyPair: domain.KeyPair{
PrivateKey: wg.GetString("private-key"),
PublicKey: wg.GetString("public-key"),
},
ListenPort: wg.GetInt("listen-port"),
Addresses: addresses,
Mtu: wg.GetInt("mtu"),
FirewallMark: 0,
DeviceUp: wg.GetBool("running"),
ImportSource: domain.ControllerTypeMikrotik,
DeviceType: domain.ControllerTypeMikrotik,
BytesUpload: uint64(iface.GetInt("tx-byte")),
BytesDownload: uint64(iface.GetInt("rx-byte")),
}
pi.SetExtras(domain.MikrotikInterfaceExtras{
Id: wg.GetString(".id"),
Comment: wg.GetString("comment"),
Disabled: wg.GetBool("disabled"),
})
return pi, nil
}
func (c *MikrotikController) GetPeers(ctx context.Context, deviceId domain.InterfaceIdentifier) (
[]domain.PhysicalPeer,
error,
) {
wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "name", "allowed-address", "client-address", "client-endpoint", "client-keepalive", "comment",
"current-endpoint-address", "current-endpoint-port", "last-handshake", "persistent-keepalive",
"public-key", "private-key", "preshared-key", "mtu", "disabled", "rx", "tx", "responder", "client-dns",
},
Filters: map[string]string{
"interface": string(deviceId),
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return nil, fmt.Errorf("failed to query peers for %s: %v", deviceId, wgReply.Error)
}
if len(wgReply.Data) == 0 {
return nil, nil
}
peers := make([]domain.PhysicalPeer, 0, len(wgReply.Data))
for _, peer := range wgReply.Data {
peerModel, err := c.convertWireGuardPeer(peer)
if err != nil {
return nil, fmt.Errorf("peer convert failed for %v: %w", peer.GetString("name"), err)
}
peers = append(peers, peerModel)
}
return peers, nil
}
func (c *MikrotikController) convertWireGuardPeer(peer lowlevel.GenericJsonObject) (
domain.PhysicalPeer,
error,
) {
keepAliveSeconds := 0
duration, err := time.ParseDuration(peer.GetString("persistent-keepalive"))
if err == nil {
keepAliveSeconds = int(duration.Seconds())
}
currentEndpoint := ""
if peer.GetString("current-endpoint-address") != "" && peer.GetString("current-endpoint-port") != "" {
currentEndpoint = peer.GetString("current-endpoint-address") + ":" + peer.GetString("current-endpoint-port")
}
lastHandshakeTime := time.Time{}
if peer.GetString("last-handshake") != "" {
relDuration, err := time.ParseDuration(peer.GetString("last-handshake"))
if err == nil {
lastHandshakeTime = time.Now().Add(-relDuration)
}
}
allowedAddresses, _ := domain.CidrsFromString(peer.GetString("allowed-address"))
clientKeepAliveSeconds := 0
duration, err = time.ParseDuration(peer.GetString("client-keepalive"))
if err == nil {
clientKeepAliveSeconds = int(duration.Seconds())
}
peerModel := domain.PhysicalPeer{
Identifier: domain.PeerIdentifier(peer.GetString("public-key")),
Endpoint: currentEndpoint,
AllowedIPs: allowedAddresses,
KeyPair: domain.KeyPair{
PublicKey: peer.GetString("public-key"),
PrivateKey: peer.GetString("private-key"),
},
PresharedKey: domain.PreSharedKey(peer.GetString("preshared-key")),
PersistentKeepalive: keepAliveSeconds,
LastHandshake: lastHandshakeTime,
ProtocolVersion: 0, // Mikrotik does not support protocol versioning, so we set it to 0
BytesUpload: uint64(peer.GetInt("rx")),
BytesDownload: uint64(peer.GetInt("tx")),
ImportSource: domain.ControllerTypeMikrotik,
}
peerModel.SetExtras(domain.MikrotikPeerExtras{
Id: peer.GetString(".id"),
Name: peer.GetString("name"),
Comment: peer.GetString("comment"),
IsResponder: peer.GetBool("responder"),
Disabled: peer.GetBool("disabled"),
ClientEndpoint: peer.GetString("client-endpoint"),
ClientAddress: peer.GetString("client-address"),
ClientDns: peer.GetString("client-dns"),
ClientKeepalive: clientKeepAliveSeconds,
})
return peerModel, nil
}
func (c *MikrotikController) SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error {
// Lock the interface to prevent concurrent modifications
mutex := c.getInterfaceMutex(id)
mutex.Lock()
defer mutex.Unlock()
physicalInterface, err := c.getOrCreateInterface(ctx, id)
if err != nil {
return err
}
deviceId := physicalInterface.GetExtras().(domain.MikrotikInterfaceExtras).Id
if updateFunc != nil {
physicalInterface, err = updateFunc(physicalInterface)
if err != nil {
return err
}
newExtras := physicalInterface.GetExtras().(domain.MikrotikInterfaceExtras)
newExtras.Id = deviceId // ensure the ID is not changed
physicalInterface.SetExtras(newExtras)
}
if err := c.updateInterface(ctx, physicalInterface); err != nil {
return err
}
return nil
}
func (c *MikrotikController) getOrCreateInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
) (*domain.PhysicalInterface, error) {
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running",
},
Filters: map[string]string{
"name": string(id),
},
})
if wgReply.Status == lowlevel.MikrotikApiStatusOk && len(wgReply.Data) > 0 {
return c.loadInterfaceData(ctx, wgReply.Data[0])
}
// create a new interface if it does not exist
createReply := c.client.Create(ctx, "/interface/wireguard", lowlevel.GenericJsonObject{
"name": string(id),
})
if wgReply.Status == lowlevel.MikrotikApiStatusOk {
return c.loadInterfaceData(ctx, createReply.Data)
}
return nil, fmt.Errorf("failed to create interface %s: %v", id, createReply.Error)
}
func (c *MikrotikController) updateInterface(ctx context.Context, pi *domain.PhysicalInterface) error {
extras := pi.GetExtras().(domain.MikrotikInterfaceExtras)
interfaceId := extras.Id
wgReply := c.client.Update(ctx, "/interface/wireguard/"+interfaceId, lowlevel.GenericJsonObject{
"name": pi.Identifier,
"comment": extras.Comment,
"mtu": strconv.Itoa(pi.Mtu),
"listen-port": strconv.Itoa(pi.ListenPort),
"private-key": pi.KeyPair.PrivateKey,
"disabled": strconv.FormatBool(!pi.DeviceUp),
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to update interface %s: %v", pi.Identifier, wgReply.Error)
}
// update the interface's addresses
currentV4, currentV6, err := c.loadIpAddresses(ctx, string(pi.Identifier))
if err != nil {
return fmt.Errorf("failed to load current addresses for interface %s: %v", pi.Identifier, err)
}
currentAddresses := c.convertIpAddresses(currentV4, currentV6)
// get all addresses that are currently not in the interface, only in pi
newAddresses := make([]domain.Cidr, 0, len(pi.Addresses))
for _, addr := range pi.Addresses {
if slices.Contains(currentAddresses, addr) {
continue
}
newAddresses = append(newAddresses, addr)
}
// get obsolete addresses that are in the interface, but not in pi
obsoleteAddresses := make([]domain.Cidr, 0, len(currentAddresses))
for _, addr := range currentAddresses {
if slices.Contains(pi.Addresses, addr) {
continue
}
obsoleteAddresses = append(obsoleteAddresses, addr)
}
// update the IP addresses for the interface
if err := c.updateIpAddresses(ctx, string(pi.Identifier), currentV4, currentV6,
newAddresses, obsoleteAddresses); err != nil {
return fmt.Errorf("failed to update IP addresses for interface %s: %v", pi.Identifier, err)
}
return nil
}
func (c *MikrotikController) updateIpAddresses(
ctx context.Context,
deviceName string,
currentV4, currentV6 []lowlevel.GenericJsonObject,
new, obsolete []domain.Cidr,
) error {
// first, delete all obsolete addresses
for _, addr := range obsolete {
// find ID of the address to delete
if addr.IsV4() {
for _, a := range currentV4 {
if a.GetString("address") == addr.String() {
// delete the address
reply := c.client.Delete(ctx, "/ip/address/"+a.GetString(".id"))
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to delete obsolete IPv4 address %s: %v", addr, reply.Error)
}
break
}
}
} else {
for _, a := range currentV6 {
if a.GetString("address") == addr.String() {
// delete the address
reply := c.client.Delete(ctx, "/ipv6/address/"+a.GetString(".id"))
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to delete obsolete IPv6 address %s: %v", addr, reply.Error)
}
break
}
}
}
}
// then, add all new addresses
for _, addr := range new {
var createPath string
if addr.IsV4() {
createPath = "/ip/address"
} else {
createPath = "/ipv6/address"
}
// create the address
reply := c.client.Create(ctx, createPath, lowlevel.GenericJsonObject{
"address": addr.String(),
"interface": deviceName,
})
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to create new address %s: %v", addr, reply.Error)
}
}
return nil
}
func (c *MikrotikController) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
// Lock the interface to prevent concurrent modifications
mutex := c.getInterfaceMutex(id)
mutex.Lock()
defer mutex.Unlock()
// delete the interface's addresses
currentV4, currentV6, err := c.loadIpAddresses(ctx, string(id))
if err != nil {
return fmt.Errorf("failed to load current addresses for interface %s: %v", id, err)
}
for _, a := range currentV4 {
// delete the address
reply := c.client.Delete(ctx, "/ip/address/"+a.GetString(".id"))
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to delete IPv4 address %s: %v", a.GetString("address"), reply.Error)
}
}
for _, a := range currentV6 {
// delete the address
reply := c.client.Delete(ctx, "/ipv6/address/"+a.GetString(".id"))
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to delete IPv6 address %s: %v", a.GetString("address"), reply.Error)
}
}
// delete the WireGuard interface
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
PropList: []string{".id"},
Filters: map[string]string{
"name": string(id),
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to find WireGuard interface %s: %v", id, wgReply.Error)
}
if len(wgReply.Data) == 0 {
return nil // interface does not exist, nothing to delete
}
interfaceId := wgReply.Data[0].GetString(".id")
deleteReply := c.client.Delete(ctx, "/interface/wireguard/"+interfaceId)
if deleteReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to delete WireGuard interface %s: %v", id, deleteReply.Error)
}
return nil
}
func (c *MikrotikController) SavePeer(
ctx context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error {
// Lock the peer to prevent concurrent modifications
mutex := c.getPeerMutex(id)
mutex.Lock()
defer mutex.Unlock()
physicalPeer, err := c.getOrCreatePeer(ctx, deviceId, id)
if err != nil {
return err
}
peerId := physicalPeer.GetExtras().(domain.MikrotikPeerExtras).Id
physicalPeer, err = updateFunc(physicalPeer)
if err != nil {
return err
}
newExtras := physicalPeer.GetExtras().(domain.MikrotikPeerExtras)
newExtras.Id = peerId // ensure the ID is not changed
physicalPeer.SetExtras(newExtras)
if err := c.updatePeer(ctx, deviceId, physicalPeer); err != nil {
return err
}
return nil
}
func (c *MikrotikController) getOrCreatePeer(
ctx context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
) (*domain.PhysicalPeer, error) {
wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "name", "public-key", "private-key", "preshared-key", "persistent-keepalive", "client-address",
"client-endpoint", "client-keepalive", "allowed-address", "client-dns", "comment", "disabled", "responder",
},
Filters: map[string]string{
"public-key": string(id),
"interface": string(deviceId),
},
})
if wgReply.Status == lowlevel.MikrotikApiStatusOk && len(wgReply.Data) > 0 {
slog.Debug("found existing Mikrotik peer", "peer", id, "interface", deviceId)
existingPeer, err := c.convertWireGuardPeer(wgReply.Data[0])
if err != nil {
return nil, err
}
return &existingPeer, nil
}
// create a new peer if it does not exist
slog.Debug("creating new Mikrotik peer", "peer", id, "interface", deviceId)
createReply := c.client.Create(ctx, "/interface/wireguard/peers", lowlevel.GenericJsonObject{
"name": fmt.Sprintf("tmp-wg-%s", id[0:8]),
"interface": string(deviceId),
"public-key": string(id),
"allowed-address": "0.0.0.0/0", // Use 0.0.0.0/0 as default, will be updated by updatePeer
})
if createReply.Status == lowlevel.MikrotikApiStatusOk {
newPeer, err := c.convertWireGuardPeer(createReply.Data)
if err != nil {
return nil, err
}
slog.Debug("successfully created Mikrotik peer", "peer", id, "interface", deviceId)
return &newPeer, nil
}
return nil, fmt.Errorf("failed to create peer %s for interface %s: %v", id, deviceId, createReply.Error)
}
func (c *MikrotikController) updatePeer(
ctx context.Context,
deviceId domain.InterfaceIdentifier,
pp *domain.PhysicalPeer,
) error {
extras := pp.GetExtras().(domain.MikrotikPeerExtras)
peerId := extras.Id
endpoint := pp.Endpoint
endpointPort := "51820" // default port if not set
if s := strings.Split(endpoint, ":"); len(s) == 2 {
endpoint = s[0]
endpointPort = s[1]
}
allowedAddressStr := domain.CidrsToString(pp.AllowedIPs)
slog.Debug("updating Mikrotik peer",
"peer", pp.Identifier,
"interface", deviceId,
"allowed-address", allowedAddressStr,
"allowed-ips-count", len(pp.AllowedIPs),
"disabled", extras.Disabled)
wgReply := c.client.Update(ctx, "/interface/wireguard/peers/"+peerId, lowlevel.GenericJsonObject{
"name": extras.Name,
"comment": extras.Comment,
"preshared-key": pp.PresharedKey,
"public-key": pp.KeyPair.PublicKey,
"private-key": pp.KeyPair.PrivateKey,
"persistent-keepalive": (time.Duration(pp.PersistentKeepalive) * time.Second).String(),
"disabled": strconv.FormatBool(extras.Disabled),
"responder": strconv.FormatBool(extras.IsResponder),
"client-endpoint": extras.ClientEndpoint,
"client-address": extras.ClientAddress,
"client-keepalive": (time.Duration(extras.ClientKeepalive) * time.Second).String(),
"client-dns": extras.ClientDns,
"endpoint-address": endpoint,
"endpoint-port": endpointPort,
"allowed-address": allowedAddressStr, // Add the missing allowed-address field
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to update peer %s on interface %s: %v", pp.Identifier, deviceId, wgReply.Error)
}
if extras.Disabled {
slog.Debug("successfully disabled Mikrotik peer", "peer", pp.Identifier, "interface", deviceId)
} else {
slog.Debug("successfully updated Mikrotik peer", "peer", pp.Identifier, "interface", deviceId)
}
return nil
}
func (c *MikrotikController) DeletePeer(
ctx context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
) error {
// Lock the peer to prevent concurrent modifications
mutex := c.getPeerMutex(id)
mutex.Lock()
defer mutex.Unlock()
wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{
PropList: []string{".id"},
Filters: map[string]string{
"public-key": string(id),
"interface": string(deviceId),
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to find WireGuard peer %s for interface %s: %v", id, deviceId, wgReply.Error)
}
if len(wgReply.Data) == 0 {
return nil // peer does not exist, nothing to delete
}
peerId := wgReply.Data[0].GetString(".id")
deleteReply := c.client.Delete(ctx, "/interface/wireguard/peers/"+peerId)
if deleteReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to delete WireGuard peer %s for interface %s: %v", id, deviceId, deleteReply.Error)
}
return nil
}
// endregion wireguard-related
// region wg-quick-related
func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
// TODO implement me
panic("implement me")
}
func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
// TODO implement me
panic("implement me")
}
func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error {
// TODO implement me
panic("implement me")
}
// endregion wg-quick-related
// region routing-related
func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
// TODO implement me
panic("implement me")
}
func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
// TODO implement me
panic("implement me")
}
// endregion routing-related
// region statistics-related
func (c *MikrotikController) PingAddresses(
ctx context.Context,
addr string,
) (*domain.PingerResult, error) {
wgReply := c.client.ExecList(ctx, "/tool/ping",
// limit to 1 packet with a max running time of 2 seconds
lowlevel.GenericJsonObject{"address": addr, "count": 1, "interval": "00:00:02"},
)
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return nil, fmt.Errorf("failed to ping %s: %v", addr, wgReply.Error)
}
var result domain.PingerResult
for _, item := range wgReply.Data {
result.PacketsRecv += item.GetInt("received")
result.PacketsSent += item.GetInt("sent")
rttStr := item.GetString("avg-rtt")
if rttStr != "" {
rtt, err := time.ParseDuration(rttStr)
if err == nil {
result.Rtts = append(result.Rtts, rtt)
} else {
// use a high value to indicate failure or timeout
result.Rtts = append(result.Rtts, 999999*time.Millisecond)
}
}
}
return &result, nil
}
// endregion statistics-related

View File

@ -21,17 +21,23 @@ import (
//go:embed frontend_config.js.gotpl //go:embed frontend_config.js.gotpl
var frontendJs embed.FS var frontendJs embed.FS
type ControllerManager interface {
GetControllerNames() []config.BackendBase
}
type ConfigEndpoint struct { type ConfigEndpoint struct {
cfg *config.Config cfg *config.Config
authenticator Authenticator authenticator Authenticator
controllerMgr ControllerManager
tpl *respond.TemplateRenderer tpl *respond.TemplateRenderer
} }
func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator) ConfigEndpoint { func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator, ctrlMgr ControllerManager) ConfigEndpoint {
ep := ConfigEndpoint{ ep := ConfigEndpoint{
cfg: cfg, cfg: cfg,
authenticator: authenticator, authenticator: authenticator,
controllerMgr: ctrlMgr,
tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs, tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs,
"frontend_config.js.gotpl"))), "frontend_config.js.gotpl"))),
} }
@ -96,13 +102,36 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
sessionUser := domain.GetUserInfo(r.Context()) sessionUser := domain.GetUserInfo(r.Context())
controllerFn := func() []model.SettingsBackendNames {
controllers := e.controllerMgr.GetControllerNames()
names := make([]model.SettingsBackendNames, 0, len(controllers))
for _, controller := range controllers {
displayName := controller.GetDisplayName()
if displayName == "" {
displayName = controller.Id // fallback to ID if no display name is set
}
if controller.Id == config.LocalBackendName {
displayName = "modals.interface-edit.backend.local" // use a localized string for the local backend
}
names = append(names, model.SettingsBackendNames{
Id: controller.Id,
Name: displayName,
})
}
return names
}
hasSocialLogin := len(e.cfg.Auth.OAuth) > 0 || len(e.cfg.Auth.OpenIDConnect) > 0 || e.cfg.Auth.WebAuthn.Enabled hasSocialLogin := len(e.cfg.Auth.OAuth) > 0 || len(e.cfg.Auth.OpenIDConnect) > 0 || e.cfg.Auth.WebAuthn.Enabled
// For anonymous users, we return the settings object with minimal information // For anonymous users, we return the settings object with minimal information
if sessionUser.Id == domain.CtxUnknownUserId || sessionUser.Id == "" { if sessionUser.Id == domain.CtxUnknownUserId || sessionUser.Id == "" {
respond.JSON(w, http.StatusOK, model.Settings{ respond.JSON(w, http.StatusOK, model.Settings{
WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled, WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled,
LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin, AvailableBackends: []model.SettingsBackendNames{}, // return an empty list instead of null
LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin,
}) })
} else { } else {
respond.JSON(w, http.StatusOK, model.Settings{ respond.JSON(w, http.StatusOK, model.Settings{
@ -112,6 +141,7 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc {
ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly, ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly,
WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled, WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled,
MinPasswordLength: e.cfg.Auth.MinPasswordLength, MinPasswordLength: e.cfg.Auth.MinPasswordLength,
AvailableBackends: controllerFn(),
LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin, LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin,
}) })
} }

View File

@ -6,11 +6,17 @@ type Error struct {
} }
type Settings struct { type Settings struct {
MailLinkOnly bool `json:"MailLinkOnly"` MailLinkOnly bool `json:"MailLinkOnly"`
PersistentConfigSupported bool `json:"PersistentConfigSupported"` PersistentConfigSupported bool `json:"PersistentConfigSupported"`
SelfProvisioning bool `json:"SelfProvisioning"` SelfProvisioning bool `json:"SelfProvisioning"`
ApiAdminOnly bool `json:"ApiAdminOnly"` ApiAdminOnly bool `json:"ApiAdminOnly"`
WebAuthnEnabled bool `json:"WebAuthnEnabled"` WebAuthnEnabled bool `json:"WebAuthnEnabled"`
MinPasswordLength int `json:"MinPasswordLength"` MinPasswordLength int `json:"MinPasswordLength"`
LoginFormVisible bool `json:"LoginFormVisible"` AvailableBackends []SettingsBackendNames `json:"AvailableBackends"`
LoginFormVisible bool `json:"LoginFormVisible"`
}
type SettingsBackendNames struct {
Id string `json:"Id"`
Name string `json:"Name"`
} }

View File

@ -4,6 +4,7 @@ import (
"time" "time"
"github.com/h44z/wg-portal/internal" "github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
@ -11,6 +12,7 @@ type Interface struct {
Identifier string `json:"Identifier" example:"wg0"` // device name, for example: wg0 Identifier string `json:"Identifier" example:"wg0"` // device name, for example: wg0
DisplayName string `json:"DisplayName"` // a nice display name/ description for the interface DisplayName string `json:"DisplayName"` // a nice display name/ description for the interface
Mode string `json:"Mode" example:"server"` // the interface type, either 'server', 'client' or 'any' Mode string `json:"Mode" example:"server"` // the interface type, either 'server', 'client' or 'any'
Backend string `json:"Backend" example:"local"` // the backend used for this interface e.g., local, mikrotik, ...
PrivateKey string `json:"PrivateKey" example:"abcdef=="` // private Key of the server interface PrivateKey string `json:"PrivateKey" example:"abcdef=="` // private Key of the server interface
PublicKey string `json:"PublicKey" example:"abcdef=="` // public Key of the server interface PublicKey string `json:"PublicKey" example:"abcdef=="` // public Key of the server interface
Disabled bool `json:"Disabled"` // flag that specifies if the interface is enabled (up) or not (down) Disabled bool `json:"Disabled"` // flag that specifies if the interface is enabled (up) or not (down)
@ -57,6 +59,7 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
Identifier: string(src.Identifier), Identifier: string(src.Identifier),
DisplayName: src.DisplayName, DisplayName: src.DisplayName,
Mode: string(src.Type), Mode: string(src.Type),
Backend: string(src.Backend),
PrivateKey: src.PrivateKey, PrivateKey: src.PrivateKey,
PublicKey: src.PublicKey, PublicKey: src.PublicKey,
Disabled: src.IsDisabled(), Disabled: src.IsDisabled(),
@ -92,6 +95,10 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
Filename: src.GetConfigFileName(), Filename: src.GetConfigFileName(),
} }
if iface.Backend == "" {
iface.Backend = config.LocalBackendName // default to local backend
}
if len(peers) > 0 { if len(peers) > 0 {
iface.TotalPeers = len(peers) iface.TotalPeers = len(peers)
@ -146,6 +153,7 @@ func NewDomainInterface(src *Interface) *domain.Interface {
SaveConfig: src.SaveConfig, SaveConfig: src.SaveConfig,
DisplayName: src.DisplayName, DisplayName: src.DisplayName,
Type: domain.InterfaceType(src.Mode), Type: domain.InterfaceType(src.Mode),
Backend: domain.InterfaceBackend(src.Backend),
DriverType: "", // currently unused DriverType: "", // currently unused
Disabled: nil, // set below Disabled: nil, // set below
DisabledReason: src.DisabledReason, DisabledReason: src.DisabledReason,

View File

@ -46,7 +46,7 @@ func Initialize(
users: users, users: users,
} }
startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second) startupContext, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel() defer cancel()
// Switch to admin user context // Switch to admin user context

View File

@ -0,0 +1,166 @@
package wireguard
import (
"context"
"fmt"
"log/slog"
"maps"
"slices"
"github.com/h44z/wg-portal/internal/adapters/wgcontroller"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
type InterfaceController interface {
GetId() domain.InterfaceBackend
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
PingAddresses(
ctx context.Context,
addr string,
) (*domain.PingerResult, error)
}
type backendInstance struct {
Config config.BackendBase // Config is the configuration for the backend instance.
Implementation InterfaceController
}
type ControllerManager struct {
cfg *config.Config
controllers map[domain.InterfaceBackend]backendInstance
}
func NewControllerManager(cfg *config.Config) (*ControllerManager, error) {
c := &ControllerManager{
cfg: cfg,
controllers: make(map[domain.InterfaceBackend]backendInstance),
}
err := c.init()
if err != nil {
return nil, err
}
return c, nil
}
func (c *ControllerManager) init() error {
if err := c.registerLocalController(); err != nil {
return err
}
if err := c.registerMikrotikControllers(); err != nil {
return err
}
c.logRegisteredControllers()
return nil
}
func (c *ControllerManager) registerLocalController() error {
localController, err := wgcontroller.NewLocalController(c.cfg)
if err != nil {
return fmt.Errorf("failed to create local WireGuard controller: %w", err)
}
c.controllers[config.LocalBackendName] = backendInstance{
Config: config.BackendBase{
Id: config.LocalBackendName,
DisplayName: "Local WireGuard Controller",
},
Implementation: localController,
}
return nil
}
func (c *ControllerManager) registerMikrotikControllers() error {
for _, backendConfig := range c.cfg.Backend.Mikrotik {
if backendConfig.Id == config.LocalBackendName {
slog.Warn("skipping registration of Mikrotik controller with reserved ID", "id", config.LocalBackendName)
continue
}
controller, err := wgcontroller.NewMikrotikController(c.cfg, &backendConfig)
if err != nil {
return fmt.Errorf("failed to create Mikrotik controller for backend %s: %w", backendConfig.Id, err)
}
c.controllers[domain.InterfaceBackend(backendConfig.Id)] = backendInstance{
Config: backendConfig.BackendBase,
Implementation: controller,
}
}
return nil
}
func (c *ControllerManager) logRegisteredControllers() {
for backend, controller := range c.controllers {
slog.Debug("backend controller registered",
"backend", backend, "type", fmt.Sprintf("%T", controller.Implementation))
}
}
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController {
return c.getController(backend, "")
}
func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController {
return c.getController(iface.Backend, iface.Identifier)
}
func (c *ControllerManager) getController(
backend domain.InterfaceBackend,
ifaceId domain.InterfaceIdentifier,
) InterfaceController {
if backend == "" {
// If no backend is specified, use the local controller.
// This might be the case for interfaces created in previous WireGuard Portal versions.
backend = config.LocalBackendName
}
controller, exists := c.controllers[backend]
if !exists {
controller, exists = c.controllers[config.LocalBackendName] // Fallback to local controller
if !exists {
// If the local controller is also not found, panic
panic(fmt.Sprintf("%s interface controller for backend %s not found", ifaceId, backend))
}
slog.Warn("controller for backend not found, using local controller",
"backend", backend, "interface", ifaceId)
}
return controller.Implementation
}
func (c *ControllerManager) GetAllControllers() []InterfaceController {
var backendInstances = make([]InterfaceController, 0, len(c.controllers))
for instance := range maps.Values(c.controllers) {
backendInstances = append(backendInstances, instance.Implementation)
}
return backendInstances
}
func (c *ControllerManager) GetControllerNames() []config.BackendBase {
var names []config.BackendBase
for _, id := range slices.Sorted(maps.Keys(c.controllers)) {
names = append(names, c.controllers[id].Config)
}
return names
}

View File

@ -6,8 +6,6 @@ import (
"sync" "sync"
"time" "time"
probing "github.com/prometheus-community/pro-bing"
"github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
@ -30,11 +28,6 @@ type StatisticsDatabaseRepo interface {
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
} }
type StatisticsInterfaceController interface {
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
}
type StatisticsMetricsServer interface { type StatisticsMetricsServer interface {
UpdateInterfaceMetrics(status domain.InterfaceStatus) UpdateInterfaceMetrics(status domain.InterfaceStatus)
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus) UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
@ -47,15 +40,20 @@ type StatisticsEventBus interface {
Publish(topic string, args ...any) Publish(topic string, args ...any)
} }
type pingJob struct {
Peer domain.Peer
Backend domain.InterfaceBackend
}
type StatisticsCollector struct { type StatisticsCollector struct {
cfg *config.Config cfg *config.Config
bus StatisticsEventBus bus StatisticsEventBus
pingWaitGroup sync.WaitGroup pingWaitGroup sync.WaitGroup
pingJobs chan domain.Peer pingJobs chan pingJob
db StatisticsDatabaseRepo db StatisticsDatabaseRepo
wg StatisticsInterfaceController wg *ControllerManager
ms StatisticsMetricsServer ms StatisticsMetricsServer
peerChangeEvent chan domain.PeerIdentifier peerChangeEvent chan domain.PeerIdentifier
@ -66,7 +64,7 @@ func NewStatisticsCollector(
cfg *config.Config, cfg *config.Config,
bus StatisticsEventBus, bus StatisticsEventBus,
db StatisticsDatabaseRepo, db StatisticsDatabaseRepo,
wg StatisticsInterfaceController, wg *ControllerManager,
ms StatisticsMetricsServer, ms StatisticsMetricsServer,
) (*StatisticsCollector, error) { ) (*StatisticsCollector, error) {
c := &StatisticsCollector{ c := &StatisticsCollector{
@ -117,7 +115,7 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
} }
for _, in := range interfaces { for _, in := range interfaces {
physicalInterface, err := c.wg.GetInterface(ctx, in.Identifier) physicalInterface, err := c.wg.GetController(in).GetInterface(ctx, in.Identifier)
if err != nil { if err != nil {
slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier, slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier,
"error", err) "error", err)
@ -169,7 +167,7 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
} }
for _, in := range interfaces { for _, in := range interfaces {
peers, err := c.wg.GetPeers(ctx, in.Identifier) peers, err := c.wg.GetController(in).GetPeers(ctx, in.Identifier)
if err != nil { if err != nil {
slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err) slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err)
continue continue
@ -271,7 +269,7 @@ func (c *StatisticsCollector) startPingWorkers(ctx context.Context) {
c.pingWaitGroup = sync.WaitGroup{} c.pingWaitGroup = sync.WaitGroup{}
c.pingWaitGroup.Add(c.cfg.Statistics.PingCheckWorkers) c.pingWaitGroup.Add(c.cfg.Statistics.PingCheckWorkers)
c.pingJobs = make(chan domain.Peer, c.cfg.Statistics.PingCheckWorkers) c.pingJobs = make(chan pingJob, c.cfg.Statistics.PingCheckWorkers)
// start workers // start workers
for i := 0; i < c.cfg.Statistics.PingCheckWorkers; i++ { for i := 0; i < c.cfg.Statistics.PingCheckWorkers; i++ {
@ -314,7 +312,10 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
continue continue
} }
for _, peer := range peers { for _, peer := range peers {
c.pingJobs <- peer c.pingJobs <- pingJob{
Peer: peer,
Backend: in.Backend,
}
} }
} }
} }
@ -323,11 +324,14 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
func (c *StatisticsCollector) pingWorker(ctx context.Context) { func (c *StatisticsCollector) pingWorker(ctx context.Context) {
defer c.pingWaitGroup.Done() defer c.pingWaitGroup.Done()
for peer := range c.pingJobs { for job := range c.pingJobs {
peer := job.Peer
backend := job.Backend
var connectionStateChanged bool var connectionStateChanged bool
var newPeerStatus domain.PeerStatus var newPeerStatus domain.PeerStatus
peerPingable := c.isPeerPingable(ctx, peer) peerPingable := c.isPeerPingable(ctx, backend, peer)
slog.Debug("peer ping check completed", "peer", peer.Identifier, "pingable", peerPingable) slog.Debug("peer ping check completed", "peer", peer.Identifier, "pingable", peerPingable)
now := time.Now() now := time.Now()
@ -368,7 +372,11 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
} }
} }
func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Peer) bool { func (c *StatisticsCollector) isPeerPingable(
ctx context.Context,
backend domain.InterfaceBackend,
peer domain.Peer,
) bool {
if !c.cfg.Statistics.UsePingChecks { if !c.cfg.Statistics.UsePingChecks {
return false return false
} }
@ -378,23 +386,13 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
return false return false
} }
pinger, err := probing.NewPinger(checkAddr) stats, err := c.wg.GetControllerByName(backend).PingAddresses(ctx, checkAddr)
if err != nil { if err != nil {
slog.Debug("failed to instantiate pinger", "peer", peer.Identifier, "address", checkAddr, "error", err) slog.Debug("failed to ping peer", "peer", peer.Identifier, "error", err)
return false return false
} }
checkCount := 1 return stats.IsPingable()
pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged)
pinger.Count = checkCount
pinger.Timeout = 2 * time.Second
err = pinger.RunWithContext(ctx) // Blocks until finished.
if err != nil {
slog.Debug("pinger for peer exited unexpectedly", "peer", peer.Identifier, "address", checkAddr, "error", err)
return false
}
stats := pinger.Statistics()
return stats.PacketsRecv == checkCount
} }
func (c *StatisticsCollector) updateInterfaceMetrics(status domain.InterfaceStatus) { func (c *StatisticsCollector) updateInterfaceMetrics(status domain.InterfaceStatus) {

View File

@ -37,25 +37,6 @@ type InterfaceAndPeerDatabaseRepo interface {
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
} }
type InterfaceController interface {
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}
type WgQuickController interface { type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
@ -75,7 +56,7 @@ type Manager struct {
cfg *config.Config cfg *config.Config
bus EventBus bus EventBus
db InterfaceAndPeerDatabaseRepo db InterfaceAndPeerDatabaseRepo
wg InterfaceController wg *ControllerManager
quick WgQuickController quick WgQuickController
userLockMap *sync.Map userLockMap *sync.Map
@ -84,7 +65,7 @@ type Manager struct {
func NewWireGuardManager( func NewWireGuardManager(
cfg *config.Config, cfg *config.Config,
bus EventBus, bus EventBus,
wg InterfaceController, wg *ControllerManager,
quick WgQuickController, quick WgQuickController,
db InterfaceAndPeerDatabaseRepo, db InterfaceAndPeerDatabaseRepo,
) (*Manager, error) { ) (*Manager, error) {

View File

@ -11,6 +11,7 @@ import (
"github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/audit" "github.com/h44z/wg-portal/internal/app/audit"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
@ -21,12 +22,17 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
return nil, err return nil, err
} }
physicalInterfaces, err := m.wg.GetInterfaces(ctx) var allPhysicalInterfaces []domain.PhysicalInterface
if err != nil { for _, wgBackend := range m.wg.GetAllControllers() {
return nil, err physicalInterfaces, err := wgBackend.GetInterfaces(ctx)
if err != nil {
return nil, err
}
allPhysicalInterfaces = append(allPhysicalInterfaces, physicalInterfaces...)
} }
return physicalInterfaces, nil return allPhysicalInterfaces, nil
} }
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier. // GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
@ -109,47 +115,49 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
return 0, err return 0, err
} }
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
if err != nil {
return 0, err
}
// if no filter is given, exclude already existing interfaces
var excludedInterfaces []domain.InterfaceIdentifier
if len(filter) == 0 {
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return 0, err
}
for _, existingInterface := range existingInterfaces {
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
}
}
imported := 0 imported := 0
for _, physicalInterface := range physicalInterfaces { for _, wgBackend := range m.wg.GetAllControllers() {
if slices.Contains(excludedInterfaces, physicalInterface.Identifier) { physicalInterfaces, err := wgBackend.GetInterfaces(ctx)
continue
}
if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) {
continue
}
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier)
if err != nil { if err != nil {
return 0, err return 0, err
} }
err = m.importInterface(ctx, &physicalInterface, physicalPeers) // if no filter is given, exclude already existing interfaces
if err != nil { var excludedInterfaces []domain.InterfaceIdentifier
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err) if len(filter) == 0 {
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return 0, err
}
for _, existingInterface := range existingInterfaces {
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
}
} }
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers)) for _, physicalInterface := range physicalInterfaces {
imported++ if slices.Contains(excludedInterfaces, physicalInterface.Identifier) {
continue
}
if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) {
continue
}
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
physicalPeers, err := wgBackend.GetPeers(ctx, physicalInterface.Identifier)
if err != nil {
return 0, err
}
err = m.importInterface(ctx, wgBackend, &physicalInterface, physicalPeers)
if err != nil {
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
}
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers))
imported++
}
} }
return imported, nil return imported, nil
@ -213,7 +221,7 @@ func (m Manager) RestoreInterfaceState(
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err) return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
} }
_, err = m.wg.GetInterface(ctx, iface.Identifier) _, err = m.wg.GetController(iface).GetInterface(ctx, iface.Identifier)
if err != nil && !iface.IsDisabled() { if err != nil && !iface.IsDisabled() {
slog.Debug("creating missing interface", "interface", iface.Identifier) slog.Debug("creating missing interface", "interface", iface.Identifier)
@ -260,18 +268,14 @@ func (m Manager) RestoreInterfaceState(
// restore peers // restore peers
for _, peer := range peers { for _, peer := range peers {
switch { switch {
case iface.IsDisabled(): // if interface is disabled, delete all peers case iface.IsDisabled() && iface.Backend == config.LocalBackendName: // if interface is disabled, delete all peers
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil { if err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier,
peer.Identifier); err != nil {
return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w", return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w",
peer.Identifier, iface.Identifier, err) peer.Identifier, iface.Identifier, err)
} }
case peer.IsDisabled(): // if peer is disabled, delete it
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w",
peer.Identifier, iface.Identifier, err)
}
default: // update peer default: // update peer
err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier, err := m.wg.GetController(iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, &peer) domain.MergeToPhysicalPeer(pp, &peer)
return pp, nil return pp, nil
@ -284,7 +288,7 @@ func (m Manager) RestoreInterfaceState(
} }
// remove non-wgportal peers // remove non-wgportal peers
physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier) physicalPeers, _ := m.wg.GetController(iface).GetPeers(ctx, iface.Identifier)
for _, physicalPeer := range physicalPeers { for _, physicalPeer := range physicalPeers {
isWgPortalPeer := false isWgPortalPeer := false
for _, peer := range peers { for _, peer := range peers {
@ -294,7 +298,8 @@ func (m Manager) RestoreInterfaceState(
} }
} }
if !isWgPortalPeer { if !isWgPortalPeer {
err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey)) err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier,
domain.PeerIdentifier(physicalPeer.PublicKey))
if err != nil { if err != nil {
return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w", return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w",
physicalPeer.PublicKey, iface.Identifier, err) physicalPeer.PublicKey, iface.Identifier, err)
@ -459,7 +464,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
existingInterface.Disabled = &now // simulate a disabled interface existingInterface.Disabled = &now // simulate a disabled interface
existingInterface.DisabledReason = domain.DisabledReasonDeleted existingInterface.DisabledReason = domain.DisabledReasonDeleted
physicalInterface, _ := m.wg.GetInterface(ctx, id) physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id)
if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil { if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
return fmt.Errorf("pre-delete hooks failed: %w", err) return fmt.Errorf("pre-delete hooks failed: %w", err)
@ -473,7 +478,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("peer deletion failure: %w", err) return fmt.Errorf("peer deletion failure: %w", err)
} }
if err := m.wg.DeleteInterface(ctx, id); err != nil { if err := m.wg.GetController(*existingInterface).DeleteInterface(ctx, id); err != nil {
return fmt.Errorf("wireguard deletion failure: %w", err) return fmt.Errorf("wireguard deletion failure: %w", err)
} }
@ -522,7 +527,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) { err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
iface.CopyCalculatedAttributes(i) iface.CopyCalculatedAttributes(i)
err := m.wg.SaveInterface(ctx, iface.Identifier, err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
domain.MergeToPhysicalInterface(pi, iface) domain.MergeToPhysicalInterface(pi, iface)
return pi, nil return pi, nil
@ -538,7 +543,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
} }
if iface.IsDisabled() { if iface.IsDisabled() {
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier) physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier)
fwMark := iface.FirewallMark fwMark := iface.FirewallMark
if physicalInterface != nil && fwMark == 0 { if physicalInterface != nil && fwMark == 0 {
fwMark = physicalInterface.FirewallMark fwMark = physicalInterface.FirewallMark
@ -556,13 +561,13 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
} }
// If the interface has just been enabled, restore its peers on the physical controller // If the interface has just been enabled, restore its peers on the physical controller
if !oldEnabled && newEnabled { if !oldEnabled && newEnabled && iface.Backend == config.LocalBackendName {
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err) return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
} }
for _, peer := range peers { for _, peer := range peers {
saveErr := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier, saveErr := m.wg.GetController(*iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, &peer) domain.MergeToPhysicalPeer(pp, &peer)
return pp, nil return pp, nil
@ -766,7 +771,12 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
return return
} }
func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterface, peers []domain.PhysicalPeer) error { func (m Manager) importInterface(
ctx context.Context,
backend InterfaceController,
in *domain.PhysicalInterface,
peers []domain.PhysicalPeer,
) error {
now := time.Now() now := time.Now()
iface := domain.ConvertPhysicalInterface(in) iface := domain.ConvertPhysicalInterface(in)
iface.BaseModel = domain.BaseModel{ iface.BaseModel = domain.BaseModel{
@ -775,8 +785,20 @@ func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterfa
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
} }
iface.Backend = backend.GetId()
iface.PeerDefAllowedIPsStr = iface.AddressStr() iface.PeerDefAllowedIPsStr = iface.AddressStr()
// try to predict the interface type based on the number of peers
switch len(peers) {
case 0:
iface.Type = domain.InterfaceTypeAny // no peers means this is an unknown interface
case 1:
iface.Type = domain.InterfaceTypeClient // one peer means this is a client interface
default: // multiple peers means this is a server interface
iface.Type = domain.InterfaceTypeServer
}
existingInterface, err := m.db.GetInterface(ctx, iface.Identifier) existingInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) { if err != nil && !errors.Is(err, domain.ErrNotFound) {
return err return err
@ -827,16 +849,20 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
peer.Interface.PreDown = domain.NewConfigOption(in.PeerDefPreDown, true) peer.Interface.PreDown = domain.NewConfigOption(in.PeerDefPreDown, true)
peer.Interface.PostDown = domain.NewConfigOption(in.PeerDefPostDown, true) peer.Interface.PostDown = domain.NewConfigOption(in.PeerDefPostDown, true)
var displayName string
switch in.Type { switch in.Type {
case domain.InterfaceTypeAny: case domain.InterfaceTypeAny:
peer.Interface.Type = domain.InterfaceTypeAny peer.Interface.Type = domain.InterfaceTypeAny
peer.DisplayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")" displayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
case domain.InterfaceTypeClient: case domain.InterfaceTypeClient:
peer.Interface.Type = domain.InterfaceTypeServer peer.Interface.Type = domain.InterfaceTypeServer
peer.DisplayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")" displayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
case domain.InterfaceTypeServer: case domain.InterfaceTypeServer:
peer.Interface.Type = domain.InterfaceTypeClient peer.Interface.Type = domain.InterfaceTypeClient
peer.DisplayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")" displayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
}
if peer.DisplayName == "" {
peer.DisplayName = displayName // use auto-generated display name if not set
} }
err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) { err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) {
@ -850,12 +876,12 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
} }
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error { func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
allPeers, err := m.db.GetInterfacePeers(ctx, id) iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil { if err != nil {
return err return err
} }
for _, peer := range allPeers { for _, peer := range allPeers {
err = m.wg.DeletePeer(ctx, id, peer.Identifier) err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err) return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
} }

View File

@ -371,7 +371,12 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return fmt.Errorf("delete not allowed: %w", err) return fmt.Errorf("delete not allowed: %w", err)
} }
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id) iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
}
err = m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, id)
if err != nil { if err != nil {
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err) return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
} }
@ -433,35 +438,28 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
interfaces := make(map[domain.InterfaceIdentifier]struct{}) interfaces := make(map[domain.InterfaceIdentifier]struct{})
for i := range peers { for _, peer := range peers {
peer := peers[i] iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
var err error if err != nil {
if peer.IsDisabled() || peer.IsExpired() { return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil {
return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
} else {
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
} }
// Always save the peer to the backend, regardless of disabled/expired state
// The backend will handle the disabled state appropriately
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
if err != nil { if err != nil {
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err) return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
} }

View File

@ -0,0 +1,94 @@
package config
import (
"fmt"
"time"
)
const LocalBackendName = "local"
type Backend struct {
Default string `yaml:"default"` // The default backend to use (defaults to the internal backend)
Mikrotik []BackendMikrotik `yaml:"mikrotik"`
}
// Validate checks the backend configuration for errors.
func (b *Backend) Validate() error {
if b.Default == "" {
b.Default = LocalBackendName
}
uniqueMap := make(map[string]struct{})
for _, backend := range b.Mikrotik {
if backend.Id == LocalBackendName {
return fmt.Errorf("backend ID %q is a reserved keyword", LocalBackendName)
}
if _, exists := uniqueMap[backend.Id]; exists {
return fmt.Errorf("backend ID %q is not unique", backend.Id)
}
uniqueMap[backend.Id] = struct{}{}
}
if b.Default != LocalBackendName {
if _, ok := uniqueMap[b.Default]; !ok {
return fmt.Errorf("default backend %q is not defined in the configuration", b.Default)
}
}
return nil
}
type BackendBase struct {
Id string `yaml:"id"` // A unique id for the backend
DisplayName string `yaml:"display_name"` // A display name for the backend
}
// GetDisplayName returns the display name of the backend.
// If no display name is set, it falls back to the ID.
func (b BackendBase) GetDisplayName() string {
if b.DisplayName == "" {
return b.Id // Fallback to ID if no display name is set
}
return b.DisplayName
}
type BackendMikrotik struct {
BackendBase `yaml:",inline"` // Embed the base fields
ApiUrl string `yaml:"api_url"` // The base URL of the Mikrotik API (e.g., "https://10.10.10.10:8729/rest")
ApiUser string `yaml:"api_user"`
ApiPassword string `yaml:"api_password"`
ApiVerifyTls bool `yaml:"api_verify_tls"` // Whether to verify the TLS certificate of the Mikrotik API
ApiTimeout time.Duration `yaml:"api_timeout"` // Timeout for API requests (default: 30 seconds)
// Concurrency controls the maximum number of concurrent API requests that this backend will issue
// when enumerating interfaces and their details. If 0 or negative, a default of 5 is used.
Concurrency int `yaml:"concurrency"`
Debug bool `yaml:"debug"` // Enable debug logging for the Mikrotik backend
}
// GetConcurrency returns the configured concurrency for this backend or a sane default (5)
// when the configured value is zero or negative.
func (b *BackendMikrotik) GetConcurrency() int {
if b == nil {
return 5
}
if b.Concurrency <= 0 {
return 5
}
return b.Concurrency
}
// GetApiTimeout returns the configured API timeout or a sane default (30 seconds)
// when the configured value is zero or negative.
func (b *BackendMikrotik) GetApiTimeout() time.Duration {
if b == nil {
return 30 * time.Second
}
if b.ApiTimeout <= 0 {
return 30 * time.Second
}
return b.ApiTimeout
}

View File

@ -44,6 +44,8 @@ type Config struct {
LimitAdditionalUserPeers int `yaml:"limit_additional_user_peers"` LimitAdditionalUserPeers int `yaml:"limit_additional_user_peers"`
} `yaml:"advanced"` } `yaml:"advanced"`
Backend Backend `yaml:"backend"`
Statistics struct { Statistics struct {
UsePingChecks bool `yaml:"use_ping_checks"` UsePingChecks bool `yaml:"use_ping_checks"`
PingCheckWorkers int `yaml:"ping_check_workers"` PingCheckWorkers int `yaml:"ping_check_workers"`
@ -99,6 +101,12 @@ func (c *Config) LogStartupValues() {
"minPasswordLength", c.Auth.MinPasswordLength, "minPasswordLength", c.Auth.MinPasswordLength,
"hideLoginForm", c.Auth.HideLoginForm, "hideLoginForm", c.Auth.HideLoginForm,
) )
slog.Debug("Config Backend",
"defaultBackend", c.Backend.Default,
"extraBackends", len(c.Backend.Mikrotik),
)
} }
// defaultConfig returns the default configuration // defaultConfig returns the default configuration
@ -122,6 +130,10 @@ func defaultConfig() *Config {
DSN: "data/sqlite.db", DSN: "data/sqlite.db",
} }
cfg.Backend = Backend{
Default: LocalBackendName, // local backend is the default (using wgcrtl)
}
cfg.Web = WebConfig{ cfg.Web = WebConfig{
RequestLogging: false, RequestLogging: false,
ExternalUrl: "http://localhost:8888", ExternalUrl: "http://localhost:8888",
@ -201,6 +213,10 @@ func GetConfig() (*Config, error) {
} }
cfg.Web.Sanitize() cfg.Web.Sanitize()
err := cfg.Backend.Validate()
if err != nil {
return nil, err
}
return cfg, nil return cfg, nil
} }

View File

@ -0,0 +1,32 @@
package domain
// ControllerType defines the type of controller used to manage interfaces.
const (
ControllerTypeMikrotik = "mikrotik"
ControllerTypeLocal = "wgctrl"
)
// Controller extras can be used to store additional information available for specific controllers only.
type MikrotikInterfaceExtras struct {
Id string // internal mikrotik ID
Comment string
Disabled bool
}
type MikrotikPeerExtras struct {
Id string // internal mikrotik ID
Name string
Comment string
IsResponder bool
Disabled bool
ClientEndpoint string
ClientAddress string
ClientDns string
ClientKeepalive int
}
type LocalPeerExtras struct {
Disabled bool
}

View File

@ -10,6 +10,8 @@ import (
"strings" "strings"
"time" "time"
"golang.org/x/sys/unix"
"github.com/h44z/wg-portal/internal" "github.com/h44z/wg-portal/internal"
) )
@ -23,6 +25,7 @@ var allowedFileNameRegex = regexp.MustCompile("[^a-zA-Z0-9-_]+")
type InterfaceIdentifier string type InterfaceIdentifier string
type InterfaceType string type InterfaceType string
type InterfaceBackend string
type Interface struct { type Interface struct {
BaseModel BaseModel
@ -49,11 +52,12 @@ type Interface struct {
SaveConfig bool // automatically persist config changes to the wgX.conf file SaveConfig bool // automatically persist config changes to the wgX.conf file
// WG Portal specific // WG Portal specific
DisplayName string // a nice display name/ description for the interface DisplayName string // a nice display name/ description for the interface
Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient
DriverType string // the interface driver type (linux, software, ...) Backend InterfaceBackend // the backend that is used to manage the interface (wgctrl, mikrotik, ...)
Disabled *time.Time `gorm:"index"` // flag that specifies if the interface is enabled (up) or not (down) DriverType string // the interface driver type (linux, software, ...)
DisabledReason string // the reason why the interface has been disabled Disabled *time.Time `gorm:"index"` // flag that specifies if the interface is enabled (up) or not (down)
DisabledReason string // the reason why the interface has been disabled
// Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of // Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of
// the peer config // the peer config
@ -204,9 +208,31 @@ type PhysicalInterface struct {
BytesUpload uint64 BytesUpload uint64
BytesDownload uint64 BytesDownload uint64
backendExtras any // additional backend-specific extras, e.g., domain.MikrotikInterfaceExtras
}
func (p *PhysicalInterface) GetExtras() any {
return p.backendExtras
}
func (p *PhysicalInterface) SetExtras(extras any) {
switch extras.(type) {
case MikrotikInterfaceExtras: // OK
default: // we only support MikrotikInterfaceExtras for now
panic(fmt.Sprintf("unsupported interface backend extras type %T", extras))
}
p.backendExtras = extras
} }
func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface { func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface {
networks := make([]Cidr, 0, len(pi.Addresses))
for _, addr := range pi.Addresses {
networks = append(networks, addr.NetworkAddr())
}
// create a new basic interface with the data from the physical interface
iface := &Interface{ iface := &Interface{
Identifier: pi.Identifier, Identifier: pi.Identifier,
KeyPair: pi.KeyPair, KeyPair: pi.KeyPair,
@ -226,11 +252,11 @@ func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface {
Type: InterfaceTypeAny, Type: InterfaceTypeAny,
DriverType: pi.DeviceType, DriverType: pi.DeviceType,
Disabled: nil, Disabled: nil,
PeerDefNetworkStr: "", PeerDefNetworkStr: CidrsToString(networks),
PeerDefDnsStr: "", PeerDefDnsStr: "",
PeerDefDnsSearchStr: "", PeerDefDnsSearchStr: "",
PeerDefEndpoint: "", PeerDefEndpoint: "",
PeerDefAllowedIPsStr: "", PeerDefAllowedIPsStr: CidrsToString(networks),
PeerDefMtu: pi.Mtu, PeerDefMtu: pi.Mtu,
PeerDefPersistentKeepalive: 0, PeerDefPersistentKeepalive: 0,
PeerDefFirewallMark: 0, PeerDefFirewallMark: 0,
@ -241,6 +267,23 @@ func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface {
PeerDefPostDown: "", PeerDefPostDown: "",
} }
if pi.GetExtras() == nil {
return iface
}
// enrich the data with controller-specific extras
now := time.Now()
switch pi.ImportSource {
case ControllerTypeMikrotik:
extras := pi.GetExtras().(MikrotikInterfaceExtras)
iface.DisplayName = extras.Comment
if extras.Disabled {
iface.Disabled = &now
} else {
iface.Disabled = nil
}
}
return iface return iface
} }
@ -253,6 +296,15 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) {
pi.FirewallMark = i.FirewallMark pi.FirewallMark = i.FirewallMark
pi.DeviceUp = !i.IsDisabled() pi.DeviceUp = !i.IsDisabled()
pi.Addresses = i.Addresses pi.Addresses = i.Addresses
switch pi.ImportSource {
case ControllerTypeMikrotik:
extras := MikrotikInterfaceExtras{
Comment: i.DisplayName,
Disabled: i.IsDisabled(),
}
pi.SetExtras(extras)
}
} }
type RoutingTableInfo struct { type RoutingTableInfo struct {
@ -279,3 +331,30 @@ func (r RoutingTableInfo) GetRoutingTable() int {
return r.Table return r.Table
} }
type IpFamily int
const (
IpFamilyIPv4 IpFamily = unix.AF_INET
IpFamilyIPv6 IpFamily = unix.AF_INET6
)
func (f IpFamily) String() string {
switch f {
case IpFamilyIPv4:
return "IPv4"
case IpFamilyIPv6:
return "IPv6"
default:
return "unknown"
}
}
// RouteRule represents a routing table rule.
type RouteRule struct {
InterfaceId InterfaceIdentifier
IpFamily IpFamily
FwMark uint32
Table int
HasDefault bool
}

View File

@ -129,7 +129,7 @@ func (p *Peer) GenerateDisplayName(prefix string) {
p.DisplayName = fmt.Sprintf("%sPeer %s", prefix, internal.TruncateString(string(p.Identifier), 8)) p.DisplayName = fmt.Sprintf("%sPeer %s", prefix, internal.TruncateString(string(p.Identifier), 8))
} }
// OverwriteUserEditableFields overwrites the user editable fields of the peer with the values from the userPeer // OverwriteUserEditableFields overwrites the user-editable fields of the peer with the values from the userPeer
func (p *Peer) OverwriteUserEditableFields(userPeer *Peer, cfg *config.Config) { func (p *Peer) OverwriteUserEditableFields(userPeer *Peer, cfg *config.Config) {
p.DisplayName = userPeer.DisplayName p.DisplayName = userPeer.DisplayName
if cfg.Core.EditableKeys { if cfg.Core.EditableKeys {
@ -182,9 +182,12 @@ type PhysicalPeer struct {
BytesUpload uint64 // upload bytes are the number of bytes that the remote peer has sent to the server BytesUpload uint64 // upload bytes are the number of bytes that the remote peer has sent to the server
BytesDownload uint64 // upload bytes are the number of bytes that the remote peer has received from the server BytesDownload uint64 // upload bytes are the number of bytes that the remote peer has received from the server
ImportSource string // import source (wgctrl, file, ...)
backendExtras any // additional backend-specific extras, e.g., domain.MikrotikPeerExtras
} }
func (p PhysicalPeer) GetPresharedKey() *wgtypes.Key { func (p *PhysicalPeer) GetPresharedKey() *wgtypes.Key {
if p.PresharedKey == "" { if p.PresharedKey == "" {
return nil return nil
} }
@ -196,7 +199,7 @@ func (p PhysicalPeer) GetPresharedKey() *wgtypes.Key {
return &key return &key
} }
func (p PhysicalPeer) GetEndpointAddress() *net.UDPAddr { func (p *PhysicalPeer) GetEndpointAddress() *net.UDPAddr {
if p.Endpoint == "" { if p.Endpoint == "" {
return nil return nil
} }
@ -208,7 +211,7 @@ func (p PhysicalPeer) GetEndpointAddress() *net.UDPAddr {
return addr return addr
} }
func (p PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration { func (p *PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration {
if p.PersistentKeepalive == 0 { if p.PersistentKeepalive == 0 {
return nil return nil
} }
@ -217,7 +220,7 @@ func (p PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration {
return &keepAliveDuration return &keepAliveDuration
} }
func (p PhysicalPeer) GetAllowedIPs() []net.IPNet { func (p *PhysicalPeer) GetAllowedIPs() []net.IPNet {
allowedIPs := make([]net.IPNet, len(p.AllowedIPs)) allowedIPs := make([]net.IPNet, len(p.AllowedIPs))
for i, ip := range p.AllowedIPs { for i, ip := range p.AllowedIPs {
allowedIPs[i] = *ip.IpNet() allowedIPs[i] = *ip.IpNet()
@ -226,6 +229,21 @@ func (p PhysicalPeer) GetAllowedIPs() []net.IPNet {
return allowedIPs return allowedIPs
} }
func (p *PhysicalPeer) GetExtras() any {
return p.backendExtras
}
func (p *PhysicalPeer) SetExtras(extras any) {
switch extras.(type) {
case MikrotikPeerExtras: // OK
case LocalPeerExtras: // OK
default: // we only support MikrotikPeerExtras and LocalPeerExtras for now
panic(fmt.Sprintf("unsupported peer backend extras type %T", extras))
}
p.backendExtras = extras
}
func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer { func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
peer := &Peer{ peer := &Peer{
Endpoint: NewConfigOption(pp.Endpoint, true), Endpoint: NewConfigOption(pp.Endpoint, true),
@ -244,6 +262,44 @@ func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
}, },
} }
if pp.GetExtras() == nil {
return peer
}
// enrich the data with controller-specific extras
now := time.Now()
switch pp.ImportSource {
case ControllerTypeMikrotik:
extras := pp.GetExtras().(MikrotikPeerExtras)
peer.Notes = extras.Comment
peer.DisplayName = extras.Name
if extras.ClientEndpoint != "" { // if the client endpoint is set, we assume that this is a client peer
peer.Endpoint = NewConfigOption(extras.ClientEndpoint, true)
peer.Interface.Type = InterfaceTypeClient
peer.Interface.Addresses, _ = CidrsFromString(extras.ClientAddress)
peer.Interface.DnsStr = NewConfigOption(extras.ClientDns, true)
peer.PersistentKeepalive = NewConfigOption(extras.ClientKeepalive, true)
} else {
peer.Interface.Type = InterfaceTypeServer
}
if extras.Disabled {
peer.Disabled = &now
peer.DisabledReason = "Disabled by Mikrotik controller"
} else {
peer.Disabled = nil
peer.DisabledReason = ""
}
case ControllerTypeLocal:
extras := pp.GetExtras().(LocalPeerExtras)
if extras.Disabled {
peer.Disabled = &now
peer.DisabledReason = "Disabled by Local controller"
} else {
peer.Disabled = nil
peer.DisabledReason = ""
}
}
return peer return peer
} }
@ -265,6 +321,27 @@ func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) {
pp.PresharedKey = p.PresharedKey pp.PresharedKey = p.PresharedKey
pp.PublicKey = p.Interface.PublicKey pp.PublicKey = p.Interface.PublicKey
pp.PersistentKeepalive = p.PersistentKeepalive.GetValue() pp.PersistentKeepalive = p.PersistentKeepalive.GetValue()
switch pp.ImportSource {
case ControllerTypeMikrotik:
extras := MikrotikPeerExtras{
Id: "",
Name: p.DisplayName,
Comment: p.Notes,
IsResponder: false,
Disabled: p.IsDisabled(),
ClientEndpoint: p.Endpoint.GetValue(),
ClientAddress: CidrsToString(p.Interface.Addresses),
ClientDns: p.Interface.DnsStr.GetValue(),
ClientKeepalive: p.PersistentKeepalive.GetValue(),
}
pp.SetExtras(extras)
case ControllerTypeLocal:
extras := LocalPeerExtras{
Disabled: p.IsDisabled(),
}
pp.SetExtras(extras)
}
} }
type PeerCreationRequest struct { type PeerCreationRequest struct {

View File

@ -1,6 +1,8 @@
package domain package domain
import "time" import (
"time"
)
type PeerStatus struct { type PeerStatus struct {
PeerId PeerIdentifier `gorm:"primaryKey;column:identifier" json:"PeerId"` PeerId PeerIdentifier `gorm:"primaryKey;column:identifier" json:"PeerId"`
@ -37,3 +39,25 @@ type InterfaceStatus struct {
BytesReceived uint64 `gorm:"column:received"` BytesReceived uint64 `gorm:"column:received"`
BytesTransmitted uint64 `gorm:"column:transmitted"` BytesTransmitted uint64 `gorm:"column:transmitted"`
} }
type PingerResult struct {
PacketsRecv int
PacketsSent int
Rtts []time.Duration
}
func (r PingerResult) IsPingable() bool {
return r.PacketsRecv > 0 && r.PacketsSent > 0 && len(r.Rtts) > 0
}
func (r PingerResult) AverageRtt() time.Duration {
if len(r.Rtts) == 0 {
return 0
}
var total time.Duration
for _, rtt := range r.Rtts {
total += rtt
}
return total / time.Duration(len(r.Rtts))
}

View File

@ -12,8 +12,8 @@ import (
"sync" "sync"
) )
// SetupLogging initializes the global logger with the given level and format // GetLoggingHandler initializes a slog.Handler based on the provided logging level and format options.
func SetupLogging(level string, pretty, json bool) { func GetLoggingHandler(level string, pretty, json bool) slog.Handler {
var logLevel = new(slog.LevelVar) var logLevel = new(slog.LevelVar)
switch strings.ToLower(level) { switch strings.ToLower(level) {
@ -46,6 +46,13 @@ func SetupLogging(level string, pretty, json bool) {
handler = slog.NewTextHandler(output, opts) handler = slog.NewTextHandler(output, opts)
} }
return handler
}
// SetupLogging initializes the global logger with the given level and format
func SetupLogging(level string, pretty, json bool) {
handler := GetLoggingHandler(level, pretty, json)
logger := slog.New(handler) logger := slog.New(handler)
slog.SetDefault(logger) slog.SetDefault(logger)

View File

@ -0,0 +1,435 @@
package lowlevel
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/config"
)
// region models
const (
MikrotikApiStatusOk = "success"
MikrotikApiStatusError = "error"
)
const (
MikrotikApiErrorCodeUnknown = iota + 600
MikrotikApiErrorCodeRequestPreparationFailed
MikrotikApiErrorCodeRequestFailed
MikrotikApiErrorCodeResponseDecodeFailed
)
type MikrotikApiResponse[T any] struct {
Status string
Code int
Data T `json:"data,omitempty"`
Error *MikrotikApiError `json:"error,omitempty"`
}
type MikrotikApiError struct {
Code int `json:"error,omitempty"`
Message string `json:"message,omitempty"`
Details string `json:"detail,omitempty"`
}
func (e *MikrotikApiError) String() string {
if e == nil {
return "no error"
}
return fmt.Sprintf("API error %d: %s - %s", e.Code, e.Message, e.Details)
}
type GenericJsonObject map[string]any
type EmptyResponse struct{}
func (JsonObject GenericJsonObject) GetString(key string) string {
if value, ok := JsonObject[key]; ok {
if strValue, ok := value.(string); ok {
return strValue
} else {
return fmt.Sprintf("%v", value) // Convert to string if not already
}
}
return ""
}
func (JsonObject GenericJsonObject) GetInt(key string) int {
if value, ok := JsonObject[key]; ok {
if intValue, ok := value.(int); ok {
return intValue
} else {
if floatValue, ok := value.(float64); ok {
return int(floatValue) // Convert float64 to int
}
if strValue, ok := value.(string); ok {
if intValue, err := strconv.Atoi(strValue); err == nil {
return intValue // Convert string to int if possible
}
}
}
}
return 0
}
func (JsonObject GenericJsonObject) GetBool(key string) bool {
if value, ok := JsonObject[key]; ok {
if boolValue, ok := value.(bool); ok {
return boolValue
} else {
if intValue, ok := value.(int); ok {
return intValue == 1 // Convert int to bool (1 is true, 0 is false)
}
if floatValue, ok := value.(float64); ok {
return int(floatValue) == 1 // Convert float64 to bool (1.0 is true, 0.0 is false)
}
if strValue, ok := value.(string); ok {
boolValue, err := strconv.ParseBool(strValue)
if err == nil {
return boolValue
}
}
}
}
return false
}
type MikrotikRequestOptions struct {
Filters map[string]string `json:"filters,omitempty"`
PropList []string `json:"proplist,omitempty"`
}
func (o *MikrotikRequestOptions) GetPath(base string) string {
if o == nil {
return base
}
path, err := url.Parse(base)
if err != nil {
return base
}
query := path.Query()
for k, v := range o.Filters {
query.Set(k, v)
}
if len(o.PropList) > 0 {
query.Set(".proplist", strings.Join(o.PropList, ","))
}
path.RawQuery = query.Encode()
return path.String()
}
// region models
// region API-client
type MikrotikApiClient struct {
coreCfg *config.Config
cfg *config.BackendMikrotik
client *http.Client
log *slog.Logger
}
func NewMikrotikApiClient(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikApiClient, error) {
c := &MikrotikApiClient{
coreCfg: coreCfg,
cfg: cfg,
}
err := c.setup()
if err != nil {
return nil, err
}
c.debugLog("Mikrotik api client created", "api_url", cfg.ApiUrl)
return c, nil
}
func (m *MikrotikApiClient) setup() error {
m.client = &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: !m.cfg.ApiVerifyTls,
},
},
Timeout: m.cfg.GetApiTimeout(),
}
if m.cfg.Debug {
m.log = slog.New(internal.GetLoggingHandler("debug",
m.coreCfg.Advanced.LogPretty,
m.coreCfg.Advanced.LogJson).
WithAttrs([]slog.Attr{
{
Key: "mikrotik-bid", Value: slog.StringValue(m.cfg.Id),
},
}))
}
return nil
}
func (m *MikrotikApiClient) debugLog(msg string, args ...any) {
if m.log != nil {
m.log.Debug("[MT-API] "+msg, args...)
}
}
func (m *MikrotikApiClient) getFullPath(command string) string {
path, err := url.JoinPath(m.cfg.ApiUrl, command)
if err != nil {
return ""
}
return path
}
func (m *MikrotikApiClient) prepareGetRequest(ctx context.Context, fullUrl string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullUrl, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" {
req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword)
}
return req, nil
}
func (m *MikrotikApiClient) prepareDeleteRequest(ctx context.Context, fullUrl string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, fullUrl, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" {
req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword)
}
return req, nil
}
func (m *MikrotikApiClient) preparePayloadRequest(
ctx context.Context,
method string,
fullUrl string,
payload GenericJsonObject,
) (*http.Request, error) {
// marshal the payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, method, fullUrl, bytes.NewReader(payloadBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" {
req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword)
}
return req, nil
}
func errToApiResponse[T any](code int, message string, err error) MikrotikApiResponse[T] {
return MikrotikApiResponse[T]{
Status: MikrotikApiStatusError,
Code: code,
Error: &MikrotikApiError{
Code: code,
Message: message,
Details: err.Error(),
},
}
}
func parseHttpResponse[T any](resp *http.Response, err error) MikrotikApiResponse[T] {
if err != nil {
return errToApiResponse[T](MikrotikApiErrorCodeRequestFailed, "failed to execute request", err)
}
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
slog.Error("failed to close response body", "error", err)
}
}(resp.Body)
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
var data T
// if the type of T is EmptyResponse, we can return an empty response with just the status
if _, ok := any(data).(EmptyResponse); ok {
return MikrotikApiResponse[T]{Status: MikrotikApiStatusOk, Code: resp.StatusCode}
}
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return errToApiResponse[T](MikrotikApiErrorCodeResponseDecodeFailed, "failed to decode response", err)
}
return MikrotikApiResponse[T]{Status: MikrotikApiStatusOk, Code: resp.StatusCode, Data: data}
}
var apiErr MikrotikApiError
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
return errToApiResponse[T](resp.StatusCode, "unknown error, unparsable response", err)
} else {
return MikrotikApiResponse[T]{Status: MikrotikApiStatusError, Code: resp.StatusCode, Error: &apiErr}
}
}
func (m *MikrotikApiClient) Query(
ctx context.Context,
command string,
opts *MikrotikRequestOptions,
) MikrotikApiResponse[[]GenericJsonObject] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := opts.GetPath(m.getFullPath(command))
req, err := m.prepareGetRequest(apiCtx, fullUrl)
if err != nil {
return errToApiResponse[[]GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API query", "url", fullUrl)
response := parseHttpResponse[[]GenericJsonObject](m.client.Do(req))
m.debugLog("retrieved API query result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
func (m *MikrotikApiClient) Get(
ctx context.Context,
command string,
opts *MikrotikRequestOptions,
) MikrotikApiResponse[GenericJsonObject] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := opts.GetPath(m.getFullPath(command))
req, err := m.prepareGetRequest(apiCtx, fullUrl)
if err != nil {
return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API get", "url", fullUrl)
response := parseHttpResponse[GenericJsonObject](m.client.Do(req))
m.debugLog("retrieved API get result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
func (m *MikrotikApiClient) Create(
ctx context.Context,
command string,
payload GenericJsonObject,
) MikrotikApiResponse[GenericJsonObject] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := m.getFullPath(command)
req, err := m.preparePayloadRequest(apiCtx, http.MethodPut, fullUrl, payload)
if err != nil {
return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API put", "url", fullUrl)
response := parseHttpResponse[GenericJsonObject](m.client.Do(req))
m.debugLog("retrieved API put result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
func (m *MikrotikApiClient) Update(
ctx context.Context,
command string,
payload GenericJsonObject,
) MikrotikApiResponse[GenericJsonObject] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := m.getFullPath(command)
req, err := m.preparePayloadRequest(apiCtx, http.MethodPatch, fullUrl, payload)
if err != nil {
return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API patch", "url", fullUrl)
response := parseHttpResponse[GenericJsonObject](m.client.Do(req))
m.debugLog("retrieved API patch result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
func (m *MikrotikApiClient) Delete(
ctx context.Context,
command string,
) MikrotikApiResponse[EmptyResponse] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := m.getFullPath(command)
req, err := m.prepareDeleteRequest(apiCtx, fullUrl)
if err != nil {
return errToApiResponse[EmptyResponse](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API delete", "url", fullUrl)
response := parseHttpResponse[EmptyResponse](m.client.Do(req))
m.debugLog("retrieved API delete result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
func (m *MikrotikApiClient) ExecList(
ctx context.Context,
command string,
payload GenericJsonObject,
) MikrotikApiResponse[[]GenericJsonObject] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := m.getFullPath(command)
req, err := m.preparePayloadRequest(apiCtx, http.MethodPost, fullUrl, payload)
if err != nil {
return errToApiResponse[[]GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API post", "url", fullUrl)
response := parseHttpResponse[[]GenericJsonObject](m.client.Do(req))
m.debugLog("retrieved API post result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
// endregion API-client

View File

@ -80,6 +80,7 @@ nav:
- Examples: documentation/configuration/examples.md - Examples: documentation/configuration/examples.md
- Usage: - Usage:
- General: documentation/usage/general.md - General: documentation/usage/general.md
- Backends: documentation/usage/backends.md
- LDAP: documentation/usage/ldap.md - LDAP: documentation/usage/ldap.md
- Security: documentation/usage/security.md - Security: documentation/usage/security.md
- Webhooks: documentation/usage/webhooks.md - Webhooks: documentation/usage/webhooks.md