route management

This commit is contained in:
Christoph Haas 2023-07-30 23:18:38 +02:00
parent 22949963cf
commit d8624748b7
6 changed files with 119 additions and 98 deletions

View File

@ -331,17 +331,17 @@ func (r *SqlRepo) upsertInterface(ui *domain.ContextUserInfo, tx *gorm.DB, in *d
func (r *SqlRepo) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := r.db.WithContext(ctx).Where("interface_identifier = ?", id).Delete(&domain.Peer{}).Error
err := tx.Where("interface_identifier = ?", id).Delete(&domain.Peer{}).Error
if err != nil {
return err
}
err = r.db.WithContext(ctx).Delete(&domain.InterfaceStatus{InterfaceId: id}).Error
err = tx.Delete(&domain.InterfaceStatus{InterfaceId: id}).Error
if err != nil {
return err
}
err = r.db.WithContext(ctx).Debug().Select(clause.Associations).Delete(&domain.Interface{Identifier: id}).Error
err = tx.Select(clause.Associations).Delete(&domain.Interface{Identifier: id}).Error
if err != nil {
return err
}
@ -518,12 +518,12 @@ func (r *SqlRepo) upsertPeer(ui *domain.ContextUserInfo, tx *gorm.DB, peer *doma
func (r *SqlRepo) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := r.db.WithContext(ctx).Delete(&domain.PeerStatus{PeerId: id}).Error
err := tx.Delete(&domain.PeerStatus{PeerId: id}).Error
if err != nil {
return err
}
err = r.db.WithContext(ctx).Select(clause.Associations).Delete(&domain.Peer{Identifier: id}).Error
err = tx.Select(clause.Associations).Delete(&domain.Peer{Identifier: id}).Error
if err != nil {
return err
}

View File

@ -308,6 +308,10 @@ func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifie
func (r *WgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
link, err := r.nl.LinkByName(string(id))
if err != nil {
var linkNotFoundError netlink.LinkNotFoundError
if errors.As(err, &linkNotFoundError) {
return nil // ignore not found error
}
return fmt.Errorf("unable to find low level interface: %w", err)
}

View File

@ -119,42 +119,3 @@ func TestWireGuardUpdateInterface(t *testing.T) {
assert.Contains(t, string(out), ipAddress)
assert.Contains(t, string(out), ipV6Address)
}
func TestWireGuardCreateInterfaceWithRoutes(t *testing.T) {
mgr := setup(t)
interfaceName := domain.InterfaceIdentifier("wg_test_001")
ipAddress := "10.11.12.13"
ipV6Address := "1337:d34d:b33f::2"
defer mgr.DeleteInterface(context.Background(), interfaceName)
iface := &domain.Interface{
Identifier: interfaceName,
//RoutingTable: "1234",
}
peers := []domain.Peer{
{
Interface: domain.PeerInterfaceConfig{
Addresses: domain.CidrsMust(domain.CidrsFromString("10.11.12.14/32,10.22.33.44/32")),
},
},
}
err := mgr.SaveInterface2(context.Background(), iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
pi.Addresses = []domain.Cidr{
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}),
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}),
}
pi.DeviceUp = true
return pi, nil
})
assert.NoError(t, err)
// Validate that the interface has been created
cmd := exec.Command("ip", "addr")
out, err := cmd.CombinedOutput()
assert.NoError(t, err)
assert.Contains(t, string(out), interfaceName)
assert.Contains(t, string(out), ipAddress)
assert.Contains(t, string(out), ipV6Address)
}

View File

@ -15,13 +15,16 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type defaultRouteRule struct {
ifaceId domain.InterfaceIdentifier
fwMark int
table int
family int
type routeRuleInfo struct {
ifaceId domain.InterfaceIdentifier
fwMark int
table int
family int
hasDefault bool
}
// Manager is try to mimic wg-quick behaviour (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
// for default routes.
type Manager struct {
cfg *config.Config
bus evbus.MessageBus
@ -55,7 +58,7 @@ func NewRouteManager(cfg *config.Config, bus evbus.MessageBus, db InterfaceAndPe
func (m Manager) connectToMessageBus() {
_ = m.bus.Subscribe(app.TopicRouteUpdate, m.handleRouteUpdateEvent)
_ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteUpdateEvent)
_ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteRemoveEvent)
}
func (m Manager) StartBackgroundJobs(ctx context.Context) {
@ -75,14 +78,14 @@ func (m Manager) handleRouteUpdateEvent(srcDescription string) {
func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) {
logrus.Debugf("handling route remove event for: %s", info.String())
if info.Table == -2 {
if !info.ManagementEnabled() {
return // route management disabled
}
if err := m.removeFwMarkRules(info.FwMark, info.FwMark, netlink.FAMILY_V4); err != nil {
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V4); err != nil {
logrus.Errorf("failed to remove v4 fwmark rules: %v", err)
}
if err := m.removeFwMarkRules(info.FwMark, info.FwMark, netlink.FAMILY_V6); err != nil {
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V6); err != nil {
logrus.Errorf("failed to remove v6 fwmark rules: %v", err)
}
@ -95,7 +98,7 @@ func (m Manager) syncRoutes(ctx context.Context) error {
return fmt.Errorf("failed to find all interfaces: %w", err)
}
rules := map[int][]defaultRouteRule{
rules := map[int][]routeRuleInfo{
netlink.FAMILY_V4: nil,
netlink.FAMILY_V6: nil,
}
@ -103,6 +106,9 @@ func (m Manager) syncRoutes(ctx context.Context) error {
if iface.IsDisabled() {
continue // disabled interface does not need route entries
}
if !iface.ManageRoutingTable() {
continue
}
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
@ -132,20 +138,22 @@ func (m Manager) syncRoutes(ctx context.Context) error {
return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err)
}
if defRouteV4 {
rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], defaultRouteRule{
ifaceId: iface.Identifier,
fwMark: fwmark,
table: table,
family: netlink.FAMILY_V4,
if table != 0 {
rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], routeRuleInfo{
ifaceId: iface.Identifier,
fwMark: fwmark,
table: table,
family: netlink.FAMILY_V4,
hasDefault: defRouteV4,
})
}
if defRouteV6 {
rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], defaultRouteRule{
ifaceId: iface.Identifier,
fwMark: fwmark,
table: table,
family: netlink.FAMILY_V6,
if table != 0 {
rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], routeRuleInfo{
ifaceId: iface.Identifier,
fwMark: fwmark,
table: table,
family: netlink.FAMILY_V6,
hasDefault: defRouteV6,
})
}
}
@ -153,7 +161,7 @@ func (m Manager) syncRoutes(ctx context.Context) error {
return m.syncRouteRules(rules)
}
func (m Manager) syncRouteRules(allRules map[int][]defaultRouteRule) error {
func (m Manager) syncRouteRules(allRules map[int][]routeRuleInfo) error {
for family, rules := range allRules {
// update fwmark rules
if err := m.setFwMarkRules(rules, family); err != nil {
@ -174,7 +182,7 @@ func (m Manager) syncRouteRules(allRules map[int][]defaultRouteRule) error {
return nil
}
func (m Manager) setFwMarkRules(rules []defaultRouteRule, family int) error {
func (m Manager) setFwMarkRules(rules []routeRuleInfo, family int) error {
for _, rule := range rules {
existingRules, err := m.nl.RuleList(family)
if err != nil {
@ -229,8 +237,14 @@ func (m Manager) removeFwMarkRules(fwmark, table int, family int) error {
return nil
}
func (m Manager) setMainRule(rules []defaultRouteRule, family int) error {
shouldHaveMainRule := len(rules) != 0
func (m Manager) setMainRule(rules []routeRuleInfo, family int) error {
shouldHaveMainRule := false
for _, rule := range rules {
if rule.hasDefault == true {
shouldHaveMainRule = true
break
}
}
if !shouldHaveMainRule {
return nil
}
@ -269,13 +283,19 @@ func (m Manager) setMainRule(rules []defaultRouteRule, family int) error {
return nil
}
func (m Manager) cleanupMainRule(rules []defaultRouteRule, family int) error {
func (m Manager) cleanupMainRule(rules []routeRuleInfo, family int) error {
existingRules, err := m.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
shouldHaveMainRule := len(rules) != 0
shouldHaveMainRule := false
for _, rule := range rules {
if rule.hasDefault == true {
shouldHaveMainRule = true
break
}
}
mainRules := 0
for _, existingRule := range existingRules {
@ -307,7 +327,7 @@ func (m Manager) cleanupMainRule(rules []defaultRouteRule, family int) error {
return nil
}
func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
prio := m.cfg.Advanced.RulePrioOffset
for {
isFresh := true
@ -326,7 +346,7 @@ func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
return prio
}
func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
prio := 32700 // linux main rule has a prio of 32766
for {
isFresh := true
@ -346,7 +366,6 @@ func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
}
func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error {
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
for _, allowedIP := range allowedIPs {
err := m.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
@ -374,16 +393,17 @@ func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIP
return fmt.Errorf("failed to fetch raw routes: %w", err)
}
for _, rawRoute := range rawRoutes {
var netlinkAddr domain.Cidr
if rawRoute.Dst == nil {
if rawRoute.Dst == nil { // handle default route
var netlinkAddr domain.Cidr
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
} else {
netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst)
rawRoute.Dst = netlinkAddr.IpNet()
}
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
remove := true
for _, allowedIP := range allowedIPs {
if netlinkAddr == allowedIP {
@ -405,22 +425,20 @@ func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIP
}
func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, allowedIPs []domain.Cidr, link netlink.Link) (table, fwmark int, err error) {
defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs)
table = iface.GetRoutingTable()
fwmark = int(iface.FirewallMark)
if (defRouteV4 || defRouteV6) && table <= 0 {
table = m.cfg.Advanced.RouteTableOffset + link.Attrs().Index // generate a new routing table base on interface index
logrus.Debugf("using routing table %d to handle default routes", table)
}
if (defRouteV4 || defRouteV6) && fwmark == 0 {
if fwmark == 0 {
fwmark = m.cfg.Advanced.RouteTableOffset + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index
logrus.Debugf("using fwmark %d to handle default routes", table)
logrus.Debugf("using fwmark %d to handle routes", table)
// apply the fwmark
// apply the temporary fwmark to the wireguard interface
err = m.setFwMark(iface.Identifier, fwmark)
}
if table == 0 {
table = fwmark // generate a new routing table base on interface index
logrus.Debugf("using routing table %d to handle default routes", table)
}
return
}

View File

@ -8,6 +8,7 @@ import (
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/domain"
"github.com/sirupsen/logrus"
"os"
"time"
)
@ -350,12 +351,14 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion failure: %w", err)
}
if physicalInterface != nil {
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: int(physicalInterface.FirewallMark),
Table: existingInterface.GetRoutingTable(),
})
fwMark := int(existingInterface.FirewallMark)
if physicalInterface != nil && fwMark == 0 {
fwMark = int(physicalInterface.FirewallMark)
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: fwMark,
Table: existingInterface.GetRoutingTable(),
})
if err := m.handleInterfacePostSaveHooks(true, existingInterface); err != nil {
return fmt.Errorf("post-delete hooks failed: %w", err)
@ -395,6 +398,17 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee
}
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
if iface.IsDisabled() {
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
fwMark := int(iface.FirewallMark)
if physicalInterface != nil && fwMark == 0 {
fwMark = int(physicalInterface.FirewallMark)
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: fwMark,
Table: iface.GetRoutingTable(),
})
}
if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil {
return nil, fmt.Errorf("post-save hooks failed: %w", err)
@ -668,7 +682,7 @@ func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceId
}
for _, peer := range allPeers {
err = m.wg.DeletePeer(ctx, id, peer.Identifier)
if err != nil {
if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
}

View File

@ -3,6 +3,7 @@ package domain
import (
"fmt"
"github.com/h44z/wg-portal/internal"
"github.com/sirupsen/logrus"
"math"
"regexp"
"strconv"
@ -34,7 +35,7 @@ type Interface struct {
Mtu int // the device MTU
FirewallMark int32 // a firewall mark
RoutingTable string // the routing table
RoutingTable string // the routing table number or "off" if the routing table should not be managed
PreUp string // action that is executed before the device is up
PostUp string // action that is executed after the device is up
@ -114,30 +115,37 @@ func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
return allowedCidrs
}
func (i *Interface) ManageRoutingTable() bool {
routingTableStr := strings.ToLower(i.RoutingTable)
return routingTableStr != "off"
}
// GetRoutingTable returns the routing table number or
//
// -1 if an error occurred
// -2 if RoutingTable was set to "off"
// -1 if RoutingTable was set to "off" or an error occurred
func (i *Interface) GetRoutingTable() int {
routingTableStr := strings.ToLower(i.RoutingTable)
switch {
case routingTableStr == "":
return 0
case routingTableStr == "off":
return -2
return -1
case strings.HasPrefix(routingTableStr, "0x"):
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
routingTable, err := strconv.ParseUint(numberStr, 16, 64)
if err != nil {
logrus.Errorf("invalid hex routing table %s: %w", routingTableStr, err)
return -1
}
if routingTable > math.MaxInt32 {
logrus.Errorf("invalid routing table %s, too big", routingTableStr)
return -1
}
return int(routingTable)
default:
routingTable, err := strconv.Atoi(routingTableStr)
if err != nil {
logrus.Errorf("invalid routing table %s: %w", routingTableStr, err)
return -1
}
return routingTable
@ -220,3 +228,19 @@ type RoutingTableInfo struct {
func (r RoutingTableInfo) String() string {
return fmt.Sprintf("%d -> %d", r.FwMark, r.Table)
}
func (r RoutingTableInfo) ManagementEnabled() bool {
if r.Table == -1 {
return false
}
return true
}
func (r RoutingTableInfo) GetRoutingTable() int {
if r.Table <= 0 {
return r.FwMark // use the dynamic routing table which has the same number as the firewall mark
}
return r.Table
}