zitadel/internal/idp/providers/ldap/session.go
Stefan Benz d92717a1c6
fix: encode ldap values to make valid UTF8 (#8210)
# Which Problems Are Solved

UUIDs stored in LDAP are octet strings, which are not valid UTF-8, so they
have to be parsed before they can be stored as IDs.

# How the Problems Are Solved

Try to parse the RawValue from LDAP as UUID, otherwise try to base64
decode and then parse as UUID, else use the data as string as before.

# Additional Changes

None

# Additional Context

Closes #7601
2024-06-28 13:46:54 +00:00

294 lines
6.7 KiB
Go

package ldap
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"net"
"net/url"
"strconv"
"time"
"unicode/utf8"
"github.com/go-ldap/ldap/v3"
"github.com/zitadel/logging"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
)
// ErrNoSingleUser is returned when the LDAP search does not yield exactly one entry.
var ErrNoSingleUser = errors.New("user does not exist or too many entries returned")
// ErrFailedLogin is returned when binding as the found user with the provided password fails.
var ErrFailedLogin = errors.New("user failed to login")
// Compile-time check that *Session satisfies the idp.Session interface.
var _ idp.Session = (*Session)(nil)
// Session holds the state of a single LDAP login attempt and implements [idp.Session].
type Session struct {
// Provider supplies the LDAP settings (servers, bind credentials, base DN,
// filters and attribute names) that FetchUser reads.
Provider *Provider
// loginUrl is the URL GetAuth redirects to.
loginUrl string
// User is the username that is searched for in the directory.
User string
// Password is the password used to bind as the found user.
Password string
// Entry is the LDAP entry of the user; FetchUser sets it after a successful search and bind.
Entry *ldap.Entry
}
// GetAuth implements the [idp.Session] interface.
// It hands the session's login URL to idp.Redirect and returns that result
// unchanged; the context is not used.
func (s *Session) GetAuth(ctx context.Context) (string, bool) {
	content, redirect := idp.Redirect(s.loginUrl)
	return content, redirect
}
// FetchUser implements the [idp.Session] interface.
//
// It tries the provider's LDAP servers in order until one of them finds and
// authenticates the user. ErrFailedLogin and ErrNoSingleUser abort the loop
// immediately because they are definitive answers about the credentials; any
// other error makes the loop move on to the next server. If every server
// fails, the error of the last attempt is returned.
//
// On success the found entry is stored in s.Entry and mapped to a user with
// the attribute names configured on the provider.
func (s *Session) FetchUser(_ context.Context) (_ idp.User, err error) {
	var user *ldap.Entry
	for _, server := range s.Provider.servers {
		user, err = tryBind(server,
			s.Provider.startTLS,
			s.Provider.bindDN,
			s.Provider.bindPassword,
			s.Provider.baseDN,
			s.Provider.getNecessaryAttributes(),
			s.Provider.userObjectClasses,
			s.Provider.userFilters,
			s.User,
			s.Password, s.Provider.timeout)
		// If there were invalid credentials or multiple users with the credentials cancel process
		if errors.Is(err, ErrFailedLogin) || errors.Is(err, ErrNoSingleUser) {
			return nil, err
		}
		// If a user bind was successful and user is filled continue with login, otherwise try next server
		if err == nil && user != nil {
			break
		}
	}
	if err != nil {
		return nil, err
	}
	// Without any configured server the loop never runs and both user and err
	// are nil; fail explicitly instead of dereferencing a nil entry below.
	if user == nil {
		return nil, errors.New("no ldap server returned a user")
	}
	s.Entry = user

	mapped, err := mapLDAPEntryToUser(
		user,
		s.Provider.idAttribute,
		s.Provider.firstNameAttribute,
		s.Provider.lastNameAttribute,
		s.Provider.displayNameAttribute,
		s.Provider.nickNameAttribute,
		s.Provider.preferredUsernameAttribute,
		s.Provider.emailAttribute,
		s.Provider.emailVerifiedAttribute,
		s.Provider.phoneAttribute,
		s.Provider.phoneVerifiedAttribute,
		s.Provider.preferredLanguageAttribute,
		s.Provider.avatarURLAttribute,
		s.Provider.profileAttribute,
	)
	if err != nil {
		// Return a literal nil so the interface value is really nil and does
		// not wrap a nil *User.
		return nil, err
	}
	return mapped, nil
}
// tryBind runs a complete login attempt against a single LDAP server.
//
// It opens a connection to server, binds with bindDN/bindPassword and then
// delegates to trySearchAndUserBind to look up the user and verify the
// password. The connection is closed before the function returns.
func tryBind(
	server string,
	startTLS bool,
	bindDN string,
	bindPassword string,
	baseDN string,
	attributes []string,
	objectClasses []string,
	userFilters []string,
	username string,
	password string,
	timeout time.Duration,
) (*ldap.Entry, error) {
	connection, connErr := getConnection(server, startTLS, timeout)
	if connErr != nil {
		return nil, connErr
	}
	defer connection.Close()

	// The first bind uses the configured bind DN so that the search can be executed.
	bindErr := connection.Bind(bindDN, bindPassword)
	if bindErr != nil {
		return nil, bindErr
	}

	entry, err := trySearchAndUserBind(
		connection,
		baseDN,
		attributes,
		objectClasses,
		userFilters,
		username,
		password,
		timeout,
	)
	return entry, err
}
// getConnection dials the LDAP server at the given URL and optionally
// upgrades the connection with StartTLS.
//
// A zero timeout falls back to ldap.DefaultTimeout. The URL is parsed before
// dialing so that an invalid URL cannot leave an open connection behind.
//
// StartTLS is only issued for the plain "ldap" scheme: an "ldaps" connection
// is already encrypted from the first byte and go-ldap rejects StartTLS on
// it. The TLS server name is the host without the port, which is what
// certificate verification expects.
//
// The caller owns the returned connection and must close it.
func getConnection(
	server string,
	startTLS bool,
	timeout time.Duration,
) (*ldap.Conn, error) {
	if timeout == 0 {
		timeout = ldap.DefaultTimeout
	}

	u, err := url.Parse(server)
	if err != nil {
		return nil, err
	}

	conn, err := ldap.DialURL(server, ldap.DialWithDialer(&net.Dialer{Timeout: timeout}))
	if err != nil {
		return nil, err
	}

	if u.Scheme == "ldap" && startTLS {
		if err := conn.StartTLS(&tls.Config{ServerName: u.Hostname()}); err != nil {
			// The connection is unusable for the caller; release it here
			// because nobody else will.
			conn.Close()
			return nil, err
		}
	}
	return conn, nil
}
// trySearchAndUserBind searches for the user on an already bound connection
// and verifies the password by binding as that user.
//
// The search filter requires all objectClasses and at least one of the
// userFilters attributes to equal username, i.e.
// (&(objectClass=...)(|(filterA=username)(filterB=username))).
// Every user filter is turned into its own sub-query so that the OR helper
// really receives several queries; handing it one pre-joined string would
// place the filters directly inside the surrounding AND.
//
// The search covers the whole subtree below baseDN, never dereferences
// aliases, has no size limit and uses timeout (in whole seconds) as its time
// limit.
//
// It returns ErrNoSingleUser unless exactly one entry is found and
// ErrFailedLogin if the bind with the user's DN and password fails. Errors of
// the search itself are returned unchanged.
func trySearchAndUserBind(
	conn *ldap.Conn,
	baseDN string,
	attributes []string,
	objectClasses []string,
	userFilters []string,
	username string,
	password string,
	timeout time.Duration,
) (*ldap.Entry, error) {
	userQueries := make([]string, 0, len(userFilters))
	for _, filter := range userFilters {
		userQueries = append(userQueries, userFiltersToSearchQuery([]string{filter}, username))
	}
	searchQuery := queriesAndToSearchQuery(
		objectClassesToSearchQuery(objectClasses),
		queriesOrToSearchQuery(userQueries...),
	)

	// Search for user with the unique attribute for the userDN
	searchRequest := ldap.NewSearchRequest(
		baseDN,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, int(timeout.Seconds()), false,
		searchQuery,
		attributes,
		nil,
	)

	sr, err := conn.Search(searchRequest)
	if err != nil {
		return nil, err
	}
	if len(sr.Entries) != 1 {
		logging.WithFields("entries", len(sr.Entries)).Info("ldap: no single user found")
		return nil, ErrNoSingleUser
	}

	user := sr.Entries[0]
	// Bind as the user to verify their password
	if err = conn.Bind(user.DN, password); err != nil {
		logging.WithFields("userDN", user.DN).WithError(err).Info("ldap user bind failed")
		return nil, ErrFailedLogin
	}
	return user, nil
}
// queriesAndToSearchQuery combines LDAP filter strings with a logical AND.
//
// No queries yield an empty string and a single query is returned as is.
// Two or more queries are concatenated and wrapped as "(&...)".
func queriesAndToSearchQuery(queries ...string) string {
	switch len(queries) {
	case 0:
		return ""
	case 1:
		return queries[0]
	}

	joined := []byte("(&")
	for i := range queries {
		joined = append(joined, queries[i]...)
	}
	joined = append(joined, ')')
	return string(joined)
}
// queriesOrToSearchQuery combines LDAP filter strings with a logical OR.
//
// No queries yield an empty string and a single query is returned as is.
// Two or more queries are concatenated and wrapped as "(|...)".
func queriesOrToSearchQuery(queries ...string) string {
	switch len(queries) {
	case 0:
		return ""
	case 1:
		return queries[0]
	}

	joined := []byte("(|")
	for i := range queries {
		joined = append(joined, queries[i]...)
	}
	joined = append(joined, ')')
	return string(joined)
}
// objectClassesToSearchQuery renders one "(objectClass=<class>)" filter per
// class and concatenates them without an enclosing operator. The class names
// are inserted as given. An empty list yields an empty string.
func objectClassesToSearchQuery(classes []string) string {
	var query []byte
	for i := range classes {
		query = append(query, "(objectClass="...)
		query = append(query, classes[i]...)
		query = append(query, ')')
	}
	return string(query)
}
// userFiltersToSearchQuery renders one "(<filter>=<username>)" equality filter
// per attribute name and concatenates them without an enclosing operator.
// The username is passed through ldap.EscapeFilter; the attribute names are
// inserted as given. An empty list yields an empty string.
func userFiltersToSearchQuery(filters []string, username string) string {
	escaped := ldap.EscapeFilter(username)
	var query []byte
	for i := range filters {
		query = append(query, '(')
		query = append(query, filters[i]...)
		query = append(query, '=')
		query = append(query, escaped...)
		query = append(query, ')')
	}
	return string(query)
}
// mapLDAPEntryToUser builds a User from an LDAP entry. Each parameter after
// user names the LDAP attribute that holds the corresponding user field.
//
// The two "verified" attributes are optional: an empty value counts as false,
// any other value must be parsable by strconv.ParseBool, otherwise that
// parse error is returned. ID, names and preferred username are read through
// getAttributeValue; the remaining fields are read directly from the entry.
func mapLDAPEntryToUser(
	user *ldap.Entry,
	idAttribute,
	firstNameAttribute,
	lastNameAttribute,
	displayNameAttribute,
	nickNameAttribute,
	preferredUsernameAttribute,
	emailAttribute,
	emailVerifiedAttribute,
	phoneAttribute,
	phoneVerifiedAttribute,
	preferredLanguageAttribute,
	avatarURLAttribute,
	profileAttribute string,
) (_ *User, err error) {
	// verified reads an optional boolean attribute from the entry.
	verified := func(attribute string) (bool, error) {
		raw := user.GetAttributeValue(attribute)
		if raw == "" {
			return false, nil
		}
		return strconv.ParseBool(raw)
	}

	emailVerified, err := verified(emailVerifiedAttribute)
	if err != nil {
		return nil, err
	}
	phoneVerified, err := verified(phoneVerifiedAttribute)
	if err != nil {
		return nil, err
	}

	mapped := NewUser(
		getAttributeValue(user, idAttribute),
		getAttributeValue(user, firstNameAttribute),
		getAttributeValue(user, lastNameAttribute),
		getAttributeValue(user, displayNameAttribute),
		getAttributeValue(user, nickNameAttribute),
		getAttributeValue(user, preferredUsernameAttribute),
		domain.EmailAddress(user.GetAttributeValue(emailAttribute)),
		emailVerified,
		domain.PhoneNumber(user.GetAttributeValue(phoneAttribute)),
		phoneVerified,
		language.Make(user.GetAttributeValue(preferredLanguageAttribute)),
		user.GetAttributeValue(avatarURLAttribute),
		user.GetAttributeValue(profileAttribute),
	)
	return mapped, nil
}
// getAttributeValue returns the value of the given attribute as a string that
// is always valid UTF-8.
//
// An empty attribute name means the attribute is not needed and yields an
// empty string. If the value is not valid UTF-8, the raw bytes of the
// attribute are returned base64 (standard encoding) encoded instead.
func getAttributeValue(user *ldap.Entry, attribute string) string {
	if attribute == "" {
		return ""
	}

	if text := user.GetAttributeValue(attribute); utf8.ValidString(text) {
		return text
	}

	raw := user.GetRawAttributeValue(attribute)
	return base64.StdEncoding.EncodeToString(raw)
}