diff --git a/README.md b/README.md index 018b748..035c69f 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ The configuration portal supports using a database (SQLite, MySQL, MsSQL, or Pos * Docker ready * Can be used with existing WireGuard setups * Support for multiple WireGuard interfaces +* Supports multiple WireGuard backends (wgctrl or MikroTik [BETA]) * Peer Expiry Feature * Handles route and DNS settings like wg-quick does * Exposes Prometheus metrics for monitoring and alerting diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index 0c1dd20..97f0b67 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -50,7 +50,8 @@ func main() { database, err := adapters.NewSqlRepository(rawDb) internal.AssertNoError(err) - wireGuard := adapters.NewWireGuardRepository() + wireGuard, err := wireguard.NewControllerManager(cfg) + internal.AssertNoError(err) wgQuick := adapters.NewWgQuickRepo() @@ -134,7 +135,7 @@ func main() { apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers) apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces) apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers) - apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth) + apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard) apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth) apiFrontend := handlersV0.NewRestApi(apiV0Session, diff --git a/docs/documentation/configuration/overview.md b/docs/documentation/configuration/overview.md index dd20d79..1268582 100644 --- a/docs/documentation/configuration/overview.md +++ b/docs/documentation/configuration/overview.md @@ -24,6 +24,9 @@ core: self_provisioning_allowed: false import_existing: true restore_state: true + +backend: + default: local advanced: log_level: info @@ -102,6 +105,7 @@ webhook: Below you will find sections like [`core`](#core), +[`backend`](#backend), [`advanced`](#advanced), [`database`](#database), [`statistics`](#statistics), @@ -165,6 +169,65 @@ More advanced options are found in the subsequent `Advanced` section. --- +## Backend + +Configuration options for the WireGuard backend, which manages the WireGuard interfaces and peers. +The current MikroTik backend is in **BETA** and may not support all features. + +### `default` +- **Default:** `local` +- **Description:** The default backend to use for managing WireGuard interfaces. + Valid options are: `local`, or other backend id's configured in the `mikrotik` section. + +### Mikrotik + +The `mikrotik` array contains a list of MikroTik backend definitions. Each entry describes how to connect to a MikroTik RouterOS instance that hosts WireGuard interfaces. + +Below are the properties for each entry inside `backend.mikrotik`: + +#### `id` +- **Default:** *(empty)* +- **Description:** A unique identifier for this backend. + This value can be referenced by `backend.default` to use this backend as default. + The identifier must be unique across all backends and must not use the reserved keyword `local`. + +#### `display_name` +- **Default:** *(empty)* +- **Description:** A human-friendly display name for this backend. If omitted, the `id` will be used as the display name. + +#### `api_url` +- **Default:** *(empty)* +- **Description:** Base URL of the MikroTik REST API, including scheme and path, e.g., `https://10.10.10.10:8729/rest`. + +#### `api_user` +- **Default:** *(empty)* +- **Description:** Username for authenticating against the MikroTik API. + Ensure that the user has sufficient permissions to manage WireGuard interfaces and peers. + +#### `api_password` +- **Default:** *(empty)* +- **Description:** Password for the specified API user. + +#### `api_verify_tls` +- **Default:** `false` +- **Description:** Whether to verify the TLS certificate of the MikroTik API endpoint. Set to `false` to allow self-signed certificates (not recommended for production). + +#### `api_timeout` +- **Default:** `30s` +- **Description:** Timeout for API requests to the MikroTik device. Uses Go duration format (e.g., `10s`, `1m`). If omitted, a default of 30 seconds is used. + +#### `concurrency` +- **Default:** `5` +- **Description:** Maximum number of concurrent API requests the backend will issue when enumerating interfaces and their details. If `0` or negative, a sane default of `5` is used. + +#### `debug` +- **Default:** `false` +- **Description:** Enable verbose debug logging for the MikroTik backend. + +For more details on configuring the MikroTik backend, see the [Backends](../usage/backends.md) documentation. + +--- + ## Advanced Additional or more specialized configuration options for logging and interface creation details. diff --git a/docs/documentation/usage/backends.md b/docs/documentation/usage/backends.md new file mode 100644 index 0000000..e891d95 --- /dev/null +++ b/docs/documentation/usage/backends.md @@ -0,0 +1,57 @@ +# Backends + +WireGuard Portal can manage WireGuard interfaces and peers on different backends. +Each backend represents a system where interfaces actually live. +You can register multiple backends and choose which one to use per interface. +A global default backend determines where newly created interfaces go (unless you explicitly choose another in the UI). + +**Supported backends:** +- **Local** (default): Manages interfaces on the host running WireGuard Portal (Linux WireGuard via wgctrl). Use this when the portal should directly configure wg devices on the same server. +- **MikroTik** RouterOS (_beta_): Manages interfaces and peers on MikroTik devices via the RouterOS REST API. Use this to control WG interfaces on RouterOS v7+. + +How backend selection works: +- The default backend is configured at `backend.default` (_local_ or the id of a defined MikroTik backend). + New interfaces created in the UI will use this backend by default. +- Each interface stores its backend. You can select a different backend when creating a new interface. + +## Configuring MikroTik backends (RouterOS v7+) + +> :warning: The MikroTik backend is currently marked beta. While basic functionality is implemented, some advanced features are not yet implemented or contain bugs. Please test carefully before using in production. + +The MikroTik backend uses the [REST API](https://help.mikrotik.com/docs/spaces/ROS/pages/47579162/REST+API) under a base URL ending with /rest. +You can register one or more MikroTik devices as backends for a single WireGuard Portal instance. + +### Prerequisites on MikroTik: +- RouterOS v7 with WireGuard support. +- REST API enabled and reachable over HTTP(S). A typical base URL is https://:8729/rest or https:///rest depending on your service setup. +- A dedicated RouterOS user with the following group permissions: + - **api** (for logging in via REST API) + - **rest-api** (for logging in via REST API) + - **read** (to read interface and peer data) + - **write** (to create/update interfaces and peers) + - **test** (to perform ping checks) + - **sensitive** (to read private keys) +- TLS certificate on the device is recommended. If you use a self-signed certificate during testing, set `api_verify_tls`: _false_ in wg-portal (not recommended for production). + +Example WireGuard Portal configuration (config/config.yaml): + +```yaml +backend: + # default backend decides where new interfaces are created + default: mikrotik-prod + + mikrotik: + - id: mikrotik-prod # unique id, not "local" + display_name: RouterOS RB5009 # optional nice name + api_url: https://10.10.10.10/rest + api_user: wgportal + api_password: a-super-secret-password + api_verify_tls: true # set to false only if using self-signed during testing + api_timeout: 30s # maximum request duration + concurrency: 5 # limit parallel REST calls to device + debug: false # verbose logging for this backend +``` + +### Known limitations: +- The MikroTik backend is still in beta. Some features may not work as expected. +- Not all WireGuard Portal features are supported yet (e.g., no support for interface hooks) \ No newline at end of file diff --git a/frontend/src/components/InterfaceEditModal.vue b/frontend/src/components/InterfaceEditModal.vue index d23b490..586290d 100644 --- a/frontend/src/components/InterfaceEditModal.vue +++ b/frontend/src/components/InterfaceEditModal.vue @@ -10,11 +10,13 @@ import isCidr from "is-cidr"; import {isIP} from 'is-ip'; import { freshInterface } from '@/helpers/models'; import {peerStore} from "@/stores/peers"; +import {settingsStore} from "@/stores/settings"; const { t } = useI18n() const interfaces = interfaceStore() const peers = peerStore() +const settings = settingsStore() const props = defineProps({ interfaceId: String, @@ -48,6 +50,26 @@ const currentTags = ref({ PeerDefDnsSearch: "" }) const formData = ref(freshInterface()) +const isSaving = ref(false) +const isDeleting = ref(false) +const isApplyingDefaults = ref(false) + +const isBackendValid = computed(() => { + if (!props.visible || !selectedInterface.value) { + return true // if modal is not visible or no interface is selected, we don't care about backend validity + } + + let backendId = selectedInterface.value.Backend + + let valid = false + let availableBackends = settings.Setting('AvailableBackends') || [] + availableBackends.forEach(backend => { + if (backend.Id === backendId) { + valid = true + } + }) + return valid +}) // functions @@ -61,6 +83,7 @@ watch(() => props.visible, async (newValue, oldValue) => { formData.value.Identifier = interfaces.Prepared.Identifier formData.value.DisplayName = interfaces.Prepared.DisplayName formData.value.Mode = interfaces.Prepared.Mode + formData.value.Backend = interfaces.Prepared.Backend formData.value.PublicKey = interfaces.Prepared.PublicKey formData.value.PrivateKey = interfaces.Prepared.PrivateKey @@ -99,6 +122,7 @@ watch(() => props.visible, async (newValue, oldValue) => { formData.value.Identifier = selectedInterface.value.Identifier formData.value.DisplayName = selectedInterface.value.DisplayName formData.value.Mode = selectedInterface.value.Mode + formData.value.Backend = selectedInterface.value.Backend formData.value.PublicKey = selectedInterface.value.PublicKey formData.value.PrivateKey = selectedInterface.value.PrivateKey @@ -237,6 +261,8 @@ function handleChangePeerDefDnsSearch(tags) { } async function save() { + if (isSaving.value) return + isSaving.value = true try { if (props.interfaceId!=='#NEW#') { await interfaces.UpdateInterface(selectedInterface.value.Identifier, formData.value) @@ -251,6 +277,8 @@ async function save() { text: e.toString(), type: 'error', }) + } finally { + isSaving.value = false } } @@ -259,6 +287,8 @@ async function applyPeerDefaults() { return; // do nothing for new interfaces } + if (isApplyingDefaults.value) return + isApplyingDefaults.value = true try { await interfaces.ApplyPeerDefaults(selectedInterface.value.Identifier, formData.value) @@ -276,10 +306,14 @@ async function applyPeerDefaults() { text: e.toString(), type: 'error', }) + } finally { + isApplyingDefaults.value = false } } async function del() { + if (isDeleting.value) return + isDeleting.value = true try { await interfaces.DeleteInterface(selectedInterface.value.Identifier) close() @@ -290,6 +324,8 @@ async function del() { text: e.toString(), type: 'error', }) + } finally { + isDeleting.value = false } } @@ -314,13 +350,22 @@ async function del() { -
- - +
+
+ + +
+
+ + + {{ $t('modals.interface-edit.backend.invalid-label') }} +
@@ -385,12 +430,14 @@ async function del() {
-
+
+
+
-
+
@@ -530,16 +577,25 @@ async function del() {

- +
diff --git a/frontend/src/components/PeerEditModal.vue b/frontend/src/components/PeerEditModal.vue index 554de5a..7c50edf 100644 --- a/frontend/src/components/PeerEditModal.vue +++ b/frontend/src/components/PeerEditModal.vue @@ -73,6 +73,8 @@ const currentTags = ref({ DnsSearch: "" }) const formData = ref(freshPeer()) +const isSaving = ref(false) +const isDeleting = ref(false) // functions @@ -270,6 +272,8 @@ function handleChangeDnsSearch(tags) { } async function save() { + if (isSaving.value) return + isSaving.value = true try { if (props.peerId !== '#NEW#') { await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value) @@ -278,26 +282,30 @@ async function save() { } close() } catch (e) { - // console.log(e) notify({ title: "Failed to save peer!", text: e.toString(), type: 'error', }) + } finally { + isSaving.value = false } } async function del() { + if (isDeleting.value) return + isDeleting.value = true try { await peers.DeletePeer(selectedPeer.value.Identifier) close() } catch (e) { - // console.log(e) notify({ title: "Failed to delete peer!", text: e.toString(), type: 'error', }) + } finally { + isDeleting.value = false } } @@ -470,10 +478,15 @@ async function del() { diff --git a/frontend/src/components/PeerMultiCreateModal.vue b/frontend/src/components/PeerMultiCreateModal.vue index f5a2c87..bc99432 100644 --- a/frontend/src/components/PeerMultiCreateModal.vue +++ b/frontend/src/components/PeerMultiCreateModal.vue @@ -38,6 +38,7 @@ function freshForm() { const currentTag = ref("") const formData = ref(freshForm()) +const isSaving = ref(false) const title = computed(() => { if (!props.visible) { @@ -60,12 +61,15 @@ function handleChangeUserIdentifiers(tags) { } async function save() { + if (isSaving.value) return + isSaving.value = true if (formData.value.Identifiers.length === 0) { notify({ title: "Missing Identifiers", text: "At least one identifier is required to create a new peer.", type: 'error', }) + isSaving.value = false return } @@ -79,6 +83,8 @@ async function save() { text: e.toString(), type: 'error', }) + } finally { + isSaving.value = false } } @@ -108,7 +114,10 @@ async function save() { diff --git a/frontend/src/components/UserEditModal.vue b/frontend/src/components/UserEditModal.vue index 6a4a7bc..340dfe2 100644 --- a/frontend/src/components/UserEditModal.vue +++ b/frontend/src/components/UserEditModal.vue @@ -34,6 +34,8 @@ const title = computed(() => { }) const formData = ref(freshUser()) +const isSaving = ref(false) +const isDeleting = ref(false) const passwordWeak = computed(() => { return formData.value.Password && formData.value.Password.length > 0 && formData.value.Password.length < settings.Setting('MinPasswordLength') @@ -89,6 +91,8 @@ function close() { } async function save() { + if (isSaving.value) return + isSaving.value = true try { if (props.userId!=='#NEW#') { await users.UpdateUser(selectedUser.value.Identifier, formData.value) @@ -102,10 +106,14 @@ async function save() { text: e.toString(), type: 'error', }) + } finally { + isSaving.value = false } } async function del() { + if (isDeleting.value) return + isDeleting.value = true try { await users.DeleteUser(selectedUser.value.Identifier) close() @@ -115,6 +123,8 @@ async function del() { text: e.toString(), type: 'error', }) + } finally { + isDeleting.value = false } } @@ -193,9 +203,15 @@ async function del() { diff --git a/frontend/src/components/UserPeerEditModal.vue b/frontend/src/components/UserPeerEditModal.vue index 7594d7b..15f2f83 100644 --- a/frontend/src/components/UserPeerEditModal.vue +++ b/frontend/src/components/UserPeerEditModal.vue @@ -55,6 +55,8 @@ const title = computed(() => { }) const formData = ref(freshPeer()) +const isSaving = ref(false) +const isDeleting = ref(false) // functions @@ -163,6 +165,8 @@ function close() { } async function save() { + if (isSaving.value) return + isSaving.value = true try { if (props.peerId !== '#NEW#') { await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value) @@ -171,26 +175,30 @@ async function save() { } close() } catch (e) { - // console.log(e) notify({ title: "Failed to save peer!", text: e.toString(), type: 'error', }) + } finally { + isSaving.value = false } } async function del() { + if (isDeleting.value) return + isDeleting.value = true try { await peers.DeletePeer(selectedPeer.value.Identifier) close() } catch (e) { - // console.log(e) notify({ title: "Failed to delete peer!", text: e.toString(), type: 'error', }) + } finally { + isDeleting.value = false } } @@ -283,10 +291,15 @@ async function del() { diff --git a/frontend/src/helpers/models.js b/frontend/src/helpers/models.js index 8f8683e..6e1e52b 100644 --- a/frontend/src/helpers/models.js +++ b/frontend/src/helpers/models.js @@ -5,6 +5,7 @@ export function freshInterface() { DisplayName: "", Identifier: "", Mode: "server", + Backend: "local", PublicKey: "", PrivateKey: "", diff --git a/frontend/src/lang/translations/de.json b/frontend/src/lang/translations/de.json index 859c70e..a0b200c 100644 --- a/frontend/src/lang/translations/de.json +++ b/frontend/src/lang/translations/de.json @@ -102,7 +102,9 @@ }, "interface": { "headline": "Schnittstellenstatus für", - "mode": "Modus", + "backend": "Backend", + "unknown-backend": "Unbekannt", + "wrong-backend": "Ungültiges Backend, das lokale WireGuard Backend wird stattdessen verwendet!", "key": "Öffentlicher Schlüssel", "endpoint": "Öffentlicher Endpunkt", "port": "Port", @@ -357,6 +359,11 @@ "client": "Client-Modus", "any": "Unbekannter Modus" }, + "backend": { + "label": "Schnittstellenbackend", + "invalid-label": "Ursprüngliches Backend ist ungültig, das lokale WireGuard Backend wird stattdessen verwendet!", + "local": "Lokales WireGuard Backend" + }, "display-name": { "label": "Anzeigename", "placeholder": "Der beschreibende Name für die Schnittstelle" diff --git a/frontend/src/lang/translations/en.json b/frontend/src/lang/translations/en.json index 57a129a..af795fc 100644 --- a/frontend/src/lang/translations/en.json +++ b/frontend/src/lang/translations/en.json @@ -102,7 +102,9 @@ }, "interface": { "headline": "Interface status for", - "mode": "mode", + "backend": "Backend", + "unknown-backend": "Unknown", + "wrong-backend": "Invalid backend, using local WireGuard backend instead!", "key": "Public Key", "endpoint": "Public Endpoint", "port": "Listening Port", @@ -357,6 +359,11 @@ "client": "Client Mode", "any": "Unknown Mode" }, + "backend": { + "label": "Interface Backend", + "invalid-label": "Original backend is no longer available, using local WireGuard backend instead!", + "local": "Local WireGuard Backend" + }, "display-name": { "label": "Display Name", "placeholder": "The descriptive name for the interface" diff --git a/frontend/src/lang/translations/fr.json b/frontend/src/lang/translations/fr.json index f5b165c..951ab22 100644 --- a/frontend/src/lang/translations/fr.json +++ b/frontend/src/lang/translations/fr.json @@ -99,7 +99,7 @@ }, "interface": { "headline": "État de l'interface pour", - "mode": "mode", + "backend": "backend", "key": "Clé publique", "endpoint": "Point de terminaison public", "port": "Port d'écoute", diff --git a/frontend/src/lang/translations/ko.json b/frontend/src/lang/translations/ko.json index 8b87d1b..6e65e06 100644 --- a/frontend/src/lang/translations/ko.json +++ b/frontend/src/lang/translations/ko.json @@ -100,7 +100,7 @@ }, "interface": { "headline": "인터페이스 상태:", - "mode": "모드", + "backend": "백엔드", "key": "공개 키", "endpoint": "공개 엔드포인트", "port": "수신 포트", diff --git a/frontend/src/lang/translations/pt.json b/frontend/src/lang/translations/pt.json index a895400..126037e 100644 --- a/frontend/src/lang/translations/pt.json +++ b/frontend/src/lang/translations/pt.json @@ -101,7 +101,7 @@ }, "interface": { "headline": "Status da interface para", - "mode": "modo", + "mode": "backend", "key": "Chave Pública", "endpoint": "Endpoint Público", "port": "Porta de Escuta", diff --git a/frontend/src/lang/translations/ru.json b/frontend/src/lang/translations/ru.json index 6df8383..a88158a 100644 --- a/frontend/src/lang/translations/ru.json +++ b/frontend/src/lang/translations/ru.json @@ -99,7 +99,7 @@ }, "interface": { "headline": "Статус интерфейса для", - "mode": "режим", + "backend": "бэкэнд", "key": "Публичный ключ", "endpoint": "Публичная конечная точка", "port": "Порт прослушивания", diff --git a/frontend/src/lang/translations/uk.json b/frontend/src/lang/translations/uk.json index 7647528..4574ad4 100644 --- a/frontend/src/lang/translations/uk.json +++ b/frontend/src/lang/translations/uk.json @@ -99,7 +99,7 @@ }, "interface": { "headline": "Статус інтерфейсу для", - "mode": "режим", + "backend": "бекенд", "key": "Публічний ключ", "endpoint": "Публічна кінцева точка", "port": "Порт прослуховування", diff --git a/frontend/src/lang/translations/vi.json b/frontend/src/lang/translations/vi.json index 7e90dd9..722918f 100644 --- a/frontend/src/lang/translations/vi.json +++ b/frontend/src/lang/translations/vi.json @@ -98,7 +98,7 @@ }, "interface": { "headline": "Trạng thái giao diện cho", - "mode": "chế độ", + "backend": "phần sau", "key": "Khóa Công khai", "endpoint": "Điểm cuối Công khai", "port": "Cổng Nghe", diff --git a/frontend/src/lang/translations/zh.json b/frontend/src/lang/translations/zh.json index cf1d715..3b5b64e 100644 --- a/frontend/src/lang/translations/zh.json +++ b/frontend/src/lang/translations/zh.json @@ -98,7 +98,7 @@ }, "interface": { "headline": "接口状态", - "mode": "模式", + "backend": "后端", "key": "公钥", "endpoint": "公开节点", "port": "监听端口", diff --git a/frontend/src/views/InterfaceView.vue b/frontend/src/views/InterfaceView.vue index 1cdb9fb..b28cf5f 100644 --- a/frontend/src/views/InterfaceView.vue +++ b/frontend/src/views/InterfaceView.vue @@ -5,17 +5,20 @@ import PeerMultiCreateModal from "../components/PeerMultiCreateModal.vue"; import InterfaceEditModal from "../components/InterfaceEditModal.vue"; import InterfaceViewModal from "../components/InterfaceViewModal.vue"; -import {onMounted, ref} from "vue"; +import {computed, onMounted, ref} from "vue"; import {peerStore} from "@/stores/peers"; import {interfaceStore} from "@/stores/interfaces"; import {notify} from "@kyvg/vue3-notification"; import {settingsStore} from "@/stores/settings"; import {humanFileSize} from '@/helpers/utils'; +import {useI18n} from "vue-i18n"; const settings = settingsStore() const interfaces = interfaceStore() const peers = peerStore() +const { t } = useI18n() + const viewedPeerId = ref("") const editPeerId = ref("") const multiCreatePeerId = ref("") @@ -45,6 +48,33 @@ function calculateInterfaceName(id, name) { return result } +const calculateBackendName = computed(() => { + let backendId = interfaces.GetSelected.Backend + + let backendName = t('interfaces.interface.unknown-backend') + let availableBackends = settings.Setting('AvailableBackends') || [] + availableBackends.forEach(backend => { + if (backend.Id === backendId) { + backendName = backend.Id === 'local' ? t(backend.Name) : backend.Name + } + }) + return backendName +}) + +const isBackendValid = computed(() => { + let backendId = interfaces.GetSelected.Backend + + let valid = false + let availableBackends = settings.Setting('AvailableBackends') || [] + availableBackends.forEach(backend => { + if (backend.Id === backendId) { + valid = true + } + }) + return valid +}) + + async function download() { await interfaces.LoadInterfaceConfig(interfaces.GetSelected.Identifier) @@ -141,7 +171,7 @@ onMounted(async () => {
- {{ $t('interfaces.interface.headline') }} {{interfaces.GetSelected.Identifier}} ({{interfaces.GetSelected.Mode}} {{ $t('interfaces.interface.mode') }}) + {{ $t('interfaces.interface.headline') }} {{interfaces.GetSelected.Identifier}} ({{ $t('modals.interface-edit.mode.' + interfaces.GetSelected.Mode )}} | {{ $t('interfaces.interface.backend') + ": " + calculateBackendName }})
diff --git a/internal/adapters/wgcontroller/local.go b/internal/adapters/wgcontroller/local.go new file mode 100644 index 0000000..7f2e7fa --- /dev/null +++ b/internal/adapters/wgcontroller/local.go @@ -0,0 +1,864 @@ +package wgcontroller + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "strings" + "time" + + probing "github.com/prometheus-community/pro-bing" + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/h44z/wg-portal/internal" + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" + "github.com/h44z/wg-portal/internal/lowlevel" +) + +// region dependencies + +// WgCtrlRepo is used to control local WireGuard devices via the wgctrl-go library. +type WgCtrlRepo interface { + io.Closer + Devices() ([]*wgtypes.Device, error) + Device(name string) (*wgtypes.Device, error) + ConfigureDevice(name string, cfg wgtypes.Config) error +} + +// A NetlinkClient is a type which can control a netlink device. +type NetlinkClient interface { + LinkAdd(link netlink.Link) error + LinkDel(link netlink.Link) error + LinkByName(name string) (netlink.Link, error) + LinkSetUp(link netlink.Link) error + LinkSetDown(link netlink.Link) error + LinkSetMTU(link netlink.Link, mtu int) error + AddrReplace(link netlink.Link, addr *netlink.Addr) error + AddrAdd(link netlink.Link, addr *netlink.Addr) error + AddrList(link netlink.Link) ([]netlink.Addr, error) + AddrDel(link netlink.Link, addr *netlink.Addr) error + RouteAdd(route *netlink.Route) error + RouteDel(route *netlink.Route) error + RouteReplace(route *netlink.Route) error + RouteList(link netlink.Link, family int) ([]netlink.Route, error) + RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error) + RuleAdd(rule *netlink.Rule) error + RuleDel(rule *netlink.Rule) error + RuleList(family int) ([]netlink.Rule, error) +} + +// endregion dependencies + +type LocalController struct { + cfg *config.Config + + wg WgCtrlRepo + nl NetlinkClient + + shellCmd string + resolvConfIfacePrefix string +} + +// NewLocalController creates a new local controller instance. +// This repository is used to interact with the WireGuard kernel or userspace module. +func NewLocalController(cfg *config.Config) (*LocalController, error) { + wg, err := wgctrl.New() + if err != nil { + return nil, fmt.Errorf("failed to create wgctrl client: %w", err) + } + + nl := &lowlevel.NetlinkManager{} + + repo := &LocalController{ + cfg: cfg, + + wg: wg, + nl: nl, + + shellCmd: "bash", // we only support bash at the moment + resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf + } + + return repo, nil +} + +func (c LocalController) GetId() domain.InterfaceBackend { + return config.LocalBackendName +} + +// region wireguard-related + +func (c LocalController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) { + devices, err := c.wg.Devices() + if err != nil { + return nil, fmt.Errorf("device list error: %w", err) + } + + interfaces := make([]domain.PhysicalInterface, 0, len(devices)) + for _, device := range devices { + interfaceModel, err := c.convertWireGuardInterface(device) + if err != nil { + return nil, fmt.Errorf("interface convert failed for %s: %w", device.Name, err) + } + interfaces = append(interfaces, interfaceModel) + } + + return interfaces, nil +} + +func (c LocalController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) ( + *domain.PhysicalInterface, + error, +) { + return c.getInterface(id) +} + +func (c LocalController) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) { + // read data from wgctrl interface + + iface := domain.PhysicalInterface{ + Identifier: domain.InterfaceIdentifier(device.Name), + KeyPair: domain.KeyPair{ + PrivateKey: device.PrivateKey.String(), + PublicKey: device.PublicKey.String(), + }, + ListenPort: device.ListenPort, + Addresses: nil, + Mtu: 0, + FirewallMark: uint32(device.FirewallMark), + DeviceUp: false, + ImportSource: domain.ControllerTypeLocal, + DeviceType: device.Type.String(), + BytesUpload: 0, + BytesDownload: 0, + } + + // read data from netlink interface + + lowLevelInterface, err := c.nl.LinkByName(device.Name) + if err != nil { + return domain.PhysicalInterface{}, fmt.Errorf("netlink error for %s: %w", device.Name, err) + } + ipAddresses, err := c.nl.AddrList(lowLevelInterface) + if err != nil { + return domain.PhysicalInterface{}, fmt.Errorf("ip read error for %s: %w", device.Name, err) + } + + for _, addr := range ipAddresses { + iface.Addresses = append(iface.Addresses, domain.CidrFromNetlinkAddr(addr)) + } + iface.Mtu = lowLevelInterface.Attrs().MTU + iface.DeviceUp = lowLevelInterface.Attrs().OperState == netlink.OperUnknown // wg only supports unknown + if stats := lowLevelInterface.Attrs().Statistics; stats != nil { + iface.BytesUpload = stats.TxBytes + iface.BytesDownload = stats.RxBytes + } + + return iface, nil +} + +func (c LocalController) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ( + []domain.PhysicalPeer, + error, +) { + device, err := c.wg.Device(string(deviceId)) + if err != nil { + return nil, fmt.Errorf("device error: %w", err) + } + + peers := make([]domain.PhysicalPeer, 0, len(device.Peers)) + for _, peer := range device.Peers { + peerModel, err := c.convertWireGuardPeer(&peer) + if err != nil { + return nil, fmt.Errorf("peer convert failed for %v: %w", peer.PublicKey, err) + } + peers = append(peers, peerModel) + } + + return peers, nil +} + +func (c LocalController) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) { + peerModel := domain.PhysicalPeer{ + Identifier: domain.PeerIdentifier(peer.PublicKey.String()), + Endpoint: "", + AllowedIPs: nil, + KeyPair: domain.KeyPair{ + PublicKey: peer.PublicKey.String(), + }, + PresharedKey: "", + PersistentKeepalive: int(peer.PersistentKeepaliveInterval.Seconds()), + LastHandshake: peer.LastHandshakeTime, + ProtocolVersion: peer.ProtocolVersion, + BytesUpload: uint64(peer.ReceiveBytes), + BytesDownload: uint64(peer.TransmitBytes), + ImportSource: domain.ControllerTypeLocal, + } + + // Set local extras - local peers are never disabled in the kernel + peerModel.SetExtras(domain.LocalPeerExtras{ + Disabled: false, + }) + + for _, addr := range peer.AllowedIPs { + peerModel.AllowedIPs = append(peerModel.AllowedIPs, domain.CidrFromIpNet(addr)) + } + if peer.Endpoint != nil { + peerModel.Endpoint = peer.Endpoint.String() + } + if peer.PresharedKey != (wgtypes.Key{}) { + peerModel.PresharedKey = domain.PreSharedKey(peer.PresharedKey.String()) + } + + return peerModel, nil +} + +func (c LocalController) SaveInterface( + _ context.Context, + id domain.InterfaceIdentifier, + updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), +) error { + physicalInterface, err := c.getOrCreateInterface(id) + if err != nil { + return err + } + + if updateFunc != nil { + physicalInterface, err = updateFunc(physicalInterface) + if err != nil { + return err + } + } + + if err := c.updateLowLevelInterface(physicalInterface); err != nil { + return err + } + if err := c.updateWireGuardInterface(physicalInterface); err != nil { + return err + } + + return nil +} + +func (c LocalController) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { + device, err := c.getInterface(id) + if err == nil { + return device, nil // interface exists + } + if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("device error: %w", err) // unknown error + } + + // create new device + if err := c.createLowLevelInterface(id); err != nil { + return nil, err + } + + device, err = c.getInterface(id) + return device, err +} + +func (c LocalController) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { + device, err := c.wg.Device(string(id)) + if err != nil { + return nil, err + } + + pi, err := c.convertWireGuardInterface(device) + return &pi, err +} + +func (c LocalController) createLowLevelInterface(id domain.InterfaceIdentifier) error { + link := &netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{ + Name: string(id), + }, + LinkType: "wireguard", + } + err := c.nl.LinkAdd(link) + if err != nil { + return fmt.Errorf("link add failed: %w", err) + } + + return nil +} + +func (c LocalController) updateLowLevelInterface(pi *domain.PhysicalInterface) error { + link, err := c.nl.LinkByName(string(pi.Identifier)) + if err != nil { + return err + } + if pi.Mtu != 0 { + if err := c.nl.LinkSetMTU(link, pi.Mtu); err != nil { + return fmt.Errorf("mtu error: %w", err) + } + } + + for _, addr := range pi.Addresses { + err := c.nl.AddrReplace(link, addr.NetlinkAddr()) + if err != nil { + return fmt.Errorf("failed to set ip %s: %w", addr.String(), err) + } + } + + // Remove unwanted IP addresses + rawAddresses, err := c.nl.AddrList(link) + if err != nil { + return fmt.Errorf("failed to fetch interface ips: %w", err) + } + for _, rawAddr := range rawAddresses { + netlinkAddr := domain.CidrFromNetlinkAddr(rawAddr) + remove := true + for _, addr := range pi.Addresses { + if addr == netlinkAddr { + remove = false + break + } + } + + if !remove { + continue + } + + err := c.nl.AddrDel(link, &rawAddr) + if err != nil { + return fmt.Errorf("failed to remove deprecated ip %s: %w", netlinkAddr.String(), err) + } + } + + // Update link state + if pi.DeviceUp { + if err := c.nl.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up device: %w", err) + } + } else { + if err := c.nl.LinkSetDown(link); err != nil { + return fmt.Errorf("failed to bring down device: %w", err) + } + } + + return nil +} + +func (c LocalController) updateWireGuardInterface(pi *domain.PhysicalInterface) error { + pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes()) + if err != nil { + return err + } + + var fwMark *int + if pi.FirewallMark != 0 { + intFwMark := int(pi.FirewallMark) + fwMark = &intFwMark + } + err = c.wg.ConfigureDevice(string(pi.Identifier), wgtypes.Config{ + PrivateKey: &pKey, + ListenPort: &pi.ListenPort, + FirewallMark: fwMark, + ReplacePeers: false, + }) + if err != nil { + return err + } + + return nil +} + +func (c LocalController) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error { + if err := c.deleteLowLevelInterface(id); err != nil { + return err + } + + return nil +} + +func (c LocalController) deleteLowLevelInterface(id domain.InterfaceIdentifier) error { + link, err := c.nl.LinkByName(string(id)) + if err != nil { + var linkNotFoundError netlink.LinkNotFoundError + if errors.As(err, &linkNotFoundError) { + return nil // ignore not found error + } + return fmt.Errorf("unable to find low level interface: %w", err) + } + + err = c.nl.LinkDel(link) + if err != nil { + return fmt.Errorf("failed to delete low level interface: %w", err) + } + + return nil +} + +func (c LocalController) SavePeer( + _ context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, + updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), +) error { + physicalPeer, err := c.getOrCreatePeer(deviceId, id) + if err != nil { + return err + } + + physicalPeer, err = updateFunc(physicalPeer) + if err != nil { + return err + } + + // Check if the peer is disabled by looking at the backend extras + // For local controller, disabled peers should be deleted + if physicalPeer.GetExtras() != nil { + switch extras := physicalPeer.GetExtras().(type) { + case domain.LocalPeerExtras: + if extras.Disabled { + // Delete the peer instead of updating it + return c.deletePeer(deviceId, id) + } + } + } + + if err := c.updatePeer(deviceId, physicalPeer); err != nil { + return err + } + + return nil +} + +func (c LocalController) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) ( + *domain.PhysicalPeer, + error, +) { + peer, err := c.getPeer(deviceId, id) + if err == nil { + return peer, nil // peer exists + } + if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("peer error: %w", err) // unknown error + } + + // create new peer + err = c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: id.ToPublicKey(), + }, + }, + }) + if err != nil { + return nil, fmt.Errorf("peer create error for %s: %w", id.ToPublicKey(), err) + } + + peer, err = c.getPeer(deviceId, id) + if err != nil { + return nil, fmt.Errorf("peer error after create: %w", err) + } + return peer, nil +} + +func (c LocalController) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) ( + *domain.PhysicalPeer, + error, +) { + if !id.IsPublicKey() { + return nil, errors.New("invalid public key") + } + + device, err := c.wg.Device(string(deviceId)) + if err != nil { + return nil, err + } + + publicKey := id.ToPublicKey() + for _, peer := range device.Peers { + if peer.PublicKey != publicKey { + continue + } + + peerModel, err := c.convertWireGuardPeer(&peer) + return &peerModel, err + } + + return nil, os.ErrNotExist +} + +func (c LocalController) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error { + cfg := wgtypes.PeerConfig{ + PublicKey: pp.GetPublicKey(), + Remove: false, + UpdateOnly: true, + PresharedKey: pp.GetPresharedKey(), + Endpoint: pp.GetEndpointAddress(), + PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(), + ReplaceAllowedIPs: true, + AllowedIPs: pp.GetAllowedIPs(), + } + + err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}}) + if err != nil { + return err + } + + return nil +} + +func (c LocalController) DeletePeer( + _ context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, +) error { + if !id.IsPublicKey() { + return errors.New("invalid public key") + } + + err := c.deletePeer(deviceId, id) + if err != nil { + return err + } + + return nil +} + +func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error { + cfg := wgtypes.PeerConfig{ + PublicKey: id.ToPublicKey(), + Remove: true, + } + + err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}}) + if err != nil { + return err + } + + return nil +} + +// endregion wireguard-related + +// region wg-quick-related + +func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { + if hookCmd == "" { + return nil + } + + slog.Debug("executing interface hook", "interface", id, "hook", hookCmd) + err := c.exec(hookCmd, id) + if err != nil { + return fmt.Errorf("failed to exec hook: %w", err) + } + + return nil +} + +func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { + if dnsStr == "" && dnsSearchStr == "" { + return nil + } + + dnsServers := internal.SliceString(dnsStr) + dnsSearchDomains := internal.SliceString(dnsSearchStr) + + dnsCommand := "resolvconf -a %resPref%i -m 0 -x" + dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains)) + + for _, dnsServer := range dnsServers { + dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer)) + } + for _, searchDomain := range dnsSearchDomains { + dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain)) + } + + err := c.exec(dnsCommand, id, dnsCommandInput...) + if err != nil { + return fmt.Errorf( + "failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w", + err, + ) + } + + return nil +} + +func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error { + dnsCommand := "resolvconf -d %resPref%i -f" + + err := c.exec(dnsCommand, id) + if err != nil { + return fmt.Errorf("failed to unset dns settings: %w", err) + } + + return nil +} + +func (c LocalController) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string { + command = strings.ReplaceAll(command, "%resPref", c.resolvConfIfacePrefix) + return strings.ReplaceAll(command, "%i", string(interfaceId)) +} + +func (c LocalController) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error { + commandWithInterfaceName := c.replaceCommandPlaceHolders(command, interfaceId) + cmd := exec.Command(c.shellCmd, "-ce", commandWithInterfaceName) + if len(stdin) > 0 { + b := &bytes.Buffer{} + for _, ln := range stdin { + if _, err := fmt.Fprint(b, ln); err != nil { + return err + } + } + cmd.Stdin = b + } + out, err := cmd.CombinedOutput() // execute and wait for output + if err != nil { + return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err) + } + slog.Debug("executed shell command", + "command", commandWithInterfaceName, + "output", string(out)) + return nil +} + +// endregion wg-quick-related + +// region routing-related + +func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { + // update fwmark rules + if err := c.setFwMarkRules(rules); err != nil { + return err + } + + // update main rule + if err := c.setMainRule(rules); err != nil { + return err + } + + // cleanup old main rules + if err := c.cleanupMainRule(rules); err != nil { + return err + } + + return nil +} + +func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error { + for _, rule := range rules { + existingRules, err := c.nl.RuleList(int(rule.IpFamily)) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err) + } + + ruleExists := false + for _, existingRule := range existingRules { + if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table { + ruleExists = true + break + } + } + + if ruleExists { + continue // rule already exists, no need to recreate it + } + + // create a missing rule + if err := c.nl.RuleAdd(&netlink.Rule{ + Family: int(rule.IpFamily), + Table: rule.Table, + Mark: rule.FwMark, + Invert: true, + SuppressIfgroup: -1, + SuppressPrefixlen: -1, + Priority: c.getRulePriority(existingRules), + Mask: nil, + Goto: -1, + Flow: -1, + }); err != nil { + return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w", + rule.IpFamily, rule.FwMark, rule.Table, err) + } + } + return nil +} + +func (c LocalController) getRulePriority(existingRules []netlink.Rule) int { + prio := 32700 // linux main rule has a priority of 32766 + for { + isFresh := true + for _, existingRule := range existingRules { + if existingRule.Priority == prio { + isFresh = false + break + } + } + if isFresh { + break + } else { + prio-- + } + } + return prio +} + +func (c LocalController) setMainRule(rules []domain.RouteRule) error { + var family domain.IpFamily + shouldHaveMainRule := false + for _, rule := range rules { + family = rule.IpFamily + if rule.HasDefault == true { + shouldHaveMainRule = true + break + } + } + if !shouldHaveMainRule { + return nil + } + + existingRules, err := c.nl.RuleList(int(family)) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %s: %w", family, err) + } + + ruleExists := false + for _, existingRule := range existingRules { + if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { + ruleExists = true + break + } + } + + if ruleExists { + return nil // rule already exists, skip re-creation + } + + if err := c.nl.RuleAdd(&netlink.Rule{ + Family: int(family), + Table: unix.RT_TABLE_MAIN, + SuppressIfgroup: -1, + SuppressPrefixlen: 0, + Priority: c.getMainRulePriority(existingRules), + Mark: 0, + Mask: nil, + Goto: -1, + Flow: -1, + }); err != nil { + return fmt.Errorf("failed to setup rule for main table: %w", err) + } + + return nil +} + +func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int { + priority := c.cfg.Advanced.RulePrioOffset + for { + isFresh := true + for _, existingRule := range existingRules { + if existingRule.Priority == priority { + isFresh = false + break + } + } + if isFresh { + break + } else { + priority++ + } + } + return priority +} + +func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error { + var family domain.IpFamily + for _, rule := range rules { + family = rule.IpFamily + break + } + + existingRules, err := c.nl.RuleList(int(family)) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %s: %w", family, err) + } + + shouldHaveMainRule := false + for _, rule := range rules { + if rule.HasDefault == true { + shouldHaveMainRule = true + break + } + } + + mainRules := 0 + for _, existingRule := range existingRules { + if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { + mainRules++ + } + } + + removalCount := 0 + if mainRules > 1 { + removalCount = mainRules - 1 // we only want one single rule + } + if !shouldHaveMainRule { + removalCount = mainRules + } + + for _, existingRule := range existingRules { + if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { + if removalCount > 0 { + existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field + if err := c.nl.RuleDel(&existingRule); err != nil { + return fmt.Errorf("failed to delete main rule: %w", err) + } + removalCount-- + } + } + } + + return nil +} + +func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { + // TODO implement me + panic("implement me") +} + +// endregion routing-related + +// region statistics-related + +func (c LocalController) PingAddresses( + ctx context.Context, + addr string, +) (*domain.PingerResult, error) { + pinger, err := probing.NewPinger(addr) + if err != nil { + return nil, fmt.Errorf("failed to instantiate pinger for %s: %w", addr, err) + } + + checkCount := 1 + pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged) + pinger.Count = checkCount + pinger.Timeout = 2 * time.Second + err = pinger.RunWithContext(ctx) // Blocks until finished. + if err != nil { + return nil, fmt.Errorf("failed to ping %s: %w", addr, err) + } + + stats := pinger.Statistics() + + return &domain.PingerResult{ + PacketsRecv: stats.PacketsRecv, + PacketsSent: stats.PacketsSent, + Rtts: stats.Rtts, + }, nil +} + +// endregion statistics-related diff --git a/internal/adapters/wgcontroller/mikrotik.go b/internal/adapters/wgcontroller/mikrotik.go new file mode 100644 index 0000000..8498d34 --- /dev/null +++ b/internal/adapters/wgcontroller/mikrotik.go @@ -0,0 +1,829 @@ +package wgcontroller + +import ( + "context" + "fmt" + "slices" + "strconv" + "strings" + "sync" + "time" + + "log/slog" + + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" + "github.com/h44z/wg-portal/internal/lowlevel" +) + +type MikrotikController struct { + coreCfg *config.Config + cfg *config.BackendMikrotik + + client *lowlevel.MikrotikApiClient + + // Add mutexes to prevent race conditions + interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex + peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex +} + +func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) { + client, err := lowlevel.NewMikrotikApiClient(coreCfg, cfg) + if err != nil { + return nil, fmt.Errorf("failed to create Mikrotik API client: %w", err) + } + + return &MikrotikController{ + coreCfg: coreCfg, + cfg: cfg, + + client: client, + + interfaceMutexes: sync.Map{}, + peerMutexes: sync.Map{}, + }, nil +} + +func (c *MikrotikController) GetId() domain.InterfaceBackend { + return domain.InterfaceBackend(c.cfg.Id) +} + +// getInterfaceMutex returns a mutex for the given interface to prevent concurrent modifications +func (c *MikrotikController) getInterfaceMutex(id domain.InterfaceIdentifier) *sync.Mutex { + mutex, _ := c.interfaceMutexes.LoadOrStore(id, &sync.Mutex{}) + return mutex.(*sync.Mutex) +} + +// getPeerMutex returns a mutex for the given peer to prevent concurrent modifications +func (c *MikrotikController) getPeerMutex(id domain.PeerIdentifier) *sync.Mutex { + mutex, _ := c.peerMutexes.LoadOrStore(id, &sync.Mutex{}) + return mutex.(*sync.Mutex) +} + +// region wireguard-related + +func (c *MikrotikController) GetInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) { + wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running", "comment", + }, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return nil, fmt.Errorf("failed to query interfaces: %v", wgReply.Error) + } + + // Parallelize loading of interface details to speed up overall latency. + // Use a bounded semaphore to avoid overloading the MikroTik device. + maxConcurrent := c.cfg.GetConcurrency() + sem := make(chan struct{}, maxConcurrent) + + interfaces := make([]domain.PhysicalInterface, 0, len(wgReply.Data)) + var mu sync.Mutex + var wgWait sync.WaitGroup + var firstErr error + ctx2, cancel := context.WithCancel(ctx) + defer cancel() + + for _, wgObj := range wgReply.Data { + wgWait.Add(1) + sem <- struct{}{} // block if more than maxConcurrent requests are processing + go func(wg lowlevel.GenericJsonObject) { + defer wgWait.Done() + defer func() { <-sem }() // read from the semaphore and make space for the next entry + if firstErr != nil { + return + } + pi, err := c.loadInterfaceData(ctx2, wg) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + cancel() + } + mu.Unlock() + return + } + mu.Lock() + interfaces = append(interfaces, *pi) + mu.Unlock() + }(wgObj) + } + + wgWait.Wait() + if firstErr != nil { + return nil, firstErr + } + + return interfaces, nil +} + +func (c *MikrotikController) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) ( + *domain.PhysicalInterface, + error, +) { + wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running", + }, + Filters: map[string]string{ + "name": string(id), + }, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return nil, fmt.Errorf("failed to query interface %s: %v", id, wgReply.Error) + } + + if len(wgReply.Data) == 0 { + return nil, fmt.Errorf("interface %s not found", id) + } + + return c.loadInterfaceData(ctx, wgReply.Data[0]) +} + +func (c *MikrotikController) loadInterfaceData( + ctx context.Context, + wireGuardObj lowlevel.GenericJsonObject, +) (*domain.PhysicalInterface, error) { + deviceId := wireGuardObj.GetString(".id") + deviceName := wireGuardObj.GetString("name") + ifaceReply := c.client.Get(ctx, "/interface/"+deviceId, &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + "name", "rx-byte", "tx-byte", + }, + }) + if ifaceReply.Status != lowlevel.MikrotikApiStatusOk { + return nil, fmt.Errorf("failed to query interface %s: %v", deviceId, ifaceReply.Error) + } + + ipv4, ipv6, err := c.loadIpAddresses(ctx, deviceName) + if err != nil { + return nil, fmt.Errorf("failed to query IP addresses for interface %s: %v", deviceId, err) + } + addresses := c.convertIpAddresses(ipv4, ipv6) + + interfaceModel, err := c.convertWireGuardInterface(wireGuardObj, ifaceReply.Data, addresses) + if err != nil { + return nil, fmt.Errorf("interface convert failed for %s: %w", deviceName, err) + } + return &interfaceModel, nil +} + +func (c *MikrotikController) loadIpAddresses( + ctx context.Context, + deviceName string, +) (ipv4 []lowlevel.GenericJsonObject, ipv6 []lowlevel.GenericJsonObject, err error) { + // Query IPv4 and IPv6 addresses in parallel to reduce latency. + var ( + v4 []lowlevel.GenericJsonObject + v6 []lowlevel.GenericJsonObject + v4Err error + v6Err error + wg sync.WaitGroup + ) + wg.Add(2) + + go func() { + defer wg.Done() + addrV4Reply := c.client.Query(ctx, "/ip/address", &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "address", "network", + }, + Filters: map[string]string{ + "interface": deviceName, + "dynamic": "false", // we only want static addresses + "disabled": "false", // we only want addresses that are not disabled + }, + }) + if addrV4Reply.Status != lowlevel.MikrotikApiStatusOk { + v4Err = fmt.Errorf("failed to query IPv4 addresses for interface %s: %v", deviceName, addrV4Reply.Error) + return + } + v4 = addrV4Reply.Data + }() + + go func() { + defer wg.Done() + addrV6Reply := c.client.Query(ctx, "/ipv6/address", &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "address", "network", + }, + Filters: map[string]string{ + "interface": deviceName, + "dynamic": "false", // we only want static addresses + "disabled": "false", // we only want addresses that are not disabled + }, + }) + if addrV6Reply.Status != lowlevel.MikrotikApiStatusOk { + v6Err = fmt.Errorf("failed to query IPv6 addresses for interface %s: %v", deviceName, addrV6Reply.Error) + return + } + v6 = addrV6Reply.Data + }() + + wg.Wait() + if v4Err != nil { + return nil, nil, v4Err + } + if v6Err != nil { + return nil, nil, v6Err + } + + return v4, v6, nil +} + +func (c *MikrotikController) convertIpAddresses( + ipv4, ipv6 []lowlevel.GenericJsonObject, +) []domain.Cidr { + addresses := make([]domain.Cidr, 0, len(ipv4)+len(ipv6)) + for _, addr := range append(ipv4, ipv6...) { + addrStr := addr.GetString("address") + if addrStr == "" { + continue + } + cidr, err := domain.CidrFromString(addrStr) + if err != nil { + continue + } + + addresses = append(addresses, cidr) + } + + return addresses +} + +func (c *MikrotikController) convertWireGuardInterface( + wg, iface lowlevel.GenericJsonObject, + addresses []domain.Cidr, +) ( + domain.PhysicalInterface, + error, +) { + pi := domain.PhysicalInterface{ + Identifier: domain.InterfaceIdentifier(wg.GetString("name")), + KeyPair: domain.KeyPair{ + PrivateKey: wg.GetString("private-key"), + PublicKey: wg.GetString("public-key"), + }, + ListenPort: wg.GetInt("listen-port"), + Addresses: addresses, + Mtu: wg.GetInt("mtu"), + FirewallMark: 0, + DeviceUp: wg.GetBool("running"), + ImportSource: domain.ControllerTypeMikrotik, + DeviceType: domain.ControllerTypeMikrotik, + BytesUpload: uint64(iface.GetInt("tx-byte")), + BytesDownload: uint64(iface.GetInt("rx-byte")), + } + + pi.SetExtras(domain.MikrotikInterfaceExtras{ + Id: wg.GetString(".id"), + Comment: wg.GetString("comment"), + Disabled: wg.GetBool("disabled"), + }) + + return pi, nil +} + +func (c *MikrotikController) GetPeers(ctx context.Context, deviceId domain.InterfaceIdentifier) ( + []domain.PhysicalPeer, + error, +) { + wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "name", "allowed-address", "client-address", "client-endpoint", "client-keepalive", "comment", + "current-endpoint-address", "current-endpoint-port", "last-handshake", "persistent-keepalive", + "public-key", "private-key", "preshared-key", "mtu", "disabled", "rx", "tx", "responder", "client-dns", + }, + Filters: map[string]string{ + "interface": string(deviceId), + }, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return nil, fmt.Errorf("failed to query peers for %s: %v", deviceId, wgReply.Error) + } + + if len(wgReply.Data) == 0 { + return nil, nil + } + + peers := make([]domain.PhysicalPeer, 0, len(wgReply.Data)) + for _, peer := range wgReply.Data { + peerModel, err := c.convertWireGuardPeer(peer) + if err != nil { + return nil, fmt.Errorf("peer convert failed for %v: %w", peer.GetString("name"), err) + } + peers = append(peers, peerModel) + } + + return peers, nil +} + +func (c *MikrotikController) convertWireGuardPeer(peer lowlevel.GenericJsonObject) ( + domain.PhysicalPeer, + error, +) { + keepAliveSeconds := 0 + duration, err := time.ParseDuration(peer.GetString("persistent-keepalive")) + if err == nil { + keepAliveSeconds = int(duration.Seconds()) + } + + currentEndpoint := "" + if peer.GetString("current-endpoint-address") != "" && peer.GetString("current-endpoint-port") != "" { + currentEndpoint = peer.GetString("current-endpoint-address") + ":" + peer.GetString("current-endpoint-port") + } + + lastHandshakeTime := time.Time{} + if peer.GetString("last-handshake") != "" { + relDuration, err := time.ParseDuration(peer.GetString("last-handshake")) + if err == nil { + lastHandshakeTime = time.Now().Add(-relDuration) + } + } + + allowedAddresses, _ := domain.CidrsFromString(peer.GetString("allowed-address")) + + clientKeepAliveSeconds := 0 + duration, err = time.ParseDuration(peer.GetString("client-keepalive")) + if err == nil { + clientKeepAliveSeconds = int(duration.Seconds()) + } + + peerModel := domain.PhysicalPeer{ + Identifier: domain.PeerIdentifier(peer.GetString("public-key")), + Endpoint: currentEndpoint, + AllowedIPs: allowedAddresses, + KeyPair: domain.KeyPair{ + PublicKey: peer.GetString("public-key"), + PrivateKey: peer.GetString("private-key"), + }, + PresharedKey: domain.PreSharedKey(peer.GetString("preshared-key")), + PersistentKeepalive: keepAliveSeconds, + LastHandshake: lastHandshakeTime, + ProtocolVersion: 0, // Mikrotik does not support protocol versioning, so we set it to 0 + BytesUpload: uint64(peer.GetInt("rx")), + BytesDownload: uint64(peer.GetInt("tx")), + ImportSource: domain.ControllerTypeMikrotik, + } + + peerModel.SetExtras(domain.MikrotikPeerExtras{ + Id: peer.GetString(".id"), + Name: peer.GetString("name"), + Comment: peer.GetString("comment"), + IsResponder: peer.GetBool("responder"), + Disabled: peer.GetBool("disabled"), + ClientEndpoint: peer.GetString("client-endpoint"), + ClientAddress: peer.GetString("client-address"), + ClientDns: peer.GetString("client-dns"), + ClientKeepalive: clientKeepAliveSeconds, + }) + + return peerModel, nil +} + +func (c *MikrotikController) SaveInterface( + ctx context.Context, + id domain.InterfaceIdentifier, + updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), +) error { + // Lock the interface to prevent concurrent modifications + mutex := c.getInterfaceMutex(id) + mutex.Lock() + defer mutex.Unlock() + + physicalInterface, err := c.getOrCreateInterface(ctx, id) + if err != nil { + return err + } + + deviceId := physicalInterface.GetExtras().(domain.MikrotikInterfaceExtras).Id + if updateFunc != nil { + physicalInterface, err = updateFunc(physicalInterface) + if err != nil { + return err + } + newExtras := physicalInterface.GetExtras().(domain.MikrotikInterfaceExtras) + newExtras.Id = deviceId // ensure the ID is not changed + physicalInterface.SetExtras(newExtras) + } + + if err := c.updateInterface(ctx, physicalInterface); err != nil { + return err + } + + return nil +} + +func (c *MikrotikController) getOrCreateInterface( + ctx context.Context, + id domain.InterfaceIdentifier, +) (*domain.PhysicalInterface, error) { + wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running", + }, + Filters: map[string]string{ + "name": string(id), + }, + }) + if wgReply.Status == lowlevel.MikrotikApiStatusOk && len(wgReply.Data) > 0 { + return c.loadInterfaceData(ctx, wgReply.Data[0]) + } + + // create a new interface if it does not exist + createReply := c.client.Create(ctx, "/interface/wireguard", lowlevel.GenericJsonObject{ + "name": string(id), + }) + if wgReply.Status == lowlevel.MikrotikApiStatusOk { + return c.loadInterfaceData(ctx, createReply.Data) + } + + return nil, fmt.Errorf("failed to create interface %s: %v", id, createReply.Error) +} + +func (c *MikrotikController) updateInterface(ctx context.Context, pi *domain.PhysicalInterface) error { + extras := pi.GetExtras().(domain.MikrotikInterfaceExtras) + interfaceId := extras.Id + wgReply := c.client.Update(ctx, "/interface/wireguard/"+interfaceId, lowlevel.GenericJsonObject{ + "name": pi.Identifier, + "comment": extras.Comment, + "mtu": strconv.Itoa(pi.Mtu), + "listen-port": strconv.Itoa(pi.ListenPort), + "private-key": pi.KeyPair.PrivateKey, + "disabled": strconv.FormatBool(!pi.DeviceUp), + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to update interface %s: %v", pi.Identifier, wgReply.Error) + } + + // update the interface's addresses + currentV4, currentV6, err := c.loadIpAddresses(ctx, string(pi.Identifier)) + if err != nil { + return fmt.Errorf("failed to load current addresses for interface %s: %v", pi.Identifier, err) + } + currentAddresses := c.convertIpAddresses(currentV4, currentV6) + + // get all addresses that are currently not in the interface, only in pi + newAddresses := make([]domain.Cidr, 0, len(pi.Addresses)) + for _, addr := range pi.Addresses { + if slices.Contains(currentAddresses, addr) { + continue + } + newAddresses = append(newAddresses, addr) + } + // get obsolete addresses that are in the interface, but not in pi + obsoleteAddresses := make([]domain.Cidr, 0, len(currentAddresses)) + for _, addr := range currentAddresses { + if slices.Contains(pi.Addresses, addr) { + continue + } + obsoleteAddresses = append(obsoleteAddresses, addr) + } + + // update the IP addresses for the interface + if err := c.updateIpAddresses(ctx, string(pi.Identifier), currentV4, currentV6, + newAddresses, obsoleteAddresses); err != nil { + return fmt.Errorf("failed to update IP addresses for interface %s: %v", pi.Identifier, err) + } + + return nil +} + +func (c *MikrotikController) updateIpAddresses( + ctx context.Context, + deviceName string, + currentV4, currentV6 []lowlevel.GenericJsonObject, + new, obsolete []domain.Cidr, +) error { + // first, delete all obsolete addresses + for _, addr := range obsolete { + // find ID of the address to delete + if addr.IsV4() { + for _, a := range currentV4 { + if a.GetString("address") == addr.String() { + // delete the address + reply := c.client.Delete(ctx, "/ip/address/"+a.GetString(".id")) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to delete obsolete IPv4 address %s: %v", addr, reply.Error) + } + break + } + } + } else { + for _, a := range currentV6 { + if a.GetString("address") == addr.String() { + // delete the address + reply := c.client.Delete(ctx, "/ipv6/address/"+a.GetString(".id")) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to delete obsolete IPv6 address %s: %v", addr, reply.Error) + } + break + } + } + } + } + + // then, add all new addresses + for _, addr := range new { + var createPath string + if addr.IsV4() { + createPath = "/ip/address" + } else { + createPath = "/ipv6/address" + } + + // create the address + reply := c.client.Create(ctx, createPath, lowlevel.GenericJsonObject{ + "address": addr.String(), + "interface": deviceName, + }) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to create new address %s: %v", addr, reply.Error) + } + } + + return nil +} + +func (c *MikrotikController) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { + // Lock the interface to prevent concurrent modifications + mutex := c.getInterfaceMutex(id) + mutex.Lock() + defer mutex.Unlock() + + // delete the interface's addresses + currentV4, currentV6, err := c.loadIpAddresses(ctx, string(id)) + if err != nil { + return fmt.Errorf("failed to load current addresses for interface %s: %v", id, err) + } + for _, a := range currentV4 { + // delete the address + reply := c.client.Delete(ctx, "/ip/address/"+a.GetString(".id")) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to delete IPv4 address %s: %v", a.GetString("address"), reply.Error) + } + } + for _, a := range currentV6 { + // delete the address + reply := c.client.Delete(ctx, "/ipv6/address/"+a.GetString(".id")) + if reply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to delete IPv6 address %s: %v", a.GetString("address"), reply.Error) + } + } + + // delete the WireGuard interface + wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{ + PropList: []string{".id"}, + Filters: map[string]string{ + "name": string(id), + }, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("unable to find WireGuard interface %s: %v", id, wgReply.Error) + } + if len(wgReply.Data) == 0 { + return nil // interface does not exist, nothing to delete + } + + interfaceId := wgReply.Data[0].GetString(".id") + deleteReply := c.client.Delete(ctx, "/interface/wireguard/"+interfaceId) + if deleteReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to delete WireGuard interface %s: %v", id, deleteReply.Error) + } + + return nil +} + +func (c *MikrotikController) SavePeer( + ctx context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, + updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), +) error { + // Lock the peer to prevent concurrent modifications + mutex := c.getPeerMutex(id) + mutex.Lock() + defer mutex.Unlock() + + physicalPeer, err := c.getOrCreatePeer(ctx, deviceId, id) + if err != nil { + return err + } + + peerId := physicalPeer.GetExtras().(domain.MikrotikPeerExtras).Id + physicalPeer, err = updateFunc(physicalPeer) + if err != nil { + return err + } + newExtras := physicalPeer.GetExtras().(domain.MikrotikPeerExtras) + newExtras.Id = peerId // ensure the ID is not changed + physicalPeer.SetExtras(newExtras) + + if err := c.updatePeer(ctx, deviceId, physicalPeer); err != nil { + return err + } + + return nil +} + +func (c *MikrotikController) getOrCreatePeer( + ctx context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, +) (*domain.PhysicalPeer, error) { + wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{ + PropList: []string{ + ".id", "name", "public-key", "private-key", "preshared-key", "persistent-keepalive", "client-address", + "client-endpoint", "client-keepalive", "allowed-address", "client-dns", "comment", "disabled", "responder", + }, + Filters: map[string]string{ + "public-key": string(id), + "interface": string(deviceId), + }, + }) + if wgReply.Status == lowlevel.MikrotikApiStatusOk && len(wgReply.Data) > 0 { + slog.Debug("found existing Mikrotik peer", "peer", id, "interface", deviceId) + existingPeer, err := c.convertWireGuardPeer(wgReply.Data[0]) + if err != nil { + return nil, err + } + return &existingPeer, nil + } + + // create a new peer if it does not exist + slog.Debug("creating new Mikrotik peer", "peer", id, "interface", deviceId) + createReply := c.client.Create(ctx, "/interface/wireguard/peers", lowlevel.GenericJsonObject{ + "name": fmt.Sprintf("tmp-wg-%s", id[0:8]), + "interface": string(deviceId), + "public-key": string(id), + "allowed-address": "0.0.0.0/0", // Use 0.0.0.0/0 as default, will be updated by updatePeer + }) + if createReply.Status == lowlevel.MikrotikApiStatusOk { + newPeer, err := c.convertWireGuardPeer(createReply.Data) + if err != nil { + return nil, err + } + slog.Debug("successfully created Mikrotik peer", "peer", id, "interface", deviceId) + return &newPeer, nil + } + + return nil, fmt.Errorf("failed to create peer %s for interface %s: %v", id, deviceId, createReply.Error) +} + +func (c *MikrotikController) updatePeer( + ctx context.Context, + deviceId domain.InterfaceIdentifier, + pp *domain.PhysicalPeer, +) error { + extras := pp.GetExtras().(domain.MikrotikPeerExtras) + peerId := extras.Id + + endpoint := pp.Endpoint + endpointPort := "51820" // default port if not set + if s := strings.Split(endpoint, ":"); len(s) == 2 { + endpoint = s[0] + endpointPort = s[1] + } + + allowedAddressStr := domain.CidrsToString(pp.AllowedIPs) + slog.Debug("updating Mikrotik peer", + "peer", pp.Identifier, + "interface", deviceId, + "allowed-address", allowedAddressStr, + "allowed-ips-count", len(pp.AllowedIPs), + "disabled", extras.Disabled) + + wgReply := c.client.Update(ctx, "/interface/wireguard/peers/"+peerId, lowlevel.GenericJsonObject{ + "name": extras.Name, + "comment": extras.Comment, + "preshared-key": pp.PresharedKey, + "public-key": pp.KeyPair.PublicKey, + "private-key": pp.KeyPair.PrivateKey, + "persistent-keepalive": (time.Duration(pp.PersistentKeepalive) * time.Second).String(), + "disabled": strconv.FormatBool(extras.Disabled), + "responder": strconv.FormatBool(extras.IsResponder), + "client-endpoint": extras.ClientEndpoint, + "client-address": extras.ClientAddress, + "client-keepalive": (time.Duration(extras.ClientKeepalive) * time.Second).String(), + "client-dns": extras.ClientDns, + "endpoint-address": endpoint, + "endpoint-port": endpointPort, + "allowed-address": allowedAddressStr, // Add the missing allowed-address field + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to update peer %s on interface %s: %v", pp.Identifier, deviceId, wgReply.Error) + } + + if extras.Disabled { + slog.Debug("successfully disabled Mikrotik peer", "peer", pp.Identifier, "interface", deviceId) + } else { + slog.Debug("successfully updated Mikrotik peer", "peer", pp.Identifier, "interface", deviceId) + } + + return nil +} + +func (c *MikrotikController) DeletePeer( + ctx context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, +) error { + // Lock the peer to prevent concurrent modifications + mutex := c.getPeerMutex(id) + mutex.Lock() + defer mutex.Unlock() + + wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{ + PropList: []string{".id"}, + Filters: map[string]string{ + "public-key": string(id), + "interface": string(deviceId), + }, + }) + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("unable to find WireGuard peer %s for interface %s: %v", id, deviceId, wgReply.Error) + } + if len(wgReply.Data) == 0 { + return nil // peer does not exist, nothing to delete + } + + peerId := wgReply.Data[0].GetString(".id") + deleteReply := c.client.Delete(ctx, "/interface/wireguard/peers/"+peerId) + if deleteReply.Status != lowlevel.MikrotikApiStatusOk { + return fmt.Errorf("failed to delete WireGuard peer %s for interface %s: %v", id, deviceId, deleteReply.Error) + } + + return nil +} + +// endregion wireguard-related + +// region wg-quick-related + +func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { + // TODO implement me + panic("implement me") +} + +func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { + // TODO implement me + panic("implement me") +} + +func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error { + // TODO implement me + panic("implement me") +} + +// endregion wg-quick-related + +// region routing-related + +func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { + // TODO implement me + panic("implement me") +} + +func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { + // TODO implement me + panic("implement me") +} + +// endregion routing-related + +// region statistics-related + +func (c *MikrotikController) PingAddresses( + ctx context.Context, + addr string, +) (*domain.PingerResult, error) { + wgReply := c.client.ExecList(ctx, "/tool/ping", + // limit to 1 packet with a max running time of 2 seconds + lowlevel.GenericJsonObject{"address": addr, "count": 1, "interval": "00:00:02"}, + ) + + if wgReply.Status != lowlevel.MikrotikApiStatusOk { + return nil, fmt.Errorf("failed to ping %s: %v", addr, wgReply.Error) + } + + var result domain.PingerResult + for _, item := range wgReply.Data { + result.PacketsRecv += item.GetInt("received") + result.PacketsSent += item.GetInt("sent") + + rttStr := item.GetString("avg-rtt") + if rttStr != "" { + rtt, err := time.ParseDuration(rttStr) + if err == nil { + result.Rtts = append(result.Rtts, rtt) + } else { + // use a high value to indicate failure or timeout + result.Rtts = append(result.Rtts, 999999*time.Millisecond) + } + } + } + + return &result, nil +} + +// endregion statistics-related diff --git a/internal/app/api/v0/handlers/endpoint_config.go b/internal/app/api/v0/handlers/endpoint_config.go index 21b342a..9936644 100644 --- a/internal/app/api/v0/handlers/endpoint_config.go +++ b/internal/app/api/v0/handlers/endpoint_config.go @@ -21,17 +21,23 @@ import ( //go:embed frontend_config.js.gotpl var frontendJs embed.FS +type ControllerManager interface { + GetControllerNames() []config.BackendBase +} + type ConfigEndpoint struct { cfg *config.Config authenticator Authenticator + controllerMgr ControllerManager tpl *respond.TemplateRenderer } -func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator) ConfigEndpoint { +func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator, ctrlMgr ControllerManager) ConfigEndpoint { ep := ConfigEndpoint{ cfg: cfg, authenticator: authenticator, + controllerMgr: ctrlMgr, tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs, "frontend_config.js.gotpl"))), } @@ -96,13 +102,36 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { sessionUser := domain.GetUserInfo(r.Context()) + controllerFn := func() []model.SettingsBackendNames { + controllers := e.controllerMgr.GetControllerNames() + names := make([]model.SettingsBackendNames, 0, len(controllers)) + + for _, controller := range controllers { + displayName := controller.GetDisplayName() + if displayName == "" { + displayName = controller.Id // fallback to ID if no display name is set + } + if controller.Id == config.LocalBackendName { + displayName = "modals.interface-edit.backend.local" // use a localized string for the local backend + } + names = append(names, model.SettingsBackendNames{ + Id: controller.Id, + Name: displayName, + }) + } + + return names + + } + hasSocialLogin := len(e.cfg.Auth.OAuth) > 0 || len(e.cfg.Auth.OpenIDConnect) > 0 || e.cfg.Auth.WebAuthn.Enabled // For anonymous users, we return the settings object with minimal information if sessionUser.Id == domain.CtxUnknownUserId || sessionUser.Id == "" { respond.JSON(w, http.StatusOK, model.Settings{ - WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled, - LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin, + WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled, + AvailableBackends: []model.SettingsBackendNames{}, // return an empty list instead of null + LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin, }) } else { respond.JSON(w, http.StatusOK, model.Settings{ @@ -112,6 +141,7 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc { ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly, WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled, MinPasswordLength: e.cfg.Auth.MinPasswordLength, + AvailableBackends: controllerFn(), LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin, }) } diff --git a/internal/app/api/v0/model/models.go b/internal/app/api/v0/model/models.go index 5c3ec73..07e2eba 100644 --- a/internal/app/api/v0/model/models.go +++ b/internal/app/api/v0/model/models.go @@ -6,11 +6,17 @@ type Error struct { } type Settings struct { - MailLinkOnly bool `json:"MailLinkOnly"` - PersistentConfigSupported bool `json:"PersistentConfigSupported"` - SelfProvisioning bool `json:"SelfProvisioning"` - ApiAdminOnly bool `json:"ApiAdminOnly"` - WebAuthnEnabled bool `json:"WebAuthnEnabled"` - MinPasswordLength int `json:"MinPasswordLength"` - LoginFormVisible bool `json:"LoginFormVisible"` + MailLinkOnly bool `json:"MailLinkOnly"` + PersistentConfigSupported bool `json:"PersistentConfigSupported"` + SelfProvisioning bool `json:"SelfProvisioning"` + ApiAdminOnly bool `json:"ApiAdminOnly"` + WebAuthnEnabled bool `json:"WebAuthnEnabled"` + MinPasswordLength int `json:"MinPasswordLength"` + AvailableBackends []SettingsBackendNames `json:"AvailableBackends"` + LoginFormVisible bool `json:"LoginFormVisible"` +} + +type SettingsBackendNames struct { + Id string `json:"Id"` + Name string `json:"Name"` } diff --git a/internal/app/api/v0/model/models_interface.go b/internal/app/api/v0/model/models_interface.go index 5684178..1b22d02 100644 --- a/internal/app/api/v0/model/models_interface.go +++ b/internal/app/api/v0/model/models_interface.go @@ -4,6 +4,7 @@ import ( "time" "github.com/h44z/wg-portal/internal" + "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" ) @@ -11,6 +12,7 @@ type Interface struct { Identifier string `json:"Identifier" example:"wg0"` // device name, for example: wg0 DisplayName string `json:"DisplayName"` // a nice display name/ description for the interface Mode string `json:"Mode" example:"server"` // the interface type, either 'server', 'client' or 'any' + Backend string `json:"Backend" example:"local"` // the backend used for this interface e.g., local, mikrotik, ... PrivateKey string `json:"PrivateKey" example:"abcdef=="` // private Key of the server interface PublicKey string `json:"PublicKey" example:"abcdef=="` // public Key of the server interface Disabled bool `json:"Disabled"` // flag that specifies if the interface is enabled (up) or not (down) @@ -57,6 +59,7 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface { Identifier: string(src.Identifier), DisplayName: src.DisplayName, Mode: string(src.Type), + Backend: string(src.Backend), PrivateKey: src.PrivateKey, PublicKey: src.PublicKey, Disabled: src.IsDisabled(), @@ -92,6 +95,10 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface { Filename: src.GetConfigFileName(), } + if iface.Backend == "" { + iface.Backend = config.LocalBackendName // default to local backend + } + if len(peers) > 0 { iface.TotalPeers = len(peers) @@ -146,6 +153,7 @@ func NewDomainInterface(src *Interface) *domain.Interface { SaveConfig: src.SaveConfig, DisplayName: src.DisplayName, Type: domain.InterfaceType(src.Mode), + Backend: domain.InterfaceBackend(src.Backend), DriverType: "", // currently unused Disabled: nil, // set below DisabledReason: src.DisabledReason, diff --git a/internal/app/app.go b/internal/app/app.go index 1eb24cb..e33cfc0 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -46,7 +46,7 @@ func Initialize( users: users, } - startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second) + startupContext, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Switch to admin user context diff --git a/internal/app/wireguard/controller_manager.go b/internal/app/wireguard/controller_manager.go new file mode 100644 index 0000000..ab1eaa9 --- /dev/null +++ b/internal/app/wireguard/controller_manager.go @@ -0,0 +1,166 @@ +package wireguard + +import ( + "context" + "fmt" + "log/slog" + "maps" + "slices" + + "github.com/h44z/wg-portal/internal/adapters/wgcontroller" + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" +) + +type InterfaceController interface { + GetId() domain.InterfaceBackend + GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) + GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) + GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) + SaveInterface( + _ context.Context, + id domain.InterfaceIdentifier, + updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), + ) error + DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error + SavePeer( + _ context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, + updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), + ) error + DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error + PingAddresses( + ctx context.Context, + addr string, + ) (*domain.PingerResult, error) +} + +type backendInstance struct { + Config config.BackendBase // Config is the configuration for the backend instance. + Implementation InterfaceController +} + +type ControllerManager struct { + cfg *config.Config + controllers map[domain.InterfaceBackend]backendInstance +} + +func NewControllerManager(cfg *config.Config) (*ControllerManager, error) { + c := &ControllerManager{ + cfg: cfg, + controllers: make(map[domain.InterfaceBackend]backendInstance), + } + + err := c.init() + if err != nil { + return nil, err + } + + return c, nil +} + +func (c *ControllerManager) init() error { + if err := c.registerLocalController(); err != nil { + return err + } + + if err := c.registerMikrotikControllers(); err != nil { + return err + } + + c.logRegisteredControllers() + + return nil +} + +func (c *ControllerManager) registerLocalController() error { + localController, err := wgcontroller.NewLocalController(c.cfg) + if err != nil { + return fmt.Errorf("failed to create local WireGuard controller: %w", err) + } + + c.controllers[config.LocalBackendName] = backendInstance{ + Config: config.BackendBase{ + Id: config.LocalBackendName, + DisplayName: "Local WireGuard Controller", + }, + Implementation: localController, + } + return nil +} + +func (c *ControllerManager) registerMikrotikControllers() error { + for _, backendConfig := range c.cfg.Backend.Mikrotik { + if backendConfig.Id == config.LocalBackendName { + slog.Warn("skipping registration of Mikrotik controller with reserved ID", "id", config.LocalBackendName) + continue + } + + controller, err := wgcontroller.NewMikrotikController(c.cfg, &backendConfig) + if err != nil { + return fmt.Errorf("failed to create Mikrotik controller for backend %s: %w", backendConfig.Id, err) + } + + c.controllers[domain.InterfaceBackend(backendConfig.Id)] = backendInstance{ + Config: backendConfig.BackendBase, + Implementation: controller, + } + } + return nil +} + +func (c *ControllerManager) logRegisteredControllers() { + for backend, controller := range c.controllers { + slog.Debug("backend controller registered", + "backend", backend, "type", fmt.Sprintf("%T", controller.Implementation)) + } +} + +func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController { + return c.getController(backend, "") +} + +func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController { + return c.getController(iface.Backend, iface.Identifier) +} + +func (c *ControllerManager) getController( + backend domain.InterfaceBackend, + ifaceId domain.InterfaceIdentifier, +) InterfaceController { + if backend == "" { + // If no backend is specified, use the local controller. + // This might be the case for interfaces created in previous WireGuard Portal versions. + backend = config.LocalBackendName + } + + controller, exists := c.controllers[backend] + if !exists { + controller, exists = c.controllers[config.LocalBackendName] // Fallback to local controller + if !exists { + // If the local controller is also not found, panic + panic(fmt.Sprintf("%s interface controller for backend %s not found", ifaceId, backend)) + } + slog.Warn("controller for backend not found, using local controller", + "backend", backend, "interface", ifaceId) + } + return controller.Implementation +} + +func (c *ControllerManager) GetAllControllers() []InterfaceController { + var backendInstances = make([]InterfaceController, 0, len(c.controllers)) + for instance := range maps.Values(c.controllers) { + backendInstances = append(backendInstances, instance.Implementation) + } + return backendInstances +} + +func (c *ControllerManager) GetControllerNames() []config.BackendBase { + var names []config.BackendBase + for _, id := range slices.Sorted(maps.Keys(c.controllers)) { + names = append(names, c.controllers[id].Config) + } + + return names +} diff --git a/internal/app/wireguard/statistics.go b/internal/app/wireguard/statistics.go index 9ab8946..78ca6eb 100644 --- a/internal/app/wireguard/statistics.go +++ b/internal/app/wireguard/statistics.go @@ -6,8 +6,6 @@ import ( "sync" "time" - probing "github.com/prometheus-community/pro-bing" - "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" @@ -30,11 +28,6 @@ type StatisticsDatabaseRepo interface { DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error } -type StatisticsInterfaceController interface { - GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) - GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) -} - type StatisticsMetricsServer interface { UpdateInterfaceMetrics(status domain.InterfaceStatus) UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus) @@ -47,15 +40,20 @@ type StatisticsEventBus interface { Publish(topic string, args ...any) } +type pingJob struct { + Peer domain.Peer + Backend domain.InterfaceBackend +} + type StatisticsCollector struct { cfg *config.Config bus StatisticsEventBus pingWaitGroup sync.WaitGroup - pingJobs chan domain.Peer + pingJobs chan pingJob db StatisticsDatabaseRepo - wg StatisticsInterfaceController + wg *ControllerManager ms StatisticsMetricsServer peerChangeEvent chan domain.PeerIdentifier @@ -66,7 +64,7 @@ func NewStatisticsCollector( cfg *config.Config, bus StatisticsEventBus, db StatisticsDatabaseRepo, - wg StatisticsInterfaceController, + wg *ControllerManager, ms StatisticsMetricsServer, ) (*StatisticsCollector, error) { c := &StatisticsCollector{ @@ -117,7 +115,7 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) { } for _, in := range interfaces { - physicalInterface, err := c.wg.GetInterface(ctx, in.Identifier) + physicalInterface, err := c.wg.GetController(in).GetInterface(ctx, in.Identifier) if err != nil { slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier, "error", err) @@ -169,7 +167,7 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) { } for _, in := range interfaces { - peers, err := c.wg.GetPeers(ctx, in.Identifier) + peers, err := c.wg.GetController(in).GetPeers(ctx, in.Identifier) if err != nil { slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err) continue @@ -271,7 +269,7 @@ func (c *StatisticsCollector) startPingWorkers(ctx context.Context) { c.pingWaitGroup = sync.WaitGroup{} c.pingWaitGroup.Add(c.cfg.Statistics.PingCheckWorkers) - c.pingJobs = make(chan domain.Peer, c.cfg.Statistics.PingCheckWorkers) + c.pingJobs = make(chan pingJob, c.cfg.Statistics.PingCheckWorkers) // start workers for i := 0; i < c.cfg.Statistics.PingCheckWorkers; i++ { @@ -314,7 +312,10 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) { continue } for _, peer := range peers { - c.pingJobs <- peer + c.pingJobs <- pingJob{ + Peer: peer, + Backend: in.Backend, + } } } } @@ -323,11 +324,14 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) { func (c *StatisticsCollector) pingWorker(ctx context.Context) { defer c.pingWaitGroup.Done() - for peer := range c.pingJobs { + for job := range c.pingJobs { + peer := job.Peer + backend := job.Backend + var connectionStateChanged bool var newPeerStatus domain.PeerStatus - peerPingable := c.isPeerPingable(ctx, peer) + peerPingable := c.isPeerPingable(ctx, backend, peer) slog.Debug("peer ping check completed", "peer", peer.Identifier, "pingable", peerPingable) now := time.Now() @@ -368,7 +372,11 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) { } } -func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Peer) bool { +func (c *StatisticsCollector) isPeerPingable( + ctx context.Context, + backend domain.InterfaceBackend, + peer domain.Peer, +) bool { if !c.cfg.Statistics.UsePingChecks { return false } @@ -378,23 +386,13 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe return false } - pinger, err := probing.NewPinger(checkAddr) + stats, err := c.wg.GetControllerByName(backend).PingAddresses(ctx, checkAddr) if err != nil { - slog.Debug("failed to instantiate pinger", "peer", peer.Identifier, "address", checkAddr, "error", err) + slog.Debug("failed to ping peer", "peer", peer.Identifier, "error", err) return false } - checkCount := 1 - pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged) - pinger.Count = checkCount - pinger.Timeout = 2 * time.Second - err = pinger.RunWithContext(ctx) // Blocks until finished. - if err != nil { - slog.Debug("pinger for peer exited unexpectedly", "peer", peer.Identifier, "address", checkAddr, "error", err) - return false - } - stats := pinger.Statistics() - return stats.PacketsRecv == checkCount + return stats.IsPingable() } func (c *StatisticsCollector) updateInterfaceMetrics(status domain.InterfaceStatus) { diff --git a/internal/app/wireguard/wireguard.go b/internal/app/wireguard/wireguard.go index eb76bc9..b28f70e 100644 --- a/internal/app/wireguard/wireguard.go +++ b/internal/app/wireguard/wireguard.go @@ -37,25 +37,6 @@ type InterfaceAndPeerDatabaseRepo interface { GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) } -type InterfaceController interface { - GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) - GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) - GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) - SaveInterface( - _ context.Context, - id domain.InterfaceIdentifier, - updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), - ) error - DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error - SavePeer( - _ context.Context, - deviceId domain.InterfaceIdentifier, - id domain.PeerIdentifier, - updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), - ) error - DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error -} - type WgQuickController interface { ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error @@ -75,7 +56,7 @@ type Manager struct { cfg *config.Config bus EventBus db InterfaceAndPeerDatabaseRepo - wg InterfaceController + wg *ControllerManager quick WgQuickController userLockMap *sync.Map @@ -84,7 +65,7 @@ type Manager struct { func NewWireGuardManager( cfg *config.Config, bus EventBus, - wg InterfaceController, + wg *ControllerManager, quick WgQuickController, db InterfaceAndPeerDatabaseRepo, ) (*Manager, error) { diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index 17a28bc..22b6658 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -11,6 +11,7 @@ import ( "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app/audit" + "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" ) @@ -21,12 +22,17 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical return nil, err } - physicalInterfaces, err := m.wg.GetInterfaces(ctx) - if err != nil { - return nil, err + var allPhysicalInterfaces []domain.PhysicalInterface + for _, wgBackend := range m.wg.GetAllControllers() { + physicalInterfaces, err := wgBackend.GetInterfaces(ctx) + if err != nil { + return nil, err + } + + allPhysicalInterfaces = append(allPhysicalInterfaces, physicalInterfaces...) } - return physicalInterfaces, nil + return allPhysicalInterfaces, nil } // GetInterfaceAndPeers returns the interface and all peers for the given interface identifier. @@ -109,47 +115,49 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter return 0, err } - physicalInterfaces, err := m.wg.GetInterfaces(ctx) - if err != nil { - return 0, err - } - - // if no filter is given, exclude already existing interfaces - var excludedInterfaces []domain.InterfaceIdentifier - if len(filter) == 0 { - existingInterfaces, err := m.db.GetAllInterfaces(ctx) - if err != nil { - return 0, err - } - for _, existingInterface := range existingInterfaces { - excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier) - } - } - imported := 0 - for _, physicalInterface := range physicalInterfaces { - if slices.Contains(excludedInterfaces, physicalInterface.Identifier) { - continue - } - - if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) { - continue - } - - slog.Info("importing new interface", "interface", physicalInterface.Identifier) - - physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier) + for _, wgBackend := range m.wg.GetAllControllers() { + physicalInterfaces, err := wgBackend.GetInterfaces(ctx) if err != nil { return 0, err } - err = m.importInterface(ctx, &physicalInterface, physicalPeers) - if err != nil { - return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err) + // if no filter is given, exclude already existing interfaces + var excludedInterfaces []domain.InterfaceIdentifier + if len(filter) == 0 { + existingInterfaces, err := m.db.GetAllInterfaces(ctx) + if err != nil { + return 0, err + } + for _, existingInterface := range existingInterfaces { + excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier) + } } - slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers)) - imported++ + for _, physicalInterface := range physicalInterfaces { + if slices.Contains(excludedInterfaces, physicalInterface.Identifier) { + continue + } + + if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) { + continue + } + + slog.Info("importing new interface", "interface", physicalInterface.Identifier) + + physicalPeers, err := wgBackend.GetPeers(ctx, physicalInterface.Identifier) + if err != nil { + return 0, err + } + + err = m.importInterface(ctx, wgBackend, &physicalInterface, physicalPeers) + if err != nil { + return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err) + } + + slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers)) + imported++ + } } return imported, nil @@ -213,7 +221,7 @@ func (m Manager) RestoreInterfaceState( return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err) } - _, err = m.wg.GetInterface(ctx, iface.Identifier) + _, err = m.wg.GetController(iface).GetInterface(ctx, iface.Identifier) if err != nil && !iface.IsDisabled() { slog.Debug("creating missing interface", "interface", iface.Identifier) @@ -260,18 +268,14 @@ func (m Manager) RestoreInterfaceState( // restore peers for _, peer := range peers { switch { - case iface.IsDisabled(): // if interface is disabled, delete all peers - if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil { + case iface.IsDisabled() && iface.Backend == config.LocalBackendName: // if interface is disabled, delete all peers + if err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier, + peer.Identifier); err != nil { return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w", peer.Identifier, iface.Identifier, err) } - case peer.IsDisabled(): // if peer is disabled, delete it - if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil { - return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w", - peer.Identifier, iface.Identifier, err) - } default: // update peer - err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier, + err := m.wg.GetController(iface).SavePeer(ctx, iface.Identifier, peer.Identifier, func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { domain.MergeToPhysicalPeer(pp, &peer) return pp, nil @@ -284,7 +288,7 @@ func (m Manager) RestoreInterfaceState( } // remove non-wgportal peers - physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier) + physicalPeers, _ := m.wg.GetController(iface).GetPeers(ctx, iface.Identifier) for _, physicalPeer := range physicalPeers { isWgPortalPeer := false for _, peer := range peers { @@ -294,7 +298,8 @@ func (m Manager) RestoreInterfaceState( } } if !isWgPortalPeer { - err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey)) + err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier, + domain.PeerIdentifier(physicalPeer.PublicKey)) if err != nil { return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w", physicalPeer.PublicKey, iface.Identifier, err) @@ -459,7 +464,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif existingInterface.Disabled = &now // simulate a disabled interface existingInterface.DisabledReason = domain.DisabledReasonDeleted - physicalInterface, _ := m.wg.GetInterface(ctx, id) + physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id) if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil { return fmt.Errorf("pre-delete hooks failed: %w", err) @@ -473,7 +478,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif return fmt.Errorf("peer deletion failure: %w", err) } - if err := m.wg.DeleteInterface(ctx, id); err != nil { + if err := m.wg.GetController(*existingInterface).DeleteInterface(ctx, id); err != nil { return fmt.Errorf("wireguard deletion failure: %w", err) } @@ -522,7 +527,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) ( err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) { iface.CopyCalculatedAttributes(i) - err := m.wg.SaveInterface(ctx, iface.Identifier, + err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { domain.MergeToPhysicalInterface(pi, iface) return pi, nil @@ -538,7 +543,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) ( } if iface.IsDisabled() { - physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier) + physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier) fwMark := iface.FirewallMark if physicalInterface != nil && fwMark == 0 { fwMark = physicalInterface.FirewallMark @@ -556,13 +561,13 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) ( } // If the interface has just been enabled, restore its peers on the physical controller - if !oldEnabled && newEnabled { + if !oldEnabled && newEnabled && iface.Backend == config.LocalBackendName { peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier) if err != nil { return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err) } for _, peer := range peers { - saveErr := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier, + saveErr := m.wg.GetController(*iface).SavePeer(ctx, iface.Identifier, peer.Identifier, func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { domain.MergeToPhysicalPeer(pp, &peer) return pp, nil @@ -766,7 +771,12 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) { return } -func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterface, peers []domain.PhysicalPeer) error { +func (m Manager) importInterface( + ctx context.Context, + backend InterfaceController, + in *domain.PhysicalInterface, + peers []domain.PhysicalPeer, +) error { now := time.Now() iface := domain.ConvertPhysicalInterface(in) iface.BaseModel = domain.BaseModel{ @@ -775,8 +785,20 @@ func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterfa CreatedAt: now, UpdatedAt: now, } + iface.Backend = backend.GetId() iface.PeerDefAllowedIPsStr = iface.AddressStr() + // try to predict the interface type based on the number of peers + switch len(peers) { + case 0: + iface.Type = domain.InterfaceTypeAny // no peers means this is an unknown interface + case 1: + iface.Type = domain.InterfaceTypeClient // one peer means this is a client interface + default: // multiple peers means this is a server interface + + iface.Type = domain.InterfaceTypeServer + } + existingInterface, err := m.db.GetInterface(ctx, iface.Identifier) if err != nil && !errors.Is(err, domain.ErrNotFound) { return err @@ -827,16 +849,20 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain peer.Interface.PreDown = domain.NewConfigOption(in.PeerDefPreDown, true) peer.Interface.PostDown = domain.NewConfigOption(in.PeerDefPostDown, true) + var displayName string switch in.Type { case domain.InterfaceTypeAny: peer.Interface.Type = domain.InterfaceTypeAny - peer.DisplayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")" + displayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")" case domain.InterfaceTypeClient: peer.Interface.Type = domain.InterfaceTypeServer - peer.DisplayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")" + displayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")" case domain.InterfaceTypeServer: peer.Interface.Type = domain.InterfaceTypeClient - peer.DisplayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")" + displayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")" + } + if peer.DisplayName == "" { + peer.DisplayName = displayName // use auto-generated display name if not set } err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) { @@ -850,12 +876,12 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain } func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error { - allPeers, err := m.db.GetInterfacePeers(ctx, id) + iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id) if err != nil { return err } for _, peer := range allPeers { - err = m.wg.DeletePeer(ctx, id, peer.Identifier) + err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier) if err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err) } diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index 500d5bb..99ccdcf 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -371,7 +371,12 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error return fmt.Errorf("delete not allowed: %w", err) } - err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id) + iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier) + if err != nil { + return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err) + } + + err = m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, id) if err != nil { return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err) } @@ -433,35 +438,28 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { interfaces := make(map[domain.InterfaceIdentifier]struct{}) - for i := range peers { - peer := peers[i] - var err error - if peer.IsDisabled() || peer.IsExpired() { - err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { - peer.CopyCalculatedAttributes(p) - - if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil { - return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err) - } - - return peer, nil - }) - } else { - err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { - peer.CopyCalculatedAttributes(p) - - err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier, - func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { - domain.MergeToPhysicalPeer(pp, peer) - return pp, nil - }) - if err != nil { - return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err) - } - - return peer, nil - }) + for _, peer := range peers { + iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier) + if err != nil { + return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err) } + + // Always save the peer to the backend, regardless of disabled/expired state + // The backend will handle the disabled state appropriately + err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { + peer.CopyCalculatedAttributes(p) + + err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier, + func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { + domain.MergeToPhysicalPeer(pp, peer) + return pp, nil + }) + if err != nil { + return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err) + } + + return peer, nil + }) if err != nil { return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err) } diff --git a/internal/config/backend.go b/internal/config/backend.go new file mode 100644 index 0000000..f81adad --- /dev/null +++ b/internal/config/backend.go @@ -0,0 +1,94 @@ +package config + +import ( + "fmt" + "time" +) + +const LocalBackendName = "local" + +type Backend struct { + Default string `yaml:"default"` // The default backend to use (defaults to the internal backend) + + Mikrotik []BackendMikrotik `yaml:"mikrotik"` +} + +// Validate checks the backend configuration for errors. +func (b *Backend) Validate() error { + if b.Default == "" { + b.Default = LocalBackendName + } + + uniqueMap := make(map[string]struct{}) + for _, backend := range b.Mikrotik { + if backend.Id == LocalBackendName { + return fmt.Errorf("backend ID %q is a reserved keyword", LocalBackendName) + } + if _, exists := uniqueMap[backend.Id]; exists { + return fmt.Errorf("backend ID %q is not unique", backend.Id) + } + uniqueMap[backend.Id] = struct{}{} + } + + if b.Default != LocalBackendName { + if _, ok := uniqueMap[b.Default]; !ok { + return fmt.Errorf("default backend %q is not defined in the configuration", b.Default) + } + } + + return nil +} + +type BackendBase struct { + Id string `yaml:"id"` // A unique id for the backend + DisplayName string `yaml:"display_name"` // A display name for the backend +} + +// GetDisplayName returns the display name of the backend. +// If no display name is set, it falls back to the ID. +func (b BackendBase) GetDisplayName() string { + if b.DisplayName == "" { + return b.Id // Fallback to ID if no display name is set + } + return b.DisplayName +} + +type BackendMikrotik struct { + BackendBase `yaml:",inline"` // Embed the base fields + + ApiUrl string `yaml:"api_url"` // The base URL of the Mikrotik API (e.g., "https://10.10.10.10:8729/rest") + ApiUser string `yaml:"api_user"` + ApiPassword string `yaml:"api_password"` + ApiVerifyTls bool `yaml:"api_verify_tls"` // Whether to verify the TLS certificate of the Mikrotik API + ApiTimeout time.Duration `yaml:"api_timeout"` // Timeout for API requests (default: 30 seconds) + + // Concurrency controls the maximum number of concurrent API requests that this backend will issue + // when enumerating interfaces and their details. If 0 or negative, a default of 5 is used. + Concurrency int `yaml:"concurrency"` + + Debug bool `yaml:"debug"` // Enable debug logging for the Mikrotik backend +} + +// GetConcurrency returns the configured concurrency for this backend or a sane default (5) +// when the configured value is zero or negative. +func (b *BackendMikrotik) GetConcurrency() int { + if b == nil { + return 5 + } + if b.Concurrency <= 0 { + return 5 + } + return b.Concurrency +} + +// GetApiTimeout returns the configured API timeout or a sane default (30 seconds) +// when the configured value is zero or negative. +func (b *BackendMikrotik) GetApiTimeout() time.Duration { + if b == nil { + return 30 * time.Second + } + if b.ApiTimeout <= 0 { + return 30 * time.Second + } + return b.ApiTimeout +} diff --git a/internal/config/config.go b/internal/config/config.go index e64a703..0574133 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -44,6 +44,8 @@ type Config struct { LimitAdditionalUserPeers int `yaml:"limit_additional_user_peers"` } `yaml:"advanced"` + Backend Backend `yaml:"backend"` + Statistics struct { UsePingChecks bool `yaml:"use_ping_checks"` PingCheckWorkers int `yaml:"ping_check_workers"` @@ -99,6 +101,12 @@ func (c *Config) LogStartupValues() { "minPasswordLength", c.Auth.MinPasswordLength, "hideLoginForm", c.Auth.HideLoginForm, ) + + slog.Debug("Config Backend", + "defaultBackend", c.Backend.Default, + "extraBackends", len(c.Backend.Mikrotik), + ) + } // defaultConfig returns the default configuration @@ -122,6 +130,10 @@ func defaultConfig() *Config { DSN: "data/sqlite.db", } + cfg.Backend = Backend{ + Default: LocalBackendName, // local backend is the default (using wgcrtl) + } + cfg.Web = WebConfig{ RequestLogging: false, ExternalUrl: "http://localhost:8888", @@ -201,6 +213,10 @@ func GetConfig() (*Config, error) { } cfg.Web.Sanitize() + err := cfg.Backend.Validate() + if err != nil { + return nil, err + } return cfg, nil } diff --git a/internal/domain/controller.go b/internal/domain/controller.go new file mode 100644 index 0000000..eaefe32 --- /dev/null +++ b/internal/domain/controller.go @@ -0,0 +1,32 @@ +package domain + +// ControllerType defines the type of controller used to manage interfaces. + +const ( + ControllerTypeMikrotik = "mikrotik" + ControllerTypeLocal = "wgctrl" +) + +// Controller extras can be used to store additional information available for specific controllers only. + +type MikrotikInterfaceExtras struct { + Id string // internal mikrotik ID + Comment string + Disabled bool +} + +type MikrotikPeerExtras struct { + Id string // internal mikrotik ID + Name string + Comment string + IsResponder bool + Disabled bool + ClientEndpoint string + ClientAddress string + ClientDns string + ClientKeepalive int +} + +type LocalPeerExtras struct { + Disabled bool +} diff --git a/internal/domain/interface.go b/internal/domain/interface.go index 977f7d3..32fc1c0 100644 --- a/internal/domain/interface.go +++ b/internal/domain/interface.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "golang.org/x/sys/unix" + "github.com/h44z/wg-portal/internal" ) @@ -23,6 +25,7 @@ var allowedFileNameRegex = regexp.MustCompile("[^a-zA-Z0-9-_]+") type InterfaceIdentifier string type InterfaceType string +type InterfaceBackend string type Interface struct { BaseModel @@ -49,11 +52,12 @@ type Interface struct { SaveConfig bool // automatically persist config changes to the wgX.conf file // WG Portal specific - DisplayName string // a nice display name/ description for the interface - Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient - DriverType string // the interface driver type (linux, software, ...) - Disabled *time.Time `gorm:"index"` // flag that specifies if the interface is enabled (up) or not (down) - DisabledReason string // the reason why the interface has been disabled + DisplayName string // a nice display name/ description for the interface + Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient + Backend InterfaceBackend // the backend that is used to manage the interface (wgctrl, mikrotik, ...) + DriverType string // the interface driver type (linux, software, ...) + Disabled *time.Time `gorm:"index"` // flag that specifies if the interface is enabled (up) or not (down) + DisabledReason string // the reason why the interface has been disabled // Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of // the peer config @@ -204,9 +208,31 @@ type PhysicalInterface struct { BytesUpload uint64 BytesDownload uint64 + + backendExtras any // additional backend-specific extras, e.g., domain.MikrotikInterfaceExtras +} + +func (p *PhysicalInterface) GetExtras() any { + return p.backendExtras +} + +func (p *PhysicalInterface) SetExtras(extras any) { + switch extras.(type) { + case MikrotikInterfaceExtras: // OK + default: // we only support MikrotikInterfaceExtras for now + panic(fmt.Sprintf("unsupported interface backend extras type %T", extras)) + } + + p.backendExtras = extras } func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface { + networks := make([]Cidr, 0, len(pi.Addresses)) + for _, addr := range pi.Addresses { + networks = append(networks, addr.NetworkAddr()) + } + + // create a new basic interface with the data from the physical interface iface := &Interface{ Identifier: pi.Identifier, KeyPair: pi.KeyPair, @@ -226,11 +252,11 @@ func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface { Type: InterfaceTypeAny, DriverType: pi.DeviceType, Disabled: nil, - PeerDefNetworkStr: "", + PeerDefNetworkStr: CidrsToString(networks), PeerDefDnsStr: "", PeerDefDnsSearchStr: "", PeerDefEndpoint: "", - PeerDefAllowedIPsStr: "", + PeerDefAllowedIPsStr: CidrsToString(networks), PeerDefMtu: pi.Mtu, PeerDefPersistentKeepalive: 0, PeerDefFirewallMark: 0, @@ -241,6 +267,23 @@ func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface { PeerDefPostDown: "", } + if pi.GetExtras() == nil { + return iface + } + + // enrich the data with controller-specific extras + now := time.Now() + switch pi.ImportSource { + case ControllerTypeMikrotik: + extras := pi.GetExtras().(MikrotikInterfaceExtras) + iface.DisplayName = extras.Comment + if extras.Disabled { + iface.Disabled = &now + } else { + iface.Disabled = nil + } + } + return iface } @@ -253,6 +296,15 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) { pi.FirewallMark = i.FirewallMark pi.DeviceUp = !i.IsDisabled() pi.Addresses = i.Addresses + + switch pi.ImportSource { + case ControllerTypeMikrotik: + extras := MikrotikInterfaceExtras{ + Comment: i.DisplayName, + Disabled: i.IsDisabled(), + } + pi.SetExtras(extras) + } } type RoutingTableInfo struct { @@ -279,3 +331,30 @@ func (r RoutingTableInfo) GetRoutingTable() int { return r.Table } + +type IpFamily int + +const ( + IpFamilyIPv4 IpFamily = unix.AF_INET + IpFamilyIPv6 IpFamily = unix.AF_INET6 +) + +func (f IpFamily) String() string { + switch f { + case IpFamilyIPv4: + return "IPv4" + case IpFamilyIPv6: + return "IPv6" + default: + return "unknown" + } +} + +// RouteRule represents a routing table rule. +type RouteRule struct { + InterfaceId InterfaceIdentifier + IpFamily IpFamily + FwMark uint32 + Table int + HasDefault bool +} diff --git a/internal/domain/peer.go b/internal/domain/peer.go index 93404eb..519d551 100644 --- a/internal/domain/peer.go +++ b/internal/domain/peer.go @@ -129,7 +129,7 @@ func (p *Peer) GenerateDisplayName(prefix string) { p.DisplayName = fmt.Sprintf("%sPeer %s", prefix, internal.TruncateString(string(p.Identifier), 8)) } -// OverwriteUserEditableFields overwrites the user editable fields of the peer with the values from the userPeer +// OverwriteUserEditableFields overwrites the user-editable fields of the peer with the values from the userPeer func (p *Peer) OverwriteUserEditableFields(userPeer *Peer, cfg *config.Config) { p.DisplayName = userPeer.DisplayName if cfg.Core.EditableKeys { @@ -182,9 +182,12 @@ type PhysicalPeer struct { BytesUpload uint64 // upload bytes are the number of bytes that the remote peer has sent to the server BytesDownload uint64 // upload bytes are the number of bytes that the remote peer has received from the server + + ImportSource string // import source (wgctrl, file, ...) + backendExtras any // additional backend-specific extras, e.g., domain.MikrotikPeerExtras } -func (p PhysicalPeer) GetPresharedKey() *wgtypes.Key { +func (p *PhysicalPeer) GetPresharedKey() *wgtypes.Key { if p.PresharedKey == "" { return nil } @@ -196,7 +199,7 @@ func (p PhysicalPeer) GetPresharedKey() *wgtypes.Key { return &key } -func (p PhysicalPeer) GetEndpointAddress() *net.UDPAddr { +func (p *PhysicalPeer) GetEndpointAddress() *net.UDPAddr { if p.Endpoint == "" { return nil } @@ -208,7 +211,7 @@ func (p PhysicalPeer) GetEndpointAddress() *net.UDPAddr { return addr } -func (p PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration { +func (p *PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration { if p.PersistentKeepalive == 0 { return nil } @@ -217,7 +220,7 @@ func (p PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration { return &keepAliveDuration } -func (p PhysicalPeer) GetAllowedIPs() []net.IPNet { +func (p *PhysicalPeer) GetAllowedIPs() []net.IPNet { allowedIPs := make([]net.IPNet, len(p.AllowedIPs)) for i, ip := range p.AllowedIPs { allowedIPs[i] = *ip.IpNet() @@ -226,6 +229,21 @@ func (p PhysicalPeer) GetAllowedIPs() []net.IPNet { return allowedIPs } +func (p *PhysicalPeer) GetExtras() any { + return p.backendExtras +} + +func (p *PhysicalPeer) SetExtras(extras any) { + switch extras.(type) { + case MikrotikPeerExtras: // OK + case LocalPeerExtras: // OK + default: // we only support MikrotikPeerExtras and LocalPeerExtras for now + panic(fmt.Sprintf("unsupported peer backend extras type %T", extras)) + } + + p.backendExtras = extras +} + func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer { peer := &Peer{ Endpoint: NewConfigOption(pp.Endpoint, true), @@ -244,6 +262,44 @@ func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer { }, } + if pp.GetExtras() == nil { + return peer + } + + // enrich the data with controller-specific extras + now := time.Now() + switch pp.ImportSource { + case ControllerTypeMikrotik: + extras := pp.GetExtras().(MikrotikPeerExtras) + peer.Notes = extras.Comment + peer.DisplayName = extras.Name + if extras.ClientEndpoint != "" { // if the client endpoint is set, we assume that this is a client peer + peer.Endpoint = NewConfigOption(extras.ClientEndpoint, true) + peer.Interface.Type = InterfaceTypeClient + peer.Interface.Addresses, _ = CidrsFromString(extras.ClientAddress) + peer.Interface.DnsStr = NewConfigOption(extras.ClientDns, true) + peer.PersistentKeepalive = NewConfigOption(extras.ClientKeepalive, true) + } else { + peer.Interface.Type = InterfaceTypeServer + } + if extras.Disabled { + peer.Disabled = &now + peer.DisabledReason = "Disabled by Mikrotik controller" + } else { + peer.Disabled = nil + peer.DisabledReason = "" + } + case ControllerTypeLocal: + extras := pp.GetExtras().(LocalPeerExtras) + if extras.Disabled { + peer.Disabled = &now + peer.DisabledReason = "Disabled by Local controller" + } else { + peer.Disabled = nil + peer.DisabledReason = "" + } + } + return peer } @@ -265,6 +321,27 @@ func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) { pp.PresharedKey = p.PresharedKey pp.PublicKey = p.Interface.PublicKey pp.PersistentKeepalive = p.PersistentKeepalive.GetValue() + + switch pp.ImportSource { + case ControllerTypeMikrotik: + extras := MikrotikPeerExtras{ + Id: "", + Name: p.DisplayName, + Comment: p.Notes, + IsResponder: false, + Disabled: p.IsDisabled(), + ClientEndpoint: p.Endpoint.GetValue(), + ClientAddress: CidrsToString(p.Interface.Addresses), + ClientDns: p.Interface.DnsStr.GetValue(), + ClientKeepalive: p.PersistentKeepalive.GetValue(), + } + pp.SetExtras(extras) + case ControllerTypeLocal: + extras := LocalPeerExtras{ + Disabled: p.IsDisabled(), + } + pp.SetExtras(extras) + } } type PeerCreationRequest struct { diff --git a/internal/domain/statistics.go b/internal/domain/statistics.go index b4a3dca..cbc987d 100644 --- a/internal/domain/statistics.go +++ b/internal/domain/statistics.go @@ -1,6 +1,8 @@ package domain -import "time" +import ( + "time" +) type PeerStatus struct { PeerId PeerIdentifier `gorm:"primaryKey;column:identifier" json:"PeerId"` @@ -37,3 +39,25 @@ type InterfaceStatus struct { BytesReceived uint64 `gorm:"column:received"` BytesTransmitted uint64 `gorm:"column:transmitted"` } + +type PingerResult struct { + PacketsRecv int + PacketsSent int + Rtts []time.Duration +} + +func (r PingerResult) IsPingable() bool { + return r.PacketsRecv > 0 && r.PacketsSent > 0 && len(r.Rtts) > 0 +} + +func (r PingerResult) AverageRtt() time.Duration { + if len(r.Rtts) == 0 { + return 0 + } + + var total time.Duration + for _, rtt := range r.Rtts { + total += rtt + } + return total / time.Duration(len(r.Rtts)) +} diff --git a/internal/logger.go b/internal/logger.go index 994c236..9bca8c8 100644 --- a/internal/logger.go +++ b/internal/logger.go @@ -12,8 +12,8 @@ import ( "sync" ) -// SetupLogging initializes the global logger with the given level and format -func SetupLogging(level string, pretty, json bool) { +// GetLoggingHandler initializes a slog.Handler based on the provided logging level and format options. +func GetLoggingHandler(level string, pretty, json bool) slog.Handler { var logLevel = new(slog.LevelVar) switch strings.ToLower(level) { @@ -46,6 +46,13 @@ func SetupLogging(level string, pretty, json bool) { handler = slog.NewTextHandler(output, opts) } + return handler +} + +// SetupLogging initializes the global logger with the given level and format +func SetupLogging(level string, pretty, json bool) { + handler := GetLoggingHandler(level, pretty, json) + logger := slog.New(handler) slog.SetDefault(logger) diff --git a/internal/lowlevel/mikrotik.go b/internal/lowlevel/mikrotik.go new file mode 100644 index 0000000..49ef1d7 --- /dev/null +++ b/internal/lowlevel/mikrotik.go @@ -0,0 +1,435 @@ +package lowlevel + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/h44z/wg-portal/internal" + "github.com/h44z/wg-portal/internal/config" +) + +// region models + +const ( + MikrotikApiStatusOk = "success" + MikrotikApiStatusError = "error" +) + +const ( + MikrotikApiErrorCodeUnknown = iota + 600 + MikrotikApiErrorCodeRequestPreparationFailed + MikrotikApiErrorCodeRequestFailed + MikrotikApiErrorCodeResponseDecodeFailed +) + +type MikrotikApiResponse[T any] struct { + Status string + Code int + Data T `json:"data,omitempty"` + Error *MikrotikApiError `json:"error,omitempty"` +} + +type MikrotikApiError struct { + Code int `json:"error,omitempty"` + Message string `json:"message,omitempty"` + Details string `json:"detail,omitempty"` +} + +func (e *MikrotikApiError) String() string { + if e == nil { + return "no error" + } + return fmt.Sprintf("API error %d: %s - %s", e.Code, e.Message, e.Details) +} + +type GenericJsonObject map[string]any +type EmptyResponse struct{} + +func (JsonObject GenericJsonObject) GetString(key string) string { + if value, ok := JsonObject[key]; ok { + if strValue, ok := value.(string); ok { + return strValue + } else { + return fmt.Sprintf("%v", value) // Convert to string if not already + } + } + return "" +} + +func (JsonObject GenericJsonObject) GetInt(key string) int { + if value, ok := JsonObject[key]; ok { + if intValue, ok := value.(int); ok { + return intValue + } else { + if floatValue, ok := value.(float64); ok { + return int(floatValue) // Convert float64 to int + } + if strValue, ok := value.(string); ok { + if intValue, err := strconv.Atoi(strValue); err == nil { + return intValue // Convert string to int if possible + } + } + } + } + return 0 +} + +func (JsonObject GenericJsonObject) GetBool(key string) bool { + if value, ok := JsonObject[key]; ok { + if boolValue, ok := value.(bool); ok { + return boolValue + } else { + if intValue, ok := value.(int); ok { + return intValue == 1 // Convert int to bool (1 is true, 0 is false) + } + if floatValue, ok := value.(float64); ok { + return int(floatValue) == 1 // Convert float64 to bool (1.0 is true, 0.0 is false) + } + if strValue, ok := value.(string); ok { + boolValue, err := strconv.ParseBool(strValue) + if err == nil { + return boolValue + } + } + } + } + return false +} + +type MikrotikRequestOptions struct { + Filters map[string]string `json:"filters,omitempty"` + PropList []string `json:"proplist,omitempty"` +} + +func (o *MikrotikRequestOptions) GetPath(base string) string { + if o == nil { + return base + } + + path, err := url.Parse(base) + if err != nil { + return base + } + + query := path.Query() + for k, v := range o.Filters { + query.Set(k, v) + } + if len(o.PropList) > 0 { + query.Set(".proplist", strings.Join(o.PropList, ",")) + } + path.RawQuery = query.Encode() + return path.String() +} + +// region models + +// region API-client + +type MikrotikApiClient struct { + coreCfg *config.Config + cfg *config.BackendMikrotik + + client *http.Client + log *slog.Logger +} + +func NewMikrotikApiClient(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikApiClient, error) { + c := &MikrotikApiClient{ + coreCfg: coreCfg, + cfg: cfg, + } + + err := c.setup() + if err != nil { + return nil, err + } + + c.debugLog("Mikrotik api client created", "api_url", cfg.ApiUrl) + + return c, nil +} + +func (m *MikrotikApiClient) setup() error { + m.client = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: !m.cfg.ApiVerifyTls, + }, + }, + Timeout: m.cfg.GetApiTimeout(), + } + + if m.cfg.Debug { + m.log = slog.New(internal.GetLoggingHandler("debug", + m.coreCfg.Advanced.LogPretty, + m.coreCfg.Advanced.LogJson). + WithAttrs([]slog.Attr{ + { + Key: "mikrotik-bid", Value: slog.StringValue(m.cfg.Id), + }, + })) + } + + return nil +} + +func (m *MikrotikApiClient) debugLog(msg string, args ...any) { + if m.log != nil { + m.log.Debug("[MT-API] "+msg, args...) + } +} + +func (m *MikrotikApiClient) getFullPath(command string) string { + path, err := url.JoinPath(m.cfg.ApiUrl, command) + if err != nil { + return "" + } + return path +} + +func (m *MikrotikApiClient) prepareGetRequest(ctx context.Context, fullUrl string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullUrl, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Accept", "application/json") + if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" { + req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword) + } + + return req, nil +} + +func (m *MikrotikApiClient) prepareDeleteRequest(ctx context.Context, fullUrl string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, fullUrl, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Accept", "application/json") + if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" { + req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword) + } + + return req, nil +} + +func (m *MikrotikApiClient) preparePayloadRequest( + ctx context.Context, + method string, + fullUrl string, + payload GenericJsonObject, +) (*http.Request, error) { + // marshal the payload to JSON + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, method, fullUrl, bytes.NewReader(payloadBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" { + req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword) + } + + return req, nil +} + +func errToApiResponse[T any](code int, message string, err error) MikrotikApiResponse[T] { + return MikrotikApiResponse[T]{ + Status: MikrotikApiStatusError, + Code: code, + Error: &MikrotikApiError{ + Code: code, + Message: message, + Details: err.Error(), + }, + } +} + +func parseHttpResponse[T any](resp *http.Response, err error) MikrotikApiResponse[T] { + if err != nil { + return errToApiResponse[T](MikrotikApiErrorCodeRequestFailed, "failed to execute request", err) + } + + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + slog.Error("failed to close response body", "error", err) + } + }(resp.Body) + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + var data T + + // if the type of T is EmptyResponse, we can return an empty response with just the status + if _, ok := any(data).(EmptyResponse); ok { + return MikrotikApiResponse[T]{Status: MikrotikApiStatusOk, Code: resp.StatusCode} + } + + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + return errToApiResponse[T](MikrotikApiErrorCodeResponseDecodeFailed, "failed to decode response", err) + } + return MikrotikApiResponse[T]{Status: MikrotikApiStatusOk, Code: resp.StatusCode, Data: data} + } + + var apiErr MikrotikApiError + if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil { + return errToApiResponse[T](resp.StatusCode, "unknown error, unparsable response", err) + } else { + return MikrotikApiResponse[T]{Status: MikrotikApiStatusError, Code: resp.StatusCode, Error: &apiErr} + } +} + +func (m *MikrotikApiClient) Query( + ctx context.Context, + command string, + opts *MikrotikRequestOptions, +) MikrotikApiResponse[[]GenericJsonObject] { + apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout()) + defer cancel() + + fullUrl := opts.GetPath(m.getFullPath(command)) + + req, err := m.prepareGetRequest(apiCtx, fullUrl) + if err != nil { + return errToApiResponse[[]GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + m.debugLog("executing API query", "url", fullUrl) + response := parseHttpResponse[[]GenericJsonObject](m.client.Do(req)) + m.debugLog("retrieved API query result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +func (m *MikrotikApiClient) Get( + ctx context.Context, + command string, + opts *MikrotikRequestOptions, +) MikrotikApiResponse[GenericJsonObject] { + apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout()) + defer cancel() + + fullUrl := opts.GetPath(m.getFullPath(command)) + + req, err := m.prepareGetRequest(apiCtx, fullUrl) + if err != nil { + return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + m.debugLog("executing API get", "url", fullUrl) + response := parseHttpResponse[GenericJsonObject](m.client.Do(req)) + m.debugLog("retrieved API get result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +func (m *MikrotikApiClient) Create( + ctx context.Context, + command string, + payload GenericJsonObject, +) MikrotikApiResponse[GenericJsonObject] { + apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout()) + defer cancel() + + fullUrl := m.getFullPath(command) + + req, err := m.preparePayloadRequest(apiCtx, http.MethodPut, fullUrl, payload) + if err != nil { + return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + m.debugLog("executing API put", "url", fullUrl) + response := parseHttpResponse[GenericJsonObject](m.client.Do(req)) + m.debugLog("retrieved API put result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +func (m *MikrotikApiClient) Update( + ctx context.Context, + command string, + payload GenericJsonObject, +) MikrotikApiResponse[GenericJsonObject] { + apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout()) + defer cancel() + + fullUrl := m.getFullPath(command) + + req, err := m.preparePayloadRequest(apiCtx, http.MethodPatch, fullUrl, payload) + if err != nil { + return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + m.debugLog("executing API patch", "url", fullUrl) + response := parseHttpResponse[GenericJsonObject](m.client.Do(req)) + m.debugLog("retrieved API patch result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +func (m *MikrotikApiClient) Delete( + ctx context.Context, + command string, +) MikrotikApiResponse[EmptyResponse] { + apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout()) + defer cancel() + + fullUrl := m.getFullPath(command) + + req, err := m.prepareDeleteRequest(apiCtx, fullUrl) + if err != nil { + return errToApiResponse[EmptyResponse](MikrotikApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + m.debugLog("executing API delete", "url", fullUrl) + response := parseHttpResponse[EmptyResponse](m.client.Do(req)) + m.debugLog("retrieved API delete result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +func (m *MikrotikApiClient) ExecList( + ctx context.Context, + command string, + payload GenericJsonObject, +) MikrotikApiResponse[[]GenericJsonObject] { + apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout()) + defer cancel() + + fullUrl := m.getFullPath(command) + + req, err := m.preparePayloadRequest(apiCtx, http.MethodPost, fullUrl, payload) + if err != nil { + return errToApiResponse[[]GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed, + "failed to create request", err) + } + + start := time.Now() + m.debugLog("executing API post", "url", fullUrl) + response := parseHttpResponse[[]GenericJsonObject](m.client.Do(req)) + m.debugLog("retrieved API post result", "url", fullUrl, "duration", time.Since(start).String()) + return response +} + +// endregion API-client diff --git a/mkdocs.yml b/mkdocs.yml index f7b5169..7184e3c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -80,6 +80,7 @@ nav: - Examples: documentation/configuration/examples.md - Usage: - General: documentation/usage/general.md + - Backends: documentation/usage/backends.md - LDAP: documentation/usage/ldap.md - Security: documentation/usage/security.md - Webhooks: documentation/usage/webhooks.md