Merge branch 'master' into stable

# Conflicts:
#	internal/domain/peer.go
This commit is contained in:
Christoph Haas
2025-10-19 13:25:07 +02:00
137 changed files with 10275 additions and 1996 deletions

View File

@@ -220,6 +220,8 @@ func (r *SqlRepo) preCheck() error {
func (r *SqlRepo) migrate() error {
slog.Debug("running migration: sys-stat", "result", r.db.AutoMigrate(&SysStat{}))
slog.Debug("running migration: user", "result", r.db.AutoMigrate(&domain.User{}))
slog.Debug("running migration: user webauthn credentials", "result",
r.db.AutoMigrate(&domain.UserWebauthnCredential{}))
slog.Debug("running migration: interface", "result", r.db.AutoMigrate(&domain.Interface{}))
slog.Debug("running migration: peer", "result", r.db.AutoMigrate(&domain.Peer{}))
slog.Debug("running migration: peer status", "result", r.db.AutoMigrate(&domain.PeerStatus{}))
@@ -746,7 +748,7 @@ func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr
func (r *SqlRepo) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
var user domain.User
err := r.db.WithContext(ctx).First(&user, id).Error
err := r.db.WithContext(ctx).Preload("WebAuthnCredentialList").First(&user, id).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil, domain.ErrNotFound
@@ -764,7 +766,7 @@ func (r *SqlRepo) GetUser(ctx context.Context, id domain.UserIdentifier) (*domai
func (r *SqlRepo) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) {
var users []domain.User
err := r.db.WithContext(ctx).Where("email = ?", email).Find(&users).Error
err := r.db.WithContext(ctx).Where("email = ?", email).Preload("WebAuthnCredentialList").Find(&users).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil, domain.ErrNotFound
}
@@ -785,11 +787,26 @@ func (r *SqlRepo) GetUserByEmail(ctx context.Context, email string) (*domain.Use
return &user, nil
}
// GetUserByWebAuthnCredential returns the user with the given webauthn credential id.
func (r *SqlRepo) GetUserByWebAuthnCredential(ctx context.Context, credentialIdBase64 string) (*domain.User, error) {
var credential domain.UserWebauthnCredential
err := r.db.WithContext(ctx).Where("credential_identifier = ?", credentialIdBase64).First(&credential).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil, domain.ErrNotFound
}
if err != nil {
return nil, err
}
return r.GetUser(ctx, domain.UserIdentifier(credential.UserIdentifier))
}
// GetAllUsers returns all users.
func (r *SqlRepo) GetAllUsers(ctx context.Context) ([]domain.User, error) {
var users []domain.User
err := r.db.WithContext(ctx).Find(&users).Error
err := r.db.WithContext(ctx).Preload("WebAuthnCredentialList").Find(&users).Error
if err != nil {
return nil, err
}
@@ -808,6 +825,7 @@ func (r *SqlRepo) FindUsers(ctx context.Context, search string) ([]domain.User,
Or("firstname LIKE ?", searchValue).
Or("lastname LIKE ?", searchValue).
Or("email LIKE ?", searchValue).
Preload("WebAuthnCredentialList").
Find(&users).Error
if err != nil {
return nil, err
@@ -853,7 +871,7 @@ func (r *SqlRepo) SaveUser(
// DeleteUser deletes the user with the given id.
func (r *SqlRepo) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
err := r.db.WithContext(ctx).Delete(&domain.User{}, id).Error
err := r.db.WithContext(ctx).Unscoped().Select(clause.Associations).Delete(&domain.User{Identifier: id}).Error
if err != nil {
return err
}
@@ -897,6 +915,11 @@ func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *doma
return err
}
err = tx.Session(&gorm.Session{FullSaveAssociations: true}).Unscoped().Model(user).Association("WebAuthnCredentialList").Unscoped().Replace(user.WebAuthnCredentialList)
if err != nil {
return fmt.Errorf("failed to update users webauthn credentials: %w", err)
}
return nil
}

View File

@@ -133,5 +133,5 @@ func (m *MetricsServer) UpdatePeerMetrics(peer *domain.Peer, status domain.PeerS
}
m.peerReceivedBytesTotal.WithLabelValues(labels...).Set(float64(status.BytesReceived))
m.peerSendBytesTotal.WithLabelValues(labels...).Set(float64(status.BytesTransmitted))
m.peerIsConnected.WithLabelValues(labels...).Set(internal.BoolToFloat64(status.IsConnected()))
m.peerIsConnected.WithLabelValues(labels...).Set(internal.BoolToFloat64(status.IsConnected))
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,113 +0,0 @@
package adapters
import (
"bytes"
"fmt"
"log/slog"
"os/exec"
"strings"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/domain"
)
// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks.
type WgQuickRepo struct {
shellCmd string
resolvConfIfacePrefix string
}
// NewWgQuickRepo creates a new WgQuickRepo instance.
func NewWgQuickRepo() *WgQuickRepo {
return &WgQuickRepo{
shellCmd: "bash",
resolvConfIfacePrefix: "tun.",
}
}
// ExecuteInterfaceHook executes the given hook command.
// The hook command can contain the following placeholders:
//
// %i: the interface identifier.
func (r *WgQuickRepo) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
if hookCmd == "" {
return nil
}
slog.Debug("executing interface hook", "interface", id, "hook", hookCmd)
err := r.exec(hookCmd, id)
if err != nil {
return fmt.Errorf("failed to exec hook: %w", err)
}
return nil
}
// SetDNS sets the DNS settings for the given interface. It uses resolvconf to set the DNS settings.
func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
if dnsStr == "" && dnsSearchStr == "" {
return nil
}
dnsServers := internal.SliceString(dnsStr)
dnsSearchDomains := internal.SliceString(dnsSearchStr)
dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
for _, dnsServer := range dnsServers {
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
}
for _, searchDomain := range dnsSearchDomains {
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
}
err := r.exec(dnsCommand, id, dnsCommandInput...)
if err != nil {
return fmt.Errorf(
"failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w",
err,
)
}
return nil
}
// UnsetDNS unsets the DNS settings for the given interface. It uses resolvconf to unset the DNS settings.
func (r *WgQuickRepo) UnsetDNS(id domain.InterfaceIdentifier) error {
dnsCommand := "resolvconf -d %resPref%i -f"
err := r.exec(dnsCommand, id)
if err != nil {
return fmt.Errorf("failed to unset dns settings: %w", err)
}
return nil
}
func (r *WgQuickRepo) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
command = strings.ReplaceAll(command, "%resPref", r.resolvConfIfacePrefix)
return strings.ReplaceAll(command, "%i", string(interfaceId))
}
func (r *WgQuickRepo) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
commandWithInterfaceName := r.replaceCommandPlaceHolders(command, interfaceId)
cmd := exec.Command(r.shellCmd, "-ce", commandWithInterfaceName)
if len(stdin) > 0 {
b := &bytes.Buffer{}
for _, ln := range stdin {
if _, err := fmt.Fprint(b, ln); err != nil {
return err
}
}
cmd.Stdin = b
}
out, err := cmd.CombinedOutput() // execute and wait for output
if err != nil {
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
}
slog.Debug("executed shell command",
"command", commandWithInterfaceName,
"output", string(out))
return nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"github.com/vishvananda/netlink"
@@ -16,8 +17,9 @@ import (
// WgRepo implements all low-level WireGuard interactions.
type WgRepo struct {
wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient
wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient
log *slog.Logger
}
// NewWireGuardRepository creates a new WgRepo instance.
@@ -31,8 +33,9 @@ func NewWireGuardRepository() *WgRepo {
nl := &lowlevel.NetlinkManager{}
repo := &WgRepo{
wg: wg,
nl: nl,
wg: wg,
nl: nl,
log: slog.Default().With(slog.String("adapter", "wireguard")),
}
return repo
@@ -40,8 +43,10 @@ func NewWireGuardRepository() *WgRepo {
// GetInterfaces returns all existing WireGuard interfaces.
func (r *WgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
r.log.Debug("getting all interfaces")
devices, err := r.wg.Devices()
if err != nil {
r.log.Error("failed to get devices", "error", err)
return nil, fmt.Errorf("device list error: %w", err)
}
@@ -60,14 +65,17 @@ func (r *WgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, e
// GetInterface returns the interface with the given id.
// If no interface is found, an error os.ErrNotExist is returned.
func (r *WgRepo) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
r.log.Debug("getting interface", "id", id)
return r.getInterface(id)
}
// GetPeers returns all peers associated with the given interface id.
// If the requested interface is found, an error os.ErrNotExist is returned.
func (r *WgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) {
r.log.Debug("getting peers for interface", "deviceId", deviceId)
device, err := r.wg.Device(string(deviceId))
if err != nil {
r.log.Error("failed to get device", "deviceId", deviceId, "error", err)
return nil, fmt.Errorf("device error: %w", err)
}
@@ -90,6 +98,7 @@ func (r *WgRepo) GetPeer(
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
) (*domain.PhysicalPeer, error) {
r.log.Debug("getting peer", "deviceId", deviceId, "peerId", id)
return r.getPeer(deviceId, id)
}
@@ -174,25 +183,31 @@ func (r *WgRepo) SaveInterface(
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error {
r.log.Debug("saving interface", "id", id)
physicalInterface, err := r.getOrCreateInterface(id)
if err != nil {
r.log.Error("failed to get or create interface", "id", id, "error", err)
return err
}
if updateFunc != nil {
physicalInterface, err = updateFunc(physicalInterface)
if err != nil {
r.log.Error("interface update function failed", "id", id, "error", err)
return err
}
}
if err := r.updateLowLevelInterface(physicalInterface); err != nil {
r.log.Error("failed to update low level interface", "id", id, "error", err)
return err
}
if err := r.updateWireGuardInterface(physicalInterface); err != nil {
r.log.Error("failed to update wireguard interface", "id", id, "error", err)
return err
}
r.log.Debug("successfully saved interface", "id", id)
return nil
}
@@ -323,10 +338,13 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
// DeleteInterface deletes the interface with the given id.
// If the requested interface is found, no error is returned.
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
r.log.Debug("deleting interface", "id", id)
if err := r.deleteLowLevelInterface(id); err != nil {
r.log.Error("failed to delete low level interface", "id", id, "error", err)
return err
}
r.log.Debug("successfully deleted interface", "id", id)
return nil
}
@@ -356,20 +374,25 @@ func (r *WgRepo) SavePeer(
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error {
r.log.Debug("saving peer", "deviceId", deviceId, "peerId", id)
physicalPeer, err := r.getOrCreatePeer(deviceId, id)
if err != nil {
r.log.Error("failed to get or create peer", "deviceId", deviceId, "peerId", id, "error", err)
return err
}
physicalPeer, err = updateFunc(physicalPeer)
if err != nil {
r.log.Error("peer update function failed", "deviceId", deviceId, "peerId", id, "error", err)
return err
}
if err := r.updatePeer(deviceId, physicalPeer); err != nil {
r.log.Error("failed to update peer", "deviceId", deviceId, "peerId", id, "error", err)
return err
}
r.log.Debug("successfully saved peer", "deviceId", deviceId, "peerId", id)
return nil
}
@@ -441,6 +464,7 @@ func (r *WgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.Phys
err := r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
if err != nil {
r.log.Error("failed to configure device for peer update", "deviceId", deviceId, "peerId", pp.Identifier, "error", err)
return err
}
@@ -450,15 +474,20 @@ func (r *WgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.Phys
// DeletePeer deletes the peer with the given id.
// If the requested interface or peer is found, no error is returned.
func (r *WgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
r.log.Debug("deleting peer", "deviceId", deviceId, "peerId", id)
if !id.IsPublicKey() {
return errors.New("invalid public key")
err := errors.New("invalid public key")
r.log.Error("invalid peer id", "peerId", id, "error", err)
return err
}
err := r.deletePeer(deviceId, id)
if err != nil {
r.log.Error("failed to delete peer", "deviceId", deviceId, "peerId", id, "error", err)
return err
}
r.log.Debug("successfully deleted peer", "deviceId", deviceId, "peerId", id)
return nil
}
@@ -470,6 +499,7 @@ func (r *WgRepo) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerI
err := r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
if err != nil {
r.log.Error("failed to configure device for peer deletion", "deviceId", deviceId, "peerId", id, "error", err)
return err
}

View File

@@ -175,30 +175,78 @@
}
}
},
"/auth/{provider}/callback": {
"get": {
"/auth/webauthn/credential/{id}": {
"put": {
"produces": [
"application/json"
],
"tags": [
"Authentication"
],
"summary": "Handle the OAuth callback.",
"operationId": "auth_handleOauthCallbackGet",
"summary": "Update a WebAuthn credential.",
"operationId": "auth_handleWebAuthnCredentialsPut",
"parameters": [
{
"type": "string",
"description": "Base64 encoded Credential ID",
"name": "id",
"in": "path",
"required": true
},
{
"description": "Credential name",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/model.WebAuthnCredentialRequest"
}
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/model.LoginProviderInfo"
"$ref": "#/definitions/model.WebAuthnCredentialResponse"
}
}
}
}
},
"delete": {
"produces": [
"application/json"
],
"tags": [
"Authentication"
],
"summary": "Delete a WebAuthn credential.",
"operationId": "auth_handleWebAuthnCredentialsDelete",
"parameters": [
{
"type": "string",
"description": "Base64 encoded Credential ID",
"name": "id",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/model.WebAuthnCredentialResponse"
}
}
}
}
}
},
"/auth/{provider}/init": {
"/auth/webauthn/credentials": {
"get": {
"produces": [
"application/json"
@@ -206,15 +254,67 @@
"tags": [
"Authentication"
],
"summary": "Initiate the OAuth login flow.",
"operationId": "auth_handleOauthInitiateGet",
"summary": "Get all available external login providers.",
"operationId": "auth_handleWebAuthnCredentialsGet",
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/model.LoginProviderInfo"
"$ref": "#/definitions/model.WebAuthnCredentialResponse"
}
}
}
}
}
},
"/auth/webauthn/login/finish": {
"post": {
"produces": [
"application/json"
],
"tags": [
"Authentication"
],
"summary": "Finish the WebAuthn login process.",
"operationId": "auth_handleWebAuthnLoginFinish",
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/model.User"
}
}
}
}
},
"/auth/webauthn/register/finish": {
"post": {
"produces": [
"application/json"
],
"tags": [
"Authentication"
],
"summary": "Finish the WebAuthn registration process.",
"operationId": "auth_handleWebAuthnRegisterFinish",
"parameters": [
{
"type": "string",
"default": "\"\"",
"description": "Credential name",
"name": "credential_name",
"in": "query"
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/model.WebAuthnCredentialResponse"
}
}
}
@@ -719,6 +819,12 @@
"schema": {
"$ref": "#/definitions/model.PeerMailRequest"
}
},
{
"type": "string",
"description": "The configuration style",
"name": "style",
"in": "query"
}
],
"responses": {
@@ -758,6 +864,12 @@
"name": "id",
"in": "path",
"required": true
},
{
"type": "string",
"description": "The configuration style",
"name": "style",
"in": "query"
}
],
"responses": {
@@ -799,6 +911,12 @@
"name": "id",
"in": "path",
"required": true
},
{
"type": "string",
"description": "The configuration style",
"name": "style",
"in": "query"
}
],
"responses": {
@@ -1432,6 +1550,38 @@
}
}
},
"/user/{id}/change-password": {
"post": {
"produces": [
"application/json"
],
"tags": [
"Users"
],
"summary": "Change the password for the given user.",
"operationId": "users_handleChangePasswordPost",
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/model.User"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/model.Error"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/model.Error"
}
}
}
}
},
"/user/{id}/interfaces": {
"get": {
"produces": [
@@ -1663,6 +1813,11 @@
"type": "string"
}
},
"Backend": {
"description": "the backend used for this interface e.g., local, mikrotik, ...",
"type": "string",
"example": "local"
},
"Disabled": {
"description": "flag that specifies if the interface is enabled (up) or not (down)",
"type": "boolean"
@@ -2036,6 +2191,10 @@
}
]
},
"UserDisplayName": {
"description": "the owner display name",
"type": "string"
},
"UserIdentifier": {
"description": "the owner",
"type": "string"
@@ -2131,14 +2290,40 @@
"ApiAdminOnly": {
"type": "boolean"
},
"AvailableBackends": {
"type": "array",
"items": {
"$ref": "#/definitions/model.SettingsBackendNames"
}
},
"LoginFormVisible": {
"type": "boolean"
},
"MailLinkOnly": {
"type": "boolean"
},
"MinPasswordLength": {
"type": "integer"
},
"PersistentConfigSupported": {
"type": "boolean"
},
"SelfProvisioning": {
"type": "boolean"
},
"WebAuthnEnabled": {
"type": "boolean"
}
}
},
"model.SettingsBackendNames": {
"type": "object",
"properties": {
"Id": {
"type": "string"
},
"Name": {
"type": "string"
}
}
},
@@ -2207,6 +2392,28 @@
"type": "string"
}
}
},
"model.WebAuthnCredentialRequest": {
"type": "object",
"properties": {
"Name": {
"type": "string"
}
}
},
"model.WebAuthnCredentialResponse": {
"type": "object",
"properties": {
"CreatedAt": {
"type": "string"
},
"ID": {
"type": "string"
},
"Name": {
"type": "string"
}
}
}
}
}

View File

@@ -65,6 +65,10 @@ definitions:
items:
type: string
type: array
Backend:
description: the backend used for this interface e.g., local, mikrotik, ...
example: local
type: string
Disabled:
description: flag that specifies if the interface is enabled (up) or not (down)
type: boolean
@@ -318,6 +322,9 @@ definitions:
allOf:
- $ref: '#/definitions/model.ConfigOption-string'
description: the routing table
UserDisplayName:
description: the owner display name
type: string
UserIdentifier:
description: the owner
type: string
@@ -381,6 +388,12 @@ definitions:
properties:
ApiAdminOnly:
type: boolean
AvailableBackends:
items:
$ref: '#/definitions/model.SettingsBackendNames'
type: array
LoginFormVisible:
type: boolean
MailLinkOnly:
type: boolean
MinPasswordLength:
@@ -389,6 +402,15 @@ definitions:
type: boolean
SelfProvisioning:
type: boolean
WebAuthnEnabled:
type: boolean
type: object
model.SettingsBackendNames:
properties:
Id:
type: string
Name:
type: string
type: object
model.User:
properties:
@@ -435,6 +457,20 @@ definitions:
Source:
type: string
type: object
model.WebAuthnCredentialRequest:
properties:
Name:
type: string
type: object
model.WebAuthnCredentialResponse:
properties:
CreatedAt:
type: string
ID:
type: string
Name:
type: string
type: object
info:
contact:
name: WireGuard Portal Developers
@@ -550,6 +586,102 @@ paths:
summary: Get information about the currently logged-in user.
tags:
- Authentication
/auth/webauthn/credential/{id}:
delete:
operationId: auth_handleWebAuthnCredentialsDelete
parameters:
- description: Base64 encoded Credential ID
in: path
name: id
required: true
type: string
produces:
- application/json
responses:
"200":
description: OK
schema:
items:
$ref: '#/definitions/model.WebAuthnCredentialResponse'
type: array
summary: Delete a WebAuthn credential.
tags:
- Authentication
put:
operationId: auth_handleWebAuthnCredentialsPut
parameters:
- description: Base64 encoded Credential ID
in: path
name: id
required: true
type: string
- description: Credential name
in: body
name: request
required: true
schema:
$ref: '#/definitions/model.WebAuthnCredentialRequest'
produces:
- application/json
responses:
"200":
description: OK
schema:
items:
$ref: '#/definitions/model.WebAuthnCredentialResponse'
type: array
summary: Update a WebAuthn credential.
tags:
- Authentication
/auth/webauthn/credentials:
get:
operationId: auth_handleWebAuthnCredentialsGet
produces:
- application/json
responses:
"200":
description: OK
schema:
items:
$ref: '#/definitions/model.WebAuthnCredentialResponse'
type: array
summary: Get all available external login providers.
tags:
- Authentication
/auth/webauthn/login/finish:
post:
operationId: auth_handleWebAuthnLoginFinish
produces:
- application/json
responses:
"200":
description: OK
schema:
$ref: '#/definitions/model.User'
summary: Finish the WebAuthn login process.
tags:
- Authentication
/auth/webauthn/register/finish:
post:
operationId: auth_handleWebAuthnRegisterFinish
parameters:
- default: '""'
description: Credential name
in: query
name: credential_name
type: string
produces:
- application/json
responses:
"200":
description: OK
schema:
items:
$ref: '#/definitions/model.WebAuthnCredentialResponse'
type: array
summary: Finish the WebAuthn registration process.
tags:
- Authentication
/config/frontend.js:
get:
operationId: config_handleConfigJsGet
@@ -958,6 +1090,10 @@ paths:
required: true
schema:
$ref: '#/definitions/model.PeerMailRequest'
- description: The configuration style
in: query
name: style
type: string
produces:
- application/json
responses:
@@ -983,6 +1119,10 @@ paths:
name: id
required: true
type: string
- description: The configuration style
in: query
name: style
type: string
produces:
- image/png
- application/json
@@ -1011,6 +1151,10 @@ paths:
name: id
required: true
type: string
- description: The configuration style
in: query
name: style
type: string
produces:
- application/json
responses:
@@ -1301,6 +1445,27 @@ paths:
summary: Enable the REST API for the given user.
tags:
- Users
/user/{id}/change-password:
post:
operationId: users_handleChangePasswordPost
produces:
- application/json
responses:
"200":
description: OK
schema:
$ref: '#/definitions/model.User'
"400":
description: Bad Request
schema:
$ref: '#/definitions/model.Error'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/model.Error'
summary: Change the password for the given user.
tags:
- Users
/user/{id}/interfaces:
get:
operationId: users_handleInterfacesGet

View File

@@ -17,11 +17,6 @@
"paths": {
"/interface/all": {
"get": {
"security": [
{
"BasicAuth": []
}
],
"produces": [
"application/json"
],
@@ -52,16 +47,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/interface/by-id/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/interface/by-id/{id}": {
"get": {
"produces": [
"application/json"
],
@@ -110,14 +105,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"put": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"put": {
"description": "This endpoint updates an existing interface with the provided data. All required fields must be filled (e.g. name, private key, public key, ...).",
"produces": [
"application/json"
@@ -182,14 +177,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"delete": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"delete": {
"produces": [
"application/json"
],
@@ -241,16 +236,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/interface/new": {
"post": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/interface/new": {
"post": {
"description": "This endpoint creates a new interface with the provided data. All required fields must be filled (e.g. name, private key, public key, ...).",
"produces": [
"application/json"
@@ -308,16 +303,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/interface/prepare": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/interface/prepare": {
"get": {
"description": "This endpoint returns a new interface with default values (fresh key pair, valid name, new IP address pool, ...).",
"produces": [
"application/json"
@@ -352,16 +347,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/metrics/by-interface/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/metrics/by-interface/{id}": {
"get": {
"produces": [
"application/json"
],
@@ -410,16 +405,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/metrics/by-peer/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/metrics/by-peer/{id}": {
"get": {
"produces": [
"application/json"
],
@@ -468,16 +463,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/metrics/by-user/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/metrics/by-user/{id}": {
"get": {
"produces": [
"application/json"
],
@@ -526,16 +521,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/peer/by-id/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/peer/by-id/{id}": {
"get": {
"description": "Normal users can only access their own records. Admins can access all records.",
"produces": [
"application/json"
@@ -585,14 +580,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"put": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"put": {
"description": "Only admins can update existing records. The peer record must contain all required fields (e.g., public key, allowed IPs).",
"produces": [
"application/json"
@@ -657,14 +652,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"delete": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"delete": {
"produces": [
"application/json"
],
@@ -716,16 +711,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/peer/by-interface/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/peer/by-interface/{id}": {
"get": {
"produces": [
"application/json"
],
@@ -765,16 +760,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/peer/by-user/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/peer/by-user/{id}": {
"get": {
"description": "Normal users can only access their own records. Admins can access all records.",
"produces": [
"application/json"
@@ -815,16 +810,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/peer/new": {
"post": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/peer/new": {
"post": {
"description": "Only admins can create new records. The peer record must contain all required fields (e.g., public key, allowed IPs).",
"produces": [
"application/json"
@@ -882,16 +877,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/peer/prepare/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/peer/prepare/{id}": {
"get": {
"description": "This endpoint is used to prepare a new peer record. The returned data contains a fresh key pair and valid ip address.",
"produces": [
"application/json"
@@ -947,16 +942,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/provisioning/data/peer-config": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/provisioning/data/peer-config": {
"get": {
"description": "Normal users can only access their own record. Admins can access all records.",
"produces": [
"text/plain",
@@ -1013,16 +1008,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/provisioning/data/peer-qr": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/provisioning/data/peer-qr": {
"get": {
"description": "Normal users can only access their own record. Admins can access all records.",
"produces": [
"image/png",
@@ -1079,16 +1074,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/provisioning/data/user-info": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/provisioning/data/user-info": {
"get": {
"description": "Normal users can only access their own record. Admins can access all records.",
"produces": [
"application/json"
@@ -1149,16 +1144,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/provisioning/new-peer": {
"post": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/provisioning/new-peer": {
"post": {
"description": "Normal users can only create new peers if self provisioning is allowed. Admins can always add new peers.",
"produces": [
"application/json"
@@ -1216,16 +1211,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/user/all": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/user/all": {
"get": {
"produces": [
"application/json"
],
@@ -1256,16 +1251,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/user/by-id/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/user/by-id/{id}": {
"get": {
"description": "Normal users can only access their own record. Admins can access all records.",
"produces": [
"application/json"
@@ -1315,14 +1310,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"put": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"put": {
"description": "Only admins can update existing records.",
"produces": [
"application/json"
@@ -1387,14 +1382,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"delete": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"delete": {
"produces": [
"application/json"
],
@@ -1446,16 +1441,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/user/new": {
"post": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/user/new": {
"post": {
"description": "Only admins can create new records.",
"produces": [
"application/json"
@@ -1513,7 +1508,12 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"security": [
{
"BasicAuth": []
}
]
}
}
},
@@ -2086,6 +2086,11 @@
"InterfaceIdentifier"
],
"properties": {
"DisplayName": {
"description": "DisplayName is an optional name for the new peer.\nIf unset, a default template value (e.g., \"API Peer ...\") will be assigned.",
"type": "string",
"example": "API Peer xyz"
},
"InterfaceIdentifier": {
"description": "InterfaceIdentifier is the identifier of the WireGuard interface the peer should be linked to.",
"type": "string",

View File

@@ -445,6 +445,12 @@ definitions:
type: object
models.ProvisioningRequest:
properties:
DisplayName:
description: |-
DisplayName is an optional name for the new peer.
If unset, a default template value (e.g., "API Peer ...") will be assigned.
example: API Peer xyz
type: string
InterfaceIdentifier:
description: InterfaceIdentifier is the identifier of the WireGuard interface
the peer should be linked to.

View File

@@ -100,6 +100,7 @@ func (s *Server) Run(ctx context.Context, listenAddress string) {
srvContext, cancelFn := context.WithCancel(ctx)
go func() {
var err error
slog.Debug("starting server", "certFile", s.cfg.Web.CertFile, "keyFile", s.cfg.Web.KeyFile)
if s.cfg.Web.CertFile != "" && s.cfg.Web.KeyFile != "" {
err = srv.ListenAndServeTLS(s.cfg.Web.CertFile, s.cfg.Web.KeyFile)
} else {

View File

@@ -27,12 +27,12 @@ type PeerServicePeerManager interface {
}
type PeerServiceConfigFileManager interface {
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier, style string) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier, style string) (io.Reader, error)
}
type PeerServiceMailManager interface {
SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error
SendPeerEmail(ctx context.Context, linkOnly bool, style string, peers ...domain.PeerIdentifier) error
}
// endregion dependencies
@@ -95,16 +95,24 @@ func (p PeerService) DeletePeer(ctx context.Context, id domain.PeerIdentifier) e
return p.peers.DeletePeer(ctx, id)
}
func (p PeerService) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
return p.configFile.GetPeerConfig(ctx, id)
func (p PeerService) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier, style string) (io.Reader, error) {
return p.configFile.GetPeerConfig(ctx, id, style)
}
func (p PeerService) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
return p.configFile.GetPeerConfigQrCode(ctx, id)
func (p PeerService) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier, style string) (
io.Reader,
error,
) {
return p.configFile.GetPeerConfigQrCode(ctx, id, style)
}
func (p PeerService) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error {
return p.mailer.SendPeerEmail(ctx, linkOnly, peers...)
func (p PeerService) SendPeerEmail(
ctx context.Context,
linkOnly bool,
style string,
peers ...domain.PeerIdentifier,
) error {
return p.mailer.SendPeerEmail(ctx, linkOnly, style, peers...)
}
func (p PeerService) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) {

View File

@@ -2,6 +2,8 @@ package backend
import (
"context"
"fmt"
"strings"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
@@ -70,6 +72,44 @@ func (u UserService) DeactivateApi(ctx context.Context, id domain.UserIdentifier
return u.users.DeactivateApi(ctx, id)
}
func (u UserService) ChangePassword(ctx context.Context, id domain.UserIdentifier, oldPassword, newPassword string) (*domain.User, error) {
oldPassword = strings.TrimSpace(oldPassword)
newPassword = strings.TrimSpace(newPassword)
if newPassword == "" {
return nil, fmt.Errorf("new password must not be empty")
}
// ensure that the new password is different from the old one
if oldPassword == newPassword {
return nil, fmt.Errorf("new password must be different from the old one")
}
user, err := u.users.GetUser(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
// ensure that the user uses the database backend; otherwise we can't change the password
if user.Source != domain.UserSourceDatabase {
return nil, fmt.Errorf("user source %s does not support password changes", user.Source)
}
// validate old password
if user.CheckPassword(oldPassword) != nil {
return nil, fmt.Errorf("current password is invalid")
}
user.Password = domain.PrivateString(newPassword)
// ensure that the new password is strong enough
if err := user.HasWeakPassword(u.cfg.Auth.MinPasswordLength); err != nil {
return nil, err
}
return u.users.UpdateUser(ctx, user)
}
func (u UserService) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
return u.wg.GetUserPeers(ctx, id)
}

View File

@@ -99,6 +99,8 @@ type Authenticator interface {
LoggedIn(scopes ...Scope) func(next http.Handler) http.Handler
// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted.
UserIdMatch(idParameter string) func(next http.Handler) http.Handler
// InfoOnly only add user info to the request context. No login check is performed.
InfoOnly() func(next http.Handler) http.Handler
}
type Session interface {

View File

@@ -29,12 +29,54 @@ type AuthenticationService interface {
OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error)
}
type WebAuthnService interface {
Enabled() bool
StartWebAuthnRegistration(ctx context.Context, userId domain.UserIdentifier) (
responseOptions []byte,
sessionData []byte,
err error,
)
FinishWebAuthnRegistration(
ctx context.Context,
userId domain.UserIdentifier,
name string,
sessionDataAsJSON []byte,
r *http.Request,
) ([]domain.UserWebauthnCredential, error)
GetCredentials(
ctx context.Context,
userId domain.UserIdentifier,
) ([]domain.UserWebauthnCredential, error)
RemoveCredential(
ctx context.Context,
userId domain.UserIdentifier,
credentialIdBase64 string,
) ([]domain.UserWebauthnCredential, error)
UpdateCredential(
ctx context.Context,
userId domain.UserIdentifier,
credentialIdBase64 string,
name string,
) ([]domain.UserWebauthnCredential, error)
StartWebAuthnLogin(_ context.Context) (
optionsAsJSON []byte,
sessionDataAsJSON []byte,
err error,
)
FinishWebAuthnLogin(
ctx context.Context,
sessionDataAsJSON []byte,
r *http.Request,
) (*domain.User, error)
}
type AuthEndpoint struct {
cfg *config.Config
authService AuthenticationService
authenticator Authenticator
session Session
validate Validator
webAuthn WebAuthnService
}
func NewAuthEndpoint(
@@ -43,6 +85,7 @@ func NewAuthEndpoint(
session Session,
validator Validator,
authService AuthenticationService,
webAuthn WebAuthnService,
) AuthEndpoint {
return AuthEndpoint{
cfg: cfg,
@@ -50,6 +93,7 @@ func NewAuthEndpoint(
authenticator: authenticator,
session: session,
validate: validator,
webAuthn: webAuthn,
}
}
@@ -66,6 +110,19 @@ func (e AuthEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup.HandleFunc("GET /login/{provider}/init", e.handleOauthInitiateGet())
apiGroup.HandleFunc("GET /login/{provider}/callback", e.handleOauthCallbackGet())
apiGroup.HandleFunc("POST /webauthn/login/start", e.handleWebAuthnLoginStart())
apiGroup.HandleFunc("POST /webauthn/login/finish", e.handleWebAuthnLoginFinish())
apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("GET /webauthn/credentials",
e.handleWebAuthnCredentialsGet())
apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("POST /webauthn/register/start",
e.handleWebAuthnRegisterStart())
apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("POST /webauthn/register/finish",
e.handleWebAuthnRegisterFinish())
apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("DELETE /webauthn/credential/{id}",
e.handleWebAuthnCredentialsDelete())
apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("PUT /webauthn/credential/{id}",
e.handleWebAuthnCredentialsPut())
apiGroup.HandleFunc("POST /login", e.handleLoginPost())
apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("POST /logout", e.handleLogoutPost())
}
@@ -398,3 +455,237 @@ func (e AuthEndpoint) isValidReturnUrl(returnUrl string) bool {
return true
}
// handleWebAuthnCredentialsGet returns a gorm Handler function.
//
// @ID auth_handleWebAuthnCredentialsGet
// @Tags Authentication
// @Summary Get all available external login providers.
// @Produce json
// @Success 200 {object} []model.WebAuthnCredentialResponse
// @Router /auth/webauthn/credentials [get]
func (e AuthEndpoint) handleWebAuthnCredentialsGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !e.webAuthn.Enabled() {
respond.JSON(w, http.StatusOK, []model.WebAuthnCredentialResponse{})
return
}
currentSession := e.session.GetData(r.Context())
userIdentifier := domain.UserIdentifier(currentSession.UserIdentifier)
credentials, err := e.webAuthn.GetCredentials(r.Context(), userIdentifier)
if err != nil {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
respond.JSON(w, http.StatusOK, model.NewWebAuthnCredentialResponses(credentials))
}
}
// handleWebAuthnCredentialsDelete returns a gorm Handler function.
//
// @ID auth_handleWebAuthnCredentialsDelete
// @Tags Authentication
// @Summary Delete a WebAuthn credential.
// @Param id path string true "Base64 encoded Credential ID"
// @Produce json
// @Success 200 {object} []model.WebAuthnCredentialResponse
// @Router /auth/webauthn/credential/{id} [delete]
func (e AuthEndpoint) handleWebAuthnCredentialsDelete() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !e.webAuthn.Enabled() {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "WebAuthn is not enabled"})
return
}
currentSession := e.session.GetData(r.Context())
userIdentifier := domain.UserIdentifier(currentSession.UserIdentifier)
credentialId := Base64UrlDecode(request.Path(r, "id"))
credentials, err := e.webAuthn.RemoveCredential(r.Context(), userIdentifier, credentialId)
if err != nil {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
respond.JSON(w, http.StatusOK, model.NewWebAuthnCredentialResponses(credentials))
}
}
// handleWebAuthnCredentialsPut returns a gorm Handler function.
//
// @ID auth_handleWebAuthnCredentialsPut
// @Tags Authentication
// @Summary Update a WebAuthn credential.
// @Param id path string true "Base64 encoded Credential ID"
// @Param request body model.WebAuthnCredentialRequest true "Credential name"
// @Produce json
// @Success 200 {object} []model.WebAuthnCredentialResponse
// @Router /auth/webauthn/credential/{id} [put]
func (e AuthEndpoint) handleWebAuthnCredentialsPut() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !e.webAuthn.Enabled() {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "WebAuthn is not enabled"})
return
}
currentSession := e.session.GetData(r.Context())
userIdentifier := domain.UserIdentifier(currentSession.UserIdentifier)
credentialId := Base64UrlDecode(request.Path(r, "id"))
var req model.WebAuthnCredentialRequest
if err := request.BodyJson(r, &req); err != nil {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
credentials, err := e.webAuthn.UpdateCredential(r.Context(), userIdentifier, credentialId, req.Name)
if err != nil {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
respond.JSON(w, http.StatusOK, model.NewWebAuthnCredentialResponses(credentials))
}
}
func (e AuthEndpoint) handleWebAuthnRegisterStart() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !e.webAuthn.Enabled() {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "WebAuthn is not enabled"})
return
}
currentSession := e.session.GetData(r.Context())
userIdentifier := domain.UserIdentifier(currentSession.UserIdentifier)
options, sessionData, err := e.webAuthn.StartWebAuthnRegistration(r.Context(), userIdentifier)
if err != nil {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
currentSession.WebAuthnData = string(sessionData)
e.session.SetData(r.Context(), currentSession)
respond.Data(w, http.StatusOK, "application/json", options)
}
}
// handleWebAuthnRegisterFinish returns a gorm Handler function.
//
// @ID auth_handleWebAuthnRegisterFinish
// @Tags Authentication
// @Summary Finish the WebAuthn registration process.
// @Param credential_name query string false "Credential name" default("")
// @Produce json
// @Success 200 {object} []model.WebAuthnCredentialResponse
// @Router /auth/webauthn/register/finish [post]
func (e AuthEndpoint) handleWebAuthnRegisterFinish() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !e.webAuthn.Enabled() {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "WebAuthn is not enabled"})
return
}
name := request.QueryDefault(r, "credential_name", "")
currentSession := e.session.GetData(r.Context())
webAuthnSessionData := []byte(currentSession.WebAuthnData)
currentSession.WebAuthnData = "" // clear the session data
e.session.SetData(r.Context(), currentSession)
credentials, err := e.webAuthn.FinishWebAuthnRegistration(
r.Context(),
domain.UserIdentifier(currentSession.UserIdentifier),
name,
webAuthnSessionData,
r)
if err != nil {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
respond.JSON(w, http.StatusOK, model.NewWebAuthnCredentialResponses(credentials))
}
}
func (e AuthEndpoint) handleWebAuthnLoginStart() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !e.webAuthn.Enabled() {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "WebAuthn is not enabled"})
return
}
currentSession := e.session.GetData(r.Context())
options, sessionData, err := e.webAuthn.StartWebAuthnLogin(r.Context())
if err != nil {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
currentSession.WebAuthnData = string(sessionData)
e.session.SetData(r.Context(), currentSession)
respond.Data(w, http.StatusOK, "application/json", options)
}
}
// handleWebAuthnLoginFinish returns a gorm Handler function.
//
// @ID auth_handleWebAuthnLoginFinish
// @Tags Authentication
// @Summary Finish the WebAuthn login process.
// @Produce json
// @Success 200 {object} model.User
// @Router /auth/webauthn/login/finish [post]
func (e AuthEndpoint) handleWebAuthnLoginFinish() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !e.webAuthn.Enabled() {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "WebAuthn is not enabled"})
return
}
currentSession := e.session.GetData(r.Context())
webAuthnSessionData := []byte(currentSession.WebAuthnData)
currentSession.WebAuthnData = "" // clear the session data
e.session.SetData(r.Context(), currentSession)
user, err := e.webAuthn.FinishWebAuthnLogin(
r.Context(),
webAuthnSessionData,
r)
if err != nil {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
e.setAuthenticatedUser(r, user)
respond.JSON(w, http.StatusOK, model.NewUser(user, false))
}
}

View File

@@ -15,22 +15,29 @@ import (
"github.com/h44z/wg-portal/internal/app/api/core/respond"
"github.com/h44z/wg-portal/internal/app/api/v0/model"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
//go:embed frontend_config.js.gotpl
var frontendJs embed.FS
type ControllerManager interface {
GetControllerNames() []config.BackendBase
}
type ConfigEndpoint struct {
cfg *config.Config
authenticator Authenticator
controllerMgr ControllerManager
tpl *respond.TemplateRenderer
}
func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator) ConfigEndpoint {
func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator, ctrlMgr ControllerManager) ConfigEndpoint {
ep := ConfigEndpoint{
cfg: cfg,
authenticator: authenticator,
controllerMgr: ctrlMgr,
tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs,
"frontend_config.js.gotpl"))),
}
@@ -46,7 +53,7 @@ func (e ConfigEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup := g.Mount("/config")
apiGroup.HandleFunc("GET /frontend.js", e.handleConfigJsGet())
apiGroup.With(e.authenticator.LoggedIn()).HandleFunc("GET /settings", e.handleSettingsGet())
apiGroup.With(e.authenticator.InfoOnly()).HandleFunc("GET /settings", e.handleSettingsGet())
}
// handleConfigJsGet returns a gorm Handler function.
@@ -93,11 +100,50 @@ func (e ConfigEndpoint) handleConfigJsGet() http.HandlerFunc {
// @Router /config/settings [get]
func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
respond.JSON(w, http.StatusOK, model.Settings{
MailLinkOnly: e.cfg.Mail.LinkOnly,
PersistentConfigSupported: e.cfg.Advanced.ConfigStoragePath != "",
SelfProvisioning: e.cfg.Core.SelfProvisioningAllowed,
ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly,
})
sessionUser := domain.GetUserInfo(r.Context())
controllerFn := func() []model.SettingsBackendNames {
controllers := e.controllerMgr.GetControllerNames()
names := make([]model.SettingsBackendNames, 0, len(controllers))
for _, controller := range controllers {
displayName := controller.GetDisplayName()
if displayName == "" {
displayName = controller.Id // fallback to ID if no display name is set
}
if controller.Id == config.LocalBackendName {
displayName = "modals.interface-edit.backend.local" // use a localized string for the local backend
}
names = append(names, model.SettingsBackendNames{
Id: controller.Id,
Name: displayName,
})
}
return names
}
hasSocialLogin := len(e.cfg.Auth.OAuth) > 0 || len(e.cfg.Auth.OpenIDConnect) > 0 || e.cfg.Auth.WebAuthn.Enabled
// For anonymous users, we return the settings object with minimal information
if sessionUser.Id == domain.CtxUnknownUserId || sessionUser.Id == "" {
respond.JSON(w, http.StatusOK, model.Settings{
WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled,
AvailableBackends: []model.SettingsBackendNames{}, // return an empty list instead of null
LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin,
})
} else {
respond.JSON(w, http.StatusOK, model.Settings{
MailLinkOnly: e.cfg.Mail.LinkOnly,
PersistentConfigSupported: e.cfg.Advanced.ConfigStoragePath != "",
SelfProvisioning: e.cfg.Core.SelfProvisioningAllowed,
ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly,
WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled,
MinPasswordLength: e.cfg.Auth.MinPasswordLength,
AvailableBackends: controllerFn(),
LoginFormVisible: !e.cfg.Auth.HideLoginForm || !hasSocialLogin,
})
}
}
}

View File

@@ -34,11 +34,11 @@ type PeerService interface {
// DeletePeer deletes the peer with the given id.
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
// GetPeerConfig returns the peer configuration for the given id.
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier, style string) (io.Reader, error)
// GetPeerConfigQrCode returns the peer configuration as qr code for the given id.
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier, style string) (io.Reader, error)
// SendPeerEmail sends the peer configuration via email.
SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error
SendPeerEmail(ctx context.Context, linkOnly bool, style string, peers ...domain.PeerIdentifier) error
// GetPeerStats returns the peer stats for the given interface.
GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error)
}
@@ -355,6 +355,7 @@ func (e PeerEndpoint) handleDelete() http.HandlerFunc {
// @Summary Get peer configuration as string.
// @Produce json
// @Param id path string true "The peer identifier"
// @Param style query string false "The configuration style"
// @Success 200 {object} string
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
@@ -369,7 +370,9 @@ func (e PeerEndpoint) handleConfigGet() http.HandlerFunc {
return
}
configTxt, err := e.peerService.GetPeerConfig(r.Context(), domain.PeerIdentifier(id))
configStyle := e.getConfigStyle(r)
configTxt, err := e.peerService.GetPeerConfig(r.Context(), domain.PeerIdentifier(id), configStyle)
if err != nil {
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
@@ -397,6 +400,7 @@ func (e PeerEndpoint) handleConfigGet() http.HandlerFunc {
// @Produce png
// @Produce json
// @Param id path string true "The peer identifier"
// @Param style query string false "The configuration style"
// @Success 200 {file} binary
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
@@ -411,7 +415,9 @@ func (e PeerEndpoint) handleQrCodeGet() http.HandlerFunc {
return
}
configQr, err := e.peerService.GetPeerConfigQrCode(r.Context(), domain.PeerIdentifier(id))
configStyle := e.getConfigStyle(r)
configQr, err := e.peerService.GetPeerConfigQrCode(r.Context(), domain.PeerIdentifier(id), configStyle)
if err != nil {
respond.JSON(w, http.StatusInternalServerError, model.Error{
Code: http.StatusInternalServerError, Message: err.Error(),
@@ -438,6 +444,7 @@ func (e PeerEndpoint) handleQrCodeGet() http.HandlerFunc {
// @Summary Send peer configuration via email.
// @Produce json
// @Param request body model.PeerMailRequest true "The peer mail request data"
// @Param style query string false "The configuration style"
// @Success 204 "No content if mail sending was successful"
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
@@ -460,11 +467,13 @@ func (e PeerEndpoint) handleEmailPost() http.HandlerFunc {
return
}
configStyle := e.getConfigStyle(r)
peerIds := make([]domain.PeerIdentifier, len(req.Identifiers))
for i := range req.Identifiers {
peerIds[i] = domain.PeerIdentifier(req.Identifiers[i])
}
if err := e.peerService.SendPeerEmail(r.Context(), req.LinkOnly, peerIds...); err != nil {
if err := e.peerService.SendPeerEmail(r.Context(), req.LinkOnly, configStyle, peerIds...); err != nil {
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
@@ -504,3 +513,11 @@ func (e PeerEndpoint) handleStatsGet() http.HandlerFunc {
respond.JSON(w, http.StatusOK, model.NewPeerStats(e.cfg.Statistics.CollectPeerData, stats))
}
}
func (e PeerEndpoint) getConfigStyle(r *http.Request) string {
configStyle := request.QueryDefault(r, "style", domain.ConfigStyleWgQuick)
if configStyle != domain.ConfigStyleWgQuick && configStyle != domain.ConfigStyleRaw {
configStyle = domain.ConfigStyleWgQuick // default to wg-quick style
}
return configStyle
}

View File

@@ -28,6 +28,8 @@ type UserService interface {
ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
// DeactivateApi disables the API for the user with the given id.
DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
// ChangePassword changes the password for the user with the given id.
ChangePassword(ctx context.Context, id domain.UserIdentifier, oldPassword, newPassword string) (*domain.User, error)
// GetUserPeers returns all peers for the given user.
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
// GetUserPeerStats returns all peer stats for the given user.
@@ -75,6 +77,7 @@ func (e UserEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}/interfaces", e.handleInterfacesGet())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/api/enable", e.handleApiEnablePost())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/api/disable", e.handleApiDisablePost())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/change-password", e.handleChangePasswordPost())
}
// handleAllGet returns a gorm Handler function.
@@ -391,3 +394,68 @@ func (e UserEndpoint) handleApiDisablePost() http.HandlerFunc {
respond.JSON(w, http.StatusOK, model.NewUser(user, false))
}
}
// handleChangePasswordPost returns a gorm Handler function.
//
// @ID users_handleChangePasswordPost
// @Tags Users
// @Summary Change the password for the given user.
// @Produce json
// @Success 200 {object} model.User
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /user/{id}/change-password [post]
func (e UserEndpoint) handleChangePasswordPost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userId := Base64UrlDecode(request.Path(r, "id"))
if userId == "" {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
return
}
var passwordData struct {
OldPassword string `json:"OldPassword"`
Password string `json:"Password"`
PasswordRepeat string `json:"PasswordRepeat"`
}
if err := request.BodyJson(r, &passwordData); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if passwordData.OldPassword == "" {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "old password missing"})
return
}
if passwordData.Password == "" {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "new password missing"})
return
}
if passwordData.OldPassword == passwordData.Password {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "password did not change"})
return
}
if passwordData.Password != passwordData.PasswordRepeat {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "password mismatch"})
return
}
user, err := e.userService.ChangePassword(r.Context(), domain.UserIdentifier(userId),
passwordData.OldPassword, passwordData.Password)
if err != nil {
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
respond.JSON(w, http.StatusOK, model.NewUser(user, false))
}
}

View File

@@ -72,6 +72,32 @@ func (h AuthenticationHandler) LoggedIn(scopes ...Scope) func(next http.Handler)
}
}
// InfoOnly only checks if the user is logged in and adds the user id to the context.
// If the user is not logged in, the context user id is set to domain.CtxUnknownUserId.
func (h AuthenticationHandler) InfoOnly() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session := h.session.GetData(r.Context())
var newContext context.Context
if !session.LoggedIn {
newContext = domain.SetUserInfo(r.Context(), domain.DefaultContextUserInfo())
} else {
newContext = domain.SetUserInfo(r.Context(), &domain.ContextUserInfo{
Id: domain.UserIdentifier(session.UserIdentifier),
IsAdmin: session.IsAdmin,
})
}
r = r.WithContext(newContext)
// Continue down the chain to Handler etc
next.ServeHTTP(w, r)
})
}
}
// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted.
func (h AuthenticationHandler) UserIdMatch(idParameter string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {

View File

@@ -31,6 +31,8 @@ type SessionData struct {
OauthProvider string
OauthReturnTo string
WebAuthnData string
CsrfToken string
}

View File

@@ -6,8 +6,17 @@ type Error struct {
}
type Settings struct {
MailLinkOnly bool `json:"MailLinkOnly"`
PersistentConfigSupported bool `json:"PersistentConfigSupported"`
SelfProvisioning bool `json:"SelfProvisioning"`
ApiAdminOnly bool `json:"ApiAdminOnly"`
MailLinkOnly bool `json:"MailLinkOnly"`
PersistentConfigSupported bool `json:"PersistentConfigSupported"`
SelfProvisioning bool `json:"SelfProvisioning"`
ApiAdminOnly bool `json:"ApiAdminOnly"`
WebAuthnEnabled bool `json:"WebAuthnEnabled"`
MinPasswordLength int `json:"MinPasswordLength"`
AvailableBackends []SettingsBackendNames `json:"AvailableBackends"`
LoginFormVisible bool `json:"LoginFormVisible"`
}
type SettingsBackendNames struct {
Id string `json:"Id"`
Name string `json:"Name"`
}

View File

@@ -1,6 +1,11 @@
package model
import "github.com/h44z/wg-portal/internal/domain"
import (
"slices"
"strings"
"github.com/h44z/wg-portal/internal/domain"
)
type LoginProviderInfo struct {
Identifier string `json:"Identifier" example:"google"`
@@ -39,3 +44,32 @@ type OauthInitiationResponse struct {
RedirectUrl string
State string
}
type WebAuthnCredentialRequest struct {
Name string `json:"Name"`
}
type WebAuthnCredentialResponse struct {
ID string `json:"ID"`
Name string `json:"Name"`
CreatedAt string `json:"CreatedAt"`
}
func NewWebAuthnCredentialResponse(src domain.UserWebauthnCredential) WebAuthnCredentialResponse {
return WebAuthnCredentialResponse{
ID: src.CredentialIdentifier,
Name: src.DisplayName,
CreatedAt: src.CreatedAt.Format("2006-01-02 15:04:05"),
}
}
func NewWebAuthnCredentialResponses(src []domain.UserWebauthnCredential) []WebAuthnCredentialResponse {
credentials := make([]WebAuthnCredentialResponse, len(src))
for i := range src {
credentials[i] = NewWebAuthnCredentialResponse(src[i])
}
// Sort by CreatedAt, newest first
slices.SortFunc(credentials, func(i, j WebAuthnCredentialResponse) int {
return strings.Compare(i.CreatedAt, j.CreatedAt)
})
return credentials
}

View File

@@ -4,6 +4,7 @@ import (
"time"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
@@ -11,6 +12,7 @@ type Interface struct {
Identifier string `json:"Identifier" example:"wg0"` // device name, for example: wg0
DisplayName string `json:"DisplayName"` // a nice display name/ description for the interface
Mode string `json:"Mode" example:"server"` // the interface type, either 'server', 'client' or 'any'
Backend string `json:"Backend" example:"local"` // the backend used for this interface e.g., local, mikrotik, ...
PrivateKey string `json:"PrivateKey" example:"abcdef=="` // private Key of the server interface
PublicKey string `json:"PublicKey" example:"abcdef=="` // public Key of the server interface
Disabled bool `json:"Disabled"` // flag that specifies if the interface is enabled (up) or not (down)
@@ -57,6 +59,7 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
Identifier: string(src.Identifier),
DisplayName: src.DisplayName,
Mode: string(src.Type),
Backend: string(src.Backend),
PrivateKey: src.PrivateKey,
PublicKey: src.PublicKey,
Disabled: src.IsDisabled(),
@@ -92,6 +95,10 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
Filename: src.GetConfigFileName(),
}
if iface.Backend == "" {
iface.Backend = config.LocalBackendName // default to local backend
}
if len(peers) > 0 {
iface.TotalPeers = len(peers)
@@ -146,6 +153,7 @@ func NewDomainInterface(src *Interface) *domain.Interface {
SaveConfig: src.SaveConfig,
DisplayName: src.DisplayName,
Type: domain.InterfaceType(src.Mode),
Backend: domain.InterfaceBackend(src.Backend),
DriverType: "", // currently unused
Disabled: nil, // set below
DisabledReason: src.DisabledReason,

View File

@@ -43,6 +43,7 @@ type Peer struct {
Identifier string `json:"Identifier" example:"super_nice_peer"` // peer unique identifier
DisplayName string `json:"DisplayName"` // a nice display name/ description for the peer
UserIdentifier string `json:"UserIdentifier"` // the owner
UserDisplayName string `json:"UserDisplayName"` // the owner display name
InterfaceIdentifier string `json:"InterfaceIdentifier"` // the interface id
Disabled bool `json:"Disabled"` // flag that specifies if the peer is enabled (up) or not (down)
DisabledReason string `json:"DisabledReason"` // the reason why the peer has been disabled
@@ -80,7 +81,7 @@ type Peer struct {
}
func NewPeer(src *domain.Peer) *Peer {
return &Peer{
p := &Peer{
Identifier: string(src.Identifier),
DisplayName: src.DisplayName,
UserIdentifier: string(src.UserIdentifier),
@@ -111,6 +112,12 @@ func NewPeer(src *domain.Peer) *Peer {
PostDown: ConfigOptionFromDomain(src.Interface.PostDown),
Filename: src.GetConfigFileName(),
}
if src.User != nil {
p.UserDisplayName = src.User.DisplayName()
}
return p
}
func NewPeers(src []domain.Peer) []Peer {
@@ -198,7 +205,7 @@ func NewPeerStats(enabled bool, src []domain.PeerStatus) *PeerStats {
for _, srcStat := range src {
stats[string(srcStat.PeerId)] = PeerStatData{
IsConnected: srcStat.IsConnected(),
IsConnected: srcStat.IsConnected,
IsPingable: srcStat.IsPingable,
LastPing: srcStat.LastPing,
BytesReceived: srcStat.BytesReceived,

View File

@@ -23,8 +23,8 @@ type ProvisioningServicePeerManagerRepo interface {
}
type ProvisioningServiceConfigFileManagerRepo interface {
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier, style string) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier, style string) (io.Reader, error)
}
type ProvisioningService struct {
@@ -96,7 +96,7 @@ func (p ProvisioningService) GetPeerConfig(ctx context.Context, peerId domain.Pe
return nil, err
}
peerCfgReader, err := p.configFiles.GetPeerConfig(ctx, peer.Identifier)
peerCfgReader, err := p.configFiles.GetPeerConfig(ctx, peer.Identifier, domain.ConfigStyleWgQuick)
if err != nil {
return nil, err
}
@@ -119,7 +119,7 @@ func (p ProvisioningService) GetPeerQrPng(ctx context.Context, peerId domain.Pee
return nil, err
}
peerCfgQrReader, err := p.configFiles.GetPeerConfigQrCode(ctx, peer.Identifier)
peerCfgQrReader, err := p.configFiles.GetPeerConfigQrCode(ctx, peer.Identifier, domain.ConfigStyleWgQuick)
if err != nil {
return nil, err
}
@@ -162,7 +162,11 @@ func (p ProvisioningService) NewPeer(ctx context.Context, req models.Provisionin
if req.PresharedKey != "" {
peer.PresharedKey = domain.PreSharedKey(req.PresharedKey)
}
peer.GenerateDisplayName("API")
if req.DisplayName == "" {
peer.GenerateDisplayName("API")
} else {
peer.DisplayName = req.DisplayName
}
// save new peer
peer, err = p.peers.CreatePeer(ctx, peer)

View File

@@ -68,6 +68,10 @@ type ProvisioningRequest struct {
// If no user identifier is set, the authenticated user is used.
UserIdentifier string `json:"UserIdentifier" example:"uid-1234567"`
// DisplayName is an optional name for the new peer.
// If unset, a default template value (e.g., "API Peer ...") will be assigned.
DisplayName string `json:"DisplayName" example:"API Peer xyz" binding:"omitempty"`
// PublicKey is the optional public key of the peer. If no public key is set, a new key pair is generated.
PublicKey string `json:"PublicKey" example:"xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=" binding:"omitempty,len=44"`
// PresharedKey is the optional pre-shared key of the peer. If no pre-shared key is set, a new key is generated.

View File

@@ -46,14 +46,18 @@ func Initialize(
users: users,
}
startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second)
startupContext, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Switch to admin user context
startupContext = domain.SetUserInfo(startupContext, domain.SystemAdminContextUserInfo())
if err := a.createDefaultUser(startupContext); err != nil {
return fmt.Errorf("failed to create default user: %w", err)
if !cfg.Core.AdminUserDisabled {
if err := a.createDefaultUser(startupContext); err != nil {
return fmt.Errorf("failed to create default user: %w", err)
}
} else {
slog.Info("Local Admin user disabled!")
}
if err := a.importNewInterfaces(startupContext); err != nil {

View File

@@ -93,6 +93,8 @@ type Authenticator struct {
// URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix
callbackUrlPrefix string
callbackUrl *url.URL
users UserManager
}
@@ -102,82 +104,152 @@ func NewAuthenticator(cfg *config.Auth, extUrl string, bus EventBus, users UserM
error,
) {
a := &Authenticator{
cfg: cfg,
bus: bus,
users: users,
callbackUrlPrefix: fmt.Sprintf("%s/api/v0", extUrl),
cfg: cfg,
bus: bus,
users: users,
callbackUrlPrefix: fmt.Sprintf("%s/api/v0", extUrl),
oauthAuthenticators: make(map[string]AuthenticatorOauth, len(cfg.OpenIDConnect)+len(cfg.OAuth)),
ldapAuthenticators: make(map[string]AuthenticatorLdap, len(cfg.Ldap)),
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := a.setupExternalAuthProviders(ctx)
parsedExtUrl, err := url.Parse(a.callbackUrlPrefix)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to parse external URL: %w", err)
}
a.callbackUrl = parsedExtUrl
return a, nil
}
func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error {
extUrl, err := url.Parse(a.callbackUrlPrefix)
if err != nil {
return fmt.Errorf("failed to parse external url: %w", err)
}
// StartBackgroundJobs starts the background jobs for the authenticator.
// It sets up the external authentication providers (OIDC, OAuth, LDAP) and retries in case of errors.
func (a *Authenticator) StartBackgroundJobs(ctx context.Context) {
go func() {
slog.Debug("setting up external auth providers...")
a.oauthAuthenticators = make(map[string]AuthenticatorOauth, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth))
a.ldapAuthenticators = make(map[string]AuthenticatorLdap, len(a.cfg.Ldap))
// Initialize local copies of authentication providers to allow retry in case of errors
oidcQueue := a.cfg.OpenIDConnect
oauthQueue := a.cfg.OAuth
ldapQueue := a.cfg.Ldap
for i := range a.cfg.OpenIDConnect { // OIDC
providerCfg := &a.cfg.OpenIDConnect[i]
// Immediate attempt
failedOidc, failedOauth, failedLdap := a.setupExternalAuthProviders(oidcQueue, oauthQueue, ldapQueue)
if len(failedOidc) == 0 && len(failedOauth) == 0 && len(failedLdap) == 0 {
slog.Info("successfully setup all external auth providers")
return
}
// Prepare for retries with only the failed ones
oidcQueue = failedOidc
oauthQueue = failedOauth
ldapQueue = failedLdap
slog.Warn("failed to setup some external auth providers, retrying in 30 seconds",
"failedOidc", len(failedOidc), "failedOauth", len(failedOauth), "failedLdap", len(failedLdap))
ticker := time.NewTicker(30 * time.Second) // Ticker for delay between retries
defer ticker.Stop()
for {
select {
case <-ticker.C:
failedOidc, failedOauth, failedLdap := a.setupExternalAuthProviders(oidcQueue, oauthQueue, ldapQueue)
if len(failedOidc) > 0 || len(failedOauth) > 0 || len(failedLdap) > 0 {
slog.Warn("failed to setup some external auth providers, retrying in 30 seconds",
"failedOidc", len(failedOidc), "failedOauth", len(failedOauth), "failedLdap", len(failedLdap))
// Retry failed providers
oidcQueue = failedOidc
oauthQueue = failedOauth
ldapQueue = failedLdap
} else {
slog.Info("successfully setup all external auth providers")
return // Exit goroutine if all providers are set up successfully
}
case <-ctx.Done():
slog.Info("context cancelled, stopping setup of external auth providers")
return // Exit goroutine if context is cancelled
}
}
}()
}
func (a *Authenticator) setupExternalAuthProviders(
oidc []config.OpenIDConnectProvider,
oauth []config.OAuthProvider,
ldap []config.LdapProvider,
) (
[]config.OpenIDConnectProvider,
[]config.OAuthProvider,
[]config.LdapProvider,
) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
var failedOidc []config.OpenIDConnectProvider
var failedOauth []config.OAuthProvider
var failedLdap []config.LdapProvider
for i := range oidc { // OIDC
providerCfg := &oidc[i]
providerId := strings.ToLower(providerCfg.ProviderName)
if _, exists := a.oauthAuthenticators[providerId]; exists {
return fmt.Errorf("auth provider with name %s is already registerd", providerId)
// this is an unrecoverable error, we cannot register the same provider twice
slog.Error("OIDC auth provider is already registered", "name", providerId)
continue // skip this provider
}
redirectUrl := *extUrl
redirectUrl := *a.callbackUrl
redirectUrl.Path = path.Join(redirectUrl.Path, "/auth/login/", providerId, "/callback")
provider, err := newOidcAuthenticator(ctx, redirectUrl.String(), providerCfg)
if err != nil {
return fmt.Errorf("failed to setup oidc authentication provider %s: %w", providerCfg.ProviderName, err)
failedOidc = append(failedOidc, oidc[i])
slog.Error("failed to setup oidc authentication provider", "name", providerId, "error", err)
continue
}
a.oauthAuthenticators[providerId] = provider
}
for i := range a.cfg.OAuth { // PLAIN OAUTH
providerCfg := &a.cfg.OAuth[i]
for i := range oauth { // PLAIN OAUTH
providerCfg := &oauth[i]
providerId := strings.ToLower(providerCfg.ProviderName)
if _, exists := a.oauthAuthenticators[providerId]; exists {
return fmt.Errorf("auth provider with name %s is already registerd", providerId)
// this is an unrecoverable error, we cannot register the same provider twice
slog.Error("OAUTH auth provider is already registered", "name", providerId)
continue // skip this provider
}
redirectUrl := *extUrl
redirectUrl := *a.callbackUrl
redirectUrl.Path = path.Join(redirectUrl.Path, "/auth/login/", providerId, "/callback")
provider, err := newPlainOauthAuthenticator(ctx, redirectUrl.String(), providerCfg)
if err != nil {
return fmt.Errorf("failed to setup oauth authentication provider %s: %w", providerId, err)
failedOauth = append(failedOauth, oauth[i])
slog.Error("failed to setup oauth authentication provider", "name", providerId, "error", err)
continue
}
a.oauthAuthenticators[providerId] = provider
}
for i := range a.cfg.Ldap { // LDAP
providerCfg := &a.cfg.Ldap[i]
for i := range ldap { // LDAP
providerCfg := &ldap[i]
providerId := strings.ToLower(providerCfg.URL)
if _, exists := a.ldapAuthenticators[providerId]; exists {
return fmt.Errorf("auth provider with name %s is already registerd", providerId)
// this is an unrecoverable error, we cannot register the same provider twice
slog.Error("LDAP auth provider is already registered", "name", providerId)
continue // skip this provider
}
provider, err := newLdapAuthenticator(ctx, providerCfg)
if err != nil {
return fmt.Errorf("failed to setup ldap authentication provider %s: %w", providerId, err)
failedLdap = append(failedLdap, ldap[i])
slog.Error("failed to setup ldap authentication provider", "name", providerId, "error", err)
continue
}
a.ldapAuthenticators[providerId] = provider
}
return nil
return failedOidc, failedOauth, failedLdap
}
// GetExternalLoginProviders returns a list of all available external login providers.
@@ -302,12 +374,15 @@ func (a *Authenticator) passwordAuthentication(
rawUserInfo, err := ldapAuth.GetUserInfo(context.Background(), identifier)
if err != nil {
if !errors.Is(err, domain.ErrNotFound) {
slog.Warn("failed to fetch ldap user info", "identifier", identifier, "error", err)
slog.Warn("failed to fetch ldap user info",
"source", ldapAuth.GetName(), "identifier", identifier, "error", err)
}
continue // user not found / other ldap error
}
ldapUserInfo, err = ldapAuth.ParseUserInfo(rawUserInfo)
if err != nil {
slog.Error("failed to parse ldap user info",
"source", ldapAuth.GetName(), "identifier", identifier, "error", err)
continue
}
@@ -320,10 +395,14 @@ func (a *Authenticator) passwordAuthentication(
}
if userSource == "" {
slog.Warn("no user source found for user",
"identifier", identifier, "ldapProviderCount", len(a.ldapAuthenticators), "inDb", userInDatabase)
return nil, errors.New("user not found")
}
if userSource == domain.UserSourceLdap && ldapProvider == nil {
slog.Warn("no ldap provider found for user",
"identifier", identifier, "ldapProviderCount", len(a.ldapAuthenticators), "inDb", userInDatabase)
return nil, errors.New("ldap provider not found")
}
@@ -434,6 +513,10 @@ func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce,
return nil, fmt.Errorf("unable to parse user information: %w", err)
}
if !isDomainAllowed(userInfo.Email, oauthProvider.GetAllowedDomains()) {
return nil, fmt.Errorf("user %s is not in allowed domains", userInfo.Email)
}
ctx = domain.SetUserInfo(ctx,
domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
user, err := a.processUserInfo(ctx, userInfo, domain.UserSourceOauth, oauthProvider.GetName(),
@@ -450,10 +533,6 @@ func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce,
return nil, fmt.Errorf("unable to process user information: %w", err)
}
if !isDomainAllowed(userInfo.Email, oauthProvider.GetAllowedDomains()) {
return nil, fmt.Errorf("user is not in allowed domains: %w", err)
}
if user.IsLocked() || user.IsDisabled() {
a.bus.Publish(app.TopicAuditLoginFailed, domain.AuditEventWrapper[audit.AuthEvent]{
Ctx: ctx,

View File

@@ -113,10 +113,13 @@ func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIden
}
if len(sr.Entries) == 0 {
slog.Debug("LDAP user not found", "source", l.GetName(), "userId", userId, "filter", loginFilter)
return nil, domain.ErrNotFound
}
if len(sr.Entries) > 1 {
slog.Debug("LDAP user not unique",
"source", l.GetName(), "userId", userId, "filter", loginFilter, "entries", len(sr.Entries))
return nil, domain.ErrNotUnique
}

View File

@@ -19,15 +19,16 @@ import (
// PlainOauthAuthenticator is an authenticator that uses OAuth for authentication.
// User information is retrieved from the specified user info endpoint.
type PlainOauthAuthenticator struct {
name string
cfg *oauth2.Config
userInfoEndpoint string
client *http.Client
userInfoMapping config.OauthFields
userAdminMapping *config.OauthAdminMapping
registrationEnabled bool
userInfoLogging bool
allowedDomains []string
name string
cfg *oauth2.Config
userInfoEndpoint string
client *http.Client
userInfoMapping config.OauthFields
userAdminMapping *config.OauthAdminMapping
registrationEnabled bool
userInfoLogging bool
sensitiveInfoLogging bool
allowedDomains []string
}
func newPlainOauthAuthenticator(
@@ -57,6 +58,7 @@ func newPlainOauthAuthenticator(
provider.userAdminMapping = &cfg.AdminMapping
provider.registrationEnabled = cfg.RegistrationEnabled
provider.userInfoLogging = cfg.LogUserInfo
provider.sensitiveInfoLogging = cfg.LogSensitiveInfo
provider.allowedDomains = cfg.AllowedDomains
return provider, nil
@@ -110,6 +112,10 @@ func (p PlainOauthAuthenticator) GetUserInfo(
response, err := p.client.Do(req)
if err != nil {
if p.sensitiveInfoLogging {
slog.Debug("OAuth: failed to get user info", "endpoint", p.userInfoEndpoint,
"token", token, "error", err)
}
return nil, fmt.Errorf("failed to get user info: %w", err)
}
defer internal.LogClose(response.Body)
@@ -121,11 +127,15 @@ func (p PlainOauthAuthenticator) GetUserInfo(
var userFields map[string]any
err = json.Unmarshal(contents, &userFields)
if err != nil {
if p.sensitiveInfoLogging {
slog.Debug("OAuth: failed to parse user info", "endpoint", p.userInfoEndpoint,
"token", token, "contents", contents, "error", err)
}
return nil, fmt.Errorf("failed to parse user info: %w", err)
}
if p.userInfoLogging {
slog.Debug("OAuth user info",
slog.Debug("OAuth: user info debug",
"source", p.name,
"info", string(contents))
}

View File

@@ -16,15 +16,16 @@ import (
// OidcAuthenticator is an authenticator for OpenID Connect providers.
type OidcAuthenticator struct {
name string
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
cfg *oauth2.Config
userInfoMapping config.OauthFields
userAdminMapping *config.OauthAdminMapping
registrationEnabled bool
userInfoLogging bool
allowedDomains []string
name string
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
cfg *oauth2.Config
userInfoMapping config.OauthFields
userAdminMapping *config.OauthAdminMapping
registrationEnabled bool
userInfoLogging bool
sensitiveInfoLogging bool
allowedDomains []string
}
func newOidcAuthenticator(
@@ -58,6 +59,7 @@ func newOidcAuthenticator(
provider.userAdminMapping = &cfg.AdminMapping
provider.registrationEnabled = cfg.RegistrationEnabled
provider.userInfoLogging = cfg.LogUserInfo
provider.sensitiveInfoLogging = cfg.LogSensitiveInfo
provider.allowedDomains = cfg.AllowedDomains
return provider, nil
@@ -102,24 +104,40 @@ func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token,
) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
if o.sensitiveInfoLogging {
slog.Debug("OIDC: token does not contain id_token", "token", token, "nonce", nonce)
}
return nil, errors.New("token does not contain id_token")
}
idToken, err := o.verifier.Verify(ctx, rawIDToken)
if err != nil {
if o.sensitiveInfoLogging {
slog.Debug("OIDC: failed to validate id_token", "token", token, "id_token", rawIDToken, "nonce", nonce,
"error",
err)
}
return nil, fmt.Errorf("failed to validate id_token: %w", err)
}
if idToken.Nonce != nonce {
if o.sensitiveInfoLogging {
slog.Debug("OIDC: id_token nonce mismatch", "token", token, "id_token", idToken, "nonce", nonce)
}
return nil, errors.New("nonce mismatch")
}
var tokenFields map[string]any
if err = idToken.Claims(&tokenFields); err != nil {
if o.sensitiveInfoLogging {
slog.Debug("OIDC: failed to parse extra claims", "token", token, "id_token", idToken, "nonce", nonce,
"error",
err)
}
return nil, fmt.Errorf("failed to parse extra claims: %w", err)
}
if o.userInfoLogging {
contents, _ := json.Marshal(tokenFields)
slog.Debug("OIDC user info",
slog.Debug("OIDC: user info debug",
"source", o.name,
"info", string(contents))
}

View File

@@ -0,0 +1,301 @@
package auth
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/audit"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
type WebAuthnUserManager interface {
// GetUser returns a user by its identifier.
GetUser(context.Context, domain.UserIdentifier) (*domain.User, error)
// GetUserByWebAuthnCredential returns a user by its WebAuthn ID.
GetUserByWebAuthnCredential(ctx context.Context, credentialIdBase64 string) (*domain.User, error)
// UpdateUser updates an existing user in the database.
UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error)
}
type WebAuthnAuthenticator struct {
webAuthn *webauthn.WebAuthn
users WebAuthnUserManager
bus EventBus
}
func NewWebAuthnAuthenticator(cfg *config.Config, bus EventBus, users WebAuthnUserManager) (
*WebAuthnAuthenticator,
error,
) {
if !cfg.Auth.WebAuthn.Enabled {
return nil, nil
}
extUrl, err := url.Parse(cfg.Web.ExternalUrl)
if err != nil {
return nil, errors.New("failed to parse external URL - required for WebAuthn RP ID")
}
rpId := extUrl.Hostname()
if rpId == "" {
return nil, errors.New("failed to determine Webauthn RPID")
}
// Initialize the WebAuthn authenticator with the provided configuration
awCfg := &webauthn.Config{
RPID: rpId,
RPDisplayName: cfg.Web.SiteTitle,
RPOrigins: []string{cfg.Web.ExternalUrl},
}
webAuthn, err := webauthn.New(awCfg)
if err != nil {
return nil, fmt.Errorf("failed to create Webauthn instance: %w", err)
}
return &WebAuthnAuthenticator{
webAuthn: webAuthn,
users: users,
bus: bus,
}, nil
}
func (a *WebAuthnAuthenticator) Enabled() bool {
return a != nil && a.webAuthn != nil
}
func (a *WebAuthnAuthenticator) StartWebAuthnRegistration(ctx context.Context, userId domain.UserIdentifier) (
optionsAsJSON []byte,
sessionDataAsJSON []byte,
err error,
) {
user, err := a.users.GetUser(ctx, userId)
if err != nil {
return nil, nil, fmt.Errorf("failed to get user: %w", err)
}
if user.IsLocked() || user.IsDisabled() {
return nil, nil, errors.New("user is locked") // adding passkey to locked user is not allowed
}
if user.WebAuthnId == "" {
user.GenerateWebAuthnId()
user, err = a.users.UpdateUser(ctx, user)
if err != nil {
return nil, nil, fmt.Errorf("failed to store webauthn id to user: %w", err)
}
}
options, sessionData, err := a.webAuthn.BeginRegistration(user,
webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired),
)
if err != nil {
return nil, nil, fmt.Errorf("failed to begin WebAuthn registration: %w", err)
}
optionsAsJSON, err = json.Marshal(options)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal webauthn options to JSON: %w", err)
}
sessionDataAsJSON, err = json.Marshal(sessionData)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal webauthn session data to JSON: %w", err)
}
return optionsAsJSON, sessionDataAsJSON, nil
}
func (a *WebAuthnAuthenticator) FinishWebAuthnRegistration(
ctx context.Context,
userId domain.UserIdentifier,
name string,
sessionDataAsJSON []byte,
r *http.Request,
) ([]domain.UserWebauthnCredential, error) {
user, err := a.users.GetUser(ctx, userId)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
if user.IsLocked() || user.IsDisabled() {
return nil, errors.New("user is locked") // adding passkey to locked user is not allowed
}
var webAuthnData webauthn.SessionData
err = json.Unmarshal(sessionDataAsJSON, &webAuthnData)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal webauthn session data: %w", err)
}
credential, err := a.webAuthn.FinishRegistration(user, webAuthnData, r)
if err != nil {
return nil, err
}
if name == "" {
name = fmt.Sprintf("Passkey %d", len(user.WebAuthnCredentialList)+1) // fallback name
}
// Add the credential to the user
err = user.AddCredential(userId, name, *credential)
if err != nil {
return nil, err
}
user, err = a.users.UpdateUser(ctx, user)
if err != nil {
return nil, err
}
return user.WebAuthnCredentialList, nil
}
func (a *WebAuthnAuthenticator) GetCredentials(
ctx context.Context,
userId domain.UserIdentifier,
) ([]domain.UserWebauthnCredential, error) {
user, err := a.users.GetUser(ctx, userId)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user.WebAuthnCredentialList, nil
}
func (a *WebAuthnAuthenticator) RemoveCredential(
ctx context.Context,
userId domain.UserIdentifier,
credentialIdBase64 string,
) ([]domain.UserWebauthnCredential, error) {
user, err := a.users.GetUser(ctx, userId)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
user.RemoveCredential(credentialIdBase64)
user, err = a.users.UpdateUser(ctx, user)
if err != nil {
return nil, err
}
return user.WebAuthnCredentialList, nil
}
func (a *WebAuthnAuthenticator) UpdateCredential(
ctx context.Context,
userId domain.UserIdentifier,
credentialIdBase64 string,
name string,
) ([]domain.UserWebauthnCredential, error) {
user, err := a.users.GetUser(ctx, userId)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
err = user.UpdateCredential(credentialIdBase64, name)
if err != nil {
return nil, err
}
user, err = a.users.UpdateUser(ctx, user)
if err != nil {
return nil, err
}
return user.WebAuthnCredentialList, nil
}
func (a *WebAuthnAuthenticator) StartWebAuthnLogin(_ context.Context) (
optionsAsJSON []byte,
sessionDataAsJSON []byte,
err error,
) {
options, sessionData, err := a.webAuthn.BeginDiscoverableLogin()
if err != nil {
return nil, nil, fmt.Errorf("failed to begin WebAuthn login: %w", err)
}
optionsAsJSON, err = json.Marshal(options)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal webauthn options to JSON: %w", err)
}
sessionDataAsJSON, err = json.Marshal(sessionData)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal webauthn session data to JSON: %w", err)
}
return optionsAsJSON, sessionDataAsJSON, nil
}
func (a *WebAuthnAuthenticator) FinishWebAuthnLogin(
ctx context.Context,
sessionDataAsJSON []byte,
r *http.Request,
) (*domain.User, error) {
var webAuthnData webauthn.SessionData
err := json.Unmarshal(sessionDataAsJSON, &webAuthnData)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal webauthn session data: %w", err)
}
// switch to admin context for user lookup
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo())
credential, err := a.webAuthn.FinishDiscoverableLogin(a.findUserForWebAuthnSecretFn(ctx), webAuthnData, r)
if err != nil {
return nil, err
}
// Find the user by the WebAuthn ID
user, err := a.users.GetUserByWebAuthnCredential(ctx,
base64.StdEncoding.EncodeToString(credential.ID))
if err != nil {
return nil, fmt.Errorf("failed to get user by webauthn credential: %w", err)
}
if user.IsLocked() || user.IsDisabled() {
a.bus.Publish(app.TopicAuditLoginFailed, domain.AuditEventWrapper[audit.AuthEvent]{
Ctx: ctx,
Source: "passkey",
Event: audit.AuthEvent{
Username: string(user.Identifier), Error: "User is locked",
},
})
return nil, errors.New("user is locked") // login with passkey is not allowed
}
a.bus.Publish(app.TopicAuthLogin, user.Identifier)
a.bus.Publish(app.TopicAuditLoginSuccess, domain.AuditEventWrapper[audit.AuthEvent]{
Ctx: ctx,
Source: "passkey",
Event: audit.AuthEvent{
Username: string(user.Identifier),
},
})
return user, nil
}
func (a *WebAuthnAuthenticator) findUserForWebAuthnSecretFn(ctx context.Context) func(rawID, userHandle []byte) (
user webauthn.User,
err error,
) {
return func(rawID, userHandle []byte) (webauthn.User, error) {
// Find the user by the WebAuthn ID
user, err := a.users.GetUserByWebAuthnCredential(ctx, base64.StdEncoding.EncodeToString(rawID))
if err != nil {
return nil, fmt.Errorf("failed to get user by webauthn credential: %w", err)
}
return user, nil
}
}

View File

@@ -46,7 +46,7 @@ type TemplateRenderer interface {
// GetInterfaceConfig returns the configuration file for the given interface.
GetInterfaceConfig(iface *domain.Interface, peers []domain.Peer) (io.Reader, error)
// GetPeerConfig returns the configuration file for the given peer.
GetPeerConfig(peer *domain.Peer) (io.Reader, error)
GetPeerConfig(peer *domain.Peer, style string) (io.Reader, error)
}
type EventBus interface {
@@ -186,7 +186,7 @@ func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIden
// GetPeerConfig returns the configuration file for the given peer.
// The file is structured in wg-quick format.
func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier, style string) (io.Reader, error) {
peer, err := m.wg.GetPeer(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to fetch peer %s: %w", id, err)
@@ -196,11 +196,11 @@ func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (i
return nil, err
}
return m.tplHandler.GetPeerConfig(peer)
return m.tplHandler.GetPeerConfig(peer, style)
}
// GetPeerConfigQrCode returns a QR code image containing the configuration for the given peer.
func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) {
func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier, style string) (io.Reader, error) {
peer, err := m.wg.GetPeer(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to fetch peer %s: %w", id, err)
@@ -210,7 +210,7 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
return nil, err
}
cfgData, err := m.tplHandler.GetPeerConfig(peer)
cfgData, err := m.tplHandler.GetPeerConfig(peer, style)
if err != nil {
return nil, fmt.Errorf("failed to get peer config for %s: %w", id, err)
}

View File

@@ -55,11 +55,12 @@ func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domai
}
// GetPeerConfig returns the rendered configuration file for a WireGuard peer.
func (c TemplateHandler) GetPeerConfig(peer *domain.Peer) (io.Reader, error) {
func (c TemplateHandler) GetPeerConfig(peer *domain.Peer, style string) (io.Reader, error) {
var tplBuff bytes.Buffer
err := c.templates.ExecuteTemplate(&tplBuff, "wg_peer.tpl", map[string]any{
"Peer": peer,
"Style": style,
"Peer": peer,
"Portal": map[string]any{
"Version": "unknown",
},

View File

@@ -1,6 +1,8 @@
# AUTOGENERATED FILE - DO NOT EDIT
# This file uses wg-quick format.
# This file uses {{ .Style }} format.
{{- if eq .Style "wgquick"}}
# See https://man7.org/linux/man-pages/man8/wg-quick.8.html#CONFIGURATION
{{- end}}
# Lines starting with the -WGP- tag are used by
# the WireGuard Portal configuration parser.
@@ -21,22 +23,27 @@
# Core settings
PrivateKey = {{ .Peer.Interface.KeyPair.PrivateKey }}
{{- if eq .Style "wgquick"}}
Address = {{ CidrsToString .Peer.Interface.Addresses }}
{{- end}}
# Misc. settings (optional)
{{- if eq .Style "wgquick"}}
{{- if .Peer.Interface.DnsStr.GetValue}}
DNS = {{ .Peer.Interface.DnsStr.GetValue }} {{- if .Peer.Interface.DnsSearchStr.GetValue}}, {{ .Peer.Interface.DnsSearchStr.GetValue }} {{- end}}
{{- end}}
{{- if ne .Peer.Interface.Mtu.GetValue 0}}
MTU = {{ .Peer.Interface.Mtu.GetValue }}
{{- end}}
{{- if ne .Peer.Interface.FirewallMark.GetValue 0}}
FwMark = {{ .Peer.Interface.FirewallMark.GetValue }}
{{- end}}
{{- if ne .Peer.Interface.RoutingTable.GetValue ""}}
Table = {{ .Peer.Interface.RoutingTable.GetValue }}
{{- end}}
{{- end}}
{{- if ne .Peer.Interface.FirewallMark.GetValue 0}}
FwMark = {{ .Peer.Interface.FirewallMark.GetValue }}
{{- end}}
{{- if eq .Style "wgquick"}}
# Interface hooks (optional)
{{- if .Peer.Interface.PreUp.GetValue}}
PreUp = {{ .Peer.Interface.PreUp.GetValue }}
@@ -50,6 +57,7 @@ PreDown = {{ .Peer.Interface.PreDown.GetValue }}
{{- if .Peer.Interface.PostDown.GetValue}}
PostDown = {{ .Peer.Interface.PostDown.GetValue }}
{{- end}}
{{- end}}
[Peer]
PublicKey = {{ .Peer.EndpointPublicKey.GetValue }}

View File

@@ -36,6 +36,7 @@ const TopicPeerDeleted = "peer:deleted"
const TopicPeerUpdated = "peer:updated"
const TopicPeerInterfaceUpdated = "peer:interface:updated"
const TopicPeerIdentifierUpdated = "peer:identifier:updated"
const TopicPeerStateChanged = "peer:state:changed"
// endregion peer-events

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"log/slog"
"net/mail"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
@@ -21,9 +22,9 @@ type ConfigFileManager interface {
// GetInterfaceConfig returns the configuration for the given interface.
GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error)
// GetPeerConfig returns the configuration for the given peer.
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfig(ctx context.Context, id domain.PeerIdentifier, style string) (io.Reader, error)
// GetPeerConfigQrCode returns the QR code for the given peer.
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error)
GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier, style string) (io.Reader, error)
}
type UserDatabaseRepo interface {
@@ -71,7 +72,7 @@ func NewMailManager(
users UserDatabaseRepo,
wg WireguardDatabaseRepo,
) (*Manager, error) {
tplHandler, err := newTemplateHandler(cfg.Web.ExternalUrl)
tplHandler, err := newTemplateHandler(cfg.Web.ExternalUrl, cfg.Web.SiteTitle)
if err != nil {
return nil, fmt.Errorf("failed to initialize template handler: %w", err)
}
@@ -89,7 +90,7 @@ func NewMailManager(
}
// SendPeerEmail sends an email to the user linked to the given peers.
func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error {
func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, style string, peers ...domain.PeerIdentifier) error {
for _, peerId := range peers {
peer, err := m.wg.GetPeer(ctx, peerId)
if err != nil {
@@ -101,29 +102,15 @@ func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...doma
}
if peer.UserIdentifier == "" {
slog.Debug("skipping peer email",
"peer", peerId,
"reason", "no user linked")
continue
return fmt.Errorf("peer %s has no user linked, no email is sent", peerId)
}
user, err := m.users.GetUser(ctx, peer.UserIdentifier)
if err != nil {
slog.Debug("skipping peer email",
"peer", peerId,
"reason", "unable to fetch user",
"error", err)
continue
email, user := m.resolveEmail(ctx, peer)
if email == "" {
return fmt.Errorf("peer %s has no valid email address, no email is sent", peerId)
}
if user.Email == "" {
slog.Debug("skipping peer email",
"peer", peerId,
"reason", "user has no mail address")
continue
}
err = m.sendPeerEmail(ctx, linkOnly, user, peer)
err = m.sendPeerEmail(ctx, linkOnly, style, &user, peer)
if err != nil {
return fmt.Errorf("failed to send peer email for %s: %w", peerId, err)
}
@@ -132,7 +119,13 @@ func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...doma
return nil
}
func (m Manager) sendPeerEmail(ctx context.Context, linkOnly bool, user *domain.User, peer *domain.Peer) error {
func (m Manager) sendPeerEmail(
ctx context.Context,
linkOnly bool,
style string,
user *domain.User,
peer *domain.Peer,
) error {
qrName := "WireGuardQRCode.png"
configName := peer.GetConfigFileName()
@@ -148,12 +141,12 @@ func (m Manager) sendPeerEmail(ctx context.Context, linkOnly bool, user *domain.
}
} else {
peerConfig, err := m.configFiles.GetPeerConfig(ctx, peer.Identifier)
peerConfig, err := m.configFiles.GetPeerConfig(ctx, peer.Identifier, style)
if err != nil {
return fmt.Errorf("failed to fetch peer config for %s: %w", peer.Identifier, err)
}
peerConfigQr, err := m.configFiles.GetPeerConfigQrCode(ctx, peer.Identifier)
peerConfigQr, err := m.configFiles.GetPeerConfigQrCode(ctx, peer.Identifier, style)
if err != nil {
return fmt.Errorf("failed to fetch peer config QR code for %s: %w", peer.Identifier, err)
}
@@ -188,3 +181,37 @@ func (m Manager) sendPeerEmail(ctx context.Context, linkOnly bool, user *domain.
return nil
}
func (m Manager) resolveEmail(ctx context.Context, peer *domain.Peer) (string, domain.User) {
user, err := m.users.GetUser(ctx, peer.UserIdentifier)
if err != nil {
if m.cfg.Mail.AllowPeerEmail {
_, err := mail.ParseAddress(string(peer.UserIdentifier)) // test if the user identifier is a valid email address
if err == nil {
slog.Debug("peer email: using user-identifier as email",
"peer", peer.Identifier, "email", peer.UserIdentifier)
return string(peer.UserIdentifier), domain.User{}
} else {
slog.Debug("peer email: skipping peer email",
"peer", peer.Identifier,
"reason", "peer has no user linked and user-identifier is not a valid email address")
return "", domain.User{}
}
} else {
slog.Debug("peer email: skipping peer email",
"peer", peer.Identifier,
"reason", "user has no user linked")
return "", domain.User{}
}
}
if user.Email == "" {
slog.Debug("peer email: skipping peer email",
"peer", peer.Identifier,
"reason", "user has no mail address")
return "", domain.User{}
}
slog.Debug("peer email: using user email", "peer", peer.Identifier, "email", user.Email)
return user.Email, *user
}

View File

@@ -17,11 +17,12 @@ var TemplateFiles embed.FS
// TemplateHandler is a struct that holds the html and text templates.
type TemplateHandler struct {
portalUrl string
portalName string
htmlTemplates *htmlTemplate.Template
textTemplates *template.Template
}
func newTemplateHandler(portalUrl string) (*TemplateHandler, error) {
func newTemplateHandler(portalUrl, portalName string) (*TemplateHandler, error) {
htmlTemplateCache, err := htmlTemplate.New("Html").ParseFS(TemplateFiles, "tpl_files/*.gohtml")
if err != nil {
return nil, fmt.Errorf("failed to parse html template files: %w", err)
@@ -34,6 +35,7 @@ func newTemplateHandler(portalUrl string) (*TemplateHandler, error) {
handler := &TemplateHandler{
portalUrl: portalUrl,
portalName: portalName,
htmlTemplates: htmlTemplateCache,
textTemplates: txtTemplateCache,
}
@@ -81,6 +83,7 @@ func (c TemplateHandler) GetConfigMailWithAttachment(user *domain.User, cfgName,
"ConfigFileName": cfgName,
"QrcodePngName": qrName,
"PortalUrl": c.portalUrl,
"PortalName": c.portalName,
})
if err != nil {
return nil, nil, fmt.Errorf("failed to execute template mail_with_attachment.gotpl: %w", err)
@@ -91,6 +94,7 @@ func (c TemplateHandler) GetConfigMailWithAttachment(user *domain.User, cfgName,
"ConfigFileName": cfgName,
"QrcodePngName": qrName,
"PortalUrl": c.portalUrl,
"PortalName": c.portalName,
})
if err != nil {
return nil, nil, fmt.Errorf("failed to execute template mail_with_attachment.gohtml: %w", err)

View File

@@ -19,7 +19,7 @@
<!--[if !mso]><!-->
<link href="https://fonts.googleapis.com/css?family=Muli:400,400i,700,700i" rel="stylesheet" />
<!--<![endif]-->
<title>Email Template</title>
<title>{{$.PortalName}}</title>
<!--[if gte mso 9]>
<style type="text/css" media="all">
sup { font-size: 100% !important; }
@@ -143,7 +143,7 @@
<td align="left">
<table border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="blue-button text-button" style="background:#000000; color:#c1cddc; font-family:'Muli', Arial,sans-serif; font-size:14px; line-height:18px; padding:12px 30px; text-align:center; border-radius:0px 22px 22px 22px; font-weight:bold;"><a href="https://www.wireguard.com/install" target="_blank" class="link-white" style="color:#ffffff; text-decoration:none;"><span class="link-white" style="color:#ffffff; text-decoration:none;">Download WireGuard VPN Client</span></a></td>
<td class="blue-button text-button" style="background:#000000; color:#ffffff; font-family:'Muli', Arial,sans-serif; font-size:14px; line-height:18px; padding:12px 30px; text-align:center; border-radius:0px 22px 22px 22px; font-weight:bold;"><a href="https://www.wireguard.com/install" target="_blank" class="link-white" style="color:#ffffff; text-decoration:none;"><span class="link-white" style="color:#ffffff; text-decoration:none;">Download WireGuard VPN Client</span></a></td>
</tr>
</table>
</td>
@@ -167,10 +167,10 @@
<td class="p30-15 bbrr" style="padding: 50px 30px; border-radius:0px 0px 26px 26px;" bgcolor="#ffffff">
<table width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="text-footer1 pb10" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:16px; line-height:20px; text-align:center; padding-bottom:10px;">This mail was generated using WireGuard Portal.</td>
<td class="text-footer1 pb10" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:16px; line-height:20px; text-align:center; padding-bottom:10px;">This mail was generated by {{$.PortalName}}.</td>
</tr>
<tr>
<td class="text-footer2" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:12px; line-height:26px; text-align:center;"><a href="{{$.PortalUrl}}" target="_blank" rel="noopener noreferrer" class="link" style="color:#000000; text-decoration:none;"><span class="link" style="color:#000000; text-decoration:none;">Visit WireGuard Portal</span></a></td>
<td class="text-footer2" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:12px; line-height:26px; text-align:center;"><a href="{{$.PortalUrl}}" target="_blank" rel="noopener noreferrer" class="link" style="color:#000000; text-decoration:none;"><span class="link" style="color:#000000; text-decoration:none;">Visit {{$.PortalName}}</span></a></td>
</tr>
</table>
</td>

View File

@@ -20,5 +20,5 @@ You can download and install the WireGuard VPN client from:
https://www.wireguard.com/install/
This mail was generated using WireGuard Portal.
This mail was generated by {{$.PortalName}}.
{{$.PortalUrl}}

View File

@@ -19,7 +19,7 @@
<!--[if !mso]><!-->
<link href="https://fonts.googleapis.com/css?family=Muli:400,400i,700,700i" rel="stylesheet" />
<!--<![endif]-->
<title>Email Template</title>
<title>{{$.PortalName}}</title>
<!--[if gte mso 9]>
<style type="text/css" media="all">
sup { font-size: 100% !important; }
@@ -143,7 +143,7 @@
<td align="left">
<table border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="blue-button text-button" style="background:#000000; color:#c1cddc; font-family:'Muli', Arial,sans-serif; font-size:14px; line-height:18px; padding:12px 30px; text-align:center; border-radius:0px 22px 22px 22px; font-weight:bold;"><a href="https://www.wireguard.com/install" target="_blank" class="link-white" style="color:#ffffff; text-decoration:none;"><span class="link-white" style="color:#ffffff; text-decoration:none;">Download WireGuard VPN Client</span></a></td>
<td class="blue-button text-button" style="background:#000000; color:#ffffff; font-family:'Muli', Arial,sans-serif; font-size:14px; line-height:18px; padding:12px 30px; text-align:center; border-radius:0px 22px 22px 22px; font-weight:bold;"><a href="https://www.wireguard.com/install" target="_blank" class="link-white" style="color:#ffffff; text-decoration:none;"><span class="link-white" style="color:#ffffff; text-decoration:none;">Download WireGuard VPN Client</span></a></td>
</tr>
</table>
</td>
@@ -167,10 +167,10 @@
<td class="p30-15 bbrr" style="padding: 50px 30px; border-radius:0px 0px 26px 26px;" bgcolor="#ffffff">
<table width="100%" border="0" cellspacing="0" cellpadding="0">
<tr>
<td class="text-footer1 pb10" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:16px; line-height:20px; text-align:center; padding-bottom:10px;">This mail was generated using WireGuard Portal.</td>
<td class="text-footer1 pb10" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:16px; line-height:20px; text-align:center; padding-bottom:10px;">This mail was generated by {{$.PortalName}}.</td>
</tr>
<tr>
<td class="text-footer2" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:12px; line-height:26px; text-align:center;"><a href="{{$.PortalUrl}}" target="_blank" rel="noopener noreferrer" class="link" style="color:#000000; text-decoration:none;"><span class="link" style="color:#000000; text-decoration:none;">Visit WireGuard Portal</span></a></td>
<td class="text-footer2" style="color:#000000; font-family:'Muli', Arial,sans-serif; font-size:12px; line-height:26px; text-align:center;"><a href="{{$.PortalUrl}}" target="_blank" rel="noopener noreferrer" class="link" style="color:#000000; text-decoration:none;"><span class="link" style="color:#000000; text-decoration:none;">Visit {{$.PortalName}}</span></a></td>
</tr>
</table>
</td>

View File

@@ -20,5 +20,5 @@ You can download and install the WireGuard VPN client from:
https://www.wireguard.com/install/
This mail was generated using WireGuard Portal.
This mail was generated by {{$.PortalName}}.
{{$.PortalUrl}}

View File

@@ -4,25 +4,23 @@ import (
"context"
"fmt"
"log/slog"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"sync"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/h44z/wg-portal/internal/lowlevel"
)
// region dependencies
type ControllerManager interface {
// GetController returns the controller for the given interface.
GetController(iface domain.Interface) domain.InterfaceController
}
type InterfaceAndPeerDatabaseRepo interface {
// GetAllInterfaces returns all interfaces
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
// GetInterfacePeers returns all peers for a given interface
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
// GetInterface returns the interface with the given identifier.
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}
type EventBus interface {
@@ -30,6 +28,13 @@ type EventBus interface {
Subscribe(topic string, fn interface{}) error
}
type RoutesController interface {
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error
}
// endregion dependencies
type routeRuleInfo struct {
@@ -45,28 +50,27 @@ type routeRuleInfo struct {
type Manager struct {
cfg *config.Config
bus EventBus
wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient
db InterfaceAndPeerDatabaseRepo
bus EventBus
db InterfaceAndPeerDatabaseRepo
wgController ControllerManager
mux *sync.Mutex
}
// NewRouteManager creates a new route manager instance.
func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
wg, err := wgctrl.New()
if err != nil {
panic("failed to init wgctrl: " + err.Error())
}
nl := &lowlevel.NetlinkManager{}
func NewRouteManager(
cfg *config.Config,
bus EventBus,
db InterfaceAndPeerDatabaseRepo,
wgController ControllerManager,
) (*Manager, error) {
m := &Manager{
cfg: cfg,
bus: bus,
db: db,
wg: wg,
nl: nl,
db: db,
wgController: wgController,
mux: &sync.Mutex{},
}
m.connectToMessageBus()
@@ -85,419 +89,82 @@ func (m Manager) StartBackgroundJobs(_ context.Context) {
// this is a no-op for now
}
func (m Manager) handleRouteUpdateEvent(srcDescription string) {
slog.Debug("handling route update event", "source", srcDescription)
func (m Manager) handleRouteUpdateEvent(info domain.RoutingTableInfo) {
m.mux.Lock() // ensure that only one route update is processed at a time
defer m.mux.Unlock()
err := m.syncRoutes(context.Background())
if err != nil {
slog.Error("failed to synchronize routes",
"source", srcDescription,
"error", err)
slog.Debug("handling route update event", "info", info.String())
if !info.ManagementEnabled() {
return // route management disabled
}
slog.Debug("routes synchronized", "source", srcDescription)
err := m.syncRoutes(context.Background(), info)
if err != nil {
slog.Error("failed to synchronize routes",
"info", info.String(), "error", err)
return
}
slog.Debug("routes synchronized", "info", info.String())
}
func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) {
m.mux.Lock() // ensure that only one route update is processed at a time
defer m.mux.Unlock()
slog.Debug("handling route remove event", "info", info.String())
if !info.ManagementEnabled() {
return // route management disabled
}
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V4); err != nil {
slog.Error("failed to remove v4 fwmark rules", "error", err)
}
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V6); err != nil {
slog.Error("failed to remove v6 fwmark rules", "error", err)
}
slog.Debug("routes removed", "table", info.String())
}
func (m Manager) syncRoutes(ctx context.Context) error {
interfaces, err := m.db.GetAllInterfaces(ctx)
err := m.removeRoutes(context.Background(), info)
if err != nil {
return fmt.Errorf("failed to find all interfaces: %w", err)
slog.Error("failed to synchronize routes",
"info", info.String(), "error", err)
return
}
rules := map[int][]routeRuleInfo{
netlink.FAMILY_V4: nil,
netlink.FAMILY_V6: nil,
}
for _, iface := range interfaces {
if iface.IsDisabled() {
continue // disabled interface does not need route entries
}
if !iface.ManageRoutingTable() {
continue
}
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return fmt.Errorf("failed to find peers for %s: %w", iface.Identifier, err)
}
allowedIPs := iface.GetAllowedIPs(peers)
defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs)
link, err := m.nl.LinkByName(string(iface.Identifier))
if err != nil {
return fmt.Errorf("failed to find physical link for %s: %w", iface.Identifier, err)
}
table, fwmark, err := m.getRoutingTableAndFwMark(&iface, link)
if err != nil {
return fmt.Errorf("failed to get table and fwmark for %s: %w", iface.Identifier, err)
}
if err := m.setInterfaceRoutes(link, table, allowedIPs); err != nil {
return fmt.Errorf("failed to set routes for %s: %w", iface.Identifier, err)
}
if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V4, allowedIPs); err != nil {
return fmt.Errorf("failed to remove deprecated v4 routes for %s: %w", iface.Identifier, err)
}
if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V6, allowedIPs); err != nil {
return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err)
}
if table != 0 {
rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], routeRuleInfo{
ifaceId: iface.Identifier,
fwMark: fwmark,
table: table,
family: netlink.FAMILY_V4,
hasDefault: defRouteV4,
})
}
if table != 0 {
rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], routeRuleInfo{
ifaceId: iface.Identifier,
fwMark: fwmark,
table: table,
family: netlink.FAMILY_V6,
hasDefault: defRouteV6,
})
}
}
return m.syncRouteRules(rules)
slog.Debug("routes removed", "info", info.String())
}
func (m Manager) syncRouteRules(allRules map[int][]routeRuleInfo) error {
for family, rules := range allRules {
// update fwmark rules
if err := m.setFwMarkRules(rules, family); err != nil {
return err
}
// update main rule
if err := m.setMainRule(rules, family); err != nil {
return err
}
// cleanup old main rules
if err := m.cleanupMainRule(rules, family); err != nil {
return err
}
}
return nil
}
func (m Manager) setFwMarkRules(rules []routeRuleInfo, family int) error {
for _, rule := range rules {
existingRules, err := m.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
ruleExists := false
for _, existingRule := range existingRules {
if rule.fwMark == existingRule.Mark && rule.table == existingRule.Table {
ruleExists = true
break
}
}
if ruleExists {
continue // rule already exists, no need to recreate it
}
// create missing rule
if err := m.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: rule.table,
Mark: rule.fwMark,
Invert: true,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Priority: m.getRulePriority(existingRules),
Mask: nil,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for fwmark %d and table %d: %w", rule.fwMark, rule.table, err)
}
}
return nil
}
func (m Manager) removeFwMarkRules(fwmark uint32, table int, family int) error {
existingRules, err := m.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
for _, existingRule := range existingRules {
if fwmark == existingRule.Mark && table == existingRule.Table {
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
if err := m.nl.RuleDel(&existingRule); err != nil {
return fmt.Errorf("failed to delete fwmark rule: %w", err)
}
}
}
return nil
}
func (m Manager) setMainRule(rules []routeRuleInfo, family int) error {
shouldHaveMainRule := false
for _, rule := range rules {
if rule.hasDefault == true {
shouldHaveMainRule = true
break
}
}
if !shouldHaveMainRule {
func (m Manager) syncRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
rc, ok := m.wgController.GetController(info.Interface).(RoutesController)
if !ok {
slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier)
return nil
}
existingRules, err := m.nl.RuleList(family)
if !info.Interface.ManageRoutingTable() {
slog.Debug("interface does not manage routing table, skipping route update",
"interface", info.Interface.Identifier)
return nil
}
err := rc.SetRoutes(ctx, info)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
return fmt.Errorf("failed to set routes for interface %s: %w", info.Interface.Identifier, err)
}
ruleExists := false
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
ruleExists = true
break
}
}
if ruleExists {
return nil // rule already exists, skip re-creation
}
if err := m.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: m.getMainRulePriority(existingRules),
Mark: 0,
Mask: nil,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for main table: %w", err)
}
return nil
}
func (m Manager) cleanupMainRule(rules []routeRuleInfo, family int) error {
existingRules, err := m.nl.RuleList(family)
func (m Manager) removeRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
rc, ok := m.wgController.GetController(info.Interface).(RoutesController)
if !ok {
slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier)
return nil
}
if !info.Interface.ManageRoutingTable() {
slog.Debug("interface does not manage routing table, skipping route removal",
"interface", info.Interface.Identifier)
return nil
}
err := rc.RemoveRoutes(ctx, info)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
shouldHaveMainRule := false
for _, rule := range rules {
if rule.hasDefault == true {
shouldHaveMainRule = true
break
}
}
mainRules := 0
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
mainRules++
}
}
removalCount := 0
if mainRules > 1 {
removalCount = mainRules - 1 // we only want one single rule
}
if !shouldHaveMainRule {
removalCount = mainRules
}
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
if removalCount > 0 {
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
if err := m.nl.RuleDel(&existingRule); err != nil {
return fmt.Errorf("failed to delete main rule: %w", err)
}
removalCount--
}
}
}
return nil
}
func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
prio := m.cfg.Advanced.RulePrioOffset
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == prio {
isFresh = false
break
}
}
if isFresh {
break
} else {
prio++
}
}
return prio
}
func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
prio := 32700 // linux main rule has a prio of 32766
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == prio {
isFresh = false
break
}
}
if isFresh {
break
} else {
prio--
}
}
return prio
}
func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error {
for _, allowedIP := range allowedIPs {
err := m.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: allowedIP.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
})
if err != nil {
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
}
}
return nil
}
func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIPs []domain.Cidr) error {
rawRoutes, err := m.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
if err != nil {
return fmt.Errorf("failed to fetch raw routes: %w", err)
}
for _, rawRoute := range rawRoutes {
if rawRoute.Dst == nil { // handle default route
var netlinkAddr domain.Cidr
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
rawRoute.Dst = netlinkAddr.IpNet()
}
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
remove := true
for _, allowedIP := range allowedIPs {
if netlinkAddr == allowedIP {
remove = false
break
}
}
if !remove {
continue
}
err := m.nl.RouteDel(&rawRoute)
if err != nil {
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
}
return fmt.Errorf("failed to remove routes for interface %s: %w", info.Interface.Identifier, err)
}
return nil
}
func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, link netlink.Link) (
table int,
fwmark uint32,
err error,
) {
table = iface.GetRoutingTable()
fwmark = iface.FirewallMark
if fwmark == 0 {
// generate a new (temporary) firewall mark based on the interface index
fwmark = uint32(m.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
slog.Debug("using fwmark to handle routes",
"interface", iface.Identifier,
"fwmark", fwmark)
// apply the temporary fwmark to the wireguard interface
err = m.setFwMark(iface.Identifier, int(fwmark))
}
if table == 0 {
table = int(fwmark) // generate a new routing table base on interface index
slog.Debug("using routing table to handle default routes",
"interface", iface.Identifier,
"table", table)
}
return
}
func (m Manager) setFwMark(id domain.InterfaceIdentifier, fwmark int) error {
err := m.wg.ConfigureDevice(string(id), wgtypes.Config{
FirewallMark: &fwmark,
})
if err != nil {
return fmt.Errorf("failed to update fwmark to: %d: %w", fwmark, err)
}
return nil
}
func (m Manager) containsDefaultRoute(allowedIPs []domain.Cidr) (ipV4, ipV6 bool) {
for _, allowedIP := range allowedIPs {
if ipV4 && ipV6 {
break // speed up
}
if allowedIP.Prefix().Bits() == 0 {
if allowedIP.IsV4() {
ipV4 = true
} else {
ipV6 = true
}
}
}
return
}

View File

@@ -25,6 +25,8 @@ type UserDatabaseRepo interface {
GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
// GetUserByEmail returns the user with the given email address.
GetUserByEmail(ctx context.Context, email string) (*domain.User, error)
// GetUserByWebAuthnCredential returns the user for the given WebAuthn credential ID.
GetUserByWebAuthnCredential(ctx context.Context, credentialIdBase64 string) (*domain.User, error)
// GetAllUsers returns all users.
GetAllUsers(ctx context.Context) ([]domain.User, error)
// FindUsers returns all users matching the search string.
@@ -129,6 +131,25 @@ func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User
return user, nil
}
// GetUserByWebAuthnCredential returns the user for the given WebAuthn credential.
func (m Manager) GetUserByWebAuthnCredential(ctx context.Context, credentialIdBase64 string) (*domain.User, error) {
user, err := m.users.GetUserByWebAuthnCredential(ctx, credentialIdBase64)
if err != nil {
return nil, fmt.Errorf("unable to load user for webauthn credential %s: %w", credentialIdBase64, err)
}
if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil {
return nil, err
}
peers, _ := m.peers.GetUserPeers(ctx, user.Identifier) // ignore error, list will be empty in error case
user.LinkedPeerCount = len(peers)
return user, nil
}
// GetAllUsers returns all users.
func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
@@ -343,6 +364,10 @@ func (m Manager) validateModifications(ctx context.Context, old, new *domain.Use
return errors.Join(fmt.Errorf("no access: %w", err), domain.ErrInvalidData)
}
if err := new.HasWeakPassword(m.cfg.Auth.MinPasswordLength); err != nil {
return errors.Join(fmt.Errorf("password too weak: %w", err), domain.ErrInvalidData)
}
if currentUser.Id == old.Identifier && old.IsAdmin && !new.IsAdmin {
return fmt.Errorf("cannot remove own admin rights: %w", domain.ErrInvalidData)
}
@@ -397,7 +422,11 @@ func (m Manager) validateCreation(ctx context.Context, new *domain.User) error {
// database users must have a password
if new.Source == domain.UserSourceDatabase && string(new.Password) == "" {
return fmt.Errorf("invalid password: %w", domain.ErrInvalidData)
return fmt.Errorf("missing password: %w", domain.ErrInvalidData)
}
if err := new.HasWeakPassword(m.cfg.Auth.MinPasswordLength); err != nil {
return errors.Join(fmt.Errorf("password too weak: %w", err), domain.ErrInvalidData)
}
return nil

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/webhooks/models"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
@@ -64,6 +65,7 @@ func (m Manager) connectToMessageBus() {
_ = m.bus.Subscribe(app.TopicPeerCreated, m.handlePeerCreateEvent)
_ = m.bus.Subscribe(app.TopicPeerUpdated, m.handlePeerUpdateEvent)
_ = m.bus.Subscribe(app.TopicPeerDeleted, m.handlePeerDeleteEvent)
_ = m.bus.Subscribe(app.TopicPeerStateChanged, m.handlePeerStateChangeEvent)
_ = m.bus.Subscribe(app.TopicInterfaceCreated, m.handleInterfaceCreateEvent)
_ = m.bus.Subscribe(app.TopicInterfaceUpdated, m.handleInterfaceUpdateEvent)
@@ -100,39 +102,47 @@ func (m Manager) sendWebhook(ctx context.Context, data io.Reader) error {
}
func (m Manager) handleUserCreateEvent(user domain.User) {
m.handleGenericEvent(WebhookEventCreate, user)
m.handleGenericEvent(WebhookEventCreate, models.NewUser(user))
}
func (m Manager) handleUserUpdateEvent(user domain.User) {
m.handleGenericEvent(WebhookEventUpdate, user)
m.handleGenericEvent(WebhookEventUpdate, models.NewUser(user))
}
func (m Manager) handleUserDeleteEvent(user domain.User) {
m.handleGenericEvent(WebhookEventDelete, user)
m.handleGenericEvent(WebhookEventDelete, models.NewUser(user))
}
func (m Manager) handlePeerCreateEvent(peer domain.Peer) {
m.handleGenericEvent(WebhookEventCreate, peer)
m.handleGenericEvent(WebhookEventCreate, models.NewPeer(peer))
}
func (m Manager) handlePeerUpdateEvent(peer domain.Peer) {
m.handleGenericEvent(WebhookEventUpdate, peer)
m.handleGenericEvent(WebhookEventUpdate, models.NewPeer(peer))
}
func (m Manager) handlePeerDeleteEvent(peer domain.Peer) {
m.handleGenericEvent(WebhookEventDelete, peer)
m.handleGenericEvent(WebhookEventDelete, models.NewPeer(peer))
}
func (m Manager) handleInterfaceCreateEvent(iface domain.Interface) {
m.handleGenericEvent(WebhookEventCreate, iface)
m.handleGenericEvent(WebhookEventCreate, models.NewInterface(iface))
}
func (m Manager) handleInterfaceUpdateEvent(iface domain.Interface) {
m.handleGenericEvent(WebhookEventUpdate, iface)
m.handleGenericEvent(WebhookEventUpdate, models.NewInterface(iface))
}
func (m Manager) handleInterfaceDeleteEvent(iface domain.Interface) {
m.handleGenericEvent(WebhookEventDelete, iface)
m.handleGenericEvent(WebhookEventDelete, models.NewInterface(iface))
}
func (m Manager) handlePeerStateChangeEvent(peerStatus domain.PeerStatus, peer domain.Peer) {
if peerStatus.IsConnected {
m.handleGenericEvent(WebhookEventConnect, models.NewPeerMetrics(peerStatus, peer))
} else {
m.handleGenericEvent(WebhookEventDisconnect, models.NewPeerMetrics(peerStatus, peer))
}
}
func (m Manager) handleGenericEvent(action WebhookEvent, payload any) {
@@ -168,15 +178,18 @@ func (m Manager) createWebhookData(action WebhookEvent, payload any) (*WebhookDa
}
switch v := payload.(type) {
case domain.User:
case models.User:
d.Entity = WebhookEntityUser
d.Identifier = string(v.Identifier)
case domain.Peer:
d.Identifier = v.Identifier
case models.Peer:
d.Entity = WebhookEntityPeer
d.Identifier = string(v.Identifier)
case domain.Interface:
d.Identifier = v.Identifier
case models.Interface:
d.Entity = WebhookEntityInterface
d.Identifier = string(v.Identifier)
d.Identifier = v.Identifier
case models.PeerMetrics:
d.Entity = WebhookEntityPeerMetric
d.Identifier = v.Peer.Identifier
default:
return nil, fmt.Errorf("unsupported payload type: %T", v)
}

View File

@@ -34,15 +34,18 @@ func (d *WebhookData) Serialize() (io.Reader, error) {
type WebhookEntity = string
const (
WebhookEntityUser WebhookEntity = "user"
WebhookEntityPeer WebhookEntity = "peer"
WebhookEntityInterface WebhookEntity = "interface"
WebhookEntityUser WebhookEntity = "user"
WebhookEntityPeer WebhookEntity = "peer"
WebhookEntityPeerMetric WebhookEntity = "peer_metric"
WebhookEntityInterface WebhookEntity = "interface"
)
type WebhookEvent = string
const (
WebhookEventCreate WebhookEvent = "create"
WebhookEventUpdate WebhookEvent = "update"
WebhookEventDelete WebhookEvent = "delete"
WebhookEventCreate WebhookEvent = "create"
WebhookEventUpdate WebhookEvent = "update"
WebhookEventDelete WebhookEvent = "delete"
WebhookEventConnect WebhookEvent = "connect"
WebhookEventDisconnect WebhookEvent = "disconnect"
)

View File

@@ -0,0 +1,99 @@
package models
import (
"time"
"github.com/h44z/wg-portal/internal/domain"
)
// Interface represents an interface model for webhooks. For details about the fields, see the domain.Interface struct.
type Interface struct {
CreatedBy string `json:"CreatedBy"`
UpdatedBy string `json:"UpdatedBy"`
CreatedAt time.Time `json:"CreatedAt"`
UpdatedAt time.Time `json:"UpdatedAt"`
Identifier string `json:"Identifier"`
PrivateKey string `json:"PrivateKey"`
PublicKey string `json:"PublicKey"`
ListenPort int `json:"ListenPort"`
Addresses []string `json:"Addresses"`
DnsStr string `json:"DnsStr"`
DnsSearchStr string `json:"DnsSearchStr"`
Mtu int `json:"Mtu"`
FirewallMark uint32 `json:"FirewallMark"`
RoutingTable string `json:"RoutingTable"`
PreUp string `json:"PreUp"`
PostUp string `json:"PostUp"`
PreDown string `json:"PreDown"`
PostDown string `json:"PostDown"`
SaveConfig bool `json:"SaveConfig"`
DisplayName string `json:"DisplayName"`
Type string `json:"Type"`
DriverType string `json:"DriverType"`
Disabled *time.Time `json:"Disabled,omitempty"`
DisabledReason string `json:"DisabledReason,omitempty"`
PeerDefNetworkStr string `json:"PeerDefNetworkStr,omitempty"`
PeerDefDnsStr string `json:"PeerDefDnsStr,omitempty"`
PeerDefDnsSearchStr string `json:"PeerDefDnsSearchStr,omitempty"`
PeerDefEndpoint string `json:"PeerDefEndpoint,omitempty"`
PeerDefAllowedIPsStr string `json:"PeerDefAllowedIPsStr,omitempty"`
PeerDefMtu int `json:"PeerDefMtu,omitempty"`
PeerDefPersistentKeepalive int `json:"PeerDefPersistentKeepalive,omitempty"`
PeerDefFirewallMark uint32 `json:"PeerDefFirewallMark,omitempty"`
PeerDefRoutingTable string `json:"PeerDefRoutingTable,omitempty"`
PeerDefPreUp string `json:"PeerDefPreUp,omitempty"`
PeerDefPostUp string `json:"PeerDefPostUp,omitempty"`
PeerDefPreDown string `json:"PeerDefPreDown,omitempty"`
PeerDefPostDown string `json:"PeerDefPostDown,omitempty"`
}
// NewInterface creates a new Interface model from a domain.Interface.
func NewInterface(src domain.Interface) Interface {
return Interface{
CreatedBy: src.CreatedBy,
UpdatedBy: src.UpdatedBy,
CreatedAt: src.CreatedAt,
UpdatedAt: src.UpdatedAt,
Identifier: string(src.Identifier),
PrivateKey: src.KeyPair.PrivateKey,
PublicKey: src.KeyPair.PublicKey,
ListenPort: src.ListenPort,
Addresses: domain.CidrsToStringSlice(src.Addresses),
DnsStr: src.DnsStr,
DnsSearchStr: src.DnsSearchStr,
Mtu: src.Mtu,
FirewallMark: src.FirewallMark,
RoutingTable: src.RoutingTable,
PreUp: src.PreUp,
PostUp: src.PostUp,
PreDown: src.PreDown,
PostDown: src.PostDown,
SaveConfig: src.SaveConfig,
DisplayName: string(src.Identifier),
Type: string(src.Type),
DriverType: src.DriverType,
Disabled: src.Disabled,
DisabledReason: src.DisabledReason,
PeerDefNetworkStr: src.PeerDefNetworkStr,
PeerDefDnsStr: src.PeerDefDnsStr,
PeerDefDnsSearchStr: src.PeerDefDnsSearchStr,
PeerDefEndpoint: src.PeerDefEndpoint,
PeerDefAllowedIPsStr: src.PeerDefAllowedIPsStr,
PeerDefMtu: src.PeerDefMtu,
PeerDefPersistentKeepalive: src.PeerDefPersistentKeepalive,
PeerDefFirewallMark: src.PeerDefFirewallMark,
PeerDefRoutingTable: src.PeerDefRoutingTable,
PeerDefPreUp: src.PeerDefPreUp,
PeerDefPostUp: src.PeerDefPostUp,
PeerDefPreDown: src.PeerDefPreDown,
PeerDefPostDown: src.PeerDefPostDown,
}
}

View File

@@ -0,0 +1,89 @@
package models
import (
"time"
"github.com/h44z/wg-portal/internal/domain"
)
// Peer represents a peer model for webhooks. For details about the fields, see the domain.Peer struct.
type Peer struct {
CreatedBy string `json:"CreatedBy"`
UpdatedBy string `json:"UpdatedBy"`
CreatedAt time.Time `json:"CreatedAt"`
UpdatedAt time.Time `json:"UpdatedAt"`
Endpoint string `json:"Endpoint"`
EndpointPublicKey string `json:"EndpointPublicKey"`
AllowedIPsStr string `json:"AllowedIPsStr"`
ExtraAllowedIPsStr string `json:"ExtraAllowedIPsStr"`
PresharedKey string `json:"PresharedKey"`
PersistentKeepalive int `json:"PersistentKeepalive"`
DisplayName string `json:"DisplayName"`
Identifier string `json:"Identifier"`
UserIdentifier string `json:"UserIdentifier"`
InterfaceIdentifier string `json:"InterfaceIdentifier"`
Disabled *time.Time `json:"Disabled,omitempty"`
DisabledReason string `json:"DisabledReason,omitempty"`
ExpiresAt *time.Time `json:"ExpiresAt,omitempty"`
Notes string `json:"Notes,omitempty"`
AutomaticallyCreated bool `json:"AutomaticallyCreated"`
PrivateKey string `json:"PrivateKey"`
PublicKey string `json:"PublicKey"`
InterfaceType string `json:"InterfaceType"`
Addresses []string `json:"Addresses"`
CheckAliveAddress string `json:"CheckAliveAddress"`
DnsStr string `json:"DnsStr"`
DnsSearchStr string `json:"DnsSearchStr"`
Mtu int `json:"Mtu"`
FirewallMark uint32 `json:"FirewallMark,omitempty"`
RoutingTable string `json:"RoutingTable,omitempty"`
PreUp string `json:"PreUp,omitempty"`
PostUp string `json:"PostUp,omitempty"`
PreDown string `json:"PreDown,omitempty"`
PostDown string `json:"PostDown,omitempty"`
}
// NewPeer creates a new Peer model from a domain.Peer.
func NewPeer(src domain.Peer) Peer {
return Peer{
CreatedBy: src.CreatedBy,
UpdatedBy: src.UpdatedBy,
CreatedAt: src.CreatedAt,
UpdatedAt: src.UpdatedAt,
Endpoint: src.Endpoint.GetValue(),
EndpointPublicKey: src.EndpointPublicKey.GetValue(),
AllowedIPsStr: src.AllowedIPsStr.GetValue(),
ExtraAllowedIPsStr: src.ExtraAllowedIPsStr,
PresharedKey: string(src.PresharedKey),
PersistentKeepalive: src.PersistentKeepalive.GetValue(),
DisplayName: src.DisplayName,
Identifier: string(src.Identifier),
UserIdentifier: string(src.UserIdentifier),
InterfaceIdentifier: string(src.InterfaceIdentifier),
Disabled: src.Disabled,
DisabledReason: src.DisabledReason,
ExpiresAt: src.ExpiresAt,
Notes: src.Notes,
AutomaticallyCreated: src.AutomaticallyCreated,
PrivateKey: src.Interface.KeyPair.PrivateKey,
PublicKey: src.Interface.KeyPair.PublicKey,
InterfaceType: string(src.Interface.Type),
Addresses: domain.CidrsToStringSlice(src.Interface.Addresses),
CheckAliveAddress: src.Interface.CheckAliveAddress,
DnsStr: src.Interface.DnsStr.GetValue(),
DnsSearchStr: src.Interface.DnsSearchStr.GetValue(),
Mtu: src.Interface.Mtu.GetValue(),
FirewallMark: src.Interface.FirewallMark.GetValue(),
RoutingTable: src.Interface.RoutingTable.GetValue(),
PreUp: src.Interface.PreUp.GetValue(),
PostUp: src.Interface.PostUp.GetValue(),
PreDown: src.Interface.PreDown.GetValue(),
PostDown: src.Interface.PostDown.GetValue(),
}
}

View File

@@ -0,0 +1,50 @@
package models
import (
"time"
"github.com/h44z/wg-portal/internal/domain"
)
// PeerMetrics represents a peer metrics model for webhooks.
// For details about the fields, see the domain.PeerStatus and domain.Peer structs.
type PeerMetrics struct {
Status PeerStatus `json:"Status"`
Peer Peer `json:"Peer"`
}
// PeerStatus represents the status of a peer for webhooks.
// For details about the fields, see the domain.PeerStatus struct.
type PeerStatus struct {
UpdatedAt time.Time `json:"UpdatedAt"`
IsConnected bool `json:"IsConnected"`
IsPingable bool `json:"IsPingable"`
LastPing *time.Time `json:"LastPing,omitempty"`
BytesReceived uint64 `json:"BytesReceived"`
BytesTransmitted uint64 `json:"BytesTransmitted"`
Endpoint string `json:"Endpoint"`
LastHandshake *time.Time `json:"LastHandshake,omitempty"`
LastSessionStart *time.Time `json:"LastSessionStart,omitempty"`
}
// NewPeerMetrics creates a new PeerMetrics model from the domain.PeerStatus and domain.Peer models.
func NewPeerMetrics(status domain.PeerStatus, peer domain.Peer) PeerMetrics {
return PeerMetrics{
Status: PeerStatus{
UpdatedAt: status.UpdatedAt,
IsConnected: status.IsConnected,
IsPingable: status.IsPingable,
LastPing: status.LastPing,
BytesReceived: status.BytesReceived,
BytesTransmitted: status.BytesTransmitted,
Endpoint: status.Endpoint,
LastHandshake: status.LastHandshake,
LastSessionStart: status.LastSessionStart,
},
Peer: NewPeer(peer),
}
}

View File

@@ -0,0 +1,56 @@
package models
import (
"time"
"github.com/h44z/wg-portal/internal/domain"
)
// User represents a user model for webhooks. For details about the fields, see the domain.User struct.
type User struct {
CreatedBy string `json:"CreatedBy"`
UpdatedBy string `json:"UpdatedBy"`
CreatedAt time.Time `json:"CreatedAt"`
UpdatedAt time.Time `json:"UpdatedAt"`
Identifier string `json:"Identifier"`
Email string `json:"Email"`
Source string `json:"Source"`
ProviderName string `json:"ProviderName"`
IsAdmin bool `json:"IsAdmin"`
Firstname string `json:"Firstname,omitempty"`
Lastname string `json:"Lastname,omitempty"`
Phone string `json:"Phone,omitempty"`
Department string `json:"Department,omitempty"`
Notes string `json:"Notes,omitempty"`
Disabled *time.Time `json:"Disabled,omitempty"`
DisabledReason string `json:"DisabledReason,omitempty"`
Locked *time.Time `json:"Locked,omitempty"`
LockedReason string `json:"LockedReason,omitempty"`
}
// NewUser creates a new User model from a domain.User
func NewUser(src domain.User) User {
return User{
CreatedBy: src.CreatedBy,
UpdatedBy: src.UpdatedBy,
CreatedAt: src.CreatedAt,
UpdatedAt: src.UpdatedAt,
Identifier: string(src.Identifier),
Email: src.Email,
Source: string(src.Source),
ProviderName: src.ProviderName,
IsAdmin: src.IsAdmin,
Firstname: src.Firstname,
Lastname: src.Lastname,
Phone: src.Phone,
Department: src.Department,
Notes: src.Notes,
Disabled: src.Disabled,
DisabledReason: src.DisabledReason,
Locked: src.Locked,
LockedReason: src.LockedReason,
}
}

View File

@@ -0,0 +1,142 @@
package wireguard
import (
"fmt"
"log/slog"
"maps"
"slices"
"github.com/h44z/wg-portal/internal/adapters/wgcontroller"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
type backendInstance struct {
Config config.BackendBase // Config is the configuration for the backend instance.
Implementation domain.InterfaceController
}
type ControllerManager struct {
cfg *config.Config
controllers map[domain.InterfaceBackend]backendInstance
}
func NewControllerManager(cfg *config.Config) (*ControllerManager, error) {
c := &ControllerManager{
cfg: cfg,
controllers: make(map[domain.InterfaceBackend]backendInstance),
}
err := c.init()
if err != nil {
return nil, err
}
return c, nil
}
func (c *ControllerManager) init() error {
if err := c.registerLocalController(); err != nil {
return err
}
if err := c.registerMikrotikControllers(); err != nil {
return err
}
c.logRegisteredControllers()
return nil
}
func (c *ControllerManager) registerLocalController() error {
localController, err := wgcontroller.NewLocalController(c.cfg)
if err != nil {
return fmt.Errorf("failed to create local WireGuard controller: %w", err)
}
c.controllers[config.LocalBackendName] = backendInstance{
Config: config.BackendBase{
Id: config.LocalBackendName,
DisplayName: "Local WireGuard Controller",
IgnoredInterfaces: c.cfg.Backend.IgnoredLocalInterfaces,
},
Implementation: localController,
}
return nil
}
func (c *ControllerManager) registerMikrotikControllers() error {
for _, backendConfig := range c.cfg.Backend.Mikrotik {
if backendConfig.Id == config.LocalBackendName {
slog.Warn("skipping registration of Mikrotik controller with reserved ID", "id", config.LocalBackendName)
continue
}
controller, err := wgcontroller.NewMikrotikController(c.cfg, &backendConfig)
if err != nil {
return fmt.Errorf("failed to create Mikrotik controller for backend %s: %w", backendConfig.Id, err)
}
c.controllers[domain.InterfaceBackend(backendConfig.Id)] = backendInstance{
Config: backendConfig.BackendBase,
Implementation: controller,
}
}
return nil
}
func (c *ControllerManager) logRegisteredControllers() {
for backend, controller := range c.controllers {
slog.Debug("backend controller registered",
"backend", backend, "type", fmt.Sprintf("%T", controller.Implementation))
}
}
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) domain.InterfaceController {
return c.getController(backend, "").Implementation
}
func (c *ControllerManager) GetController(iface domain.Interface) domain.InterfaceController {
return c.getController(iface.Backend, iface.Identifier).Implementation
}
func (c *ControllerManager) getController(
backend domain.InterfaceBackend,
ifaceId domain.InterfaceIdentifier,
) backendInstance {
if backend == "" {
// If no backend is specified, use the local controller.
// This might be the case for interfaces created in previous WireGuard Portal versions.
backend = config.LocalBackendName
}
controller, exists := c.controllers[backend]
if !exists {
controller, exists = c.controllers[config.LocalBackendName] // Fallback to local controller
if !exists {
// If the local controller is also not found, panic
panic(fmt.Sprintf("%s interface controller for backend %s not found", ifaceId, backend))
}
slog.Warn("controller for backend not found, using local controller",
"backend", backend, "interface", ifaceId)
}
return controller
}
func (c *ControllerManager) GetAllControllers() []backendInstance {
var backendInstances = make([]backendInstance, 0, len(c.controllers))
for instance := range maps.Values(c.controllers) {
backendInstances = append(backendInstances, instance)
}
return backendInstances
}
func (c *ControllerManager) GetControllerNames() []config.BackendBase {
var names []config.BackendBase
for _, id := range slices.Sorted(maps.Keys(c.controllers)) {
names = append(names, c.controllers[id].Config)
}
return names
}

View File

@@ -6,8 +6,6 @@ import (
"sync"
"time"
probing "github.com/prometheus-community/pro-bing"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
@@ -30,11 +28,6 @@ type StatisticsDatabaseRepo interface {
DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
}
type StatisticsInterfaceController interface {
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
}
type StatisticsMetricsServer interface {
UpdateInterfaceMetrics(status domain.InterfaceStatus)
UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus)
@@ -43,6 +36,13 @@ type StatisticsMetricsServer interface {
type StatisticsEventBus interface {
// Subscribe subscribes to a topic
Subscribe(topic string, fn interface{}) error
// Publish sends a message to the message bus.
Publish(topic string, args ...any)
}
type pingJob struct {
Peer domain.Peer
Backend domain.InterfaceBackend
}
type StatisticsCollector struct {
@@ -50,11 +50,13 @@ type StatisticsCollector struct {
bus StatisticsEventBus
pingWaitGroup sync.WaitGroup
pingJobs chan domain.Peer
pingJobs chan pingJob
db StatisticsDatabaseRepo
wg StatisticsInterfaceController
wg *ControllerManager
ms StatisticsMetricsServer
peerChangeEvent chan domain.PeerIdentifier
}
// NewStatisticsCollector creates a new statistics collector.
@@ -62,7 +64,7 @@ func NewStatisticsCollector(
cfg *config.Config,
bus StatisticsEventBus,
db StatisticsDatabaseRepo,
wg StatisticsInterfaceController,
wg *ControllerManager,
ms StatisticsMetricsServer,
) (*StatisticsCollector, error) {
c := &StatisticsCollector{
@@ -113,7 +115,7 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
}
for _, in := range interfaces {
physicalInterface, err := c.wg.GetInterface(ctx, in.Identifier)
physicalInterface, err := c.wg.GetController(in).GetInterface(ctx, in.Identifier)
if err != nil {
slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier,
"error", err)
@@ -165,14 +167,18 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
}
for _, in := range interfaces {
peers, err := c.wg.GetPeers(ctx, in.Identifier)
peers, err := c.wg.GetController(in).GetPeers(ctx, in.Identifier)
if err != nil {
slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err)
continue
}
for _, peer := range peers {
var connectionStateChanged bool
var newPeerStatus domain.PeerStatus
err = c.db.UpdatePeerStatus(ctx, peer.Identifier,
func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
wasConnected := p.IsConnected
var lastHandshake *time.Time
if !peer.LastHandshake.IsZero() {
lastHandshake = &peer.LastHandshake
@@ -186,6 +192,13 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
p.BytesTransmitted = peer.BytesDownload // store bytes that where received from the peer and sent by the server
p.Endpoint = peer.Endpoint
p.LastHandshake = lastHandshake
p.CalcConnected()
if wasConnected != p.IsConnected {
slog.Debug("peer connection state changed", "peer", peer.Identifier, "connected", p.IsConnected)
connectionStateChanged = true
newPeerStatus = *p // store new status for event publishing
}
// Update prometheus metrics
go c.updatePeerMetrics(ctx, *p)
@@ -197,6 +210,17 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
} else {
slog.Debug("updated peer status", "peer", peer.Identifier)
}
if connectionStateChanged {
peerModel, err := c.db.GetPeer(ctx, peer.Identifier)
if err != nil {
slog.Error("failed to fetch peer for data collection", "peer", peer.Identifier, "error",
err)
continue
}
// publish event if connection state changed
c.bus.Publish(app.TopicPeerStateChanged, newPeerStatus, *peerModel)
}
}
}
}
@@ -245,7 +269,7 @@ func (c *StatisticsCollector) startPingWorkers(ctx context.Context) {
c.pingWaitGroup = sync.WaitGroup{}
c.pingWaitGroup.Add(c.cfg.Statistics.PingCheckWorkers)
c.pingJobs = make(chan domain.Peer, c.cfg.Statistics.PingCheckWorkers)
c.pingJobs = make(chan pingJob, c.cfg.Statistics.PingCheckWorkers)
// start workers
for i := 0; i < c.cfg.Statistics.PingCheckWorkers; i++ {
@@ -288,7 +312,10 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
continue
}
for _, peer := range peers {
c.pingJobs <- peer
c.pingJobs <- pingJob{
Peer: peer,
Backend: in.Backend,
}
}
}
}
@@ -297,13 +324,21 @@ func (c *StatisticsCollector) enqueuePingChecks(ctx context.Context) {
func (c *StatisticsCollector) pingWorker(ctx context.Context) {
defer c.pingWaitGroup.Done()
for peer := range c.pingJobs {
peerPingable := c.isPeerPingable(ctx, peer)
for job := range c.pingJobs {
peer := job.Peer
backend := job.Backend
var connectionStateChanged bool
var newPeerStatus domain.PeerStatus
peerPingable := c.isPeerPingable(ctx, backend, peer)
slog.Debug("peer ping check completed", "peer", peer.Identifier, "pingable", peerPingable)
now := time.Now()
err := c.db.UpdatePeerStatus(ctx, peer.Identifier,
func(p *domain.PeerStatus) (*domain.PeerStatus, error) {
wasConnected := p.IsConnected
if peerPingable {
p.IsPingable = true
p.LastPing = &now
@@ -311,6 +346,13 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
p.IsPingable = false
p.LastPing = nil
}
p.UpdatedAt = time.Now()
p.CalcConnected()
if wasConnected != p.IsConnected {
connectionStateChanged = true
newPeerStatus = *p // store new status for event publishing
}
// Update prometheus metrics
go c.updatePeerMetrics(ctx, *p)
@@ -322,10 +364,19 @@ func (c *StatisticsCollector) pingWorker(ctx context.Context) {
} else {
slog.Debug("updated peer ping status", "peer", peer.Identifier)
}
if connectionStateChanged {
// publish event if connection state changed
c.bus.Publish(app.TopicPeerStateChanged, newPeerStatus, peer)
}
}
}
func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Peer) bool {
func (c *StatisticsCollector) isPeerPingable(
ctx context.Context,
backend domain.InterfaceBackend,
peer domain.Peer,
) bool {
if !c.cfg.Statistics.UsePingChecks {
return false
}
@@ -335,23 +386,13 @@ func (c *StatisticsCollector) isPeerPingable(ctx context.Context, peer domain.Pe
return false
}
pinger, err := probing.NewPinger(checkAddr)
stats, err := c.wg.GetControllerByName(backend).PingAddresses(ctx, checkAddr)
if err != nil {
slog.Debug("failed to instantiate pinger", "peer", peer.Identifier, "address", checkAddr, "error", err)
slog.Debug("failed to ping peer", "peer", peer.Identifier, "error", err)
return false
}
checkCount := 1
pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged)
pinger.Count = checkCount
pinger.Timeout = 2 * time.Second
err = pinger.RunWithContext(ctx) // Blocks until finished.
if err != nil {
slog.Debug("pinger for peer exited unexpectedly", "peer", peer.Identifier, "address", checkAddr, "error", err)
return false
}
stats := pinger.Statistics()
return stats.PacketsRecv == checkCount
return stats.IsPingable()
}
func (c *StatisticsCollector) updateInterfaceMetrics(status domain.InterfaceStatus) {

View File

@@ -37,29 +37,10 @@ type InterfaceAndPeerDatabaseRepo interface {
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
}
type InterfaceController interface {
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}
type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(id domain.InterfaceIdentifier) error
ExecuteInterfaceHook(ctx context.Context, id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
}
type EventBus interface {
@@ -72,11 +53,10 @@ type EventBus interface {
// endregion dependencies
type Manager struct {
cfg *config.Config
bus EventBus
db InterfaceAndPeerDatabaseRepo
wg InterfaceController
quick WgQuickController
cfg *config.Config
bus EventBus
db InterfaceAndPeerDatabaseRepo
wg *ControllerManager
userLockMap *sync.Map
}
@@ -84,8 +64,7 @@ type Manager struct {
func NewWireGuardManager(
cfg *config.Config,
bus EventBus,
wg InterfaceController,
quick WgQuickController,
wg *ControllerManager,
db InterfaceAndPeerDatabaseRepo,
) (*Manager, error) {
m := &Manager{
@@ -93,7 +72,6 @@ func NewWireGuardManager(
bus: bus,
wg: wg,
db: db,
quick: quick,
userLockMap: &sync.Map{},
}

View File

@@ -11,24 +11,10 @@ import (
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/app/audit"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// GetImportableInterfaces returns all physical interfaces that are available on the system.
// This function also returns interfaces that are already available in the database.
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return nil, err
}
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
if err != nil {
return nil, err
}
return physicalInterfaces, nil
}
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
@@ -104,52 +90,64 @@ func (m Manager) GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier)
}
// ImportNewInterfaces imports all new physical interfaces that are available on the system.
// If a filter is set, only interfaces that match the filter will be imported.
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
if err := domain.ValidateAdminAccessRights(ctx); err != nil {
return 0, err
}
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
var existingInterfaceIds []domain.InterfaceIdentifier
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return 0, err
}
// if no filter is given, exclude already existing interfaces
var excludedInterfaces []domain.InterfaceIdentifier
if len(filter) == 0 {
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil {
return 0, err
}
for _, existingInterface := range existingInterfaces {
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
}
for _, existingInterface := range existingInterfaces {
existingInterfaceIds = append(existingInterfaceIds, existingInterface.Identifier)
}
imported := 0
for _, physicalInterface := range physicalInterfaces {
if slices.Contains(excludedInterfaces, physicalInterface.Identifier) {
continue
}
if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) {
continue
}
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier)
for _, wgBackend := range m.wg.GetAllControllers() {
physicalInterfaces, err := wgBackend.Implementation.GetInterfaces(ctx)
if err != nil {
return 0, err
}
err = m.importInterface(ctx, &physicalInterface, physicalPeers)
if err != nil {
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
}
for _, physicalInterface := range physicalInterfaces {
if slices.Contains(wgBackend.Config.IgnoredInterfaces, string(physicalInterface.Identifier)) {
slog.Info("ignoring interface due to backend filter restrictions",
"interface", physicalInterface.Identifier, "filter", wgBackend.Config.IgnoredInterfaces,
"backend", wgBackend.Config.Id)
continue // skip ignored interfaces
}
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers))
imported++
if slices.Contains(existingInterfaceIds, physicalInterface.Identifier) {
continue // skip interfaces that already exist
}
if len(filter) > 0 && !slices.Contains(filter, physicalInterface.Identifier) {
slog.Info("ignoring interface due to filter restrictions",
"interface", physicalInterface.Identifier, "filter", wgBackend.Config.IgnoredInterfaces,
"backend", wgBackend.Config.Id)
continue
}
slog.Info("importing new interface",
"interface", physicalInterface.Identifier, "backend", wgBackend.Config.Id)
physicalPeers, err := wgBackend.Implementation.GetPeers(ctx, physicalInterface.Identifier)
if err != nil {
return 0, err
}
err = m.importInterface(ctx, wgBackend.Implementation, &physicalInterface, physicalPeers)
if err != nil {
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
}
slog.Info("imported new interface",
"interface", physicalInterface.Identifier, "peers", len(physicalPeers), "backend", wgBackend.Config.Id)
imported++
}
}
return imported, nil
@@ -213,9 +211,20 @@ func (m Manager) RestoreInterfaceState(
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
}
_, err = m.wg.GetInterface(ctx, iface.Identifier)
controller := m.wg.GetController(iface)
_, err = controller.GetInterface(ctx, iface.Identifier)
if err != nil && !iface.IsDisabled() {
slog.Debug("creating missing interface", "interface", iface.Identifier)
slog.Debug("creating missing interface", "interface", iface.Identifier, "backend", controller.GetId())
// temporarily disable interface in database so that the current state is reflected correctly
_ = m.db.SaveInterface(ctx, iface.Identifier,
func(in *domain.Interface) (*domain.Interface, error) {
now := time.Now()
in.Disabled = &now // set
in.DisabledReason = domain.DisabledReasonInterfaceMissing
return in, nil
})
// temporarily disable interface in database so that the current state is reflected correctly
_ = m.db.SaveInterface(ctx, iface.Identifier,
@@ -242,7 +251,8 @@ func (m Manager) RestoreInterfaceState(
return fmt.Errorf("failed to create physical interface %s: %w", iface.Identifier, err)
}
} else {
slog.Debug("restoring interface state", "interface", iface.Identifier, "disabled", iface.IsDisabled())
slog.Debug("restoring interface state",
"interface", iface.Identifier, "disabled", iface.IsDisabled(), "backend", controller.GetId())
// try to move interface to stored state
_, err = m.saveInterface(ctx, &iface)
@@ -269,18 +279,14 @@ func (m Manager) RestoreInterfaceState(
// restore peers
for _, peer := range peers {
switch {
case iface.IsDisabled(): // if interface is disabled, delete all peers
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
case iface.IsDisabled() && iface.Backend == config.LocalBackendName: // if interface is disabled, delete all peers
if err := controller.DeletePeer(ctx, iface.Identifier,
peer.Identifier); err != nil {
return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w",
peer.Identifier, iface.Identifier, err)
}
case peer.IsDisabled(): // if peer is disabled, delete it
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w",
peer.Identifier, iface.Identifier, err)
}
default: // update peer
err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
err := controller.SavePeer(ctx, iface.Identifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, &peer)
return pp, nil
@@ -293,7 +299,7 @@ func (m Manager) RestoreInterfaceState(
}
// remove non-wgportal peers
physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier)
physicalPeers, _ := controller.GetPeers(ctx, iface.Identifier)
for _, physicalPeer := range physicalPeers {
isWgPortalPeer := false
for _, peer := range peers {
@@ -303,7 +309,8 @@ func (m Manager) RestoreInterfaceState(
}
}
if !isWgPortalPeer {
err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey))
err := controller.DeletePeer(ctx, iface.Identifier,
domain.PeerIdentifier(physicalPeer.PublicKey))
if err != nil {
return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w",
physicalPeer.PublicKey, iface.Identifier, err)
@@ -455,7 +462,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return err
}
existingInterface, err := m.db.GetInterface(ctx, id)
existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", id, err)
}
@@ -464,25 +471,33 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion not allowed: %w", err)
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
Interface: *existingInterface,
AllowedIps: existingInterface.GetAllowedIPs(existingPeers),
FwMark: existingInterface.FirewallMark,
Table: existingInterface.GetRoutingTable(),
TableStr: existingInterface.RoutingTable,
IsDeleted: true,
})
now := time.Now()
existingInterface.Disabled = &now // simulate a disabled interface
existingInterface.DisabledReason = domain.DisabledReasonDeleted
physicalInterface, _ := m.wg.GetInterface(ctx, id)
if err := m.handleInterfacePreSaveHooks(true, existingInterface); err != nil {
if err := m.handleInterfacePreSaveHooks(ctx, existingInterface, !existingInterface.IsDisabled(),
false); err != nil {
return fmt.Errorf("pre-delete hooks failed: %w", err)
}
if err := m.handleInterfacePreSaveActions(existingInterface); err != nil {
if err := m.handleInterfacePreSaveActions(ctx, existingInterface); err != nil {
return fmt.Errorf("pre-delete actions failed: %w", err)
}
if err := m.deleteInterfacePeers(ctx, id); err != nil {
if err := m.deleteInterfacePeers(ctx, existingInterface, existingPeers); err != nil {
return fmt.Errorf("peer deletion failure: %w", err)
}
if err := m.wg.DeleteInterface(ctx, id); err != nil {
if err := m.wg.GetController(*existingInterface).DeleteInterface(ctx, id); err != nil {
return fmt.Errorf("wireguard deletion failure: %w", err)
}
@@ -490,16 +505,12 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion failure: %w", err)
}
fwMark := existingInterface.FirewallMark
if physicalInterface != nil && fwMark == 0 {
fwMark = physicalInterface.FirewallMark
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: fwMark,
Table: existingInterface.GetRoutingTable(),
})
if err := m.handleInterfacePostSaveHooks(true, existingInterface); err != nil {
if err := m.handleInterfacePostSaveHooks(
ctx,
existingInterface,
!existingInterface.IsDisabled(),
false,
); err != nil {
return fmt.Errorf("post-delete hooks failed: %w", err)
}
@@ -518,20 +529,24 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return nil, fmt.Errorf("interface validation failed: %w", err)
}
stateChanged := m.hasInterfaceStateChanged(ctx, iface)
oldEnabled, newEnabled, routeTableChanged := false, !iface.IsDisabled(), false // if the interface did not exist, we assume it was not enabled
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err == nil {
oldEnabled, newEnabled, routeTableChanged = m.getInterfaceStateHistory(oldInterface, iface)
}
if err := m.handleInterfacePreSaveHooks(stateChanged, iface); err != nil {
if err := m.handleInterfacePreSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
return nil, fmt.Errorf("pre-save hooks failed: %w", err)
}
if err := m.handleInterfacePreSaveActions(iface); err != nil {
if err := m.handleInterfacePreSaveActions(ctx, iface); err != nil {
return nil, fmt.Errorf("pre-save actions failed: %w", err)
}
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
err = m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
iface.CopyCalculatedAttributes(i)
err := m.wg.SaveInterface(ctx, iface.Identifier,
err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
domain.MergeToPhysicalInterface(pi, iface)
return pi, nil
@@ -546,24 +561,84 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return nil, fmt.Errorf("failed to save interface: %w", err)
}
if iface.IsDisabled() {
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
fwMark := iface.FirewallMark
if physicalInterface != nil && fwMark == 0 {
fwMark = physicalInterface.FirewallMark
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: fwMark,
Table: iface.GetRoutingTable(),
// update the interface type of peers in db
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
for _, peer := range peers {
err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) {
switch iface.Type {
case domain.InterfaceTypeAny:
peer.Interface.Type = domain.InterfaceTypeAny
case domain.InterfaceTypeClient:
peer.Interface.Type = domain.InterfaceTypeServer
case domain.InterfaceTypeServer:
peer.Interface.Type = domain.InterfaceTypeClient
}
return &peer, nil
})
} else {
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
if err != nil {
return nil, fmt.Errorf("failed to update peer %s for interface %s: %w", peer.Identifier,
iface.Identifier, err)
}
}
if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil {
if iface.IsDisabled() {
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
} else {
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
// if the route table changed, ensure that the old entries are remove
if routeTableChanged {
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
Interface: *oldInterface,
AllowedIps: oldInterface.GetAllowedIPs(peers),
FwMark: oldInterface.FirewallMark,
Table: oldInterface.GetRoutingTable(),
TableStr: oldInterface.RoutingTable,
IsDeleted: true, // mark the old entries as deleted
})
}
}
if err := m.handleInterfacePostSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
return nil, fmt.Errorf("post-save hooks failed: %w", err)
}
// If the interface has just been enabled, restore its peers on the physical controller
if !oldEnabled && newEnabled && iface.Backend == config.LocalBackendName {
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return nil, fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
for _, peer := range peers {
saveErr := m.wg.GetController(*iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, &peer)
return pp, nil
})
if saveErr != nil {
return nil, fmt.Errorf("failed to restore peer %s for interface %s: %w", peer.Identifier,
iface.Identifier, saveErr)
}
}
// notify that peers for this interface have changed so config/routes can be updated
m.bus.Publish(app.TopicPeerInterfaceUpdated, iface.Identifier)
}
m.bus.Publish(app.TopicAuditInterfaceChanged, domain.AuditEventWrapper[audit.InterfaceEvent]{
Ctx: ctx,
Event: audit.InterfaceEvent{
@@ -575,75 +650,90 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return iface, nil
}
func (m Manager) hasInterfaceStateChanged(ctx context.Context, iface *domain.Interface) bool {
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err != nil {
return false
}
if oldInterface.IsDisabled() != iface.IsDisabled() {
return true // interface in db has changed
}
wgInterface, err := m.wg.GetInterface(ctx, iface.Identifier)
if err != nil {
return true // interface might not exist - so we assume that there must be a change
}
// compare physical interface settings
if len(wgInterface.Addresses) != len(iface.Addresses) ||
wgInterface.Mtu != iface.Mtu ||
wgInterface.FirewallMark != iface.FirewallMark ||
wgInterface.ListenPort != iface.ListenPort ||
wgInterface.PrivateKey != iface.PrivateKey ||
wgInterface.PublicKey != iface.PublicKey {
return true
}
return false
func (m Manager) getInterfaceStateHistory(
oldInterface *domain.Interface,
iface *domain.Interface,
) (oldEnabled, newEnabled, routeTableChanged bool) {
return !oldInterface.IsDisabled(), !iface.IsDisabled(), oldInterface.RoutingTable != iface.RoutingTable
}
func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error {
if !iface.IsDisabled() {
if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
} else {
if err := m.quick.UnsetDNS(iface.Identifier); err != nil {
return fmt.Errorf("failed to clear dns settings: %w", err)
func (m Manager) handleInterfacePreSaveActions(ctx context.Context, iface *domain.Interface) error {
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to perform pre-save actions", "interface", iface.Identifier,
"error", "no capable controller found")
return nil
}
// update DNS settings only for client interfaces
if iface.Type == domain.InterfaceTypeClient || iface.Type == domain.InterfaceTypeAny {
if !iface.IsDisabled() {
if err := wgQuickController.SetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
} else {
if err := wgQuickController.UnsetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to clear dns settings: %w", err)
}
}
}
return nil
}
func (m Manager) handleInterfacePreSaveHooks(stateChanged bool, iface *domain.Interface) error {
if !stateChanged {
func (m Manager) handleInterfacePreSaveHooks(
ctx context.Context,
iface *domain.Interface,
oldEnabled, newEnabled bool,
) error {
if oldEnabled == newEnabled {
return nil // do nothing if state did not change
}
if !iface.IsDisabled() {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
slog.Debug("executing pre-save hooks", "interface", iface.Identifier, "up", newEnabled)
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to execute pre-save hooks", "interface", iface.Identifier, "up", newEnabled,
"error", "no capable controller found")
return nil
}
if newEnabled {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreUp); err != nil {
return fmt.Errorf("failed to execute pre-up hook: %w", err)
}
} else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreDown); err != nil {
return fmt.Errorf("failed to execute pre-down hook: %w", err)
}
}
return nil
}
func (m Manager) handleInterfacePostSaveHooks(stateChanged bool, iface *domain.Interface) error {
if !stateChanged {
func (m Manager) handleInterfacePostSaveHooks(
ctx context.Context,
iface *domain.Interface,
oldEnabled, newEnabled bool,
) error {
if oldEnabled == newEnabled {
return nil // do nothing if state did not change
}
if !iface.IsDisabled() {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
slog.Debug("executing post-save hooks", "interface", iface.Identifier, "up", newEnabled)
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to execute post-save hooks", "interface", iface.Identifier, "up", newEnabled,
"error", "no capable controller found")
return nil
}
if newEnabled {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostUp); err != nil {
return fmt.Errorf("failed to execute post-up hook: %w", err)
}
} else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostDown); err != nil {
return fmt.Errorf("failed to execute post-down hook: %w", err)
}
}
@@ -769,7 +859,12 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
return
}
func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterface, peers []domain.PhysicalPeer) error {
func (m Manager) importInterface(
ctx context.Context,
backend domain.InterfaceController,
in *domain.PhysicalInterface,
peers []domain.PhysicalPeer,
) error {
now := time.Now()
iface := domain.ConvertPhysicalInterface(in)
iface.BaseModel = domain.BaseModel{
@@ -778,8 +873,20 @@ func (m Manager) importInterface(ctx context.Context, in *domain.PhysicalInterfa
CreatedAt: now,
UpdatedAt: now,
}
iface.Backend = backend.GetId()
iface.PeerDefAllowedIPsStr = iface.AddressStr()
// try to predict the interface type based on the number of peers
switch len(peers) {
case 0:
iface.Type = domain.InterfaceTypeAny // no peers means this is an unknown interface
case 1:
iface.Type = domain.InterfaceTypeClient // one peer means this is a client interface
default: // multiple peers means this is a server interface
iface.Type = domain.InterfaceTypeServer
}
existingInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return err
@@ -830,16 +937,20 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
peer.Interface.PreDown = domain.NewConfigOption(in.PeerDefPreDown, true)
peer.Interface.PostDown = domain.NewConfigOption(in.PeerDefPostDown, true)
var displayName string
switch in.Type {
case domain.InterfaceTypeAny:
peer.Interface.Type = domain.InterfaceTypeAny
peer.DisplayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
displayName = "Autodetected Peer (" + peer.Interface.PublicKey[0:8] + ")"
case domain.InterfaceTypeClient:
peer.Interface.Type = domain.InterfaceTypeServer
peer.DisplayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
displayName = "Autodetected Endpoint (" + peer.Interface.PublicKey[0:8] + ")"
case domain.InterfaceTypeServer:
peer.Interface.Type = domain.InterfaceTypeClient
peer.DisplayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
displayName = "Autodetected Client (" + peer.Interface.PublicKey[0:8] + ")"
}
if peer.DisplayName == "" {
peer.DisplayName = displayName // use auto-generated display name if not set
}
err := m.db.SavePeer(ctx, peer.Identifier, func(_ *domain.Peer) (*domain.Peer, error) {
@@ -852,13 +963,9 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
return nil
}
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
allPeers, err := m.db.GetInterfacePeers(ctx, id)
if err != nil {
return err
}
func (m Manager) deleteInterfacePeers(ctx context.Context, iface *domain.Interface, allPeers []domain.Peer) error {
for _, peer := range allPeers {
err = m.wg.DeletePeer(ctx, id, peer.Identifier)
err := m.wg.GetController(*iface).DeletePeer(ctx, iface.Identifier, peer.Identifier)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
}

View File

@@ -188,6 +188,32 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
sessionUser := domain.GetUserInfo(ctx)
peer.Identifier = domain.PeerIdentifier(peer.Interface.PublicKey) // ensure that identifier corresponds to the public key
// Enforce peer limit for non-admin users if LimitAdditionalUserPeers is set
if m.cfg.Core.SelfProvisioningAllowed && !sessionUser.IsAdmin && m.cfg.Advanced.LimitAdditionalUserPeers > 0 {
peers, err := m.db.GetUserPeers(ctx, peer.UserIdentifier)
if err != nil {
return nil, fmt.Errorf("failed to fetch peers for user %s: %w", peer.UserIdentifier, err)
}
// Count enabled peers (disabled IS NULL)
peerCount := 0
for _, p := range peers {
if !p.IsDisabled() {
peerCount++
}
}
totalAllowedPeers := 1 + m.cfg.Advanced.LimitAdditionalUserPeers // 1 default peer + x additional peers
if peerCount >= totalAllowedPeers {
slog.WarnContext(ctx, "peer creation blocked due to limit",
"user", peer.UserIdentifier,
"current_count", peerCount,
"allowed_count", totalAllowedPeers)
return nil, fmt.Errorf("peer limit reached (%d peers allowed): %w", totalAllowedPeers,
domain.ErrNoPermission)
}
}
existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
if err != nil && !errors.Is(err, domain.ErrNotFound) {
return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err)
@@ -347,7 +373,12 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return fmt.Errorf("delete not allowed: %w", err)
}
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
}
err = m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, id)
if err != nil {
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
}
@@ -357,9 +388,20 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return fmt.Errorf("failed to delete peer %s: %w", id, err)
}
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
m.bus.Publish(app.TopicPeerDeleted, *peer)
// Update routes after peers have changed
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
// Update interface after peers have changed
m.bus.Publish(app.TopicPeerInterfaceUpdated, peer.InterfaceIdentifier)
@@ -407,37 +449,36 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
// region helper-functions
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
interfaces := make(map[domain.InterfaceIdentifier]struct{})
interfaces := make(map[domain.InterfaceIdentifier]domain.Interface)
for i := range peers {
peer := peers[i]
var err error
if peer.IsDisabled() || peer.IsExpired() {
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil {
return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
} else {
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
for _, peer := range peers {
// get interface from db if it is not yet in the map
if _, ok := interfaces[peer.InterfaceIdentifier]; !ok {
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
}
interfaces[peer.InterfaceIdentifier] = *iface
}
iface := interfaces[peer.InterfaceIdentifier]
// Always save the peer to the backend, regardless of disabled/expired state
// The backend will handle the disabled state appropriately
err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
err := m.wg.GetController(iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
})
if err != nil {
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
}
return peer, nil
})
if err != nil {
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
}
@@ -451,13 +492,22 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
Peer: *peer,
},
})
interfaces[peer.InterfaceIdentifier] = struct{}{}
}
// Update routes after peers have changed
if len(interfaces) != 0 {
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
for id, iface := range interfaces {
interfacePeers, err := m.db.GetInterfacePeers(ctx, id)
if err != nil {
return fmt.Errorf("failed to re-load peers for interface %s: %w", id, err)
}
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: iface,
AllowedIps: iface.GetAllowedIPs(interfacePeers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
}
for iface := range interfaces {

View File

@@ -0,0 +1,194 @@
package wireguard
import (
"context"
"testing"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
)
// --- Test mocks ---
type mockBus struct{}
func (f *mockBus) Publish(topic string, args ...any) {}
func (f *mockBus) Subscribe(topic string, fn interface{}) error { return nil }
type mockController struct{}
func (f *mockController) GetId() domain.InterfaceBackend { return "local" }
func (f *mockController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
return nil, nil
}
func (f *mockController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (
*domain.PhysicalInterface,
error,
) {
return &domain.PhysicalInterface{Identifier: id}, nil
}
func (f *mockController) GetPeers(_ context.Context, _ domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) {
return nil, nil
}
func (f *mockController) SaveInterface(
_ context.Context,
_ domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error {
_, _ = updateFunc(&domain.PhysicalInterface{})
return nil
}
func (f *mockController) DeleteInterface(_ context.Context, _ domain.InterfaceIdentifier) error {
return nil
}
func (f *mockController) SavePeer(
_ context.Context,
_ domain.InterfaceIdentifier,
_ domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error {
_, _ = updateFunc(&domain.PhysicalPeer{})
return nil
}
func (f *mockController) DeletePeer(_ context.Context, _ domain.InterfaceIdentifier, _ domain.PeerIdentifier) error {
return nil
}
func (f *mockController) PingAddresses(_ context.Context, _ string) (*domain.PingerResult, error) {
return nil, nil
}
type mockDB struct {
savedPeers map[domain.PeerIdentifier]*domain.Peer
iface *domain.Interface
}
func (f *mockDB) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) {
if f.iface != nil && f.iface.Identifier == id {
return f.iface, nil
}
return &domain.Interface{Identifier: id}, nil
}
func (f *mockDB) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
[]domain.Peer,
error,
) {
return f.iface, nil, nil
}
func (f *mockDB) GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error) {
return nil, nil
}
func (f *mockDB) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) { return nil, nil }
func (f *mockDB) GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error) {
return nil, nil
}
func (f *mockDB) SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.Interface) (*domain.Interface, error),
) error {
if f.iface == nil {
f.iface = &domain.Interface{Identifier: id}
}
var err error
f.iface, err = updateFunc(f.iface)
return err
}
func (f *mockDB) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
return nil
}
func (f *mockDB) GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) {
return nil, nil
}
func (f *mockDB) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
return nil, nil
}
func (f *mockDB) SavePeer(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.Peer) (*domain.Peer, error),
) error {
if f.savedPeers == nil {
f.savedPeers = make(map[domain.PeerIdentifier]*domain.Peer)
}
existing := f.savedPeers[id]
if existing == nil {
existing = &domain.Peer{Identifier: id}
}
updated, err := updateFunc(existing)
if err != nil {
return err
}
f.savedPeers[updated.Identifier] = updated
return nil
}
func (f *mockDB) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error { return nil }
func (f *mockDB) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
return nil, domain.ErrNotFound
}
func (f *mockDB) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (
map[domain.Cidr][]domain.Cidr,
error,
) {
return map[domain.Cidr][]domain.Cidr{}, nil
}
// --- Test ---
func TestCreatePeer_SetsIdentifier_FromPublicKey(t *testing.T) {
// Arrange
cfg := &config.Config{}
cfg.Core.SelfProvisioningAllowed = true
cfg.Core.EditableKeys = true
cfg.Advanced.LimitAdditionalUserPeers = 0
bus := &mockBus{}
// Prepare a controller manager with our mock controller
ctrlMgr := &ControllerManager{
controllers: map[domain.InterfaceBackend]backendInstance{
config.LocalBackendName: {Implementation: &mockController{}},
},
}
db := &mockDB{iface: &domain.Interface{Identifier: "wg0", Type: domain.InterfaceTypeServer}}
m := Manager{
cfg: cfg,
bus: bus,
db: db,
wg: ctrlMgr,
}
userId := domain.UserIdentifier("user@example.com")
ctx := domain.SetUserInfo(context.Background(), &domain.ContextUserInfo{Id: userId, IsAdmin: false})
pubKey := "TEST_PUBLIC_KEY_ABC123"
input := &domain.Peer{
Identifier: "should_be_overwritten",
UserIdentifier: userId,
InterfaceIdentifier: domain.InterfaceIdentifier("wg0"),
Interface: domain.PeerInterfaceConfig{
KeyPair: domain.KeyPair{PublicKey: pubKey},
},
}
// Act
out, err := m.CreatePeer(ctx, input)
// Assert
if err != nil {
t.Fatalf("CreatePeer returned error: %v", err)
}
expectedId := domain.PeerIdentifier(pubKey)
if out.Identifier != expectedId {
t.Fatalf("expected Identifier to be set from public key %q, got %q", expectedId, out.Identifier)
}
// Ensure the saved peer in DB also has the expected identifier
if db.savedPeers[expectedId] == nil {
t.Fatalf("expected peer with identifier %q to be saved in DB", expectedId)
}
}

View File

@@ -16,6 +16,14 @@ type Auth struct {
OAuth []OAuthProvider `yaml:"oauth"`
// Ldap contains a list of LDAP providers.
Ldap []LdapProvider `yaml:"ldap"`
// Webauthn contains the configuration for the WebAuthn authenticator.
WebAuthn WebauthnConfig `yaml:"webauthn"`
// MinPasswordLength is the minimum password length for user accounts. This also applies to the admin user.
// It is encouraged to set this value to at least 16 characters.
MinPasswordLength int `yaml:"min_password_length"`
// HideLoginForm specifies whether the login form should be hidden. If no social login providers are configured,
// the login form will be shown regardless of this setting.
HideLoginForm bool `yaml:"hide_login_form"`
}
// BaseFields contains the basic fields that are used to map user information from the authentication providers.
@@ -203,6 +211,10 @@ type OpenIDConnectProvider struct {
// If LogUserInfo is set to true, the user info retrieved from the OIDC provider will be logged in trace level.
LogUserInfo bool `yaml:"log_user_info"`
// If LogSensitiveInfo is set to true, sensitive information retrieved from the OIDC provider will be logged in trace level.
// This also includes OAuth tokens! Keep this disabled in production!
LogSensitiveInfo bool `yaml:"log_sensitive_info"`
}
// OAuthProvider contains the configuration for the OAuth provider.
@@ -244,4 +256,14 @@ type OAuthProvider struct {
// If LogUserInfo is set to true, the user info retrieved from the OAuth provider will be logged in trace level.
LogUserInfo bool `yaml:"log_user_info"`
// If LogSensitiveInfo is set to true, sensitive information retrieved from the OAuth provider will be logged in trace level.
// This also includes OAuth tokens! Keep this disabled in production!
LogSensitiveInfo bool `yaml:"log_sensitive_info"`
}
// WebauthnConfig contains the configuration for the WebAuthn authenticator.
type WebauthnConfig struct {
// Enabled specifies whether WebAuthn is enabled.
Enabled bool `yaml:"enabled"`
}

103
internal/config/backend.go Normal file
View File

@@ -0,0 +1,103 @@
package config
import (
"fmt"
"time"
)
const LocalBackendName = "local"
type Backend struct {
Default string `yaml:"default"` // The default backend to use (defaults to the internal backend)
// Local Backend-specific configuration
IgnoredLocalInterfaces []string `yaml:"ignored_local_interfaces"` // A list of interface names that should be ignored by this backend (e.g., "wg0")
LocalResolvconfPrefix string `yaml:"local_resolvconf_prefix"` // The prefix to use for interface names when passing them to resolvconf.
// External Backend-specific configuration
Mikrotik []BackendMikrotik `yaml:"mikrotik"`
}
// Validate checks the backend configuration for errors.
func (b *Backend) Validate() error {
if b.Default == "" {
b.Default = LocalBackendName
}
uniqueMap := make(map[string]struct{})
for _, backend := range b.Mikrotik {
if backend.Id == LocalBackendName {
return fmt.Errorf("backend ID %q is a reserved keyword", LocalBackendName)
}
if _, exists := uniqueMap[backend.Id]; exists {
return fmt.Errorf("backend ID %q is not unique", backend.Id)
}
uniqueMap[backend.Id] = struct{}{}
}
if b.Default != LocalBackendName {
if _, ok := uniqueMap[b.Default]; !ok {
return fmt.Errorf("default backend %q is not defined in the configuration", b.Default)
}
}
return nil
}
type BackendBase struct {
Id string `yaml:"id"` // A unique id for the backend
DisplayName string `yaml:"display_name"` // A display name for the backend
IgnoredInterfaces []string `yaml:"ignored_interfaces"` // A list of interface names that should be ignored by this backend (e.g., "wg0")
}
// GetDisplayName returns the display name of the backend.
// If no display name is set, it falls back to the ID.
func (b BackendBase) GetDisplayName() string {
if b.DisplayName == "" {
return b.Id // Fallback to ID if no display name is set
}
return b.DisplayName
}
type BackendMikrotik struct {
BackendBase `yaml:",inline"` // Embed the base fields
ApiUrl string `yaml:"api_url"` // The base URL of the Mikrotik API (e.g., "https://10.10.10.10:8729/rest")
ApiUser string `yaml:"api_user"`
ApiPassword string `yaml:"api_password"`
ApiVerifyTls bool `yaml:"api_verify_tls"` // Whether to verify the TLS certificate of the Mikrotik API
ApiTimeout time.Duration `yaml:"api_timeout"` // Timeout for API requests (default: 30 seconds)
// Concurrency controls the maximum number of concurrent API requests that this backend will issue
// when enumerating interfaces and their details. If 0 or negative, a default of 5 is used.
Concurrency int `yaml:"concurrency"`
Debug bool `yaml:"debug"` // Enable debug logging for the Mikrotik backend
}
// GetConcurrency returns the configured concurrency for this backend or a sane default (5)
// when the configured value is zero or negative.
func (b *BackendMikrotik) GetConcurrency() int {
if b == nil {
return 5
}
if b.Concurrency <= 0 {
return 5
}
return b.Concurrency
}
// GetApiTimeout returns the configured API timeout or a sane default (30 seconds)
// when the configured value is zero or negative.
func (b *BackendMikrotik) GetApiTimeout() time.Duration {
if b == nil {
return 30 * time.Second
}
if b.ApiTimeout <= 0 {
return 30 * time.Second
}
return b.ApiTimeout
}

View File

@@ -14,9 +14,10 @@ import (
type Config struct {
Core struct {
// AdminUser defines the default administrator account that will be created
AdminUser string `yaml:"admin_user"`
AdminPassword string `yaml:"admin_password"`
AdminApiToken string `yaml:"admin_api_token"` // if set, the API access is enabled automatically
AdminUserDisabled bool `yaml:"disable_admin_user"`
AdminUser string `yaml:"admin_user"`
AdminPassword string `yaml:"admin_password"`
AdminApiToken string `yaml:"admin_api_token"` // if set, the API access is enabled automatically
EditableKeys bool `yaml:"editable_keys"`
CreateDefaultPeer bool `yaml:"create_default_peer"`
@@ -29,20 +30,23 @@ type Config struct {
} `yaml:"core"`
Advanced struct {
LogLevel string `yaml:"log_level"`
LogPretty bool `yaml:"log_pretty"`
LogJson bool `yaml:"log_json"`
StartListenPort int `yaml:"start_listen_port"`
StartCidrV4 string `yaml:"start_cidr_v4"`
StartCidrV6 string `yaml:"start_cidr_v6"`
UseIpV6 bool `yaml:"use_ip_v6"`
ConfigStoragePath string `yaml:"config_storage_path"` // keep empty to disable config export to file
ExpiryCheckInterval time.Duration `yaml:"expiry_check_interval"`
RulePrioOffset int `yaml:"rule_prio_offset"`
RouteTableOffset int `yaml:"route_table_offset"`
ApiAdminOnly bool `yaml:"api_admin_only"` // if true, only admin users can access the API
LogLevel string `yaml:"log_level"`
LogPretty bool `yaml:"log_pretty"`
LogJson bool `yaml:"log_json"`
StartListenPort int `yaml:"start_listen_port"`
StartCidrV4 string `yaml:"start_cidr_v4"`
StartCidrV6 string `yaml:"start_cidr_v6"`
UseIpV6 bool `yaml:"use_ip_v6"`
ConfigStoragePath string `yaml:"config_storage_path"` // keep empty to disable config export to file
ExpiryCheckInterval time.Duration `yaml:"expiry_check_interval"`
RulePrioOffset int `yaml:"rule_prio_offset"`
RouteTableOffset int `yaml:"route_table_offset"`
ApiAdminOnly bool `yaml:"api_admin_only"` // if true, only admin users can access the API
LimitAdditionalUserPeers int `yaml:"limit_additional_user_peers"`
} `yaml:"advanced"`
Backend Backend `yaml:"backend"`
Statistics struct {
UsePingChecks bool `yaml:"use_ping_checks"`
PingCheckWorkers int `yaml:"ping_check_workers"`
@@ -76,6 +80,7 @@ func (c *Config) LogStartupValues() {
"reEnablePeerAfterUserEnable", c.Core.ReEnablePeerAfterUserEnable,
"deletePeerAfterUserDeleted", c.Core.DeletePeerAfterUserDeleted,
"selfProvisioningAllowed", c.Core.SelfProvisioningAllowed,
"limitAdditionalUserPeers", c.Advanced.LimitAdditionalUserPeers,
"importExisting", c.Core.ImportExisting,
"restoreState", c.Core.RestoreState,
"useIpV6", c.Advanced.UseIpV6,
@@ -93,15 +98,25 @@ func (c *Config) LogStartupValues() {
"oidcProviders", len(c.Auth.OpenIDConnect),
"oauthProviders", len(c.Auth.OAuth),
"ldapProviders", len(c.Auth.Ldap),
"webauthnEnabled", c.Auth.WebAuthn.Enabled,
"minPasswordLength", c.Auth.MinPasswordLength,
"hideLoginForm", c.Auth.HideLoginForm,
)
slog.Debug("Config Backend",
"defaultBackend", c.Backend.Default,
"extraBackends", len(c.Backend.Mikrotik),
)
}
// defaultConfig returns the default configuration
func defaultConfig() *Config {
cfg := &Config{}
cfg.Core.AdminUserDisabled = false
cfg.Core.AdminUser = "admin@wgportal.local"
cfg.Core.AdminPassword = "wgportal"
cfg.Core.AdminPassword = "wgportal-default"
cfg.Core.AdminApiToken = "" // by default, the API access is disabled
cfg.Core.ImportExisting = true
cfg.Core.RestoreState = true
@@ -117,6 +132,13 @@ func defaultConfig() *Config {
DSN: "data/sqlite.db",
}
cfg.Backend = Backend{
Default: LocalBackendName, // local backend is the default (using wgcrtl)
// Most resolconf implementations use "tun." as a prefix for interface names.
// But systemd's implementation uses no prefix, for example.
LocalResolvconfPrefix: "tun.",
}
cfg.Web = WebConfig{
RequestLogging: false,
ExternalUrl: "http://localhost:8888",
@@ -137,6 +159,7 @@ func defaultConfig() *Config {
cfg.Advanced.RulePrioOffset = 20000
cfg.Advanced.RouteTableOffset = 20000
cfg.Advanced.ApiAdminOnly = true
cfg.Advanced.LimitAdditionalUserPeers = 0
cfg.Statistics.UsePingChecks = true
cfg.Statistics.PingCheckWorkers = 10
@@ -164,6 +187,10 @@ func defaultConfig() *Config {
cfg.Webhook.Authentication = ""
cfg.Webhook.Timeout = 10 * time.Second
cfg.Auth.WebAuthn.Enabled = true
cfg.Auth.MinPasswordLength = 16
cfg.Auth.HideLoginForm = false
return cfg
}
@@ -191,6 +218,10 @@ func GetConfig() (*Config, error) {
}
cfg.Web.Sanitize()
err := cfg.Backend.Validate()
if err != nil {
return nil, err
}
return cfg, nil
}

View File

@@ -41,4 +41,6 @@ type MailConfig struct {
From string `yaml:"from"`
// LinkOnly specifies whether emails should only contain a link to WireGuard Portal or attach the full configuration
LinkOnly bool `yaml:"link_only"`
// AllowPeerEmail specifies whether emails should be sent to peers which have no valid user account linked, but an email address is set as "user".
AllowPeerEmail bool `yaml:"allow_peer_email"`
}

View File

@@ -62,4 +62,7 @@ const (
LockedReasonAdmin = "locked by admin"
LockedReasonApi = "locked by admin"
ConfigStyleRaw = "raw"
ConfigStyleWgQuick = "wgquick"
)

View File

@@ -0,0 +1,32 @@
package domain
// ControllerType defines the type of controller used to manage interfaces.
const (
ControllerTypeMikrotik = "mikrotik"
ControllerTypeLocal = "wgctrl"
)
// Controller extras can be used to store additional information available for specific controllers only.
type MikrotikInterfaceExtras struct {
Id string // internal mikrotik ID
Comment string
Disabled bool
}
type MikrotikPeerExtras struct {
Id string // internal mikrotik ID
Name string
Comment string
IsResponder bool
Disabled bool
ClientEndpoint string
ClientAddress string
ClientDns string
ClientKeepalive int
}
type LocalPeerExtras struct {
Disabled bool
}

View File

@@ -10,7 +10,10 @@ import (
"strings"
"time"
"golang.org/x/sys/unix"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/config"
)
const (
@@ -23,6 +26,7 @@ var allowedFileNameRegex = regexp.MustCompile("[^a-zA-Z0-9-_]+")
type InterfaceIdentifier string
type InterfaceType string
type InterfaceBackend string
type Interface struct {
BaseModel
@@ -49,11 +53,12 @@ type Interface struct {
SaveConfig bool // automatically persist config changes to the wgX.conf file
// WG Portal specific
DisplayName string // a nice display name/ description for the interface
Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient
DriverType string // the interface driver type (linux, software, ...)
Disabled *time.Time `gorm:"index"` // flag that specifies if the interface is enabled (up) or not (down)
DisabledReason string // the reason why the interface has been disabled
DisplayName string // a nice display name/ description for the interface
Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient
Backend InterfaceBackend // the backend that is used to manage the interface (wgctrl, mikrotik, ...)
DriverType string // the interface driver type (linux, software, ...)
Disabled *time.Time `gorm:"index"` // flag that specifies if the interface is enabled (up) or not (down)
DisabledReason string // the reason why the interface has been disabled
// Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of
// the peer config
@@ -128,17 +133,30 @@ func (i *Interface) GetConfigFileName() string {
return filename
}
// GetAllowedIPs returns the allowed IPs for the interface depending on the interface type and peers.
// For example, if the interface type is Server, the allowed IPs are the IPs of the peers.
// If the interface type is Client, the allowed IPs correspond to the AllowedIPsStr of the peers.
func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
var allowedCidrs []Cidr
for _, peer := range peers {
for _, ip := range peer.Interface.Addresses {
allowedCidrs = append(allowedCidrs, ip.HostAddr())
switch i.Type {
case InterfaceTypeServer, InterfaceTypeAny:
for _, peer := range peers {
for _, ip := range peer.Interface.Addresses {
allowedCidrs = append(allowedCidrs, ip.HostAddr())
}
if peer.ExtraAllowedIPsStr != "" {
extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr)
if err == nil {
allowedCidrs = append(allowedCidrs, extraIPs...)
}
}
}
if peer.ExtraAllowedIPsStr != "" {
extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr)
case InterfaceTypeClient:
for _, peer := range peers {
allowedIPs, err := CidrsFromString(peer.AllowedIPsStr.GetValue())
if err == nil {
allowedCidrs = append(allowedCidrs, extraIPs...)
allowedCidrs = append(allowedCidrs, allowedIPs...)
}
}
}
@@ -155,6 +173,7 @@ func (i *Interface) ManageRoutingTable() bool {
//
// -1 if RoutingTable was set to "off" or an error occurred
func (i *Interface) GetRoutingTable() int {
routingTableStr := strings.ToLower(i.RoutingTable)
switch {
case routingTableStr == "":
@@ -162,6 +181,9 @@ func (i *Interface) GetRoutingTable() int {
case routingTableStr == "off":
return -1
case strings.HasPrefix(routingTableStr, "0x"):
if i.Backend != config.LocalBackendName {
return 0 // ignore numeric routing table numbers for non-local controllers
}
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
routingTable, err := strconv.ParseUint(numberStr, 16, 64)
if err != nil {
@@ -174,6 +196,9 @@ func (i *Interface) GetRoutingTable() int {
}
return int(routingTable)
default:
if i.Backend != config.LocalBackendName {
return 0 // ignore numeric routing table numbers for non-local controllers
}
routingTable, err := strconv.Atoi(routingTableStr)
if err != nil {
slog.Error("failed to parse routing table number", "table", routingTableStr, "error", err)
@@ -204,9 +229,31 @@ type PhysicalInterface struct {
BytesUpload uint64
BytesDownload uint64
backendExtras any // additional backend-specific extras, e.g., domain.MikrotikInterfaceExtras
}
func (p *PhysicalInterface) GetExtras() any {
return p.backendExtras
}
func (p *PhysicalInterface) SetExtras(extras any) {
switch extras.(type) {
case MikrotikInterfaceExtras: // OK
default: // we only support MikrotikInterfaceExtras for now
panic(fmt.Sprintf("unsupported interface backend extras type %T", extras))
}
p.backendExtras = extras
}
func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface {
networks := make([]Cidr, 0, len(pi.Addresses))
for _, addr := range pi.Addresses {
networks = append(networks, addr.NetworkAddr())
}
// create a new basic interface with the data from the physical interface
iface := &Interface{
Identifier: pi.Identifier,
KeyPair: pi.KeyPair,
@@ -226,11 +273,11 @@ func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface {
Type: InterfaceTypeAny,
DriverType: pi.DeviceType,
Disabled: nil,
PeerDefNetworkStr: "",
PeerDefNetworkStr: CidrsToString(networks),
PeerDefDnsStr: "",
PeerDefDnsSearchStr: "",
PeerDefEndpoint: "",
PeerDefAllowedIPsStr: "",
PeerDefAllowedIPsStr: CidrsToString(networks),
PeerDefMtu: pi.Mtu,
PeerDefPersistentKeepalive: 0,
PeerDefFirewallMark: 0,
@@ -241,6 +288,23 @@ func ConvertPhysicalInterface(pi *PhysicalInterface) *Interface {
PeerDefPostDown: "",
}
if pi.GetExtras() == nil {
return iface
}
// enrich the data with controller-specific extras
now := time.Now()
switch pi.ImportSource {
case ControllerTypeMikrotik:
extras := pi.GetExtras().(MikrotikInterfaceExtras)
iface.DisplayName = extras.Comment
if extras.Disabled {
iface.Disabled = &now
} else {
iface.Disabled = nil
}
}
return iface
}
@@ -253,15 +317,30 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) {
pi.FirewallMark = i.FirewallMark
pi.DeviceUp = !i.IsDisabled()
pi.Addresses = i.Addresses
switch pi.ImportSource {
case ControllerTypeMikrotik:
extras := MikrotikInterfaceExtras{
Comment: i.DisplayName,
Disabled: i.IsDisabled(),
}
pi.SetExtras(extras)
}
}
type RoutingTableInfo struct {
FwMark uint32
Table int
Interface Interface
AllowedIps []Cidr
FwMark uint32
Table int
TableStr string // the routing table number as string (used by mikrotik, linux uses the numeric value)
IsDeleted bool // true if the interface was deleted, false otherwise
}
func (r RoutingTableInfo) String() string {
return fmt.Sprintf("%d -> %d", r.FwMark, r.Table)
v4, v6 := CidrsPerFamily(r.AllowedIps)
return fmt.Sprintf("%s: fwmark=%d; table=%d; routes_4=%d; routes_6=%d", r.Interface.Identifier, r.FwMark, r.Table,
len(v4), len(v6))
}
func (r RoutingTableInfo) ManagementEnabled() bool {
@@ -279,3 +358,30 @@ func (r RoutingTableInfo) GetRoutingTable() int {
return r.Table
}
type IpFamily int
const (
IpFamilyIPv4 IpFamily = unix.AF_INET
IpFamilyIPv6 IpFamily = unix.AF_INET6
)
func (f IpFamily) String() string {
switch f {
case IpFamilyIPv4:
return "IPv4"
case IpFamilyIPv6:
return "IPv6"
default:
return "unknown"
}
}
// RouteRule represents a routing table rule.
type RouteRule struct {
InterfaceId InterfaceIdentifier
IpFamily IpFamily
FwMark uint32
Table int
HasDefault bool
}

View File

@@ -0,0 +1,27 @@
package domain
import "context"
type InterfaceController interface {
GetId() InterfaceBackend
GetInterfaces(_ context.Context) ([]PhysicalInterface, error)
GetInterface(_ context.Context, id InterfaceIdentifier) (*PhysicalInterface, error)
GetPeers(_ context.Context, deviceId InterfaceIdentifier) ([]PhysicalPeer, error)
SaveInterface(
_ context.Context,
id InterfaceIdentifier,
updateFunc func(pi *PhysicalInterface) (*PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId InterfaceIdentifier,
id PeerIdentifier,
updateFunc func(pp *PhysicalPeer) (*PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId InterfaceIdentifier, id PeerIdentifier) error
PingAddresses(
ctx context.Context,
addr string,
) (*PingerResult, error)
}

View File

@@ -5,6 +5,8 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/h44z/wg-portal/internal/config"
)
func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) {
@@ -37,8 +39,9 @@ func TestInterface_GetConfigFileNameReturnsCorrectFileName(t *testing.T) {
assert.Equal(t, expected, iface.GetConfigFileName())
}
func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
func TestInterface_GetAllowedIPsReturnsCorrectCidrsServerMode(t *testing.T) {
peer1 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"},
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
@@ -46,16 +49,45 @@ func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
},
}
peer2 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"},
ExtraAllowedIPsStr: "10.20.2.2/32",
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
},
},
}
iface := &Interface{}
iface := &Interface{Type: InterfaceTypeServer}
expected := []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
{Cidr: "10.20.2.2/32", Addr: "10.20.2.2", NetLength: 32},
}
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
}
func TestInterface_GetAllowedIPsReturnsCorrectCidrsClientMode(t *testing.T) {
peer1 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"},
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
},
},
}
peer2 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"},
ExtraAllowedIPsStr: "10.20.2.2/32",
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
},
},
}
iface := &Interface{Type: InterfaceTypeClient}
expected := []Cidr{
{Cidr: "192.168.2.2/32", Addr: "192.168.2.2", NetLength: 32},
{Cidr: "10.0.2.2/32", Addr: "10.0.2.2", NetLength: 32},
}
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
}
@@ -66,10 +98,22 @@ func TestInterface_ManageRoutingTableReturnsCorrectValue(t *testing.T) {
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
iface = &Interface{RoutingTable: "off", Backend: config.LocalBackendName}
assert.False(t, iface.ManageRoutingTable())
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
iface = &Interface{RoutingTable: "off", Backend: "mikrotik-xxx"}
assert.False(t, iface.ManageRoutingTable())
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
}
func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
iface := &Interface{RoutingTable: ""}
iface := &Interface{RoutingTable: "", Backend: config.LocalBackendName}
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "off"
@@ -81,3 +125,17 @@ func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
iface.RoutingTable = "200"
assert.Equal(t, 200, iface.GetRoutingTable())
}
func TestInterface_GetRoutingTableNonLocal(t *testing.T) {
iface := &Interface{RoutingTable: "off", Backend: "something different"}
assert.Equal(t, -1, iface.GetRoutingTable())
iface.RoutingTable = "0"
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "100"
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "abc"
assert.Equal(t, 0, iface.GetRoutingTable())
}

View File

@@ -26,6 +26,10 @@ func (c Cidr) IsValid() bool {
return c.Prefix().IsValid()
}
func (c Cidr) EqualPrefix(other Cidr) bool {
return c.Addr == other.Addr && c.NetLength == other.NetLength
}
func CidrFromString(str string) (Cidr, error) {
prefix, err := netip.ParsePrefix(strings.TrimSpace(str))
if err != nil {
@@ -199,3 +203,26 @@ func (c Cidr) Contains(other Cidr) bool {
return subnet.Contains(otherIP)
}
// ContainsDefaultRoute returns true if the given CIDRs contain a default route.
func ContainsDefaultRoute(cidrs []Cidr) bool {
for _, allowedIP := range cidrs {
if allowedIP.Prefix().Bits() == 0 {
return true
}
}
return false
}
// CidrsPerFamily returns a slice of CIDRs, one for each family (IPv4 and IPv6).
func CidrsPerFamily(cidrs []Cidr) (ipv4, ipv6 []Cidr) {
for _, cidr := range cidrs {
if cidr.IsV4() {
ipv4 = append(ipv4, cidr)
} else {
ipv6 = append(ipv6, cidr)
}
}
return
}

View File

@@ -1,12 +1,14 @@
package domain
import (
"errors"
"fmt"
"net"
"strings"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"gorm.io/gorm"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/config"
@@ -44,6 +46,7 @@ type Peer struct {
DisplayName string // a nice display name/ description for the peer
Identifier PeerIdentifier `gorm:"primaryKey;column:identifier"` // peer unique identifier
UserIdentifier UserIdentifier `gorm:"index;column:user_identifier"` // the owner
User *User `gorm:"-"` // the owner user object; loaded automatically after fetch
InterfaceIdentifier InterfaceIdentifier `gorm:"index;column:interface_identifier"` // the interface id
Disabled *time.Time `gorm:"column:disabled"` // if this field is set, the peer is disabled
DisabledReason string // the reason why the peer has been disabled
@@ -129,7 +132,7 @@ func (p *Peer) GenerateDisplayName(prefix string) {
p.DisplayName = fmt.Sprintf("%sPeer %s", prefix, internal.TruncateString(string(p.Identifier), 8))
}
// OverwriteUserEditableFields overwrites the user editable fields of the peer with the values from the userPeer
// OverwriteUserEditableFields overwrites the user-editable fields of the peer with the values from the userPeer
func (p *Peer) OverwriteUserEditableFields(userPeer *Peer, cfg *config.Config) {
p.DisplayName = userPeer.DisplayName
if cfg.Core.EditableKeys {
@@ -182,9 +185,12 @@ type PhysicalPeer struct {
BytesUpload uint64 // upload bytes are the number of bytes that the remote peer has sent to the server
BytesDownload uint64 // upload bytes are the number of bytes that the remote peer has received from the server
ImportSource string // import source (wgctrl, file, ...)
backendExtras any // additional backend-specific extras, e.g., domain.MikrotikPeerExtras
}
func (p PhysicalPeer) GetPresharedKey() *wgtypes.Key {
func (p *PhysicalPeer) GetPresharedKey() *wgtypes.Key {
if p.PresharedKey == "" {
return nil
}
@@ -196,7 +202,7 @@ func (p PhysicalPeer) GetPresharedKey() *wgtypes.Key {
return &key
}
func (p PhysicalPeer) GetEndpointAddress() *net.UDPAddr {
func (p *PhysicalPeer) GetEndpointAddress() *net.UDPAddr {
if p.Endpoint == "" {
return nil
}
@@ -208,7 +214,7 @@ func (p PhysicalPeer) GetEndpointAddress() *net.UDPAddr {
return addr
}
func (p PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration {
func (p *PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration {
if p.PersistentKeepalive == 0 {
return nil
}
@@ -217,7 +223,7 @@ func (p PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration {
return &keepAliveDuration
}
func (p PhysicalPeer) GetAllowedIPs() []net.IPNet {
func (p *PhysicalPeer) GetAllowedIPs() []net.IPNet {
allowedIPs := make([]net.IPNet, len(p.AllowedIPs))
for i, ip := range p.AllowedIPs {
allowedIPs[i] = *ip.IpNet()
@@ -226,6 +232,21 @@ func (p PhysicalPeer) GetAllowedIPs() []net.IPNet {
return allowedIPs
}
func (p *PhysicalPeer) GetExtras() any {
return p.backendExtras
}
func (p *PhysicalPeer) SetExtras(extras any) {
switch extras.(type) {
case MikrotikPeerExtras: // OK
case LocalPeerExtras: // OK
default: // we only support MikrotikPeerExtras and LocalPeerExtras for now
panic(fmt.Sprintf("unsupported peer backend extras type %T", extras))
}
p.backendExtras = extras
}
func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
peer := &Peer{
Endpoint: NewConfigOption(pp.Endpoint, true),
@@ -244,30 +265,123 @@ func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
},
}
if pp.GetExtras() == nil {
return peer
}
// enrich the data with controller-specific extras
now := time.Now()
switch pp.ImportSource {
case ControllerTypeMikrotik:
extras := pp.GetExtras().(MikrotikPeerExtras)
peer.Notes = extras.Comment
peer.DisplayName = extras.Name
if extras.ClientEndpoint != "" { // if the client endpoint is set, we assume that this is a client peer
peer.Endpoint = NewConfigOption(extras.ClientEndpoint, true)
peer.Interface.Type = InterfaceTypeClient
peer.Interface.Addresses, _ = CidrsFromString(extras.ClientAddress)
peer.Interface.DnsStr = NewConfigOption(extras.ClientDns, true)
peer.PersistentKeepalive = NewConfigOption(extras.ClientKeepalive, true)
} else {
peer.Interface.Type = InterfaceTypeServer
}
if extras.Disabled {
peer.Disabled = &now
peer.DisabledReason = "Disabled by Mikrotik controller"
} else {
peer.Disabled = nil
peer.DisabledReason = ""
}
case ControllerTypeLocal:
extras := pp.GetExtras().(LocalPeerExtras)
if extras.Disabled {
peer.Disabled = &now
peer.DisabledReason = "Disabled by Local controller"
} else {
peer.Disabled = nil
peer.DisabledReason = ""
}
}
return peer
}
func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) {
pp.Identifier = p.Identifier
pp.Endpoint = p.Endpoint.GetValue()
if p.Interface.Type == InterfaceTypeServer {
allowedIPs, _ := CidrsFromString(p.AllowedIPsStr.GetValue())
extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr)
pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...)
} else {
pp.PresharedKey = p.PresharedKey
pp.PublicKey = p.Interface.PublicKey
switch p.Interface.Type {
case InterfaceTypeClient: // this means that the corresponding interface in wgportal is a server interface
allowedIPs := make([]Cidr, len(p.Interface.Addresses))
for i, ip := range p.Interface.Addresses {
allowedIPs[i] = ip.HostAddr()
allowedIPs[i] = ip.HostAddr() // add the peer's host address to the allowed IPs
}
extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr)
pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...)
case InterfaceTypeServer: // this means that the corresponding interface in wgportal is a client interface
allowedIPs, _ := CidrsFromString(p.AllowedIPsStr.GetValue())
extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr)
pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...)
pp.Endpoint = p.Endpoint.GetValue()
pp.PersistentKeepalive = p.PersistentKeepalive.GetValue()
case InterfaceTypeAny: // this means that the corresponding interface in wgportal has no specific type
allowedIPs := make([]Cidr, len(p.Interface.Addresses))
for i, ip := range p.Interface.Addresses {
allowedIPs[i] = ip.HostAddr() // add the peer's host address to the allowed IPs
}
extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr)
pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...)
pp.Endpoint = p.Endpoint.GetValue()
pp.PersistentKeepalive = p.PersistentKeepalive.GetValue()
}
switch pp.ImportSource {
case ControllerTypeMikrotik:
extras := MikrotikPeerExtras{
Id: "",
Name: p.DisplayName,
Comment: p.Notes,
IsResponder: p.Interface.Type == InterfaceTypeClient,
Disabled: p.IsDisabled(),
ClientEndpoint: p.Endpoint.GetValue(),
ClientAddress: CidrsToString(p.Interface.Addresses),
ClientDns: p.Interface.DnsStr.GetValue(),
ClientKeepalive: p.PersistentKeepalive.GetValue(),
}
pp.SetExtras(extras)
case ControllerTypeLocal:
extras := LocalPeerExtras{
Disabled: p.IsDisabled(),
}
pp.SetExtras(extras)
}
pp.PresharedKey = p.PresharedKey
pp.PublicKey = p.Interface.PublicKey
pp.PersistentKeepalive = p.PersistentKeepalive.GetValue()
}
type PeerCreationRequest struct {
UserIdentifiers []string
Prefix string
}
// AfterFind is a GORM hook that automatically loads the associated User object
// based on the UserIdentifier field. If the identifier is empty or no user is
// found, the User field is set to nil.
func (p *Peer) AfterFind(tx *gorm.DB) error {
if p == nil {
return nil
}
if p.UserIdentifier == "" {
p.User = nil
return nil
}
var u User
if err := tx.Where("identifier = ?", p.UserIdentifier).First(&u).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
p.User = nil
return nil
}
return err
}
p.User = &u
return nil
}

View File

@@ -1,23 +1,27 @@
package domain
import "time"
import (
"time"
)
type PeerStatus struct {
PeerId PeerIdentifier `gorm:"primaryKey;column:identifier"`
UpdatedAt time.Time `gorm:"column:updated_at"`
PeerId PeerIdentifier `gorm:"primaryKey;column:identifier" json:"PeerId"`
UpdatedAt time.Time `gorm:"column:updated_at" json:"-"`
IsPingable bool `gorm:"column:pingable"`
LastPing *time.Time `gorm:"column:last_ping"`
IsConnected bool `gorm:"column:connected" json:"IsConnected"` // indicates if the peer is connected based on the last handshake or ping
BytesReceived uint64 `gorm:"column:received"`
BytesTransmitted uint64 `gorm:"column:transmitted"`
IsPingable bool `gorm:"column:pingable" json:"IsPingable"`
LastPing *time.Time `gorm:"column:last_ping" json:"LastPing"`
LastHandshake *time.Time `gorm:"column:last_handshake"`
Endpoint string `gorm:"column:endpoint"`
LastSessionStart *time.Time `gorm:"column:last_session_start"`
BytesReceived uint64 `gorm:"column:received" json:"BytesReceived"`
BytesTransmitted uint64 `gorm:"column:transmitted" json:"BytesTransmitted"`
LastHandshake *time.Time `gorm:"column:last_handshake" json:"LastHandshake"`
Endpoint string `gorm:"column:endpoint" json:"Endpoint"`
LastSessionStart *time.Time `gorm:"column:last_session_start" json:"LastSessionStart"`
}
func (s PeerStatus) IsConnected() bool {
func (s *PeerStatus) CalcConnected() {
oldestHandshakeTime := time.Now().Add(-2 * time.Minute) // if a handshake is older than 2 minutes, the peer is no longer connected
handshakeValid := false
@@ -25,7 +29,7 @@ func (s PeerStatus) IsConnected() bool {
handshakeValid = !s.LastHandshake.Before(oldestHandshakeTime)
}
return s.IsPingable || handshakeValid
s.IsConnected = s.IsPingable || handshakeValid
}
type InterfaceStatus struct {
@@ -35,3 +39,25 @@ type InterfaceStatus struct {
BytesReceived uint64 `gorm:"column:received"`
BytesTransmitted uint64 `gorm:"column:transmitted"`
}
type PingerResult struct {
PacketsRecv int
PacketsSent int
Rtts []time.Duration
}
func (r PingerResult) IsPingable() bool {
return r.PacketsRecv > 0 && r.PacketsSent > 0 && len(r.Rtts) > 0
}
func (r PingerResult) AverageRtt() time.Duration {
if len(r.Rtts) == 0 {
return 0
}
var total time.Duration
for _, rtt := range r.Rtts {
total += rtt
}
return total / time.Duration(len(r.Rtts))
}

View File

@@ -66,8 +66,9 @@ func TestPeerStatus_IsConnected(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.status.IsConnected(); got != tt.want {
t.Errorf("IsConnected() = %v, want %v", got, tt.want)
tt.status.CalcConnected()
if got := tt.status.IsConnected; got != tt.want {
t.Errorf("IsConnected = %v, want %v", got, tt.want)
}
})
}

View File

@@ -2,9 +2,16 @@ package domain
import (
"crypto/subtle"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"slices"
"strings"
"time"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
@@ -43,6 +50,10 @@ type User struct {
Locked *time.Time `gorm:"index;column:locked"` // if this field is set, the user is locked and can no longer login (WireGuard peers still can connect)
LockedReason string // the reason why the user has been locked
// Passwordless authentication
WebAuthnId string `gorm:"column:webauthn_id"` // the webauthn id of the user, used for webauthn authentication
WebAuthnCredentialList []UserWebauthnCredential `gorm:"foreignKey:user_identifier"` // the webauthn credentials of the user, used for webauthn authentication
// API token for REST API access
ApiToken string `form:"api_token" binding:"omitempty"`
ApiTokenCreated *time.Time
@@ -77,6 +88,22 @@ func (u *User) CanChangePassword() error {
return errors.New("password change only allowed for database source")
}
func (u *User) HasWeakPassword(minLength int) error {
if u.Source != UserSourceDatabase {
return nil // password is not required for non-database users, so no check needed
}
if u.Password == "" {
return nil // password is not set, so no check needed
}
if len(u.Password) < minLength {
return fmt.Errorf("password is too short, minimum length is %d", minLength)
}
return nil // password is strong enough
}
func (u *User) EditAllowed(new *User) error {
if u.Source == UserSourceDatabase {
return nil
@@ -157,3 +184,157 @@ func (u *User) CopyCalculatedAttributes(src *User) {
u.BaseModel = src.BaseModel
u.LinkedPeerCount = src.LinkedPeerCount
}
// DisplayName returns the display name of the user.
// The display name is the first and last name, or the email address of the user.
// If none of these fields are set, the user identifier is returned.
func (u *User) DisplayName() string {
var displayName string
switch {
case u.Firstname != "" && u.Lastname != "":
displayName = fmt.Sprintf("%s %s", u.Firstname, u.Lastname)
case u.Firstname != "":
displayName = u.Firstname
case u.Lastname != "":
displayName = u.Lastname
case u.Email != "":
displayName = u.Email
default:
displayName = string(u.Identifier)
}
return displayName
}
// region webauthn
func (u *User) WebAuthnID() []byte {
decodeString, err := base64.StdEncoding.DecodeString(u.WebAuthnId)
if err != nil {
return nil
}
return decodeString
}
func (u *User) GenerateWebAuthnId() {
randomUid1 := uuid.New().String() // 32 hex digits + 4 dashes
randomUid2 := uuid.New().String() // 32 hex digits + 4 dashes
webAuthnId := []byte(strings.ReplaceAll(fmt.Sprintf("%s%s", randomUid1, randomUid2), "-", "")) // 64 hex digits
u.WebAuthnId = base64.StdEncoding.EncodeToString(webAuthnId)
}
func (u *User) WebAuthnName() string {
return string(u.Identifier)
}
func (u *User) WebAuthnDisplayName() string {
return u.DisplayName()
}
func (u *User) WebAuthnCredentials() []webauthn.Credential {
credentials := make([]webauthn.Credential, len(u.WebAuthnCredentialList))
for i, cred := range u.WebAuthnCredentialList {
credential, err := cred.GetCredential()
if err != nil {
continue
}
credentials[i] = credential
}
return credentials
}
func (u *User) AddCredential(userId UserIdentifier, name string, credential webauthn.Credential) error {
cred, err := NewUserWebauthnCredential(userId, name, credential)
if err != nil {
return err
}
// Check if the credential already exists
for _, c := range u.WebAuthnCredentialList {
if c.GetCredentialId() == string(credential.ID) {
return errors.New("credential already exists")
}
}
u.WebAuthnCredentialList = append(u.WebAuthnCredentialList, cred)
return nil
}
func (u *User) UpdateCredential(credentialIdBase64, name string) error {
for i, c := range u.WebAuthnCredentialList {
if c.CredentialIdentifier == credentialIdBase64 {
u.WebAuthnCredentialList[i].DisplayName = name
return nil
}
}
return errors.New("credential not found")
}
func (u *User) RemoveCredential(credentialIdBase64 string) {
u.WebAuthnCredentialList = slices.DeleteFunc(u.WebAuthnCredentialList, func(e UserWebauthnCredential) bool {
return e.CredentialIdentifier == credentialIdBase64
})
}
type UserWebauthnCredential struct {
UserIdentifier string `gorm:"primaryKey;column:user_identifier"` // the user identifier
CredentialIdentifier string `gorm:"primaryKey;uniqueIndex;column:credential_identifier"` // base64 encoded credential id
CreatedAt time.Time `gorm:"column:created_at"` // the time when the credential was created
DisplayName string `gorm:"column:display_name"` // the display name of the credential
SerializedCredential string `gorm:"column:serialized_credential"` // JSON and base64 encoded credential
}
func NewUserWebauthnCredential(userIdentifier UserIdentifier, name string, credential webauthn.Credential) (
UserWebauthnCredential,
error,
) {
c := UserWebauthnCredential{
UserIdentifier: string(userIdentifier),
CreatedAt: time.Now(),
DisplayName: name,
CredentialIdentifier: base64.StdEncoding.EncodeToString(credential.ID),
}
err := c.SetCredential(credential)
if err != nil {
return c, err
}
return c, nil
}
func (c *UserWebauthnCredential) SetCredential(credential webauthn.Credential) error {
jsonData, err := json.Marshal(credential)
if err != nil {
return fmt.Errorf("failed to marshal credential: %w", err)
}
c.SerializedCredential = base64.StdEncoding.EncodeToString(jsonData)
return nil
}
func (c *UserWebauthnCredential) GetCredential() (webauthn.Credential, error) {
jsonData, err := base64.StdEncoding.DecodeString(c.SerializedCredential)
if err != nil {
return webauthn.Credential{}, fmt.Errorf("failed to decode base64 credential: %w", err)
}
var credential webauthn.Credential
if err := json.Unmarshal(jsonData, &credential); err != nil {
return webauthn.Credential{}, fmt.Errorf("failed to unmarshal credential: %w", err)
}
return credential, nil
}
func (c *UserWebauthnCredential) GetCredentialId() string {
decodeString, _ := base64.StdEncoding.DecodeString(c.CredentialIdentifier)
return string(decodeString)
}
// endregion webauthn

View File

@@ -12,8 +12,8 @@ import (
"sync"
)
// SetupLogging initializes the global logger with the given level and format
func SetupLogging(level string, pretty, json bool) {
// GetLoggingHandler initializes a slog.Handler based on the provided logging level and format options.
func GetLoggingHandler(level string, pretty, json bool) slog.Handler {
var logLevel = new(slog.LevelVar)
switch strings.ToLower(level) {
@@ -46,6 +46,13 @@ func SetupLogging(level string, pretty, json bool) {
handler = slog.NewTextHandler(output, opts)
}
return handler
}
// SetupLogging initializes the global logger with the given level and format
func SetupLogging(level string, pretty, json bool) {
handler := GetLoggingHandler(level, pretty, json)
logger := slog.New(handler)
slog.SetDefault(logger)

View File

@@ -0,0 +1,436 @@
package lowlevel
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/config"
)
// region models
const (
MikrotikApiStatusOk = "success"
MikrotikApiStatusError = "error"
)
const (
MikrotikApiErrorCodeUnknown = iota + 600
MikrotikApiErrorCodeRequestPreparationFailed
MikrotikApiErrorCodeRequestFailed
MikrotikApiErrorCodeResponseDecodeFailed
)
type MikrotikApiResponse[T any] struct {
Status string
Code int
Data T `json:"data,omitempty"`
Error *MikrotikApiError `json:"error,omitempty"`
}
type MikrotikApiError struct {
Code int `json:"error,omitempty"`
Message string `json:"message,omitempty"`
Details string `json:"detail,omitempty"`
}
func (e *MikrotikApiError) String() string {
if e == nil {
return "no error"
}
return fmt.Sprintf("API error %d: %s - %s", e.Code, e.Message, e.Details)
}
type GenericJsonObject map[string]any
type EmptyResponse struct{}
func (JsonObject GenericJsonObject) GetString(key string) string {
if value, ok := JsonObject[key]; ok {
if strValue, ok := value.(string); ok {
return strValue
} else {
return fmt.Sprintf("%v", value) // Convert to string if not already
}
}
return ""
}
func (JsonObject GenericJsonObject) GetInt(key string) int {
if value, ok := JsonObject[key]; ok {
if intValue, ok := value.(int); ok {
return intValue
} else {
if floatValue, ok := value.(float64); ok {
return int(floatValue) // Convert float64 to int
}
if strValue, ok := value.(string); ok {
if intValue, err := strconv.Atoi(strValue); err == nil {
return intValue // Convert string to int if possible
}
}
}
}
return 0
}
func (JsonObject GenericJsonObject) GetBool(key string) bool {
if value, ok := JsonObject[key]; ok {
if boolValue, ok := value.(bool); ok {
return boolValue
} else {
if intValue, ok := value.(int); ok {
return intValue == 1 // Convert int to bool (1 is true, 0 is false)
}
if floatValue, ok := value.(float64); ok {
return int(floatValue) == 1 // Convert float64 to bool (1.0 is true, 0.0 is false)
}
if strValue, ok := value.(string); ok {
boolValue, err := strconv.ParseBool(strValue)
if err == nil {
return boolValue
}
}
}
}
return false
}
type MikrotikRequestOptions struct {
Filters map[string]string `json:"filters,omitempty"`
PropList []string `json:"proplist,omitempty"`
}
func (o *MikrotikRequestOptions) GetPath(base string) string {
if o == nil {
return base
}
path, err := url.Parse(base)
if err != nil {
return base
}
query := path.Query()
for k, v := range o.Filters {
query.Set(k, v)
}
if len(o.PropList) > 0 {
query.Set(".proplist", strings.Join(o.PropList, ","))
}
path.RawQuery = query.Encode()
return path.String()
}
// region models
// region API-client
type MikrotikApiClient struct {
coreCfg *config.Config
cfg *config.BackendMikrotik
client *http.Client
log *slog.Logger
}
func NewMikrotikApiClient(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikApiClient, error) {
c := &MikrotikApiClient{
coreCfg: coreCfg,
cfg: cfg,
}
err := c.setup()
if err != nil {
return nil, err
}
c.debugLog("Mikrotik api client created", "api_url", cfg.ApiUrl)
return c, nil
}
func (m *MikrotikApiClient) setup() error {
m.client = &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: !m.cfg.ApiVerifyTls,
},
},
Timeout: m.cfg.GetApiTimeout(),
}
if m.cfg.Debug {
m.log = slog.New(internal.GetLoggingHandler("debug",
m.coreCfg.Advanced.LogPretty,
m.coreCfg.Advanced.LogJson).
WithAttrs([]slog.Attr{
{
Key: "mikrotik-bid", Value: slog.StringValue(m.cfg.Id),
},
}))
}
return nil
}
func (m *MikrotikApiClient) debugLog(msg string, args ...any) {
if m.log != nil {
m.log.Debug("[MT-API] "+msg, args...)
}
}
func (m *MikrotikApiClient) getFullPath(command string) string {
path, err := url.JoinPath(m.cfg.ApiUrl, command)
if err != nil {
return ""
}
return path
}
func (m *MikrotikApiClient) prepareGetRequest(ctx context.Context, fullUrl string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullUrl, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" {
req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword)
}
return req, nil
}
func (m *MikrotikApiClient) prepareDeleteRequest(ctx context.Context, fullUrl string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, fullUrl, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" {
req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword)
}
return req, nil
}
func (m *MikrotikApiClient) preparePayloadRequest(
ctx context.Context,
method string,
fullUrl string,
payload GenericJsonObject,
) (*http.Request, error) {
// marshal the payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, method, fullUrl, bytes.NewReader(payloadBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
if m.cfg.ApiUser != "" && m.cfg.ApiPassword != "" {
req.SetBasicAuth(m.cfg.ApiUser, m.cfg.ApiPassword)
}
return req, nil
}
func errToApiResponse[T any](code int, message string, err error) MikrotikApiResponse[T] {
return MikrotikApiResponse[T]{
Status: MikrotikApiStatusError,
Code: code,
Error: &MikrotikApiError{
Code: code,
Message: message,
Details: err.Error(),
},
}
}
func parseHttpResponse[T any](resp *http.Response, err error) MikrotikApiResponse[T] {
if err != nil {
return errToApiResponse[T](MikrotikApiErrorCodeRequestFailed, "failed to execute request", err)
}
defer func(Body io.ReadCloser) {
_, _ = io.Copy(io.Discard, Body) // ensure to empty the body
err := Body.Close()
if err != nil {
slog.Error("failed to close response body", "error", err)
}
}(resp.Body)
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
var data T
// if the type of T is EmptyResponse, we can return an empty response with just the status
if _, ok := any(data).(EmptyResponse); ok {
return MikrotikApiResponse[T]{Status: MikrotikApiStatusOk, Code: resp.StatusCode}
}
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return errToApiResponse[T](MikrotikApiErrorCodeResponseDecodeFailed, "failed to decode response", err)
}
return MikrotikApiResponse[T]{Status: MikrotikApiStatusOk, Code: resp.StatusCode, Data: data}
}
var apiErr MikrotikApiError
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
return errToApiResponse[T](resp.StatusCode, "unknown error, unparsable response", err)
} else {
return MikrotikApiResponse[T]{Status: MikrotikApiStatusError, Code: resp.StatusCode, Error: &apiErr}
}
}
func (m *MikrotikApiClient) Query(
ctx context.Context,
command string,
opts *MikrotikRequestOptions,
) MikrotikApiResponse[[]GenericJsonObject] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := opts.GetPath(m.getFullPath(command))
req, err := m.prepareGetRequest(apiCtx, fullUrl)
if err != nil {
return errToApiResponse[[]GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API query", "url", fullUrl)
response := parseHttpResponse[[]GenericJsonObject](m.client.Do(req))
m.debugLog("retrieved API query result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
func (m *MikrotikApiClient) Get(
ctx context.Context,
command string,
opts *MikrotikRequestOptions,
) MikrotikApiResponse[GenericJsonObject] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := opts.GetPath(m.getFullPath(command))
req, err := m.prepareGetRequest(apiCtx, fullUrl)
if err != nil {
return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API get", "url", fullUrl)
response := parseHttpResponse[GenericJsonObject](m.client.Do(req))
m.debugLog("retrieved API get result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
func (m *MikrotikApiClient) Create(
ctx context.Context,
command string,
payload GenericJsonObject,
) MikrotikApiResponse[GenericJsonObject] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := m.getFullPath(command)
req, err := m.preparePayloadRequest(apiCtx, http.MethodPut, fullUrl, payload)
if err != nil {
return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API put", "url", fullUrl)
response := parseHttpResponse[GenericJsonObject](m.client.Do(req))
m.debugLog("retrieved API put result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
func (m *MikrotikApiClient) Update(
ctx context.Context,
command string,
payload GenericJsonObject,
) MikrotikApiResponse[GenericJsonObject] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := m.getFullPath(command)
req, err := m.preparePayloadRequest(apiCtx, http.MethodPatch, fullUrl, payload)
if err != nil {
return errToApiResponse[GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API patch", "url", fullUrl)
response := parseHttpResponse[GenericJsonObject](m.client.Do(req))
m.debugLog("retrieved API patch result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
func (m *MikrotikApiClient) Delete(
ctx context.Context,
command string,
) MikrotikApiResponse[EmptyResponse] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := m.getFullPath(command)
req, err := m.prepareDeleteRequest(apiCtx, fullUrl)
if err != nil {
return errToApiResponse[EmptyResponse](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API delete", "url", fullUrl)
response := parseHttpResponse[EmptyResponse](m.client.Do(req))
m.debugLog("retrieved API delete result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
func (m *MikrotikApiClient) ExecList(
ctx context.Context,
command string,
payload GenericJsonObject,
) MikrotikApiResponse[[]GenericJsonObject] {
apiCtx, cancel := context.WithTimeout(ctx, m.cfg.GetApiTimeout())
defer cancel()
fullUrl := m.getFullPath(command)
req, err := m.preparePayloadRequest(apiCtx, http.MethodPost, fullUrl, payload)
if err != nil {
return errToApiResponse[[]GenericJsonObject](MikrotikApiErrorCodeRequestPreparationFailed,
"failed to create request", err)
}
start := time.Now()
m.debugLog("executing API post", "url", fullUrl)
response := parseHttpResponse[[]GenericJsonObject](m.client.Do(req))
m.debugLog("retrieved API post result", "url", fullUrl, "duration", time.Since(start).String())
return response
}
// endregion API-client