fix(oidc): do not return access token for response type id_token (#8777)

# Which Problems Are Solved

Do not return an access token for implicit flow from v1 login, if the
`response_type` is `id_token`

# How the Problems Are Solved

Do not create the access token event if if the `response_type` is
`id_token`.

# Additional Changes

Token endpoint calls without auth request, such as machine users, token
exchange and refresh token, do not have a `response_type`. For such
calls the `OIDCResponseTypeUnspecified` enum is added at a `-1` offset,
in order not to break existing client configs.

# Additional Context

- https://discord.com/channels/927474939156643850/1294001717725237298
- Fixes https://github.com/zitadel/zitadel/issues/8776
This commit is contained in:
Tim Möhlmann 2024-11-12 17:20:48 +02:00 committed by GitHub
parent 69e9926bcc
commit 778b4041ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 105 additions and 3 deletions

View File

@ -600,6 +600,7 @@ func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorize
nil, nil,
slices.Contains(scope, oidc.ScopeOfflineAccess), slices.Contains(scope, oidc.ScopeOfflineAccess),
authReq.SessionID, authReq.SessionID,
authReq.oidc().ResponseType,
) )
if err != nil { if err != nil {
op.AuthRequestError(w, r, authReq, err, authorizer) op.AuthRequestError(w, r, authReq, err, authorizer)

View File

@ -47,6 +47,7 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ
nil, nil,
false, false,
"", "",
domain.OIDCResponseTypeUnspecified,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -87,6 +87,7 @@ func (s *Server) codeExchangeV1(ctx context.Context, client *Client, req *oidc.A
nil, nil,
slices.Contains(scope, oidc.ScopeOfflineAccess), slices.Contains(scope, oidc.ScopeOfflineAccess),
authReq.SessionID, authReq.SessionID,
authReq.oidc().ResponseType,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -300,6 +300,7 @@ func (s *Server) createExchangeAccessToken(
actor, actor,
slices.Contains(scope, oidc.ScopeOfflineAccess), slices.Contains(scope, oidc.ScopeOfflineAccess),
"", "",
domain.OIDCResponseTypeUnspecified,
) )
if err != nil { if err != nil {
return "", "", "", 0, err return "", "", "", 0, err
@ -346,6 +347,7 @@ func (s *Server) createExchangeJWT(
actor, actor,
slices.Contains(scope, oidc.ScopeOfflineAccess), slices.Contains(scope, oidc.ScopeOfflineAccess),
"", "",
domain.OIDCResponseTypeUnspecified,
) )
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, roleAssertion, getSigner) accessToken, err = s.createJWT(ctx, client, session, getUserInfo, roleAssertion, getSigner)
if err != nil { if err != nil {

View File

@ -57,6 +57,7 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr
nil, nil,
false, false,
"", "",
domain.OIDCResponseTypeUnspecified,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -69,6 +69,7 @@ func (s *Server) refreshTokenV1(ctx context.Context, client *Client, r *op.Clien
refreshToken.Actor, refreshToken.Actor,
true, true,
"", "",
domain.OIDCResponseTypeUnspecified,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -147,6 +147,7 @@ func (c *Commands) CreateOIDCSession(ctx context.Context,
actor *domain.TokenActor, actor *domain.TokenActor,
needRefreshToken bool, needRefreshToken bool,
sessionID string, sessionID string,
responseType domain.OIDCResponseType,
) (session *OIDCSession, err error) { ) (session *OIDCSession, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -164,8 +165,10 @@ func (c *Commands) CreateOIDCSession(ctx context.Context,
cmd.AddSession(ctx, userID, resourceOwner, sessionID, clientID, audience, scope, authMethods, authTime, nonce, preferredLanguage, userAgent) cmd.AddSession(ctx, userID, resourceOwner, sessionID, clientID, audience, scope, authMethods, authTime, nonce, preferredLanguage, userAgent)
cmd.RegisterLogout(ctx, sessionID, userID, clientID, backChannelLogoutURI) cmd.RegisterLogout(ctx, sessionID, userID, clientID, backChannelLogoutURI)
if err = cmd.AddAccessToken(ctx, scope, userID, resourceOwner, reason, actor); err != nil { if responseType != domain.OIDCResponseTypeIDToken {
return nil, err if err = cmd.AddAccessToken(ctx, scope, userID, resourceOwner, reason, actor); err != nil {
return nil, err
}
} }
if needRefreshToken { if needRefreshToken {
if err = cmd.AddRefreshToken(ctx, userID); err != nil { if err = cmd.AddRefreshToken(ctx, userID); err != nil {

View File

@ -749,6 +749,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
actor *domain.TokenActor actor *domain.TokenActor
needRefreshToken bool needRefreshToken bool
sessionID string sessionID string
responseType domain.OIDCResponseType
} }
tests := []struct { tests := []struct {
name string name string
@ -788,6 +789,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
Issuer: "foo.com", Issuer: "foo.com",
}, },
needRefreshToken: false, needRefreshToken: false,
responseType: domain.OIDCResponseTypeUnspecified,
}, },
wantErr: io.ErrClosedPipe, wantErr: io.ErrClosedPipe,
}, },
@ -844,6 +846,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
Issuer: "foo.com", Issuer: "foo.com",
}, },
needRefreshToken: false, needRefreshToken: false,
responseType: domain.OIDCResponseTypeUnspecified,
}, },
wantErr: zerrors.ThrowPreconditionFailed(nil, "OIDCS-kj3g2", "Errors.User.NotActive"), wantErr: zerrors.ThrowPreconditionFailed(nil, "OIDCS-kj3g2", "Errors.User.NotActive"),
}, },
@ -918,6 +921,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
Issuer: "foo.com", Issuer: "foo.com",
}, },
needRefreshToken: false, needRefreshToken: false,
responseType: domain.OIDCResponseTypeUnspecified,
}, },
want: &OIDCSession{ want: &OIDCSession{
TokenID: "V2_oidcSessionID-at_accessTokenID", TokenID: "V2_oidcSessionID-at_accessTokenID",
@ -943,6 +947,87 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
}, },
}, },
}, },
{
name: "ID token only",
fields: fields{
eventstore: expectEventstore(
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
),
expectFilter(), // token lifetime
expectPush(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"userID", "org1", "", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "nonce", &language.Afrikaans,
&domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
),
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID"),
defaultAccessTokenLifetime: time.Hour,
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
audience: []string{"audience"},
scope: []string{"openid", "offline_access"},
authMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
authTime: testNow,
nonce: "nonce",
preferredLanguage: &language.Afrikaans,
userAgent: &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
reason: domain.TokenReasonAuthRequest,
actor: &domain.TokenActor{
UserID: "user2",
Issuer: "foo.com",
},
needRefreshToken: false,
responseType: domain.OIDCResponseTypeIDToken,
},
want: &OIDCSession{
ClientID: "clientID",
UserID: "userID",
Audience: []string{"audience"},
Scope: []string{"openid", "offline_access"},
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
AuthTime: testNow,
Nonce: "nonce",
PreferredLanguage: &language.Afrikaans,
UserAgent: &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
},
},
{ {
name: "disable user token event", name: "disable user token event",
fields: fields{ fields: fields{
@ -1018,6 +1103,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
Issuer: "foo.com", Issuer: "foo.com",
}, },
needRefreshToken: false, needRefreshToken: false,
responseType: domain.OIDCResponseTypeUnspecified,
}, },
want: &OIDCSession{ want: &OIDCSession{
TokenID: "V2_oidcSessionID-at_accessTokenID", TokenID: "V2_oidcSessionID-at_accessTokenID",
@ -1115,6 +1201,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
Issuer: "foo.com", Issuer: "foo.com",
}, },
needRefreshToken: true, needRefreshToken: true,
responseType: domain.OIDCResponseTypeUnspecified,
}, },
want: &OIDCSession{ want: &OIDCSession{
TokenID: "V2_oidcSessionID-at_accessTokenID", TokenID: "V2_oidcSessionID-at_accessTokenID",
@ -1213,6 +1300,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
}, },
needRefreshToken: false, needRefreshToken: false,
sessionID: "sessionID", sessionID: "sessionID",
responseType: domain.OIDCResponseTypeUnspecified,
}, },
want: &OIDCSession{ want: &OIDCSession{
TokenID: "V2_oidcSessionID-at_accessTokenID", TokenID: "V2_oidcSessionID-at_accessTokenID",
@ -1594,6 +1682,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
Issuer: "foo.com", Issuer: "foo.com",
}, },
needRefreshToken: false, needRefreshToken: false,
responseType: domain.OIDCResponseTypeUnspecified,
}, },
wantErr: zerrors.ThrowPermissionDenied(nil, "test", "test"), wantErr: zerrors.ThrowPermissionDenied(nil, "test", "test"),
}, },
@ -1675,6 +1764,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
Issuer: "foo.com", Issuer: "foo.com",
}, },
needRefreshToken: false, needRefreshToken: false,
responseType: domain.OIDCResponseTypeUnspecified,
}, },
want: &OIDCSession{ want: &OIDCSession{
TokenID: "V2_oidcSessionID-at_accessTokenID", TokenID: "V2_oidcSessionID-at_accessTokenID",
@ -1729,6 +1819,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
tt.args.actor, tt.args.actor,
tt.args.needRefreshToken, tt.args.needRefreshToken,
tt.args.sessionID, tt.args.sessionID,
tt.args.responseType,
) )
require.ErrorIs(t, err, tt.wantErr) require.ErrorIs(t, err, tt.wantErr)
if got != nil { if got != nil {

View File

@ -79,7 +79,8 @@ const (
type OIDCResponseType int32 type OIDCResponseType int32
const ( const (
OIDCResponseTypeCode OIDCResponseType = iota OIDCResponseTypeUnspecified OIDCResponseType = iota - 1 // Negative offset not to break existing configs.
OIDCResponseTypeCode
OIDCResponseTypeIDToken OIDCResponseTypeIDToken
OIDCResponseTypeIDTokenToken OIDCResponseTypeIDTokenToken
) )