mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-20 23:07:33 +00:00
15fd3045e0
* feat: first implementation for saml sp * fix: add command side instance and org for saml provider * fix: add query side instance and org for saml provider * fix: request handling in event and retrieval of finished intent * fix: add review changes and integration tests * fix: add integration tests for saml idp * fix: correct unit tests with review changes * fix: add saml session unit test * fix: add saml session unit test * fix: add saml session unit test * fix: changes from review * fix: changes from review * fix: proto build error * fix: proto build error * fix: proto build error * fix: proto require metadata oneof * fix: login with saml provider * fix: integration test for saml assertion * lint client.go * fix json tag * fix: linting * fix import * fix: linting * fix saml idp query * fix: linting * lint: try all issues * revert linting config * fix: add regenerate endpoints * fix: translations * fix mk.yaml * ignore acs path for user agent cookie * fix: add AuthFromProvider test for saml * fix: integration test for saml retrieve information --------- Co-authored-by: Livio Spring <livio.a@gmail.com>
277 lines
6.1 KiB
Go
277 lines
6.1 KiB
Go
package ldap
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"net"
|
|
"net/url"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/go-ldap/ldap/v3"
|
|
"golang.org/x/text/language"
|
|
|
|
"github.com/zitadel/zitadel/internal/domain"
|
|
"github.com/zitadel/zitadel/internal/idp"
|
|
)
|
|
|
|
var (
	// ErrNoSingleUser is returned when the user search does not yield exactly one entry.
	ErrNoSingleUser = errors.New("user does not exist or too many entries returned")
	// ErrFailedLogin is returned when binding as the found user with the supplied password fails.
	ErrFailedLogin = errors.New("user failed to login")
)
|
|
|
|
var _ idp.Session = (*Session)(nil)
|
|
|
|
type Session struct {
|
|
Provider *Provider
|
|
loginUrl string
|
|
User string
|
|
Password string
|
|
Entry *ldap.Entry
|
|
}
|
|
|
|
// GetAuth implements the [idp.Session] interface.
|
|
func (s *Session) GetAuth(ctx context.Context) (string, bool) {
|
|
return idp.Redirect(s.loginUrl)
|
|
}
|
|
|
|
func (s *Session) FetchUser(_ context.Context) (_ idp.User, err error) {
|
|
var user *ldap.Entry
|
|
for _, server := range s.Provider.servers {
|
|
user, err = tryBind(server,
|
|
s.Provider.startTLS,
|
|
s.Provider.bindDN,
|
|
s.Provider.bindPassword,
|
|
s.Provider.baseDN,
|
|
s.Provider.getNecessaryAttributes(),
|
|
s.Provider.userObjectClasses,
|
|
s.Provider.userFilters,
|
|
s.User,
|
|
s.Password, s.Provider.timeout)
|
|
// If there were invalid credentials or multiple users with the credentials cancel process
|
|
if err != nil && (errors.Is(err, ErrFailedLogin) || errors.Is(err, ErrNoSingleUser)) {
|
|
return nil, err
|
|
}
|
|
// If a user bind was successful and user is filled continue with login, otherwise try next server
|
|
if err == nil && user != nil {
|
|
break
|
|
}
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s.Entry = user
|
|
|
|
return mapLDAPEntryToUser(
|
|
user,
|
|
s.Provider.idAttribute,
|
|
s.Provider.firstNameAttribute,
|
|
s.Provider.lastNameAttribute,
|
|
s.Provider.displayNameAttribute,
|
|
s.Provider.nickNameAttribute,
|
|
s.Provider.preferredUsernameAttribute,
|
|
s.Provider.emailAttribute,
|
|
s.Provider.emailVerifiedAttribute,
|
|
s.Provider.phoneAttribute,
|
|
s.Provider.phoneVerifiedAttribute,
|
|
s.Provider.preferredLanguageAttribute,
|
|
s.Provider.avatarURLAttribute,
|
|
s.Provider.profileAttribute,
|
|
)
|
|
}
|
|
|
|
func tryBind(
|
|
server string,
|
|
startTLS bool,
|
|
bindDN string,
|
|
bindPassword string,
|
|
baseDN string,
|
|
attributes []string,
|
|
objectClasses []string,
|
|
userFilters []string,
|
|
username string,
|
|
password string,
|
|
timeout time.Duration,
|
|
) (*ldap.Entry, error) {
|
|
conn, err := getConnection(server, startTLS, timeout)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer conn.Close()
|
|
|
|
if err := conn.Bind(bindDN, bindPassword); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return trySearchAndUserBind(
|
|
conn,
|
|
baseDN,
|
|
attributes,
|
|
objectClasses,
|
|
userFilters,
|
|
username,
|
|
password,
|
|
timeout,
|
|
)
|
|
}
|
|
|
|
func getConnection(
|
|
server string,
|
|
startTLS bool,
|
|
timeout time.Duration,
|
|
) (*ldap.Conn, error) {
|
|
if timeout == 0 {
|
|
timeout = ldap.DefaultTimeout
|
|
}
|
|
|
|
conn, err := ldap.DialURL(server, ldap.DialWithDialer(&net.Dialer{Timeout: timeout}))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
u, err := url.Parse(server)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if u.Scheme == "ldaps" && startTLS {
|
|
err = conn.StartTLS(&tls.Config{ServerName: u.Host})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return conn, nil
|
|
}
|
|
|
|
func trySearchAndUserBind(
|
|
conn *ldap.Conn,
|
|
baseDN string,
|
|
attributes []string,
|
|
objectClasses []string,
|
|
userFilters []string,
|
|
username string,
|
|
password string,
|
|
timeout time.Duration,
|
|
) (*ldap.Entry, error) {
|
|
searchQuery := queriesAndToSearchQuery(
|
|
objectClassesToSearchQuery(objectClasses),
|
|
queriesOrToSearchQuery(
|
|
userFiltersToSearchQuery(userFilters, username),
|
|
),
|
|
)
|
|
|
|
// Search for user with the unique attribute for the userDN
|
|
searchRequest := ldap.NewSearchRequest(
|
|
baseDN,
|
|
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, int(timeout.Seconds()), false,
|
|
searchQuery,
|
|
attributes,
|
|
nil,
|
|
)
|
|
|
|
sr, err := conn.Search(searchRequest)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(sr.Entries) != 1 {
|
|
return nil, ErrNoSingleUser
|
|
}
|
|
|
|
user := sr.Entries[0]
|
|
// Bind as the user to verify their password
|
|
if err = conn.Bind(user.DN, password); err != nil {
|
|
return nil, ErrFailedLogin
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
// queriesAndToSearchQuery joins LDAP filter strings into a conjunction
// "(&...)". With no arguments it returns "", and a single argument is
// returned unchanged (no wrapper).
func queriesAndToSearchQuery(queries ...string) string {
	switch len(queries) {
	case 0:
		return ""
	case 1:
		return queries[0]
	}
	combined := "(&"
	for i := range queries {
		combined += queries[i]
	}
	return combined + ")"
}
|
|
|
|
// queriesOrToSearchQuery joins LDAP filter strings into a disjunction
// "(|...)". With no arguments it returns "", and a single argument is
// returned unchanged (no wrapper).
func queriesOrToSearchQuery(queries ...string) string {
	switch len(queries) {
	case 0:
		return ""
	case 1:
		return queries[0]
	}
	combined := "(|"
	for i := range queries {
		combined += queries[i]
	}
	return combined + ")"
}
|
|
|
|
// objectClassesToSearchQuery renders one "(objectClass=<class>)" term per
// class and concatenates them without an AND/OR wrapper. Class names are
// inserted verbatim (not escaped). An empty list yields "".
func objectClassesToSearchQuery(classes []string) string {
	var query string
	for i := range classes {
		query = query + "(objectClass=" + classes[i] + ")"
	}
	return query
}
|
|
|
|
func userFiltersToSearchQuery(filters []string, username string) string {
|
|
searchQuery := ""
|
|
for _, filter := range filters {
|
|
searchQuery += "(" + filter + "=" + ldap.EscapeFilter(username) + ")"
|
|
}
|
|
return searchQuery
|
|
}
|
|
|
|
func mapLDAPEntryToUser(
|
|
user *ldap.Entry,
|
|
idAttribute,
|
|
firstNameAttribute,
|
|
lastNameAttribute,
|
|
displayNameAttribute,
|
|
nickNameAttribute,
|
|
preferredUsernameAttribute,
|
|
emailAttribute,
|
|
emailVerifiedAttribute,
|
|
phoneAttribute,
|
|
phoneVerifiedAttribute,
|
|
preferredLanguageAttribute,
|
|
avatarURLAttribute,
|
|
profileAttribute string,
|
|
) (_ *User, err error) {
|
|
var emailVerified bool
|
|
if v := user.GetAttributeValue(emailVerifiedAttribute); v != "" {
|
|
emailVerified, err = strconv.ParseBool(v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
var phoneVerified bool
|
|
if v := user.GetAttributeValue(phoneVerifiedAttribute); v != "" {
|
|
phoneVerified, err = strconv.ParseBool(v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return NewUser(
|
|
user.GetAttributeValue(idAttribute),
|
|
user.GetAttributeValue(firstNameAttribute),
|
|
user.GetAttributeValue(lastNameAttribute),
|
|
user.GetAttributeValue(displayNameAttribute),
|
|
user.GetAttributeValue(nickNameAttribute),
|
|
user.GetAttributeValue(preferredUsernameAttribute),
|
|
domain.EmailAddress(user.GetAttributeValue(emailAttribute)),
|
|
emailVerified,
|
|
domain.PhoneNumber(user.GetAttributeValue(phoneAttribute)),
|
|
phoneVerified,
|
|
language.Make(user.GetAttributeValue(preferredLanguageAttribute)),
|
|
user.GetAttributeValue(avatarURLAttribute),
|
|
user.GetAttributeValue(profileAttribute),
|
|
), nil
|
|
}
|