Livio Spring 374b9a7f66
fix(saml): provide option to get internal as default ACS (#8888)
# Which Problems Are Solved

Some SAML IdPs, including Google, only allow a single
AssertionConsumerService URL to be configured.
Since the current metadata provides multiple URLs and the hosted login UI
is published neither first nor with `isDefault=true`, those IdPs
pick a different one and then return an error on sign-in.

# How the Problems Are Solved

Allow reordering the ACS URLs using a query parameter
(`internalUI=true`) when retrieving the metadata endpoint.
This lists `ui/login/login/externalidp/saml/acs` first and also
sets `isDefault=true` on it.

# Additional Changes

None

# Additional Context

Reported by a customer
2024-11-15 07:19:43 +01:00

471 lines
15 KiB
Go

package idp
import (
"bytes"
"context"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"github.com/crewjam/saml"
"github.com/gorilla/mux"
"github.com/muhlemmer/gu"
"github.com/zitadel/logging"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/form"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/apple"
"github.com/zitadel/zitadel/internal/idp/providers/azuread"
"github.com/zitadel/zitadel/internal/idp/providers/github"
"github.com/zitadel/zitadel/internal/idp/providers/gitlab"
"github.com/zitadel/zitadel/internal/idp/providers/google"
"github.com/zitadel/zitadel/internal/idp/providers/jwt"
"github.com/zitadel/zitadel/internal/idp/providers/ldap"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
openid "github.com/zitadel/zitadel/internal/idp/providers/oidc"
saml2 "github.com/zitadel/zitadel/internal/idp/providers/saml"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors"
)
// Route patterns, redirect query parameter names and path variable names used
// by the IDP handler.
const (
// HandlerPrefix is the path prefix used when building the callback and SAML root URLs.
HandlerPrefix = "/idps"
// idpPrefix is the route pattern segment capturing a numeric IDP ID as path variable.
idpPrefix = "/{" + varIDPID + ":[0-9]+}"
callbackPath = "/callback"
metadataPath = idpPrefix + "/saml/metadata"
acsPath = idpPrefix + "/saml/acs"
certificatePath = idpPrefix + "/saml/certificate"
// Query parameters set on the success / failure redirect URLs.
paramIntentID = "id"
paramToken = "token"
paramUserID = "user"
paramError = "error"
paramErrorDescription = "error_description"
// varIDPID is the name of the path variable holding the IDP ID.
varIDPID = "idpid"
// paramInternalUI is the query parameter read by the metadata endpoint to
// list the login UI ACS endpoints first and mark the first one as default.
paramInternalUI = "internalUI"
)
// Handler bundles the dependencies of the HTTP handlers for external IDP
// callbacks and the SAML metadata, certificate and ACS endpoints.
type Handler struct {
commands *command.Commands
queries *query.Queries
// parser decodes the callback request into externalIDPCallbackData.
parser *form.Parser
encryptionAlgorithm crypto.EncryptionAlgorithm
// callbackURL builds the callback URL passed to the provider lookup.
callbackURL func(ctx context.Context) string
// samlRootURL builds the per-IDP SAML root URL passed to the provider lookup.
samlRootURL func(ctx context.Context, idpID string) string
// loginSAMLRootURL builds the location of the login UI ACS endpoints
// that are added to the SAML metadata.
loginSAMLRootURL func(ctx context.Context) string
}
// externalIDPCallbackData holds the values parsed from a callback request.
// State is mandatory and is used to look up the active intent; Error and
// ErrorDescription carry a failure reported by the provider.
type externalIDPCallbackData struct {
State string `schema:"state"`
Code string `schema:"code"`
Error string `schema:"error"`
ErrorDescription string `schema:"error_description"`
// Apple returns a user on first registration
User string `schema:"user"`
}
// externalSAMLIDPCallbackData holds the values extracted from a request to one
// of the SAML endpoints: the IDP ID from the path and the "SAMLResponse" and
// "RelayState" form values.
type externalSAMLIDPCallbackData struct {
IDPID string
Response string
RelayState string
}
// CallbackURL generates the instance specific URL to the IDP callback handler
func CallbackURL() func(ctx context.Context) string {
return func(ctx context.Context) string {
return http_utils.DomainContext(ctx).Origin() + HandlerPrefix + callbackPath
}
}
// SAMLRootURL returns a builder for the SAML root URL of a single IDP.
// The result is "<origin><HandlerPrefix>/<idpID>/", with the origin taken
// from the domain context of ctx.
func SAMLRootURL() func(ctx context.Context, idpID string) string {
	return func(ctx context.Context, idpID string) string {
		origin := http_utils.DomainContext(ctx).Origin()
		return origin + HandlerPrefix + "/" + idpID + "/"
	}
}
func LoginSAMLRootURL() func(ctx context.Context) string {
return func(ctx context.Context) string {
return http_utils.DomainContext(ctx).Origin() + login.HandlerPrefix + login.EndpointSAMLACS
}
}
// NewHandler wires up a Handler with the given dependencies and returns a
// router serving the callback, SAML metadata, SAML certificate and SAML ACS
// endpoints. The instanceInterceptor is installed as middleware on the router.
func NewHandler(
	commands *command.Commands,
	queries *query.Queries,
	encryptionAlgorithm crypto.EncryptionAlgorithm,
	instanceInterceptor func(next http.Handler) http.Handler,
) http.Handler {
	handler := new(Handler)
	handler.commands = commands
	handler.queries = queries
	handler.parser = form.NewParser()
	handler.encryptionAlgorithm = encryptionAlgorithm
	handler.callbackURL = CallbackURL()
	handler.samlRootURL = SAMLRootURL()
	handler.loginSAMLRootURL = LoginSAMLRootURL()

	router := mux.NewRouter()
	router.Use(instanceInterceptor)

	// Registration order is kept as a slice on purpose: routes are added in
	// exactly this sequence.
	routes := []struct {
		path   string
		handle func(http.ResponseWriter, *http.Request)
	}{
		{callbackPath, handler.handleCallback},
		{metadataPath, handler.handleMetadata},
		{certificatePath, handler.handleCertificate},
		{acsPath, handler.handleACS},
	}
	for _, route := range routes {
		router.HandleFunc(route.path, route.handle)
	}
	return router
}
// parseSAMLRequest collects the IDP ID path variable and the "SAMLResponse"
// and "RelayState" form values of the request. Missing values are left empty.
func parseSAMLRequest(r *http.Request) *externalSAMLIDPCallbackData {
	data := new(externalSAMLIDPCallbackData)
	data.IDPID = mux.Vars(r)[varIDPID]
	data.Response = r.FormValue("SAMLResponse")
	data.RelayState = r.FormValue("RelayState")
	return data
}
// getProvider loads the provider with the given ID through the commands,
// handing over the callback URL and the SAML root URL built for this request.
func (h *Handler) getProvider(ctx context.Context, idpID string) (idp.Provider, error) {
	callback := h.callbackURL(ctx)
	samlRoot := h.samlRootURL(ctx, idpID)
	return h.commands.GetProvider(ctx, idpID, callback, samlRoot)
}
// handleCertificate serves the certificate of the SAML provider addressed by
// the IDP ID in the request path as a download named "idp.crt".
//
// It answers with 400 Bad Request when the provider cannot be loaded or is not
// a SAML provider. If writing the body fails, it attempts to report a
// 500 Internal Server Error.
func (h *Handler) handleCertificate(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	data := parseSAMLRequest(r)
	provider, err := h.getProvider(ctx, data.IDPID)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	samlProvider, ok := provider.(*saml2.Provider)
	if !ok {
		http.Error(w, zerrors.ThrowInvalidArgument(nil, "SAML-lrud8s9coi", "Errors.Intent.IDPInvalid").Error(), http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Disposition", "attachment; filename=idp.crt")
	// NOTE(review): the response Content-Type mirrors the request header,
	// which is usually empty on a download request — confirm whether a fixed
	// certificate media type is intended here.
	w.Header().Set("Content-Type", r.Header.Get("Content-Type"))

	// Stream the certificate bytes directly; the former intermediate buffer
	// only added a copy and an error branch that could never be taken
	// (bytes.Buffer.Write always returns a nil error).
	if _, err = io.Copy(w, bytes.NewReader(samlProvider.Certificate)); err != nil {
		http.Error(w, fmt.Errorf("failed to response with certificate: %w", err).Error(), http.StatusInternalServerError)
		return
	}
}
// handleMetadata serves the service provider metadata of the SAML provider
// addressed by the IDP ID in the request path as indented XML.
//
// The ACS endpoints of the login UI are added to the metadata; with the query
// parameter "internalUI" set to a true value they are listed first and the
// first one is marked as default (see assertionConsumerServices).
//
// It answers with 400 Bad Request when the provider cannot be loaded, is not a
// SAML provider or its service provider cannot be created, and with
// 500 Internal Server Error when the metadata cannot be marshalled.
func (h *Handler) handleMetadata(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	data := parseSAMLRequest(r)
	provider, err := h.getProvider(ctx, data.IDPID)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	samlProvider, ok := provider.(*saml2.Provider)
	if !ok {
		http.Error(w, zerrors.ThrowInvalidArgument(nil, "SAML-lrud8s9coi", "Errors.Intent.IDPInvalid").Error(), http.StatusBadRequest)
		return
	}
	sp, err := samlProvider.GetSP()
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	metadata := sp.ServiceProvider.Metadata()

	// The parse error is ignored on purpose: a missing or malformed parameter
	// yields false, i.e. the default ordering.
	internalUI, _ := strconv.ParseBool(r.URL.Query().Get(paramInternalUI))
	h.assertionConsumerServices(ctx, metadata, internalUI)

	// Do not send an empty 200 response when marshalling fails.
	buf, err := xml.MarshalIndent(metadata, "", " ")
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/samlmetadata+xml")
	if _, err = w.Write(buf); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
}
// assertionConsumerServices adds the login UI ACS endpoints (HTTP-POST and
// HTTP-Artifact binding) to every SP SSO descriptor of the metadata.
//
// With internalUI false the two endpoints are appended after the existing ones
// and indexed with the previous number of endpoints plus 1 and 2.
// With internalUI true they are placed in front with indexes 0 and 1, the
// HTTP-POST one is flagged as default and the existing endpoints are
// re-indexed starting at 2.
func (h *Handler) assertionConsumerServices(ctx context.Context, metadata *saml.EntityDescriptor, internalUI bool) {
	for d := range metadata.SPSSODescriptors {
		descriptor := &metadata.SPSSODescriptors[d]
		existing := descriptor.AssertionConsumerServices

		if !internalUI {
			count := len(existing)
			descriptor.AssertionConsumerServices = append(existing,
				saml.IndexedEndpoint{
					Binding:  saml.HTTPPostBinding,
					Location: h.loginSAMLRootURL(ctx),
					Index:    count + 1,
				},
				saml.IndexedEndpoint{
					Binding:  saml.HTTPArtifactBinding,
					Location: h.loginSAMLRootURL(ctx),
					Index:    count + 2,
				},
			)
			continue
		}

		reordered := make([]saml.IndexedEndpoint, 0, len(existing)+2)
		reordered = append(reordered,
			saml.IndexedEndpoint{
				Binding:   saml.HTTPPostBinding,
				Location:  h.loginSAMLRootURL(ctx),
				Index:     0,
				IsDefault: gu.Ptr(true),
			},
			saml.IndexedEndpoint{
				Binding:  saml.HTTPArtifactBinding,
				Location: h.loginSAMLRootURL(ctx),
				Index:    1,
			},
		)
		// Shift the existing endpoints behind the two login UI ones.
		for pos := range existing {
			existing[pos].Index = pos + 2
			reordered = append(reordered, existing[pos])
		}
		descriptor.AssertionConsumerServices = reordered
	}
}
// handleACS processes a SAML response sent to the assertion consumer service
// of the IDP addressed by the IDP ID in the request path.
//
// The RelayState is used to look up the active intent. On success the intent
// is marked as succeeded and the user agent is redirected to the intent's
// success URL; failures after the intent was loaded are redirected to its
// failure URL. Problems before an intent is available are answered directly
// with an HTTP error.
func (h *Handler) handleACS(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	data := parseSAMLRequest(r)
	provider, err := h.getProvider(ctx, data.IDPID)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	samlProvider, ok := provider.(*saml2.Provider)
	if !ok {
		err := zerrors.ThrowInvalidArgument(nil, "SAML-ui9wyux0hp", "Errors.Intent.IDPInvalid")
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	intent, err := h.commands.GetActiveIntent(ctx, data.RelayState)
	if err != nil {
		if zerrors.IsNotFound(err) {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// The redirect helpers dereference the intent to read its failure URL.
		// GetActiveIntent is defined outside this file; in case it returns no
		// intent together with the error, answer directly instead of panicking.
		if intent == nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		redirectToFailureURLErr(w, r, intent, err)
		return
	}
	session, err := saml2.NewSession(samlProvider, intent.RequestID, r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	idpUser, err := session.FetchUser(ctx)
	if err != nil {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
		logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
		redirectToFailureURLErr(w, r, intent, err)
		return
	}
	// Best effort: a failed lookup is only logged and the intent succeeds
	// without a linked user ID.
	userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
	logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists")
	token, err := h.commands.SucceedSAMLIDPIntent(ctx, intent, idpUser, userID, session.Assertion)
	if err != nil {
		redirectToFailureURLErr(w, r, intent, zerrors.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
		return
	}
	redirectToSuccessURL(w, r, intent, token, userID)
}
// handleCallback processes the callback of an external (non-SAML) IDP.
//
// The state of the request is used to look up the active intent. The user is
// fetched from the provider with the returned code; if no user is linked to
// the external ID yet, a migration from a previous external ID is attempted.
// On success the intent is marked as succeeded and the user agent is
// redirected to the intent's success URL; failures after the intent was loaded
// are recorded on the intent and redirected to its failure URL. Problems
// before an intent is available are answered directly with an HTTP error.
func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	data, err := h.parseCallbackRequest(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	intent, err := h.commands.GetActiveIntent(ctx, data.State)
	if err != nil {
		if zerrors.IsNotFound(err) {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// The redirect helpers dereference the intent to read its failure URL.
		// GetActiveIntent is defined outside this file; in case it returns no
		// intent together with the error, answer directly instead of panicking.
		if intent == nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	// the provider might have returned an error
	if data.Error != "" {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, reason(data.Error, data.ErrorDescription))
		logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
		redirectToFailureURL(w, r, intent, data.Error, data.ErrorDescription)
		return
	}

	provider, err := h.getProvider(ctx, intent.IDPID)
	if err != nil {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
		logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	idpUser, idpSession, err := h.fetchIDPUserFromCode(ctx, provider, data.Code, data.User)
	if err != nil {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
		logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	// Best effort: failed lookups / migrations are only logged and the intent
	// succeeds with whatever user ID could be determined.
	userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
	logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists")
	if userID == "" {
		userID, err = h.tryMigrateExternalUser(ctx, intent.IDPID, idpUser, idpSession)
		logging.WithFields("intent", intent.AggregateID).OnError(err).Error("migration check failed")
	}

	token, err := h.commands.SucceedIDPIntent(ctx, intent, idpUser, idpSession, userID)
	if err != nil {
		redirectToFailureURLErr(w, r, intent, zerrors.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
		return
	}
	redirectToSuccessURL(w, r, intent, token, userID)
}
// tryMigrateExternalUser looks for a user that is linked to the IDP under a
// previous external ID and, if one is found, migrates the link to the current
// external ID of idpUser.
//
// It returns an empty user ID and a nil error when the session does not
// support migration, when there is no previous ID, or when no user is uniquely
// linked to the previous ID. When the migration command itself fails, the
// found user ID is returned together with the error.
func (h *Handler) tryMigrateExternalUser(ctx context.Context, idpID string, idpUser idp.User, idpSession idp.Session) (userID string, err error) {
	migration, ok := idpSession.(idp.SessionSupportsMigration)
	if !ok {
		return "", nil
	}
	previousID, err := migration.RetrievePreviousID()
	if err != nil || previousID == "" {
		return "", err
	}
	userID, err = h.checkExternalUser(ctx, idpID, previousID)
	// checkExternalUser reports "no unique link" as an empty ID without error;
	// there is nothing to migrate then, so do not issue a command for an
	// empty user ID.
	if err != nil || userID == "" {
		return "", err
	}
	return userID, h.commands.MigrateUserIDP(ctx, userID, "", idpID, previousID, idpUser.GetID())
}
// parseCallbackRequest decodes the callback request with the handler's parser.
// It returns the parser's error if decoding fails and an invalid argument
// error if the request carries no state.
func (h *Handler) parseCallbackRequest(r *http.Request) (*externalIDPCallbackData, error) {
	parsed := &externalIDPCallbackData{}
	if err := h.parser.Parse(r, parsed); err != nil {
		return nil, err
	}
	if parsed.State != "" {
		return parsed, nil
	}
	return nil, zerrors.ThrowInvalidArgument(nil, "IDP-Hk38e", "Errors.Intent.StateMissing")
}
// redirectToSuccessURL redirects (302 Found) to the success URL of the intent
// after setting the intent ID and token as query parameters; the user ID is
// only set when it is not empty. The query of the intent's success URL is
// overwritten in the process.
func redirectToSuccessURL(w http.ResponseWriter, r *http.Request, intent *command.IDPIntentWriteModel, token, userID string) {
	params := intent.SuccessURL.Query()
	params.Set(paramIntentID, intent.AggregateID)
	params.Set(paramToken, token)
	if len(userID) > 0 {
		params.Set(paramUserID, userID)
	}
	intent.SuccessURL.RawQuery = params.Encode()

	location := intent.SuccessURL.String()
	http.Redirect(w, r, location, http.StatusFound)
}
// redirectToFailureURLErr redirects to the failure URL of the intent using the
// given error. For a ZitadelError the error ID is sent as error and its
// message as description; for any other error the error text is sent as error
// with an empty description.
func redirectToFailureURLErr(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err error) {
	code, details := err.Error(), ""
	zitadelErr := new(zerrors.ZitadelError)
	if errors.As(err, &zitadelErr) {
		code = zitadelErr.GetID()
		details = zitadelErr.GetMessage() // TODO: i18n?
	}
	redirectToFailureURL(w, r, i, code, details)
}
// redirectToFailureURL redirects (302 Found) to the failure URL of the intent
// after setting the intent ID, the error and its description as query
// parameters. The query of the intent's failure URL is overwritten in the
// process.
func redirectToFailureURL(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err, description string) {
	params := i.FailureURL.Query()
	for _, kv := range [][2]string{
		{paramIntentID, i.AggregateID},
		{paramError, err},
		{paramErrorDescription, description},
	} {
		params.Set(kv[0], kv[1])
	}
	i.FailureURL.RawQuery = params.Encode()

	location := i.FailureURL.String()
	http.Redirect(w, r, location, http.StatusFound)
}
// fetchIDPUserFromCode creates the session matching the concrete provider type
// from the given code and fetches the user through it. appleUser is only
// passed on to Apple sessions.
//
// JWT, LDAP and SAML providers are rejected with an invalid argument error,
// any other unknown provider type with an unimplemented error.
func (h *Handler) fetchIDPUserFromCode(ctx context.Context, identityProvider idp.Provider, code string, appleUser string) (user idp.User, idpTokens idp.Session, err error) {
	var sess idp.Session
	switch p := identityProvider.(type) {
	case *oauth.Provider:
		sess = &oauth.Session{Provider: p, Code: code}
	case *openid.Provider:
		sess = &openid.Session{Provider: p, Code: code}
	case *azuread.Provider:
		sess = &azuread.Session{Provider: p, Code: code}
	case *github.Provider:
		sess = &oauth.Session{Provider: p.Provider, Code: code}
	case *gitlab.Provider:
		sess = &openid.Session{Provider: p.Provider, Code: code}
	case *google.Provider:
		sess = &openid.Session{Provider: p.Provider, Code: code}
	case *apple.Provider:
		oidcSession := &openid.Session{Provider: p.Provider, Code: code}
		sess = &apple.Session{Session: oidcSession, UserFormValue: appleUser}
	case *jwt.Provider, *ldap.Provider, *saml2.Provider:
		return nil, nil, zerrors.ThrowInvalidArgument(nil, "IDP-52jmn", "Errors.ExternalIDP.IDPTypeNotImplemented")
	default:
		return nil, nil, zerrors.ThrowUnimplemented(nil, "IDP-SSDg", "Errors.ExternalIDP.IDPTypeNotImplemented")
	}

	if user, err = sess.FetchUser(ctx); err != nil {
		return nil, nil, err
	}
	return user, sess, nil
}
// checkExternalUser searches the IDP user links for the given IDP ID and
// external user ID. It returns the linked user ID only if exactly one link
// matches; otherwise the user ID is empty. Errors from building or running the
// search are returned as is.
func (h *Handler) checkExternalUser(ctx context.Context, idpID, externalUserID string) (userID string, err error) {
	byIDP, err := query.NewIDPUserLinkIDPIDSearchQuery(idpID)
	if err != nil {
		return "", err
	}
	byExternalID, err := query.NewIDPUserLinksExternalIDSearchQuery(externalUserID)
	if err != nil {
		return "", err
	}

	search := &query.IDPUserLinksSearchQuery{
		Queries: []query.SearchQuery{byIDP, byExternalID},
	}
	result, err := h.queries.IDPUserLinks(ctx, search, nil)
	if err != nil {
		return "", err
	}
	if len(result.Links) == 1 {
		return result.Links[0].UserID, nil
	}
	return "", nil
}
// reason combines an error and its description into a single text of the form
// "err: description". Without a description only err is returned.
func reason(err, description string) string {
	if description != "" {
		return err + ": " + description
	}
	return err
}