2022-11-11 18:17:38 +01:00

211 lines
5.5 KiB
Go

package ldap
import (
"crypto/tls"
"os"
"strings"
"github.com/gin-gonic/gin"
"github.com/go-ldap/ldap/v3"
"github.com/h44z/wg-portal/internal/authentication"
ldapconfig "github.com/h44z/wg-portal/internal/ldap"
"github.com/h44z/wg-portal/internal/users"
"github.com/pkg/errors"
)
// Provider implements a password login method for an LDAP backend.
type Provider struct {
// config supplies the connection settings (URL, TLS options, bind
// credentials) as well as the search base, login filter and attribute
// names that the methods of Provider read.
config *ldapconfig.Config
}
// New builds a Provider around cfg and verifies that the LDAP server is
// reachable by opening (and immediately closing) one connection.
//
// It returns an error, wrapped with "unable to open ldap connection", if that
// probe connection cannot be established.
func New(cfg *ldapconfig.Config) (*Provider, error) {
	provider := &Provider{config: cfg}

	// Probe connectivity once so that misconfiguration surfaces at startup.
	conn, err := provider.open()
	if err != nil {
		return nil, errors.Wrap(err, "unable to open ldap connection")
	}
	provider.close(conn)

	return provider, nil
}
// GetName returns the provider name, which is the string form of
// users.UserSourceLdap.
func (Provider) GetName() string {
	name := string(users.UserSourceLdap)
	return name
}
// GetType returns the provider type, which is
// authentication.AuthProviderTypePassword.
func (Provider) GetType() authentication.AuthProviderType {
	providerType := authentication.AuthProviderTypePassword
	return providerType
}
// GetPriority returns the provider priority, which is always 1 for the LDAP
// password provider.
func (Provider) GetPriority() int {
	const ldapPriority = 1 // LDAP password provider
	return ldapPriority
}
// SetupRoutes is a no-op: this provider registers nothing on the given
// router group.
func (Provider) SetupRoutes(_ *gin.RouterGroup) {}
// Login verifies the credentials carried by ctx against the LDAP directory.
//
// The lower-cased username is inserted into the configured login filter in
// place of the {{login_identifier}} placeholder and searched for below the
// configured base DN. Exactly one entry must match; the password is then
// checked by binding as that entry's DN.
//
// On success it returns the value of the configured e-mail attribute of the
// matched entry. It returns an error if the username or password is blank,
// the connection cannot be opened, the search fails, the number of matching
// entries is not exactly one, or the bind with the user's password fails.
func (provider Provider) Login(ctx *authentication.AuthContext) (string, error) {
	username := strings.ToLower(ctx.Username)
	password := ctx.Password

	// Reject blank credentials before touching the directory.
	if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" {
		return "", errors.New("empty username or password")
	}

	client, err := provider.open()
	if err != nil {
		return "", errors.Wrap(err, "unable to open ldap connection")
	}
	defer provider.close(client)

	// Search for the given username. The value comes straight from the login
	// form, so it is escaped before being placed into the filter; otherwise
	// characters such as '*', '(' or ')' could rewrite the filter.
	attrs := []string{"dn", provider.config.EmailAttribute}
	loginFilter := strings.ReplaceAll(provider.config.LoginFilter, "{{login_identifier}}", ldap.EscapeFilter(username))
	searchRequest := ldap.NewSearchRequest(
		provider.config.BaseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		loginFilter,
		attrs,
		nil,
	)

	sr, err := client.Search(searchRequest)
	if err != nil {
		return "", errors.Wrap(err, "unable to find user in ldap")
	}

	// Zero matches means an unknown user; more than one is ambiguous.
	if len(sr.Entries) != 1 {
		return "", errors.Errorf("invalid amount of ldap entries (%d)", len(sr.Entries))
	}
	entry := sr.Entries[0]

	// Bind as the user to verify their password.
	if err := client.Bind(entry.DN, password); err != nil {
		return "", errors.Wrap(err, "invalid credentials")
	}

	return entry.GetAttributeValue(provider.config.EmailAttribute), nil
}
// Logout is a no-op for this provider and always returns nil.
func (Provider) Logout(_ *authentication.AuthContext) error {
	// nothing here
	return nil
}
// GetUserModel looks up the user named in ctx and builds an
// authentication.User from the matched LDAP entry.
//
// The lower-cased username is inserted into the configured login filter in
// place of the {{login_identifier}} placeholder and searched for below the
// configured base DN. First name, last name, e-mail and phone are read from
// the configured attributes. IsAdmin is set when any value of the configured
// group-member attribute equals the configured admin group.
//
// It returns an error if the username is blank, the connection cannot be
// opened, the search fails, or the number of matching entries is not exactly
// one. A nil user is only ever returned together with a non-nil error.
func (provider Provider) GetUserModel(ctx *authentication.AuthContext) (*authentication.User, error) {
	username := strings.ToLower(ctx.Username)

	// Validate input
	if strings.Trim(username, " ") == "" {
		return nil, errors.New("empty username")
	}

	client, err := provider.open()
	if err != nil {
		return nil, errors.Wrap(err, "unable to open ldap connection")
	}
	defer provider.close(client)

	// Search for the given username. The value is escaped before being placed
	// into the filter so that characters such as '*', '(' or ')' cannot
	// rewrite the filter.
	attrs := []string{
		"dn",
		provider.config.EmailAttribute,
		provider.config.FirstNameAttribute,
		provider.config.LastNameAttribute,
		provider.config.PhoneAttribute,
		provider.config.GroupMemberAttribute,
	}
	loginFilter := strings.ReplaceAll(provider.config.LoginFilter, "{{login_identifier}}", ldap.EscapeFilter(username))
	searchRequest := ldap.NewSearchRequest(
		provider.config.BaseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		loginFilter,
		attrs,
		nil,
	)

	sr, err := client.Search(searchRequest)
	if err != nil {
		return nil, errors.Wrap(err, "unable to find user in ldap")
	}

	// err is nil here, so a fresh error must be created: wrapping a nil error
	// yields nil and would hand the caller a nil user without any error.
	if len(sr.Entries) != 1 {
		return nil, errors.Errorf("invalid amount of ldap entries (%d)", len(sr.Entries))
	}
	entry := sr.Entries[0]

	user := &authentication.User{
		Firstname: entry.GetAttributeValue(provider.config.FirstNameAttribute),
		Lastname:  entry.GetAttributeValue(provider.config.LastNameAttribute),
		Email:     entry.GetAttributeValue(provider.config.EmailAttribute),
		Phone:     entry.GetAttributeValue(provider.config.PhoneAttribute),
		IsAdmin:   false,
	}

	// Membership in the configured admin group grants admin rights.
	for _, group := range entry.GetAttributeValues(provider.config.GroupMemberAttribute) {
		if group == provider.config.AdminLdapGroup {
			user.IsAdmin = true
			break
		}
	}

	return user, nil
}
// open dials the configured LDAP URL and binds with the configured bind user.
//
// TLS settings: if LdapCertConn is set, a client certificate is loaded from
// the LdapTlsCert and LdapTlsKey files and presented to the server; otherwise
// server certificate verification is skipped unless CertValidation is set.
// If StartTLS is set, the dialed connection is upgraded to TLS before the
// bind.
//
// On success the caller owns the returned connection and must close it. On
// any error no connection is returned, and a connection that was already
// dialed is closed before returning so that it does not leak.
func (provider Provider) open() (*ldap.Conn, error) {
	cfg := provider.config

	var tlsConfig *tls.Config
	if cfg.LdapCertConn {
		certPlain, err := os.ReadFile(cfg.LdapTlsCert)
		if err != nil {
			return nil, errors.WithMessage(err, "failed to load the certificate")
		}
		key, err := os.ReadFile(cfg.LdapTlsKey)
		if err != nil {
			return nil, errors.WithMessage(err, "failed to load the key")
		}
		certX509, err := tls.X509KeyPair(certPlain, key)
		if err != nil {
			return nil, errors.WithMessage(err, "failed X509")
		}
		tlsConfig = &tls.Config{Certificates: []tls.Certificate{certX509}}
	} else {
		tlsConfig = &tls.Config{InsecureSkipVerify: !cfg.CertValidation}
	}

	conn, err := ldap.DialURL(cfg.URL, ldap.DialWithTLSConfig(tlsConfig))
	if err != nil {
		return nil, errors.WithMessage(err, "failed to connect to LDAP")
	}

	if cfg.StartTLS {
		// Upgrade the already established connection to TLS.
		if err := conn.StartTLS(tlsConfig); err != nil {
			conn.Close() // do not leak the dialed connection
			return nil, errors.WithMessage(err, "failed to start TLS session")
		}
	}

	if err := conn.Bind(cfg.BindUser, cfg.BindPass); err != nil {
		conn.Close() // do not leak the dialed connection
		return nil, errors.WithMessage(err, "failed to bind user")
	}

	return conn, nil
}
// close closes conn. A nil conn is accepted and ignored.
func (Provider) close(conn *ldap.Conn) {
	if conn == nil {
		return
	}
	conn.Close()
}