feat(api): add OIDC session service (#6157)

This PR starts the OIDC implementation for the API V2 including the Implicit and Code Flow.


Co-authored-by: Livio Spring <livio.a@gmail.com>
Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
Livio Spring
2023-07-10 15:27:00 +02:00
committed by GitHub
parent be1fe36776
commit 14b8cf4894
69 changed files with 5948 additions and 106 deletions

View File

@@ -0,0 +1,215 @@
package command
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type AuthRequest struct {
ID string
LoginClient string
ClientID string
RedirectURI string
State string
Nonce string
Scope []string
Audience []string
ResponseType domain.OIDCResponseType
CodeChallenge *domain.OIDCCodeChallenge
Prompt []domain.Prompt
UILocales []string
MaxAge *time.Duration
LoginHint *string
HintUserID *string
}
type CurrentAuthRequest struct {
*AuthRequest
SessionID string
UserID string
AMR []string
AuthTime time.Time
}
const IDPrefixV2 = "V2_"
func (c *Commands) AddAuthRequest(ctx context.Context, authRequest *AuthRequest) (_ *CurrentAuthRequest, err error) {
authRequestID, err := c.idGenerator.Next()
if err != nil {
return nil, err
}
authRequest.ID = IDPrefixV2 + authRequestID
writeModel, err := c.getAuthRequestWriteModel(ctx, authRequest.ID)
if err != nil {
return nil, err
}
if writeModel.AuthRequestState != domain.AuthRequestStateUnspecified {
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting")
}
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewAddedEvent(
ctx,
&authrequest.NewAggregate(authRequest.ID, authz.GetInstance(ctx).InstanceID()).Aggregate,
authRequest.LoginClient,
authRequest.ClientID,
authRequest.RedirectURI,
authRequest.State,
authRequest.Nonce,
authRequest.Scope,
authRequest.Audience,
authRequest.ResponseType,
authRequest.CodeChallenge,
authRequest.Prompt,
authRequest.UILocales,
authRequest.MaxAge,
authRequest.LoginHint,
authRequest.HintUserID,
))
if err != nil {
return nil, err
}
return authRequestWriteModelToCurrentAuthRequest(writeModel), nil
}
func (c *Commands) LinkSessionToAuthRequest(ctx context.Context, id, sessionID, sessionToken string, checkLoginClient bool) (*domain.ObjectDetails, *CurrentAuthRequest, error) {
writeModel, err := c.getAuthRequestWriteModel(ctx, id)
if err != nil {
return nil, nil, err
}
if writeModel.AuthRequestState == domain.AuthRequestStateUnspecified {
return nil, nil, errors.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting")
}
if writeModel.AuthRequestState != domain.AuthRequestStateAdded {
return nil, nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled")
}
if checkLoginClient && authz.GetCtxData(ctx).UserID != writeModel.LoginClient {
return nil, nil, errors.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient")
}
sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetCtxData(ctx).OrgID)
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
if err != nil {
return nil, nil, err
}
if sessionWriteModel.State == domain.SessionStateUnspecified {
return nil, nil, errors.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting")
}
if err := c.sessionPermission(ctx, sessionWriteModel, sessionToken, domain.PermissionSessionWrite); err != nil {
return nil, nil, err
}
if err := c.pushAppendAndReduce(ctx, writeModel, authrequest.NewSessionLinkedEvent(
ctx, &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
sessionID,
sessionWriteModel.UserID,
sessionWriteModel.AuthenticationTime(),
amr.List(sessionWriteModel),
)); err != nil {
return nil, nil, err
}
return writeModelToObjectDetails(&writeModel.WriteModel), authRequestWriteModelToCurrentAuthRequest(writeModel), nil
}
func (c *Commands) FailAuthRequest(ctx context.Context, id string, reason domain.OIDCErrorReason) (*domain.ObjectDetails, *CurrentAuthRequest, error) {
writeModel, err := c.getAuthRequestWriteModel(ctx, id)
if err != nil {
return nil, nil, err
}
if writeModel.AuthRequestState != domain.AuthRequestStateAdded {
return nil, nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled")
}
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewFailedEvent(
ctx,
&authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
reason,
))
if err != nil {
return nil, nil, err
}
return writeModelToObjectDetails(&writeModel.WriteModel), authRequestWriteModelToCurrentAuthRequest(writeModel), nil
}
func (c *Commands) AddAuthRequestCode(ctx context.Context, authRequestID, code string) (err error) {
if code == "" {
return errors.ThrowPreconditionFailed(nil, "COMMAND-Ht52d", "Errors.AuthRequest.InvalidCode")
}
writeModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
if err != nil {
return err
}
if writeModel.AuthRequestState != domain.AuthRequestStateAdded || writeModel.SessionID == "" {
return errors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.AlreadyHandled")
}
return c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeAddedEvent(ctx,
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
}
func (c *Commands) ExchangeAuthCode(ctx context.Context, code string) (authRequest *CurrentAuthRequest, err error) {
if code == "" {
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode")
}
writeModel, err := c.getAuthRequestWriteModel(ctx, code)
if err != nil {
return nil, err
}
if writeModel.AuthRequestState != domain.AuthRequestStateCodeAdded {
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode")
}
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeExchangedEvent(ctx,
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
if err != nil {
return nil, err
}
return authRequestWriteModelToCurrentAuthRequest(writeModel), nil
}
func authRequestWriteModelToCurrentAuthRequest(writeModel *AuthRequestWriteModel) (_ *CurrentAuthRequest) {
return &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: writeModel.AggregateID,
LoginClient: writeModel.LoginClient,
ClientID: writeModel.ClientID,
RedirectURI: writeModel.RedirectURI,
State: writeModel.State,
Nonce: writeModel.Nonce,
Scope: writeModel.Scope,
Audience: writeModel.Audience,
ResponseType: writeModel.ResponseType,
CodeChallenge: writeModel.CodeChallenge,
Prompt: writeModel.Prompt,
UILocales: writeModel.UILocales,
MaxAge: writeModel.MaxAge,
LoginHint: writeModel.LoginHint,
HintUserID: writeModel.HintUserID,
},
SessionID: writeModel.SessionID,
UserID: writeModel.UserID,
AMR: writeModel.AMR,
AuthTime: writeModel.AuthTime,
}
}
func (c *Commands) GetCurrentAuthRequest(ctx context.Context, id string) (_ *CurrentAuthRequest, err error) {
wm, err := c.getAuthRequestWriteModel(ctx, id)
if err != nil {
return nil, err
}
return authRequestWriteModelToCurrentAuthRequest(wm), nil
}
func (c *Commands) getAuthRequestWriteModel(ctx context.Context, id string) (writeModel *AuthRequestWriteModel, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
writeModel = NewAuthRequestWriteModel(ctx, id)
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
if err != nil {
return nil, err
}
return writeModel, nil
}

View File

@@ -0,0 +1,110 @@
package command
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/authrequest"
)
type AuthRequestWriteModel struct {
eventstore.WriteModel
aggregate *eventstore.Aggregate
LoginClient string
ClientID string
RedirectURI string
State string
Nonce string
Scope []string
Audience []string
ResponseType domain.OIDCResponseType
CodeChallenge *domain.OIDCCodeChallenge
Prompt []domain.Prompt
UILocales []string
MaxAge *time.Duration
LoginHint *string
HintUserID *string
SessionID string
UserID string
AuthTime time.Time
AMR []string
AuthRequestState domain.AuthRequestState
}
func NewAuthRequestWriteModel(ctx context.Context, id string) *AuthRequestWriteModel {
return &AuthRequestWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: id,
},
aggregate: &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
}
}
func (m *AuthRequestWriteModel) Reduce() error {
for _, event := range m.Events {
switch e := event.(type) {
case *authrequest.AddedEvent:
m.LoginClient = e.LoginClient
m.ClientID = e.ClientID
m.RedirectURI = e.RedirectURI
m.State = e.State
m.Nonce = e.Nonce
m.Scope = e.Scope
m.Audience = e.Audience
m.ResponseType = e.ResponseType
m.CodeChallenge = e.CodeChallenge
m.Prompt = e.Prompt
m.UILocales = e.UILocales
m.MaxAge = e.MaxAge
m.LoginHint = e.LoginHint
m.HintUserID = e.HintUserID
m.AuthRequestState = domain.AuthRequestStateAdded
case *authrequest.SessionLinkedEvent:
m.SessionID = e.SessionID
m.UserID = e.UserID
m.AuthTime = e.AuthTime
m.AMR = e.AMR
case *authrequest.CodeAddedEvent:
m.AuthRequestState = domain.AuthRequestStateCodeAdded
case *authrequest.FailedEvent:
m.AuthRequestState = domain.AuthRequestStateFailed
case *authrequest.CodeExchangedEvent:
m.AuthRequestState = domain.AuthRequestStateCodeExchanged
case *authrequest.SucceededEvent:
m.AuthRequestState = domain.AuthRequestStateSucceeded
}
}
return m.WriteModel.Reduce()
}
func (m *AuthRequestWriteModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(authrequest.AggregateType).
AggregateIDs(m.AggregateID).
Builder()
}
// CheckAuthenticated checks that the auth request exists, a session must have been linked
// and in case of a Code Flow the code must have been exchanged
func (m *AuthRequestWriteModel) CheckAuthenticated() error {
if m.SessionID == "" {
return caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated")
}
// in case of OIDC Code Flow, the code must have been exchanged
if m.ResponseType == domain.OIDCResponseTypeCode && m.AuthRequestState == domain.AuthRequestStateCodeExchanged {
return nil
}
// in case of OIDC Implicit Flow, check that the requests exists, but has not succeeded yet
if (m.ResponseType == domain.OIDCResponseTypeIDToken || m.ResponseType == domain.OIDCResponseTypeIDTokenToken) &&
m.AuthRequestState == domain.AuthRequestStateAdded {
return nil
}
return caos_errs.ThrowPreconditionFailed(nil, "AUTHR-sajk3", "Errors.AuthRequest.NotAuthenticated")
}

View File

@@ -0,0 +1,998 @@
package command
import (
"context"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/session"
)
func TestCommands_AddAuthRequest(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
}
type args struct {
ctx context.Context
request *AuthRequest
}
tests := []struct {
name string
fields fields
args args
want *CurrentAuthRequest
wantErr error
}{
{
"already exists error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
},
args{
ctx: mockCtx,
request: &AuthRequest{},
},
nil,
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting"),
},
{
"added",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
}),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
},
args{
ctx: mockCtx,
request: &AuthRequest{
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
CodeChallenge: &domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
Prompt: []domain.Prompt{domain.PromptNone},
UILocales: []string{"en", "de"},
MaxAge: gu.Ptr(time.Duration(0)),
LoginHint: gu.Ptr("loginHint"),
HintUserID: gu.Ptr("hintUserID"),
},
},
&CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_id",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
CodeChallenge: &domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
Prompt: []domain.Prompt{domain.PromptNone},
UILocales: []string{"en", "de"},
MaxAge: gu.Ptr(time.Duration(0)),
LoginHint: gu.Ptr("loginHint"),
HintUserID: gu.Ptr("hintUserID"),
},
},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
}
got, err := c.AddAuthRequest(tt.args.ctx, tt.args.request)
require.ErrorIs(t, tt.wantErr, err)
assert.Equal(t, tt.want, got)
})
}
}
func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore *eventstore.Eventstore
tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
checkPermission domain.PermissionCheck
}
type args struct {
ctx context.Context
id string
sessionID string
sessionToken string
checkLoginClient bool
}
type res struct {
details *domain.ObjectDetails
authReq *CurrentAuthRequest
wantErr error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"authRequest not found",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckNotAllowed(),
},
args{
ctx: mockCtx,
id: "id",
sessionID: "sessionID",
},
res{
wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting"),
},
},
{
"authRequest not existing",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
eventFromEventPusher(
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
domain.OIDCErrorReasonUnspecified),
),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckAllowed(),
},
args{
ctx: mockCtx,
id: "id",
sessionID: "sessionID",
},
res{
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled"),
},
},
{
"wrong login client",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckAllowed(),
},
args{
ctx: authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"),
id: "id",
sessionID: "sessionID",
sessionToken: "token",
checkLoginClient: true,
},
res{
wantErr: caos_errs.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient"),
},
},
{
"session not existing",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectFilter(),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckNotAllowed(),
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
},
res{
wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting"),
},
},
{
"missing permission",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckNotAllowed(),
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
},
res{
wantErr: caos_errs.ThrowPermissionDenied(nil, "AUTHZ-HKJD33", "Errors.PermissionDenied"),
},
},
{
"invalid session token",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid")
},
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
sessionToken: "invalid",
},
res{
wantErr: caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"),
},
},
{
"linked",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
"userID", testNow),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
testNow),
),
),
expectPush(
[]*repository.Event{eventFromEventPusherWithInstanceID(
"instanceID",
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
)}),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckAllowed(),
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
sessionToken: "token",
},
res{
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
authReq: &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_id",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
},
SessionID: "sessionID",
UserID: "userID",
AMR: []string{amr.PWD},
},
},
},
{
"linked with login client check",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
"userID", testNow),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
testNow),
),
),
expectPush(
[]*repository.Event{eventFromEventPusherWithInstanceID(
"instanceID",
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
)}),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckAllowed(),
},
args{
ctx: authz.NewMockContext("instanceID", "orgID", "loginClient"),
id: "V2_id",
sessionID: "sessionID",
sessionToken: "token",
checkLoginClient: true,
},
res{
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
authReq: &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_id",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
},
SessionID: "sessionID",
UserID: "userID",
AMR: []string{amr.PWD},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
sessionTokenVerifier: tt.fields.tokenVerifier,
checkPermission: tt.fields.checkPermission,
}
details, got, err := c.LinkSessionToAuthRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken, tt.args.checkLoginClient)
require.ErrorIs(t, err, tt.res.wantErr)
assert.Equal(t, tt.res.details, details)
if err == nil {
assert.WithinRange(t, got.AuthTime, testNow, testNow)
got.AuthTime = time.Time{}
}
assert.Equal(t, tt.res.authReq, got)
})
}
}
func TestCommands_FailAuthRequest(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
id string
reason domain.OIDCErrorReason
}
type res struct {
details *domain.ObjectDetails
authReq *CurrentAuthRequest
wantErr error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"authRequest not existing",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
},
args{
ctx: mockCtx,
id: "foo",
reason: domain.OIDCErrorReasonLoginRequired,
},
res{
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled"),
},
},
{
"failed",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectPush(
[]*repository.Event{eventFromEventPusherWithInstanceID(
"instanceID",
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
domain.OIDCErrorReasonLoginRequired),
)}),
),
},
args{
ctx: mockCtx,
id: "V2_id",
reason: domain.OIDCErrorReasonLoginRequired,
},
res{
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
authReq: &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_id",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
details, got, err := c.FailAuthRequest(tt.args.ctx, tt.args.id, tt.args.reason)
require.ErrorIs(t, err, tt.res.wantErr)
assert.Equal(t, tt.res.details, details)
assert.Equal(t, tt.res.authReq, got)
})
}
}
func TestCommands_AddAuthRequestCode(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
id string
code string
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
"empty code error",
fields{
eventstore: eventstoreExpect(t),
},
args{
ctx: mockCtx,
id: "V2_authRequestID",
code: "",
},
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Ht52d", "Errors.AuthRequest.InvalidCode"),
},
{
"no session linked error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
),
),
},
args{
ctx: mockCtx,
id: "V2_authRequestID",
code: "V2_authRequestID",
},
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.AlreadyHandled"),
},
{
"success",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
},
),
),
},
args{
ctx: mockCtx,
id: "V2_authRequestID",
code: "V2_authRequestID",
},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
err := c.AddAuthRequestCode(tt.args.ctx, tt.args.id, tt.args.code)
assert.ErrorIs(t, tt.wantErr, err)
})
}
}
func TestCommands_ExchangeAuthCode(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
code string
}
type res struct {
authRequest *CurrentAuthRequest
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"empty code error",
fields{
eventstore: eventstoreExpect(t),
},
args{
ctx: mockCtx,
code: "",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode"),
},
},
{
"no code added error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
),
),
},
args{
ctx: mockCtx,
code: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode"),
},
},
{
"code exchanged",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewCodeExchangedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
},
),
),
},
args{
ctx: mockCtx,
code: "V2_authRequestID",
},
res{
authRequest: &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_authRequestID",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
CodeChallenge: &domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
Prompt: []domain.Prompt{domain.PromptNone},
UILocales: []string{"en", "de"},
MaxAge: gu.Ptr(time.Duration(0)),
LoginHint: gu.Ptr("loginHint"),
HintUserID: gu.Ptr("hintUserID"),
},
SessionID: "sessionID",
UserID: "userID",
AMR: []string{"pwd"},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
got, err := c.ExchangeAuthCode(tt.args.ctx, tt.args.code)
assert.ErrorIs(t, tt.res.err, err)
if err == nil {
// equal on time won't work -> test separately and clear it before comparing the rest
assert.WithinRange(t, got.AuthTime, testNow, testNow)
got.AuthTime = time.Time{}
}
assert.Equal(t, tt.res.authRequest, got)
})
}
}

View File

@@ -15,10 +15,12 @@ import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/action"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/idpintent"
instance_repo "github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/org"
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/quota"
@@ -43,18 +45,21 @@ type Commands struct {
externalSecure bool
externalPort uint16
idpConfigEncryption crypto.EncryptionAlgorithm
smtpEncryption crypto.EncryptionAlgorithm
smsEncryption crypto.EncryptionAlgorithm
userEncryption crypto.EncryptionAlgorithm
userPasswordAlg crypto.HashAlgorithm
machineKeySize int
applicationKeySize int
domainVerificationAlg crypto.EncryptionAlgorithm
domainVerificationGenerator crypto.Generator
domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error
sessionTokenCreator func(sessionID string) (id string, token string, err error)
sessionTokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
idpConfigEncryption crypto.EncryptionAlgorithm
smtpEncryption crypto.EncryptionAlgorithm
smsEncryption crypto.EncryptionAlgorithm
userEncryption crypto.EncryptionAlgorithm
userPasswordAlg crypto.HashAlgorithm
machineKeySize int
applicationKeySize int
domainVerificationAlg crypto.EncryptionAlgorithm
domainVerificationGenerator crypto.Generator
domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error
sessionTokenCreator func(sessionID string) (id string, token string, err error)
sessionTokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration
defaultRefreshTokenIdleLifetime time.Duration
multifactors domain.MultifactorConfigs
webauthnConfig *webauthn_helper.Config
@@ -80,6 +85,9 @@ func StartCommands(
httpClient *http.Client,
permissionCheck domain.PermissionCheck,
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error),
defaultAccessTokenLifetime,
defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime time.Duration,
) (repo *Commands, err error) {
if externalDomain == "" {
return nil, errors.ThrowInvalidArgument(nil, "COMMAND-Df21s", "no external domain specified")
@@ -88,31 +96,34 @@ func StartCommands(
// reuse the oidcEncryption to be able to handle both tokens in the interceptor later on
sessionAlg := oidcEncryption
repo = &Commands{
eventstore: es,
static: staticStore,
idGenerator: idGenerator,
zitadelRoles: zitadelRoles,
externalDomain: externalDomain,
externalSecure: externalSecure,
externalPort: externalPort,
keySize: defaults.KeyConfig.Size,
certKeySize: defaults.KeyConfig.CertificateSize,
privateKeyLifetime: defaults.KeyConfig.PrivateKeyLifetime,
publicKeyLifetime: defaults.KeyConfig.PublicKeyLifetime,
certificateLifetime: defaults.KeyConfig.CertificateLifetime,
idpConfigEncryption: idpConfigEncryption,
smtpEncryption: smtpEncryption,
smsEncryption: smsEncryption,
userEncryption: userEncryption,
domainVerificationAlg: domainVerificationEncryption,
keyAlgorithm: oidcEncryption,
certificateAlgorithm: samlEncryption,
webauthnConfig: webAuthN,
httpClient: httpClient,
checkPermission: permissionCheck,
newCode: newCryptoCode,
sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg),
sessionTokenVerifier: sessionTokenVerifier,
eventstore: es,
static: staticStore,
idGenerator: idGenerator,
zitadelRoles: zitadelRoles,
externalDomain: externalDomain,
externalSecure: externalSecure,
externalPort: externalPort,
keySize: defaults.KeyConfig.Size,
certKeySize: defaults.KeyConfig.CertificateSize,
privateKeyLifetime: defaults.KeyConfig.PrivateKeyLifetime,
publicKeyLifetime: defaults.KeyConfig.PublicKeyLifetime,
certificateLifetime: defaults.KeyConfig.CertificateLifetime,
idpConfigEncryption: idpConfigEncryption,
smtpEncryption: smtpEncryption,
smsEncryption: smsEncryption,
userEncryption: userEncryption,
domainVerificationAlg: domainVerificationEncryption,
keyAlgorithm: oidcEncryption,
certificateAlgorithm: samlEncryption,
webauthnConfig: webAuthN,
httpClient: httpClient,
checkPermission: permissionCheck,
newCode: newCryptoCode,
sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg),
sessionTokenVerifier: sessionTokenVerifier,
defaultAccessTokenLifetime: defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime: defaultRefreshTokenIdleLifetime,
}
instance_repo.RegisterEventMappers(repo.eventstore)
@@ -125,6 +136,8 @@ func StartCommands(
quota.RegisterEventMappers(repo.eventstore)
session.RegisterEventMappers(repo.eventstore)
idpintent.RegisterEventMappers(repo.eventstore)
authrequest.RegisterEventMappers(repo.eventstore)
oidcsession.RegisterEventMappers(repo.eventstore)
milestone.RegisterEventMappers(repo.eventstore)
repo.userPasswordAlg = crypto.NewBCrypt(defaults.SecretGenerators.PasswordSaltCost)

View File

@@ -16,9 +16,11 @@ import (
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/eventstore/repository/mock"
action_repo "github.com/zitadel/zitadel/internal/repository/action"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/idpintent"
iam_repo "github.com/zitadel/zitadel/internal/repository/instance"
key_repo "github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/org"
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/session"
@@ -43,6 +45,8 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore {
action_repo.RegisterEventMappers(es)
session.RegisterEventMappers(es)
idpintent.RegisterEventMappers(es)
authrequest.RegisterEventMappers(es)
oidcsession.RegisterEventMappers(es)
return es
}

View File

@@ -0,0 +1,281 @@
package command
import (
"context"
"encoding/base64"
"strings"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
)
// AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration.
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) {
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
if err != nil {
return "", time.Time{}, err
}
cmd.AddSession(ctx)
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil {
return "", time.Time{}, err
}
cmd.SetAuthRequestSuccessful(ctx)
accessTokenID, _, accessTokenExpiration, err := cmd.PushEvents(ctx)
return accessTokenID, accessTokenExpiration, err
}
// AddOIDCSessionRefreshAndAccessToken creates a new OIDC Session, creates an access token and refresh token.
// It returns the access token id, expiration and the refresh token.
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
func (c *Commands) AddOIDCSessionRefreshAndAccessToken(ctx context.Context, authRequestID string) (tokenID, refreshToken string, tokenExpiration time.Time, err error) {
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
if err != nil {
return "", "", time.Time{}, err
}
cmd.AddSession(ctx)
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil {
return "", "", time.Time{}, err
}
if err = cmd.AddRefreshToken(ctx); err != nil {
return "", "", time.Time{}, err
}
cmd.SetAuthRequestSuccessful(ctx)
return cmd.PushEvents(ctx)
}
// ExchangeOIDCSessionRefreshAndAccessToken updates an existing OIDC Session, creates a new access and refresh token.
// It returns the access token id and expiration and the new refresh token.
func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, oidcSessionID, refreshToken string, scope []string) (tokenID, newRefreshToken string, tokenExpiration time.Time, err error) {
cmd, err := c.newOIDCSessionUpdateEvents(ctx, oidcSessionID, refreshToken)
if err != nil {
return "", "", time.Time{}, err
}
if err = cmd.AddAccessToken(ctx, scope); err != nil {
return "", "", time.Time{}, err
}
if err = cmd.RenewRefreshToken(ctx); err != nil {
return "", "", time.Time{}, err
}
return cmd.PushEvents(ctx)
}
// OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant).
// If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned.
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) {
split := strings.Split(refreshToken, ":")
if len(split) != 2 {
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")
}
writeModel := NewOIDCSessionWriteModel(split[0], "")
err := c.eventstore.FilterToQueryReducer(ctx, writeModel)
if err != nil {
return nil, caos_errs.ThrowPreconditionFailed(err, "OIDCS-SAF31", "Errors.OIDCSession.RefreshTokenInvalid")
}
if err = writeModel.CheckRefreshToken(split[1]); err != nil {
return nil, err
}
return writeModel, nil
}
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) {
authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
if err != nil {
return nil, err
}
if err = authRequestWriteModel.CheckAuthenticated(); err != nil {
return nil, err
}
sessionWriteModel := NewSessionWriteModel(authRequestWriteModel.SessionID, authz.GetCtxData(ctx).OrgID)
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
if err != nil {
return nil, err
}
if sessionWriteModel.State != domain.SessionStateActive {
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated")
}
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
if err != nil {
return nil, err
}
sessionID, err := c.idGenerator.Next()
if err != nil {
return nil, err
}
sessionID = IDPrefixV2 + sessionID
return &OIDCSessionEvents{
eventstore: c.eventstore,
idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm,
oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID()),
sessionWriteModel: sessionWriteModel,
authRequestWriteModel: authRequestWriteModel,
accessTokenLifetime: accessTokenLifetime,
refreshTokenLifeTime: refreshTokenLifeTime,
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
}, nil
}
func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID string, err error) {
decoded, err := base64.RawURLEncoding.DecodeString(refreshToken)
if err != nil {
return "", err
}
decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID())
if err != nil {
return "", err
}
split := strings.Split(decrypted, ":")
if len(split) != 2 {
return "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid")
}
return split[1], nil
}
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) {
refreshTokenID, err := c.decryptRefreshToken(refreshToken)
if err != nil {
return nil, err
}
sessionWriteModel := NewOIDCSessionWriteModel(oidcSessionID, authz.GetInstance(ctx).InstanceID())
if err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel); err != nil {
return nil, err
}
if err = sessionWriteModel.CheckRefreshToken(refreshTokenID); err != nil {
return nil, err
}
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
if err != nil {
return nil, err
}
return &OIDCSessionEvents{
eventstore: c.eventstore,
idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm,
oidcSessionWriteModel: sessionWriteModel,
accessTokenLifetime: accessTokenLifetime,
refreshTokenLifeTime: refreshTokenLifeTime,
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
}, nil
}
type OIDCSessionEvents struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
encryptionAlg crypto.EncryptionAlgorithm
events []eventstore.Command
oidcSessionWriteModel *OIDCSessionWriteModel
sessionWriteModel *SessionWriteModel
authRequestWriteModel *AuthRequestWriteModel
accessTokenLifetime time.Duration
refreshTokenLifeTime time.Duration
refreshTokenIdleLifetime time.Duration
// accessTokenID is set by the command
accessTokenID string
// refreshToken is set by the command
refreshToken string
}
func (c *OIDCSessionEvents) AddSession(ctx context.Context) {
c.events = append(c.events, oidcsession.NewAddedEvent(
ctx,
c.oidcSessionWriteModel.aggregate,
c.sessionWriteModel.UserID,
c.sessionWriteModel.AggregateID,
c.authRequestWriteModel.ClientID,
c.authRequestWriteModel.Audience,
c.authRequestWriteModel.Scope,
amr.List(c.sessionWriteModel),
c.sessionWriteModel.AuthenticationTime(),
))
}
func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) {
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate))
}
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) (err error) {
c.accessTokenID, err = c.idGenerator.Next()
if err != nil {
return err
}
c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime))
return nil
}
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
var refreshTokenID string
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
if err != nil {
return err
}
c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime))
return nil
}
func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
var refreshTokenID string
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
if err != nil {
return err
}
c.events = append(c.events, oidcsession.NewRefreshTokenRenewedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenIdleLifetime))
return nil
}
func (c *OIDCSessionEvents) generateRefreshToken() (refreshTokenID, refreshToken string, err error) {
refreshTokenID, err = c.idGenerator.Next()
if err != nil {
return "", "", err
}
token, err := c.encryptionAlg.Encrypt([]byte(c.oidcSessionWriteModel.AggregateID + ":" + refreshTokenID))
if err != nil {
return "", "", err
}
return refreshTokenID, base64.RawURLEncoding.EncodeToString(token), nil
}
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID string, refreshToken string, accessTokenExpiration time.Time, err error) {
pushedEvents, err := c.eventstore.Push(ctx, c.events...)
if err != nil {
return "", "", time.Time{}, err
}
err = AppendAndReduce(c.oidcSessionWriteModel, pushedEvents...)
if err != nil {
return "", "", time.Time{}, err
}
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
return c.oidcSessionWriteModel.AggregateID + "-" + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil
}
func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) {
oidcSettings := NewInstanceOIDCSettingsWriteModel(ctx)
err = c.eventstore.FilterToQueryReducer(ctx, oidcSettings)
if err != nil {
return 0, 0, 0, err
}
accessTokenLifetime = c.defaultAccessTokenLifetime
refreshTokenLifetime = c.defaultRefreshTokenLifetime
refreshTokenIdleLifetime = c.defaultRefreshTokenIdleLifetime
if oidcSettings.AccessTokenLifetime > 0 {
accessTokenLifetime = oidcSettings.AccessTokenLifetime
}
if oidcSettings.RefreshTokenExpiration > 0 {
refreshTokenLifetime = oidcSettings.RefreshTokenExpiration
}
if oidcSettings.RefreshTokenIdleExpiration > 0 {
refreshTokenIdleLifetime = oidcSettings.RefreshTokenIdleExpiration
}
return accessTokenLifetime, refreshTokenLifetime, refreshTokenIdleLifetime, nil
}

View File

@@ -0,0 +1,114 @@
package command
import (
"time"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
)
type OIDCSessionWriteModel struct {
eventstore.WriteModel
UserID string
SessionID string
ClientID string
Audience []string
Scope []string
AuthMethodsReferences []string
AuthTime time.Time
State domain.OIDCSessionState
AccessTokenExpiration time.Time
RefreshTokenID string
RefreshTokenExpiration time.Time
RefreshTokenIdleExpiration time.Time
aggregate *eventstore.Aggregate
}
func NewOIDCSessionWriteModel(id string, resourceOwner string) *OIDCSessionWriteModel {
return &OIDCSessionWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: id,
ResourceOwner: resourceOwner,
},
aggregate: &oidcsession.NewAggregate(id, resourceOwner).Aggregate,
}
}
func (wm *OIDCSessionWriteModel) Reduce() error {
for _, event := range wm.Events {
switch e := event.(type) {
case *oidcsession.AddedEvent:
wm.reduceAdded(e)
case *oidcsession.AccessTokenAddedEvent:
wm.reduceAccessTokenAdded(e)
case *oidcsession.RefreshTokenAddedEvent:
wm.reduceRefreshTokenAdded(e)
case *oidcsession.RefreshTokenRenewedEvent:
wm.reduceRefreshTokenRenewed(e)
}
}
return wm.WriteModel.Reduce()
}
func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(oidcsession.AggregateType).
AggregateIDs(wm.AggregateID).
EventTypes(
oidcsession.AddedType,
oidcsession.AccessTokenAddedType,
oidcsession.RefreshTokenAddedType,
oidcsession.RefreshTokenRenewedType,
).
Builder()
if wm.ResourceOwner != "" {
query.ResourceOwner(wm.ResourceOwner)
}
return query
}
func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
wm.UserID = e.UserID
wm.SessionID = e.SessionID
wm.ClientID = e.ClientID
wm.Audience = e.Audience
wm.Scope = e.Scope
wm.AuthMethodsReferences = e.AuthMethodsReferences
wm.AuthTime = e.AuthTime
wm.State = domain.OIDCSessionStateActive
}
func (wm *OIDCSessionWriteModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime)
}
func (wm *OIDCSessionWriteModel) reduceRefreshTokenAdded(e *oidcsession.RefreshTokenAddedEvent) {
wm.RefreshTokenID = e.ID
wm.RefreshTokenExpiration = e.CreationDate().Add(e.Lifetime)
wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime)
}
func (wm *OIDCSessionWriteModel) reduceRefreshTokenRenewed(e *oidcsession.RefreshTokenRenewedEvent) {
wm.RefreshTokenID = e.ID
wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime)
}
func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error {
if wm.State != domain.OIDCSessionStateActive {
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid")
}
if wm.RefreshTokenID != refreshTokenID {
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid")
}
now := time.Now()
if wm.RefreshTokenExpiration.Before(now) || wm.RefreshTokenIdleExpiration.Before(now) {
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid")
}
return nil
}

View File

@@ -0,0 +1,795 @@
package command
import (
"context"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/session"
)
var (
testNow = time.Now()
tokenCreationNow = time.Time{}
)
func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration
defaultRefreshTokenIdleLifetime time.Duration
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
authRequestID string
}
type res struct {
id string
expiration time.Time
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"unauthenticated error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated"),
},
},
{
"inactive session error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
eventFromEventPusher(
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
expectFilter(),
),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"),
},
},
{
"add successful",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
eventFromEventPusher(
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"),
),
eventFromEventPusher(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
"userID", testNow),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
testNow),
),
),
expectFilter(), // token lifetime
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid"}, []string{amr.PWD}, testNow),
),
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid"}, time.Hour),
),
eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
},
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"),
defaultAccessTokenLifetime: time.Hour,
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
id: "V2_oidcSessionID-accessTokenID",
expiration: tokenCreationNow.Add(time.Hour),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm,
}
gotID, gotExpiration, err := c.AddOIDCSessionAccessToken(tt.args.ctx, tt.args.authRequestID)
assert.Equal(t, tt.res.id, gotID)
assert.Equal(t, tt.res.expiration, gotExpiration)
assert.ErrorIs(t, err, tt.res.err)
})
}
}
func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration
defaultRefreshTokenIdleLifetime time.Duration
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
authRequestID string
}
type res struct {
id string
refreshToken string
expiration time.Time
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"unauthenticated error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated"),
},
},
{
"inactive session error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid", "offline_access"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
eventFromEventPusher(
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
expectFilter(),
),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"),
},
},
{
"add successful",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid", "offline_access"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
eventFromEventPusher(
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"),
),
eventFromEventPusher(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
"userID", testNow),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
testNow),
),
),
expectFilter(), // token lifetime
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "offline_access"}, time.Hour),
),
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
},
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID", "refreshTokenID"),
defaultAccessTokenLifetime: time.Hour,
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
id: "V2_oidcSessionID-accessTokenID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", //V2_oidcSessionID:refreshTokenID
expiration: tokenCreationNow.Add(time.Hour),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm,
}
gotID, gotRefreshToken, gotExpiration, err := c.AddOIDCSessionRefreshAndAccessToken(tt.args.ctx, tt.args.authRequestID)
assert.Equal(t, tt.res.id, gotID)
assert.Equal(t, tt.res.refreshToken, gotRefreshToken)
assert.Equal(t, tt.res.expiration, gotExpiration)
assert.ErrorIs(t, err, tt.res.err)
})
}
}
func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration
defaultRefreshTokenIdleLifetime time.Duration
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
oidcSessionID string
refreshToken string
scope []string
}
type res struct {
id string
refreshToken string
expiration time.Time
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"invalid refresh token format error",
fields{
eventstore: eventstoreExpect(t),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID",
refreshToken: "aW52YWxpZA",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"inactive session error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"invalid refresh token error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"expired refresh token error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
eventFromEventPusher(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"refresh successful",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
),
expectFilter(), // token lifetime
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "offline_access"}, time.Hour),
),
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewRefreshTokenRenewedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID2", 24*time.Hour),
),
},
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "accessTokenID", "refreshTokenID2"),
defaultAccessTokenLifetime: time.Hour,
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
scope: []string{"openid", "offline_access"},
},
res{
id: "V2_oidcSessionID-accessTokenID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRDI",
expiration: time.Time{}.Add(time.Hour),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm,
}
gotID, gotRefreshToken, gotExpiration, err := c.ExchangeOIDCSessionRefreshAndAccessToken(tt.args.ctx, tt.args.oidcSessionID, tt.args.refreshToken, tt.args.scope)
assert.Equal(t, tt.res.id, gotID)
assert.Equal(t, tt.res.refreshToken, gotRefreshToken)
assert.Equal(t, tt.res.expiration, gotExpiration)
assert.ErrorIs(t, err, tt.res.err)
})
}
}
func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration
defaultRefreshTokenIdleLifetime time.Duration
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
refreshToken string
}
type res struct {
model *OIDCSessionWriteModel
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"invalid refresh token format error",
fields{
eventstore: eventstoreExpect(t),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "invalid",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"inactive session error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"invalid refresh token error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"expired refresh token error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
eventFromEventPusher(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"get successful",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID",
},
res{
model: &OIDCSessionWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: "V2_oidcSessionID",
ChangeDate: testNow,
},
UserID: "userID",
SessionID: "sessionID",
ClientID: "clientID",
Audience: []string{"audience"},
Scope: []string{"openid", "profile", "offline_access"},
AuthMethodsReferences: []string{amr.PWD},
AuthTime: testNow,
State: domain.OIDCSessionStateActive,
RefreshTokenID: "refreshTokenID",
RefreshTokenExpiration: testNow.Add(7 * 24 * time.Hour),
RefreshTokenIdleExpiration: testNow.Add(24 * time.Hour),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm,
}
got, err := c.OIDCSessionByRefreshToken(tt.args.ctx, tt.args.refreshToken)
require.ErrorIs(t, err, tt.res.err)
if tt.res.err == nil {
assert.WithinRange(t, got.ChangeDate, tt.res.model.ChangeDate.Add(-2*time.Second), tt.res.model.ChangeDate.Add(2*time.Second))
assert.Equal(t, tt.res.model.AggregateID, got.AggregateID)
assert.Equal(t, tt.res.model.UserID, got.UserID)
assert.Equal(t, tt.res.model.SessionID, got.SessionID)
assert.Equal(t, tt.res.model.ClientID, got.ClientID)
assert.Equal(t, tt.res.model.Audience, got.Audience)
assert.Equal(t, tt.res.model.Scope, got.Scope)
assert.Equal(t, tt.res.model.AuthMethodsReferences, got.AuthMethodsReferences)
assert.WithinRange(t, got.AuthTime, tt.res.model.AuthTime.Add(-2*time.Second), tt.res.model.AuthTime.Add(2*time.Second))
assert.Equal(t, tt.res.model.State, got.State)
assert.Equal(t, tt.res.model.RefreshTokenID, got.RefreshTokenID)
assert.WithinRange(t, got.RefreshTokenExpiration, tt.res.model.RefreshTokenExpiration.Add(-2*time.Second), tt.res.model.RefreshTokenExpiration.Add(2*time.Second))
assert.WithinRange(t, got.RefreshTokenIdleExpiration, tt.res.model.RefreshTokenIdleExpiration.Add(-2*time.Second), tt.res.model.RefreshTokenIdleExpiration.Add(2*time.Second))
}
})
}
}

View File

@@ -52,6 +52,24 @@ type SessionWriteModel struct {
aggregate *eventstore.Aggregate
}
func (wm *SessionWriteModel) IsPasswordChecked() bool {
return !wm.PasswordCheckedAt.IsZero()
}
func (wm *SessionWriteModel) IsPasskeyChecked() bool {
return !wm.PasskeyCheckedAt.IsZero()
}
func (wm *SessionWriteModel) IsU2FChecked() bool {
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
return false
}
func (wm *SessionWriteModel) IsOTPChecked() bool {
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
return false
}
func NewSessionWriteModel(sessionID string, resourceOwner string) *SessionWriteModel {
return &SessionWriteModel{
WriteModel: eventstore.WriteModel{
@@ -210,3 +228,19 @@ func (wm *SessionWriteModel) ChangeMetadata(ctx context.Context, metadata map[st
wm.commands = append(wm.commands, session.NewMetadataSetEvent(ctx, wm.aggregate, wm.Metadata))
}
}
// AuthenticationTime returns the time the user authenticated using the latest time of all checks
func (wm *SessionWriteModel) AuthenticationTime() time.Time {
var authTime time.Time
for _, check := range []time.Time{
wm.PasswordCheckedAt,
wm.PasskeyCheckedAt,
wm.IntentCheckedAt,
// TODO: add U2F and OTP check https://github.com/zitadel/zitadel/issues/5477
} {
if check.After(authTime) {
authTime = check
}
}
return authTime
}

View File

@@ -528,7 +528,7 @@ func TestCommands_createHumanOTP(t *testing.T) {
}
func TestCommands_HumanCheckMFAOTPSetup(t *testing.T) {
ctx := authz.NewMockContext("inst1", "org1", "user1")
ctx := authz.NewMockContext("", "org1", "user1")
cryptoAlg := crypto.CreateMockEncryptionAlg(gomock.NewController(t))
key, secret, err := domain.NewOTPKey("example.com", "user1", cryptoAlg)

View File

@@ -188,7 +188,7 @@ func TestCommands_AddUserTOTP(t *testing.T) {
}
func TestCommands_CheckUserTOTP(t *testing.T) {
ctx := authz.NewMockContext("inst1", "org1", "user1")
ctx := authz.NewMockContext("", "org1", "user1")
cryptoAlg := crypto.CreateMockEncryptionAlg(gomock.NewController(t))
key, secret, err := domain.NewOTPKey("example.com", "user1", cryptoAlg)