mirror of
https://github.com/h44z/wg-portal.git
synced 2025-08-12 08:12:23 +00:00
wip: create different backend handlers (#426)
This commit is contained in:
parent
33dcc80078
commit
15d035ec10
@ -50,7 +50,8 @@ func main() {
|
|||||||
database, err := adapters.NewSqlRepository(rawDb)
|
database, err := adapters.NewSqlRepository(rawDb)
|
||||||
internal.AssertNoError(err)
|
internal.AssertNoError(err)
|
||||||
|
|
||||||
wireGuard := adapters.NewWireGuardRepository()
|
wireGuard, err := wireguard.NewControllerManager(cfg)
|
||||||
|
internal.AssertNoError(err)
|
||||||
|
|
||||||
wgQuick := adapters.NewWgQuickRepo()
|
wgQuick := adapters.NewWgQuickRepo()
|
||||||
|
|
||||||
@ -133,7 +134,7 @@ func main() {
|
|||||||
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers)
|
apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers)
|
||||||
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces)
|
apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces)
|
||||||
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
|
apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers)
|
||||||
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth)
|
apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard)
|
||||||
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
|
apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth)
|
||||||
|
|
||||||
apiFrontend := handlersV0.NewRestApi(apiV0Session,
|
apiFrontend := handlersV0.NewRestApi(apiV0Session,
|
||||||
|
809
internal/adapters/wgcontroller/local.go
Normal file
809
internal/adapters/wgcontroller/local.go
Normal file
@ -0,0 +1,809 @@
|
|||||||
|
package wgcontroller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal"
|
||||||
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||||
|
)
|
||||||
|
|
||||||
|
// region dependencies
|
||||||
|
|
||||||
|
// WgCtrlRepo is used to control local WireGuard devices via the wgctrl-go library.
|
||||||
|
type WgCtrlRepo interface {
|
||||||
|
io.Closer
|
||||||
|
Devices() ([]*wgtypes.Device, error)
|
||||||
|
Device(name string) (*wgtypes.Device, error)
|
||||||
|
ConfigureDevice(name string, cfg wgtypes.Config) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// A NetlinkClient is a type which can control a netlink device.
|
||||||
|
type NetlinkClient interface {
|
||||||
|
LinkAdd(link netlink.Link) error
|
||||||
|
LinkDel(link netlink.Link) error
|
||||||
|
LinkByName(name string) (netlink.Link, error)
|
||||||
|
LinkSetUp(link netlink.Link) error
|
||||||
|
LinkSetDown(link netlink.Link) error
|
||||||
|
LinkSetMTU(link netlink.Link, mtu int) error
|
||||||
|
AddrReplace(link netlink.Link, addr *netlink.Addr) error
|
||||||
|
AddrAdd(link netlink.Link, addr *netlink.Addr) error
|
||||||
|
AddrList(link netlink.Link) ([]netlink.Addr, error)
|
||||||
|
AddrDel(link netlink.Link, addr *netlink.Addr) error
|
||||||
|
RouteAdd(route *netlink.Route) error
|
||||||
|
RouteDel(route *netlink.Route) error
|
||||||
|
RouteReplace(route *netlink.Route) error
|
||||||
|
RouteList(link netlink.Link, family int) ([]netlink.Route, error)
|
||||||
|
RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error)
|
||||||
|
RuleAdd(rule *netlink.Rule) error
|
||||||
|
RuleDel(rule *netlink.Rule) error
|
||||||
|
RuleList(family int) ([]netlink.Rule, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion dependencies
|
||||||
|
|
||||||
|
type LocalController struct {
|
||||||
|
cfg *config.Config
|
||||||
|
|
||||||
|
wg WgCtrlRepo
|
||||||
|
nl NetlinkClient
|
||||||
|
|
||||||
|
shellCmd string
|
||||||
|
resolvConfIfacePrefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLocalController creates a new local controller instance.
|
||||||
|
// This repository is used to interact with the WireGuard kernel or userspace module.
|
||||||
|
func NewLocalController(cfg *config.Config) (*LocalController, error) {
|
||||||
|
wg, err := wgctrl.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create wgctrl client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nl := &lowlevel.NetlinkManager{}
|
||||||
|
|
||||||
|
repo := &LocalController{
|
||||||
|
cfg: cfg,
|
||||||
|
|
||||||
|
wg: wg,
|
||||||
|
nl: nl,
|
||||||
|
|
||||||
|
shellCmd: "bash", // we only support bash at the moment
|
||||||
|
resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf
|
||||||
|
}
|
||||||
|
|
||||||
|
return repo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// region wireguard-related
|
||||||
|
|
||||||
|
func (c LocalController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
|
||||||
|
devices, err := c.wg.Devices()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("device list error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaces := make([]domain.PhysicalInterface, 0, len(devices))
|
||||||
|
for _, device := range devices {
|
||||||
|
interfaceModel, err := c.convertWireGuardInterface(device)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("interface convert failed for %s: %w", device.Name, err)
|
||||||
|
}
|
||||||
|
interfaces = append(interfaces, interfaceModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return interfaces, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (
|
||||||
|
*domain.PhysicalInterface,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
return c.getInterface(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) {
|
||||||
|
// read data from wgctrl interface
|
||||||
|
|
||||||
|
iface := domain.PhysicalInterface{
|
||||||
|
Identifier: domain.InterfaceIdentifier(device.Name),
|
||||||
|
KeyPair: domain.KeyPair{
|
||||||
|
PrivateKey: device.PrivateKey.String(),
|
||||||
|
PublicKey: device.PublicKey.String(),
|
||||||
|
},
|
||||||
|
ListenPort: device.ListenPort,
|
||||||
|
Addresses: nil,
|
||||||
|
Mtu: 0,
|
||||||
|
FirewallMark: uint32(device.FirewallMark),
|
||||||
|
DeviceUp: false,
|
||||||
|
ImportSource: "wgctrl",
|
||||||
|
DeviceType: device.Type.String(),
|
||||||
|
BytesUpload: 0,
|
||||||
|
BytesDownload: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
// read data from netlink interface
|
||||||
|
|
||||||
|
lowLevelInterface, err := c.nl.LinkByName(device.Name)
|
||||||
|
if err != nil {
|
||||||
|
return domain.PhysicalInterface{}, fmt.Errorf("netlink error for %s: %w", device.Name, err)
|
||||||
|
}
|
||||||
|
ipAddresses, err := c.nl.AddrList(lowLevelInterface)
|
||||||
|
if err != nil {
|
||||||
|
return domain.PhysicalInterface{}, fmt.Errorf("ip read error for %s: %w", device.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range ipAddresses {
|
||||||
|
iface.Addresses = append(iface.Addresses, domain.CidrFromNetlinkAddr(addr))
|
||||||
|
}
|
||||||
|
iface.Mtu = lowLevelInterface.Attrs().MTU
|
||||||
|
iface.DeviceUp = lowLevelInterface.Attrs().OperState == netlink.OperUnknown // wg only supports unknown
|
||||||
|
if stats := lowLevelInterface.Attrs().Statistics; stats != nil {
|
||||||
|
iface.BytesUpload = stats.TxBytes
|
||||||
|
iface.BytesDownload = stats.RxBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
return iface, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) (
|
||||||
|
[]domain.PhysicalPeer,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
device, err := c.wg.Device(string(deviceId))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("device error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peers := make([]domain.PhysicalPeer, 0, len(device.Peers))
|
||||||
|
for _, peer := range device.Peers {
|
||||||
|
peerModel, err := c.convertWireGuardPeer(&peer)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("peer convert failed for %v: %w", peer.PublicKey, err)
|
||||||
|
}
|
||||||
|
peers = append(peers, peerModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) {
|
||||||
|
peerModel := domain.PhysicalPeer{
|
||||||
|
Identifier: domain.PeerIdentifier(peer.PublicKey.String()),
|
||||||
|
Endpoint: "",
|
||||||
|
AllowedIPs: nil,
|
||||||
|
KeyPair: domain.KeyPair{
|
||||||
|
PublicKey: peer.PublicKey.String(),
|
||||||
|
},
|
||||||
|
PresharedKey: "",
|
||||||
|
PersistentKeepalive: int(peer.PersistentKeepaliveInterval.Seconds()),
|
||||||
|
LastHandshake: peer.LastHandshakeTime,
|
||||||
|
ProtocolVersion: peer.ProtocolVersion,
|
||||||
|
BytesUpload: uint64(peer.ReceiveBytes),
|
||||||
|
BytesDownload: uint64(peer.TransmitBytes),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range peer.AllowedIPs {
|
||||||
|
peerModel.AllowedIPs = append(peerModel.AllowedIPs, domain.CidrFromIpNet(addr))
|
||||||
|
}
|
||||||
|
if peer.Endpoint != nil {
|
||||||
|
peerModel.Endpoint = peer.Endpoint.String()
|
||||||
|
}
|
||||||
|
if peer.PresharedKey != (wgtypes.Key{}) {
|
||||||
|
peerModel.PresharedKey = domain.PreSharedKey(peer.PresharedKey.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return peerModel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) SaveInterface(
|
||||||
|
_ context.Context,
|
||||||
|
id domain.InterfaceIdentifier,
|
||||||
|
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||||
|
) error {
|
||||||
|
physicalInterface, err := c.getOrCreateInterface(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if updateFunc != nil {
|
||||||
|
physicalInterface, err = updateFunc(physicalInterface)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.updateLowLevelInterface(physicalInterface); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.updateWireGuardInterface(physicalInterface); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
|
||||||
|
device, err := c.getInterface(id)
|
||||||
|
if err == nil {
|
||||||
|
return device, nil // interface exists
|
||||||
|
}
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil, fmt.Errorf("device error: %w", err) // unknown error
|
||||||
|
}
|
||||||
|
|
||||||
|
// create new device
|
||||||
|
if err := c.createLowLevelInterface(id); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
device, err = c.getInterface(id)
|
||||||
|
return device, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
|
||||||
|
device, err := c.wg.Device(string(id))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pi, err := c.convertWireGuardInterface(device)
|
||||||
|
return &pi, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) createLowLevelInterface(id domain.InterfaceIdentifier) error {
|
||||||
|
link := &netlink.GenericLink{
|
||||||
|
LinkAttrs: netlink.LinkAttrs{
|
||||||
|
Name: string(id),
|
||||||
|
},
|
||||||
|
LinkType: "wireguard",
|
||||||
|
}
|
||||||
|
err := c.nl.LinkAdd(link)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("link add failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) updateLowLevelInterface(pi *domain.PhysicalInterface) error {
|
||||||
|
link, err := c.nl.LinkByName(string(pi.Identifier))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if pi.Mtu != 0 {
|
||||||
|
if err := c.nl.LinkSetMTU(link, pi.Mtu); err != nil {
|
||||||
|
return fmt.Errorf("mtu error: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range pi.Addresses {
|
||||||
|
err := c.nl.AddrReplace(link, addr.NetlinkAddr())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set ip %s: %w", addr.String(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove unwanted IP addresses
|
||||||
|
rawAddresses, err := c.nl.AddrList(link)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch interface ips: %w", err)
|
||||||
|
}
|
||||||
|
for _, rawAddr := range rawAddresses {
|
||||||
|
netlinkAddr := domain.CidrFromNetlinkAddr(rawAddr)
|
||||||
|
remove := true
|
||||||
|
for _, addr := range pi.Addresses {
|
||||||
|
if addr == netlinkAddr {
|
||||||
|
remove = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !remove {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.nl.AddrDel(link, &rawAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove deprecated ip %s: %w", netlinkAddr.String(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update link state
|
||||||
|
if pi.DeviceUp {
|
||||||
|
if err := c.nl.LinkSetUp(link); err != nil {
|
||||||
|
return fmt.Errorf("failed to bring up device: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := c.nl.LinkSetDown(link); err != nil {
|
||||||
|
return fmt.Errorf("failed to bring down device: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
|
||||||
|
pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var fwMark *int
|
||||||
|
if pi.FirewallMark != 0 {
|
||||||
|
intFwMark := int(pi.FirewallMark)
|
||||||
|
fwMark = &intFwMark
|
||||||
|
}
|
||||||
|
err = c.wg.ConfigureDevice(string(pi.Identifier), wgtypes.Config{
|
||||||
|
PrivateKey: &pKey,
|
||||||
|
ListenPort: &pi.ListenPort,
|
||||||
|
FirewallMark: fwMark,
|
||||||
|
ReplacePeers: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
|
||||||
|
if err := c.deleteLowLevelInterface(id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
|
||||||
|
link, err := c.nl.LinkByName(string(id))
|
||||||
|
if err != nil {
|
||||||
|
var linkNotFoundError netlink.LinkNotFoundError
|
||||||
|
if errors.As(err, &linkNotFoundError) {
|
||||||
|
return nil // ignore not found error
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unable to find low level interface: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.nl.LinkDel(link)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete low level interface: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) SavePeer(
|
||||||
|
_ context.Context,
|
||||||
|
deviceId domain.InterfaceIdentifier,
|
||||||
|
id domain.PeerIdentifier,
|
||||||
|
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||||
|
) error {
|
||||||
|
physicalPeer, err := c.getOrCreatePeer(deviceId, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
physicalPeer, err = updateFunc(physicalPeer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.updatePeer(deviceId, physicalPeer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
|
||||||
|
*domain.PhysicalPeer,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
peer, err := c.getPeer(deviceId, id)
|
||||||
|
if err == nil {
|
||||||
|
return peer, nil // peer exists
|
||||||
|
}
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil, fmt.Errorf("peer error: %w", err) // unknown error
|
||||||
|
}
|
||||||
|
|
||||||
|
// create new peer
|
||||||
|
err = c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{
|
||||||
|
{
|
||||||
|
PublicKey: id.ToPublicKey(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("peer create error for %s: %w", id.ToPublicKey(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, err = c.getPeer(deviceId, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("peer error after create: %w", err)
|
||||||
|
}
|
||||||
|
return peer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
|
||||||
|
*domain.PhysicalPeer,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
if !id.IsPublicKey() {
|
||||||
|
return nil, errors.New("invalid public key")
|
||||||
|
}
|
||||||
|
|
||||||
|
device, err := c.wg.Device(string(deviceId))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey := id.ToPublicKey()
|
||||||
|
for _, peer := range device.Peers {
|
||||||
|
if peer.PublicKey != publicKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerModel, err := c.convertWireGuardPeer(&peer)
|
||||||
|
return &peerModel, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error {
|
||||||
|
cfg := wgtypes.PeerConfig{
|
||||||
|
PublicKey: pp.GetPublicKey(),
|
||||||
|
Remove: false,
|
||||||
|
UpdateOnly: true,
|
||||||
|
PresharedKey: pp.GetPresharedKey(),
|
||||||
|
Endpoint: pp.GetEndpointAddress(),
|
||||||
|
PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(),
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
AllowedIPs: pp.GetAllowedIPs(),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) DeletePeer(
|
||||||
|
_ context.Context,
|
||||||
|
deviceId domain.InterfaceIdentifier,
|
||||||
|
id domain.PeerIdentifier,
|
||||||
|
) error {
|
||||||
|
if !id.IsPublicKey() {
|
||||||
|
return errors.New("invalid public key")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.deletePeer(deviceId, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
|
||||||
|
cfg := wgtypes.PeerConfig{
|
||||||
|
PublicKey: id.ToPublicKey(),
|
||||||
|
Remove: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion wireguard-related
|
||||||
|
|
||||||
|
// region wg-quick-related
|
||||||
|
|
||||||
|
func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
|
||||||
|
if hookCmd == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("executing interface hook", "interface", id, "hook", hookCmd)
|
||||||
|
err := c.exec(hookCmd, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to exec hook: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||||
|
if dnsStr == "" && dnsSearchStr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServers := internal.SliceString(dnsStr)
|
||||||
|
dnsSearchDomains := internal.SliceString(dnsSearchStr)
|
||||||
|
|
||||||
|
dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
|
||||||
|
dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
|
||||||
|
|
||||||
|
for _, dnsServer := range dnsServers {
|
||||||
|
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
|
||||||
|
}
|
||||||
|
for _, searchDomain := range dnsSearchDomains {
|
||||||
|
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.exec(dnsCommand, id, dnsCommandInput...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w",
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error {
|
||||||
|
dnsCommand := "resolvconf -d %resPref%i -f"
|
||||||
|
|
||||||
|
err := c.exec(dnsCommand, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to unset dns settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
|
||||||
|
command = strings.ReplaceAll(command, "%resPref", c.resolvConfIfacePrefix)
|
||||||
|
return strings.ReplaceAll(command, "%i", string(interfaceId))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
|
||||||
|
commandWithInterfaceName := c.replaceCommandPlaceHolders(command, interfaceId)
|
||||||
|
cmd := exec.Command(c.shellCmd, "-ce", commandWithInterfaceName)
|
||||||
|
if len(stdin) > 0 {
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
for _, ln := range stdin {
|
||||||
|
if _, err := fmt.Fprint(b, ln); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cmd.Stdin = b
|
||||||
|
}
|
||||||
|
out, err := cmd.CombinedOutput() // execute and wait for output
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
|
||||||
|
}
|
||||||
|
slog.Debug("executed shell command",
|
||||||
|
"command", commandWithInterfaceName,
|
||||||
|
"output", string(out))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion wg-quick-related
|
||||||
|
|
||||||
|
// region routing-related
|
||||||
|
|
||||||
|
func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||||
|
// update fwmark rules
|
||||||
|
if err := c.setFwMarkRules(rules); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// update main rule
|
||||||
|
if err := c.setMainRule(rules); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup old main rules
|
||||||
|
if err := c.cleanupMainRule(rules); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error {
|
||||||
|
for _, rule := range rules {
|
||||||
|
existingRules, err := c.nl.RuleList(int(rule.IpFamily))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleExists := false
|
||||||
|
for _, existingRule := range existingRules {
|
||||||
|
if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table {
|
||||||
|
ruleExists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ruleExists {
|
||||||
|
continue // rule already exists, no need to recreate it
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a missing rule
|
||||||
|
if err := c.nl.RuleAdd(&netlink.Rule{
|
||||||
|
Family: int(rule.IpFamily),
|
||||||
|
Table: rule.Table,
|
||||||
|
Mark: rule.FwMark,
|
||||||
|
Invert: true,
|
||||||
|
SuppressIfgroup: -1,
|
||||||
|
SuppressPrefixlen: -1,
|
||||||
|
Priority: c.getRulePriority(existingRules),
|
||||||
|
Mask: nil,
|
||||||
|
Goto: -1,
|
||||||
|
Flow: -1,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w",
|
||||||
|
rule.IpFamily, rule.FwMark, rule.Table, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
|
||||||
|
prio := 32700 // linux main rule has a priority of 32766
|
||||||
|
for {
|
||||||
|
isFresh := true
|
||||||
|
for _, existingRule := range existingRules {
|
||||||
|
if existingRule.Priority == prio {
|
||||||
|
isFresh = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isFresh {
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
prio--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return prio
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) setMainRule(rules []domain.RouteRule) error {
|
||||||
|
var family domain.IpFamily
|
||||||
|
shouldHaveMainRule := false
|
||||||
|
for _, rule := range rules {
|
||||||
|
family = rule.IpFamily
|
||||||
|
if rule.HasDefault == true {
|
||||||
|
shouldHaveMainRule = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !shouldHaveMainRule {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
existingRules, err := c.nl.RuleList(int(family))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleExists := false
|
||||||
|
for _, existingRule := range existingRules {
|
||||||
|
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||||
|
ruleExists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ruleExists {
|
||||||
|
return nil // rule already exists, skip re-creation
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.nl.RuleAdd(&netlink.Rule{
|
||||||
|
Family: int(family),
|
||||||
|
Table: unix.RT_TABLE_MAIN,
|
||||||
|
SuppressIfgroup: -1,
|
||||||
|
SuppressPrefixlen: 0,
|
||||||
|
Priority: c.getMainRulePriority(existingRules),
|
||||||
|
Mark: 0,
|
||||||
|
Mask: nil,
|
||||||
|
Goto: -1,
|
||||||
|
Flow: -1,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to setup rule for main table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
|
||||||
|
priority := c.cfg.Advanced.RulePrioOffset
|
||||||
|
for {
|
||||||
|
isFresh := true
|
||||||
|
for _, existingRule := range existingRules {
|
||||||
|
if existingRule.Priority == priority {
|
||||||
|
isFresh = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isFresh {
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
priority++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return priority
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error {
|
||||||
|
var family domain.IpFamily
|
||||||
|
for _, rule := range rules {
|
||||||
|
family = rule.IpFamily
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
existingRules, err := c.nl.RuleList(int(family))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldHaveMainRule := false
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.HasDefault == true {
|
||||||
|
shouldHaveMainRule = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mainRules := 0
|
||||||
|
for _, existingRule := range existingRules {
|
||||||
|
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||||
|
mainRules++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
removalCount := 0
|
||||||
|
if mainRules > 1 {
|
||||||
|
removalCount = mainRules - 1 // we only want one single rule
|
||||||
|
}
|
||||||
|
if !shouldHaveMainRule {
|
||||||
|
removalCount = mainRules
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, existingRule := range existingRules {
|
||||||
|
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||||
|
if removalCount > 0 {
|
||||||
|
existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field
|
||||||
|
if err := c.nl.RuleDel(&existingRule); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete main rule: %w", err)
|
||||||
|
}
|
||||||
|
removalCount--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion routing-related
|
105
internal/adapters/wgcontroller/mikrotik.go
Normal file
105
internal/adapters/wgcontroller/mikrotik.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
package wgcontroller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MikrotikController struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMikrotikController() (*MikrotikController, error) {
|
||||||
|
return &MikrotikController{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// region wireguard-related
|
||||||
|
|
||||||
|
func (c MikrotikController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MikrotikController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (
|
||||||
|
*domain.PhysicalInterface,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MikrotikController) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) (
|
||||||
|
[]domain.PhysicalPeer,
|
||||||
|
error,
|
||||||
|
) {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MikrotikController) SaveInterface(
|
||||||
|
_ context.Context,
|
||||||
|
id domain.InterfaceIdentifier,
|
||||||
|
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||||
|
) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MikrotikController) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MikrotikController) SavePeer(
|
||||||
|
_ context.Context,
|
||||||
|
deviceId domain.InterfaceIdentifier,
|
||||||
|
id domain.PeerIdentifier,
|
||||||
|
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||||
|
) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MikrotikController) DeletePeer(
|
||||||
|
_ context.Context,
|
||||||
|
deviceId domain.InterfaceIdentifier,
|
||||||
|
id domain.PeerIdentifier,
|
||||||
|
) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion wireguard-related
|
||||||
|
|
||||||
|
// region wg-quick-related
|
||||||
|
|
||||||
|
func (c MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion wg-quick-related
|
||||||
|
|
||||||
|
// region routing-related
|
||||||
|
|
||||||
|
func (c MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion routing-related
|
@ -21,17 +21,23 @@ import (
|
|||||||
//go:embed frontend_config.js.gotpl
|
//go:embed frontend_config.js.gotpl
|
||||||
var frontendJs embed.FS
|
var frontendJs embed.FS
|
||||||
|
|
||||||
|
type ControllerManager interface {
|
||||||
|
GetControllerNames() []config.BackendBase
|
||||||
|
}
|
||||||
|
|
||||||
type ConfigEndpoint struct {
|
type ConfigEndpoint struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
authenticator Authenticator
|
authenticator Authenticator
|
||||||
|
controllerMgr ControllerManager
|
||||||
|
|
||||||
tpl *respond.TemplateRenderer
|
tpl *respond.TemplateRenderer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator) ConfigEndpoint {
|
func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator, ctrlMgr ControllerManager) ConfigEndpoint {
|
||||||
ep := ConfigEndpoint{
|
ep := ConfigEndpoint{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
authenticator: authenticator,
|
authenticator: authenticator,
|
||||||
|
controllerMgr: ctrlMgr,
|
||||||
tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs,
|
tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs,
|
||||||
"frontend_config.js.gotpl"))),
|
"frontend_config.js.gotpl"))),
|
||||||
}
|
}
|
||||||
@ -96,17 +102,21 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc {
|
|||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
sessionUser := domain.GetUserInfo(r.Context())
|
sessionUser := domain.GetUserInfo(r.Context())
|
||||||
|
|
||||||
nameFn := func(backend config.Backend) []model.SettingsBackendNames {
|
controllerFn := func() []model.SettingsBackendNames {
|
||||||
names := make([]model.SettingsBackendNames, 0, len(backend.Mikrotik)+1)
|
controllers := e.controllerMgr.GetControllerNames()
|
||||||
|
names := make([]model.SettingsBackendNames, 0, len(controllers))
|
||||||
|
|
||||||
names = append(names, model.SettingsBackendNames{
|
for _, controller := range controllers {
|
||||||
Id: backend.Default,
|
displayName := controller.DisplayName
|
||||||
Name: "modals.interface-edit.backend.local",
|
if displayName == "" {
|
||||||
})
|
displayName = controller.Id // fallback to ID if no display name is set
|
||||||
for _, b := range backend.Mikrotik {
|
}
|
||||||
|
if controller.Id == config.LocalBackendName {
|
||||||
|
displayName = "modals.interface-edit.backend.local" // use a localized string for the local backend
|
||||||
|
}
|
||||||
names = append(names, model.SettingsBackendNames{
|
names = append(names, model.SettingsBackendNames{
|
||||||
Id: b.Id,
|
Id: controller.Id,
|
||||||
Name: b.DisplayName,
|
Name: displayName,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,7 +138,7 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc {
|
|||||||
ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly,
|
ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly,
|
||||||
WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled,
|
WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled,
|
||||||
MinPasswordLength: e.cfg.Auth.MinPasswordLength,
|
MinPasswordLength: e.cfg.Auth.MinPasswordLength,
|
||||||
AvailableBackends: nameFn(e.cfg.Backend),
|
AvailableBackends: controllerFn(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -59,7 +59,7 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
|
|||||||
Identifier: string(src.Identifier),
|
Identifier: string(src.Identifier),
|
||||||
DisplayName: src.DisplayName,
|
DisplayName: src.DisplayName,
|
||||||
Mode: string(src.Type),
|
Mode: string(src.Type),
|
||||||
Backend: config.LocalBackendName, // TODO: add backend support
|
Backend: string(src.Backend),
|
||||||
PrivateKey: src.PrivateKey,
|
PrivateKey: src.PrivateKey,
|
||||||
PublicKey: src.PublicKey,
|
PublicKey: src.PublicKey,
|
||||||
Disabled: src.IsDisabled(),
|
Disabled: src.IsDisabled(),
|
||||||
@ -95,6 +95,10 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface {
|
|||||||
Filename: src.GetConfigFileName(),
|
Filename: src.GetConfigFileName(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if iface.Backend == "" {
|
||||||
|
iface.Backend = config.LocalBackendName // default to local backend
|
||||||
|
}
|
||||||
|
|
||||||
if len(peers) > 0 {
|
if len(peers) > 0 {
|
||||||
iface.TotalPeers = len(peers)
|
iface.TotalPeers = len(peers)
|
||||||
|
|
||||||
@ -149,6 +153,7 @@ func NewDomainInterface(src *Interface) *domain.Interface {
|
|||||||
SaveConfig: src.SaveConfig,
|
SaveConfig: src.SaveConfig,
|
||||||
DisplayName: src.DisplayName,
|
DisplayName: src.DisplayName,
|
||||||
Type: domain.InterfaceType(src.Mode),
|
Type: domain.InterfaceType(src.Mode),
|
||||||
|
Backend: domain.InterfaceBackend(src.Backend),
|
||||||
DriverType: "", // currently unused
|
DriverType: "", // currently unused
|
||||||
Disabled: nil, // set below
|
Disabled: nil, // set below
|
||||||
DisabledReason: src.DisabledReason,
|
DisabledReason: src.DisabledReason,
|
||||||
|
161
internal/app/wireguard/controller_manager.go
Normal file
161
internal/app/wireguard/controller_manager.go
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"maps"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/h44z/wg-portal/internal/adapters/wgcontroller"
|
||||||
|
"github.com/h44z/wg-portal/internal/config"
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InterfaceController interface {
|
||||||
|
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
||||||
|
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||||
|
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||||
|
SaveInterface(
|
||||||
|
_ context.Context,
|
||||||
|
id domain.InterfaceIdentifier,
|
||||||
|
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||||
|
) error
|
||||||
|
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||||
|
SavePeer(
|
||||||
|
_ context.Context,
|
||||||
|
deviceId domain.InterfaceIdentifier,
|
||||||
|
id domain.PeerIdentifier,
|
||||||
|
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||||
|
) error
|
||||||
|
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type backendInstance struct {
|
||||||
|
Config config.BackendBase // Config is the configuration for the backend instance.
|
||||||
|
Implementation InterfaceController
|
||||||
|
}
|
||||||
|
|
||||||
|
type ControllerManager struct {
|
||||||
|
cfg *config.Config
|
||||||
|
controllers map[domain.InterfaceBackend]backendInstance
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewControllerManager(cfg *config.Config) (*ControllerManager, error) {
|
||||||
|
c := &ControllerManager{
|
||||||
|
cfg: cfg,
|
||||||
|
controllers: make(map[domain.InterfaceBackend]backendInstance),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := c.init()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ControllerManager) init() error {
|
||||||
|
if err := c.registerLocalController(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.registerMikrotikControllers(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logRegisteredControllers()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ControllerManager) registerLocalController() error {
|
||||||
|
localController, err := wgcontroller.NewLocalController(c.cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create local WireGuard controller: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.controllers[config.LocalBackendName] = backendInstance{
|
||||||
|
Config: config.BackendBase{
|
||||||
|
Id: config.LocalBackendName,
|
||||||
|
DisplayName: "Local WireGuard Controller",
|
||||||
|
},
|
||||||
|
Implementation: localController,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ControllerManager) registerMikrotikControllers() error {
|
||||||
|
for _, backendConfig := range c.cfg.Backend.Mikrotik {
|
||||||
|
if backendConfig.Id == config.LocalBackendName {
|
||||||
|
slog.Warn("skipping registration of Mikrotik controller with reserved ID", "id", config.LocalBackendName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
controller, err := wgcontroller.NewMikrotikController() // TODO: Pass backendConfig to the controller constructor
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create Mikrotik controller for backend %s: %w", backendConfig.Id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.controllers[domain.InterfaceBackend(backendConfig.Id)] = backendInstance{
|
||||||
|
Config: backendConfig.BackendBase,
|
||||||
|
Implementation: controller,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ControllerManager) logRegisteredControllers() {
|
||||||
|
for backend, controller := range c.controllers {
|
||||||
|
slog.Debug("backend controller registered",
|
||||||
|
"backend", backend, "type", fmt.Sprintf("%T", controller.Implementation))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController {
|
||||||
|
return c.getController(backend, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController {
|
||||||
|
return c.getController(iface.Backend, iface.Identifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ControllerManager) getController(
|
||||||
|
backend domain.InterfaceBackend,
|
||||||
|
ifaceId domain.InterfaceIdentifier,
|
||||||
|
) InterfaceController {
|
||||||
|
if backend == "" {
|
||||||
|
// If no backend is specified, use the local controller.
|
||||||
|
// This might be the case for interfaces created in previous WireGuard Portal versions.
|
||||||
|
backend = config.LocalBackendName
|
||||||
|
}
|
||||||
|
|
||||||
|
controller, exists := c.controllers[backend]
|
||||||
|
if !exists {
|
||||||
|
controller, exists = c.controllers[config.LocalBackendName] // Fallback to local controller
|
||||||
|
if !exists {
|
||||||
|
// If the local controller is also not found, panic
|
||||||
|
panic(fmt.Sprintf("%s interface controller for backend %s not found", ifaceId, backend))
|
||||||
|
}
|
||||||
|
slog.Warn("controller for backend not found, using local controller",
|
||||||
|
"backend", backend, "interface", ifaceId)
|
||||||
|
}
|
||||||
|
return controller.Implementation
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ControllerManager) GetAllControllers() []InterfaceController {
|
||||||
|
var backendInstances = make([]InterfaceController, 0, len(c.controllers))
|
||||||
|
for instance := range maps.Values(c.controllers) {
|
||||||
|
backendInstances = append(backendInstances, instance.Implementation)
|
||||||
|
}
|
||||||
|
return backendInstances
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ControllerManager) GetControllerNames() []config.BackendBase {
|
||||||
|
var names []config.BackendBase
|
||||||
|
for _, id := range slices.Sorted(maps.Keys(c.controllers)) {
|
||||||
|
names = append(names, c.controllers[id].Config)
|
||||||
|
}
|
||||||
|
|
||||||
|
return names
|
||||||
|
}
|
@ -53,7 +53,7 @@ type StatisticsCollector struct {
|
|||||||
pingJobs chan domain.Peer
|
pingJobs chan domain.Peer
|
||||||
|
|
||||||
db StatisticsDatabaseRepo
|
db StatisticsDatabaseRepo
|
||||||
wg StatisticsInterfaceController
|
wg *ControllerManager
|
||||||
ms StatisticsMetricsServer
|
ms StatisticsMetricsServer
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,7 +62,7 @@ func NewStatisticsCollector(
|
|||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
bus StatisticsEventBus,
|
bus StatisticsEventBus,
|
||||||
db StatisticsDatabaseRepo,
|
db StatisticsDatabaseRepo,
|
||||||
wg StatisticsInterfaceController,
|
wg *ControllerManager,
|
||||||
ms StatisticsMetricsServer,
|
ms StatisticsMetricsServer,
|
||||||
) (*StatisticsCollector, error) {
|
) (*StatisticsCollector, error) {
|
||||||
c := &StatisticsCollector{
|
c := &StatisticsCollector{
|
||||||
@ -113,7 +113,7 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, in := range interfaces {
|
for _, in := range interfaces {
|
||||||
physicalInterface, err := c.wg.GetInterface(ctx, in.Identifier)
|
physicalInterface, err := c.wg.GetController(in).GetInterface(ctx, in.Identifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier,
|
slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier,
|
||||||
"error", err)
|
"error", err)
|
||||||
@ -165,7 +165,7 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, in := range interfaces {
|
for _, in := range interfaces {
|
||||||
peers, err := c.wg.GetPeers(ctx, in.Identifier)
|
peers, err := c.wg.GetController(in).GetPeers(ctx, in.Identifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err)
|
slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err)
|
||||||
continue
|
continue
|
||||||
|
@ -37,25 +37,6 @@ type InterfaceAndPeerDatabaseRepo interface {
|
|||||||
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
|
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type InterfaceController interface {
|
|
||||||
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
|
||||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
|
||||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
|
||||||
SaveInterface(
|
|
||||||
_ context.Context,
|
|
||||||
id domain.InterfaceIdentifier,
|
|
||||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
|
||||||
) error
|
|
||||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
|
||||||
SavePeer(
|
|
||||||
_ context.Context,
|
|
||||||
deviceId domain.InterfaceIdentifier,
|
|
||||||
id domain.PeerIdentifier,
|
|
||||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
|
||||||
) error
|
|
||||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type WgQuickController interface {
|
type WgQuickController interface {
|
||||||
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
|
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
|
||||||
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||||
@ -75,7 +56,7 @@ type Manager struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
bus EventBus
|
bus EventBus
|
||||||
db InterfaceAndPeerDatabaseRepo
|
db InterfaceAndPeerDatabaseRepo
|
||||||
wg InterfaceController
|
wg *ControllerManager
|
||||||
quick WgQuickController
|
quick WgQuickController
|
||||||
|
|
||||||
userLockMap *sync.Map
|
userLockMap *sync.Map
|
||||||
@ -84,7 +65,7 @@ type Manager struct {
|
|||||||
func NewWireGuardManager(
|
func NewWireGuardManager(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
bus EventBus,
|
bus EventBus,
|
||||||
wg InterfaceController,
|
wg *ControllerManager,
|
||||||
quick WgQuickController,
|
quick WgQuickController,
|
||||||
db InterfaceAndPeerDatabaseRepo,
|
db InterfaceAndPeerDatabaseRepo,
|
||||||
) (*Manager, error) {
|
) (*Manager, error) {
|
||||||
|
@ -21,12 +21,17 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
|
var allPhysicalInterfaces []domain.PhysicalInterface
|
||||||
if err != nil {
|
for _, wgBackend := range m.wg.GetAllControllers() {
|
||||||
return nil, err
|
physicalInterfaces, err := wgBackend.GetInterfaces(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
allPhysicalInterfaces = append(allPhysicalInterfaces, physicalInterfaces...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return physicalInterfaces, nil
|
return allPhysicalInterfaces, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
|
// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier.
|
||||||
@ -109,47 +114,49 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
physicalInterfaces, err := m.wg.GetInterfaces(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// if no filter is given, exclude already existing interfaces
|
|
||||||
var excludedInterfaces []domain.InterfaceIdentifier
|
|
||||||
if len(filter) == 0 {
|
|
||||||
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for _, existingInterface := range existingInterfaces {
|
|
||||||
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
imported := 0
|
imported := 0
|
||||||
for _, physicalInterface := range physicalInterfaces {
|
for _, wgBackend := range m.wg.GetAllControllers() {
|
||||||
if slices.Contains(excludedInterfaces, physicalInterface.Identifier) {
|
physicalInterfaces, err := wgBackend.GetInterfaces(ctx)
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
|
|
||||||
|
|
||||||
physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.importInterface(ctx, &physicalInterface, physicalPeers)
|
// if no filter is given, exclude already existing interfaces
|
||||||
if err != nil {
|
var excludedInterfaces []domain.InterfaceIdentifier
|
||||||
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
|
if len(filter) == 0 {
|
||||||
|
existingInterfaces, err := m.db.GetAllInterfaces(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
for _, existingInterface := range existingInterfaces {
|
||||||
|
excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers))
|
for _, physicalInterface := range physicalInterfaces {
|
||||||
imported++
|
if slices.Contains(excludedInterfaces, physicalInterface.Identifier) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("importing new interface", "interface", physicalInterface.Identifier)
|
||||||
|
|
||||||
|
physicalPeers, err := wgBackend.GetPeers(ctx, physicalInterface.Identifier)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.importInterface(ctx, &physicalInterface, physicalPeers)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers))
|
||||||
|
imported++
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return imported, nil
|
return imported, nil
|
||||||
@ -213,7 +220,7 @@ func (m Manager) RestoreInterfaceState(
|
|||||||
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
|
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = m.wg.GetInterface(ctx, iface.Identifier)
|
_, err = m.wg.GetController(iface).GetInterface(ctx, iface.Identifier)
|
||||||
if err != nil && !iface.IsDisabled() {
|
if err != nil && !iface.IsDisabled() {
|
||||||
slog.Debug("creating missing interface", "interface", iface.Identifier)
|
slog.Debug("creating missing interface", "interface", iface.Identifier)
|
||||||
|
|
||||||
@ -261,17 +268,19 @@ func (m Manager) RestoreInterfaceState(
|
|||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
switch {
|
switch {
|
||||||
case iface.IsDisabled(): // if interface is disabled, delete all peers
|
case iface.IsDisabled(): // if interface is disabled, delete all peers
|
||||||
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
|
if err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier,
|
||||||
|
peer.Identifier); err != nil {
|
||||||
return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w",
|
return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w",
|
||||||
peer.Identifier, iface.Identifier, err)
|
peer.Identifier, iface.Identifier, err)
|
||||||
}
|
}
|
||||||
case peer.IsDisabled(): // if peer is disabled, delete it
|
case peer.IsDisabled(): // if peer is disabled, delete it
|
||||||
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
|
if err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier,
|
||||||
|
peer.Identifier); err != nil {
|
||||||
return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w",
|
return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w",
|
||||||
peer.Identifier, iface.Identifier, err)
|
peer.Identifier, iface.Identifier, err)
|
||||||
}
|
}
|
||||||
default: // update peer
|
default: // update peer
|
||||||
err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
|
err := m.wg.GetController(iface).SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||||
domain.MergeToPhysicalPeer(pp, &peer)
|
domain.MergeToPhysicalPeer(pp, &peer)
|
||||||
return pp, nil
|
return pp, nil
|
||||||
@ -284,7 +293,7 @@ func (m Manager) RestoreInterfaceState(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// remove non-wgportal peers
|
// remove non-wgportal peers
|
||||||
physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier)
|
physicalPeers, _ := m.wg.GetController(iface).GetPeers(ctx, iface.Identifier)
|
||||||
for _, physicalPeer := range physicalPeers {
|
for _, physicalPeer := range physicalPeers {
|
||||||
isWgPortalPeer := false
|
isWgPortalPeer := false
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
@ -294,7 +303,8 @@ func (m Manager) RestoreInterfaceState(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !isWgPortalPeer {
|
if !isWgPortalPeer {
|
||||||
err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey))
|
err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier,
|
||||||
|
domain.PeerIdentifier(physicalPeer.PublicKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w",
|
return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w",
|
||||||
physicalPeer.PublicKey, iface.Identifier, err)
|
physicalPeer.PublicKey, iface.Identifier, err)
|
||||||
@ -459,7 +469,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
|||||||
existingInterface.Disabled = &now // simulate a disabled interface
|
existingInterface.Disabled = &now // simulate a disabled interface
|
||||||
existingInterface.DisabledReason = domain.DisabledReasonDeleted
|
existingInterface.DisabledReason = domain.DisabledReasonDeleted
|
||||||
|
|
||||||
physicalInterface, _ := m.wg.GetInterface(ctx, id)
|
physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id)
|
||||||
|
|
||||||
if err := m.handleInterfacePreSaveHooks(true, existingInterface); err != nil {
|
if err := m.handleInterfacePreSaveHooks(true, existingInterface); err != nil {
|
||||||
return fmt.Errorf("pre-delete hooks failed: %w", err)
|
return fmt.Errorf("pre-delete hooks failed: %w", err)
|
||||||
@ -473,7 +483,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
|||||||
return fmt.Errorf("peer deletion failure: %w", err)
|
return fmt.Errorf("peer deletion failure: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.wg.DeleteInterface(ctx, id); err != nil {
|
if err := m.wg.GetController(*existingInterface).DeleteInterface(ctx, id); err != nil {
|
||||||
return fmt.Errorf("wireguard deletion failure: %w", err)
|
return fmt.Errorf("wireguard deletion failure: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -522,7 +532,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
|||||||
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
||||||
iface.CopyCalculatedAttributes(i)
|
iface.CopyCalculatedAttributes(i)
|
||||||
|
|
||||||
err := m.wg.SaveInterface(ctx, iface.Identifier,
|
err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
|
||||||
func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||||
domain.MergeToPhysicalInterface(pi, iface)
|
domain.MergeToPhysicalInterface(pi, iface)
|
||||||
return pi, nil
|
return pi, nil
|
||||||
@ -538,7 +548,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if iface.IsDisabled() {
|
if iface.IsDisabled() {
|
||||||
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
|
physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier)
|
||||||
fwMark := iface.FirewallMark
|
fwMark := iface.FirewallMark
|
||||||
if physicalInterface != nil && fwMark == 0 {
|
if physicalInterface != nil && fwMark == 0 {
|
||||||
fwMark = physicalInterface.FirewallMark
|
fwMark = physicalInterface.FirewallMark
|
||||||
@ -576,7 +586,7 @@ func (m Manager) hasInterfaceStateChanged(ctx context.Context, iface *domain.Int
|
|||||||
return true // interface in db has changed
|
return true // interface in db has changed
|
||||||
}
|
}
|
||||||
|
|
||||||
wgInterface, err := m.wg.GetInterface(ctx, iface.Identifier)
|
wgInterface, err := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true // interface might not exist - so we assume that there must be a change
|
return true // interface might not exist - so we assume that there must be a change
|
||||||
}
|
}
|
||||||
@ -844,12 +854,12 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
|
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||||
allPeers, err := m.db.GetInterfacePeers(ctx, id)
|
iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, peer := range allPeers {
|
for _, peer := range allPeers {
|
||||||
err = m.wg.DeletePeer(ctx, id, peer.Identifier)
|
err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier)
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
|
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
|
||||||
}
|
}
|
||||||
|
@ -352,7 +352,12 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
|||||||
return fmt.Errorf("delete not allowed: %w", err)
|
return fmt.Errorf("delete not allowed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
|
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
|
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
|
||||||
}
|
}
|
||||||
@ -414,14 +419,18 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
|
|||||||
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||||
interfaces := make(map[domain.InterfaceIdentifier]struct{})
|
interfaces := make(map[domain.InterfaceIdentifier]struct{})
|
||||||
|
|
||||||
for i := range peers {
|
for _, peer := range peers {
|
||||||
peer := peers[i]
|
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
|
||||||
var err error
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
|
||||||
|
}
|
||||||
|
|
||||||
if peer.IsDisabled() || peer.IsExpired() {
|
if peer.IsDisabled() || peer.IsExpired() {
|
||||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||||
peer.CopyCalculatedAttributes(p)
|
peer.CopyCalculatedAttributes(p)
|
||||||
|
|
||||||
if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil {
|
if err := m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier,
|
||||||
|
peer.Identifier); err != nil {
|
||||||
return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
|
return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -431,7 +440,7 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
|||||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||||
peer.CopyCalculatedAttributes(p)
|
peer.CopyCalculatedAttributes(p)
|
||||||
|
|
||||||
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||||
domain.MergeToPhysicalPeer(pp, peer)
|
domain.MergeToPhysicalPeer(pp, peer)
|
||||||
return pp, nil
|
return pp, nil
|
||||||
|
@ -38,9 +38,13 @@ func (b *Backend) Validate() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BackendBase struct {
|
||||||
|
Id string `yaml:"id"` // A unique id for the backend
|
||||||
|
DisplayName string `yaml:"display_name"` // A display name for the backend
|
||||||
|
}
|
||||||
|
|
||||||
type BackendMikrotik struct {
|
type BackendMikrotik struct {
|
||||||
Id string `yaml:"id"` // A unique id for the Mikrotik backend
|
BackendBase `yaml:",inline"` // Embed the base fields
|
||||||
DisplayName string `yaml:"display_name"` // A display name for the Mikrotik backend
|
|
||||||
|
|
||||||
ApiUrl string `yaml:"api_url"` // The base URL of the Mikrotik API (e.g., "https://10.10.10.10:8729/rest")
|
ApiUrl string `yaml:"api_url"` // The base URL of the Mikrotik API (e.g., "https://10.10.10.10:8729/rest")
|
||||||
ApiUser string `yaml:"api_user"`
|
ApiUser string `yaml:"api_user"`
|
||||||
|
@ -10,6 +10,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal"
|
"github.com/h44z/wg-portal/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,6 +25,7 @@ var allowedFileNameRegex = regexp.MustCompile("[^a-zA-Z0-9-_]+")
|
|||||||
|
|
||||||
type InterfaceIdentifier string
|
type InterfaceIdentifier string
|
||||||
type InterfaceType string
|
type InterfaceType string
|
||||||
|
type InterfaceBackend string
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
BaseModel
|
BaseModel
|
||||||
@ -49,11 +52,12 @@ type Interface struct {
|
|||||||
SaveConfig bool // automatically persist config changes to the wgX.conf file
|
SaveConfig bool // automatically persist config changes to the wgX.conf file
|
||||||
|
|
||||||
// WG Portal specific
|
// WG Portal specific
|
||||||
DisplayName string // a nice display name/ description for the interface
|
DisplayName string // a nice display name/ description for the interface
|
||||||
Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient
|
Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient
|
||||||
DriverType string // the interface driver type (linux, software, ...)
|
Backend InterfaceBackend // the backend that is used to manage the interface (wgctrl, mikrotik, ...)
|
||||||
Disabled *time.Time `gorm:"index"` // flag that specifies if the interface is enabled (up) or not (down)
|
DriverType string // the interface driver type (linux, software, ...)
|
||||||
DisabledReason string // the reason why the interface has been disabled
|
Disabled *time.Time `gorm:"index"` // flag that specifies if the interface is enabled (up) or not (down)
|
||||||
|
DisabledReason string // the reason why the interface has been disabled
|
||||||
|
|
||||||
// Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of
|
// Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of
|
||||||
// the peer config
|
// the peer config
|
||||||
@ -279,3 +283,30 @@ func (r RoutingTableInfo) GetRoutingTable() int {
|
|||||||
|
|
||||||
return r.Table
|
return r.Table
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type IpFamily int
|
||||||
|
|
||||||
|
const (
|
||||||
|
IpFamilyIPv4 IpFamily = unix.AF_INET
|
||||||
|
IpFamilyIPv6 IpFamily = unix.AF_INET6
|
||||||
|
)
|
||||||
|
|
||||||
|
func (f IpFamily) String() string {
|
||||||
|
switch f {
|
||||||
|
case IpFamilyIPv4:
|
||||||
|
return "IPv4"
|
||||||
|
case IpFamilyIPv6:
|
||||||
|
return "IPv6"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteRule represents a routing table rule.
|
||||||
|
type RouteRule struct {
|
||||||
|
InterfaceId InterfaceIdentifier
|
||||||
|
IpFamily IpFamily
|
||||||
|
FwMark uint32
|
||||||
|
Table int
|
||||||
|
HasDefault bool
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user