mirror of
https://github.com/h44z/wg-portal.git
synced 2025-08-10 07:22:24 +00:00
wip: wgquick
This commit is contained in:
parent
c934d7ecd3
commit
5aa94999ab
@ -53,15 +53,15 @@ func (l *GormLogger) LogMode(logger.LogLevel) logger.Interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *GormLogger) Info(ctx context.Context, s string, args ...interface{}) {
|
func (l *GormLogger) Info(ctx context.Context, s string, args ...interface{}) {
|
||||||
logrus.WithContext(ctx).Infof(s, args)
|
logrus.WithContext(ctx).Infof(s, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *GormLogger) Warn(ctx context.Context, s string, args ...interface{}) {
|
func (l *GormLogger) Warn(ctx context.Context, s string, args ...interface{}) {
|
||||||
logrus.WithContext(ctx).Warnf(s, args)
|
logrus.WithContext(ctx).Warnf(s, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *GormLogger) Error(ctx context.Context, s string, args ...interface{}) {
|
func (l *GormLogger) Error(ctx context.Context, s string, args ...interface{}) {
|
||||||
logrus.WithContext(ctx).Errorf(s, args)
|
logrus.WithContext(ctx).Errorf(s, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||||
|
99
internal/adapters/wgquick.go
Normal file
99
internal/adapters/wgquick.go
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
package adapters
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"github.com/h44z/wg-portal/internal"
|
||||||
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks.
|
||||||
|
type WgQuickRepo struct {
|
||||||
|
shellCmd string
|
||||||
|
resolvConfIfacePrefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWgQuickRepo() *WgQuickRepo {
|
||||||
|
return &WgQuickRepo{
|
||||||
|
shellCmd: "bash",
|
||||||
|
resolvConfIfacePrefix: "tun.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *WgQuickRepo) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
|
||||||
|
if hookCmd == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.exec(hookCmd, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to exec hook: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||||
|
if dnsStr == "" && dnsSearchStr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsServers := internal.SliceString(dnsStr)
|
||||||
|
dnsSearchDomains := internal.SliceString(dnsSearchStr)
|
||||||
|
|
||||||
|
dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
|
||||||
|
dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
|
||||||
|
|
||||||
|
for _, dnsServer := range dnsServers {
|
||||||
|
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
|
||||||
|
}
|
||||||
|
for _, searchDomain := range dnsSearchDomains {
|
||||||
|
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.exec(dnsCommand, id, dnsCommandInput...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set dns settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *WgQuickRepo) UnsetDNS(id domain.InterfaceIdentifier) error {
|
||||||
|
dnsCommand := "resolvconf -d %resPref%i -f"
|
||||||
|
|
||||||
|
err := r.exec(dnsCommand, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to unset dns settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *WgQuickRepo) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
|
||||||
|
command = strings.ReplaceAll(command, "%resPref", r.resolvConfIfacePrefix)
|
||||||
|
return strings.ReplaceAll(command, "%i", string(interfaceId))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *WgQuickRepo) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
|
||||||
|
commandWithInterfaceName := r.replaceCommandPlaceHolders(command, interfaceId)
|
||||||
|
cmd := exec.Command(r.shellCmd, "-ce", commandWithInterfaceName)
|
||||||
|
if len(stdin) > 0 {
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
for _, ln := range stdin {
|
||||||
|
if _, err := fmt.Fprint(b, ln); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cmd.Stdin = b
|
||||||
|
}
|
||||||
|
out, err := cmd.CombinedOutput() // execute and wait for output
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
|
||||||
|
}
|
||||||
|
logrus.Tracef("executed shell command %s, with output: %s", commandWithInterfaceName, string(out))
|
||||||
|
return nil
|
||||||
|
}
|
@ -4,19 +4,21 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/h44z/wg-portal/internal/domain"
|
"github.com/h44z/wg-portal/internal/domain"
|
||||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WgRepo implements all low-level WireGuard interactions.
|
// WgRepo implements all low-level WireGuard interactions.
|
||||||
type WgRepo struct {
|
type WgRepo struct {
|
||||||
wg lowlevel.WireGuardClient
|
wg lowlevel.WireGuardClient
|
||||||
nl lowlevel.NetlinkClient
|
nl lowlevel.NetlinkClient
|
||||||
|
quick *WgQuickRepo
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWireGuardRepository() *WgRepo {
|
func NewWireGuardRepository() *WgRepo {
|
||||||
@ -28,8 +30,9 @@ func NewWireGuardRepository() *WgRepo {
|
|||||||
nl := &lowlevel.NetlinkManager{}
|
nl := &lowlevel.NetlinkManager{}
|
||||||
|
|
||||||
repo := &WgRepo{
|
repo := &WgRepo{
|
||||||
wg: wg,
|
wg: wg,
|
||||||
nl: nl,
|
nl: nl,
|
||||||
|
quick: NewWgQuickRepo(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return repo
|
return repo
|
||||||
@ -152,18 +155,40 @@ func (r *WgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer,
|
|||||||
return peerModel, nil
|
return peerModel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
|
func (r *WgRepo) SaveInterface(_ context.Context, iface *domain.Interface, peers []domain.Peer, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
|
||||||
physicalInterface, err := r.getOrCreateInterface(id)
|
physicalInterface, err := r.getOrCreateInterface(iface.Identifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wasUp := physicalInterface.DeviceUp
|
||||||
if updateFunc != nil {
|
if updateFunc != nil {
|
||||||
physicalInterface, err = updateFunc(physicalInterface)
|
physicalInterface, err = updateFunc(physicalInterface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
stateChanged := wasUp != physicalInterface.DeviceUp
|
||||||
|
|
||||||
|
if stateChanged {
|
||||||
|
if physicalInterface.DeviceUp {
|
||||||
|
if err := r.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
|
||||||
|
return fmt.Errorf("failed to update dns settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute pre-up hook: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := r.quick.UnsetDNS(iface.Identifier); err != nil {
|
||||||
|
return fmt.Errorf("failed to clear dns settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute pre-down hook: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.updateLowLevelInterface(physicalInterface); err != nil {
|
if err := r.updateLowLevelInterface(physicalInterface); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -171,6 +196,21 @@ func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier,
|
|||||||
if err := r.updateWireGuardInterface(physicalInterface); err != nil {
|
if err := r.updateWireGuardInterface(physicalInterface); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err := r.updateRoutes(iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if stateChanged {
|
||||||
|
if physicalInterface.DeviceUp {
|
||||||
|
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute post-up hook: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
|
||||||
|
return fmt.Errorf("failed to execute post-down hook: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -298,6 +338,72 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *WgRepo) updateRoutes(interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error {
|
||||||
|
if table == -1 {
|
||||||
|
logrus.Trace("ignoring route update")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
link, err := r.nl.LinkByName(string(interfaceId))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if link.Attrs().OperState == netlink.OperDown {
|
||||||
|
return nil // cannot set route for interface that is down
|
||||||
|
}
|
||||||
|
|
||||||
|
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
|
||||||
|
for _, allowedIP := range allowedIPs {
|
||||||
|
if allowedIP.Prefix().Bits() == 0 { // default route
|
||||||
|
// TODO
|
||||||
|
} else {
|
||||||
|
err := r.nl.RouteReplace(&netlink.Route{
|
||||||
|
LinkIndex: link.Attrs().Index,
|
||||||
|
Dst: allowedIP.IpNet(),
|
||||||
|
Table: table,
|
||||||
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
|
Type: unix.RTN_UNICAST,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove unwanted routes
|
||||||
|
rawRoutes, err := r.nl.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{
|
||||||
|
LinkIndex: link.Attrs().Index,
|
||||||
|
Table: unix.RT_TABLE_UNSPEC, // all tables
|
||||||
|
Scope: unix.RT_SCOPE_LINK,
|
||||||
|
Type: unix.RTN_UNICAST,
|
||||||
|
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch raw routes: %w", err)
|
||||||
|
}
|
||||||
|
for _, rawRoute := range rawRoutes {
|
||||||
|
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
|
||||||
|
remove := true
|
||||||
|
for _, allowedIP := range allowedIPs {
|
||||||
|
if netlinkAddr == allowedIP {
|
||||||
|
remove = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !remove {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.nl.RouteDel(&rawRoute)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
|
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
|
||||||
if err := r.deleteLowLevelInterface(id); err != nil {
|
if err := r.deleteLowLevelInterface(id); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -388,16 +494,10 @@ func (r *WgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.Phys
|
|||||||
Endpoint: pp.GetEndpointAddress(),
|
Endpoint: pp.GetEndpointAddress(),
|
||||||
PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(),
|
PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(),
|
||||||
ReplaceAllowedIPs: true,
|
ReplaceAllowedIPs: true,
|
||||||
AllowedIPs: nil,
|
AllowedIPs: pp.GetAllowedIPs(),
|
||||||
}
|
}
|
||||||
|
|
||||||
ips, err := pp.GetAllowedIPs()
|
err := r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cfg.AllowedIPs = ips
|
|
||||||
|
|
||||||
err = r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -70,8 +70,8 @@ func TestWireGuardCreateInterface(t *testing.T) {
|
|||||||
|
|
||||||
err := mgr.SaveInterface(context.Background(), interfaceName, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
err := mgr.SaveInterface(context.Background(), interfaceName, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||||
pi.Addresses = []domain.Cidr{
|
pi.Addresses = []domain.Cidr{
|
||||||
{Ip: domain.IpAddress(net.ParseIP(ipAddress)), NetLength: 24, Bits: 32},
|
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}),
|
||||||
{Ip: domain.IpAddress(net.ParseIP(ipV6Address)), NetLength: 64, Bits: 128},
|
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}),
|
||||||
}
|
}
|
||||||
return pi, nil
|
return pi, nil
|
||||||
})
|
})
|
||||||
@ -104,8 +104,8 @@ func TestWireGuardUpdateInterface(t *testing.T) {
|
|||||||
ipV6Address := "1337:d34d:b33f::2"
|
ipV6Address := "1337:d34d:b33f::2"
|
||||||
err = mgr.SaveInterface(context.Background(), interfaceName, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
err = mgr.SaveInterface(context.Background(), interfaceName, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||||
pi.Addresses = []domain.Cidr{
|
pi.Addresses = []domain.Cidr{
|
||||||
{Ip: domain.IpAddress(net.ParseIP(ipAddress)), NetLength: 24, Bits: 32},
|
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}),
|
||||||
{Ip: domain.IpAddress(net.ParseIP(ipV6Address)), NetLength: 64, Bits: 128},
|
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}),
|
||||||
}
|
}
|
||||||
return pi, nil
|
return pi, nil
|
||||||
})
|
})
|
||||||
@ -119,3 +119,42 @@ func TestWireGuardUpdateInterface(t *testing.T) {
|
|||||||
assert.Contains(t, string(out), ipAddress)
|
assert.Contains(t, string(out), ipAddress)
|
||||||
assert.Contains(t, string(out), ipV6Address)
|
assert.Contains(t, string(out), ipV6Address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWireGuardCreateInterfaceWithRoutes(t *testing.T) {
|
||||||
|
mgr := setup(t)
|
||||||
|
|
||||||
|
interfaceName := domain.InterfaceIdentifier("wg_test_001")
|
||||||
|
ipAddress := "10.11.12.13"
|
||||||
|
ipV6Address := "1337:d34d:b33f::2"
|
||||||
|
defer mgr.DeleteInterface(context.Background(), interfaceName)
|
||||||
|
|
||||||
|
iface := &domain.Interface{
|
||||||
|
Identifier: interfaceName,
|
||||||
|
//RoutingTable: "1234",
|
||||||
|
}
|
||||||
|
peers := []domain.Peer{
|
||||||
|
{
|
||||||
|
Interface: domain.PeerInterfaceConfig{
|
||||||
|
Addresses: domain.CidrsMust(domain.CidrsFromString("10.11.12.14/32,10.22.33.44/32")),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mgr.SaveInterface2(context.Background(), iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||||
|
pi.Addresses = []domain.Cidr{
|
||||||
|
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}),
|
||||||
|
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}),
|
||||||
|
}
|
||||||
|
pi.DeviceUp = true
|
||||||
|
return pi, nil
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Validate that the interface has been created
|
||||||
|
cmd := exec.Command("ip", "addr")
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(out), interfaceName)
|
||||||
|
assert.Contains(t, string(out), ipAddress)
|
||||||
|
assert.Contains(t, string(out), ipV6Address)
|
||||||
|
}
|
||||||
|
@ -37,7 +37,7 @@ type InterfaceController interface {
|
|||||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||||
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error)
|
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error)
|
||||||
SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
|
SaveInterface(_ context.Context, iface *domain.Interface, peers []domain.Peer, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
|
||||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||||
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
|
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
|
||||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||||
|
@ -135,12 +135,12 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
|
|||||||
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
|
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
physicalInterface, err := m.wg.GetInterface(ctx, iface.Identifier)
|
_, err = m.wg.GetInterface(ctx, iface.Identifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Debugf("creating missing interface %s...", iface.Identifier)
|
logrus.Debugf("creating missing interface %s...", iface.Identifier)
|
||||||
|
|
||||||
// try to create a new interface
|
// try to create a new interface
|
||||||
err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
err := m.wg.SaveInterface(ctx, &iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||||
domain.MergeToPhysicalInterface(pi, &iface)
|
domain.MergeToPhysicalInterface(pi, &iface)
|
||||||
|
|
||||||
return pi, nil
|
return pi, nil
|
||||||
@ -169,32 +169,30 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if physicalInterface.DeviceUp != !iface.IsDisabled() {
|
logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
|
||||||
logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
|
|
||||||
|
|
||||||
// try to move interface to stored state
|
// try to move interface to stored state
|
||||||
err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
err := m.wg.SaveInterface(ctx, &iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||||
pi.DeviceUp = !iface.IsDisabled()
|
pi.DeviceUp = !iface.IsDisabled()
|
||||||
|
|
||||||
return pi, nil
|
return pi, nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if updateDbOnError {
|
if updateDbOnError {
|
||||||
// disable interface in database as no physical interface is available
|
// disable interface in database as no physical interface is available
|
||||||
_ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) {
|
_ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) {
|
||||||
if iface.IsDisabled() {
|
if iface.IsDisabled() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
in.Disabled = &now // set
|
in.Disabled = &now // set
|
||||||
in.DisabledReason = "no physical interface active"
|
in.DisabledReason = "no physical interface active"
|
||||||
} else {
|
} else {
|
||||||
in.Disabled = nil
|
in.Disabled = nil
|
||||||
in.DisabledReason = ""
|
in.DisabledReason = ""
|
||||||
}
|
}
|
||||||
return in, nil
|
return in, nil
|
||||||
})
|
})
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to change physical interface state for %s: %w", iface.Identifier, err)
|
|
||||||
}
|
}
|
||||||
|
return fmt.Errorf("failed to change physical interface state for %s: %w", iface.Identifier, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -294,7 +292,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
|
|||||||
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
||||||
in.CopyCalculatedAttributes(i)
|
in.CopyCalculatedAttributes(i)
|
||||||
|
|
||||||
err = m.wg.SaveInterface(ctx, in.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
err = m.wg.SaveInterface(ctx, in, nil, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||||
domain.MergeToPhysicalInterface(pi, in)
|
domain.MergeToPhysicalInterface(pi, in)
|
||||||
return pi, nil
|
return pi, nil
|
||||||
})
|
})
|
||||||
@ -324,7 +322,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
|
|||||||
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
||||||
in.CopyCalculatedAttributes(i)
|
in.CopyCalculatedAttributes(i)
|
||||||
|
|
||||||
err = m.wg.SaveInterface(ctx, in.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
err = m.wg.SaveInterface(ctx, in, existingPeers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||||
domain.MergeToPhysicalInterface(pi, in)
|
domain.MergeToPhysicalInterface(pi, in)
|
||||||
return pi, nil
|
return pi, nil
|
||||||
})
|
})
|
||||||
|
@ -306,7 +306,7 @@ func (m Manager) getFreshPeerIpConfig(ctx context.Context, iface *domain.Interfa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ips = append(ips, ip)
|
ips = append(ips, ip.HostAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -3,7 +3,10 @@ package domain
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/h44z/wg-portal/internal"
|
"github.com/h44z/wg-portal/internal"
|
||||||
|
"math"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -95,6 +98,47 @@ func (i *Interface) GetConfigFileName() string {
|
|||||||
return filename
|
return filename
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
|
||||||
|
var allowedCidrs []Cidr
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
allowedCidrs = append(allowedCidrs, peer.Interface.Addresses...)
|
||||||
|
if peer.ExtraAllowedIPsStr != "" {
|
||||||
|
extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr)
|
||||||
|
if err == nil {
|
||||||
|
allowedCidrs = append(allowedCidrs, extraIPs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return allowedCidrs
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoutingTable returns the routing table number or -1 if an error occurred or RoutingTable was set to "off"
|
||||||
|
func (i *Interface) GetRoutingTable() int {
|
||||||
|
routingTableStr := strings.ToLower(i.RoutingTable)
|
||||||
|
switch {
|
||||||
|
case routingTableStr == "":
|
||||||
|
return 0
|
||||||
|
case strings.HasPrefix(routingTableStr, "0x"):
|
||||||
|
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
|
||||||
|
routingTable, err := strconv.ParseUint(numberStr, 16, 64)
|
||||||
|
if err != nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if routingTable > math.MaxInt32 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return int(routingTable)
|
||||||
|
default:
|
||||||
|
routingTable, err := strconv.Atoi(routingTableStr)
|
||||||
|
if err != nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return routingTable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type PhysicalInterface struct {
|
type PhysicalInterface struct {
|
||||||
Identifier InterfaceIdentifier // device name, for example: wg0
|
Identifier InterfaceIdentifier // device name, for example: wg0
|
||||||
KeyPair // private/public Key of the server interface
|
KeyPair // private/public Key of the server interface
|
||||||
|
@ -163,6 +163,14 @@ func (c Cidr) NextAddr() Cidr {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c Cidr) HostAddr() Cidr {
|
||||||
|
return Cidr{
|
||||||
|
Cidr: netip.PrefixFrom(c.Prefix().Addr(), c.Prefix().Addr().BitLen()).String(),
|
||||||
|
Addr: c.Addr,
|
||||||
|
NetLength: c.Prefix().Addr().BitLen(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c Cidr) NextSubnet() Cidr {
|
func (c Cidr) NextSubnet() Cidr {
|
||||||
prefix := c.Prefix()
|
prefix := c.Prefix()
|
||||||
nextAddr := c.BroadcastAddr().Prefix().Addr().Next()
|
nextAddr := c.BroadcastAddr().Prefix().Addr().Next()
|
||||||
|
@ -190,13 +190,13 @@ func (p PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration {
|
|||||||
return &keepAliveDuration
|
return &keepAliveDuration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p PhysicalPeer) GetAllowedIPs() ([]net.IPNet, error) {
|
func (p PhysicalPeer) GetAllowedIPs() []net.IPNet {
|
||||||
allowedIPs := make([]net.IPNet, len(p.AllowedIPs))
|
allowedIPs := make([]net.IPNet, len(p.AllowedIPs))
|
||||||
for i, ip := range p.AllowedIPs {
|
for i, ip := range p.AllowedIPs {
|
||||||
allowedIPs[i] = *ip.IpNet()
|
allowedIPs[i] = *ip.IpNet()
|
||||||
}
|
}
|
||||||
|
|
||||||
return allowedIPs, nil
|
return allowedIPs
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
|
func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
|
||||||
@ -223,9 +223,15 @@ func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
|
|||||||
func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) {
|
func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) {
|
||||||
pp.Identifier = p.Identifier
|
pp.Identifier = p.Identifier
|
||||||
pp.Endpoint = p.Endpoint.GetValue()
|
pp.Endpoint = p.Endpoint.GetValue()
|
||||||
allowedIPs, _ := CidrsFromString(p.AllowedIPsStr.GetValue())
|
if p.Interface.Type == InterfaceTypeServer {
|
||||||
extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr)
|
allowedIPs, _ := CidrsFromString(p.AllowedIPsStr.GetValue())
|
||||||
pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...)
|
extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr)
|
||||||
|
pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...)
|
||||||
|
} else {
|
||||||
|
allowedIPs := p.Interface.Addresses
|
||||||
|
extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr)
|
||||||
|
pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...)
|
||||||
|
}
|
||||||
pp.PresharedKey = p.PresharedKey
|
pp.PresharedKey = p.PresharedKey
|
||||||
pp.PublicKey = p.Interface.PublicKey
|
pp.PublicKey = p.Interface.PublicKey
|
||||||
pp.PersistentKeepalive = p.PersistentKeepalive.GetValue()
|
pp.PersistentKeepalive = p.PersistentKeepalive.GetValue()
|
||||||
|
@ -16,6 +16,14 @@ type NetlinkClient interface {
|
|||||||
AddrAdd(link netlink.Link, addr *netlink.Addr) error
|
AddrAdd(link netlink.Link, addr *netlink.Addr) error
|
||||||
AddrList(link netlink.Link) ([]netlink.Addr, error)
|
AddrList(link netlink.Link) ([]netlink.Addr, error)
|
||||||
AddrDel(link netlink.Link, addr *netlink.Addr) error
|
AddrDel(link netlink.Link, addr *netlink.Addr) error
|
||||||
|
RouteAdd(route *netlink.Route) error
|
||||||
|
RouteDel(route *netlink.Route) error
|
||||||
|
RouteReplace(route *netlink.Route) error
|
||||||
|
RouteList(link netlink.Link, family int) ([]netlink.Route, error)
|
||||||
|
RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error)
|
||||||
|
RuleAdd(rule *netlink.Rule) error
|
||||||
|
RuleDel(rule *netlink.Rule) error
|
||||||
|
RuleList(family int) ([]netlink.Rule, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetlinkManager struct {
|
type NetlinkManager struct {
|
||||||
@ -66,3 +74,35 @@ func (n NetlinkManager) AddrList(link netlink.Link) ([]netlink.Addr, error) {
|
|||||||
func (n NetlinkManager) AddrDel(link netlink.Link, addr *netlink.Addr) error {
|
func (n NetlinkManager) AddrDel(link netlink.Link, addr *netlink.Addr) error {
|
||||||
return netlink.AddrDel(link, addr)
|
return netlink.AddrDel(link, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n NetlinkManager) RouteAdd(route *netlink.Route) error {
|
||||||
|
return netlink.RouteAdd(route)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n NetlinkManager) RouteDel(route *netlink.Route) error {
|
||||||
|
return netlink.RouteDel(route)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n NetlinkManager) RouteReplace(route *netlink.Route) error {
|
||||||
|
return netlink.RouteReplace(route)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n NetlinkManager) RouteList(link netlink.Link, family int) ([]netlink.Route, error) {
|
||||||
|
return netlink.RouteList(link, family)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n NetlinkManager) RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error) {
|
||||||
|
return netlink.RouteListFiltered(family, filter, filterMask)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n NetlinkManager) RuleAdd(rule *netlink.Rule) error {
|
||||||
|
return netlink.RuleAdd(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n NetlinkManager) RuleDel(rule *netlink.Rule) error {
|
||||||
|
return netlink.RuleDel(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n NetlinkManager) RuleList(family int) ([]netlink.Rule, error) {
|
||||||
|
return netlink.RuleList(family)
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user