feat(ldap): adding root ca option to ldap config (#9292)

# Which Problems Are Solved

Adding ability to add a root CA to LDAP configs

# Additional Context

- Closes https://github.com/zitadel/zitadel/issues/7888

---------

Co-authored-by: Iraq Jaber <IraqJaber@gmail.com>
This commit is contained in:
Iraq
2025-02-18 10:06:50 +00:00
committed by GitHub
parent d7332d1ac4
commit 5bbb953ffb
27 changed files with 418 additions and 243 deletions

View File

@@ -23,6 +23,7 @@ type Provider struct {
userObjectClasses []string
userFilters []string
timeout time.Duration
rootCA []byte
loginUrl string
@@ -185,6 +186,7 @@ func New(
userObjectClasses []string,
userFilters []string,
timeout time.Duration,
rootCA []byte,
loginUrl string,
options ...ProviderOpts,
) *Provider {
@@ -199,6 +201,7 @@ func New(
userObjectClasses: userObjectClasses,
userFilters: userFilters,
timeout: timeout,
rootCA: rootCA,
loginUrl: loginUrl,
}
for _, option := range options {

View File

@@ -18,11 +18,13 @@ func TestProvider_Options(t *testing.T) {
userObjectClasses []string
userFilters []string
timeout time.Duration
rootCA []byte
loginUrl string
opts []ProviderOpts
}
type want struct {
name string
rootCA []byte
startTls bool
linkingAllowed bool
creationAllowed bool
@@ -114,6 +116,7 @@ func TestProvider_Options(t *testing.T) {
userObjectClasses: []string{"object"},
userFilters: []string{"filter"},
timeout: 30 * time.Second,
rootCA: []byte("certificate"),
loginUrl: "url",
opts: []ProviderOpts{
WithoutStartTLS(),
@@ -138,6 +141,7 @@ func TestProvider_Options(t *testing.T) {
},
want: want{
name: "ldap",
rootCA: []byte("certificate"),
startTls: false,
linkingAllowed: true,
creationAllowed: true,
@@ -172,11 +176,13 @@ func TestProvider_Options(t *testing.T) {
tt.fields.userObjectClasses,
tt.fields.userFilters,
tt.fields.timeout,
tt.fields.rootCA,
tt.fields.loginUrl,
tt.fields.opts...,
)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.rootCA, provider.rootCA)
a.Equal(tt.want.startTls, provider.startTLS)
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())

View File

@@ -3,6 +3,7 @@ package ldap
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"net"
@@ -21,6 +22,7 @@ import (
var ErrNoSingleUser = errors.New("user does not exist or too many entries returned")
var ErrFailedLogin = errors.New("user failed to login")
var ErrUnableToAppendRootCA = errors.New("unable to append rootCA")
var _ idp.Session = (*Session)(nil)
@@ -49,7 +51,9 @@ func (s *Session) FetchUser(_ context.Context) (_ idp.User, err error) {
s.Provider.userObjectClasses,
s.Provider.userFilters,
s.User,
s.Password, s.Provider.timeout)
s.Password,
s.Provider.timeout,
s.Provider.rootCA)
// If there were invalid credentials or multiple users with the credentials cancel process
if err != nil && (errors.Is(err, ErrFailedLogin) || errors.Is(err, ErrNoSingleUser)) {
return nil, err
@@ -94,8 +98,9 @@ func tryBind(
username string,
password string,
timeout time.Duration,
rootCA []byte,
) (*ldap.Entry, error) {
conn, err := getConnection(server, startTLS, timeout)
conn, err := getConnection(server, startTLS, timeout, rootCA)
if err != nil {
return nil, err
}
@@ -114,6 +119,7 @@ func tryBind(
username,
password,
timeout,
rootCA,
)
}
@@ -121,21 +127,37 @@ func getConnection(
server string,
startTLS bool,
timeout time.Duration,
rootCA []byte,
) (*ldap.Conn, error) {
if timeout == 0 {
timeout = ldap.DefaultTimeout
}
conn, err := ldap.DialURL(server, ldap.DialWithDialer(&net.Dialer{Timeout: timeout}))
if err != nil {
return nil, err
}
dialer := make([]ldap.DialOpt, 1, 2)
dialer[0] = ldap.DialWithDialer(&net.Dialer{Timeout: timeout})
u, err := url.Parse(server)
if err != nil {
return nil, err
}
if u.Scheme == "ldaps" && startTLS {
if u.Scheme == "ldaps" && len(rootCA) > 0 {
rootCAs := x509.NewCertPool()
if ok := rootCAs.AppendCertsFromPEM(rootCA); !ok {
return nil, ErrUnableToAppendRootCA
}
dialer = append(dialer, ldap.DialWithTLSConfig(&tls.Config{
RootCAs: rootCAs,
}))
}
conn, err := ldap.DialURL(server, dialer...)
if err != nil {
return nil, err
}
if u.Scheme == "ldap" && startTLS {
err = conn.StartTLS(&tls.Config{ServerName: u.Host})
if err != nil {
return nil, err
@@ -153,6 +175,7 @@ func trySearchAndUserBind(
username string,
password string,
timeout time.Duration,
rootCA []byte,
) (*ldap.Entry, error) {
searchQuery := queriesAndToSearchQuery(
objectClassesToSearchQuery(objectClasses),