fix(saml): provide option to get internal as default ACS (#8888)

# Which Problems Are Solved

Some SAML IdPs including Google only allow to configure a single
AssertionConsumerService URL.
Since the current metadata provides multiple and the hosted login UI is
not published as neither the first nor with `isDefault=true`, those IdPs
take another and then return an error on sign in.

# How the Problems Are Solved

Allow to reorder the ACS URLs using a query parameter
(`internalUI=true`) when retrieving the metadata endpoint.
This will list the `ui/login/login/externalidp/saml/acs` first and also
set the `isDefault=true`.

# Additional Changes

None

# Additional Context

Reported by a customer
This commit is contained in:
Livio Spring 2024-11-15 07:19:43 +01:00 committed by GitHub
parent 85bdf01505
commit 374b9a7f66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 162 additions and 22 deletions

View File

@ -8,9 +8,11 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strconv"
"github.com/crewjam/saml" "github.com/crewjam/saml"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/muhlemmer/gu"
"github.com/zitadel/logging" "github.com/zitadel/logging"
http_utils "github.com/zitadel/zitadel/internal/api/http" http_utils "github.com/zitadel/zitadel/internal/api/http"
@ -49,6 +51,7 @@ const (
paramError = "error" paramError = "error"
paramErrorDescription = "error_description" paramErrorDescription = "error_description"
varIDPID = "idpid" varIDPID = "idpid"
paramInternalUI = "internalUI"
) )
type Handler struct { type Handler struct {
@ -187,21 +190,8 @@ func (h *Handler) handleMetadata(w http.ResponseWriter, r *http.Request) {
} }
metadata := sp.ServiceProvider.Metadata() metadata := sp.ServiceProvider.Metadata()
for i, spDesc := range metadata.SPSSODescriptors { internalUI, _ := strconv.ParseBool(r.URL.Query().Get(paramInternalUI))
spDesc.AssertionConsumerServices = append( h.assertionConsumerServices(ctx, metadata, internalUI)
spDesc.AssertionConsumerServices,
saml.IndexedEndpoint{
Binding: saml.HTTPPostBinding,
Location: h.loginSAMLRootURL(ctx),
Index: len(spDesc.AssertionConsumerServices) + 1,
}, saml.IndexedEndpoint{
Binding: saml.HTTPArtifactBinding,
Location: h.loginSAMLRootURL(ctx),
Index: len(spDesc.AssertionConsumerServices) + 2,
},
)
metadata.SPSSODescriptors[i] = spDesc
}
buf, _ := xml.MarshalIndent(metadata, "", " ") buf, _ := xml.MarshalIndent(metadata, "", " ")
w.Header().Set("Content-Type", "application/samlmetadata+xml") w.Header().Set("Content-Type", "application/samlmetadata+xml")
@ -212,6 +202,48 @@ func (h *Handler) handleMetadata(w http.ResponseWriter, r *http.Request) {
} }
} }
func (h *Handler) assertionConsumerServices(ctx context.Context, metadata *saml.EntityDescriptor, internalUI bool) {
if !internalUI {
for i, spDesc := range metadata.SPSSODescriptors {
spDesc.AssertionConsumerServices = append(
spDesc.AssertionConsumerServices,
saml.IndexedEndpoint{
Binding: saml.HTTPPostBinding,
Location: h.loginSAMLRootURL(ctx),
Index: len(spDesc.AssertionConsumerServices) + 1,
}, saml.IndexedEndpoint{
Binding: saml.HTTPArtifactBinding,
Location: h.loginSAMLRootURL(ctx),
Index: len(spDesc.AssertionConsumerServices) + 2,
},
)
metadata.SPSSODescriptors[i] = spDesc
}
return
}
for i, spDesc := range metadata.SPSSODescriptors {
acs := make([]saml.IndexedEndpoint, 0, len(spDesc.AssertionConsumerServices)+2)
acs = append(acs,
saml.IndexedEndpoint{
Binding: saml.HTTPPostBinding,
Location: h.loginSAMLRootURL(ctx),
Index: 0,
IsDefault: gu.Ptr(true),
},
saml.IndexedEndpoint{
Binding: saml.HTTPArtifactBinding,
Location: h.loginSAMLRootURL(ctx),
Index: 1,
})
for i := 0; i < len(spDesc.AssertionConsumerServices); i++ {
spDesc.AssertionConsumerServices[i].Index = 2 + i
acs = append(acs, spDesc.AssertionConsumerServices[i])
}
spDesc.AssertionConsumerServices = acs
metadata.SPSSODescriptors[i] = spDesc
}
}
func (h *Handler) handleACS(w http.ResponseWriter, r *http.Request) { func (h *Handler) handleACS(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
data := parseSAMLRequest(r) data := parseSAMLRequest(r)

View File

@ -23,6 +23,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
saml_xml "github.com/zitadel/saml/pkg/provider/xml" saml_xml "github.com/zitadel/saml/pkg/provider/xml"
"github.com/zitadel/saml/pkg/provider/xml/md"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
http_util "github.com/zitadel/zitadel/internal/api/http" http_util "github.com/zitadel/zitadel/internal/api/http"
@ -111,13 +112,15 @@ func TestServer_SAMLMetadata(t *testing.T) {
oauthIdpResp := Instance.AddGenericOAuthProvider(CTX, Instance.DefaultOrg.Id) oauthIdpResp := Instance.AddGenericOAuthProvider(CTX, Instance.DefaultOrg.Id)
type args struct { type args struct {
ctx context.Context ctx context.Context
idpID string idpID string
internalUI bool
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want int want int
wantACS []md.IndexedEndpointType
}{ }{
{ {
name: "saml metadata, invalid idp", name: "saml metadata, invalid idp",
@ -142,11 +145,115 @@ func TestServer_SAMLMetadata(t *testing.T) {
idpID: samlRedirectIdpID, idpID: samlRedirectIdpID,
}, },
want: http.StatusOK, want: http.StatusOK,
wantACS: []md.IndexedEndpointType{
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "AssertionConsumerService",
},
Index: "1",
IsDefault: "",
Binding: saml.HTTPPostBinding,
Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/idps/" + samlRedirectIdpID + "/saml/acs",
ResponseLocation: "",
},
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "AssertionConsumerService",
},
Index: "2",
IsDefault: "",
Binding: saml.HTTPArtifactBinding,
Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/idps/" + samlRedirectIdpID + "/saml/acs",
ResponseLocation: "",
},
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "AssertionConsumerService",
},
Index: "3",
IsDefault: "",
Binding: saml.HTTPPostBinding,
Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/ui/login/login/externalidp/saml/acs",
ResponseLocation: "",
},
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "AssertionConsumerService",
},
Index: "4",
IsDefault: "",
Binding: saml.HTTPArtifactBinding,
Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/ui/login/login/externalidp/saml/acs",
ResponseLocation: "",
},
},
},
{
name: "saml metadata, ok (internalUI)",
args: args{
ctx: CTX,
idpID: samlRedirectIdpID,
internalUI: true,
},
want: http.StatusOK,
wantACS: []md.IndexedEndpointType{
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "AssertionConsumerService",
},
Index: "0",
IsDefault: "true",
Binding: saml.HTTPPostBinding,
Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/ui/login/login/externalidp/saml/acs",
ResponseLocation: "",
},
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "AssertionConsumerService",
},
Index: "1",
IsDefault: "",
Binding: saml.HTTPArtifactBinding,
Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/ui/login/login/externalidp/saml/acs",
ResponseLocation: "",
},
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "AssertionConsumerService",
},
Index: "2",
IsDefault: "",
Binding: saml.HTTPPostBinding,
Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/idps/" + samlRedirectIdpID + "/saml/acs",
ResponseLocation: "",
},
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "AssertionConsumerService",
},
Index: "3",
IsDefault: "",
Binding: saml.HTTPArtifactBinding,
Location: http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/idps/" + samlRedirectIdpID + "/saml/acs",
ResponseLocation: "",
},
},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
metadataURL := http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/idps/" + tt.args.idpID + "/saml/metadata" metadataURL := http_util.BuildOrigin(Instance.Host(), Instance.Config.Secure) + "/idps/" + tt.args.idpID + "/saml/metadata"
if tt.args.internalUI {
metadataURL = metadataURL + "?internalUI=true"
}
resp, err := http.Get(metadataURL) resp, err := http.Get(metadataURL)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tt.want, resp.StatusCode) assert.Equal(t, tt.want, resp.StatusCode)
@ -155,10 +262,11 @@ func TestServer_SAMLMetadata(t *testing.T) {
defer resp.Body.Close() defer resp.Body.Close()
assert.NoError(t, err) assert.NoError(t, err)
_, err = saml_xml.ParseMetadataXmlIntoStruct(b) metadata, err := saml_xml.ParseMetadataXmlIntoStruct(b)
assert.NoError(t, err) assert.NoError(t, err)
}
assert.Equal(t, metadata.SPSSODescriptor.AssertionConsumerService, tt.wantACS)
}
}) })
} }
} }