wg-portal/internal/adapters/wireguard.go
Christoph Haas 2113999b22 wip routes
2023-07-30 18:55:44 +02:00

625 lines
17 KiB
Go

package adapters
import (
"context"
"errors"
"fmt"
"github.com/h44z/wg-portal/internal/domain"
"github.com/h44z/wg-portal/internal/lowlevel"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"os"
)
// WgRepo implements all low-level WireGuard interactions.
type WgRepo struct {
wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient
}
func NewWireGuardRepository() *WgRepo {
wg, err := wgctrl.New()
if err != nil {
panic("failed to init wgctrl: " + err.Error())
}
nl := &lowlevel.NetlinkManager{}
repo := &WgRepo{
wg: wg,
nl: nl,
}
return repo
}
func (r *WgRepo) GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) {
devices, err := r.wg.Devices()
if err != nil {
return nil, fmt.Errorf("device list error: %w", err)
}
interfaces := make([]domain.PhysicalInterface, 0, len(devices))
for _, device := range devices {
interfaceModel, err := r.convertWireGuardInterface(device)
if err != nil {
return nil, fmt.Errorf("interface convert failed for %s: %w", device.Name, err)
}
interfaces = append(interfaces, interfaceModel)
}
return interfaces, nil
}
func (r *WgRepo) GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
return r.getInterface(id)
}
func (r *WgRepo) GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) {
device, err := r.wg.Device(string(deviceId))
if err != nil {
return nil, fmt.Errorf("device error: %w", err)
}
peers := make([]domain.PhysicalPeer, 0, len(device.Peers))
for _, peer := range device.Peers {
peerModel, err := r.convertWireGuardPeer(&peer)
if err != nil {
return nil, fmt.Errorf("peer convert failed for %v: %w", peer.PublicKey, err)
}
peers = append(peers, peerModel)
}
return peers, nil
}
func (r *WgRepo) GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
return r.getPeer(deviceId, id)
}
func (r *WgRepo) convertWireGuardInterface(device *wgtypes.Device) (domain.PhysicalInterface, error) {
// read data from wgctrl interface
iface := domain.PhysicalInterface{
Identifier: domain.InterfaceIdentifier(device.Name),
KeyPair: domain.KeyPair{
PrivateKey: device.PrivateKey.String(),
PublicKey: device.PublicKey.String(),
},
ListenPort: device.ListenPort,
Addresses: nil,
Mtu: 0,
FirewallMark: int32(device.FirewallMark),
DeviceUp: false,
ImportSource: "wgctrl",
DeviceType: device.Type.String(),
BytesUpload: 0,
BytesDownload: 0,
}
// read data from netlink interface
lowLevelInterface, err := r.nl.LinkByName(device.Name)
if err != nil {
return domain.PhysicalInterface{}, fmt.Errorf("netlink error for %s: %w", device.Name, err)
}
ipAddresses, err := r.nl.AddrList(lowLevelInterface)
if err != nil {
return domain.PhysicalInterface{}, fmt.Errorf("ip read error for %s: %w", device.Name, err)
}
for _, addr := range ipAddresses {
iface.Addresses = append(iface.Addresses, domain.CidrFromNetlinkAddr(addr))
}
iface.Mtu = lowLevelInterface.Attrs().MTU
iface.DeviceUp = lowLevelInterface.Attrs().OperState == netlink.OperUnknown // wg only supports unknown
if stats := lowLevelInterface.Attrs().Statistics; stats != nil {
iface.BytesUpload = stats.TxBytes
iface.BytesDownload = stats.RxBytes
}
return iface, nil
}
func (r *WgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer, error) {
peerModel := domain.PhysicalPeer{
Identifier: domain.PeerIdentifier(peer.PublicKey.String()),
Endpoint: "",
AllowedIPs: nil,
KeyPair: domain.KeyPair{
PublicKey: peer.PublicKey.String(),
},
PresharedKey: "",
PersistentKeepalive: int(peer.PersistentKeepaliveInterval.Seconds()),
LastHandshake: peer.LastHandshakeTime,
ProtocolVersion: peer.ProtocolVersion,
BytesUpload: uint64(peer.ReceiveBytes),
BytesDownload: uint64(peer.TransmitBytes),
}
for _, addr := range peer.AllowedIPs {
peerModel.AllowedIPs = append(peerModel.AllowedIPs, domain.CidrFromIpNet(addr))
}
if peer.Endpoint != nil {
peerModel.Endpoint = peer.Endpoint.String()
}
if peer.PresharedKey != (wgtypes.Key{}) {
peerModel.PresharedKey = domain.PreSharedKey(peer.PresharedKey.String())
}
return peerModel, nil
}
func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
physicalInterface, err := r.getOrCreateInterface(id)
if err != nil {
return err
}
if updateFunc != nil {
physicalInterface, err = updateFunc(physicalInterface)
if err != nil {
return err
}
}
if err := r.updateLowLevelInterface(physicalInterface); err != nil {
return err
}
if err := r.updateWireGuardInterface(physicalInterface); err != nil {
return err
}
return nil
}
func (r *WgRepo) getOrCreateInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
device, err := r.getInterface(id)
if err == nil {
return device, nil
}
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("device error: %w", err)
}
// create new device
if err := r.createLowLevelInterface(id); err != nil {
return nil, err
}
device, err = r.getInterface(id)
return device, err
}
func (r *WgRepo) getInterface(id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) {
device, err := r.wg.Device(string(id))
if err != nil {
return nil, err
}
pi, err := r.convertWireGuardInterface(device)
return &pi, err
}
func (r *WgRepo) createLowLevelInterface(id domain.InterfaceIdentifier) error {
link := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{
Name: string(id),
},
LinkType: "wireguard",
}
err := r.nl.LinkAdd(link)
if err != nil {
return fmt.Errorf("link add failed: %w", err)
}
return nil
}
func (r *WgRepo) updateLowLevelInterface(pi *domain.PhysicalInterface) error {
link, err := r.nl.LinkByName(string(pi.Identifier))
if err != nil {
return err
}
if pi.Mtu != 0 {
if err := r.nl.LinkSetMTU(link, pi.Mtu); err != nil {
return fmt.Errorf("mtu error: %w", err)
}
}
for _, addr := range pi.Addresses {
err := r.nl.AddrReplace(link, addr.NetlinkAddr())
if err != nil {
return fmt.Errorf("failed to set ip %s: %w", addr.String(), err)
}
}
// Remove unwanted IP addresses
rawAddresses, err := r.nl.AddrList(link)
if err != nil {
return fmt.Errorf("failed to fetch interface ips: %w", err)
}
for _, rawAddr := range rawAddresses {
netlinkAddr := domain.CidrFromNetlinkAddr(rawAddr)
remove := true
for _, addr := range pi.Addresses {
if addr == netlinkAddr {
remove = false
break
}
}
if !remove {
continue
}
err := r.nl.AddrDel(link, &rawAddr)
if err != nil {
return fmt.Errorf("failed to remove deprecated ip %s: %w", netlinkAddr.String(), err)
}
}
// Update link state
if pi.DeviceUp {
if err := r.nl.LinkSetUp(link); err != nil {
return fmt.Errorf("failed to bring up device: %w", err)
}
} else {
if err := r.nl.LinkSetDown(link); err != nil {
return fmt.Errorf("failed to bring down device: %w", err)
}
}
return nil
}
func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
pKey, err := wgtypes.NewKey(pi.KeyPair.GetPrivateKeyBytes())
if err != nil {
return err
}
var fwMark *int
if pi.FirewallMark != 0 {
*fwMark = int(pi.FirewallMark)
}
err = r.wg.ConfigureDevice(string(pi.Identifier), wgtypes.Config{
PrivateKey: &pKey,
ListenPort: &pi.ListenPort,
FirewallMark: fwMark,
ReplacePeers: false,
})
if err != nil {
return err
}
return nil
}
func (r *WgRepo) SaveRoutes(_ context.Context, iface *domain.Interface, peers []domain.Peer) error {
table := iface.GetRoutingTable()
if table == -2 {
logrus.Trace("ignoring route update, feature disabled")
return nil
}
if !iface.IsDisabled() {
return r.setupRoutes(iface, peers)
} else {
return r.cleanupRoutes(iface, peers)
}
}
func (r *WgRepo) setupRoutes(iface *domain.Interface, peers []domain.Peer) error {
link, err := r.nl.LinkByName(string(iface.Identifier))
if err != nil {
return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
}
if link.Attrs().OperState == netlink.OperDown {
return nil // cannot set route for interface that is down
}
table, fwmark, err := r.getRoutingTableAndFwMark(iface, peers, link)
if err != nil {
return fmt.Errorf("failed to get table and fwmark: %w", err)
}
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
allowedIPs := iface.GetAllowedIPs(peers)
for _, allowedIP := range allowedIPs {
if allowedIP.Prefix().Bits() == 0 { // default route handling
if err := r.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: allowedIP.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}); err != nil {
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
}
family := netlink.FAMILY_V4
if !allowedIP.IsV4() {
family = netlink.FAMILY_V6
}
if err := r.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: table,
Mark: fwmark,
Invert: true,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Priority: -1,
Mask: -1,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for fwmark %d: %w", fwmark, err)
}
if err := r.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: -1,
Mark: -1,
Mask: -1,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for main table: %w", err)
}
} else {
err := r.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: allowedIP.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
})
if err != nil {
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
}
}
}
if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V4); err != nil {
return fmt.Errorf("failed to remove deprecated v4 routes: %w", err)
}
if err := r.removeDeprecatedRoutes(link, allowedIPs, netlink.FAMILY_V6); err != nil {
return fmt.Errorf("failed to remove deprecated v6 routes: %w", err)
}
return nil
}
func (r *WgRepo) removeDeprecatedRoutes(link netlink.Link, allowedIPs []domain.Cidr, family int) error {
rawRoutes, err := r.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
if err != nil {
return fmt.Errorf("failed to fetch raw routes: %w", err)
}
for _, rawRoute := range rawRoutes {
var netlinkAddr domain.Cidr
if rawRoute.Dst == nil {
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
} else {
netlinkAddr = domain.CidrFromIpNet(*rawRoute.Dst)
}
remove := true
for _, allowedIP := range allowedIPs {
if netlinkAddr == allowedIP {
remove = false
break
}
}
if !remove {
continue
}
err := r.nl.RouteDel(&rawRoute)
if err != nil {
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
}
}
return nil
}
func (r *WgRepo) cleanupRoutes(iface *domain.Interface, peers []domain.Peer) error {
link, err := r.nl.LinkByName(string(iface.Identifier))
if err != nil {
return fmt.Errorf("unable to find physical interface %s: %w", iface.Identifier, err)
}
table, _, err := r.getRoutingTableAndFwMark(iface, peers, link)
if err != nil {
return fmt.Errorf("failed to get table and fwmark: %w", err)
}
if table == 0 {
return nil // noting to remove
}
delRule := netlink.NewRule()
delRule.Family = netlink.FAMILY_ALL
delRule.Table = table
if err := r.nl.RuleDel(delRule); err != nil && !errors.Is(err, unix.ENOENT) {
return fmt.Errorf("failed to delete rule for table %d: %w", table, err)
}
return nil
}
func (r *WgRepo) getRoutingTableAndFwMark(iface *domain.Interface, peers []domain.Peer, link netlink.Link) (table, fwmark int, err error) {
allowedIPs := iface.GetAllowedIPs(peers)
containsDefaultRoute := false
for _, allowedIP := range allowedIPs {
if allowedIP.Prefix().Bits() == 0 {
containsDefaultRoute = true
break
}
}
table = iface.GetRoutingTable()
fwmark = int(iface.FirewallMark)
if containsDefaultRoute && table <= 0 {
table = 20000 + link.Attrs().Index // generate a new routing table base on interface index
logrus.Debugf("using routing table %d to handle default routes", table)
}
if containsDefaultRoute && fwmark == 0 {
fwmark = 20000 + link.Attrs().Index // generate a new (temporary) firewall mark based on the interface index
logrus.Debugf("using fwmark %d to handle default routes", table)
// apply the fwmark
err = r.wg.ConfigureDevice(string(iface.Identifier), wgtypes.Config{
FirewallMark: &fwmark,
})
if err != nil {
return 0, 0, fmt.Errorf("failed to update temporary fwmark to: %d: %w", fwmark, err)
}
}
return
}
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
if err := r.deleteLowLevelInterface(id); err != nil {
return err
}
return nil
}
func (r *WgRepo) deleteLowLevelInterface(id domain.InterfaceIdentifier) error {
link, err := r.nl.LinkByName(string(id))
if err != nil {
return fmt.Errorf("unable to find low level interface: %w", err)
}
err = r.nl.LinkDel(link)
if err != nil {
return fmt.Errorf("failed to delete low level interface: %w", err)
}
return nil
}
func (r *WgRepo) SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error {
physicalPeer, err := r.getOrCreatePeer(deviceId, id)
if err != nil {
return err
}
physicalPeer, err = updateFunc(physicalPeer)
if err != nil {
return err
}
if err := r.updatePeer(deviceId, physicalPeer); err != nil {
return err
}
return nil
}
func (r *WgRepo) getOrCreatePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
peer, err := r.getPeer(deviceId, id)
if err == nil {
return peer, nil
}
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("peer error: %w", err)
}
// create new peer
err = r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{Peers: []wgtypes.PeerConfig{{
PublicKey: id.ToPublicKey(),
}}})
peer, err = r.getPeer(deviceId, id)
return peer, nil
}
func (r *WgRepo) getPeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error) {
if !id.IsPublicKey() {
return nil, errors.New("invalid public key")
}
device, err := r.wg.Device(string(deviceId))
if err != nil {
return nil, err
}
publicKey := id.ToPublicKey()
for _, peer := range device.Peers {
if peer.PublicKey != publicKey {
continue
}
peerModel, err := r.convertWireGuardPeer(&peer)
return &peerModel, err
}
return nil, os.ErrNotExist
}
func (r *WgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.PhysicalPeer) error {
cfg := wgtypes.PeerConfig{
PublicKey: pp.GetPublicKey(),
Remove: false,
UpdateOnly: true,
PresharedKey: pp.GetPresharedKey(),
Endpoint: pp.GetEndpointAddress(),
PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(),
ReplaceAllowedIPs: true,
AllowedIPs: pp.GetAllowedIPs(),
}
err := r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
if err != nil {
return err
}
return nil
}
func (r *WgRepo) DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
if !id.IsPublicKey() {
return errors.New("invalid public key")
}
err := r.deletePeer(deviceId, id)
if err != nil {
return err
}
return nil
}
func (r *WgRepo) deletePeer(deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error {
cfg := wgtypes.PeerConfig{
PublicKey: id.ToPublicKey(),
Remove: true,
}
err := r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
if err != nil {
return err
}
return nil
}