mirror of
https://github.com/h44z/wg-portal.git
synced 2025-08-25 14:31:14 +00:00
WIP: new user management and authentication system, use go 1.16 embed
This commit is contained in:
203
internal/authentication/providers/ldap/provider.go
Normal file
203
internal/authentication/providers/ldap/provider.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/h44z/wg-portal/internal/authentication"
|
||||
ldapconfig "github.com/h44z/wg-portal/internal/ldap"
|
||||
"github.com/h44z/wg-portal/internal/users"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Provider provide login with password method
|
||||
type Provider struct {
|
||||
config *ldapconfig.Config
|
||||
}
|
||||
|
||||
func New(cfg *ldapconfig.Config) (*Provider, error) {
|
||||
p := &Provider{
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
// test ldap connectivity
|
||||
client, err := p.open()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unable to open ldap connection")
|
||||
}
|
||||
defer p.close(client)
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// GetName return provider name
|
||||
func (Provider) GetName() string {
|
||||
return string(users.UserSourceLdap)
|
||||
}
|
||||
|
||||
// GetType return provider type
|
||||
func (Provider) GetType() authentication.AuthProviderType {
|
||||
return authentication.AuthProviderTypePassword
|
||||
}
|
||||
|
||||
// GetPriority return provider priority
|
||||
func (Provider) GetPriority() int {
|
||||
return 1 // LDAP password provider
|
||||
}
|
||||
|
||||
func (provider Provider) SetupRoutes(routes *gin.RouterGroup) {
|
||||
// nothing todo here
|
||||
}
|
||||
|
||||
func (provider Provider) Login(ctx *authentication.AuthContext) (string, error) {
|
||||
username := strings.ToLower(ctx.Username)
|
||||
password := ctx.Password
|
||||
|
||||
// Validate input
|
||||
if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" {
|
||||
return "", errors.New("empty username or password")
|
||||
}
|
||||
|
||||
client, err := provider.open()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "unable to open ldap connection")
|
||||
}
|
||||
defer provider.close(client)
|
||||
|
||||
// Search for the given username
|
||||
attrs := []string{"dn", provider.config.EmailAttribute}
|
||||
if provider.config.DisabledAttribute != "" {
|
||||
attrs = append(attrs, provider.config.DisabledAttribute)
|
||||
}
|
||||
searchRequest := ldap.NewSearchRequest(
|
||||
provider.config.BaseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
||||
fmt.Sprintf("(&(objectClass=%s)(%s=%s))", provider.config.UserClass, provider.config.EmailAttribute, username),
|
||||
attrs,
|
||||
nil,
|
||||
)
|
||||
|
||||
sr, err := client.Search(searchRequest)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "unable to find user in ldap")
|
||||
}
|
||||
|
||||
if len(sr.Entries) != 1 {
|
||||
return "", errors.Wrapf(err, "invalid amount of ldap entries (%d)", len(sr.Entries))
|
||||
}
|
||||
|
||||
userDN := sr.Entries[0].DN
|
||||
|
||||
// Check if user is disabled, if so deny login
|
||||
if provider.config.DisabledAttribute != "" {
|
||||
uac := sr.Entries[0].GetAttributeValue(provider.config.DisabledAttribute)
|
||||
switch provider.config.Type {
|
||||
case ldapconfig.TypeActiveDirectory:
|
||||
if ldapconfig.IsActiveDirectoryUserDisabled(uac) {
|
||||
return "", errors.Wrapf(err, "user is disabled")
|
||||
}
|
||||
case ldapconfig.TypeOpenLDAP:
|
||||
if ldapconfig.IsOpenLdapUserDisabled(uac) {
|
||||
return "", errors.Wrapf(err, "user is disabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Bind as the user to verify their password
|
||||
err = client.Bind(userDN, password)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "invalid credentials")
|
||||
}
|
||||
|
||||
return sr.Entries[0].GetAttributeValue(provider.config.EmailAttribute), nil
|
||||
}
|
||||
|
||||
func (provider Provider) Logout(context *authentication.AuthContext) error {
|
||||
return nil // nothing todo here
|
||||
}
|
||||
|
||||
func (provider Provider) GetUserModel(ctx *authentication.AuthContext) (*authentication.User, error) {
|
||||
username := strings.ToLower(ctx.Username)
|
||||
|
||||
// Validate input
|
||||
if strings.Trim(username, " ") == "" {
|
||||
return nil, errors.New("empty username")
|
||||
}
|
||||
|
||||
client, err := provider.open()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unable to open ldap connection")
|
||||
}
|
||||
defer provider.close(client)
|
||||
|
||||
// Search for the given username
|
||||
attrs := []string{"dn", provider.config.EmailAttribute, provider.config.FirstNameAttribute, provider.config.LastNameAttribute,
|
||||
provider.config.PhoneAttribute, provider.config.GroupMemberAttribute}
|
||||
if provider.config.DisabledAttribute != "" {
|
||||
attrs = append(attrs, provider.config.DisabledAttribute)
|
||||
}
|
||||
searchRequest := ldap.NewSearchRequest(
|
||||
provider.config.BaseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
||||
fmt.Sprintf("(&(objectClass=%s)(%s=%s))", provider.config.UserClass, provider.config.EmailAttribute, username),
|
||||
attrs,
|
||||
nil,
|
||||
)
|
||||
|
||||
sr, err := client.Search(searchRequest)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unable to find user in ldap")
|
||||
}
|
||||
|
||||
if len(sr.Entries) != 1 {
|
||||
return nil, errors.Wrapf(err, "invalid amount of ldap entries (%d)", len(sr.Entries))
|
||||
}
|
||||
|
||||
user := &authentication.User{
|
||||
Firstname: sr.Entries[0].GetAttributeValue(provider.config.FirstNameAttribute),
|
||||
Lastname: sr.Entries[0].GetAttributeValue(provider.config.LastNameAttribute),
|
||||
Email: sr.Entries[0].GetAttributeValue(provider.config.EmailAttribute),
|
||||
Phone: sr.Entries[0].GetAttributeValue(provider.config.PhoneAttribute),
|
||||
IsAdmin: false,
|
||||
}
|
||||
|
||||
for _, group := range sr.Entries[0].GetAttributeValues(provider.config.GroupMemberAttribute) {
|
||||
if group == provider.config.AdminLdapGroup {
|
||||
user.IsAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (provider Provider) open() (*ldap.Conn, error) {
|
||||
conn, err := ldap.DialURL(provider.config.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if provider.config.StartTLS {
|
||||
// Reconnect with TLS
|
||||
err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.Bind(provider.config.BindUser, provider.config.BindPass)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (provider Provider) close(conn *ldap.Conn) {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
186
internal/authentication/providers/password/provider.go
Normal file
186
internal/authentication/providers/password/provider.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package password
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/h44z/wg-portal/internal/authentication"
|
||||
"github.com/h44z/wg-portal/internal/users"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Provider provide login with password method
|
||||
type Provider struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func New(cfg *users.Config) (*Provider, error) {
|
||||
p := &Provider{}
|
||||
|
||||
var err error
|
||||
p.db, err = users.GetDatabaseForConfig(cfg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to setup authentication database %s", cfg.Database)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// GetName return provider name
|
||||
func (Provider) GetName() string {
|
||||
return string(users.UserSourceDatabase)
|
||||
}
|
||||
|
||||
// GetType return provider type
|
||||
func (Provider) GetType() authentication.AuthProviderType {
|
||||
return authentication.AuthProviderTypePassword
|
||||
}
|
||||
|
||||
// GetPriority return provider priority
|
||||
func (Provider) GetPriority() int {
|
||||
return 0 // DB password provider = highest prio
|
||||
}
|
||||
|
||||
func (provider Provider) SetupRoutes(routes *gin.RouterGroup) {
|
||||
// nothing todo here
|
||||
}
|
||||
|
||||
func (provider Provider) Login(ctx *authentication.AuthContext) (string, error) {
|
||||
username := strings.ToLower(ctx.Username)
|
||||
password := ctx.Password
|
||||
|
||||
// Validate input
|
||||
if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" {
|
||||
return "", errors.New("empty username or password")
|
||||
}
|
||||
|
||||
// Authenticate agains the users database
|
||||
user := users.User{}
|
||||
provider.db.Where("email = ?", username).First(&user)
|
||||
|
||||
if user.Email == "" {
|
||||
return "", errors.New("invalid username")
|
||||
}
|
||||
|
||||
// Compare the stored hashed password, with the hashed version of the password that was received
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
|
||||
return "", errors.New("invalid password")
|
||||
}
|
||||
|
||||
return user.Email, nil
|
||||
}
|
||||
|
||||
func (provider Provider) Logout(context *authentication.AuthContext) error {
|
||||
return nil // nothing todo here
|
||||
}
|
||||
|
||||
func (provider Provider) GetUserModel(ctx *authentication.AuthContext) (*authentication.User, error) {
|
||||
username := strings.ToLower(ctx.Username)
|
||||
|
||||
// Validate input
|
||||
if strings.Trim(username, " ") == "" {
|
||||
return nil, errors.New("empty username")
|
||||
}
|
||||
|
||||
// Fetch usermodel from users database
|
||||
user := users.User{}
|
||||
provider.db.Where("email = ?", username).First(&user)
|
||||
if user.Email != username {
|
||||
return nil, errors.New("invalid or disabled username")
|
||||
}
|
||||
|
||||
return &authentication.User{
|
||||
Email: user.Email,
|
||||
IsAdmin: user.IsAdmin,
|
||||
Firstname: user.Firstname,
|
||||
Lastname: user.Lastname,
|
||||
Phone: user.Phone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (provider Provider) InitializeAdmin(email, password string) error {
|
||||
admin := users.User{}
|
||||
provider.db.Unscoped().Where("email = ?", email).FirstOrInit(&admin)
|
||||
|
||||
// newly created admin
|
||||
if admin.Email != email {
|
||||
// For security reasons a random admin password will be generated if the default one is still in use!
|
||||
if password == "wgportal" {
|
||||
password = generateRandomPassword()
|
||||
|
||||
fmt.Println("#############################################")
|
||||
fmt.Println("Administrator credentials:")
|
||||
fmt.Println(" Email: ", email)
|
||||
fmt.Println(" Password: ", password)
|
||||
fmt.Println()
|
||||
fmt.Println("This information will only be displayed once!")
|
||||
fmt.Println("#############################################")
|
||||
}
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to hash admin password")
|
||||
}
|
||||
|
||||
admin.Email = email
|
||||
admin.Password = string(hashedPassword)
|
||||
admin.Firstname = "WireGuard"
|
||||
admin.Lastname = "Administrator"
|
||||
admin.CreatedAt = time.Now()
|
||||
admin.UpdatedAt = time.Now()
|
||||
admin.IsAdmin = true
|
||||
admin.Source = users.UserSourceDatabase
|
||||
|
||||
res := provider.db.Create(admin)
|
||||
if res.Error != nil {
|
||||
return errors.Wrapf(res.Error, "failed to create admin %s", admin.Email)
|
||||
}
|
||||
}
|
||||
|
||||
// update/reactivate
|
||||
if !admin.IsAdmin || admin.DeletedAt.Valid {
|
||||
// For security reasons a random admin password will be generated if the default one is still in use!
|
||||
if password == "wgportal" {
|
||||
password = generateRandomPassword()
|
||||
|
||||
fmt.Println("#############################################")
|
||||
fmt.Println("Administrator credentials:")
|
||||
fmt.Println(" Email: ", email)
|
||||
fmt.Println(" Password: ", password)
|
||||
fmt.Println()
|
||||
fmt.Println("This information will only be displayed once!")
|
||||
fmt.Println("#############################################")
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to hash admin password")
|
||||
}
|
||||
|
||||
admin.Password = string(hashedPassword)
|
||||
admin.IsAdmin = true
|
||||
admin.UpdatedAt = time.Now()
|
||||
|
||||
res := provider.db.Save(admin)
|
||||
if res.Error != nil {
|
||||
return errors.Wrapf(res.Error, "failed to update admin %s", admin.Email)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateRandomPassword() string {
|
||||
rand.Seed(time.Now().Unix())
|
||||
var randPassword strings.Builder
|
||||
charSet := "abcdedfghijklmnopqrstABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!#$"
|
||||
for i := 0; i < 12; i++ {
|
||||
random := rand.Intn(len(charSet))
|
||||
randPassword.WriteString(string(charSet[random]))
|
||||
}
|
||||
return randPassword.String()
|
||||
}
|
Reference in New Issue
Block a user