mirror of
https://github.com/h44z/wg-portal.git
synced 2025-10-14 03:56:17 +00:00
Cleanup route handling (#542)
* mikrotik: allow to set DNS, wip: handle routes in wg-controller * replace old route handling for local controller * cleanup route handling for local backend * implement route handling for mikrotik controller
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -84,8 +85,8 @@ func NewLocalController(cfg *config.Config) (*LocalController, error) {
|
||||
wg: wg,
|
||||
nl: nl,
|
||||
|
||||
shellCmd: "bash", // we only support bash at the moment
|
||||
resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf
|
||||
shellCmd: "bash", // we only support bash at the moment
|
||||
resolvConfIfacePrefix: cfg.Backend.LocalResolvconfPrefix, // WireGuard interfaces have a tun. prefix in resolvconf
|
||||
}
|
||||
|
||||
return repo, nil
|
||||
@@ -546,7 +547,11 @@ func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id doma
|
||||
|
||||
// region wg-quick-related
|
||||
|
||||
func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
|
||||
func (c LocalController) ExecuteInterfaceHook(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
hookCmd string,
|
||||
) error {
|
||||
if hookCmd == "" {
|
||||
return nil
|
||||
}
|
||||
@@ -560,7 +565,7 @@ func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hoo
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||
func (c LocalController) SetDNS(_ context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||
if dnsStr == "" && dnsSearchStr == "" {
|
||||
return nil
|
||||
}
|
||||
@@ -589,7 +594,7 @@ func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearch
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error {
|
||||
func (c LocalController) UnsetDNS(_ context.Context, id domain.InterfaceIdentifier, _, _ string) error {
|
||||
dnsCommand := "resolvconf -d %resPref%i -f"
|
||||
|
||||
err := c.exec(dnsCommand, id)
|
||||
@@ -611,7 +616,7 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
|
||||
if len(stdin) > 0 {
|
||||
b := &bytes.Buffer{}
|
||||
for _, ln := range stdin {
|
||||
if _, err := fmt.Fprint(b, ln); err != nil {
|
||||
if _, err := fmt.Fprint(b, ln+"\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -619,6 +624,8 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
|
||||
}
|
||||
out, err := cmd.CombinedOutput() // execute and wait for output
|
||||
if err != nil {
|
||||
slog.Warn("failed to executed shell command",
|
||||
"command", commandWithInterfaceName, "stdin", stdin, "output", string(out), "error", err)
|
||||
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
|
||||
}
|
||||
slog.Debug("executed shell command",
|
||||
@@ -631,49 +638,116 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
|
||||
|
||||
// region routing-related
|
||||
|
||||
func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// update fwmark rules
|
||||
if err := c.setFwMarkRules(rules); err != nil {
|
||||
return err
|
||||
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
|
||||
func (c LocalController) SetRoutes(_ context.Context, info domain.RoutingTableInfo) error {
|
||||
interfaceId := info.Interface.Identifier
|
||||
slog.Debug("setting linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark,
|
||||
"cidrs", info.AllowedIps)
|
||||
|
||||
link, err := c.nl.LinkByName(string(interfaceId))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err)
|
||||
}
|
||||
|
||||
// update main rule
|
||||
if err := c.setMainRule(rules); err != nil {
|
||||
return err
|
||||
cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
|
||||
realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, info.Table, info.FwMark)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
|
||||
}
|
||||
wgDev, err := c.wg.Device(string(interfaceId))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get wg device for %s: %w", interfaceId, err)
|
||||
}
|
||||
currentFwMark := wgDev.FirewallMark
|
||||
if int(realFwMark) != currentFwMark {
|
||||
slog.Debug("updating fwmark for interface", "interface", interfaceId, "oldFwMark", currentFwMark,
|
||||
"newFwMark", realFwMark, "oldTable", info.Table, "newTable", realTable)
|
||||
if err := c.updateFwMarkOnInterface(interfaceId, int(realFwMark)); err != nil {
|
||||
return fmt.Errorf("failed to update fwmark for interface %s to %d: %w", interfaceId, realFwMark, err)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup old main rules
|
||||
if err := c.cleanupMainRule(rules); err != nil {
|
||||
return err
|
||||
if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4); err != nil {
|
||||
return fmt.Errorf("failed to set v4 routes: %w", err)
|
||||
}
|
||||
if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6); err != nil {
|
||||
return fmt.Errorf("failed to set v6 routes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error {
|
||||
for _, rule := range rules {
|
||||
existingRules, err := c.nl.RuleList(int(rule.IpFamily))
|
||||
func (c LocalController) setRoutesForFamily(
|
||||
interfaceId domain.InterfaceIdentifier,
|
||||
link netlink.Link,
|
||||
family int,
|
||||
table int,
|
||||
fwMark uint32,
|
||||
cidrs []domain.Cidr,
|
||||
) error {
|
||||
// first create or update the routes
|
||||
for _, cidr := range cidrs {
|
||||
err := c.nl.RouteReplace(&netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: cidr.IpNet(),
|
||||
Table: table,
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err)
|
||||
return fmt.Errorf("failed to add/update route %s on table %d for interface %s: %w",
|
||||
cidr.String(), table, interfaceId, err)
|
||||
}
|
||||
}
|
||||
|
||||
ruleExists := false
|
||||
for _, existingRule := range existingRules {
|
||||
if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table {
|
||||
ruleExists = true
|
||||
break
|
||||
// next remove old routes
|
||||
rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Table: unix.RT_TABLE_UNSPEC, // all tables
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
|
||||
interfaceId, family, err)
|
||||
}
|
||||
for _, rawRoute := range rawRoutes {
|
||||
if rawRoute.Dst == nil { // handle default route
|
||||
var netlinkAddr domain.Cidr
|
||||
if family == netlink.FAMILY_V4 {
|
||||
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
|
||||
} else {
|
||||
netlinkAddr, _ = domain.CidrFromString("::/0")
|
||||
}
|
||||
rawRoute.Dst = netlinkAddr.IpNet()
|
||||
}
|
||||
|
||||
if ruleExists {
|
||||
continue // rule already exists, no need to recreate it
|
||||
route := domain.CidrFromIpNet(*rawRoute.Dst)
|
||||
if slices.Contains(cidrs, route) {
|
||||
continue
|
||||
}
|
||||
|
||||
// create a missing rule
|
||||
if err := c.nl.RouteDel(&rawRoute); err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated route %s from interface %s: %w", route, interfaceId, err)
|
||||
}
|
||||
}
|
||||
|
||||
// next, update route rules for normal routes
|
||||
if table == 0 {
|
||||
return nil // no need to update route rules as we are using the default table
|
||||
}
|
||||
existingRules, err := c.nl.RuleList(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family-id %d: %w", family, err)
|
||||
}
|
||||
ruleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
|
||||
return rule.Mark == fwMark && rule.Table == table
|
||||
})
|
||||
if !ruleExists {
|
||||
if err := c.nl.RuleAdd(&netlink.Rule{
|
||||
Family: int(rule.IpFamily),
|
||||
Table: rule.Table,
|
||||
Mark: rule.FwMark,
|
||||
Family: family,
|
||||
Table: table,
|
||||
Mark: fwMark,
|
||||
Invert: true,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
@@ -682,15 +756,102 @@ func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error {
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w",
|
||||
rule.IpFamily, rule.FwMark, rule.Table, err)
|
||||
return fmt.Errorf("failed to setup rule for fwmark %d and table %d for family-id %d: %w",
|
||||
fwMark, table, family, err)
|
||||
}
|
||||
}
|
||||
mainRuleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
|
||||
return rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN
|
||||
})
|
||||
if !mainRuleExists && domain.ContainsDefaultRoute(cidrs) {
|
||||
err = c.nl.RuleAdd(&netlink.Rule{
|
||||
Family: family,
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: 0,
|
||||
Priority: c.getMainRulePriority(existingRules),
|
||||
Mark: 0,
|
||||
Mask: nil,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
})
|
||||
}
|
||||
|
||||
// finally, clean up extra main rules - only one rule is allowed
|
||||
existingRules, err = c.nl.RuleList(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing main rules for family-id %d: %w", family, err)
|
||||
}
|
||||
mainRuleCount := 0
|
||||
for _, rule := range existingRules {
|
||||
if rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN {
|
||||
mainRuleCount++
|
||||
}
|
||||
if mainRuleCount > 1 {
|
||||
if err := c.nl.RuleDel(&rule); err != nil {
|
||||
return fmt.Errorf("failed to remove extra main rule for family-id %d: %w", family, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) getOrCreateRoutingTableAndFwMark(
|
||||
link netlink.Link,
|
||||
tableIn int,
|
||||
fwMarkIn uint32,
|
||||
) (
|
||||
table int,
|
||||
fwmark uint32,
|
||||
err error,
|
||||
) {
|
||||
table = tableIn
|
||||
fwmark = fwMarkIn
|
||||
|
||||
if fwmark == 0 {
|
||||
// generate a new (temporary) firewall mark based on the interface index
|
||||
fwmark = uint32(c.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
|
||||
}
|
||||
if table == 0 {
|
||||
table = int(fwmark) // generate a new routing table base on interface index
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c LocalController) updateFwMarkOnInterface(interfaceId domain.InterfaceIdentifier, fwMark int) error {
|
||||
// apply the new fwmark to the wireguard interface
|
||||
err := c.wg.ConfigureDevice(string(interfaceId), wgtypes.Config{
|
||||
FirewallMark: &fwMark,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update fwmark of interface %s to: %d: %w", interfaceId, fwMark, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
|
||||
prio := c.cfg.Advanced.RulePrioOffset
|
||||
for {
|
||||
isFresh := true
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Priority == prio {
|
||||
isFresh = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFresh {
|
||||
break
|
||||
} else {
|
||||
prio++
|
||||
}
|
||||
}
|
||||
return prio
|
||||
}
|
||||
|
||||
func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
|
||||
prio := 32700 // linux main rule has a priority of 32766
|
||||
prio := 32700 // linux main rule has a prio of 32766
|
||||
for {
|
||||
isFresh := true
|
||||
for _, existingRule := range existingRules {
|
||||
@@ -708,126 +869,145 @@ func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
|
||||
return prio
|
||||
}
|
||||
|
||||
func (c LocalController) setMainRule(rules []domain.RouteRule) error {
|
||||
var family domain.IpFamily
|
||||
shouldHaveMainRule := false
|
||||
for _, rule := range rules {
|
||||
family = rule.IpFamily
|
||||
if rule.HasDefault == true {
|
||||
shouldHaveMainRule = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !shouldHaveMainRule {
|
||||
return nil
|
||||
}
|
||||
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
|
||||
func (c LocalController) RemoveRoutes(_ context.Context, info domain.RoutingTableInfo) error {
|
||||
interfaceId := info.Interface.Identifier
|
||||
slog.Debug("removing linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark,
|
||||
"cidrs", info.AllowedIps)
|
||||
|
||||
existingRules, err := c.nl.RuleList(int(family))
|
||||
wgDev, err := c.wg.Device(string(interfaceId))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
|
||||
slog.Debug("wg device already removed, route cleanup might be incomplete", "interface", interfaceId)
|
||||
wgDev = nil
|
||||
}
|
||||
|
||||
ruleExists := false
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
ruleExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ruleExists {
|
||||
return nil // rule already exists, skip re-creation
|
||||
}
|
||||
|
||||
if err := c.nl.RuleAdd(&netlink.Rule{
|
||||
Family: int(family),
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: 0,
|
||||
Priority: c.getMainRulePriority(existingRules),
|
||||
Mark: 0,
|
||||
Mask: nil,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup rule for main table: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
|
||||
priority := c.cfg.Advanced.RulePrioOffset
|
||||
for {
|
||||
isFresh := true
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Priority == priority {
|
||||
isFresh = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFresh {
|
||||
break
|
||||
} else {
|
||||
priority++
|
||||
}
|
||||
}
|
||||
return priority
|
||||
}
|
||||
|
||||
func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error {
|
||||
var family domain.IpFamily
|
||||
for _, rule := range rules {
|
||||
family = rule.IpFamily
|
||||
break
|
||||
}
|
||||
|
||||
existingRules, err := c.nl.RuleList(int(family))
|
||||
link, err := c.nl.LinkByName(string(interfaceId))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
|
||||
slog.Debug("physical link already removed, route cleanup might be incomplete", "interface", interfaceId)
|
||||
link = nil
|
||||
}
|
||||
|
||||
shouldHaveMainRule := false
|
||||
for _, rule := range rules {
|
||||
if rule.HasDefault == true {
|
||||
shouldHaveMainRule = true
|
||||
break
|
||||
fwMark := info.FwMark
|
||||
if wgDev != nil && info.FwMark == 0 {
|
||||
fwMark = uint32(wgDev.FirewallMark)
|
||||
}
|
||||
table := info.Table
|
||||
if wgDev != nil && info.Table == 0 {
|
||||
table = wgDev.FirewallMark // use the fwMark as table, this is the default behavior
|
||||
}
|
||||
linkIndex := -1
|
||||
if link != nil {
|
||||
linkIndex = link.Attrs().Index
|
||||
}
|
||||
|
||||
cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
|
||||
realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
|
||||
}
|
||||
|
||||
if linkIndex > 0 {
|
||||
err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove v4 routes: %w", err)
|
||||
}
|
||||
err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove v6 routes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
mainRules := 0
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
mainRules++
|
||||
if table > 0 {
|
||||
err = c.removeRouteRulesForTable(netlink.FAMILY_V4, realTable)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove v4 route rules for %s: %w", interfaceId, err)
|
||||
}
|
||||
}
|
||||
|
||||
removalCount := 0
|
||||
if mainRules > 1 {
|
||||
removalCount = mainRules - 1 // we only want one single rule
|
||||
}
|
||||
if !shouldHaveMainRule {
|
||||
removalCount = mainRules
|
||||
}
|
||||
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
if removalCount > 0 {
|
||||
existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field
|
||||
if err := c.nl.RuleDel(&existingRule); err != nil {
|
||||
return fmt.Errorf("failed to delete main rule: %w", err)
|
||||
}
|
||||
removalCount--
|
||||
}
|
||||
err = c.removeRouteRulesForTable(netlink.FAMILY_V6, realTable)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove v6 route rules for %s: %w", interfaceId, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
func (c LocalController) removeRoutesForFamily(
|
||||
interfaceId domain.InterfaceIdentifier,
|
||||
link netlink.Link,
|
||||
family int,
|
||||
table int,
|
||||
fwMark uint32,
|
||||
cidrs []domain.Cidr,
|
||||
) error {
|
||||
// first remove all rules
|
||||
existingRules, err := c.nl.RuleList(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
|
||||
}
|
||||
for _, existingRule := range existingRules {
|
||||
if fwMark == existingRule.Mark && table == existingRule.Table {
|
||||
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
|
||||
if err := c.nl.RuleDel(&existingRule); err != nil {
|
||||
return fmt.Errorf("failed to delete old fwmark rule: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// next remove all routes
|
||||
rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Table: unix.RT_TABLE_UNSPEC, // all tables
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
|
||||
interfaceId, family, err)
|
||||
}
|
||||
for _, rawRoute := range rawRoutes {
|
||||
if rawRoute.Dst == nil { // handle default route
|
||||
var netlinkAddr domain.Cidr
|
||||
if family == netlink.FAMILY_V4 {
|
||||
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
|
||||
} else {
|
||||
netlinkAddr, _ = domain.CidrFromString("::/0")
|
||||
}
|
||||
rawRoute.Dst = netlinkAddr.IpNet()
|
||||
}
|
||||
|
||||
if rawRoute.Table != table {
|
||||
continue // ignore routes from other tables
|
||||
}
|
||||
|
||||
route := domain.CidrFromIpNet(*rawRoute.Dst)
|
||||
if !slices.Contains(cidrs, route) {
|
||||
continue // only remove routes that were previously added
|
||||
}
|
||||
|
||||
if err := c.nl.RouteDel(&rawRoute); err != nil {
|
||||
return fmt.Errorf("failed to remove old route %s from interface %s: %w", route, interfaceId, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c LocalController) removeRouteRulesForTable(
|
||||
family int,
|
||||
table int,
|
||||
) error {
|
||||
existingRules, err := c.nl.RuleList(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing route rules for family-id %d: %w", family, err)
|
||||
}
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == table {
|
||||
err := c.nl.RuleDel(&existingRule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete old rule for table %d and family-id %d: %w", table, family, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion routing-related
|
||||
|
@@ -15,6 +15,9 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||
)
|
||||
|
||||
const MikrotikRouteDistance = 5
|
||||
const MikrotikDefaultRoutingTable = "main"
|
||||
|
||||
type MikrotikController struct {
|
||||
coreCfg *config.Config
|
||||
cfg *config.BackendMikrotik
|
||||
@@ -22,8 +25,9 @@ type MikrotikController struct {
|
||||
client *lowlevel.MikrotikApiClient
|
||||
|
||||
// Add mutexes to prevent race conditions
|
||||
interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex
|
||||
peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex
|
||||
interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex
|
||||
peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex
|
||||
coreMutex sync.Mutex // for updating the core configuration such as routing table or DNS settings
|
||||
}
|
||||
|
||||
func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) {
|
||||
@@ -40,6 +44,7 @@ func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik)
|
||||
|
||||
interfaceMutexes: sync.Map{},
|
||||
peerMutexes: sync.Map{},
|
||||
coreMutex: sync.Mutex{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -763,33 +768,404 @@ func (c *MikrotikController) DeletePeer(
|
||||
|
||||
// region wg-quick-related
|
||||
|
||||
func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
|
||||
func (c *MikrotikController) ExecuteInterfaceHook(
|
||||
_ context.Context,
|
||||
_ domain.InterfaceIdentifier,
|
||||
_ string,
|
||||
) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
slog.Error("interface hooks are not yet supported for Mikrotik backends, please open an issue on GitHub")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
func (c *MikrotikController) SetDNS(
|
||||
ctx context.Context,
|
||||
_ domain.InterfaceIdentifier,
|
||||
dnsStr, _ string,
|
||||
) error {
|
||||
// Lock the interface to prevent concurrent modifications
|
||||
c.coreMutex.Lock()
|
||||
defer c.coreMutex.Unlock()
|
||||
|
||||
// check if the server is already configured
|
||||
wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{"servers"},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error)
|
||||
}
|
||||
|
||||
var existingServers []string
|
||||
existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...)
|
||||
|
||||
newServers := strings.Split(dnsStr, ",")
|
||||
|
||||
mergedServers := slices.Clone(existingServers)
|
||||
for _, s := range newServers {
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
if !slices.Contains(mergedServers, s) {
|
||||
mergedServers = append(mergedServers, s)
|
||||
}
|
||||
}
|
||||
mergedServersStr := strings.Join(mergedServers, ",")
|
||||
|
||||
reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{
|
||||
"servers": mergedServersStr,
|
||||
})
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
func (c *MikrotikController) UnsetDNS(
|
||||
ctx context.Context,
|
||||
_ domain.InterfaceIdentifier,
|
||||
dnsStr, _ string,
|
||||
) error {
|
||||
// Lock the interface to prevent concurrent modifications
|
||||
c.coreMutex.Lock()
|
||||
defer c.coreMutex.Unlock()
|
||||
|
||||
// retrieve current DNS settings
|
||||
wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{"servers"},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error)
|
||||
}
|
||||
|
||||
var existingServers []string
|
||||
existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...)
|
||||
|
||||
oldServers := strings.Split(dnsStr, ",")
|
||||
|
||||
mergedServers := make([]string, 0, len(existingServers))
|
||||
for _, s := range existingServers {
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
if !slices.Contains(oldServers, s) {
|
||||
mergedServers = append(mergedServers, s) // only keep the servers that are not in the old list
|
||||
}
|
||||
}
|
||||
mergedServersStr := strings.Join(mergedServers, ",")
|
||||
|
||||
reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{
|
||||
"servers": mergedServersStr,
|
||||
})
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion wg-quick-related
|
||||
|
||||
// region routing-related
|
||||
|
||||
func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
|
||||
func (c *MikrotikController) SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
|
||||
interfaceId := info.Interface.Identifier
|
||||
slog.Debug("setting mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps)
|
||||
|
||||
// Mikrotik needs some time to apply the changes.
|
||||
// If we don't wait, the routes might get created multiple times as the dynamic routes are not yet available.
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
tableName, err := c.getOrCreateRoutingTables(ctx, info.Interface.Identifier, info.TableStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get or create routing table for %s: %v", interfaceId, err)
|
||||
}
|
||||
|
||||
cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
|
||||
|
||||
err = c.setRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set IPv4 routes for %s: %v", interfaceId, err)
|
||||
}
|
||||
|
||||
err = c.setRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set IPv6 routes for %s: %v", interfaceId, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
func (c *MikrotikController) resolveRouteTableName(name string) string {
|
||||
name = strings.TrimSpace(name)
|
||||
|
||||
var mikrotikTableName string
|
||||
switch strings.ToLower(name) {
|
||||
case "", "0":
|
||||
mikrotikTableName = MikrotikDefaultRoutingTable
|
||||
case MikrotikDefaultRoutingTable:
|
||||
return fmt.Sprintf("wgportal-%s",
|
||||
MikrotikDefaultRoutingTable) // if the Mikrotik Main table should be used, the table-name should be left empty or set to "0".
|
||||
default:
|
||||
mikrotikTableName = name
|
||||
}
|
||||
|
||||
return mikrotikTableName
|
||||
}
|
||||
|
||||
func (c *MikrotikController) getOrCreateRoutingTables(
|
||||
ctx context.Context,
|
||||
interfaceId domain.InterfaceIdentifier,
|
||||
table string,
|
||||
) (string, error) {
|
||||
// retrieve current routing tables
|
||||
wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "dynamic", "fib", "name",
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return "", fmt.Errorf("unable to query routing tables: %v", wgReply.Error)
|
||||
}
|
||||
|
||||
wantedTableName := c.resolveRouteTableName(table)
|
||||
|
||||
// check if the table already exists
|
||||
for _, table := range wgReply.Data {
|
||||
if table.GetString("name") == wantedTableName {
|
||||
return wantedTableName, nil // already exists, nothing to do
|
||||
}
|
||||
}
|
||||
|
||||
// create the table if it does not exist
|
||||
createReply := c.client.Create(ctx, "/routing/table", lowlevel.GenericJsonObject{
|
||||
"name": wantedTableName,
|
||||
"comment": fmt.Sprintf("Routing Table for %s", interfaceId),
|
||||
"fib": strconv.FormatBool(true),
|
||||
})
|
||||
if createReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return "", fmt.Errorf("failed to create routing table %s: %v", wantedTableName, createReply.Error)
|
||||
}
|
||||
|
||||
return wantedTableName, nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) setRoutesForFamily(
|
||||
ctx context.Context,
|
||||
interfaceId domain.InterfaceIdentifier,
|
||||
ipV6 bool,
|
||||
table string,
|
||||
cidrs []domain.Cidr,
|
||||
) error {
|
||||
apiPath := "/ip/route"
|
||||
if ipV6 {
|
||||
apiPath = "/ipv6/route"
|
||||
}
|
||||
|
||||
// retrieve current routes
|
||||
wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw",
|
||||
"routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"gateway": string(interfaceId),
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error)
|
||||
}
|
||||
|
||||
// first create or update the routes
|
||||
for _, cidr := range cidrs {
|
||||
// check if the route already exists
|
||||
exists := false
|
||||
for _, route := range wgReply.Data {
|
||||
existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
|
||||
if err != nil {
|
||||
slog.Warn("failed to parse route destination address",
|
||||
"cidr", route.GetString("dst-address"), "error", err)
|
||||
continue
|
||||
}
|
||||
if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if exists {
|
||||
continue // route already exists, nothing to do
|
||||
}
|
||||
|
||||
// create the route
|
||||
reply := c.client.Create(ctx, apiPath, lowlevel.GenericJsonObject{
|
||||
"gateway": string(interfaceId),
|
||||
"dst-address": cidr.String(),
|
||||
"distance": strconv.Itoa(MikrotikRouteDistance),
|
||||
"disabled": strconv.FormatBool(false),
|
||||
"routing-table": table,
|
||||
})
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to create new route %s via %s: %v", cidr.String(), interfaceId, reply.Error)
|
||||
}
|
||||
}
|
||||
|
||||
// finally, remove the routes that are not in the new list
|
||||
for _, route := range wgReply.Data {
|
||||
if route.GetBool("dynamic") {
|
||||
continue // dynamic routes are not managed by the controller, nothing to do
|
||||
}
|
||||
|
||||
existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
|
||||
if err != nil {
|
||||
slog.Warn("failed to parse route destination address",
|
||||
"cidr", route.GetString("dst-address"), "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
valid := false
|
||||
for _, cidr := range cidrs {
|
||||
if existingRoute.EqualPrefix(cidr) {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if valid {
|
||||
continue // route is still valid, nothing to do
|
||||
}
|
||||
|
||||
// remove the route
|
||||
reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id"))
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to remove outdated route %s: %v", existingRoute.String(), reply.Error)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
|
||||
func (c *MikrotikController) RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
|
||||
interfaceId := info.Interface.Identifier
|
||||
slog.Debug("removing mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps)
|
||||
|
||||
tableName := c.resolveRouteTableName(info.TableStr)
|
||||
|
||||
cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
|
||||
|
||||
err := c.removeRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove IPv4 routes for %s: %v", interfaceId, err)
|
||||
}
|
||||
|
||||
err = c.removeRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove IPv6 routes for %s: %v", interfaceId, err)
|
||||
}
|
||||
|
||||
err = c.removeRoutingTable(ctx, tableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove routing table for %s: %v", interfaceId, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) removeRoutesForFamily(
|
||||
ctx context.Context,
|
||||
interfaceId domain.InterfaceIdentifier,
|
||||
ipV6 bool,
|
||||
table string,
|
||||
cidrs []domain.Cidr,
|
||||
) error {
|
||||
apiPath := "/ip/route"
|
||||
if ipV6 {
|
||||
apiPath = "/ipv6/route"
|
||||
}
|
||||
|
||||
// retrieve current routes
|
||||
wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw",
|
||||
"routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder",
|
||||
},
|
||||
Filters: map[string]string{
|
||||
"gateway": string(interfaceId),
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error)
|
||||
}
|
||||
|
||||
// remove the routes from the list
|
||||
for _, route := range wgReply.Data {
|
||||
if route.GetBool("dynamic") {
|
||||
continue // dynamic routes are not managed by the controller, nothing to do
|
||||
}
|
||||
|
||||
existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
|
||||
if err != nil {
|
||||
slog.Warn("failed to parse route destination address",
|
||||
"cidr", route.GetString("dst-address"), "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
remove := false
|
||||
for _, cidr := range cidrs {
|
||||
if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table {
|
||||
remove = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !remove {
|
||||
continue // route is still valid, nothing to do
|
||||
}
|
||||
|
||||
// remove the route
|
||||
reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id"))
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to remove old route %s: %v", existingRoute.String(), reply.Error)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MikrotikController) removeRoutingTable(
|
||||
ctx context.Context,
|
||||
table string,
|
||||
) error {
|
||||
if table == MikrotikDefaultRoutingTable {
|
||||
return nil // we cannot remove the default table
|
||||
}
|
||||
|
||||
// retrieve current routing tables
|
||||
wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{
|
||||
PropList: []string{
|
||||
".id", "dynamic", "fib", "name",
|
||||
},
|
||||
})
|
||||
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("unable to query routing tables: %v", wgReply.Error)
|
||||
}
|
||||
|
||||
for _, existingTable := range wgReply.Data {
|
||||
if existingTable.GetBool("dynamic") {
|
||||
continue // dynamic tables are not managed by the controller, nothing to do
|
||||
}
|
||||
if existingTable.GetString("name") != table {
|
||||
continue // not the table we want to remove
|
||||
}
|
||||
|
||||
// remove the table
|
||||
reply := c.client.Delete(ctx, "/routing/table/"+existingTable.GetString(".id"))
|
||||
if reply.Status != lowlevel.MikrotikApiStatusOk {
|
||||
return fmt.Errorf("failed to remove routing table %s: %v", table, reply.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion routing-related
|
||||
|
@@ -1,113 +0,0 @@
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks.
|
||||
type WgQuickRepo struct {
|
||||
shellCmd string
|
||||
resolvConfIfacePrefix string
|
||||
}
|
||||
|
||||
// NewWgQuickRepo creates a new WgQuickRepo instance.
|
||||
func NewWgQuickRepo() *WgQuickRepo {
|
||||
return &WgQuickRepo{
|
||||
shellCmd: "bash",
|
||||
resolvConfIfacePrefix: "tun.",
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteInterfaceHook executes the given hook command.
|
||||
// The hook command can contain the following placeholders:
|
||||
//
|
||||
// %i: the interface identifier.
|
||||
func (r *WgQuickRepo) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
|
||||
if hookCmd == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
slog.Debug("executing interface hook", "interface", id, "hook", hookCmd)
|
||||
err := r.exec(hookCmd, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exec hook: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDNS sets the DNS settings for the given interface. It uses resolvconf to set the DNS settings.
|
||||
func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||
if dnsStr == "" && dnsSearchStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsServers := internal.SliceString(dnsStr)
|
||||
dnsSearchDomains := internal.SliceString(dnsSearchStr)
|
||||
|
||||
dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
|
||||
dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
|
||||
|
||||
for _, dnsServer := range dnsServers {
|
||||
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
|
||||
}
|
||||
for _, searchDomain := range dnsSearchDomains {
|
||||
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
|
||||
}
|
||||
|
||||
err := r.exec(dnsCommand, id, dnsCommandInput...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnsetDNS unsets the DNS settings for the given interface. It uses resolvconf to unset the DNS settings.
|
||||
func (r *WgQuickRepo) UnsetDNS(id domain.InterfaceIdentifier) error {
|
||||
dnsCommand := "resolvconf -d %resPref%i -f"
|
||||
|
||||
err := r.exec(dnsCommand, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unset dns settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgQuickRepo) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
|
||||
command = strings.ReplaceAll(command, "%resPref", r.resolvConfIfacePrefix)
|
||||
return strings.ReplaceAll(command, "%i", string(interfaceId))
|
||||
}
|
||||
|
||||
func (r *WgQuickRepo) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
|
||||
commandWithInterfaceName := r.replaceCommandPlaceHolders(command, interfaceId)
|
||||
cmd := exec.Command(r.shellCmd, "-ce", commandWithInterfaceName)
|
||||
if len(stdin) > 0 {
|
||||
b := &bytes.Buffer{}
|
||||
for _, ln := range stdin {
|
||||
if _, err := fmt.Fprint(b, ln); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
cmd.Stdin = b
|
||||
}
|
||||
out, err := cmd.CombinedOutput() // execute and wait for output
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
|
||||
}
|
||||
slog.Debug("executed shell command",
|
||||
"command", commandWithInterfaceName,
|
||||
"output", string(out))
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user