mirror of
				https://github.com/h44z/wg-portal.git
				synced 2025-11-03 23:56:18 +00:00 
			
		
		
		
	@@ -3,11 +3,12 @@ package adapters
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/h44z/wg-portal/internal"
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/domain"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks.
 | 
			
		||||
@@ -57,7 +58,10 @@ func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr
 | 
			
		||||
 | 
			
		||||
	err := r.exec(dnsCommand, id, dnsCommandInput...)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to set dns settings: %w", err)
 | 
			
		||||
		return fmt.Errorf(
 | 
			
		||||
			"failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w",
 | 
			
		||||
			err,
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
@@ -130,7 +130,7 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if currentSession.LoggedIn {
 | 
			
		||||
			if autoRedirect {
 | 
			
		||||
			if autoRedirect && e.isValidReturnUrl(returnTo) {
 | 
			
		||||
				queryParams := returnUrl.Query()
 | 
			
		||||
				queryParams.Set("wgLoginState", "success")
 | 
			
		||||
				returnParams = queryParams.Encode()
 | 
			
		||||
@@ -237,7 +237,7 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
 | 
			
		||||
		user, err := e.app.Authenticator.OauthLoginStep2(loginCtx, provider, currentSession.OauthNonce, oauthCode)
 | 
			
		||||
		cancel()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if returnUrl != nil {
 | 
			
		||||
			if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) {
 | 
			
		||||
				redirectToReturn()
 | 
			
		||||
			} else {
 | 
			
		||||
				c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: err.Error()})
 | 
			
		||||
@@ -247,7 +247,7 @@ func (e authEndpoint) handleOauthCallbackGet() gin.HandlerFunc {
 | 
			
		||||
 | 
			
		||||
		e.setAuthenticatedUser(c, user)
 | 
			
		||||
 | 
			
		||||
		if returnUrl != nil {
 | 
			
		||||
		if returnUrl != nil && e.isValidReturnUrl(returnUrl.String()) {
 | 
			
		||||
			queryParams := returnUrl.Query()
 | 
			
		||||
			queryParams.Set("wgLoginState", "success")
 | 
			
		||||
			returnParams = queryParams.Encode()
 | 
			
		||||
 
 | 
			
		||||
@@ -246,6 +246,10 @@ func (a *Authenticator) passwordAuthentication(
 | 
			
		||||
		return nil, errors.New("user not found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if userSource == domain.UserSourceLdap && ldapProvider == nil {
 | 
			
		||||
		return nil, errors.New("ldap provider not found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch userSource {
 | 
			
		||||
	case domain.UserSourceDatabase:
 | 
			
		||||
		err = existingUser.CheckPassword(password)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package app
 | 
			
		||||
const TopicUserCreated = "user:created"
 | 
			
		||||
const TopicUserRegistered = "user:registered"
 | 
			
		||||
const TopicUserDisabled = "user:disabled"
 | 
			
		||||
const TopicUserEnabled = "user:enabled"
 | 
			
		||||
const TopicUserDeleted = "user:deleted"
 | 
			
		||||
const TopicAuthLogin = "auth:login"
 | 
			
		||||
const TopicRouteUpdate = "route:update"
 | 
			
		||||
 
 | 
			
		||||
@@ -30,7 +30,10 @@ type Manager struct {
 | 
			
		||||
	peers PeerDatabaseRepo
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabaseRepo, peers PeerDatabaseRepo) (*Manager, error) {
 | 
			
		||||
func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabaseRepo, peers PeerDatabaseRepo) (
 | 
			
		||||
	*Manager,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
	m := &Manager{
 | 
			
		||||
		cfg: cfg,
 | 
			
		||||
		bus: bus,
 | 
			
		||||
@@ -170,6 +173,13 @@ func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.Use
 | 
			
		||||
		return nil, fmt.Errorf("update failure: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch {
 | 
			
		||||
	case !existingUser.IsDisabled() && user.IsDisabled():
 | 
			
		||||
		m.bus.Publish(app.TopicUserDisabled, *user)
 | 
			
		||||
	case existingUser.IsDisabled() && !user.IsDisabled():
 | 
			
		||||
		m.bus.Publish(app.TopicUserEnabled, *user)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return user, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -225,7 +235,7 @@ func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error
 | 
			
		||||
		return fmt.Errorf("deletion failure: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	m.bus.Publish(app.TopicUserDeleted, existingUser)
 | 
			
		||||
	m.bus.Publish(app.TopicUserDeleted, *existingUser)
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -374,7 +384,13 @@ func (m Manager) synchronizeLdapUsers(ctx context.Context, provider *config.Ldap
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) updateLdapUsers(ctx context.Context, providerName string, rawUsers []internal.RawLdapUser, fields *config.LdapFields, adminGroupDN *ldap.DN) error {
 | 
			
		||||
func (m Manager) updateLdapUsers(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	providerName string,
 | 
			
		||||
	rawUsers []internal.RawLdapUser,
 | 
			
		||||
	fields *config.LdapFields,
 | 
			
		||||
	adminGroupDN *ldap.DN,
 | 
			
		||||
) error {
 | 
			
		||||
	for _, rawUser := range rawUsers {
 | 
			
		||||
		user, err := convertRawLdapUser(providerName, rawUser, fields, adminGroupDN)
 | 
			
		||||
		if err != nil && !errors.Is(err, domain.ErrNotFound) {
 | 
			
		||||
@@ -397,7 +413,8 @@ func (m Manager) updateLdapUsers(ctx context.Context, providerName string, rawUs
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if existingUser != nil && existingUser.Source == domain.UserSourceLdap && userChangedInLdap(existingUser, user) {
 | 
			
		||||
		if existingUser != nil && existingUser.Source == domain.UserSourceLdap && userChangedInLdap(existingUser,
 | 
			
		||||
			user) {
 | 
			
		||||
 | 
			
		||||
			err := m.users.SaveUser(tctx, user.Identifier, func(u *domain.User) (*domain.User, error) {
 | 
			
		||||
				u.UpdatedAt = time.Now()
 | 
			
		||||
@@ -421,7 +438,12 @@ func (m Manager) updateLdapUsers(ctx context.Context, providerName string, rawUs
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) disableMissingLdapUsers(ctx context.Context, providerName string, rawUsers []internal.RawLdapUser, fields *config.LdapFields) error {
 | 
			
		||||
func (m Manager) disableMissingLdapUsers(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	providerName string,
 | 
			
		||||
	rawUsers []internal.RawLdapUser,
 | 
			
		||||
	fields *config.LdapFields,
 | 
			
		||||
) error {
 | 
			
		||||
	allUsers, err := m.users.GetAllUsers(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
 
 | 
			
		||||
@@ -2,9 +2,10 @@ package wireguard
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/h44z/wg-portal/internal/app"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	evbus "github.com/vardius/message-bus"
 | 
			
		||||
 | 
			
		||||
@@ -21,7 +22,13 @@ type Manager struct {
 | 
			
		||||
	quick WgQuickController
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWireGuardManager(cfg *config.Config, bus evbus.MessageBus, wg InterfaceController, quick WgQuickController, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
 | 
			
		||||
func NewWireGuardManager(
 | 
			
		||||
	cfg *config.Config,
 | 
			
		||||
	bus evbus.MessageBus,
 | 
			
		||||
	wg InterfaceController,
 | 
			
		||||
	quick WgQuickController,
 | 
			
		||||
	db InterfaceAndPeerDatabaseRepo,
 | 
			
		||||
) (*Manager, error) {
 | 
			
		||||
	m := &Manager{
 | 
			
		||||
		cfg:   cfg,
 | 
			
		||||
		bus:   bus,
 | 
			
		||||
@@ -42,6 +49,9 @@ func (m Manager) StartBackgroundJobs(ctx context.Context) {
 | 
			
		||||
func (m Manager) connectToMessageBus() {
 | 
			
		||||
	_ = m.bus.Subscribe(app.TopicUserCreated, m.handleUserCreationEvent)
 | 
			
		||||
	_ = m.bus.Subscribe(app.TopicAuthLogin, m.handleUserLoginEvent)
 | 
			
		||||
	_ = m.bus.Subscribe(app.TopicUserDisabled, m.handleUserDisabledEvent)
 | 
			
		||||
	_ = m.bus.Subscribe(app.TopicUserEnabled, m.handleUserEnabledEvent)
 | 
			
		||||
	_ = m.bus.Subscribe(app.TopicUserDeleted, m.handleUserDeletedEvent)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) handleUserCreationEvent(user *domain.User) {
 | 
			
		||||
@@ -84,6 +94,104 @@ func (m Manager) handleUserLoginEvent(userId domain.UserIdentifier) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) handleUserDisabledEvent(user domain.User) {
 | 
			
		||||
	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | 
			
		||||
	userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logrus.Errorf("failed to retrieve peers for disabled user %s: %v", user.Identifier, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, peer := range userPeers {
 | 
			
		||||
		if peer.IsDisabled() {
 | 
			
		||||
			continue // peer is already disabled
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logrus.Debugf("disabling peer %s due to user %s being disabled", peer.Identifier, user.Identifier)
 | 
			
		||||
 | 
			
		||||
		peer.Disabled = user.Disabled // set to user disabled timestamp
 | 
			
		||||
		peer.DisabledReason = domain.DisabledReasonUserDisabled
 | 
			
		||||
 | 
			
		||||
		_, err := m.UpdatePeer(ctx, &peer)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logrus.Errorf("failed to disable peer %s for disabled user %s: %v",
 | 
			
		||||
				peer.Identifier, user.Identifier, err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) handleUserEnabledEvent(user domain.User) {
 | 
			
		||||
	if !m.cfg.Core.ReEnablePeerAfterUserEnable {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | 
			
		||||
	userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logrus.Errorf("failed to retrieve peers for re-enabled user %s: %v", user.Identifier, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, peer := range userPeers {
 | 
			
		||||
		if !peer.IsDisabled() {
 | 
			
		||||
			continue // peer is already active
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if peer.DisabledReason != domain.DisabledReasonUserDisabled {
 | 
			
		||||
			continue // peer was disabled for another reason
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logrus.Debugf("enabling peer %s due to user %s being enabled", peer.Identifier, user.Identifier)
 | 
			
		||||
 | 
			
		||||
		peer.Disabled = nil
 | 
			
		||||
		peer.DisabledReason = ""
 | 
			
		||||
 | 
			
		||||
		_, err := m.UpdatePeer(ctx, &peer)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logrus.Errorf("failed to enable peer %s for enabled user %s: %v",
 | 
			
		||||
				peer.Identifier, user.Identifier, err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) handleUserDeletedEvent(user domain.User) {
 | 
			
		||||
	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | 
			
		||||
	userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logrus.Errorf("failed to retrieve peers for deleted user %s: %v", user.Identifier, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	deletionTime := time.Now()
 | 
			
		||||
	for _, peer := range userPeers {
 | 
			
		||||
		if peer.IsDisabled() {
 | 
			
		||||
			continue // peer is already disabled
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if m.cfg.Core.DeletePeerAfterUserDeleted {
 | 
			
		||||
			logrus.Debugf("deleting peer %s due to user %s being deleted", peer.Identifier, user.Identifier)
 | 
			
		||||
 | 
			
		||||
			if err := m.DeletePeer(ctx, peer.Identifier); err != nil {
 | 
			
		||||
				logrus.Errorf("failed to delete peer %s for deleted user %s: %v",
 | 
			
		||||
					peer.Identifier, user.Identifier, err)
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			logrus.Debugf("disabling peer %s due to user %s being deleted", peer.Identifier, user.Identifier)
 | 
			
		||||
 | 
			
		||||
			peer.UserIdentifier = "" // remove user reference
 | 
			
		||||
			peer.Disabled = &deletionTime
 | 
			
		||||
			peer.DisabledReason = domain.DisabledReasonUserDeleted
 | 
			
		||||
 | 
			
		||||
			_, err := m.UpdatePeer(ctx, &peer)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logrus.Errorf("failed to disable peer %s for deleted user %s: %v",
 | 
			
		||||
					peer.Identifier, user.Identifier, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) runExpiredPeersCheck(ctx context.Context) {
 | 
			
		||||
	ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo())
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -175,14 +175,11 @@ func (m Manager) RestoreInterfaceState(
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		_, err = m.wg.GetInterface(ctx, iface.Identifier)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
		if err != nil && !iface.IsDisabled() {
 | 
			
		||||
			logrus.Debugf("creating missing interface %s...", iface.Identifier)
 | 
			
		||||
 | 
			
		||||
			// try to create a new interface
 | 
			
		||||
			_, err = m.saveInterface(ctx, &iface, peers)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			_, err = m.saveInterface(ctx, &iface)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if updateDbOnError {
 | 
			
		||||
					// disable interface in database as no physical interface exists
 | 
			
		||||
@@ -196,23 +193,11 @@ func (m Manager) RestoreInterfaceState(
 | 
			
		||||
				}
 | 
			
		||||
				return fmt.Errorf("failed to create physical interface %s: %w", iface.Identifier, err)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// restore peers
 | 
			
		||||
			for _, peer := range peers {
 | 
			
		||||
				err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
 | 
			
		||||
					func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
 | 
			
		||||
						domain.MergeToPhysicalPeer(pp, &peer)
 | 
			
		||||
						return pp, nil
 | 
			
		||||
					})
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return fmt.Errorf("failed to create physical peer %s: %w", peer.Identifier, err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			logrus.Debugf("restoring interface state for %s to disabled=%t", iface.Identifier, iface.IsDisabled())
 | 
			
		||||
 | 
			
		||||
			// try to move interface to stored state
 | 
			
		||||
			_, err = m.saveInterface(ctx, &iface, peers)
 | 
			
		||||
			_, err = m.saveInterface(ctx, &iface)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if updateDbOnError {
 | 
			
		||||
					// disable interface in database as no physical interface is available
 | 
			
		||||
@@ -232,6 +217,51 @@ func (m Manager) RestoreInterfaceState(
 | 
			
		||||
				return fmt.Errorf("failed to change physical interface state for %s: %w", iface.Identifier, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// restore peers
 | 
			
		||||
		for _, peer := range peers {
 | 
			
		||||
			switch {
 | 
			
		||||
			case iface.IsDisabled(): // if interface is disabled, delete all peers
 | 
			
		||||
				if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
 | 
			
		||||
					return fmt.Errorf("failed to remove peer %s for disabled interface %s: %w",
 | 
			
		||||
						peer.Identifier, iface.Identifier, err)
 | 
			
		||||
				}
 | 
			
		||||
			case peer.IsDisabled(): // if peer is disabled, delete it
 | 
			
		||||
				if err := m.wg.DeletePeer(ctx, iface.Identifier, peer.Identifier); err != nil {
 | 
			
		||||
					return fmt.Errorf("failed to remove disbaled peer %s from interface %s: %w",
 | 
			
		||||
						peer.Identifier, iface.Identifier, err)
 | 
			
		||||
				}
 | 
			
		||||
			default: // update peer
 | 
			
		||||
				err := m.wg.SavePeer(ctx, iface.Identifier, peer.Identifier,
 | 
			
		||||
					func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
 | 
			
		||||
						domain.MergeToPhysicalPeer(pp, &peer)
 | 
			
		||||
						return pp, nil
 | 
			
		||||
					})
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return fmt.Errorf("failed to create/update physical peer %s for interface %s: %w",
 | 
			
		||||
						peer.Identifier, iface.Identifier, err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// remove non-wgportal peers
 | 
			
		||||
		physicalPeers, _ := m.wg.GetPeers(ctx, iface.Identifier)
 | 
			
		||||
		for _, physicalPeer := range physicalPeers {
 | 
			
		||||
			isWgPortalPeer := false
 | 
			
		||||
			for _, peer := range peers {
 | 
			
		||||
				if peer.Identifier == domain.PeerIdentifier(physicalPeer.PublicKey) {
 | 
			
		||||
					isWgPortalPeer = true
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if !isWgPortalPeer {
 | 
			
		||||
				err := m.wg.DeletePeer(ctx, iface.Identifier, domain.PeerIdentifier(physicalPeer.PublicKey))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return fmt.Errorf("failed to remove non-wgportal peer %s from interface %s: %w",
 | 
			
		||||
						physicalPeer.PublicKey, iface.Identifier, err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
@@ -334,7 +364,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
 | 
			
		||||
		return nil, fmt.Errorf("creation not allowed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	in, err = m.saveInterface(ctx, in, nil)
 | 
			
		||||
	in, err = m.saveInterface(ctx, in)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("creation failure: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -356,7 +386,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
 | 
			
		||||
		return nil, nil, fmt.Errorf("update not allowed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	in, err = m.saveInterface(ctx, in, existingPeers)
 | 
			
		||||
	in, err = m.saveInterface(ctx, in)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, fmt.Errorf("update failure: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -422,7 +452,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
 | 
			
		||||
 | 
			
		||||
// region helper-functions
 | 
			
		||||
 | 
			
		||||
func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, peers []domain.Peer) (
 | 
			
		||||
func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
 | 
			
		||||
	*domain.Interface,
 | 
			
		||||
	error,
 | 
			
		||||
) {
 | 
			
		||||
@@ -454,7 +484,6 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee
 | 
			
		||||
		return nil, fmt.Errorf("failed to save interface: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
 | 
			
		||||
	if iface.IsDisabled() {
 | 
			
		||||
		physicalInterface, _ := m.wg.GetInterface(ctx, iface.Identifier)
 | 
			
		||||
		fwMark := iface.FirewallMark
 | 
			
		||||
@@ -465,6 +494,8 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface, pee
 | 
			
		||||
			FwMark: fwMark,
 | 
			
		||||
			Table:  iface.GetRoutingTable(),
 | 
			
		||||
		})
 | 
			
		||||
	} else {
 | 
			
		||||
		m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := m.handleInterfacePostSaveHooks(stateChanged, iface); err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -248,6 +248,10 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := m.validatePeerDeletion(ctx, peer); err != nil {
 | 
			
		||||
		return fmt.Errorf("delete not allowed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
 | 
			
		||||
@@ -309,20 +313,33 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
 | 
			
		||||
 | 
			
		||||
	for i := range peers {
 | 
			
		||||
		peer := peers[i]
 | 
			
		||||
		err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
 | 
			
		||||
			peer.CopyCalculatedAttributes(p)
 | 
			
		||||
		var err error
 | 
			
		||||
		if peer.IsDisabled() || peer.IsExpired() {
 | 
			
		||||
			err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
 | 
			
		||||
				peer.CopyCalculatedAttributes(p)
 | 
			
		||||
 | 
			
		||||
			err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
 | 
			
		||||
				func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
 | 
			
		||||
					domain.MergeToPhysicalPeer(pp, peer)
 | 
			
		||||
					return pp, nil
 | 
			
		||||
				})
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
 | 
			
		||||
			}
 | 
			
		||||
				if err := m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, peer.Identifier); err != nil {
 | 
			
		||||
					return nil, fmt.Errorf("failed to delete wireguard peer %s: %w", peer.Identifier, err)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			return peer, nil
 | 
			
		||||
		})
 | 
			
		||||
				return peer, nil
 | 
			
		||||
			})
 | 
			
		||||
		} else {
 | 
			
		||||
			err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
 | 
			
		||||
				peer.CopyCalculatedAttributes(p)
 | 
			
		||||
 | 
			
		||||
				err := m.wg.SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
 | 
			
		||||
					func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
 | 
			
		||||
						domain.MergeToPhysicalPeer(pp, peer)
 | 
			
		||||
						return pp, nil
 | 
			
		||||
					})
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return nil, fmt.Errorf("failed to save wireguard peer %s: %w", peer.Identifier, err)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				return peer, nil
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("save failure for peer %s: %w", peer.Identifier, err)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -20,6 +20,8 @@ type Config struct {
 | 
			
		||||
		EditableKeys                bool `yaml:"editable_keys"`
 | 
			
		||||
		CreateDefaultPeer           bool `yaml:"create_default_peer"`
 | 
			
		||||
		CreateDefaultPeerOnCreation bool `yaml:"create_default_peer_on_creation"`
 | 
			
		||||
		ReEnablePeerAfterUserEnable bool `yaml:"re_enable_peer_after_user_enable"`
 | 
			
		||||
		DeletePeerAfterUserDeleted  bool `yaml:"delete_peer_after_user_deleted"`
 | 
			
		||||
		SelfProvisioningAllowed     bool `yaml:"self_provisioning_allowed"`
 | 
			
		||||
		ImportExisting              bool `yaml:"import_existing"`
 | 
			
		||||
		RestoreState                bool `yaml:"restore_state"`
 | 
			
		||||
@@ -61,9 +63,13 @@ type Config struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Config) LogStartupValues() {
 | 
			
		||||
	logrus.Infof("Log Level: %s", c.Advanced.LogLevel)
 | 
			
		||||
 | 
			
		||||
	logrus.Debug("WireGuard Portal Features:")
 | 
			
		||||
	logrus.Debugf("  - EditableKeys: %t", c.Core.EditableKeys)
 | 
			
		||||
	logrus.Debugf("  - CreateDefaultPeerOnCreation: %t", c.Core.CreateDefaultPeerOnCreation)
 | 
			
		||||
	logrus.Debugf("  - ReEnablePeerAfterUserEnable: %t", c.Core.ReEnablePeerAfterUserEnable)
 | 
			
		||||
	logrus.Debugf("  - DeletePeerAfterUserDeleted: %t", c.Core.DeletePeerAfterUserDeleted)
 | 
			
		||||
	logrus.Debugf("  - SelfProvisioningAllowed: %t", c.Core.SelfProvisioningAllowed)
 | 
			
		||||
	logrus.Debugf("  - ImportExisting: %t", c.Core.ImportExisting)
 | 
			
		||||
	logrus.Debugf("  - RestoreState: %t", c.Core.RestoreState)
 | 
			
		||||
@@ -85,8 +91,16 @@ func (c *Config) LogStartupValues() {
 | 
			
		||||
func defaultConfig() *Config {
 | 
			
		||||
	cfg := &Config{}
 | 
			
		||||
 | 
			
		||||
	cfg.Core.AdminUser = "admin@wgportal.local"
 | 
			
		||||
	cfg.Core.AdminPassword = "wgportal"
 | 
			
		||||
	cfg.Core.ImportExisting = true
 | 
			
		||||
	cfg.Core.RestoreState = true
 | 
			
		||||
	cfg.Core.CreateDefaultPeer = false
 | 
			
		||||
	cfg.Core.CreateDefaultPeerOnCreation = false
 | 
			
		||||
	cfg.Core.EditableKeys = true
 | 
			
		||||
	cfg.Core.SelfProvisioningAllowed = false
 | 
			
		||||
	cfg.Core.ReEnablePeerAfterUserEnable = true
 | 
			
		||||
	cfg.Core.DeletePeerAfterUserDeleted = false
 | 
			
		||||
 | 
			
		||||
	cfg.Database = DatabaseConfig{
 | 
			
		||||
		Type: "sqlite",
 | 
			
		||||
@@ -104,6 +118,7 @@ func defaultConfig() *Config {
 | 
			
		||||
		SiteCompanyName:   "WireGuard Portal",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cfg.Advanced.LogLevel = "info"
 | 
			
		||||
	cfg.Advanced.StartListenPort = 51820
 | 
			
		||||
	cfg.Advanced.StartCidrV4 = "10.11.12.0/24"
 | 
			
		||||
	cfg.Advanced.StartCidrV6 = "fdfd:d3ad:c0de:1234::0/64"
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,9 @@
 | 
			
		||||
package domain
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type BaseModel struct {
 | 
			
		||||
@@ -26,30 +25,32 @@ func (PrivateString) String() string {
 | 
			
		||||
 | 
			
		||||
func (ps PrivateString) Value() (driver.Value, error) {
 | 
			
		||||
	if len(ps) == 0 {
 | 
			
		||||
        return nil, nil
 | 
			
		||||
    }
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	return string(ps), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ps *PrivateString) Scan(value interface{}) error {
 | 
			
		||||
    if value == nil {
 | 
			
		||||
        *ps = ""
 | 
			
		||||
        return nil
 | 
			
		||||
    }
 | 
			
		||||
    switch v := value.(type) {
 | 
			
		||||
    case string:
 | 
			
		||||
        *ps = PrivateString(v)
 | 
			
		||||
    case []byte:
 | 
			
		||||
        *ps = PrivateString(string(v))
 | 
			
		||||
    default:
 | 
			
		||||
        return errors.New("invalid type for PrivateString")
 | 
			
		||||
    }
 | 
			
		||||
    return nil
 | 
			
		||||
	if value == nil {
 | 
			
		||||
		*ps = ""
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	switch v := value.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		*ps = PrivateString(v)
 | 
			
		||||
	case []byte:
 | 
			
		||||
		*ps = PrivateString(string(v))
 | 
			
		||||
	default:
 | 
			
		||||
		return errors.New("invalid type for PrivateString")
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	DisabledReasonExpired          = "expired"
 | 
			
		||||
	DisabledReasonDeleted          = "deleted"
 | 
			
		||||
	DisabledReasonUserDisabled     = "user disabled"
 | 
			
		||||
	DisabledReasonUserDeleted      = "user deleted"
 | 
			
		||||
	DisabledReasonUserEdit         = "user edit action"
 | 
			
		||||
	DisabledReasonUserCreate       = "user create action"
 | 
			
		||||
	DisabledReasonAdminEdit        = "admin edit action"
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user