mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 07:47:32 +00:00
feat(OIDC): add support for end_session for V2 tokens (#6226)
This PR adds support for the OIDC end_session_endpoint for V2 tokens. Sending an id_token_hint as parameter will directly terminate the underlying (SSO) session and all its tokens. Without this param, the user will be redirected to the Login UI, where he will able to choose if to logout.
This commit is contained in:
@@ -29,8 +29,9 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
redirectURI = "oidcIntegrationTest://callback"
|
||||
redirectURI = "oidcintegrationtest://callback"
|
||||
redirectURIImplicit = "http://localhost:9999/callback"
|
||||
logoutRedirectURI = "oidcintegrationtest://logged-out"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -51,7 +52,7 @@ func TestMain(m *testing.M) {
|
||||
func TestServer_GetAuthRequest(t *testing.T) {
|
||||
project, err := Tester.CreateProject(CTX)
|
||||
require.NoError(t, err)
|
||||
client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI, project.GetId())
|
||||
client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI, logoutRedirectURI, project.GetId())
|
||||
require.NoError(t, err)
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
@@ -95,7 +96,7 @@ func TestServer_GetAuthRequest(t *testing.T) {
|
||||
func TestServer_CreateCallback(t *testing.T) {
|
||||
project, err := Tester.CreateProject(CTX)
|
||||
require.NoError(t, err)
|
||||
client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI, project.GetId())
|
||||
client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI, logoutRedirectURI, project.GetId())
|
||||
require.NoError(t, err)
|
||||
sessionResp, err := Tester.Client.SessionV2.CreateSession(CTX, &session.CreateSessionRequest{
|
||||
Checks: &session.Checks{
|
||||
|
@@ -308,10 +308,41 @@ func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID strin
|
||||
return err
|
||||
}
|
||||
|
||||
func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID string) (err *oidc.Error) {
|
||||
func (o *OPStorage) TerminateSessionFromRequest(ctx context.Context, endSessionRequest *op.EndSessionRequest) (redirectURI string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
// check for the login client header
|
||||
// and if not provided, terminate the session using the V1 method
|
||||
headers, _ := http_utils.HeadersFromCtx(ctx)
|
||||
if loginClient := headers.Get(LoginClientHeader); loginClient == "" {
|
||||
return endSessionRequest.RedirectURI, o.TerminateSession(ctx, endSessionRequest.UserID, endSessionRequest.ClientID)
|
||||
}
|
||||
|
||||
// in case there are not id_token_hint, redirect to the UI and let it decide which session to terminate
|
||||
if endSessionRequest.IDTokenHintClaims == nil {
|
||||
return o.defaultLogoutURLV2 + endSessionRequest.RedirectURI, nil
|
||||
}
|
||||
|
||||
// terminate the session of the id_token_hint
|
||||
_, err = o.command.TerminateSessionWithoutTokenCheck(ctx, endSessionRequest.IDTokenHintClaims.SessionID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return endSessionRequest.RedirectURI, nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID string) (err *oidc.Error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
// check for nil, because `err` is not an error and EndWithError would panic
|
||||
if err == nil {
|
||||
span.End()
|
||||
return
|
||||
}
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
if strings.HasPrefix(token, command.IDPrefixV2) {
|
||||
err := o.command.RevokeOIDCSessionToken(ctx, token, clientID)
|
||||
if err == nil {
|
||||
|
@@ -8,15 +8,17 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
oidc_api "github.com/zitadel/zitadel/internal/api/oidc"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
session "github.com/zitadel/zitadel/pkg/grpc/session/v2alpha"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -170,14 +172,9 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
|
||||
// test actual refresh grant
|
||||
newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken)
|
||||
require.NoError(t, err)
|
||||
idToken, _ := newTokens.Extra("id_token").(string)
|
||||
assert.NotEmpty(t, idToken)
|
||||
assert.NotEmpty(t, newTokens.AccessToken)
|
||||
assert.NotEmpty(t, newTokens.RefreshToken)
|
||||
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), newTokens.AccessToken, idToken, provider.IDTokenVerifier())
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, newTokens, true)
|
||||
// auth time must still be the initial
|
||||
assertIDTokenClaims(t, claims, armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, newTokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
|
||||
// refresh with an old refresh_token must fail
|
||||
_, err = rp.RefreshAccessToken(provider, tokens.RefreshToken, "", "")
|
||||
@@ -374,6 +371,131 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOPStorage_TerminateSession(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
authRequestID := createAuthRequest(t, clientID, redirectURI)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// test code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
|
||||
// userinfo must not fail
|
||||
_, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
postLogoutRedirect, err := rp.EndSession(provider, tokens.IDToken, logoutRedirectURI, "state")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, logoutRedirectURI+"?state=state", postLogoutRedirect.String())
|
||||
|
||||
// userinfo must fail
|
||||
_, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// test code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
|
||||
// userinfo must not fail
|
||||
_, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
postLogoutRedirect, err := rp.EndSession(provider, tokens.IDToken, logoutRedirectURI, "state")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, logoutRedirectURI+"?state=state", postLogoutRedirect.String())
|
||||
|
||||
// userinfo must fail
|
||||
_, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
|
||||
require.Error(t, err)
|
||||
|
||||
refreshedTokens, err := refreshTokens(t, clientID, tokens.RefreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// userinfo must not fail
|
||||
_, err = rp.Userinfo(refreshedTokens.AccessToken, refreshedTokens.TokenType, refreshedTokens.IDTokenClaims.Subject, provider)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
authRequestID := createAuthRequest(t, clientID, redirectURI)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// test code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
|
||||
postLogoutRedirect, err := rp.EndSession(provider, "", logoutRedirectURI, "state")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http_utils.BuildOrigin(Tester.Host(), Tester.Config.ExternalSecure)+Tester.Config.OIDC.DefaultLogoutURLV2+logoutRedirectURI+"?state=state", postLogoutRedirect.String())
|
||||
|
||||
// userinfo must not fail until login UI terminated session
|
||||
_, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
// simulate termination by login UI
|
||||
_, err = Tester.Client.SessionV2.DeleteSession(CTXLOGIN, &session.DeleteSessionRequest{
|
||||
SessionId: sessionID,
|
||||
SessionToken: gu.Ptr(sessionToken),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// userinfo must fail
|
||||
_, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
@@ -382,11 +504,24 @@ func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDT
|
||||
return rp.CodeExchange[*oidc.IDTokenClaims](context.Background(), code, provider, rp.WithCodeVerifier(codeVerifier))
|
||||
}
|
||||
|
||||
func refreshTokens(t testing.TB, clientID, refreshToken string) (*oauth2.Token, error) {
|
||||
func refreshTokens(t testing.TB, clientID, refreshToken string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
|
||||
return rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
tokens, err := rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idToken, _ := tokens.Extra("id_token").(string)
|
||||
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), tokens.AccessToken, idToken, provider.IDTokenVerifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: tokens,
|
||||
IDToken: idToken,
|
||||
IDTokenClaims: claims,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func assertCodeResponse(t *testing.T, callback string) string {
|
||||
|
@@ -163,6 +163,18 @@ func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.Us
|
||||
return o.setUserinfo(ctx, userInfo, userID, applicationID, scopes, nil)
|
||||
}
|
||||
|
||||
// SetUserinfoFromRequest extends the SetUserinfoFromScopes during the id_token generation.
|
||||
// This is required for V2 tokens to be able to set the sessionID (`sid`) claim.
|
||||
func (o *OPStorage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.UserInfo, request op.IDTokenRequest, _ []string) error {
|
||||
switch t := request.(type) {
|
||||
case *AuthRequestV2:
|
||||
userinfo.AppendClaims("sid", t.SessionID)
|
||||
case *RefreshTokenRequestV2:
|
||||
userinfo.AppendClaims("sid", t.SessionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
@@ -51,7 +51,7 @@ func TestOPStorage_SetUserinfoFromToken(t *testing.T) {
|
||||
func TestOPStorage_SetIntrospectionFromToken(t *testing.T) {
|
||||
project, err := Tester.CreateProject(CTX)
|
||||
require.NoError(t, err)
|
||||
app, err := Tester.CreateOIDCNativeClient(CTX, redirectURI, project.GetId())
|
||||
app, err := Tester.CreateOIDCNativeClient(CTX, redirectURI, logoutRedirectURI, project.GetId())
|
||||
require.NoError(t, err)
|
||||
api, err := Tester.CreateAPIClient(CTX, project.GetId())
|
||||
require.NoError(t, err)
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
@@ -30,8 +31,9 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
redirectURI = "oidcIntegrationTest://callback"
|
||||
redirectURI = "oidcintegrationtest://callback"
|
||||
redirectURIImplicit = "http://localhost:9999/callback"
|
||||
logoutRedirectURI = "oidcintegrationtest://logged-out"
|
||||
zitadelAudienceScope = domain.ProjectIDScope + domain.ProjectIDScopeZITADEL + domain.AudSuffix
|
||||
)
|
||||
|
||||
@@ -213,10 +215,52 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) {
|
||||
require.Nil(t, myUserResp)
|
||||
}
|
||||
|
||||
func Test_ZITADEL_API_terminated_session(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess, zitadelAudienceScope)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
|
||||
// make sure token works
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
|
||||
myUserResp, err := Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, User.GetUserId(), myUserResp.GetUser().GetId())
|
||||
|
||||
// refresh token
|
||||
postLogoutRedirect, err := rp.EndSession(provider, tokens.IDToken, logoutRedirectURI, "state")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, logoutRedirectURI+"?state=state", postLogoutRedirect.String())
|
||||
|
||||
// use token from terminated session
|
||||
ctx = metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
|
||||
myUserResp, err = Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{})
|
||||
require.Error(t, err)
|
||||
require.Nil(t, myUserResp)
|
||||
}
|
||||
|
||||
func createClient(t testing.TB) string {
|
||||
project, err := Tester.CreateProject(CTX)
|
||||
require.NoError(t, err)
|
||||
app, err := Tester.CreateOIDCNativeClient(CTX, redirectURI, project.GetId())
|
||||
app, err := Tester.CreateOIDCNativeClient(CTX, redirectURI, logoutRedirectURI, project.GetId())
|
||||
require.NoError(t, err)
|
||||
return app.GetClientId()
|
||||
}
|
||||
|
@@ -42,6 +42,7 @@ type Config struct {
|
||||
CustomEndpoints *EndpointConfig
|
||||
DeviceAuth *DeviceAuthorizationConfig
|
||||
DefaultLoginURLV2 string
|
||||
DefaultLogoutURLV2 string
|
||||
}
|
||||
|
||||
type EndpointConfig struct {
|
||||
@@ -67,6 +68,7 @@ type OPStorage struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
defaultLoginURL string
|
||||
defaultLoginURLV2 string
|
||||
defaultLogoutURLV2 string
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultIdTokenLifetime time.Duration
|
||||
signingKeyAlgorithm string
|
||||
@@ -184,6 +186,7 @@ func newStorage(config Config, command *command.Commands, query *query.Queries,
|
||||
eventstore: es,
|
||||
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
|
||||
defaultLoginURLV2: config.DefaultLoginURLV2,
|
||||
defaultLogoutURLV2: config.DefaultLogoutURLV2,
|
||||
signingKeyAlgorithm: config.SigningKeyAlgorithm,
|
||||
defaultAccessTokenLifetime: config.DefaultAccessTokenLifetime,
|
||||
defaultIdTokenLifetime: config.DefaultIdTokenLifetime,
|
||||
|
Reference in New Issue
Block a user