From 15d035ec10d4a069fc714cfdc3e5eccb2b21e039 Mon Sep 17 00:00:00 2001 From: Christoph Haas Date: Fri, 30 May 2025 23:19:48 +0200 Subject: [PATCH] wip: create different backend handlers (#426) --- cmd/wg-portal/main.go | 5 +- internal/adapters/wgcontroller/local.go | 809 ++++++++++++++++++ internal/adapters/wgcontroller/mikrotik.go | 105 +++ .../app/api/v0/handlers/endpoint_config.go | 32 +- internal/app/api/v0/model/models_interface.go | 7 +- internal/app/wireguard/controller_manager.go | 161 ++++ internal/app/wireguard/statistics.go | 8 +- internal/app/wireguard/wireguard.go | 23 +- .../app/wireguard/wireguard_interfaces.go | 112 +-- internal/app/wireguard/wireguard_peers.go | 21 +- internal/config/backend.go | 8 +- internal/domain/interface.go | 41 +- 12 files changed, 1229 insertions(+), 103 deletions(-) create mode 100644 internal/adapters/wgcontroller/local.go create mode 100644 internal/adapters/wgcontroller/mikrotik.go create mode 100644 internal/app/wireguard/controller_manager.go diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index dbd2020..753cb81 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -50,7 +50,8 @@ func main() { database, err := adapters.NewSqlRepository(rawDb) internal.AssertNoError(err) - wireGuard := adapters.NewWireGuardRepository() + wireGuard, err := wireguard.NewControllerManager(cfg) + internal.AssertNoError(err) wgQuick := adapters.NewWgQuickRepo() @@ -133,7 +134,7 @@ func main() { apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers) apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces) apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers) - apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth) + apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth, wireGuard) apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth) apiFrontend := handlersV0.NewRestApi(apiV0Session, diff --git a/internal/adapters/wgcontroller/local.go b/internal/adapters/wgcontroller/local.go new file mode 100644 index 0000000..b1541aa --- /dev/null +++ b/internal/adapters/wgcontroller/local.go @@ -0,0 +1,809 @@ +package wgcontroller + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "strings" + + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/h44z/wg-portal/internal" + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" + "github.com/h44z/wg-portal/internal/lowlevel" +) + +// region dependencies + +// WgCtrlRepo is used to control local WireGuard devices via the wgctrl-go library. +type WgCtrlRepo interface { + io.Closer + Devices() ([]*wgtypes.Device, error) + Device(name string) (*wgtypes.Device, error) + ConfigureDevice(name string, cfg wgtypes.Config) error +} + +// A NetlinkClient is a type which can control a netlink device. +type NetlinkClient interface { + LinkAdd(link netlink.Link) error + LinkDel(link netlink.Link) error + LinkByName(name string) (netlink.Link, error) + LinkSetUp(link netlink.Link) error + LinkSetDown(link netlink.Link) error + LinkSetMTU(link netlink.Link, mtu int) error + AddrReplace(link netlink.Link, addr *netlink.Addr) error + AddrAdd(link netlink.Link, addr *netlink.Addr) error + AddrList(link netlink.Link) ([]netlink.Addr, error) + AddrDel(link netlink.Link, addr *netlink.Addr) error + RouteAdd(route *netlink.Route) error + RouteDel(route *netlink.Route) error + RouteReplace(route *netlink.Route) error + RouteList(link netlink.Link, family int) ([]netlink.Route, error) + RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error) + RuleAdd(rule *netlink.Rule) error + RuleDel(rule *netlink.Rule) error + RuleList(family int) ([]netlink.Rule, error) +} + +// endregion dependencies + +type LocalController struct { + cfg *config.Config + + wg WgCtrlRepo + nl NetlinkClient + + shellCmd string + resolvConfIfacePrefix string +} + +// NewLocalController creates a new local controller instance. +// This repository is used to interact with the WireGuard kernel or userspace module. +func NewLocalController(cfg *config.Config) (*LocalController, error) { + wg, err := wgctrl.New() + if err != nil { + return nil, fmt.Errorf("failed to create wgctrl client: %w", err) + } + + nl := &lowlevel.NetlinkManager{} + + repo := &LocalController{ + cfg: cfg, + + wg: wg, + nl: nl, + + shellCmd: "bash", // we only support bash at the moment + resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf + } + + return repo, nil +} + +// region wireguard-related + +func (c LocalController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) { + devices, err := c.wg.Devices() + if err != nil { + return nil, fmt.Errorf("device list error: %w", err) + } + + interfaces := make([]domain.PhysicalInterface, 0, len(devices)) + for _, device := range devices { + interfaceModel, err := c.convertWireGuardInterface(device) + if err != nil { + return nil, fmt.Errorf("interface convert failed for %s: %w", device.Name, err) + } + interfaces = append(interfaces, interfaceModel) + } + + return interfaces, nil +} + +func (c LocalController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) ( + *domain.PhysicalInterface, + error, +) { + return c.getInterface(id) +} + +func (c LocalController) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) { + // read data from wgctrl interface + + iface := domain.PhysicalInterface{ + Identifier: domain.InterfaceIdentifier(device.Name), + KeyPair: domain.KeyPair{ + PrivateKey: device.PrivateKey.String(), + PublicKey: device.PublicKey.String(), + }, + ListenPort: device.ListenPort, + Addresses: nil, + Mtu: 0, + FirewallMark: uint32(device.FirewallMark), + DeviceUp: false, + ImportSource: "wgctrl", + DeviceType: device.Type.String(), + BytesUpload: 0, + BytesDownload: 0, + } + + // read data from netlink interface + + lowLevelInterface, err := c.nl.LinkByName(device.Name) + if err != nil { + return domain.PhysicalInterface{}, fmt.Errorf("netlink error for %s: %w", device.Name, err) + } + ipAddresses, err := c.nl.AddrList(lowLevelInterface) + if err != nil { + return domain.PhysicalInterface{}, fmt.Errorf("ip read error for %s: %w", device.Name, err) + } + + for _, addr := range ipAddresses { + iface.Addresses = append(iface.Addresses, domain.CidrFromNetlinkAddr(addr)) + } + iface.Mtu = lowLevelInterface.Attrs().MTU + iface.DeviceUp = lowLevelInterface.Attrs().OperState == netlink.OperUnknown // wg only supports unknown + if stats := lowLevelInterface.Attrs().Statistics; stats != nil { + iface.BytesUpload = stats.TxBytes + iface.BytesDownload = stats.RxBytes + } + + return iface, nil +} + +func (c LocalController) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ( + []domain.PhysicalPeer, + error, +) { + device, err := c.wg.Device(string(deviceId)) + if err != nil { + return nil, fmt.Errorf("device error: %w", err) + } + + peers := make([]domain.PhysicalPeer, 0, len(device.Peers)) + for _, peer := range device.Peers { + peerModel, err := c.convertWireGuardPeer(&peer) + if err != nil { + return nil, fmt.Errorf("peer convert failed for %v: %w", peer.PublicKey, err) + } + peers = append(peers, peerModel) + } + + return peers, nil +} + +func (c LocalController) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) { + peerModel := domain.PhysicalPeer{ + Identifier: domain.PeerIdentifier(peer.PublicKey.String()), + Endpoint: "", + AllowedIPs: nil, + KeyPair: domain.KeyPair{ + PublicKey: peer.PublicKey.String(), + }, + PresharedKey: "", + PersistentKeepalive: int(peer.PersistentKeepaliveInterval.Seconds()), + LastHandshake: peer.LastHandshakeTime, + ProtocolVersion: peer.ProtocolVersion, + BytesUpload: uint64(peer.ReceiveBytes), + BytesDownload: uint64(peer.TransmitBytes), + } + + for _, addr := range peer.AllowedIPs { + peerModel.AllowedIPs = append(peerModel.AllowedIPs, domain.CidrFromIpNet(addr)) + } + if peer.Endpoint != nil { + peerModel.Endpoint = peer.Endpoint.String() + } + if peer.PresharedKey != (wgtypes.Key{}) { + peerModel.PresharedKey = domain.PreSharedKey(peer.PresharedKey.String()) + } + + return peerModel, nil +} + +func (c LocalController) SaveInterface( + _ context.Context, + id domain.InterfaceIdentifier, + updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), +) error { + physicalInterface, err := c.getOrCreateInterface(id) + if err != nil { + return err + } + + if updateFunc != nil { + physicalInterface, err = updateFunc(physicalInterface) + if err != nil { + return err + } + } + + if err := c.updateLowLevelInterface(physicalInterface); err != nil { + return err + } + if err := c.updateWireGuardInterface(physicalInterface); err != nil { + return err + } + + return nil +} + +func (c LocalController) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { + device, err := c.getInterface(id) + if err == nil { + return device, nil // interface exists + } + if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("device error: %w", err) // unknown error + } + + // create new device + if err := c.createLowLevelInterface(id); err != nil { + return nil, err + } + + device, err = c.getInterface(id) + return device, err +} + +func (c LocalController) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) { + device, err := c.wg.Device(string(id)) + if err != nil { + return nil, err + } + + pi, err := c.convertWireGuardInterface(device) + return &pi, err +} + +func (c LocalController) createLowLevelInterface(id domain.InterfaceIdentifier) error { + link := &netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{ + Name: string(id), + }, + LinkType: "wireguard", + } + err := c.nl.LinkAdd(link) + if err != nil { + return fmt.Errorf("link add failed: %w", err) + } + + return nil +} + +func (c LocalController) updateLowLevelInterface(pi *domain.PhysicalInterface) error { + link, err := c.nl.LinkByName(string(pi.Identifier)) + if err != nil { + return err + } + if pi.Mtu != 0 { + if err := c.nl.LinkSetMTU(link, pi.Mtu); err != nil { + return fmt.Errorf("mtu error: %w", err) + } + } + + for _, addr := range pi.Addresses { + err := c.nl.AddrReplace(link, addr.NetlinkAddr()) + if err != nil { + return fmt.Errorf("failed to set ip %s: %w", addr.String(), err) + } + } + + // Remove unwanted IP addresses + rawAddresses, err := c.nl.AddrList(link) + if err != nil { + return fmt.Errorf("failed to fetch interface ips: %w", err) + } + for _, rawAddr := range rawAddresses { + netlinkAddr := domain.CidrFromNetlinkAddr(rawAddr) + remove := true + for _, addr := range pi.Addresses { + if addr == netlinkAddr { + remove = false + break + } + } + + if !remove { + continue + } + + err := c.nl.AddrDel(link, &rawAddr) + if err != nil { + return fmt.Errorf("failed to remove deprecated ip %s: %w", netlinkAddr.String(), err) + } + } + + // Update link state + if pi.DeviceUp { + if err := c.nl.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up device: %w", err) + } + } else { + if err := c.nl.LinkSetDown(link); err != nil { + return fmt.Errorf("failed to bring down device: %w", err) + } + } + + return nil +} + +func (c LocalController) updateWireGuardInterface(pi *domain.PhysicalInterface) error { + pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes()) + if err != nil { + return err + } + + var fwMark *int + if pi.FirewallMark != 0 { + intFwMark := int(pi.FirewallMark) + fwMark = &intFwMark + } + err = c.wg.ConfigureDevice(string(pi.Identifier), wgtypes.Config{ + PrivateKey: &pKey, + ListenPort: &pi.ListenPort, + FirewallMark: fwMark, + ReplacePeers: false, + }) + if err != nil { + return err + } + + return nil +} + +func (c LocalController) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error { + if err := c.deleteLowLevelInterface(id); err != nil { + return err + } + + return nil +} + +func (c LocalController) deleteLowLevelInterface(id domain.InterfaceIdentifier) error { + link, err := c.nl.LinkByName(string(id)) + if err != nil { + var linkNotFoundError netlink.LinkNotFoundError + if errors.As(err, &linkNotFoundError) { + return nil // ignore not found error + } + return fmt.Errorf("unable to find low level interface: %w", err) + } + + err = c.nl.LinkDel(link) + if err != nil { + return fmt.Errorf("failed to delete low level interface: %w", err) + } + + return nil +} + +func (c LocalController) SavePeer( + _ context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, + updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), +) error { + physicalPeer, err := c.getOrCreatePeer(deviceId, id) + if err != nil { + return err + } + + physicalPeer, err = updateFunc(physicalPeer) + if err != nil { + return err + } + + if err := c.updatePeer(deviceId, physicalPeer); err != nil { + return err + } + + return nil +} + +func (c LocalController) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) ( + *domain.PhysicalPeer, + error, +) { + peer, err := c.getPeer(deviceId, id) + if err == nil { + return peer, nil // peer exists + } + if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("peer error: %w", err) // unknown error + } + + // create new peer + err = c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: id.ToPublicKey(), + }, + }, + }) + if err != nil { + return nil, fmt.Errorf("peer create error for %s: %w", id.ToPublicKey(), err) + } + + peer, err = c.getPeer(deviceId, id) + if err != nil { + return nil, fmt.Errorf("peer error after create: %w", err) + } + return peer, nil +} + +func (c LocalController) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) ( + *domain.PhysicalPeer, + error, +) { + if !id.IsPublicKey() { + return nil, errors.New("invalid public key") + } + + device, err := c.wg.Device(string(deviceId)) + if err != nil { + return nil, err + } + + publicKey := id.ToPublicKey() + for _, peer := range device.Peers { + if peer.PublicKey != publicKey { + continue + } + + peerModel, err := c.convertWireGuardPeer(&peer) + return &peerModel, err + } + + return nil, os.ErrNotExist +} + +func (c LocalController) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error { + cfg := wgtypes.PeerConfig{ + PublicKey: pp.GetPublicKey(), + Remove: false, + UpdateOnly: true, + PresharedKey: pp.GetPresharedKey(), + Endpoint: pp.GetEndpointAddress(), + PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(), + ReplaceAllowedIPs: true, + AllowedIPs: pp.GetAllowedIPs(), + } + + err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}}) + if err != nil { + return err + } + + return nil +} + +func (c LocalController) DeletePeer( + _ context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, +) error { + if !id.IsPublicKey() { + return errors.New("invalid public key") + } + + err := c.deletePeer(deviceId, id) + if err != nil { + return err + } + + return nil +} + +func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error { + cfg := wgtypes.PeerConfig{ + PublicKey: id.ToPublicKey(), + Remove: true, + } + + err := c.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}}) + if err != nil { + return err + } + + return nil +} + +// endregion wireguard-related + +// region wg-quick-related + +func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { + if hookCmd == "" { + return nil + } + + slog.Debug("executing interface hook", "interface", id, "hook", hookCmd) + err := c.exec(hookCmd, id) + if err != nil { + return fmt.Errorf("failed to exec hook: %w", err) + } + + return nil +} + +func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { + if dnsStr == "" && dnsSearchStr == "" { + return nil + } + + dnsServers := internal.SliceString(dnsStr) + dnsSearchDomains := internal.SliceString(dnsSearchStr) + + dnsCommand := "resolvconf -a %resPref%i -m 0 -x" + dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains)) + + for _, dnsServer := range dnsServers { + dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer)) + } + for _, searchDomain := range dnsSearchDomains { + dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain)) + } + + err := c.exec(dnsCommand, id, dnsCommandInput...) + if err != nil { + return fmt.Errorf( + "failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w", + err, + ) + } + + return nil +} + +func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error { + dnsCommand := "resolvconf -d %resPref%i -f" + + err := c.exec(dnsCommand, id) + if err != nil { + return fmt.Errorf("failed to unset dns settings: %w", err) + } + + return nil +} + +func (c LocalController) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string { + command = strings.ReplaceAll(command, "%resPref", c.resolvConfIfacePrefix) + return strings.ReplaceAll(command, "%i", string(interfaceId)) +} + +func (c LocalController) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error { + commandWithInterfaceName := c.replaceCommandPlaceHolders(command, interfaceId) + cmd := exec.Command(c.shellCmd, "-ce", commandWithInterfaceName) + if len(stdin) > 0 { + b := &bytes.Buffer{} + for _, ln := range stdin { + if _, err := fmt.Fprint(b, ln); err != nil { + return err + } + } + cmd.Stdin = b + } + out, err := cmd.CombinedOutput() // execute and wait for output + if err != nil { + return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err) + } + slog.Debug("executed shell command", + "command", commandWithInterfaceName, + "output", string(out)) + return nil +} + +// endregion wg-quick-related + +// region routing-related + +func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { + // update fwmark rules + if err := c.setFwMarkRules(rules); err != nil { + return err + } + + // update main rule + if err := c.setMainRule(rules); err != nil { + return err + } + + // cleanup old main rules + if err := c.cleanupMainRule(rules); err != nil { + return err + } + + return nil +} + +func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error { + for _, rule := range rules { + existingRules, err := c.nl.RuleList(int(rule.IpFamily)) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err) + } + + ruleExists := false + for _, existingRule := range existingRules { + if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table { + ruleExists = true + break + } + } + + if ruleExists { + continue // rule already exists, no need to recreate it + } + + // create a missing rule + if err := c.nl.RuleAdd(&netlink.Rule{ + Family: int(rule.IpFamily), + Table: rule.Table, + Mark: rule.FwMark, + Invert: true, + SuppressIfgroup: -1, + SuppressPrefixlen: -1, + Priority: c.getRulePriority(existingRules), + Mask: nil, + Goto: -1, + Flow: -1, + }); err != nil { + return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w", + rule.IpFamily, rule.FwMark, rule.Table, err) + } + } + return nil +} + +func (c LocalController) getRulePriority(existingRules []netlink.Rule) int { + prio := 32700 // linux main rule has a priority of 32766 + for { + isFresh := true + for _, existingRule := range existingRules { + if existingRule.Priority == prio { + isFresh = false + break + } + } + if isFresh { + break + } else { + prio-- + } + } + return prio +} + +func (c LocalController) setMainRule(rules []domain.RouteRule) error { + var family domain.IpFamily + shouldHaveMainRule := false + for _, rule := range rules { + family = rule.IpFamily + if rule.HasDefault == true { + shouldHaveMainRule = true + break + } + } + if !shouldHaveMainRule { + return nil + } + + existingRules, err := c.nl.RuleList(int(family)) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %s: %w", family, err) + } + + ruleExists := false + for _, existingRule := range existingRules { + if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { + ruleExists = true + break + } + } + + if ruleExists { + return nil // rule already exists, skip re-creation + } + + if err := c.nl.RuleAdd(&netlink.Rule{ + Family: int(family), + Table: unix.RT_TABLE_MAIN, + SuppressIfgroup: -1, + SuppressPrefixlen: 0, + Priority: c.getMainRulePriority(existingRules), + Mark: 0, + Mask: nil, + Goto: -1, + Flow: -1, + }); err != nil { + return fmt.Errorf("failed to setup rule for main table: %w", err) + } + + return nil +} + +func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int { + priority := c.cfg.Advanced.RulePrioOffset + for { + isFresh := true + for _, existingRule := range existingRules { + if existingRule.Priority == priority { + isFresh = false + break + } + } + if isFresh { + break + } else { + priority++ + } + } + return priority +} + +func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error { + var family domain.IpFamily + for _, rule := range rules { + family = rule.IpFamily + break + } + + existingRules, err := c.nl.RuleList(int(family)) + if err != nil { + return fmt.Errorf("failed to get existing rules for family %s: %w", family, err) + } + + shouldHaveMainRule := false + for _, rule := range rules { + if rule.HasDefault == true { + shouldHaveMainRule = true + break + } + } + + mainRules := 0 + for _, existingRule := range existingRules { + if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { + mainRules++ + } + } + + removalCount := 0 + if mainRules > 1 { + removalCount = mainRules - 1 // we only want one single rule + } + if !shouldHaveMainRule { + removalCount = mainRules + } + + for _, existingRule := range existingRules { + if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { + if removalCount > 0 { + existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field + if err := c.nl.RuleDel(&existingRule); err != nil { + return fmt.Errorf("failed to delete main rule: %w", err) + } + removalCount-- + } + } + } + + return nil +} + +func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { + // TODO implement me + panic("implement me") +} + +// endregion routing-related diff --git a/internal/adapters/wgcontroller/mikrotik.go b/internal/adapters/wgcontroller/mikrotik.go new file mode 100644 index 0000000..c01fae7 --- /dev/null +++ b/internal/adapters/wgcontroller/mikrotik.go @@ -0,0 +1,105 @@ +package wgcontroller + +import ( + "context" + + "github.com/h44z/wg-portal/internal/domain" +) + +type MikrotikController struct { +} + +func NewMikrotikController() (*MikrotikController, error) { + return &MikrotikController{}, nil +} + +// region wireguard-related + +func (c MikrotikController) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) { + // TODO implement me + panic("implement me") +} + +func (c MikrotikController) GetInterface(_ context.Context, id domain.InterfaceIdentifier) ( + *domain.PhysicalInterface, + error, +) { + // TODO implement me + panic("implement me") +} + +func (c MikrotikController) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ( + []domain.PhysicalPeer, + error, +) { + // TODO implement me + panic("implement me") +} + +func (c MikrotikController) SaveInterface( + _ context.Context, + id domain.InterfaceIdentifier, + updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), +) error { + // TODO implement me + panic("implement me") +} + +func (c MikrotikController) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error { + // TODO implement me + panic("implement me") +} + +func (c MikrotikController) SavePeer( + _ context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, + updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), +) error { + // TODO implement me + panic("implement me") +} + +func (c MikrotikController) DeletePeer( + _ context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, +) error { + // TODO implement me + panic("implement me") +} + +// endregion wireguard-related + +// region wg-quick-related + +func (c MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { + // TODO implement me + panic("implement me") +} + +func (c MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { + // TODO implement me + panic("implement me") +} + +func (c MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error { + // TODO implement me + panic("implement me") +} + +// endregion wg-quick-related + +// region routing-related + +func (c MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { + // TODO implement me + panic("implement me") +} + +func (c MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { + // TODO implement me + panic("implement me") +} + +// endregion routing-related diff --git a/internal/app/api/v0/handlers/endpoint_config.go b/internal/app/api/v0/handlers/endpoint_config.go index 30bdf51..1bc8a25 100644 --- a/internal/app/api/v0/handlers/endpoint_config.go +++ b/internal/app/api/v0/handlers/endpoint_config.go @@ -21,17 +21,23 @@ import ( //go:embed frontend_config.js.gotpl var frontendJs embed.FS +type ControllerManager interface { + GetControllerNames() []config.BackendBase +} + type ConfigEndpoint struct { cfg *config.Config authenticator Authenticator + controllerMgr ControllerManager tpl *respond.TemplateRenderer } -func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator) ConfigEndpoint { +func NewConfigEndpoint(cfg *config.Config, authenticator Authenticator, ctrlMgr ControllerManager) ConfigEndpoint { ep := ConfigEndpoint{ cfg: cfg, authenticator: authenticator, + controllerMgr: ctrlMgr, tpl: respond.NewTemplateRenderer(template.Must(template.ParseFS(frontendJs, "frontend_config.js.gotpl"))), } @@ -96,17 +102,21 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { sessionUser := domain.GetUserInfo(r.Context()) - nameFn := func(backend config.Backend) []model.SettingsBackendNames { - names := make([]model.SettingsBackendNames, 0, len(backend.Mikrotik)+1) + controllerFn := func() []model.SettingsBackendNames { + controllers := e.controllerMgr.GetControllerNames() + names := make([]model.SettingsBackendNames, 0, len(controllers)) - names = append(names, model.SettingsBackendNames{ - Id: backend.Default, - Name: "modals.interface-edit.backend.local", - }) - for _, b := range backend.Mikrotik { + for _, controller := range controllers { + displayName := controller.DisplayName + if displayName == "" { + displayName = controller.Id // fallback to ID if no display name is set + } + if controller.Id == config.LocalBackendName { + displayName = "modals.interface-edit.backend.local" // use a localized string for the local backend + } names = append(names, model.SettingsBackendNames{ - Id: b.Id, - Name: b.DisplayName, + Id: controller.Id, + Name: displayName, }) } @@ -128,7 +138,7 @@ func (e ConfigEndpoint) handleSettingsGet() http.HandlerFunc { ApiAdminOnly: e.cfg.Advanced.ApiAdminOnly, WebAuthnEnabled: e.cfg.Auth.WebAuthn.Enabled, MinPasswordLength: e.cfg.Auth.MinPasswordLength, - AvailableBackends: nameFn(e.cfg.Backend), + AvailableBackends: controllerFn(), }) } } diff --git a/internal/app/api/v0/model/models_interface.go b/internal/app/api/v0/model/models_interface.go index a5c3600..1b22d02 100644 --- a/internal/app/api/v0/model/models_interface.go +++ b/internal/app/api/v0/model/models_interface.go @@ -59,7 +59,7 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface { Identifier: string(src.Identifier), DisplayName: src.DisplayName, Mode: string(src.Type), - Backend: config.LocalBackendName, // TODO: add backend support + Backend: string(src.Backend), PrivateKey: src.PrivateKey, PublicKey: src.PublicKey, Disabled: src.IsDisabled(), @@ -95,6 +95,10 @@ func NewInterface(src *domain.Interface, peers []domain.Peer) *Interface { Filename: src.GetConfigFileName(), } + if iface.Backend == "" { + iface.Backend = config.LocalBackendName // default to local backend + } + if len(peers) > 0 { iface.TotalPeers = len(peers) @@ -149,6 +153,7 @@ func NewDomainInterface(src *Interface) *domain.Interface { SaveConfig: src.SaveConfig, DisplayName: src.DisplayName, Type: domain.InterfaceType(src.Mode), + Backend: domain.InterfaceBackend(src.Backend), DriverType: "", // currently unused Disabled: nil, // set below DisabledReason: src.DisabledReason, diff --git a/internal/app/wireguard/controller_manager.go b/internal/app/wireguard/controller_manager.go new file mode 100644 index 0000000..7a5ccd5 --- /dev/null +++ b/internal/app/wireguard/controller_manager.go @@ -0,0 +1,161 @@ +package wireguard + +import ( + "context" + "fmt" + "log/slog" + "maps" + "slices" + + "github.com/h44z/wg-portal/internal/adapters/wgcontroller" + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" +) + +type InterfaceController interface { + GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) + GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) + GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) + SaveInterface( + _ context.Context, + id domain.InterfaceIdentifier, + updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), + ) error + DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error + SavePeer( + _ context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, + updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), + ) error + DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error +} + +type backendInstance struct { + Config config.BackendBase // Config is the configuration for the backend instance. + Implementation InterfaceController +} + +type ControllerManager struct { + cfg *config.Config + controllers map[domain.InterfaceBackend]backendInstance +} + +func NewControllerManager(cfg *config.Config) (*ControllerManager, error) { + c := &ControllerManager{ + cfg: cfg, + controllers: make(map[domain.InterfaceBackend]backendInstance), + } + + err := c.init() + if err != nil { + return nil, err + } + + return c, nil +} + +func (c *ControllerManager) init() error { + if err := c.registerLocalController(); err != nil { + return err + } + + if err := c.registerMikrotikControllers(); err != nil { + return err + } + + c.logRegisteredControllers() + + return nil +} + +func (c *ControllerManager) registerLocalController() error { + localController, err := wgcontroller.NewLocalController(c.cfg) + if err != nil { + return fmt.Errorf("failed to create local WireGuard controller: %w", err) + } + + c.controllers[config.LocalBackendName] = backendInstance{ + Config: config.BackendBase{ + Id: config.LocalBackendName, + DisplayName: "Local WireGuard Controller", + }, + Implementation: localController, + } + return nil +} + +func (c *ControllerManager) registerMikrotikControllers() error { + for _, backendConfig := range c.cfg.Backend.Mikrotik { + if backendConfig.Id == config.LocalBackendName { + slog.Warn("skipping registration of Mikrotik controller with reserved ID", "id", config.LocalBackendName) + continue + } + + controller, err := wgcontroller.NewMikrotikController() // TODO: Pass backendConfig to the controller constructor + if err != nil { + return fmt.Errorf("failed to create Mikrotik controller for backend %s: %w", backendConfig.Id, err) + } + + c.controllers[domain.InterfaceBackend(backendConfig.Id)] = backendInstance{ + Config: backendConfig.BackendBase, + Implementation: controller, + } + } + return nil +} + +func (c *ControllerManager) logRegisteredControllers() { + for backend, controller := range c.controllers { + slog.Debug("backend controller registered", + "backend", backend, "type", fmt.Sprintf("%T", controller.Implementation)) + } +} + +func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController { + return c.getController(backend, "") +} + +func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController { + return c.getController(iface.Backend, iface.Identifier) +} + +func (c *ControllerManager) getController( + backend domain.InterfaceBackend, + ifaceId domain.InterfaceIdentifier, +) InterfaceController { + if backend == "" { + // If no backend is specified, use the local controller. + // This might be the case for interfaces created in previous WireGuard Portal versions. + backend = config.LocalBackendName + } + + controller, exists := c.controllers[backend] + if !exists { + controller, exists = c.controllers[config.LocalBackendName] // Fallback to local controller + if !exists { + // If the local controller is also not found, panic + panic(fmt.Sprintf("%s interface controller for backend %s not found", ifaceId, backend)) + } + slog.Warn("controller for backend not found, using local controller", + "backend", backend, "interface", ifaceId) + } + return controller.Implementation +} + +func (c *ControllerManager) GetAllControllers() []InterfaceController { + var backendInstances = make([]InterfaceController, 0, len(c.controllers)) + for instance := range maps.Values(c.controllers) { + backendInstances = append(backendInstances, instance.Implementation) + } + return backendInstances +} + +func (c *ControllerManager) GetControllerNames() []config.BackendBase { + var names []config.BackendBase + for _, id := range slices.Sorted(maps.Keys(c.controllers)) { + names = append(names, c.controllers[id].Config) + } + + return names +} diff --git a/internal/app/wireguard/statistics.go b/internal/app/wireguard/statistics.go index ffa7571..c9098a9 100644 --- a/internal/app/wireguard/statistics.go +++ b/internal/app/wireguard/statistics.go @@ -53,7 +53,7 @@ type StatisticsCollector struct { pingJobs chan domain.Peer db StatisticsDatabaseRepo - wg StatisticsInterfaceController + wg *ControllerManager ms StatisticsMetricsServer } @@ -62,7 +62,7 @@ func NewStatisticsCollector( cfg *config.Config, bus StatisticsEventBus, db StatisticsDatabaseRepo, - wg StatisticsInterfaceController, + wg *ControllerManager, ms StatisticsMetricsServer, ) (*StatisticsCollector, error) { c := &StatisticsCollector{ @@ -113,7 +113,7 @@ func (c *StatisticsCollector) collectInterfaceData(ctx context.Context) { } for _, in := range interfaces { - physicalInterface, err := c.wg.GetInterface(ctx, in.Identifier) + physicalInterface, err := c.wg.GetController(in).GetInterface(ctx, in.Identifier) if err != nil { slog.Warn("failed to load physical interface for data collection", "interface", in.Identifier, "error", err) @@ -165,7 +165,7 @@ func (c *StatisticsCollector) collectPeerData(ctx context.Context) { } for _, in := range interfaces { - peers, err := c.wg.GetPeers(ctx, in.Identifier) + peers, err := c.wg.GetController(in).GetPeers(ctx, in.Identifier) if err != nil { slog.Warn("failed to fetch peers for data collection", "interface", in.Identifier, "error", err) continue diff --git a/internal/app/wireguard/wireguard.go b/internal/app/wireguard/wireguard.go index eb76bc9..b28f70e 100644 --- a/internal/app/wireguard/wireguard.go +++ b/internal/app/wireguard/wireguard.go @@ -37,25 +37,6 @@ type InterfaceAndPeerDatabaseRepo interface { GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) } -type InterfaceController interface { - GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) - GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) - GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) - SaveInterface( - _ context.Context, - id domain.InterfaceIdentifier, - updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), - ) error - DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error - SavePeer( - _ context.Context, - deviceId domain.InterfaceIdentifier, - id domain.PeerIdentifier, - updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), - ) error - DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error -} - type WgQuickController interface { ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error @@ -75,7 +56,7 @@ type Manager struct { cfg *config.Config bus EventBus db InterfaceAndPeerDatabaseRepo - wg InterfaceController + wg *ControllerManager quick WgQuickController userLockMap *sync.Map @@ -84,7 +65,7 @@ type Manager struct { func NewWireGuardManager( cfg *config.Config, bus EventBus, - wg InterfaceController, + wg *ControllerManager, quick WgQuickController, db InterfaceAndPeerDatabaseRepo, ) (*Manager, error) { diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index f232eec..a2a727f 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -21,12 +21,17 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical return nil, err } - physicalInterfaces, err := m.wg.GetInterfaces(ctx) - if err != nil { - return nil, err + var allPhysicalInterfaces []domain.PhysicalInterface + for _, wgBackend := range m.wg.GetAllControllers() { + physicalInterfaces, err := wgBackend.GetInterfaces(ctx) + if err != nil { + return nil, err + } + + allPhysicalInterfaces = append(allPhysicalInterfaces, physicalInterfaces...) } - return physicalInterfaces, nil + return allPhysicalInterfaces, nil } // GetInterfaceAndPeers returns the interface and all peers for the given interface identifier. @@ -109,47 +114,49 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter return 0, err } - physicalInterfaces, err := m.wg.GetInterfaces(ctx) - if err != nil { - return 0, err - } - - // if no filter is given, exclude already existing interfaces - var excludedInterfaces []domain.InterfaceIdentifier - if len(filter) == 0 { - existingInterfaces, err := m.db.GetAllInterfaces(ctx) - if err != nil { - return 0, err - } - for _, existingInterface := range existingInterfaces { - excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier) - } - } - imported := 0 - for _, physicalInterface := range physicalInterfaces { - if slices.Contains(excludedInterfaces, physicalInterface.Identifier) { - continue - } - - if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) { - continue - } - - slog.Info("importing new interface", "interface", physicalInterface.Identifier) - - physicalPeers, err := m.wg.GetPeers(ctx, physicalInterface.Identifier) + for _, wgBackend := range m.wg.GetAllControllers() { + physicalInterfaces, err := wgBackend.GetInterfaces(ctx) if err != nil { return 0, err } - err = m.importInterface(ctx, &physicalInterface, physicalPeers) - if err != nil { - return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err) + // if no filter is given, exclude already existing interfaces + var excludedInterfaces []domain.InterfaceIdentifier + if len(filter) == 0 { + existingInterfaces, err := m.db.GetAllInterfaces(ctx) + if err != nil { + return 0, err + } + for _, existingInterface := range existingInterfaces { + excludedInterfaces = append(excludedInterfaces, existingInterface.Identifier) + } } - slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers)) - imported++ + for _, physicalInterface := range physicalInterfaces { + if slices.Contains(excludedInterfaces, physicalInterface.Identifier) { + continue + } + + if len(filter) != 0 && !slices.Contains(filter, physicalInterface.Identifier) { + continue + } + + slog.Info("importing new interface", "interface", physicalInterface.Identifier) + + physicalPeers, err := wgBackend.GetPeers(ctx, physicalInterface.Identifier) + if err != nil { + return 0, err + } + + err = m.importInterface(ctx, &physicalInterface, physicalPeers) + if err != nil { + return 0, fmt.Errorf("import of %s failed: %w", physicalInterface.Identifier, err) + } + + slog.Info("imported new interface", "interface", physicalInterface.Identifier, "peers", len(physicalPeers)) + imported++ + } } return imported, nil @@ -213,7 +220,7 @@ func (m Manager) RestoreInterfaceState( return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err) } - _, err = m.wg.GetInterface(ctx, iface.Identifier) + _, err = m.wg.GetController(iface).GetInterface(ctx, iface.Identifier) if err != nil && !iface.IsDisabled() { slog.Debug("creating missing interface", "interface", iface.Identifier) @@ -261,17 +268,19 @@ func (m Manager) RestoreInterfaceState( for _, peer := range peers { switch { case iface.IsDisabled(): // if interface is disabled, delete all peers - if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil { + if err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier, + peer.Identifier); err != nil { return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w", peer.Identifier, iface.Identifier, err) } case peer.IsDisabled(): // if peer is disabled, delete it - if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil { + if err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier, + peer.Identifier); err != nil { return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w", peer.Identifier, iface.Identifier, err) } default: // update peer - err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier, + err := m.wg.GetController(iface).SavePeer(ctx, iface.Identifier, peer.Identifier, func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { domain.MergeToPhysicalPeer(pp, &peer) return pp, nil @@ -284,7 +293,7 @@ func (m Manager) RestoreInterfaceState( } // remove non-wgportal peers - physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier) + physicalPeers, _ := m.wg.GetController(iface).GetPeers(ctx, iface.Identifier) for _, physicalPeer := range physicalPeers { isWgPortalPeer := false for _, peer := range peers { @@ -294,7 +303,8 @@ func (m Manager) RestoreInterfaceState( } } if !isWgPortalPeer { - err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey)) + err := m.wg.GetController(iface).DeletePeer(ctx, iface.Identifier, + domain.PeerIdentifier(physicalPeer.PublicKey)) if err != nil { return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w", physicalPeer.PublicKey, iface.Identifier, err) @@ -459,7 +469,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif existingInterface.Disabled = &now // simulate a disabled interface existingInterface.DisabledReason = domain.DisabledReasonDeleted - physicalInterface, _ := m.wg.GetInterface(ctx, id) + physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id) if err := m.handleInterfacePreSaveHooks(true, existingInterface); err != nil { return fmt.Errorf("pre-delete hooks failed: %w", err) @@ -473,7 +483,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif return fmt.Errorf("peer deletion failure: %w", err) } - if err := m.wg.DeleteInterface(ctx, id); err != nil { + if err := m.wg.GetController(*existingInterface).DeleteInterface(ctx, id); err != nil { return fmt.Errorf("wireguard deletion failure: %w", err) } @@ -522,7 +532,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) ( err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) { iface.CopyCalculatedAttributes(i) - err := m.wg.SaveInterface(ctx, iface.Identifier, + err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) { domain.MergeToPhysicalInterface(pi, iface) return pi, nil @@ -538,7 +548,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) ( } if iface.IsDisabled() { - physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier) + physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier) fwMark := iface.FirewallMark if physicalInterface != nil && fwMark == 0 { fwMark = physicalInterface.FirewallMark @@ -576,7 +586,7 @@ func (m Manager) hasInterfaceStateChanged(ctx context.Context, iface *domain.Int return true // interface in db has changed } - wgInterface, err := m.wg.GetInterface(ctx, iface.Identifier) + wgInterface, err := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier) if err != nil { return true // interface might not exist - so we assume that there must be a change } @@ -844,12 +854,12 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain } func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error { - allPeers, err := m.db.GetInterfacePeers(ctx, id) + iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id) if err != nil { return err } for _, peer := range allPeers { - err = m.wg.DeletePeer(ctx, id, peer.Identifier) + err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier) if err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err) } diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index 2131323..e3c78bf 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -352,7 +352,12 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error return fmt.Errorf("delete not allowed: %w", err) } - err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id) + iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier) + if err != nil { + return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err) + } + + err = m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, id) if err != nil { return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err) } @@ -414,14 +419,18 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { interfaces := make(map[domain.InterfaceIdentifier]struct{}) - for i := range peers { - peer := peers[i] - var err error + for _, peer := range peers { + iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier) + if err != nil { + return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err) + } + if peer.IsDisabled() || peer.IsExpired() { err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { peer.CopyCalculatedAttributes(p) - if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil { + if err := m.wg.GetController(*iface).DeletePeer(ctx, peer.InterfaceIdentifier, + peer.Identifier); err != nil { return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err) } @@ -431,7 +440,7 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { peer.CopyCalculatedAttributes(p) - err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier, + err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier, func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { domain.MergeToPhysicalPeer(pp, peer) return pp, nil diff --git a/internal/config/backend.go b/internal/config/backend.go index dc55b0f..714568f 100644 --- a/internal/config/backend.go +++ b/internal/config/backend.go @@ -38,9 +38,13 @@ func (b *Backend) Validate() error { return nil } +type BackendBase struct { + Id string `yaml:"id"` // A unique id for the backend + DisplayName string `yaml:"display_name"` // A display name for the backend +} + type BackendMikrotik struct { - Id string `yaml:"id"` // A unique id for the Mikrotik backend - DisplayName string `yaml:"display_name"` // A display name for the Mikrotik backend + BackendBase `yaml:",inline"` // Embed the base fields ApiUrl string `yaml:"api_url"` // The base URL of the Mikrotik API (e.g., "https://10.10.10.10:8729/rest") ApiUser string `yaml:"api_user"` diff --git a/internal/domain/interface.go b/internal/domain/interface.go index 977f7d3..05699ec 100644 --- a/internal/domain/interface.go +++ b/internal/domain/interface.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "golang.org/x/sys/unix" + "github.com/h44z/wg-portal/internal" ) @@ -23,6 +25,7 @@ var allowedFileNameRegex = regexp.MustCompile("[^a-zA-Z0-9-_]+") type InterfaceIdentifier string type InterfaceType string +type InterfaceBackend string type Interface struct { BaseModel @@ -49,11 +52,12 @@ type Interface struct { SaveConfig bool // automatically persist config changes to the wgX.conf file // WG Portal specific - DisplayName string // a nice display name/ description for the interface - Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient - DriverType string // the interface driver type (linux, software, ...) - Disabled *time.Time `gorm:"index"` // flag that specifies if the interface is enabled (up) or not (down) - DisabledReason string // the reason why the interface has been disabled + DisplayName string // a nice display name/ description for the interface + Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient + Backend InterfaceBackend // the backend that is used to manage the interface (wgctrl, mikrotik, ...) + DriverType string // the interface driver type (linux, software, ...) + Disabled *time.Time `gorm:"index"` // flag that specifies if the interface is enabled (up) or not (down) + DisabledReason string // the reason why the interface has been disabled // Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of // the peer config @@ -279,3 +283,30 @@ func (r RoutingTableInfo) GetRoutingTable() int { return r.Table } + +type IpFamily int + +const ( + IpFamilyIPv4 IpFamily = unix.AF_INET + IpFamilyIPv6 IpFamily = unix.AF_INET6 +) + +func (f IpFamily) String() string { + switch f { + case IpFamilyIPv4: + return "IPv4" + case IpFamilyIPv6: + return "IPv6" + default: + return "unknown" + } +} + +// RouteRule represents a routing table rule. +type RouteRule struct { + InterfaceId InterfaceIdentifier + IpFamily IpFamily + FwMark uint32 + Table int + HasDefault bool +}