From 374b9a7f66046253da14dcc936fe8a3842385352 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Fri, 15 Nov 2024 07:19:43 +0100 Subject: [PATCH] fix(saml): provide option to get internal as default ACS (#8888) # Which Problems Are Solved Some SAML IdPs including Google only allow to configure a single AssertionConsumerService URL. Since the current metadata provides multiple and the hosted login UI is not published as neither the first nor with `isDefault=true`, those IdPs take another and then return an error on sign in. # How the Problems Are Solved Allow to reorder the ACS URLs using a query parameter (`internalUI=true`) when retrieving the metadata endpoint. This will list the `ui/login/login/externalidp/saml/acs` first and also set the `isDefault=true`. # Additional Changes None # Additional Context Reported by a customer --- internal/api/idp/idp.go | 62 ++++++--- internal/api/idp/integration_test/idp_test.go | 122 +++++++++++++++++- 2 files changed, 162 insertions(+), 22 deletions(-) diff --git a/internal/api/idp/idp.go b/internal/api/idp/idp.go index 3d46f029da..01594c43ba 100644 --- a/internal/api/idp/idp.go +++ b/internal/api/idp/idp.go @@ -8,9 +8,11 @@ import ( "fmt" "io" "net/http" + "strconv" "github.com/crewjam/saml" "github.com/gorilla/mux" + "github.com/muhlemmer/gu" "github.com/zitadel/logging" http_utils "github.com/zitadel/zitadel/internal/api/http" @@ -49,6 +51,7 @@ const ( paramError = "error" paramErrorDescription = "error_description" varIDPID = "idpid" + paramInternalUI = "internalUI" ) type Handler struct { @@ -187,21 +190,8 @@ func (h *Handler) handleMetadata(w http.ResponseWriter, r *http.Request) { } metadata := sp.ServiceProvider.Metadata() - for i, spDesc := range metadata.SPSSODescriptors { - spDesc.AssertionConsumerServices = append( - spDesc.AssertionConsumerServices, - saml.IndexedEndpoint{ - Binding: saml.HTTPPostBinding, - Location: h.loginSAMLRootURL(ctx), - Index: len(spDesc.AssertionConsumerServices) + 1, - }, saml.IndexedEndpoint{ - Binding: saml.HTTPArtifactBinding, - Location: h.loginSAMLRootURL(ctx), - Index: len(spDesc.AssertionConsumerServices) + 2, - }, - ) - metadata.SPSSODescriptors[i] = spDesc - } + internalUI, _ := strconv.ParseBool(r.URL.Query().Get(paramInternalUI)) + h.assertionConsumerServices(ctx, metadata, internalUI) buf, _ := xml.MarshalIndent(metadata, "", " ") w.Header().Set("Content-Type", "application/samlmetadata+xml") @@ -212,6 +202,48 @@ func (h *Handler) handleMetadata(w http.ResponseWriter, r *http.Request) { } } +func (h *Handler) assertionConsumerServices(ctx context.Context, metadata *saml.EntityDescriptor, internalUI bool) { + if !internalUI { + for i, spDesc := range metadata.SPSSODescriptors { + spDesc.AssertionConsumerServices = append( + spDesc.AssertionConsumerServices, + saml.IndexedEndpoint{ + Binding: saml.HTTPPostBinding, + Location: h.loginSAMLRootURL(ctx), + Index: len(spDesc.AssertionConsumerServices) + 1, + }, saml.IndexedEndpoint{ + Binding: saml.HTTPArtifactBinding, + Location: h.loginSAMLRootURL(ctx), + Index: len(spDesc.AssertionConsumerServices) + 2, + }, + ) + metadata.SPSSODescriptors[i] = spDesc + } + return + } + for i, spDesc := range metadata.SPSSODescriptors { + acs := make([]saml.IndexedEndpoint, 0, len(spDesc.AssertionConsumerServices)+2) + acs = append(acs, + saml.IndexedEndpoint{ + Binding: saml.HTTPPostBinding, + Location: h.loginSAMLRootURL(ctx), + Index: 0, + IsDefault: gu.Ptr(true), + }, + saml.IndexedEndpoint{ + Binding: saml.HTTPArtifactBinding, + Location: h.loginSAMLRootURL(ctx), + Index: 1, + }) + for i := 0; i < len(spDesc.AssertionConsumerServices); i++ { + spDesc.AssertionConsumerServices[i].Index = 2 + i + acs = append(acs, spDesc.AssertionConsumerServices[i]) + } + spDesc.AssertionConsumerServices = acs + metadata.SPSSODescriptors[i] = spDesc + } +} + func (h *Handler) handleACS(w http.ResponseWriter, r *http.Request) { ctx := r.Context() data := parseSAMLRequest(r) diff --git a/internal/api/idp/integration_test/idp_test.go b/internal/api/idp/integration_test/idp_test.go index 8e7141271a..d7616f8f53 100644 --- a/internal/api/idp/integration_test/idp_test.go +++ b/internal/api/idp/integration_test/idp_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" saml_xml "github.com/zitadel/saml/pkg/provider/xml" + "github.com/zitadel/saml/pkg/provider/xml/md" "golang.org/x/crypto/bcrypt" http_util "github.com/zitadel/zitadel/internal/api/http" @@ -111,13 +112,15 @@ func TestServer_SAMLMetadata(t *testing.T) { oauthIdpResp := Instance.AddGenericOAuthProvider(CTX, Instance.DefaultOrg.Id) type args struct { - ctx context.Context - idpID string + ctx context.Context + idpID string + internalUI bool } tests := []struct { - name string - args args - want int + name string + args args + want int + wantACS []md.IndexedEndpointType }{ { name: "saml metadata, invalid idp", @@ -142,11 +145,115 @@ func TestServer_SAMLMetadata(t *testing.T) { idpID: samlRedirectIdpID, }, want: http.StatusOK, + wantACS: []md.IndexedEndpointType{ + { + XMLName: xml.Name{ + Space: "urn:oasis:names:tc:SAML:2.0:metadata", + Local: "AssertionConsumerService", + }, + Index: "1", + IsDefault: "", + Binding: saml.HTTPPostBinding, + Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/idps/" + samlRedirectIdpID + "/saml/acs", + ResponseLocation: "", + }, + { + XMLName: xml.Name{ + Space: "urn:oasis:names:tc:SAML:2.0:metadata", + Local: "AssertionConsumerService", + }, + Index: "2", + IsDefault: "", + Binding: saml.HTTPArtifactBinding, + Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/idps/" + samlRedirectIdpID + "/saml/acs", + ResponseLocation: "", + }, + { + XMLName: xml.Name{ + Space: "urn:oasis:names:tc:SAML:2.0:metadata", + Local: "AssertionConsumerService", + }, + Index: "3", + IsDefault: "", + Binding: saml.HTTPPostBinding, + Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/ui/login/login/externalidp/saml/acs", + ResponseLocation: "", + }, + { + XMLName: xml.Name{ + Space: "urn:oasis:names:tc:SAML:2.0:metadata", + Local: "AssertionConsumerService", + }, + Index: "4", + IsDefault: "", + Binding: saml.HTTPArtifactBinding, + Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/ui/login/login/externalidp/saml/acs", + ResponseLocation: "", + }, + }, + }, + { + name: "saml metadata, ok (internalUI)", + args: args{ + ctx: CTX, + idpID: samlRedirectIdpID, + internalUI: true, + }, + want: http.StatusOK, + wantACS: []md.IndexedEndpointType{ + { + XMLName: xml.Name{ + Space: "urn:oasis:names:tc:SAML:2.0:metadata", + Local: "AssertionConsumerService", + }, + Index: "0", + IsDefault: "true", + Binding: saml.HTTPPostBinding, + Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/ui/login/login/externalidp/saml/acs", + ResponseLocation: "", + }, + { + XMLName: xml.Name{ + Space: "urn:oasis:names:tc:SAML:2.0:metadata", + Local: "AssertionConsumerService", + }, + Index: "1", + IsDefault: "", + Binding: saml.HTTPArtifactBinding, + Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/ui/login/login/externalidp/saml/acs", + ResponseLocation: "", + }, + { + XMLName: xml.Name{ + Space: "urn:oasis:names:tc:SAML:2.0:metadata", + Local: "AssertionConsumerService", + }, + Index: "2", + IsDefault: "", + Binding: saml.HTTPPostBinding, + Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/idps/" + samlRedirectIdpID + "/saml/acs", + ResponseLocation: "", + }, + { + XMLName: xml.Name{ + Space: "urn:oasis:names:tc:SAML:2.0:metadata", + Local: "AssertionConsumerService", + }, + Index: "3", + IsDefault: "", + Binding: saml.HTTPArtifactBinding, + Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/idps/" + samlRedirectIdpID + "/saml/acs", + ResponseLocation: "", + }, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { metadataURL := http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/idps/" + tt.args.idpID + "/saml/metadata" + if tt.args.internalUI { + metadataURL = metadataURL + "?internalUI=true" + } resp, err := http.Get(metadataURL) assert.NoError(t, err) assert.Equal(t, tt.want, resp.StatusCode) @@ -155,10 +262,11 @@ func TestServer_SAMLMetadata(t *testing.T) { defer resp.Body.Close() assert.NoError(t, err) - _, err = saml_xml.ParseMetadataXmlIntoStruct(b) + metadata, err := saml_xml.ParseMetadataXmlIntoStruct(b) assert.NoError(t, err) - } + assert.Equal(t, metadata.SPSSODescriptor.AssertionConsumerService, tt.wantACS) + } }) } }