+
{{ $t('modals.interface-edit.save-config.label') }}
diff --git a/internal/adapters/wgcontroller/local.go b/internal/adapters/wgcontroller/local.go
index 7f2e7fa..d91a5f6 100644
--- a/internal/adapters/wgcontroller/local.go
+++ b/internal/adapters/wgcontroller/local.go
@@ -9,6 +9,7 @@ import (
"log/slog"
"os"
"os/exec"
+ "slices"
"strings"
"time"
@@ -84,8 +85,8 @@ func NewLocalController(cfg *config.Config) (*LocalController, error) {
wg: wg,
nl: nl,
- shellCmd: "bash", // we only support bash at the moment
- resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf
+ shellCmd: "bash", // we only support bash at the moment
+ resolvConfIfacePrefix: cfg.Backend.LocalResolvconfPrefix, // WireGuard interfaces have a tun. prefix in resolvconf
}
return repo, nil
@@ -546,7 +547,11 @@ func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id doma
// region wg-quick-related
-func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
+func (c LocalController) ExecuteInterfaceHook(
+ _ context.Context,
+ id domain.InterfaceIdentifier,
+ hookCmd string,
+) error {
if hookCmd == "" {
return nil
}
@@ -560,7 +565,7 @@ func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hoo
return nil
}
-func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
+func (c LocalController) SetDNS(_ context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
if dnsStr == "" && dnsSearchStr == "" {
return nil
}
@@ -589,7 +594,7 @@ func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearch
return nil
}
-func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error {
+func (c LocalController) UnsetDNS(_ context.Context, id domain.InterfaceIdentifier, _, _ string) error {
dnsCommand := "resolvconf -d %resPref%i -f"
err := c.exec(dnsCommand, id)
@@ -611,7 +616,7 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
if len(stdin) > 0 {
b := &bytes.Buffer{}
for _, ln := range stdin {
- if _, err := fmt.Fprint(b, ln); err != nil {
+ if _, err := fmt.Fprint(b, ln+"\n"); err != nil {
return err
}
}
@@ -619,6 +624,8 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
}
out, err := cmd.CombinedOutput() // execute and wait for output
if err != nil {
+ slog.Warn("failed to executed shell command",
+ "command", commandWithInterfaceName, "stdin", stdin, "output", string(out), "error", err)
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
}
slog.Debug("executed shell command",
@@ -631,49 +638,116 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
// region routing-related
-func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
- // update fwmark rules
- if err := c.setFwMarkRules(rules); err != nil {
- return err
+// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
+func (c LocalController) SetRoutes(_ context.Context, info domain.RoutingTableInfo) error {
+ interfaceId := info.Interface.Identifier
+ slog.Debug("setting linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark,
+ "cidrs", info.AllowedIps)
+
+ link, err := c.nl.LinkByName(string(interfaceId))
+ if err != nil {
+ return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err)
}
- // update main rule
- if err := c.setMainRule(rules); err != nil {
- return err
+ cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
+ realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, info.Table, info.FwMark)
+ if err != nil {
+ return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
+ }
+ wgDev, err := c.wg.Device(string(interfaceId))
+ if err != nil {
+ return fmt.Errorf("failed to get wg device for %s: %w", interfaceId, err)
+ }
+ currentFwMark := wgDev.FirewallMark
+ if int(realFwMark) != currentFwMark {
+ slog.Debug("updating fwmark for interface", "interface", interfaceId, "oldFwMark", currentFwMark,
+ "newFwMark", realFwMark, "oldTable", info.Table, "newTable", realTable)
+ if err := c.updateFwMarkOnInterface(interfaceId, int(realFwMark)); err != nil {
+ return fmt.Errorf("failed to update fwmark for interface %s to %d: %w", interfaceId, realFwMark, err)
+ }
}
- // cleanup old main rules
- if err := c.cleanupMainRule(rules); err != nil {
- return err
+ if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4); err != nil {
+ return fmt.Errorf("failed to set v4 routes: %w", err)
+ }
+ if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6); err != nil {
+ return fmt.Errorf("failed to set v6 routes: %w", err)
}
return nil
}
-func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error {
- for _, rule := range rules {
- existingRules, err := c.nl.RuleList(int(rule.IpFamily))
+func (c LocalController) setRoutesForFamily(
+ interfaceId domain.InterfaceIdentifier,
+ link netlink.Link,
+ family int,
+ table int,
+ fwMark uint32,
+ cidrs []domain.Cidr,
+) error {
+ // first create or update the routes
+ for _, cidr := range cidrs {
+ err := c.nl.RouteReplace(&netlink.Route{
+ LinkIndex: link.Attrs().Index,
+ Dst: cidr.IpNet(),
+ Table: table,
+ Scope: unix.RT_SCOPE_LINK,
+ Type: unix.RTN_UNICAST,
+ })
if err != nil {
- return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err)
+ return fmt.Errorf("failed to add/update route %s on table %d for interface %s: %w",
+ cidr.String(), table, interfaceId, err)
}
+ }
- ruleExists := false
- for _, existingRule := range existingRules {
- if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table {
- ruleExists = true
- break
+ // next remove old routes
+ rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
+ LinkIndex: link.Attrs().Index,
+ Table: unix.RT_TABLE_UNSPEC, // all tables
+ Scope: unix.RT_SCOPE_LINK,
+ Type: unix.RTN_UNICAST,
+ }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
+ if err != nil {
+ return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
+ interfaceId, family, err)
+ }
+ for _, rawRoute := range rawRoutes {
+ if rawRoute.Dst == nil { // handle default route
+ var netlinkAddr domain.Cidr
+ if family == netlink.FAMILY_V4 {
+ netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
+ } else {
+ netlinkAddr, _ = domain.CidrFromString("::/0")
}
+ rawRoute.Dst = netlinkAddr.IpNet()
}
- if ruleExists {
- continue // rule already exists, no need to recreate it
+ route := domain.CidrFromIpNet(*rawRoute.Dst)
+ if slices.Contains(cidrs, route) {
+ continue
}
- // create a missing rule
+ if err := c.nl.RouteDel(&rawRoute); err != nil {
+ return fmt.Errorf("failed to remove deprecated route %s from interface %s: %w", route, interfaceId, err)
+ }
+ }
+
+ // next, update route rules for normal routes
+ if table == 0 {
+ return nil // no need to update route rules as we are using the default table
+ }
+ existingRules, err := c.nl.RuleList(family)
+ if err != nil {
+ return fmt.Errorf("failed to get existing rules for family-id %d: %w", family, err)
+ }
+ ruleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
+ return rule.Mark == fwMark && rule.Table == table
+ })
+ if !ruleExists {
if err := c.nl.RuleAdd(&netlink.Rule{
- Family: int(rule.IpFamily),
- Table: rule.Table,
- Mark: rule.FwMark,
+ Family: family,
+ Table: table,
+ Mark: fwMark,
Invert: true,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
@@ -682,15 +756,102 @@ func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error {
Goto: -1,
Flow: -1,
}); err != nil {
- return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w",
- rule.IpFamily, rule.FwMark, rule.Table, err)
+ return fmt.Errorf("failed to setup rule for fwmark %d and table %d for family-id %d: %w",
+ fwMark, table, family, err)
}
}
+ mainRuleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
+ return rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN
+ })
+ if !mainRuleExists && domain.ContainsDefaultRoute(cidrs) {
+ err = c.nl.RuleAdd(&netlink.Rule{
+ Family: family,
+ Table: unix.RT_TABLE_MAIN,
+ SuppressIfgroup: -1,
+ SuppressPrefixlen: 0,
+ Priority: c.getMainRulePriority(existingRules),
+ Mark: 0,
+ Mask: nil,
+ Goto: -1,
+ Flow: -1,
+ })
+ }
+
+ // finally, clean up extra main rules - only one rule is allowed
+ existingRules, err = c.nl.RuleList(family)
+ if err != nil {
+ return fmt.Errorf("failed to get existing main rules for family-id %d: %w", family, err)
+ }
+ mainRuleCount := 0
+ for _, rule := range existingRules {
+ if rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN {
+ mainRuleCount++
+ }
+ if mainRuleCount > 1 {
+ if err := c.nl.RuleDel(&rule); err != nil {
+ return fmt.Errorf("failed to remove extra main rule for family-id %d: %w", family, err)
+ }
+ }
+ }
+
return nil
}
+func (c LocalController) getOrCreateRoutingTableAndFwMark(
+ link netlink.Link,
+ tableIn int,
+ fwMarkIn uint32,
+) (
+ table int,
+ fwmark uint32,
+ err error,
+) {
+ table = tableIn
+ fwmark = fwMarkIn
+
+ if fwmark == 0 {
+ // generate a new (temporary) firewall mark based on the interface index
+ fwmark = uint32(c.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
+ }
+ if table == 0 {
+ table = int(fwmark) // generate a new routing table base on interface index
+ }
+ return
+}
+
+func (c LocalController) updateFwMarkOnInterface(interfaceId domain.InterfaceIdentifier, fwMark int) error {
+ // apply the new fwmark to the wireguard interface
+ err := c.wg.ConfigureDevice(string(interfaceId), wgtypes.Config{
+ FirewallMark: &fwMark,
+ })
+ if err != nil {
+ return fmt.Errorf("failed to update fwmark of interface %s to: %d: %w", interfaceId, fwMark, err)
+ }
+
+ return nil
+}
+
+func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
+ prio := c.cfg.Advanced.RulePrioOffset
+ for {
+ isFresh := true
+ for _, existingRule := range existingRules {
+ if existingRule.Priority == prio {
+ isFresh = false
+ break
+ }
+ }
+ if isFresh {
+ break
+ } else {
+ prio++
+ }
+ }
+ return prio
+}
+
func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
- prio := 32700 // linux main rule has a priority of 32766
+ prio := 32700 // linux main rule has a prio of 32766
for {
isFresh := true
for _, existingRule := range existingRules {
@@ -708,126 +869,145 @@ func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
return prio
}
-func (c LocalController) setMainRule(rules []domain.RouteRule) error {
- var family domain.IpFamily
- shouldHaveMainRule := false
- for _, rule := range rules {
- family = rule.IpFamily
- if rule.HasDefault == true {
- shouldHaveMainRule = true
- break
- }
- }
- if !shouldHaveMainRule {
- return nil
- }
+// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
+func (c LocalController) RemoveRoutes(_ context.Context, info domain.RoutingTableInfo) error {
+ interfaceId := info.Interface.Identifier
+ slog.Debug("removing linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark,
+ "cidrs", info.AllowedIps)
- existingRules, err := c.nl.RuleList(int(family))
+ wgDev, err := c.wg.Device(string(interfaceId))
if err != nil {
- return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
+ slog.Debug("wg device already removed, route cleanup might be incomplete", "interface", interfaceId)
+ wgDev = nil
}
-
- ruleExists := false
- for _, existingRule := range existingRules {
- if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
- ruleExists = true
- break
- }
- }
-
- if ruleExists {
- return nil // rule already exists, skip re-creation
- }
-
- if err := c.nl.RuleAdd(&netlink.Rule{
- Family: int(family),
- Table: unix.RT_TABLE_MAIN,
- SuppressIfgroup: -1,
- SuppressPrefixlen: 0,
- Priority: c.getMainRulePriority(existingRules),
- Mark: 0,
- Mask: nil,
- Goto: -1,
- Flow: -1,
- }); err != nil {
- return fmt.Errorf("failed to setup rule for main table: %w", err)
- }
-
- return nil
-}
-
-func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
- priority := c.cfg.Advanced.RulePrioOffset
- for {
- isFresh := true
- for _, existingRule := range existingRules {
- if existingRule.Priority == priority {
- isFresh = false
- break
- }
- }
- if isFresh {
- break
- } else {
- priority++
- }
- }
- return priority
-}
-
-func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error {
- var family domain.IpFamily
- for _, rule := range rules {
- family = rule.IpFamily
- break
- }
-
- existingRules, err := c.nl.RuleList(int(family))
+ link, err := c.nl.LinkByName(string(interfaceId))
if err != nil {
- return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
+ slog.Debug("physical link already removed, route cleanup might be incomplete", "interface", interfaceId)
+ link = nil
}
- shouldHaveMainRule := false
- for _, rule := range rules {
- if rule.HasDefault == true {
- shouldHaveMainRule = true
- break
+ fwMark := info.FwMark
+ if wgDev != nil && info.FwMark == 0 {
+ fwMark = uint32(wgDev.FirewallMark)
+ }
+ table := info.Table
+ if wgDev != nil && info.Table == 0 {
+ table = wgDev.FirewallMark // use the fwMark as table, this is the default behavior
+ }
+ linkIndex := -1
+ if link != nil {
+ linkIndex = link.Attrs().Index
+ }
+
+ cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
+ realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark)
+ if err != nil {
+ return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
+ }
+
+ if linkIndex > 0 {
+ err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4)
+ if err != nil {
+ return fmt.Errorf("failed to remove v4 routes: %w", err)
+ }
+ err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6)
+ if err != nil {
+ return fmt.Errorf("failed to remove v6 routes: %w", err)
}
}
- mainRules := 0
- for _, existingRule := range existingRules {
- if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
- mainRules++
+ if table > 0 {
+ err = c.removeRouteRulesForTable(netlink.FAMILY_V4, realTable)
+ if err != nil {
+ return fmt.Errorf("failed to remove v4 route rules for %s: %w", interfaceId, err)
}
- }
-
- removalCount := 0
- if mainRules > 1 {
- removalCount = mainRules - 1 // we only want one single rule
- }
- if !shouldHaveMainRule {
- removalCount = mainRules
- }
-
- for _, existingRule := range existingRules {
- if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
- if removalCount > 0 {
- existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field
- if err := c.nl.RuleDel(&existingRule); err != nil {
- return fmt.Errorf("failed to delete main rule: %w", err)
- }
- removalCount--
- }
+ err = c.removeRouteRulesForTable(netlink.FAMILY_V6, realTable)
+ if err != nil {
+ return fmt.Errorf("failed to remove v6 route rules for %s: %w", interfaceId, err)
}
}
return nil
}
-func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
- // TODO implement me
- panic("implement me")
+func (c LocalController) removeRoutesForFamily(
+ interfaceId domain.InterfaceIdentifier,
+ link netlink.Link,
+ family int,
+ table int,
+ fwMark uint32,
+ cidrs []domain.Cidr,
+) error {
+ // first remove all rules
+ existingRules, err := c.nl.RuleList(family)
+ if err != nil {
+ return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
+ }
+ for _, existingRule := range existingRules {
+ if fwMark == existingRule.Mark && table == existingRule.Table {
+ existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
+ if err := c.nl.RuleDel(&existingRule); err != nil {
+ return fmt.Errorf("failed to delete old fwmark rule: %w", err)
+ }
+ }
+ }
+
+ // next remove all routes
+ rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
+ LinkIndex: link.Attrs().Index,
+ Table: unix.RT_TABLE_UNSPEC, // all tables
+ Scope: unix.RT_SCOPE_LINK,
+ Type: unix.RTN_UNICAST,
+ }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
+ if err != nil {
+ return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
+ interfaceId, family, err)
+ }
+ for _, rawRoute := range rawRoutes {
+ if rawRoute.Dst == nil { // handle default route
+ var netlinkAddr domain.Cidr
+ if family == netlink.FAMILY_V4 {
+ netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
+ } else {
+ netlinkAddr, _ = domain.CidrFromString("::/0")
+ }
+ rawRoute.Dst = netlinkAddr.IpNet()
+ }
+
+ if rawRoute.Table != table {
+ continue // ignore routes from other tables
+ }
+
+ route := domain.CidrFromIpNet(*rawRoute.Dst)
+ if !slices.Contains(cidrs, route) {
+ continue // only remove routes that were previously added
+ }
+
+ if err := c.nl.RouteDel(&rawRoute); err != nil {
+ return fmt.Errorf("failed to remove old route %s from interface %s: %w", route, interfaceId, err)
+ }
+ }
+
+ return nil
+}
+
+func (c LocalController) removeRouteRulesForTable(
+ family int,
+ table int,
+) error {
+ existingRules, err := c.nl.RuleList(family)
+ if err != nil {
+ return fmt.Errorf("failed to get existing route rules for family-id %d: %w", family, err)
+ }
+ for _, existingRule := range existingRules {
+ if existingRule.Table == table {
+ err := c.nl.RuleDel(&existingRule)
+ if err != nil {
+ return fmt.Errorf("failed to delete old rule for table %d and family-id %d: %w", table, family, err)
+ }
+ }
+ }
+ return nil
}
// endregion routing-related
diff --git a/internal/adapters/wgcontroller/mikrotik.go b/internal/adapters/wgcontroller/mikrotik.go
index ac98094..e730bd7 100644
--- a/internal/adapters/wgcontroller/mikrotik.go
+++ b/internal/adapters/wgcontroller/mikrotik.go
@@ -15,6 +15,9 @@ import (
"github.com/h44z/wg-portal/internal/lowlevel"
)
+const MikrotikRouteDistance = 5
+const MikrotikDefaultRoutingTable = "main"
+
type MikrotikController struct {
coreCfg *config.Config
cfg *config.BackendMikrotik
@@ -22,8 +25,9 @@ type MikrotikController struct {
client *lowlevel.MikrotikApiClient
// Add mutexes to prevent race conditions
- interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex
- peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex
+ interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex
+ peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex
+ coreMutex sync.Mutex // for updating the core configuration such as routing table or DNS settings
}
func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) {
@@ -40,6 +44,7 @@ func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik)
interfaceMutexes: sync.Map{},
peerMutexes: sync.Map{},
+ coreMutex: sync.Mutex{},
}, nil
}
@@ -763,33 +768,404 @@ func (c *MikrotikController) DeletePeer(
// region wg-quick-related
-func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
+func (c *MikrotikController) ExecuteInterfaceHook(
+ _ context.Context,
+ _ domain.InterfaceIdentifier,
+ _ string,
+) error {
// TODO implement me
- panic("implement me")
+ slog.Error("interface hooks are not yet supported for Mikrotik backends, please open an issue on GitHub")
+ return nil
}
-func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
- // TODO implement me
- panic("implement me")
+func (c *MikrotikController) SetDNS(
+ ctx context.Context,
+ _ domain.InterfaceIdentifier,
+ dnsStr, _ string,
+) error {
+ // Lock the interface to prevent concurrent modifications
+ c.coreMutex.Lock()
+ defer c.coreMutex.Unlock()
+
+ // check if the server is already configured
+ wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{
+ PropList: []string{"servers"},
+ })
+ if wgReply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error)
+ }
+
+ var existingServers []string
+ existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...)
+
+ newServers := strings.Split(dnsStr, ",")
+
+ mergedServers := slices.Clone(existingServers)
+ for _, s := range newServers {
+ if s == "" {
+ continue
+ }
+ if !slices.Contains(mergedServers, s) {
+ mergedServers = append(mergedServers, s)
+ }
+ }
+ mergedServersStr := strings.Join(mergedServers, ",")
+
+ reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{
+ "servers": mergedServersStr,
+ })
+ if reply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error)
+ }
+
+ return nil
}
-func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error {
- // TODO implement me
- panic("implement me")
+func (c *MikrotikController) UnsetDNS(
+ ctx context.Context,
+ _ domain.InterfaceIdentifier,
+ dnsStr, _ string,
+) error {
+ // Lock the interface to prevent concurrent modifications
+ c.coreMutex.Lock()
+ defer c.coreMutex.Unlock()
+
+ // retrieve current DNS settings
+ wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{
+ PropList: []string{"servers"},
+ })
+ if wgReply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error)
+ }
+
+ var existingServers []string
+ existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...)
+
+ oldServers := strings.Split(dnsStr, ",")
+
+ mergedServers := make([]string, 0, len(existingServers))
+ for _, s := range existingServers {
+ if s == "" {
+ continue
+ }
+ if !slices.Contains(oldServers, s) {
+ mergedServers = append(mergedServers, s) // only keep the servers that are not in the old list
+ }
+ }
+ mergedServersStr := strings.Join(mergedServers, ",")
+
+ reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{
+ "servers": mergedServersStr,
+ })
+ if reply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error)
+ }
+
+ return nil
}
// endregion wg-quick-related
// region routing-related
-func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
- // TODO implement me
- panic("implement me")
+// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
+func (c *MikrotikController) SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
+ interfaceId := info.Interface.Identifier
+ slog.Debug("setting mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps)
+
+ // Mikrotik needs some time to apply the changes.
+ // If we don't wait, the routes might get created multiple times as the dynamic routes are not yet available.
+ time.Sleep(2 * time.Second)
+
+ tableName, err := c.getOrCreateRoutingTables(ctx, info.Interface.Identifier, info.TableStr)
+ if err != nil {
+ return fmt.Errorf("failed to get or create routing table for %s: %v", interfaceId, err)
+ }
+
+ cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
+
+ err = c.setRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4)
+ if err != nil {
+ return fmt.Errorf("failed to set IPv4 routes for %s: %v", interfaceId, err)
+ }
+
+ err = c.setRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6)
+ if err != nil {
+ return fmt.Errorf("failed to set IPv6 routes for %s: %v", interfaceId, err)
+ }
+
+ return nil
}
-func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
- // TODO implement me
- panic("implement me")
+func (c *MikrotikController) resolveRouteTableName(name string) string {
+ name = strings.TrimSpace(name)
+
+ var mikrotikTableName string
+ switch strings.ToLower(name) {
+ case "", "0":
+ mikrotikTableName = MikrotikDefaultRoutingTable
+ case MikrotikDefaultRoutingTable:
+ return fmt.Sprintf("wgportal-%s",
+ MikrotikDefaultRoutingTable) // if the Mikrotik Main table should be used, the table-name should be left empty or set to "0".
+ default:
+ mikrotikTableName = name
+ }
+
+ return mikrotikTableName
+}
+
+func (c *MikrotikController) getOrCreateRoutingTables(
+ ctx context.Context,
+ interfaceId domain.InterfaceIdentifier,
+ table string,
+) (string, error) {
+ // retrieve current routing tables
+ wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{
+ PropList: []string{
+ ".id", "dynamic", "fib", "name",
+ },
+ })
+ if wgReply.Status != lowlevel.MikrotikApiStatusOk {
+ return "", fmt.Errorf("unable to query routing tables: %v", wgReply.Error)
+ }
+
+ wantedTableName := c.resolveRouteTableName(table)
+
+ // check if the table already exists
+ for _, table := range wgReply.Data {
+ if table.GetString("name") == wantedTableName {
+ return wantedTableName, nil // already exists, nothing to do
+ }
+ }
+
+ // create the table if it does not exist
+ createReply := c.client.Create(ctx, "/routing/table", lowlevel.GenericJsonObject{
+ "name": wantedTableName,
+ "comment": fmt.Sprintf("Routing Table for %s", interfaceId),
+ "fib": strconv.FormatBool(true),
+ })
+ if createReply.Status != lowlevel.MikrotikApiStatusOk {
+ return "", fmt.Errorf("failed to create routing table %s: %v", wantedTableName, createReply.Error)
+ }
+
+ return wantedTableName, nil
+}
+
+func (c *MikrotikController) setRoutesForFamily(
+ ctx context.Context,
+ interfaceId domain.InterfaceIdentifier,
+ ipV6 bool,
+ table string,
+ cidrs []domain.Cidr,
+) error {
+ apiPath := "/ip/route"
+ if ipV6 {
+ apiPath = "/ipv6/route"
+ }
+
+ // retrieve current routes
+ wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{
+ PropList: []string{
+ ".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw",
+ "routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder",
+ },
+ Filters: map[string]string{
+ "gateway": string(interfaceId),
+ },
+ })
+ if wgReply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error)
+ }
+
+ // first create or update the routes
+ for _, cidr := range cidrs {
+ // check if the route already exists
+ exists := false
+ for _, route := range wgReply.Data {
+ existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
+ if err != nil {
+ slog.Warn("failed to parse route destination address",
+ "cidr", route.GetString("dst-address"), "error", err)
+ continue
+ }
+ if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table {
+ exists = true
+ break
+ }
+ }
+ if exists {
+ continue // route already exists, nothing to do
+ }
+
+ // create the route
+ reply := c.client.Create(ctx, apiPath, lowlevel.GenericJsonObject{
+ "gateway": string(interfaceId),
+ "dst-address": cidr.String(),
+ "distance": strconv.Itoa(MikrotikRouteDistance),
+ "disabled": strconv.FormatBool(false),
+ "routing-table": table,
+ })
+ if reply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("failed to create new route %s via %s: %v", cidr.String(), interfaceId, reply.Error)
+ }
+ }
+
+ // finally, remove the routes that are not in the new list
+ for _, route := range wgReply.Data {
+ if route.GetBool("dynamic") {
+ continue // dynamic routes are not managed by the controller, nothing to do
+ }
+
+ existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
+ if err != nil {
+ slog.Warn("failed to parse route destination address",
+ "cidr", route.GetString("dst-address"), "error", err)
+ continue
+ }
+
+ valid := false
+ for _, cidr := range cidrs {
+ if existingRoute.EqualPrefix(cidr) {
+ valid = true
+ break
+ }
+ }
+ if valid {
+ continue // route is still valid, nothing to do
+ }
+
+ // remove the route
+ reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id"))
+ if reply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("failed to remove outdated route %s: %v", existingRoute.String(), reply.Error)
+ }
+ }
+
+ return nil
+}
+
+// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
+func (c *MikrotikController) RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
+ interfaceId := info.Interface.Identifier
+ slog.Debug("removing mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps)
+
+ tableName := c.resolveRouteTableName(info.TableStr)
+
+ cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
+
+ err := c.removeRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4)
+ if err != nil {
+ return fmt.Errorf("failed to remove IPv4 routes for %s: %v", interfaceId, err)
+ }
+
+ err = c.removeRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6)
+ if err != nil {
+ return fmt.Errorf("failed to remove IPv6 routes for %s: %v", interfaceId, err)
+ }
+
+ err = c.removeRoutingTable(ctx, tableName)
+ if err != nil {
+ return fmt.Errorf("failed to remove routing table for %s: %v", interfaceId, err)
+ }
+
+ return nil
+}
+
+func (c *MikrotikController) removeRoutesForFamily(
+ ctx context.Context,
+ interfaceId domain.InterfaceIdentifier,
+ ipV6 bool,
+ table string,
+ cidrs []domain.Cidr,
+) error {
+ apiPath := "/ip/route"
+ if ipV6 {
+ apiPath = "/ipv6/route"
+ }
+
+ // retrieve current routes
+ wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{
+ PropList: []string{
+ ".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw",
+ "routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder",
+ },
+ Filters: map[string]string{
+ "gateway": string(interfaceId),
+ },
+ })
+ if wgReply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error)
+ }
+
+ // remove the routes from the list
+ for _, route := range wgReply.Data {
+ if route.GetBool("dynamic") {
+ continue // dynamic routes are not managed by the controller, nothing to do
+ }
+
+ existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
+ if err != nil {
+ slog.Warn("failed to parse route destination address",
+ "cidr", route.GetString("dst-address"), "error", err)
+ continue
+ }
+
+ remove := false
+ for _, cidr := range cidrs {
+ if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table {
+ remove = true
+ break
+ }
+ }
+ if !remove {
+ continue // route is still valid, nothing to do
+ }
+
+ // remove the route
+ reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id"))
+ if reply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("failed to remove old route %s: %v", existingRoute.String(), reply.Error)
+ }
+ }
+
+ return nil
+}
+
+func (c *MikrotikController) removeRoutingTable(
+ ctx context.Context,
+ table string,
+) error {
+ if table == MikrotikDefaultRoutingTable {
+ return nil // we cannot remove the default table
+ }
+
+ // retrieve current routing tables
+ wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{
+ PropList: []string{
+ ".id", "dynamic", "fib", "name",
+ },
+ })
+ if wgReply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("unable to query routing tables: %v", wgReply.Error)
+ }
+
+ for _, existingTable := range wgReply.Data {
+ if existingTable.GetBool("dynamic") {
+ continue // dynamic tables are not managed by the controller, nothing to do
+ }
+ if existingTable.GetString("name") != table {
+ continue // not the table we want to remove
+ }
+
+ // remove the table
+ reply := c.client.Delete(ctx, "/routing/table/"+existingTable.GetString(".id"))
+ if reply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("failed to remove routing table %s: %v", table, reply.Error)
+ }
+ return nil
+ }
+
+ return nil
}
// endregion routing-related
diff --git a/internal/adapters/wgquick.go b/internal/adapters/wgquick.go
deleted file mode 100644
index 992e69a..0000000
--- a/internal/adapters/wgquick.go
+++ /dev/null
@@ -1,113 +0,0 @@
-package adapters
-
-import (
- "bytes"
- "fmt"
- "log/slog"
- "os/exec"
- "strings"
-
- "github.com/h44z/wg-portal/internal"
- "github.com/h44z/wg-portal/internal/domain"
-)
-
-// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks.
-type WgQuickRepo struct {
- shellCmd string
- resolvConfIfacePrefix string
-}
-
-// NewWgQuickRepo creates a new WgQuickRepo instance.
-func NewWgQuickRepo() *WgQuickRepo {
- return &WgQuickRepo{
- shellCmd: "bash",
- resolvConfIfacePrefix: "tun.",
- }
-}
-
-// ExecuteInterfaceHook executes the given hook command.
-// The hook command can contain the following placeholders:
-//
-// %i: the interface identifier.
-func (r *WgQuickRepo) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
- if hookCmd == "" {
- return nil
- }
-
- slog.Debug("executing interface hook", "interface", id, "hook", hookCmd)
- err := r.exec(hookCmd, id)
- if err != nil {
- return fmt.Errorf("failed to exec hook: %w", err)
- }
-
- return nil
-}
-
-// SetDNS sets the DNS settings for the given interface. It uses resolvconf to set the DNS settings.
-func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
- if dnsStr == "" && dnsSearchStr == "" {
- return nil
- }
-
- dnsServers := internal.SliceString(dnsStr)
- dnsSearchDomains := internal.SliceString(dnsSearchStr)
-
- dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
- dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
-
- for _, dnsServer := range dnsServers {
- dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
- }
- for _, searchDomain := range dnsSearchDomains {
- dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
- }
-
- err := r.exec(dnsCommand, id, dnsCommandInput...)
- if err != nil {
- return fmt.Errorf(
- "failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w",
- err,
- )
- }
-
- return nil
-}
-
-// UnsetDNS unsets the DNS settings for the given interface. It uses resolvconf to unset the DNS settings.
-func (r *WgQuickRepo) UnsetDNS(id domain.InterfaceIdentifier) error {
- dnsCommand := "resolvconf -d %resPref%i -f"
-
- err := r.exec(dnsCommand, id)
- if err != nil {
- return fmt.Errorf("failed to unset dns settings: %w", err)
- }
-
- return nil
-}
-
-func (r *WgQuickRepo) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
- command = strings.ReplaceAll(command, "%resPref", r.resolvConfIfacePrefix)
- return strings.ReplaceAll(command, "%i", string(interfaceId))
-}
-
-func (r *WgQuickRepo) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
- commandWithInterfaceName := r.replaceCommandPlaceHolders(command, interfaceId)
- cmd := exec.Command(r.shellCmd, "-ce", commandWithInterfaceName)
- if len(stdin) > 0 {
- b := &bytes.Buffer{}
- for _, ln := range stdin {
- if _, err := fmt.Fprint(b, ln); err != nil {
- return err
- }
- }
- cmd.Stdin = b
- }
- out, err := cmd.CombinedOutput() // execute and wait for output
- if err != nil {
- return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
- }
- slog.Debug("executed shell command",
- "command", commandWithInterfaceName,
- "output", string(out))
- return nil
-}
diff --git a/internal/app/route/routes.go b/internal/app/route/routes.go
index c87bcaf..62cd67e 100644
--- a/internal/app/route/routes.go
+++ b/internal/app/route/routes.go
@@ -4,25 +4,23 @@ import (
"context"
"fmt"
"log/slog"
-
- "github.com/vishvananda/netlink"
- "golang.org/x/sys/unix"
- "golang.zx2c4.com/wireguard/wgctrl"
- "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+ "sync"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
- "github.com/h44z/wg-portal/internal/lowlevel"
)
// region dependencies
+type ControllerManager interface {
+ // GetController returns the controller for the given interface.
+ GetController(iface domain.Interface) domain.InterfaceController
+}
+
type InterfaceAndPeerDatabaseRepo interface {
- // GetAllInterfaces returns all interfaces
- GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
- // GetInterfacePeers returns all peers for a given interface
- GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
+ // GetInterface returns the interface with the given identifier.
+ GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}
type EventBus interface {
@@ -30,6 +28,13 @@ type EventBus interface {
Subscribe(topic string, fn interface{}) error
}
+type RoutesController interface {
+ // SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
+ SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error
+ // RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
+ RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error
+}
+
// endregion dependencies
type routeRuleInfo struct {
@@ -45,28 +50,27 @@ type routeRuleInfo struct {
type Manager struct {
cfg *config.Config
- bus EventBus
- wg lowlevel.WireGuardClient
- nl lowlevel.NetlinkClient
- db InterfaceAndPeerDatabaseRepo
+ bus EventBus
+ db InterfaceAndPeerDatabaseRepo
+ wgController ControllerManager
+
+ mux *sync.Mutex
}
// NewRouteManager creates a new route manager instance.
-func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
- wg, err := wgctrl.New()
- if err != nil {
- panic("failed to init wgctrl: " + err.Error())
- }
-
- nl := &lowlevel.NetlinkManager{}
-
+func NewRouteManager(
+ cfg *config.Config,
+ bus EventBus,
+ db InterfaceAndPeerDatabaseRepo,
+ wgController ControllerManager,
+) (*Manager, error) {
m := &Manager{
cfg: cfg,
bus: bus,
- db: db,
- wg: wg,
- nl: nl,
+ db: db,
+ wgController: wgController,
+ mux: &sync.Mutex{},
}
m.connectToMessageBus()
@@ -85,419 +89,82 @@ func (m Manager) StartBackgroundJobs(_ context.Context) {
// this is a no-op for now
}
-func (m Manager) handleRouteUpdateEvent(srcDescription string) {
- slog.Debug("handling route update event", "source", srcDescription)
+func (m Manager) handleRouteUpdateEvent(info domain.RoutingTableInfo) {
+ m.mux.Lock() // ensure that only one route update is processed at a time
+ defer m.mux.Unlock()
- err := m.syncRoutes(context.Background())
- if err != nil {
- slog.Error("failed to synchronize routes",
- "source", srcDescription,
- "error", err)
+ slog.Debug("handling route update event", "info", info.String())
+
+ if !info.ManagementEnabled() {
+ return // route management disabled
}
- slog.Debug("routes synchronized", "source", srcDescription)
+ err := m.syncRoutes(context.Background(), info)
+ if err != nil {
+ slog.Error("failed to synchronize routes",
+ "info", info.String(), "error", err)
+ return
+ }
+
+ slog.Debug("routes synchronized", "info", info.String())
}
func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) {
+ m.mux.Lock() // ensure that only one route update is processed at a time
+ defer m.mux.Unlock()
+
slog.Debug("handling route remove event", "info", info.String())
if !info.ManagementEnabled() {
return // route management disabled
}
- if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V4); err != nil {
- slog.Error("failed to remove v4 fwmark rules", "error", err)
- }
- if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V6); err != nil {
- slog.Error("failed to remove v6 fwmark rules", "error", err)
- }
-
- slog.Debug("routes removed", "table", info.String())
-}
-
-func (m Manager) syncRoutes(ctx context.Context) error {
- interfaces, err := m.db.GetAllInterfaces(ctx)
+ err := m.removeRoutes(context.Background(), info)
if err != nil {
- return fmt.Errorf("failed to find all interfaces: %w", err)
+ slog.Error("failed to synchronize routes",
+ "info", info.String(), "error", err)
+ return
}
- rules := map[int][]routeRuleInfo{
- netlink.FAMILY_V4: nil,
- netlink.FAMILY_V6: nil,
- }
- for _, iface := range interfaces {
- if iface.IsDisabled() {
- continue // disabled interface does not need route entries
- }
- if !iface.ManageRoutingTable() {
- continue
- }
-
- peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
- if err != nil {
- return fmt.Errorf("failed to find peers for %s: %w", iface.Identifier, err)
- }
- allowedIPs := iface.GetAllowedIPs(peers)
- defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs)
-
- link, err := m.nl.LinkByName(string(iface.Identifier))
- if err != nil {
- return fmt.Errorf("failed to find physical link for %s: %w", iface.Identifier, err)
- }
-
- table, fwmark, err := m.getRoutingTableAndFwMark(&iface, link)
- if err != nil {
- return fmt.Errorf("failed to get table and fwmark for %s: %w", iface.Identifier, err)
- }
-
- if err := m.setInterfaceRoutes(link, table, allowedIPs); err != nil {
- return fmt.Errorf("failed to set routes for %s: %w", iface.Identifier, err)
- }
-
- if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V4, allowedIPs); err != nil {
- return fmt.Errorf("failed to remove deprecated v4 routes for %s: %w", iface.Identifier, err)
- }
- if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V6, allowedIPs); err != nil {
- return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err)
- }
-
- if table != 0 {
- rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], routeRuleInfo{
- ifaceId: iface.Identifier,
- fwMark: fwmark,
- table: table,
- family: netlink.FAMILY_V4,
- hasDefault: defRouteV4,
- })
- }
- if table != 0 {
- rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], routeRuleInfo{
- ifaceId: iface.Identifier,
- fwMark: fwmark,
- table: table,
- family: netlink.FAMILY_V6,
- hasDefault: defRouteV6,
- })
- }
- }
-
- return m.syncRouteRules(rules)
+ slog.Debug("routes removed", "info", info.String())
}
-func (m Manager) syncRouteRules(allRules map[int][]routeRuleInfo) error {
- for family, rules := range allRules {
- // update fwmark rules
- if err := m.setFwMarkRules(rules, family); err != nil {
- return err
- }
-
- // update main rule
- if err := m.setMainRule(rules, family); err != nil {
- return err
- }
-
- // cleanup old main rules
- if err := m.cleanupMainRule(rules, family); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (m Manager) setFwMarkRules(rules []routeRuleInfo, family int) error {
- for _, rule := range rules {
- existingRules, err := m.nl.RuleList(family)
- if err != nil {
- return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
- }
-
- ruleExists := false
- for _, existingRule := range existingRules {
- if rule.fwMark == existingRule.Mark && rule.table == existingRule.Table {
- ruleExists = true
- break
- }
- }
-
- if ruleExists {
- continue // rule already exists, no need to recreate it
- }
-
- // create missing rule
- if err := m.nl.RuleAdd(&netlink.Rule{
- Family: family,
- Table: rule.table,
- Mark: rule.fwMark,
- Invert: true,
- SuppressIfgroup: -1,
- SuppressPrefixlen: -1,
- Priority: m.getRulePriority(existingRules),
- Mask: nil,
- Goto: -1,
- Flow: -1,
- }); err != nil {
- return fmt.Errorf("failed to setup rule for fwmark %d and table %d: %w", rule.fwMark, rule.table, err)
- }
- }
- return nil
-}
-
-func (m Manager) removeFwMarkRules(fwmark uint32, table int, family int) error {
- existingRules, err := m.nl.RuleList(family)
- if err != nil {
- return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
- }
-
- for _, existingRule := range existingRules {
- if fwmark == existingRule.Mark && table == existingRule.Table {
- existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
- if err := m.nl.RuleDel(&existingRule); err != nil {
- return fmt.Errorf("failed to delete fwmark rule: %w", err)
- }
- }
- }
- return nil
-}
-
-func (m Manager) setMainRule(rules []routeRuleInfo, family int) error {
- shouldHaveMainRule := false
- for _, rule := range rules {
- if rule.hasDefault == true {
- shouldHaveMainRule = true
- break
- }
- }
- if !shouldHaveMainRule {
+func (m Manager) syncRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
+ rc, ok := m.wgController.GetController(info.Interface).(RoutesController)
+ if !ok {
+ slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier)
return nil
}
- existingRules, err := m.nl.RuleList(family)
+ if !info.Interface.ManageRoutingTable() {
+ slog.Debug("interface does not manage routing table, skipping route update",
+ "interface", info.Interface.Identifier)
+ return nil
+ }
+
+ err := rc.SetRoutes(ctx, info)
if err != nil {
- return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
+ return fmt.Errorf("failed to set routes for interface %s: %w", info.Interface.Identifier, err)
}
-
- ruleExists := false
- for _, existingRule := range existingRules {
- if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
- ruleExists = true
- break
- }
- }
-
- if ruleExists {
- return nil // rule already exists, skip re-creation
- }
-
- if err := m.nl.RuleAdd(&netlink.Rule{
- Family: family,
- Table: unix.RT_TABLE_MAIN,
- SuppressIfgroup: -1,
- SuppressPrefixlen: 0,
- Priority: m.getMainRulePriority(existingRules),
- Mark: 0,
- Mask: nil,
- Goto: -1,
- Flow: -1,
- }); err != nil {
- return fmt.Errorf("failed to setup rule for main table: %w", err)
- }
-
return nil
}
-func (m Manager) cleanupMainRule(rules []routeRuleInfo, family int) error {
- existingRules, err := m.nl.RuleList(family)
+func (m Manager) removeRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
+ rc, ok := m.wgController.GetController(info.Interface).(RoutesController)
+ if !ok {
+ slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier)
+ return nil
+ }
+
+ if !info.Interface.ManageRoutingTable() {
+ slog.Debug("interface does not manage routing table, skipping route removal",
+ "interface", info.Interface.Identifier)
+ return nil
+ }
+
+ err := rc.RemoveRoutes(ctx, info)
if err != nil {
- return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
- }
-
- shouldHaveMainRule := false
- for _, rule := range rules {
- if rule.hasDefault == true {
- shouldHaveMainRule = true
- break
- }
- }
-
- mainRules := 0
- for _, existingRule := range existingRules {
- if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
- mainRules++
- }
- }
-
- removalCount := 0
- if mainRules > 1 {
- removalCount = mainRules - 1 // we only want one single rule
- }
- if !shouldHaveMainRule {
- removalCount = mainRules
- }
-
- for _, existingRule := range existingRules {
- if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
- if removalCount > 0 {
- existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
- if err := m.nl.RuleDel(&existingRule); err != nil {
- return fmt.Errorf("failed to delete main rule: %w", err)
- }
- removalCount--
- }
- }
- }
-
- return nil
-}
-
-func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
- prio := m.cfg.Advanced.RulePrioOffset
- for {
- isFresh := true
- for _, existingRule := range existingRules {
- if existingRule.Priority == prio {
- isFresh = false
- break
- }
- }
- if isFresh {
- break
- } else {
- prio++
- }
- }
- return prio
-}
-
-func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
- prio := 32700 // linux main rule has a prio of 32766
- for {
- isFresh := true
- for _, existingRule := range existingRules {
- if existingRule.Priority == prio {
- isFresh = false
- break
- }
- }
- if isFresh {
- break
- } else {
- prio--
- }
- }
- return prio
-}
-
-func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error {
- for _, allowedIP := range allowedIPs {
- err := m.nl.RouteReplace(&netlink.Route{
- LinkIndex: link.Attrs().Index,
- Dst: allowedIP.IpNet(),
- Table: table,
- Scope: unix.RT_SCOPE_LINK,
- Type: unix.RTN_UNICAST,
- })
- if err != nil {
- return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
- }
- }
-
- return nil
-}
-
-func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIPs []domain.Cidr) error {
- rawRoutes, err := m.nl.RouteListFiltered(family, &netlink.Route{
- LinkIndex: link.Attrs().Index,
- Table: unix.RT_TABLE_UNSPEC, // all tables
- Scope: unix.RT_SCOPE_LINK,
- Type: unix.RTN_UNICAST,
- }, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
- if err != nil {
- return fmt.Errorf("failed to fetch raw routes: %w", err)
- }
- for _, rawRoute := range rawRoutes {
- if rawRoute.Dst == nil { // handle default route
- var netlinkAddr domain.Cidr
- if family == netlink.FAMILY_V4 {
- netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
- } else {
- netlinkAddr, _ = domain.CidrFromString("::/0")
- }
- rawRoute.Dst = netlinkAddr.IpNet()
- }
-
- netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
- remove := true
- for _, allowedIP := range allowedIPs {
- if netlinkAddr == allowedIP {
- remove = false
- break
- }
- }
-
- if !remove {
- continue
- }
-
- err := m.nl.RouteDel(&rawRoute)
- if err != nil {
- return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
- }
+ return fmt.Errorf("failed to remove routes for interface %s: %w", info.Interface.Identifier, err)
}
return nil
}
-
-func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, link netlink.Link) (
- table int,
- fwmark uint32,
- err error,
-) {
- table = iface.GetRoutingTable()
- fwmark = iface.FirewallMark
-
- if fwmark == 0 {
- // generate a new (temporary) firewall mark based on the interface index
- fwmark = uint32(m.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
- slog.Debug("using fwmark to handle routes",
- "interface", iface.Identifier,
- "fwmark", fwmark)
-
- // apply the temporary fwmark to the wireguard interface
- err = m.setFwMark(iface.Identifier, int(fwmark))
- }
- if table == 0 {
- table = int(fwmark) // generate a new routing table base on interface index
- slog.Debug("using routing table to handle default routes",
- "interface", iface.Identifier,
- "table", table)
- }
- return
-}
-
-func (m Manager) setFwMark(id domain.InterfaceIdentifier, fwmark int) error {
- err := m.wg.ConfigureDevice(string(id), wgtypes.Config{
- FirewallMark: &fwmark,
- })
- if err != nil {
- return fmt.Errorf("failed to update fwmark to: %d: %w", fwmark, err)
- }
- return nil
-}
-
-func (m Manager) containsDefaultRoute(allowedIPs []domain.Cidr) (ipV4, ipV6 bool) {
- for _, allowedIP := range allowedIPs {
- if ipV4 && ipV6 {
- break // speed up
- }
-
- if allowedIP.Prefix().Bits() == 0 {
- if allowedIP.IsV4() {
- ipV4 = true
- } else {
- ipV6 = true
- }
- }
- }
-
- return
-}
diff --git a/internal/app/wireguard/controller_manager.go b/internal/app/wireguard/controller_manager.go
index 2eea6af..0f6bd23 100644
--- a/internal/app/wireguard/controller_manager.go
+++ b/internal/app/wireguard/controller_manager.go
@@ -1,7 +1,6 @@
package wireguard
import (
- "context"
"fmt"
"log/slog"
"maps"
@@ -12,33 +11,9 @@ import (
"github.com/h44z/wg-portal/internal/domain"
)
-type InterfaceController interface {
- GetId() domain.InterfaceBackend
- GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
- GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
- GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
- SaveInterface(
- _ context.Context,
- id domain.InterfaceIdentifier,
- updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
- ) error
- DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
- SavePeer(
- _ context.Context,
- deviceId domain.InterfaceIdentifier,
- id domain.PeerIdentifier,
- updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
- ) error
- DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
- PingAddresses(
- ctx context.Context,
- addr string,
- ) (*domain.PingerResult, error)
-}
-
type backendInstance struct {
Config config.BackendBase // Config is the configuration for the backend instance.
- Implementation InterfaceController
+ Implementation domain.InterfaceController
}
type ControllerManager struct {
@@ -118,11 +93,11 @@ func (c *ControllerManager) logRegisteredControllers() {
}
}
-func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController {
+func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) domain.InterfaceController {
return c.getController(backend, "").Implementation
}
-func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController {
+func (c *ControllerManager) GetController(iface domain.Interface) domain.InterfaceController {
return c.getController(iface.Backend, iface.Identifier).Implementation
}
diff --git a/internal/app/wireguard/wireguard.go b/internal/app/wireguard/wireguard.go
index b28f70e..e1e9dfa 100644
--- a/internal/app/wireguard/wireguard.go
+++ b/internal/app/wireguard/wireguard.go
@@ -38,9 +38,9 @@ type InterfaceAndPeerDatabaseRepo interface {
}
type WgQuickController interface {
- ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
- SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
- UnsetDNS(id domain.InterfaceIdentifier) error
+ ExecuteInterfaceHook(ctx context.Context, id domain.InterfaceIdentifier, hookCmd string) error
+ SetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
+ UnsetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
}
type EventBus interface {
@@ -53,11 +53,10 @@ type EventBus interface {
// endregion dependencies
type Manager struct {
- cfg *config.Config
- bus EventBus
- db InterfaceAndPeerDatabaseRepo
- wg *ControllerManager
- quick WgQuickController
+ cfg *config.Config
+ bus EventBus
+ db InterfaceAndPeerDatabaseRepo
+ wg *ControllerManager
userLockMap *sync.Map
}
@@ -66,7 +65,6 @@ func NewWireGuardManager(
cfg *config.Config,
bus EventBus,
wg *ControllerManager,
- quick WgQuickController,
db InterfaceAndPeerDatabaseRepo,
) (*Manager, error) {
m := &Manager{
@@ -74,7 +72,6 @@ func NewWireGuardManager(
bus: bus,
wg: wg,
db: db,
- quick: quick,
userLockMap: &sync.Map{},
}
diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go
index 368d1eb..3dbe53f 100644
--- a/internal/app/wireguard/wireguard_interfaces.go
+++ b/internal/app/wireguard/wireguard_interfaces.go
@@ -453,7 +453,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return err
}
- existingInterface, err := m.db.GetInterface(ctx, id)
+ existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", id, err)
}
@@ -462,21 +462,29 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion not allowed: %w", err)
}
+ m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
+ Interface: *existingInterface,
+ AllowedIps: existingInterface.GetAllowedIPs(existingPeers),
+ FwMark: existingInterface.FirewallMark,
+ Table: existingInterface.GetRoutingTable(),
+ TableStr: existingInterface.RoutingTable,
+ IsDeleted: true,
+ })
+
now := time.Now()
existingInterface.Disabled = &now // simulate a disabled interface
existingInterface.DisabledReason = domain.DisabledReasonDeleted
- physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id)
-
- if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
+ if err := m.handleInterfacePreSaveHooks(ctx, existingInterface, !existingInterface.IsDisabled(),
+ false); err != nil {
return fmt.Errorf("pre-delete hooks failed: %w", err)
}
- if err := m.handleInterfacePreSaveActions(existingInterface); err != nil {
+ if err := m.handleInterfacePreSaveActions(ctx, existingInterface); err != nil {
return fmt.Errorf("pre-delete actions failed: %w", err)
}
- if err := m.deleteInterfacePeers(ctx, id); err != nil {
+ if err := m.deleteInterfacePeers(ctx, existingInterface, existingPeers); err != nil {
return fmt.Errorf("peer deletion failure: %w", err)
}
@@ -488,16 +496,12 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion failure: %w", err)
}
- fwMark := existingInterface.FirewallMark
- if physicalInterface != nil && fwMark == 0 {
- fwMark = physicalInterface.FirewallMark
- }
- m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
- FwMark: fwMark,
- Table: existingInterface.GetRoutingTable(),
- })
-
- if err := m.handleInterfacePostSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
+ if err := m.handleInterfacePostSaveHooks(
+ ctx,
+ existingInterface,
+ !existingInterface.IsDisabled(),
+ false,
+ ); err != nil {
return fmt.Errorf("post-delete hooks failed: %w", err)
}
@@ -516,17 +520,21 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return nil, fmt.Errorf("interface validation failed: %w", err)
}
- oldEnabled, newEnabled := m.getInterfaceStateHistory(ctx, iface)
+ oldEnabled, newEnabled, routeTableChanged := false, !iface.IsDisabled(), false // if the interface did not exist, we assume it was not enabled
+ oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
+ if err == nil {
+ oldEnabled, newEnabled, routeTableChanged = m.getInterfaceStateHistory(oldInterface, iface)
+ }
- if err := m.handleInterfacePreSaveHooks(iface, oldEnabled, newEnabled); err != nil {
+ if err := m.handleInterfacePreSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
return nil, fmt.Errorf("pre-save hooks failed: %w", err)
}
- if err := m.handleInterfacePreSaveActions(iface); err != nil {
+ if err := m.handleInterfacePreSaveActions(ctx, iface); err != nil {
return nil, fmt.Errorf("pre-save actions failed: %w", err)
}
- err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
+ err = m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
iface.CopyCalculatedAttributes(i)
err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
@@ -569,20 +577,35 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
}
if iface.IsDisabled() {
- physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier)
- fwMark := iface.FirewallMark
- if physicalInterface != nil && fwMark == 0 {
- fwMark = physicalInterface.FirewallMark
- }
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
- FwMark: fwMark,
- Table: iface.GetRoutingTable(),
+ Interface: *iface,
+ AllowedIps: iface.GetAllowedIPs(peers),
+ FwMark: iface.FirewallMark,
+ Table: iface.GetRoutingTable(),
+ TableStr: iface.RoutingTable,
})
} else {
- m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
+ m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
+ Interface: *iface,
+ AllowedIps: iface.GetAllowedIPs(peers),
+ FwMark: iface.FirewallMark,
+ Table: iface.GetRoutingTable(),
+ TableStr: iface.RoutingTable,
+ })
+ // if the route table changed, ensure that the old entries are remove
+ if routeTableChanged {
+ m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
+ Interface: *oldInterface,
+ AllowedIps: oldInterface.GetAllowedIPs(peers),
+ FwMark: oldInterface.FirewallMark,
+ Table: oldInterface.GetRoutingTable(),
+ TableStr: oldInterface.RoutingTable,
+ IsDeleted: true, // mark the old entries as deleted
+ })
+ }
}
- if err := m.handleInterfacePostSaveHooks(iface, oldEnabled, newEnabled); err != nil {
+ if err := m.handleInterfacePostSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
return nil, fmt.Errorf("post-save hooks failed: %w", err)
}
@@ -618,60 +641,90 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return iface, nil
}
-func (m Manager) getInterfaceStateHistory(ctx context.Context, iface *domain.Interface) (oldEnabled, newEnabled bool) {
- oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
- if err != nil {
- return false, !iface.IsDisabled() // if the interface did not exist, we assume it was not enabled
- }
-
- return !oldInterface.IsDisabled(), !iface.IsDisabled()
+func (m Manager) getInterfaceStateHistory(
+ oldInterface *domain.Interface,
+ iface *domain.Interface,
+) (oldEnabled, newEnabled, routeTableChanged bool) {
+ return !oldInterface.IsDisabled(), !iface.IsDisabled(), oldInterface.RoutingTable != iface.RoutingTable
}
-func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error {
- if !iface.IsDisabled() {
- if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
- return fmt.Errorf("failed to update dns settings: %w", err)
- }
- } else {
- if err := m.quick.UnsetDNS(iface.Identifier); err != nil {
- return fmt.Errorf("failed to clear dns settings: %w", err)
+func (m Manager) handleInterfacePreSaveActions(ctx context.Context, iface *domain.Interface) error {
+ wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
+ if !ok {
+ slog.Warn("failed to perform pre-save actions", "interface", iface.Identifier,
+ "error", "no capable controller found")
+ return nil
+ }
+
+ // update DNS settings only for client interfaces
+ if iface.Type == domain.InterfaceTypeClient || iface.Type == domain.InterfaceTypeAny {
+ if !iface.IsDisabled() {
+ if err := wgQuickController.SetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
+ return fmt.Errorf("failed to update dns settings: %w", err)
+ }
+ } else {
+ if err := wgQuickController.UnsetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
+ return fmt.Errorf("failed to clear dns settings: %w", err)
+ }
}
}
return nil
}
-func (m Manager) handleInterfacePreSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error {
+func (m Manager) handleInterfacePreSaveHooks(
+ ctx context.Context,
+ iface *domain.Interface,
+ oldEnabled, newEnabled bool,
+) error {
if oldEnabled == newEnabled {
return nil // do nothing if state did not change
}
slog.Debug("executing pre-save hooks", "interface", iface.Identifier, "up", newEnabled)
+ wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
+ if !ok {
+ slog.Warn("failed to execute pre-save hooks", "interface", iface.Identifier, "up", newEnabled,
+ "error", "no capable controller found")
+ return nil
+ }
+
if newEnabled {
- if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
+ if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreUp); err != nil {
return fmt.Errorf("failed to execute pre-up hook: %w", err)
}
} else {
- if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
+ if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreDown); err != nil {
return fmt.Errorf("failed to execute pre-down hook: %w", err)
}
}
return nil
}
-func (m Manager) handleInterfacePostSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error {
+func (m Manager) handleInterfacePostSaveHooks(
+ ctx context.Context,
+ iface *domain.Interface,
+ oldEnabled, newEnabled bool,
+) error {
if oldEnabled == newEnabled {
return nil // do nothing if state did not change
}
slog.Debug("executing post-save hooks", "interface", iface.Identifier, "up", newEnabled)
+ wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
+ if !ok {
+ slog.Warn("failed to execute post-save hooks", "interface", iface.Identifier, "up", newEnabled,
+ "error", "no capable controller found")
+ return nil
+ }
+
if newEnabled {
- if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
+ if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostUp); err != nil {
return fmt.Errorf("failed to execute post-up hook: %w", err)
}
} else {
- if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
+ if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostDown); err != nil {
return fmt.Errorf("failed to execute post-down hook: %w", err)
}
}
@@ -799,7 +852,7 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
func (m Manager) importInterface(
ctx context.Context,
- backend InterfaceController,
+ backend domain.InterfaceController,
in *domain.PhysicalInterface,
peers []domain.PhysicalPeer,
) error {
@@ -901,13 +954,9 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
return nil
}
-func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
- iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
- if err != nil {
- return err
- }
+func (m Manager) deleteInterfacePeers(ctx context.Context, iface *domain.Interface, allPeers []domain.Peer) error {
for _, peer := range allPeers {
- err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier)
+ err := m.wg.GetController(*iface).DeletePeer(ctx, iface.Identifier, peer.Identifier)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
}
diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go
index f3c5364..e42af28 100644
--- a/internal/app/wireguard/wireguard_peers.go
+++ b/internal/app/wireguard/wireguard_peers.go
@@ -388,9 +388,20 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return fmt.Errorf("failed to delete peer %s: %w", id, err)
}
+ peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
+ if err != nil {
+ return fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
+ }
+
m.bus.Publish(app.TopicPeerDeleted, *peer)
// Update routes after peers have changed
- m.bus.Publish(app.TopicRouteUpdate, "peers updated")
+ m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
+ Interface: *iface,
+ AllowedIps: iface.GetAllowedIPs(peers),
+ FwMark: iface.FirewallMark,
+ Table: iface.GetRoutingTable(),
+ TableStr: iface.RoutingTable,
+ })
// Update interface after peers have changed
m.bus.Publish(app.TopicPeerInterfaceUpdated, peer.InterfaceIdentifier)
@@ -438,20 +449,26 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
// region helper-functions
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
- interfaces := make(map[domain.InterfaceIdentifier]struct{})
+ interfaces := make(map[domain.InterfaceIdentifier]domain.Interface)
for _, peer := range peers {
- iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
- if err != nil {
- return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
+ // get interface from db if it is not yet in the map
+ if _, ok := interfaces[peer.InterfaceIdentifier]; !ok {
+ iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
+ if err != nil {
+ return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
+ }
+ interfaces[peer.InterfaceIdentifier] = *iface
}
+ iface := interfaces[peer.InterfaceIdentifier]
+
// Always save the peer to the backend, regardless of disabled/expired state
// The backend will handle the disabled state appropriately
- err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
+ err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
- err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
+ err := m.wg.GetController(iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
@@ -475,13 +492,22 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
Peer: *peer,
},
})
-
- interfaces[peer.InterfaceIdentifier] = struct{}{}
}
// Update routes after peers have changed
- if len(interfaces) != 0 {
- m.bus.Publish(app.TopicRouteUpdate, "peers updated")
+ for id, iface := range interfaces {
+ interfacePeers, err := m.db.GetInterfacePeers(ctx, id)
+ if err != nil {
+ return fmt.Errorf("failed to re-load peers for interface %s: %w", id, err)
+ }
+
+ m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
+ Interface: iface,
+ AllowedIps: iface.GetAllowedIPs(interfacePeers),
+ FwMark: iface.FirewallMark,
+ Table: iface.GetRoutingTable(),
+ TableStr: iface.RoutingTable,
+ })
}
for iface := range interfaces {
diff --git a/internal/config/backend.go b/internal/config/backend.go
index fa8ff2e..cee7c8b 100644
--- a/internal/config/backend.go
+++ b/internal/config/backend.go
@@ -13,6 +13,7 @@ type Backend struct {
// Local Backend-specific configuration
IgnoredLocalInterfaces []string `yaml:"ignored_local_interfaces"` // A list of interface names that should be ignored by this backend (e.g., "wg0")
+ LocalResolvconfPrefix string `yaml:"local_resolvconf_prefix"` // The prefix to use for interface names when passing them to resolvconf.
// External Backend-specific configuration
diff --git a/internal/config/config.go b/internal/config/config.go
index 4203099..338dbf6 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -134,6 +134,9 @@ func defaultConfig() *Config {
cfg.Backend = Backend{
Default: LocalBackendName, // local backend is the default (using wgcrtl)
+ // Most resolconf implementations use "tun." as a prefix for interface names.
+ // But systemd's implementation uses no prefix, for example.
+ LocalResolvconfPrefix: "tun.",
}
cfg.Web = WebConfig{
diff --git a/internal/domain/interface.go b/internal/domain/interface.go
index 32fc1c0..01c720c 100644
--- a/internal/domain/interface.go
+++ b/internal/domain/interface.go
@@ -13,6 +13,7 @@ import (
"golang.org/x/sys/unix"
"github.com/h44z/wg-portal/internal"
+ "github.com/h44z/wg-portal/internal/config"
)
const (
@@ -132,17 +133,30 @@ func (i *Interface) GetConfigFileName() string {
return filename
}
+// GetAllowedIPs returns the allowed IPs for the interface depending on the interface type and peers.
+// For example, if the interface type is Server, the allowed IPs are the IPs of the peers.
+// If the interface type is Client, the allowed IPs correspond to the AllowedIPsStr of the peers.
func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
var allowedCidrs []Cidr
- for _, peer := range peers {
- for _, ip := range peer.Interface.Addresses {
- allowedCidrs = append(allowedCidrs, ip.HostAddr())
+ switch i.Type {
+ case InterfaceTypeServer, InterfaceTypeAny:
+ for _, peer := range peers {
+ for _, ip := range peer.Interface.Addresses {
+ allowedCidrs = append(allowedCidrs, ip.HostAddr())
+ }
+ if peer.ExtraAllowedIPsStr != "" {
+ extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr)
+ if err == nil {
+ allowedCidrs = append(allowedCidrs, extraIPs...)
+ }
+ }
}
- if peer.ExtraAllowedIPsStr != "" {
- extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr)
+ case InterfaceTypeClient:
+ for _, peer := range peers {
+ allowedIPs, err := CidrsFromString(peer.AllowedIPsStr.GetValue())
if err == nil {
- allowedCidrs = append(allowedCidrs, extraIPs...)
+ allowedCidrs = append(allowedCidrs, allowedIPs...)
}
}
}
@@ -159,6 +173,7 @@ func (i *Interface) ManageRoutingTable() bool {
//
// -1 if RoutingTable was set to "off" or an error occurred
func (i *Interface) GetRoutingTable() int {
+
routingTableStr := strings.ToLower(i.RoutingTable)
switch {
case routingTableStr == "":
@@ -166,6 +181,9 @@ func (i *Interface) GetRoutingTable() int {
case routingTableStr == "off":
return -1
case strings.HasPrefix(routingTableStr, "0x"):
+ if i.Backend != config.LocalBackendName {
+ return 0 // ignore numeric routing table numbers for non-local controllers
+ }
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
routingTable, err := strconv.ParseUint(numberStr, 16, 64)
if err != nil {
@@ -178,6 +196,9 @@ func (i *Interface) GetRoutingTable() int {
}
return int(routingTable)
default:
+ if i.Backend != config.LocalBackendName {
+ return 0 // ignore numeric routing table numbers for non-local controllers
+ }
routingTable, err := strconv.Atoi(routingTableStr)
if err != nil {
slog.Error("failed to parse routing table number", "table", routingTableStr, "error", err)
@@ -308,12 +329,18 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) {
}
type RoutingTableInfo struct {
- FwMark uint32
- Table int
+ Interface Interface
+ AllowedIps []Cidr
+ FwMark uint32
+ Table int
+ TableStr string // the routing table number as string (used by mikrotik, linux uses the numeric value)
+ IsDeleted bool // true if the interface was deleted, false otherwise
}
func (r RoutingTableInfo) String() string {
- return fmt.Sprintf("%d -> %d", r.FwMark, r.Table)
+ v4, v6 := CidrsPerFamily(r.AllowedIps)
+ return fmt.Sprintf("%s: fwmark=%d; table=%d; routes_4=%d; routes_6=%d", r.Interface.Identifier, r.FwMark, r.Table,
+ len(v4), len(v6))
}
func (r RoutingTableInfo) ManagementEnabled() bool {
diff --git a/internal/domain/interface_controller.go b/internal/domain/interface_controller.go
new file mode 100644
index 0000000..efd3207
--- /dev/null
+++ b/internal/domain/interface_controller.go
@@ -0,0 +1,27 @@
+package domain
+
+import "context"
+
+type InterfaceController interface {
+ GetId() InterfaceBackend
+ GetInterfaces(_ context.Context) ([]PhysicalInterface, error)
+ GetInterface(_ context.Context, id InterfaceIdentifier) (*PhysicalInterface, error)
+ GetPeers(_ context.Context, deviceId InterfaceIdentifier) ([]PhysicalPeer, error)
+ SaveInterface(
+ _ context.Context,
+ id InterfaceIdentifier,
+ updateFunc func(pi *PhysicalInterface) (*PhysicalInterface, error),
+ ) error
+ DeleteInterface(_ context.Context, id InterfaceIdentifier) error
+ SavePeer(
+ _ context.Context,
+ deviceId InterfaceIdentifier,
+ id PeerIdentifier,
+ updateFunc func(pp *PhysicalPeer) (*PhysicalPeer, error),
+ ) error
+ DeletePeer(_ context.Context, deviceId InterfaceIdentifier, id PeerIdentifier) error
+ PingAddresses(
+ ctx context.Context,
+ addr string,
+ ) (*PingerResult, error)
+}
diff --git a/internal/domain/interface_test.go b/internal/domain/interface_test.go
index 54aa74d..9f0ee50 100644
--- a/internal/domain/interface_test.go
+++ b/internal/domain/interface_test.go
@@ -5,6 +5,8 @@ import (
"time"
"github.com/stretchr/testify/assert"
+
+ "github.com/h44z/wg-portal/internal/config"
)
func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) {
@@ -37,8 +39,9 @@ func TestInterface_GetConfigFileNameReturnsCorrectFileName(t *testing.T) {
assert.Equal(t, expected, iface.GetConfigFileName())
}
-func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
+func TestInterface_GetAllowedIPsReturnsCorrectCidrsServerMode(t *testing.T) {
peer1 := Peer{
+ AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"},
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
@@ -46,16 +49,45 @@ func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
},
}
peer2 := Peer{
+ AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"},
+ ExtraAllowedIPsStr: "10.20.2.2/32",
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
},
},
}
- iface := &Interface{}
+ iface := &Interface{Type: InterfaceTypeServer}
expected := []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
+ {Cidr: "10.20.2.2/32", Addr: "10.20.2.2", NetLength: 32},
+ }
+ assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
+}
+
+func TestInterface_GetAllowedIPsReturnsCorrectCidrsClientMode(t *testing.T) {
+ peer1 := Peer{
+ AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"},
+ Interface: PeerInterfaceConfig{
+ Addresses: []Cidr{
+ {Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
+ },
+ },
+ }
+ peer2 := Peer{
+ AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"},
+ ExtraAllowedIPsStr: "10.20.2.2/32",
+ Interface: PeerInterfaceConfig{
+ Addresses: []Cidr{
+ {Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
+ },
+ },
+ }
+ iface := &Interface{Type: InterfaceTypeClient}
+ expected := []Cidr{
+ {Cidr: "192.168.2.2/32", Addr: "192.168.2.2", NetLength: 32},
+ {Cidr: "10.0.2.2/32", Addr: "10.0.2.2", NetLength: 32},
}
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
}
@@ -66,10 +98,22 @@ func TestInterface_ManageRoutingTableReturnsCorrectValue(t *testing.T) {
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
+
+ iface = &Interface{RoutingTable: "off", Backend: config.LocalBackendName}
+ assert.False(t, iface.ManageRoutingTable())
+
+ iface.RoutingTable = "100"
+ assert.True(t, iface.ManageRoutingTable())
+
+ iface = &Interface{RoutingTable: "off", Backend: "mikrotik-xxx"}
+ assert.False(t, iface.ManageRoutingTable())
+
+ iface.RoutingTable = "100"
+ assert.True(t, iface.ManageRoutingTable())
}
func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
- iface := &Interface{RoutingTable: ""}
+ iface := &Interface{RoutingTable: "", Backend: config.LocalBackendName}
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "off"
@@ -81,3 +125,17 @@ func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
iface.RoutingTable = "200"
assert.Equal(t, 200, iface.GetRoutingTable())
}
+
+func TestInterface_GetRoutingTableNonLocal(t *testing.T) {
+ iface := &Interface{RoutingTable: "off", Backend: "something different"}
+ assert.Equal(t, -1, iface.GetRoutingTable())
+
+ iface.RoutingTable = "0"
+ assert.Equal(t, 0, iface.GetRoutingTable())
+
+ iface.RoutingTable = "100"
+ assert.Equal(t, 0, iface.GetRoutingTable())
+
+ iface.RoutingTable = "abc"
+ assert.Equal(t, 0, iface.GetRoutingTable())
+}
diff --git a/internal/domain/ip.go b/internal/domain/ip.go
index ee67413..6081b70 100644
--- a/internal/domain/ip.go
+++ b/internal/domain/ip.go
@@ -26,6 +26,10 @@ func (c Cidr) IsValid() bool {
return c.Prefix().IsValid()
}
+func (c Cidr) EqualPrefix(other Cidr) bool {
+ return c.Addr == other.Addr && c.NetLength == other.NetLength
+}
+
func CidrFromString(str string) (Cidr, error) {
prefix, err := netip.ParsePrefix(strings.TrimSpace(str))
if err != nil {
@@ -199,3 +203,26 @@ func (c Cidr) Contains(other Cidr) bool {
return subnet.Contains(otherIP)
}
+
+// ContainsDefaultRoute returns true if the given CIDRs contain a default route.
+func ContainsDefaultRoute(cidrs []Cidr) bool {
+ for _, allowedIP := range cidrs {
+ if allowedIP.Prefix().Bits() == 0 {
+ return true
+ }
+ }
+
+ return false
+}
+
+// CidrsPerFamily returns a slice of CIDRs, one for each family (IPv4 and IPv6).
+func CidrsPerFamily(cidrs []Cidr) (ipv4, ipv6 []Cidr) {
+ for _, cidr := range cidrs {
+ if cidr.IsV4() {
+ ipv4 = append(ipv4, cidr)
+ } else {
+ ipv6 = append(ipv6, cidr)
+ }
+ }
+ return
+}
diff --git a/internal/lowlevel/mikrotik.go b/internal/lowlevel/mikrotik.go
index 49ef1d7..0c86a13 100644
--- a/internal/lowlevel/mikrotik.go
+++ b/internal/lowlevel/mikrotik.go
@@ -267,6 +267,7 @@ func parseHttpResponse[T any](resp *http.Response, err error) MikrotikApiRespons
}
defer func(Body io.ReadCloser) {
+ _, _ = io.Copy(io.Discard, Body) // ensure to empty the body
err := Body.Close()
if err != nil {
slog.Error("failed to close response body", "error", err)