feat(api): add OIDC session service (#6157)

This PR starts the OIDC implementation for the API V2 including the Implicit and Code Flow.


Co-authored-by: Livio Spring <livio.a@gmail.com>
Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
Livio Spring
2023-07-10 15:27:00 +02:00
committed by GitHub
parent be1fe36776
commit 14b8cf4894
69 changed files with 5948 additions and 106 deletions

View File

@@ -0,0 +1,43 @@
// Package amr maps zitadel session factors to Authentication Method Reference Values
// as defined in [RFC 8176, section 2].
//
// [RFC 8176, section 2]: https://datatracker.ietf.org/doc/html/rfc8176#section-2
package amr
const (
// Password states that the users password has been verified
// Deprecated: use `PWD` instead
Password = "password"
// PWD states that the users password has been verified
PWD = "pwd"
// MFA states that multiple factors have been verified (e.g. pwd and otp or passkey)
MFA = "mfa"
// OTP states that a one time password has been verified (e.g. TOTP)
OTP = "otp"
// UserPresence states that the end users presence has been verified (e.g. passkey and u2f)
UserPresence = "user"
)
type AuthenticationMethodReference interface {
IsPasswordChecked() bool
IsPasskeyChecked() bool
IsU2FChecked() bool
IsOTPChecked() bool
}
func List(model AuthenticationMethodReference) []string {
amr := make([]string, 0)
if model.IsPasswordChecked() {
amr = append(amr, PWD)
}
if model.IsPasskeyChecked() || model.IsU2FChecked() {
amr = append(amr, UserPresence)
}
if model.IsOTPChecked() {
amr = append(amr, OTP)
}
if model.IsPasskeyChecked() || len(amr) >= 2 {
amr = append(amr, MFA)
}
return amr
}

View File

@@ -0,0 +1,93 @@
package amr
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAMR(t *testing.T) {
type args struct {
model AuthenticationMethodReference
}
tests := []struct {
name string
args args
want []string
}{
{
"no checks, empty",
args{
new(test),
},
[]string{},
},
{
"pw checked",
args{
&test{pwChecked: true},
},
[]string{PWD},
},
{
"passkey checked",
args{
&test{passkeyChecked: true},
},
[]string{UserPresence, MFA},
},
{
"u2f checked",
args{
&test{u2fChecked: true},
},
[]string{UserPresence},
},
{
"otp checked",
args{
&test{otpChecked: true},
},
[]string{OTP},
},
{
"multiple checked",
args{
&test{
pwChecked: true,
u2fChecked: true,
},
},
[]string{PWD, UserPresence, MFA},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := List(tt.args.model)
assert.Equal(t, tt.want, got)
})
}
}
type test struct {
pwChecked bool
passkeyChecked bool
u2fChecked bool
otpChecked bool
}
func (t test) IsPasswordChecked() bool {
return t.pwChecked
}
func (t test) IsPasskeyChecked() bool {
return t.passkeyChecked
}
func (t test) IsU2FChecked() bool {
return t.u2fChecked
}
func (t test) IsOTPChecked() bool {
return t.otpChecked
}

View File

@@ -2,6 +2,7 @@ package oidc
import (
"context"
"encoding/base64"
"strings"
"time"
@@ -10,16 +11,75 @@ import (
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/user/model"
)
const (
LoginClientHeader = "x-zitadel-login-client"
)
func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
headers, _ := http_utils.HeadersFromCtx(ctx)
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
return o.createAuthRequestLoginClient(ctx, req, userID, loginClient)
}
return o.createAuthRequest(ctx, req, userID)
}
func (o *OPStorage) createAuthRequestLoginClient(ctx context.Context, req *oidc.AuthRequest, hintUserID, loginClient string) (op.AuthRequest, error) {
project, err := o.query.ProjectByClientID(ctx, req.ClientID, false)
if err != nil {
return nil, err
}
scope, err := o.assertProjectRoleScopesByProject(ctx, project, req.Scopes)
if err != nil {
return nil, err
}
audience, err := o.audienceFromProjectID(ctx, project.ID)
if err != nil {
return nil, err
}
audience = domain.AddAudScopeToAudience(ctx, audience, scope)
authRequest := &command.AuthRequest{
LoginClient: loginClient,
ClientID: req.ClientID,
RedirectURI: req.RedirectURI,
State: req.State,
Nonce: req.Nonce,
Scope: scope,
Audience: audience,
ResponseType: ResponseTypeToBusiness(req.ResponseType),
CodeChallenge: CodeChallengeToBusiness(req.CodeChallenge, req.CodeChallengeMethod),
Prompt: PromptToBusiness(req.Prompt),
UILocales: UILocalesToBusiness(req.UILocales),
MaxAge: MaxAgeToBusiness(req.MaxAge),
}
if req.LoginHint != "" {
authRequest.LoginHint = &req.LoginHint
}
if hintUserID != "" {
authRequest.HintUserID = &hintUserID
}
aar, err := o.command.AddAuthRequest(ctx, authRequest)
if err != nil {
return nil, err
}
return &AuthRequestV2{aar}, nil
}
func (o *OPStorage) createAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-sd436", "no user agent id")
@@ -36,9 +96,31 @@ func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest
return AuthRequestFromBusiness(resp)
}
func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string) ([]string, error) {
projectIDQuery, err := query.NewAppProjectIDSearchQuery(projectID)
if err != nil {
return nil, err
}
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
if err != nil {
return nil, err
}
return append(appIDs, projectID), nil
}
func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if strings.HasPrefix(id, command.IDPrefixV2) {
req, err := o.command.GetCurrentAuthRequest(ctx, id)
if err != nil {
return nil, err
}
return &AuthRequestV2{req}, nil
}
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-D3g21", "no user agent id")
@@ -54,6 +136,17 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
plainCode, err := o.decryptGrant(code)
if err != nil {
return nil, err
}
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
authReq, err := o.command.ExchangeAuthCode(ctx, plainCode)
if err != nil {
return nil, err
}
return &AuthRequestV2{authReq}, nil
}
resp, err := o.repo.AuthRequestByCode(ctx, code)
if err != nil {
return nil, err
@@ -61,9 +154,23 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au
return AuthRequestFromBusiness(resp)
}
// decryptGrant decrypts a code or refresh_token
func (o *OPStorage) decryptGrant(grant string) (string, error) {
decodedGrant, err := base64.RawURLEncoding.DecodeString(grant)
if err != nil {
return "", err
}
return o.encAlg.DecryptString(decodedGrant, o.encAlg.EncryptionKeyID())
}
func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if strings.HasPrefix(id, command.IDPrefixV2) {
return o.command.AddAuthRequestCode(ctx, id, code)
}
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return errors.ThrowPreconditionFailed(nil, "OIDC-Dgus2", "no user agent id")
@@ -81,12 +188,15 @@ func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
var userAgentID, applicationID, userOrgID string
authReq, ok := req.(*AuthRequest)
if ok {
switch authReq := req.(type) {
case *AuthRequest:
userAgentID = authReq.AgentID
applicationID = authReq.ApplicationID
userOrgID = authReq.UserOrgID
case *AuthRequestV2:
return o.command.AddOIDCSessionAccessToken(setContextUserSystem(ctx), authReq.GetID())
}
accessTokenLifetime, _, _, _, err := o.getOIDCSettings(ctx)
@@ -104,6 +214,15 @@ func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest)
func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
// handle V2 request directly
switch tokenReq := req.(type) {
case *AuthRequestV2:
return o.command.AddOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.GetID())
case *RefreshTokenRequestV2:
return o.command.ExchangeOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.OIDCSessionWriteModel.AggregateID, refreshToken, tokenReq.RequestedScopes)
}
userAgentID, applicationID, userOrgID, authTime, authMethodsReferences := getInfoFromRequest(req)
scopes, err := o.assertProjectRoleScopes(ctx, applicationID, req.GetScopes())
if err != nil {
@@ -142,7 +261,22 @@ func getInfoFromRequest(req op.TokenRequest) (string, string, string, time.Time,
return "", "", "", time.Time{}, nil
}
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
plainCode, err := o.decryptGrant(refreshToken)
if err != nil {
return nil, err
}
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainCode)
if err != nil {
return nil, err
}
return &RefreshTokenRequestV2{OIDCSessionWriteModel: oidcSession}, nil
}
tokenView, err := o.repo.RefreshTokenByToken(ctx, refreshToken)
if err != nil {
return nil, err
@@ -245,6 +379,29 @@ func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string
return scopes, nil
}
func (o *OPStorage) assertProjectRoleScopesByProject(ctx context.Context, project *query.Project, scopes []string) ([]string, error) {
for _, scope := range scopes {
if strings.HasPrefix(scope, ScopeProjectRolePrefix) {
return scopes, nil
}
}
if !project.ProjectRoleAssertion {
return scopes, nil
}
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(project.ID)
if err != nil {
return nil, errors.ThrowInternal(err, "OIDC-Cyc78", "Errors.Internal")
}
roles, err := o.query.SearchProjectRoles(ctx, true, &query.ProjectRoleSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
if err != nil {
return nil, err
}
for _, role := range roles.ProjectRoles {
scopes = append(scopes, ScopeProjectRolePrefix+role.Key)
}
return scopes, nil
}
func (o *OPStorage) assertClientScopesForPAT(ctx context.Context, token *model.TokenView, clientID, projectID string) error {
token.Audience = append(token.Audience, clientID)
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(projectID)
@@ -279,3 +436,58 @@ func (o *OPStorage) getOIDCSettings(ctx context.Context) (accessTokenLifetime, i
}
return o.defaultAccessTokenLifetime, o.defaultIdTokenLifetime, o.defaultRefreshTokenIdleExpiration, o.defaultRefreshTokenExpiration, nil
}
func CreateErrorCallbackURL(authReq op.AuthRequest, reason, description, uri string, authorizer op.Authorizer) (string, error) {
e := struct {
Error string `schema:"error"`
Description string `schema:"error_description,omitempty"`
URI string `schema:"error_uri,omitempty"`
State string `schema:"state,omitempty"`
}{
Error: reason,
Description: description,
URI: uri,
State: authReq.GetState(),
}
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), e, authorizer.Encoder())
if err != nil {
return "", err
}
return callback, nil
}
func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authorizer op.Authorizer) (string, error) {
code, err := op.CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto())
if err != nil {
return "", err
}
codeResponse := struct {
code string
state string
}{
code: code,
state: authReq.GetState(),
}
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
if err != nil {
return "", err
}
return callback, err
}
func CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest, authorizer op.Authorizer) (string, error) {
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil {
return "", err
}
createAccessToken := req.GetResponseType() != oidc.ResponseTypeIDTokenOnly
resp, err := op.CreateTokenResponse(ctx, req, client, authorizer, createAccessToken, "", "")
if err != nil {
return "", err
}
callback, err := op.AuthResponseURL(req.GetRedirectURI(), req.GetResponseType(), req.GetResponseMode(), resp, authorizer.Encoder())
if err != nil {
return "", err
}
return callback, err
}

View File

@@ -12,20 +12,12 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/user/model"
)
const (
// DEPRECATED: use `amrPWD` instead
amrPassword = "password"
amrPWD = "pwd"
amrMFA = "mfa"
amrOTP = "otp"
amrUserPresence = "user"
)
type AuthRequest struct {
*domain.AuthRequest
}
@@ -40,19 +32,19 @@ func (a *AuthRequest) GetACR() string {
}
func (a *AuthRequest) GetAMR() []string {
amr := make([]string, 0)
list := make([]string, 0)
if a.PasswordVerified {
amr = append(amr, amrPassword, amrPWD)
list = append(list, amr.Password, amr.PWD)
}
if len(a.MFAsVerified) > 0 {
amr = append(amr, amrMFA)
list = append(list, amr.MFA)
for _, mfa := range a.MFAsVerified {
if amrMFA := AMRFromMFAType(mfa); amrMFA != "" {
amr = append(amr, amrMFA)
list = append(list, amrMFA)
}
}
}
return amr
return list
}
func (a *AuthRequest) GetAudience() []string {
@@ -271,10 +263,10 @@ func CodeChallengeToOIDC(challenge *domain.OIDCCodeChallenge) *oidc.CodeChalleng
func AMRFromMFAType(mfaType domain.MFAType) string {
switch mfaType {
case domain.MFATypeOTP:
return amrOTP
return amr.OTP
case domain.MFATypeU2F,
domain.MFATypeU2FUserVerification:
return amrUserPresence
return amr.UserPresence
default:
return ""
}

View File

@@ -0,0 +1,106 @@
package oidc
import (
"time"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/zitadel/internal/command"
)
type AuthRequestV2 struct {
*command.CurrentAuthRequest
}
func (a *AuthRequestV2) GetID() string {
return a.ID
}
func (a *AuthRequestV2) GetACR() string {
return "" //PLANNED: impl
}
func (a *AuthRequestV2) GetAMR() []string {
return a.AMR
}
func (a *AuthRequestV2) GetAudience() []string {
return a.Audience
}
func (a *AuthRequestV2) GetAuthTime() time.Time {
return a.AuthTime
}
func (a *AuthRequestV2) GetClientID() string {
return a.ClientID
}
func (a *AuthRequestV2) GetCodeChallenge() *oidc.CodeChallenge {
return CodeChallengeToOIDC(a.CodeChallenge)
}
func (a *AuthRequestV2) GetNonce() string {
return a.Nonce
}
func (a *AuthRequestV2) GetRedirectURI() string {
return a.RedirectURI
}
func (a *AuthRequestV2) GetResponseType() oidc.ResponseType {
return ResponseTypeToOIDC(a.ResponseType)
}
func (a *AuthRequestV2) GetResponseMode() oidc.ResponseMode {
return ""
}
func (a *AuthRequestV2) GetScopes() []string {
return a.Scope
}
func (a *AuthRequestV2) GetState() string {
return a.State
}
func (a *AuthRequestV2) GetSubject() string {
return a.UserID
}
func (a *AuthRequestV2) Done() bool {
return a.UserID != "" && a.SessionID != ""
}
type RefreshTokenRequestV2 struct {
*command.OIDCSessionWriteModel
RequestedScopes []string
}
func (r *RefreshTokenRequestV2) GetAMR() []string {
return r.AuthMethodsReferences
}
func (r *RefreshTokenRequestV2) GetAudience() []string {
return r.Audience
}
func (r *RefreshTokenRequestV2) GetAuthTime() time.Time {
return r.AuthTime
}
func (r *RefreshTokenRequestV2) GetClientID() string {
return r.ClientID
}
func (r *RefreshTokenRequestV2) GetScopes() []string {
return r.Scope
}
func (r *RefreshTokenRequestV2) GetSubject() string {
return r.UserID
}
func (r *RefreshTokenRequestV2) SetCurrentScopes(scopes []string) {
r.RequestedScopes = scopes
}

View File

@@ -0,0 +1,275 @@
//go:build integration
package oidc_test
import (
"context"
"net/url"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/integration"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
)
var (
CTX context.Context
CTXLOGIN context.Context
Tester *integration.Tester
User *user.AddHumanUserResponse
)
const (
redirectURI = "oidcIntegrationTest://callback"
redirectURIImplicit = "http://localhost:9999/callback"
)
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, errCtx, cancel := integration.Contexts(5 * time.Minute)
defer cancel()
Tester = integration.NewTester(ctx)
defer Tester.Done()
CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx
User = Tester.CreateHumanUser(CTX)
Tester.RegisterUserPasskey(CTX, User.GetUserId())
CTXLOGIN, _ = Tester.WithAuthorization(ctx, integration.Login), errCtx
return m.Run()
}())
}
func createClient(t testing.TB) string {
app, err := Tester.CreateOIDCNativeClient(CTX, redirectURI)
require.NoError(t, err)
return app.GetClientId()
}
func createImplicitClient(t testing.TB) string {
app, err := Tester.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit)
require.NoError(t, err)
return app.GetClientId()
}
func createAuthRequest(t testing.TB, clientID, redirectURI string, scope ...string) string {
redURL, err := Tester.CreateOIDCAuthRequest(clientID, Tester.Users[integration.FirstInstanceUsersKey][integration.Login].ID, redirectURI, scope...)
require.NoError(t, err)
return redURL
}
func createAuthRequestImplicit(t testing.TB, clientID, redirectURI string, scope ...string) string {
redURL, err := Tester.CreateOIDCAuthRequestImplicit(clientID, Tester.Users[integration.FirstInstanceUsersKey][integration.Login].ID, redirectURI, scope...)
require.NoError(t, err)
return redURL
}
func TestOPStorage_CreateAuthRequest(t *testing.T) {
clientID := createClient(t)
id := createAuthRequest(t, clientID, redirectURI)
require.Contains(t, id, command.IDPrefixV2)
}
func TestOPStorage_CreateAccessToken_code(t *testing.T) {
clientID := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// test code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
// callback on a succeeded request must fail
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.Error(t, err)
// exchange with a used code must fail
_, err = exchangeTokens(t, clientID, code)
require.Error(t, err)
}
func TestOPStorage_CreateAccessToken_implicit(t *testing.T) {
clientID := createImplicitClient(t)
authRequestID := createAuthRequestImplicit(t, clientID, redirectURIImplicit)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// test implicit callback
callback, err := url.Parse(linkResp.GetCallbackUrl())
require.NoError(t, err)
values, err := url.ParseQuery(callback.Fragment)
require.NoError(t, err)
accessToken := values.Get("access_token")
idToken := values.Get("id_token")
refreshToken := values.Get("refresh_token")
assert.NotEmpty(t, accessToken)
assert.NotEmpty(t, idToken)
assert.Empty(t, refreshToken)
assert.NotEmpty(t, values.Get("expires_in"))
assert.Equal(t, oidc.BearerToken, values.Get("token_type"))
assert.Equal(t, "state", values.Get("state"))
// check id_token / claims
provider, err := Tester.CreateRelyingParty(clientID, redirectURIImplicit)
require.NoError(t, err)
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier())
require.NoError(t, err)
assertTokenClaims(t, claims, startTime, changeTime)
// callback on a succeeded request must fail
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.Error(t, err)
}
func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
clientID := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// test code exchange (expect refresh token to be returned)
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
}
func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
clientID := createClient(t)
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
// test actual refresh grant
newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken)
require.NoError(t, err)
idToken, _ := newTokens.Extra("id_token").(string)
assert.NotEmpty(t, idToken)
assert.NotEmpty(t, newTokens.AccessToken)
assert.NotEmpty(t, newTokens.RefreshToken)
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), newTokens.AccessToken, idToken, provider.IDTokenVerifier())
require.NoError(t, err)
// auth time must still be the initial
assertTokenClaims(t, claims, startTime, changeTime)
// refresh with an old refresh_token must fail
_, err = rp.RefreshAccessToken(provider, tokens.RefreshToken, "", "")
require.Error(t, err)
}
func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
require.NoError(t, err)
codeVerifier := "codeVerifier"
return rp.CodeExchange[*oidc.IDTokenClaims](context.Background(), code, provider, rp.WithCodeVerifier(codeVerifier))
}
func refreshTokens(t testing.TB, clientID, refreshToken string) (*oauth2.Token, error) {
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
require.NoError(t, err)
return rp.RefreshAccessToken(provider, refreshToken, "", "")
}
func assertCodeResponse(t *testing.T, callback string) string {
callbackURL, err := url.Parse(callback)
require.NoError(t, err)
code := callbackURL.Query().Get("code")
require.NotEmpty(t, code)
assert.Equal(t, "state", callbackURL.Query().Get("state"))
return code
}
func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requireRefreshToken bool) {
assert.NotEmpty(t, tokens.AccessToken)
assert.NotEmpty(t, tokens.IDToken)
if requireRefreshToken {
assert.NotEmpty(t, tokens.RefreshToken)
} else {
assert.Empty(t, tokens.RefreshToken)
}
}
func assertTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, sessionStart, sessionChange time.Time) {
assert.Equal(t, User.GetUserId(), claims.Subject)
assert.Equal(t, []string{amr.UserPresence, amr.MFA}, claims.AuthenticationMethodsReferences)
assert.WithinRange(t, claims.AuthTime.AsTime().UTC(), sessionStart.Add(-1*time.Second), sessionChange.Add(1*time.Second))
}

View File

@@ -66,7 +66,7 @@ func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Cl
return nil, err
}
return ClientFromBusiness(client, o.defaultLoginURL, accessTokenLifetime, idTokenLifetime, allowedScopes)
return ClientFromBusiness(client, o.defaultLoginURL, o.defaultLoginURLV2, accessTokenLifetime, idTokenLifetime, allowedScopes)
}
func (o *OPStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (_ *jose.JSONWebKey, err error) {
@@ -153,7 +153,7 @@ func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.Us
return o.setUserinfo(ctx, userInfo, userID, applicationID, scopes, nil)
}
func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) (err error) {
token, err := o.repo.TokenByIDs(ctx, subject, tokenID)
if err != nil {
return errors.ThrowPermissionDenied(nil, "OIDC-Dsfb2", "token is not valid or has expired")

View File

@@ -7,6 +7,7 @@ import (
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
@@ -15,18 +16,20 @@ import (
type Client struct {
app *query.App
defaultLoginURL string
defaultLoginURLV2 string
defaultAccessTokenLifetime time.Duration
defaultIdTokenLifetime time.Duration
allowedScopes []string
}
func ClientFromBusiness(app *query.App, defaultLoginURL string, defaultAccessTokenLifetime, defaultIdTokenLifetime time.Duration, allowedScopes []string) (op.Client, error) {
func ClientFromBusiness(app *query.App, defaultLoginURL, defaultLoginURLV2 string, defaultAccessTokenLifetime, defaultIdTokenLifetime time.Duration, allowedScopes []string) (op.Client, error) {
if app.OIDCConfig == nil {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-d5bhD", "client is not a proper oidc application")
}
return &Client{
app: app,
defaultLoginURL: defaultLoginURL,
defaultLoginURLV2: defaultLoginURLV2,
defaultAccessTokenLifetime: defaultAccessTokenLifetime,
defaultIdTokenLifetime: defaultIdTokenLifetime,
allowedScopes: allowedScopes},
@@ -46,6 +49,9 @@ func (c *Client) GetID() string {
}
func (c *Client) LoginURL(id string) string {
if strings.HasPrefix(id, command.IDPrefixV2) {
return c.defaultLoginURLV2 + id
}
return c.defaultLoginURL + id
}

View File

@@ -41,6 +41,7 @@ type Config struct {
Cache *middleware.CacheConfig
CustomEndpoints *EndpointConfig
DeviceAuth *DeviceAuthorizationConfig
DefaultLoginURLV2 string
}
type EndpointConfig struct {
@@ -65,6 +66,7 @@ type OPStorage struct {
query *query.Queries
eventstore *eventstore.Eventstore
defaultLoginURL string
defaultLoginURLV2 string
defaultAccessTokenLifetime time.Duration
defaultIdTokenLifetime time.Duration
signingKeyAlgorithm string
@@ -181,6 +183,7 @@ func newStorage(config Config, command *command.Commands, query *query.Queries,
query: query,
eventstore: es,
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
defaultLoginURLV2: config.DefaultLoginURLV2,
signingKeyAlgorithm: config.SigningKeyAlgorithm,
defaultAccessTokenLifetime: config.DefaultAccessTokenLifetime,
defaultIdTokenLifetime: config.DefaultIdTokenLifetime,