Livio Spring 8f88c4cf5b
feat: add PKCE option to generic OAuth2 / OIDC identity providers (#9373)
# Which Problems Are Solved

Some OAuth2 and OIDC providers require the use of PKCE for all their
clients. While ZITADEL already recommended the same for its clients, it
did not yet support the option on the IdP configuration.

# How the Problems Are Solved

- A new boolean `use_pkce` is added to the add/update generic OAuth/OIDC
endpoints.
- A new checkbox is added to the generic OAuth and OIDC provider
templates.
- The `rp.WithPKCE` option is added to the provider if the use of PKCE
has been set.
- The `rp.WithCodeChallenge` and `rp.WithCodeVerifier` options are added
to the OIDC/OAuth BeginAuth and CodeExchange functions.
- The verifier and any other persistent arguments are stored in the
intent or auth request.
- The corresponding session object is created before the intent, so that
this information can be stored.
- (refactored session structs to use a constructor for unified creation
and better overview of actual usage)

Here's a screenshot showing the URI including the PKCE params:


![use_pkce_in_url](https://github.com/zitadel/zitadel/assets/30386061/eaeab123-a5da-4826-b001-2ae9efa35169)

# Additional Changes

None.

# Additional Context

- Closes #6449
- This PR replaces the existing PR (#8228) by @doncicuto, whose initial
work was cherry-picked as the base. Thank you very much for that!

---------

Co-authored-by: Miguel Cabrerizo <doncicuto@gmail.com>
Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
2025-02-26 12:20:47 +00:00

471 lines
15 KiB
Go

package idp
import (
"bytes"
"context"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"github.com/crewjam/saml"
"github.com/gorilla/mux"
"github.com/muhlemmer/gu"
"github.com/zitadel/logging"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/form"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/apple"
"github.com/zitadel/zitadel/internal/idp/providers/azuread"
"github.com/zitadel/zitadel/internal/idp/providers/github"
"github.com/zitadel/zitadel/internal/idp/providers/gitlab"
"github.com/zitadel/zitadel/internal/idp/providers/google"
"github.com/zitadel/zitadel/internal/idp/providers/jwt"
"github.com/zitadel/zitadel/internal/idp/providers/ldap"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
openid "github.com/zitadel/zitadel/internal/idp/providers/oidc"
saml2 "github.com/zitadel/zitadel/internal/idp/providers/saml"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors"
)
// Routing paths and request parameter names used by the IDP handler.
const (
	// HandlerPrefix is the path prefix placed in front of the handler's
	// paths when absolute URLs are built.
	HandlerPrefix = "/idps"

	// varIDPID names the path variable carrying the IDP ID; idpPrefix
	// restricts that variable to digits.
	varIDPID  = "idpid"
	idpPrefix = "/{" + varIDPID + ":[0-9]+}"

	// Paths of the individual endpoints.
	callbackPath    = "/callback"
	metadataPath    = idpPrefix + "/saml/metadata"
	acsPath         = idpPrefix + "/saml/acs"
	certificatePath = idpPrefix + "/saml/certificate"

	// Query parameters appended to the success and failure redirect URLs.
	paramIntentID         = "id"
	paramToken            = "token"
	paramUserID           = "user"
	paramError            = "error"
	paramErrorDescription = "error_description"

	// paramInternalUI is the query parameter read by the metadata endpoint.
	paramInternalUI = "internalUI"
)
// Handler serves the IDP endpoints: the callback, and the SAML metadata,
// certificate and assertion consumer service routes.
type Handler struct {
	// Backend access used to load providers, intents and user links.
	commands *command.Commands
	queries  *query.Queries

	// parser decodes the callback request into externalIDPCallbackData.
	parser              *form.Parser
	encryptionAlgorithm crypto.EncryptionAlgorithm

	// URL builders; each derives its result from the given context.
	callbackURL      func(ctx context.Context) string
	samlRootURL      func(ctx context.Context, idpID string) string
	loginSAMLRootURL func(ctx context.Context) string
}
// externalIDPCallbackData holds the parameters an external identity provider
// sends to the callback endpoint.
type externalIDPCallbackData struct {
	// State is required; parseCallbackRequest rejects requests without it.
	State string `schema:"state"`
	// Code is handed to the provider session to fetch the user.
	Code string `schema:"code"`
	// Error and ErrorDescription are set when the provider reports a failure.
	Error            string `schema:"error"`
	ErrorDescription string `schema:"error_description"`
	// User is returned by Apple on first registration.
	User string `schema:"user"`
}
// externalSAMLIDPCallbackData collects the values parseSAMLRequest extracts
// from a request to one of the SAML endpoints.
type externalSAMLIDPCallbackData struct {
	// IDPID is taken from the request path.
	IDPID string
	// Response is the "SAMLResponse" form value.
	Response string
	// RelayState is the "RelayState" form value; the ACS endpoint uses it as
	// the intent ID.
	RelayState string
}
// CallbackURL returns a builder for the instance specific URL of the IDP
// callback handler: the origin of the context's domain followed by
// HandlerPrefix and callbackPath.
func CallbackURL() func(ctx context.Context) string {
	build := func(ctx context.Context) string {
		origin := http_utils.DomainContext(ctx).Origin()
		return origin + HandlerPrefix + callbackPath
	}
	return build
}
// SAMLRootURL returns a builder for the root URL of a single IDP's SAML
// endpoints: the origin of the context's domain, HandlerPrefix and the IDP ID
// wrapped in slashes.
func SAMLRootURL() func(ctx context.Context, idpID string) string {
	build := func(ctx context.Context, idpID string) string {
		origin := http_utils.DomainContext(ctx).Origin()
		return origin + HandlerPrefix + "/" + idpID + "/"
	}
	return build
}
// LoginSAMLRootURL returns a builder for the URL of the login UI's SAML ACS
// endpoint: the origin of the context's domain followed by the login handler
// prefix and its SAML ACS endpoint path.
func LoginSAMLRootURL() func(ctx context.Context) string {
	build := func(ctx context.Context) string {
		origin := http_utils.DomainContext(ctx).Origin()
		return origin + login.HandlerPrefix + login.EndpointSAMLACS
	}
	return build
}
// NewHandler builds the IDP handler and returns a router serving its
// callback, SAML metadata, SAML certificate and SAML ACS endpoints.
// The instanceInterceptor is installed as middleware on the router.
func NewHandler(
	commands *command.Commands,
	queries *query.Queries,
	encryptionAlgorithm crypto.EncryptionAlgorithm,
	instanceInterceptor func(next http.Handler) http.Handler,
) http.Handler {
	handler := new(Handler)
	handler.commands = commands
	handler.queries = queries
	handler.parser = form.NewParser()
	handler.encryptionAlgorithm = encryptionAlgorithm
	handler.callbackURL = CallbackURL()
	handler.samlRootURL = SAMLRootURL()
	handler.loginSAMLRootURL = LoginSAMLRootURL()

	router := mux.NewRouter()
	router.Use(instanceInterceptor)

	// Routes are registered in this fixed order.
	routes := []struct {
		path  string
		serve func(http.ResponseWriter, *http.Request)
	}{
		{callbackPath, handler.handleCallback},
		{metadataPath, handler.handleMetadata},
		{certificatePath, handler.handleCertificate},
		{acsPath, handler.handleACS},
	}
	for _, route := range routes {
		router.HandleFunc(route.path, route.serve)
	}
	return router
}
// parseSAMLRequest extracts the IDP ID from the request's path variables and
// the "SAMLResponse" and "RelayState" form values from the request.
func parseSAMLRequest(r *http.Request) *externalSAMLIDPCallbackData {
	parsed := new(externalSAMLIDPCallbackData)
	parsed.IDPID = mux.Vars(r)[varIDPID]
	parsed.Response = r.FormValue("SAMLResponse")
	parsed.RelayState = r.FormValue("RelayState")
	return parsed
}
// getProvider loads the provider with the given ID, passing along the
// callback URL and the provider's SAML root URL built for this context.
func (h *Handler) getProvider(ctx context.Context, idpID string) (idp.Provider, error) {
	callback := h.callbackURL(ctx)
	samlRoot := h.samlRootURL(ctx, idpID)
	return h.commands.GetProvider(ctx, idpID, callback, samlRoot)
}
// handleCertificate serves the certificate of the SAML provider addressed by
// the IDP ID in the request path as a download named "idp.crt".
//
// It answers with 400 when the provider cannot be loaded or is not a SAML
// provider.
func (h *Handler) handleCertificate(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	data := parseSAMLRequest(r)
	provider, err := h.getProvider(ctx, data.IDPID)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	samlProvider, ok := provider.(*saml2.Provider)
	if !ok {
		http.Error(w, zerrors.ThrowInvalidArgument(nil, "SAML-lrud8s9coi", "Errors.Intent.IDPInvalid").Error(), http.StatusBadRequest)
		return
	}
	w.Header().Set("Content-Disposition", "attachment; filename=idp.crt")
	// Use a fixed media type: echoing the request's Content-Type header would
	// let the client choose the response type and yields an empty header on
	// requests that carry none.
	w.Header().Set("Content-Type", "application/octet-stream")
	// Stream the certificate bytes directly; no intermediate buffer is needed.
	if _, err = io.Copy(w, bytes.NewReader(samlProvider.Certificate)); err != nil {
		http.Error(w, fmt.Errorf("failed to response with certificate: %w", err).Error(), http.StatusInternalServerError)
		return
	}
}
// handleMetadata serves the SAML service provider metadata of the provider
// addressed by the IDP ID in the request path.
//
// The metadata's assertion consumer services are extended with the login UI
// endpoints; the "internalUI" query parameter selects how (see
// assertionConsumerServices). It answers with 400 when the provider cannot be
// loaded, is not a SAML provider, or its service provider cannot be built, and
// with 500 when the metadata cannot be serialized.
func (h *Handler) handleMetadata(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	data := parseSAMLRequest(r)
	provider, err := h.getProvider(ctx, data.IDPID)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	samlProvider, ok := provider.(*saml2.Provider)
	if !ok {
		http.Error(w, zerrors.ThrowInvalidArgument(nil, "SAML-lrud8s9coi", "Errors.Intent.IDPInvalid").Error(), http.StatusBadRequest)
		return
	}
	sp, err := samlProvider.GetSP()
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	metadata := sp.ServiceProvider.Metadata()
	// A missing or unparsable flag falls back to false.
	internalUI, _ := strconv.ParseBool(r.URL.Query().Get(paramInternalUI))
	h.assertionConsumerServices(ctx, metadata, internalUI)
	// Report serialization failures instead of answering 200 with an empty body.
	buf, err := xml.MarshalIndent(metadata, "", " ")
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/samlmetadata+xml")
	_, err = w.Write(buf)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
}
// assertionConsumerServices adds the login UI's SAML ACS URL, once with the
// HTTP-POST and once with the HTTP-Artifact binding, to every SP SSO
// descriptor of the metadata.
//
// With internalUI set, the two login endpoints come first (indexes 0 and 1,
// the POST endpoint marked as default) and the existing endpoints are
// renumbered from 2. Otherwise the two login endpoints are appended after the
// existing ones, indexed len+1 and len+2.
func (h *Handler) assertionConsumerServices(ctx context.Context, metadata *saml.EntityDescriptor, internalUI bool) {
	for d := range metadata.SPSSODescriptors {
		descriptor := metadata.SPSSODescriptors[d]
		existing := descriptor.AssertionConsumerServices

		if internalUI {
			merged := make([]saml.IndexedEndpoint, 0, len(existing)+2)
			merged = append(merged,
				saml.IndexedEndpoint{
					Binding:   saml.HTTPPostBinding,
					Location:  h.loginSAMLRootURL(ctx),
					Index:     0,
					IsDefault: gu.Ptr(true),
				},
				saml.IndexedEndpoint{
					Binding:  saml.HTTPArtifactBinding,
					Location: h.loginSAMLRootURL(ctx),
					Index:    1,
				},
			)
			// Shift the existing endpoints behind the two login endpoints.
			for e := range existing {
				existing[e].Index = 2 + e
				merged = append(merged, existing[e])
			}
			descriptor.AssertionConsumerServices = merged
		} else {
			count := len(existing)
			descriptor.AssertionConsumerServices = append(
				existing,
				saml.IndexedEndpoint{
					Binding:  saml.HTTPPostBinding,
					Location: h.loginSAMLRootURL(ctx),
					Index:    count + 1,
				},
				saml.IndexedEndpoint{
					Binding:  saml.HTTPArtifactBinding,
					Location: h.loginSAMLRootURL(ctx),
					Index:    count + 2,
				},
			)
		}
		metadata.SPSSODescriptors[d] = descriptor
	}
}
// handleACS is the SAML assertion consumer service endpoint of an IDP.
//
// It loads the SAML provider addressed by the request path and the active
// intent identified by the RelayState, fetches the user from the SAML session
// and marks the intent as succeeded. The client is then redirected to the
// intent's success URL, or to its failure URL if fetching the user or
// succeeding the intent fails. Errors that occur before an intent is
// available are answered directly with an HTTP error.
func (h *Handler) handleACS(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	data := parseSAMLRequest(r)
	provider, err := h.getProvider(ctx, data.IDPID)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	samlProvider, ok := provider.(*saml2.Provider)
	if !ok {
		err := zerrors.ThrowInvalidArgument(nil, "SAML-ui9wyux0hp", "Errors.Intent.IDPInvalid")
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	intent, err := h.commands.GetActiveIntent(ctx, data.RelayState)
	if err != nil {
		if zerrors.IsNotFound(err) {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// The failure redirect reads its target from the intent, so it is only
		// possible when an intent was returned alongside the error.
		if intent == nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		redirectToFailureURLErr(w, r, intent, err)
		return
	}
	session, err := saml2.NewSession(samlProvider, intent.RequestID, r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	idpUser, err := session.FetchUser(ctx)
	if err != nil {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
		logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
		redirectToFailureURLErr(w, r, intent, err)
		return
	}
	// A failed lookup is only logged; the intent then succeeds without a user ID.
	userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
	logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists")
	token, err := h.commands.SucceedSAMLIDPIntent(ctx, intent, idpUser, userID, session.Assertion)
	if err != nil {
		redirectToFailureURLErr(w, r, intent, zerrors.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
		return
	}
	redirectToSuccessURL(w, r, intent, token, userID)
}
// handleCallback handles the redirect back from an external identity
// provider.
//
// It parses the callback parameters, loads the active intent identified by
// the state, fetches the user from the provider using the returned code and
// marks the intent as succeeded. The client is then redirected to the
// intent's success URL. If the provider reported an error, or any later step
// fails, the client is redirected to the intent's failure URL instead.
// Errors that occur before an intent is available are answered directly with
// an HTTP error.
func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	data, err := h.parseCallbackRequest(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	intent, err := h.commands.GetActiveIntent(ctx, data.State)
	if err != nil {
		if zerrors.IsNotFound(err) {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		// The failure redirect reads its target from the intent, so it is only
		// possible when an intent was returned alongside the error.
		if intent == nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	// the provider might have returned an error
	if data.Error != "" {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, reason(data.Error, data.ErrorDescription))
		logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
		redirectToFailureURL(w, r, intent, data.Error, data.ErrorDescription)
		return
	}

	provider, err := h.getProvider(ctx, intent.IDPID)
	if err != nil {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
		logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	idpUser, idpSession, err := h.fetchIDPUserFromCode(ctx, provider, data.Code, data.User, intent.IDPArguments)
	if err != nil {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
		logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	// Failed lookups are only logged; the intent then succeeds without a user ID.
	userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
	logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists")
	if userID == "" {
		// No link for the current external ID: try a link under a previous ID.
		userID, err = h.tryMigrateExternalUser(ctx, intent.IDPID, idpUser, idpSession)
		logging.WithFields("intent", intent.AggregateID).OnError(err).Error("migration check failed")
	}

	token, err := h.commands.SucceedIDPIntent(ctx, intent, idpUser, idpSession, userID)
	if err != nil {
		redirectToFailureURLErr(w, r, intent, zerrors.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
		return
	}
	redirectToSuccessURL(w, r, intent, token, userID)
}
// tryMigrateExternalUser looks for a user linked to the IDP under a previous
// external ID and, if one exists, migrates the link to the user's current
// external ID.
//
// It returns an empty user ID and no error when the session does not support
// migration, no previous ID is available, or no user is linked to the
// previous ID. When a linked user is found, its ID is returned together with
// the result of the migration command.
func (h *Handler) tryMigrateExternalUser(ctx context.Context, idpID string, idpUser idp.User, idpSession idp.Session) (userID string, err error) {
	migration, ok := idpSession.(idp.SessionSupportsMigration)
	if !ok {
		return "", nil
	}
	previousID, err := migration.RetrievePreviousID()
	if err != nil || previousID == "" {
		return "", err
	}
	userID, err = h.checkExternalUser(ctx, idpID, previousID)
	// Without a linked user there is nothing to migrate; do not issue the
	// migration command with an empty user ID.
	if err != nil || userID == "" {
		return "", err
	}
	return userID, h.commands.MigrateUserIDP(ctx, userID, "", idpID, previousID, idpUser.GetID())
}
// parseCallbackRequest decodes the callback parameters of the request and
// rejects requests that do not carry a state.
func (h *Handler) parseCallbackRequest(r *http.Request) (*externalIDPCallbackData, error) {
	var callback externalIDPCallbackData
	if parseErr := h.parser.Parse(r, &callback); parseErr != nil {
		return nil, parseErr
	}
	if callback.State != "" {
		return &callback, nil
	}
	return nil, zerrors.ThrowInvalidArgument(nil, "IDP-Hk38e", "Errors.Intent.StateMissing")
}
// redirectToSuccessURL redirects (302) to the intent's success URL, adding
// the intent ID, the token and, when known, the user ID as query parameters.
// The intent's success URL itself is updated with the new query.
func redirectToSuccessURL(w http.ResponseWriter, r *http.Request, intent *command.IDPIntentWriteModel, token, userID string) {
	params := intent.SuccessURL.Query()
	params.Set(paramIntentID, intent.AggregateID)
	params.Set(paramToken, token)
	if len(userID) > 0 {
		params.Set(paramUserID, userID)
	}
	intent.SuccessURL.RawQuery = params.Encode()

	target := intent.SuccessURL.String()
	http.Redirect(w, r, target, http.StatusFound)
}
// redirectToFailureURLErr redirects to the intent's failure URL for the given
// error. For a ZITADEL error the error ID and message are passed on as error
// and description; any other error is passed as its text with an empty
// description.
func redirectToFailureURLErr(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err error) {
	var zitadelErr *zerrors.ZitadelError
	if !errors.As(err, &zitadelErr) {
		redirectToFailureURL(w, r, i, err.Error(), "")
		return
	}
	// TODO: i18n of the message?
	redirectToFailureURL(w, r, i, zitadelErr.GetID(), zitadelErr.GetMessage())
}
// redirectToFailureURL redirects (302) to the intent's failure URL, adding
// the intent ID, the error and its description as query parameters.
// The intent's failure URL itself is updated with the new query.
func redirectToFailureURL(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err, description string) {
	params := i.FailureURL.Query()
	for _, kv := range [][2]string{
		{paramIntentID, i.AggregateID},
		{paramError, err},
		{paramErrorDescription, description},
	} {
		params.Set(kv[0], kv[1])
	}
	i.FailureURL.RawQuery = params.Encode()

	target := i.FailureURL.String()
	http.Redirect(w, r, target, http.StatusFound)
}
// fetchIDPUserFromCode creates the session matching the concrete provider
// type from the authorization code and fetches the user through it.
//
// appleUser is only passed to Apple sessions. idpArguments is passed to the
// OAuth and OIDC based sessions, including GitHub, GitLab and Google, but not
// to Azure AD or Apple sessions. JWT, LDAP and SAML providers are rejected
// with an invalid-argument error, any other provider type with an
// unimplemented error. On success the user and the session are returned.
func (h *Handler) fetchIDPUserFromCode(ctx context.Context, identityProvider idp.Provider, code string, appleUser string, idpArguments map[string]any) (user idp.User, idpTokens idp.Session, err error) {
	var codeSession idp.Session

	switch typed := identityProvider.(type) {
	case *oauth.Provider:
		codeSession = oauth.NewSession(typed, code, idpArguments)
	case *openid.Provider:
		codeSession = openid.NewSession(typed, code, idpArguments)
	case *azuread.Provider:
		codeSession = azuread.NewSession(typed, code)
	case *github.Provider:
		// GitHub wraps the generic OAuth provider.
		codeSession = oauth.NewSession(typed.Provider, code, idpArguments)
	case *gitlab.Provider:
		// GitLab wraps the generic OIDC provider.
		codeSession = openid.NewSession(typed.Provider, code, idpArguments)
	case *google.Provider:
		// Google wraps the generic OIDC provider.
		codeSession = openid.NewSession(typed.Provider, code, idpArguments)
	case *apple.Provider:
		codeSession = apple.NewSession(typed, code, appleUser)
	case *jwt.Provider, *ldap.Provider, *saml2.Provider:
		return nil, nil, zerrors.ThrowInvalidArgument(nil, "IDP-52jmn", "Errors.ExternalIDP.IDPTypeNotImplemented")
	default:
		return nil, nil, zerrors.ThrowUnimplemented(nil, "IDP-SSDg", "Errors.ExternalIDP.IDPTypeNotImplemented")
	}

	fetched, fetchErr := codeSession.FetchUser(ctx)
	if fetchErr != nil {
		return nil, nil, fetchErr
	}
	return fetched, codeSession, nil
}
// checkExternalUser searches the user links for the given IDP and external
// user ID. It returns the linked user's ID only if exactly one link matches;
// otherwise the returned ID is empty.
func (h *Handler) checkExternalUser(ctx context.Context, idpID, externalUserID string) (userID string, err error) {
	byIDP, err := query.NewIDPUserLinkIDPIDSearchQuery(idpID)
	if err != nil {
		return "", err
	}
	byExternalID, err := query.NewIDPUserLinksExternalIDSearchQuery(externalUserID)
	if err != nil {
		return "", err
	}

	search := &query.IDPUserLinksSearchQuery{
		Queries: []query.SearchQuery{byIDP, byExternalID},
	}
	found, err := h.queries.IDPUserLinks(ctx, search, nil)
	if err != nil {
		return "", err
	}
	if len(found.Links) == 1 {
		return found.Links[0].UserID, nil
	}
	return "", nil
}
// reason joins an error and its description as "error: description".
// Without a description only the error is returned.
func reason(err, description string) string {
	if description != "" {
		return err + ": " + description
	}
	return err
}