diff --git a/docs/docs/guides/integrate/identity-providers/jwt_idp.md b/docs/docs/guides/integrate/identity-providers/jwt_idp.md index 2a0e8b8e7a..52324ad5c2 100644 --- a/docs/docs/guides/integrate/identity-providers/jwt_idp.md +++ b/docs/docs/guides/integrate/identity-providers/jwt_idp.md @@ -49,7 +49,7 @@ The **JWT IdP Configuration** might then be: Therefore, if the user is redirected from ZITADEL to the JWT Endpoint on the WAF (`https://apps.test.com/existing/auth-new`), the session cookies previously issued by the WAF, will be sent along by the browser due to the path being on the same domain as the exiting application. -The WAF will reuse the session and send the JWT in the HTTP header `x-custom-tkn` to its upstream, the ZITADEL JWT Endpoint (`https://accounts.test.com/ui/login/login/jwt/authorize`). +The WAF will reuse the session and send the JWT in the HTTP header `x-custom-tkn` to its upstream, the ZITADEL JWT Endpoint (`https://accounts.test.com/ipds/jwt`). For the signature validation, ZITADEL must be able to connect to Keys Endpoint (`https://issuer.test.internal/keys`) and it will check if the token was signed (claim `iss`) by the defined Issuer (`https://issuer.test.internal`). diff --git a/docs/static/img/guides/jwt_idp.png b/docs/static/img/guides/jwt_idp.png index 73d5353521..218996aef7 100644 Binary files a/docs/static/img/guides/jwt_idp.png and b/docs/static/img/guides/jwt_idp.png differ diff --git a/internal/api/grpc/user/v2/integration_test/user_test.go b/internal/api/grpc/user/v2/integration_test/user_test.go index 7f211afd6f..4cf4ab21f8 100644 --- a/internal/api/grpc/user/v2/integration_test/user_test.go +++ b/internal/api/grpc/user/v2/integration_test/user_test.go @@ -1875,6 +1875,7 @@ func TestServer_StartIdentityProviderIntent(t *testing.T) { samlIdpID := Instance.AddSAMLProvider(IamCTX) samlRedirectIdpID := Instance.AddSAMLRedirectProvider(IamCTX, "") samlPostIdpID := Instance.AddSAMLPostProvider(IamCTX) + jwtIdPID := Instance.AddJWTProvider(IamCTX) type args struct { ctx context.Context req *user.StartIdentityProviderIntentRequest @@ -2097,6 +2098,30 @@ func TestServer_StartIdentityProviderIntent(t *testing.T) { }, wantErr: false, }, + { + name: "next step jwt idp", + args: args{ + CTX, + &user.StartIdentityProviderIntentRequest{ + IdpId: jwtIdPID, + Content: &user.StartIdentityProviderIntentRequest_Urls{ + Urls: &user.RedirectURLs{ + SuccessUrl: "https://example.com/success", + FailureUrl: "https://example.com/failure", + }, + }, + }, + }, + want: want{ + details: &object.Details{ + ChangeDate: timestamppb.Now(), + ResourceOwner: Instance.ID(), + }, + url: "https://example.com/jwt", + parametersExisting: []string{"authRequestID", "userAgentID"}, + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -2134,6 +2159,7 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { oidcIdpID := Instance.AddGenericOIDCProvider(IamCTX, gofakeit.AppName()).GetId() samlIdpID := Instance.AddSAMLPostProvider(IamCTX) ldapIdpID := Instance.AddLDAPProvider(IamCTX) + jwtIdPID := Instance.AddJWTProvider(IamCTX) authURL, err := url.Parse(Instance.CreateIntent(CTX, oauthIdpID).GetAuthUrl()) require.NoError(t, err) intentID := authURL.Query().Get("state") @@ -2168,6 +2194,10 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { require.NoError(t, err) samlSuccessfulWithUserID, samlWithUserToken, samlWithUserChangeDate, samlWithUserSequence, err := sink.SuccessfulSAMLIntent(Instance.ID(), samlIdpID, "id", "user", expiry) require.NoError(t, err) + jwtSuccessfulID, jwtToken, jwtChangeDate, jwtSequence, err := sink.SuccessfulJWTIntent(Instance.ID(), jwtIdPID, "id", "", expiry) + require.NoError(t, err) + jwtSuccessfulWithUserID, jwtWithUserToken, jwtWithUserChangeDate, jwtWithUserSequence, err := sink.SuccessfulJWTIntent(Instance.ID(), jwtIdPID, "id", "user", expiry) + require.NoError(t, err) type args struct { ctx context.Context req *user.RetrieveIdentityProviderIntentRequest @@ -2591,6 +2621,88 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { }, wantErr: false, }, + { + name: "retrieve successful jwt intent", + args: args{ + CTX, + &user.RetrieveIdentityProviderIntentRequest{ + IdpIntentId: jwtSuccessfulID, + IdpIntentToken: jwtToken, + }, + }, + want: &user.RetrieveIdentityProviderIntentResponse{ + Details: &object.Details{ + ChangeDate: timestamppb.New(jwtChangeDate), + ResourceOwner: Instance.ID(), + Sequence: jwtSequence, + }, + IdpInformation: &user.IDPInformation{ + Access: &user.IDPInformation_Oauth{ + Oauth: &user.IDPOAuthAccessInformation{ + IdToken: gu.Ptr("idToken"), + }, + }, + IdpId: jwtIdPID, + UserId: "id", + UserName: "", + RawInformation: func() *structpb.Struct { + s, err := structpb.NewStruct(map[string]interface{}{ + "sub": "id", + }) + require.NoError(t, err) + return s + }(), + }, + AddHumanUser: &user.AddHumanUserRequest{ + Profile: &user.SetHumanProfile{ + PreferredLanguage: gu.Ptr("und"), + }, + IdpLinks: []*user.IDPLink{ + {IdpId: jwtIdPID, UserId: "id"}, + }, + Email: &user.SetHumanEmail{ + Verification: &user.SetHumanEmail_SendCode{SendCode: &user.SendEmailVerificationCode{}}, + }, + }, + }, + wantErr: false, + }, + { + name: "retrieve successful jwt intent with linked user", + args: args{ + CTX, + &user.RetrieveIdentityProviderIntentRequest{ + IdpIntentId: jwtSuccessfulWithUserID, + IdpIntentToken: jwtWithUserToken, + }, + }, + want: &user.RetrieveIdentityProviderIntentResponse{ + Details: &object.Details{ + ChangeDate: timestamppb.New(jwtWithUserChangeDate), + ResourceOwner: Instance.ID(), + Sequence: jwtWithUserSequence, + }, + IdpInformation: &user.IDPInformation{ + Access: &user.IDPInformation_Oauth{ + Oauth: &user.IDPOAuthAccessInformation{ + IdToken: gu.Ptr("idToken"), + }, + }, + IdpId: jwtIdPID, + UserId: "id", + UserName: "", + RawInformation: func() *structpb.Struct { + s, err := structpb.NewStruct(map[string]interface{}{ + "sub": "id", + }) + require.NoError(t, err) + return s + }(), + }, + UserId: "user", + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/api/grpc/user/v2/intent.go b/internal/api/grpc/user/v2/intent.go index afb34deb83..5514b6ef03 100644 --- a/internal/api/grpc/user/v2/intent.go +++ b/internal/api/grpc/user/v2/intent.go @@ -173,7 +173,7 @@ func (s *Server) RetrieveIdentityProviderIntent(ctx context.Context, req *user.R case *oidc.Provider: idpUser, err = unmarshalIdpUser(intent.IDPUser, oidc.InitUser()) case *jwt.Provider: - idpUser, err = unmarshalIdpUser(intent.IDPUser, &jwt.User{}) + idpUser, err = unmarshalIdpUser(intent.IDPUser, jwt.InitUser()) case *azuread.Provider: idpUser, err = unmarshalRawIdpUser(intent.IDPUser, p.User()) case *github.Provider: diff --git a/internal/api/idp/idp.go b/internal/api/idp/idp.go index f688ba2352..8b1c24134a 100644 --- a/internal/api/idp/idp.go +++ b/internal/api/idp/idp.go @@ -3,6 +3,7 @@ package idp import ( "bytes" "context" + "encoding/base64" "encoding/xml" "errors" "fmt" @@ -48,6 +49,7 @@ const ( acsPath = idpPrefix + "/saml/acs" certificatePath = idpPrefix + "/saml/certificate" sloPath = idpPrefix + "/saml/slo" + jwtPath = "/jwt" paramIntentID = "id" paramToken = "token" @@ -129,6 +131,7 @@ func NewHandler( router.HandleFunc(certificatePath, h.handleCertificate) router.HandleFunc(acsPath, h.handleACS) router.HandleFunc(sloPath, h.handleSLO) + router.HandleFunc(jwtPath, h.handleJWT) return router } @@ -307,6 +310,89 @@ func (h *Handler) handleACS(w http.ResponseWriter, r *http.Request) { redirectToSuccessURL(w, r, intent, token, userID) } +func (h *Handler) handleJWT(w http.ResponseWriter, r *http.Request) { + intentID, err := h.intentIDFromJWTRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + intent, err := h.commands.GetActiveIntent(r.Context(), intentID) + if err != nil { + if zerrors.IsNotFound(err) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + redirectToFailureURLErr(w, r, intent, err) + return + } + idpConfig, err := h.getProvider(r.Context(), intent.IDPID) + if err != nil { + cmdErr := h.commands.FailIDPIntent(r.Context(), intent, err.Error()) + logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent") + redirectToFailureURLErr(w, r, intent, err) + return + } + jwtIDP, ok := idpConfig.(*jwt.Provider) + if !ok { + err := zerrors.ThrowInvalidArgument(nil, "IDP-JK23ed", "Errors.ExternalIDP.IDPTypeNotImplemented") + cmdErr := h.commands.FailIDPIntent(r.Context(), intent, err.Error()) + logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent") + redirectToFailureURLErr(w, r, intent, err) + return + } + h.handleJWTExtraction(w, r, intent, jwtIDP) +} + +func (h *Handler) intentIDFromJWTRequest(r *http.Request) (string, error) { + // for compatibility of the old JWT provider we use the auth request id parameter to pass the intent id + intentID := r.FormValue(jwt.QueryAuthRequestID) + // for compatibility of the old JWT provider we use the user agent id parameter to pass the encrypted intent id + encryptedIntentID := r.FormValue(jwt.QueryUserAgentID) + if err := h.checkIntentID(intentID, encryptedIntentID); err != nil { + return "", err + } + return intentID, nil +} + +func (h *Handler) checkIntentID(intentID, encryptedIntentID string) error { + if intentID == "" || encryptedIntentID == "" { + return zerrors.ThrowInvalidArgument(nil, "LOGIN-adfzz", "Errors.AuthRequest.MissingParameters") + } + id, err := base64.RawURLEncoding.DecodeString(encryptedIntentID) + if err != nil { + return err + } + decryptedIntentID, err := h.encryptionAlgorithm.DecryptString(id, h.encryptionAlgorithm.EncryptionKeyID()) + if err != nil { + return err + } + if intentID != decryptedIntentID { + return zerrors.ThrowInvalidArgument(nil, "LOGIN-adfzz", "Errors.AuthRequest.MissingParameters") + } + return nil +} + +func (h *Handler) handleJWTExtraction(w http.ResponseWriter, r *http.Request, intent *command.IDPIntentWriteModel, identityProvider *jwt.Provider) { + session := jwt.NewSessionFromRequest(identityProvider, r) + user, err := session.FetchUser(r.Context()) + if err != nil { + cmdErr := h.commands.FailIDPIntent(r.Context(), intent, err.Error()) + logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent") + redirectToFailureURLErr(w, r, intent, err) + return + } + + userID, err := h.checkExternalUser(r.Context(), intent.IDPID, user.GetID()) + logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists") + + token, err := h.commands.SucceedIDPIntent(r.Context(), intent, user, session, userID) + if err != nil { + redirectToFailureURLErr(w, r, intent, err) + return + } + redirectToSuccessURL(w, r, intent, token, userID) +} + func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) { ctx := r.Context() data, err := h.parseCallbackRequest(r) diff --git a/internal/idp/providers/jwt/jwt.go b/internal/idp/providers/jwt/jwt.go index 99347f31a3..d972102b01 100644 --- a/internal/idp/providers/jwt/jwt.go +++ b/internal/idp/providers/jwt/jwt.go @@ -11,14 +11,14 @@ import ( ) const ( - queryAuthRequestID = "authRequestID" - queryUserAgentID = "userAgentID" + QueryAuthRequestID = "authRequestID" + QueryUserAgentID = "userAgentID" ) var _ idp.Provider = (*Provider)(nil) var ( - ErrMissingUserAgentID = errors.New("userAgentID missing") + ErrMissingState = errors.New("state missing") ) // Provider is the [idp.Provider] implementation for a JWT provider @@ -92,32 +92,32 @@ func (p *Provider) Name() string { // It will create a [Session] with an AuthURL, pointing to the jwtEndpoint // with the authRequest and encrypted userAgent ids. func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) { - userAgentID, err := userAgentIDFromParams(params...) - if err != nil { - return nil, err + if state == "" { + return nil, ErrMissingState } + userAgentID := userAgentIDFromParams(state, params...) redirect, err := url.Parse(p.jwtEndpoint) if err != nil { return nil, err } q := redirect.Query() - q.Set(queryAuthRequestID, state) + q.Set(QueryAuthRequestID, state) nonce, err := p.encryptionAlg.Encrypt([]byte(userAgentID)) if err != nil { return nil, err } - q.Set(queryUserAgentID, base64.RawURLEncoding.EncodeToString(nonce)) + q.Set(QueryUserAgentID, base64.RawURLEncoding.EncodeToString(nonce)) redirect.RawQuery = q.Encode() return &Session{AuthURL: redirect.String()}, nil } -func userAgentIDFromParams(params ...idp.Parameter) (string, error) { +func userAgentIDFromParams(state string, params ...idp.Parameter) string { for _, param := range params { if id, ok := param.(idp.UserAgentID); ok { - return string(id), nil + return string(id) } } - return "", ErrMissingUserAgentID + return state } // IsLinkingAllowed implements the [idp.Provider] interface. diff --git a/internal/idp/providers/jwt/jwt_test.go b/internal/idp/providers/jwt/jwt_test.go index 59e32b4690..5756c58e07 100644 --- a/internal/idp/providers/jwt/jwt_test.go +++ b/internal/idp/providers/jwt/jwt_test.go @@ -23,6 +23,7 @@ func TestProvider_BeginAuth(t *testing.T) { encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm } type args struct { + state string params []idp.Parameter } type want struct { @@ -36,7 +37,7 @@ func TestProvider_BeginAuth(t *testing.T) { want want }{ { - name: "missing userAgentID error", + name: "missing state, error", fields: fields{ issuer: "https://jwt.com", jwtEndpoint: "https://auth.com/jwt", @@ -47,14 +48,34 @@ func TestProvider_BeginAuth(t *testing.T) { }, }, args: args{ + state: "", params: nil, }, want: want{ err: func(err error) bool { - return errors.Is(err, ErrMissingUserAgentID) + return errors.Is(err, ErrMissingState) }, }, }, + { + name: "missing userAgentID, fallback to state", + fields: fields{ + issuer: "https://jwt.com", + jwtEndpoint: "https://auth.com/jwt", + keysEndpoint: "https://jwt.com/keys", + headerName: "jwt-header", + encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm { + return crypto.CreateMockEncryptionAlg(gomock.NewController(t)) + }, + }, + args: args{ + state: "testState", + params: nil, + }, + want: want{ + session: &Session{AuthURL: "https://auth.com/jwt?authRequestID=testState&userAgentID=dGVzdFN0YXRl"}, + }, + }, { name: "successful auth", fields: fields{ @@ -67,6 +88,7 @@ func TestProvider_BeginAuth(t *testing.T) { }, }, args: args{ + state: "testState", params: []idp.Parameter{ idp.UserAgentID("agent"), }, @@ -91,7 +113,7 @@ func TestProvider_BeginAuth(t *testing.T) { require.NoError(t, err) ctx := context.Background() - session, err := provider.BeginAuth(ctx, "testState", tt.args.params...) + session, err := provider.BeginAuth(ctx, tt.args.state, tt.args.params...) if tt.want.err != nil && !tt.want.err(err) { a.Fail("invalid error", err) } diff --git a/internal/idp/providers/jwt/session.go b/internal/idp/providers/jwt/session.go index 5138812f3c..85b164a9c5 100644 --- a/internal/idp/providers/jwt/session.go +++ b/internal/idp/providers/jwt/session.go @@ -5,11 +5,13 @@ import ( "errors" "fmt" "net/http" + "strings" "time" "github.com/zitadel/logging" "github.com/zitadel/oidc/v3/pkg/client/rp" "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/oauth2" "golang.org/x/text/language" "github.com/zitadel/zitadel/internal/domain" @@ -34,6 +36,11 @@ func NewSession(provider *Provider, tokens *oidc.Tokens[*oidc.IDTokenClaims]) *S return &Session{Provider: provider, Tokens: tokens} } +func NewSessionFromRequest(provider *Provider, r *http.Request) *Session { + token := strings.TrimPrefix(r.Header.Get(provider.headerName), oidc.PrefixBearer) + return NewSession(provider, &oidc.Tokens[*oidc.IDTokenClaims]{IDToken: token, Token: &oauth2.Token{}}) +} + // GetAuth implements the [idp.Session] interface. func (s *Session) GetAuth(ctx context.Context) (string, bool) { return idp.Redirect(s.AuthURL) @@ -99,6 +106,12 @@ func (s *Session) validateToken(ctx context.Context, token string) (*oidc.IDToke return claims, nil } +func InitUser() *User { + return &User{ + IDTokenClaims: &oidc.IDTokenClaims{}, + } +} + type User struct { *oidc.IDTokenClaims } diff --git a/internal/integration/client.go b/internal/integration/client.go index 3efd682ee1..320809a7e8 100644 --- a/internal/integration/client.go +++ b/internal/integration/client.go @@ -684,6 +684,24 @@ func (i *Instance) AddLDAPProvider(ctx context.Context) string { return resp.GetId() } +func (i *Instance) AddJWTProvider(ctx context.Context) string { + resp, err := i.Client.Admin.AddJWTProvider(ctx, &admin.AddJWTProviderRequest{ + Name: "jwt-idp", + Issuer: "https://example.com", + JwtEndpoint: "https://example.com/jwt", + KeysEndpoint: "https://example.com/keys", + HeaderName: "Authorization", + ProviderOptions: &idp.Options{ + IsLinkingAllowed: true, + IsCreationAllowed: true, + IsAutoCreation: true, + IsAutoUpdate: true, + }, + }) + logging.OnError(err).Panic("create jwt idp") + return resp.GetId() +} + func (i *Instance) CreateIntent(ctx context.Context, idpID string) *user_v2.StartIdentityProviderIntentResponse { resp, err := i.Client.UserV2.StartIdentityProviderIntent(ctx, &user_v2.StartIdentityProviderIntentRequest{ IdpId: idpID, diff --git a/internal/integration/sink/server.go b/internal/integration/sink/server.go index 8abb31a63e..653c5236d6 100644 --- a/internal/integration/sink/server.go +++ b/internal/integration/sink/server.go @@ -27,6 +27,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/idp/providers/jwt" "github.com/zitadel/zitadel/internal/idp/providers/ldap" "github.com/zitadel/zitadel/internal/idp/providers/oauth" openid "github.com/zitadel/zitadel/internal/idp/providers/oidc" @@ -124,6 +125,25 @@ func SuccessfulLDAPIntent(instanceID, idpID, idpUserID, userID string) (string, return resp.IntentID, resp.Token, resp.ChangeDate, resp.Sequence, nil } +func SuccessfulJWTIntent(instanceID, idpID, idpUserID, userID string, expiry time.Time) (string, string, time.Time, uint64, error) { + u := url.URL{ + Scheme: "http", + Host: host, + Path: successfulIntentJWTPath(), + } + resp, err := callIntent(u.String(), &SuccessfulIntentRequest{ + InstanceID: instanceID, + IDPID: idpID, + IDPUserID: idpUserID, + UserID: userID, + Expiry: expiry, + }) + if err != nil { + return "", "", time.Time{}, uint64(0), err + } + return resp.IntentID, resp.Token, resp.ChangeDate, resp.Sequence, nil +} + // StartServer starts a simple HTTP server on localhost:8081 // ZITADEL can use the server to send HTTP requests which can be // used to validate tests through [Subscribe]rs. @@ -145,6 +165,7 @@ func StartServer(commands *command.Commands) (close func()) { router.HandleFunc(successfulIntentOIDCPath(), successfulIntentHandler(commands, createSuccessfulOIDCIntent)) router.HandleFunc(successfulIntentSAMLPath(), successfulIntentHandler(commands, createSuccessfulSAMLIntent)) router.HandleFunc(successfulIntentLDAPPath(), successfulIntentHandler(commands, createSuccessfulLDAPIntent)) + router.HandleFunc(successfulIntentJWTPath(), successfulIntentHandler(commands, createSuccessfulJWTIntent)) } s := &http.Server{ Addr: listenAddr, @@ -195,6 +216,10 @@ func successfulIntentLDAPPath() string { return path.Join(successfulIntentPath(), "/", "ldap") } +func successfulIntentJWTPath() string { + return path.Join(successfulIntentPath(), "/", "jwt") +} + // forwarder handles incoming HTTP requests from ZITADEL and // forwards them to all subscribed web sockets. type forwarder struct { @@ -497,3 +522,30 @@ func createSuccessfulLDAPIntent(ctx context.Context, cmd *command.Commands, req writeModel.ProcessedSequence, }, nil } + +func createSuccessfulJWTIntent(ctx context.Context, cmd *command.Commands, req *SuccessfulIntentRequest) (*SuccessfulIntentResponse, error) { + intentID, err := createIntent(ctx, cmd, req.InstanceID, req.IDPID) + writeModel, err := cmd.GetIntentWriteModel(ctx, intentID, req.InstanceID) + idpUser := &jwt.User{ + IDTokenClaims: &oidc.IDTokenClaims{ + TokenClaims: oidc.TokenClaims{ + Subject: req.IDPUserID, + }, + }, + } + session := &jwt.Session{ + Tokens: &oidc.Tokens[*oidc.IDTokenClaims]{ + IDToken: "idToken", + }, + } + token, err := cmd.SucceedIDPIntent(ctx, writeModel, idpUser, session, req.UserID) + if err != nil { + return nil, err + } + return &SuccessfulIntentResponse{ + intentID, + token, + writeModel.ChangeDate, + writeModel.ProcessedSequence, + }, nil +}