mirror of
https://github.com/h44z/wg-portal.git
synced 2025-10-07 08:56:18 +00:00
mikrotik: allow to set DNS, wip: handle routes in wg-controller
This commit is contained in:
@@ -5,24 +5,21 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||
)
|
||||
|
||||
// region dependencies
|
||||
|
||||
type ControllerManager interface {
|
||||
// GetController returns the controller for the given interface.
|
||||
GetController(iface domain.Interface) domain.InterfaceController
|
||||
}
|
||||
|
||||
type InterfaceAndPeerDatabaseRepo interface {
|
||||
// GetAllInterfaces returns all interfaces
|
||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||
// GetInterfacePeers returns all peers for a given interface
|
||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||
// GetInterface returns the interface with the given identifier.
|
||||
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
|
||||
}
|
||||
|
||||
type EventBus interface {
|
||||
@@ -30,6 +27,25 @@ type EventBus interface {
|
||||
Subscribe(topic string, fn interface{}) error
|
||||
}
|
||||
|
||||
type RoutesController interface {
|
||||
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
|
||||
SetRoutes(
|
||||
ctx context.Context,
|
||||
interfaceId domain.InterfaceIdentifier,
|
||||
table int,
|
||||
fwMark uint32,
|
||||
cidrs []domain.Cidr,
|
||||
) error
|
||||
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
|
||||
RemoveRoutes(
|
||||
ctx context.Context,
|
||||
interfaceId domain.InterfaceIdentifier,
|
||||
table int,
|
||||
fwMark uint32,
|
||||
oldCidrs []domain.Cidr,
|
||||
) error
|
||||
}
|
||||
|
||||
// endregion dependencies
|
||||
|
||||
type routeRuleInfo struct {
|
||||
@@ -45,28 +61,24 @@ type routeRuleInfo struct {
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
|
||||
bus EventBus
|
||||
wg lowlevel.WireGuardClient
|
||||
nl lowlevel.NetlinkClient
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
bus EventBus
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
wgController ControllerManager
|
||||
}
|
||||
|
||||
// NewRouteManager creates a new route manager instance.
|
||||
func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
panic("failed to init wgctrl: " + err.Error())
|
||||
}
|
||||
|
||||
nl := &lowlevel.NetlinkManager{}
|
||||
|
||||
func NewRouteManager(
|
||||
cfg *config.Config,
|
||||
bus EventBus,
|
||||
db InterfaceAndPeerDatabaseRepo,
|
||||
wgController ControllerManager,
|
||||
) (*Manager, error) {
|
||||
m := &Manager{
|
||||
cfg: cfg,
|
||||
bus: bus,
|
||||
|
||||
db: db,
|
||||
wg: wg,
|
||||
nl: nl,
|
||||
db: db,
|
||||
wgController: wgController,
|
||||
}
|
||||
|
||||
m.connectToMessageBus()
|
||||
@@ -85,17 +97,21 @@ func (m Manager) StartBackgroundJobs(_ context.Context) {
|
||||
// this is a no-op for now
|
||||
}
|
||||
|
||||
func (m Manager) handleRouteUpdateEvent(srcDescription string) {
|
||||
slog.Debug("handling route update event", "source", srcDescription)
|
||||
func (m Manager) handleRouteUpdateEvent(info domain.RoutingTableInfo) {
|
||||
slog.Debug("handling route update event", "info", info.String())
|
||||
|
||||
err := m.syncRoutes(context.Background())
|
||||
if err != nil {
|
||||
slog.Error("failed to synchronize routes",
|
||||
"source", srcDescription,
|
||||
"error", err)
|
||||
if !info.ManagementEnabled() {
|
||||
return // route management disabled
|
||||
}
|
||||
|
||||
slog.Debug("routes synchronized", "source", srcDescription)
|
||||
err := m.syncRoutes(context.Background(), info)
|
||||
if err != nil {
|
||||
slog.Error("failed to synchronize routes",
|
||||
"info", info.String(), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("routes synchronized", "info", info.String())
|
||||
}
|
||||
|
||||
func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) {
|
||||
@@ -105,399 +121,40 @@ func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) {
|
||||
return // route management disabled
|
||||
}
|
||||
|
||||
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V4); err != nil {
|
||||
slog.Error("failed to remove v4 fwmark rules", "error", err)
|
||||
}
|
||||
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V6); err != nil {
|
||||
slog.Error("failed to remove v6 fwmark rules", "error", err)
|
||||
}
|
||||
|
||||
slog.Debug("routes removed", "table", info.String())
|
||||
}
|
||||
|
||||
func (m Manager) syncRoutes(ctx context.Context) error {
|
||||
interfaces, err := m.db.GetAllInterfaces(ctx)
|
||||
err := m.removeRoutes(context.Background(), info)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find all interfaces: %w", err)
|
||||
slog.Error("failed to synchronize routes",
|
||||
"info", info.String(), "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
rules := map[int][]routeRuleInfo{
|
||||
netlink.FAMILY_V4: nil,
|
||||
netlink.FAMILY_V6: nil,
|
||||
}
|
||||
for _, iface := range interfaces {
|
||||
if iface.IsDisabled() {
|
||||
continue // disabled interface does not need route entries
|
||||
}
|
||||
if !iface.ManageRoutingTable() {
|
||||
continue
|
||||
}
|
||||
|
||||
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find peers for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
allowedIPs := iface.GetAllowedIPs(peers)
|
||||
defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs)
|
||||
|
||||
link, err := m.nl.LinkByName(string(iface.Identifier))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find physical link for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
table, fwmark, err := m.getRoutingTableAndFwMark(&iface, link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get table and fwmark for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
if err := m.setInterfaceRoutes(link, table, allowedIPs); err != nil {
|
||||
return fmt.Errorf("failed to set routes for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V4, allowedIPs); err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated v4 routes for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V6, allowedIPs); err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
if table != 0 {
|
||||
rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], routeRuleInfo{
|
||||
ifaceId: iface.Identifier,
|
||||
fwMark: fwmark,
|
||||
table: table,
|
||||
family: netlink.FAMILY_V4,
|
||||
hasDefault: defRouteV4,
|
||||
})
|
||||
}
|
||||
if table != 0 {
|
||||
rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], routeRuleInfo{
|
||||
ifaceId: iface.Identifier,
|
||||
fwMark: fwmark,
|
||||
table: table,
|
||||
family: netlink.FAMILY_V6,
|
||||
hasDefault: defRouteV6,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return m.syncRouteRules(rules)
|
||||
slog.Debug("routes removed", "info", info.String())
|
||||
}
|
||||
|
||||
func (m Manager) syncRouteRules(allRules map[int][]routeRuleInfo) error {
|
||||
for family, rules := range allRules {
|
||||
// update fwmark rules
|
||||
if err := m.setFwMarkRules(rules, family); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update main rule
|
||||
if err := m.setMainRule(rules, family); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// cleanup old main rules
|
||||
if err := m.cleanupMainRule(rules, family); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) setFwMarkRules(rules []routeRuleInfo, family int) error {
|
||||
for _, rule := range rules {
|
||||
existingRules, err := m.nl.RuleList(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
|
||||
}
|
||||
|
||||
ruleExists := false
|
||||
for _, existingRule := range existingRules {
|
||||
if rule.fwMark == existingRule.Mark && rule.table == existingRule.Table {
|
||||
ruleExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ruleExists {
|
||||
continue // rule already exists, no need to recreate it
|
||||
}
|
||||
|
||||
// create missing rule
|
||||
if err := m.nl.RuleAdd(&netlink.Rule{
|
||||
Family: family,
|
||||
Table: rule.table,
|
||||
Mark: rule.fwMark,
|
||||
Invert: true,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
Priority: m.getRulePriority(existingRules),
|
||||
Mask: nil,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup rule for fwmark %d and table %d: %w", rule.fwMark, rule.table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) removeFwMarkRules(fwmark uint32, table int, family int) error {
|
||||
existingRules, err := m.nl.RuleList(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
|
||||
}
|
||||
|
||||
for _, existingRule := range existingRules {
|
||||
if fwmark == existingRule.Mark && table == existingRule.Table {
|
||||
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
|
||||
if err := m.nl.RuleDel(&existingRule); err != nil {
|
||||
return fmt.Errorf("failed to delete fwmark rule: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) setMainRule(rules []routeRuleInfo, family int) error {
|
||||
shouldHaveMainRule := false
|
||||
for _, rule := range rules {
|
||||
if rule.hasDefault == true {
|
||||
shouldHaveMainRule = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !shouldHaveMainRule {
|
||||
func (m Manager) syncRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
|
||||
rc, ok := m.wgController.GetController(info.Interface).(RoutesController)
|
||||
if !ok {
|
||||
slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier)
|
||||
return nil
|
||||
}
|
||||
|
||||
existingRules, err := m.nl.RuleList(family)
|
||||
err := rc.SetRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
|
||||
return fmt.Errorf("failed to set routes for interface %s: %w", info.Interface.Identifier, err)
|
||||
}
|
||||
|
||||
ruleExists := false
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
ruleExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ruleExists {
|
||||
return nil // rule already exists, skip re-creation
|
||||
}
|
||||
|
||||
if err := m.nl.RuleAdd(&netlink.Rule{
|
||||
Family: family,
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: 0,
|
||||
Priority: m.getMainRulePriority(existingRules),
|
||||
Mark: 0,
|
||||
Mask: nil,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup rule for main table: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) cleanupMainRule(rules []routeRuleInfo, family int) error {
|
||||
existingRules, err := m.nl.RuleList(family)
|
||||
func (m Manager) removeRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
|
||||
rc, ok := m.wgController.GetController(info.Interface).(RoutesController)
|
||||
if !ok {
|
||||
slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := rc.RemoveRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
|
||||
}
|
||||
|
||||
shouldHaveMainRule := false
|
||||
for _, rule := range rules {
|
||||
if rule.hasDefault == true {
|
||||
shouldHaveMainRule = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mainRules := 0
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
mainRules++
|
||||
}
|
||||
}
|
||||
|
||||
removalCount := 0
|
||||
if mainRules > 1 {
|
||||
removalCount = mainRules - 1 // we only want one single rule
|
||||
}
|
||||
if !shouldHaveMainRule {
|
||||
removalCount = mainRules
|
||||
}
|
||||
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
if removalCount > 0 {
|
||||
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
|
||||
if err := m.nl.RuleDel(&existingRule); err != nil {
|
||||
return fmt.Errorf("failed to delete main rule: %w", err)
|
||||
}
|
||||
removalCount--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
|
||||
prio := m.cfg.Advanced.RulePrioOffset
|
||||
for {
|
||||
isFresh := true
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Priority == prio {
|
||||
isFresh = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFresh {
|
||||
break
|
||||
} else {
|
||||
prio++
|
||||
}
|
||||
}
|
||||
return prio
|
||||
}
|
||||
|
||||
func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
|
||||
prio := 32700 // linux main rule has a prio of 32766
|
||||
for {
|
||||
isFresh := true
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Priority == prio {
|
||||
isFresh = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFresh {
|
||||
break
|
||||
} else {
|
||||
prio--
|
||||
}
|
||||
}
|
||||
return prio
|
||||
}
|
||||
|
||||
func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error {
|
||||
for _, allowedIP := range allowedIPs {
|
||||
err := m.nl.RouteReplace(&netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: allowedIP.IpNet(),
|
||||
Table: table,
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIPs []domain.Cidr) error {
|
||||
rawRoutes, err := m.nl.RouteListFiltered(family, &netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Table: unix.RT_TABLE_UNSPEC, // all tables
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch raw routes: %w", err)
|
||||
}
|
||||
for _, rawRoute := range rawRoutes {
|
||||
if rawRoute.Dst == nil { // handle default route
|
||||
var netlinkAddr domain.Cidr
|
||||
if family == netlink.FAMILY_V4 {
|
||||
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
|
||||
} else {
|
||||
netlinkAddr, _ = domain.CidrFromString("::/0")
|
||||
}
|
||||
rawRoute.Dst = netlinkAddr.IpNet()
|
||||
}
|
||||
|
||||
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
|
||||
remove := true
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if netlinkAddr == allowedIP {
|
||||
remove = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !remove {
|
||||
continue
|
||||
}
|
||||
|
||||
err := m.nl.RouteDel(&rawRoute)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
|
||||
}
|
||||
return fmt.Errorf("failed to remove routes for interface %s: %w", info.Interface.Identifier, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, link netlink.Link) (
|
||||
table int,
|
||||
fwmark uint32,
|
||||
err error,
|
||||
) {
|
||||
table = iface.GetRoutingTable()
|
||||
fwmark = iface.FirewallMark
|
||||
|
||||
if fwmark == 0 {
|
||||
// generate a new (temporary) firewall mark based on the interface index
|
||||
fwmark = uint32(m.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
|
||||
slog.Debug("using fwmark to handle routes",
|
||||
"interface", iface.Identifier,
|
||||
"fwmark", fwmark)
|
||||
|
||||
// apply the temporary fwmark to the wireguard interface
|
||||
err = m.setFwMark(iface.Identifier, int(fwmark))
|
||||
}
|
||||
if table == 0 {
|
||||
table = int(fwmark) // generate a new routing table base on interface index
|
||||
slog.Debug("using routing table to handle default routes",
|
||||
"interface", iface.Identifier,
|
||||
"table", table)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m Manager) setFwMark(id domain.InterfaceIdentifier, fwmark int) error {
|
||||
err := m.wg.ConfigureDevice(string(id), wgtypes.Config{
|
||||
FirewallMark: &fwmark,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update fwmark to: %d: %w", fwmark, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) containsDefaultRoute(allowedIPs []domain.Cidr) (ipV4, ipV6 bool) {
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if ipV4 && ipV6 {
|
||||
break // speed up
|
||||
}
|
||||
|
||||
if allowedIP.Prefix().Bits() == 0 {
|
||||
if allowedIP.IsV4() {
|
||||
ipV4 = true
|
||||
} else {
|
||||
ipV6 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"maps"
|
||||
@@ -12,33 +11,9 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type InterfaceController interface {
|
||||
GetId() domain.InterfaceBackend
|
||||
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
SaveInterface(
|
||||
_ context.Context,
|
||||
id domain.InterfaceIdentifier,
|
||||
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
|
||||
) error
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(
|
||||
_ context.Context,
|
||||
deviceId domain.InterfaceIdentifier,
|
||||
id domain.PeerIdentifier,
|
||||
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
|
||||
) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
PingAddresses(
|
||||
ctx context.Context,
|
||||
addr string,
|
||||
) (*domain.PingerResult, error)
|
||||
}
|
||||
|
||||
type backendInstance struct {
|
||||
Config config.BackendBase // Config is the configuration for the backend instance.
|
||||
Implementation InterfaceController
|
||||
Implementation domain.InterfaceController
|
||||
}
|
||||
|
||||
type ControllerManager struct {
|
||||
@@ -118,11 +93,11 @@ func (c *ControllerManager) logRegisteredControllers() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController {
|
||||
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) domain.InterfaceController {
|
||||
return c.getController(backend, "").Implementation
|
||||
}
|
||||
|
||||
func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController {
|
||||
func (c *ControllerManager) GetController(iface domain.Interface) domain.InterfaceController {
|
||||
return c.getController(iface.Backend, iface.Identifier).Implementation
|
||||
}
|
||||
|
||||
|
@@ -38,9 +38,9 @@ type InterfaceAndPeerDatabaseRepo interface {
|
||||
}
|
||||
|
||||
type WgQuickController interface {
|
||||
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
|
||||
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||
UnsetDNS(id domain.InterfaceIdentifier) error
|
||||
ExecuteInterfaceHook(ctx context.Context, id domain.InterfaceIdentifier, hookCmd string) error
|
||||
SetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||
UnsetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
|
||||
}
|
||||
|
||||
type EventBus interface {
|
||||
@@ -53,11 +53,10 @@ type EventBus interface {
|
||||
// endregion dependencies
|
||||
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
bus EventBus
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
wg *ControllerManager
|
||||
quick WgQuickController
|
||||
cfg *config.Config
|
||||
bus EventBus
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
wg *ControllerManager
|
||||
|
||||
userLockMap *sync.Map
|
||||
}
|
||||
@@ -66,7 +65,6 @@ func NewWireGuardManager(
|
||||
cfg *config.Config,
|
||||
bus EventBus,
|
||||
wg *ControllerManager,
|
||||
quick WgQuickController,
|
||||
db InterfaceAndPeerDatabaseRepo,
|
||||
) (*Manager, error) {
|
||||
m := &Manager{
|
||||
@@ -74,7 +72,6 @@ func NewWireGuardManager(
|
||||
bus: bus,
|
||||
wg: wg,
|
||||
db: db,
|
||||
quick: quick,
|
||||
userLockMap: &sync.Map{},
|
||||
}
|
||||
|
||||
|
@@ -453,7 +453,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
return err
|
||||
}
|
||||
|
||||
existingInterface, err := m.db.GetInterface(ctx, id)
|
||||
existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find interface %s: %w", id, err)
|
||||
}
|
||||
@@ -468,15 +468,16 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
|
||||
physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id)
|
||||
|
||||
if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
|
||||
if err := m.handleInterfacePreSaveHooks(ctx, existingInterface, !existingInterface.IsDisabled(),
|
||||
false); err != nil {
|
||||
return fmt.Errorf("pre-delete hooks failed: %w", err)
|
||||
}
|
||||
|
||||
if err := m.handleInterfacePreSaveActions(existingInterface); err != nil {
|
||||
if err := m.handleInterfacePreSaveActions(ctx, existingInterface); err != nil {
|
||||
return fmt.Errorf("pre-delete actions failed: %w", err)
|
||||
}
|
||||
|
||||
if err := m.deleteInterfacePeers(ctx, id); err != nil {
|
||||
if err := m.deleteInterfacePeers(ctx, existingInterface, existingPeers); err != nil {
|
||||
return fmt.Errorf("peer deletion failure: %w", err)
|
||||
}
|
||||
|
||||
@@ -493,11 +494,18 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
fwMark = physicalInterface.FirewallMark
|
||||
}
|
||||
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
|
||||
FwMark: fwMark,
|
||||
Table: existingInterface.GetRoutingTable(),
|
||||
Interface: *existingInterface,
|
||||
AllowedIps: existingInterface.GetAllowedIPs(existingPeers),
|
||||
FwMark: fwMark,
|
||||
Table: existingInterface.GetRoutingTable(),
|
||||
})
|
||||
|
||||
if err := m.handleInterfacePostSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
|
||||
if err := m.handleInterfacePostSaveHooks(
|
||||
ctx,
|
||||
existingInterface,
|
||||
!existingInterface.IsDisabled(),
|
||||
false,
|
||||
); err != nil {
|
||||
return fmt.Errorf("post-delete hooks failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -518,11 +526,11 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
|
||||
oldEnabled, newEnabled := m.getInterfaceStateHistory(ctx, iface)
|
||||
|
||||
if err := m.handleInterfacePreSaveHooks(iface, oldEnabled, newEnabled); err != nil {
|
||||
if err := m.handleInterfacePreSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
|
||||
return nil, fmt.Errorf("pre-save hooks failed: %w", err)
|
||||
}
|
||||
|
||||
if err := m.handleInterfacePreSaveActions(iface); err != nil {
|
||||
if err := m.handleInterfacePreSaveActions(ctx, iface); err != nil {
|
||||
return nil, fmt.Errorf("pre-save actions failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -575,14 +583,21 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
fwMark = physicalInterface.FirewallMark
|
||||
}
|
||||
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
|
||||
FwMark: fwMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
Interface: *iface,
|
||||
AllowedIps: iface.GetAllowedIPs(peers),
|
||||
FwMark: fwMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
})
|
||||
} else {
|
||||
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
|
||||
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
|
||||
Interface: *iface,
|
||||
AllowedIps: iface.GetAllowedIPs(peers),
|
||||
FwMark: iface.FirewallMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
})
|
||||
}
|
||||
|
||||
if err := m.handleInterfacePostSaveHooks(iface, oldEnabled, newEnabled); err != nil {
|
||||
if err := m.handleInterfacePostSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
|
||||
return nil, fmt.Errorf("post-save hooks failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -627,51 +642,83 @@ func (m Manager) getInterfaceStateHistory(ctx context.Context, iface *domain.Int
|
||||
return !oldInterface.IsDisabled(), !iface.IsDisabled()
|
||||
}
|
||||
|
||||
func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error {
|
||||
if !iface.IsDisabled() {
|
||||
if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
|
||||
return fmt.Errorf("failed to update dns settings: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := m.quick.UnsetDNS(iface.Identifier); err != nil {
|
||||
return fmt.Errorf("failed to clear dns settings: %w", err)
|
||||
func (m Manager) handleInterfacePreSaveActions(ctx context.Context, iface *domain.Interface) error {
|
||||
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
|
||||
if !ok {
|
||||
slog.Warn("failed to perform pre-save actions", "interface", iface.Identifier,
|
||||
"error", "no capable controller found")
|
||||
return nil
|
||||
}
|
||||
|
||||
// update DNS settings only for client interfaces
|
||||
if iface.Type == domain.InterfaceTypeClient || iface.Type == domain.InterfaceTypeAny {
|
||||
if !iface.IsDisabled() {
|
||||
if err := wgQuickController.SetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
|
||||
return fmt.Errorf("failed to update dns settings: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := wgQuickController.UnsetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
|
||||
return fmt.Errorf("failed to clear dns settings: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) handleInterfacePreSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error {
|
||||
func (m Manager) handleInterfacePreSaveHooks(
|
||||
ctx context.Context,
|
||||
iface *domain.Interface,
|
||||
oldEnabled, newEnabled bool,
|
||||
) error {
|
||||
if oldEnabled == newEnabled {
|
||||
return nil // do nothing if state did not change
|
||||
}
|
||||
|
||||
slog.Debug("executing pre-save hooks", "interface", iface.Identifier, "up", newEnabled)
|
||||
|
||||
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
|
||||
if !ok {
|
||||
slog.Warn("failed to execute pre-save hooks", "interface", iface.Identifier, "up", newEnabled,
|
||||
"error", "no capable controller found")
|
||||
return nil
|
||||
}
|
||||
|
||||
if newEnabled {
|
||||
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
|
||||
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreUp); err != nil {
|
||||
return fmt.Errorf("failed to execute pre-up hook: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
|
||||
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreDown); err != nil {
|
||||
return fmt.Errorf("failed to execute pre-down hook: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) handleInterfacePostSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error {
|
||||
func (m Manager) handleInterfacePostSaveHooks(
|
||||
ctx context.Context,
|
||||
iface *domain.Interface,
|
||||
oldEnabled, newEnabled bool,
|
||||
) error {
|
||||
if oldEnabled == newEnabled {
|
||||
return nil // do nothing if state did not change
|
||||
}
|
||||
|
||||
slog.Debug("executing post-save hooks", "interface", iface.Identifier, "up", newEnabled)
|
||||
|
||||
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
|
||||
if !ok {
|
||||
slog.Warn("failed to execute post-save hooks", "interface", iface.Identifier, "up", newEnabled,
|
||||
"error", "no capable controller found")
|
||||
return nil
|
||||
}
|
||||
|
||||
if newEnabled {
|
||||
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
|
||||
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostUp); err != nil {
|
||||
return fmt.Errorf("failed to execute post-up hook: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
|
||||
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostDown); err != nil {
|
||||
return fmt.Errorf("failed to execute post-down hook: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -799,7 +846,7 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
|
||||
|
||||
func (m Manager) importInterface(
|
||||
ctx context.Context,
|
||||
backend InterfaceController,
|
||||
backend domain.InterfaceController,
|
||||
in *domain.PhysicalInterface,
|
||||
peers []domain.PhysicalPeer,
|
||||
) error {
|
||||
@@ -901,13 +948,9 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func (m Manager) deleteInterfacePeers(ctx context.Context, iface *domain.Interface, allPeers []domain.Peer) error {
|
||||
for _, peer := range allPeers {
|
||||
err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier)
|
||||
err := m.wg.GetController(*iface).DeletePeer(ctx, iface.Identifier, peer.Identifier)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
@@ -388,9 +388,19 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
return fmt.Errorf("failed to delete peer %s: %w", id, err)
|
||||
}
|
||||
|
||||
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
m.bus.Publish(app.TopicPeerDeleted, *peer)
|
||||
// Update routes after peers have changed
|
||||
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
|
||||
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
|
||||
Interface: *iface,
|
||||
AllowedIps: iface.GetAllowedIPs(peers),
|
||||
FwMark: iface.FirewallMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
})
|
||||
// Update interface after peers have changed
|
||||
m.bus.Publish(app.TopicPeerInterfaceUpdated, peer.InterfaceIdentifier)
|
||||
|
||||
@@ -438,20 +448,28 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
|
||||
// region helper-functions
|
||||
|
||||
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||
interfaces := make(map[domain.InterfaceIdentifier]struct{})
|
||||
interfaces := make(map[domain.InterfaceIdentifier]domain.Interface)
|
||||
interfacePeers := make(map[domain.InterfaceIdentifier][]domain.Peer)
|
||||
|
||||
for _, peer := range peers {
|
||||
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
|
||||
// get interface from db if it is not yet in the map
|
||||
if _, ok := interfaces[peer.InterfaceIdentifier]; !ok {
|
||||
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
|
||||
}
|
||||
interfaces[peer.InterfaceIdentifier] = *iface
|
||||
}
|
||||
|
||||
iface := interfaces[peer.InterfaceIdentifier]
|
||||
interfacePeers[iface.Identifier] = append(interfacePeers[iface.Identifier], *peer)
|
||||
|
||||
// Always save the peer to the backend, regardless of disabled/expired state
|
||||
// The backend will handle the disabled state appropriately
|
||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
|
||||
err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||
err := m.wg.GetController(iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, peer)
|
||||
return pp, nil
|
||||
@@ -475,13 +493,16 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||
Peer: *peer,
|
||||
},
|
||||
})
|
||||
|
||||
interfaces[peer.InterfaceIdentifier] = struct{}{}
|
||||
}
|
||||
|
||||
// Update routes after peers have changed
|
||||
if len(interfaces) != 0 {
|
||||
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
|
||||
for id, iface := range interfaces {
|
||||
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
|
||||
Interface: iface,
|
||||
AllowedIps: iface.GetAllowedIPs(interfacePeers[id]),
|
||||
FwMark: iface.FirewallMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
})
|
||||
}
|
||||
|
||||
for iface := range interfaces {
|
||||
|
Reference in New Issue
Block a user