mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 20:47:32 +00:00
feat(api): add OIDC session service (#6157)
This PR starts the OIDC implementation for the API V2 including the Implicit and Code Flow. Co-authored-by: Livio Spring <livio.a@gmail.com> Co-authored-by: Tim Möhlmann <tim+github@zitadel.com> Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
215
internal/command/auth_request.go
Normal file
215
internal/command/auth_request.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
type AuthRequest struct {
|
||||
ID string
|
||||
LoginClient string
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
State string
|
||||
Nonce string
|
||||
Scope []string
|
||||
Audience []string
|
||||
ResponseType domain.OIDCResponseType
|
||||
CodeChallenge *domain.OIDCCodeChallenge
|
||||
Prompt []domain.Prompt
|
||||
UILocales []string
|
||||
MaxAge *time.Duration
|
||||
LoginHint *string
|
||||
HintUserID *string
|
||||
}
|
||||
|
||||
type CurrentAuthRequest struct {
|
||||
*AuthRequest
|
||||
SessionID string
|
||||
UserID string
|
||||
AMR []string
|
||||
AuthTime time.Time
|
||||
}
|
||||
|
||||
const IDPrefixV2 = "V2_"
|
||||
|
||||
func (c *Commands) AddAuthRequest(ctx context.Context, authRequest *AuthRequest) (_ *CurrentAuthRequest, err error) {
|
||||
authRequestID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authRequest.ID = IDPrefixV2 + authRequestID
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, authRequest.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateUnspecified {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting")
|
||||
}
|
||||
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewAddedEvent(
|
||||
ctx,
|
||||
&authrequest.NewAggregate(authRequest.ID, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
authRequest.LoginClient,
|
||||
authRequest.ClientID,
|
||||
authRequest.RedirectURI,
|
||||
authRequest.State,
|
||||
authRequest.Nonce,
|
||||
authRequest.Scope,
|
||||
authRequest.Audience,
|
||||
authRequest.ResponseType,
|
||||
authRequest.CodeChallenge,
|
||||
authRequest.Prompt,
|
||||
authRequest.UILocales,
|
||||
authRequest.MaxAge,
|
||||
authRequest.LoginHint,
|
||||
authRequest.HintUserID,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return authRequestWriteModelToCurrentAuthRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) LinkSessionToAuthRequest(ctx context.Context, id, sessionID, sessionToken string, checkLoginClient bool) (*domain.ObjectDetails, *CurrentAuthRequest, error) {
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if writeModel.AuthRequestState == domain.AuthRequestStateUnspecified {
|
||||
return nil, nil, errors.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting")
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateAdded {
|
||||
return nil, nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled")
|
||||
}
|
||||
if checkLoginClient && authz.GetCtxData(ctx).UserID != writeModel.LoginClient {
|
||||
return nil, nil, errors.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient")
|
||||
}
|
||||
sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetCtxData(ctx).OrgID)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if sessionWriteModel.State == domain.SessionStateUnspecified {
|
||||
return nil, nil, errors.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting")
|
||||
}
|
||||
if err := c.sessionPermission(ctx, sessionWriteModel, sessionToken, domain.PermissionSessionWrite); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := c.pushAppendAndReduce(ctx, writeModel, authrequest.NewSessionLinkedEvent(
|
||||
ctx, &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
sessionID,
|
||||
sessionWriteModel.UserID,
|
||||
sessionWriteModel.AuthenticationTime(),
|
||||
amr.List(sessionWriteModel),
|
||||
)); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return writeModelToObjectDetails(&writeModel.WriteModel), authRequestWriteModelToCurrentAuthRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) FailAuthRequest(ctx context.Context, id string, reason domain.OIDCErrorReason) (*domain.ObjectDetails, *CurrentAuthRequest, error) {
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateAdded {
|
||||
return nil, nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled")
|
||||
}
|
||||
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewFailedEvent(
|
||||
ctx,
|
||||
&authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
reason,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return writeModelToObjectDetails(&writeModel.WriteModel), authRequestWriteModelToCurrentAuthRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) AddAuthRequestCode(ctx context.Context, authRequestID, code string) (err error) {
|
||||
if code == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "COMMAND-Ht52d", "Errors.AuthRequest.InvalidCode")
|
||||
}
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateAdded || writeModel.SessionID == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.AlreadyHandled")
|
||||
}
|
||||
return c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeAddedEvent(ctx,
|
||||
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
|
||||
}
|
||||
|
||||
func (c *Commands) ExchangeAuthCode(ctx context.Context, code string) (authRequest *CurrentAuthRequest, err error) {
|
||||
if code == "" {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode")
|
||||
}
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateCodeAdded {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode")
|
||||
}
|
||||
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeExchangedEvent(ctx,
|
||||
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return authRequestWriteModelToCurrentAuthRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func authRequestWriteModelToCurrentAuthRequest(writeModel *AuthRequestWriteModel) (_ *CurrentAuthRequest) {
|
||||
return &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: writeModel.AggregateID,
|
||||
LoginClient: writeModel.LoginClient,
|
||||
ClientID: writeModel.ClientID,
|
||||
RedirectURI: writeModel.RedirectURI,
|
||||
State: writeModel.State,
|
||||
Nonce: writeModel.Nonce,
|
||||
Scope: writeModel.Scope,
|
||||
Audience: writeModel.Audience,
|
||||
ResponseType: writeModel.ResponseType,
|
||||
CodeChallenge: writeModel.CodeChallenge,
|
||||
Prompt: writeModel.Prompt,
|
||||
UILocales: writeModel.UILocales,
|
||||
MaxAge: writeModel.MaxAge,
|
||||
LoginHint: writeModel.LoginHint,
|
||||
HintUserID: writeModel.HintUserID,
|
||||
},
|
||||
SessionID: writeModel.SessionID,
|
||||
UserID: writeModel.UserID,
|
||||
AMR: writeModel.AMR,
|
||||
AuthTime: writeModel.AuthTime,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Commands) GetCurrentAuthRequest(ctx context.Context, id string) (_ *CurrentAuthRequest, err error) {
|
||||
wm, err := c.getAuthRequestWriteModel(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return authRequestWriteModelToCurrentAuthRequest(wm), nil
|
||||
}
|
||||
|
||||
func (c *Commands) getAuthRequestWriteModel(ctx context.Context, id string) (writeModel *AuthRequestWriteModel, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
writeModel = NewAuthRequestWriteModel(ctx, id)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return writeModel, nil
|
||||
}
|
110
internal/command/auth_request_model.go
Normal file
110
internal/command/auth_request_model.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
)
|
||||
|
||||
type AuthRequestWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
aggregate *eventstore.Aggregate
|
||||
|
||||
LoginClient string
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
State string
|
||||
Nonce string
|
||||
Scope []string
|
||||
Audience []string
|
||||
ResponseType domain.OIDCResponseType
|
||||
CodeChallenge *domain.OIDCCodeChallenge
|
||||
Prompt []domain.Prompt
|
||||
UILocales []string
|
||||
MaxAge *time.Duration
|
||||
LoginHint *string
|
||||
HintUserID *string
|
||||
SessionID string
|
||||
UserID string
|
||||
AuthTime time.Time
|
||||
AMR []string
|
||||
AuthRequestState domain.AuthRequestState
|
||||
}
|
||||
|
||||
func NewAuthRequestWriteModel(ctx context.Context, id string) *AuthRequestWriteModel {
|
||||
return &AuthRequestWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: id,
|
||||
},
|
||||
aggregate: &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AuthRequestWriteModel) Reduce() error {
|
||||
for _, event := range m.Events {
|
||||
switch e := event.(type) {
|
||||
case *authrequest.AddedEvent:
|
||||
m.LoginClient = e.LoginClient
|
||||
m.ClientID = e.ClientID
|
||||
m.RedirectURI = e.RedirectURI
|
||||
m.State = e.State
|
||||
m.Nonce = e.Nonce
|
||||
m.Scope = e.Scope
|
||||
m.Audience = e.Audience
|
||||
m.ResponseType = e.ResponseType
|
||||
m.CodeChallenge = e.CodeChallenge
|
||||
m.Prompt = e.Prompt
|
||||
m.UILocales = e.UILocales
|
||||
m.MaxAge = e.MaxAge
|
||||
m.LoginHint = e.LoginHint
|
||||
m.HintUserID = e.HintUserID
|
||||
m.AuthRequestState = domain.AuthRequestStateAdded
|
||||
case *authrequest.SessionLinkedEvent:
|
||||
m.SessionID = e.SessionID
|
||||
m.UserID = e.UserID
|
||||
m.AuthTime = e.AuthTime
|
||||
m.AMR = e.AMR
|
||||
case *authrequest.CodeAddedEvent:
|
||||
m.AuthRequestState = domain.AuthRequestStateCodeAdded
|
||||
case *authrequest.FailedEvent:
|
||||
m.AuthRequestState = domain.AuthRequestStateFailed
|
||||
case *authrequest.CodeExchangedEvent:
|
||||
m.AuthRequestState = domain.AuthRequestStateCodeExchanged
|
||||
case *authrequest.SucceededEvent:
|
||||
m.AuthRequestState = domain.AuthRequestStateSucceeded
|
||||
}
|
||||
}
|
||||
|
||||
return m.WriteModel.Reduce()
|
||||
}
|
||||
|
||||
func (m *AuthRequestWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AddQuery().
|
||||
AggregateTypes(authrequest.AggregateType).
|
||||
AggregateIDs(m.AggregateID).
|
||||
Builder()
|
||||
}
|
||||
|
||||
// CheckAuthenticated checks that the auth request exists, a session must have been linked
|
||||
// and in case of a Code Flow the code must have been exchanged
|
||||
func (m *AuthRequestWriteModel) CheckAuthenticated() error {
|
||||
if m.SessionID == "" {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated")
|
||||
}
|
||||
// in case of OIDC Code Flow, the code must have been exchanged
|
||||
if m.ResponseType == domain.OIDCResponseTypeCode && m.AuthRequestState == domain.AuthRequestStateCodeExchanged {
|
||||
return nil
|
||||
}
|
||||
// in case of OIDC Implicit Flow, check that the requests exists, but has not succeeded yet
|
||||
if (m.ResponseType == domain.OIDCResponseTypeIDToken || m.ResponseType == domain.OIDCResponseTypeIDTokenToken) &&
|
||||
m.AuthRequestState == domain.AuthRequestStateAdded {
|
||||
return nil
|
||||
}
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "AUTHR-sajk3", "Errors.AuthRequest.NotAuthenticated")
|
||||
}
|
998
internal/command/auth_request_test.go
Normal file
998
internal/command/auth_request_test.go
Normal file
@@ -0,0 +1,998 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
)
|
||||
|
||||
func TestCommands_AddAuthRequest(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
request *AuthRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *CurrentAuthRequest
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
"already exists error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
request: &AuthRequest{},
|
||||
},
|
||||
nil,
|
||||
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting"),
|
||||
},
|
||||
{
|
||||
"added",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
}),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
request: &AuthRequest{
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
CodeChallenge: &domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
Prompt: []domain.Prompt{domain.PromptNone},
|
||||
UILocales: []string{"en", "de"},
|
||||
MaxAge: gu.Ptr(time.Duration(0)),
|
||||
LoginHint: gu.Ptr("loginHint"),
|
||||
HintUserID: gu.Ptr("hintUserID"),
|
||||
},
|
||||
},
|
||||
&CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
CodeChallenge: &domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
Prompt: []domain.Prompt{domain.PromptNone},
|
||||
UILocales: []string{"en", "de"},
|
||||
MaxAge: gu.Ptr(time.Duration(0)),
|
||||
LoginHint: gu.Ptr("loginHint"),
|
||||
HintUserID: gu.Ptr("hintUserID"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
}
|
||||
got, err := c.AddAuthRequest(tt.args.ctx, tt.args.request)
|
||||
require.ErrorIs(t, tt.wantErr, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
|
||||
checkPermission domain.PermissionCheck
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
sessionID string
|
||||
sessionToken string
|
||||
checkLoginClient bool
|
||||
}
|
||||
type res struct {
|
||||
details *domain.ObjectDetails
|
||||
authReq *CurrentAuthRequest
|
||||
wantErr error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"authRequest not found",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckNotAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "id",
|
||||
sessionID: "sessionID",
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"authRequest not existing",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
|
||||
domain.OIDCErrorReasonUnspecified),
|
||||
),
|
||||
),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "id",
|
||||
sessionID: "sessionID",
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"wrong login client",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"),
|
||||
id: "id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "token",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"session not existing",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckNotAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"missing permission",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
|
||||
),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckNotAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowPermissionDenied(nil, "AUTHZ-HKJD33", "Errors.PermissionDenied"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid session token",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
|
||||
),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid")
|
||||
},
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "invalid",
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"linked",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
|
||||
"userID", testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{eventFromEventPusherWithInstanceID(
|
||||
"instanceID",
|
||||
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
)}),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "token",
|
||||
},
|
||||
res{
|
||||
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||
authReq: &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
},
|
||||
SessionID: "sessionID",
|
||||
UserID: "userID",
|
||||
AMR: []string{amr.PWD},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"linked with login client check",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
|
||||
"userID", testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{eventFromEventPusherWithInstanceID(
|
||||
"instanceID",
|
||||
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
)}),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: authz.NewMockContext("instanceID", "orgID", "loginClient"),
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "token",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
res{
|
||||
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||
authReq: &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
},
|
||||
SessionID: "sessionID",
|
||||
UserID: "userID",
|
||||
AMR: []string{amr.PWD},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
sessionTokenVerifier: tt.fields.tokenVerifier,
|
||||
checkPermission: tt.fields.checkPermission,
|
||||
}
|
||||
details, got, err := c.LinkSessionToAuthRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken, tt.args.checkLoginClient)
|
||||
require.ErrorIs(t, err, tt.res.wantErr)
|
||||
assert.Equal(t, tt.res.details, details)
|
||||
if err == nil {
|
||||
assert.WithinRange(t, got.AuthTime, testNow, testNow)
|
||||
got.AuthTime = time.Time{}
|
||||
}
|
||||
assert.Equal(t, tt.res.authReq, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_FailAuthRequest(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
reason domain.OIDCErrorReason
|
||||
}
|
||||
type res struct {
|
||||
details *domain.ObjectDetails
|
||||
authReq *CurrentAuthRequest
|
||||
wantErr error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"authRequest not existing",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "foo",
|
||||
reason: domain.OIDCErrorReasonLoginRequired,
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"failed",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{eventFromEventPusherWithInstanceID(
|
||||
"instanceID",
|
||||
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
domain.OIDCErrorReasonLoginRequired),
|
||||
)}),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
reason: domain.OIDCErrorReasonLoginRequired,
|
||||
},
|
||||
res{
|
||||
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||
authReq: &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
details, got, err := c.FailAuthRequest(tt.args.ctx, tt.args.id, tt.args.reason)
|
||||
require.ErrorIs(t, err, tt.res.wantErr)
|
||||
assert.Equal(t, tt.res.details, details)
|
||||
assert.Equal(t, tt.res.authReq, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_AddAuthRequestCode(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
code string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
"empty code error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_authRequestID",
|
||||
code: "",
|
||||
},
|
||||
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Ht52d", "Errors.AuthRequest.InvalidCode"),
|
||||
},
|
||||
{
|
||||
"no session linked error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_authRequestID",
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.AlreadyHandled"),
|
||||
},
|
||||
{
|
||||
"success",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_authRequestID",
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
err := c.AddAuthRequestCode(tt.args.ctx, tt.args.id, tt.args.code)
|
||||
assert.ErrorIs(t, tt.wantErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_ExchangeAuthCode(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
code string
|
||||
}
|
||||
type res struct {
|
||||
authRequest *CurrentAuthRequest
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"empty code error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"no code added error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"code exchanged",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
authrequest.NewCodeExchangedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
authRequest: &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_authRequestID",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
CodeChallenge: &domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
Prompt: []domain.Prompt{domain.PromptNone},
|
||||
UILocales: []string{"en", "de"},
|
||||
MaxAge: gu.Ptr(time.Duration(0)),
|
||||
LoginHint: gu.Ptr("loginHint"),
|
||||
HintUserID: gu.Ptr("hintUserID"),
|
||||
},
|
||||
SessionID: "sessionID",
|
||||
UserID: "userID",
|
||||
AMR: []string{"pwd"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
got, err := c.ExchangeAuthCode(tt.args.ctx, tt.args.code)
|
||||
assert.ErrorIs(t, tt.res.err, err)
|
||||
|
||||
if err == nil {
|
||||
// equal on time won't work -> test separately and clear it before comparing the rest
|
||||
assert.WithinRange(t, got.AuthTime, testNow, testNow)
|
||||
got.AuthTime = time.Time{}
|
||||
}
|
||||
assert.Equal(t, tt.res.authRequest, got)
|
||||
})
|
||||
}
|
||||
}
|
@@ -15,10 +15,12 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/repository/action"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/idpintent"
|
||||
instance_repo "github.com/zitadel/zitadel/internal/repository/instance"
|
||||
"github.com/zitadel/zitadel/internal/repository/keypair"
|
||||
"github.com/zitadel/zitadel/internal/repository/milestone"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
|
||||
"github.com/zitadel/zitadel/internal/repository/quota"
|
||||
@@ -43,18 +45,21 @@ type Commands struct {
|
||||
externalSecure bool
|
||||
externalPort uint16
|
||||
|
||||
idpConfigEncryption crypto.EncryptionAlgorithm
|
||||
smtpEncryption crypto.EncryptionAlgorithm
|
||||
smsEncryption crypto.EncryptionAlgorithm
|
||||
userEncryption crypto.EncryptionAlgorithm
|
||||
userPasswordAlg crypto.HashAlgorithm
|
||||
machineKeySize int
|
||||
applicationKeySize int
|
||||
domainVerificationAlg crypto.EncryptionAlgorithm
|
||||
domainVerificationGenerator crypto.Generator
|
||||
domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error
|
||||
sessionTokenCreator func(sessionID string) (id string, token string, err error)
|
||||
sessionTokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
|
||||
idpConfigEncryption crypto.EncryptionAlgorithm
|
||||
smtpEncryption crypto.EncryptionAlgorithm
|
||||
smsEncryption crypto.EncryptionAlgorithm
|
||||
userEncryption crypto.EncryptionAlgorithm
|
||||
userPasswordAlg crypto.HashAlgorithm
|
||||
machineKeySize int
|
||||
applicationKeySize int
|
||||
domainVerificationAlg crypto.EncryptionAlgorithm
|
||||
domainVerificationGenerator crypto.Generator
|
||||
domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error
|
||||
sessionTokenCreator func(sessionID string) (id string, token string, err error)
|
||||
sessionTokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
|
||||
multifactors domain.MultifactorConfigs
|
||||
webauthnConfig *webauthn_helper.Config
|
||||
@@ -80,6 +85,9 @@ func StartCommands(
|
||||
httpClient *http.Client,
|
||||
permissionCheck domain.PermissionCheck,
|
||||
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error),
|
||||
defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime time.Duration,
|
||||
) (repo *Commands, err error) {
|
||||
if externalDomain == "" {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "COMMAND-Df21s", "no external domain specified")
|
||||
@@ -88,31 +96,34 @@ func StartCommands(
|
||||
// reuse the oidcEncryption to be able to handle both tokens in the interceptor later on
|
||||
sessionAlg := oidcEncryption
|
||||
repo = &Commands{
|
||||
eventstore: es,
|
||||
static: staticStore,
|
||||
idGenerator: idGenerator,
|
||||
zitadelRoles: zitadelRoles,
|
||||
externalDomain: externalDomain,
|
||||
externalSecure: externalSecure,
|
||||
externalPort: externalPort,
|
||||
keySize: defaults.KeyConfig.Size,
|
||||
certKeySize: defaults.KeyConfig.CertificateSize,
|
||||
privateKeyLifetime: defaults.KeyConfig.PrivateKeyLifetime,
|
||||
publicKeyLifetime: defaults.KeyConfig.PublicKeyLifetime,
|
||||
certificateLifetime: defaults.KeyConfig.CertificateLifetime,
|
||||
idpConfigEncryption: idpConfigEncryption,
|
||||
smtpEncryption: smtpEncryption,
|
||||
smsEncryption: smsEncryption,
|
||||
userEncryption: userEncryption,
|
||||
domainVerificationAlg: domainVerificationEncryption,
|
||||
keyAlgorithm: oidcEncryption,
|
||||
certificateAlgorithm: samlEncryption,
|
||||
webauthnConfig: webAuthN,
|
||||
httpClient: httpClient,
|
||||
checkPermission: permissionCheck,
|
||||
newCode: newCryptoCode,
|
||||
sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg),
|
||||
sessionTokenVerifier: sessionTokenVerifier,
|
||||
eventstore: es,
|
||||
static: staticStore,
|
||||
idGenerator: idGenerator,
|
||||
zitadelRoles: zitadelRoles,
|
||||
externalDomain: externalDomain,
|
||||
externalSecure: externalSecure,
|
||||
externalPort: externalPort,
|
||||
keySize: defaults.KeyConfig.Size,
|
||||
certKeySize: defaults.KeyConfig.CertificateSize,
|
||||
privateKeyLifetime: defaults.KeyConfig.PrivateKeyLifetime,
|
||||
publicKeyLifetime: defaults.KeyConfig.PublicKeyLifetime,
|
||||
certificateLifetime: defaults.KeyConfig.CertificateLifetime,
|
||||
idpConfigEncryption: idpConfigEncryption,
|
||||
smtpEncryption: smtpEncryption,
|
||||
smsEncryption: smsEncryption,
|
||||
userEncryption: userEncryption,
|
||||
domainVerificationAlg: domainVerificationEncryption,
|
||||
keyAlgorithm: oidcEncryption,
|
||||
certificateAlgorithm: samlEncryption,
|
||||
webauthnConfig: webAuthN,
|
||||
httpClient: httpClient,
|
||||
checkPermission: permissionCheck,
|
||||
newCode: newCryptoCode,
|
||||
sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg),
|
||||
sessionTokenVerifier: sessionTokenVerifier,
|
||||
defaultAccessTokenLifetime: defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: defaultRefreshTokenIdleLifetime,
|
||||
}
|
||||
|
||||
instance_repo.RegisterEventMappers(repo.eventstore)
|
||||
@@ -125,6 +136,8 @@ func StartCommands(
|
||||
quota.RegisterEventMappers(repo.eventstore)
|
||||
session.RegisterEventMappers(repo.eventstore)
|
||||
idpintent.RegisterEventMappers(repo.eventstore)
|
||||
authrequest.RegisterEventMappers(repo.eventstore)
|
||||
oidcsession.RegisterEventMappers(repo.eventstore)
|
||||
milestone.RegisterEventMappers(repo.eventstore)
|
||||
|
||||
repo.userPasswordAlg = crypto.NewBCrypt(defaults.SecretGenerators.PasswordSaltCost)
|
||||
|
@@ -16,9 +16,11 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository/mock"
|
||||
action_repo "github.com/zitadel/zitadel/internal/repository/action"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/idpintent"
|
||||
iam_repo "github.com/zitadel/zitadel/internal/repository/instance"
|
||||
key_repo "github.com/zitadel/zitadel/internal/repository/keypair"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
@@ -43,6 +45,8 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore {
|
||||
action_repo.RegisterEventMappers(es)
|
||||
session.RegisterEventMappers(es)
|
||||
idpintent.RegisterEventMappers(es)
|
||||
authrequest.RegisterEventMappers(es)
|
||||
oidcsession.RegisterEventMappers(es)
|
||||
return es
|
||||
}
|
||||
|
||||
|
281
internal/command/oidc_session.go
Normal file
281
internal/command/oidc_session.go
Normal file
@@ -0,0 +1,281 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
)
|
||||
|
||||
// AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration.
|
||||
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
|
||||
func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) {
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
cmd.AddSession(ctx)
|
||||
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
cmd.SetAuthRequestSuccessful(ctx)
|
||||
accessTokenID, _, accessTokenExpiration, err := cmd.PushEvents(ctx)
|
||||
return accessTokenID, accessTokenExpiration, err
|
||||
}
|
||||
|
||||
// AddOIDCSessionRefreshAndAccessToken creates a new OIDC Session, creates an access token and refresh token.
|
||||
// It returns the access token id, expiration and the refresh token.
|
||||
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
|
||||
func (c *Commands) AddOIDCSessionRefreshAndAccessToken(ctx context.Context, authRequestID string) (tokenID, refreshToken string, tokenExpiration time.Time, err error) {
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
cmd.AddSession(ctx)
|
||||
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
if err = cmd.AddRefreshToken(ctx); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
cmd.SetAuthRequestSuccessful(ctx)
|
||||
return cmd.PushEvents(ctx)
|
||||
}
|
||||
|
||||
// ExchangeOIDCSessionRefreshAndAccessToken updates an existing OIDC Session, creates a new access and refresh token.
|
||||
// It returns the access token id and expiration and the new refresh token.
|
||||
func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, oidcSessionID, refreshToken string, scope []string) (tokenID, newRefreshToken string, tokenExpiration time.Time, err error) {
|
||||
cmd, err := c.newOIDCSessionUpdateEvents(ctx, oidcSessionID, refreshToken)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
if err = cmd.AddAccessToken(ctx, scope); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
if err = cmd.RenewRefreshToken(ctx); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
return cmd.PushEvents(ctx)
|
||||
}
|
||||
|
||||
// OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant).
|
||||
// If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned.
|
||||
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) {
|
||||
split := strings.Split(refreshToken, ":")
|
||||
if len(split) != 2 {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
writeModel := NewOIDCSessionWriteModel(split[0], "")
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(err, "OIDCS-SAF31", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
if err = writeModel.CheckRefreshToken(split[1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return writeModel, nil
|
||||
}
|
||||
|
||||
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) {
|
||||
authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = authRequestWriteModel.CheckAuthenticated(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionWriteModel := NewSessionWriteModel(authRequestWriteModel.SessionID, authz.GetCtxData(ctx).OrgID)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sessionWriteModel.State != domain.SessionStateActive {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated")
|
||||
}
|
||||
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID = IDPrefixV2 + sessionID
|
||||
return &OIDCSessionEvents{
|
||||
eventstore: c.eventstore,
|
||||
idGenerator: c.idGenerator,
|
||||
encryptionAlg: c.keyAlgorithm,
|
||||
oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID()),
|
||||
sessionWriteModel: sessionWriteModel,
|
||||
authRequestWriteModel: authRequestWriteModel,
|
||||
accessTokenLifetime: accessTokenLifetime,
|
||||
refreshTokenLifeTime: refreshTokenLifeTime,
|
||||
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID string, err error) {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(refreshToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
split := strings.Split(decrypted, ":")
|
||||
if len(split) != 2 {
|
||||
return "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
return split[1], nil
|
||||
}
|
||||
|
||||
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) {
|
||||
refreshTokenID, err := c.decryptRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionWriteModel := NewOIDCSessionWriteModel(oidcSessionID, authz.GetInstance(ctx).InstanceID())
|
||||
if err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = sessionWriteModel.CheckRefreshToken(refreshTokenID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OIDCSessionEvents{
|
||||
eventstore: c.eventstore,
|
||||
idGenerator: c.idGenerator,
|
||||
encryptionAlg: c.keyAlgorithm,
|
||||
oidcSessionWriteModel: sessionWriteModel,
|
||||
accessTokenLifetime: accessTokenLifetime,
|
||||
refreshTokenLifeTime: refreshTokenLifeTime,
|
||||
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type OIDCSessionEvents struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
encryptionAlg crypto.EncryptionAlgorithm
|
||||
events []eventstore.Command
|
||||
oidcSessionWriteModel *OIDCSessionWriteModel
|
||||
sessionWriteModel *SessionWriteModel
|
||||
authRequestWriteModel *AuthRequestWriteModel
|
||||
accessTokenLifetime time.Duration
|
||||
refreshTokenLifeTime time.Duration
|
||||
refreshTokenIdleLifetime time.Duration
|
||||
|
||||
// accessTokenID is set by the command
|
||||
accessTokenID string
|
||||
|
||||
// refreshToken is set by the command
|
||||
refreshToken string
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddSession(ctx context.Context) {
|
||||
c.events = append(c.events, oidcsession.NewAddedEvent(
|
||||
ctx,
|
||||
c.oidcSessionWriteModel.aggregate,
|
||||
c.sessionWriteModel.UserID,
|
||||
c.sessionWriteModel.AggregateID,
|
||||
c.authRequestWriteModel.ClientID,
|
||||
c.authRequestWriteModel.Audience,
|
||||
c.authRequestWriteModel.Scope,
|
||||
amr.List(c.sessionWriteModel),
|
||||
c.sessionWriteModel.AuthenticationTime(),
|
||||
))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) {
|
||||
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) (err error) {
|
||||
c.accessTokenID, err = c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
|
||||
var refreshTokenID string
|
||||
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
|
||||
var refreshTokenID string
|
||||
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.events = append(c.events, oidcsession.NewRefreshTokenRenewedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenIdleLifetime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) generateRefreshToken() (refreshTokenID, refreshToken string, err error) {
|
||||
refreshTokenID, err = c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
token, err := c.encryptionAlg.Encrypt([]byte(c.oidcSessionWriteModel.AggregateID + ":" + refreshTokenID))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return refreshTokenID, base64.RawURLEncoding.EncodeToString(token), nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID string, refreshToken string, accessTokenExpiration time.Time, err error) {
|
||||
pushedEvents, err := c.eventstore.Push(ctx, c.events...)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
err = AppendAndReduce(c.oidcSessionWriteModel, pushedEvents...)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
|
||||
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
|
||||
return c.oidcSessionWriteModel.AggregateID + "-" + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil
|
||||
}
|
||||
|
||||
func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) {
|
||||
oidcSettings := NewInstanceOIDCSettingsWriteModel(ctx)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, oidcSettings)
|
||||
if err != nil {
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
accessTokenLifetime = c.defaultAccessTokenLifetime
|
||||
refreshTokenLifetime = c.defaultRefreshTokenLifetime
|
||||
refreshTokenIdleLifetime = c.defaultRefreshTokenIdleLifetime
|
||||
if oidcSettings.AccessTokenLifetime > 0 {
|
||||
accessTokenLifetime = oidcSettings.AccessTokenLifetime
|
||||
}
|
||||
if oidcSettings.RefreshTokenExpiration > 0 {
|
||||
refreshTokenLifetime = oidcSettings.RefreshTokenExpiration
|
||||
}
|
||||
if oidcSettings.RefreshTokenIdleExpiration > 0 {
|
||||
refreshTokenIdleLifetime = oidcSettings.RefreshTokenIdleExpiration
|
||||
}
|
||||
return accessTokenLifetime, refreshTokenLifetime, refreshTokenIdleLifetime, nil
|
||||
}
|
114
internal/command/oidc_session_model.go
Normal file
114
internal/command/oidc_session_model.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
)
|
||||
|
||||
type OIDCSessionWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
|
||||
UserID string
|
||||
SessionID string
|
||||
ClientID string
|
||||
Audience []string
|
||||
Scope []string
|
||||
AuthMethodsReferences []string
|
||||
AuthTime time.Time
|
||||
State domain.OIDCSessionState
|
||||
AccessTokenExpiration time.Time
|
||||
RefreshTokenID string
|
||||
RefreshTokenExpiration time.Time
|
||||
RefreshTokenIdleExpiration time.Time
|
||||
|
||||
aggregate *eventstore.Aggregate
|
||||
}
|
||||
|
||||
func NewOIDCSessionWriteModel(id string, resourceOwner string) *OIDCSessionWriteModel {
|
||||
return &OIDCSessionWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: id,
|
||||
ResourceOwner: resourceOwner,
|
||||
},
|
||||
aggregate: &oidcsession.NewAggregate(id, resourceOwner).Aggregate,
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) Reduce() error {
|
||||
for _, event := range wm.Events {
|
||||
switch e := event.(type) {
|
||||
case *oidcsession.AddedEvent:
|
||||
wm.reduceAdded(e)
|
||||
case *oidcsession.AccessTokenAddedEvent:
|
||||
wm.reduceAccessTokenAdded(e)
|
||||
case *oidcsession.RefreshTokenAddedEvent:
|
||||
wm.reduceRefreshTokenAdded(e)
|
||||
case *oidcsession.RefreshTokenRenewedEvent:
|
||||
wm.reduceRefreshTokenRenewed(e)
|
||||
}
|
||||
}
|
||||
return wm.WriteModel.Reduce()
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AddQuery().
|
||||
AggregateTypes(oidcsession.AggregateType).
|
||||
AggregateIDs(wm.AggregateID).
|
||||
EventTypes(
|
||||
oidcsession.AddedType,
|
||||
oidcsession.AccessTokenAddedType,
|
||||
oidcsession.RefreshTokenAddedType,
|
||||
oidcsession.RefreshTokenRenewedType,
|
||||
).
|
||||
Builder()
|
||||
|
||||
if wm.ResourceOwner != "" {
|
||||
query.ResourceOwner(wm.ResourceOwner)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
|
||||
wm.UserID = e.UserID
|
||||
wm.SessionID = e.SessionID
|
||||
wm.ClientID = e.ClientID
|
||||
wm.Audience = e.Audience
|
||||
wm.Scope = e.Scope
|
||||
wm.AuthMethodsReferences = e.AuthMethodsReferences
|
||||
wm.AuthTime = e.AuthTime
|
||||
wm.State = domain.OIDCSessionStateActive
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
|
||||
wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime)
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) reduceRefreshTokenAdded(e *oidcsession.RefreshTokenAddedEvent) {
|
||||
wm.RefreshTokenID = e.ID
|
||||
wm.RefreshTokenExpiration = e.CreationDate().Add(e.Lifetime)
|
||||
wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime)
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) reduceRefreshTokenRenewed(e *oidcsession.RefreshTokenRenewedEvent) {
|
||||
wm.RefreshTokenID = e.ID
|
||||
wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime)
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error {
|
||||
if wm.State != domain.OIDCSessionStateActive {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
if wm.RefreshTokenID != refreshTokenID {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
now := time.Now()
|
||||
if wm.RefreshTokenExpiration.Before(now) || wm.RefreshTokenIdleExpiration.Before(now) {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
return nil
|
||||
}
|
795
internal/command/oidc_session_test.go
Normal file
795
internal/command/oidc_session_test.go
Normal file
@@ -0,0 +1,795 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
)
|
||||
|
||||
var (
|
||||
testNow = time.Now()
|
||||
tokenCreationNow = time.Time{}
|
||||
)
|
||||
|
||||
func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
authRequestID string
|
||||
}
|
||||
type res struct {
|
||||
id string
|
||||
expiration time.Time
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"unauthenticated error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"inactive session error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"add successful",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
"userID", testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
),
|
||||
expectFilter(), // token lifetime
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"),
|
||||
defaultAccessTokenLifetime: time.Hour,
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
id: "V2_oidcSessionID-accessTokenID",
|
||||
expiration: tokenCreationNow.Add(time.Hour),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
gotID, gotExpiration, err := c.AddOIDCSessionAccessToken(tt.args.ctx, tt.args.authRequestID)
|
||||
assert.Equal(t, tt.res.id, gotID)
|
||||
assert.Equal(t, tt.res.expiration, gotExpiration)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
authRequestID string
|
||||
}
|
||||
type res struct {
|
||||
id string
|
||||
refreshToken string
|
||||
expiration time.Time
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"unauthenticated error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"inactive session error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"add successful",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
"userID", testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
),
|
||||
expectFilter(), // token lifetime
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID", "refreshTokenID"),
|
||||
defaultAccessTokenLifetime: time.Hour,
|
||||
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
|
||||
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
id: "V2_oidcSessionID-accessTokenID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", //V2_oidcSessionID:refreshTokenID
|
||||
expiration: tokenCreationNow.Add(time.Hour),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
gotID, gotRefreshToken, gotExpiration, err := c.AddOIDCSessionRefreshAndAccessToken(tt.args.ctx, tt.args.authRequestID)
|
||||
assert.Equal(t, tt.res.id, gotID)
|
||||
assert.Equal(t, tt.res.refreshToken, gotRefreshToken)
|
||||
assert.Equal(t, tt.res.expiration, gotExpiration)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
oidcSessionID string
|
||||
refreshToken string
|
||||
scope []string
|
||||
}
|
||||
type res struct {
|
||||
id string
|
||||
refreshToken string
|
||||
expiration time.Time
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"invalid refresh token format error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcSessionID: "V2_oidcSessionID",
|
||||
refreshToken: "aW52YWxpZA",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"inactive session error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcSessionID: "V2_oidcSessionID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid refresh token error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcSessionID: "V2_oidcSessionID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"expired refresh token error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcSessionID: "V2_oidcSessionID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"refresh successful",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
|
||||
),
|
||||
),
|
||||
expectFilter(), // token lifetime
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewRefreshTokenRenewedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID2", 24*time.Hour),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "accessTokenID", "refreshTokenID2"),
|
||||
defaultAccessTokenLifetime: time.Hour,
|
||||
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
|
||||
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcSessionID: "V2_oidcSessionID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
|
||||
scope: []string{"openid", "offline_access"},
|
||||
},
|
||||
res{
|
||||
id: "V2_oidcSessionID-accessTokenID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRDI",
|
||||
expiration: time.Time{}.Add(time.Hour),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
gotID, gotRefreshToken, gotExpiration, err := c.ExchangeOIDCSessionRefreshAndAccessToken(tt.args.ctx, tt.args.oidcSessionID, tt.args.refreshToken, tt.args.scope)
|
||||
assert.Equal(t, tt.res.id, gotID)
|
||||
assert.Equal(t, tt.res.refreshToken, gotRefreshToken)
|
||||
assert.Equal(t, tt.res.expiration, gotExpiration)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
refreshToken string
|
||||
}
|
||||
type res struct {
|
||||
model *OIDCSessionWriteModel
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"invalid refresh token format error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
refreshToken: "invalid",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"inactive session error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
refreshToken: "V2_oidcSessionID:refreshTokenID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid refresh token error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
refreshToken: "V2_oidcSessionID:refreshTokenID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"expired refresh token error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
refreshToken: "V2_oidcSessionID:refreshTokenID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"get successful",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
refreshToken: "V2_oidcSessionID:refreshTokenID",
|
||||
},
|
||||
res{
|
||||
model: &OIDCSessionWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: "V2_oidcSessionID",
|
||||
ChangeDate: testNow,
|
||||
},
|
||||
UserID: "userID",
|
||||
SessionID: "sessionID",
|
||||
ClientID: "clientID",
|
||||
Audience: []string{"audience"},
|
||||
Scope: []string{"openid", "profile", "offline_access"},
|
||||
AuthMethodsReferences: []string{amr.PWD},
|
||||
AuthTime: testNow,
|
||||
State: domain.OIDCSessionStateActive,
|
||||
RefreshTokenID: "refreshTokenID",
|
||||
RefreshTokenExpiration: testNow.Add(7 * 24 * time.Hour),
|
||||
RefreshTokenIdleExpiration: testNow.Add(24 * time.Hour),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
got, err := c.OIDCSessionByRefreshToken(tt.args.ctx, tt.args.refreshToken)
|
||||
require.ErrorIs(t, err, tt.res.err)
|
||||
if tt.res.err == nil {
|
||||
assert.WithinRange(t, got.ChangeDate, tt.res.model.ChangeDate.Add(-2*time.Second), tt.res.model.ChangeDate.Add(2*time.Second))
|
||||
assert.Equal(t, tt.res.model.AggregateID, got.AggregateID)
|
||||
assert.Equal(t, tt.res.model.UserID, got.UserID)
|
||||
assert.Equal(t, tt.res.model.SessionID, got.SessionID)
|
||||
assert.Equal(t, tt.res.model.ClientID, got.ClientID)
|
||||
assert.Equal(t, tt.res.model.Audience, got.Audience)
|
||||
assert.Equal(t, tt.res.model.Scope, got.Scope)
|
||||
assert.Equal(t, tt.res.model.AuthMethodsReferences, got.AuthMethodsReferences)
|
||||
assert.WithinRange(t, got.AuthTime, tt.res.model.AuthTime.Add(-2*time.Second), tt.res.model.AuthTime.Add(2*time.Second))
|
||||
assert.Equal(t, tt.res.model.State, got.State)
|
||||
assert.Equal(t, tt.res.model.RefreshTokenID, got.RefreshTokenID)
|
||||
assert.WithinRange(t, got.RefreshTokenExpiration, tt.res.model.RefreshTokenExpiration.Add(-2*time.Second), tt.res.model.RefreshTokenExpiration.Add(2*time.Second))
|
||||
assert.WithinRange(t, got.RefreshTokenIdleExpiration, tt.res.model.RefreshTokenIdleExpiration.Add(-2*time.Second), tt.res.model.RefreshTokenIdleExpiration.Add(2*time.Second))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -52,6 +52,24 @@ type SessionWriteModel struct {
|
||||
aggregate *eventstore.Aggregate
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) IsPasswordChecked() bool {
|
||||
return !wm.PasswordCheckedAt.IsZero()
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) IsPasskeyChecked() bool {
|
||||
return !wm.PasskeyCheckedAt.IsZero()
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) IsU2FChecked() bool {
|
||||
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
|
||||
return false
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) IsOTPChecked() bool {
|
||||
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
|
||||
return false
|
||||
}
|
||||
|
||||
func NewSessionWriteModel(sessionID string, resourceOwner string) *SessionWriteModel {
|
||||
return &SessionWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
@@ -210,3 +228,19 @@ func (wm *SessionWriteModel) ChangeMetadata(ctx context.Context, metadata map[st
|
||||
wm.commands = append(wm.commands, session.NewMetadataSetEvent(ctx, wm.aggregate, wm.Metadata))
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticationTime returns the time the user authenticated using the latest time of all checks
|
||||
func (wm *SessionWriteModel) AuthenticationTime() time.Time {
|
||||
var authTime time.Time
|
||||
for _, check := range []time.Time{
|
||||
wm.PasswordCheckedAt,
|
||||
wm.PasskeyCheckedAt,
|
||||
wm.IntentCheckedAt,
|
||||
// TODO: add U2F and OTP check https://github.com/zitadel/zitadel/issues/5477
|
||||
} {
|
||||
if check.After(authTime) {
|
||||
authTime = check
|
||||
}
|
||||
}
|
||||
return authTime
|
||||
}
|
||||
|
@@ -528,7 +528,7 @@ func TestCommands_createHumanOTP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCommands_HumanCheckMFAOTPSetup(t *testing.T) {
|
||||
ctx := authz.NewMockContext("inst1", "org1", "user1")
|
||||
ctx := authz.NewMockContext("", "org1", "user1")
|
||||
|
||||
cryptoAlg := crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
key, secret, err := domain.NewOTPKey("example.com", "user1", cryptoAlg)
|
||||
|
@@ -188,7 +188,7 @@ func TestCommands_AddUserTOTP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCommands_CheckUserTOTP(t *testing.T) {
|
||||
ctx := authz.NewMockContext("inst1", "org1", "user1")
|
||||
ctx := authz.NewMockContext("", "org1", "user1")
|
||||
|
||||
cryptoAlg := crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
key, secret, err := domain.NewOTPKey("example.com", "user1", cryptoAlg)
|
||||
|
Reference in New Issue
Block a user