feat(saml): allow setting nameid-format and alternative mapping for transient format (#7979)

# Which Problems Are Solved

ZITADEL currently always uses
`urn:oasis:names:tc:SAML:2.0:nameid-format:persistent` in SAML requests,
relying on the IdP to respect that flag and always return a peristent
nameid in order to be able to map the external user with an existing
user (idp link) in ZITADEL.
In case the IdP however returns a
`urn:oasis:names:tc:SAML:2.0:nameid-format:transient` (transient)
nameid, the attribute will differ between each request and it will not
be possible to match existing users.

# How the Problems Are Solved

This PR adds the following two options on SAML IdP:
- **nameIDFormat**: allows to set the nameid-format used in the SAML
Request
- **transientMappingAttributeName**: allows to set an attribute name,
which will be used instead of the nameid itself in case the returned
nameid-format is transient

# Additional Changes

To reduce impact on current installations, the `idp_templates6_saml`
table is altered with the two added columns by a setup job. New
installations will automatically get the table with the two columns
directly.
All idp unit tests are updated to use `expectEventstore` instead of the
deprecated `eventstoreExpect`.

# Additional Context

Closes #7483
Closes #7743

---------

Co-authored-by: peintnermax <max@caos.ch>
Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
Livio Spring
2024-05-23 07:04:07 +02:00
committed by GitHub
parent 12be21a3ff
commit e57a9b57c8
58 changed files with 1306 additions and 720 deletions

View File

@@ -2,6 +2,7 @@ package admin
import (
"github.com/crewjam/saml"
"github.com/muhlemmer/gu"
idp_grpc "github.com/zitadel/zitadel/internal/api/grpc/idp"
"github.com/zitadel/zitadel/internal/api/grpc/object"
@@ -469,24 +470,36 @@ func updateAppleProviderToCommand(req *admin_pb.UpdateAppleProviderRequest) comm
}
func addSAMLProviderToCommand(req *admin_pb.AddSAMLProviderRequest) command.SAMLProvider {
var nameIDFormat *domain.SAMLNameIDFormat
if req.NameIdFormat != nil {
nameIDFormat = gu.Ptr(idp_grpc.SAMLNameIDFormatToDomain(req.GetNameIdFormat()))
}
return command.SAMLProvider{
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
NameIDFormat: nameIDFormat,
TransientMappingAttributeName: req.GetTransientMappingAttributeName(),
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
}
}
func updateSAMLProviderToCommand(req *admin_pb.UpdateSAMLProviderRequest) command.SAMLProvider {
var nameIDFormat *domain.SAMLNameIDFormat
if req.NameIdFormat != nil {
nameIDFormat = gu.Ptr(idp_grpc.SAMLNameIDFormatToDomain(req.GetNameIdFormat()))
}
return command.SAMLProvider{
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
NameIDFormat: nameIDFormat,
TransientMappingAttributeName: req.GetTransientMappingAttributeName(),
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
}
}

View File

@@ -2,6 +2,7 @@ package idp
import (
"github.com/crewjam/saml"
"github.com/muhlemmer/gu"
"google.golang.org/protobuf/types/known/durationpb"
obj_grpc "github.com/zitadel/zitadel/internal/api/grpc/object"
@@ -339,6 +340,21 @@ func azureADTenantTypeToCommand(tenantType idp_pb.AzureADTenantType) azuread.Ten
}
}
func SAMLNameIDFormatToDomain(format idp_pb.SAMLNameIDFormat) domain.SAMLNameIDFormat {
switch format {
case idp_pb.SAMLNameIDFormat_SAML_NAME_ID_FORMAT_UNSPECIFIED:
return domain.SAMLNameIDFormatUnspecified
case idp_pb.SAMLNameIDFormat_SAML_NAME_ID_FORMAT_EMAIL_ADDRESS:
return domain.SAMLNameIDFormatEmailAddress
case idp_pb.SAMLNameIDFormat_SAML_NAME_ID_FORMAT_PERSISTENT:
return domain.SAMLNameIDFormatPersistent
case idp_pb.SAMLNameIDFormat_SAML_NAME_ID_FORMAT_TRANSIENT:
return domain.SAMLNameIDFormatTransient
default:
return domain.SAMLNameIDFormatUnspecified
}
}
func ProvidersToPb(providers []*query.IDPTemplate) []*idp_pb.Provider {
list := make([]*idp_pb.Provider, len(providers))
for i, provider := range providers {
@@ -639,11 +655,17 @@ func appleConfigToPb(providerConfig *idp_pb.ProviderConfig, template *query.Appl
}
func samlConfigToPb(providerConfig *idp_pb.ProviderConfig, template *query.SAMLIDPTemplate) {
nameIDFormat := idp_pb.SAMLNameIDFormat_SAML_NAME_ID_FORMAT_PERSISTENT
if template.NameIDFormat.Valid {
nameIDFormat = nameIDToPb(template.NameIDFormat.V)
}
providerConfig.Config = &idp_pb.ProviderConfig_Saml{
Saml: &idp_pb.SAMLConfig{
MetadataXml: template.Metadata,
Binding: bindingToPb(template.Binding),
WithSignedRequest: template.WithSignedRequest,
MetadataXml: template.Metadata,
Binding: bindingToPb(template.Binding),
WithSignedRequest: template.WithSignedRequest,
NameIdFormat: nameIDFormat,
TransientMappingAttributeName: gu.Ptr(template.TransientMappingAttributeName),
},
}
}
@@ -662,3 +684,18 @@ func bindingToPb(binding string) idp_pb.SAMLBinding {
return idp_pb.SAMLBinding_SAML_BINDING_UNSPECIFIED
}
}
func nameIDToPb(format domain.SAMLNameIDFormat) idp_pb.SAMLNameIDFormat {
switch format {
case domain.SAMLNameIDFormatUnspecified:
return idp_pb.SAMLNameIDFormat_SAML_NAME_ID_FORMAT_UNSPECIFIED
case domain.SAMLNameIDFormatEmailAddress:
return idp_pb.SAMLNameIDFormat_SAML_NAME_ID_FORMAT_EMAIL_ADDRESS
case domain.SAMLNameIDFormatPersistent:
return idp_pb.SAMLNameIDFormat_SAML_NAME_ID_FORMAT_PERSISTENT
case domain.SAMLNameIDFormatTransient:
return idp_pb.SAMLNameIDFormat_SAML_NAME_ID_FORMAT_TRANSIENT
default:
return idp_pb.SAMLNameIDFormat_SAML_NAME_ID_FORMAT_UNSPECIFIED
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"github.com/crewjam/saml"
"github.com/muhlemmer/gu"
"github.com/zitadel/zitadel/internal/api/authz"
idp_grpc "github.com/zitadel/zitadel/internal/api/grpc/idp"
@@ -462,24 +463,36 @@ func updateAppleProviderToCommand(req *mgmt_pb.UpdateAppleProviderRequest) comma
}
func addSAMLProviderToCommand(req *mgmt_pb.AddSAMLProviderRequest) command.SAMLProvider {
var nameIDFormat *domain.SAMLNameIDFormat
if req.NameIdFormat != nil {
nameIDFormat = gu.Ptr(idp_grpc.SAMLNameIDFormatToDomain(req.GetNameIdFormat()))
}
return command.SAMLProvider{
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
NameIDFormat: nameIDFormat,
TransientMappingAttributeName: req.GetTransientMappingAttributeName(),
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
}
}
func updateSAMLProviderToCommand(req *mgmt_pb.UpdateSAMLProviderRequest) command.SAMLProvider {
var nameIDFormat *domain.SAMLNameIDFormat
if req.NameIdFormat != nil {
nameIDFormat = gu.Ptr(idp_grpc.SAMLNameIDFormatToDomain(req.GetNameIdFormat()))
}
return command.SAMLProvider{
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
NameIDFormat: nameIDFormat,
TransientMappingAttributeName: req.GetTransientMappingAttributeName(),
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
}
}

View File

@@ -1883,7 +1883,7 @@ func TestServer_StartIdentityProviderIntent(t *testing.T) {
orgResp := Tester.CreateOrganization(IamCTX, fmt.Sprintf("NotDefaultOrg%d", time.Now().UnixNano()), fmt.Sprintf("%d@mouse.com", time.Now().UnixNano()))
notDefaultOrgIdpID := Tester.AddOrgGenericOAuthProvider(t, CTX, orgResp.OrganizationId)
samlIdpID := Tester.AddSAMLProvider(t, CTX)
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t, CTX)
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t, CTX, "")
samlPostIdpID := Tester.AddSAMLPostProvider(t, CTX)
type args struct {
ctx context.Context

View File

@@ -229,11 +229,6 @@ func (h *Handler) handleACS(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
sp, err := samlProvider.GetSP()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
intent, err := h.commands.GetActiveIntent(ctx, data.RelayState)
if err != nil {
@@ -245,10 +240,10 @@ func (h *Handler) handleACS(w http.ResponseWriter, r *http.Request) {
return
}
session := saml2.Session{
ServiceProvider: sp,
RequestID: intent.RequestID,
Request: r,
session, err := saml2.NewSession(samlProvider, intent.RequestID, r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
idpUser, err := session.FetchUser(r.Context())

View File

@@ -52,7 +52,7 @@ func TestMain(m *testing.M) {
}
func TestServer_SAMLCertificate(t *testing.T) {
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t, CTX)
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t, CTX, "")
oauthIdpID := Tester.AddGenericOAuthProvider(t, CTX)
type args struct {
@@ -109,7 +109,7 @@ func TestServer_SAMLCertificate(t *testing.T) {
}
func TestServer_SAMLMetadata(t *testing.T) {
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t, CTX)
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t, CTX, "")
oauthIdpID := Tester.AddGenericOAuthProvider(t, CTX)
type args struct {
@@ -167,7 +167,7 @@ func TestServer_SAMLMetadata(t *testing.T) {
func TestServer_SAMLACS(t *testing.T) {
userHuman := Tester.CreateHumanUser(CTX)
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t, CTX)
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t, CTX, "urn:oid:0.9.2342.19200300.100.1.1") // the username is set in urn:oid:0.9.2342.19200300.100.1.1
externalUserID := "test1"
linkedExternalUserID := "test2"
Tester.CreateUserIDPlink(CTX, userHuman.UserId, linkedExternalUserID, samlRedirectIdpID, linkedExternalUserID)
@@ -180,13 +180,15 @@ func TestServer_SAMLACS(t *testing.T) {
assert.NoError(t, err)
type args struct {
ctx context.Context
successURL string
failureURL string
idpID string
username string
intentID string
response string
ctx context.Context
successURL string
failureURL string
idpID string
username string
nameID string
nameIDFormat string
intentID string
response string
}
type want struct {
successful bool
@@ -201,12 +203,14 @@ func TestServer_SAMLACS(t *testing.T) {
{
name: "intent invalid",
args: args{
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: externalUserID,
intentID: "notexisting",
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: externalUserID,
nameID: externalUserID,
nameIDFormat: string(saml.PersistentNameIDFormat),
intentID: "notexisting",
},
want: want{
successful: false,
@@ -217,12 +221,14 @@ func TestServer_SAMLACS(t *testing.T) {
{
name: "response invalid",
args: args{
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: externalUserID,
response: "invalid",
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: externalUserID,
nameID: externalUserID,
nameIDFormat: string(saml.PersistentNameIDFormat),
response: "invalid",
},
want: want{
successful: false,
@@ -232,11 +238,13 @@ func TestServer_SAMLACS(t *testing.T) {
{
name: "saml flow redirect, ok",
args: args{
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: externalUserID,
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: externalUserID,
nameID: externalUserID,
nameIDFormat: string(saml.PersistentNameIDFormat),
},
want: want{
successful: true,
@@ -246,11 +254,45 @@ func TestServer_SAMLACS(t *testing.T) {
{
name: "saml flow redirect with link, ok",
args: args{
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: linkedExternalUserID,
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: linkedExternalUserID,
nameID: linkedExternalUserID,
nameIDFormat: string(saml.PersistentNameIDFormat),
},
want: want{
successful: true,
user: userHuman.UserId,
},
},
{
name: "saml flow redirect (transient), ok",
args: args{
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: externalUserID,
nameID: "genericID",
nameIDFormat: string(saml.TransientNameIDFormat),
},
want: want{
successful: true,
user: "",
},
},
{
name: "saml flow redirect with link (transient), ok",
args: args{
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: linkedExternalUserID,
nameID: "genericID",
nameIDFormat: string(saml.TransientNameIDFormat),
},
want: want{
successful: true,
@@ -287,7 +329,7 @@ func TestServer_SAMLACS(t *testing.T) {
relayState = tt.args.intentID
}
callbackURL := http_util.BuildOrigin(Tester.Host(), Tester.Server.Config.ExternalSecure) + "/idps/" + tt.args.idpID + "/saml/acs"
response := createResponse(t, idp, samlRequest, tt.args.username)
response := createResponse(t, idp, samlRequest, tt.args.nameID, tt.args.nameIDFormat, tt.args.username)
//test purposes, use defined response
if tt.args.response != "" {
response = tt.args.response
@@ -432,14 +474,16 @@ func getIDP(zitadelBaseURL string, idpIDs []string, user1, user2 string) (*saml.
return &idpServer.IDP, nil
}
func createResponse(t *testing.T, idp *saml.IdentityProvider, req *http.Request, username string) string {
func createResponse(t *testing.T, idp *saml.IdentityProvider, req *http.Request, nameID, nameIDFormat, username string) string {
authnReq, err := saml.NewIdpAuthnRequest(idp, req)
assert.NoError(t, authnReq.Validate())
err = idp.AssertionMaker.MakeAssertion(authnReq, &saml.Session{
CreateTime: time.Now().UTC(),
Index: "",
NameID: username,
CreateTime: time.Now().UTC(),
Index: "",
NameID: nameID,
NameIDFormat: nameIDFormat,
UserName: username,
})
assert.NoError(t, err)
err = authnReq.MakeResponse()

View File

@@ -319,12 +319,11 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque
l.externalAuthFailed(w, r, authReq, nil, nil, err)
return
}
sp, err := provider.(*saml.Provider).GetSP()
session, err = saml.NewSession(provider.(*saml.Provider), authReq.SAMLRequestID, r)
if err != nil {
l.externalAuthFailed(w, r, authReq, nil, nil, err)
return
}
session = &saml.Session{ServiceProvider: sp, RequestID: authReq.SAMLRequestID, Request: r}
case domain.IDPTypeJWT,
domain.IDPTypeLDAP,
domain.IDPTypeUnspecified:
@@ -1029,13 +1028,19 @@ func (l *Login) samlProvider(ctx context.Context, identityProvider *query.IDPTem
if err != nil {
return nil, err
}
opts := make([]saml.ProviderOpts, 0, 2)
opts := make([]saml.ProviderOpts, 0, 6)
if identityProvider.SAMLIDPTemplate.WithSignedRequest {
opts = append(opts, saml.WithSignedRequest())
}
if identityProvider.SAMLIDPTemplate.Binding != "" {
opts = append(opts, saml.WithBinding(identityProvider.SAMLIDPTemplate.Binding))
}
if identityProvider.SAMLIDPTemplate.NameIDFormat.Valid {
opts = append(opts, saml.WithNameIDFormat(identityProvider.SAMLIDPTemplate.NameIDFormat.V))
}
if identityProvider.SAMLIDPTemplate.TransientMappingAttributeName != "" {
opts = append(opts, saml.WithTransientMappingAttributeName(identityProvider.SAMLIDPTemplate.TransientMappingAttributeName))
}
opts = append(opts,
saml.WithEntityID(http_utils.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), l.externalSecure)+"/idps/"+identityProvider.ID+"/saml/metadata"),
saml.WithCustomRequestTracker(