mirror of
https://github.com/h44z/wg-portal.git
synced 2025-09-14 06:51:15 +00:00
@@ -2,9 +2,10 @@ package wireguard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/h44z/wg-portal/internal/app"
|
||||
"github.com/sirupsen/logrus"
|
||||
"time"
|
||||
|
||||
evbus "github.com/vardius/message-bus"
|
||||
|
||||
@@ -21,7 +22,13 @@ type Manager struct {
|
||||
quick WgQuickController
|
||||
}
|
||||
|
||||
func NewWireGuardManager(cfg *config.Config, bus evbus.MessageBus, wg InterfaceController, quick WgQuickController, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
|
||||
func NewWireGuardManager(
|
||||
cfg *config.Config,
|
||||
bus evbus.MessageBus,
|
||||
wg InterfaceController,
|
||||
quick WgQuickController,
|
||||
db InterfaceAndPeerDatabaseRepo,
|
||||
) (*Manager, error) {
|
||||
m := &Manager{
|
||||
cfg: cfg,
|
||||
bus: bus,
|
||||
@@ -42,6 +49,9 @@ func (m Manager) StartBackgroundJobs(ctx context.Context) {
|
||||
func (m Manager) connectToMessageBus() {
|
||||
_ = m.bus.Subscribe(app.TopicUserCreated, m.handleUserCreationEvent)
|
||||
_ = m.bus.Subscribe(app.TopicAuthLogin, m.handleUserLoginEvent)
|
||||
_ = m.bus.Subscribe(app.TopicUserDisabled, m.handleUserDisabledEvent)
|
||||
_ = m.bus.Subscribe(app.TopicUserEnabled, m.handleUserEnabledEvent)
|
||||
_ = m.bus.Subscribe(app.TopicUserDeleted, m.handleUserDeletedEvent)
|
||||
}
|
||||
|
||||
func (m Manager) handleUserCreationEvent(user *domain.User) {
|
||||
@@ -84,6 +94,104 @@ func (m Manager) handleUserLoginEvent(userId domain.UserIdentifier) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m Manager) handleUserDisabledEvent(user domain.User) {
|
||||
ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
|
||||
userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to retrieve peers for disabled user %s: %v", user.Identifier, err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, peer := range userPeers {
|
||||
if peer.IsDisabled() {
|
||||
continue // peer is already disabled
|
||||
}
|
||||
|
||||
logrus.Debugf("disabling peer %s due to user %s being disabled", peer.Identifier, user.Identifier)
|
||||
|
||||
peer.Disabled = user.Disabled // set to user disabled timestamp
|
||||
peer.DisabledReason = domain.DisabledReasonUserDisabled
|
||||
|
||||
_, err := m.UpdatePeer(ctx, &peer)
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to disable peer %s for disabled user %s: %v",
|
||||
peer.Identifier, user.Identifier, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m Manager) handleUserEnabledEvent(user domain.User) {
|
||||
if !m.cfg.Core.ReEnablePeerAfterUserEnable {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
|
||||
userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to retrieve peers for re-enabled user %s: %v", user.Identifier, err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, peer := range userPeers {
|
||||
if !peer.IsDisabled() {
|
||||
continue // peer is already active
|
||||
}
|
||||
|
||||
if peer.DisabledReason != domain.DisabledReasonUserDisabled {
|
||||
continue // peer was disabled for another reason
|
||||
}
|
||||
|
||||
logrus.Debugf("enabling peer %s due to user %s being enabled", peer.Identifier, user.Identifier)
|
||||
|
||||
peer.Disabled = nil
|
||||
peer.DisabledReason = ""
|
||||
|
||||
_, err := m.UpdatePeer(ctx, &peer)
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to enable peer %s for enabled user %s: %v",
|
||||
peer.Identifier, user.Identifier, err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m Manager) handleUserDeletedEvent(user domain.User) {
|
||||
ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
|
||||
userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to retrieve peers for deleted user %s: %v", user.Identifier, err)
|
||||
return
|
||||
}
|
||||
|
||||
deletionTime := time.Now()
|
||||
for _, peer := range userPeers {
|
||||
if peer.IsDisabled() {
|
||||
continue // peer is already disabled
|
||||
}
|
||||
|
||||
if m.cfg.Core.DeletePeerAfterUserDeleted {
|
||||
logrus.Debugf("deleting peer %s due to user %s being deleted", peer.Identifier, user.Identifier)
|
||||
|
||||
if err := m.DeletePeer(ctx, peer.Identifier); err != nil {
|
||||
logrus.Errorf("failed to delete peer %s for deleted user %s: %v",
|
||||
peer.Identifier, user.Identifier, err)
|
||||
}
|
||||
} else {
|
||||
logrus.Debugf("disabling peer %s due to user %s being deleted", peer.Identifier, user.Identifier)
|
||||
|
||||
peer.UserIdentifier = "" // remove user reference
|
||||
peer.Disabled = &deletionTime
|
||||
peer.DisabledReason = domain.DisabledReasonUserDeleted
|
||||
|
||||
_, err := m.UpdatePeer(ctx, &peer)
|
||||
if err != nil {
|
||||
logrus.Errorf("failed to disable peer %s for deleted user %s: %v",
|
||||
peer.Identifier, user.Identifier, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m Manager) runExpiredPeersCheck(ctx context.Context) {
|
||||
ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo())
|
||||
|
||||
|
@@ -175,14 +175,11 @@ func (m Manager) RestoreInterfaceState(
|
||||
}
|
||||
|
||||
_, err = m.wg.GetInterface(ctx, iface.Identifier)
|
||||
if err != nil {
|
||||
if err != nil && !iface.IsDisabled() {
|
||||
logrus.Debugf("creating missing interface %s...", iface.Identifier)
|
||||
|
||||
// try to create a new interface
|
||||
_, err = m.saveInterface(ctx, &iface, peers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = m.saveInterface(ctx, &iface)
|
||||
if err != nil {
|
||||
if updateDbOnError {
|
||||
// disable interface in database as no physical interface exists
|
||||
@@ -196,23 +193,11 @@ func (m Manager) RestoreInterfaceState(
|
||||
}
|
||||
return fmt.Errorf("failed to create physical interface %s: %w", iface.Identifier, err)
|
||||
}
|
||||
|
||||
// restore peers
|
||||
for _, peer := range peers {
|
||||
err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, &peer)
|
||||
return pp, nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create physical peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
|
||||
|
||||
// try to move interface to stored state
|
||||
_, err = m.saveInterface(ctx, &iface, peers)
|
||||
_, err = m.saveInterface(ctx, &iface)
|
||||
if err != nil {
|
||||
if updateDbOnError {
|
||||
// disable interface in database as no physical interface is available
|
||||
@@ -232,6 +217,51 @@ func (m Manager) RestoreInterfaceState(
|
||||
return fmt.Errorf("failed to change physical interface state for %s: %w", iface.Identifier, err)
|
||||
}
|
||||
}
|
||||
|
||||
// restore peers
|
||||
for _, peer := range peers {
|
||||
switch {
|
||||
case iface.IsDisabled(): // if interface is disabled, delete all peers
|
||||
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
|
||||
return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w",
|
||||
peer.Identifier, iface.Identifier, err)
|
||||
}
|
||||
case peer.IsDisabled(): // if peer is disabled, delete it
|
||||
if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
|
||||
return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w",
|
||||
peer.Identifier, iface.Identifier, err)
|
||||
}
|
||||
default: // update peer
|
||||
err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, &peer)
|
||||
return pp, nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create/update physical peer %s for interface %s: %w",
|
||||
peer.Identifier, iface.Identifier, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// remove non-wgportal peers
|
||||
physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier)
|
||||
for _, physicalPeer := range physicalPeers {
|
||||
isWgPortalPeer := false
|
||||
for _, peer := range peers {
|
||||
if peer.Identifier == domain.PeerIdentifier(physicalPeer.PublicKey) {
|
||||
isWgPortalPeer = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isWgPortalPeer {
|
||||
err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w",
|
||||
physicalPeer.PublicKey, iface.Identifier, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -334,7 +364,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
|
||||
return nil, fmt.Errorf("creation not allowed: %w", err)
|
||||
}
|
||||
|
||||
in, err = m.saveInterface(ctx, in, nil)
|
||||
in, err = m.saveInterface(ctx, in)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creation failure: %w", err)
|
||||
}
|
||||
@@ -356,7 +386,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
|
||||
return nil, nil, fmt.Errorf("update not allowed: %w", err)
|
||||
}
|
||||
|
||||
in, err = m.saveInterface(ctx, in, existingPeers)
|
||||
in, err = m.saveInterface(ctx, in)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("update failure: %w", err)
|
||||
}
|
||||
@@ -422,7 +452,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
|
||||
|
||||
// region helper-functions
|
||||
|
||||
func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, peers []domain.Peer) (
|
||||
func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
|
||||
*domain.Interface,
|
||||
error,
|
||||
) {
|
||||
@@ -454,7 +484,6 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee
|
||||
return nil, fmt.Errorf("failed to save interface: %w", err)
|
||||
}
|
||||
|
||||
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
|
||||
if iface.IsDisabled() {
|
||||
physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
|
||||
fwMark := iface.FirewallMark
|
||||
@@ -465,6 +494,8 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee
|
||||
FwMark: fwMark,
|
||||
Table: iface.GetRoutingTable(),
|
||||
})
|
||||
} else {
|
||||
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
|
||||
}
|
||||
|
||||
if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil {
|
||||
|
@@ -248,6 +248,10 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.validatePeerDeletion(ctx, peer); err != nil {
|
||||
return fmt.Errorf("delete not allowed: %w", err)
|
||||
}
|
||||
|
||||
err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
|
||||
@@ -309,20 +313,33 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
|
||||
|
||||
for i := range peers {
|
||||
peer := peers[i]
|
||||
err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
var err error
|
||||
if peer.IsDisabled() || peer.IsExpired() {
|
||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
|
||||
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, peer)
|
||||
return pp, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
})
|
||||
return peer, nil
|
||||
})
|
||||
} else {
|
||||
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
|
||||
peer.CopyCalculatedAttributes(p)
|
||||
|
||||
err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
|
||||
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
|
||||
domain.MergeToPhysicalPeer(pp, peer)
|
||||
return pp, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user