perf(oidc): optimize token creation (#7822)

* implement code exchange

* port tokenexchange to v2 tokens

* implement refresh token

* implement client credentials

* implement jwt profile

* implement device token

* cleanup unused code

* fix current unit tests

* add user agent unit test

* unit test domain package

* need refresh token as argument

* test commands create oidc session

* test commands device auth

* fix device auth build error

* implicit for oidc session API

* implement authorize callback handler for legacy implicit mode

* upgrade oidc module to working draft

* add missing auth methods and time

* handle all errors in defer

* do not fail auth request on error

the oauth2 Go client automagically retries on any error. If we fail the auth request on the first error, the next attempt will always fail with the Errors.AuthRequest.NoCode, because the auth request state is already set to failed.
The original error is then already lost and the oauth2 library does not return the original error.

Therefore we should not fail the auth request.

Might be worth discussing and perhaps send a bug report to Oauth2?

* fix code flow tests by explicitly setting code exchanged

* fix unit tests in command package

* return allowed scope from client credential client

* add device auth done reducer

* carry nonce thru session into ID token

* fix token exchange integration tests

* allow project role scope prefix in client credentials client

* gci formatting

* do not return refresh token in client credentials and jwt profile

* check org scope

* solve linting issue on authorize callback error

* end session based on v2 session ID

* use preferred language and user agent ID for v2 access tokens

* pin oidc v3.23.2

* add integration test for jwt profile and client credentials with org scopes

* refresh token v1 to v2

* add user token v2 audit event

* add activity trigger

* cleanup and set panics for unused methods

* use the encrypted code for v1 auth request get by code

* add missing event translation

* fix pipeline errors (hopefully)

* fix another test

* revert pointer usage of preferred language

* solve browser info panic in device auth

* remove duplicate entries in AMRToAuthMethodTypes to prevent future `mfa` claim

* revoke v1 refresh token to prevent reuse

* fix terminate oidc session

* always return a new refresh toke in refresh token grant

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
Tim Möhlmann
2024-05-16 08:07:56 +03:00
committed by GitHub
parent 6cf9ca9f7e
commit 8e0c8393e9
84 changed files with 3429 additions and 2635 deletions

View File

@@ -6,8 +6,10 @@ import (
"strings"
"time"
"github.com/muhlemmer/gu"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command"
@@ -19,19 +21,20 @@ import (
)
type accessToken struct {
tokenID string
userID string
resourceOwner string
subject string
clientID string
audience []string
scope []string
authMethods []domain.UserAuthMethodType
authTime time.Time
tokenCreation time.Time
tokenExpiration time.Time
isPAT bool
actor *domain.TokenActor
tokenID string
userID string
resourceOwner string
subject string
preferredLanguage *language.Tag
clientID string
audience []string
scope []string
authMethods []domain.UserAuthMethodType
authTime time.Time
tokenCreation time.Time
tokenExpiration time.Time
isPAT bool
actor *domain.TokenActor
}
var ErrInvalidTokenFormat = errors.New("invalid token format")
@@ -73,35 +76,41 @@ func (s *Server) verifyAccessToken(ctx context.Context, tkn string) (_ *accessTo
}
func accessTokenV1(tokenID, subject string, token *model.TokenView) *accessToken {
var preferredLanguage *language.Tag
if token.PreferredLanguage != "" {
preferredLanguage = gu.Ptr(language.Make(token.PreferredLanguage))
}
return &accessToken{
tokenID: tokenID,
userID: token.UserID,
resourceOwner: token.ResourceOwner,
subject: subject,
clientID: token.ApplicationID,
audience: token.Audience,
scope: token.Scopes,
tokenCreation: token.CreationDate,
tokenExpiration: token.Expiration,
isPAT: token.IsPAT,
actor: token.Actor,
tokenID: tokenID,
userID: token.UserID,
resourceOwner: token.ResourceOwner,
subject: subject,
preferredLanguage: preferredLanguage,
clientID: token.ApplicationID,
audience: token.Audience,
scope: token.Scopes,
tokenCreation: token.CreationDate,
tokenExpiration: token.Expiration,
isPAT: token.IsPAT,
actor: token.Actor,
}
}
func accessTokenV2(tokenID, subject string, token *query.OIDCSessionAccessTokenReadModel) *accessToken {
return &accessToken{
tokenID: tokenID,
userID: token.UserID,
resourceOwner: token.ResourceOwner,
subject: subject,
clientID: token.ClientID,
audience: token.Audience,
scope: token.Scope,
authMethods: token.AuthMethods,
authTime: token.AuthTime,
tokenCreation: token.AccessTokenCreation,
tokenExpiration: token.AccessTokenExpiration,
actor: token.Actor,
tokenID: tokenID,
userID: token.UserID,
resourceOwner: token.ResourceOwner,
subject: subject,
preferredLanguage: token.PreferredLanguage,
clientID: token.ClientID,
audience: token.Audience,
scope: token.Scope,
authMethods: token.AuthMethods,
authTime: token.AuthTime,
tokenCreation: token.AccessTokenCreation,
tokenExpiration: token.AccessTokenExpiration,
actor: token.Actor,
}
}

View File

@@ -1,6 +1,10 @@
package oidc
import "github.com/zitadel/zitadel/internal/domain"
import (
"slices"
"github.com/zitadel/zitadel/internal/domain"
)
const (
// Password states that the users password has been verified
@@ -87,5 +91,5 @@ func AMRToAuthMethodTypes(amr []string) []domain.UserAuthMethodType {
authMethods = append(authMethods, domain.UserAuthMethodTypeU2F)
}
}
return authMethods
return slices.Compact(authMethods) // remove duplicate entries
}

View File

@@ -3,14 +3,17 @@ package oidc
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"slices"
"strings"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/activity"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/http/middleware"
@@ -64,18 +67,19 @@ func (o *OPStorage) createAuthRequestLoginClient(ctx context.Context, req *oidc.
return nil, err
}
authRequest := &command.AuthRequest{
LoginClient: loginClient,
ClientID: req.ClientID,
RedirectURI: req.RedirectURI,
State: req.State,
Nonce: req.Nonce,
Scope: scope,
Audience: audience,
ResponseType: ResponseTypeToBusiness(req.ResponseType),
CodeChallenge: CodeChallengeToBusiness(req.CodeChallenge, req.CodeChallengeMethod),
Prompt: PromptToBusiness(req.Prompt),
UILocales: UILocalesToBusiness(req.UILocales),
MaxAge: MaxAgeToBusiness(req.MaxAge),
LoginClient: loginClient,
ClientID: req.ClientID,
RedirectURI: req.RedirectURI,
State: req.State,
Nonce: req.Nonce,
Scope: scope,
Audience: audience,
NeedRefreshToken: slices.Contains(scope, oidc.ScopeOfflineAccess),
ResponseType: ResponseTypeToBusiness(req.ResponseType),
CodeChallenge: CodeChallengeToBusiness(req.CodeChallenge, req.CodeChallengeMethod),
Prompt: PromptToBusiness(req.Prompt),
UILocales: UILocalesToBusiness(req.UILocales),
MaxAge: MaxAgeToBusiness(req.MaxAge),
}
if req.LoginHint != "" {
authRequest.LoginHint = &req.LoginHint
@@ -149,28 +153,7 @@ func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRe
}
func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
plainCode, err := o.decryptGrant(code)
if err != nil {
return nil, zerrors.ThrowInvalidArgument(err, "OIDC-ahLi2", "Errors.User.Code.Invalid")
}
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
authReq, err := o.command.ExchangeAuthCode(ctx, plainCode)
if err != nil {
return nil, err
}
return &AuthRequestV2{authReq}, nil
}
resp, err := o.repo.AuthRequestByCode(ctx, code)
if err != nil {
return nil, err
}
return AuthRequestFromBusiness(resp)
panic(o.panicErr("AuthRequestByCode"))
}
// decryptGrant decrypts a code or refresh_token
@@ -201,136 +184,40 @@ func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err erro
}
func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
return o.repo.DeleteAuthRequest(ctx, id)
panic(o.panicErr("DeleteAuthRequest"))
}
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
if authReq, ok := req.(*AuthRequestV2); ok {
activity.Trigger(ctx, "", authReq.CurrentAuthRequest.UserID, activity.OIDCAccessToken, o.eventstore.FilterToQueryReducer)
return o.command.AddOIDCSessionAccessToken(setContextUserSystem(ctx), authReq.GetID())
}
userAgentID, applicationID, userOrgID, authTime, amr, reason, actor := getInfoFromRequest(req)
accessTokenLifetime, _, _, _, err := o.getOIDCSettings(ctx)
if err != nil {
return "", time.Time{}, err
}
resp, err := o.command.AddUserToken(setContextUserSystem(ctx), userOrgID, userAgentID, applicationID, req.GetSubject(), req.GetAudience(), req.GetScopes(), amr, accessTokenLifetime, authTime, reason, actor)
if err != nil {
return "", time.Time{}, err
}
// trigger activity log for authentication for user
activity.Trigger(ctx, userOrgID, req.GetSubject(), activity.OIDCAccessToken, o.eventstore.FilterToQueryReducer)
return resp.TokenID, resp.Expiration, nil
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (string, time.Time, error) {
panic(o.panicErr("CreateAccessToken"))
}
func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
// handle V2 request directly
switch tokenReq := req.(type) {
case *AuthRequestV2:
// trigger activity log for authentication for user
activity.Trigger(ctx, "", tokenReq.GetSubject(), activity.OIDCRefreshToken, o.eventstore.FilterToQueryReducer)
return o.command.AddOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.GetID())
case *RefreshTokenRequestV2:
// trigger activity log for authentication for user
activity.Trigger(ctx, "", tokenReq.GetSubject(), activity.OIDCRefreshToken, o.eventstore.FilterToQueryReducer)
return o.command.ExchangeOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.OIDCSessionWriteModel.AggregateID, refreshToken, tokenReq.RequestedScopes)
}
userAgentID, applicationID, userOrgID, authTime, authMethodsReferences, reason, actor := getInfoFromRequest(req)
scopes, err := o.assertProjectRoleScopes(ctx, applicationID, req.GetScopes())
if err != nil {
return "", "", time.Time{}, zerrors.ThrowPreconditionFailed(err, "OIDC-Df2fq", "Errors.Internal")
}
if request, ok := req.(op.RefreshTokenRequest); ok {
request.SetCurrentScopes(scopes)
}
accessTokenLifetime, _, refreshTokenIdleExpiration, refreshTokenExpiration, err := o.getOIDCSettings(ctx)
if err != nil {
return "", "", time.Time{}, err
}
resp, token, err := o.command.AddAccessAndRefreshToken(setContextUserSystem(ctx), userOrgID, userAgentID, applicationID, req.GetSubject(),
refreshToken, req.GetAudience(), scopes, authMethodsReferences, accessTokenLifetime,
refreshTokenIdleExpiration, refreshTokenExpiration, authTime, reason, actor) //PLANNED: lifetime from client
if err != nil {
if zerrors.IsErrorInvalidArgument(err) {
err = oidc.ErrInvalidGrant().WithParent(err)
}
return "", "", time.Time{}, err
}
// trigger activity log for authentication for user
activity.Trigger(ctx, userOrgID, req.GetSubject(), activity.OIDCRefreshToken, o.eventstore.FilterToQueryReducer)
return resp.TokenID, token, resp.Expiration, nil
func (o *OPStorage) CreateAccessAndRefreshTokens(context.Context, op.TokenRequest, string) (string, string, time.Time, error) {
panic(o.panicErr("CreateAccessAndRefreshTokens"))
}
func getInfoFromRequest(req op.TokenRequest) (agentID string, clientID string, userOrgID string, authTime time.Time, amr []string, reason domain.TokenReason, actor *domain.TokenActor) {
func (*OPStorage) panicErr(method string) error {
return fmt.Errorf("OPStorage.%s should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues", method)
}
func getInfoFromRequest(req op.TokenRequest) (agentID, clientID, userOrgID string, authTime time.Time, amr []string, preferredLanguage *language.Tag, reason domain.TokenReason, actor *domain.TokenActor) {
switch r := req.(type) {
case *AuthRequest:
return r.AgentID, r.ApplicationID, r.UserOrgID, r.AuthTime, r.GetAMR(), domain.TokenReasonAuthRequest, nil
return r.AgentID, r.ApplicationID, r.UserOrgID, r.AuthTime, r.GetAMR(), r.PreferredLanguage, domain.TokenReasonAuthRequest, nil
case *RefreshTokenRequest:
return r.UserAgentID, r.ClientID, "", r.AuthTime, r.AuthMethodsReferences, domain.TokenReasonRefresh, r.Actor
return r.UserAgentID, r.ClientID, "", r.AuthTime, r.AuthMethodsReferences, nil, domain.TokenReasonRefresh, r.Actor
case op.IDTokenRequest:
return "", r.GetClientID(), "", r.GetAuthTime(), r.GetAMR(), domain.TokenReasonAuthRequest, nil
return "", r.GetClientID(), "", r.GetAuthTime(), r.GetAMR(), nil, domain.TokenReasonAuthRequest, nil
case *oidc.JWTTokenRequest:
return "", "", "", r.GetAuthTime(), nil, domain.TokenReasonJWTProfile, nil
return "", "", "", r.GetAuthTime(), nil, nil, domain.TokenReasonJWTProfile, nil
case *clientCredentialsRequest:
return "", "", "", time.Time{}, nil, domain.TokenReasonClientCredentials, nil
return "", "", "", time.Time{}, nil, nil, domain.TokenReasonClientCredentials, nil
default:
return "", "", "", time.Time{}, nil, domain.TokenReasonAuthRequest, nil
return "", "", "", time.Time{}, nil, nil, domain.TokenReasonAuthRequest, nil
}
}
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
plainToken, err := o.decryptGrant(refreshToken)
if err != nil {
return nil, op.ErrInvalidRefreshToken
}
if strings.HasPrefix(plainToken, command.IDPrefixV2) {
oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainToken)
if err != nil {
return nil, err
}
// trigger activity log for authentication for user
activity.Trigger(ctx, "", oidcSession.UserID, activity.OIDCRefreshToken, o.eventstore.FilterToQueryReducer)
return &RefreshTokenRequestV2{OIDCSessionWriteModel: oidcSession}, nil
}
tokenView, err := o.repo.RefreshTokenByToken(ctx, refreshToken)
if err != nil {
return nil, err
}
// trigger activity log for use of refresh token for user
activity.Trigger(ctx, tokenView.ResourceOwner, tokenView.UserID, activity.OIDCRefreshToken, o.eventstore.FilterToQueryReducer)
return RefreshTokenRequestFromBusiness(tokenView), nil
panic("TokenRequestByRefreshToken should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues")
}
func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID string) (err error) {
@@ -368,18 +255,19 @@ func (o *OPStorage) TerminateSessionFromRequest(ctx context.Context, endSessionR
}()
// check for the login client header
// and if not provided, terminate the session using the V1 method
headers, _ := http_utils.HeadersFromCtx(ctx)
if loginClient := headers.Get(LoginClientHeader); loginClient == "" {
return endSessionRequest.RedirectURI, o.TerminateSession(ctx, endSessionRequest.UserID, endSessionRequest.ClientID)
}
// in case there are not id_token_hint, redirect to the UI and let it decide which session to terminate
if endSessionRequest.IDTokenHintClaims == nil {
// in case there is no id_token_hint, redirect to the UI and let it decide which session to terminate
if headers.Get(LoginClientHeader) != "" && endSessionRequest.IDTokenHintClaims == nil {
return o.defaultLogoutURLV2 + endSessionRequest.RedirectURI, nil
}
// terminate the session of the id_token_hint
// If there is no login client header and no id_token_hint or the id_token_hint does not have a session ID,
// do a v1 Terminate session.
if endSessionRequest.IDTokenHintClaims == nil || endSessionRequest.IDTokenHintClaims.SessionID == "" {
return endSessionRequest.RedirectURI, o.TerminateSession(ctx, endSessionRequest.UserID, endSessionRequest.ClientID)
}
// terminate the v2 session of the id_token_hint
_, err = o.command.TerminateSessionWithoutTokenCheck(ctx, endSessionRequest.IDTokenHintClaims.SessionID)
if err != nil {
return "", err
@@ -543,18 +431,6 @@ func setContextUserSystem(ctx context.Context) context.Context {
return authz.SetCtxData(ctx, data)
}
func (o *OPStorage) getOIDCSettings(ctx context.Context) (accessTokenLifetime, idTokenLifetime, refreshTokenIdleExpiration, refreshTokenExpiration time.Duration, _ error) {
oidcSettings, err := o.query.OIDCSettingsByAggID(ctx, authz.GetInstance(ctx).InstanceID())
if err != nil && !zerrors.IsNotFound(err) {
return time.Duration(0), time.Duration(0), time.Duration(0), time.Duration(0), err
}
if oidcSettings != nil {
return oidcSettings.AccessTokenLifetime, oidcSettings.IdTokenLifetime, oidcSettings.RefreshTokenIdleExpiration, oidcSettings.RefreshTokenExpiration, nil
}
return o.defaultAccessTokenLifetime, o.defaultIdTokenLifetime, o.defaultRefreshTokenIdleExpiration, o.defaultRefreshTokenExpiration, nil
}
func CreateErrorCallbackURL(authReq op.AuthRequest, reason, description, uri string, authorizer op.Authorizer) (string, error) {
e := struct {
Error string `schema:"error"`
@@ -593,19 +469,140 @@ func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authoriz
return callback, err
}
func CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest, authorizer op.Authorizer) (string, error) {
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest) (string, error) {
provider := s.Provider()
opClient, err := provider.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil {
return "", err
}
createAccessToken := req.GetResponseType() != oidc.ResponseTypeIDTokenOnly
resp, err := op.CreateTokenResponse(ctx, req, client, authorizer, createAccessToken, "", "")
client, ok := opClient.(*Client)
if !ok {
return "", zerrors.ThrowInternal(nil, "OIDC-waeN6", "Error.Internal")
}
session, state, err := s.command.CreateOIDCSessionFromAuthRequest(
setContextUserSystem(ctx),
req.GetID(),
implicitFlowComplianceChecker(),
slices.Contains(client.GrantTypes(), oidc.GrantTypeRefreshToken),
)
if err != nil {
return "", err
}
callback, err := op.AuthResponseURL(req.GetRedirectURI(), req.GetResponseType(), req.GetResponseMode(), resp, authorizer.Encoder())
resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion)
if err != nil {
return "", err
}
callback, err := op.AuthResponseURL(req.GetRedirectURI(), req.GetResponseType(), req.GetResponseMode(), resp, provider.Encoder())
if err != nil {
return "", err
}
return callback, err
}
func implicitFlowComplianceChecker() command.AuthRequestComplianceChecker {
return func(_ context.Context, authReq *command.AuthRequestWriteModel) error {
if err := authReq.CheckAuthenticated(); err != nil {
return err
}
return nil
}
}
func (s *Server) authorizeCallbackHandler(w http.ResponseWriter, r *http.Request) {
authorizer := s.Provider()
authReq, err := func() (authReq op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(r.Context())
r = r.WithContext(ctx)
defer func() { span.EndWithError(err) }()
id, err := op.ParseAuthorizeCallbackRequest(r)
if err != nil {
return nil, err
}
authReq, err = authorizer.Storage().AuthRequestByID(r.Context(), id)
if err != nil {
return nil, err
}
if !authReq.Done() {
return authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required.")
}
return authReq, s.authResponse(authReq, authorizer, w, r)
}()
if err != nil {
op.AuthRequestError(w, r, authReq, err, authorizer)
}
}
func (s *Server) authResponse(authReq op.AuthRequest, authorizer op.Authorizer, w http.ResponseWriter, r *http.Request) (err error) {
ctx, span := tracing.NewSpan(r.Context())
r = r.WithContext(ctx)
defer func() { span.EndWithError(err) }()
client, err := authorizer.Storage().GetClientByClientID(ctx, authReq.GetClientID())
if err != nil {
op.AuthRequestError(w, r, authReq, err, authorizer)
return err
}
if authReq.GetResponseType() == oidc.ResponseTypeCode {
op.AuthResponseCode(w, r, authReq, authorizer)
return nil
}
return s.authResponseToken(authReq, authorizer, client, w, r)
}
func (s *Server) authResponseToken(authReq op.AuthRequest, authorizer op.Authorizer, opClient op.Client, w http.ResponseWriter, r *http.Request) (err error) {
ctx, span := tracing.NewSpan(r.Context())
r = r.WithContext(ctx)
defer func() { span.EndWithError(err) }()
client, ok := opClient.(*Client)
if !ok {
return zerrors.ThrowInternal(nil, "OIDC-waeN6", "Error.Internal")
}
userAgentID, _, userOrgID, authTime, authMethodsReferences, preferredLanguage, reason, actor := getInfoFromRequest(authReq)
scope := authReq.GetScopes()
session, err := s.command.CreateOIDCSession(ctx,
authReq.GetSubject(),
userOrgID,
client.client.ClientID,
scope,
authReq.GetAudience(),
AMRToAuthMethodTypes(authMethodsReferences),
authTime,
authReq.GetNonce(),
preferredLanguage,
&domain.UserAgent{
FingerprintID: &userAgentID,
},
reason,
actor,
slices.Contains(scope, oidc.ScopeOfflineAccess),
)
if err != nil {
op.AuthRequestError(w, r, authReq, err, authorizer)
return err
}
resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion)
if err != nil {
op.AuthRequestError(w, r, authReq, err, authorizer)
return err
}
if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
if err = op.AuthResponseFormPost(w, authReq.GetRedirectURI(), resp, authorizer.Encoder()); err != nil {
op.AuthRequestError(w, r, authReq, err, authorizer)
return err
}
return nil
}
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
if err != nil {
op.AuthRequestError(w, r, authReq, err, authorizer)
return err
}
http.Redirect(w, r, callback, http.StatusFound)
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"slices"
"strings"
"time"
@@ -69,7 +70,7 @@ func (o *OPStorage) GetKeyByIDAndIssuer(ctx context.Context, keyID, issuer strin
err = oidcError(err)
span.EndWithError(err)
}()
publicKeyData, err := o.query.GetAuthNKeyPublicKeyByIDAndIdentifier(ctx, keyID, issuer, false)
publicKeyData, err := o.query.GetAuthNKeyPublicKeyByIDAndIdentifier(ctx, keyID, issuer)
if err != nil {
return nil, err
}
@@ -1040,3 +1041,26 @@ func (s *Server) verifyClientSecret(ctx context.Context, client *query.OIDCClien
s.command.OIDCSecretCheckSucceeded(ctx, client.AppID, client.ProjectID, client.Settings.ResourceOwner, updated)
return nil
}
func (s *Server) checkOrgScopes(ctx context.Context, user *query.User, scopes []string) ([]string, error) {
if slices.ContainsFunc(scopes, func(scope string) bool {
return strings.HasPrefix(scope, domain.OrgDomainPrimaryScope)
}) {
org, err := s.query.OrgByID(ctx, false, user.ResourceOwner)
if err != nil {
return nil, err
}
scopes = slices.DeleteFunc(scopes, func(scope string) bool {
if domain, ok := strings.CutPrefix(scope, domain.OrgDomainPrimaryScope); ok {
return domain != org.Domain
}
return false
})
}
return slices.DeleteFunc(scopes, func(scope string) bool {
if orgID, ok := strings.CutPrefix(scope, domain.OrgIDScope); ok {
return orgID != user.ResourceOwner
}
return false
}), nil
}

View File

@@ -104,28 +104,7 @@ func (c *Client) AccessTokenType() op.AccessTokenType {
}
func (c *Client) IsScopeAllowed(scope string) bool {
if strings.HasPrefix(scope, domain.OrgDomainPrimaryScope) {
return true
}
if strings.HasPrefix(scope, domain.OrgIDScope) {
return true
}
if strings.HasPrefix(scope, domain.ProjectIDScope) {
return true
}
if strings.HasPrefix(scope, domain.SelectIDPScope) {
return true
}
if scope == ScopeUserMetaData {
return true
}
if scope == ScopeResourceOwner {
return true
}
if scope == ScopeProjectsRoles {
return true
}
return slices.Contains(c.allowedScopes, scope)
return isScopeAllowed(scope, c.allowedScopes...)
}
func (c *Client) ClockSkew() time.Duration {
@@ -249,3 +228,28 @@ func clientIDFromCredentials(cc *op.ClientCredentials) (clientID string, asserti
}
return cc.ClientID, false, nil
}
func isScopeAllowed(scope string, allowedScopes ...string) bool {
if strings.HasPrefix(scope, domain.OrgDomainPrimaryScope) {
return true
}
if strings.HasPrefix(scope, domain.OrgIDScope) {
return true
}
if strings.HasPrefix(scope, domain.ProjectIDScope) {
return true
}
if strings.HasPrefix(scope, domain.SelectIDPScope) {
return true
}
if scope == ScopeUserMetaData {
return true
}
if scope == ScopeResourceOwner {
return true
}
if scope == ScopeProjectsRoles {
return true
}
return slices.Contains(allowedScopes, scope)
}

View File

@@ -2,6 +2,7 @@ package oidc
import (
"context"
"strings"
"time"
"github.com/zitadel/oidc/v3/pkg/oidc"
@@ -136,9 +137,8 @@ func (c *clientCredentialsClient) RestrictAdditionalAccessTokenScopes() func(sco
}
}
// IsScopeAllowed returns null false as the check is executed during the auth request validation
func (c *clientCredentialsClient) IsScopeAllowed(scope string) bool {
return false
return isScopeAllowed(scope) || strings.HasPrefix(scope, ScopeProjectRolePrefix)
}
// IDTokenUserinfoClaimsAssertion returns null false as no id_token is issued

View File

@@ -7,7 +7,6 @@ import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client"
@@ -22,7 +21,6 @@ import (
"github.com/zitadel/zitadel/pkg/grpc/authn"
"github.com/zitadel/zitadel/pkg/grpc/management"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta"
"github.com/zitadel/zitadel/pkg/grpc/user"
)
func TestServer_Introspect(t *testing.T) {
@@ -346,73 +344,3 @@ func createInvalidKeyData(t testing.TB, client *management.AddOIDCAppResponse) [
require.NoError(t, err)
return data
}
func TestServer_CreateAccessToken_ClientCredentials(t *testing.T) {
_, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
require.NoError(t, err)
type clientDetails struct {
clientID string
clientSecret string
keyData []byte
}
tests := []struct {
name string
clientID string
clientSecret string
wantErr bool
}{
{
name: "missing client ID error",
clientID: "",
clientSecret: clientSecret,
wantErr: true,
},
{
name: "client not found error",
clientID: "foo",
clientSecret: clientSecret,
wantErr: true,
},
{
name: "machine user without secret error",
clientID: func() string {
name := gofakeit.Username()
_, err := Tester.Client.Mgmt.AddMachineUser(CTX, &management.AddMachineUserRequest{
Name: name,
UserName: name,
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
})
require.NoError(t, err)
return name
}(),
clientSecret: clientSecret,
wantErr: true,
},
{
name: "wrong secret error",
clientID: clientID,
clientSecret: "bar",
wantErr: true,
},
{
name: "success",
clientID: clientID,
clientSecret: clientSecret,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), tt.clientID, tt.clientSecret, redirectURI, []string{oidc.ScopeOpenID})
require.NoError(t, err)
tokens, err := rp.ClientCredentials(CTX, provider, nil)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.NotNil(t, tokens)
assert.NotEmpty(t, tokens.AccessToken)
})
}
}

View File

@@ -2,14 +2,14 @@ package oidc
import (
"context"
"slices"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
@@ -80,7 +80,7 @@ func (o *OPStorage) StoreDeviceAuthorization(ctx context.Context, clientID, devi
if err != nil {
return err
}
details, err := o.command.AddDeviceAuth(ctx, clientID, deviceCode, userCode, expires, scope, audience)
details, err := o.command.AddDeviceAuth(ctx, clientID, deviceCode, userCode, expires, scope, audience, slices.Contains(scope, oidc.ScopeOfflineAccess))
if err == nil {
logger.SetFields("details", details).Debug(logMsg)
}
@@ -88,50 +88,6 @@ func (o *OPStorage) StoreDeviceAuthorization(ctx context.Context, clientID, devi
return err
}
func newDeviceAuthorizationState(d *query.DeviceAuth) *op.DeviceAuthorizationState {
return &op.DeviceAuthorizationState{
ClientID: d.ClientID,
Scopes: d.Scopes,
Audience: d.Audience,
Expires: d.Expires,
Done: d.State.Done(),
Denied: d.State.Denied(),
Subject: d.Subject,
AMR: AuthMethodTypesToAMR(d.UserAuthMethods),
AuthTime: d.AuthTime,
}
}
// GetDeviceAuthorizatonState retrieves the current state of the Device Authorization process.
// It implements the [op.DeviceAuthorizationStorage] interface and is used by devices that
// are polling until they successfully receive a token or we indicate a denied or expired state.
// As generated user codes are of low entropy, this implementation also takes care or
// device authorization request cleanup, when it has been Approved, Denied or Expired.
func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (state *op.DeviceAuthorizationState, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
deviceAuth, err := o.query.DeviceAuthByDeviceCode(ctx, deviceCode)
if err != nil {
return nil, err
}
logging.WithFields(
"device_code", deviceCode,
"expires", deviceAuth.Expires, "scopes", deviceAuth.Scopes,
"subject", deviceAuth.Subject, "state", deviceAuth.State,
).Debug("device authorization state")
// Cancel the request if it is expired, only if it wasn't Done meanwhile
if !deviceAuth.State.Done() && deviceAuth.Expires.Before(time.Now()) {
_, err = o.command.CancelDeviceAuth(ctx, deviceAuth.DeviceCode, domain.DeviceAuthCanceledExpired)
if err != nil {
return nil, err
}
deviceAuth.State = domain.DeviceAuthStateExpired
}
return newDeviceAuthorizationState(deviceAuth), nil
func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, _, deviceCode string) (state *op.DeviceAuthorizationState, err error) {
return nil, nil
}

View File

@@ -140,10 +140,13 @@ func NewServer(
fallbackLogger: fallbackLogger,
hasher: hasher,
signingKeyAlgorithm: config.SigningKeyAlgorithm,
encAlg: encryptionAlg,
opCrypto: op.NewAESCrypto(opConfig.CryptoKey),
assetAPIPrefix: assets.AssetAPI(externalSecure),
}
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
server.Handler = op.RegisterLegacyServer(server,
server.authorizeCallbackHandler,
op.WithFallbackLogger(fallbackLogger),
op.WithHTTPMiddleware(
middleware.MetricsHandler(metricTypes),

View File

@@ -37,7 +37,10 @@ type Server struct {
fallbackLogger *slog.Logger
hasher *crypto.Hasher
signingKeyAlgorithm string
assetAPIPrefix func(ctx context.Context) string
encAlg crypto.EncryptionAlgorithm
opCrypto op.Crypto
assetAPIPrefix func(ctx context.Context) string
}
func endpoints(endpointConfig *EndpointConfig) op.Endpoints {
@@ -153,41 +156,6 @@ func (s *Server) DeviceAuthorization(ctx context.Context, r *op.ClientRequest[oi
return s.LegacyServer.DeviceAuthorization(ctx, r)
}
func (s *Server) CodeExchange(ctx context.Context, r *op.ClientRequest[oidc.AccessTokenRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return s.LegacyServer.CodeExchange(ctx, r)
}
func (s *Server) RefreshToken(ctx context.Context, r *op.ClientRequest[oidc.RefreshTokenRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return s.LegacyServer.RefreshToken(ctx, r)
}
func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGrantRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return s.LegacyServer.JWTProfile(ctx, r)
}
func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequest[oidc.ClientCredentialsRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return s.LegacyServer.ClientCredentialsExchange(ctx, r)
}
func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.DeviceAccessTokenRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
return s.LegacyServer.DeviceToken(ctx, r)
}
func (s *Server) Revocation(ctx context.Context, r *op.ClientRequest[oidc.RevocationRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
@@ -232,3 +200,10 @@ func (s *Server) createDiscoveryConfig(ctx context.Context, supportedUILocales o
RequestParameterSupported: s.Provider().RequestObjectSupported(),
}
}
func response(resp any, err error) (*op.Response, error) {
if err != nil {
return nil, err
}
return op.NewResponse(resp), nil
}

198
internal/api/oidc/token.go Normal file
View File

@@ -0,0 +1,198 @@
package oidc
import (
"context"
"encoding/base64"
"slices"
"sync"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/zitadel/oidc/v3/pkg/crypto"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
/*
For each grant-type, tokens creation follows the same rough logical steps:
1. Information gathering: who is requesting the token, what do we put in the claims?
2. Decision making: is the request authorized? (valid exchange code, auth request completed, valid token etc...)
3. Build an OIDC session in storage: inform the eventstore we are creating tokens.
4. Use the OIDC session to encrypt and / or sign the requested tokens
In some cases step 1 till 3 are completely implemented in the command package,
for example the v2 code exchange and refresh token.
*/
func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion bool) (_ *oidc.AccessTokenResponse, err error) {
getUserInfo := s.getUserInfoOnce(session.UserID, projectID, projectRoleAssertion, session.Scope)
getSigner := s.getSignerOnce()
resp := &oidc.AccessTokenResponse{
TokenType: oidc.BearerToken,
RefreshToken: session.RefreshToken,
ExpiresIn: timeToOIDCExpiresIn(session.Expiration),
State: state,
}
// If the session does not have a token ID, it is an implicit ID-Token only response.
if session.TokenID != "" {
if client.AccessTokenType() == op.AccessTokenTypeJWT {
resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner)
} else {
resp.AccessToken, err = op.CreateBearerToken(session.TokenID, session.UserID, s.opCrypto)
}
if err != nil {
return nil, err
}
}
if slices.Contains(session.Scope, oidc.ScopeOpenID) {
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
}
return resp, err
}
// signerFunc is a getter function that allows add-hoc retrieval of the instance's signer.
type signerFunc func(ctx context.Context) (jose.Signer, jose.SignatureAlgorithm, error)
// getSignerOnce returns a function which retrieves the instance's signer from the database once.
// Repeated calls of the returned function return the same results.
func (s *Server) getSignerOnce() signerFunc {
var (
once sync.Once
signer jose.Signer
signAlg jose.SignatureAlgorithm
err error
)
return func(ctx context.Context) (jose.Signer, jose.SignatureAlgorithm, error) {
once.Do(func() {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
var signingKey op.SigningKey
signingKey, err = s.Provider().Storage().SigningKey(ctx)
if err != nil {
return
}
signAlg = signingKey.SignatureAlgorithm()
signer, err = op.SignerFromKey(signingKey)
if err != nil {
return
}
})
return signer, signAlg, err
}
}
// userInfoFunc is a getter function that allows add-hoc retrieval of a user.
type userInfoFunc func(ctx context.Context) (*oidc.UserInfo, error)
// getUserInfoOnce returns a function which retrieves userinfo from the database once.
// Repeated calls of the returned function return the same results.
func (s *Server) getUserInfoOnce(userID, projectID string, projectRoleAssertion bool, scope []string) userInfoFunc {
var (
once sync.Once
userInfo *oidc.UserInfo
err error
)
return func(ctx context.Context) (*oidc.UserInfo, error) {
once.Do(func() {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userInfo, err = s.userInfo(ctx, userID, scope, projectID, projectRoleAssertion, false)
})
return userInfo, err
}
}
func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userInfo, err := getUserInfo(ctx)
if err != nil {
return "", 0, err
}
signer, signAlg, err := getSigningKey(ctx)
if err != nil {
return "", 0, err
}
expTime := time.Now().Add(client.IDTokenLifetime()).Add(client.ClockSkew())
claims := oidc.NewIDTokenClaims(
op.IssuerFromContext(ctx),
"",
audience,
expTime,
authTime,
nonce,
"",
AuthMethodTypesToAMR(authMethods),
client.GetID(),
client.ClockSkew(),
)
claims.SessionID = sessionID
claims.Actor = actorDomainToClaims(actor)
claims.SetUserInfo(userInfo)
if accessToken != "" {
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signAlg)
if err != nil {
return "", 0, err
}
}
idToken, err = crypto.Sign(claims, signer)
return idToken, timeToOIDCExpiresIn(expTime), err
}
func timeToOIDCExpiresIn(exp time.Time) uint64 {
return uint64(time.Until(exp) / time.Second)
}
func (*Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, getSigner signerFunc) (_ string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userInfo, err := getUserInfo(ctx)
if err != nil {
return "", err
}
signer, _, err := getSigner(ctx)
if err != nil {
return "", err
}
expTime := session.Expiration.Add(client.ClockSkew())
claims := oidc.NewAccessTokenClaims(
op.IssuerFromContext(ctx),
userInfo.Subject,
session.Audience,
expTime,
session.TokenID,
client.GetID(),
client.ClockSkew(),
)
claims.Actor = actorDomainToClaims(session.Actor)
claims.Claims = userInfo.Claims
return crypto.Sign(claims, signer)
}
// decryptCode decrypts a code or refresh_token
func (s *Server) decryptCode(ctx context.Context, code string) (_ string, err error) {
_, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
decoded, err := base64.RawURLEncoding.DecodeString(code)
if err != nil {
return "", err
}
return s.encAlg.DecryptString(decoded, s.encAlg.EncryptionKeyID())
}

View File

@@ -0,0 +1,51 @@
package oidc
import (
"context"
"time"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequest[oidc.ClientCredentialsRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
span.EndWithError(err)
err = oidcError(err)
}()
client, ok := r.Client.(*clientCredentialsClient)
if !ok {
return nil, zerrors.ThrowInternal(nil, "OIDC-ga0EP", "Error.Internal")
}
scope, err := op.ValidateAuthReqScopes(client, r.Data.Scope)
if err != nil {
return nil, err
}
scope, err = s.checkOrgScopes(ctx, client.user, scope)
if err != nil {
return nil, err
}
session, err := s.command.CreateOIDCSession(ctx,
client.user.ID,
client.user.ResourceOwner,
r.Data.ClientID,
scope,
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
time.Now(),
"",
nil,
nil,
domain.TokenReasonClientCredentials,
nil,
false,
)
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false))
}

View File

@@ -0,0 +1,143 @@
//go:build integration
package oidc_test
import (
"testing"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
oidc_api "github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/pkg/grpc/management"
"github.com/zitadel/zitadel/pkg/grpc/user"
)
func TestServer_ClientCredentialsExchange(t *testing.T) {
userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
require.NoError(t, err)
type claims struct {
resourceOwnerID any
resourceOwnerName any
resourceOwnerPrimaryDomain any
orgDomain any
}
tests := []struct {
name string
clientID string
clientSecret string
scope []string
wantClaims claims
wantErr bool
}{
{
name: "missing client ID error",
clientID: "",
clientSecret: clientSecret,
scope: []string{oidc.ScopeOpenID},
wantErr: true,
},
{
name: "client not found error",
clientID: "foo",
clientSecret: clientSecret,
scope: []string{oidc.ScopeOpenID},
wantErr: true,
},
{
name: "machine user without secret error",
clientID: func() string {
name := gofakeit.Username()
_, err := Tester.Client.Mgmt.AddMachineUser(CTX, &management.AddMachineUserRequest{
Name: name,
UserName: name,
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
})
require.NoError(t, err)
return name
}(),
clientSecret: clientSecret,
scope: []string{oidc.ScopeOpenID},
wantErr: true,
},
{
name: "wrong secret error",
clientID: clientID,
clientSecret: "bar",
scope: []string{oidc.ScopeOpenID},
wantErr: true,
},
{
name: "success",
clientID: clientID,
clientSecret: clientSecret,
scope: []string{oidc.ScopeOpenID},
},
{
name: "org id and domain scope",
clientID: clientID,
clientSecret: clientSecret,
scope: []string{
oidc.ScopeOpenID,
domain.OrgIDScope + Tester.Organisation.ID,
domain.OrgDomainPrimaryScope + Tester.Organisation.Domain,
},
wantClaims: claims{
resourceOwnerID: Tester.Organisation.ID,
resourceOwnerName: Tester.Organisation.Name,
resourceOwnerPrimaryDomain: Tester.Organisation.Domain,
orgDomain: Tester.Organisation.Domain,
},
},
{
name: "invalid org domain filtered",
clientID: clientID,
clientSecret: clientSecret,
scope: []string{
oidc.ScopeOpenID,
domain.OrgDomainPrimaryScope + Tester.Organisation.Domain,
domain.OrgDomainPrimaryScope + "foo"},
wantClaims: claims{
orgDomain: Tester.Organisation.Domain,
},
},
{
name: "invalid org id filtered",
clientID: clientID,
clientSecret: clientSecret,
scope: []string{oidc.ScopeOpenID,
domain.OrgIDScope + Tester.Organisation.ID,
domain.OrgIDScope + "foo",
},
wantClaims: claims{
resourceOwnerID: Tester.Organisation.ID,
resourceOwnerName: Tester.Organisation.Name,
resourceOwnerPrimaryDomain: Tester.Organisation.Domain,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), tt.clientID, tt.clientSecret, redirectURI, tt.scope)
require.NoError(t, err)
tokens, err := rp.ClientCredentials(CTX, provider, nil)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, tokens)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider)
require.NoError(t, err)
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
})
}
}

View File

@@ -0,0 +1,125 @@
package oidc
import (
"context"
"slices"
"strings"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
func (s *Server) CodeExchange(ctx context.Context, r *op.ClientRequest[oidc.AccessTokenRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
span.EndWithError(err)
err = oidcError(err)
}()
client, ok := r.Client.(*Client)
if !ok {
return nil, zerrors.ThrowInternal(nil, "OIDC-Ae2ph", "Error.Internal")
}
plainCode, err := s.decryptCode(ctx, r.Data.Code)
if err != nil {
return nil, zerrors.ThrowInvalidArgument(err, "OIDC-ahLi2", "Errors.User.Code.Invalid")
}
var (
session *command.OIDCSession
state string
)
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
session, state, err = s.command.CreateOIDCSessionFromAuthRequest(
setContextUserSystem(ctx),
plainCode,
codeExchangeComplianceChecker(client, r.Data),
slices.Contains(client.GrantTypes(), oidc.GrantTypeRefreshToken),
)
} else {
session, state, err = s.codeExchangeV1(ctx, client, r.Data, r.Data.Code)
}
if err != nil {
return nil, err
}
return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion))
}
// codeExchangeV1 creates a v2 token from a v1 auth request.
func (s *Server) codeExchangeV1(ctx context.Context, client *Client, req *oidc.AccessTokenRequest, code string) (session *command.OIDCSession, state string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
authReq, err := s.getAuthRequestV1ByCode(ctx, code)
if err != nil {
return nil, "", err
}
if challenge := authReq.GetCodeChallenge(); challenge != nil || client.AuthMethod() == oidc.AuthMethodNone {
if err = op.AuthorizeCodeChallenge(req.CodeVerifier, challenge); err != nil {
return nil, "", err
}
}
if req.RedirectURI != authReq.GetRedirectURI() {
return nil, "", oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
}
userAgentID, _, userOrgID, authTime, authMethodsReferences, preferredLanguage, reason, actor := getInfoFromRequest(authReq)
scope := authReq.GetScopes()
session, err = s.command.CreateOIDCSession(ctx,
authReq.GetSubject(),
userOrgID,
client.client.ClientID,
scope,
authReq.GetAudience(),
AMRToAuthMethodTypes(authMethodsReferences),
authTime,
authReq.GetNonce(),
preferredLanguage,
&domain.UserAgent{
FingerprintID: &userAgentID,
},
reason,
actor,
slices.Contains(scope, oidc.ScopeOfflineAccess),
)
if err != nil {
return nil, "", err
}
return session, authReq.GetState(), s.repo.DeleteAuthRequest(ctx, authReq.GetID())
}
// getAuthRequestV1ByCode finds the v1 auth request by code.
// code needs to be the encrypted version of the ID,
// this is required by the underlying repo.
func (s *Server) getAuthRequestV1ByCode(ctx context.Context, code string) (op.AuthRequest, error) {
authReq, err := s.repo.AuthRequestByCode(ctx, code)
if err != nil {
return nil, err
}
return AuthRequestFromBusiness(authReq)
}
func codeExchangeComplianceChecker(client *Client, req *oidc.AccessTokenRequest) command.AuthRequestComplianceChecker {
return func(ctx context.Context, authReq *command.AuthRequestWriteModel) error {
if authReq.CodeChallenge != nil || client.AuthMethod() == oidc.AuthMethodNone {
err := op.AuthorizeCodeChallenge(req.CodeVerifier, CodeChallengeToOIDC(authReq.CodeChallenge))
if err != nil {
return err
}
}
if req.RedirectURI != authReq.RedirectURI {
return oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
}
if err := authReq.CheckAuthenticated(); err != nil {
return err
}
return nil
}
}

View File

@@ -0,0 +1,46 @@
package oidc
import (
"context"
"errors"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.DeviceAccessTokenRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
span.EndWithError(err)
err = oidcError(err)
}()
client, ok := r.Client.(*Client)
if !ok {
return nil, zerrors.ThrowInternal(nil, "OIDC-Ae2ph", "Error.Internal")
}
session, err := s.command.CreateOIDCSessionFromDeviceAuth(ctx, r.Data.DeviceCode)
if err == nil {
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
}
if errors.Is(err, context.DeadlineExceeded) {
return nil, oidc.ErrSlowDown().WithParent(err)
}
var target command.DeviceAuthStateError
if errors.As(err, &target) {
state := domain.DeviceAuthState(target)
if state == domain.DeviceAuthStateInitiated {
return nil, oidc.ErrAuthorizationPending()
}
if state == domain.DeviceAuthStateExpired {
return nil, oidc.ErrExpiredDeviceCode()
}
}
return nil, oidc.ErrAccessDenied().WithParent(err)
}

View File

@@ -5,9 +5,9 @@ import (
"slices"
"time"
"github.com/zitadel/oidc/v3/pkg/crypto"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
@@ -134,13 +134,16 @@ func (s *Server) verifyExchangeToken(ctx context.Context, client *Client, token
return idTokenClaimsToExchangeToken(claims, resourceOwner), nil
case oidc.JWTTokenType:
resourceOwner := new(string)
verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, client.client.ClockSkew, s.jwtProfileUserCheck(ctx, resourceOwner))
var (
resourceOwner string
preferredLanguage *language.Tag
)
verifier := op.NewJWTProfileVerifierKeySet(keySetMap(client.client.PublicKeys), op.IssuerFromContext(ctx), time.Hour, client.client.ClockSkew, s.jwtProfileUserCheck(ctx, &resourceOwner, &preferredLanguage))
jwt, err := op.VerifyJWTAssertion(ctx, token, verifier)
if err != nil {
return nil, zerrors.ThrowPermissionDenied(err, "OIDC-eiS6o", "Errors.TokenExchange.Token.Invalid")
}
return jwtToExchangeToken(jwt, *resourceOwner), nil
return jwtToExchangeToken(jwt, resourceOwner, preferredLanguage), nil
case UserIDTokenType:
user, err := s.query.GetUserByID(ctx, false, token)
@@ -156,13 +159,18 @@ func (s *Server) verifyExchangeToken(ctx context.Context, client *Client, token
}
}
func (s *Server) jwtProfileUserCheck(ctx context.Context, resourceOwner *string) op.JWTProfileVerifierOption {
// jwtProfileUserCheck finds the user by subject (user ID) and sets the resourceOwner through the pointer.
// preferred Language is set only if it was defined for a Human user, else the pointed pointer remains nil.
func (s *Server) jwtProfileUserCheck(ctx context.Context, resourceOwner *string, preferredLanguage **language.Tag) op.JWTProfileVerifierOption {
return op.SubjectCheck(func(request *oidc.JWTTokenRequest) error {
user, err := s.query.GetUserByID(ctx, false, request.Subject)
if err != nil {
return zerrors.ThrowPermissionDenied(err, "OIDC-Nee6r", "Errors.TokenExchange.Token.Invalid")
}
*resourceOwner = user.ResourceOwner
if user.Human != nil && !user.Human.PreferredLanguage.IsRoot() {
*preferredLanguage = &user.Human.PreferredLanguage
}
return nil
})
}
@@ -210,21 +218,8 @@ func validateTokenExchangeAudience(requestedAudience, subjectAudience, actorAudi
// Both tokens may point to the same object (subjectToken) in case of a regular Token Exchange.
// When the subject and actor Tokens point to different objects, the new tokens will be for impersonation / delegation.
func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenType, client *Client, subjectToken, actorToken *exchangeToken, audience, scopes []string) (_ *oidc.TokenExchangeResponse, err error) {
var (
userInfo *oidc.UserInfo
signingKey op.SigningKey
)
if slices.Contains(scopes, oidc.ScopeOpenID) || tokenType == oidc.JWTTokenType || tokenType == oidc.IDTokenType {
projectID := client.client.ProjectID
userInfo, err = s.userInfo(ctx, subjectToken.userID, scopes, projectID, client.client.ProjectRoleAssertion, false)
if err != nil {
return nil, err
}
signingKey, err = s.Provider().Storage().SigningKey(ctx)
if err != nil {
return nil, err
}
}
getUserInfo := s.getUserInfoOnce(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, scopes)
getSigner := s.getSignerOnce()
resp := &oidc.TokenExchangeResponse{
Scopes: scopes,
@@ -237,21 +232,23 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
actor = actorToken.nestedActor()
}
var sessionID string
switch tokenType {
case oidc.AccessTokenType, "":
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeAccessToken(ctx, client, subjectToken.resourceOwner, subjectToken.userID, audience, scopes, actorToken.authMethods, actorToken.authTime, reason, actor)
resp.AccessToken, resp.RefreshToken, sessionID, resp.ExpiresIn, err = s.createExchangeAccessToken(ctx, client, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor)
resp.TokenType = oidc.BearerToken
resp.IssuedTokenType = oidc.AccessTokenType
case oidc.JWTTokenType:
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, signingKey, client, subjectToken.resourceOwner, subjectToken.userID, audience, scopes, actorToken.authMethods, actorToken.authTime, reason, actor, userInfo.Claims)
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor)
resp.TokenType = oidc.BearerToken
resp.IssuedTokenType = oidc.JWTTokenType
case oidc.IDTokenType:
resp.AccessToken, resp.ExpiresIn, err = s.createExchangeIDToken(ctx, signingKey, client, subjectToken.userID, "", audience, userInfo, actorToken.authMethods, actorToken.authTime, reason, actor)
resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
resp.TokenType = TokenTypeNA
resp.IssuedTokenType = oidc.IDTokenType
case oidc.RefreshTokenType, UserIDTokenType:
fallthrough
default:
@@ -262,7 +259,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
}
if slices.Contains(scopes, oidc.ScopeOpenID) && tokenType != oidc.IDTokenType {
resp.IDToken, _, err = s.createExchangeIDToken(ctx, signingKey, client, subjectToken.userID, resp.AccessToken, audience, userInfo, actorToken.authMethods, actorToken.authTime, reason, actor)
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
if err != nil {
return nil, err
}
@@ -271,77 +268,83 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
return resp, nil
}
func (s *Server) createExchangeAccessToken(ctx context.Context, client *Client, resourceOwner, userID string, audience, scopes []string, authMethods []domain.UserAuthMethodType, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) (accessToken string, refreshToken string, exp uint64, err error) {
tokenInfo, refreshToken, err := s.createAccessTokenCommands(ctx, client, resourceOwner, userID, audience, scopes, authMethods, authTime, reason, actor)
if err != nil {
return "", "", 0, err
}
accessToken, err = op.CreateBearerToken(tokenInfo.TokenID, userID, s.Provider().Crypto())
if err != nil {
return "", "", 0, err
}
return accessToken, refreshToken, timeToOIDCExpiresIn(tokenInfo.Expiration), nil
}
func (s *Server) createExchangeAccessToken(
ctx context.Context,
client *Client,
userID,
resourceOwner string,
audience,
scope []string,
authMethods []domain.UserAuthMethodType,
authTime time.Time,
preferredLanguage *language.Tag,
reason domain.TokenReason,
actor *domain.TokenActor,
) (accessToken, refreshToken, sessionID string, exp uint64, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
func (s *Server) createExchangeJWT(ctx context.Context, signingKey op.SigningKey, client *Client, resourceOwner, userID string, audience, scopes []string, authMethods []domain.UserAuthMethodType, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor, privateClaims map[string]any) (accessToken string, refreshToken string, exp uint64, err error) {
tokenInfo, refreshToken, err := s.createAccessTokenCommands(ctx, client, resourceOwner, userID, audience, scopes, authMethods, authTime, reason, actor)
if err != nil {
return "", "", 0, err
}
expTime := tokenInfo.Expiration.Add(client.ClockSkew())
claims := oidc.NewAccessTokenClaims(op.IssuerFromContext(ctx), userID, tokenInfo.Audience, expTime, tokenInfo.TokenID, client.GetID(), client.ClockSkew())
claims.Actor = actorDomainToClaims(tokenInfo.Actor)
claims.Claims = privateClaims
signer, err := op.SignerFromKey(signingKey)
if err != nil {
return "", "", 0, err
}
accessToken, err = crypto.Sign(claims, signer)
if err != nil {
return "", "", 0, nil
}
return accessToken, refreshToken, timeToOIDCExpiresIn(expTime), nil
}
func (s *Server) createExchangeIDToken(ctx context.Context, signingKey op.SigningKey, client *Client, userID, accessToken string, audience []string, userInfo *oidc.UserInfo, authMethods []domain.UserAuthMethodType, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
expTime := time.Now().Add(client.IDTokenLifetime()).Add(client.ClockSkew())
claims := oidc.NewIDTokenClaims(op.IssuerFromContext(ctx), userID, audience, expTime, authTime, "", "", AuthMethodTypesToAMR(authMethods), client.GetID(), client.ClockSkew())
claims.Actor = actorDomainToClaims(actor)
claims.SetUserInfo(userInfo)
if accessToken != "" {
claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signingKey.SignatureAlgorithm())
if err != nil {
return "", 0, err
}
}
signer, err := op.SignerFromKey(signingKey)
if err != nil {
return "", 0, err
}
idToken, err = crypto.Sign(claims, signer)
return idToken, timeToOIDCExpiresIn(expTime), err
}
func timeToOIDCExpiresIn(exp time.Time) uint64 {
return uint64(time.Until(exp) / time.Second)
}
func (s *Server) createAccessTokenCommands(ctx context.Context, client *Client, resourceOwner, userID string, audience, scopes []string, authMethods []domain.UserAuthMethodType, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) (tokenInfo *domain.Token, refreshToken string, err error) {
settings := client.client.Settings
if slices.Contains(scopes, oidc.ScopeOfflineAccess) {
return s.command.AddAccessAndRefreshToken(
ctx, resourceOwner, "", client.GetID(), userID, "", audience, scopes, AuthMethodTypesToAMR(authMethods),
settings.AccessTokenLifetime, settings.RefreshTokenIdleExpiration, settings.RefreshTokenExpiration,
authTime, reason, actor,
)
}
tokenInfo, err = s.command.AddUserToken(
ctx, resourceOwner, "", client.GetID(), userID, audience, scopes, AuthMethodTypesToAMR(authMethods),
settings.AccessTokenLifetime,
authTime, reason, actor,
session, err := s.command.CreateOIDCSession(ctx,
userID,
resourceOwner,
client.client.ClientID,
scope,
audience,
authMethods,
authTime,
"",
preferredLanguage,
nil,
reason,
actor,
slices.Contains(scope, oidc.ScopeOfflineAccess),
)
return tokenInfo, "", err
if err != nil {
return "", "", "", 0, err
}
accessToken, err = op.CreateBearerToken(session.TokenID, userID, s.opCrypto)
if err != nil {
return "", "", "", 0, err
}
return accessToken, session.RefreshToken, session.SessionID, timeToOIDCExpiresIn(session.Expiration), nil
}
func (s *Server) createExchangeJWT(
ctx context.Context,
client *Client,
getUserInfo userInfoFunc,
getSigner signerFunc,
userID,
resourceOwner string,
audience,
scope []string,
authMethods []domain.UserAuthMethodType,
authTime time.Time,
preferredLanguage *language.Tag,
reason domain.TokenReason,
actor *domain.TokenActor,
) (accessToken string, refreshToken string, exp uint64, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
session, err := s.command.CreateOIDCSession(ctx,
userID,
resourceOwner,
client.client.ClientID,
scope,
audience,
authMethods,
authTime,
"",
preferredLanguage,
nil,
reason,
actor,
slices.Contains(scope, oidc.ScopeOfflineAccess),
)
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner)
if err != nil {
return "", "", 0, err
}
return accessToken, session.RefreshToken, timeToOIDCExpiresIn(session.Expiration), nil
}

View File

@@ -4,21 +4,23 @@ import (
"time"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
)
type exchangeToken struct {
tokenType oidc.TokenType
userID string
issuer string
resourceOwner string
authTime time.Time
authMethods []domain.UserAuthMethodType
actor *domain.TokenActor
audience []string
scopes []string
tokenType oidc.TokenType
userID string
issuer string
resourceOwner string
authTime time.Time
authMethods []domain.UserAuthMethodType
actor *domain.TokenActor
audience []string
scopes []string
preferredLanguage *language.Tag
}
func (et *exchangeToken) nestedActor() *domain.TokenActor {
@@ -31,27 +33,33 @@ func (et *exchangeToken) nestedActor() *domain.TokenActor {
func accessToExchangeToken(token *accessToken, issuer string) *exchangeToken {
return &exchangeToken{
tokenType: oidc.AccessTokenType,
userID: token.userID,
issuer: issuer,
resourceOwner: token.resourceOwner,
authMethods: token.authMethods,
actor: token.actor,
audience: token.audience,
scopes: token.scope,
tokenType: oidc.AccessTokenType,
userID: token.userID,
issuer: issuer,
resourceOwner: token.resourceOwner,
authMethods: token.authMethods,
actor: token.actor,
audience: token.audience,
scopes: token.scope,
preferredLanguage: token.preferredLanguage,
}
}
func idTokenClaimsToExchangeToken(claims *oidc.IDTokenClaims, resourceOwner string) *exchangeToken {
var preferredLanguage *language.Tag
if tag := claims.Locale.Tag(); !tag.IsRoot() {
preferredLanguage = &tag
}
return &exchangeToken{
tokenType: oidc.IDTokenType,
userID: claims.Subject,
issuer: claims.Issuer,
resourceOwner: resourceOwner,
authTime: claims.GetAuthTime(),
authMethods: AMRToAuthMethodTypes(claims.AuthenticationMethodsReferences),
actor: actorClaimsToDomain(claims.Actor),
audience: claims.Audience,
tokenType: oidc.IDTokenType,
userID: claims.Subject,
issuer: claims.Issuer,
resourceOwner: resourceOwner,
authTime: claims.GetAuthTime(),
authMethods: AMRToAuthMethodTypes(claims.AuthenticationMethodsReferences),
actor: actorClaimsToDomain(claims.Actor),
audience: claims.Audience,
preferredLanguage: preferredLanguage,
}
}
@@ -77,7 +85,7 @@ func actorDomainToClaims(actor *domain.TokenActor) *oidc.ActorClaims {
}
}
func jwtToExchangeToken(jwt *oidc.JWTTokenRequest, resourceOwner string) *exchangeToken {
func jwtToExchangeToken(jwt *oidc.JWTTokenRequest, resourceOwner string, preferredLanguage *language.Tag) *exchangeToken {
return &exchangeToken{
tokenType: oidc.JWTTokenType,
userID: jwt.Subject,
@@ -86,6 +94,7 @@ func jwtToExchangeToken(jwt *oidc.JWTTokenRequest, resourceOwner string) *exchan
scopes: jwt.Scopes,
authTime: jwt.IssuedAt.AsTime(),
// audience omitted as we don't thrust audiences not signed by us
preferredLanguage: preferredLanguage,
}
}

View File

@@ -587,5 +587,5 @@ func TestImpersonation_API_Call(t *testing.T) {
_, err = Tester.Client.Admin.GetAllowedLanguages(impersonatedCTX, &admin.GetAllowedLanguagesRequest{})
status := status.Convert(err)
assert.Equal(t, codes.PermissionDenied, status.Code())
assert.Equal(t, "Errors.TokenExchange.Token.NotForAPI (APP-wai8O)", status.Message())
assert.Equal(t, "Errors.TokenExchange.Token.NotForAPI (APP-Shi0J)", status.Message())
}

View File

@@ -0,0 +1,99 @@
package oidc
import (
"context"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGrantRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
span.EndWithError(err)
err = oidcError(err)
}()
user, jwtReq, err := s.verifyJWTProfile(ctx, r.Data)
if err != nil {
return nil, err
}
client := &clientCredentialsClient{
id: jwtReq.Subject,
user: user,
}
scope, err := op.ValidateAuthReqScopes(client, r.Data.Scope)
if err != nil {
return nil, err
}
scope, err = s.checkOrgScopes(ctx, client.user, scope)
if err != nil {
return nil, err
}
session, err := s.command.CreateOIDCSession(ctx,
user.ID,
user.ResourceOwner,
jwtReq.Subject,
scope,
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
nil,
time.Now(),
"",
nil,
nil,
domain.TokenReasonClientCredentials,
nil,
false,
)
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false))
}
func (s *Server) verifyJWTProfile(ctx context.Context, req *oidc.JWTProfileGrantRequest) (user *query.User, tokenRequest *oidc.JWTTokenRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
checkSubject := func(jwt *oidc.JWTTokenRequest) (err error) {
user, err = s.query.GetUserByID(ctx, true, jwt.Subject)
return err
}
verifier := op.NewJWTProfileVerifier(
&jwtProfileKeyStorage{query: s.query},
op.IssuerFromContext(ctx),
time.Hour, time.Second,
op.SubjectCheck(checkSubject),
)
tokenRequest, err = op.VerifyJWTAssertion(ctx, req.Assertion, verifier)
if err != nil {
return nil, nil, err
}
return user, tokenRequest, nil
}
type jwtProfileKeyStorage struct {
query *query.Queries
}
func (s *jwtProfileKeyStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (*jose.JSONWebKey, error) {
publicKeyData, err := s.query.GetAuthNKeyPublicKeyByIDAndIdentifier(ctx, keyID, userID)
if err != nil {
return nil, err
}
publicKey, err := crypto.BytesToPublicKey(publicKeyData)
if err != nil {
return nil, err
}
return &jose.JSONWebKey{
KeyID: keyID,
Use: "sig",
Key: publicKey,
}, nil
}

View File

@@ -0,0 +1,103 @@
//go:build integration
package oidc_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client/profile"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
oidc_api "github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/domain"
)
func TestServer_JWTProfile(t *testing.T) {
userID, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX)
require.NoError(t, err)
type claims struct {
resourceOwnerID any
resourceOwnerName any
resourceOwnerPrimaryDomain any
orgDomain any
}
tests := []struct {
name string
keyData []byte
scope []string
wantClaims claims
wantErr bool
}{
{
name: "success",
keyData: keyData,
scope: []string{oidc.ScopeOpenID},
},
{
name: "org id and domain scope",
keyData: keyData,
scope: []string{
oidc.ScopeOpenID,
domain.OrgIDScope + Tester.Organisation.ID,
domain.OrgDomainPrimaryScope + Tester.Organisation.Domain,
},
wantClaims: claims{
resourceOwnerID: Tester.Organisation.ID,
resourceOwnerName: Tester.Organisation.Name,
resourceOwnerPrimaryDomain: Tester.Organisation.Domain,
orgDomain: Tester.Organisation.Domain,
},
},
{
name: "invalid org domain filtered",
keyData: keyData,
scope: []string{
oidc.ScopeOpenID,
domain.OrgDomainPrimaryScope + Tester.Organisation.Domain,
domain.OrgDomainPrimaryScope + "foo"},
wantClaims: claims{
orgDomain: Tester.Organisation.Domain,
},
},
{
name: "invalid org id filtered",
keyData: keyData,
scope: []string{oidc.ScopeOpenID,
domain.OrgIDScope + Tester.Organisation.ID,
domain.OrgIDScope + "foo",
},
wantClaims: claims{
resourceOwnerID: Tester.Organisation.ID,
resourceOwnerName: Tester.Organisation.Name,
resourceOwnerPrimaryDomain: Tester.Organisation.Domain,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokenSource, err := profile.NewJWTProfileTokenSourceFromKeyFileData(CTX, Tester.OIDCIssuer(), tt.keyData, tt.scope)
require.NoError(t, err)
tokens, err := tokenSource.TokenCtx(CTX)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, tokens)
provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), "", "", redirectURI, tt.scope)
require.NoError(t, err)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider)
require.NoError(t, err)
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
})
}
}

View File

@@ -0,0 +1,101 @@
package oidc
import (
"context"
"errors"
"slices"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
func (s *Server) RefreshToken(ctx context.Context, r *op.ClientRequest[oidc.RefreshTokenRequest]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
span.EndWithError(err)
err = oidcError(err)
}()
client, ok := r.Client.(*Client)
if !ok {
return nil, zerrors.ThrowInternal(nil, "OIDC-ga0EP", "Error.Internal")
}
session, err := s.command.ExchangeOIDCSessionRefreshAndAccessToken(ctx, r.Data.RefreshToken, r.Data.Scopes, refreshTokenComplianceChecker())
if err == nil {
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
} else if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")) {
// We try again for v1 tokens when we encountered specific parsing error
return s.refreshTokenV1(ctx, client, r)
}
return nil, err
}
// refreshTokenV1 verifies a v1 refresh token.
// When valid a v2 OIDC session is created and v2 tokens are returned.
// This "upgrades" existing v1 sessions to v2 session without requiring users to re-login.
//
// This function can be removed when we retire the v1 token repo.
func (s *Server) refreshTokenV1(ctx context.Context, client *Client, r *op.ClientRequest[oidc.RefreshTokenRequest]) (_ *op.Response, err error) {
refreshToken, err := s.repo.RefreshTokenByToken(ctx, r.Data.RefreshToken)
if err != nil {
return nil, err
}
scope, err := validateRefreshTokenScopes(refreshToken.Scopes, r.Data.Scopes)
if err != nil {
return nil, err
}
session, err := s.command.CreateOIDCSession(ctx,
refreshToken.UserID,
refreshToken.ResourceOwner,
refreshToken.ClientID,
scope,
refreshToken.Audience,
AMRToAuthMethodTypes(refreshToken.AuthMethodsReferences),
refreshToken.AuthTime,
"",
nil, // Preferred language not in refresh token view
&domain.UserAgent{
FingerprintID: &refreshToken.UserAgentID,
Description: &refreshToken.UserAgentID,
},
domain.TokenReasonRefresh,
refreshToken.Actor,
true,
)
if err != nil {
return nil, err
}
// make sure the v1 refresh token can't be reused.
_, err = s.command.RevokeRefreshToken(ctx, refreshToken.UserID, refreshToken.ResourceOwner, refreshToken.ID)
if err != nil {
return nil, err
}
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
}
// refreshTokenComplianceChecker validates that the requested scope is a subset of the original auth request scope.
func refreshTokenComplianceChecker() command.RefreshTokenComplianceChecker {
return func(_ context.Context, model *command.OIDCSessionWriteModel, requestedScope []string) ([]string, error) {
return validateRefreshTokenScopes(model.Scope, requestedScope)
}
}
func validateRefreshTokenScopes(currentScope, requestedScope []string) ([]string, error) {
if len(requestedScope) == 0 {
return currentScope, nil
}
for _, s := range requestedScope {
if !slices.Contains(currentScope, s) {
return nil, oidc.ErrInvalidScope()
}
}
return requestedScope, nil
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoRequest]) (_ *op.Response, err error) {
@@ -48,7 +49,8 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques
)
if token.clientID != "" {
projectID, assertion, err = s.query.GetOIDCUserinfoClientByID(ctx, token.clientID)
if err != nil {
// token.clientID might contain a username (e.g. client credentials) -> ignore the not found
if err != nil && !zerrors.IsNotFound(err) {
return nil, err
}
}