zitadel/internal/api/oidc/token.go
Tim Möhlmann 8e0c8393e9
perf(oidc): optimize token creation (#7822)
* implement code exchange

* port tokenexchange to v2 tokens

* implement refresh token

* implement client credentials

* implement jwt profile

* implement device token

* cleanup unused code

* fix current unit tests

* add user agent unit test

* unit test domain package

* need refresh token as argument

* test commands create oidc session

* test commands device auth

* fix device auth build error

* implicit for oidc session API

* implement authorize callback handler for legacy implicit mode

* upgrade oidc module to working draft

* add missing auth methods and time

* handle all errors in defer

* do not fail auth request on error

The oauth2 Go client automatically retries on any error. If we fail the auth request on the first error, the next attempt will always fail with Errors.AuthRequest.NoCode, because the auth request state has already been set to failed.
By then the original error is lost, and the oauth2 library does not return it.

Therefore we should not fail the auth request.

Might be worth discussing, and perhaps sending a bug report to the oauth2 maintainers?

* fix code flow tests by explicitly setting code exchanged

* fix unit tests in command package

* return allowed scope from client credential client

* add device auth done reducer

* carry nonce thru session into ID token

* fix token exchange integration tests

* allow project role scope prefix in client credentials client

* gci formatting

* do not return refresh token in client credentials and jwt profile

* check org scope

* solve linting issue on authorize callback error

* end session based on v2 session ID

* use preferred language and user agent ID for v2 access tokens

* pin oidc v3.23.2

* add integration test for jwt profile and client credentials with org scopes

* refresh token v1 to v2

* add user token v2 audit event

* add activity trigger

* cleanup and set panics for unused methods

* use the encrypted code for v1 auth request get by code

* add missing event translation

* fix pipeline errors (hopefully)

* fix another test

* revert pointer usage of preferred language

* solve browser info panic in device auth

* remove duplicate entries in AMRToAuthMethodTypes to prevent future `mfa` claim

* revoke v1 refresh token to prevent reuse

* fix terminate oidc session

* always return a new refresh token in refresh token grant

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
2024-05-16 07:07:56 +02:00

199 lines
6.2 KiB
Go

package oidc
import (
"context"
"encoding/base64"
"slices"
"sync"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/zitadel/oidc/v3/pkg/crypto"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
/*
For each grant type, token creation follows the same rough logical steps:
1. Information gathering: who is requesting the token, and what do we put in the claims?
2. Decision making: is the request authorized? (valid exchange code, completed auth request, valid token, etc.)
3. Build an OIDC session in storage: inform the eventstore that we are creating tokens.
4. Use the OIDC session to encrypt and/or sign the requested tokens.
In some cases steps 1 through 3 are implemented entirely in the command package,
for example for the v2 code exchange and refresh token grants.
*/
// accessTokenResponseFromSession builds the token endpoint response for client
// from an already persisted OIDC session.
//
// The session's refresh token is copied into the response, expires_in is
// derived from the session expiration, and state is echoed back unchanged.
// An access token is only created when the session carries a token ID: a
// signed JWT when the client is configured for JWT access tokens, an opaque
// bearer token otherwise. An ID token is added when the session scope contains
// "openid". User info and the signing key are resolved through the *Once
// getters, so the access token and the ID token share a single lookup of each.
//
// On any error the returned response is nil.
func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion bool) (_ *oidc.AccessTokenResponse, err error) {
	getUserInfo := s.getUserInfoOnce(session.UserID, projectID, projectRoleAssertion, session.Scope)
	getSigner := s.getSignerOnce()
	resp := &oidc.AccessTokenResponse{
		TokenType:    oidc.BearerToken,
		RefreshToken: session.RefreshToken,
		ExpiresIn:    timeToOIDCExpiresIn(session.Expiration),
		State:        state,
	}
	// If the session does not have a token ID, it is an implicit ID-Token only response.
	if session.TokenID != "" {
		if client.AccessTokenType() == op.AccessTokenTypeJWT {
			resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner)
		} else {
			resp.AccessToken, err = op.CreateBearerToken(session.TokenID, session.UserID, s.opCrypto)
		}
		if err != nil {
			return nil, err
		}
	}
	if slices.Contains(session.Scope, oidc.ScopeOpenID) {
		resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
		// Do not hand out a half-built response: match the access token error path above.
		if err != nil {
			return nil, err
		}
	}
	return resp, nil
}
// signerFunc is a getter function that allows ad-hoc retrieval of the instance's signer.
// Besides the signer it returns the signature algorithm of the underlying signing key.
type signerFunc func(ctx context.Context) (jose.Signer, jose.SignatureAlgorithm, error)
// getSignerOnce returns a signerFunc that loads the instance's signing key from
// storage the first time it is invoked and memoizes the outcome.
// Every later invocation returns the signer, algorithm and error produced by
// that first attempt (a failure included); storage is not queried again, and
// only the context of the first call is used for the lookup.
func (s *Server) getSignerOnce() signerFunc {
	var (
		loaded  sync.Once
		signer  jose.Signer
		alg     jose.SignatureAlgorithm
		loadErr error
	)
	return func(ctx context.Context) (jose.Signer, jose.SignatureAlgorithm, error) {
		loaded.Do(func() {
			ctx, span := tracing.NewSpan(ctx)
			defer func() { span.EndWithError(loadErr) }()

			key, keyErr := s.Provider().Storage().SigningKey(ctx)
			if keyErr != nil {
				loadErr = keyErr
				return
			}
			alg = key.SignatureAlgorithm()
			signer, loadErr = op.SignerFromKey(key)
		})
		return signer, alg, loadErr
	}
}
// userInfoFunc is a getter function that allows ad-hoc retrieval of a user's info.
type userInfoFunc func(ctx context.Context) (*oidc.UserInfo, error)
// getUserInfoOnce returns a userInfoFunc that looks up the user's info for the
// given project and scope on its first invocation only.
// Whatever that single lookup produced, user info or error, is memoized and
// handed back by every subsequent call; only the first call's context is used.
func (s *Server) getUserInfoOnce(userID, projectID string, projectRoleAssertion bool, scope []string) userInfoFunc {
	var (
		fetch   sync.Once
		info    *oidc.UserInfo
		infoErr error
	)
	return func(ctx context.Context) (*oidc.UserInfo, error) {
		fetch.Do(func() {
			ctx, span := tracing.NewSpan(ctx)
			defer func() { span.EndWithError(infoErr) }()
			info, infoErr = s.userInfo(ctx, userID, scope, projectID, projectRoleAssertion, false)
		})
		return info, infoErr
	}
}
// createIDToken assembles and signs an ID token for client.
//
// User claims are obtained from getUserInfo first, then the signer from
// getSigningKey; either failure aborts with zero results. The expiry is now
// plus the client's ID token lifetime plus its clock skew. When accessToken is
// non-empty, its hash (for the signing algorithm) is stored in the at_hash claim.
// exp is the expiry expressed as whole seconds from now.
func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	userInfo, err := getUserInfo(ctx)
	if err != nil {
		return "", 0, err
	}
	signer, signAlg, err := getSigningKey(ctx)
	if err != nil {
		return "", 0, err
	}

	skew := client.ClockSkew()
	expiresAt := time.Now().Add(client.IDTokenLifetime()).Add(skew)
	claims := oidc.NewIDTokenClaims(
		op.IssuerFromContext(ctx),
		"", // subject left empty here; SetUserInfo below is expected to fill it from userInfo
		audience,
		expiresAt,
		authTime,
		nonce,
		"", // no ACR value
		AuthMethodTypesToAMR(authMethods),
		client.GetID(),
		skew,
	)
	claims.SessionID = sessionID
	claims.Actor = actorDomainToClaims(actor)
	claims.SetUserInfo(userInfo)

	if accessToken != "" {
		if claims.AccessTokenHash, err = oidc.ClaimHash(accessToken, signAlg); err != nil {
			return "", 0, err
		}
	}

	idToken, err = crypto.Sign(claims, signer)
	return idToken, timeToOIDCExpiresIn(expiresAt), err
}
// timeToOIDCExpiresIn converts an absolute expiry into an OIDC "expires_in"
// value: the number of whole seconds from now until exp (fractions truncated).
//
// An expiry that is already in the past yields 0. Converting the negative
// duration straight to uint64 would otherwise wrap around to an enormous
// lifetime.
func timeToOIDCExpiresIn(exp time.Time) uint64 {
	remaining := time.Until(exp)
	if remaining <= 0 {
		return 0
	}
	return uint64(remaining / time.Second)
}
// createJWT builds a JWT access token for session and signs it.
//
// User info is fetched before the signer; either failure aborts. The token
// subject and extra claims come from the user info, while the audience, token
// ID and actor come from the session. The expiry is the session expiration
// extended by the client's clock skew.
func (*Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, getSigner signerFunc) (_ string, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	userInfo, err := getUserInfo(ctx)
	if err != nil {
		return "", err
	}
	signer, _, err := getSigner(ctx)
	if err != nil {
		return "", err
	}

	skew := client.ClockSkew()
	claims := oidc.NewAccessTokenClaims(
		op.IssuerFromContext(ctx),
		userInfo.Subject,
		session.Audience,
		session.Expiration.Add(skew),
		session.TokenID,
		client.GetID(),
		skew,
	)
	claims.Actor = actorDomainToClaims(session.Actor)
	claims.Claims = userInfo.Claims
	return crypto.Sign(claims, signer)
}
// decryptCode recovers the plaintext of an encrypted code or refresh_token.
// The input must be unpadded base64url text; it is decoded and then decrypted
// with s.encAlg under the key ID reported by s.encAlg.EncryptionKeyID().
func (s *Server) decryptCode(ctx context.Context, code string) (_ string, err error) {
	_, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	ciphertext, decodeErr := base64.RawURLEncoding.DecodeString(code)
	if decodeErr != nil {
		return "", decodeErr
	}
	return s.encAlg.DecryptString(ciphertext, s.encAlg.EncryptionKeyID())
}