feat(saml): allow setting nameid-format and alternative mapping for transient format (#7979)

# Which Problems Are Solved

ZITADEL currently always uses
`urn:oasis:names:tc:SAML:2.0:nameid-format:persistent` in SAML requests,
relying on the IdP to respect that flag and always return a peristent
nameid in order to be able to map the external user with an existing
user (idp link) in ZITADEL.
In case the IdP however returns a
`urn:oasis:names:tc:SAML:2.0:nameid-format:transient` (transient)
nameid, the attribute will differ between each request and it will not
be possible to match existing users.

# How the Problems Are Solved

This PR adds the following two options on SAML IdP:
- **nameIDFormat**: allows to set the nameid-format used in the SAML
Request
- **transientMappingAttributeName**: allows to set an attribute name,
which will be used instead of the nameid itself in case the returned
nameid-format is transient

# Additional Changes

To reduce impact on current installations, the `idp_templates6_saml`
table is altered with the two added columns by a setup job. New
installations will automatically get the table with the two columns
directly.
All idp unit tests are updated to use `expectEventstore` instead of the
deprecated `eventstoreExpect`.

# Additional Context

Closes #7483
Closes #7743

---------

Co-authored-by: peintnermax <max@caos.ch>
Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
Livio Spring
2024-05-23 07:04:07 +02:00
committed by GitHub
parent 12be21a3ff
commit e57a9b57c8
58 changed files with 1306 additions and 720 deletions

View File

@@ -1,7 +1,6 @@
package saml
import (
"github.com/crewjam/saml"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
@@ -20,8 +19,8 @@ func NewUser() *UserMapper {
return &UserMapper{Attributes: map[string][]string{}}
}
func (u *UserMapper) SetID(id *saml.NameID) {
u.ID = id.Value
func (u *UserMapper) SetID(id string) {
u.ID = id
}
// GetID is an implementation of the [idp.User] interface.

View File

@@ -11,6 +11,7 @@ import (
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -26,7 +27,9 @@ type Provider struct {
spOptions *samlsp.Options
binding string
binding string
nameIDFormat saml.NameIDFormat
transientMappingAttributeName string
isLinkingAllowed bool
isCreationAllowed bool
@@ -77,6 +80,18 @@ func WithBinding(binding string) ProviderOpts {
}
}
func WithNameIDFormat(format domain.SAMLNameIDFormat) ProviderOpts {
return func(p *Provider) {
p.nameIDFormat = nameIDFormatFromDomain(format)
}
}
func WithTransientMappingAttributeName(attribute string) ProviderOpts {
return func(p *Provider) {
p.transientMappingAttributeName = attribute
}
}
func WithCustomRequestTracker(tracker samlsp.RequestTracker) ProviderOpts {
return func(p *Provider) {
p.requestTracker = tracker
@@ -124,6 +139,8 @@ func New(
name: name,
spOptions: &opts,
Certificate: certificate,
// the library uses transient as default, which does not make sense for federating accounts
nameIDFormat: saml.PersistentNameIDFormat,
}
for _, option := range options {
option(provider)
@@ -156,10 +173,7 @@ func (p *Provider) GetSP() (*samlsp.Middleware, error) {
if err != nil {
return nil, zerrors.ThrowInternal(err, "SAML-qee09ffuq5", "Errors.Intent.IDPInvalid")
}
// the library uses transient as default, which we currently can't handle (https://github.com/zitadel/zitadel/discussions/7421)
// for the moment we'll use persistent (for those who actually use it from the saml request) and add an option
// later on to specify on the provider: https://github.com/zitadel/zitadel/issues/7743
sp.ServiceProvider.AuthnNameIDFormat = saml.PersistentNameIDFormat
sp.ServiceProvider.AuthnNameIDFormat = p.nameIDFormat
if p.requestTracker != nil {
sp.RequestTracker = p.requestTracker
}
@@ -180,3 +194,22 @@ func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Paramet
state: state,
}, nil
}
func (p *Provider) TransientMappingAttributeName() string {
return p.transientMappingAttributeName
}
func nameIDFormatFromDomain(format domain.SAMLNameIDFormat) saml.NameIDFormat {
switch format {
case domain.SAMLNameIDFormatUnspecified:
return saml.UnspecifiedNameIDFormat
case domain.SAMLNameIDFormatEmailAddress:
return saml.EmailAddressNameIDFormat
case domain.SAMLNameIDFormatPersistent:
return saml.PersistentNameIDFormat
case domain.SAMLNameIDFormatTransient:
return saml.TransientNameIDFormat
default:
return saml.UnspecifiedNameIDFormat
}
}

View File

@@ -3,10 +3,12 @@ package saml
import (
"testing"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp/providers/saml/requesttracker"
)
@@ -20,16 +22,18 @@ func TestProvider_Options(t *testing.T) {
options []ProviderOpts
}
type want struct {
err bool
name string
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
binding string
withSignedRequest bool
requesttracker samlsp.RequestTracker
entityID string
err bool
name string
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
binding string
nameIDFormat saml.NameIDFormat
transientMappingAttributeName string
withSignedRequest bool
requesttracker samlsp.RequestTracker
entityID string
}
tests := []struct {
name string
@@ -103,10 +107,11 @@ func TestProvider_Options(t *testing.T) {
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
nameIDFormat: saml.PersistentNameIDFormat,
},
},
{
name: "all true",
name: "all set / true",
fields: fields{
name: "saml",
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
@@ -121,18 +126,22 @@ func TestProvider_Options(t *testing.T) {
WithSignedRequest(),
WithCustomRequestTracker(&requesttracker.RequestTracker{}),
WithEntityID("entityID"),
WithNameIDFormat(domain.SAMLNameIDFormatTransient),
WithTransientMappingAttributeName("attribute"),
},
},
want: want{
name: "saml",
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
binding: "binding",
withSignedRequest: true,
requesttracker: &requesttracker.RequestTracker{},
entityID: "entityID",
name: "saml",
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
binding: "binding",
entityID: "entityID",
nameIDFormat: saml.TransientNameIDFormat,
transientMappingAttributeName: "attribute",
withSignedRequest: true,
requesttracker: &requesttracker.RequestTracker{},
},
},
}
@@ -152,6 +161,8 @@ func TestProvider_Options(t *testing.T) {
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
a.Equal(tt.want.binding, provider.binding)
a.Equal(tt.want.nameIDFormat, provider.nameIDFormat)
a.Equal(tt.want.transientMappingAttributeName, provider.transientMappingAttributeName)
a.Equal(tt.want.withSignedRequest, provider.spOptions.SignRequest)
a.Equal(tt.want.requesttracker, provider.requestTracker)
a.Equal(tt.want.entityID, provider.spOptions.EntityID)

View File

@@ -17,8 +17,9 @@ var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the SAML provider.
type Session struct {
ServiceProvider *samlsp.Middleware
state string
ServiceProvider *samlsp.Middleware
state string
TransientMappingAttributeName string
RequestID string
Request *http.Request
@@ -26,6 +27,19 @@ type Session struct {
Assertion *saml.Assertion
}
func NewSession(provider *Provider, requestID string, request *http.Request) (*Session, error) {
sp, err := provider.GetSP()
if err != nil {
return nil, err
}
return &Session{
ServiceProvider: sp,
TransientMappingAttributeName: provider.TransientMappingAttributeName(),
RequestID: requestID,
Request: request,
}, nil
}
// GetAuth implements the [idp.Session] interface.
func (s *Session) GetAuth(ctx context.Context) (string, bool) {
url, _ := url.Parse(s.state)
@@ -56,8 +70,17 @@ func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
return nil, zerrors.ThrowInvalidArgument(err, "SAML-nuo0vphhh9", "Errors.Intent.ResponseInvalid")
}
nameID := s.Assertion.Subject.NameID
userMapper := NewUser()
userMapper.SetID(s.Assertion.Subject.NameID)
// use the nameID as default mapping id
userMapper.SetID(nameID.Value)
if nameID.Format == string(saml.TransientNameIDFormat) {
mappingID, err := s.transientMappingID()
if err != nil {
return nil, err
}
userMapper.SetID(mappingID)
}
for _, statement := range s.Assertion.AttributeStatements {
for _, attribute := range statement.Attributes {
values := make([]string, len(attribute.Values))
@@ -70,6 +93,21 @@ func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
return userMapper, nil
}
func (s *Session) transientMappingID() (string, error) {
for _, statement := range s.Assertion.AttributeStatements {
for _, attribute := range statement.Attributes {
if attribute.Name != s.TransientMappingAttributeName {
continue
}
if len(attribute.Values) != 1 {
return "", zerrors.ThrowInvalidArgument(nil, "SAML-Soij4", "Errors.Intent.MissingSingleMappingAttribute")
}
return attribute.Values[0].Value, nil
}
}
return "", zerrors.ThrowInvalidArgument(nil, "SAML-swwg2", "Errors.Intent.MissingSingleMappingAttribute")
}
type TempResponseWriter struct {
header http.Header
content *bytes.Buffer

File diff suppressed because one or more lines are too long