mirror of
				https://github.com/h44z/wg-portal.git
				synced 2025-11-03 23:56:18 +00:00 
			
		
		
		
	fix REST API permission checks (#209)
This commit is contained in:
		@@ -44,7 +44,8 @@ func (e authEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenti
 | 
			
		||||
// @Router /auth/providers [get]
 | 
			
		||||
func (e authEndpoint) handleExternalLoginProvidersGet() gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		providers := e.app.Authenticator.GetExternalLoginProviders(c.Request.Context())
 | 
			
		||||
		ctx := domain.SetUserInfoFromGin(c)
 | 
			
		||||
		providers := e.app.Authenticator.GetExternalLoginProviders(ctx)
 | 
			
		||||
 | 
			
		||||
		c.JSON(http.StatusOK, model.NewLoginProviderInfos(providers))
 | 
			
		||||
	}
 | 
			
		||||
@@ -69,7 +70,7 @@ func (e authEndpoint) handleSessionInfoGet() gin.HandlerFunc {
 | 
			
		||||
		var email *string
 | 
			
		||||
 | 
			
		||||
		if currentSession.LoggedIn {
 | 
			
		||||
			uid := string(currentSession.UserIdentifier)
 | 
			
		||||
			uid := currentSession.UserIdentifier
 | 
			
		||||
			f := currentSession.Firstname
 | 
			
		||||
			l := currentSession.Lastname
 | 
			
		||||
			e := currentSession.Email
 | 
			
		||||
@@ -134,7 +135,8 @@ func (e authEndpoint) handleOauthInitiateGet() gin.HandlerFunc {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(c.Request.Context(), provider)
 | 
			
		||||
		ctx := domain.SetUserInfoFromGin(c)
 | 
			
		||||
		authCodeUrl, state, nonce, err := e.app.Authenticator.OauthLoginStep1(ctx, provider)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if autoRedirect {
 | 
			
		||||
				redirectToReturn()
 | 
			
		||||
@@ -292,7 +294,8 @@ func (e authEndpoint) handleLoginPost() gin.HandlerFunc {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		user, err := e.app.Authenticator.PlainLogin(c.Request.Context(), loginData.Username, loginData.Password)
 | 
			
		||||
		ctx := domain.SetUserInfoFromGin(c)
 | 
			
		||||
		user, err := e.app.Authenticator.PlainLogin(ctx, loginData.Username, loginData.Password)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusUnauthorized, model.Error{Code: http.StatusUnauthorized, Message: "login failed"})
 | 
			
		||||
			return
 | 
			
		||||
 
 | 
			
		||||
@@ -19,7 +19,7 @@ func (e interfaceEndpoint) GetName() string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
 | 
			
		||||
	apiGroup := g.Group("/interface", e.authenticator.LoggedIn())
 | 
			
		||||
	apiGroup := g.Group("/interface", e.authenticator.LoggedIn(ScopeAdmin))
 | 
			
		||||
 | 
			
		||||
	apiGroup.GET("/prepare", e.handlePrepareGet())
 | 
			
		||||
	apiGroup.GET("/all", e.handleAllGet())
 | 
			
		||||
@@ -45,7 +45,8 @@ func (e interfaceEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *aut
 | 
			
		||||
// @Router /interface/prepare [get]
 | 
			
		||||
func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		in, err := e.app.PrepareInterface(c.Request.Context())
 | 
			
		||||
		ctx := domain.SetUserInfoFromGin(c)
 | 
			
		||||
		in, err := e.app.PrepareInterface(ctx)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, model.Error{
 | 
			
		||||
				Code: http.StatusInternalServerError, Message: err.Error(),
 | 
			
		||||
@@ -68,7 +69,8 @@ func (e interfaceEndpoint) handlePrepareGet() gin.HandlerFunc {
 | 
			
		||||
// @Router /interface/all [get]
 | 
			
		||||
func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		interfaces, peers, err := e.app.GetAllInterfacesAndPeers(c.Request.Context())
 | 
			
		||||
		ctx := domain.SetUserInfoFromGin(c)
 | 
			
		||||
		interfaces, peers, err := e.app.GetAllInterfacesAndPeers(ctx)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, model.Error{
 | 
			
		||||
				Code: http.StatusInternalServerError, Message: err.Error(),
 | 
			
		||||
@@ -92,6 +94,7 @@ func (e interfaceEndpoint) handleAllGet() gin.HandlerFunc {
 | 
			
		||||
// @Router /interface/get/{id} [get]
 | 
			
		||||
func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		ctx := domain.SetUserInfoFromGin(c)
 | 
			
		||||
		id := Base64UrlDecode(c.Param("id"))
 | 
			
		||||
		if id == "" {
 | 
			
		||||
			c.JSON(http.StatusBadRequest, model.Error{
 | 
			
		||||
@@ -100,7 +103,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		iface, peers, err := e.app.GetInterfaceAndPeers(c.Request.Context(), domain.InterfaceIdentifier(id))
 | 
			
		||||
		iface, peers, err := e.app.GetInterfaceAndPeers(ctx, domain.InterfaceIdentifier(id))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, model.Error{
 | 
			
		||||
				Code: http.StatusInternalServerError, Message: err.Error(),
 | 
			
		||||
@@ -124,6 +127,7 @@ func (e interfaceEndpoint) handleSingleGet() gin.HandlerFunc {
 | 
			
		||||
// @Router /interface/config/{id} [get]
 | 
			
		||||
func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		ctx := domain.SetUserInfoFromGin(c)
 | 
			
		||||
		id := Base64UrlDecode(c.Param("id"))
 | 
			
		||||
		if id == "" {
 | 
			
		||||
			c.JSON(http.StatusBadRequest, model.Error{
 | 
			
		||||
@@ -132,7 +136,7 @@ func (e interfaceEndpoint) handleConfigGet() gin.HandlerFunc {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		config, err := e.app.GetInterfaceConfig(c.Request.Context(), domain.InterfaceIdentifier(id))
 | 
			
		||||
		config, err := e.app.GetInterfaceConfig(ctx, domain.InterfaceIdentifier(id))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, model.Error{
 | 
			
		||||
				Code: http.StatusInternalServerError, Message: err.Error(),
 | 
			
		||||
 
 | 
			
		||||
@@ -21,11 +21,11 @@ func (e peerEndpoint) GetName() string {
 | 
			
		||||
func (e peerEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
 | 
			
		||||
	apiGroup := g.Group("/peer", e.authenticator.LoggedIn())
 | 
			
		||||
 | 
			
		||||
	apiGroup.GET("/iface/:iface/all", e.handleAllGet())
 | 
			
		||||
	apiGroup.GET("/iface/:iface/stats", e.handleStatsGet())
 | 
			
		||||
	apiGroup.GET("/iface/:iface/prepare", e.handlePrepareGet())
 | 
			
		||||
	apiGroup.POST("/iface/:iface/new", e.handleCreatePost())
 | 
			
		||||
	apiGroup.POST("/iface/:iface/multiplenew", e.handleCreateMultiplePost())
 | 
			
		||||
	apiGroup.GET("/iface/:iface/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
 | 
			
		||||
	apiGroup.GET("/iface/:iface/stats", e.authenticator.LoggedIn(ScopeAdmin), e.handleStatsGet())
 | 
			
		||||
	apiGroup.GET("/iface/:iface/prepare", e.authenticator.LoggedIn(ScopeAdmin), e.handlePrepareGet())
 | 
			
		||||
	apiGroup.POST("/iface/:iface/new", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
 | 
			
		||||
	apiGroup.POST("/iface/:iface/multiplenew", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreateMultiplePost())
 | 
			
		||||
	apiGroup.GET("/config-qr/:id", e.handleQrCodeGet())
 | 
			
		||||
	apiGroup.POST("/config-mail", e.handleEmailPost())
 | 
			
		||||
	apiGroup.GET("/config/:id", e.handleConfigGet())
 | 
			
		||||
@@ -298,6 +298,8 @@ func (e peerEndpoint) handleDelete() gin.HandlerFunc {
 | 
			
		||||
// @Router /peer/config/{id} [get]
 | 
			
		||||
func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		ctx := domain.SetUserInfoFromGin(c)
 | 
			
		||||
 | 
			
		||||
		id := Base64UrlDecode(c.Param("id"))
 | 
			
		||||
		if id == "" {
 | 
			
		||||
			c.JSON(http.StatusBadRequest, model.Error{
 | 
			
		||||
@@ -306,7 +308,7 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		config, err := e.app.GetPeerConfig(c.Request.Context(), domain.PeerIdentifier(id))
 | 
			
		||||
		config, err := e.app.GetPeerConfig(ctx, domain.PeerIdentifier(id))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, model.Error{
 | 
			
		||||
				Code: http.StatusInternalServerError, Message: err.Error(),
 | 
			
		||||
@@ -339,6 +341,7 @@ func (e peerEndpoint) handleConfigGet() gin.HandlerFunc {
 | 
			
		||||
// @Router /peer/config-qr/{id} [get]
 | 
			
		||||
func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		ctx := domain.SetUserInfoFromGin(c)
 | 
			
		||||
		id := Base64UrlDecode(c.Param("id"))
 | 
			
		||||
		if id == "" {
 | 
			
		||||
			c.JSON(http.StatusBadRequest, model.Error{
 | 
			
		||||
@@ -347,7 +350,7 @@ func (e peerEndpoint) handleQrCodeGet() gin.HandlerFunc {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		config, err := e.app.GetPeerConfigQrCode(c.Request.Context(), domain.PeerIdentifier(id))
 | 
			
		||||
		config, err := e.app.GetPeerConfigQrCode(ctx, domain.PeerIdentifier(id))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, model.Error{
 | 
			
		||||
				Code: http.StatusInternalServerError, Message: err.Error(),
 | 
			
		||||
@@ -392,11 +395,13 @@ func (e peerEndpoint) handleEmailPost() gin.HandlerFunc {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		ctx := domain.SetUserInfoFromGin(c)
 | 
			
		||||
 | 
			
		||||
		peerIds := make([]domain.PeerIdentifier, len(req.Identifiers))
 | 
			
		||||
		for i := range req.Identifiers {
 | 
			
		||||
			peerIds[i] = domain.PeerIdentifier(req.Identifiers[i])
 | 
			
		||||
		}
 | 
			
		||||
		err = e.app.SendPeerEmail(c.Request.Context(), req.LinkOnly, peerIds...)
 | 
			
		||||
		err = e.app.SendPeerEmail(ctx, req.LinkOnly, peerIds...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
 | 
			
		||||
			return
 | 
			
		||||
 
 | 
			
		||||
@@ -20,13 +20,13 @@ func (e userEndpoint) GetName() string {
 | 
			
		||||
func (e userEndpoint) RegisterRoutes(g *gin.RouterGroup, authenticator *authenticationHandler) {
 | 
			
		||||
	apiGroup := g.Group("/user", e.authenticator.LoggedIn())
 | 
			
		||||
 | 
			
		||||
	apiGroup.GET("/all", e.handleAllGet())
 | 
			
		||||
	apiGroup.GET("/:id", e.handleSingleGet())
 | 
			
		||||
	apiGroup.PUT("/:id", e.handleUpdatePut())
 | 
			
		||||
	apiGroup.DELETE("/:id", e.handleDelete())
 | 
			
		||||
	apiGroup.POST("/new", e.handleCreatePost())
 | 
			
		||||
	apiGroup.GET("/:id/peers", e.handlePeersGet())
 | 
			
		||||
	apiGroup.GET("/:id/stats", e.handleStatsGet())
 | 
			
		||||
	apiGroup.GET("/all", e.authenticator.LoggedIn(ScopeAdmin), e.handleAllGet())
 | 
			
		||||
	apiGroup.GET("/:id", e.authenticator.UserIdMatch("id"), e.handleSingleGet())
 | 
			
		||||
	apiGroup.PUT("/:id", e.authenticator.UserIdMatch("id"), e.handleUpdatePut())
 | 
			
		||||
	apiGroup.DELETE("/:id", e.authenticator.UserIdMatch("id"), e.handleDelete())
 | 
			
		||||
	apiGroup.POST("/new", e.authenticator.LoggedIn(ScopeAdmin), e.handleCreatePost())
 | 
			
		||||
	apiGroup.GET("/:id/peers", e.authenticator.UserIdMatch("id"), e.handlePeersGet())
 | 
			
		||||
	apiGroup.GET("/:id/stats", e.authenticator.UserIdMatch("id"), e.handleStatsGet())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// handleAllGet returns a gorm handler function.
 | 
			
		||||
 
 | 
			
		||||
@@ -58,6 +58,31 @@ func (h authenticationHandler) LoggedIn(scopes ...Scope) gin.HandlerFunc {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UserIdMatch checks if the user id in the session matches the user id in the request. If not, the request is aborted.
 | 
			
		||||
func (h authenticationHandler) UserIdMatch(idParameter string) gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		session := h.Session.GetData(c)
 | 
			
		||||
 | 
			
		||||
		if session.IsAdmin {
 | 
			
		||||
			c.Next() // Admins can do everything
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		sessionUserId := domain.UserIdentifier(session.UserIdentifier)
 | 
			
		||||
		requestUserId := domain.UserIdentifier(Base64UrlDecode(c.Param(idParameter)))
 | 
			
		||||
 | 
			
		||||
		if sessionUserId != requestUserId {
 | 
			
		||||
			// Abort the request with the appropriate error code
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			c.JSON(http.StatusForbidden, model.Error{Code: http.StatusForbidden, Message: "not enough permissions"})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Continue down the chain to handler etc
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UserHasScopes(session SessionData, scopes ...Scope) bool {
 | 
			
		||||
	// No scopes give, so the check should succeed
 | 
			
		||||
	if len(scopes) == 0 {
 | 
			
		||||
 
 | 
			
		||||
@@ -150,6 +150,7 @@ func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.Lo
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool {
 | 
			
		||||
	ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context
 | 
			
		||||
	user, err := a.users.GetUser(ctx, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false
 | 
			
		||||
@@ -187,6 +188,8 @@ func (a *Authenticator) PlainLogin(ctx context.Context, username, password strin
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Authenticator) passwordAuthentication(ctx context.Context, identifier domain.UserIdentifier, password string) (*domain.User, error) {
 | 
			
		||||
	ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
 | 
			
		||||
 | 
			
		||||
	var ldapUserInfo *domain.AuthenticatorUserInfo
 | 
			
		||||
	var ldapProvider domain.LdapAuthenticator
 | 
			
		||||
 | 
			
		||||
@@ -315,6 +318,7 @@ func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce,
 | 
			
		||||
		return nil, fmt.Errorf("unable to parse user information: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists
 | 
			
		||||
	user, err := a.processUserInfo(ctx, userInfo, domain.UserSourceOauth, oauthProvider.GetName(), oauthProvider.RegistrationEnabled())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("unable to process user information: %w", err)
 | 
			
		||||
 
 | 
			
		||||
@@ -109,6 +109,10 @@ func (m Manager) handlePeerInterfaceUpdatedEvent(id domain.InterfaceIdentifier)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	iface, peers, err := m.wg.GetInterfaceAndPeers(ctx, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to fetch interface %s: %w", id, err)
 | 
			
		||||
@@ -123,6 +127,10 @@ func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (i
 | 
			
		||||
		return nil, fmt.Errorf("failed to fetch peer %s: %w", id, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return m.tplHandler.GetPeerConfig(peer)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -132,6 +140,10 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
 | 
			
		||||
		return nil, fmt.Errorf("failed to fetch peer %s: %w", id, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cfgData, err := m.tplHandler.GetPeerConfig(peer)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to get peer config for %s: %w", id, err)
 | 
			
		||||
@@ -172,6 +184,10 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if m.fsRepo == nil {
 | 
			
		||||
		return fmt.Errorf("peristing configuration is not supported")
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -44,6 +44,10 @@ func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...doma
 | 
			
		||||
			return fmt.Errorf("failed to fetch peer %s: %w", peerId, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if peer.UserIdentifier == "" {
 | 
			
		||||
			logrus.Debugf("skipping peer email for %s, no user linked", peerId)
 | 
			
		||||
			continue
 | 
			
		||||
 
 | 
			
		||||
@@ -43,6 +43,10 @@ func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabase
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := m.NewUser(ctx, user)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -58,6 +62,10 @@ func (m Manager) NewUser(ctx context.Context, user *domain.User) error {
 | 
			
		||||
		return errors.New("missing user identifier")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := m.users.SaveUser(ctx, user.Identifier, func(u *domain.User) (*domain.User, error) {
 | 
			
		||||
		u.Identifier = user.Identifier
 | 
			
		||||
		u.Email = user.Email
 | 
			
		||||
@@ -83,6 +91,10 @@ func (m Manager) StartBackgroundJobs(ctx context.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
 | 
			
		||||
	if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user, err := m.users.GetUser(ctx, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("unable to load peer %s: %w", id, err)
 | 
			
		||||
@@ -95,6 +107,10 @@ func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	users, err := m.users.GetAllUsers(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("unable to load users: %w", err)
 | 
			
		||||
@@ -123,6 +139,10 @@ func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
 | 
			
		||||
	if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	existingUser, err := m.users.GetUser(ctx, user.Identifier)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err)
 | 
			
		||||
@@ -153,6 +173,10 @@ func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.Use
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	existingUser, err := m.users.GetUser(ctx, user.Identifier)
 | 
			
		||||
	if err != nil && !errors.Is(err, domain.ErrNotFound) {
 | 
			
		||||
		return nil, fmt.Errorf("unable to load existing user %s: %w", user.Identifier, err)
 | 
			
		||||
@@ -182,6 +206,10 @@ func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.Use
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	existingUser, err := m.users.GetUser(ctx, id)
 | 
			
		||||
	if err != nil && !errors.Is(err, domain.ErrNotFound) {
 | 
			
		||||
		return fmt.Errorf("unable to find user %s: %w", id, err)
 | 
			
		||||
 
 | 
			
		||||
@@ -47,7 +47,8 @@ func (m Manager) handleUserCreationEvent(user *domain.User) {
 | 
			
		||||
	logrus.Errorf("handling new user event for %s", user.Identifier)
 | 
			
		||||
 | 
			
		||||
	if m.cfg.Core.CreateDefaultPeer {
 | 
			
		||||
		err := m.CreateDefaultPeer(context.Background(), user)
 | 
			
		||||
		ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | 
			
		||||
		err := m.CreateDefaultPeer(ctx, user)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logrus.Errorf("failed to create default peer for %s: %v", user.Identifier, err)
 | 
			
		||||
			return
 | 
			
		||||
 
 | 
			
		||||
@@ -13,6 +13,10 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	physicalInterfaces, err := m.wg.GetInterfaces(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -22,14 +26,26 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return m.db.GetInterfaceAndPeers(ctx, id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return m.db.GetAllInterfaces(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	interfaces, err := m.db.GetAllInterfaces(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, fmt.Errorf("unable to load all interfaces: %w", err)
 | 
			
		||||
@@ -48,6 +64,10 @@ func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interfa
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	physicalInterfaces, err := m.wg.GetInterfaces(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
@@ -95,6 +115,10 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	existingInterface, err := m.db.GetInterface(ctx, in.Identifier)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
 | 
			
		||||
@@ -122,6 +146,10 @@ func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) er
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	interfaces, err := m.db.GetAllInterfaces(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -201,6 +229,10 @@ func (m Manager) RestoreInterfaceState(ctx context.Context, updateDbOnError bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	currentUser := domain.GetUserInfo(ctx)
 | 
			
		||||
 | 
			
		||||
	kp, err := domain.NewFreshKeypair()
 | 
			
		||||
@@ -277,6 +309,10 @@ func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	existingInterface, err := m.db.GetInterface(ctx, in.Identifier)
 | 
			
		||||
	if err != nil && !errors.Is(err, domain.ErrNotFound) {
 | 
			
		||||
		return nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
 | 
			
		||||
@@ -298,6 +334,10 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, in.Identifier)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, fmt.Errorf("unable to load existing interface %s: %w", in.Identifier, err)
 | 
			
		||||
@@ -316,6 +356,10 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	existingInterface, err := m.db.GetInterface(ctx, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("unable to find interface %s: %w", id, err)
 | 
			
		||||
 
 | 
			
		||||
@@ -12,6 +12,10 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	existingInterfaces, err := m.db.GetAllInterfaces(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to fetch all interfaces: %w", err)
 | 
			
		||||
@@ -49,10 +53,18 @@ func (m Manager) CreateDefaultPeer(ctx context.Context, user *domain.User) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
 | 
			
		||||
	if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return m.db.GetUserPeers(ctx, id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, err // TODO: self provisioning?
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	currentUser := domain.GetUserInfo(ctx)
 | 
			
		||||
 | 
			
		||||
	iface, err := m.db.GetInterface(ctx, id)
 | 
			
		||||
@@ -128,10 +140,18 @@ func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain
 | 
			
		||||
		return nil, fmt.Errorf("unable to find peer %s: %w", id, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return peer, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) {
 | 
			
		||||
	if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	existingPeer, err := m.db.GetPeer(ctx, peer.Identifier)
 | 
			
		||||
	if err != nil && !errors.Is(err, domain.ErrNotFound) {
 | 
			
		||||
		return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err)
 | 
			
		||||
@@ -153,6 +173,10 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) CreateMultiplePeers(ctx context.Context, interfaceId domain.InterfaceIdentifier, r *domain.PeerCreationRequest) ([]domain.Peer, error) {
 | 
			
		||||
	if err := domain.ValidateAdminAccessRights(ctx); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var newPeers []*domain.Peer
 | 
			
		||||
 | 
			
		||||
	for _, id := range r.UserIdentifiers {
 | 
			
		||||
@@ -192,6 +216,10 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee
 | 
			
		||||
		return nil, fmt.Errorf("unable to load existing peer %s: %w", peer.Identifier, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := domain.ValidateUserAccessRights(ctx, existingPeer.UserIdentifier); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := m.validatePeerModifications(ctx, existingPeer, peer); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("update not allowed: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -210,6 +238,10 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
 | 
			
		||||
		return fmt.Errorf("unable to find peer %s: %w", id, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = m.wg.DeletePeer(ctx, peer.InterfaceIdentifier, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("wireguard failed to delete peer %s: %w", id, err)
 | 
			
		||||
@@ -231,6 +263,10 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier
 | 
			
		||||
 | 
			
		||||
	peerIds := make([]domain.PeerIdentifier, len(peers))
 | 
			
		||||
	for i, peer := range peers {
 | 
			
		||||
		if err := domain.ValidateUserAccessRights(ctx, peer.UserIdentifier); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		peerIds[i] = peer.Identifier
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -238,6 +274,10 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) {
 | 
			
		||||
	if err := domain.ValidateUserAccessRights(ctx, id); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	peers, err := m.db.GetUserPeers(ctx, id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to fetch peers for user %s: %w", id, err)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package domain
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
@@ -72,3 +73,29 @@ func GetUserInfo(ctx context.Context) *ContextUserInfo {
 | 
			
		||||
 | 
			
		||||
	return DefaultContextUserInfo()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ValidateUserAccessRights(ctx context.Context, requiredUser UserIdentifier) error {
 | 
			
		||||
	sessionUser := GetUserInfo(ctx)
 | 
			
		||||
 | 
			
		||||
	if sessionUser.IsAdmin {
 | 
			
		||||
		return nil // Admins can do everything
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sessionUser.Id == requiredUser {
 | 
			
		||||
		return nil // User can access own data
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logrus.Warnf("insufficient permissions for %s (want %s), stack: %s", sessionUser.Id, requiredUser, GetStackTrace())
 | 
			
		||||
	return fmt.Errorf("insufficient permissions")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ValidateAdminAccessRights(ctx context.Context) error {
 | 
			
		||||
	sessionUser := GetUserInfo(ctx)
 | 
			
		||||
 | 
			
		||||
	if sessionUser.IsAdmin {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logrus.Warnf("insufficient admin permissions for %s, stack: %s", sessionUser.Id, GetStackTrace())
 | 
			
		||||
	return fmt.Errorf("insufficient permissions")
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,18 @@
 | 
			
		||||
package domain
 | 
			
		||||
 | 
			
		||||
import "errors"
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"runtime"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ErrNotFound = errors.New("record not found")
 | 
			
		||||
var ErrNotUnique = errors.New("record not unique")
 | 
			
		||||
 | 
			
		||||
// GetStackTrace returns a stack trace of the current goroutine. The stack trace has at most 1024 bytes.
 | 
			
		||||
func GetStackTrace() string {
 | 
			
		||||
	b := make([]byte, 1024)
 | 
			
		||||
	n := runtime.Stack(b, false)
 | 
			
		||||
	s := string(b[:n])
 | 
			
		||||
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user