mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-15 04:18:01 +00:00
247 lines
8.3 KiB
Go
247 lines
8.3 KiB
Go
|
package idp
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"net/http"
|
||
|
|
||
|
"github.com/gorilla/mux"
|
||
|
"github.com/zitadel/logging"
|
||
|
|
||
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||
|
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||
|
"github.com/zitadel/zitadel/internal/command"
|
||
|
"github.com/zitadel/zitadel/internal/crypto"
|
||
|
"github.com/zitadel/zitadel/internal/domain"
|
||
|
z_errs "github.com/zitadel/zitadel/internal/errors"
|
||
|
"github.com/zitadel/zitadel/internal/form"
|
||
|
"github.com/zitadel/zitadel/internal/idp"
|
||
|
"github.com/zitadel/zitadel/internal/idp/providers/azuread"
|
||
|
"github.com/zitadel/zitadel/internal/idp/providers/github"
|
||
|
"github.com/zitadel/zitadel/internal/idp/providers/gitlab"
|
||
|
"github.com/zitadel/zitadel/internal/idp/providers/google"
|
||
|
"github.com/zitadel/zitadel/internal/idp/providers/jwt"
|
||
|
"github.com/zitadel/zitadel/internal/idp/providers/ldap"
|
||
|
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||
|
openid "github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||
|
"github.com/zitadel/zitadel/internal/query"
|
||
|
)
|
||
|
|
||
|
// Paths of the IDP handler. CallbackURL joins them into the externally
// visible callback URL, while NewHandler registers only callbackPath on its
// router (presumably the router is mounted under HandlerPrefix by the caller).
const (
	HandlerPrefix = "/idps"
	callbackPath  = "/callback"
)

// Query parameter names added to the intent's success / failure URLs when
// redirecting back after the provider callback.
const (
	paramIntentID         = "id"
	paramToken            = "token"
	paramUserID           = "user"
	paramError            = "error"
	paramErrorDescription = "error_description"
)
|
||
|
|
||
|
// Handler serves the callback endpoint that external identity providers
// redirect to for an IDP intent that has been started.
type Handler struct {
	// commands loads the intent write model, resolves the intent's provider
	// and records the intent as failed or succeeded.
	commands *command.Commands
	// queries is used to look up an existing link between the external user
	// and a user (see checkExternalUser).
	queries *query.Queries
	// parser decodes the callback request parameters into externalIDPCallbackData.
	parser *form.Parser
	// encryptionAlgorithm is set by NewHandler; it is not referenced by the
	// code in this file.
	encryptionAlgorithm crypto.EncryptionAlgorithm
	// callbackURL returns the callback URL of the instance found in ctx; it is
	// passed to GetProvider when resolving the intent's provider.
	callbackURL func(ctx context.Context) string
}
|
||
|
|
||
|
// externalIDPCallbackData holds the parameters an identity provider sends
// to the callback endpoint, decoded by the form parser via the schema tags.
type externalIDPCallbackData struct {
	// State is used to look up the intent; a request without it is rejected
	// by parseCallbackRequest.
	State string `schema:"state"`
	// Code is handed to the provider session to fetch the external user.
	Code string `schema:"code"`
	// Error, if non-empty, means the provider reported a failure; the intent
	// is then marked as failed.
	Error string `schema:"error"`
	// ErrorDescription optionally accompanies Error.
	ErrorDescription string `schema:"error_description"`
}
|
||
|
|
||
|
// CallbackURL generates the instance specific URL to the IDP callback handler
|
||
|
func CallbackURL(externalSecure bool) func(ctx context.Context) string {
|
||
|
return func(ctx context.Context) string {
|
||
|
return http_utils.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), externalSecure) + HandlerPrefix + callbackPath
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// NewHandler returns an http.Handler that serves the IDP callback at
// callbackPath, wrapped by instanceInterceptor as middleware.
//
// externalSecure selects the scheme used when building the callback URL
// passed to the identity providers (see CallbackURL).
func NewHandler(
	commands *command.Commands,
	queries *query.Queries,
	encryptionAlgorithm crypto.EncryptionAlgorithm,
	externalSecure bool,
	instanceInterceptor func(next http.Handler) http.Handler,
) http.Handler {
	h := new(Handler)
	h.commands = commands
	h.queries = queries
	h.parser = form.NewParser()
	h.encryptionAlgorithm = encryptionAlgorithm
	h.callbackURL = CallbackURL(externalSecure)

	router := mux.NewRouter()
	// the interceptor runs before the callback so the instance is available in the request context
	router.Use(instanceInterceptor)
	router.HandleFunc(callbackPath, h.handleCallback)
	return router
}
|
||
|
|
||
|
// handleCallback processes the redirect from an external identity provider.
//
// It parses the callback parameters and loads the started intent referenced
// by the state parameter. It then takes one of two paths:
//   - On failure (the provider reported an error, the provider could not be
//     resolved, or the external user could not be fetched), it marks the
//     intent as failed and redirects to the intent's failure URL.
//   - On success, it marks the intent as succeeded and redirects to the
//     intent's success URL. The redirect carries the intent token and, if an
//     existing user link was found, the user ID.
func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
	data, err := h.parseCallbackRequest(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	intent := h.getActiveIntent(w, r, data.State)
	if intent == nil {
		// if we didn't get an active intent the error was already handled (either redirected or displayed directly)
		return
	}

	ctx := r.Context()

	// failIntent records the failure on the intent. Persisting the failure is
	// best effort: an error is only logged, the caller still redirects the
	// user to the failure URL.
	failIntent := func(failReason string) {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, failReason)
		logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
	}

	// the provider might have returned an error
	if data.Error != "" {
		failIntent(reason(data.Error, data.ErrorDescription))
		redirectToFailureURL(w, r, intent, data.Error, data.ErrorDescription)
		return
	}

	provider, err := h.commands.GetProvider(ctx, intent.IDPID, h.callbackURL(ctx))
	if err != nil {
		failIntent(err.Error())
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	idpUser, idpSession, err := h.fetchIDPUser(ctx, provider, data.Code)
	if err != nil {
		failIntent(err.Error())
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	// a failed link lookup is not fatal: the intent still succeeds, just without a user ID
	userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
	logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists")

	token, err := h.commands.SucceedIDPIntent(ctx, intent, idpUser, idpSession, userID)
	if err != nil {
		redirectToFailureURLErr(w, r, intent, z_errs.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
		return
	}
	redirectToSuccessURL(w, r, intent, token, userID)
}
|
||
|
|
||
|
// parseCallbackRequest decodes the callback parameters of r.
//
// It returns the parser's error unchanged if decoding fails, and an
// InvalidArgument error if the state parameter is missing, because the state
// is needed to look up the intent.
func (h *Handler) parseCallbackRequest(r *http.Request) (*externalIDPCallbackData, error) {
	data := &externalIDPCallbackData{}
	if err := h.parser.Parse(r, data); err != nil {
		return nil, err
	}
	if data.State != "" {
		return data, nil
	}
	return nil, z_errs.ThrowInvalidArgument(nil, "IDP-Hk38e", "Errors.Intent.StateMissing")
}
|
||
|
|
||
|
// getActiveIntent loads the intent for state and returns it only if it is in
// the started state.
//
// In every other case it writes the response itself and returns nil:
//   - lookup error: plain 500 response
//   - unspecified state: plain 400 response (NOTE(review): presumably the
//     intent does not exist, so there is no failure URL to redirect to)
//   - any other state: redirect to the intent's failure URL
func (h *Handler) getActiveIntent(w http.ResponseWriter, r *http.Request, state string) *command.IDPIntentWriteModel {
	intent, err := h.commands.GetIntentWriteModel(r.Context(), state, "")
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return nil
	}

	switch intent.State {
	case domain.IDPIntentStateStarted:
		return intent
	case domain.IDPIntentStateUnspecified:
		http.Error(w, reason("IDP-Hk38e", "Errors.Intent.NotStarted"), http.StatusBadRequest)
	default:
		redirectToFailureURL(w, r, intent, "IDP-Sfrgs", "Errors.Intent.NotStarted")
	}
	return nil
}
|
||
|
|
||
|
// redirectToSuccessURL sends a 302 redirect to the intent's success URL.
//
// It keeps the URL's existing query parameters and sets the intent ID and
// token on it. The user ID is added only when it is non-empty. The intent's
// SuccessURL is updated in place with the new query.
func redirectToSuccessURL(w http.ResponseWriter, r *http.Request, intent *command.IDPIntentWriteModel, token, userID string) {
	params := intent.SuccessURL.Query()
	if userID != "" {
		params.Set(paramUserID, userID)
	}
	// Encode sorts by key, so the order of the Set calls does not matter
	params.Set(paramToken, token)
	params.Set(paramIntentID, intent.AggregateID)

	intent.SuccessURL.RawQuery = params.Encode()
	target := intent.SuccessURL.String()
	http.Redirect(w, r, target, http.StatusFound)
}
|
||
|
|
||
|
// redirectToFailureURLErr redirects to the intent's failure URL, deriving the
// error and error_description parameters from err.
//
// If err wraps a *z_errs.CaosError, its ID becomes the error parameter and its
// message (e.g. "Errors.Intent.TokenCreationFailed") becomes the description.
// Any other error is passed through as its Error() text with an empty
// description.
func redirectToFailureURLErr(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err error) {
	msg := err.Error()
	var description string
	// errors.As fills this nil pointer if a *CaosError is found in the chain
	var zErr *z_errs.CaosError
	if errors.As(err, &zErr) {
		msg = zErr.GetID()
		description = zErr.GetMessage() // TODO: i18n?
	}
	redirectToFailureURL(w, r, i, msg, description)
}
|
||
|
|
||
|
// redirectToFailureURL sends a 302 redirect to the intent's failure URL.
//
// It keeps the URL's existing query parameters and sets the intent ID, err
// and description on it. The description is set even when empty. The
// intent's FailureURL is updated in place with the new query.
func redirectToFailureURL(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err, description string) {
	params := i.FailureURL.Query()
	params.Set(paramErrorDescription, description)
	params.Set(paramError, err)
	params.Set(paramIntentID, i.AggregateID)

	i.FailureURL.RawQuery = params.Encode()
	target := i.FailureURL.String()
	http.Redirect(w, r, target, http.StatusFound)
}
|
||
|
|
||
|
// fetchIDPUser creates a session for identityProvider using the callback
// code and fetches the external user through it.
//
// It returns the fetched user and the session used to fetch it.
//   - OAuth-based providers (oauth, azuread, github) get an oauth.Session on
//     their (embedded) OAuth provider.
//   - OIDC-based providers (oidc, gitlab, google) get an openid.Session.
//   - JWT and LDAP providers are not handled here and yield an InvalidArgument
//     error.
//   - Unknown provider types yield an Unimplemented error.
func (h *Handler) fetchIDPUser(ctx context.Context, identityProvider idp.Provider, code string) (idp.User, idp.Session, error) {
	var session idp.Session
	switch provider := identityProvider.(type) {
	case *oauth.Provider:
		session = &oauth.Session{Provider: provider, Code: code}
	case *openid.Provider:
		session = &openid.Session{Provider: provider, Code: code}
	case *azuread.Provider:
		session = &oauth.Session{Provider: provider.Provider, Code: code}
	case *github.Provider:
		session = &oauth.Session{Provider: provider.Provider, Code: code}
	case *gitlab.Provider:
		session = &openid.Session{Provider: provider.Provider, Code: code}
	case *google.Provider:
		session = &openid.Session{Provider: provider.Provider, Code: code}
	case *jwt.Provider, *ldap.Provider:
		return nil, nil, z_errs.ThrowInvalidArgument(nil, "IDP-52jmn", "Errors.ExternalIDP.IDPTypeNotImplemented")
	default:
		return nil, nil, z_errs.ThrowUnimplemented(nil, "IDP-SSDg", "Errors.ExternalIDP.IDPTypeNotImplemented")
	}

	user, err := session.FetchUser(ctx)
	if err != nil {
		return nil, nil, err
	}
	return user, session, nil
}
|
||
|
|
||
|
// checkExternalUser looks up the user linked to externalUserID on the IDP
// idpID.
//
// It returns the linked user's ID only if exactly one link matches, and ""
// otherwise (no link, or more than one). Errors building the search queries
// or running the lookup are returned.
func (h *Handler) checkExternalUser(ctx context.Context, idpID, externalUserID string) (userID string, err error) {
	byIDP, err := query.NewIDPUserLinkIDPIDSearchQuery(idpID)
	if err != nil {
		return "", err
	}
	byExternalID, err := query.NewIDPUserLinksExternalIDSearchQuery(externalUserID)
	if err != nil {
		return "", err
	}

	search := &query.IDPUserLinksSearchQuery{
		Queries: []query.SearchQuery{byIDP, byExternalID},
	}
	// NOTE(review): meaning of the trailing `false` flag is defined by IDPUserLinks — confirm there
	links, err := h.queries.IDPUserLinks(ctx, search, false)
	if err != nil {
		return "", err
	}
	if len(links.Links) == 1 {
		return links.Links[0].UserID, nil
	}
	return "", nil
}
|
||
|
|
||
|
// reason joins an error code and its optional description into a single
// string of the form "err: description", or just err when the description is
// empty.
func reason(err, description string) string {
	if description != "" {
		return err + ": " + description
	}
	return err
}
|