mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 08:27:32 +00:00
feat: ldap provider login (#5448)
Add the logic to configure and use LDAP provider as an external IDP with a dedicated login GUI.
This commit is contained in:
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"golang.org/x/text/language"
|
||||
@@ -15,49 +17,154 @@ import (
|
||||
)
|
||||
|
||||
var ErrNoSingleUser = errors.New("user does not exist or too many entries returned")
|
||||
var ErrFailedLogin = errors.New("user failed to login")
|
||||
|
||||
var _ idp.Session = (*Session)(nil)
|
||||
|
||||
type Session struct {
|
||||
Provider *Provider
|
||||
loginUrl string
|
||||
user string
|
||||
password string
|
||||
User string
|
||||
Password string
|
||||
}
|
||||
|
||||
func (s *Session) GetAuthURL() string {
|
||||
return s.loginUrl
|
||||
}
|
||||
func (s *Session) FetchUser(_ context.Context) (idp.User, error) {
|
||||
l, err := ldap.DialURL("ldap://" + s.Provider.host + ":" + s.Provider.port)
|
||||
|
||||
func (s *Session) FetchUser(_ context.Context) (_ idp.User, err error) {
|
||||
var user *ldap.Entry
|
||||
for _, server := range s.Provider.servers {
|
||||
user, err = tryBind(server,
|
||||
s.Provider.startTLS,
|
||||
s.Provider.bindDN,
|
||||
s.Provider.bindPassword,
|
||||
s.Provider.baseDN,
|
||||
s.Provider.getNecessaryAttributes(),
|
||||
s.Provider.userObjectClasses,
|
||||
s.Provider.userFilters,
|
||||
s.User,
|
||||
s.Password, s.Provider.timeout)
|
||||
// If there were invalid credentials or multiple users with the credentials cancel process
|
||||
if err != nil && (errors.Is(err, ErrFailedLogin) || errors.Is(err, ErrNoSingleUser)) {
|
||||
return nil, err
|
||||
}
|
||||
// If a user bind was successful and user is filled continue with login, otherwise try next server
|
||||
if err == nil && user != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
if s.Provider.tls {
|
||||
err = l.StartTLS(&tls.Config{ServerName: s.Provider.host})
|
||||
return mapLDAPEntryToUser(
|
||||
user,
|
||||
s.Provider.idAttribute,
|
||||
s.Provider.firstNameAttribute,
|
||||
s.Provider.lastNameAttribute,
|
||||
s.Provider.displayNameAttribute,
|
||||
s.Provider.nickNameAttribute,
|
||||
s.Provider.preferredUsernameAttribute,
|
||||
s.Provider.emailAttribute,
|
||||
s.Provider.emailVerifiedAttribute,
|
||||
s.Provider.phoneAttribute,
|
||||
s.Provider.phoneVerifiedAttribute,
|
||||
s.Provider.preferredLanguageAttribute,
|
||||
s.Provider.avatarURLAttribute,
|
||||
s.Provider.profileAttribute,
|
||||
)
|
||||
}
|
||||
|
||||
func tryBind(
|
||||
server string,
|
||||
startTLS bool,
|
||||
bindDN string,
|
||||
bindPassword string,
|
||||
baseDN string,
|
||||
attributes []string,
|
||||
objectClasses []string,
|
||||
userFilters []string,
|
||||
username string,
|
||||
password string,
|
||||
timeout time.Duration,
|
||||
) (*ldap.Entry, error) {
|
||||
conn, err := getConnection(server, startTLS, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.Bind(bindDN, bindPassword); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return trySearchAndUserBind(
|
||||
conn,
|
||||
baseDN,
|
||||
attributes,
|
||||
objectClasses,
|
||||
userFilters,
|
||||
username,
|
||||
password,
|
||||
timeout,
|
||||
)
|
||||
}
|
||||
|
||||
func getConnection(
|
||||
server string,
|
||||
startTLS bool,
|
||||
timeout time.Duration,
|
||||
) (*ldap.Conn, error) {
|
||||
if timeout == 0 {
|
||||
timeout = ldap.DefaultTimeout
|
||||
}
|
||||
|
||||
conn, err := ldap.DialURL(server, ldap.DialWithDialer(&net.Dialer{Timeout: timeout}))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, err := url.Parse(server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if u.Scheme == "ldaps" && startTLS {
|
||||
err = conn.StartTLS(&tls.Config{ServerName: u.Host})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Bind as the admin to search for user
|
||||
err = l.Bind("cn="+s.Provider.admin+","+s.Provider.baseDN, s.Provider.password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func trySearchAndUserBind(
|
||||
conn *ldap.Conn,
|
||||
baseDN string,
|
||||
attributes []string,
|
||||
objectClasses []string,
|
||||
userFilters []string,
|
||||
username string,
|
||||
password string,
|
||||
timeout time.Duration,
|
||||
) (*ldap.Entry, error) {
|
||||
searchQuery := queriesAndToSearchQuery(
|
||||
objectClassesToSearchQuery(objectClasses),
|
||||
queriesOrToSearchQuery(
|
||||
userFiltersToSearchQuery(userFilters, username),
|
||||
),
|
||||
)
|
||||
|
||||
// Search for user with the unique attribute for the userDN
|
||||
searchRequest := ldap.NewSearchRequest(
|
||||
s.Provider.baseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
||||
fmt.Sprintf("(&(objectClass="+s.Provider.userObjectClass+")("+s.Provider.userUniqueAttribute+"=%s))", ldap.EscapeFilter(s.user)),
|
||||
[]string{"dn"},
|
||||
baseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, int(timeout.Seconds()), false,
|
||||
searchQuery,
|
||||
attributes,
|
||||
nil,
|
||||
)
|
||||
|
||||
sr, err := l.Search(searchRequest)
|
||||
sr, err := conn.Search(searchRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -67,33 +174,100 @@ func (s *Session) FetchUser(_ context.Context) (idp.User, error) {
|
||||
|
||||
user := sr.Entries[0]
|
||||
// Bind as the user to verify their password
|
||||
err = l.Bind(user.DN, s.password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err = conn.Bind(user.DN, password); err != nil {
|
||||
return nil, ErrFailedLogin
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
emailVerified, err := strconv.ParseBool(user.GetAttributeValue(s.Provider.emailVerifiedAttribute))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func queriesAndToSearchQuery(queries ...string) string {
|
||||
if len(queries) == 0 {
|
||||
return ""
|
||||
}
|
||||
phoneVerified, err := strconv.ParseBool(user.GetAttributeValue(s.Provider.phoneVerifiedAttribute))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if len(queries) == 1 {
|
||||
return queries[0]
|
||||
}
|
||||
joinQueries := "(&"
|
||||
for _, s := range queries {
|
||||
joinQueries += s
|
||||
}
|
||||
return joinQueries + ")"
|
||||
}
|
||||
|
||||
func queriesOrToSearchQuery(queries ...string) string {
|
||||
if len(queries) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(queries) == 1 {
|
||||
return queries[0]
|
||||
}
|
||||
joinQueries := "(|"
|
||||
for _, s := range queries {
|
||||
joinQueries += s
|
||||
}
|
||||
return joinQueries + ")"
|
||||
}
|
||||
|
||||
func objectClassesToSearchQuery(classes []string) string {
|
||||
searchQuery := ""
|
||||
for _, class := range classes {
|
||||
searchQuery += "(objectClass=" + class + ")"
|
||||
}
|
||||
return searchQuery
|
||||
}
|
||||
|
||||
func userFiltersToSearchQuery(filters []string, username string) string {
|
||||
searchQuery := ""
|
||||
for _, filter := range filters {
|
||||
searchQuery += "(" + filter + "=" + ldap.EscapeFilter(username) + ")"
|
||||
}
|
||||
return searchQuery
|
||||
}
|
||||
|
||||
func mapLDAPEntryToUser(
|
||||
user *ldap.Entry,
|
||||
idAttribute,
|
||||
firstNameAttribute,
|
||||
lastNameAttribute,
|
||||
displayNameAttribute,
|
||||
nickNameAttribute,
|
||||
preferredUsernameAttribute,
|
||||
emailAttribute,
|
||||
emailVerifiedAttribute,
|
||||
phoneAttribute,
|
||||
phoneVerifiedAttribute,
|
||||
preferredLanguageAttribute,
|
||||
avatarURLAttribute,
|
||||
profileAttribute string,
|
||||
) (_ *User, err error) {
|
||||
var emailVerified bool
|
||||
if v := user.GetAttributeValue(emailVerifiedAttribute); v != "" {
|
||||
emailVerified, err = strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var phoneVerified bool
|
||||
if v := user.GetAttributeValue(phoneVerifiedAttribute); v != "" {
|
||||
phoneVerified, err = strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return NewUser(
|
||||
user.GetAttributeValue(s.Provider.idAttribute),
|
||||
user.GetAttributeValue(s.Provider.firstNameAttribute),
|
||||
user.GetAttributeValue(s.Provider.lastNameAttribute),
|
||||
user.GetAttributeValue(s.Provider.displayNameAttribute),
|
||||
user.GetAttributeValue(s.Provider.nickNameAttribute),
|
||||
user.GetAttributeValue(s.Provider.preferredUsernameAttribute),
|
||||
domain.EmailAddress(user.GetAttributeValue(s.Provider.emailAttribute)),
|
||||
user.GetAttributeValue(idAttribute),
|
||||
user.GetAttributeValue(firstNameAttribute),
|
||||
user.GetAttributeValue(lastNameAttribute),
|
||||
user.GetAttributeValue(displayNameAttribute),
|
||||
user.GetAttributeValue(nickNameAttribute),
|
||||
user.GetAttributeValue(preferredUsernameAttribute),
|
||||
domain.EmailAddress(user.GetAttributeValue(emailAttribute)),
|
||||
emailVerified,
|
||||
domain.PhoneNumber(user.GetAttributeValue(s.Provider.phoneAttribute)),
|
||||
domain.PhoneNumber(user.GetAttributeValue(phoneAttribute)),
|
||||
phoneVerified,
|
||||
language.Make(user.GetAttributeValue(s.Provider.preferredLanguageAttribute)),
|
||||
user.GetAttributeValue(s.Provider.avatarURLAttribute),
|
||||
user.GetAttributeValue(s.Provider.profileAttribute),
|
||||
language.Make(user.GetAttributeValue(preferredLanguageAttribute)),
|
||||
user.GetAttributeValue(avatarURLAttribute),
|
||||
user.GetAttributeValue(profileAttribute),
|
||||
), nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user