mirror of
https://github.com/h44z/wg-portal.git
synced 2025-08-12 08:12:23 +00:00
Mikrotik integration (#467)
Allow MikroTik routes as WireGuard backends
This commit is contained in:
parent
a86f83a219
commit
112f6bfb77
@ -32,6 +32,7 @@ The configuration portal supports using a database (SQLite, MySQL, MsSQL, or Pos
|
||||
* Docker ready
|
||||
* Can be used with existing WireGuard setups
|
||||
* Support for multiple WireGuard interfaces
|
||||
* Supports multiple WireGuard backends (wgctrl or MikroTik [BETA])
|
||||
* Peer Expiry Feature
|
||||
* Handles route and DNS settings like wg-quick does
|
||||
* Exposes Prometheus metrics for monitoring and alerting
|
||||
|
@ -50,7 +50,8 @@ func main() {
|
||||
database, err := adapters.NewSqlRepository(rawDb)
|
||||
internal.AssertNoError(err)
|
||||
|
||||
wireGuard := adapters.NewWireGuardRepository()
|
||||
wireGuard, err := wireguard.NewControllerManager(cfg)
|
||||
internal.AssertNoError(err)
|
||||
|
||||
wgQuick := adapters.NewWgQuickRepo()
|
||||
|
||||
@ -134,7 +135,7 @@ func main() {
|
||||
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers)
|
||||
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces)
|
||||
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
|
||||
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth)
|
||||
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard)
|
||||
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
|
||||
|
||||
apiFrontend := handlersV0.NewRestApi(apiV0Session,
|
||||
|
@ -25,6 +25,9 @@ core:
|
||||
import_existing: true
|
||||
restore_state: true
|
||||
|
||||
backend:
|
||||
default: local
|
||||
|
||||
advanced:
|
||||
log_level: info
|
||||
log_pretty: false
|
||||
@ -102,6 +105,7 @@ webhook:
|
||||
|
||||
Below you will find sections like
|
||||
[`core`](#core),
|
||||
[`backend`](#backend),
|
||||
[`advanced`](#advanced),
|
||||
[`database`](#database),
|
||||
[`statistics`](#statistics),
|
||||
@ -165,6 +169,65 @@ More advanced options are found in the subsequent `Advanced` section.
|
||||
|
||||
---
|
||||
|
||||
## Backend
|
||||
|
||||
Configuration options for the WireGuard backend, which manages the WireGuard interfaces and peers.
|
||||
The current MikroTik backend is in **BETA** and may not support all features.
|
||||
|
||||
### `default`
|
||||
- **Default:** `local`
|
||||
- **Description:** The default backend to use for managing WireGuard interfaces.
|
||||
Valid options are: `local`, or other backend id's configured in the `mikrotik` section.
|
||||
|
||||
### Mikrotik
|
||||
|
||||
The `mikrotik` array contains a list of MikroTik backend definitions. Each entry describes how to connect to a MikroTik RouterOS instance that hosts WireGuard interfaces.
|
||||
|
||||
Below are the properties for each entry inside `backend.mikrotik`:
|
||||
|
||||
#### `id`
|
||||
- **Default:** *(empty)*
|
||||
- **Description:** A unique identifier for this backend.
|
||||
This value can be referenced by `backend.default` to use this backend as default.
|
||||
The identifier must be unique across all backends and must not use the reserved keyword `local`.
|
||||
|
||||
#### `display_name`
|
||||
- **Default:** *(empty)*
|
||||
- **Description:** A human-friendly display name for this backend. If omitted, the `id` will be used as the display name.
|
||||
|
||||
#### `api_url`
|
||||
- **Default:** *(empty)*
|
||||
- **Description:** Base URL of the MikroTik REST API, including scheme and path, e.g., `https://10.10.10.10:8729/rest`.
|
||||
|
||||
#### `api_user`
|
||||
- **Default:** *(empty)*
|
||||
- **Description:** Username for authenticating against the MikroTik API.
|
||||
Ensure that the user has sufficient permissions to manage WireGuard interfaces and peers.
|
||||
|
||||
#### `api_password`
|
||||
- **Default:** *(empty)*
|
||||
- **Description:** Password for the specified API user.
|
||||
|
||||
#### `api_verify_tls`
|
||||
- **Default:** `false`
|
||||
- **Description:** Whether to verify the TLS certificate of the MikroTik API endpoint. Set to `false` to allow self-signed certificates (not recommended for production).
|
||||
|
||||
#### `api_timeout`
|
||||
- **Default:** `30s`
|
||||
- **Description:** Timeout for API requests to the MikroTik device. Uses Go duration format (e.g., `10s`, `1m`). If omitted, a default of 30 seconds is used.
|
||||
|
||||
#### `concurrency`
|
||||
- **Default:** `5`
|
||||
- **Description:** Maximum number of concurrent API requests the backend will issue when enumerating interfaces and their details. If `0` or negative, a sane default of `5` is used.
|
||||
|
||||
#### `debug`
|
||||
- **Default:** `false`
|
||||
- **Description:** Enable verbose debug logging for the MikroTik backend.
|
||||
|
||||
For more details on configuring the MikroTik backend, see the [Backends](../usage/backends.md) documentation.
|
||||
|
||||
---
|
||||
|
||||
## Advanced
|
||||
|
||||
Additional or more specialized configuration options for logging and interface creation details.
|
||||
|
57
docs/documentation/usage/backends.md
Normal file
57
docs/documentation/usage/backends.md
Normal file
@ -0,0 +1,57 @@
|
||||
# Backends
|
||||
|
||||
WireGuard Portal can manage WireGuard interfaces and peers on different backends.
|
||||
Each backend represents a system where interfaces actually live.
|
||||
You can register multiple backends and choose which one to use per interface.
|
||||
A global default backend determines where newly created interfaces go (unless you explicitly choose another in the UI).
|
||||
|
||||
**Supported backends:**
|
||||
- **Local** (default): Manages interfaces on the host running WireGuard Portal (Linux WireGuard via wgctrl). Use this when the portal should directly configure wg devices on the same server.
|
||||
- **MikroTik** RouterOS (_beta_): Manages interfaces and peers on MikroTik devices via the RouterOS REST API. Use this to control WG interfaces on RouterOS v7+.
|
||||
|
||||
How backend selection works:
|
||||
- The default backend is configured at `backend.default` (_local_ or the id of a defined MikroTik backend).
|
||||
New interfaces created in the UI will use this backend by default.
|
||||
- Each interface stores its backend. You can select a different backend when creating a new interface.
|
||||
|
||||
## Configuring MikroTik backends (RouterOS v7+)
|
||||
|
||||
> :warning: The MikroTik backend is currently marked beta. While basic functionality is implemented, some advanced features are not yet implemented or contain bugs. Please test carefully before using in production.
|
||||
|
||||
The MikroTik backend uses the [REST API](https://help.mikrotik.com/docs/spaces/ROS/pages/47579162/REST+API) under a base URL ending with /rest.
|
||||
You can register one or more MikroTik devices as backends for a single WireGuard Portal instance.
|
||||
|
||||
### Prerequisites on MikroTik:
|
||||
- RouterOS v7 with WireGuard support.
|
||||
- REST API enabled and reachable over HTTP(S). A typical base URL is https://<router-address>:8729/rest or https://<router-address>/rest depending on your service setup.
|
||||
- A dedicated RouterOS user with the following group permissions:
|
||||
- **api** (for logging in via REST API)
|
||||
- **rest-api** (for logging in via REST API)
|
||||
- **read** (to read interface and peer data)
|
||||
- **write** (to create/update interfaces and peers)
|
||||
- **test** (to perform ping checks)
|
||||
- **sensitive** (to read private keys)
|
||||
- TLS certificate on the device is recommended. If you use a self-signed certificate during testing, set `api_verify_tls`: _false_ in wg-portal (not recommended for production).
|
||||
|
||||
Example WireGuard Portal configuration (config/config.yaml):
|
||||
|
||||
```yaml
|
||||
backend:
|
||||
# default backend decides where new interfaces are created
|
||||
default: mikrotik-prod
|
||||
|
||||
mikrotik:
|
||||
- id: mikrotik-prod # unique id, not "local"
|
||||
display_name: RouterOS RB5009 # optional nice name
|
||||
api_url: https://10.10.10.10/rest
|
||||
api_user: wgportal
|
||||
api_password: a-super-secret-password
|
||||
api_verify_tls: true # set to false only if using self-signed during testing
|
||||
api_timeout: 30s # maximum request duration
|
||||
concurrency: 5 # limit parallel REST calls to device
|
||||
debug: false # verbose logging for this backend
|
||||
```
|
||||
|
||||
### Known limitations:
|
||||
- The MikroTik backend is still in beta. Some features may not work as expected.
|
||||
- Not all WireGuard Portal features are supported yet (e.g., no support for interface hooks)
|
@ -10,11 +10,13 @@ import isCidr from "is-cidr";
|
||||
import {isIP} from 'is-ip';
|
||||
import { freshInterface } from '@/helpers/models';
|
||||
import {peerStore} from "@/stores/peers";
|
||||
import {settingsStore} from "@/stores/settings";
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const interfaces = interfaceStore()
|
||||
const peers = peerStore()
|
||||
const settings = settingsStore()
|
||||
|
||||
const props = defineProps({
|
||||
interfaceId: String,
|
||||
@ -48,6 +50,26 @@ const currentTags = ref({
|
||||
PeerDefDnsSearch: ""
|
||||
})
|
||||
const formData = ref(freshInterface())
|
||||
const isSaving = ref(false)
|
||||
const isDeleting = ref(false)
|
||||
const isApplyingDefaults = ref(false)
|
||||
|
||||
const isBackendValid = computed(() => {
|
||||
if (!props.visible || !selectedInterface.value) {
|
||||
return true // if modal is not visible or no interface is selected, we don't care about backend validity
|
||||
}
|
||||
|
||||
let backendId = selectedInterface.value.Backend
|
||||
|
||||
let valid = false
|
||||
let availableBackends = settings.Setting('AvailableBackends') || []
|
||||
availableBackends.forEach(backend => {
|
||||
if (backend.Id === backendId) {
|
||||
valid = true
|
||||
}
|
||||
})
|
||||
return valid
|
||||
})
|
||||
|
||||
// functions
|
||||
|
||||
@ -61,6 +83,7 @@ watch(() => props.visible, async (newValue, oldValue) => {
|
||||
formData.value.Identifier = interfaces.Prepared.Identifier
|
||||
formData.value.DisplayName = interfaces.Prepared.DisplayName
|
||||
formData.value.Mode = interfaces.Prepared.Mode
|
||||
formData.value.Backend = interfaces.Prepared.Backend
|
||||
|
||||
formData.value.PublicKey = interfaces.Prepared.PublicKey
|
||||
formData.value.PrivateKey = interfaces.Prepared.PrivateKey
|
||||
@ -99,6 +122,7 @@ watch(() => props.visible, async (newValue, oldValue) => {
|
||||
formData.value.Identifier = selectedInterface.value.Identifier
|
||||
formData.value.DisplayName = selectedInterface.value.DisplayName
|
||||
formData.value.Mode = selectedInterface.value.Mode
|
||||
formData.value.Backend = selectedInterface.value.Backend
|
||||
|
||||
formData.value.PublicKey = selectedInterface.value.PublicKey
|
||||
formData.value.PrivateKey = selectedInterface.value.PrivateKey
|
||||
@ -237,6 +261,8 @@ function handleChangePeerDefDnsSearch(tags) {
|
||||
}
|
||||
|
||||
async function save() {
|
||||
if (isSaving.value) return
|
||||
isSaving.value = true
|
||||
try {
|
||||
if (props.interfaceId!=='#NEW#') {
|
||||
await interfaces.UpdateInterface(selectedInterface.value.Identifier, formData.value)
|
||||
@ -251,6 +277,8 @@ async function save() {
|
||||
text: e.toString(),
|
||||
type: 'error',
|
||||
})
|
||||
} finally {
|
||||
isSaving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
@ -259,6 +287,8 @@ async function applyPeerDefaults() {
|
||||
return; // do nothing for new interfaces
|
||||
}
|
||||
|
||||
if (isApplyingDefaults.value) return
|
||||
isApplyingDefaults.value = true
|
||||
try {
|
||||
await interfaces.ApplyPeerDefaults(selectedInterface.value.Identifier, formData.value)
|
||||
|
||||
@ -276,10 +306,14 @@ async function applyPeerDefaults() {
|
||||
text: e.toString(),
|
||||
type: 'error',
|
||||
})
|
||||
} finally {
|
||||
isApplyingDefaults.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function del() {
|
||||
if (isDeleting.value) return
|
||||
isDeleting.value = true
|
||||
try {
|
||||
await interfaces.DeleteInterface(selectedInterface.value.Identifier)
|
||||
close()
|
||||
@ -290,6 +324,8 @@ async function del() {
|
||||
text: e.toString(),
|
||||
type: 'error',
|
||||
})
|
||||
} finally {
|
||||
isDeleting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
@ -314,7 +350,8 @@ async function del() {
|
||||
<label class="form-label mt-4">{{ $t('modals.interface-edit.identifier.label') }}</label>
|
||||
<input v-model="formData.Identifier" class="form-control" :placeholder="$t('modals.interface-edit.identifier.placeholder')" type="text">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<div class="row">
|
||||
<div class="form-group col-md-6">
|
||||
<label class="form-label mt-4">{{ $t('modals.interface-edit.mode.label') }}</label>
|
||||
<select v-model="formData.Mode" class="form-select">
|
||||
<option value="server">{{ $t('modals.interface-edit.mode.server') }}</option>
|
||||
@ -322,6 +359,14 @@ async function del() {
|
||||
<option value="any">{{ $t('modals.interface-edit.mode.any') }}</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group col-md-6">
|
||||
<label class="form-label mt-4" for="ifaceBackendSelector">{{ $t('modals.interface-edit.backend.label') }}</label>
|
||||
<select id="ifaceBackendSelector" v-model="formData.Backend" class="form-select" aria-describedby="backendHelp">
|
||||
<option v-for="backend in settings.Setting('AvailableBackends')" :value="backend.Id">{{ backend.Id === 'local' ? $t(backend.Name) : backend.Name }}</option>
|
||||
</select>
|
||||
<small v-if="!isBackendValid" id="backendHelp" class="form-text text-warning">{{ $t('modals.interface-edit.backend.invalid-label') }}</small>
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label class="form-label mt-4">{{ $t('modals.interface-edit.display-name.label') }}</label>
|
||||
<input v-model="formData.DisplayName" class="form-control" :placeholder="$t('modals.interface-edit.display-name.placeholder')" type="text">
|
||||
@ -385,12 +430,14 @@ async function del() {
|
||||
<label class="form-label mt-4">{{ $t('modals.interface-edit.mtu.label') }}</label>
|
||||
<input v-model="formData.Mtu" class="form-control" :placeholder="$t('modals.interface-edit.mtu.placeholder')" type="number">
|
||||
</div>
|
||||
<div class="form-group col-md-6">
|
||||
<div class="form-group col-md-6" v-if="formData.Backend==='local'">
|
||||
<label class="form-label mt-4">{{ $t('modals.interface-edit.firewall-mark.label') }}</label>
|
||||
<input v-model="formData.FirewallMark" class="form-control" :placeholder="$t('modals.interface-edit.firewall-mark.placeholder')" type="number">
|
||||
</div>
|
||||
<div class="form-group col-md-6" v-else>
|
||||
</div>
|
||||
<div class="row">
|
||||
</div>
|
||||
<div class="row" v-if="formData.Backend==='local'">
|
||||
<div class="form-group col-md-6">
|
||||
<label class="form-label mt-4">{{ $t('modals.interface-edit.routing-table.label') }}</label>
|
||||
<input v-model="formData.RoutingTable" aria-describedby="routingTableHelp" class="form-control" :placeholder="$t('modals.interface-edit.routing-table.placeholder')" type="text">
|
||||
@ -530,16 +577,25 @@ async function del() {
|
||||
</fieldset>
|
||||
<fieldset v-if="props.interfaceId!=='#NEW#'" class="text-end">
|
||||
<hr class="mt-4">
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="applyPeerDefaults">{{ $t('modals.interface-edit.button-apply-defaults') }}</button>
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="applyPeerDefaults" :disabled="isApplyingDefaults">
|
||||
<span v-if="isApplyingDefaults" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
|
||||
{{ $t('modals.interface-edit.button-apply-defaults') }}
|
||||
</button>
|
||||
</fieldset>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<template #footer>
|
||||
<div class="flex-fill text-start">
|
||||
<button v-if="props.interfaceId!=='#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del">{{ $t('general.delete') }}</button>
|
||||
<button v-if="props.interfaceId!=='#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del" :disabled="isDeleting">
|
||||
<span v-if="isDeleting" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
|
||||
{{ $t('general.delete') }}
|
||||
</button>
|
||||
</div>
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="save">{{ $t('general.save') }}</button>
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="isSaving">
|
||||
<span v-if="isSaving" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
|
||||
{{ $t('general.save') }}
|
||||
</button>
|
||||
<button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button>
|
||||
</template>
|
||||
</Modal>
|
||||
|
@ -73,6 +73,8 @@ const currentTags = ref({
|
||||
DnsSearch: ""
|
||||
})
|
||||
const formData = ref(freshPeer())
|
||||
const isSaving = ref(false)
|
||||
const isDeleting = ref(false)
|
||||
|
||||
// functions
|
||||
|
||||
@ -270,6 +272,8 @@ function handleChangeDnsSearch(tags) {
|
||||
}
|
||||
|
||||
async function save() {
|
||||
if (isSaving.value) return
|
||||
isSaving.value = true
|
||||
try {
|
||||
if (props.peerId !== '#NEW#') {
|
||||
await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value)
|
||||
@ -278,26 +282,30 @@ async function save() {
|
||||
}
|
||||
close()
|
||||
} catch (e) {
|
||||
// console.log(e)
|
||||
notify({
|
||||
title: "Failed to save peer!",
|
||||
text: e.toString(),
|
||||
type: 'error',
|
||||
})
|
||||
} finally {
|
||||
isSaving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function del() {
|
||||
if (isDeleting.value) return
|
||||
isDeleting.value = true
|
||||
try {
|
||||
await peers.DeletePeer(selectedPeer.value.Identifier)
|
||||
close()
|
||||
} catch (e) {
|
||||
// console.log(e)
|
||||
notify({
|
||||
title: "Failed to delete peer!",
|
||||
text: e.toString(),
|
||||
type: 'error',
|
||||
})
|
||||
} finally {
|
||||
isDeleting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
@ -470,10 +478,15 @@ async function del() {
|
||||
</template>
|
||||
<template #footer>
|
||||
<div class="flex-fill text-start">
|
||||
<button v-if="props.peerId !== '#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del">{{
|
||||
$t('general.delete') }}</button>
|
||||
<button v-if="props.peerId !== '#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del" :disabled="isDeleting">
|
||||
<span v-if="isDeleting" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
|
||||
{{ $t('general.delete') }}
|
||||
</button>
|
||||
</div>
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="save">{{ $t('general.save') }}</button>
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="isSaving">
|
||||
<span v-if="isSaving" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
|
||||
{{ $t('general.save') }}
|
||||
</button>
|
||||
<button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button>
|
||||
</template>
|
||||
</Modal>
|
||||
|
@ -38,6 +38,7 @@ function freshForm() {
|
||||
|
||||
const currentTag = ref("")
|
||||
const formData = ref(freshForm())
|
||||
const isSaving = ref(false)
|
||||
|
||||
const title = computed(() => {
|
||||
if (!props.visible) {
|
||||
@ -60,12 +61,15 @@ function handleChangeUserIdentifiers(tags) {
|
||||
}
|
||||
|
||||
async function save() {
|
||||
if (isSaving.value) return
|
||||
isSaving.value = true
|
||||
if (formData.value.Identifiers.length === 0) {
|
||||
notify({
|
||||
title: "Missing Identifiers",
|
||||
text: "At least one identifier is required to create a new peer.",
|
||||
type: 'error',
|
||||
})
|
||||
isSaving.value = false
|
||||
return
|
||||
}
|
||||
|
||||
@ -79,6 +83,8 @@ async function save() {
|
||||
text: e.toString(),
|
||||
type: 'error',
|
||||
})
|
||||
} finally {
|
||||
isSaving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
@ -108,7 +114,10 @@ async function save() {
|
||||
</fieldset>
|
||||
</template>
|
||||
<template #footer>
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="save">{{ $t('general.save') }}</button>
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="isSaving">
|
||||
<span v-if="isSaving" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
|
||||
{{ $t('general.save') }}
|
||||
</button>
|
||||
<button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button>
|
||||
</template>
|
||||
</Modal>
|
||||
|
@ -34,6 +34,8 @@ const title = computed(() => {
|
||||
})
|
||||
|
||||
const formData = ref(freshUser())
|
||||
const isSaving = ref(false)
|
||||
const isDeleting = ref(false)
|
||||
|
||||
const passwordWeak = computed(() => {
|
||||
return formData.value.Password && formData.value.Password.length > 0 && formData.value.Password.length < settings.Setting('MinPasswordLength')
|
||||
@ -89,6 +91,8 @@ function close() {
|
||||
}
|
||||
|
||||
async function save() {
|
||||
if (isSaving.value) return
|
||||
isSaving.value = true
|
||||
try {
|
||||
if (props.userId!=='#NEW#') {
|
||||
await users.UpdateUser(selectedUser.value.Identifier, formData.value)
|
||||
@ -102,10 +106,14 @@ async function save() {
|
||||
text: e.toString(),
|
||||
type: 'error',
|
||||
})
|
||||
} finally {
|
||||
isSaving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function del() {
|
||||
if (isDeleting.value) return
|
||||
isDeleting.value = true
|
||||
try {
|
||||
await users.DeleteUser(selectedUser.value.Identifier)
|
||||
close()
|
||||
@ -115,6 +123,8 @@ async function del() {
|
||||
text: e.toString(),
|
||||
type: 'error',
|
||||
})
|
||||
} finally {
|
||||
isDeleting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
@ -193,9 +203,15 @@ async function del() {
|
||||
</template>
|
||||
<template #footer>
|
||||
<div class="flex-fill text-start">
|
||||
<button v-if="props.userId!=='#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del">{{ $t('general.delete') }}</button>
|
||||
<button v-if="props.userId!=='#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del" :disabled="isDeleting">
|
||||
<span v-if="isDeleting" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
|
||||
{{ $t('general.delete') }}
|
||||
</button>
|
||||
</div>
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="!formValid">{{ $t('general.save') }}</button>
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="!formValid || isSaving">
|
||||
<span v-if="isSaving" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
|
||||
{{ $t('general.save') }}
|
||||
</button>
|
||||
<button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button>
|
||||
</template>
|
||||
</Modal>
|
||||
|
@ -55,6 +55,8 @@ const title = computed(() => {
|
||||
})
|
||||
|
||||
const formData = ref(freshPeer())
|
||||
const isSaving = ref(false)
|
||||
const isDeleting = ref(false)
|
||||
|
||||
// functions
|
||||
|
||||
@ -163,6 +165,8 @@ function close() {
|
||||
}
|
||||
|
||||
async function save() {
|
||||
if (isSaving.value) return
|
||||
isSaving.value = true
|
||||
try {
|
||||
if (props.peerId !== '#NEW#') {
|
||||
await peers.UpdatePeer(selectedPeer.value.Identifier, formData.value)
|
||||
@ -171,26 +175,30 @@ async function save() {
|
||||
}
|
||||
close()
|
||||
} catch (e) {
|
||||
// console.log(e)
|
||||
notify({
|
||||
title: "Failed to save peer!",
|
||||
text: e.toString(),
|
||||
type: 'error',
|
||||
})
|
||||
} finally {
|
||||
isSaving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function del() {
|
||||
if (isDeleting.value) return
|
||||
isDeleting.value = true
|
||||
try {
|
||||
await peers.DeletePeer(selectedPeer.value.Identifier)
|
||||
close()
|
||||
} catch (e) {
|
||||
// console.log(e)
|
||||
notify({
|
||||
title: "Failed to delete peer!",
|
||||
text: e.toString(),
|
||||
type: 'error',
|
||||
})
|
||||
} finally {
|
||||
isDeleting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
@ -283,10 +291,15 @@ async function del() {
|
||||
</template>
|
||||
<template #footer>
|
||||
<div class="flex-fill text-start">
|
||||
<button v-if="props.peerId !== '#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del">{{
|
||||
$t('general.delete') }}</button>
|
||||
<button v-if="props.peerId !== '#NEW#'" class="btn btn-danger me-1" type="button" @click.prevent="del" :disabled="isDeleting">
|
||||
<span v-if="isDeleting" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
|
||||
{{ $t('general.delete') }}
|
||||
</button>
|
||||
</div>
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="save">{{ $t('general.save') }}</button>
|
||||
<button class="btn btn-primary me-1" type="button" @click.prevent="save" :disabled="isSaving">
|
||||
<span v-if="isSaving" class="spinner-border spinner-border-sm me-1" role="status" aria-hidden="true"></span>
|
||||
{{ $t('general.save') }}
|
||||
</button>
|
||||
<button class="btn btn-secondary" type="button" @click.prevent="close">{{ $t('general.close') }}</button>
|
||||
</template>
|
||||
</Modal>
|
||||
|
@ -5,6 +5,7 @@ export function freshInterface() {
|
||||
DisplayName: "",
|
||||
Identifier: "",
|
||||
Mode: "server",
|
||||
Backend: "local",
|
||||
|
||||
PublicKey: "",
|
||||
PrivateKey: "",
|
||||
|
@ -102,7 +102,9 @@
|
||||
},
|
||||
"interface": {
|
||||
"headline": "Schnittstellenstatus für",
|
||||
"mode": "Modus",
|
||||
"backend": "Backend",
|
||||
"unknown-backend": "Unbekannt",
|
||||
"wrong-backend": "Ungültiges Backend, das lokale WireGuard Backend wird stattdessen verwendet!",
|
||||
"key": "Öffentlicher Schlüssel",
|
||||
"endpoint": "Öffentlicher Endpunkt",
|
||||
"port": "Port",
|
||||
@ -357,6 +359,11 @@
|
||||
"client": "Client-Modus",
|
||||
"any": "Unbekannter Modus"
|
||||
},
|
||||
"backend": {
|
||||
"label": "Schnittstellenbackend",
|
||||
"invalid-label": "Ursprüngliches Backend ist ungültig, das lokale WireGuard Backend wird stattdessen verwendet!",
|
||||
"local": "Lokales WireGuard Backend"
|
||||
},
|
||||
"display-name": {
|
||||
"label": "Anzeigename",
|
||||
"placeholder": "Der beschreibende Name für die Schnittstelle"
|
||||
|
@ -102,7 +102,9 @@
|
||||
},
|
||||
"interface": {
|
||||
"headline": "Interface status for",
|
||||
"mode": "mode",
|
||||
"backend": "Backend",
|
||||
"unknown-backend": "Unknown",
|
||||
"wrong-backend": "Invalid backend, using local WireGuard backend instead!",
|
||||
"key": "Public Key",
|
||||
"endpoint": "Public Endpoint",
|
||||
"port": "Listening Port",
|
||||
@ -357,6 +359,11 @@
|
||||
"client": "Client Mode",
|
||||
"any": "Unknown Mode"
|
||||
},
|
||||
"backend": {
|
||||
"label": "Interface Backend",
|
||||
"invalid-label": "Original backend is no longer available, using local WireGuard backend instead!",
|
||||
"local": "Local WireGuard Backend"
|
||||
},
|
||||
"display-name": {
|
||||
"label": "Display Name",
|
||||
"placeholder": "The descriptive name for the interface"
|
||||
|
@ -99,7 +99,7 @@
|
||||
},
|
||||
"interface": {
|
||||
"headline": "État de l'interface pour",
|
||||
"mode": "mode",
|
||||
"backend": "backend",
|
||||
"key": "Clé publique",
|
||||
"endpoint": "Point de terminaison public",
|
||||
"port": "Port d'écoute",
|
||||
|
@ -100,7 +100,7 @@
|
||||
},
|
||||
"interface": {
|
||||
"headline": "인터페이스 상태:",
|
||||
"mode": "모드",
|
||||
"backend": "백엔드",
|
||||
"key": "공개 키",
|
||||
"endpoint": "공개 엔드포인트",
|
||||
"port": "수신 포트",
|
||||
|
@ -101,7 +101,7 @@
|
||||
},
|
||||
"interface": {
|
||||
"headline": "Status da interface para",
|
||||
"mode": "modo",
|
||||
"mode": "backend",
|
||||
"key": "Chave Pública",
|
||||
"endpoint": "Endpoint Público",
|
||||
"port": "Porta de Escuta",
|
||||
|
@ -99,7 +99,7 @@
|
||||
},
|
||||
"interface": {
|
||||
"headline": "Статус интерфейса для",
|
||||
"mode": "режим",
|
||||
"backend": "бэкэнд",
|
||||
"key": "Публичный ключ",
|
||||
"endpoint": "Публичная конечная точка",
|
||||
"port": "Порт прослушивания",
|
||||
|
@ -99,7 +99,7 @@
|
||||
},
|
||||
"interface": {
|
||||
"headline": "Статус інтерфейсу для",
|
||||
"mode": "режим",
|
||||
"backend": "бекенд",
|
||||
"key": "Публічний ключ",
|
||||
"endpoint": "Публічна кінцева точка",
|
||||
"port": "Порт прослуховування",
|
||||
|
@ -98,7 +98,7 @@
|
||||
},
|
||||
"interface": {
|
||||
"headline": "Trạng thái giao diện cho",
|
||||
"mode": "chế độ",
|
||||
"backend": "phần sau",
|
||||
"key": "Khóa Công khai",
|
||||
"endpoint": "Điểm cuối Công khai",
|
||||
"port": "Cổng Nghe",
|
||||
|
@ -98,7 +98,7 @@
|
||||
},
|
||||
"interface": {
|
||||
"headline": "接口状态",
|
||||
"mode": "模式",
|
||||
"backend": "后端",
|
||||
"key": "公钥",
|
||||
"endpoint": "公开节点",
|
||||
"port": "监听端口",
|
||||
|
@ -5,17 +5,20 @@ import PeerMultiCreateModal from "../components/PeerMultiCreateModal.vue";
|
||||
import InterfaceEditModal from "../components/InterfaceEditModal.vue";
|
||||
import InterfaceViewModal from "../components/InterfaceViewModal.vue";
|
||||
|
||||
import {onMounted, ref} from "vue";
|
||||
import {computed, onMounted, ref} from "vue";
|
||||
import {peerStore} from "@/stores/peers";
|
||||
import {interfaceStore} from "@/stores/interfaces";
|
||||
import {notify} from "@kyvg/vue3-notification";
|
||||
import {settingsStore} from "@/stores/settings";
|
||||
import {humanFileSize} from '@/helpers/utils';
|
||||
import {useI18n} from "vue-i18n";
|
||||
|
||||
const settings = settingsStore()
|
||||
const interfaces = interfaceStore()
|
||||
const peers = peerStore()
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const viewedPeerId = ref("")
|
||||
const editPeerId = ref("")
|
||||
const multiCreatePeerId = ref("")
|
||||
@ -45,6 +48,33 @@ function calculateInterfaceName(id, name) {
|
||||
return result
|
||||
}
|
||||
|
||||
const calculateBackendName = computed(() => {
|
||||
let backendId = interfaces.GetSelected.Backend
|
||||
|
||||
let backendName = t('interfaces.interface.unknown-backend')
|
||||
let availableBackends = settings.Setting('AvailableBackends') || []
|
||||
availableBackends.forEach(backend => {
|
||||
if (backend.Id === backendId) {
|
||||
backendName = backend.Id === 'local' ? t(backend.Name) : backend.Name
|
||||
}
|
||||
})
|
||||
return backendName
|
||||
})
|
||||
|
||||
const isBackendValid = computed(() => {
|
||||
let backendId = interfaces.GetSelected.Backend
|
||||
|
||||
let valid = false
|
||||
let availableBackends = settings.Setting('AvailableBackends') || []
|
||||
availableBackends.forEach(backend => {
|
||||
if (backend.Id === backendId) {
|
||||
valid = true
|
||||
}
|
||||
})
|
||||
return valid
|
||||
})
|
||||
|
||||
|
||||
async function download() {
|
||||
await interfaces.LoadInterfaceConfig(interfaces.GetSelected.Identifier)
|
||||
|
||||
@ -141,7 +171,7 @@ onMounted(async () => {
|
||||
<div class="card-header">
|
||||
<div class="row">
|
||||
<div class="col-12 col-lg-8">
|
||||
{{ $t('interfaces.interface.headline') }} <strong>{{interfaces.GetSelected.Identifier}}</strong> ({{interfaces.GetSelected.Mode}} {{ $t('interfaces.interface.mode') }})
|
||||
{{ $t('interfaces.interface.headline') }} <strong>{{interfaces.GetSelected.Identifier}}</strong> ({{ $t('modals.interface-edit.mode.' + interfaces.GetSelected.Mode )}} | {{ $t('interfaces.interface.backend') + ": " + calculateBackendName }}<span v-if="!isBackendValid" :title="t('interfaces.interface.wrong-backend')" class="ms-1 me-1"><i class="fa-solid fa-triangle-exclamation"></i></span>)
|
||||
<span v-if="interfaces.GetSelected.Disabled" class="text-danger"><i class="fa fa-circle-xmark" :title="interfaces.GetSelected.DisabledReason"></i></span>
|
||||
</div>
|
||||
<div class="col-12 col-lg-4 text-lg-end">
|
||||
|
864
internal/adapters/wgcontroller/local.go
Normal file
864
internal/adapters/wgcontroller/local.go
Normal file
@ -0,0 +1,864 @@
|
||||
package wgcontroller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
probing "github.com/prometheus-community/pro-bing"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||
)
|
||||
|
||||
// region dependencies
|
||||
|
||||
// WgCtrlRepo is used to control local WireGuard devices via the wgctrl-go library.
|
||||
type WgCtrlRepo interface {
|
||||
io.Closer
|
||||
Devices() ([]*wgtypes.Device, error)
|
||||
Device(name string) (*wgtypes.Device, error)
|
||||
ConfigureDevice(name string, cfg wgtypes.Config) error
|
||||
}
|
||||
|
||||
// A NetlinkClient is a type which can control a netlink device.
|
||||
type NetlinkClient interface {
|
||||
LinkAdd(link netlink.Link) error
|
||||
LinkDel(link netlink.Link) error
|
||||
LinkByName(name string) (netlink.Link, error)
|
||||
LinkSetUp(link netlink.Link) error
|
||||
LinkSetDown(link netlink.Link) error
|
||||
LinkSetMTU(link netlink.Link, mtu int) error
|
||||
AddrReplace(link netlink.Link, addr *netlink.Addr) error
|
||||
AddrAdd(link netlink.Link, addr *netlink.Addr) error
|
||||
AddrList(link netlink.Link) ([]netlink.Addr, error)
|
||||
AddrDel(link netlink.Link, addr *netlink.Addr) error
|
||||
RouteAdd(route *netlink.Route) error
|
||||
RouteDel(route *netlink.Route) error
|
||||
RouteReplace(route *netlink.Route) error
|
||||
RouteList(link netlink.Link, family int) ([]netlink.Route, error)
|
||||
RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error)
|
||||
RuleAdd(rule *netlink.Rule) error
|
||||
RuleDel(rule *netlink.Rule) error
|
||||
RuleList(family int) ([]netlink.Rule, error)
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
type LocalController struct {
|
||||
cfg *config.Config
|
||||
|
||||
wg WgCtrlRepo
|
||||
nl NetlinkClient
|
||||
|
||||
shellCmd string
|
||||
resolvConfIfacePrefix string
|
||||
}
|
||||
|
||||
// NewLocalController creates a new local controller instance.
|
||||
// This repository is used to interact with the WireGuard kernel or userspace module.
|
||||
func NewLocalController(cfg *config.Config) (*LocalController, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create wgctrl client: %w", err)
|
||||
}
|
||||
|
||||
nl := &lowlevel.NetlinkManager{}
|
||||
|
||||
repo := &LocalController{
|
||||
cfg: cfg,
|
||||
|
||||
wg: wg,
|
||||
nl: nl,
|
||||
|
||||
shellCmd: "bash", // we only support bash at the moment
|
||||
resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf
|
||||
}
|
||||
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (c LocalController) GetId() domain.InterfaceBackend {
|
||||
return config.LocalBackendName
|
||||
}
|
||||
|
||||
// region wireguard-related
|
||||
|
||||
func (c LocalController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
|
||||
devices, err := c.wg.Devices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device list error: %w", err)
|
||||
}
|
||||
|
||||
interfaces := make([]domain.PhysicalInterface, 0, len(devices))
|
||||
for _, device := range devices {
|
||||
interfaceModel, err := c.convertWireGuardInterface(device)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("interface convert failed for %s: %w", device.Name, err)
|
||||
}
|
||||
interfaces = append(interfaces, interfaceModel)
|
||||
}
|
||||
|
||||
return interfaces, nil
|
||||
}
|
||||
|
||||
func (c LocalController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (
|
||||
*domain.PhysicalInterface,
|
||||
error,
|
||||
) {
|
||||
return c.getInterface(id)
|
||||
}
|
||||
|
||||
func (c LocalController) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) {
|
||||
// read data from wgctrl interface
|
||||
|
||||
iface := domain.PhysicalInterface{
|
||||
Identifier: domain.InterfaceIdentifier(device.Name),
|
||||
KeyPair: domain.KeyPair{
|
||||
PrivateKey: device.PrivateKey.String(),
|
||||
PublicKey: device.PublicKey.String(),
|
||||
},
|
||||
ListenPort: device.ListenPort,
|
||||
Addresses: nil,
|
||||
Mtu: 0,
|
||||
FirewallMark: uint32(device.FirewallMark),
|
||||
DeviceUp: false,
|
||||
ImportSource: domain.ControllerTypeLocal,
|
||||
DeviceType: device.Type.String(),
|
||||
BytesUpload: 0,
|
||||
BytesDownload: 0,
|
||||
}
|
||||
|
||||
// read data from netlink interface
|
||||
|
||||
lowLevelInterface, err := c.nl.LinkByName(device.Name)
|
||||
if err != nil {
|
||||
return domain.PhysicalInterface{}, fmt.Errorf("netlink error for %s: %w", device.Name, err)
|
||||
}
|
||||
ipAddresses, err := c.nl.AddrList(lowLevelInterface)
|
||||
if err != nil {
|
||||
return domain.PhysicalInterface{}, fmt.Errorf("ip read error for %s: %w", device.Name, err)
|
||||
}
|
||||
|
||||
for _, addr := range ipAddresses {
|
||||
iface.Addresses = append(iface.Addresses, domain.CidrFromNetlinkAddr(addr))
|
||||
}
|
||||
iface.Mtu = lowLevelInterface.Attrs().MTU
|
||||
iface.DeviceUp = lowLevelInterface.Attrs().OperState == netlink.OperUnknown // wg only supports unknown
|
||||
if stats := lowLevelInterface.Attrs().Statistics; stats != nil {
|
||||
iface.BytesUpload = stats.TxBytes
|
||||
iface.BytesDownload = stats.RxBytes
|
||||
}
|
||||
|
||||
return iface, nil
|
||||
}
|
||||
|
||||
func (c LocalController) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) (
|
||||
[]domain.PhysicalPeer,
|
||||
error,
|
||||
) {
|
||||
device, err := c.wg.Device(string(deviceId))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device error: %w", err)
|
||||
}
|
||||
|
||||
peers := make([]domain.PhysicalPeer, 0, len(device.Peers))
|
||||
for _, peer := range device.Peers {
|
||||
peerModel, err := c.convertWireGuardPeer(&peer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer convert failed for %v: %w", peer.PublicKey, err)
|
||||
}
|
||||
peers = append(peers, peerModel)
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (c LocalController) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) {
|
||||
peerModel := domain.PhysicalPeer{
|
||||
Identifier: domain.PeerIdentifier(peer.PublicKey.String()),
|
||||
Endpoint: "",
|
||||
AllowedIPs: nil,
|
||||
KeyPair: domain.KeyPair{
|
||||
PublicKey: peer.PublicKey.String(),
|
||||
},
|
||||
PresharedKey: "",
|
||||
PersistentKeepalive: int(peer.PersistentKeepaliveInterval.Seconds()),
|
||||
LastHandshake: peer.LastHandshakeTime,
|
||||
ProtocolVersion: peer.ProtocolVersion,
|
||||
BytesUpload: uint64(peer.ReceiveBytes),
|
||||
BytesDownload: uint64(peer.TransmitBytes),
|
||||
ImportSource: domain.ControllerTypeLocal,
|
||||
}
|
||||
|
||||
// Set local extras - local peers are never disabled in the kernel
|
||||
peerModel.SetExtras(domain.LocalPeerExtras{
|
||||
Disabled: false,
|
||||
})
|
||||
|
||||
for _, addr := range peer.AllowedIPs {
|
||||
peerModel.AllowedIPs = append(peerModel.AllowedIPs, domain.CidrFromIpNet(addr))
|
||||
}
|
||||
if peer.Endpoint != nil {
|
||||
peerModel.Endpoint = peer.Endpoint.String()
|
||||
}
|
||||
if peer.PresharedKey != (wgtypes.Key{}) {
|
||||
peerModel.PresharedKey = domain.PreSharedKey(peer.PresharedKey.String())
|
||||
}
|
||||
|
||||
return peerModel, nil
|
||||
}
|
||||
|
||||
func (c LocalController) SaveInterface(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error {
|
||||
physicalInterface, err := c.getOrCreateInterface(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateFunc != nil {
|
||||
physicalInterface, err = updateFunc(physicalInterface)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.updateLowLevelInterface(physicalInterface); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.updateWireGuardInterface(physicalInterface); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
|
||||
device, err := c.getInterface(id)
|
||||
if err == nil {
|
||||
return device, nil // interface exists
|
||||
}
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("device error: %w", err) // unknown error
|
||||
}
|
||||
|
||||
// create new device
|
||||
if err := c.createLowLevelInterface(id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
device, err = c.getInterface(id)
|
||||
return device, err
|
||||
}
|
||||
|
||||
func (c LocalController) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
|
||||
device, err := c.wg.Device(string(id))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pi, err := c.convertWireGuardInterface(device)
|
||||
return &pi, err
|
||||
}
|
||||
|
||||
func (c LocalController) createLowLevelInterface(id domain.InterfaceIdentifier) error {
|
||||
link := &netlink.GenericLink{
|
||||
LinkAttrs: netlink.LinkAttrs{
|
||||
Name: string(id),
|
||||
},
|
||||
LinkType: "wireguard",
|
||||
}
|
||||
err := c.nl.LinkAdd(link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("link add failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) updateLowLevelInterface(pi *domain.PhysicalInterface) error {
|
||||
link, err := c.nl.LinkByName(string(pi.Identifier))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pi.Mtu != 0 {
|
||||
if err := c.nl.LinkSetMTU(link, pi.Mtu); err != nil {
|
||||
return fmt.Errorf("mtu error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, addr := range pi.Addresses {
|
||||
err := c.nl.AddrReplace(link, addr.NetlinkAddr())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set ip %s: %w", addr.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove unwanted IP addresses
|
||||
rawAddresses, err := c.nl.AddrList(link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch interface ips: %w", err)
|
||||
}
|
||||
for _, rawAddr := range rawAddresses {
|
||||
netlinkAddr := domain.CidrFromNetlinkAddr(rawAddr)
|
||||
remove := true
|
||||
for _, addr := range pi.Addresses {
|
||||
if addr == netlinkAddr {
|
||||
remove = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !remove {
|
||||
continue
|
||||
}
|
||||
|
||||
err := c.nl.AddrDel(link, &rawAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated ip %s: %w", netlinkAddr.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update link state
|
||||
if pi.DeviceUp {
|
||||
if err := c.nl.LinkSetUp(link); err != nil {
|
||||
return fmt.Errorf("failed to bring up device: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := c.nl.LinkSetDown(link); err != nil {
|
||||
return fmt.Errorf("failed to bring down device: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
|
||||
pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var fwMark *int
|
||||
if pi.FirewallMark != 0 {
|
||||
intFwMark := int(pi.FirewallMark)
|
||||
fwMark = &intFwMark
|
||||
}
|
||||
err = c.wg.ConfigureDevice(string(pi.Identifier), wgtypes.Config{
|
||||
PrivateKey: &pKey,
|
||||
ListenPort: &pi.ListenPort,
|
||||
FirewallMark: fwMark,
|
||||
ReplacePeers: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
|
||||
if err := c.deleteLowLevelInterface(id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
|
||||
link, err := c.nl.LinkByName(string(id))
|
||||
if err != nil {
|
||||
var linkNotFoundError netlink.LinkNotFoundError
|
||||
if errors.As(err, &linkNotFoundError) {
|
||||
return nil // ignore not found error
|
||||
}
|
||||
return fmt.Errorf("unable to find low level interface: %w", err)
|
||||
}
|
||||
|
||||
err = c.nl.LinkDel(link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete low level interface: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) SavePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error {
|
||||
physicalPeer, err := c.getOrCreatePeer(deviceId, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
physicalPeer, err = updateFunc(physicalPeer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the peer is disabled by looking at the backend extras
|
||||
// For local controller, disabled peers should be deleted
|
||||
if physicalPeer.GetExtras() != nil {
|
||||
switch extras := physicalPeer.GetExtras().(type) {
|
||||
case domain.LocalPeerExtras:
|
||||
if extras.Disabled {
|
||||
// Delete the peer instead of updating it
|
||||
return c.deletePeer(deviceId, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.updatePeer(deviceId, physicalPeer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
|
||||
*domain.PhysicalPeer,
|
||||
error,
|
||||
) {
|
||||
peer, err := c.getPeer(deviceId, id)
|
||||
if err == nil {
|
||||
return peer, nil // peer exists
|
||||
}
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("peer error: %w", err) // unknown error
|
||||
}
|
||||
|
||||
// create new peer
|
||||
err = c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{
|
||||
{
|
||||
PublicKey: id.ToPublicKey(),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer create error for %s: %w", id.ToPublicKey(), err)
|
||||
}
|
||||
|
||||
peer, err = c.getPeer(deviceId, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer error after create: %w", err)
|
||||
}
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
func (c LocalController) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
|
||||
*domain.PhysicalPeer,
|
||||
error,
|
||||
) {
|
||||
if !id.IsPublicKey() {
|
||||
return nil, errors.New("invalid public key")
|
||||
}
|
||||
|
||||
device, err := c.wg.Device(string(deviceId))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey := id.ToPublicKey()
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey != publicKey {
|
||||
continue
|
||||
}
|
||||
|
||||
peerModel, err := c.convertWireGuardPeer(&peer)
|
||||
return &peerModel, err
|
||||
}
|
||||
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
func (c LocalController) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error {
|
||||
cfg := wgtypes.PeerConfig{
|
||||
PublicKey: pp.GetPublicKey(),
|
||||
Remove: false,
|
||||
UpdateOnly: true,
|
||||
PresharedKey: pp.GetPresharedKey(),
|
||||
Endpoint: pp.GetEndpointAddress(),
|
||||
PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(),
|
||||
ReplaceAllowedIPs: true,
|
||||
AllowedIPs: pp.GetAllowedIPs(),
|
||||
}
|
||||
|
||||
err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) DeletePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
) error {
|
||||
if !id.IsPublicKey() {
|
||||
return errors.New("invalid public key")
|
||||
}
|
||||
|
||||
err := c.deletePeer(deviceId, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
|
||||
cfg := wgtypes.PeerConfig{
|
||||
PublicKey: id.ToPublicKey(),
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion wireguard-related
|
||||
|
||||
// region wg-quick-related
|
||||
|
||||
func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
|
||||
if hookCmd == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
slog.Debug("executing interface hook", "interface", id, "hook", hookCmd)
|
||||
err := c.exec(hookCmd, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exec hook: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||
if dnsStr == "" && dnsSearchStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsServers := internal.SliceString(dnsStr)
|
||||
dnsSearchDomains := internal.SliceString(dnsSearchStr)
|
||||
|
||||
dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
|
||||
dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
|
||||
|
||||
for _, dnsServer := range dnsServers {
|
||||
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
|
||||
}
|
||||
for _, searchDomain := range dnsSearchDomains {
|
||||
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
|
||||
}
|
||||
|
||||
err := c.exec(dnsCommand, id, dnsCommandInput...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error {
|
||||
dnsCommand := "resolvconf -d %resPref%i -f"
|
||||
|
||||
err := c.exec(dnsCommand, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unset dns settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
|
||||
command = strings.ReplaceAll(command, "%resPref", c.resolvConfIfacePrefix)
|
||||
return strings.ReplaceAll(command, "%i", string(interfaceId))
|
||||
}
|
||||
|
||||
func (c LocalController) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
|
||||
commandWithInterfaceName := c.replaceCommandPlaceHolders(command, interfaceId)
|
||||
cmd := exec.Command(c.shellCmd, "-ce", commandWithInterfaceName)
|
||||
if len(stdin) > 0 {
|
||||
b := &bytes.Buffer{}
|
||||
for _, ln := range stdin {
|
||||
if _, err := fmt.Fprint(b, ln); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
cmd.Stdin = b
|
||||
}
|
||||
out, err := cmd.CombinedOutput() // execute and wait for output
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
|
||||
}
|
||||
slog.Debug("executed shell command",
|
||||
"command", commandWithInterfaceName,
|
||||
"output", string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion wg-quick-related
|
||||
|
||||
// region routing-related
|
||||
|
||||
func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// update fwmark rules
|
||||
if err := c.setFwMarkRules(rules); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update main rule
|
||||
if err := c.setMainRule(rules); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// cleanup old main rules
|
||||
if err := c.cleanupMainRule(rules); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error {
|
||||
for _, rule := range rules {
|
||||
existingRules, err := c.nl.RuleList(int(rule.IpFamily))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err)
|
||||
}
|
||||
|
||||
ruleExists := false
|
||||
for _, existingRule := range existingRules {
|
||||
if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table {
|
||||
ruleExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ruleExists {
|
||||
continue // rule already exists, no need to recreate it
|
||||
}
|
||||
|
||||
// create a missing rule
|
||||
if err := c.nl.RuleAdd(&netlink.Rule{
|
||||
Family: int(rule.IpFamily),
|
||||
Table: rule.Table,
|
||||
Mark: rule.FwMark,
|
||||
Invert: true,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
Priority: c.getRulePriority(existingRules),
|
||||
Mask: nil,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w",
|
||||
rule.IpFamily, rule.FwMark, rule.Table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
|
||||
prio := 32700 // linux main rule has a priority of 32766
|
||||
for {
|
||||
isFresh := true
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Priority == prio {
|
||||
isFresh = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFresh {
|
||||
break
|
||||
} else {
|
||||
prio--
|
||||
}
|
||||
}
|
||||
return prio
|
||||
}
|
||||
|
||||
func (c LocalController) setMainRule(rules []domain.RouteRule) error {
|
||||
var family domain.IpFamily
|
||||
shouldHaveMainRule := false
|
||||
for _, rule := range rules {
|
||||
family = rule.IpFamily
|
||||
if rule.HasDefault == true {
|
||||
shouldHaveMainRule = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !shouldHaveMainRule {
|
||||
return nil
|
||||
}
|
||||
|
||||
existingRules, err := c.nl.RuleList(int(family))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
|
||||
}
|
||||
|
||||
ruleExists := false
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
ruleExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ruleExists {
|
||||
return nil // rule already exists, skip re-creation
|
||||
}
|
||||
|
||||
if err := c.nl.RuleAdd(&netlink.Rule{
|
||||
Family: int(family),
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: 0,
|
||||
Priority: c.getMainRulePriority(existingRules),
|
||||
Mark: 0,
|
||||
Mask: nil,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup rule for main table: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
|
||||
priority := c.cfg.Advanced.RulePrioOffset
|
||||
for {
|
||||
isFresh := true
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Priority == priority {
|
||||
isFresh = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFresh {
|
||||
break
|
||||
} else {
|
||||
priority++
|
||||
}
|
||||
}
|
||||
return priority
|
||||
}
|
||||
|
||||
func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error {
|
||||
var family domain.IpFamily
|
||||
for _, rule := range rules {
|
||||
family = rule.IpFamily
|
||||
break
|
||||
}
|
||||
|
||||
existingRules, err := c.nl.RuleList(int(family))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
|
||||
}
|
||||
|
||||
shouldHaveMainRule := false
|
||||
for _, rule := range rules {
|
||||
if rule.HasDefault == true {
|
||||
shouldHaveMainRule = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mainRules := 0
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
mainRules++
|
||||
}
|
||||
}
|
||||
|
||||
removalCount := 0
|
||||
if mainRules > 1 {
|
||||
removalCount = mainRules - 1 // we only want one single rule
|
||||
}
|
||||
if !shouldHaveMainRule {
|
||||
removalCount = mainRules
|
||||
}
|
||||
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
if removalCount > 0 {
|
||||
existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field
|
||||
if err := c.nl.RuleDel(&existingRule); err != nil {
|
||||
return fmt.Errorf("failed to delete main rule: %w", err)
|
||||
}
|
||||
removalCount--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// endregion routing-related
|
||||
|
||||
// region statistics-related
|
||||
|
||||
func (c LocalController) PingAddresses(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
) (*domain.PingerResult, error) {
|
||||
pinger, err := probing.NewPinger(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to instantiate pinger for %s: %w", addr, err)
|
||||
}
|
||||
|
||||
checkCount := 1
|
||||
pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged)
|
||||
pinger.Count = checkCount
|
||||
pinger.Timeout = 2 * time.Second
|
||||
err = pinger.RunWithContext(ctx) // Blocks until finished.
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to ping %s: %w", addr, err)
|
||||
}
|
||||
|
||||
stats := pinger.Statistics()
|
||||
|
||||
return &domain.PingerResult{
|
||||
PacketsRecv: stats.PacketsRecv,
|
||||
PacketsSent: stats.PacketsSent,
|
||||
Rtts: stats.Rtts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// endregion statistics-related
|
829
internal/adapters/wgcontroller/mikrotik.go
Normal file
829
internal/adapters/wgcontroller/mikrotik.go
Normal file
@ -0,0 +1,829 @@
|
||||
package wgcontroller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||
)
|
||||
|
||||
type MikrotikController struct {
|
||||
coreCfg *config.Config
|
||||
cfg *config.BackendMikrotik
|
||||
|
||||
client *lowlevel.MikrotikApiClient
|
||||
|
||||
// Add mutexes to prevent race conditions
|
||||
interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex
|
||||
peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex
|
||||
}
|
||||
|
||||
func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) {
|
||||
client, err := lowlevel.NewMikrotikApiClient(coreCfg, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Mikrotik API client: %w", err)
|
||||
}
|
||||
|
||||
return &MikrotikController{
|
||||
coreCfg: coreCfg,
|
||||
cfg: cfg,
|
||||
|
||||
client: client,
|
||||
|
||||
interfaceMutexes: sync.Map{},
|
||||
peerMutexes: sync.Map{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) GetId() domain.InterfaceBackend {
|
||||
return domain.InterfaceBackend(c.cfg.Id)
|
||||
}
|
||||
|
||||
// getInterfaceMutex returns a mutex for the given interface to prevent concurrent modifications
|
||||
func (c *MikrotikController) getInterfaceMutex(id domain.InterfaceIdentifier) *sync.Mutex {
|
||||
mutex, _ := c.interfaceMutexes.LoadOrStore(id, &sync.Mutex{})
|
||||
return mutex.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// getPeerMutex returns a mutex for the given peer to prevent concurrent modifications
|
||||
func (c *MikrotikController) getPeerMutex(id domain.PeerIdentifier) *sync.Mutex {
|
||||
mutex, _ := c.peerMutexes.LoadOrStore(id, &sync.Mutex{})
|
||||
return mutex.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// region wireguard-related
|
||||
|
||||
func (c *MikrotikController) GetInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running", "comment",
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return nil, fmt.Errorf("failed to query interfaces: %v", wgReply.Error)
|
||||
}
|
||||
|
||||
// Parallelize loading of interface details to speed up overall latency.
|
||||
// Use a bounded semaphore to avoid overloading the MikroTik device.
|
||||
maxConcurrent := c.cfg.GetConcurrency()
|
||||
sem := make(chan struct{}, maxConcurrent)
|
||||
|
||||
interfaces := make([]domain.PhysicalInterface, 0, len(wgReply.Data))
|
||||
var mu sync.Mutex
|
||||
var wgWait sync.WaitGroup
|
||||
var firstErr error
|
||||
ctx2, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
for _, wgObj := range wgReply.Data {
|
||||
wgWait.Add(1)
|
||||
sem <- struct{}{} // block if more than maxConcurrent requests are processing
|
||||
go func(wg lowlevel.GenericJsonObject) {
|
||||
defer wgWait.Done()
|
||||
defer func() { <-sem }() // read from the semaphore and make space for the next entry
|
||||
if firstErr != nil {
|
||||
return
|
||||
}
|
||||
pi, err := c.loadInterfaceData(ctx2, wg)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
cancel()
|
||||
}
|
||||
mu.Unlock()
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
interfaces = append(interfaces, *pi)
|
||||
mu.Unlock()
|
||||
}(wgObj)
|
||||
}
|
||||
|
||||
wgWait.Wait()
|
||||
if firstErr != nil {
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
return interfaces, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (
|
||||
*domain.PhysicalInterface,
|
||||
error,
|
||||
) {
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"name": string(id),
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return nil, fmt.Errorf("failed to query interface %s: %v", id, wgReply.Error)
|
||||
}
|
||||
|
||||
if len(wgReply.Data) == 0 {
|
||||
return nil, fmt.Errorf("interface %s not found", id)
|
||||
}
|
||||
|
||||
return c.loadInterfaceData(ctx, wgReply.Data[0])
|
||||
}
|
||||
|
||||
func (c *MikrotikController) loadInterfaceData(
|
||||
ctx context.Context,
|
||||
wireGuardObj lowlevel.GenericJsonObject,
|
||||
) (*domain.PhysicalInterface, error) {
|
||||
deviceId := wireGuardObj.GetString(".id")
|
||||
deviceName := wireGuardObj.GetString("name")
|
||||
ifaceReply := c.client.Get(ctx, "/interface/"+deviceId, &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
"name", "rx-byte", "tx-byte",
|
||||
},
|
||||
})
|
||||
if ifaceReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return nil, fmt.Errorf("failed to query interface %s: %v", deviceId, ifaceReply.Error)
|
||||
}
|
||||
|
||||
ipv4, ipv6, err := c.loadIpAddresses(ctx, deviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query IP addresses for interface %s: %v", deviceId, err)
|
||||
}
|
||||
addresses := c.convertIpAddresses(ipv4, ipv6)
|
||||
|
||||
interfaceModel, err := c.convertWireGuardInterface(wireGuardObj, ifaceReply.Data, addresses)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("interface convert failed for %s: %w", deviceName, err)
|
||||
}
|
||||
return &interfaceModel, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) loadIpAddresses(
|
||||
ctx context.Context,
|
||||
deviceName string,
|
||||
) (ipv4 []lowlevel.GenericJsonObject, ipv6 []lowlevel.GenericJsonObject, err error) {
|
||||
// Query IPv4 and IPv6 addresses in parallel to reduce latency.
|
||||
var (
|
||||
v4 []lowlevel.GenericJsonObject
|
||||
v6 []lowlevel.GenericJsonObject
|
||||
v4Err error
|
||||
v6Err error
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
addrV4Reply := c.client.Query(ctx, "/ip/address", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "address", "network",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"interface": deviceName,
|
||||
"dynamic": "false", // we only want static addresses
|
||||
"disabled": "false", // we only want addresses that are not disabled
|
||||
},
|
||||
})
|
||||
if addrV4Reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
v4Err = fmt.Errorf("failed to query IPv4 addresses for interface %s: %v", deviceName, addrV4Reply.Error)
|
||||
return
|
||||
}
|
||||
v4 = addrV4Reply.Data
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
addrV6Reply := c.client.Query(ctx, "/ipv6/address", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "address", "network",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"interface": deviceName,
|
||||
"dynamic": "false", // we only want static addresses
|
||||
"disabled": "false", // we only want addresses that are not disabled
|
||||
},
|
||||
})
|
||||
if addrV6Reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
v6Err = fmt.Errorf("failed to query IPv6 addresses for interface %s: %v", deviceName, addrV6Reply.Error)
|
||||
return
|
||||
}
|
||||
v6 = addrV6Reply.Data
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
if v4Err != nil {
|
||||
return nil, nil, v4Err
|
||||
}
|
||||
if v6Err != nil {
|
||||
return nil, nil, v6Err
|
||||
}
|
||||
|
||||
return v4, v6, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) convertIpAddresses(
|
||||
ipv4, ipv6 []lowlevel.GenericJsonObject,
|
||||
) []domain.Cidr {
|
||||
addresses := make([]domain.Cidr, 0, len(ipv4)+len(ipv6))
|
||||
for _, addr := range append(ipv4, ipv6...) {
|
||||
addrStr := addr.GetString("address")
|
||||
if addrStr == "" {
|
||||
continue
|
||||
}
|
||||
cidr, err := domain.CidrFromString(addrStr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
addresses = append(addresses, cidr)
|
||||
}
|
||||
|
||||
return addresses
|
||||
}
|
||||
|
||||
func (c *MikrotikController) convertWireGuardInterface(
|
||||
wg, iface lowlevel.GenericJsonObject,
|
||||
addresses []domain.Cidr,
|
||||
) (
|
||||
domain.PhysicalInterface,
|
||||
error,
|
||||
) {
|
||||
pi := domain.PhysicalInterface{
|
||||
Identifier: domain.InterfaceIdentifier(wg.GetString("name")),
|
||||
KeyPair: domain.KeyPair{
|
||||
PrivateKey: wg.GetString("private-key"),
|
||||
PublicKey: wg.GetString("public-key"),
|
||||
},
|
||||
ListenPort: wg.GetInt("listen-port"),
|
||||
Addresses: addresses,
|
||||
Mtu: wg.GetInt("mtu"),
|
||||
FirewallMark: 0,
|
||||
DeviceUp: wg.GetBool("running"),
|
||||
ImportSource: domain.ControllerTypeMikrotik,
|
||||
DeviceType: domain.ControllerTypeMikrotik,
|
||||
BytesUpload: uint64(iface.GetInt("tx-byte")),
|
||||
BytesDownload: uint64(iface.GetInt("rx-byte")),
|
||||
}
|
||||
|
||||
pi.SetExtras(domain.MikrotikInterfaceExtras{
|
||||
Id: wg.GetString(".id"),
|
||||
Comment: wg.GetString("comment"),
|
||||
Disabled: wg.GetBool("disabled"),
|
||||
})
|
||||
|
||||
return pi, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) GetPeers(ctx context.Context, deviceId domain.InterfaceIdentifier) (
|
||||
[]domain.PhysicalPeer,
|
||||
error,
|
||||
) {
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "name", "allowed-address", "client-address", "client-endpoint", "client-keepalive", "comment",
|
||||
"current-endpoint-address", "current-endpoint-port", "last-handshake", "persistent-keepalive",
|
||||
"public-key", "private-key", "preshared-key", "mtu", "disabled", "rx", "tx", "responder", "client-dns",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"interface": string(deviceId),
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return nil, fmt.Errorf("failed to query peers for %s: %v", deviceId, wgReply.Error)
|
||||
}
|
||||
|
||||
if len(wgReply.Data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
peers := make([]domain.PhysicalPeer, 0, len(wgReply.Data))
|
||||
for _, peer := range wgReply.Data {
|
||||
peerModel, err := c.convertWireGuardPeer(peer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer convert failed for %v: %w", peer.GetString("name"), err)
|
||||
}
|
||||
peers = append(peers, peerModel)
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) convertWireGuardPeer(peer lowlevel.GenericJsonObject) (
|
||||
domain.PhysicalPeer,
|
||||
error,
|
||||
) {
|
||||
keepAliveSeconds := 0
|
||||
duration, err := time.ParseDuration(peer.GetString("persistent-keepalive"))
|
||||
if err == nil {
|
||||
keepAliveSeconds = int(duration.Seconds())
|
||||
}
|
||||
|
||||
currentEndpoint := ""
|
||||
if peer.GetString("current-endpoint-address") != "" && peer.GetString("current-endpoint-port") != "" {
|
||||
currentEndpoint = peer.GetString("current-endpoint-address") + ":" + peer.GetString("current-endpoint-port")
|
||||
}
|
||||
|
||||
lastHandshakeTime := time.Time{}
|
||||
if peer.GetString("last-handshake") != "" {
|
||||
relDuration, err := time.ParseDuration(peer.GetString("last-handshake"))
|
||||
if err == nil {
|
||||
lastHandshakeTime = time.Now().Add(-relDuration)
|
||||
}
|
||||
}
|
||||
|
||||
allowedAddresses, _ := domain.CidrsFromString(peer.GetString("allowed-address"))
|
||||
|
||||
clientKeepAliveSeconds := 0
|
||||
duration, err = time.ParseDuration(peer.GetString("client-keepalive"))
|
||||
if err == nil {
|
||||
clientKeepAliveSeconds = int(duration.Seconds())
|
||||
}
|
||||
|
||||
peerModel := domain.PhysicalPeer{
|
||||
Identifier: domain.PeerIdentifier(peer.GetString("public-key")),
|
||||
Endpoint: currentEndpoint,
|
||||
AllowedIPs: allowedAddresses,
|
||||
KeyPair: domain.KeyPair{
|
||||
PublicKey: peer.GetString("public-key"),
|
||||
PrivateKey: peer.GetString("private-key"),
|
||||
},
|
||||
PresharedKey: domain.PreSharedKey(peer.GetString("preshared-key")),
|
||||
PersistentKeepalive: keepAliveSeconds,
|
||||
LastHandshake: lastHandshakeTime,
|
||||
ProtocolVersion: 0, // Mikrotik does not support protocol versioning, so we set it to 0
|
||||
BytesUpload: uint64(peer.GetInt("rx")),
|
||||
BytesDownload: uint64(peer.GetInt("tx")),
|
||||
ImportSource: domain.ControllerTypeMikrotik,
|
||||
}
|
||||
|
||||
peerModel.SetExtras(domain.MikrotikPeerExtras{
|
||||
Id: peer.GetString(".id"),
|
||||
Name: peer.GetString("name"),
|
||||
Comment: peer.GetString("comment"),
|
||||
IsResponder: peer.GetBool("responder"),
|
||||
Disabled: peer.GetBool("disabled"),
|
||||
ClientEndpoint: peer.GetString("client-endpoint"),
|
||||
ClientAddress: peer.GetString("client-address"),
|
||||
ClientDns: peer.GetString("client-dns"),
|
||||
ClientKeepalive: clientKeepAliveSeconds,
|
||||
})
|
||||
|
||||
return peerModel, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) SaveInterface(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error {
|
||||
// Lock the interface to prevent concurrent modifications
|
||||
mutex := c.getInterfaceMutex(id)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
physicalInterface, err := c.getOrCreateInterface(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deviceId := physicalInterface.GetExtras().(domain.MikrotikInterfaceExtras).Id
|
||||
if updateFunc != nil {
|
||||
physicalInterface, err = updateFunc(physicalInterface)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newExtras := physicalInterface.GetExtras().(domain.MikrotikInterfaceExtras)
|
||||
newExtras.Id = deviceId // ensure the ID is not changed
|
||||
physicalInterface.SetExtras(newExtras)
|
||||
}
|
||||
|
||||
if err := c.updateInterface(ctx, physicalInterface); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) getOrCreateInterface(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
) (*domain.PhysicalInterface, error) {
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"name": string(id),
|
||||
},
|
||||
})
|
||||
if wgReply.Status == lowlevel.MikrotikApiStatusOk && len(wgReply.Data) > 0 {
|
||||
return c.loadInterfaceData(ctx, wgReply.Data[0])
|
||||
}
|
||||
|
||||
// create a new interface if it does not exist
|
||||
createReply := c.client.Create(ctx, "/interface/wireguard", lowlevel.GenericJsonObject{
|
||||
"name": string(id),
|
||||
})
|
||||
if wgReply.Status == lowlevel.MikrotikApiStatusOk {
|
||||
return c.loadInterfaceData(ctx, createReply.Data)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to create interface %s: %v", id, createReply.Error)
|
||||
}
|
||||
|
||||
func (c *MikrotikController) updateInterface(ctx context.Context, pi *domain.PhysicalInterface) error {
|
||||
extras := pi.GetExtras().(domain.MikrotikInterfaceExtras)
|
||||
interfaceId := extras.Id
|
||||
wgReply := c.client.Update(ctx, "/interface/wireguard/"+interfaceId, lowlevel.GenericJsonObject{
|
||||
"name": pi.Identifier,
|
||||
"comment": extras.Comment,
|
||||
"mtu": strconv.Itoa(pi.Mtu),
|
||||
"listen-port": strconv.Itoa(pi.ListenPort),
|
||||
"private-key": pi.KeyPair.PrivateKey,
|
||||
"disabled": strconv.FormatBool(!pi.DeviceUp),
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to update interface %s: %v", pi.Identifier, wgReply.Error)
|
||||
}
|
||||
|
||||
// update the interface's addresses
|
||||
currentV4, currentV6, err := c.loadIpAddresses(ctx, string(pi.Identifier))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load current addresses for interface %s: %v", pi.Identifier, err)
|
||||
}
|
||||
currentAddresses := c.convertIpAddresses(currentV4, currentV6)
|
||||
|
||||
// get all addresses that are currently not in the interface, only in pi
|
||||
newAddresses := make([]domain.Cidr, 0, len(pi.Addresses))
|
||||
for _, addr := range pi.Addresses {
|
||||
if slices.Contains(currentAddresses, addr) {
|
||||
continue
|
||||
}
|
||||
newAddresses = append(newAddresses, addr)
|
||||
}
|
||||
// get obsolete addresses that are in the interface, but not in pi
|
||||
obsoleteAddresses := make([]domain.Cidr, 0, len(currentAddresses))
|
||||
for _, addr := range currentAddresses {
|
||||
if slices.Contains(pi.Addresses, addr) {
|
||||
continue
|
||||
}
|
||||
obsoleteAddresses = append(obsoleteAddresses, addr)
|
||||
}
|
||||
|
||||
// update the IP addresses for the interface
|
||||
if err := c.updateIpAddresses(ctx, string(pi.Identifier), currentV4, currentV6,
|
||||
newAddresses, obsoleteAddresses); err != nil {
|
||||
return fmt.Errorf("failed to update IP addresses for interface %s: %v", pi.Identifier, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) updateIpAddresses(
|
||||
ctx context.Context,
|
||||
deviceName string,
|
||||
currentV4, currentV6 []lowlevel.GenericJsonObject,
|
||||
new, obsolete []domain.Cidr,
|
||||
) error {
|
||||
// first, delete all obsolete addresses
|
||||
for _, addr := range obsolete {
|
||||
// find ID of the address to delete
|
||||
if addr.IsV4() {
|
||||
for _, a := range currentV4 {
|
||||
if a.GetString("address") == addr.String() {
|
||||
// delete the address
|
||||
reply := c.client.Delete(ctx, "/ip/address/"+a.GetString(".id"))
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete obsolete IPv4 address %s: %v", addr, reply.Error)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, a := range currentV6 {
|
||||
if a.GetString("address") == addr.String() {
|
||||
// delete the address
|
||||
reply := c.client.Delete(ctx, "/ipv6/address/"+a.GetString(".id"))
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete obsolete IPv6 address %s: %v", addr, reply.Error)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// then, add all new addresses
|
||||
for _, addr := range new {
|
||||
var createPath string
|
||||
if addr.IsV4() {
|
||||
createPath = "/ip/address"
|
||||
} else {
|
||||
createPath = "/ipv6/address"
|
||||
}
|
||||
|
||||
// create the address
|
||||
reply := c.client.Create(ctx, createPath, lowlevel.GenericJsonObject{
|
||||
"address": addr.String(),
|
||||
"interface": deviceName,
|
||||
})
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to create new address %s: %v", addr, reply.Error)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
// Lock the interface to prevent concurrent modifications
|
||||
mutex := c.getInterfaceMutex(id)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
// delete the interface's addresses
|
||||
currentV4, currentV6, err := c.loadIpAddresses(ctx, string(id))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load current addresses for interface %s: %v", id, err)
|
||||
}
|
||||
for _, a := range currentV4 {
|
||||
// delete the address
|
||||
reply := c.client.Delete(ctx, "/ip/address/"+a.GetString(".id"))
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete IPv4 address %s: %v", a.GetString("address"), reply.Error)
|
||||
}
|
||||
}
|
||||
for _, a := range currentV6 {
|
||||
// delete the address
|
||||
reply := c.client.Delete(ctx, "/ipv6/address/"+a.GetString(".id"))
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete IPv6 address %s: %v", a.GetString("address"), reply.Error)
|
||||
}
|
||||
}
|
||||
|
||||
// delete the WireGuard interface
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{".id"},
|
||||
Filters: map[string]string{
|
||||
"name": string(id),
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("unable to find WireGuard interface %s: %v", id, wgReply.Error)
|
||||
}
|
||||
if len(wgReply.Data) == 0 {
|
||||
return nil // interface does not exist, nothing to delete
|
||||
}
|
||||
|
||||
interfaceId := wgReply.Data[0].GetString(".id")
|
||||
deleteReply := c.client.Delete(ctx, "/interface/wireguard/"+interfaceId)
|
||||
if deleteReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete WireGuard interface %s: %v", id, deleteReply.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) SavePeer(
|
||||
ctx context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error {
|
||||
// Lock the peer to prevent concurrent modifications
|
||||
mutex := c.getPeerMutex(id)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
physicalPeer, err := c.getOrCreatePeer(ctx, deviceId, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peerId := physicalPeer.GetExtras().(domain.MikrotikPeerExtras).Id
|
||||
physicalPeer, err = updateFunc(physicalPeer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newExtras := physicalPeer.GetExtras().(domain.MikrotikPeerExtras)
|
||||
newExtras.Id = peerId // ensure the ID is not changed
|
||||
physicalPeer.SetExtras(newExtras)
|
||||
|
||||
if err := c.updatePeer(ctx, deviceId, physicalPeer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) getOrCreatePeer(
|
||||
ctx context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
) (*domain.PhysicalPeer, error) {
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "name", "public-key", "private-key", "preshared-key", "persistent-keepalive", "client-address",
|
||||
"client-endpoint", "client-keepalive", "allowed-address", "client-dns", "comment", "disabled", "responder",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"public-key": string(id),
|
||||
"interface": string(deviceId),
|
||||
},
|
||||
})
|
||||
if wgReply.Status == lowlevel.MikrotikApiStatusOk && len(wgReply.Data) > 0 {
|
||||
slog.Debug("found existing Mikrotik peer", "peer", id, "interface", deviceId)
|
||||
existingPeer, err := c.convertWireGuardPeer(wgReply.Data[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &existingPeer, nil
|
||||
}
|
||||
|
||||
// create a new peer if it does not exist
|
||||
slog.Debug("creating new Mikrotik peer", "peer", id, "interface", deviceId)
|
||||
createReply := c.client.Create(ctx, "/interface/wireguard/peers", lowlevel.GenericJsonObject{
|
||||
"name": fmt.Sprintf("tmp-wg-%s", id[0:8]),
|
||||
"interface": string(deviceId),
|
||||
"public-key": string(id),
|
||||
"allowed-address": "0.0.0.0/0", // Use 0.0.0.0/0 as default, will be updated by updatePeer
|
||||
})
|
||||
if createReply.Status == lowlevel.MikrotikApiStatusOk {
|
||||
newPeer, err := c.convertWireGuardPeer(createReply.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug("successfully created Mikrotik peer", "peer", id, "interface", deviceId)
|
||||
return &newPeer, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to create peer %s for interface %s: %v", id, deviceId, createReply.Error)
|
||||
}
|
||||
|
||||
func (c *MikrotikController) updatePeer(
|
||||
ctx context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
pp *domain.PhysicalPeer,
|
||||
) error {
|
||||
extras := pp.GetExtras().(domain.MikrotikPeerExtras)
|
||||
peerId := extras.Id
|
||||
|
||||
endpoint := pp.Endpoint
|
||||
endpointPort := "51820" // default port if not set
|
||||
if s := strings.Split(endpoint, ":"); len(s) == 2 {
|
||||
endpoint = s[0]
|
||||
endpointPort = s[1]
|
||||
}
|
||||
|
||||
allowedAddressStr := domain.CidrsToString(pp.AllowedIPs)
|
||||
slog.Debug("updating Mikrotik peer",
|
||||
"peer", pp.Identifier,
|
||||
"interface", deviceId,
|
||||
"allowed-address", allowedAddressStr,
|
||||
"allowed-ips-count", len(pp.AllowedIPs),
|
||||
"disabled", extras.Disabled)
|
||||
|
||||
wgReply := c.client.Update(ctx, "/interface/wireguard/peers/"+peerId, lowlevel.GenericJsonObject{
|
||||
"name": extras.Name,
|
||||
"comment": extras.Comment,
|
||||
"preshared-key": pp.PresharedKey,
|
||||
"public-key": pp.KeyPair.PublicKey,
|
||||
"private-key": pp.KeyPair.PrivateKey,
|
||||
"persistent-keepalive": (time.Duration(pp.PersistentKeepalive) * time.Second).String(),
|
||||
"disabled": strconv.FormatBool(extras.Disabled),
|
||||
"responder": strconv.FormatBool(extras.IsResponder),
|
||||
"client-endpoint": extras.ClientEndpoint,
|
||||
"client-address": extras.ClientAddress,
|
||||
"client-keepalive": (time.Duration(extras.ClientKeepalive) * time.Second).String(),
|
||||
"client-dns": extras.ClientDns,
|
||||
"endpoint-address": endpoint,
|
||||
"endpoint-port": endpointPort,
|
||||
"allowed-address": allowedAddressStr, // Add the missing allowed-address field
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to update peer %s on interface %s: %v", pp.Identifier, deviceId, wgReply.Error)
|
||||
}
|
||||
|
||||
if extras.Disabled {
|
||||
slog.Debug("successfully disabled Mikrotik peer", "peer", pp.Identifier, "interface", deviceId)
|
||||
} else {
|
||||
slog.Debug("successfully updated Mikrotik peer", "peer", pp.Identifier, "interface", deviceId)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) DeletePeer(
|
||||
ctx context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
) error {
|
||||
// Lock the peer to prevent concurrent modifications
|
||||
mutex := c.getPeerMutex(id)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{".id"},
|
||||
Filters: map[string]string{
|
||||
"public-key": string(id),
|
||||
"interface": string(deviceId),
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("unable to find WireGuard peer %s for interface %s: %v", id, deviceId, wgReply.Error)
|
||||
}
|
||||
if len(wgReply.Data) == 0 {
|
||||
return nil // peer does not exist, nothing to delete
|
||||
}
|
||||
|
||||
peerId := wgReply.Data[0].GetString(".id")
|
||||
deleteReply := c.client.Delete(ctx, "/interface/wireguard/peers/"+peerId)
|
||||
if deleteReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete WireGuard peer %s for interface %s: %v", id, deviceId, deleteReply.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion wireguard-related
|
||||
|
||||
// region wg-quick-related
|
||||
|
||||
func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// endregion wg-quick-related
|
||||
|
||||
// region routing-related
|
||||
|
||||
func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// endregion routing-related
|
||||
|
||||
// region statistics-related
|
||||
|
||||
func (c *MikrotikController) PingAddresses(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
) (*domain.PingerResult, error) {
|
||||
wgReply := c.client.ExecList(ctx, "/tool/ping",
|
||||
// limit to 1 packet with a max running time of 2 seconds
|
||||
lowlevel.GenericJsonObject{"address": addr, "count": 1, "interval": "00:00:02"},
|
||||
)
|
||||
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return nil, fmt.Errorf("failed to ping %s: %v", addr, wgReply.Error)
|
||||
}
|
||||
|
||||
var result domain.PingerResult
|
||||
for _, item := range wgReply.Data {
|
||||
result.PacketsRecv += item.GetInt("received")
|
||||
result.PacketsSent += item.GetInt("sent")
|
||||
|
||||
rttStr := item.GetString("avg-rtt")
|
||||
if rttStr != "" {
|
||||
rtt, err := time.ParseDuration(rttStr)
|
||||
if err == nil {
|
||||
result.Rtts = append(result.Rtts, rtt)
|
||||
} else {
|
||||
// use a high value to indicate failure or timeout
|
||||
result.Rtts = append(result.Rtts, 999999*time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// endregion statistics-related
|
@ -21,17 +21,23 @@ import (
|
||||
//go:embed frontend_config.js.gotpl
|
||||
var frontendJs embed.FS
|
||||
|
||||
type ControllerManager interface {
|
||||
GetControllerNames() []config.BackendBase
|
||||
}
|
||||
|
||||
type ConfigEndpoint struct {
|
||||
cfg *config.Config
|
||||
authenticator Authenticator
|
||||
controllerMgr ControllerManager
|
||||
|
||||
tpl *respond.TemplateRenderer
|
||||
}
|
||||
|
||||
func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator) ConfigEndpoint {
|
||||
func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator, ctrlMgr ControllerManager) ConfigEndpoint {
|
||||
ep := ConfigEndpoint{
|
||||
cfg: cfg,
|
||||
authenticator: authenticator,
|
||||
controllerMgr: ctrlMgr,
|
||||
tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs,
|
||||
"frontend_config.js.gotpl"))),
|
||||
}
|
||||
@ -96,12 +102,35 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
sessionUser := domain.GetUserInfo(r.Context())
|
||||
|
||||
controllerFn := func() []model.SettingsBackendNames {
|
||||
controllers := e.controllerMgr.GetControllerNames()
|
||||
names := make([]model.SettingsBackendNames, 0, len(controllers))
|
||||
|
||||
for _, controller := range controllers {
|
||||
displayName := controller.GetDisplayName()
|
||||
if displayName == "" {
|
||||
displayName = controller.Id // fallback to ID if no display name is set
|
||||
}
|
||||
if controller.Id == config.LocalBackendName {
|
||||
displayName = "modals.interface-edit.backend.local" // use a localized string for the local backend
|
||||
}
|
||||
names = append(names, model.SettingsBackendNames{
|
||||
Id: controller.Id,
|
||||
Name: displayName,
|
||||
})
|
||||
}
|
||||
|
||||
return names
|
||||
|
||||
}
|
||||
|
||||
hasSocialLogin := len(e.cfg.Auth.OAuth) > 0 || len(e.cfg.Auth.OpenIDConnect) > 0 || e.cfg.Auth.WebAuthn.Enabled
|
||||
|
||||
// For anonymous users, we return the settings object with minimal information
|
||||
if sessionUser.Id == domain.CtxUnknownUserId || sessionUser.Id == "" {
|
||||
respond.JSON(w, http.StatusOK, model.Settings{
|
||||
WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled,
|
||||
AvailableBackends: []model.SettingsBackendNames{}, // return an empty list instead of null
|
||||
LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin,
|
||||
})
|
||||
} else {
|
||||
@ -112,6 +141,7 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc {
|
||||
ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly,
|
||||
WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled,
|
||||
MinPasswordLength: e.cfg.Auth.MinPasswordLength,
|
||||
AvailableBackends: controllerFn(),
|
||||
LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin,
|
||||
})
|
||||
}
|
||||
|
@ -12,5 +12,11 @@ type Settings struct {
|
||||
ApiAdminOnly bool `json:"ApiAdminOnly"`
|
||||
WebAuthnEnabled bool `json:"WebAuthnEnabled"`
|
||||
MinPasswordLength int `json:"MinPasswordLength"`
|
||||
AvailableBackends []SettingsBackendNames `json:"AvailableBackends"`
|
||||
LoginFormVisible bool `json:"LoginFormVisible"`
|
||||
}
|
||||
|
||||
type SettingsBackendNames struct {
|
||||
Id string `json:"Id"`
|
||||
Name string `json:"Name"`
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
@ -11,6 +12,7 @@ type Interface struct {
|
||||
Identifier string `json:"Identifier" example:"wg0"` // device name, for example: wg0
|
||||
DisplayName string `json:"DisplayName"` // a nice display name/ description for the interface
|
||||
Mode string `json:"Mode" example:"server"` // the interface type, either 'server', 'client' or 'any'
|
||||
Backend string `json:"Backend" example:"local"` // the backend used for this interface e.g., local, mikrotik, ...
|
||||
PrivateKey string `json:"PrivateKey" example:"abcdef=="` // private Key of the server interface
|
||||
PublicKey string `json:"PublicKey" example:"abcdef=="` // public Key of the server interface
|
||||
Disabled bool `json:"Disabled"` // flag that specifies if the interface is enabled (up) or not (down)
|
||||
@ -57,6 +59,7 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
|
||||
Identifier: string(src.Identifier),
|
||||
DisplayName: src.DisplayName,
|
||||
Mode: string(src.Type),
|
||||
Backend: string(src.Backend),
|
||||
PrivateKey: src.PrivateKey,
|
||||
PublicKey: src.PublicKey,
|
||||
Disabled: src.IsDisabled(),
|
||||
@ -92,6 +95,10 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
|
||||
Filename: src.GetConfigFileName(),
|
||||
}
|
||||
|
||||
if iface.Backend == "" {
|
||||
iface.Backend = config.LocalBackendName // default to local backend
|
||||
}
|
||||
|
||||
if len(peers) > 0 {
|
||||
iface.TotalPeers = len(peers)
|
||||
|
||||
@ -146,6 +153,7 @@ func NewDomainInterface(src *Interface) *domain.Interface {
|
||||
SaveConfig: src.SaveConfig,
|
||||
DisplayName: src.DisplayName,
|
||||
Type: domain.InterfaceType(src.Mode),
|
||||
Backend: domain.InterfaceBackend(src.Backend),
|
||||
DriverType: "", // currently unused
|
||||
Disabled: nil, // set below
|
||||
DisabledReason: src.DisabledReason,
|
||||
|
@ -46,7 +46,7 @@ func Initialize(
|
||||
users: users,
|
||||
}
|
||||
|
||||
startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
startupContext, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Switch to admin user context
|
||||
|
166
internal/app/wireguard/controller_manager.go
Normal file
166
internal/app/wireguard/controller_manager.go
Normal file
@ -0,0 +1,166 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"slices"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/adapters/wgcontroller"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type InterfaceController interface {
|
||||
GetId() domain.InterfaceBackend
|
||||
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
SaveInterface(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
PingAddresses(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
) (*domain.PingerResult, error)
|
||||
}
|
||||
|
||||
type backendInstance struct {
|
||||
Config config.BackendBase // Config is the configuration for the backend instance.
|
||||
Implementation InterfaceController
|
||||
}
|
||||
|
||||
type ControllerManager struct {
|
||||
cfg *config.Config
|
||||
controllers map[domain.InterfaceBackend]backendInstance
|
||||
}
|
||||
|
||||
func NewControllerManager(cfg *config.Config) (*ControllerManager, error) {
|
||||
c := &ControllerManager{
|
||||
cfg: cfg,
|
||||
controllers: make(map[domain.InterfaceBackend]backendInstance),
|
||||
}
|
||||
|
||||
err := c.init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) init() error {
|
||||
if err := c.registerLocalController(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.registerMikrotikControllers(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.logRegisteredControllers()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) registerLocalController() error {
|
||||
localController, err := wgcontroller.NewLocalController(c.cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create local WireGuard controller: %w", err)
|
||||
}
|
||||
|
||||
c.controllers[config.LocalBackendName] = backendInstance{
|
||||
Config: config.BackendBase{
|
||||
Id: config.LocalBackendName,
|
||||
DisplayName: "Local WireGuard Controller",
|
||||
},
|
||||
Implementation: localController,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) registerMikrotikControllers() error {
|
||||
for _, backendConfig := range c.cfg.Backend.Mikrotik {
|
||||
if backendConfig.Id == config.LocalBackendName {
|
||||
slog.Warn("skipping registration of Mikrotik controller with reserved ID", "id", config.LocalBackendName)
|
||||
continue
|
||||
}
|
||||
|
||||
controller, err := wgcontroller.NewMikrotikController(c.cfg, &backendConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Mikrotik controller for backend %s: %w", backendConfig.Id, err)
|
||||
}
|
||||
|
||||
c.controllers[domain.InterfaceBackend(backendConfig.Id)] = backendInstance{
|
||||
Config: backendConfig.BackendBase,
|
||||
Implementation: controller,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControllerManager) logRegisteredControllers() {
|
||||
for backend, controller := range c.controllers {
|
||||
slog.Debug("backend controller registered",
|
||||
"backend", backend, "type", fmt.Sprintf("%T", controller.Implementation))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController {
|
||||
return c.getController(backend, "")
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController {
|
||||
return c.getController(iface.Backend, iface.Identifier)
|
||||
}
|
||||
|
||||
func (c *ControllerManager) getController(
|
||||
backend domain.InterfaceBackend,
|
||||
ifaceId domain.InterfaceIdentifier,
|
||||
) InterfaceController {
|
||||
if backend == "" {
|
||||
// If no backend is specified, use the local controller.
|
||||
// This might be the case for interfaces created in previous WireGuard Portal versions.
|
||||
backend = config.LocalBackendName
|
||||
}
|
||||
|
||||
controller, exists := c.controllers[backend]
|
||||
if !exists {
|
||||
controller, exists = c.controllers[config.LocalBackendName] // Fallback to local controller
|
||||
if !exists {
|
||||
// If the local controller is also not found, panic
|
||||
panic(fmt.Sprintf("%s interface controller for backend %s not found", ifaceId, backend))
|
||||
}
|
||||
slog.Warn("controller for backend not found, using local controller",
|
||||
"backend", backend, "interface", ifaceId)
|
||||
}
|
||||
return controller.Implementation
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetAllControllers() []InterfaceController {
|
||||
var backendInstances = make([]InterfaceController, 0, len(c.controllers))
|
||||
for instance := range maps.Values(c.controllers) {
|
||||
backendInstances = append(backendInstances, instance.Implementation)
|
||||
}
|
||||
return backendInstances
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetControllerNames() []config.BackendBase {
|
||||
var names []config.BackendBase
|
||||
for _, id := range slices.Sorted(maps.Keys(c.controllers)) {
|
||||
names = append(names, c.controllers[id].Config)
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
@ -6,8 +6,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
probing "github.com/prometheus-community/pro-bing"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
@ -30,11 +28,6 @@ type StatisticsDatabaseRepo interface {
|
||||
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type StatisticsInterfaceController interface {
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
}
|
||||
|
||||
type StatisticsMetricsServer interface {
|
||||
UpdateInterfaceMetrics(status domain.InterfaceStatus)
|
||||
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
|
||||
@ -47,15 +40,20 @@ type StatisticsEventBus interface {
|
||||
Publish(topic string, args ...any)
|
||||
}
|
||||
|
||||
type pingJob struct {
|
||||
Peer domain.Peer
|
||||
Backend domain.InterfaceBackend
|
||||
}
|
||||
|
||||
type StatisticsCollector struct {
|
||||
cfg *config.Config
|
||||
bus StatisticsEventBus
|
||||
|
||||
pingWaitGroup sync.WaitGroup
|
||||
pingJobs chan domain.Peer
|
||||
pingJobs chan pingJob
|
||||
|
||||
db StatisticsDatabaseRepo
|
||||
wg StatisticsInterfaceController
|
||||
wg *ControllerManager
|
||||
ms StatisticsMetricsServer
|
||||
|
||||
peerChangeEvent chan domain.PeerIdentifier
|
||||
@ -66,7 +64,7 @@ func NewStatisticsCollector(
|
||||
cfg *config.Config,
|
||||
bus StatisticsEventBus,
|
||||
db StatisticsDatabaseRepo,
|
||||
wg StatisticsInterfaceController,
|
||||
wg *ControllerManager,
|
||||
ms StatisticsMetricsServer,
|
||||
) (*StatisticsCollector, error) {
|
||||
c := &StatisticsCollector{
|
||||
@ -117,7 +115,7 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
|
||||
}
|
||||
|
||||
for _, in := range interfaces {
|
||||
physicalInterface, err := c.wg.GetInterface(ctx, in.Identifier)
|
||||
physicalInterface, err := c.wg.GetController(in).GetInterface(ctx, in.Identifier)
|
||||
if err != nil {
|
||||
slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier,
|
||||
"error", err)
|
||||
@ -169,7 +167,7 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
|
||||
}
|
||||
|
||||
for _, in := range interfaces {
|
||||
peers, err := c.wg.GetPeers(ctx, in.Identifier)
|
||||
peers, err := c.wg.GetController(in).GetPeers(ctx, in.Identifier)
|
||||
if err != nil {
|
||||
slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err)
|
||||
continue
|
||||
@ -271,7 +269,7 @@ func (c *StatisticsCollector) startPingWorkers(ctx context.Context) {
|
||||
|
||||
c.pingWaitGroup = sync.WaitGroup{}
|
||||
c.pingWaitGroup.Add(c.cfg.Statistics.PingCheckWorkers)
|
||||
c.pingJobs = make(chan domain.Peer, c.cfg.Statistics.PingCheckWorkers)
|
||||
c.pingJobs = make(chan pingJob, c.cfg.Statistics.PingCheckWorkers)
|
||||
|
||||
// start workers
|
||||
for i := 0; i < c.cfg.Statistics.PingCheckWorkers; i++ {
|
||||
@ -314,7 +312,10 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
for _, peer := range peers {
|
||||
c.pingJobs <- peer
|
||||
c.pingJobs <- pingJob{
|
||||
Peer: peer,
|
||||
Backend: in.Backend,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -323,11 +324,14 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
|
||||
|
||||
func (c *StatisticsCollector) pingWorker(ctx context.Context) {
|
||||
defer c.pingWaitGroup.Done()
|
||||
for peer := range c.pingJobs {
|
||||
for job := range c.pingJobs {
|
||||
peer := job.Peer
|
||||
backend := job.Backend
|
||||
|
||||
var connectionStateChanged bool
|
||||
var newPeerStatus domain.PeerStatus
|
||||
|
||||
peerPingable := c.isPeerPingable(ctx, peer)
|
||||
peerPingable := c.isPeerPingable(ctx, backend, peer)
|
||||
slog.Debug("peer ping check completed", "peer", peer.Identifier, "pingable", peerPingable)
|
||||
|
||||
now := time.Now()
|
||||
@ -368,7 +372,11 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Peer) bool {
|
||||
func (c *StatisticsCollector) isPeerPingable(
|
||||
ctx context.Context,
|
||||
backend domain.InterfaceBackend,
|
||||
peer domain.Peer,
|
||||
) bool {
|
||||
if !c.cfg.Statistics.UsePingChecks {
|
||||
return false
|
||||
}
|
||||
@ -378,23 +386,13 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
|
||||
return false
|
||||
}
|
||||
|
||||
pinger, err := probing.NewPinger(checkAddr)
|
||||
stats, err := c.wg.GetControllerByName(backend).PingAddresses(ctx, checkAddr)
|
||||
if err != nil {
|
||||
slog.Debug("failed to instantiate pinger", "peer", peer.Identifier, "address", checkAddr, "error", err)
|
||||
slog.Debug("failed to ping peer", "peer", peer.Identifier, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
checkCount := 1
|
||||
pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged)
|
||||
pinger.Count = checkCount
|
||||
pinger.Timeout = 2 * time.Second
|
||||
err = pinger.RunWithContext(ctx) // Blocks until finished.
|
||||
if err != nil {
|
||||
slog.Debug("pinger for peer exited unexpectedly", "peer", peer.Identifier, "address", checkAddr, "error", err)
|
||||
return false
|
||||
}
|
||||
stats := pinger.Statistics()
|
||||
return stats.PacketsRecv == checkCount
|
||||
return stats.IsPingable()
|
||||
}
|
||||
|
||||
func (c *StatisticsCollector) updateInterfaceMetrics(status domain.InterfaceStatus) {
|
||||
|
@ -37,25 +37,6 @@ type InterfaceAndPeerDatabaseRepo interface {
|
||||
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
|
||||
}
|
||||
|
||||
type InterfaceController interface {
|
||||
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
SaveInterface(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
}
|
||||
|
||||
type WgQuickController interface {
|
||||
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
|
||||
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||
@ -75,7 +56,7 @@ type Manager struct {
|
||||
cfg *config.Config
|
||||
bus EventBus
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
wg InterfaceController
|
||||
wg *ControllerManager
|
||||
quick WgQuickController
|
||||
|
||||
userLockMap *sync.Map
|
||||
@ -84,7 +65,7 @@ type Manager struct {
|
||||
func NewWireGuardManager(
|
||||
cfg *config.Config,
|
||||
bus EventBus,
|
||||
wg InterfaceController,
|
||||
wg *ControllerManager,
|
||||
quick WgQuickController,
|
||||
db InterfaceAndPeerDatabaseRepo,
|
||||
) (*Manager, error) {
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/app/audit"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
@ -21,12 +22,17 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
|
||||
return nil, err
|
||||
}
|
||||
|
||||
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
|
||||
var allPhysicalInterfaces []domain.PhysicalInterface
|
||||
for _, wgBackend := range m.wg.GetAllControllers() {
|
||||
physicalInterfaces, err := wgBackend.GetInterfaces(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return physicalInterfaces, nil
|
||||
allPhysicalInterfaces = append(allPhysicalInterfaces, physicalInterfaces...)
|
||||
}
|
||||
|
||||
return allPhysicalInterfaces, nil
|
||||
}
|
||||
|
||||
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
|
||||
@ -109,7 +115,9 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
|
||||
return 0, err
|
||||
}
|
||||
|
||||
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
|
||||
imported := 0
|
||||
for _, wgBackend := range m.wg.GetAllControllers() {
|
||||
physicalInterfaces, err := wgBackend.GetInterfaces(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -126,7 +134,6 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
|
||||
}
|
||||
}
|
||||
|
||||
imported := 0
|
||||
for _, physicalInterface := range physicalInterfaces {
|
||||
if slices.Contains(excludedInterfaces, physicalInterface.Identifier) {
|
||||
continue
|
||||
@ -138,12 +145,12 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
|
||||
|
||||
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
|
||||
|
||||
physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier)
|
||||
physicalPeers, err := wgBackend.GetPeers(ctx, physicalInterface.Identifier)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = m.importInterface(ctx, &physicalInterface, physicalPeers)
|
||||
err = m.importInterface(ctx, wgBackend, &physicalInterface, physicalPeers)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
|
||||
}
|
||||
@ -151,6 +158,7 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
|
||||
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers))
|
||||
imported++
|
||||
}
|
||||
}
|
||||
|
||||
return imported, nil
|
||||
}
|
||||
@ -213,7 +221,7 @@ func (m Manager) RestoreInterfaceState(
|
||||
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
_, err = m.wg.GetInterface(ctx, iface.Identifier)
|
||||
_, err = m.wg.GetController(iface).GetInterface(ctx, iface.Identifier)
|
||||
if err != nil && !iface.IsDisabled() {
|
||||
slog.Debug("creating missing interface", "interface", iface.Identifier)
|
||||
|
||||
@ -260,18 +268,14 @@ func (m Manager) RestoreInterfaceState(
|
||||
// restore peers
|
||||
for _, peer := range peers {
|
||||
switch {
|
||||
case iface.IsDisabled(): // if interface is disabled, delete all peers
|
||||
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
|
||||
case iface.IsDisabled() && iface.Backend == config.LocalBackendName: // if interface is disabled, delete all peers
|
||||
if err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier,
|
||||
peer.Identifier); err != nil {
|
||||
return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w",
|
||||
peer.Identifier, iface.Identifier, err)
|
||||
}
|
||||
case peer.IsDisabled(): // if peer is disabled, delete it
|
||||
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
|
||||
return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w",
|
||||
peer.Identifier, iface.Identifier, err)
|
||||
}
|
||||
default: // update peer
|
||||
err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
err := m.wg.GetController(iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, &peer)
|
||||
return pp, nil
|
||||
@ -284,7 +288,7 @@ func (m Manager) RestoreInterfaceState(
|
||||
}
|
||||
|
||||
// remove non-wgportal peers
|
||||
physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier)
|
||||
physicalPeers, _ := m.wg.GetController(iface).GetPeers(ctx, iface.Identifier)
|
||||
for _, physicalPeer := range physicalPeers {
|
||||
isWgPortalPeer := false
|
||||
for _, peer := range peers {
|
||||
@ -294,7 +298,8 @@ func (m Manager) RestoreInterfaceState(
|
||||
}
|
||||
}
|
||||
if !isWgPortalPeer {
|
||||
err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey))
|
||||
err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier,
|
||||
domain.PeerIdentifier(physicalPeer.PublicKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w",
|
||||
physicalPeer.PublicKey, iface.Identifier, err)
|
||||
@ -459,7 +464,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
existingInterface.Disabled = &now // simulate a disabled interface
|
||||
existingInterface.DisabledReason = domain.DisabledReasonDeleted
|
||||
|
||||
physicalInterface, _ := m.wg.GetInterface(ctx, id)
|
||||
physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id)
|
||||
|
||||
if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
|
||||
return fmt.Errorf("pre-delete hooks failed: %w", err)
|
||||
@ -473,7 +478,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
return fmt.Errorf("peer deletion failure: %w", err)
|
||||
}
|
||||
|
||||
if err := m.wg.DeleteInterface(ctx, id); err != nil {
|
||||
if err := m.wg.GetController(*existingInterface).DeleteInterface(ctx, id); err != nil {
|
||||
return fmt.Errorf("wireguard deletion failure: %w", err)
|
||||
}
|
||||
|
||||
@ -522,7 +527,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
||||
iface.CopyCalculatedAttributes(i)
|
||||
|
||||
err := m.wg.SaveInterface(ctx, iface.Identifier,
|
||||
err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
|
||||
func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
domain.MergeToPhysicalInterface(pi, iface)
|
||||
return pi, nil
|
||||
@ -538,7 +543,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
}
|
||||
|
||||
if iface.IsDisabled() {
|
||||
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
|
||||
physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier)
|
||||
fwMark := iface.FirewallMark
|
||||
if physicalInterface != nil && fwMark == 0 {
|
||||
fwMark = physicalInterface.FirewallMark
|
||||
@ -556,13 +561,13 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
}
|
||||
|
||||
// If the interface has just been enabled, restore its peers on the physical controller
|
||||
if !oldEnabled && newEnabled {
|
||||
if !oldEnabled && newEnabled && iface.Backend == config.LocalBackendName {
|
||||
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
for _, peer := range peers {
|
||||
saveErr := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
saveErr := m.wg.GetController(*iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, &peer)
|
||||
return pp, nil
|
||||
@ -766,7 +771,12 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterface, peers []domain.PhysicalPeer) error {
|
||||
func (m Manager) importInterface(
|
||||
ctx context.Context,
|
||||
backend InterfaceController,
|
||||
in *domain.PhysicalInterface,
|
||||
peers []domain.PhysicalPeer,
|
||||
) error {
|
||||
now := time.Now()
|
||||
iface := domain.ConvertPhysicalInterface(in)
|
||||
iface.BaseModel = domain.BaseModel{
|
||||
@ -775,8 +785,20 @@ func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterfa
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
iface.Backend = backend.GetId()
|
||||
iface.PeerDefAllowedIPsStr = iface.AddressStr()
|
||||
|
||||
// try to predict the interface type based on the number of peers
|
||||
switch len(peers) {
|
||||
case 0:
|
||||
iface.Type = domain.InterfaceTypeAny // no peers means this is an unknown interface
|
||||
case 1:
|
||||
iface.Type = domain.InterfaceTypeClient // one peer means this is a client interface
|
||||
default: // multiple peers means this is a server interface
|
||||
|
||||
iface.Type = domain.InterfaceTypeServer
|
||||
}
|
||||
|
||||
existingInterface, err := m.db.GetInterface(ctx, iface.Identifier)
|
||||
if err != nil && !errors.Is(err, domain.ErrNotFound) {
|
||||
return err
|
||||
@ -827,16 +849,20 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
|
||||
peer.Interface.PreDown = domain.NewConfigOption(in.PeerDefPreDown, true)
|
||||
peer.Interface.PostDown = domain.NewConfigOption(in.PeerDefPostDown, true)
|
||||
|
||||
var displayName string
|
||||
switch in.Type {
|
||||
case domain.InterfaceTypeAny:
|
||||
peer.Interface.Type = domain.InterfaceTypeAny
|
||||
peer.DisplayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
displayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
case domain.InterfaceTypeClient:
|
||||
peer.Interface.Type = domain.InterfaceTypeServer
|
||||
peer.DisplayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
displayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
case domain.InterfaceTypeServer:
|
||||
peer.Interface.Type = domain.InterfaceTypeClient
|
||||
peer.DisplayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
displayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
|
||||
}
|
||||
if peer.DisplayName == "" {
|
||||
peer.DisplayName = displayName // use auto-generated display name if not set
|
||||
}
|
||||
|
||||
err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) {
|
||||
@ -850,12 +876,12 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
|
||||
}
|
||||
|
||||
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
allPeers, err := m.db.GetInterfacePeers(ctx, id)
|
||||
iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, peer := range allPeers {
|
||||
err = m.wg.DeletePeer(ctx, id, peer.Identifier)
|
||||
err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
@ -371,7 +371,12 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
return fmt.Errorf("delete not allowed: %w", err)
|
||||
}
|
||||
|
||||
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
|
||||
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
|
||||
}
|
||||
|
||||
err = m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
|
||||
}
|
||||
@ -433,24 +438,18 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
|
||||
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||
interfaces := make(map[domain.InterfaceIdentifier]struct{})
|
||||
|
||||
for i := range peers {
|
||||
peer := peers[i]
|
||||
var err error
|
||||
if peer.IsDisabled() || peer.IsExpired() {
|
||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
|
||||
if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
|
||||
for _, peer := range peers {
|
||||
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
})
|
||||
} else {
|
||||
// Always save the peer to the backend, regardless of disabled/expired state
|
||||
// The backend will handle the disabled state appropriately
|
||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
|
||||
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||
err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, peer)
|
||||
return pp, nil
|
||||
@ -461,7 +460,6 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||
|
||||
return peer, nil
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
94
internal/config/backend.go
Normal file
94
internal/config/backend.go
Normal file
@ -0,0 +1,94 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const LocalBackendName = "local"
|
||||
|
||||
type Backend struct {
|
||||
Default string `yaml:"default"` // The default backend to use (defaults to the internal backend)
|
||||
|
||||
Mikrotik []BackendMikrotik `yaml:"mikrotik"`
|
||||
}
|
||||
|
||||
// Validate checks the backend configuration for errors.
|
||||
func (b *Backend) Validate() error {
|
||||
if b.Default == "" {
|
||||
b.Default = LocalBackendName
|
||||
}
|
||||
|
||||
uniqueMap := make(map[string]struct{})
|
||||
for _, backend := range b.Mikrotik {
|
||||
if backend.Id == LocalBackendName {
|
||||
return fmt.Errorf("backend ID %q is a reserved keyword", LocalBackendName)
|
||||
}
|
||||
if _, exists := uniqueMap[backend.Id]; exists {
|
||||
return fmt.Errorf("backend ID %q is not unique", backend.Id)
|
||||
}
|
||||
uniqueMap[backend.Id] = struct{}{}
|
||||
}
|
||||
|
||||
if b.Default != LocalBackendName {
|
||||
if _, ok := uniqueMap[b.Default]; !ok {
|
||||
return fmt.Errorf("default backend %q is not defined in the configuration", b.Default)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type BackendBase struct {
|
||||
Id string `yaml:"id"` // A unique id for the backend
|
||||
DisplayName string `yaml:"display_name"` // A display name for the backend
|
||||
}
|
||||
|
||||
// GetDisplayName returns the display name of the backend.
|
||||
// If no display name is set, it falls back to the ID.
|
||||
func (b BackendBase) GetDisplayName() string {
|
||||
if b.DisplayName == "" {
|
||||
return b.Id // Fallback to ID if no display name is set
|
||||
}
|
||||
return b.DisplayName
|
||||
}
|
||||
|
||||
type BackendMikrotik struct {
|
||||
BackendBase `yaml:",inline"` // Embed the base fields
|
||||
|
||||
ApiUrl string `yaml:"api_url"` // The base URL of the Mikrotik API (e.g., "https://10.10.10.10:8729/rest")
|
||||
ApiUser string `yaml:"api_user"`
|
||||
ApiPassword string `yaml:"api_password"`
|
||||
ApiVerifyTls bool `yaml:"api_verify_tls"` // Whether to verify the TLS certificate of the Mikrotik API
|
||||
ApiTimeout time.Duration `yaml:"api_timeout"` // Timeout for API requests (default: 30 seconds)
|
||||
|
||||
// Concurrency controls the maximum number of concurrent API requests that this backend will issue
|
||||
// when enumerating interfaces and their details. If 0 or negative, a default of 5 is used.
|
||||
Concurrency int `yaml:"concurrency"`
|
||||
|
||||
Debug bool `yaml:"debug"` // Enable debug logging for the Mikrotik backend
|
||||
}
|
||||
|
||||
// GetConcurrency returns the configured concurrency for this backend or a sane default (5)
|
||||
// when the configured value is zero or negative.
|
||||
func (b *BackendMikrotik) GetConcurrency() int {
|
||||
if b == nil {
|
||||
return 5
|
||||
}
|
||||
if b.Concurrency <= 0 {
|
||||
return 5
|
||||
}
|
||||
return b.Concurrency
|
||||
}
|
||||
|
||||
// GetApiTimeout returns the configured API timeout or a sane default (30 seconds)
|
||||
// when the configured value is zero or negative.
|
||||
func (b *BackendMikrotik) GetApiTimeout() time.Duration {
|
||||
if b == nil {
|
||||
return 30 * time.Second
|
||||
}
|
||||
if b.ApiTimeout <= 0 {
|
||||
return 30 * time.Second
|
||||
}
|
||||
return b.ApiTimeout
|
||||
}
|
@ -44,6 +44,8 @@ type Config struct {
|
||||
LimitAdditionalUserPeers int `yaml:"limit_additional_user_peers"`
|
||||
} `yaml:"advanced"`
|
||||
|
||||
Backend Backend `yaml:"backend"`
|
||||
|
||||
Statistics struct {
|
||||
UsePingChecks bool `yaml:"use_ping_checks"`
|
||||
PingCheckWorkers int `yaml:"ping_check_workers"`
|
||||
@ -99,6 +101,12 @@ func (c *Config) LogStartupValues() {
|
||||
"minPasswordLength", c.Auth.MinPasswordLength,
|
||||
"hideLoginForm", c.Auth.HideLoginForm,
|
||||
)
|
||||
|
||||
slog.Debug("Config Backend",
|
||||
"defaultBackend", c.Backend.Default,
|
||||
"extraBackends", len(c.Backend.Mikrotik),
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
// defaultConfig returns the default configuration
|
||||
@ -122,6 +130,10 @@ func defaultConfig() *Config {
|
||||
DSN: "data/sqlite.db",
|
||||
}
|
||||
|
||||
cfg.Backend = Backend{
|
||||
Default: LocalBackendName, // local backend is the default (using wgcrtl)
|
||||
}
|
||||
|
||||
cfg.Web = WebConfig{
|
||||
RequestLogging: false,
|
||||
ExternalUrl: "http://localhost:8888",
|
||||
@ -201,6 +213,10 @@ func GetConfig() (*Config, error) {
|
||||
}
|
||||
|
||||
cfg.Web.Sanitize()
|
||||
err := cfg.Backend.Validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
32
internal/domain/controller.go
Normal file
32
internal/domain/controller.go
Normal file
@ -0,0 +1,32 @@
|
||||
package domain
|
||||
|
||||
// ControllerType defines the type of controller used to manage interfaces.
|
||||
|
||||
const (
|
||||
ControllerTypeMikrotik = "mikrotik"
|
||||
ControllerTypeLocal = "wgctrl"
|
||||
)
|
||||
|
||||
// Controller extras can be used to store additional information available for specific controllers only.
|
||||
|
||||
type MikrotikInterfaceExtras struct {
|
||||
Id string // internal mikrotik ID
|
||||
Comment string
|
||||
Disabled bool
|
||||
}
|
||||
|
||||
type MikrotikPeerExtras struct {
|
||||
Id string // internal mikrotik ID
|
||||
Name string
|
||||
Comment string
|
||||
IsResponder bool
|
||||
Disabled bool
|
||||
ClientEndpoint string
|
||||
ClientAddress string
|
||||
ClientDns string
|
||||
ClientKeepalive int
|
||||
}
|
||||
|
||||
type LocalPeerExtras struct {
|
||||
Disabled bool
|
||||
}
|
@ -10,6 +10,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
)
|
||||
|
||||
@ -23,6 +25,7 @@ var allowedFileNameRegex = regexp.MustCompile("[^a-zA-Z0-9-_]+")
|
||||
|
||||
type InterfaceIdentifier string
|
||||
type InterfaceType string
|
||||
type InterfaceBackend string
|
||||
|
||||
type Interface struct {
|
||||
BaseModel
|
||||
@ -51,6 +54,7 @@ type Interface struct {
|
||||
// WG Portal specific
|
||||
DisplayName string // a nice display name/ description for the interface
|
||||
Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient
|
||||
Backend InterfaceBackend // the backend that is used to manage the interface (wgctrl, mikrotik, ...)
|
||||
DriverType string // the interface driver type (linux, software, ...)
|
||||
Disabled *time.Time `gorm:"index"` // flag that specifies if the interface is enabled (up) or not (down)
|
||||
DisabledReason string // the reason why the interface has been disabled
|
||||
@ -204,9 +208,31 @@ type PhysicalInterface struct {
|
||||
|
||||
BytesUpload uint64
|
||||
BytesDownload uint64
|
||||
|
||||
backendExtras any // additional backend-specific extras, e.g., domain.MikrotikInterfaceExtras
|
||||
}
|
||||
|
||||
func (p *PhysicalInterface) GetExtras() any {
|
||||
return p.backendExtras
|
||||
}
|
||||
|
||||
func (p *PhysicalInterface) SetExtras(extras any) {
|
||||
switch extras.(type) {
|
||||
case MikrotikInterfaceExtras: // OK
|
||||
default: // we only support MikrotikInterfaceExtras for now
|
||||
panic(fmt.Sprintf("unsupported interface backend extras type %T", extras))
|
||||
}
|
||||
|
||||
p.backendExtras = extras
|
||||
}
|
||||
|
||||
func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface {
|
||||
networks := make([]Cidr, 0, len(pi.Addresses))
|
||||
for _, addr := range pi.Addresses {
|
||||
networks = append(networks, addr.NetworkAddr())
|
||||
}
|
||||
|
||||
// create a new basic interface with the data from the physical interface
|
||||
iface := &Interface{
|
||||
Identifier: pi.Identifier,
|
||||
KeyPair: pi.KeyPair,
|
||||
@ -226,11 +252,11 @@ func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface {
|
||||
Type: InterfaceTypeAny,
|
||||
DriverType: pi.DeviceType,
|
||||
Disabled: nil,
|
||||
PeerDefNetworkStr: "",
|
||||
PeerDefNetworkStr: CidrsToString(networks),
|
||||
PeerDefDnsStr: "",
|
||||
PeerDefDnsSearchStr: "",
|
||||
PeerDefEndpoint: "",
|
||||
PeerDefAllowedIPsStr: "",
|
||||
PeerDefAllowedIPsStr: CidrsToString(networks),
|
||||
PeerDefMtu: pi.Mtu,
|
||||
PeerDefPersistentKeepalive: 0,
|
||||
PeerDefFirewallMark: 0,
|
||||
@ -241,6 +267,23 @@ func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface {
|
||||
PeerDefPostDown: "",
|
||||
}
|
||||
|
||||
if pi.GetExtras() == nil {
|
||||
return iface
|
||||
}
|
||||
|
||||
// enrich the data with controller-specific extras
|
||||
now := time.Now()
|
||||
switch pi.ImportSource {
|
||||
case ControllerTypeMikrotik:
|
||||
extras := pi.GetExtras().(MikrotikInterfaceExtras)
|
||||
iface.DisplayName = extras.Comment
|
||||
if extras.Disabled {
|
||||
iface.Disabled = &now
|
||||
} else {
|
||||
iface.Disabled = nil
|
||||
}
|
||||
}
|
||||
|
||||
return iface
|
||||
}
|
||||
|
||||
@ -253,6 +296,15 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) {
|
||||
pi.FirewallMark = i.FirewallMark
|
||||
pi.DeviceUp = !i.IsDisabled()
|
||||
pi.Addresses = i.Addresses
|
||||
|
||||
switch pi.ImportSource {
|
||||
case ControllerTypeMikrotik:
|
||||
extras := MikrotikInterfaceExtras{
|
||||
Comment: i.DisplayName,
|
||||
Disabled: i.IsDisabled(),
|
||||
}
|
||||
pi.SetExtras(extras)
|
||||
}
|
||||
}
|
||||
|
||||
type RoutingTableInfo struct {
|
||||
@ -279,3 +331,30 @@ func (r RoutingTableInfo) GetRoutingTable() int {
|
||||
|
||||
return r.Table
|
||||
}
|
||||
|
||||
type IpFamily int
|
||||
|
||||
const (
|
||||
IpFamilyIPv4 IpFamily = unix.AF_INET
|
||||
IpFamilyIPv6 IpFamily = unix.AF_INET6
|
||||
)
|
||||
|
||||
func (f IpFamily) String() string {
|
||||
switch f {
|
||||
case IpFamilyIPv4:
|
||||
return "IPv4"
|
||||
case IpFamilyIPv6:
|
||||
return "IPv6"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// RouteRule represents a routing table rule.
|
||||
type RouteRule struct {
|
||||
InterfaceId InterfaceIdentifier
|
||||
IpFamily IpFamily
|
||||
FwMark uint32
|
||||
Table int
|
||||
HasDefault bool
|
||||
}
|
||||
|
@ -129,7 +129,7 @@ func (p *Peer) GenerateDisplayName(prefix string) {
|
||||
p.DisplayName = fmt.Sprintf("%sPeer %s", prefix, internal.TruncateString(string(p.Identifier), 8))
|
||||
}
|
||||
|
||||
// OverwriteUserEditableFields overwrites the user editable fields of the peer with the values from the userPeer
|
||||
// OverwriteUserEditableFields overwrites the user-editable fields of the peer with the values from the userPeer
|
||||
func (p *Peer) OverwriteUserEditableFields(userPeer *Peer, cfg *config.Config) {
|
||||
p.DisplayName = userPeer.DisplayName
|
||||
if cfg.Core.EditableKeys {
|
||||
@ -182,9 +182,12 @@ type PhysicalPeer struct {
|
||||
|
||||
BytesUpload uint64 // upload bytes are the number of bytes that the remote peer has sent to the server
|
||||
BytesDownload uint64 // upload bytes are the number of bytes that the remote peer has received from the server
|
||||
|
||||
ImportSource string // import source (wgctrl, file, ...)
|
||||
backendExtras any // additional backend-specific extras, e.g., domain.MikrotikPeerExtras
|
||||
}
|
||||
|
||||
func (p PhysicalPeer) GetPresharedKey() *wgtypes.Key {
|
||||
func (p *PhysicalPeer) GetPresharedKey() *wgtypes.Key {
|
||||
if p.PresharedKey == "" {
|
||||
return nil
|
||||
}
|
||||
@ -196,7 +199,7 @@ func (p PhysicalPeer) GetPresharedKey() *wgtypes.Key {
|
||||
return &key
|
||||
}
|
||||
|
||||
func (p PhysicalPeer) GetEndpointAddress() *net.UDPAddr {
|
||||
func (p *PhysicalPeer) GetEndpointAddress() *net.UDPAddr {
|
||||
if p.Endpoint == "" {
|
||||
return nil
|
||||
}
|
||||
@ -208,7 +211,7 @@ func (p PhysicalPeer) GetEndpointAddress() *net.UDPAddr {
|
||||
return addr
|
||||
}
|
||||
|
||||
func (p PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration {
|
||||
func (p *PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration {
|
||||
if p.PersistentKeepalive == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -217,7 +220,7 @@ func (p PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration {
|
||||
return &keepAliveDuration
|
||||
}
|
||||
|
||||
func (p PhysicalPeer) GetAllowedIPs() []net.IPNet {
|
||||
func (p *PhysicalPeer) GetAllowedIPs() []net.IPNet {
|
||||
allowedIPs := make([]net.IPNet, len(p.AllowedIPs))
|
||||
for i, ip := range p.AllowedIPs {
|
||||
allowedIPs[i] = *ip.IpNet()
|
||||
@ -226,6 +229,21 @@ func (p PhysicalPeer) GetAllowedIPs() []net.IPNet {
|
||||
return allowedIPs
|
||||
}
|
||||
|
||||
func (p *PhysicalPeer) GetExtras() any {
|
||||
return p.backendExtras
|
||||
}
|
||||
|
||||
func (p *PhysicalPeer) SetExtras(extras any) {
|
||||
switch extras.(type) {
|
||||
case MikrotikPeerExtras: // OK
|
||||
case LocalPeerExtras: // OK
|
||||
default: // we only support MikrotikPeerExtras and LocalPeerExtras for now
|
||||
panic(fmt.Sprintf("unsupported peer backend extras type %T", extras))
|
||||
}
|
||||
|
||||
p.backendExtras = extras
|
||||
}
|
||||
|
||||
func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
|
||||
peer := &Peer{
|
||||
Endpoint: NewConfigOption(pp.Endpoint, true),
|
||||
@ -244,6 +262,44 @@ func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
|
||||
},
|
||||
}
|
||||
|
||||
if pp.GetExtras() == nil {
|
||||
return peer
|
||||
}
|
||||
|
||||
// enrich the data with controller-specific extras
|
||||
now := time.Now()
|
||||
switch pp.ImportSource {
|
||||
case ControllerTypeMikrotik:
|
||||
extras := pp.GetExtras().(MikrotikPeerExtras)
|
||||
peer.Notes = extras.Comment
|
||||
peer.DisplayName = extras.Name
|
||||
if extras.ClientEndpoint != "" { // if the client endpoint is set, we assume that this is a client peer
|
||||
peer.Endpoint = NewConfigOption(extras.ClientEndpoint, true)
|
||||
peer.Interface.Type = InterfaceTypeClient
|
||||
peer.Interface.Addresses, _ = CidrsFromString(extras.ClientAddress)
|
||||
peer.Interface.DnsStr = NewConfigOption(extras.ClientDns, true)
|
||||
peer.PersistentKeepalive = NewConfigOption(extras.ClientKeepalive, true)
|
||||
} else {
|
||||
peer.Interface.Type = InterfaceTypeServer
|
||||
}
|
||||
if extras.Disabled {
|
||||
peer.Disabled = &now
|
||||
peer.DisabledReason = "Disabled by Mikrotik controller"
|
||||
} else {
|
||||
peer.Disabled = nil
|
||||
peer.DisabledReason = ""
|
||||
}
|
||||
case ControllerTypeLocal:
|
||||
extras := pp.GetExtras().(LocalPeerExtras)
|
||||
if extras.Disabled {
|
||||
peer.Disabled = &now
|
||||
peer.DisabledReason = "Disabled by Local controller"
|
||||
} else {
|
||||
peer.Disabled = nil
|
||||
peer.DisabledReason = ""
|
||||
}
|
||||
}
|
||||
|
||||
return peer
|
||||
}
|
||||
|
||||
@ -265,6 +321,27 @@ func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) {
|
||||
pp.PresharedKey = p.PresharedKey
|
||||
pp.PublicKey = p.Interface.PublicKey
|
||||
pp.PersistentKeepalive = p.PersistentKeepalive.GetValue()
|
||||
|
||||
switch pp.ImportSource {
|
||||
case ControllerTypeMikrotik:
|
||||
extras := MikrotikPeerExtras{
|
||||
Id: "",
|
||||
Name: p.DisplayName,
|
||||
Comment: p.Notes,
|
||||
IsResponder: false,
|
||||
Disabled: p.IsDisabled(),
|
||||
ClientEndpoint: p.Endpoint.GetValue(),
|
||||
ClientAddress: CidrsToString(p.Interface.Addresses),
|
||||
ClientDns: p.Interface.DnsStr.GetValue(),
|
||||
ClientKeepalive: p.PersistentKeepalive.GetValue(),
|
||||
}
|
||||
pp.SetExtras(extras)
|
||||
case ControllerTypeLocal:
|
||||
extras := LocalPeerExtras{
|
||||
Disabled: p.IsDisabled(),
|
||||
}
|
||||
pp.SetExtras(extras)
|
||||
}
|
||||
}
|
||||
|
||||
type PeerCreationRequest struct {
|
||||
|
@ -1,6 +1,8 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type PeerStatus struct {
|
||||
PeerId PeerIdentifier `gorm:"primaryKey;column:identifier" json:"PeerId"`
|
||||
@ -37,3 +39,25 @@ type InterfaceStatus struct {
|
||||
BytesReceived uint64 `gorm:"column:received"`
|
||||
BytesTransmitted uint64 `gorm:"column:transmitted"`
|
||||
}
|
||||
|
||||
type PingerResult struct {
|
||||
PacketsRecv int
|
||||
PacketsSent int
|
||||
Rtts []time.Duration
|
||||
}
|
||||
|
||||
func (r PingerResult) IsPingable() bool {
|
||||
return r.PacketsRecv > 0 && r.PacketsSent > 0 && len(r.Rtts) > 0
|
||||
}
|
||||
|
||||
func (r PingerResult) AverageRtt() time.Duration {
|
||||
if len(r.Rtts) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var total time.Duration
|
||||
for _, rtt := range r.Rtts {
|
||||
total += rtt
|
||||
}
|
||||
return total / time.Duration(len(r.Rtts))
|
||||
}
|
||||
|
@ -12,8 +12,8 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SetupLogging initializes the global logger with the given level and format
|
||||
func SetupLogging(level string, pretty, json bool) {
|
||||
// GetLoggingHandler initializes a slog.Handler based on the provided logging level and format options.
|
||||
func GetLoggingHandler(level string, pretty, json bool) slog.Handler {
|
||||
var logLevel = new(slog.LevelVar)
|
||||
|
||||
switch strings.ToLower(level) {
|
||||
@ -46,6 +46,13 @@ func SetupLogging(level string, pretty, json bool) {
|
||||
handler = slog.NewTextHandler(output, opts)
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
// SetupLogging initializes the global logger with the given level and format
|
||||
func SetupLogging(level string, pretty, json bool) {
|
||||
handler := GetLoggingHandler(level, pretty, json)
|
||||
|
||||
logger := slog.New(handler)
|
||||
|
||||
slog.SetDefault(logger)
|
||||
|
435
internal/lowlevel/mikrotik.go
Normal file
435
internal/lowlevel/mikrotik.go
Normal file
@ -0,0 +1,435 @@
|
||||
package lowlevel
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
)
|
||||
|
||||
// region models
|
||||
|
||||
const (
|
||||
MikrotikApiStatusOk = "success"
|
||||
MikrotikApiStatusError = "error"
|
||||
)
|
||||
|
||||
const (
|
||||
MikrotikApiErrorCodeUnknown = iota + 600
|
||||
MikrotikApiErrorCodeRequestPreparationFailed
|
||||
MikrotikApiErrorCodeRequestFailed
|
||||
MikrotikApiErrorCodeResponseDecodeFailed
|
||||
)
|
||||
|
||||
type MikrotikApiResponse[T any] struct {
|
||||
Status string
|
||||
Code int
|
||||
Data T `json:"data,omitempty"`
|
||||
Error *MikrotikApiError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type MikrotikApiError struct {
|
||||
Code int `json:"error,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Details string `json:"detail,omitempty"`
|
||||
}
|
||||
|
||||
func (e *MikrotikApiError) String() string {
|
||||
if e == nil {
|
||||
return "no error"
|
||||
}
|
||||
return fmt.Sprintf("API error %d: %s - %s", e.Code, e.Message, e.Details)
|
||||
}
|
||||
|
||||
type GenericJsonObject map[string]any
|
||||
type EmptyResponse struct{}
|
||||
|
||||
func (JsonObject GenericJsonObject) GetString(key string) string {
|
||||
if value, ok := JsonObject[key]; ok {
|
||||
if strValue, ok := value.(string); ok {
|
||||
return strValue
|
||||
} else {
|
||||
return fmt.Sprintf("%v", value) // Convert to string if not already
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (JsonObject GenericJsonObject) GetInt(key string) int {
|
||||
if value, ok := JsonObject[key]; ok {
|
||||
if intValue, ok := value.(int); ok {
|
||||
return intValue
|
||||
} else {
|
||||
if floatValue, ok := value.(float64); ok {
|
||||
return int(floatValue) // Convert float64 to int
|
||||
}
|
||||
if strValue, ok := value.(string); ok {
|
||||
if intValue, err := strconv.Atoi(strValue); err == nil {
|
||||
return intValue // Convert string to int if possible
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (JsonObject GenericJsonObject) GetBool(key string) bool {
|
||||
if value, ok := JsonObject[key]; ok {
|
||||
if boolValue, ok := value.(bool); ok {
|
||||
return boolValue
|
||||
} else {
|
||||
if intValue, ok := value.(int); ok {
|
||||
return intValue == 1 // Convert int to bool (1 is true, 0 is false)
|
||||
}
|
||||
if floatValue, ok := value.(float64); ok {
|
||||
return int(floatValue) == 1 // Convert float64 to bool (1.0 is true, 0.0 is false)
|
||||
}
|
||||
if strValue, ok := value.(string); ok {
|
||||
boolValue, err := strconv.ParseBool(strValue)
|
||||
if err == nil {
|
||||
return boolValue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type MikrotikRequestOptions struct {
|
||||
Filters map[string]string `json:"filters,omitempty"`
|
||||
PropList []string `json:"proplist,omitempty"`
|
||||
}
|
||||
|
||||
func (o *MikrotikRequestOptions) GetPath(base string) string {
|
||||
if o == nil {
|
||||
return base
|
||||
}
|
||||
|
||||
path, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return base
|
||||
}
|
||||
|
||||
query := path.Query()
|
||||
for k, v := range o.Filters {
|
||||
query.Set(k, v)
|
||||
}
|
||||
if len(o.PropList) > 0 {
|
||||
query.Set(".proplist", strings.Join(o.PropList, ","))
|
||||
}
|
||||
path.RawQuery = query.Encode()
|
||||
return path.String()
|
||||
}
|
||||
|
||||
// region models
|
||||
|
||||
// region API-client
|
||||
|
||||
type MikrotikApiClient struct {
|
||||
coreCfg *config.Config
|
||||
cfg *config.BackendMikrotik
|
||||
|
||||
client *http.Client
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
func NewMikrotikApiClient(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikApiClient, error) {
|
||||
c := &MikrotikApiClient{
|
||||
coreCfg: coreCfg,
|
||||
cfg: cfg,
|
||||
}
|
||||
|
||||
err := c.setup()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.debugLog("Mikrotik api client created", "api_url", cfg.ApiUrl)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) setup() error {
|
||||
m.client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: !m.cfg.ApiVerifyTls,
|
||||
},
|
||||
},
|
||||
Timeout: m.cfg.GetApiTimeout(),
|
||||
}
|
||||
|
||||
if m.cfg.Debug {
|
||||
m.log = slog.New(internal.GetLoggingHandler("debug",
|
||||
m.coreCfg.Advanced.LogPretty,
|
||||
m.coreCfg.Advanced.LogJson).
|
||||
WithAttrs([]slog.Attr{
|
||||
{
|
||||
Key: "mikrotik-bid", Value: slog.StringValue(m.cfg.Id),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) debugLog(msg string, args ...any) {
|
||||
if m.log != nil {
|
||||
m.log.Debug("[MT-API] "+msg, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) getFullPath(command string) string {
|
||||
path, err := url.JoinPath(m.cfg.ApiUrl, command)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) prepareGetRequest(ctx context.Context, fullUrl string) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullUrl, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" {
|
||||
req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) prepareDeleteRequest(ctx context.Context, fullUrl string) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, fullUrl, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" {
|
||||
req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) preparePayloadRequest(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
fullUrl string,
|
||||
payload GenericJsonObject,
|
||||
) (*http.Request, error) {
|
||||
// marshal the payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, fullUrl, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" {
|
||||
req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func errToApiResponse[T any](code int, message string, err error) MikrotikApiResponse[T] {
|
||||
return MikrotikApiResponse[T]{
|
||||
Status: MikrotikApiStatusError,
|
||||
Code: code,
|
||||
Error: &MikrotikApiError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: err.Error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func parseHttpResponse[T any](resp *http.Response, err error) MikrotikApiResponse[T] {
|
||||
if err != nil {
|
||||
return errToApiResponse[T](MikrotikApiErrorCodeRequestFailed, "failed to execute request", err)
|
||||
}
|
||||
|
||||
defer func(Body io.ReadCloser) {
|
||||
err := Body.Close()
|
||||
if err != nil {
|
||||
slog.Error("failed to close response body", "error", err)
|
||||
}
|
||||
}(resp.Body)
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
var data T
|
||||
|
||||
// if the type of T is EmptyResponse, we can return an empty response with just the status
|
||||
if _, ok := any(data).(EmptyResponse); ok {
|
||||
return MikrotikApiResponse[T]{Status: MikrotikApiStatusOk, Code: resp.StatusCode}
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return errToApiResponse[T](MikrotikApiErrorCodeResponseDecodeFailed, "failed to decode response", err)
|
||||
}
|
||||
return MikrotikApiResponse[T]{Status: MikrotikApiStatusOk, Code: resp.StatusCode, Data: data}
|
||||
}
|
||||
|
||||
var apiErr MikrotikApiError
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
|
||||
return errToApiResponse[T](resp.StatusCode, "unknown error, unparsable response", err)
|
||||
} else {
|
||||
return MikrotikApiResponse[T]{Status: MikrotikApiStatusError, Code: resp.StatusCode, Error: &apiErr}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) Query(
|
||||
ctx context.Context,
|
||||
command string,
|
||||
opts *MikrotikRequestOptions,
|
||||
) MikrotikApiResponse[[]GenericJsonObject] {
|
||||
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
|
||||
defer cancel()
|
||||
|
||||
fullUrl := opts.GetPath(m.getFullPath(command))
|
||||
|
||||
req, err := m.prepareGetRequest(apiCtx, fullUrl)
|
||||
if err != nil {
|
||||
return errToApiResponse[[]GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
|
||||
"failed to create request", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
m.debugLog("executing API query", "url", fullUrl)
|
||||
response := parseHttpResponse[[]GenericJsonObject](m.client.Do(req))
|
||||
m.debugLog("retrieved API query result", "url", fullUrl, "duration", time.Since(start).String())
|
||||
return response
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) Get(
|
||||
ctx context.Context,
|
||||
command string,
|
||||
opts *MikrotikRequestOptions,
|
||||
) MikrotikApiResponse[GenericJsonObject] {
|
||||
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
|
||||
defer cancel()
|
||||
|
||||
fullUrl := opts.GetPath(m.getFullPath(command))
|
||||
|
||||
req, err := m.prepareGetRequest(apiCtx, fullUrl)
|
||||
if err != nil {
|
||||
return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
|
||||
"failed to create request", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
m.debugLog("executing API get", "url", fullUrl)
|
||||
response := parseHttpResponse[GenericJsonObject](m.client.Do(req))
|
||||
m.debugLog("retrieved API get result", "url", fullUrl, "duration", time.Since(start).String())
|
||||
return response
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) Create(
|
||||
ctx context.Context,
|
||||
command string,
|
||||
payload GenericJsonObject,
|
||||
) MikrotikApiResponse[GenericJsonObject] {
|
||||
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
|
||||
defer cancel()
|
||||
|
||||
fullUrl := m.getFullPath(command)
|
||||
|
||||
req, err := m.preparePayloadRequest(apiCtx, http.MethodPut, fullUrl, payload)
|
||||
if err != nil {
|
||||
return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
|
||||
"failed to create request", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
m.debugLog("executing API put", "url", fullUrl)
|
||||
response := parseHttpResponse[GenericJsonObject](m.client.Do(req))
|
||||
m.debugLog("retrieved API put result", "url", fullUrl, "duration", time.Since(start).String())
|
||||
return response
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) Update(
|
||||
ctx context.Context,
|
||||
command string,
|
||||
payload GenericJsonObject,
|
||||
) MikrotikApiResponse[GenericJsonObject] {
|
||||
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
|
||||
defer cancel()
|
||||
|
||||
fullUrl := m.getFullPath(command)
|
||||
|
||||
req, err := m.preparePayloadRequest(apiCtx, http.MethodPatch, fullUrl, payload)
|
||||
if err != nil {
|
||||
return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
|
||||
"failed to create request", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
m.debugLog("executing API patch", "url", fullUrl)
|
||||
response := parseHttpResponse[GenericJsonObject](m.client.Do(req))
|
||||
m.debugLog("retrieved API patch result", "url", fullUrl, "duration", time.Since(start).String())
|
||||
return response
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) Delete(
|
||||
ctx context.Context,
|
||||
command string,
|
||||
) MikrotikApiResponse[EmptyResponse] {
|
||||
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
|
||||
defer cancel()
|
||||
|
||||
fullUrl := m.getFullPath(command)
|
||||
|
||||
req, err := m.prepareDeleteRequest(apiCtx, fullUrl)
|
||||
if err != nil {
|
||||
return errToApiResponse[EmptyResponse](MikrotikApiErrorCodeRequestPreparationFailed,
|
||||
"failed to create request", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
m.debugLog("executing API delete", "url", fullUrl)
|
||||
response := parseHttpResponse[EmptyResponse](m.client.Do(req))
|
||||
m.debugLog("retrieved API delete result", "url", fullUrl, "duration", time.Since(start).String())
|
||||
return response
|
||||
}
|
||||
|
||||
func (m *MikrotikApiClient) ExecList(
|
||||
ctx context.Context,
|
||||
command string,
|
||||
payload GenericJsonObject,
|
||||
) MikrotikApiResponse[[]GenericJsonObject] {
|
||||
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
|
||||
defer cancel()
|
||||
|
||||
fullUrl := m.getFullPath(command)
|
||||
|
||||
req, err := m.preparePayloadRequest(apiCtx, http.MethodPost, fullUrl, payload)
|
||||
if err != nil {
|
||||
return errToApiResponse[[]GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
|
||||
"failed to create request", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
m.debugLog("executing API post", "url", fullUrl)
|
||||
response := parseHttpResponse[[]GenericJsonObject](m.client.Do(req))
|
||||
m.debugLog("retrieved API post result", "url", fullUrl, "duration", time.Since(start).String())
|
||||
return response
|
||||
}
|
||||
|
||||
// endregion API-client
|
@ -80,6 +80,7 @@ nav:
|
||||
- Examples: documentation/configuration/examples.md
|
||||
- Usage:
|
||||
- General: documentation/usage/general.md
|
||||
- Backends: documentation/usage/backends.md
|
||||
- LDAP: documentation/usage/ldap.md
|
||||
- Security: documentation/usage/security.md
|
||||
- Webhooks: documentation/usage/webhooks.md
|
||||
|
Loading…
x
Reference in New Issue
Block a user