mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 12:07:37 +00:00
feat: ldap provider login (#5448)
Add the logic to configure and use LDAP provider as an external IDP with a dedicated login GUI.
This commit is contained in:
@@ -2,6 +2,7 @@ package ldap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
@@ -12,16 +13,18 @@ var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider is the [idp.Provider] implementation for a generic LDAP provider
|
||||
type Provider struct {
|
||||
name string
|
||||
host string
|
||||
port string
|
||||
tls bool
|
||||
baseDN string
|
||||
userObjectClass string
|
||||
userUniqueAttribute string
|
||||
admin string
|
||||
password string
|
||||
loginUrl string
|
||||
name string
|
||||
servers []string
|
||||
startTLS bool
|
||||
baseDN string
|
||||
bindDN string
|
||||
bindPassword string
|
||||
userBase string
|
||||
userObjectClasses []string
|
||||
userFilters []string
|
||||
timeout time.Duration
|
||||
|
||||
loginUrl string
|
||||
|
||||
isLinkingAllowed bool
|
||||
isCreationAllowed bool
|
||||
@@ -74,17 +77,10 @@ func WithAutoUpdate() ProviderOpts {
|
||||
}
|
||||
}
|
||||
|
||||
// WithCustomPort configures a custom port used for the communication instead of :389 as per default
|
||||
func WithCustomPort(port string) ProviderOpts {
|
||||
// WithoutStartTLS configures to communication insecure with the LDAP server without startTLS
|
||||
func WithoutStartTLS() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.port = port
|
||||
}
|
||||
}
|
||||
|
||||
// Insecure configures to communication insecure with the LDAP server without TLS
|
||||
func Insecure() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.tls = false
|
||||
p.startTLS = false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,27 +177,29 @@ func WithProfileAttribute(name string) ProviderOpts {
|
||||
|
||||
func New(
|
||||
name string,
|
||||
host string,
|
||||
servers []string,
|
||||
baseDN string,
|
||||
userObjectClass string,
|
||||
userUniqueAttribute string,
|
||||
admin string,
|
||||
password string,
|
||||
bindDN string,
|
||||
bindPassword string,
|
||||
userBase string,
|
||||
userObjectClasses []string,
|
||||
userFilters []string,
|
||||
timeout time.Duration,
|
||||
loginUrl string,
|
||||
options ...ProviderOpts,
|
||||
) *Provider {
|
||||
provider := &Provider{
|
||||
name: name,
|
||||
host: host,
|
||||
port: DefaultPort,
|
||||
tls: true,
|
||||
baseDN: baseDN,
|
||||
userObjectClass: userObjectClass,
|
||||
userUniqueAttribute: userUniqueAttribute,
|
||||
admin: admin,
|
||||
password: password,
|
||||
loginUrl: loginUrl,
|
||||
idAttribute: userUniqueAttribute,
|
||||
name: name,
|
||||
servers: servers,
|
||||
startTLS: true,
|
||||
baseDN: baseDN,
|
||||
bindDN: bindDN,
|
||||
bindPassword: bindPassword,
|
||||
userBase: userBase,
|
||||
userObjectClasses: userObjectClasses,
|
||||
userFilters: userFilters,
|
||||
timeout: timeout,
|
||||
loginUrl: loginUrl,
|
||||
}
|
||||
for _, option := range options {
|
||||
option(provider)
|
||||
@@ -216,7 +214,7 @@ func (p *Provider) Name() string {
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
|
||||
return &Session{
|
||||
Provider: p,
|
||||
loginUrl: p.loginUrl + "?state=" + state,
|
||||
loginUrl: p.loginUrl + state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -235,3 +233,47 @@ func (p *Provider) IsAutoCreation() bool {
|
||||
func (p *Provider) IsAutoUpdate() bool {
|
||||
return p.isAutoUpdate
|
||||
}
|
||||
|
||||
func (p *Provider) getNecessaryAttributes() []string {
|
||||
attributes := []string{p.userBase}
|
||||
if p.idAttribute != "" {
|
||||
attributes = append(attributes, p.idAttribute)
|
||||
}
|
||||
if p.firstNameAttribute != "" {
|
||||
attributes = append(attributes, p.firstNameAttribute)
|
||||
}
|
||||
if p.lastNameAttribute != "" {
|
||||
attributes = append(attributes, p.lastNameAttribute)
|
||||
}
|
||||
if p.displayNameAttribute != "" {
|
||||
attributes = append(attributes, p.displayNameAttribute)
|
||||
}
|
||||
if p.nickNameAttribute != "" {
|
||||
attributes = append(attributes, p.nickNameAttribute)
|
||||
}
|
||||
if p.preferredUsernameAttribute != "" {
|
||||
attributes = append(attributes, p.preferredUsernameAttribute)
|
||||
}
|
||||
if p.emailAttribute != "" {
|
||||
attributes = append(attributes, p.emailAttribute)
|
||||
}
|
||||
if p.emailVerifiedAttribute != "" {
|
||||
attributes = append(attributes, p.emailVerifiedAttribute)
|
||||
}
|
||||
if p.phoneAttribute != "" {
|
||||
attributes = append(attributes, p.phoneAttribute)
|
||||
}
|
||||
if p.phoneVerifiedAttribute != "" {
|
||||
attributes = append(attributes, p.phoneVerifiedAttribute)
|
||||
}
|
||||
if p.preferredLanguageAttribute != "" {
|
||||
attributes = append(attributes, p.preferredLanguageAttribute)
|
||||
}
|
||||
if p.avatarURLAttribute != "" {
|
||||
attributes = append(attributes, p.avatarURLAttribute)
|
||||
}
|
||||
if p.profileAttribute != "" {
|
||||
attributes = append(attributes, p.profileAttribute)
|
||||
}
|
||||
return attributes
|
||||
}
|
||||
|
@@ -2,26 +2,28 @@ package ldap
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProvider_Options(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
host string
|
||||
baseDN string
|
||||
userObjectClass string
|
||||
userUniqueAttribute string
|
||||
admin string
|
||||
password string
|
||||
loginUrl string
|
||||
opts []ProviderOpts
|
||||
name string
|
||||
servers []string
|
||||
baseDN string
|
||||
bindDN string
|
||||
bindPassword string
|
||||
userBase string
|
||||
userObjectClasses []string
|
||||
userFilters []string
|
||||
timeout time.Duration
|
||||
loginUrl string
|
||||
opts []ProviderOpts
|
||||
}
|
||||
type want struct {
|
||||
name string
|
||||
port string
|
||||
tls bool
|
||||
startTls bool
|
||||
linkingAllowed bool
|
||||
creationAllowed bool
|
||||
autoCreation bool
|
||||
@@ -48,39 +50,43 @@ func TestProvider_Options(t *testing.T) {
|
||||
{
|
||||
name: "default",
|
||||
fields: fields{
|
||||
name: "ldap",
|
||||
host: "host",
|
||||
baseDN: "base",
|
||||
userObjectClass: "class",
|
||||
userUniqueAttribute: "attr",
|
||||
admin: "admin",
|
||||
password: "password",
|
||||
loginUrl: "url",
|
||||
opts: nil,
|
||||
name: "ldap",
|
||||
servers: []string{"server"},
|
||||
baseDN: "base",
|
||||
bindDN: "binddn",
|
||||
bindPassword: "password",
|
||||
userBase: "user",
|
||||
userObjectClasses: []string{"object"},
|
||||
userFilters: []string{"filter"},
|
||||
timeout: 30 * time.Second,
|
||||
loginUrl: "url",
|
||||
opts: nil,
|
||||
},
|
||||
want: want{
|
||||
name: "ldap",
|
||||
port: DefaultPort,
|
||||
tls: true,
|
||||
startTls: true,
|
||||
linkingAllowed: false,
|
||||
creationAllowed: false,
|
||||
autoCreation: false,
|
||||
autoUpdate: false,
|
||||
idAttribute: "attr",
|
||||
idAttribute: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all true",
|
||||
fields: fields{
|
||||
name: "ldap",
|
||||
host: "host",
|
||||
baseDN: "base",
|
||||
userObjectClass: "class",
|
||||
userUniqueAttribute: "attr",
|
||||
admin: "admin",
|
||||
password: "password",
|
||||
loginUrl: "url",
|
||||
name: "ldap",
|
||||
servers: []string{"server"},
|
||||
baseDN: "base",
|
||||
bindDN: "binddn",
|
||||
bindPassword: "password",
|
||||
userBase: "user",
|
||||
userObjectClasses: []string{"object"},
|
||||
userFilters: []string{"filter"},
|
||||
timeout: 30 * time.Second,
|
||||
loginUrl: "url",
|
||||
opts: []ProviderOpts{
|
||||
WithoutStartTLS(),
|
||||
WithLinkingAllowed(),
|
||||
WithCreationAllowed(),
|
||||
WithAutoCreation(),
|
||||
@@ -89,28 +95,28 @@ func TestProvider_Options(t *testing.T) {
|
||||
},
|
||||
want: want{
|
||||
name: "ldap",
|
||||
port: DefaultPort,
|
||||
tls: true,
|
||||
startTls: false,
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
autoUpdate: true,
|
||||
idAttribute: "attr",
|
||||
idAttribute: "",
|
||||
},
|
||||
}, {
|
||||
name: "all true, attributes set",
|
||||
fields: fields{
|
||||
name: "ldap",
|
||||
host: "host",
|
||||
baseDN: "base",
|
||||
userObjectClass: "class",
|
||||
userUniqueAttribute: "attr",
|
||||
admin: "admin",
|
||||
password: "password",
|
||||
loginUrl: "url",
|
||||
name: "ldap",
|
||||
servers: []string{"server"},
|
||||
baseDN: "base",
|
||||
bindDN: "binddn",
|
||||
bindPassword: "password",
|
||||
userBase: "user",
|
||||
userObjectClasses: []string{"object"},
|
||||
userFilters: []string{"filter"},
|
||||
timeout: 30 * time.Second,
|
||||
loginUrl: "url",
|
||||
opts: []ProviderOpts{
|
||||
Insecure(),
|
||||
WithCustomPort("port"),
|
||||
WithoutStartTLS(),
|
||||
WithLinkingAllowed(),
|
||||
WithCreationAllowed(),
|
||||
WithAutoCreation(),
|
||||
@@ -132,8 +138,7 @@ func TestProvider_Options(t *testing.T) {
|
||||
},
|
||||
want: want{
|
||||
name: "ldap",
|
||||
port: "port",
|
||||
tls: false,
|
||||
startTls: false,
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
@@ -157,11 +162,22 @@ func TestProvider_Options(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
provider := New(tt.fields.name, tt.fields.host, tt.fields.baseDN, tt.fields.userObjectClass, tt.fields.userUniqueAttribute, tt.fields.admin, tt.fields.password, tt.fields.loginUrl, tt.fields.opts...)
|
||||
provider := New(
|
||||
tt.fields.name,
|
||||
tt.fields.servers,
|
||||
tt.fields.baseDN,
|
||||
tt.fields.bindDN,
|
||||
tt.fields.bindPassword,
|
||||
tt.fields.userBase,
|
||||
tt.fields.userObjectClasses,
|
||||
tt.fields.userFilters,
|
||||
tt.fields.timeout,
|
||||
tt.fields.loginUrl,
|
||||
tt.fields.opts...,
|
||||
)
|
||||
|
||||
a.Equal(tt.want.name, provider.Name())
|
||||
a.Equal(tt.want.port, provider.port)
|
||||
a.Equal(tt.want.tls, provider.tls)
|
||||
a.Equal(tt.want.startTls, provider.startTLS)
|
||||
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
|
||||
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
|
||||
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
|
||||
|
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"golang.org/x/text/language"
|
||||
@@ -15,49 +17,154 @@ import (
|
||||
)
|
||||
|
||||
var ErrNoSingleUser = errors.New("user does not exist or too many entries returned")
|
||||
var ErrFailedLogin = errors.New("user failed to login")
|
||||
|
||||
var _ idp.Session = (*Session)(nil)
|
||||
|
||||
type Session struct {
|
||||
Provider *Provider
|
||||
loginUrl string
|
||||
user string
|
||||
password string
|
||||
User string
|
||||
Password string
|
||||
}
|
||||
|
||||
func (s *Session) GetAuthURL() string {
|
||||
return s.loginUrl
|
||||
}
|
||||
func (s *Session) FetchUser(_ context.Context) (idp.User, error) {
|
||||
l, err := ldap.DialURL("ldap://" + s.Provider.host + ":" + s.Provider.port)
|
||||
|
||||
func (s *Session) FetchUser(_ context.Context) (_ idp.User, err error) {
|
||||
var user *ldap.Entry
|
||||
for _, server := range s.Provider.servers {
|
||||
user, err = tryBind(server,
|
||||
s.Provider.startTLS,
|
||||
s.Provider.bindDN,
|
||||
s.Provider.bindPassword,
|
||||
s.Provider.baseDN,
|
||||
s.Provider.getNecessaryAttributes(),
|
||||
s.Provider.userObjectClasses,
|
||||
s.Provider.userFilters,
|
||||
s.User,
|
||||
s.Password, s.Provider.timeout)
|
||||
// If there were invalid credentials or multiple users with the credentials cancel process
|
||||
if err != nil && (errors.Is(err, ErrFailedLogin) || errors.Is(err, ErrNoSingleUser)) {
|
||||
return nil, err
|
||||
}
|
||||
// If a user bind was successful and user is filled continue with login, otherwise try next server
|
||||
if err == nil && user != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
if s.Provider.tls {
|
||||
err = l.StartTLS(&tls.Config{ServerName: s.Provider.host})
|
||||
return mapLDAPEntryToUser(
|
||||
user,
|
||||
s.Provider.idAttribute,
|
||||
s.Provider.firstNameAttribute,
|
||||
s.Provider.lastNameAttribute,
|
||||
s.Provider.displayNameAttribute,
|
||||
s.Provider.nickNameAttribute,
|
||||
s.Provider.preferredUsernameAttribute,
|
||||
s.Provider.emailAttribute,
|
||||
s.Provider.emailVerifiedAttribute,
|
||||
s.Provider.phoneAttribute,
|
||||
s.Provider.phoneVerifiedAttribute,
|
||||
s.Provider.preferredLanguageAttribute,
|
||||
s.Provider.avatarURLAttribute,
|
||||
s.Provider.profileAttribute,
|
||||
)
|
||||
}
|
||||
|
||||
func tryBind(
|
||||
server string,
|
||||
startTLS bool,
|
||||
bindDN string,
|
||||
bindPassword string,
|
||||
baseDN string,
|
||||
attributes []string,
|
||||
objectClasses []string,
|
||||
userFilters []string,
|
||||
username string,
|
||||
password string,
|
||||
timeout time.Duration,
|
||||
) (*ldap.Entry, error) {
|
||||
conn, err := getConnection(server, startTLS, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.Bind(bindDN, bindPassword); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return trySearchAndUserBind(
|
||||
conn,
|
||||
baseDN,
|
||||
attributes,
|
||||
objectClasses,
|
||||
userFilters,
|
||||
username,
|
||||
password,
|
||||
timeout,
|
||||
)
|
||||
}
|
||||
|
||||
func getConnection(
|
||||
server string,
|
||||
startTLS bool,
|
||||
timeout time.Duration,
|
||||
) (*ldap.Conn, error) {
|
||||
if timeout == 0 {
|
||||
timeout = ldap.DefaultTimeout
|
||||
}
|
||||
|
||||
conn, err := ldap.DialURL(server, ldap.DialWithDialer(&net.Dialer{Timeout: timeout}))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, err := url.Parse(server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if u.Scheme == "ldaps" && startTLS {
|
||||
err = conn.StartTLS(&tls.Config{ServerName: u.Host})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Bind as the admin to search for user
|
||||
err = l.Bind("cn="+s.Provider.admin+","+s.Provider.baseDN, s.Provider.password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func trySearchAndUserBind(
|
||||
conn *ldap.Conn,
|
||||
baseDN string,
|
||||
attributes []string,
|
||||
objectClasses []string,
|
||||
userFilters []string,
|
||||
username string,
|
||||
password string,
|
||||
timeout time.Duration,
|
||||
) (*ldap.Entry, error) {
|
||||
searchQuery := queriesAndToSearchQuery(
|
||||
objectClassesToSearchQuery(objectClasses),
|
||||
queriesOrToSearchQuery(
|
||||
userFiltersToSearchQuery(userFilters, username),
|
||||
),
|
||||
)
|
||||
|
||||
// Search for user with the unique attribute for the userDN
|
||||
searchRequest := ldap.NewSearchRequest(
|
||||
s.Provider.baseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
|
||||
fmt.Sprintf("(&(objectClass="+s.Provider.userObjectClass+")("+s.Provider.userUniqueAttribute+"=%s))", ldap.EscapeFilter(s.user)),
|
||||
[]string{"dn"},
|
||||
baseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, int(timeout.Seconds()), false,
|
||||
searchQuery,
|
||||
attributes,
|
||||
nil,
|
||||
)
|
||||
|
||||
sr, err := l.Search(searchRequest)
|
||||
sr, err := conn.Search(searchRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -67,33 +174,100 @@ func (s *Session) FetchUser(_ context.Context) (idp.User, error) {
|
||||
|
||||
user := sr.Entries[0]
|
||||
// Bind as the user to verify their password
|
||||
err = l.Bind(user.DN, s.password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err = conn.Bind(user.DN, password); err != nil {
|
||||
return nil, ErrFailedLogin
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
emailVerified, err := strconv.ParseBool(user.GetAttributeValue(s.Provider.emailVerifiedAttribute))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func queriesAndToSearchQuery(queries ...string) string {
|
||||
if len(queries) == 0 {
|
||||
return ""
|
||||
}
|
||||
phoneVerified, err := strconv.ParseBool(user.GetAttributeValue(s.Provider.phoneVerifiedAttribute))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if len(queries) == 1 {
|
||||
return queries[0]
|
||||
}
|
||||
joinQueries := "(&"
|
||||
for _, s := range queries {
|
||||
joinQueries += s
|
||||
}
|
||||
return joinQueries + ")"
|
||||
}
|
||||
|
||||
func queriesOrToSearchQuery(queries ...string) string {
|
||||
if len(queries) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(queries) == 1 {
|
||||
return queries[0]
|
||||
}
|
||||
joinQueries := "(|"
|
||||
for _, s := range queries {
|
||||
joinQueries += s
|
||||
}
|
||||
return joinQueries + ")"
|
||||
}
|
||||
|
||||
func objectClassesToSearchQuery(classes []string) string {
|
||||
searchQuery := ""
|
||||
for _, class := range classes {
|
||||
searchQuery += "(objectClass=" + class + ")"
|
||||
}
|
||||
return searchQuery
|
||||
}
|
||||
|
||||
func userFiltersToSearchQuery(filters []string, username string) string {
|
||||
searchQuery := ""
|
||||
for _, filter := range filters {
|
||||
searchQuery += "(" + filter + "=" + ldap.EscapeFilter(username) + ")"
|
||||
}
|
||||
return searchQuery
|
||||
}
|
||||
|
||||
func mapLDAPEntryToUser(
|
||||
user *ldap.Entry,
|
||||
idAttribute,
|
||||
firstNameAttribute,
|
||||
lastNameAttribute,
|
||||
displayNameAttribute,
|
||||
nickNameAttribute,
|
||||
preferredUsernameAttribute,
|
||||
emailAttribute,
|
||||
emailVerifiedAttribute,
|
||||
phoneAttribute,
|
||||
phoneVerifiedAttribute,
|
||||
preferredLanguageAttribute,
|
||||
avatarURLAttribute,
|
||||
profileAttribute string,
|
||||
) (_ *User, err error) {
|
||||
var emailVerified bool
|
||||
if v := user.GetAttributeValue(emailVerifiedAttribute); v != "" {
|
||||
emailVerified, err = strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var phoneVerified bool
|
||||
if v := user.GetAttributeValue(phoneVerifiedAttribute); v != "" {
|
||||
phoneVerified, err = strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return NewUser(
|
||||
user.GetAttributeValue(s.Provider.idAttribute),
|
||||
user.GetAttributeValue(s.Provider.firstNameAttribute),
|
||||
user.GetAttributeValue(s.Provider.lastNameAttribute),
|
||||
user.GetAttributeValue(s.Provider.displayNameAttribute),
|
||||
user.GetAttributeValue(s.Provider.nickNameAttribute),
|
||||
user.GetAttributeValue(s.Provider.preferredUsernameAttribute),
|
||||
domain.EmailAddress(user.GetAttributeValue(s.Provider.emailAttribute)),
|
||||
user.GetAttributeValue(idAttribute),
|
||||
user.GetAttributeValue(firstNameAttribute),
|
||||
user.GetAttributeValue(lastNameAttribute),
|
||||
user.GetAttributeValue(displayNameAttribute),
|
||||
user.GetAttributeValue(nickNameAttribute),
|
||||
user.GetAttributeValue(preferredUsernameAttribute),
|
||||
domain.EmailAddress(user.GetAttributeValue(emailAttribute)),
|
||||
emailVerified,
|
||||
domain.PhoneNumber(user.GetAttributeValue(s.Provider.phoneAttribute)),
|
||||
domain.PhoneNumber(user.GetAttributeValue(phoneAttribute)),
|
||||
phoneVerified,
|
||||
language.Make(user.GetAttributeValue(s.Provider.preferredLanguageAttribute)),
|
||||
user.GetAttributeValue(s.Provider.avatarURLAttribute),
|
||||
user.GetAttributeValue(s.Provider.profileAttribute),
|
||||
language.Make(user.GetAttributeValue(preferredLanguageAttribute)),
|
||||
user.GetAttributeValue(avatarURLAttribute),
|
||||
user.GetAttributeValue(profileAttribute),
|
||||
), nil
|
||||
}
|
||||
|
400
internal/idp/providers/ldap/session_test.go
Normal file
400
internal/idp/providers/ldap/session_test.go
Normal file
@@ -0,0 +1,400 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
func TestProvider_objectClassesToSearchQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
fields: []string{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "one",
|
||||
fields: []string{"test"},
|
||||
want: "(objectClass=test)",
|
||||
},
|
||||
{
|
||||
name: "three",
|
||||
fields: []string{"test1", "test2", "test3"},
|
||||
want: "(objectClass=test1)(objectClass=test2)(objectClass=test3)",
|
||||
},
|
||||
{
|
||||
name: "five",
|
||||
fields: []string{"test1", "test2", "test3", "test4", "test5"},
|
||||
want: "(objectClass=test1)(objectClass=test2)(objectClass=test3)(objectClass=test4)(objectClass=test5)",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
a.Equal(tt.want, objectClassesToSearchQuery(tt.fields))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_userFiltersToSearchQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields []string
|
||||
username string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
fields: []string{},
|
||||
username: "user",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "one",
|
||||
fields: []string{"test"},
|
||||
username: "user",
|
||||
want: "(test=user)",
|
||||
},
|
||||
{
|
||||
name: "three",
|
||||
fields: []string{"test1", "test2", "test3"},
|
||||
username: "user",
|
||||
want: "(test1=user)(test2=user)(test3=user)",
|
||||
},
|
||||
{
|
||||
name: "five",
|
||||
fields: []string{"test1", "test2", "test3", "test4", "test5"},
|
||||
username: "user",
|
||||
want: "(test1=user)(test2=user)(test3=user)(test4=user)(test5=user)",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
a.Equal(tt.want, userFiltersToSearchQuery(tt.fields, tt.username))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_queriesAndToSearchQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
fields: []string{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "one",
|
||||
fields: []string{"(test)"},
|
||||
want: "(test)",
|
||||
},
|
||||
{
|
||||
name: "three",
|
||||
fields: []string{"(test1)", "(test2)", "(test3)"},
|
||||
want: "(&(test1)(test2)(test3))",
|
||||
},
|
||||
{
|
||||
name: "five",
|
||||
fields: []string{"(test1)", "(test2)", "(test3)", "(test4)", "(test5)"},
|
||||
want: "(&(test1)(test2)(test3)(test4)(test5))",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
a.Equal(tt.want, queriesAndToSearchQuery(tt.fields...))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_queriesOrToSearchQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
fields: []string{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "one",
|
||||
fields: []string{"(test)"},
|
||||
want: "(test)",
|
||||
},
|
||||
{
|
||||
name: "three",
|
||||
fields: []string{"(test1)", "(test2)", "(test3)"},
|
||||
want: "(|(test1)(test2)(test3))",
|
||||
},
|
||||
{
|
||||
name: "five",
|
||||
fields: []string{"(test1)", "(test2)", "(test3)", "(test4)", "(test5)"},
|
||||
want: "(|(test1)(test2)(test3)(test4)(test5))",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
a.Equal(tt.want, queriesOrToSearchQuery(tt.fields...))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_mapLDAPEntryToUser(t *testing.T) {
|
||||
type fields struct {
|
||||
user *ldap.Entry
|
||||
idAttribute string
|
||||
firstNameAttribute string
|
||||
lastNameAttribute string
|
||||
displayNameAttribute string
|
||||
nickNameAttribute string
|
||||
preferredUsernameAttribute string
|
||||
emailAttribute string
|
||||
emailVerifiedAttribute string
|
||||
phoneAttribute string
|
||||
phoneVerifiedAttribute string
|
||||
preferredLanguageAttribute string
|
||||
avatarURLAttribute string
|
||||
profileAttribute string
|
||||
}
|
||||
type want struct {
|
||||
user *User
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
fields: fields{
|
||||
user: &ldap.Entry{
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
{Name: "id", Values: []string{"id"}},
|
||||
{Name: "first", Values: []string{"first"}},
|
||||
{Name: "last", Values: []string{"last"}},
|
||||
{Name: "display", Values: []string{"display"}},
|
||||
{Name: "nick", Values: []string{"nick"}},
|
||||
{Name: "preferred", Values: []string{"preferred"}},
|
||||
{Name: "email", Values: []string{"email"}},
|
||||
{Name: "emailVerified", Values: []string{"false"}},
|
||||
{Name: "phone", Values: []string{"phone"}},
|
||||
{Name: "phoneVerified", Values: []string{"false"}},
|
||||
{Name: "lang", Values: []string{"und"}},
|
||||
{Name: "avatar", Values: []string{"avatar"}},
|
||||
{Name: "profile", Values: []string{"profile"}},
|
||||
},
|
||||
},
|
||||
idAttribute: "",
|
||||
firstNameAttribute: "",
|
||||
lastNameAttribute: "",
|
||||
displayNameAttribute: "",
|
||||
nickNameAttribute: "",
|
||||
preferredUsernameAttribute: "",
|
||||
emailAttribute: "",
|
||||
emailVerifiedAttribute: "",
|
||||
phoneAttribute: "",
|
||||
phoneVerifiedAttribute: "",
|
||||
preferredLanguageAttribute: "",
|
||||
avatarURLAttribute: "",
|
||||
profileAttribute: "",
|
||||
},
|
||||
want: want{
|
||||
user: &User{
|
||||
id: "",
|
||||
firstName: "",
|
||||
lastName: "",
|
||||
displayName: "",
|
||||
nickName: "",
|
||||
preferredUsername: "",
|
||||
email: "",
|
||||
emailVerified: false,
|
||||
phone: "",
|
||||
phoneVerified: false,
|
||||
preferredLanguage: language.Tag{},
|
||||
avatarURL: "",
|
||||
profile: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "failed parse emailVerified",
|
||||
fields: fields{
|
||||
user: &ldap.Entry{
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
{Name: "id", Values: []string{"id"}},
|
||||
{Name: "first", Values: []string{"first"}},
|
||||
{Name: "last", Values: []string{"last"}},
|
||||
{Name: "display", Values: []string{"display"}},
|
||||
{Name: "nick", Values: []string{"nick"}},
|
||||
{Name: "preferred", Values: []string{"preferred"}},
|
||||
{Name: "email", Values: []string{"email"}},
|
||||
{Name: "emailVerified", Values: []string{"failure"}},
|
||||
{Name: "phone", Values: []string{"phone"}},
|
||||
{Name: "phoneVerified", Values: []string{"false"}},
|
||||
{Name: "lang", Values: []string{"und"}},
|
||||
{Name: "avatar", Values: []string{"avatar"}},
|
||||
{Name: "profile", Values: []string{"profile"}},
|
||||
},
|
||||
},
|
||||
idAttribute: "id",
|
||||
firstNameAttribute: "first",
|
||||
lastNameAttribute: "last",
|
||||
displayNameAttribute: "display",
|
||||
nickNameAttribute: "nick",
|
||||
preferredUsernameAttribute: "preferred",
|
||||
emailAttribute: "email",
|
||||
emailVerifiedAttribute: "emailVerified",
|
||||
phoneAttribute: "phone",
|
||||
phoneVerifiedAttribute: "phoneVerified",
|
||||
preferredLanguageAttribute: "lang",
|
||||
avatarURLAttribute: "avatar",
|
||||
profileAttribute: "profile",
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return err != nil
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "failed parse phoneVerified",
|
||||
fields: fields{
|
||||
user: &ldap.Entry{
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
{Name: "id", Values: []string{"id"}},
|
||||
{Name: "first", Values: []string{"first"}},
|
||||
{Name: "last", Values: []string{"last"}},
|
||||
{Name: "display", Values: []string{"display"}},
|
||||
{Name: "nick", Values: []string{"nick"}},
|
||||
{Name: "preferred", Values: []string{"preferred"}},
|
||||
{Name: "email", Values: []string{"email"}},
|
||||
{Name: "emailVerified", Values: []string{"false"}},
|
||||
{Name: "phone", Values: []string{"phone"}},
|
||||
{Name: "phoneVerified", Values: []string{"failure"}},
|
||||
{Name: "lang", Values: []string{"und"}},
|
||||
{Name: "avatar", Values: []string{"avatar"}},
|
||||
{Name: "profile", Values: []string{"profile"}},
|
||||
},
|
||||
},
|
||||
idAttribute: "id",
|
||||
firstNameAttribute: "first",
|
||||
lastNameAttribute: "last",
|
||||
displayNameAttribute: "display",
|
||||
nickNameAttribute: "nick",
|
||||
preferredUsernameAttribute: "preferred",
|
||||
emailAttribute: "email",
|
||||
emailVerifiedAttribute: "emailVerified",
|
||||
phoneAttribute: "phone",
|
||||
phoneVerifiedAttribute: "phoneVerified",
|
||||
preferredLanguageAttribute: "lang",
|
||||
avatarURLAttribute: "avatar",
|
||||
profileAttribute: "profile",
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return err != nil
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full user",
|
||||
fields: fields{
|
||||
user: &ldap.Entry{
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
{Name: "id", Values: []string{"id"}},
|
||||
{Name: "first", Values: []string{"first"}},
|
||||
{Name: "last", Values: []string{"last"}},
|
||||
{Name: "display", Values: []string{"display"}},
|
||||
{Name: "nick", Values: []string{"nick"}},
|
||||
{Name: "preferred", Values: []string{"preferred"}},
|
||||
{Name: "email", Values: []string{"email"}},
|
||||
{Name: "emailVerified", Values: []string{"false"}},
|
||||
{Name: "phone", Values: []string{"phone"}},
|
||||
{Name: "phoneVerified", Values: []string{"false"}},
|
||||
{Name: "lang", Values: []string{"und"}},
|
||||
{Name: "avatar", Values: []string{"avatar"}},
|
||||
{Name: "profile", Values: []string{"profile"}},
|
||||
},
|
||||
},
|
||||
idAttribute: "id",
|
||||
firstNameAttribute: "first",
|
||||
lastNameAttribute: "last",
|
||||
displayNameAttribute: "display",
|
||||
nickNameAttribute: "nick",
|
||||
preferredUsernameAttribute: "preferred",
|
||||
emailAttribute: "email",
|
||||
emailVerifiedAttribute: "emailVerified",
|
||||
phoneAttribute: "phone",
|
||||
phoneVerifiedAttribute: "phoneVerified",
|
||||
preferredLanguageAttribute: "lang",
|
||||
avatarURLAttribute: "avatar",
|
||||
profileAttribute: "profile",
|
||||
},
|
||||
want: want{
|
||||
user: &User{
|
||||
id: "id",
|
||||
firstName: "first",
|
||||
lastName: "last",
|
||||
displayName: "display",
|
||||
nickName: "nick",
|
||||
preferredUsername: "preferred",
|
||||
email: "email",
|
||||
emailVerified: false,
|
||||
phone: "phone",
|
||||
phoneVerified: false,
|
||||
preferredLanguage: language.Make("und"),
|
||||
avatarURL: "avatar",
|
||||
profile: "profile",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := mapLDAPEntryToUser(
|
||||
tt.fields.user,
|
||||
tt.fields.idAttribute,
|
||||
tt.fields.firstNameAttribute,
|
||||
tt.fields.lastNameAttribute,
|
||||
tt.fields.displayNameAttribute,
|
||||
tt.fields.nickNameAttribute,
|
||||
tt.fields.preferredUsernameAttribute,
|
||||
tt.fields.emailAttribute,
|
||||
tt.fields.emailVerifiedAttribute,
|
||||
tt.fields.phoneAttribute,
|
||||
tt.fields.phoneVerifiedAttribute,
|
||||
tt.fields.preferredLanguageAttribute,
|
||||
tt.fields.avatarURLAttribute,
|
||||
tt.fields.profileAttribute,
|
||||
)
|
||||
if tt.want.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.want.err != nil && !tt.want.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
assert.Equal(t, tt.want.user, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user