Cleanup route handling (#542)

* mikrotik: allow to set DNS, wip: handle routes in wg-controller

* replace old route handling for local controller

* cleanup route handling for local backend

* implement route handling for mikrotik controller
This commit is contained in:
h44z
2025-10-12 14:31:19 +02:00
committed by GitHub
parent 298c9405f6
commit cdf3a49801
19 changed files with 1116 additions and 803 deletions

View File

@@ -53,8 +53,6 @@ func main() {
wireGuard, err := wireguard.NewControllerManager(cfg) wireGuard, err := wireguard.NewControllerManager(cfg)
internal.AssertNoError(err) internal.AssertNoError(err)
wgQuick := adapters.NewWgQuickRepo()
mailer := adapters.NewSmtpMailRepo(cfg.Mail) mailer := adapters.NewSmtpMailRepo(cfg.Mail)
metricsServer := adapters.NewMetricsServer(cfg) metricsServer := adapters.NewMetricsServer(cfg)
@@ -93,7 +91,7 @@ func main() {
webAuthn, err := auth.NewWebAuthnAuthenticator(cfg, eventBus, userManager) webAuthn, err := auth.NewWebAuthnAuthenticator(cfg, eventBus, userManager)
internal.AssertNoError(err) internal.AssertNoError(err)
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database) wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, database)
internal.AssertNoError(err) internal.AssertNoError(err)
wireGuardManager.StartBackgroundJobs(ctx) wireGuardManager.StartBackgroundJobs(ctx)
@@ -107,7 +105,7 @@ func main() {
mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database) mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database)
internal.AssertNoError(err) internal.AssertNoError(err)
routeManager, err := route.NewRouteManager(cfg, eventBus, database) routeManager, err := route.NewRouteManager(cfg, eventBus, database, wireGuard)
internal.AssertNoError(err) internal.AssertNoError(err)
routeManager.StartBackgroundJobs(ctx) routeManager.StartBackgroundJobs(ctx)

View File

@@ -15,6 +15,9 @@ backend:
# default backend decides where new interfaces are created # default backend decides where new interfaces are created
default: mikrotik default: mikrotik
# A prefix for resolvconf. Usually it is "tun.". If you are using systemd, the prefix should be empty.
local_resolvconf_prefix: "tun."
mikrotik: mikrotik:
- id: mikrotik # unique id, not "local" - id: mikrotik # unique id, not "local"
display_name: RouterOS RB5009 # optional nice name display_name: RouterOS RB5009 # optional nice name

View File

@@ -28,6 +28,7 @@ core:
backend: backend:
default: local default: local
local_resolvconf_prefix: tun.
advanced: advanced:
log_level: info log_level: info
@@ -185,6 +186,11 @@ The current MikroTik backend is in **BETA** and may not support all features.
- **Description:** The default backend to use for managing WireGuard interfaces. - **Description:** The default backend to use for managing WireGuard interfaces.
Valid options are: `local`, or other backend id's configured in the `mikrotik` section. Valid options are: `local`, or other backend id's configured in the `mikrotik` section.
### `local_resolvconf_prefix`
- **Default:** `tun.`
- **Description:** Interface name prefix for WireGuard interfaces on the local system which is used to configure DNS servers with *resolvconf*.
It depends on the *resolvconf* implementation you are using, most use a prefix of `tun.`, but some have an empty prefix (e.g., systemd).
### `ignored_local_interfaces` ### `ignored_local_interfaces`
- **Default:** *(empty)* - **Default:** *(empty)*
- **Description:** A list of interface names to exclude when enumerating local interfaces. - **Description:** A list of interface names to exclude when enumerating local interfaces.

View File

@@ -444,6 +444,11 @@ async function del() {
<label class="form-label mt-4">{{ $t('modals.interface-edit.firewall-mark.label') }}</label> <label class="form-label mt-4">{{ $t('modals.interface-edit.firewall-mark.label') }}</label>
<input v-model="formData.FirewallMark" class="form-control" :placeholder="$t('modals.interface-edit.firewall-mark.placeholder')" type="number"> <input v-model="formData.FirewallMark" class="form-control" :placeholder="$t('modals.interface-edit.firewall-mark.placeholder')" type="number">
</div> </div>
<div class="form-group col-md-6" v-if="formData.Backend!=='local'">
<label class="form-label mt-4">{{ $t('modals.interface-edit.routing-table.label') }}</label>
<input v-model="formData.RoutingTable" aria-describedby="routingTableHelp" class="form-control" :placeholder="$t('modals.interface-edit.routing-table.placeholder')" type="text">
<small id="routingTableHelp" class="form-text text-muted">{{ $t('modals.interface-edit.routing-table.description') }}</small>
</div>
<div class="form-group col-md-6" v-else> <div class="form-group col-md-6" v-else>
</div> </div>
</div> </div>
@@ -457,7 +462,7 @@ async function del() {
</div> </div>
</div> </div>
</fieldset> </fieldset>
<fieldset> <fieldset v-if="formData.Backend==='local'">
<legend class="mt-4">{{ $t('modals.interface-edit.header-hooks') }}</legend> <legend class="mt-4">{{ $t('modals.interface-edit.header-hooks') }}</legend>
<div class="form-group"> <div class="form-group">
<label class="form-label mt-4">{{ $t('modals.interface-edit.pre-up.label') }}</label> <label class="form-label mt-4">{{ $t('modals.interface-edit.pre-up.label') }}</label>
@@ -482,7 +487,7 @@ async function del() {
<input v-model="formData.Disabled" class="form-check-input" type="checkbox"> <input v-model="formData.Disabled" class="form-check-input" type="checkbox">
<label class="form-check-label">{{ $t('modals.interface-edit.disabled.label') }}</label> <label class="form-check-label">{{ $t('modals.interface-edit.disabled.label') }}</label>
</div> </div>
<div class="form-check form-switch"> <div class="form-check form-switch" v-if="formData.Backend==='local'">
<input v-model="formData.SaveConfig" checked="" class="form-check-input" type="checkbox"> <input v-model="formData.SaveConfig" checked="" class="form-check-input" type="checkbox">
<label class="form-check-label">{{ $t('modals.interface-edit.save-config.label') }}</label> <label class="form-check-label">{{ $t('modals.interface-edit.save-config.label') }}</label>
</div> </div>

View File

@@ -9,6 +9,7 @@ import (
"log/slog" "log/slog"
"os" "os"
"os/exec" "os/exec"
"slices"
"strings" "strings"
"time" "time"
@@ -85,7 +86,7 @@ func NewLocalController(cfg *config.Config) (*LocalController, error) {
nl: nl, nl: nl,
shellCmd: "bash", // we only support bash at the moment shellCmd: "bash", // we only support bash at the moment
resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf resolvConfIfacePrefix: cfg.Backend.LocalResolvconfPrefix, // WireGuard interfaces have a tun. prefix in resolvconf
} }
return repo, nil return repo, nil
@@ -546,7 +547,11 @@ func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id doma
// region wg-quick-related // region wg-quick-related
func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { func (c LocalController) ExecuteInterfaceHook(
_ context.Context,
id domain.InterfaceIdentifier,
hookCmd string,
) error {
if hookCmd == "" { if hookCmd == "" {
return nil return nil
} }
@@ -560,7 +565,7 @@ func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hoo
return nil return nil
} }
func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { func (c LocalController) SetDNS(_ context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
if dnsStr == "" && dnsSearchStr == "" { if dnsStr == "" && dnsSearchStr == "" {
return nil return nil
} }
@@ -589,7 +594,7 @@ func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearch
return nil return nil
} }
func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error { func (c LocalController) UnsetDNS(_ context.Context, id domain.InterfaceIdentifier, _, _ string) error {
dnsCommand := "resolvconf -d %resPref%i -f" dnsCommand := "resolvconf -d %resPref%i -f"
err := c.exec(dnsCommand, id) err := c.exec(dnsCommand, id)
@@ -611,7 +616,7 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
if len(stdin) > 0 { if len(stdin) > 0 {
b := &bytes.Buffer{} b := &bytes.Buffer{}
for _, ln := range stdin { for _, ln := range stdin {
if _, err := fmt.Fprint(b, ln); err != nil { if _, err := fmt.Fprint(b, ln+"\n"); err != nil {
return err return err
} }
} }
@@ -619,6 +624,8 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
} }
out, err := cmd.CombinedOutput() // execute and wait for output out, err := cmd.CombinedOutput() // execute and wait for output
if err != nil { if err != nil {
slog.Warn("failed to executed shell command",
"command", commandWithInterfaceName, "stdin", stdin, "output", string(out), "error", err)
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err) return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
} }
slog.Debug("executed shell command", slog.Debug("executed shell command",
@@ -631,49 +638,116 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
// region routing-related // region routing-related
func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { // SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
// update fwmark rules func (c LocalController) SetRoutes(_ context.Context, info domain.RoutingTableInfo) error {
if err := c.setFwMarkRules(rules); err != nil { interfaceId := info.Interface.Identifier
return err slog.Debug("setting linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark,
"cidrs", info.AllowedIps)
link, err := c.nl.LinkByName(string(interfaceId))
if err != nil {
return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err)
} }
// update main rule cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
if err := c.setMainRule(rules); err != nil { realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, info.Table, info.FwMark)
return err if err != nil {
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
}
wgDev, err := c.wg.Device(string(interfaceId))
if err != nil {
return fmt.Errorf("failed to get wg device for %s: %w", interfaceId, err)
}
currentFwMark := wgDev.FirewallMark
if int(realFwMark) != currentFwMark {
slog.Debug("updating fwmark for interface", "interface", interfaceId, "oldFwMark", currentFwMark,
"newFwMark", realFwMark, "oldTable", info.Table, "newTable", realTable)
if err := c.updateFwMarkOnInterface(interfaceId, int(realFwMark)); err != nil {
return fmt.Errorf("failed to update fwmark for interface %s to %d: %w", interfaceId, realFwMark, err)
}
} }
// cleanup old main rules if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4); err != nil {
if err := c.cleanupMainRule(rules); err != nil { return fmt.Errorf("failed to set v4 routes: %w", err)
return err }
if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6); err != nil {
return fmt.Errorf("failed to set v6 routes: %w", err)
} }
return nil return nil
} }
func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error { func (c LocalController) setRoutesForFamily(
for _, rule := range rules { interfaceId domain.InterfaceIdentifier,
existingRules, err := c.nl.RuleList(int(rule.IpFamily)) link netlink.Link,
family int,
table int,
fwMark uint32,
cidrs []domain.Cidr,
) error {
// first create or update the routes
for _, cidr := range cidrs {
err := c.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: cidr.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
})
if err != nil { if err != nil {
return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err) return fmt.Errorf("failed to add/update route %s on table %d for interface %s: %w",
} cidr.String(), table, interfaceId, err)
ruleExists := false
for _, existingRule := range existingRules {
if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table {
ruleExists = true
break
} }
} }
if ruleExists { // next remove old routes
continue // rule already exists, no need to recreate it rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
if err != nil {
return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
interfaceId, family, err)
}
for _, rawRoute := range rawRoutes {
if rawRoute.Dst == nil { // handle default route
var netlinkAddr domain.Cidr
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
rawRoute.Dst = netlinkAddr.IpNet()
} }
// create a missing rule route := domain.CidrFromIpNet(*rawRoute.Dst)
if slices.Contains(cidrs, route) {
continue
}
if err := c.nl.RouteDel(&rawRoute); err != nil {
return fmt.Errorf("failed to remove deprecated route %s from interface %s: %w", route, interfaceId, err)
}
}
// next, update route rules for normal routes
if table == 0 {
return nil // no need to update route rules as we are using the default table
}
existingRules, err := c.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family-id %d: %w", family, err)
}
ruleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
return rule.Mark == fwMark && rule.Table == table
})
if !ruleExists {
if err := c.nl.RuleAdd(&netlink.Rule{ if err := c.nl.RuleAdd(&netlink.Rule{
Family: int(rule.IpFamily), Family: family,
Table: rule.Table, Table: table,
Mark: rule.FwMark, Mark: fwMark,
Invert: true, Invert: true,
SuppressIfgroup: -1, SuppressIfgroup: -1,
SuppressPrefixlen: -1, SuppressPrefixlen: -1,
@@ -682,15 +756,102 @@ func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error {
Goto: -1, Goto: -1,
Flow: -1, Flow: -1,
}); err != nil { }); err != nil {
return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w", return fmt.Errorf("failed to setup rule for fwmark %d and table %d for family-id %d: %w",
rule.IpFamily, rule.FwMark, rule.Table, err) fwMark, table, family, err)
} }
} }
mainRuleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
return rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN
})
if !mainRuleExists && domain.ContainsDefaultRoute(cidrs) {
err = c.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: c.getMainRulePriority(existingRules),
Mark: 0,
Mask: nil,
Goto: -1,
Flow: -1,
})
}
// finally, clean up extra main rules - only one rule is allowed
existingRules, err = c.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing main rules for family-id %d: %w", family, err)
}
mainRuleCount := 0
for _, rule := range existingRules {
if rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN {
mainRuleCount++
}
if mainRuleCount > 1 {
if err := c.nl.RuleDel(&rule); err != nil {
return fmt.Errorf("failed to remove extra main rule for family-id %d: %w", family, err)
}
}
}
return nil return nil
} }
func (c LocalController) getOrCreateRoutingTableAndFwMark(
link netlink.Link,
tableIn int,
fwMarkIn uint32,
) (
table int,
fwmark uint32,
err error,
) {
table = tableIn
fwmark = fwMarkIn
if fwmark == 0 {
// generate a new (temporary) firewall mark based on the interface index
fwmark = uint32(c.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
}
if table == 0 {
table = int(fwmark) // generate a new routing table base on interface index
}
return
}
func (c LocalController) updateFwMarkOnInterface(interfaceId domain.InterfaceIdentifier, fwMark int) error {
// apply the new fwmark to the wireguard interface
err := c.wg.ConfigureDevice(string(interfaceId), wgtypes.Config{
FirewallMark: &fwMark,
})
if err != nil {
return fmt.Errorf("failed to update fwmark of interface %s to: %d: %w", interfaceId, fwMark, err)
}
return nil
}
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
prio := c.cfg.Advanced.RulePrioOffset
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == prio {
isFresh = false
break
}
}
if isFresh {
break
} else {
prio++
}
}
return prio
}
func (c LocalController) getRulePriority(existingRules []netlink.Rule) int { func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
prio := 32700 // linux main rule has a priority of 32766 prio := 32700 // linux main rule has a prio of 32766
for { for {
isFresh := true isFresh := true
for _, existingRule := range existingRules { for _, existingRule := range existingRules {
@@ -708,126 +869,145 @@ func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
return prio return prio
} }
func (c LocalController) setMainRule(rules []domain.RouteRule) error { // RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
var family domain.IpFamily func (c LocalController) RemoveRoutes(_ context.Context, info domain.RoutingTableInfo) error {
shouldHaveMainRule := false interfaceId := info.Interface.Identifier
for _, rule := range rules { slog.Debug("removing linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark,
family = rule.IpFamily "cidrs", info.AllowedIps)
if rule.HasDefault == true {
shouldHaveMainRule = true
break
}
}
if !shouldHaveMainRule {
return nil
}
existingRules, err := c.nl.RuleList(int(family)) wgDev, err := c.wg.Device(string(interfaceId))
if err != nil { if err != nil {
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err) slog.Debug("wg device already removed, route cleanup might be incomplete", "interface", interfaceId)
wgDev = nil
}
link, err := c.nl.LinkByName(string(interfaceId))
if err != nil {
slog.Debug("physical link already removed, route cleanup might be incomplete", "interface", interfaceId)
link = nil
} }
ruleExists := false fwMark := info.FwMark
for _, existingRule := range existingRules { if wgDev != nil && info.FwMark == 0 {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { fwMark = uint32(wgDev.FirewallMark)
ruleExists = true }
break table := info.Table
if wgDev != nil && info.Table == 0 {
table = wgDev.FirewallMark // use the fwMark as table, this is the default behavior
}
linkIndex := -1
if link != nil {
linkIndex = link.Attrs().Index
}
cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark)
if err != nil {
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
}
if linkIndex > 0 {
err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4)
if err != nil {
return fmt.Errorf("failed to remove v4 routes: %w", err)
}
err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6)
if err != nil {
return fmt.Errorf("failed to remove v6 routes: %w", err)
} }
} }
if ruleExists { if table > 0 {
return nil // rule already exists, skip re-creation err = c.removeRouteRulesForTable(netlink.FAMILY_V4, realTable)
if err != nil {
return fmt.Errorf("failed to remove v4 route rules for %s: %w", interfaceId, err)
}
err = c.removeRouteRulesForTable(netlink.FAMILY_V6, realTable)
if err != nil {
return fmt.Errorf("failed to remove v6 route rules for %s: %w", interfaceId, err)
} }
if err := c.nl.RuleAdd(&netlink.Rule{
Family: int(family),
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: c.getMainRulePriority(existingRules),
Mark: 0,
Mask: nil,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for main table: %w", err)
} }
return nil return nil
} }
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int { func (c LocalController) removeRoutesForFamily(
priority := c.cfg.Advanced.RulePrioOffset interfaceId domain.InterfaceIdentifier,
for { link netlink.Link,
isFresh := true family int,
for _, existingRule := range existingRules { table int,
if existingRule.Priority == priority { fwMark uint32,
isFresh = false cidrs []domain.Cidr,
break ) error {
} // first remove all rules
} existingRules, err := c.nl.RuleList(family)
if isFresh {
break
} else {
priority++
}
}
return priority
}
func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error {
var family domain.IpFamily
for _, rule := range rules {
family = rule.IpFamily
break
}
existingRules, err := c.nl.RuleList(int(family))
if err != nil { if err != nil {
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err) return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
} }
shouldHaveMainRule := false
for _, rule := range rules {
if rule.HasDefault == true {
shouldHaveMainRule = true
break
}
}
mainRules := 0
for _, existingRule := range existingRules { for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 { if fwMark == existingRule.Mark && table == existingRule.Table {
mainRules++ existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
}
}
removalCount := 0
if mainRules > 1 {
removalCount = mainRules - 1 // we only want one single rule
}
if !shouldHaveMainRule {
removalCount = mainRules
}
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
if removalCount > 0 {
existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field
if err := c.nl.RuleDel(&existingRule); err != nil { if err := c.nl.RuleDel(&existingRule); err != nil {
return fmt.Errorf("failed to delete main rule: %w", err) return fmt.Errorf("failed to delete old fwmark rule: %w", err)
} }
removalCount--
} }
} }
// next remove all routes
rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
if err != nil {
return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
interfaceId, family, err)
}
for _, rawRoute := range rawRoutes {
if rawRoute.Dst == nil { // handle default route
var netlinkAddr domain.Cidr
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
rawRoute.Dst = netlinkAddr.IpNet()
}
if rawRoute.Table != table {
continue // ignore routes from other tables
}
route := domain.CidrFromIpNet(*rawRoute.Dst)
if !slices.Contains(cidrs, route) {
continue // only remove routes that were previously added
}
if err := c.nl.RouteDel(&rawRoute); err != nil {
return fmt.Errorf("failed to remove old route %s from interface %s: %w", route, interfaceId, err)
}
} }
return nil return nil
} }
func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { func (c LocalController) removeRouteRulesForTable(
// TODO implement me family int,
panic("implement me") table int,
) error {
existingRules, err := c.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing route rules for family-id %d: %w", family, err)
}
for _, existingRule := range existingRules {
if existingRule.Table == table {
err := c.nl.RuleDel(&existingRule)
if err != nil {
return fmt.Errorf("failed to delete old rule for table %d and family-id %d: %w", table, family, err)
}
}
}
return nil
} }
// endregion routing-related // endregion routing-related

View File

@@ -15,6 +15,9 @@ import (
"github.com/h44z/wg-portal/internal/lowlevel" "github.com/h44z/wg-portal/internal/lowlevel"
) )
const MikrotikRouteDistance = 5
const MikrotikDefaultRoutingTable = "main"
type MikrotikController struct { type MikrotikController struct {
coreCfg *config.Config coreCfg *config.Config
cfg *config.BackendMikrotik cfg *config.BackendMikrotik
@@ -24,6 +27,7 @@ type MikrotikController struct {
// Add mutexes to prevent race conditions // Add mutexes to prevent race conditions
interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex
peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex
coreMutex sync.Mutex // for updating the core configuration such as routing table or DNS settings
} }
func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) { func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) {
@@ -40,6 +44,7 @@ func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik)
interfaceMutexes: sync.Map{}, interfaceMutexes: sync.Map{},
peerMutexes: sync.Map{}, peerMutexes: sync.Map{},
coreMutex: sync.Mutex{},
}, nil }, nil
} }
@@ -763,33 +768,404 @@ func (c *MikrotikController) DeletePeer(
// region wg-quick-related // region wg-quick-related
func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error { func (c *MikrotikController) ExecuteInterfaceHook(
_ context.Context,
_ domain.InterfaceIdentifier,
_ string,
) error {
// TODO implement me // TODO implement me
panic("implement me") slog.Error("interface hooks are not yet supported for Mikrotik backends, please open an issue on GitHub")
return nil
} }
func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error { func (c *MikrotikController) SetDNS(
// TODO implement me ctx context.Context,
panic("implement me") _ domain.InterfaceIdentifier,
dnsStr, _ string,
) error {
// Lock the interface to prevent concurrent modifications
c.coreMutex.Lock()
defer c.coreMutex.Unlock()
// check if the server is already configured
wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{
PropList: []string{"servers"},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error)
} }
func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error { var existingServers []string
// TODO implement me existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...)
panic("implement me")
newServers := strings.Split(dnsStr, ",")
mergedServers := slices.Clone(existingServers)
for _, s := range newServers {
if s == "" {
continue
}
if !slices.Contains(mergedServers, s) {
mergedServers = append(mergedServers, s)
}
}
mergedServersStr := strings.Join(mergedServers, ",")
reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{
"servers": mergedServersStr,
})
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error)
}
return nil
}
func (c *MikrotikController) UnsetDNS(
ctx context.Context,
_ domain.InterfaceIdentifier,
dnsStr, _ string,
) error {
// Lock the interface to prevent concurrent modifications
c.coreMutex.Lock()
defer c.coreMutex.Unlock()
// retrieve current DNS settings
wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{
PropList: []string{"servers"},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error)
}
var existingServers []string
existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...)
oldServers := strings.Split(dnsStr, ",")
mergedServers := make([]string, 0, len(existingServers))
for _, s := range existingServers {
if s == "" {
continue
}
if !slices.Contains(oldServers, s) {
mergedServers = append(mergedServers, s) // only keep the servers that are not in the old list
}
}
mergedServersStr := strings.Join(mergedServers, ",")
reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{
"servers": mergedServersStr,
})
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error)
}
return nil
} }
// endregion wg-quick-related // endregion wg-quick-related
// region routing-related // region routing-related
func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error { // SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
// TODO implement me func (c *MikrotikController) SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
panic("implement me") interfaceId := info.Interface.Identifier
slog.Debug("setting mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps)
// Mikrotik needs some time to apply the changes.
// If we don't wait, the routes might get created multiple times as the dynamic routes are not yet available.
time.Sleep(2 * time.Second)
tableName, err := c.getOrCreateRoutingTables(ctx, info.Interface.Identifier, info.TableStr)
if err != nil {
return fmt.Errorf("failed to get or create routing table for %s: %v", interfaceId, err)
} }
func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error { cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
// TODO implement me
panic("implement me") err = c.setRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4)
if err != nil {
return fmt.Errorf("failed to set IPv4 routes for %s: %v", interfaceId, err)
}
err = c.setRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6)
if err != nil {
return fmt.Errorf("failed to set IPv6 routes for %s: %v", interfaceId, err)
}
return nil
}
func (c *MikrotikController) resolveRouteTableName(name string) string {
name = strings.TrimSpace(name)
var mikrotikTableName string
switch strings.ToLower(name) {
case "", "0":
mikrotikTableName = MikrotikDefaultRoutingTable
case MikrotikDefaultRoutingTable:
return fmt.Sprintf("wgportal-%s",
MikrotikDefaultRoutingTable) // if the Mikrotik Main table should be used, the table-name should be left empty or set to "0".
default:
mikrotikTableName = name
}
return mikrotikTableName
}
func (c *MikrotikController) getOrCreateRoutingTables(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
table string,
) (string, error) {
// retrieve current routing tables
wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "dynamic", "fib", "name",
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return "", fmt.Errorf("unable to query routing tables: %v", wgReply.Error)
}
wantedTableName := c.resolveRouteTableName(table)
// check if the table already exists
for _, table := range wgReply.Data {
if table.GetString("name") == wantedTableName {
return wantedTableName, nil // already exists, nothing to do
}
}
// create the table if it does not exist
createReply := c.client.Create(ctx, "/routing/table", lowlevel.GenericJsonObject{
"name": wantedTableName,
"comment": fmt.Sprintf("Routing Table for %s", interfaceId),
"fib": strconv.FormatBool(true),
})
if createReply.Status != lowlevel.MikrotikApiStatusOk {
return "", fmt.Errorf("failed to create routing table %s: %v", wantedTableName, createReply.Error)
}
return wantedTableName, nil
}
func (c *MikrotikController) setRoutesForFamily(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
ipV6 bool,
table string,
cidrs []domain.Cidr,
) error {
apiPath := "/ip/route"
if ipV6 {
apiPath = "/ipv6/route"
}
// retrieve current routes
wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw",
"routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder",
},
Filters: map[string]string{
"gateway": string(interfaceId),
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error)
}
// first create or update the routes
for _, cidr := range cidrs {
// check if the route already exists
exists := false
for _, route := range wgReply.Data {
existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
if err != nil {
slog.Warn("failed to parse route destination address",
"cidr", route.GetString("dst-address"), "error", err)
continue
}
if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table {
exists = true
break
}
}
if exists {
continue // route already exists, nothing to do
}
// create the route
reply := c.client.Create(ctx, apiPath, lowlevel.GenericJsonObject{
"gateway": string(interfaceId),
"dst-address": cidr.String(),
"distance": strconv.Itoa(MikrotikRouteDistance),
"disabled": strconv.FormatBool(false),
"routing-table": table,
})
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to create new route %s via %s: %v", cidr.String(), interfaceId, reply.Error)
}
}
// finally, remove the routes that are not in the new list
for _, route := range wgReply.Data {
if route.GetBool("dynamic") {
continue // dynamic routes are not managed by the controller, nothing to do
}
existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
if err != nil {
slog.Warn("failed to parse route destination address",
"cidr", route.GetString("dst-address"), "error", err)
continue
}
valid := false
for _, cidr := range cidrs {
if existingRoute.EqualPrefix(cidr) {
valid = true
break
}
}
if valid {
continue // route is still valid, nothing to do
}
// remove the route
reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id"))
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to remove outdated route %s: %v", existingRoute.String(), reply.Error)
}
}
return nil
}
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
func (c *MikrotikController) RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
interfaceId := info.Interface.Identifier
slog.Debug("removing mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps)
tableName := c.resolveRouteTableName(info.TableStr)
cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
err := c.removeRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4)
if err != nil {
return fmt.Errorf("failed to remove IPv4 routes for %s: %v", interfaceId, err)
}
err = c.removeRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6)
if err != nil {
return fmt.Errorf("failed to remove IPv6 routes for %s: %v", interfaceId, err)
}
err = c.removeRoutingTable(ctx, tableName)
if err != nil {
return fmt.Errorf("failed to remove routing table for %s: %v", interfaceId, err)
}
return nil
}
func (c *MikrotikController) removeRoutesForFamily(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
ipV6 bool,
table string,
cidrs []domain.Cidr,
) error {
apiPath := "/ip/route"
if ipV6 {
apiPath = "/ipv6/route"
}
// retrieve current routes
wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw",
"routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder",
},
Filters: map[string]string{
"gateway": string(interfaceId),
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error)
}
// remove the routes from the list
for _, route := range wgReply.Data {
if route.GetBool("dynamic") {
continue // dynamic routes are not managed by the controller, nothing to do
}
existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
if err != nil {
slog.Warn("failed to parse route destination address",
"cidr", route.GetString("dst-address"), "error", err)
continue
}
remove := false
for _, cidr := range cidrs {
if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table {
remove = true
break
}
}
if !remove {
continue // route is still valid, nothing to do
}
// remove the route
reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id"))
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to remove old route %s: %v", existingRoute.String(), reply.Error)
}
}
return nil
}
func (c *MikrotikController) removeRoutingTable(
ctx context.Context,
table string,
) error {
if table == MikrotikDefaultRoutingTable {
return nil // we cannot remove the default table
}
// retrieve current routing tables
wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "dynamic", "fib", "name",
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to query routing tables: %v", wgReply.Error)
}
for _, existingTable := range wgReply.Data {
if existingTable.GetBool("dynamic") {
continue // dynamic tables are not managed by the controller, nothing to do
}
if existingTable.GetString("name") != table {
continue // not the table we want to remove
}
// remove the table
reply := c.client.Delete(ctx, "/routing/table/"+existingTable.GetString(".id"))
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to remove routing table %s: %v", table, reply.Error)
}
return nil
}
return nil
} }
// endregion routing-related // endregion routing-related

View File

@@ -1,113 +0,0 @@
package adapters
import (
"bytes"
"fmt"
"log/slog"
"os/exec"
"strings"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/domain"
)
// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks.
type WgQuickRepo struct {
shellCmd string
resolvConfIfacePrefix string
}
// NewWgQuickRepo creates a new WgQuickRepo instance.
func NewWgQuickRepo() *WgQuickRepo {
return &WgQuickRepo{
shellCmd: "bash",
resolvConfIfacePrefix: "tun.",
}
}
// ExecuteInterfaceHook executes the given hook command.
// The hook command can contain the following placeholders:
//
// %i: the interface identifier.
func (r *WgQuickRepo) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
if hookCmd == "" {
return nil
}
slog.Debug("executing interface hook", "interface", id, "hook", hookCmd)
err := r.exec(hookCmd, id)
if err != nil {
return fmt.Errorf("failed to exec hook: %w", err)
}
return nil
}
// SetDNS sets the DNS settings for the given interface. It uses resolvconf to set the DNS settings.
func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
if dnsStr == "" && dnsSearchStr == "" {
return nil
}
dnsServers := internal.SliceString(dnsStr)
dnsSearchDomains := internal.SliceString(dnsSearchStr)
dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
for _, dnsServer := range dnsServers {
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
}
for _, searchDomain := range dnsSearchDomains {
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
}
err := r.exec(dnsCommand, id, dnsCommandInput...)
if err != nil {
return fmt.Errorf(
"failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w",
err,
)
}
return nil
}
// UnsetDNS unsets the DNS settings for the given interface. It uses resolvconf to unset the DNS settings.
func (r *WgQuickRepo) UnsetDNS(id domain.InterfaceIdentifier) error {
dnsCommand := "resolvconf -d %resPref%i -f"
err := r.exec(dnsCommand, id)
if err != nil {
return fmt.Errorf("failed to unset dns settings: %w", err)
}
return nil
}
func (r *WgQuickRepo) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
command = strings.ReplaceAll(command, "%resPref", r.resolvConfIfacePrefix)
return strings.ReplaceAll(command, "%i", string(interfaceId))
}
func (r *WgQuickRepo) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
commandWithInterfaceName := r.replaceCommandPlaceHolders(command, interfaceId)
cmd := exec.Command(r.shellCmd, "-ce", commandWithInterfaceName)
if len(stdin) > 0 {
b := &bytes.Buffer{}
for _, ln := range stdin {
if _, err := fmt.Fprint(b, ln); err != nil {
return err
}
}
cmd.Stdin = b
}
out, err := cmd.CombinedOutput() // execute and wait for output
if err != nil {
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
}
slog.Debug("executed shell command",
"command", commandWithInterfaceName,
"output", string(out))
return nil
}

View File

@@ -4,25 +4,23 @@ import (
"context" "context"
"fmt" "fmt"
"log/slog" "log/slog"
"sync"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
"github.com/h44z/wg-portal/internal/lowlevel"
) )
// region dependencies // region dependencies
type ControllerManager interface {
// GetController returns the controller for the given interface.
GetController(iface domain.Interface) domain.InterfaceController
}
type InterfaceAndPeerDatabaseRepo interface { type InterfaceAndPeerDatabaseRepo interface {
// GetAllInterfaces returns all interfaces // GetInterface returns the interface with the given identifier.
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
// GetInterfacePeers returns all peers for a given interface
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
} }
type EventBus interface { type EventBus interface {
@@ -30,6 +28,13 @@ type EventBus interface {
Subscribe(topic string, fn interface{}) error Subscribe(topic string, fn interface{}) error
} }
type RoutesController interface {
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error
}
// endregion dependencies // endregion dependencies
type routeRuleInfo struct { type routeRuleInfo struct {
@@ -46,27 +51,26 @@ type Manager struct {
cfg *config.Config cfg *config.Config
bus EventBus bus EventBus
wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient
db InterfaceAndPeerDatabaseRepo db InterfaceAndPeerDatabaseRepo
wgController ControllerManager
mux *sync.Mutex
} }
// NewRouteManager creates a new route manager instance. // NewRouteManager creates a new route manager instance.
func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) { func NewRouteManager(
wg, err := wgctrl.New() cfg *config.Config,
if err != nil { bus EventBus,
panic("failed to init wgctrl: " + err.Error()) db InterfaceAndPeerDatabaseRepo,
} wgController ControllerManager,
) (*Manager, error) {
nl := &lowlevel.NetlinkManager{}
m := &Manager{ m := &Manager{
cfg: cfg, cfg: cfg,
bus: bus, bus: bus,
db: db, db: db,
wg: wg, wgController: wgController,
nl: nl, mux: &sync.Mutex{},
} }
m.connectToMessageBus() m.connectToMessageBus()
@@ -85,419 +89,82 @@ func (m Manager) StartBackgroundJobs(_ context.Context) {
// this is a no-op for now // this is a no-op for now
} }
func (m Manager) handleRouteUpdateEvent(srcDescription string) { func (m Manager) handleRouteUpdateEvent(info domain.RoutingTableInfo) {
slog.Debug("handling route update event", "source", srcDescription) m.mux.Lock() // ensure that only one route update is processed at a time
defer m.mux.Unlock()
err := m.syncRoutes(context.Background()) slog.Debug("handling route update event", "info", info.String())
if err != nil {
slog.Error("failed to synchronize routes", if !info.ManagementEnabled() {
"source", srcDescription, return // route management disabled
"error", err)
} }
slog.Debug("routes synchronized", "source", srcDescription) err := m.syncRoutes(context.Background(), info)
if err != nil {
slog.Error("failed to synchronize routes",
"info", info.String(), "error", err)
return
}
slog.Debug("routes synchronized", "info", info.String())
} }
func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) { func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) {
m.mux.Lock() // ensure that only one route update is processed at a time
defer m.mux.Unlock()
slog.Debug("handling route remove event", "info", info.String()) slog.Debug("handling route remove event", "info", info.String())
if !info.ManagementEnabled() { if !info.ManagementEnabled() {
return // route management disabled return // route management disabled
} }
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V4); err != nil { err := m.removeRoutes(context.Background(), info)
slog.Error("failed to remove v4 fwmark rules", "error", err)
}
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V6); err != nil {
slog.Error("failed to remove v6 fwmark rules", "error", err)
}
slog.Debug("routes removed", "table", info.String())
}
func (m Manager) syncRoutes(ctx context.Context) error {
interfaces, err := m.db.GetAllInterfaces(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to find all interfaces: %w", err) slog.Error("failed to synchronize routes",
} "info", info.String(), "error", err)
rules := map[int][]routeRuleInfo{
netlink.FAMILY_V4: nil,
netlink.FAMILY_V6: nil,
}
for _, iface := range interfaces {
if iface.IsDisabled() {
continue // disabled interface does not need route entries
}
if !iface.ManageRoutingTable() {
continue
}
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return fmt.Errorf("failed to find peers for %s: %w", iface.Identifier, err)
}
allowedIPs := iface.GetAllowedIPs(peers)
defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs)
link, err := m.nl.LinkByName(string(iface.Identifier))
if err != nil {
return fmt.Errorf("failed to find physical link for %s: %w", iface.Identifier, err)
}
table, fwmark, err := m.getRoutingTableAndFwMark(&iface, link)
if err != nil {
return fmt.Errorf("failed to get table and fwmark for %s: %w", iface.Identifier, err)
}
if err := m.setInterfaceRoutes(link, table, allowedIPs); err != nil {
return fmt.Errorf("failed to set routes for %s: %w", iface.Identifier, err)
}
if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V4, allowedIPs); err != nil {
return fmt.Errorf("failed to remove deprecated v4 routes for %s: %w", iface.Identifier, err)
}
if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V6, allowedIPs); err != nil {
return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err)
}
if table != 0 {
rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], routeRuleInfo{
ifaceId: iface.Identifier,
fwMark: fwmark,
table: table,
family: netlink.FAMILY_V4,
hasDefault: defRouteV4,
})
}
if table != 0 {
rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], routeRuleInfo{
ifaceId: iface.Identifier,
fwMark: fwmark,
table: table,
family: netlink.FAMILY_V6,
hasDefault: defRouteV6,
})
}
}
return m.syncRouteRules(rules)
}
func (m Manager) syncRouteRules(allRules map[int][]routeRuleInfo) error {
for family, rules := range allRules {
// update fwmark rules
if err := m.setFwMarkRules(rules, family); err != nil {
return err
}
// update main rule
if err := m.setMainRule(rules, family); err != nil {
return err
}
// cleanup old main rules
if err := m.cleanupMainRule(rules, family); err != nil {
return err
}
}
return nil
}
func (m Manager) setFwMarkRules(rules []routeRuleInfo, family int) error {
for _, rule := range rules {
existingRules, err := m.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
ruleExists := false
for _, existingRule := range existingRules {
if rule.fwMark == existingRule.Mark && rule.table == existingRule.Table {
ruleExists = true
break
}
}
if ruleExists {
continue // rule already exists, no need to recreate it
}
// create missing rule
if err := m.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: rule.table,
Mark: rule.fwMark,
Invert: true,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Priority: m.getRulePriority(existingRules),
Mask: nil,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for fwmark %d and table %d: %w", rule.fwMark, rule.table, err)
}
}
return nil
}
func (m Manager) removeFwMarkRules(fwmark uint32, table int, family int) error {
existingRules, err := m.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
for _, existingRule := range existingRules {
if fwmark == existingRule.Mark && table == existingRule.Table {
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
if err := m.nl.RuleDel(&existingRule); err != nil {
return fmt.Errorf("failed to delete fwmark rule: %w", err)
}
}
}
return nil
}
func (m Manager) setMainRule(rules []routeRuleInfo, family int) error {
shouldHaveMainRule := false
for _, rule := range rules {
if rule.hasDefault == true {
shouldHaveMainRule = true
break
}
}
if !shouldHaveMainRule {
return nil
}
existingRules, err := m.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
ruleExists := false
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
ruleExists = true
break
}
}
if ruleExists {
return nil // rule already exists, skip re-creation
}
if err := m.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: m.getMainRulePriority(existingRules),
Mark: 0,
Mask: nil,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for main table: %w", err)
}
return nil
}
func (m Manager) cleanupMainRule(rules []routeRuleInfo, family int) error {
existingRules, err := m.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
shouldHaveMainRule := false
for _, rule := range rules {
if rule.hasDefault == true {
shouldHaveMainRule = true
break
}
}
mainRules := 0
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
mainRules++
}
}
removalCount := 0
if mainRules > 1 {
removalCount = mainRules - 1 // we only want one single rule
}
if !shouldHaveMainRule {
removalCount = mainRules
}
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
if removalCount > 0 {
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
if err := m.nl.RuleDel(&existingRule); err != nil {
return fmt.Errorf("failed to delete main rule: %w", err)
}
removalCount--
}
}
}
return nil
}
func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
prio := m.cfg.Advanced.RulePrioOffset
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == prio {
isFresh = false
break
}
}
if isFresh {
break
} else {
prio++
}
}
return prio
}
func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
prio := 32700 // linux main rule has a prio of 32766
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == prio {
isFresh = false
break
}
}
if isFresh {
break
} else {
prio--
}
}
return prio
}
func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error {
for _, allowedIP := range allowedIPs {
err := m.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: allowedIP.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
})
if err != nil {
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
}
}
return nil
}
func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIPs []domain.Cidr) error {
rawRoutes, err := m.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
if err != nil {
return fmt.Errorf("failed to fetch raw routes: %w", err)
}
for _, rawRoute := range rawRoutes {
if rawRoute.Dst == nil { // handle default route
var netlinkAddr domain.Cidr
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
rawRoute.Dst = netlinkAddr.IpNet()
}
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
remove := true
for _, allowedIP := range allowedIPs {
if netlinkAddr == allowedIP {
remove = false
break
}
}
if !remove {
continue
}
err := m.nl.RouteDel(&rawRoute)
if err != nil {
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
}
}
return nil
}
func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, link netlink.Link) (
table int,
fwmark uint32,
err error,
) {
table = iface.GetRoutingTable()
fwmark = iface.FirewallMark
if fwmark == 0 {
// generate a new (temporary) firewall mark based on the interface index
fwmark = uint32(m.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
slog.Debug("using fwmark to handle routes",
"interface", iface.Identifier,
"fwmark", fwmark)
// apply the temporary fwmark to the wireguard interface
err = m.setFwMark(iface.Identifier, int(fwmark))
}
if table == 0 {
table = int(fwmark) // generate a new routing table base on interface index
slog.Debug("using routing table to handle default routes",
"interface", iface.Identifier,
"table", table)
}
return return
} }
func (m Manager) setFwMark(id domain.InterfaceIdentifier, fwmark int) error { slog.Debug("routes removed", "info", info.String())
err := m.wg.ConfigureDevice(string(id), wgtypes.Config{ }
FirewallMark: &fwmark,
}) func (m Manager) syncRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
rc, ok := m.wgController.GetController(info.Interface).(RoutesController)
if !ok {
slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier)
return nil
}
if !info.Interface.ManageRoutingTable() {
slog.Debug("interface does not manage routing table, skipping route update",
"interface", info.Interface.Identifier)
return nil
}
err := rc.SetRoutes(ctx, info)
if err != nil { if err != nil {
return fmt.Errorf("failed to update fwmark to: %d: %w", fwmark, err) return fmt.Errorf("failed to set routes for interface %s: %w", info.Interface.Identifier, err)
} }
return nil return nil
} }
func (m Manager) containsDefaultRoute(allowedIPs []domain.Cidr) (ipV4, ipV6 bool) { func (m Manager) removeRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
for _, allowedIP := range allowedIPs { rc, ok := m.wgController.GetController(info.Interface).(RoutesController)
if ipV4 && ipV6 { if !ok {
break // speed up slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier)
return nil
} }
if allowedIP.Prefix().Bits() == 0 { if !info.Interface.ManageRoutingTable() {
if allowedIP.IsV4() { slog.Debug("interface does not manage routing table, skipping route removal",
ipV4 = true "interface", info.Interface.Identifier)
} else { return nil
ipV6 = true
}
}
} }
return err := rc.RemoveRoutes(ctx, info)
if err != nil {
return fmt.Errorf("failed to remove routes for interface %s: %w", info.Interface.Identifier, err)
}
return nil
} }

View File

@@ -1,7 +1,6 @@
package wireguard package wireguard
import ( import (
"context"
"fmt" "fmt"
"log/slog" "log/slog"
"maps" "maps"
@@ -12,33 +11,9 @@ import (
"github.com/h44z/wg-portal/internal/domain" "github.com/h44z/wg-portal/internal/domain"
) )
type InterfaceController interface {
GetId() domain.InterfaceBackend
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
PingAddresses(
ctx context.Context,
addr string,
) (*domain.PingerResult, error)
}
type backendInstance struct { type backendInstance struct {
Config config.BackendBase // Config is the configuration for the backend instance. Config config.BackendBase // Config is the configuration for the backend instance.
Implementation InterfaceController Implementation domain.InterfaceController
} }
type ControllerManager struct { type ControllerManager struct {
@@ -118,11 +93,11 @@ func (c *ControllerManager) logRegisteredControllers() {
} }
} }
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController { func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) domain.InterfaceController {
return c.getController(backend, "").Implementation return c.getController(backend, "").Implementation
} }
func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController { func (c *ControllerManager) GetController(iface domain.Interface) domain.InterfaceController {
return c.getController(iface.Backend, iface.Identifier).Implementation return c.getController(iface.Backend, iface.Identifier).Implementation
} }

View File

@@ -38,9 +38,9 @@ type InterfaceAndPeerDatabaseRepo interface {
} }
type WgQuickController interface { type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error ExecuteInterfaceHook(ctx context.Context, id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error SetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(id domain.InterfaceIdentifier) error UnsetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
} }
type EventBus interface { type EventBus interface {
@@ -57,7 +57,6 @@ type Manager struct {
bus EventBus bus EventBus
db InterfaceAndPeerDatabaseRepo db InterfaceAndPeerDatabaseRepo
wg *ControllerManager wg *ControllerManager
quick WgQuickController
userLockMap *sync.Map userLockMap *sync.Map
} }
@@ -66,7 +65,6 @@ func NewWireGuardManager(
cfg *config.Config, cfg *config.Config,
bus EventBus, bus EventBus,
wg *ControllerManager, wg *ControllerManager,
quick WgQuickController,
db InterfaceAndPeerDatabaseRepo, db InterfaceAndPeerDatabaseRepo,
) (*Manager, error) { ) (*Manager, error) {
m := &Manager{ m := &Manager{
@@ -74,7 +72,6 @@ func NewWireGuardManager(
bus: bus, bus: bus,
wg: wg, wg: wg,
db: db, db: db,
quick: quick,
userLockMap: &sync.Map{}, userLockMap: &sync.Map{},
} }

View File

@@ -453,7 +453,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return err return err
} }
existingInterface, err := m.db.GetInterface(ctx, id) existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil { if err != nil {
return fmt.Errorf("unable to find interface %s: %w", id, err) return fmt.Errorf("unable to find interface %s: %w", id, err)
} }
@@ -462,21 +462,29 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion not allowed: %w", err) return fmt.Errorf("deletion not allowed: %w", err)
} }
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
Interface: *existingInterface,
AllowedIps: existingInterface.GetAllowedIPs(existingPeers),
FwMark: existingInterface.FirewallMark,
Table: existingInterface.GetRoutingTable(),
TableStr: existingInterface.RoutingTable,
IsDeleted: true,
})
now := time.Now() now := time.Now()
existingInterface.Disabled = &now // simulate a disabled interface existingInterface.Disabled = &now // simulate a disabled interface
existingInterface.DisabledReason = domain.DisabledReasonDeleted existingInterface.DisabledReason = domain.DisabledReasonDeleted
physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id) if err := m.handleInterfacePreSaveHooks(ctx, existingInterface, !existingInterface.IsDisabled(),
false); err != nil {
if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
return fmt.Errorf("pre-delete hooks failed: %w", err) return fmt.Errorf("pre-delete hooks failed: %w", err)
} }
if err := m.handleInterfacePreSaveActions(existingInterface); err != nil { if err := m.handleInterfacePreSaveActions(ctx, existingInterface); err != nil {
return fmt.Errorf("pre-delete actions failed: %w", err) return fmt.Errorf("pre-delete actions failed: %w", err)
} }
if err := m.deleteInterfacePeers(ctx, id); err != nil { if err := m.deleteInterfacePeers(ctx, existingInterface, existingPeers); err != nil {
return fmt.Errorf("peer deletion failure: %w", err) return fmt.Errorf("peer deletion failure: %w", err)
} }
@@ -488,16 +496,12 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion failure: %w", err) return fmt.Errorf("deletion failure: %w", err)
} }
fwMark := existingInterface.FirewallMark if err := m.handleInterfacePostSaveHooks(
if physicalInterface != nil && fwMark == 0 { ctx,
fwMark = physicalInterface.FirewallMark existingInterface,
} !existingInterface.IsDisabled(),
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ false,
FwMark: fwMark, ); err != nil {
Table: existingInterface.GetRoutingTable(),
})
if err := m.handleInterfacePostSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
return fmt.Errorf("post-delete hooks failed: %w", err) return fmt.Errorf("post-delete hooks failed: %w", err)
} }
@@ -516,17 +520,21 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return nil, fmt.Errorf("interface validation failed: %w", err) return nil, fmt.Errorf("interface validation failed: %w", err)
} }
oldEnabled, newEnabled := m.getInterfaceStateHistory(ctx, iface) oldEnabled, newEnabled, routeTableChanged := false, !iface.IsDisabled(), false // if the interface did not exist, we assume it was not enabled
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err == nil {
oldEnabled, newEnabled, routeTableChanged = m.getInterfaceStateHistory(oldInterface, iface)
}
if err := m.handleInterfacePreSaveHooks(iface, oldEnabled, newEnabled); err != nil { if err := m.handleInterfacePreSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
return nil, fmt.Errorf("pre-save hooks failed: %w", err) return nil, fmt.Errorf("pre-save hooks failed: %w", err)
} }
if err := m.handleInterfacePreSaveActions(iface); err != nil { if err := m.handleInterfacePreSaveActions(ctx, iface); err != nil {
return nil, fmt.Errorf("pre-save actions failed: %w", err) return nil, fmt.Errorf("pre-save actions failed: %w", err)
} }
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) { err = m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
iface.CopyCalculatedAttributes(i) iface.CopyCalculatedAttributes(i)
err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier, err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
@@ -569,20 +577,35 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
} }
if iface.IsDisabled() { if iface.IsDisabled() {
physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier)
fwMark := iface.FirewallMark
if physicalInterface != nil && fwMark == 0 {
fwMark = physicalInterface.FirewallMark
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: fwMark, Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(), Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
}) })
} else { } else {
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier)) m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
// if the route table changed, ensure that the old entries are remove
if routeTableChanged {
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
Interface: *oldInterface,
AllowedIps: oldInterface.GetAllowedIPs(peers),
FwMark: oldInterface.FirewallMark,
Table: oldInterface.GetRoutingTable(),
TableStr: oldInterface.RoutingTable,
IsDeleted: true, // mark the old entries as deleted
})
}
} }
if err := m.handleInterfacePostSaveHooks(iface, oldEnabled, newEnabled); err != nil { if err := m.handleInterfacePostSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
return nil, fmt.Errorf("post-save hooks failed: %w", err) return nil, fmt.Errorf("post-save hooks failed: %w", err)
} }
@@ -618,60 +641,90 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return iface, nil return iface, nil
} }
func (m Manager) getInterfaceStateHistory(ctx context.Context, iface *domain.Interface) (oldEnabled, newEnabled bool) { func (m Manager) getInterfaceStateHistory(
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier) oldInterface *domain.Interface,
if err != nil { iface *domain.Interface,
return false, !iface.IsDisabled() // if the interface did not exist, we assume it was not enabled ) (oldEnabled, newEnabled, routeTableChanged bool) {
return !oldInterface.IsDisabled(), !iface.IsDisabled(), oldInterface.RoutingTable != iface.RoutingTable
} }
return !oldInterface.IsDisabled(), !iface.IsDisabled() func (m Manager) handleInterfacePreSaveActions(ctx context.Context, iface *domain.Interface) error {
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to perform pre-save actions", "interface", iface.Identifier,
"error", "no capable controller found")
return nil
} }
func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error { // update DNS settings only for client interfaces
if iface.Type == domain.InterfaceTypeClient || iface.Type == domain.InterfaceTypeAny {
if !iface.IsDisabled() { if !iface.IsDisabled() {
if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil { if err := wgQuickController.SetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err) return fmt.Errorf("failed to update dns settings: %w", err)
} }
} else { } else {
if err := m.quick.UnsetDNS(iface.Identifier); err != nil { if err := wgQuickController.UnsetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to clear dns settings: %w", err) return fmt.Errorf("failed to clear dns settings: %w", err)
} }
} }
}
return nil return nil
} }
func (m Manager) handleInterfacePreSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error { func (m Manager) handleInterfacePreSaveHooks(
ctx context.Context,
iface *domain.Interface,
oldEnabled, newEnabled bool,
) error {
if oldEnabled == newEnabled { if oldEnabled == newEnabled {
return nil // do nothing if state did not change return nil // do nothing if state did not change
} }
slog.Debug("executing pre-save hooks", "interface", iface.Identifier, "up", newEnabled) slog.Debug("executing pre-save hooks", "interface", iface.Identifier, "up", newEnabled)
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to execute pre-save hooks", "interface", iface.Identifier, "up", newEnabled,
"error", "no capable controller found")
return nil
}
if newEnabled { if newEnabled {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil { if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreUp); err != nil {
return fmt.Errorf("failed to execute pre-up hook: %w", err) return fmt.Errorf("failed to execute pre-up hook: %w", err)
} }
} else { } else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil { if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreDown); err != nil {
return fmt.Errorf("failed to execute pre-down hook: %w", err) return fmt.Errorf("failed to execute pre-down hook: %w", err)
} }
} }
return nil return nil
} }
func (m Manager) handleInterfacePostSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error { func (m Manager) handleInterfacePostSaveHooks(
ctx context.Context,
iface *domain.Interface,
oldEnabled, newEnabled bool,
) error {
if oldEnabled == newEnabled { if oldEnabled == newEnabled {
return nil // do nothing if state did not change return nil // do nothing if state did not change
} }
slog.Debug("executing post-save hooks", "interface", iface.Identifier, "up", newEnabled) slog.Debug("executing post-save hooks", "interface", iface.Identifier, "up", newEnabled)
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to execute post-save hooks", "interface", iface.Identifier, "up", newEnabled,
"error", "no capable controller found")
return nil
}
if newEnabled { if newEnabled {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil { if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostUp); err != nil {
return fmt.Errorf("failed to execute post-up hook: %w", err) return fmt.Errorf("failed to execute post-up hook: %w", err)
} }
} else { } else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil { if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostDown); err != nil {
return fmt.Errorf("failed to execute post-down hook: %w", err) return fmt.Errorf("failed to execute post-down hook: %w", err)
} }
} }
@@ -799,7 +852,7 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
func (m Manager) importInterface( func (m Manager) importInterface(
ctx context.Context, ctx context.Context,
backend InterfaceController, backend domain.InterfaceController,
in *domain.PhysicalInterface, in *domain.PhysicalInterface,
peers []domain.PhysicalPeer, peers []domain.PhysicalPeer,
) error { ) error {
@@ -901,13 +954,9 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
return nil return nil
} }
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error { func (m Manager) deleteInterfacePeers(ctx context.Context, iface *domain.Interface, allPeers []domain.Peer) error {
iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil {
return err
}
for _, peer := range allPeers { for _, peer := range allPeers {
err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier) err := m.wg.GetController(*iface).DeletePeer(ctx, iface.Identifier, peer.Identifier)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err) return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
} }

View File

@@ -388,9 +388,20 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return fmt.Errorf("failed to delete peer %s: %w", id, err) return fmt.Errorf("failed to delete peer %s: %w", id, err)
} }
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
m.bus.Publish(app.TopicPeerDeleted, *peer) m.bus.Publish(app.TopicPeerDeleted, *peer)
// Update routes after peers have changed // Update routes after peers have changed
m.bus.Publish(app.TopicRouteUpdate, "peers updated") m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
// Update interface after peers have changed // Update interface after peers have changed
m.bus.Publish(app.TopicPeerInterfaceUpdated, peer.InterfaceIdentifier) m.bus.Publish(app.TopicPeerInterfaceUpdated, peer.InterfaceIdentifier)
@@ -438,20 +449,26 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
// region helper-functions // region helper-functions
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
interfaces := make(map[domain.InterfaceIdentifier]struct{}) interfaces := make(map[domain.InterfaceIdentifier]domain.Interface)
for _, peer := range peers { for _, peer := range peers {
// get interface from db if it is not yet in the map
if _, ok := interfaces[peer.InterfaceIdentifier]; !ok {
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier) iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
if err != nil { if err != nil {
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err) return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
} }
interfaces[peer.InterfaceIdentifier] = *iface
}
iface := interfaces[peer.InterfaceIdentifier]
// Always save the peer to the backend, regardless of disabled/expired state // Always save the peer to the backend, regardless of disabled/expired state
// The backend will handle the disabled state appropriately // The backend will handle the disabled state appropriately
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) { err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p) peer.CopyCalculatedAttributes(p)
err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier, err := m.wg.GetController(iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) { func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer) domain.MergeToPhysicalPeer(pp, peer)
return pp, nil return pp, nil
@@ -475,13 +492,22 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
Peer: *peer, Peer: *peer,
}, },
}) })
interfaces[peer.InterfaceIdentifier] = struct{}{}
} }
// Update routes after peers have changed // Update routes after peers have changed
if len(interfaces) != 0 { for id, iface := range interfaces {
m.bus.Publish(app.TopicRouteUpdate, "peers updated") interfacePeers, err := m.db.GetInterfacePeers(ctx, id)
if err != nil {
return fmt.Errorf("failed to re-load peers for interface %s: %w", id, err)
}
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: iface,
AllowedIps: iface.GetAllowedIPs(interfacePeers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
} }
for iface := range interfaces { for iface := range interfaces {

View File

@@ -13,6 +13,7 @@ type Backend struct {
// Local Backend-specific configuration // Local Backend-specific configuration
IgnoredLocalInterfaces []string `yaml:"ignored_local_interfaces"` // A list of interface names that should be ignored by this backend (e.g., "wg0") IgnoredLocalInterfaces []string `yaml:"ignored_local_interfaces"` // A list of interface names that should be ignored by this backend (e.g., "wg0")
LocalResolvconfPrefix string `yaml:"local_resolvconf_prefix"` // The prefix to use for interface names when passing them to resolvconf.
// External Backend-specific configuration // External Backend-specific configuration

View File

@@ -134,6 +134,9 @@ func defaultConfig() *Config {
cfg.Backend = Backend{ cfg.Backend = Backend{
Default: LocalBackendName, // local backend is the default (using wgcrtl) Default: LocalBackendName, // local backend is the default (using wgcrtl)
// Most resolconf implementations use "tun." as a prefix for interface names.
// But systemd's implementation uses no prefix, for example.
LocalResolvconfPrefix: "tun.",
} }
cfg.Web = WebConfig{ cfg.Web = WebConfig{

View File

@@ -13,6 +13,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/h44z/wg-portal/internal" "github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/config"
) )
const ( const (
@@ -132,9 +133,14 @@ func (i *Interface) GetConfigFileName() string {
return filename return filename
} }
// GetAllowedIPs returns the allowed IPs for the interface depending on the interface type and peers.
// For example, if the interface type is Server, the allowed IPs are the IPs of the peers.
// If the interface type is Client, the allowed IPs correspond to the AllowedIPsStr of the peers.
func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr { func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
var allowedCidrs []Cidr var allowedCidrs []Cidr
switch i.Type {
case InterfaceTypeServer, InterfaceTypeAny:
for _, peer := range peers { for _, peer := range peers {
for _, ip := range peer.Interface.Addresses { for _, ip := range peer.Interface.Addresses {
allowedCidrs = append(allowedCidrs, ip.HostAddr()) allowedCidrs = append(allowedCidrs, ip.HostAddr())
@@ -146,6 +152,14 @@ func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
} }
} }
} }
case InterfaceTypeClient:
for _, peer := range peers {
allowedIPs, err := CidrsFromString(peer.AllowedIPsStr.GetValue())
if err == nil {
allowedCidrs = append(allowedCidrs, allowedIPs...)
}
}
}
return allowedCidrs return allowedCidrs
} }
@@ -159,6 +173,7 @@ func (i *Interface) ManageRoutingTable() bool {
// //
// -1 if RoutingTable was set to "off" or an error occurred // -1 if RoutingTable was set to "off" or an error occurred
func (i *Interface) GetRoutingTable() int { func (i *Interface) GetRoutingTable() int {
routingTableStr := strings.ToLower(i.RoutingTable) routingTableStr := strings.ToLower(i.RoutingTable)
switch { switch {
case routingTableStr == "": case routingTableStr == "":
@@ -166,6 +181,9 @@ func (i *Interface) GetRoutingTable() int {
case routingTableStr == "off": case routingTableStr == "off":
return -1 return -1
case strings.HasPrefix(routingTableStr, "0x"): case strings.HasPrefix(routingTableStr, "0x"):
if i.Backend != config.LocalBackendName {
return 0 // ignore numeric routing table numbers for non-local controllers
}
numberStr := strings.ReplaceAll(routingTableStr, "0x", "") numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
routingTable, err := strconv.ParseUint(numberStr, 16, 64) routingTable, err := strconv.ParseUint(numberStr, 16, 64)
if err != nil { if err != nil {
@@ -178,6 +196,9 @@ func (i *Interface) GetRoutingTable() int {
} }
return int(routingTable) return int(routingTable)
default: default:
if i.Backend != config.LocalBackendName {
return 0 // ignore numeric routing table numbers for non-local controllers
}
routingTable, err := strconv.Atoi(routingTableStr) routingTable, err := strconv.Atoi(routingTableStr)
if err != nil { if err != nil {
slog.Error("failed to parse routing table number", "table", routingTableStr, "error", err) slog.Error("failed to parse routing table number", "table", routingTableStr, "error", err)
@@ -308,12 +329,18 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) {
} }
type RoutingTableInfo struct { type RoutingTableInfo struct {
Interface Interface
AllowedIps []Cidr
FwMark uint32 FwMark uint32
Table int Table int
TableStr string // the routing table number as string (used by mikrotik, linux uses the numeric value)
IsDeleted bool // true if the interface was deleted, false otherwise
} }
func (r RoutingTableInfo) String() string { func (r RoutingTableInfo) String() string {
return fmt.Sprintf("%d -> %d", r.FwMark, r.Table) v4, v6 := CidrsPerFamily(r.AllowedIps)
return fmt.Sprintf("%s: fwmark=%d; table=%d; routes_4=%d; routes_6=%d", r.Interface.Identifier, r.FwMark, r.Table,
len(v4), len(v6))
} }
func (r RoutingTableInfo) ManagementEnabled() bool { func (r RoutingTableInfo) ManagementEnabled() bool {

View File

@@ -0,0 +1,27 @@
package domain
import "context"
type InterfaceController interface {
GetId() InterfaceBackend
GetInterfaces(_ context.Context) ([]PhysicalInterface, error)
GetInterface(_ context.Context, id InterfaceIdentifier) (*PhysicalInterface, error)
GetPeers(_ context.Context, deviceId InterfaceIdentifier) ([]PhysicalPeer, error)
SaveInterface(
_ context.Context,
id InterfaceIdentifier,
updateFunc func(pi *PhysicalInterface) (*PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId InterfaceIdentifier,
id PeerIdentifier,
updateFunc func(pp *PhysicalPeer) (*PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId InterfaceIdentifier, id PeerIdentifier) error
PingAddresses(
ctx context.Context,
addr string,
) (*PingerResult, error)
}

View File

@@ -5,6 +5,8 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/h44z/wg-portal/internal/config"
) )
func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) { func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) {
@@ -37,8 +39,9 @@ func TestInterface_GetConfigFileNameReturnsCorrectFileName(t *testing.T) {
assert.Equal(t, expected, iface.GetConfigFileName()) assert.Equal(t, expected, iface.GetConfigFileName())
} }
func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) { func TestInterface_GetAllowedIPsReturnsCorrectCidrsServerMode(t *testing.T) {
peer1 := Peer{ peer1 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"},
Interface: PeerInterfaceConfig{ Interface: PeerInterfaceConfig{
Addresses: []Cidr{ Addresses: []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32}, {Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
@@ -46,16 +49,45 @@ func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
}, },
} }
peer2 := Peer{ peer2 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"},
ExtraAllowedIPsStr: "10.20.2.2/32",
Interface: PeerInterfaceConfig{ Interface: PeerInterfaceConfig{
Addresses: []Cidr{ Addresses: []Cidr{
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32}, {Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
}, },
}, },
} }
iface := &Interface{} iface := &Interface{Type: InterfaceTypeServer}
expected := []Cidr{ expected := []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32}, {Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32}, {Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
{Cidr: "10.20.2.2/32", Addr: "10.20.2.2", NetLength: 32},
}
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
}
func TestInterface_GetAllowedIPsReturnsCorrectCidrsClientMode(t *testing.T) {
peer1 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"},
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
},
},
}
peer2 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"},
ExtraAllowedIPsStr: "10.20.2.2/32",
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
},
},
}
iface := &Interface{Type: InterfaceTypeClient}
expected := []Cidr{
{Cidr: "192.168.2.2/32", Addr: "192.168.2.2", NetLength: 32},
{Cidr: "10.0.2.2/32", Addr: "10.0.2.2", NetLength: 32},
} }
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2})) assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
} }
@@ -66,10 +98,22 @@ func TestInterface_ManageRoutingTableReturnsCorrectValue(t *testing.T) {
iface.RoutingTable = "100" iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable()) assert.True(t, iface.ManageRoutingTable())
iface = &Interface{RoutingTable: "off", Backend: config.LocalBackendName}
assert.False(t, iface.ManageRoutingTable())
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
iface = &Interface{RoutingTable: "off", Backend: "mikrotik-xxx"}
assert.False(t, iface.ManageRoutingTable())
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
} }
func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) { func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
iface := &Interface{RoutingTable: ""} iface := &Interface{RoutingTable: "", Backend: config.LocalBackendName}
assert.Equal(t, 0, iface.GetRoutingTable()) assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "off" iface.RoutingTable = "off"
@@ -81,3 +125,17 @@ func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
iface.RoutingTable = "200" iface.RoutingTable = "200"
assert.Equal(t, 200, iface.GetRoutingTable()) assert.Equal(t, 200, iface.GetRoutingTable())
} }
func TestInterface_GetRoutingTableNonLocal(t *testing.T) {
iface := &Interface{RoutingTable: "off", Backend: "something different"}
assert.Equal(t, -1, iface.GetRoutingTable())
iface.RoutingTable = "0"
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "100"
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "abc"
assert.Equal(t, 0, iface.GetRoutingTable())
}

View File

@@ -26,6 +26,10 @@ func (c Cidr) IsValid() bool {
return c.Prefix().IsValid() return c.Prefix().IsValid()
} }
func (c Cidr) EqualPrefix(other Cidr) bool {
return c.Addr == other.Addr && c.NetLength == other.NetLength
}
func CidrFromString(str string) (Cidr, error) { func CidrFromString(str string) (Cidr, error) {
prefix, err := netip.ParsePrefix(strings.TrimSpace(str)) prefix, err := netip.ParsePrefix(strings.TrimSpace(str))
if err != nil { if err != nil {
@@ -199,3 +203,26 @@ func (c Cidr) Contains(other Cidr) bool {
return subnet.Contains(otherIP) return subnet.Contains(otherIP)
} }
// ContainsDefaultRoute returns true if the given CIDRs contain a default route.
func ContainsDefaultRoute(cidrs []Cidr) bool {
for _, allowedIP := range cidrs {
if allowedIP.Prefix().Bits() == 0 {
return true
}
}
return false
}
// CidrsPerFamily returns a slice of CIDRs, one for each family (IPv4 and IPv6).
func CidrsPerFamily(cidrs []Cidr) (ipv4, ipv6 []Cidr) {
for _, cidr := range cidrs {
if cidr.IsV4() {
ipv4 = append(ipv4, cidr)
} else {
ipv6 = append(ipv6, cidr)
}
}
return
}

View File

@@ -267,6 +267,7 @@ func parseHttpResponse[T any](resp *http.Response, err error) MikrotikApiRespons
} }
defer func(Body io.ReadCloser) { defer func(Body io.ReadCloser) {
_, _ = io.Copy(io.Discard, Body) // ensure to empty the body
err := Body.Close() err := Body.Close()
if err != nil { if err != nil {
slog.Error("failed to close response body", "error", err) slog.Error("failed to close response body", "error", err)