mirror of
https://github.com/h44z/wg-portal.git
synced 2025-08-09 15:02:24 +00:00
route management
This commit is contained in:
parent
2113999b22
commit
22949963cf
@ -8,6 +8,7 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/app/auth"
|
||||
"github.com/h44z/wg-portal/internal/app/configfile"
|
||||
"github.com/h44z/wg-portal/internal/app/mail"
|
||||
"github.com/h44z/wg-portal/internal/app/route"
|
||||
"github.com/h44z/wg-portal/internal/app/users"
|
||||
"github.com/h44z/wg-portal/internal/app/wireguard"
|
||||
"os"
|
||||
@ -86,6 +87,10 @@ func main() {
|
||||
internal.AssertNoError(err)
|
||||
auditRecorder.StartBackgroundJobs(ctx)
|
||||
|
||||
routeManager, err := route.NewRouteManager(cfg, eventBus, database)
|
||||
internal.AssertNoError(err)
|
||||
routeManager.StartBackgroundJobs(ctx)
|
||||
|
||||
backend, err := app.New(cfg, eventBus, authenticator, userManager, wireGuardManager,
|
||||
statisticsCollector, cfgFileManager, mailManager)
|
||||
internal.AssertNoError(err)
|
||||
|
@ -6,9 +6,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"os"
|
||||
@ -299,202 +297,6 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgRepo) SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error {
|
||||
table := iface.GetRoutingTable()
|
||||
if table == -2 {
|
||||
logrus.Trace("ignoring route update, feature disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !iface.IsDisabled() {
|
||||
return r.setupRoutes(iface, peers)
|
||||
} else {
|
||||
return r.cleanupRoutes(iface, peers)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *WgRepo) setupRoutes(iface *domain.Interface, peers []domain.Peer) error {
|
||||
link, err := r.nl.LinkByName(string(iface.Identifier))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
if link.Attrs().OperState == netlink.OperDown {
|
||||
return nil // cannot set route for interface that is down
|
||||
}
|
||||
|
||||
table, fwmark, err := r.getRoutingTableAndFwMark(iface, peers, link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get table and fwmark: %w", err)
|
||||
}
|
||||
|
||||
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
|
||||
allowedIPs := iface.GetAllowedIPs(peers)
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if allowedIP.Prefix().Bits() == 0 { // default route handling
|
||||
if err := r.nl.RouteReplace(&netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: allowedIP.IpNet(),
|
||||
Table: table,
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||
}
|
||||
|
||||
family := netlink.FAMILY_V4
|
||||
if !allowedIP.IsV4() {
|
||||
family = netlink.FAMILY_V6
|
||||
}
|
||||
if err := r.nl.RuleAdd(&netlink.Rule{
|
||||
Family: family,
|
||||
Table: table,
|
||||
Mark: fwmark,
|
||||
Invert: true,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
Priority: -1,
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup rule for fwmark %d: %w", fwmark, err)
|
||||
}
|
||||
if err := r.nl.RuleAdd(&netlink.Rule{
|
||||
Family: family,
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: 0,
|
||||
Priority: -1,
|
||||
Mark: -1,
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup rule for main table: %w", err)
|
||||
}
|
||||
} else {
|
||||
err := r.nl.RouteReplace(&netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: allowedIP.IpNet(),
|
||||
Table: table,
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V4); err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated v4 routes: %w", err)
|
||||
}
|
||||
if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V6); err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated v6 routes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgRepo) removeDeprecatedRoutes(link netlink.Link, allowedIPs []domain.Cidr, family int) error {
|
||||
rawRoutes, err := r.nl.RouteListFiltered(family, &netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Table: unix.RT_TABLE_UNSPEC, // all tables
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch raw routes: %w", err)
|
||||
}
|
||||
for _, rawRoute := range rawRoutes {
|
||||
var netlinkAddr domain.Cidr
|
||||
if rawRoute.Dst == nil {
|
||||
if family == netlink.FAMILY_V4 {
|
||||
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
|
||||
} else {
|
||||
netlinkAddr, _ = domain.CidrFromString("::/0")
|
||||
}
|
||||
} else {
|
||||
netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst)
|
||||
}
|
||||
remove := true
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if netlinkAddr == allowedIP {
|
||||
remove = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !remove {
|
||||
continue
|
||||
}
|
||||
|
||||
err := r.nl.RouteDel(&rawRoute)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgRepo) cleanupRoutes(iface *domain.Interface, peers []domain.Peer) error {
|
||||
link, err := r.nl.LinkByName(string(iface.Identifier))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
table, _, err := r.getRoutingTableAndFwMark(iface, peers, link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get table and fwmark: %w", err)
|
||||
}
|
||||
|
||||
if table == 0 {
|
||||
return nil // noting to remove
|
||||
}
|
||||
|
||||
delRule := netlink.NewRule()
|
||||
delRule.Family = netlink.FAMILY_ALL
|
||||
delRule.Table = table
|
||||
if err := r.nl.RuleDel(delRule); err != nil && !errors.Is(err, unix.ENOENT) {
|
||||
return fmt.Errorf("failed to delete rule for table %d: %w", table, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgRepo) getRoutingTableAndFwMark(iface *domain.Interface, peers []domain.Peer, link netlink.Link) (table, fwmark int, err error) {
|
||||
allowedIPs := iface.GetAllowedIPs(peers)
|
||||
containsDefaultRoute := false
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if allowedIP.Prefix().Bits() == 0 {
|
||||
containsDefaultRoute = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
table = iface.GetRoutingTable()
|
||||
fwmark = int(iface.FirewallMark)
|
||||
|
||||
if containsDefaultRoute && table <= 0 {
|
||||
table = 20000 + link.Attrs().Index // generate a new routing table base on interface index
|
||||
logrus.Debugf("using routing table %d to handle default routes", table)
|
||||
}
|
||||
if containsDefaultRoute && fwmark == 0 {
|
||||
fwmark = 20000 + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index
|
||||
logrus.Debugf("using fwmark %d to handle default routes", table)
|
||||
|
||||
// apply the fwmark
|
||||
err = r.wg.ConfigureDevice(string(iface.Identifier), wgtypes.Config{
|
||||
FirewallMark: &fwmark,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to update temporary fwmark to: %d: %w", fwmark, err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
|
||||
if err := r.deleteLowLevelInterface(id); err != nil {
|
||||
return err
|
||||
|
@ -5,3 +5,5 @@ const TopicUserRegistered = "user:registered"
|
||||
const TopicUserDisabled = "user:disabled"
|
||||
const TopicUserDeleted = "user:deleted"
|
||||
const TopicAuthLogin = "auth:login"
|
||||
const TopicRouteUpdate = "route:update"
|
||||
const TopicRouteRemove = "route:remove"
|
||||
|
11
internal/app/route/repos.go
Normal file
11
internal/app/route/repos.go
Normal file
@ -0,0 +1,11 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
)
|
||||
|
||||
type InterfaceAndPeerDatabaseRepo interface {
|
||||
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
|
||||
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
|
||||
}
|
453
internal/app/route/routes.go
Normal file
453
internal/app/route/routes.go
Normal file
@ -0,0 +1,453 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/config"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||
"github.com/sirupsen/logrus"
|
||||
evbus "github.com/vardius/message-bus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type defaultRouteRule struct {
|
||||
ifaceId domain.InterfaceIdentifier
|
||||
fwMark int
|
||||
table int
|
||||
family int
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
bus evbus.MessageBus
|
||||
|
||||
wg lowlevel.WireGuardClient
|
||||
nl lowlevel.NetlinkClient
|
||||
db InterfaceAndPeerDatabaseRepo
|
||||
}
|
||||
|
||||
func NewRouteManager(cfg *config.Config, bus evbus.MessageBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
panic("failed to init wgctrl: " + err.Error())
|
||||
}
|
||||
|
||||
nl := &lowlevel.NetlinkManager{}
|
||||
|
||||
m := &Manager{
|
||||
cfg: cfg,
|
||||
bus: bus,
|
||||
|
||||
db: db,
|
||||
wg: wg,
|
||||
nl: nl,
|
||||
}
|
||||
|
||||
m.connectToMessageBus()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m Manager) connectToMessageBus() {
|
||||
_ = m.bus.Subscribe(app.TopicRouteUpdate, m.handleRouteUpdateEvent)
|
||||
_ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteUpdateEvent)
|
||||
}
|
||||
|
||||
func (m Manager) StartBackgroundJobs(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (m Manager) handleRouteUpdateEvent(srcDescription string) {
|
||||
logrus.Debugf("handling route update event: %s", srcDescription)
|
||||
|
||||
err := m.syncRoutes(context.Background())
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to synchronize routes for event %s: %v", srcDescription, err)
|
||||
}
|
||||
|
||||
logrus.Debugf("routes synchronized, event: %s", srcDescription)
|
||||
}
|
||||
|
||||
func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) {
|
||||
logrus.Debugf("handling route remove event for: %s", info.String())
|
||||
|
||||
if info.Table == -2 {
|
||||
return // route management disabled
|
||||
}
|
||||
|
||||
if err := m.removeFwMarkRules(info.FwMark, info.FwMark, netlink.FAMILY_V4); err != nil {
|
||||
logrus.Errorf("failed to remove v4 fwmark rules: %v", err)
|
||||
}
|
||||
if err := m.removeFwMarkRules(info.FwMark, info.FwMark, netlink.FAMILY_V6); err != nil {
|
||||
logrus.Errorf("failed to remove v6 fwmark rules: %v", err)
|
||||
}
|
||||
|
||||
logrus.Debugf("routes removed, table: %s", info.String())
|
||||
}
|
||||
|
||||
func (m Manager) syncRoutes(ctx context.Context) error {
|
||||
interfaces, err := m.db.GetAllInterfaces(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find all interfaces: %w", err)
|
||||
}
|
||||
|
||||
rules := map[int][]defaultRouteRule{
|
||||
netlink.FAMILY_V4: nil,
|
||||
netlink.FAMILY_V6: nil,
|
||||
}
|
||||
for _, iface := range interfaces {
|
||||
if iface.IsDisabled() {
|
||||
continue // disabled interface does not need route entries
|
||||
}
|
||||
|
||||
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find peers for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
allowedIPs := iface.GetAllowedIPs(peers)
|
||||
defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs)
|
||||
|
||||
link, err := m.nl.LinkByName(string(iface.Identifier))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find physical link for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
table, fwmark, err := m.getRoutingTableAndFwMark(&iface, allowedIPs, link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get table and fwmark for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
if err := m.setInterfaceRoutes(link, table, allowedIPs); err != nil {
|
||||
return fmt.Errorf("failed to set routes for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V4, allowedIPs); err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated v4 routes for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V6, allowedIPs); err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
if defRouteV4 {
|
||||
rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], defaultRouteRule{
|
||||
ifaceId: iface.Identifier,
|
||||
fwMark: fwmark,
|
||||
table: table,
|
||||
family: netlink.FAMILY_V4,
|
||||
})
|
||||
}
|
||||
if defRouteV6 {
|
||||
rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], defaultRouteRule{
|
||||
ifaceId: iface.Identifier,
|
||||
fwMark: fwmark,
|
||||
table: table,
|
||||
family: netlink.FAMILY_V6,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return m.syncRouteRules(rules)
|
||||
}
|
||||
|
||||
func (m Manager) syncRouteRules(allRules map[int][]defaultRouteRule) error {
|
||||
for family, rules := range allRules {
|
||||
// update fwmark rules
|
||||
if err := m.setFwMarkRules(rules, family); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update main rule
|
||||
if err := m.setMainRule(rules, family); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// cleanup old main rules
|
||||
if err := m.cleanupMainRule(rules, family); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) setFwMarkRules(rules []defaultRouteRule, family int) error {
|
||||
for _, rule := range rules {
|
||||
existingRules, err := m.nl.RuleList(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
|
||||
}
|
||||
|
||||
ruleExists := false
|
||||
for _, existingRule := range existingRules {
|
||||
if rule.fwMark == existingRule.Mark && rule.table == existingRule.Table {
|
||||
ruleExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ruleExists {
|
||||
continue // rule already exists, no need to recreate it
|
||||
}
|
||||
|
||||
// create missing rule
|
||||
if err := m.nl.RuleAdd(&netlink.Rule{
|
||||
Family: family,
|
||||
Table: rule.table,
|
||||
Mark: rule.fwMark,
|
||||
Invert: true,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
Priority: m.getRulePriority(existingRules),
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup rule for fwmark %d and table %d: %w", rule.fwMark, rule.table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) removeFwMarkRules(fwmark, table int, family int) error {
|
||||
existingRules, err := m.nl.RuleList(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
|
||||
}
|
||||
|
||||
for _, existingRule := range existingRules {
|
||||
if fwmark == existingRule.Mark && table == existingRule.Table {
|
||||
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
|
||||
if err := m.nl.RuleDel(&existingRule); err != nil {
|
||||
return fmt.Errorf("failed to delete fwmark rule: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) setMainRule(rules []defaultRouteRule, family int) error {
|
||||
shouldHaveMainRule := len(rules) != 0
|
||||
if !shouldHaveMainRule {
|
||||
return nil
|
||||
}
|
||||
|
||||
existingRules, err := m.nl.RuleList(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
|
||||
}
|
||||
|
||||
ruleExists := false
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
ruleExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ruleExists {
|
||||
return nil // rule already exists, skip re-creation
|
||||
}
|
||||
|
||||
if err := m.nl.RuleAdd(&netlink.Rule{
|
||||
Family: family,
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: 0,
|
||||
Priority: m.getMainRulePriority(existingRules),
|
||||
Mark: -1,
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup rule for main table: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) cleanupMainRule(rules []defaultRouteRule, family int) error {
|
||||
existingRules, err := m.nl.RuleList(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
|
||||
}
|
||||
|
||||
shouldHaveMainRule := len(rules) != 0
|
||||
|
||||
mainRules := 0
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
mainRules++
|
||||
}
|
||||
}
|
||||
|
||||
removalCount := 0
|
||||
if mainRules > 1 {
|
||||
removalCount = mainRules - 1 // we only want one single rule
|
||||
}
|
||||
if !shouldHaveMainRule {
|
||||
removalCount = mainRules
|
||||
}
|
||||
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
|
||||
if removalCount > 0 {
|
||||
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
|
||||
if err := m.nl.RuleDel(&existingRule); err != nil {
|
||||
return fmt.Errorf("failed to delete main rule: %w", err)
|
||||
}
|
||||
removalCount--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
|
||||
prio := m.cfg.Advanced.RulePrioOffset
|
||||
for {
|
||||
isFresh := true
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Priority == prio {
|
||||
isFresh = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFresh {
|
||||
break
|
||||
} else {
|
||||
prio++
|
||||
}
|
||||
}
|
||||
return prio
|
||||
}
|
||||
|
||||
func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
|
||||
prio := 32700 // linux main rule has a prio of 32766
|
||||
for {
|
||||
isFresh := true
|
||||
for _, existingRule := range existingRules {
|
||||
if existingRule.Priority == prio {
|
||||
isFresh = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isFresh {
|
||||
break
|
||||
} else {
|
||||
prio--
|
||||
}
|
||||
}
|
||||
return prio
|
||||
}
|
||||
|
||||
func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error {
|
||||
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
|
||||
for _, allowedIP := range allowedIPs {
|
||||
err := m.nl.RouteReplace(&netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: allowedIP.IpNet(),
|
||||
Table: table,
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIPs []domain.Cidr) error {
|
||||
rawRoutes, err := m.nl.RouteListFiltered(family, &netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Table: unix.RT_TABLE_UNSPEC, // all tables
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch raw routes: %w", err)
|
||||
}
|
||||
for _, rawRoute := range rawRoutes {
|
||||
var netlinkAddr domain.Cidr
|
||||
if rawRoute.Dst == nil {
|
||||
if family == netlink.FAMILY_V4 {
|
||||
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
|
||||
} else {
|
||||
netlinkAddr, _ = domain.CidrFromString("::/0")
|
||||
}
|
||||
} else {
|
||||
netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst)
|
||||
}
|
||||
remove := true
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if netlinkAddr == allowedIP {
|
||||
remove = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !remove {
|
||||
continue
|
||||
}
|
||||
|
||||
err := m.nl.RouteDel(&rawRoute)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, allowedIPs []domain.Cidr, link netlink.Link) (table, fwmark int, err error) {
|
||||
defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs)
|
||||
|
||||
table = iface.GetRoutingTable()
|
||||
fwmark = int(iface.FirewallMark)
|
||||
|
||||
if (defRouteV4 || defRouteV6) && table <= 0 {
|
||||
table = m.cfg.Advanced.RouteTableOffset + link.Attrs().Index // generate a new routing table base on interface index
|
||||
logrus.Debugf("using routing table %d to handle default routes", table)
|
||||
}
|
||||
if (defRouteV4 || defRouteV6) && fwmark == 0 {
|
||||
fwmark = m.cfg.Advanced.RouteTableOffset + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index
|
||||
logrus.Debugf("using fwmark %d to handle default routes", table)
|
||||
|
||||
// apply the fwmark
|
||||
err = m.setFwMark(iface.Identifier, fwmark)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m Manager) setFwMark(id domain.InterfaceIdentifier, fwmark int) error {
|
||||
err := m.wg.ConfigureDevice(string(id), wgtypes.Config{
|
||||
FirewallMark: &fwmark,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update fwmark to: %d: %w", fwmark, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) containsDefaultRoute(allowedIPs []domain.Cidr) (ipV4, ipV6 bool) {
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if ipV4 && ipV6 {
|
||||
break // speed up
|
||||
}
|
||||
|
||||
if allowedIP.Prefix().Bits() == 0 {
|
||||
if allowedIP.IsV4() {
|
||||
ipV4 = true
|
||||
} else {
|
||||
ipV6 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
@ -41,7 +41,6 @@ type InterfaceController interface {
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error
|
||||
}
|
||||
|
||||
type WgQuickController interface {
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/sirupsen/logrus"
|
||||
"time"
|
||||
@ -327,6 +328,8 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
existingInterface.Disabled = &now // simulate a disabled interface
|
||||
existingInterface.DisabledReason = domain.DisabledReasonDeleted
|
||||
|
||||
physicalInterface, _ := m.wg.GetInterface(ctx, id)
|
||||
|
||||
if err := m.handleInterfacePreSaveHooks(true, existingInterface); err != nil {
|
||||
return fmt.Errorf("pre-delete hooks failed: %w", err)
|
||||
}
|
||||
@ -347,6 +350,13 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
return fmt.Errorf("deletion failure: %w", err)
|
||||
}
|
||||
|
||||
if physicalInterface != nil {
|
||||
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
|
||||
FwMark: int(physicalInterface.FirewallMark),
|
||||
Table: existingInterface.GetRoutingTable(),
|
||||
})
|
||||
}
|
||||
|
||||
if err := m.handleInterfacePostSaveHooks(true, existingInterface); err != nil {
|
||||
return fmt.Errorf("post-delete hooks failed: %w", err)
|
||||
}
|
||||
@ -384,10 +394,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee
|
||||
return nil, fmt.Errorf("failed to save interface: %w", err)
|
||||
}
|
||||
|
||||
err = m.wg.SaveRoutes(ctx, iface, peers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save routes: %w", err)
|
||||
}
|
||||
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
|
||||
|
||||
if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil {
|
||||
return nil, fmt.Errorf("post-save hooks failed: %w", err)
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/sirupsen/logrus"
|
||||
"time"
|
||||
@ -279,15 +280,8 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||
}
|
||||
|
||||
// Update routes after peers have changed
|
||||
for ifaceId := range interfaces {
|
||||
iface, ifacePeers, err := m.db.GetInterfaceAndPeers(ctx, ifaceId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load peer interface %s: %w", ifaceId, err)
|
||||
}
|
||||
err = m.wg.SaveRoutes(ctx, iface, ifacePeers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update peer routes on interface %s: %w", ifaceId, err)
|
||||
}
|
||||
if len(interfaces) != 0 {
|
||||
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -33,6 +33,8 @@ type Config struct {
|
||||
UseIpV6 bool `yaml:"use_ip_v6"`
|
||||
ConfigStoragePath string `yaml:"config_storage_path"` // keep empty to disable config export to file
|
||||
ExpiryCheckInterval time.Duration `yaml:"expiry_check_interval"`
|
||||
RulePrioOffset int `yaml:"rule_prio_offset"`
|
||||
RouteTableOffset int `yaml:"route_table_offset"`
|
||||
} `yaml:"advanced"`
|
||||
|
||||
Statistics struct {
|
||||
@ -106,6 +108,8 @@ func defaultConfig() *Config {
|
||||
cfg.Advanced.StartCidrV6 = "fdfd:d3ad:c0de:1234::0/64"
|
||||
cfg.Advanced.UseIpV6 = true
|
||||
cfg.Advanced.ExpiryCheckInterval = 15 * time.Minute
|
||||
cfg.Advanced.RulePrioOffset = 20000
|
||||
cfg.Advanced.RouteTableOffset = 20000
|
||||
|
||||
cfg.Statistics.UsePingChecks = true
|
||||
cfg.Statistics.PingCheckWorkers = 10
|
||||
|
@ -211,3 +211,12 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) {
|
||||
pi.DeviceUp = !i.IsDisabled()
|
||||
pi.Addresses = i.Addresses
|
||||
}
|
||||
|
||||
type RoutingTableInfo struct {
|
||||
FwMark int
|
||||
Table int
|
||||
}
|
||||
|
||||
func (r RoutingTableInfo) String() string {
|
||||
return fmt.Sprintf("%d -> %d", r.FwMark, r.Table)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user