mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 04:07:31 +00:00
feat(saml): allow setting nameid-format and alternative mapping for transient format (#7979)
# Which Problems Are Solved ZITADEL currently always uses `urn:oasis:names:tc:SAML:2.0:nameid-format:persistent` in SAML requests, relying on the IdP to respect that flag and always return a peristent nameid in order to be able to map the external user with an existing user (idp link) in ZITADEL. In case the IdP however returns a `urn:oasis:names:tc:SAML:2.0:nameid-format:transient` (transient) nameid, the attribute will differ between each request and it will not be possible to match existing users. # How the Problems Are Solved This PR adds the following two options on SAML IdP: - **nameIDFormat**: allows to set the nameid-format used in the SAML Request - **transientMappingAttributeName**: allows to set an attribute name, which will be used instead of the nameid itself in case the returned nameid-format is transient # Additional Changes To reduce impact on current installations, the `idp_templates6_saml` table is altered with the two added columns by a setup job. New installations will automatically get the table with the two columns directly. All idp unit tests are updated to use `expectEventstore` instead of the deprecated `eventstoreExpect`. # Additional Context Closes #7483 Closes #7743 --------- Co-authored-by: peintnermax <max@caos.ch> Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"github.com/crewjam/saml"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@@ -20,8 +19,8 @@ func NewUser() *UserMapper {
|
||||
return &UserMapper{Attributes: map[string][]string{}}
|
||||
}
|
||||
|
||||
func (u *UserMapper) SetID(id *saml.NameID) {
|
||||
u.ID = id.Value
|
||||
func (u *UserMapper) SetID(id string) {
|
||||
u.ID = id
|
||||
}
|
||||
|
||||
// GetID is an implementation of the [idp.User] interface.
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/crewjam/saml/samlsp"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
@@ -26,7 +27,9 @@ type Provider struct {
|
||||
|
||||
spOptions *samlsp.Options
|
||||
|
||||
binding string
|
||||
binding string
|
||||
nameIDFormat saml.NameIDFormat
|
||||
transientMappingAttributeName string
|
||||
|
||||
isLinkingAllowed bool
|
||||
isCreationAllowed bool
|
||||
@@ -77,6 +80,18 @@ func WithBinding(binding string) ProviderOpts {
|
||||
}
|
||||
}
|
||||
|
||||
func WithNameIDFormat(format domain.SAMLNameIDFormat) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.nameIDFormat = nameIDFormatFromDomain(format)
|
||||
}
|
||||
}
|
||||
|
||||
func WithTransientMappingAttributeName(attribute string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.transientMappingAttributeName = attribute
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomRequestTracker(tracker samlsp.RequestTracker) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.requestTracker = tracker
|
||||
@@ -124,6 +139,8 @@ func New(
|
||||
name: name,
|
||||
spOptions: &opts,
|
||||
Certificate: certificate,
|
||||
// the library uses transient as default, which does not make sense for federating accounts
|
||||
nameIDFormat: saml.PersistentNameIDFormat,
|
||||
}
|
||||
for _, option := range options {
|
||||
option(provider)
|
||||
@@ -156,10 +173,7 @@ func (p *Provider) GetSP() (*samlsp.Middleware, error) {
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "SAML-qee09ffuq5", "Errors.Intent.IDPInvalid")
|
||||
}
|
||||
// the library uses transient as default, which we currently can't handle (https://github.com/zitadel/zitadel/discussions/7421)
|
||||
// for the moment we'll use persistent (for those who actually use it from the saml request) and add an option
|
||||
// later on to specify on the provider: https://github.com/zitadel/zitadel/issues/7743
|
||||
sp.ServiceProvider.AuthnNameIDFormat = saml.PersistentNameIDFormat
|
||||
sp.ServiceProvider.AuthnNameIDFormat = p.nameIDFormat
|
||||
if p.requestTracker != nil {
|
||||
sp.RequestTracker = p.requestTracker
|
||||
}
|
||||
@@ -180,3 +194,22 @@ func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Paramet
|
||||
state: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *Provider) TransientMappingAttributeName() string {
|
||||
return p.transientMappingAttributeName
|
||||
}
|
||||
|
||||
func nameIDFormatFromDomain(format domain.SAMLNameIDFormat) saml.NameIDFormat {
|
||||
switch format {
|
||||
case domain.SAMLNameIDFormatUnspecified:
|
||||
return saml.UnspecifiedNameIDFormat
|
||||
case domain.SAMLNameIDFormatEmailAddress:
|
||||
return saml.EmailAddressNameIDFormat
|
||||
case domain.SAMLNameIDFormatPersistent:
|
||||
return saml.PersistentNameIDFormat
|
||||
case domain.SAMLNameIDFormatTransient:
|
||||
return saml.TransientNameIDFormat
|
||||
default:
|
||||
return saml.UnspecifiedNameIDFormat
|
||||
}
|
||||
}
|
||||
|
@@ -3,10 +3,12 @@ package saml
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/crewjam/saml/samlsp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/saml/requesttracker"
|
||||
)
|
||||
|
||||
@@ -20,16 +22,18 @@ func TestProvider_Options(t *testing.T) {
|
||||
options []ProviderOpts
|
||||
}
|
||||
type want struct {
|
||||
err bool
|
||||
name string
|
||||
linkingAllowed bool
|
||||
creationAllowed bool
|
||||
autoCreation bool
|
||||
autoUpdate bool
|
||||
binding string
|
||||
withSignedRequest bool
|
||||
requesttracker samlsp.RequestTracker
|
||||
entityID string
|
||||
err bool
|
||||
name string
|
||||
linkingAllowed bool
|
||||
creationAllowed bool
|
||||
autoCreation bool
|
||||
autoUpdate bool
|
||||
binding string
|
||||
nameIDFormat saml.NameIDFormat
|
||||
transientMappingAttributeName string
|
||||
withSignedRequest bool
|
||||
requesttracker samlsp.RequestTracker
|
||||
entityID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -103,10 +107,11 @@ func TestProvider_Options(t *testing.T) {
|
||||
creationAllowed: false,
|
||||
autoCreation: false,
|
||||
autoUpdate: false,
|
||||
nameIDFormat: saml.PersistentNameIDFormat,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all true",
|
||||
name: "all set / true",
|
||||
fields: fields{
|
||||
name: "saml",
|
||||
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
|
||||
@@ -121,18 +126,22 @@ func TestProvider_Options(t *testing.T) {
|
||||
WithSignedRequest(),
|
||||
WithCustomRequestTracker(&requesttracker.RequestTracker{}),
|
||||
WithEntityID("entityID"),
|
||||
WithNameIDFormat(domain.SAMLNameIDFormatTransient),
|
||||
WithTransientMappingAttributeName("attribute"),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
name: "saml",
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
autoUpdate: true,
|
||||
binding: "binding",
|
||||
withSignedRequest: true,
|
||||
requesttracker: &requesttracker.RequestTracker{},
|
||||
entityID: "entityID",
|
||||
name: "saml",
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
autoUpdate: true,
|
||||
binding: "binding",
|
||||
entityID: "entityID",
|
||||
nameIDFormat: saml.TransientNameIDFormat,
|
||||
transientMappingAttributeName: "attribute",
|
||||
withSignedRequest: true,
|
||||
requesttracker: &requesttracker.RequestTracker{},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -152,6 +161,8 @@ func TestProvider_Options(t *testing.T) {
|
||||
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
|
||||
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
|
||||
a.Equal(tt.want.binding, provider.binding)
|
||||
a.Equal(tt.want.nameIDFormat, provider.nameIDFormat)
|
||||
a.Equal(tt.want.transientMappingAttributeName, provider.transientMappingAttributeName)
|
||||
a.Equal(tt.want.withSignedRequest, provider.spOptions.SignRequest)
|
||||
a.Equal(tt.want.requesttracker, provider.requestTracker)
|
||||
a.Equal(tt.want.entityID, provider.spOptions.EntityID)
|
||||
|
@@ -17,8 +17,9 @@ var _ idp.Session = (*Session)(nil)
|
||||
|
||||
// Session is the [idp.Session] implementation for the SAML provider.
|
||||
type Session struct {
|
||||
ServiceProvider *samlsp.Middleware
|
||||
state string
|
||||
ServiceProvider *samlsp.Middleware
|
||||
state string
|
||||
TransientMappingAttributeName string
|
||||
|
||||
RequestID string
|
||||
Request *http.Request
|
||||
@@ -26,6 +27,19 @@ type Session struct {
|
||||
Assertion *saml.Assertion
|
||||
}
|
||||
|
||||
func NewSession(provider *Provider, requestID string, request *http.Request) (*Session, error) {
|
||||
sp, err := provider.GetSP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Session{
|
||||
ServiceProvider: sp,
|
||||
TransientMappingAttributeName: provider.TransientMappingAttributeName(),
|
||||
RequestID: requestID,
|
||||
Request: request,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAuth implements the [idp.Session] interface.
|
||||
func (s *Session) GetAuth(ctx context.Context) (string, bool) {
|
||||
url, _ := url.Parse(s.state)
|
||||
@@ -56,8 +70,17 @@ func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "SAML-nuo0vphhh9", "Errors.Intent.ResponseInvalid")
|
||||
}
|
||||
|
||||
nameID := s.Assertion.Subject.NameID
|
||||
userMapper := NewUser()
|
||||
userMapper.SetID(s.Assertion.Subject.NameID)
|
||||
// use the nameID as default mapping id
|
||||
userMapper.SetID(nameID.Value)
|
||||
if nameID.Format == string(saml.TransientNameIDFormat) {
|
||||
mappingID, err := s.transientMappingID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userMapper.SetID(mappingID)
|
||||
}
|
||||
for _, statement := range s.Assertion.AttributeStatements {
|
||||
for _, attribute := range statement.Attributes {
|
||||
values := make([]string, len(attribute.Values))
|
||||
@@ -70,6 +93,21 @@ func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||
return userMapper, nil
|
||||
}
|
||||
|
||||
func (s *Session) transientMappingID() (string, error) {
|
||||
for _, statement := range s.Assertion.AttributeStatements {
|
||||
for _, attribute := range statement.Attributes {
|
||||
if attribute.Name != s.TransientMappingAttributeName {
|
||||
continue
|
||||
}
|
||||
if len(attribute.Values) != 1 {
|
||||
return "", zerrors.ThrowInvalidArgument(nil, "SAML-Soij4", "Errors.Intent.MissingSingleMappingAttribute")
|
||||
}
|
||||
return attribute.Values[0].Value, nil
|
||||
}
|
||||
}
|
||||
return "", zerrors.ThrowInvalidArgument(nil, "SAML-swwg2", "Errors.Intent.MissingSingleMappingAttribute")
|
||||
}
|
||||
|
||||
type TempResponseWriter struct {
|
||||
header http.Header
|
||||
content *bytes.Buffer
|
||||
|
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user