mirror of
https://github.com/h44z/wg-portal.git
synced 2025-08-12 16:22:23 +00:00
wip routes
This commit is contained in:
parent
f4e5072f97
commit
2113999b22
@ -299,38 +299,106 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error {
|
func (r *WgRepo) SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error {
|
||||||
if table == -1 {
|
table := iface.GetRoutingTable()
|
||||||
logrus.Trace("ignoring route update")
|
if table == -2 {
|
||||||
|
logrus.Trace("ignoring route update, feature disabled")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
link, err := r.nl.LinkByName(string(interfaceId))
|
if !iface.IsDisabled() {
|
||||||
|
return r.setupRoutes(iface, peers)
|
||||||
|
} else {
|
||||||
|
return r.cleanupRoutes(iface, peers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *WgRepo) setupRoutes(iface *domain.Interface, peers []domain.Peer) error {
|
||||||
|
link, err := r.nl.LinkByName(string(iface.Identifier))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if link.Attrs().OperState == netlink.OperDown {
|
if link.Attrs().OperState == netlink.OperDown {
|
||||||
return nil // cannot set route for interface that is down
|
return nil // cannot set route for interface that is down
|
||||||
}
|
}
|
||||||
|
|
||||||
|
table, fwmark, err := r.getRoutingTableAndFwMark(iface, peers, link)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get table and fwmark: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
|
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
|
||||||
|
allowedIPs := iface.GetAllowedIPs(peers)
|
||||||
for _, allowedIP := range allowedIPs {
|
for _, allowedIP := range allowedIPs {
|
||||||
// if allowedIP.Prefix().Bits() == 0 { // default route handling - TODO
|
if allowedIP.Prefix().Bits() == 0 { // default route handling
|
||||||
err := r.nl.RouteReplace(&netlink.Route{
|
if err := r.nl.RouteReplace(&netlink.Route{
|
||||||
LinkIndex: link.Attrs().Index,
|
LinkIndex: link.Attrs().Index,
|
||||||
Dst: allowedIP.IpNet(),
|
Dst: allowedIP.IpNet(),
|
||||||
Table: table,
|
Table: table,
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
Type: unix.RTN_UNICAST,
|
Type: unix.RTN_UNICAST,
|
||||||
})
|
}); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||||
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
}
|
||||||
|
|
||||||
|
family := netlink.FAMILY_V4
|
||||||
|
if !allowedIP.IsV4() {
|
||||||
|
family = netlink.FAMILY_V6
|
||||||
|
}
|
||||||
|
if err := r.nl.RuleAdd(&netlink.Rule{
|
||||||
|
Family: family,
|
||||||
|
Table: table,
|
||||||
|
Mark: fwmark,
|
||||||
|
Invert: true,
|
||||||
|
SuppressIfgroup: -1,
|
||||||
|
SuppressPrefixlen: -1,
|
||||||
|
Priority: -1,
|
||||||
|
Mask: -1,
|
||||||
|
Goto: -1,
|
||||||
|
Flow: -1,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to setup rule for fwmark %d: %w", fwmark, err)
|
||||||
|
}
|
||||||
|
if err := r.nl.RuleAdd(&netlink.Rule{
|
||||||
|
Family: family,
|
||||||
|
Table: unix.RT_TABLE_MAIN,
|
||||||
|
SuppressIfgroup: -1,
|
||||||
|
SuppressPrefixlen: 0,
|
||||||
|
Priority: -1,
|
||||||
|
Mark: -1,
|
||||||
|
Mask: -1,
|
||||||
|
Goto: -1,
|
||||||
|
Flow: -1,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to setup rule for main table: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := r.nl.RouteReplace(&netlink.Route{
|
||||||
|
LinkIndex: link.Attrs().Index,
|
||||||
|
Dst: allowedIP.IpNet(),
|
||||||
|
Table: table,
|
||||||
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
|
Type: unix.RTN_UNICAST,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove unwanted routes
|
if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V4); err != nil {
|
||||||
rawRoutes, err := r.nl.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{
|
return fmt.Errorf("failed to remove deprecated v4 routes: %w", err)
|
||||||
|
}
|
||||||
|
if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V6); err != nil {
|
||||||
|
return fmt.Errorf("failed to remove deprecated v6 routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *WgRepo) removeDeprecatedRoutes(link netlink.Link, allowedIPs []domain.Cidr, family int) error {
|
||||||
|
rawRoutes, err := r.nl.RouteListFiltered(family, &netlink.Route{
|
||||||
LinkIndex: link.Attrs().Index,
|
LinkIndex: link.Attrs().Index,
|
||||||
Table: unix.RT_TABLE_UNSPEC, // all tables
|
Table: unix.RT_TABLE_UNSPEC, // all tables
|
||||||
Scope: unix.RT_SCOPE_LINK,
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
@ -340,7 +408,16 @@ func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdent
|
|||||||
return fmt.Errorf("failed to fetch raw routes: %w", err)
|
return fmt.Errorf("failed to fetch raw routes: %w", err)
|
||||||
}
|
}
|
||||||
for _, rawRoute := range rawRoutes {
|
for _, rawRoute := range rawRoutes {
|
||||||
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
|
var netlinkAddr domain.Cidr
|
||||||
|
if rawRoute.Dst == nil {
|
||||||
|
if family == netlink.FAMILY_V4 {
|
||||||
|
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
|
||||||
|
} else {
|
||||||
|
netlinkAddr, _ = domain.CidrFromString("::/0")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst)
|
||||||
|
}
|
||||||
remove := true
|
remove := true
|
||||||
for _, allowedIP := range allowedIPs {
|
for _, allowedIP := range allowedIPs {
|
||||||
if netlinkAddr == allowedIP {
|
if netlinkAddr == allowedIP {
|
||||||
@ -358,10 +435,66 @@ func (r *WgRepo) SaveRoutes(_ context.Context, interfaceId domain.InterfaceIdent
|
|||||||
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
|
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *WgRepo) cleanupRoutes(iface *domain.Interface, peers []domain.Peer) error {
|
||||||
|
link, err := r.nl.LinkByName(string(iface.Identifier))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
table, _, err := r.getRoutingTableAndFwMark(iface, peers, link)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get table and fwmark: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if table == 0 {
|
||||||
|
return nil // noting to remove
|
||||||
|
}
|
||||||
|
|
||||||
|
delRule := netlink.NewRule()
|
||||||
|
delRule.Family = netlink.FAMILY_ALL
|
||||||
|
delRule.Table = table
|
||||||
|
if err := r.nl.RuleDel(delRule); err != nil && !errors.Is(err, unix.ENOENT) {
|
||||||
|
return fmt.Errorf("failed to delete rule for table %d: %w", table, err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *WgRepo) getRoutingTableAndFwMark(iface *domain.Interface, peers []domain.Peer, link netlink.Link) (table, fwmark int, err error) {
|
||||||
|
allowedIPs := iface.GetAllowedIPs(peers)
|
||||||
|
containsDefaultRoute := false
|
||||||
|
for _, allowedIP := range allowedIPs {
|
||||||
|
if allowedIP.Prefix().Bits() == 0 {
|
||||||
|
containsDefaultRoute = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
table = iface.GetRoutingTable()
|
||||||
|
fwmark = int(iface.FirewallMark)
|
||||||
|
|
||||||
|
if containsDefaultRoute && table <= 0 {
|
||||||
|
table = 20000 + link.Attrs().Index // generate a new routing table base on interface index
|
||||||
|
logrus.Debugf("using routing table %d to handle default routes", table)
|
||||||
|
}
|
||||||
|
if containsDefaultRoute && fwmark == 0 {
|
||||||
|
fwmark = 20000 + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index
|
||||||
|
logrus.Debugf("using fwmark %d to handle default routes", table)
|
||||||
|
|
||||||
|
// apply the fwmark
|
||||||
|
err = r.wg.ConfigureDevice(string(iface.Identifier), wgtypes.Config{
|
||||||
|
FirewallMark: &fwmark,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("failed to update temporary fwmark to: %d: %w", fwmark, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
|
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
|
||||||
if err := r.deleteLowLevelInterface(id); err != nil {
|
if err := r.deleteLowLevelInterface(id); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -41,7 +41,7 @@ type InterfaceController interface {
|
|||||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||||
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
|
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
|
||||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||||
SaveRoutes(_ context.Context, deviceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error
|
SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type WgQuickController interface {
|
type WgQuickController interface {
|
||||||
|
@ -384,7 +384,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee
|
|||||||
return nil, fmt.Errorf("failed to save interface: %w", err)
|
return nil, fmt.Errorf("failed to save interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers))
|
err = m.wg.SaveRoutes(ctx, iface, peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to save routes: %w", err)
|
return nil, fmt.Errorf("failed to save routes: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -284,7 +284,7 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load peer interface %s: %w", ifaceId, err)
|
return fmt.Errorf("failed to load peer interface %s: %w", ifaceId, err)
|
||||||
}
|
}
|
||||||
err = m.wg.SaveRoutes(ctx, iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(ifacePeers))
|
err = m.wg.SaveRoutes(ctx, iface, ifacePeers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update peer routes on interface %s: %w", ifaceId, err)
|
return fmt.Errorf("failed to update peer routes on interface %s: %w", ifaceId, err)
|
||||||
}
|
}
|
||||||
|
@ -114,12 +114,17 @@ func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
|
|||||||
return allowedCidrs
|
return allowedCidrs
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoutingTable returns the routing table number or -1 if an error occurred or RoutingTable was set to "off"
|
// GetRoutingTable returns the routing table number or
|
||||||
|
//
|
||||||
|
// -1 if an error occurred
|
||||||
|
// -2 if RoutingTable was set to "off"
|
||||||
func (i *Interface) GetRoutingTable() int {
|
func (i *Interface) GetRoutingTable() int {
|
||||||
routingTableStr := strings.ToLower(i.RoutingTable)
|
routingTableStr := strings.ToLower(i.RoutingTable)
|
||||||
switch {
|
switch {
|
||||||
case routingTableStr == "":
|
case routingTableStr == "":
|
||||||
return 0
|
return 0
|
||||||
|
case routingTableStr == "off":
|
||||||
|
return -2
|
||||||
case strings.HasPrefix(routingTableStr, "0x"):
|
case strings.HasPrefix(routingTableStr, "0x"):
|
||||||
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
|
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
|
||||||
routingTable, err := strconv.ParseUint(numberStr, 16, 64)
|
routingTable, err := strconv.ParseUint(numberStr, 16, 64)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user