mirror of
https://github.com/h44z/wg-portal.git
synced 2025-09-15 07:11:15 +00:00
Mikrotik integration (#467)
Allow MikroTik routes as WireGuard backends
This commit is contained in:
864
internal/adapters/wgcontroller/local.go
Normal file
864
internal/adapters/wgcontroller/local.go
Normal file
@@ -0,0 +1,864 @@
|
||||
package wgcontroller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
probing "github.com/prometheus-community/pro-bing"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||
)
|
||||
|
||||
// region dependencies
|
||||
|
||||
// WgCtrlRepo is used to control local WireGuard devices via the wgctrl-go library.
|
||||
type WgCtrlRepo interface {
|
||||
io.Closer
|
||||
Devices() ([]*wgtypes.Device, error)
|
||||
Device(name string) (*wgtypes.Device, error)
|
||||
ConfigureDevice(name string, cfg wgtypes.Config) error
|
||||
}
|
||||
|
||||
// A NetlinkClient is a type which can control a netlink device.
|
||||
type NetlinkClient interface {
|
||||
LinkAdd(link netlink.Link) error
|
||||
LinkDel(link netlink.Link) error
|
||||
LinkByName(name string) (netlink.Link, error)
|
||||
LinkSetUp(link netlink.Link) error
|
||||
LinkSetDown(link netlink.Link) error
|
||||
LinkSetMTU(link netlink.Link, mtu int) error
|
||||
AddrReplace(link netlink.Link, addr *netlink.Addr) error
|
||||
AddrAdd(link netlink.Link, addr *netlink.Addr) error
|
||||
AddrList(link netlink.Link) ([]netlink.Addr, error)
|
||||
AddrDel(link netlink.Link, addr *netlink.Addr) error
|
||||
RouteAdd(route *netlink.Route) error
|
||||
RouteDel(route *netlink.Route) error
|
||||
RouteReplace(route *netlink.Route) error
|
||||
RouteList(link netlink.Link, family int) ([]netlink.Route, error)
|
||||
RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error)
|
||||
RuleAdd(rule *netlink.Rule) error
|
||||
RuleDel(rule *netlink.Rule) error
|
||||
RuleList(family int) ([]netlink.Rule, error)
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
type LocalController struct {
|
||||
cfg *config.Config
|
||||
|
||||
wg WgCtrlRepo
|
||||
nl NetlinkClient
|
||||
|
||||
shellCmd string
|
||||
resolvConfIfacePrefix string
|
||||
}
|
||||
|
||||
// NewLocalController creates a new local controller instance.
|
||||
// This repository is used to interact with the WireGuard kernel or userspace module.
|
||||
func NewLocalController(cfg *config.Config) (*LocalController, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create wgctrl client: %w", err)
|
||||
}
|
||||
|
||||
nl := &lowlevel.NetlinkManager{}
|
||||
|
||||
repo := &LocalController{
|
||||
cfg: cfg,
|
||||
|
||||
wg: wg,
|
||||
nl: nl,
|
||||
|
||||
shellCmd: "bash", // we only support bash at the moment
|
||||
resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf
|
||||
}
|
||||
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (c LocalController) GetId() domain.InterfaceBackend {
|
||||
return config.LocalBackendName
|
||||
}
|
||||
|
||||
// region wireguard-related
|
||||
|
||||
func (c LocalController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
|
||||
devices, err := c.wg.Devices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device list error: %w", err)
|
||||
}
|
||||
|
||||
interfaces := make([]domain.PhysicalInterface, 0, len(devices))
|
||||
for _, device := range devices {
|
||||
interfaceModel, err := c.convertWireGuardInterface(device)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("interface convert failed for %s: %w", device.Name, err)
|
||||
}
|
||||
interfaces = append(interfaces, interfaceModel)
|
||||
}
|
||||
|
||||
return interfaces, nil
|
||||
}
|
||||
|
||||
func (c LocalController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (
|
||||
*domain.PhysicalInterface,
|
||||
error,
|
||||
) {
|
||||
return c.getInterface(id)
|
||||
}
|
||||
|
||||
func (c LocalController) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) {
|
||||
// read data from wgctrl interface
|
||||
|
||||
iface := domain.PhysicalInterface{
|
||||
Identifier: domain.InterfaceIdentifier(device.Name),
|
||||
KeyPair: domain.KeyPair{
|
||||
PrivateKey: device.PrivateKey.String(),
|
||||
PublicKey: device.PublicKey.String(),
|
||||
},
|
||||
ListenPort: device.ListenPort,
|
||||
Addresses: nil,
|
||||
Mtu: 0,
|
||||
FirewallMark: uint32(device.FirewallMark),
|
||||
DeviceUp: false,
|
||||
ImportSource: domain.ControllerTypeLocal,
|
||||
DeviceType: device.Type.String(),
|
||||
BytesUpload: 0,
|
||||
BytesDownload: 0,
|
||||
}
|
||||
|
||||
// read data from netlink interface
|
||||
|
||||
lowLevelInterface, err := c.nl.LinkByName(device.Name)
|
||||
if err != nil {
|
||||
return domain.PhysicalInterface{}, fmt.Errorf("netlink error for %s: %w", device.Name, err)
|
||||
}
|
||||
ipAddresses, err := c.nl.AddrList(lowLevelInterface)
|
||||
if err != nil {
|
||||
return domain.PhysicalInterface{}, fmt.Errorf("ip read error for %s: %w", device.Name, err)
|
||||
}
|
||||
|
||||
for _, addr := range ipAddresses {
|
||||
iface.Addresses = append(iface.Addresses, domain.CidrFromNetlinkAddr(addr))
|
||||
}
|
||||
iface.Mtu = lowLevelInterface.Attrs().MTU
|
||||
iface.DeviceUp = lowLevelInterface.Attrs().OperState == netlink.OperUnknown // wg only supports unknown
|
||||
if stats := lowLevelInterface.Attrs().Statistics; stats != nil {
|
||||
iface.BytesUpload = stats.TxBytes
|
||||
iface.BytesDownload = stats.RxBytes
|
||||
}
|
||||
|
||||
return iface, nil
|
||||
}
|
||||
|
||||
func (c LocalController) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) (
|
||||
[]domain.PhysicalPeer,
|
||||
error,
|
||||
) {
|
||||
device, err := c.wg.Device(string(deviceId))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device error: %w", err)
|
||||
}
|
||||
|
||||
peers := make([]domain.PhysicalPeer, 0, len(device.Peers))
|
||||
for _, peer := range device.Peers {
|
||||
peerModel, err := c.convertWireGuardPeer(&peer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer convert failed for %v: %w", peer.PublicKey, err)
|
||||
}
|
||||
peers = append(peers, peerModel)
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (c LocalController) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) {
|
||||
peerModel := domain.PhysicalPeer{
|
||||
Identifier: domain.PeerIdentifier(peer.PublicKey.String()),
|
||||
Endpoint: "",
|
||||
AllowedIPs: nil,
|
||||
KeyPair: domain.KeyPair{
|
||||
PublicKey: peer.PublicKey.String(),
|
||||
},
|
||||
PresharedKey: "",
|
||||
PersistentKeepalive: int(peer.PersistentKeepaliveInterval.Seconds()),
|
||||
LastHandshake: peer.LastHandshakeTime,
|
||||
ProtocolVersion: peer.ProtocolVersion,
|
||||
BytesUpload: uint64(peer.ReceiveBytes),
|
||||
BytesDownload: uint64(peer.TransmitBytes),
|
||||
ImportSource: domain.ControllerTypeLocal,
|
||||
}
|
||||
|
||||
// Set local extras - local peers are never disabled in the kernel
|
||||
peerModel.SetExtras(domain.LocalPeerExtras{
|
||||
Disabled: false,
|
||||
})
|
||||
|
||||
for _, addr := range peer.AllowedIPs {
|
||||
peerModel.AllowedIPs = append(peerModel.AllowedIPs, domain.CidrFromIpNet(addr))
|
||||
}
|
||||
if peer.Endpoint != nil {
|
||||
peerModel.Endpoint = peer.Endpoint.String()
|
||||
}
|
||||
if peer.PresharedKey != (wgtypes.Key{}) {
|
||||
peerModel.PresharedKey = domain.PreSharedKey(peer.PresharedKey.String())
|
||||
}
|
||||
|
||||
return peerModel, nil
|
||||
}
|
||||
|
||||
func (c LocalController) SaveInterface(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error {
|
||||
physicalInterface, err := c.getOrCreateInterface(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateFunc != nil {
|
||||
physicalInterface, err = updateFunc(physicalInterface)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.updateLowLevelInterface(physicalInterface); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.updateWireGuardInterface(physicalInterface); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
|
||||
device, err := c.getInterface(id)
|
||||
if err == nil {
|
||||
return device, nil // interface exists
|
||||
}
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("device error: %w", err) // unknown error
|
||||
}
|
||||
|
||||
// create new device
|
||||
if err := c.createLowLevelInterface(id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
device, err = c.getInterface(id)
|
||||
return device, err
|
||||
}
|
||||
|
||||
func (c LocalController) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
|
||||
device, err := c.wg.Device(string(id))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pi, err := c.convertWireGuardInterface(device)
|
||||
return &pi, err
|
||||
}
|
||||
|
||||
func (c LocalController) createLowLevelInterface(id domain.InterfaceIdentifier) error {
|
||||
link := &netlink.GenericLink{
|
||||
LinkAttrs: netlink.LinkAttrs{
|
||||
Name: string(id),
|
||||
},
|
||||
LinkType: "wireguard",
|
||||
}
|
||||
err := c.nl.LinkAdd(link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("link add failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) updateLowLevelInterface(pi *domain.PhysicalInterface) error {
|
||||
link, err := c.nl.LinkByName(string(pi.Identifier))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pi.Mtu != 0 {
|
||||
if err := c.nl.LinkSetMTU(link, pi.Mtu); err != nil {
|
||||
return fmt.Errorf("mtu error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, addr := range pi.Addresses {
|
||||
err := c.nl.AddrReplace(link, addr.NetlinkAddr())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set ip %s: %w", addr.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove unwanted IP addresses
|
||||
rawAddresses, err := c.nl.AddrList(link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch interface ips: %w", err)
|
||||
}
|
||||
for _, rawAddr := range rawAddresses {
|
||||
netlinkAddr := domain.CidrFromNetlinkAddr(rawAddr)
|
||||
remove := true
|
||||
for _, addr := range pi.Addresses {
|
||||
if addr == netlinkAddr {
|
||||
remove = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !remove {
|
||||
continue
|
||||
}
|
||||
|
||||
err := c.nl.AddrDel(link, &rawAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated ip %s: %w", netlinkAddr.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update link state
|
||||
if pi.DeviceUp {
|
||||
if err := c.nl.LinkSetUp(link); err != nil {
|
||||
return fmt.Errorf("failed to bring up device: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := c.nl.LinkSetDown(link); err != nil {
|
||||
return fmt.Errorf("failed to bring down device: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
|
||||
pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var fwMark *int
|
||||
if pi.FirewallMark != 0 {
|
||||
intFwMark := int(pi.FirewallMark)
|
||||
fwMark = &intFwMark
|
||||
}
|
||||
err = c.wg.ConfigureDevice(string(pi.Identifier), wgtypes.Config{
|
||||
PrivateKey: &pKey,
|
||||
ListenPort: &pi.ListenPort,
|
||||
FirewallMark: fwMark,
|
||||
ReplacePeers: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
|
||||
if err := c.deleteLowLevelInterface(id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
|
||||
link, err := c.nl.LinkByName(string(id))
|
||||
if err != nil {
|
||||
var linkNotFoundError netlink.LinkNotFoundError
|
||||
if errors.As(err, &linkNotFoundError) {
|
||||
return nil // ignore not found error
|
||||
}
|
||||
return fmt.Errorf("unable to find low level interface: %w", err)
|
||||
}
|
||||
|
||||
err = c.nl.LinkDel(link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete low level interface: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) SavePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error {
|
||||
physicalPeer, err := c.getOrCreatePeer(deviceId, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
physicalPeer, err = updateFunc(physicalPeer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the peer is disabled by looking at the backend extras
|
||||
// For local controller, disabled peers should be deleted
|
||||
if physicalPeer.GetExtras() != nil {
|
||||
switch extras := physicalPeer.GetExtras().(type) {
|
||||
case domain.LocalPeerExtras:
|
||||
if extras.Disabled {
|
||||
// Delete the peer instead of updating it
|
||||
return c.deletePeer(deviceId, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.updatePeer(deviceId, physicalPeer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
|
||||
*domain.PhysicalPeer,
|
||||
error,
|
||||
) {
|
||||
peer, err := c.getPeer(deviceId, id)
|
||||
if err == nil {
|
||||
return peer, nil // peer exists
|
||||
}
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("peer error: %w", err) // unknown error
|
||||
}
|
||||
|
||||
// create new peer
|
||||
err = c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{
|
||||
{
|
||||
PublicKey: id.ToPublicKey(),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer create error for %s: %w", id.ToPublicKey(), err)
|
||||
}
|
||||
|
||||
peer, err = c.getPeer(deviceId, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer error after create: %w", err)
|
||||
}
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
func (c LocalController) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
|
||||
*domain.PhysicalPeer,
|
||||
error,
|
||||
) {
|
||||
if !id.IsPublicKey() {
|
||||
return nil, errors.New("invalid public key")
|
||||
}
|
||||
|
||||
device, err := c.wg.Device(string(deviceId))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey := id.ToPublicKey()
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey != publicKey {
|
||||
continue
|
||||
}
|
||||
|
||||
peerModel, err := c.convertWireGuardPeer(&peer)
|
||||
return &peerModel, err
|
||||
}
|
||||
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
func (c LocalController) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error {
|
||||
cfg := wgtypes.PeerConfig{
|
||||
PublicKey: pp.GetPublicKey(),
|
||||
Remove: false,
|
||||
UpdateOnly: true,
|
||||
PresharedKey: pp.GetPresharedKey(),
|
||||
Endpoint: pp.GetEndpointAddress(),
|
||||
PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(),
|
||||
ReplaceAllowedIPs: true,
|
||||
AllowedIPs: pp.GetAllowedIPs(),
|
||||
}
|
||||
|
||||
err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) DeletePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
) error {
|
||||
if !id.IsPublicKey() {
|
||||
return errors.New("invalid public key")
|
||||
}
|
||||
|
||||
err := c.deletePeer(deviceId, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
|
||||
cfg := wgtypes.PeerConfig{
|
||||
PublicKey: id.ToPublicKey(),
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion wireguard-related
|
||||
|
||||
// region wg-quick-related
|
||||
|
||||
func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
|
||||
if hookCmd == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
slog.Debug("executing interface hook", "interface", id, "hook", hookCmd)
|
||||
err := c.exec(hookCmd, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exec hook: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||
if dnsStr == "" && dnsSearchStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsServers := internal.SliceString(dnsStr)
|
||||
dnsSearchDomains := internal.SliceString(dnsSearchStr)
|
||||
|
||||
dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
|
||||
dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
|
||||
|
||||
for _, dnsServer := range dnsServers {
|
||||
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
|
||||
}
|
||||
for _, searchDomain := range dnsSearchDomains {
|
||||
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
|
||||
}
|
||||
|
||||
err := c.exec(dnsCommand, id, dnsCommandInput...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error {
|
||||
dnsCommand := "resolvconf -d %resPref%i -f"
|
||||
|
||||
err := c.exec(dnsCommand, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unset dns settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
|
||||
command = strings.ReplaceAll(command, "%resPref", c.resolvConfIfacePrefix)
|
||||
return strings.ReplaceAll(command, "%i", string(interfaceId))
|
||||
}
|
||||
|
||||
func (c LocalController) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
|
||||
commandWithInterfaceName := c.replaceCommandPlaceHolders(command, interfaceId)
|
||||
cmd := exec.Command(c.shellCmd, "-ce", commandWithInterfaceName)
|
||||
if len(stdin) > 0 {
|
||||
b := &bytes.Buffer{}
|
||||
for _, ln := range stdin {
|
||||
if _, err := fmt.Fprint(b, ln); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
cmd.Stdin = b
|
||||
}
|
||||
out, err := cmd.CombinedOutput() // execute and wait for output
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
|
||||
}
|
||||
slog.Debug("executed shell command",
|
||||
"command", commandWithInterfaceName,
|
||||
"output", string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion wg-quick-related
|
||||
|
||||
// region routing-related
|
||||
|
||||
func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// update fwmark rules
|
||||
if err := c.setFwMarkRules(rules); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update main rule
|
||||
if err := c.setMainRule(rules); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// cleanup old main rules
|
||||
if err := c.cleanupMainRule(rules); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error {
|
||||
for _, rule := range rules {
|
||||
existingRules, err := c.nl.RuleList(int(rule.IpFamily))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err)
|
||||
}
|
||||
|
||||
ruleExists := false
|
||||
for _, existingRule := range existingRules {
|
||||
if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table {
|
||||
ruleExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ruleExists {
|
||||
continue // rule already exists, no need to recreate it
|
||||
}
|
||||
|
||||
// create a missing rule
|
||||
if err := c.nl.RuleAdd(&netlink.Rule{
|
||||
Family: int(rule.IpFamily),
|
||||
Table: rule.Table,
|
||||
Mark: rule.FwMark,
|
||||
Invert: true,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
Priority: c.getRulePriority(existingRules),
|
||||
Mask: nil,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w",
|
||||
rule.IpFamily, rule.FwMark, rule.Table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
|
||||
prio := 32700 // linux main rule has a priority of 32766
|
||||
for {
|
||||
isFresh := true
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Priority == prio {
|
||||
isFresh = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFresh {
|
||||
break
|
||||
} else {
|
||||
prio--
|
||||
}
|
||||
}
|
||||
return prio
|
||||
}
|
||||
|
||||
func (c LocalController) setMainRule(rules []domain.RouteRule) error {
|
||||
var family domain.IpFamily
|
||||
shouldHaveMainRule := false
|
||||
for _, rule := range rules {
|
||||
family = rule.IpFamily
|
||||
if rule.HasDefault == true {
|
||||
shouldHaveMainRule = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !shouldHaveMainRule {
|
||||
return nil
|
||||
}
|
||||
|
||||
existingRules, err := c.nl.RuleList(int(family))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
|
||||
}
|
||||
|
||||
ruleExists := false
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
ruleExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ruleExists {
|
||||
return nil // rule already exists, skip re-creation
|
||||
}
|
||||
|
||||
if err := c.nl.RuleAdd(&netlink.Rule{
|
||||
Family: int(family),
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: 0,
|
||||
Priority: c.getMainRulePriority(existingRules),
|
||||
Mark: 0,
|
||||
Mask: nil,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup rule for main table: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
|
||||
priority := c.cfg.Advanced.RulePrioOffset
|
||||
for {
|
||||
isFresh := true
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Priority == priority {
|
||||
isFresh = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFresh {
|
||||
break
|
||||
} else {
|
||||
priority++
|
||||
}
|
||||
}
|
||||
return priority
|
||||
}
|
||||
|
||||
func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error {
|
||||
var family domain.IpFamily
|
||||
for _, rule := range rules {
|
||||
family = rule.IpFamily
|
||||
break
|
||||
}
|
||||
|
||||
existingRules, err := c.nl.RuleList(int(family))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
|
||||
}
|
||||
|
||||
shouldHaveMainRule := false
|
||||
for _, rule := range rules {
|
||||
if rule.HasDefault == true {
|
||||
shouldHaveMainRule = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mainRules := 0
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
mainRules++
|
||||
}
|
||||
}
|
||||
|
||||
removalCount := 0
|
||||
if mainRules > 1 {
|
||||
removalCount = mainRules - 1 // we only want one single rule
|
||||
}
|
||||
if !shouldHaveMainRule {
|
||||
removalCount = mainRules
|
||||
}
|
||||
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
if removalCount > 0 {
|
||||
existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field
|
||||
if err := c.nl.RuleDel(&existingRule); err != nil {
|
||||
return fmt.Errorf("failed to delete main rule: %w", err)
|
||||
}
|
||||
removalCount--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// endregion routing-related
|
||||
|
||||
// region statistics-related
|
||||
|
||||
func (c LocalController) PingAddresses(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
) (*domain.PingerResult, error) {
|
||||
pinger, err := probing.NewPinger(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to instantiate pinger for %s: %w", addr, err)
|
||||
}
|
||||
|
||||
checkCount := 1
|
||||
pinger.SetPrivileged(!c.cfg.Statistics.PingUnprivileged)
|
||||
pinger.Count = checkCount
|
||||
pinger.Timeout = 2 * time.Second
|
||||
err = pinger.RunWithContext(ctx) // Blocks until finished.
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to ping %s: %w", addr, err)
|
||||
}
|
||||
|
||||
stats := pinger.Statistics()
|
||||
|
||||
return &domain.PingerResult{
|
||||
PacketsRecv: stats.PacketsRecv,
|
||||
PacketsSent: stats.PacketsSent,
|
||||
Rtts: stats.Rtts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// endregion statistics-related
|
829
internal/adapters/wgcontroller/mikrotik.go
Normal file
829
internal/adapters/wgcontroller/mikrotik.go
Normal file
@@ -0,0 +1,829 @@
|
||||
package wgcontroller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||
)
|
||||
|
||||
type MikrotikController struct {
|
||||
coreCfg *config.Config
|
||||
cfg *config.BackendMikrotik
|
||||
|
||||
client *lowlevel.MikrotikApiClient
|
||||
|
||||
// Add mutexes to prevent race conditions
|
||||
interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex
|
||||
peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex
|
||||
}
|
||||
|
||||
func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) {
|
||||
client, err := lowlevel.NewMikrotikApiClient(coreCfg, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Mikrotik API client: %w", err)
|
||||
}
|
||||
|
||||
return &MikrotikController{
|
||||
coreCfg: coreCfg,
|
||||
cfg: cfg,
|
||||
|
||||
client: client,
|
||||
|
||||
interfaceMutexes: sync.Map{},
|
||||
peerMutexes: sync.Map{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) GetId() domain.InterfaceBackend {
|
||||
return domain.InterfaceBackend(c.cfg.Id)
|
||||
}
|
||||
|
||||
// getInterfaceMutex returns a mutex for the given interface to prevent concurrent modifications
|
||||
func (c *MikrotikController) getInterfaceMutex(id domain.InterfaceIdentifier) *sync.Mutex {
|
||||
mutex, _ := c.interfaceMutexes.LoadOrStore(id, &sync.Mutex{})
|
||||
return mutex.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// getPeerMutex returns a mutex for the given peer to prevent concurrent modifications
|
||||
func (c *MikrotikController) getPeerMutex(id domain.PeerIdentifier) *sync.Mutex {
|
||||
mutex, _ := c.peerMutexes.LoadOrStore(id, &sync.Mutex{})
|
||||
return mutex.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// region wireguard-related
|
||||
|
||||
func (c *MikrotikController) GetInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running", "comment",
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return nil, fmt.Errorf("failed to query interfaces: %v", wgReply.Error)
|
||||
}
|
||||
|
||||
// Parallelize loading of interface details to speed up overall latency.
|
||||
// Use a bounded semaphore to avoid overloading the MikroTik device.
|
||||
maxConcurrent := c.cfg.GetConcurrency()
|
||||
sem := make(chan struct{}, maxConcurrent)
|
||||
|
||||
interfaces := make([]domain.PhysicalInterface, 0, len(wgReply.Data))
|
||||
var mu sync.Mutex
|
||||
var wgWait sync.WaitGroup
|
||||
var firstErr error
|
||||
ctx2, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
for _, wgObj := range wgReply.Data {
|
||||
wgWait.Add(1)
|
||||
sem <- struct{}{} // block if more than maxConcurrent requests are processing
|
||||
go func(wg lowlevel.GenericJsonObject) {
|
||||
defer wgWait.Done()
|
||||
defer func() { <-sem }() // read from the semaphore and make space for the next entry
|
||||
if firstErr != nil {
|
||||
return
|
||||
}
|
||||
pi, err := c.loadInterfaceData(ctx2, wg)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
cancel()
|
||||
}
|
||||
mu.Unlock()
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
interfaces = append(interfaces, *pi)
|
||||
mu.Unlock()
|
||||
}(wgObj)
|
||||
}
|
||||
|
||||
wgWait.Wait()
|
||||
if firstErr != nil {
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
return interfaces, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (
|
||||
*domain.PhysicalInterface,
|
||||
error,
|
||||
) {
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"name": string(id),
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return nil, fmt.Errorf("failed to query interface %s: %v", id, wgReply.Error)
|
||||
}
|
||||
|
||||
if len(wgReply.Data) == 0 {
|
||||
return nil, fmt.Errorf("interface %s not found", id)
|
||||
}
|
||||
|
||||
return c.loadInterfaceData(ctx, wgReply.Data[0])
|
||||
}
|
||||
|
||||
func (c *MikrotikController) loadInterfaceData(
|
||||
ctx context.Context,
|
||||
wireGuardObj lowlevel.GenericJsonObject,
|
||||
) (*domain.PhysicalInterface, error) {
|
||||
deviceId := wireGuardObj.GetString(".id")
|
||||
deviceName := wireGuardObj.GetString("name")
|
||||
ifaceReply := c.client.Get(ctx, "/interface/"+deviceId, &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
"name", "rx-byte", "tx-byte",
|
||||
},
|
||||
})
|
||||
if ifaceReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return nil, fmt.Errorf("failed to query interface %s: %v", deviceId, ifaceReply.Error)
|
||||
}
|
||||
|
||||
ipv4, ipv6, err := c.loadIpAddresses(ctx, deviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query IP addresses for interface %s: %v", deviceId, err)
|
||||
}
|
||||
addresses := c.convertIpAddresses(ipv4, ipv6)
|
||||
|
||||
interfaceModel, err := c.convertWireGuardInterface(wireGuardObj, ifaceReply.Data, addresses)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("interface convert failed for %s: %w", deviceName, err)
|
||||
}
|
||||
return &interfaceModel, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) loadIpAddresses(
|
||||
ctx context.Context,
|
||||
deviceName string,
|
||||
) (ipv4 []lowlevel.GenericJsonObject, ipv6 []lowlevel.GenericJsonObject, err error) {
|
||||
// Query IPv4 and IPv6 addresses in parallel to reduce latency.
|
||||
var (
|
||||
v4 []lowlevel.GenericJsonObject
|
||||
v6 []lowlevel.GenericJsonObject
|
||||
v4Err error
|
||||
v6Err error
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
addrV4Reply := c.client.Query(ctx, "/ip/address", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "address", "network",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"interface": deviceName,
|
||||
"dynamic": "false", // we only want static addresses
|
||||
"disabled": "false", // we only want addresses that are not disabled
|
||||
},
|
||||
})
|
||||
if addrV4Reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
v4Err = fmt.Errorf("failed to query IPv4 addresses for interface %s: %v", deviceName, addrV4Reply.Error)
|
||||
return
|
||||
}
|
||||
v4 = addrV4Reply.Data
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
addrV6Reply := c.client.Query(ctx, "/ipv6/address", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "address", "network",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"interface": deviceName,
|
||||
"dynamic": "false", // we only want static addresses
|
||||
"disabled": "false", // we only want addresses that are not disabled
|
||||
},
|
||||
})
|
||||
if addrV6Reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
v6Err = fmt.Errorf("failed to query IPv6 addresses for interface %s: %v", deviceName, addrV6Reply.Error)
|
||||
return
|
||||
}
|
||||
v6 = addrV6Reply.Data
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
if v4Err != nil {
|
||||
return nil, nil, v4Err
|
||||
}
|
||||
if v6Err != nil {
|
||||
return nil, nil, v6Err
|
||||
}
|
||||
|
||||
return v4, v6, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) convertIpAddresses(
|
||||
ipv4, ipv6 []lowlevel.GenericJsonObject,
|
||||
) []domain.Cidr {
|
||||
addresses := make([]domain.Cidr, 0, len(ipv4)+len(ipv6))
|
||||
for _, addr := range append(ipv4, ipv6...) {
|
||||
addrStr := addr.GetString("address")
|
||||
if addrStr == "" {
|
||||
continue
|
||||
}
|
||||
cidr, err := domain.CidrFromString(addrStr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
addresses = append(addresses, cidr)
|
||||
}
|
||||
|
||||
return addresses
|
||||
}
|
||||
|
||||
func (c *MikrotikController) convertWireGuardInterface(
|
||||
wg, iface lowlevel.GenericJsonObject,
|
||||
addresses []domain.Cidr,
|
||||
) (
|
||||
domain.PhysicalInterface,
|
||||
error,
|
||||
) {
|
||||
pi := domain.PhysicalInterface{
|
||||
Identifier: domain.InterfaceIdentifier(wg.GetString("name")),
|
||||
KeyPair: domain.KeyPair{
|
||||
PrivateKey: wg.GetString("private-key"),
|
||||
PublicKey: wg.GetString("public-key"),
|
||||
},
|
||||
ListenPort: wg.GetInt("listen-port"),
|
||||
Addresses: addresses,
|
||||
Mtu: wg.GetInt("mtu"),
|
||||
FirewallMark: 0,
|
||||
DeviceUp: wg.GetBool("running"),
|
||||
ImportSource: domain.ControllerTypeMikrotik,
|
||||
DeviceType: domain.ControllerTypeMikrotik,
|
||||
BytesUpload: uint64(iface.GetInt("tx-byte")),
|
||||
BytesDownload: uint64(iface.GetInt("rx-byte")),
|
||||
}
|
||||
|
||||
pi.SetExtras(domain.MikrotikInterfaceExtras{
|
||||
Id: wg.GetString(".id"),
|
||||
Comment: wg.GetString("comment"),
|
||||
Disabled: wg.GetBool("disabled"),
|
||||
})
|
||||
|
||||
return pi, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) GetPeers(ctx context.Context, deviceId domain.InterfaceIdentifier) (
|
||||
[]domain.PhysicalPeer,
|
||||
error,
|
||||
) {
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "name", "allowed-address", "client-address", "client-endpoint", "client-keepalive", "comment",
|
||||
"current-endpoint-address", "current-endpoint-port", "last-handshake", "persistent-keepalive",
|
||||
"public-key", "private-key", "preshared-key", "mtu", "disabled", "rx", "tx", "responder", "client-dns",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"interface": string(deviceId),
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return nil, fmt.Errorf("failed to query peers for %s: %v", deviceId, wgReply.Error)
|
||||
}
|
||||
|
||||
if len(wgReply.Data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
peers := make([]domain.PhysicalPeer, 0, len(wgReply.Data))
|
||||
for _, peer := range wgReply.Data {
|
||||
peerModel, err := c.convertWireGuardPeer(peer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("peer convert failed for %v: %w", peer.GetString("name"), err)
|
||||
}
|
||||
peers = append(peers, peerModel)
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) convertWireGuardPeer(peer lowlevel.GenericJsonObject) (
|
||||
domain.PhysicalPeer,
|
||||
error,
|
||||
) {
|
||||
keepAliveSeconds := 0
|
||||
duration, err := time.ParseDuration(peer.GetString("persistent-keepalive"))
|
||||
if err == nil {
|
||||
keepAliveSeconds = int(duration.Seconds())
|
||||
}
|
||||
|
||||
currentEndpoint := ""
|
||||
if peer.GetString("current-endpoint-address") != "" && peer.GetString("current-endpoint-port") != "" {
|
||||
currentEndpoint = peer.GetString("current-endpoint-address") + ":" + peer.GetString("current-endpoint-port")
|
||||
}
|
||||
|
||||
lastHandshakeTime := time.Time{}
|
||||
if peer.GetString("last-handshake") != "" {
|
||||
relDuration, err := time.ParseDuration(peer.GetString("last-handshake"))
|
||||
if err == nil {
|
||||
lastHandshakeTime = time.Now().Add(-relDuration)
|
||||
}
|
||||
}
|
||||
|
||||
allowedAddresses, _ := domain.CidrsFromString(peer.GetString("allowed-address"))
|
||||
|
||||
clientKeepAliveSeconds := 0
|
||||
duration, err = time.ParseDuration(peer.GetString("client-keepalive"))
|
||||
if err == nil {
|
||||
clientKeepAliveSeconds = int(duration.Seconds())
|
||||
}
|
||||
|
||||
peerModel := domain.PhysicalPeer{
|
||||
Identifier: domain.PeerIdentifier(peer.GetString("public-key")),
|
||||
Endpoint: currentEndpoint,
|
||||
AllowedIPs: allowedAddresses,
|
||||
KeyPair: domain.KeyPair{
|
||||
PublicKey: peer.GetString("public-key"),
|
||||
PrivateKey: peer.GetString("private-key"),
|
||||
},
|
||||
PresharedKey: domain.PreSharedKey(peer.GetString("preshared-key")),
|
||||
PersistentKeepalive: keepAliveSeconds,
|
||||
LastHandshake: lastHandshakeTime,
|
||||
ProtocolVersion: 0, // Mikrotik does not support protocol versioning, so we set it to 0
|
||||
BytesUpload: uint64(peer.GetInt("rx")),
|
||||
BytesDownload: uint64(peer.GetInt("tx")),
|
||||
ImportSource: domain.ControllerTypeMikrotik,
|
||||
}
|
||||
|
||||
peerModel.SetExtras(domain.MikrotikPeerExtras{
|
||||
Id: peer.GetString(".id"),
|
||||
Name: peer.GetString("name"),
|
||||
Comment: peer.GetString("comment"),
|
||||
IsResponder: peer.GetBool("responder"),
|
||||
Disabled: peer.GetBool("disabled"),
|
||||
ClientEndpoint: peer.GetString("client-endpoint"),
|
||||
ClientAddress: peer.GetString("client-address"),
|
||||
ClientDns: peer.GetString("client-dns"),
|
||||
ClientKeepalive: clientKeepAliveSeconds,
|
||||
})
|
||||
|
||||
return peerModel, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) SaveInterface(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error {
|
||||
// Lock the interface to prevent concurrent modifications
|
||||
mutex := c.getInterfaceMutex(id)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
physicalInterface, err := c.getOrCreateInterface(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deviceId := physicalInterface.GetExtras().(domain.MikrotikInterfaceExtras).Id
|
||||
if updateFunc != nil {
|
||||
physicalInterface, err = updateFunc(physicalInterface)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newExtras := physicalInterface.GetExtras().(domain.MikrotikInterfaceExtras)
|
||||
newExtras.Id = deviceId // ensure the ID is not changed
|
||||
physicalInterface.SetExtras(newExtras)
|
||||
}
|
||||
|
||||
if err := c.updateInterface(ctx, physicalInterface); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) getOrCreateInterface(
|
||||
ctx context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
) (*domain.PhysicalInterface, error) {
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "name", "public-key", "private-key", "listen-port", "mtu", "disabled", "running",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"name": string(id),
|
||||
},
|
||||
})
|
||||
if wgReply.Status == lowlevel.MikrotikApiStatusOk && len(wgReply.Data) > 0 {
|
||||
return c.loadInterfaceData(ctx, wgReply.Data[0])
|
||||
}
|
||||
|
||||
// create a new interface if it does not exist
|
||||
createReply := c.client.Create(ctx, "/interface/wireguard", lowlevel.GenericJsonObject{
|
||||
"name": string(id),
|
||||
})
|
||||
if wgReply.Status == lowlevel.MikrotikApiStatusOk {
|
||||
return c.loadInterfaceData(ctx, createReply.Data)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to create interface %s: %v", id, createReply.Error)
|
||||
}
|
||||
|
||||
func (c *MikrotikController) updateInterface(ctx context.Context, pi *domain.PhysicalInterface) error {
|
||||
extras := pi.GetExtras().(domain.MikrotikInterfaceExtras)
|
||||
interfaceId := extras.Id
|
||||
wgReply := c.client.Update(ctx, "/interface/wireguard/"+interfaceId, lowlevel.GenericJsonObject{
|
||||
"name": pi.Identifier,
|
||||
"comment": extras.Comment,
|
||||
"mtu": strconv.Itoa(pi.Mtu),
|
||||
"listen-port": strconv.Itoa(pi.ListenPort),
|
||||
"private-key": pi.KeyPair.PrivateKey,
|
||||
"disabled": strconv.FormatBool(!pi.DeviceUp),
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to update interface %s: %v", pi.Identifier, wgReply.Error)
|
||||
}
|
||||
|
||||
// update the interface's addresses
|
||||
currentV4, currentV6, err := c.loadIpAddresses(ctx, string(pi.Identifier))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load current addresses for interface %s: %v", pi.Identifier, err)
|
||||
}
|
||||
currentAddresses := c.convertIpAddresses(currentV4, currentV6)
|
||||
|
||||
// get all addresses that are currently not in the interface, only in pi
|
||||
newAddresses := make([]domain.Cidr, 0, len(pi.Addresses))
|
||||
for _, addr := range pi.Addresses {
|
||||
if slices.Contains(currentAddresses, addr) {
|
||||
continue
|
||||
}
|
||||
newAddresses = append(newAddresses, addr)
|
||||
}
|
||||
// get obsolete addresses that are in the interface, but not in pi
|
||||
obsoleteAddresses := make([]domain.Cidr, 0, len(currentAddresses))
|
||||
for _, addr := range currentAddresses {
|
||||
if slices.Contains(pi.Addresses, addr) {
|
||||
continue
|
||||
}
|
||||
obsoleteAddresses = append(obsoleteAddresses, addr)
|
||||
}
|
||||
|
||||
// update the IP addresses for the interface
|
||||
if err := c.updateIpAddresses(ctx, string(pi.Identifier), currentV4, currentV6,
|
||||
newAddresses, obsoleteAddresses); err != nil {
|
||||
return fmt.Errorf("failed to update IP addresses for interface %s: %v", pi.Identifier, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) updateIpAddresses(
|
||||
ctx context.Context,
|
||||
deviceName string,
|
||||
currentV4, currentV6 []lowlevel.GenericJsonObject,
|
||||
new, obsolete []domain.Cidr,
|
||||
) error {
|
||||
// first, delete all obsolete addresses
|
||||
for _, addr := range obsolete {
|
||||
// find ID of the address to delete
|
||||
if addr.IsV4() {
|
||||
for _, a := range currentV4 {
|
||||
if a.GetString("address") == addr.String() {
|
||||
// delete the address
|
||||
reply := c.client.Delete(ctx, "/ip/address/"+a.GetString(".id"))
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete obsolete IPv4 address %s: %v", addr, reply.Error)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, a := range currentV6 {
|
||||
if a.GetString("address") == addr.String() {
|
||||
// delete the address
|
||||
reply := c.client.Delete(ctx, "/ipv6/address/"+a.GetString(".id"))
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete obsolete IPv6 address %s: %v", addr, reply.Error)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// then, add all new addresses
|
||||
for _, addr := range new {
|
||||
var createPath string
|
||||
if addr.IsV4() {
|
||||
createPath = "/ip/address"
|
||||
} else {
|
||||
createPath = "/ipv6/address"
|
||||
}
|
||||
|
||||
// create the address
|
||||
reply := c.client.Create(ctx, createPath, lowlevel.GenericJsonObject{
|
||||
"address": addr.String(),
|
||||
"interface": deviceName,
|
||||
})
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to create new address %s: %v", addr, reply.Error)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
// Lock the interface to prevent concurrent modifications
|
||||
mutex := c.getInterfaceMutex(id)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
// delete the interface's addresses
|
||||
currentV4, currentV6, err := c.loadIpAddresses(ctx, string(id))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load current addresses for interface %s: %v", id, err)
|
||||
}
|
||||
for _, a := range currentV4 {
|
||||
// delete the address
|
||||
reply := c.client.Delete(ctx, "/ip/address/"+a.GetString(".id"))
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete IPv4 address %s: %v", a.GetString("address"), reply.Error)
|
||||
}
|
||||
}
|
||||
for _, a := range currentV6 {
|
||||
// delete the address
|
||||
reply := c.client.Delete(ctx, "/ipv6/address/"+a.GetString(".id"))
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete IPv6 address %s: %v", a.GetString("address"), reply.Error)
|
||||
}
|
||||
}
|
||||
|
||||
// delete the WireGuard interface
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{".id"},
|
||||
Filters: map[string]string{
|
||||
"name": string(id),
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("unable to find WireGuard interface %s: %v", id, wgReply.Error)
|
||||
}
|
||||
if len(wgReply.Data) == 0 {
|
||||
return nil // interface does not exist, nothing to delete
|
||||
}
|
||||
|
||||
interfaceId := wgReply.Data[0].GetString(".id")
|
||||
deleteReply := c.client.Delete(ctx, "/interface/wireguard/"+interfaceId)
|
||||
if deleteReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete WireGuard interface %s: %v", id, deleteReply.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) SavePeer(
|
||||
ctx context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error {
|
||||
// Lock the peer to prevent concurrent modifications
|
||||
mutex := c.getPeerMutex(id)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
physicalPeer, err := c.getOrCreatePeer(ctx, deviceId, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peerId := physicalPeer.GetExtras().(domain.MikrotikPeerExtras).Id
|
||||
physicalPeer, err = updateFunc(physicalPeer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newExtras := physicalPeer.GetExtras().(domain.MikrotikPeerExtras)
|
||||
newExtras.Id = peerId // ensure the ID is not changed
|
||||
physicalPeer.SetExtras(newExtras)
|
||||
|
||||
if err := c.updatePeer(ctx, deviceId, physicalPeer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) getOrCreatePeer(
|
||||
ctx context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
) (*domain.PhysicalPeer, error) {
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "name", "public-key", "private-key", "preshared-key", "persistent-keepalive", "client-address",
|
||||
"client-endpoint", "client-keepalive", "allowed-address", "client-dns", "comment", "disabled", "responder",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"public-key": string(id),
|
||||
"interface": string(deviceId),
|
||||
},
|
||||
})
|
||||
if wgReply.Status == lowlevel.MikrotikApiStatusOk && len(wgReply.Data) > 0 {
|
||||
slog.Debug("found existing Mikrotik peer", "peer", id, "interface", deviceId)
|
||||
existingPeer, err := c.convertWireGuardPeer(wgReply.Data[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &existingPeer, nil
|
||||
}
|
||||
|
||||
// create a new peer if it does not exist
|
||||
slog.Debug("creating new Mikrotik peer", "peer", id, "interface", deviceId)
|
||||
createReply := c.client.Create(ctx, "/interface/wireguard/peers", lowlevel.GenericJsonObject{
|
||||
"name": fmt.Sprintf("tmp-wg-%s", id[0:8]),
|
||||
"interface": string(deviceId),
|
||||
"public-key": string(id),
|
||||
"allowed-address": "0.0.0.0/0", // Use 0.0.0.0/0 as default, will be updated by updatePeer
|
||||
})
|
||||
if createReply.Status == lowlevel.MikrotikApiStatusOk {
|
||||
newPeer, err := c.convertWireGuardPeer(createReply.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slog.Debug("successfully created Mikrotik peer", "peer", id, "interface", deviceId)
|
||||
return &newPeer, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to create peer %s for interface %s: %v", id, deviceId, createReply.Error)
|
||||
}
|
||||
|
||||
func (c *MikrotikController) updatePeer(
|
||||
ctx context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
pp *domain.PhysicalPeer,
|
||||
) error {
|
||||
extras := pp.GetExtras().(domain.MikrotikPeerExtras)
|
||||
peerId := extras.Id
|
||||
|
||||
endpoint := pp.Endpoint
|
||||
endpointPort := "51820" // default port if not set
|
||||
if s := strings.Split(endpoint, ":"); len(s) == 2 {
|
||||
endpoint = s[0]
|
||||
endpointPort = s[1]
|
||||
}
|
||||
|
||||
allowedAddressStr := domain.CidrsToString(pp.AllowedIPs)
|
||||
slog.Debug("updating Mikrotik peer",
|
||||
"peer", pp.Identifier,
|
||||
"interface", deviceId,
|
||||
"allowed-address", allowedAddressStr,
|
||||
"allowed-ips-count", len(pp.AllowedIPs),
|
||||
"disabled", extras.Disabled)
|
||||
|
||||
wgReply := c.client.Update(ctx, "/interface/wireguard/peers/"+peerId, lowlevel.GenericJsonObject{
|
||||
"name": extras.Name,
|
||||
"comment": extras.Comment,
|
||||
"preshared-key": pp.PresharedKey,
|
||||
"public-key": pp.KeyPair.PublicKey,
|
||||
"private-key": pp.KeyPair.PrivateKey,
|
||||
"persistent-keepalive": (time.Duration(pp.PersistentKeepalive) * time.Second).String(),
|
||||
"disabled": strconv.FormatBool(extras.Disabled),
|
||||
"responder": strconv.FormatBool(extras.IsResponder),
|
||||
"client-endpoint": extras.ClientEndpoint,
|
||||
"client-address": extras.ClientAddress,
|
||||
"client-keepalive": (time.Duration(extras.ClientKeepalive) * time.Second).String(),
|
||||
"client-dns": extras.ClientDns,
|
||||
"endpoint-address": endpoint,
|
||||
"endpoint-port": endpointPort,
|
||||
"allowed-address": allowedAddressStr, // Add the missing allowed-address field
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to update peer %s on interface %s: %v", pp.Identifier, deviceId, wgReply.Error)
|
||||
}
|
||||
|
||||
if extras.Disabled {
|
||||
slog.Debug("successfully disabled Mikrotik peer", "peer", pp.Identifier, "interface", deviceId)
|
||||
} else {
|
||||
slog.Debug("successfully updated Mikrotik peer", "peer", pp.Identifier, "interface", deviceId)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) DeletePeer(
|
||||
ctx context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
) error {
|
||||
// Lock the peer to prevent concurrent modifications
|
||||
mutex := c.getPeerMutex(id)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
wgReply := c.client.Query(ctx, "/interface/wireguard/peers", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{".id"},
|
||||
Filters: map[string]string{
|
||||
"public-key": string(id),
|
||||
"interface": string(deviceId),
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("unable to find WireGuard peer %s for interface %s: %v", id, deviceId, wgReply.Error)
|
||||
}
|
||||
if len(wgReply.Data) == 0 {
|
||||
return nil // peer does not exist, nothing to delete
|
||||
}
|
||||
|
||||
peerId := wgReply.Data[0].GetString(".id")
|
||||
deleteReply := c.client.Delete(ctx, "/interface/wireguard/peers/"+peerId)
|
||||
if deleteReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to delete WireGuard peer %s for interface %s: %v", id, deviceId, deleteReply.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion wireguard-related
|
||||
|
||||
// region wg-quick-related
|
||||
|
||||
func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// endregion wg-quick-related
|
||||
|
||||
// region routing-related
|
||||
|
||||
func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// endregion routing-related
|
||||
|
||||
// region statistics-related
|
||||
|
||||
func (c *MikrotikController) PingAddresses(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
) (*domain.PingerResult, error) {
|
||||
wgReply := c.client.ExecList(ctx, "/tool/ping",
|
||||
// limit to 1 packet with a max running time of 2 seconds
|
||||
lowlevel.GenericJsonObject{"address": addr, "count": 1, "interval": "00:00:02"},
|
||||
)
|
||||
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return nil, fmt.Errorf("failed to ping %s: %v", addr, wgReply.Error)
|
||||
}
|
||||
|
||||
var result domain.PingerResult
|
||||
for _, item := range wgReply.Data {
|
||||
result.PacketsRecv += item.GetInt("received")
|
||||
result.PacketsSent += item.GetInt("sent")
|
||||
|
||||
rttStr := item.GetString("avg-rtt")
|
||||
if rttStr != "" {
|
||||
rtt, err := time.ParseDuration(rttStr)
|
||||
if err == nil {
|
||||
result.Rtts = append(result.Rtts, rtt)
|
||||
} else {
|
||||
// use a high value to indicate failure or timeout
|
||||
result.Rtts = append(result.Rtts, 999999*time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// endregion statistics-related
|
Reference in New Issue
Block a user