mirror of
https://github.com/h44z/wg-portal.git
synced 2025-08-10 07:22:24 +00:00
wip routes
This commit is contained in:
parent
f4e5072f97
commit
2113999b22
@ -299,38 +299,106 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error {
|
||||
if table == -1 {
|
||||
logrus.Trace("ignoring route update")
|
||||
func (r *WgRepo) SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error {
|
||||
table := iface.GetRoutingTable()
|
||||
if table == -2 {
|
||||
logrus.Trace("ignoring route update, feature disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
link, err := r.nl.LinkByName(string(interfaceId))
|
||||
if !iface.IsDisabled() {
|
||||
return r.setupRoutes(iface, peers)
|
||||
} else {
|
||||
return r.cleanupRoutes(iface, peers)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *WgRepo) setupRoutes(iface *domain.Interface, peers []domain.Peer) error {
|
||||
link, err := r.nl.LinkByName(string(iface.Identifier))
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
if link.Attrs().OperState == netlink.OperDown {
|
||||
return nil // cannot set route for interface that is down
|
||||
}
|
||||
|
||||
table, fwmark, err := r.getRoutingTableAndFwMark(iface, peers, link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get table and fwmark: %w", err)
|
||||
}
|
||||
|
||||
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
|
||||
allowedIPs := iface.GetAllowedIPs(peers)
|
||||
for _, allowedIP := range allowedIPs {
|
||||
// if allowedIP.Prefix().Bits() == 0 { // default route handling - TODO
|
||||
err := r.nl.RouteReplace(&netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: allowedIP.IpNet(),
|
||||
Table: table,
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||
if allowedIP.Prefix().Bits() == 0 { // default route handling
|
||||
if err := r.nl.RouteReplace(&netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: allowedIP.IpNet(),
|
||||
Table: table,
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||
}
|
||||
|
||||
family := netlink.FAMILY_V4
|
||||
if !allowedIP.IsV4() {
|
||||
family = netlink.FAMILY_V6
|
||||
}
|
||||
if err := r.nl.RuleAdd(&netlink.Rule{
|
||||
Family: family,
|
||||
Table: table,
|
||||
Mark: fwmark,
|
||||
Invert: true,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: -1,
|
||||
Priority: -1,
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup rule for fwmark %d: %w", fwmark, err)
|
||||
}
|
||||
if err := r.nl.RuleAdd(&netlink.Rule{
|
||||
Family: family,
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
SuppressIfgroup: -1,
|
||||
SuppressPrefixlen: 0,
|
||||
Priority: -1,
|
||||
Mark: -1,
|
||||
Mask: -1,
|
||||
Goto: -1,
|
||||
Flow: -1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to setup rule for main table: %w", err)
|
||||
}
|
||||
} else {
|
||||
err := r.nl.RouteReplace(&netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: allowedIP.IpNet(),
|
||||
Table: table,
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove unwanted routes
|
||||
rawRoutes, err := r.nl.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{
|
||||
if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V4); err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated v4 routes: %w", err)
|
||||
}
|
||||
if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V6); err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated v6 routes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgRepo) removeDeprecatedRoutes(link netlink.Link, allowedIPs []domain.Cidr, family int) error {
|
||||
rawRoutes, err := r.nl.RouteListFiltered(family, &netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Table: unix.RT_TABLE_UNSPEC, // all tables
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
@ -340,7 +408,16 @@ func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdent
|
||||
return fmt.Errorf("failed to fetch raw routes: %w", err)
|
||||
}
|
||||
for _, rawRoute := range rawRoutes {
|
||||
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
|
||||
var netlinkAddr domain.Cidr
|
||||
if rawRoute.Dst == nil {
|
||||
if family == netlink.FAMILY_V4 {
|
||||
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
|
||||
} else {
|
||||
netlinkAddr, _ = domain.CidrFromString("::/0")
|
||||
}
|
||||
} else {
|
||||
netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst)
|
||||
}
|
||||
remove := true
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if netlinkAddr == allowedIP {
|
||||
@ -358,10 +435,66 @@ func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdent
|
||||
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgRepo) cleanupRoutes(iface *domain.Interface, peers []domain.Peer) error {
|
||||
link, err := r.nl.LinkByName(string(iface.Identifier))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
table, _, err := r.getRoutingTableAndFwMark(iface, peers, link)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get table and fwmark: %w", err)
|
||||
}
|
||||
|
||||
if table == 0 {
|
||||
return nil // noting to remove
|
||||
}
|
||||
|
||||
delRule := netlink.NewRule()
|
||||
delRule.Family = netlink.FAMILY_ALL
|
||||
delRule.Table = table
|
||||
if err := r.nl.RuleDel(delRule); err != nil && !errors.Is(err, unix.ENOENT) {
|
||||
return fmt.Errorf("failed to delete rule for table %d: %w", table, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgRepo) getRoutingTableAndFwMark(iface *domain.Interface, peers []domain.Peer, link netlink.Link) (table, fwmark int, err error) {
|
||||
allowedIPs := iface.GetAllowedIPs(peers)
|
||||
containsDefaultRoute := false
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if allowedIP.Prefix().Bits() == 0 {
|
||||
containsDefaultRoute = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
table = iface.GetRoutingTable()
|
||||
fwmark = int(iface.FirewallMark)
|
||||
|
||||
if containsDefaultRoute && table <= 0 {
|
||||
table = 20000 + link.Attrs().Index // generate a new routing table base on interface index
|
||||
logrus.Debugf("using routing table %d to handle default routes", table)
|
||||
}
|
||||
if containsDefaultRoute && fwmark == 0 {
|
||||
fwmark = 20000 + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index
|
||||
logrus.Debugf("using fwmark %d to handle default routes", table)
|
||||
|
||||
// apply the fwmark
|
||||
err = r.wg.ConfigureDevice(string(iface.Identifier), wgtypes.Config{
|
||||
FirewallMark: &fwmark,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to update temporary fwmark to: %d: %w", fwmark, err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
|
||||
if err := r.deleteLowLevelInterface(id); err != nil {
|
||||
return err
|
||||
|
@ -41,7 +41,7 @@ type InterfaceController interface {
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
SaveRoutes(_ context.Context, deviceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error
|
||||
SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error
|
||||
}
|
||||
|
||||
type WgQuickController interface {
|
||||
|
@ -384,7 +384,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee
|
||||
return nil, fmt.Errorf("failed to save interface: %w", err)
|
||||
}
|
||||
|
||||
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers))
|
||||
err = m.wg.SaveRoutes(ctx, iface, peers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save routes: %w", err)
|
||||
}
|
||||
|
@ -284,7 +284,7 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load peer interface %s: %w", ifaceId, err)
|
||||
}
|
||||
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(ifacePeers))
|
||||
err = m.wg.SaveRoutes(ctx, iface, ifacePeers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update peer routes on interface %s: %w", ifaceId, err)
|
||||
}
|
||||
|
@ -114,12 +114,17 @@ func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
|
||||
return allowedCidrs
|
||||
}
|
||||
|
||||
// GetRoutingTable returns the routing table number or -1 if an error occurred or RoutingTable was set to "off"
|
||||
// GetRoutingTable returns the routing table number or
|
||||
//
|
||||
// -1 if an error occurred
|
||||
// -2 if RoutingTable was set to "off"
|
||||
func (i *Interface) GetRoutingTable() int {
|
||||
routingTableStr := strings.ToLower(i.RoutingTable)
|
||||
switch {
|
||||
case routingTableStr == "":
|
||||
return 0
|
||||
case routingTableStr == "off":
|
||||
return -2
|
||||
case strings.HasPrefix(routingTableStr, "0x"):
|
||||
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
|
||||
routingTable, err := strconv.ParseUint(numberStr, 16, 64)
|
||||
|
Loading…
x
Reference in New Issue
Block a user