fix: uniform oidc errors (#7237)

* fix: uniform oidc errors

sanitize oidc error reporting when passing package boundary towards oidc.

* add should TriggerBulk in get audiences for auth request

* upgrade to oidc 3.10.1

* provisional oidc upgrade to error branch

* pin oidc 3.10.2
This commit is contained in:
Tim Möhlmann 2024-01-18 08:10:49 +02:00 committed by GitHub
parent cdfcdec101
commit af4e0484d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 267 additions and 61 deletions

2
go.mod
View File

@ -61,7 +61,7 @@ require (
github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203 github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203
github.com/ttacon/libphonenumber v1.2.1 github.com/ttacon/libphonenumber v1.2.1
github.com/zitadel/logging v0.5.0 github.com/zitadel/logging v0.5.0
github.com/zitadel/oidc/v3 v3.10.0 github.com/zitadel/oidc/v3 v3.10.2
github.com/zitadel/passwap v0.5.0 github.com/zitadel/passwap v0.5.0
github.com/zitadel/saml v0.1.3 github.com/zitadel/saml v0.1.3
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1

4
go.sum
View File

@ -782,8 +782,8 @@ github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8=
github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
github.com/zitadel/logging v0.5.0 h1:Kunouvqse/efXy4UDvFw5s3vP+Z4AlHo3y8wF7stXHA= github.com/zitadel/logging v0.5.0 h1:Kunouvqse/efXy4UDvFw5s3vP+Z4AlHo3y8wF7stXHA=
github.com/zitadel/logging v0.5.0/go.mod h1:IzP5fzwFhzzyxHkSmfF8dsyqFsQRJLLcQmwhIBzlGsE= github.com/zitadel/logging v0.5.0/go.mod h1:IzP5fzwFhzzyxHkSmfF8dsyqFsQRJLLcQmwhIBzlGsE=
github.com/zitadel/oidc/v3 v3.10.0 h1:qAGlw6FGQEpkWya8tT03P6pU4AHNrZ0Kfyxmwsd4am0= github.com/zitadel/oidc/v3 v3.10.2 h1:nowZrpOBR4tdIlYXE8/l5Nl84QDYwyHpccIE1l2OAd4=
github.com/zitadel/oidc/v3 v3.10.0/go.mod h1:nfjWH8ps4B7T0JGJyLLOIUlhr0Z4becyGKui/sXYpA8= github.com/zitadel/oidc/v3 v3.10.2/go.mod h1:nfjWH8ps4B7T0JGJyLLOIUlhr0Z4becyGKui/sXYpA8=
github.com/zitadel/passwap v0.5.0 h1:kFMoRyo0GnxtOz7j9+r/CsRwSCjHGRaAKoUe69NwPvs= github.com/zitadel/passwap v0.5.0 h1:kFMoRyo0GnxtOz7j9+r/CsRwSCjHGRaAKoUe69NwPvs=
github.com/zitadel/passwap v0.5.0/go.mod h1:uqY7D3jqdTFcKsW0Q3Pcv5qDMmSHpVTzUZewUKC1KZA= github.com/zitadel/passwap v0.5.0/go.mod h1:uqY7D3jqdTFcKsW0Q3Pcv5qDMmSHpVTzUZewUKC1KZA=
github.com/zitadel/saml v0.1.3 h1:LI4DOCVyyU1qKPkzs3vrGcA5J3H4pH3+CL9zr9ShkpM= github.com/zitadel/saml v0.1.3 h1:LI4DOCVyyU1qKPkzs3vrGcA5J3H4pH3+CL9zr9ShkpM=

View File

@ -27,20 +27,22 @@ type accessToken struct {
isPAT bool isPAT bool
} }
var ErrInvalidTokenFormat = errors.New("invalid token format")
func (s *Server) verifyAccessToken(ctx context.Context, tkn string) (*accessToken, error) { func (s *Server) verifyAccessToken(ctx context.Context, tkn string) (*accessToken, error) {
var tokenID, subject string var tokenID, subject string
if tokenIDSubject, err := s.Provider().Crypto().Decrypt(tkn); err == nil { if tokenIDSubject, err := s.Provider().Crypto().Decrypt(tkn); err == nil {
split := strings.Split(tokenIDSubject, ":") split := strings.Split(tokenIDSubject, ":")
if len(split) != 2 { if len(split) != 2 {
return nil, errors.New("invalid token format") return nil, zerrors.ThrowPermissionDenied(ErrInvalidTokenFormat, "OIDC-rei1O", "token is not valid or has expired")
} }
tokenID, subject = split[0], split[1] tokenID, subject = split[0], split[1]
} else { } else {
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.keySet) verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.keySet)
claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, tkn, verifier) claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, tkn, verifier)
if err != nil { if err != nil {
return nil, err return nil, zerrors.ThrowPermissionDenied(err, "OIDC-Eib8e", "token is not valid or has expired")
} }
tokenID, subject = claims.JWTID, claims.Subject tokenID, subject = claims.JWTID, claims.Subject
} }

View File

@ -28,7 +28,10 @@ const (
func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) { func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
headers, _ := http_utils.HeadersFromCtx(ctx) headers, _ := http_utils.HeadersFromCtx(ctx)
if loginClient := headers.Get(LoginClientHeader); loginClient != "" { if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
@ -102,7 +105,7 @@ func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string)
if err != nil { if err != nil {
return nil, err return nil, err
} }
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}) appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,7 +115,10 @@ func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string)
func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRequest, err error) { func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
if strings.HasPrefix(id, command.IDPrefixV2) { if strings.HasPrefix(id, command.IDPrefixV2) {
req, err := o.command.GetCurrentAuthRequest(ctx, id) req, err := o.command.GetCurrentAuthRequest(ctx, id)
@ -135,7 +141,10 @@ func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRe
func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.AuthRequest, err error) { func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
plainCode, err := o.decryptGrant(code) plainCode, err := o.decryptGrant(code)
if err != nil { if err != nil {
@ -166,7 +175,10 @@ func (o *OPStorage) decryptGrant(grant string) (string, error) {
func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err error) { func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
if strings.HasPrefix(id, command.IDPrefixV2) { if strings.HasPrefix(id, command.IDPrefixV2) {
return o.command.AddAuthRequestCode(ctx, id, code) return o.command.AddAuthRequestCode(ctx, id, code)
@ -181,14 +193,20 @@ func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err erro
func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) { func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
return o.repo.DeleteAuthRequest(ctx, id) return o.repo.DeleteAuthRequest(ctx, id)
} }
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) { func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
var userAgentID, applicationID, userOrgID string var userAgentID, applicationID, userOrgID string
switch authReq := req.(type) { switch authReq := req.(type) {
@ -221,7 +239,10 @@ func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest)
func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) { func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
// handle V2 request directly // handle V2 request directly
switch tokenReq := req.(type) { switch tokenReq := req.(type) {
@ -279,7 +300,10 @@ func getInfoFromRequest(req op.TokenRequest) (string, string, string, time.Time,
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) { func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
plainToken, err := o.decryptGrant(refreshToken) plainToken, err := o.decryptGrant(refreshToken)
if err != nil { if err != nil {
@ -307,7 +331,10 @@ func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken
func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID string) (err error) { func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx) userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok { if !ok {
logging.Error("no user agent id") logging.Error("no user agent id")
@ -331,7 +358,10 @@ func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID strin
func (o *OPStorage) TerminateSessionFromRequest(ctx context.Context, endSessionRequest *op.EndSessionRequest) (redirectURI string, err error) { func (o *OPStorage) TerminateSessionFromRequest(ctx context.Context, endSessionRequest *op.EndSessionRequest) (redirectURI string, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
// check for the login client header // check for the login client header
// and if not provided, terminate the session using the V1 method // and if not provided, terminate the session using the V1 method
@ -408,6 +438,12 @@ func (o *OPStorage) revokeTokenV1(ctx context.Context, token, userID, clientID s
} }
func (o *OPStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) { func (o *OPStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
plainToken, err := o.decryptGrant(token) plainToken, err := o.decryptGrant(token)
if err != nil { if err != nil {
return "", "", op.ErrInvalidRefreshToken return "", "", op.ErrInvalidRefreshToken

View File

@ -51,7 +51,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) {
// test code exchange // test code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -69,7 +69,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) {
require.Error(t, err) require.Error(t, err)
// exchange with a used code must fail // exchange with a used code must fail
_, err = exchangeTokens(t, clientID, code) _, err = exchangeTokens(t, clientID, code, redirectURI)
require.Error(t, err) require.Error(t, err)
} }
@ -140,7 +140,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
// test code exchange (expect refresh token to be returned) // test code exchange (expect refresh token to be returned)
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -165,7 +165,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -201,7 +201,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -244,7 +244,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -281,7 +281,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -324,7 +324,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing.
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -359,7 +359,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -391,7 +391,7 @@ func TestOPStorage_TerminateSession(t *testing.T) {
// test code exchange // test code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -428,7 +428,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) {
// test code exchange // test code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -472,7 +472,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) {
// test code exchange // test code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -497,7 +497,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) { func exchangeTokens(t testing.TB, clientID, code, redirectURI string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
require.NoError(t, err) require.NoError(t, err)

View File

@ -42,7 +42,10 @@ const (
func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Client, err error) { func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Client, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
client, err := o.query.GetOIDCClientByID(ctx, id, false) client, err := o.query.GetOIDCClientByID(ctx, id, false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -59,7 +62,10 @@ func (o *OPStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID str
func (o *OPStorage) GetKeyByIDAndIssuer(ctx context.Context, keyID, issuer string) (_ *jose.JSONWebKey, err error) { func (o *OPStorage) GetKeyByIDAndIssuer(ctx context.Context, keyID, issuer string) (_ *jose.JSONWebKey, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
publicKeyData, err := o.query.GetAuthNKeyPublicKeyByIDAndIdentifier(ctx, keyID, issuer, false) publicKeyData, err := o.query.GetAuthNKeyPublicKeyByIDAndIdentifier(ctx, keyID, issuer, false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,7 +81,12 @@ func (o *OPStorage) GetKeyByIDAndIssuer(ctx context.Context, keyID, issuer strin
}, nil }, nil
} }
func (o *OPStorage) ValidateJWTProfileScopes(ctx context.Context, subject string, scopes []string) ([]string, error) { func (o *OPStorage) ValidateJWTProfileScopes(ctx context.Context, subject string, scopes []string) (_ []string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
user, err := o.query.GetUserByID(ctx, true, subject) user, err := o.query.GetUserByID(ctx, true, subject)
if err != nil { if err != nil {
return nil, err return nil, err
@ -85,7 +96,10 @@ func (o *OPStorage) ValidateJWTProfileScopes(ctx context.Context, subject string
func (o *OPStorage) AuthorizeClientIDSecret(ctx context.Context, id string, secret string) (err error) { func (o *OPStorage) AuthorizeClientIDSecret(ctx context.Context, id string, secret string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
ctx = authz.SetCtxData(ctx, authz.CtxData{ ctx = authz.SetCtxData(ctx, authz.CtxData{
UserID: oidcCtx, UserID: oidcCtx,
OrgID: oidcCtx, OrgID: oidcCtx,
@ -102,7 +116,10 @@ func (o *OPStorage) AuthorizeClientIDSecret(ctx context.Context, id string, secr
func (o *OPStorage) SetUserinfoFromToken(ctx context.Context, userInfo *oidc.UserInfo, tokenID, subject, origin string) (err error) { func (o *OPStorage) SetUserinfoFromToken(ctx context.Context, userInfo *oidc.UserInfo, tokenID, subject, origin string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
if strings.HasPrefix(tokenID, command.IDPrefixV2) { if strings.HasPrefix(tokenID, command.IDPrefixV2) {
token, err := o.query.ActiveAccessTokenByToken(ctx, tokenID) token, err := o.query.ActiveAccessTokenByToken(ctx, tokenID)
@ -129,7 +146,10 @@ func (o *OPStorage) SetUserinfoFromToken(ctx context.Context, userInfo *oidc.Use
func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.UserInfo, userID, applicationID string, scopes []string) (err error) { func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.UserInfo, userID, applicationID string, scopes []string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
if applicationID != "" { if applicationID != "" {
app, err := o.query.AppByOIDCClientID(ctx, applicationID) app, err := o.query.AppByOIDCClientID(ctx, applicationID)
if err != nil { if err != nil {
@ -159,7 +179,10 @@ func (o *OPStorage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.U
func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) (err error) { func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
if strings.HasPrefix(tokenID, command.IDPrefixV2) { if strings.HasPrefix(tokenID, command.IDPrefixV2) {
token, err := o.query.ActiveAccessTokenByToken(ctx, tokenID) token, err := o.query.ActiveAccessTokenByToken(ctx, tokenID)
@ -196,7 +219,12 @@ func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection
token.CreationDate, token.Expiration) token.CreationDate, token.Expiration)
} }
func (o *OPStorage) ClientCredentialsTokenRequest(ctx context.Context, clientID string, scope []string) (op.TokenRequest, error) { func (o *OPStorage) ClientCredentialsTokenRequest(ctx context.Context, clientID string, scope []string) (_ op.TokenRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
user, err := o.query.GetUserByLoginName(ctx, false, clientID) user, err := o.query.GetUserByLoginName(ctx, false, clientID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -545,6 +573,12 @@ func (o *OPStorage) userinfoFlows(ctx context.Context, user *query.User, userGra
} }
func (o *OPStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) { func (o *OPStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
roles := make([]string, 0) roles := make([]string, 0)
var allRoles bool var allRoles bool
for _, scope := range scopes { for _, scope := range scopes {
@ -903,7 +937,10 @@ func userinfoClaims(userInfo *oidc.UserInfo) func(c *actions.FieldConfig) interf
func (s *Server) VerifyClient(ctx context.Context, r *op.Request[op.ClientCredentials]) (_ op.Client, err error) { func (s *Server) VerifyClient(ctx context.Context, r *op.Request[op.ClientCredentials]) (_ op.Client, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials { if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials {
return s.clientCredentialsAuth(ctx, r.Data.ClientID, r.Data.ClientSecret) return s.clientCredentialsAuth(ctx, r.Data.ClientID, r.Data.ClientSecret)

View File

@ -43,7 +43,7 @@ func TestOPStorage_SetUserinfoFromToken(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -152,7 +152,7 @@ func TestServer_Introspect(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, app.GetClientId(), code) tokens, err := exchangeTokens(t, app.GetClientId(), code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)

View File

@ -109,14 +109,9 @@ func newDeviceAuthorizationState(d *query.DeviceAuth) *op.DeviceAuthorizationSta
// As generated user codes are of low entropy, this implementation also takes care or // As generated user codes are of low entropy, this implementation also takes care or
// device authorization request cleanup, when it has been Approved, Denied or Expired. // device authorization request cleanup, when it has been Approved, Denied or Expired.
func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (state *op.DeviceAuthorizationState, err error) { func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (state *op.DeviceAuthorizationState, err error) {
const logMsg = "get device authorization state"
logger := logging.WithFields("device_code", deviceCode)
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { defer func() {
if err != nil { err = oidcError(err)
logger.WithError(err).Error(logMsg)
}
span.EndWithError(err) span.EndWithError(err)
}() }()
@ -124,7 +119,8 @@ func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, de
if err != nil { if err != nil {
return nil, err return nil, err
} }
logger.SetFields( logging.WithFields(
"device_code", deviceCode,
"expires", deviceAuth.Expires, "scopes", deviceAuth.Scopes, "expires", deviceAuth.Expires, "scopes", deviceAuth.Scopes,
"subject", deviceAuth.Subject, "state", deviceAuth.State, "subject", deviceAuth.Subject, "state", deviceAuth.State,
).Debug("device authorization state") ).Debug("device authorization state")

View File

@ -0,0 +1,49 @@
package oidc
import (
"errors"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
http_util "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/zerrors"
)
// oidcError ensures [*oidc.Error] and [op.StatusError] types for err.
// It must be used when an error passes the package boundary towards oidc.
// When err is already of the correct type is passed as-is.
// If the err is a Zitadel error, it is transformed with a proper HTTP status code.
// Unknown errors are treated as internal server errors.
func oidcError(err error) error {
if err == nil {
return nil
}
var (
sError op.StatusError
oError *oidc.Error
zError *zerrors.ZitadelError
)
if errors.As(err, &sError) || errors.As(err, &oError) {
return err
}
// here we are encountering an error type that is completely unknown to us.
if !errors.As(err, &zError) {
err = zerrors.ThrowInternal(err, "OIDC-AhX2u", "Errors.Internal")
errors.As(err, &zError)
}
statusCode, _ := http_util.ZitadelErrorToHTTPStatusCode(err)
newOidcErr := oidc.ErrServerError
if statusCode < 500 {
newOidcErr = oidc.ErrInvalidRequest
}
return op.NewStatusError(
newOidcErr().
WithParent(err).
WithDescription(zError.GetMessage()),
statusCode,
)
}

View File

@ -0,0 +1,63 @@
package oidc
import (
"io"
"net/http"
"testing"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/zerrors"
)
func Test_oidcError(t *testing.T) {
tests := []struct {
name string
err error
wantErr error
}{
{
name: "nil",
err: nil,
wantErr: nil,
},
{
name: "status err",
err: op.NewStatusError(io.ErrClosedPipe, http.StatusTeapot),
wantErr: op.NewStatusError(io.ErrClosedPipe, http.StatusTeapot),
},
{
name: "oidc err",
err: oidc.ErrInvalidClient().WithParent(io.ErrClosedPipe),
wantErr: oidc.ErrInvalidClient().WithParent(io.ErrClosedPipe),
},
{
name: "unknown err",
err: io.ErrClosedPipe,
wantErr: op.NewStatusError(
oidc.ErrServerError().
WithParent(io.ErrClosedPipe).
WithDescription("Errors.Internal"),
http.StatusInternalServerError,
),
},
{
name: "zitadel error, invalid request",
err: zerrors.ThrowPreconditionFailed(io.ErrClosedPipe, "TEST-123", "oopsie"),
wantErr: op.NewStatusError(
oidc.ErrInvalidRequest().
WithParent(io.ErrClosedPipe).
WithDescription("oopsie"),
http.StatusBadRequest,
),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := oidcError(tt.err)
require.ErrorIs(t, err, tt.wantErr)
})
}
}

View File

@ -18,7 +18,10 @@ import (
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (resp *op.Response, err error) { func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (resp *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
if s.features.LegacyIntrospection { if s.features.LegacyIntrospection {
return s.LegacyServer.Introspect(ctx, r) return s.LegacyServer.Introspect(ctx, r)

View File

@ -7,10 +7,17 @@ import (
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
func (o *OPStorage) JWTProfileTokenType(ctx context.Context, request op.TokenRequest) (op.AccessTokenType, error) { func (o *OPStorage) JWTProfileTokenType(ctx context.Context, request op.TokenRequest) (_ op.AccessTokenType, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
mapJWTProfileScopesToAudience(ctx, request) mapJWTProfileScopesToAudience(ctx, request)
user, err := o.query.GetUserByID(ctx, false, request.GetSubject()) user, err := o.query.GetUserByID(ctx, false, request.GetSubject())
if err != nil { if err != nil {

View File

@ -111,7 +111,10 @@ func (k *keySetCache) getKey(ctx context.Context, keyID string) (_ *jose.JSONWeb
// VerifySignature implements the oidc.KeySet interface. // VerifySignature implements the oidc.KeySet interface.
func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) { func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
if len(jws.Signatures) != 1 { if len(jws.Signatures) != 1 {
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid") return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")

View File

@ -71,7 +71,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -107,7 +107,7 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -134,7 +134,7 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPassword, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPassword, startTime, changeTime)
@ -162,7 +162,7 @@ func Test_ZITADEL_API_success(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -196,7 +196,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -225,7 +225,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
@ -267,7 +267,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) {
// code exchange // code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)

View File

@ -110,10 +110,13 @@ func (s *Server) Ready(ctx context.Context, r *op.Request[struct{}]) (_ *op.Resp
func (s *Server) Discovery(ctx context.Context, r *op.Request[struct{}]) (_ *op.Response, err error) { func (s *Server) Discovery(ctx context.Context, r *op.Request[struct{}]) (_ *op.Response, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() {
err = oidcError(err)
span.EndWithError(err)
}()
restrictions, err := s.query.GetInstanceRestrictions(ctx) restrictions, err := s.query.GetInstanceRestrictions(ctx)
if err != nil { if err != nil {
return nil, err return nil, op.NewStatusError(oidc.ErrServerError().WithParent(err).WithDescription("internal server error"), http.StatusInternalServerError)
} }
allowedLanguages := restrictions.AllowedLanguages allowedLanguages := restrictions.AllowedLanguages
if len(allowedLanguages) == 0 { if len(allowedLanguages) == 0 {

View File

@ -140,7 +140,7 @@ func (repo *AuthRequestRepo) CreateAuthRequest(ctx context.Context, request *dom
if err != nil { if err != nil {
return nil, err return nil, err
} }
appIDs, err := repo.Query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}) appIDs, err := repo.Query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -476,10 +476,17 @@ func (q *Queries) SearchApps(ctx context.Context, queries *AppSearchQueries, wit
return apps, err return apps, err
} }
func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries) (ids []string, err error) { func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries, shouldTriggerBulk bool) (ids []string, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
if shouldTriggerBulk {
_, traceSpan := tracing.NewNamedSpan(ctx, "TriggerAppProjection")
ctx, err = projection.AppProjection.Trigger(ctx, handler.WithAwaitRunning())
logging.OnError(err).Debug("trigger failed")
traceSpan.EndWithError(err)
}
query, scan := prepareClientIDsQuery(ctx, q.client) query, scan := prepareClientIDsQuery(ctx, q.client)
eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
stmt, args, err := queries.toQuery(query).Where(eq).ToSql() stmt, args, err := queries.toQuery(query).Where(eq).ToSql()