mirror of
https://github.com/h44z/wg-portal.git
synced 2025-08-09 15:02:24 +00:00
route management
This commit is contained in:
parent
22949963cf
commit
d8624748b7
@ -331,17 +331,17 @@ func (r *SqlRepo) upsertInterface(ui *domain.ContextUserInfo, tx *gorm.DB, in *d
|
||||
|
||||
func (r *SqlRepo) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
err := r.db.WithContext(ctx).Where("interface_identifier = ?", id).Delete(&domain.Peer{}).Error
|
||||
err := tx.Where("interface_identifier = ?", id).Delete(&domain.Peer{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.db.WithContext(ctx).Delete(&domain.InterfaceStatus{InterfaceId: id}).Error
|
||||
err = tx.Delete(&domain.InterfaceStatus{InterfaceId: id}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.db.WithContext(ctx).Debug().Select(clause.Associations).Delete(&domain.Interface{Identifier: id}).Error
|
||||
err = tx.Select(clause.Associations).Delete(&domain.Interface{Identifier: id}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -518,12 +518,12 @@ func (r *SqlRepo) upsertPeer(ui *domain.ContextUserInfo, tx *gorm.DB, peer *doma
|
||||
|
||||
func (r *SqlRepo) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
err := r.db.WithContext(ctx).Delete(&domain.PeerStatus{PeerId: id}).Error
|
||||
err := tx.Delete(&domain.PeerStatus{PeerId: id}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.db.WithContext(ctx).Select(clause.Associations).Delete(&domain.Peer{Identifier: id}).Error
|
||||
err = tx.Select(clause.Associations).Delete(&domain.Peer{Identifier: id}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -308,6 +308,10 @@ func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifie
|
||||
func (r *WgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
|
||||
link, err := r.nl.LinkByName(string(id))
|
||||
if err != nil {
|
||||
var linkNotFoundError netlink.LinkNotFoundError
|
||||
if errors.As(err, &linkNotFoundError) {
|
||||
return nil // ignore not found error
|
||||
}
|
||||
return fmt.Errorf("unable to find low level interface: %w", err)
|
||||
}
|
||||
|
||||
|
@ -119,42 +119,3 @@ func TestWireGuardUpdateInterface(t *testing.T) {
|
||||
assert.Contains(t, string(out), ipAddress)
|
||||
assert.Contains(t, string(out), ipV6Address)
|
||||
}
|
||||
|
||||
func TestWireGuardCreateInterfaceWithRoutes(t *testing.T) {
|
||||
mgr := setup(t)
|
||||
|
||||
interfaceName := domain.InterfaceIdentifier("wg_test_001")
|
||||
ipAddress := "10.11.12.13"
|
||||
ipV6Address := "1337:d34d:b33f::2"
|
||||
defer mgr.DeleteInterface(context.Background(), interfaceName)
|
||||
|
||||
iface := &domain.Interface{
|
||||
Identifier: interfaceName,
|
||||
//RoutingTable: "1234",
|
||||
}
|
||||
peers := []domain.Peer{
|
||||
{
|
||||
Interface: domain.PeerInterfaceConfig{
|
||||
Addresses: domain.CidrsMust(domain.CidrsFromString("10.11.12.14/32,10.22.33.44/32")),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := mgr.SaveInterface2(context.Background(), iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
pi.Addresses = []domain.Cidr{
|
||||
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}),
|
||||
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}),
|
||||
}
|
||||
pi.DeviceUp = true
|
||||
return pi, nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Validate that the interface has been created
|
||||
cmd := exec.Command("ip", "addr")
|
||||
out, err := cmd.CombinedOutput()
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(out), interfaceName)
|
||||
assert.Contains(t, string(out), ipAddress)
|
||||
assert.Contains(t, string(out), ipV6Address)
|
||||
}
|
||||
|
@ -15,13 +15,16 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type defaultRouteRule struct {
|
||||
ifaceId domain.InterfaceIdentifier
|
||||
fwMark int
|
||||
table int
|
||||
family int
|
||||
type routeRuleInfo struct {
|
||||
ifaceId domain.InterfaceIdentifier
|
||||
fwMark int
|
||||
table int
|
||||
family int
|
||||
hasDefault bool
|
||||
}
|
||||
|
||||
// Manager is try to mimic wg-quick behaviour (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
|
||||
// for default routes.
|
||||
type Manager struct {
|
||||
cfg *config.Config
|
||||
bus evbus.MessageBus
|
||||
@ -55,7 +58,7 @@ func NewRouteManager(cfg *config.Config, bus evbus.MessageBus, db InterfaceAndPe
|
||||
|
||||
func (m Manager) connectToMessageBus() {
|
||||
_ = m.bus.Subscribe(app.TopicRouteUpdate, m.handleRouteUpdateEvent)
|
||||
_ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteUpdateEvent)
|
||||
_ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteRemoveEvent)
|
||||
}
|
||||
|
||||
func (m Manager) StartBackgroundJobs(ctx context.Context) {
|
||||
@ -75,14 +78,14 @@ func (m Manager) handleRouteUpdateEvent(srcDescription string) {
|
||||
func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) {
|
||||
logrus.Debugf("handling route remove event for: %s", info.String())
|
||||
|
||||
if info.Table == -2 {
|
||||
if !info.ManagementEnabled() {
|
||||
return // route management disabled
|
||||
}
|
||||
|
||||
if err := m.removeFwMarkRules(info.FwMark, info.FwMark, netlink.FAMILY_V4); err != nil {
|
||||
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V4); err != nil {
|
||||
logrus.Errorf("failed to remove v4 fwmark rules: %v", err)
|
||||
}
|
||||
if err := m.removeFwMarkRules(info.FwMark, info.FwMark, netlink.FAMILY_V6); err != nil {
|
||||
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V6); err != nil {
|
||||
logrus.Errorf("failed to remove v6 fwmark rules: %v", err)
|
||||
}
|
||||
|
||||
@ -95,7 +98,7 @@ func (m Manager) syncRoutes(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to find all interfaces: %w", err)
|
||||
}
|
||||
|
||||
rules := map[int][]defaultRouteRule{
|
||||
rules := map[int][]routeRuleInfo{
|
||||
netlink.FAMILY_V4: nil,
|
||||
netlink.FAMILY_V6: nil,
|
||||
}
|
||||
@ -103,6 +106,9 @@ func (m Manager) syncRoutes(ctx context.Context) error {
|
||||
if iface.IsDisabled() {
|
||||
continue // disabled interface does not need route entries
|
||||
}
|
||||
if !iface.ManageRoutingTable() {
|
||||
continue
|
||||
}
|
||||
|
||||
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
@ -132,20 +138,22 @@ func (m Manager) syncRoutes(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
if defRouteV4 {
|
||||
rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], defaultRouteRule{
|
||||
ifaceId: iface.Identifier,
|
||||
fwMark: fwmark,
|
||||
table: table,
|
||||
family: netlink.FAMILY_V4,
|
||||
if table != 0 {
|
||||
rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], routeRuleInfo{
|
||||
ifaceId: iface.Identifier,
|
||||
fwMark: fwmark,
|
||||
table: table,
|
||||
family: netlink.FAMILY_V4,
|
||||
hasDefault: defRouteV4,
|
||||
})
|
||||
}
|
||||
if defRouteV6 {
|
||||
rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], defaultRouteRule{
|
||||
ifaceId: iface.Identifier,
|
||||
fwMark: fwmark,
|
||||
table: table,
|
||||
family: netlink.FAMILY_V6,
|
||||
if table != 0 {
|
||||
rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], routeRuleInfo{
|
||||
ifaceId: iface.Identifier,
|
||||
fwMark: fwmark,
|
||||
table: table,
|
||||
family: netlink.FAMILY_V6,
|
||||
hasDefault: defRouteV6,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -153,7 +161,7 @@ func (m Manager) syncRoutes(ctx context.Context) error {
|
||||
return m.syncRouteRules(rules)
|
||||
}
|
||||
|
||||
func (m Manager) syncRouteRules(allRules map[int][]defaultRouteRule) error {
|
||||
func (m Manager) syncRouteRules(allRules map[int][]routeRuleInfo) error {
|
||||
for family, rules := range allRules {
|
||||
// update fwmark rules
|
||||
if err := m.setFwMarkRules(rules, family); err != nil {
|
||||
@ -174,7 +182,7 @@ func (m Manager) syncRouteRules(allRules map[int][]defaultRouteRule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) setFwMarkRules(rules []defaultRouteRule, family int) error {
|
||||
func (m Manager) setFwMarkRules(rules []routeRuleInfo, family int) error {
|
||||
for _, rule := range rules {
|
||||
existingRules, err := m.nl.RuleList(family)
|
||||
if err != nil {
|
||||
@ -229,8 +237,14 @@ func (m Manager) removeFwMarkRules(fwmark, table int, family int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) setMainRule(rules []defaultRouteRule, family int) error {
|
||||
shouldHaveMainRule := len(rules) != 0
|
||||
func (m Manager) setMainRule(rules []routeRuleInfo, family int) error {
|
||||
shouldHaveMainRule := false
|
||||
for _, rule := range rules {
|
||||
if rule.hasDefault == true {
|
||||
shouldHaveMainRule = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !shouldHaveMainRule {
|
||||
return nil
|
||||
}
|
||||
@ -269,13 +283,19 @@ func (m Manager) setMainRule(rules []defaultRouteRule, family int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) cleanupMainRule(rules []defaultRouteRule, family int) error {
|
||||
func (m Manager) cleanupMainRule(rules []routeRuleInfo, family int) error {
|
||||
existingRules, err := m.nl.RuleList(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
|
||||
}
|
||||
|
||||
shouldHaveMainRule := len(rules) != 0
|
||||
shouldHaveMainRule := false
|
||||
for _, rule := range rules {
|
||||
if rule.hasDefault == true {
|
||||
shouldHaveMainRule = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mainRules := 0
|
||||
for _, existingRule := range existingRules {
|
||||
@ -307,7 +327,7 @@ func (m Manager) cleanupMainRule(rules []defaultRouteRule, family int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
|
||||
func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
|
||||
prio := m.cfg.Advanced.RulePrioOffset
|
||||
for {
|
||||
isFresh := true
|
||||
@ -326,7 +346,7 @@ func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
|
||||
return prio
|
||||
}
|
||||
|
||||
func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
|
||||
func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
|
||||
prio := 32700 // linux main rule has a prio of 32766
|
||||
for {
|
||||
isFresh := true
|
||||
@ -346,7 +366,6 @@ func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
|
||||
}
|
||||
|
||||
func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error {
|
||||
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
|
||||
for _, allowedIP := range allowedIPs {
|
||||
err := m.nl.RouteReplace(&netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
@ -374,16 +393,17 @@ func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIP
|
||||
return fmt.Errorf("failed to fetch raw routes: %w", err)
|
||||
}
|
||||
for _, rawRoute := range rawRoutes {
|
||||
var netlinkAddr domain.Cidr
|
||||
if rawRoute.Dst == nil {
|
||||
if rawRoute.Dst == nil { // handle default route
|
||||
var netlinkAddr domain.Cidr
|
||||
if family == netlink.FAMILY_V4 {
|
||||
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
|
||||
} else {
|
||||
netlinkAddr, _ = domain.CidrFromString("::/0")
|
||||
}
|
||||
} else {
|
||||
netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst)
|
||||
rawRoute.Dst = netlinkAddr.IpNet()
|
||||
}
|
||||
|
||||
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
|
||||
remove := true
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if netlinkAddr == allowedIP {
|
||||
@ -405,22 +425,20 @@ func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIP
|
||||
}
|
||||
|
||||
func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, allowedIPs []domain.Cidr, link netlink.Link) (table, fwmark int, err error) {
|
||||
defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs)
|
||||
|
||||
table = iface.GetRoutingTable()
|
||||
fwmark = int(iface.FirewallMark)
|
||||
|
||||
if (defRouteV4 || defRouteV6) && table <= 0 {
|
||||
table = m.cfg.Advanced.RouteTableOffset + link.Attrs().Index // generate a new routing table base on interface index
|
||||
logrus.Debugf("using routing table %d to handle default routes", table)
|
||||
}
|
||||
if (defRouteV4 || defRouteV6) && fwmark == 0 {
|
||||
if fwmark == 0 {
|
||||
fwmark = m.cfg.Advanced.RouteTableOffset + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index
|
||||
logrus.Debugf("using fwmark %d to handle default routes", table)
|
||||
logrus.Debugf("using fwmark %d to handle routes", table)
|
||||
|
||||
// apply the fwmark
|
||||
// apply the temporary fwmark to the wireguard interface
|
||||
err = m.setFwMark(iface.Identifier, fwmark)
|
||||
}
|
||||
if table == 0 {
|
||||
table = fwmark // generate a new routing table base on interface index
|
||||
logrus.Debugf("using routing table %d to handle default routes", table)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/sirupsen/logrus"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -350,12 +351,14 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
return fmt.Errorf("deletion failure: %w", err)
|
||||
}
|
||||
|
||||
if physicalInterface != nil {
|
||||
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
|
||||
FwMark: int(physicalInterface.FirewallMark),
|
||||
Table: existingInterface.GetRoutingTable(),
|
||||
})
|
||||
fwMark := int(existingInterface.FirewallMark)
|
||||
if physicalInterface != nil && fwMark == 0 {
|
||||
fwMark = int(physicalInterface.FirewallMark)
|
||||
}
|
||||
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
|
||||
FwMark: fwMark,
|
||||
Table: existingInterface.GetRoutingTable(),
|
||||
})
|
||||
|
||||
if err := m.handleInterfacePostSaveHooks(true, existingInterface); err != nil {
|
||||
return fmt.Errorf("post-delete hooks failed: %w", err)
|
||||
@ -395,6 +398,17 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee
|
||||
}
|
||||
|
||||
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
|
||||
if iface.IsDisabled() {
|
||||
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
|
||||
fwMark := int(iface.FirewallMark)
|
||||
if physicalInterface != nil && fwMark == 0 {
|
||||
fwMark = int(physicalInterface.FirewallMark)
|
||||
}
|
||||
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
|
||||
FwMark: fwMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
})
|
||||
}
|
||||
|
||||
if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil {
|
||||
return nil, fmt.Errorf("post-save hooks failed: %w", err)
|
||||
@ -668,7 +682,7 @@ func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceId
|
||||
}
|
||||
for _, peer := range allPeers {
|
||||
err = m.wg.DeletePeer(ctx, id, peer.Identifier)
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
||||
|
@ -3,6 +3,7 @@ package domain
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/sirupsen/logrus"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@ -34,7 +35,7 @@ type Interface struct {
|
||||
|
||||
Mtu int // the device MTU
|
||||
FirewallMark int32 // a firewall mark
|
||||
RoutingTable string // the routing table
|
||||
RoutingTable string // the routing table number or "off" if the routing table should not be managed
|
||||
|
||||
PreUp string // action that is executed before the device is up
|
||||
PostUp string // action that is executed after the device is up
|
||||
@ -114,30 +115,37 @@ func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
|
||||
return allowedCidrs
|
||||
}
|
||||
|
||||
func (i *Interface) ManageRoutingTable() bool {
|
||||
routingTableStr := strings.ToLower(i.RoutingTable)
|
||||
return routingTableStr != "off"
|
||||
}
|
||||
|
||||
// GetRoutingTable returns the routing table number or
|
||||
//
|
||||
// -1 if an error occurred
|
||||
// -2 if RoutingTable was set to "off"
|
||||
// -1 if RoutingTable was set to "off" or an error occurred
|
||||
func (i *Interface) GetRoutingTable() int {
|
||||
routingTableStr := strings.ToLower(i.RoutingTable)
|
||||
switch {
|
||||
case routingTableStr == "":
|
||||
return 0
|
||||
case routingTableStr == "off":
|
||||
return -2
|
||||
return -1
|
||||
case strings.HasPrefix(routingTableStr, "0x"):
|
||||
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
|
||||
routingTable, err := strconv.ParseUint(numberStr, 16, 64)
|
||||
if err != nil {
|
||||
logrus.Errorf("invalid hex routing table %s: %w", routingTableStr, err)
|
||||
return -1
|
||||
}
|
||||
if routingTable > math.MaxInt32 {
|
||||
logrus.Errorf("invalid routing table %s, too big", routingTableStr)
|
||||
return -1
|
||||
}
|
||||
return int(routingTable)
|
||||
default:
|
||||
routingTable, err := strconv.Atoi(routingTableStr)
|
||||
if err != nil {
|
||||
logrus.Errorf("invalid routing table %s: %w", routingTableStr, err)
|
||||
return -1
|
||||
}
|
||||
return routingTable
|
||||
@ -220,3 +228,19 @@ type RoutingTableInfo struct {
|
||||
func (r RoutingTableInfo) String() string {
|
||||
return fmt.Sprintf("%d -> %d", r.FwMark, r.Table)
|
||||
}
|
||||
|
||||
func (r RoutingTableInfo) ManagementEnabled() bool {
|
||||
if r.Table == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (r RoutingTableInfo) GetRoutingTable() int {
|
||||
if r.Table <= 0 {
|
||||
return r.FwMark // use the dynamic routing table which has the same number as the firewall mark
|
||||
}
|
||||
|
||||
return r.Table
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user