mirror of
https://github.com/h44z/wg-portal.git
synced 2025-08-09 15:02:24 +00:00
wip: wgquick
This commit is contained in:
parent
c934d7ecd3
commit
5aa94999ab
@ -53,15 +53,15 @@ func (l *GormLogger) LogMode(logger.LogLevel) logger.Interface {
|
||||
}
|
||||
|
||||
func (l *GormLogger) Info(ctx context.Context, s string, args ...interface{}) {
|
||||
logrus.WithContext(ctx).Infof(s, args)
|
||||
logrus.WithContext(ctx).Infof(s, args...)
|
||||
}
|
||||
|
||||
func (l *GormLogger) Warn(ctx context.Context, s string, args ...interface{}) {
|
||||
logrus.WithContext(ctx).Warnf(s, args)
|
||||
logrus.WithContext(ctx).Warnf(s, args...)
|
||||
}
|
||||
|
||||
func (l *GormLogger) Error(ctx context.Context, s string, args ...interface{}) {
|
||||
logrus.WithContext(ctx).Errorf(s, args)
|
||||
logrus.WithContext(ctx).Errorf(s, args...)
|
||||
}
|
||||
|
||||
func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
|
99
internal/adapters/wgquick.go
Normal file
99
internal/adapters/wgquick.go
Normal file
@ -0,0 +1,99 @@
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/sirupsen/logrus"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks.
|
||||
type WgQuickRepo struct {
|
||||
shellCmd string
|
||||
resolvConfIfacePrefix string
|
||||
}
|
||||
|
||||
func NewWgQuickRepo() *WgQuickRepo {
|
||||
return &WgQuickRepo{
|
||||
shellCmd: "bash",
|
||||
resolvConfIfacePrefix: "tun.",
|
||||
}
|
||||
}
|
||||
|
||||
func (r *WgQuickRepo) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
|
||||
if hookCmd == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := r.exec(hookCmd, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exec hook: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
|
||||
if dnsStr == "" && dnsSearchStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsServers := internal.SliceString(dnsStr)
|
||||
dnsSearchDomains := internal.SliceString(dnsSearchStr)
|
||||
|
||||
dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
|
||||
dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
|
||||
|
||||
for _, dnsServer := range dnsServers {
|
||||
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
|
||||
}
|
||||
for _, searchDomain := range dnsSearchDomains {
|
||||
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
|
||||
}
|
||||
|
||||
err := r.exec(dnsCommand, id, dnsCommandInput...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set dns settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgQuickRepo) UnsetDNS(id domain.InterfaceIdentifier) error {
|
||||
dnsCommand := "resolvconf -d %resPref%i -f"
|
||||
|
||||
err := r.exec(dnsCommand, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unset dns settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgQuickRepo) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
|
||||
command = strings.ReplaceAll(command, "%resPref", r.resolvConfIfacePrefix)
|
||||
return strings.ReplaceAll(command, "%i", string(interfaceId))
|
||||
}
|
||||
|
||||
func (r *WgQuickRepo) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
|
||||
commandWithInterfaceName := r.replaceCommandPlaceHolders(command, interfaceId)
|
||||
cmd := exec.Command(r.shellCmd, "-ce", commandWithInterfaceName)
|
||||
if len(stdin) > 0 {
|
||||
b := &bytes.Buffer{}
|
||||
for _, ln := range stdin {
|
||||
if _, err := fmt.Fprint(b, ln); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
cmd.Stdin = b
|
||||
}
|
||||
out, err := cmd.CombinedOutput() // execute and wait for output
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
|
||||
}
|
||||
logrus.Tracef("executed shell command %s, with output: %s", commandWithInterfaceName, string(out))
|
||||
return nil
|
||||
}
|
@ -4,19 +4,21 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/domain"
|
||||
"github.com/h44z/wg-portal/internal/lowlevel"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"os"
|
||||
)
|
||||
|
||||
// WgRepo implements all low-level WireGuard interactions.
|
||||
type WgRepo struct {
|
||||
wg lowlevel.WireGuardClient
|
||||
nl lowlevel.NetlinkClient
|
||||
wg lowlevel.WireGuardClient
|
||||
nl lowlevel.NetlinkClient
|
||||
quick *WgQuickRepo
|
||||
}
|
||||
|
||||
func NewWireGuardRepository() *WgRepo {
|
||||
@ -28,8 +30,9 @@ func NewWireGuardRepository() *WgRepo {
|
||||
nl := &lowlevel.NetlinkManager{}
|
||||
|
||||
repo := &WgRepo{
|
||||
wg: wg,
|
||||
nl: nl,
|
||||
wg: wg,
|
||||
nl: nl,
|
||||
quick: NewWgQuickRepo(),
|
||||
}
|
||||
|
||||
return repo
|
||||
@ -152,18 +155,40 @@ func (r *WgRepo) convertWireGuardPeer(peer *wgtypes.Peer) (domain.PhysicalPeer,
|
||||
return peerModel, nil
|
||||
}
|
||||
|
||||
func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
|
||||
physicalInterface, err := r.getOrCreateInterface(id)
|
||||
func (r *WgRepo) SaveInterface(_ context.Context, iface *domain.Interface, peers []domain.Peer, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error {
|
||||
physicalInterface, err := r.getOrCreateInterface(iface.Identifier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
wasUp := physicalInterface.DeviceUp
|
||||
if updateFunc != nil {
|
||||
physicalInterface, err = updateFunc(physicalInterface)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
stateChanged := wasUp != physicalInterface.DeviceUp
|
||||
|
||||
if stateChanged {
|
||||
if physicalInterface.DeviceUp {
|
||||
if err := r.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
|
||||
return fmt.Errorf("failed to update dns settings: %w", err)
|
||||
}
|
||||
|
||||
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
|
||||
return fmt.Errorf("failed to execute pre-up hook: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.quick.UnsetDNS(iface.Identifier); err != nil {
|
||||
return fmt.Errorf("failed to clear dns settings: %w", err)
|
||||
}
|
||||
|
||||
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
|
||||
return fmt.Errorf("failed to execute pre-down hook: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.updateLowLevelInterface(physicalInterface); err != nil {
|
||||
return err
|
||||
@ -171,6 +196,21 @@ func (r *WgRepo) SaveInterface(_ context.Context, id domain.InterfaceIdentifier,
|
||||
if err := r.updateWireGuardInterface(physicalInterface); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.updateRoutes(iface.Identifier, iface.GetRoutingTable(), iface.GetAllowedIPs(peers)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if stateChanged {
|
||||
if physicalInterface.DeviceUp {
|
||||
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
|
||||
return fmt.Errorf("failed to execute post-up hook: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
|
||||
return fmt.Errorf("failed to execute post-down hook: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -298,6 +338,72 @@ func (r *WgRepo) updateWireGuardInterface(pi *domain.PhysicalInterface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgRepo) updateRoutes(interfaceId domain.InterfaceIdentifier, table int, allowedIPs []domain.Cidr) error {
|
||||
if table == -1 {
|
||||
logrus.Trace("ignoring route update")
|
||||
return nil
|
||||
}
|
||||
|
||||
link, err := r.nl.LinkByName(string(interfaceId))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if link.Attrs().OperState == netlink.OperDown {
|
||||
return nil // cannot set route for interface that is down
|
||||
}
|
||||
|
||||
// try to mimic wg-quick (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if allowedIP.Prefix().Bits() == 0 { // default route
|
||||
// TODO
|
||||
} else {
|
||||
err := r.nl.RouteReplace(&netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Dst: allowedIP.IpNet(),
|
||||
Table: table,
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove unwanted routes
|
||||
rawRoutes, err := r.nl.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{
|
||||
LinkIndex: link.Attrs().Index,
|
||||
Table: unix.RT_TABLE_UNSPEC, // all tables
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch raw routes: %w", err)
|
||||
}
|
||||
for _, rawRoute := range rawRoutes {
|
||||
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
|
||||
remove := true
|
||||
for _, allowedIP := range allowedIPs {
|
||||
if netlinkAddr == allowedIP {
|
||||
remove = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !remove {
|
||||
continue
|
||||
}
|
||||
|
||||
err := r.nl.RouteDel(&rawRoute)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *WgRepo) DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error {
|
||||
if err := r.deleteLowLevelInterface(id); err != nil {
|
||||
return err
|
||||
@ -388,16 +494,10 @@ func (r *WgRepo) updatePeer(deviceId domain.InterfaceIdentifier, pp *domain.Phys
|
||||
Endpoint: pp.GetEndpointAddress(),
|
||||
PersistentKeepaliveInterval: pp.GetPersistentKeepaliveTime(),
|
||||
ReplaceAllowedIPs: true,
|
||||
AllowedIPs: nil,
|
||||
AllowedIPs: pp.GetAllowedIPs(),
|
||||
}
|
||||
|
||||
ips, err := pp.GetAllowedIPs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.AllowedIPs = ips
|
||||
|
||||
err = r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
|
||||
err := r.wg.ConfigureDevice(string(deviceId), wgtypes.Config{ReplacePeers: false, Peers: []wgtypes.PeerConfig{cfg}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -70,8 +70,8 @@ func TestWireGuardCreateInterface(t *testing.T) {
|
||||
|
||||
err := mgr.SaveInterface(context.Background(), interfaceName, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
pi.Addresses = []domain.Cidr{
|
||||
{Ip: domain.IpAddress(net.ParseIP(ipAddress)), NetLength: 24, Bits: 32},
|
||||
{Ip: domain.IpAddress(net.ParseIP(ipV6Address)), NetLength: 64, Bits: 128},
|
||||
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}),
|
||||
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}),
|
||||
}
|
||||
return pi, nil
|
||||
})
|
||||
@ -104,8 +104,8 @@ func TestWireGuardUpdateInterface(t *testing.T) {
|
||||
ipV6Address := "1337:d34d:b33f::2"
|
||||
err = mgr.SaveInterface(context.Background(), interfaceName, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
pi.Addresses = []domain.Cidr{
|
||||
{Ip: domain.IpAddress(net.ParseIP(ipAddress)), NetLength: 24, Bits: 32},
|
||||
{Ip: domain.IpAddress(net.ParseIP(ipV6Address)), NetLength: 64, Bits: 128},
|
||||
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}),
|
||||
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}),
|
||||
}
|
||||
return pi, nil
|
||||
})
|
||||
@ -119,3 +119,42 @@ func TestWireGuardUpdateInterface(t *testing.T) {
|
||||
assert.Contains(t, string(out), ipAddress)
|
||||
assert.Contains(t, string(out), ipV6Address)
|
||||
}
|
||||
|
||||
func TestWireGuardCreateInterfaceWithRoutes(t *testing.T) {
|
||||
mgr := setup(t)
|
||||
|
||||
interfaceName := domain.InterfaceIdentifier("wg_test_001")
|
||||
ipAddress := "10.11.12.13"
|
||||
ipV6Address := "1337:d34d:b33f::2"
|
||||
defer mgr.DeleteInterface(context.Background(), interfaceName)
|
||||
|
||||
iface := &domain.Interface{
|
||||
Identifier: interfaceName,
|
||||
//RoutingTable: "1234",
|
||||
}
|
||||
peers := []domain.Peer{
|
||||
{
|
||||
Interface: domain.PeerInterfaceConfig{
|
||||
Addresses: domain.CidrsMust(domain.CidrsFromString("10.11.12.14/32,10.22.33.44/32")),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := mgr.SaveInterface2(context.Background(), iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
pi.Addresses = []domain.Cidr{
|
||||
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipAddress), Mask: net.CIDRMask(24, 32)}),
|
||||
domain.CidrFromIpNet(net.IPNet{IP: net.ParseIP(ipV6Address), Mask: net.CIDRMask(64, 128)}),
|
||||
}
|
||||
pi.DeviceUp = true
|
||||
return pi, nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Validate that the interface has been created
|
||||
cmd := exec.Command("ip", "addr")
|
||||
out, err := cmd.CombinedOutput()
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(out), interfaceName)
|
||||
assert.Contains(t, string(out), ipAddress)
|
||||
assert.Contains(t, string(out), ipV6Address)
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ type InterfaceController interface {
|
||||
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
|
||||
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
|
||||
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error)
|
||||
SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
|
||||
SaveInterface(_ context.Context, iface *domain.Interface, peers []domain.Peer, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
|
||||
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
|
||||
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
|
||||
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
|
||||
|
@ -135,12 +135,12 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
|
||||
return fmt.Errorf("failed to load peers for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
physicalInterface, err := m.wg.GetInterface(ctx, iface.Identifier)
|
||||
_, err = m.wg.GetInterface(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
logrus.Debugf("creating missing interface %s...", iface.Identifier)
|
||||
|
||||
// try to create a new interface
|
||||
err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
err := m.wg.SaveInterface(ctx, &iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
domain.MergeToPhysicalInterface(pi, &iface)
|
||||
|
||||
return pi, nil
|
||||
@ -169,32 +169,30 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if physicalInterface.DeviceUp != !iface.IsDisabled() {
|
||||
logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
|
||||
logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
|
||||
|
||||
// try to move interface to stored state
|
||||
err := m.wg.SaveInterface(ctx, iface.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
pi.DeviceUp = !iface.IsDisabled()
|
||||
// try to move interface to stored state
|
||||
err := m.wg.SaveInterface(ctx, &iface, peers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
pi.DeviceUp = !iface.IsDisabled()
|
||||
|
||||
return pi, nil
|
||||
})
|
||||
if err != nil {
|
||||
if updateDbOnError {
|
||||
// disable interface in database as no physical interface is available
|
||||
_ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) {
|
||||
if iface.IsDisabled() {
|
||||
now := time.Now()
|
||||
in.Disabled = &now // set
|
||||
in.DisabledReason = "no physical interface active"
|
||||
} else {
|
||||
in.Disabled = nil
|
||||
in.DisabledReason = ""
|
||||
}
|
||||
return in, nil
|
||||
})
|
||||
}
|
||||
return fmt.Errorf("failed to change physical interface state for %s: %w", iface.Identifier, err)
|
||||
return pi, nil
|
||||
})
|
||||
if err != nil {
|
||||
if updateDbOnError {
|
||||
// disable interface in database as no physical interface is available
|
||||
_ = m.db.SaveInterface(ctx, iface.Identifier, func(in *domain.Interface) (*domain.Interface, error) {
|
||||
if iface.IsDisabled() {
|
||||
now := time.Now()
|
||||
in.Disabled = &now // set
|
||||
in.DisabledReason = "no physical interface active"
|
||||
} else {
|
||||
in.Disabled = nil
|
||||
in.DisabledReason = ""
|
||||
}
|
||||
return in, nil
|
||||
})
|
||||
}
|
||||
return fmt.Errorf("failed to change physical interface state for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -294,7 +292,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
|
||||
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
||||
in.CopyCalculatedAttributes(i)
|
||||
|
||||
err = m.wg.SaveInterface(ctx, in.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
err = m.wg.SaveInterface(ctx, in, nil, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
domain.MergeToPhysicalInterface(pi, in)
|
||||
return pi, nil
|
||||
})
|
||||
@ -324,7 +322,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
|
||||
err = m.db.SaveInterface(ctx, in.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
|
||||
in.CopyCalculatedAttributes(i)
|
||||
|
||||
err = m.wg.SaveInterface(ctx, in.Identifier, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
err = m.wg.SaveInterface(ctx, in, existingPeers, func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error) {
|
||||
domain.MergeToPhysicalInterface(pi, in)
|
||||
return pi, nil
|
||||
})
|
||||
|
@ -306,7 +306,7 @@ func (m Manager) getFreshPeerIpConfig(ctx context.Context, iface *domain.Interfa
|
||||
}
|
||||
}
|
||||
|
||||
ips = append(ips, ip)
|
||||
ips = append(ips, ip.HostAddr())
|
||||
}
|
||||
|
||||
return
|
||||
|
@ -3,7 +3,10 @@ package domain
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/h44z/wg-portal/internal"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -95,6 +98,47 @@ func (i *Interface) GetConfigFileName() string {
|
||||
return filename
|
||||
}
|
||||
|
||||
func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
|
||||
var allowedCidrs []Cidr
|
||||
|
||||
for _, peer := range peers {
|
||||
allowedCidrs = append(allowedCidrs, peer.Interface.Addresses...)
|
||||
if peer.ExtraAllowedIPsStr != "" {
|
||||
extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr)
|
||||
if err == nil {
|
||||
allowedCidrs = append(allowedCidrs, extraIPs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allowedCidrs
|
||||
}
|
||||
|
||||
// GetRoutingTable returns the routing table number or -1 if an error occurred or RoutingTable was set to "off"
|
||||
func (i *Interface) GetRoutingTable() int {
|
||||
routingTableStr := strings.ToLower(i.RoutingTable)
|
||||
switch {
|
||||
case routingTableStr == "":
|
||||
return 0
|
||||
case strings.HasPrefix(routingTableStr, "0x"):
|
||||
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
|
||||
routingTable, err := strconv.ParseUint(numberStr, 16, 64)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
if routingTable > math.MaxInt32 {
|
||||
return -1
|
||||
}
|
||||
return int(routingTable)
|
||||
default:
|
||||
routingTable, err := strconv.Atoi(routingTableStr)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
return routingTable
|
||||
}
|
||||
}
|
||||
|
||||
type PhysicalInterface struct {
|
||||
Identifier InterfaceIdentifier // device name, for example: wg0
|
||||
KeyPair // private/public Key of the server interface
|
||||
|
@ -163,6 +163,14 @@ func (c Cidr) NextAddr() Cidr {
|
||||
}
|
||||
}
|
||||
|
||||
func (c Cidr) HostAddr() Cidr {
|
||||
return Cidr{
|
||||
Cidr: netip.PrefixFrom(c.Prefix().Addr(), c.Prefix().Addr().BitLen()).String(),
|
||||
Addr: c.Addr,
|
||||
NetLength: c.Prefix().Addr().BitLen(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c Cidr) NextSubnet() Cidr {
|
||||
prefix := c.Prefix()
|
||||
nextAddr := c.BroadcastAddr().Prefix().Addr().Next()
|
||||
|
@ -190,13 +190,13 @@ func (p PhysicalPeer) GetPersistentKeepaliveTime() *time.Duration {
|
||||
return &keepAliveDuration
|
||||
}
|
||||
|
||||
func (p PhysicalPeer) GetAllowedIPs() ([]net.IPNet, error) {
|
||||
func (p PhysicalPeer) GetAllowedIPs() []net.IPNet {
|
||||
allowedIPs := make([]net.IPNet, len(p.AllowedIPs))
|
||||
for i, ip := range p.AllowedIPs {
|
||||
allowedIPs[i] = *ip.IpNet()
|
||||
}
|
||||
|
||||
return allowedIPs, nil
|
||||
return allowedIPs
|
||||
}
|
||||
|
||||
func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
|
||||
@ -223,9 +223,15 @@ func ConvertPhysicalPeer(pp *PhysicalPeer) *Peer {
|
||||
func MergeToPhysicalPeer(pp *PhysicalPeer, p *Peer) {
|
||||
pp.Identifier = p.Identifier
|
||||
pp.Endpoint = p.Endpoint.GetValue()
|
||||
allowedIPs, _ := CidrsFromString(p.AllowedIPsStr.GetValue())
|
||||
extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr)
|
||||
pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...)
|
||||
if p.Interface.Type == InterfaceTypeServer {
|
||||
allowedIPs, _ := CidrsFromString(p.AllowedIPsStr.GetValue())
|
||||
extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr)
|
||||
pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...)
|
||||
} else {
|
||||
allowedIPs := p.Interface.Addresses
|
||||
extraAllowedIPs, _ := CidrsFromString(p.ExtraAllowedIPsStr)
|
||||
pp.AllowedIPs = append(allowedIPs, extraAllowedIPs...)
|
||||
}
|
||||
pp.PresharedKey = p.PresharedKey
|
||||
pp.PublicKey = p.Interface.PublicKey
|
||||
pp.PersistentKeepalive = p.PersistentKeepalive.GetValue()
|
||||
|
@ -16,6 +16,14 @@ type NetlinkClient interface {
|
||||
AddrAdd(link netlink.Link, addr *netlink.Addr) error
|
||||
AddrList(link netlink.Link) ([]netlink.Addr, error)
|
||||
AddrDel(link netlink.Link, addr *netlink.Addr) error
|
||||
RouteAdd(route *netlink.Route) error
|
||||
RouteDel(route *netlink.Route) error
|
||||
RouteReplace(route *netlink.Route) error
|
||||
RouteList(link netlink.Link, family int) ([]netlink.Route, error)
|
||||
RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error)
|
||||
RuleAdd(rule *netlink.Rule) error
|
||||
RuleDel(rule *netlink.Rule) error
|
||||
RuleList(family int) ([]netlink.Rule, error)
|
||||
}
|
||||
|
||||
type NetlinkManager struct {
|
||||
@ -66,3 +74,35 @@ func (n NetlinkManager) AddrList(link netlink.Link) ([]netlink.Addr, error) {
|
||||
func (n NetlinkManager) AddrDel(link netlink.Link, addr *netlink.Addr) error {
|
||||
return netlink.AddrDel(link, addr)
|
||||
}
|
||||
|
||||
func (n NetlinkManager) RouteAdd(route *netlink.Route) error {
|
||||
return netlink.RouteAdd(route)
|
||||
}
|
||||
|
||||
func (n NetlinkManager) RouteDel(route *netlink.Route) error {
|
||||
return netlink.RouteDel(route)
|
||||
}
|
||||
|
||||
func (n NetlinkManager) RouteReplace(route *netlink.Route) error {
|
||||
return netlink.RouteReplace(route)
|
||||
}
|
||||
|
||||
func (n NetlinkManager) RouteList(link netlink.Link, family int) ([]netlink.Route, error) {
|
||||
return netlink.RouteList(link, family)
|
||||
}
|
||||
|
||||
func (n NetlinkManager) RouteListFiltered(family int, filter *netlink.Route, filterMask uint64) ([]netlink.Route, error) {
|
||||
return netlink.RouteListFiltered(family, filter, filterMask)
|
||||
}
|
||||
|
||||
func (n NetlinkManager) RuleAdd(rule *netlink.Rule) error {
|
||||
return netlink.RuleAdd(rule)
|
||||
}
|
||||
|
||||
func (n NetlinkManager) RuleDel(rule *netlink.Rule) error {
|
||||
return netlink.RuleDel(rule)
|
||||
}
|
||||
|
||||
func (n NetlinkManager) RuleList(family int) ([]netlink.Rule, error) {
|
||||
return netlink.RuleList(family)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user