update base

This commit is contained in:
Elio Bischof
2023-06-29 12:45:45 +02:00
117 changed files with 7254 additions and 3466 deletions

View File

@@ -10,10 +10,11 @@ import (
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
)
const (
SessionsProjectionTable = "projections.sessions1"
SessionsProjectionTable = "projections.sessions3"
SessionColumnID = "id"
SessionColumnCreationDate = "creation_date"
@@ -21,11 +22,13 @@ const (
SessionColumnSequence = "sequence"
SessionColumnState = "state"
SessionColumnResourceOwner = "resource_owner"
SessionColumnDomain = "domain"
SessionColumnInstanceID = "instance_id"
SessionColumnCreator = "creator"
SessionColumnUserID = "user_id"
SessionColumnUserCheckedAt = "user_checked_at"
SessionColumnPasswordCheckedAt = "password_checked_at"
SessionColumnIntentCheckedAt = "intent_checked_at"
SessionColumnPasskeyCheckedAt = "passkey_checked_at"
SessionColumnMetadata = "metadata"
SessionColumnTokenID = "token_id"
@@ -47,11 +50,13 @@ func newSessionProjection(ctx context.Context, config crdb.StatementHandlerConfi
crdb.NewColumn(SessionColumnSequence, crdb.ColumnTypeInt64),
crdb.NewColumn(SessionColumnState, crdb.ColumnTypeEnum),
crdb.NewColumn(SessionColumnResourceOwner, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnDomain, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnInstanceID, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnCreator, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnUserID, crdb.ColumnTypeText, crdb.Nullable()),
crdb.NewColumn(SessionColumnUserCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()),
crdb.NewColumn(SessionColumnPasswordCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()),
crdb.NewColumn(SessionColumnIntentCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()),
crdb.NewColumn(SessionColumnPasskeyCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()),
crdb.NewColumn(SessionColumnMetadata, crdb.ColumnTypeJSONB, crdb.Nullable()),
crdb.NewColumn(SessionColumnTokenID, crdb.ColumnTypeText, crdb.Nullable()),
@@ -80,6 +85,10 @@ func (p *sessionProjection) reducers() []handler.AggregateReducer {
Event: session.PasswordCheckedType,
Reduce: p.reducePasswordChecked,
},
{
Event: session.IntentCheckedType,
Reduce: p.reduceIntentChecked,
},
{
Event: session.PasskeyCheckedType,
Reduce: p.reducePasskeyChecked,
@@ -107,6 +116,15 @@ func (p *sessionProjection) reducers() []handler.AggregateReducer {
},
},
},
{
Aggregate: user.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: user.HumanPasswordChangedType,
Reduce: p.reducePasswordChanged,
},
},
},
}
}
@@ -124,6 +142,7 @@ func (p *sessionProjection) reduceSessionAdded(event eventstore.Event) (*handler
handler.NewCol(SessionColumnCreationDate, e.CreationDate()),
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnResourceOwner, e.Aggregate().ResourceOwner),
handler.NewCol(SessionColumnDomain, e.Domain),
handler.NewCol(SessionColumnState, domain.SessionStateActive),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnCreator, e.User),
@@ -171,6 +190,26 @@ func (p *sessionProjection) reducePasswordChecked(event eventstore.Event) (*hand
), nil
}
func (p *sessionProjection) reduceIntentChecked(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.IntentCheckedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-SDgr2", "reduce.wrong.event.type %s", session.IntentCheckedType)
}
return crdb.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnIntentCheckedAt, e.CheckedAt),
},
[]handler.Condition{
handler.NewCond(SessionColumnID, e.Aggregate().ID),
handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}
func (p *sessionProjection) reducePasskeyChecked(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.PasskeyCheckedEvent)
if !ok {
@@ -245,3 +284,21 @@ func (p *sessionProjection) reduceSessionTerminated(event eventstore.Event) (*ha
},
), nil
}
func (p *sessionProjection) reducePasswordChanged(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*user.HumanPasswordChangedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-Deg3d", "reduce.wrong.event.type %s", user.HumanPasswordChangedType)
}
return crdb.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnPasswordCheckedAt, nil),
},
[]handler.Condition{
handler.NewCond(SessionColumnUserID, e.Aggregate().ID),
crdb.NewLessThanCond(SessionColumnPasswordCheckedAt, e.CreationDate()),
},
), nil
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
)
func TestSessionProjection_reduces(t *testing.T) {
@@ -29,7 +30,9 @@ func TestSessionProjection_reduces(t *testing.T) {
event: getEvent(testEvent(
session.AddedType,
session.AggregateType,
[]byte(`{}`),
[]byte(`{
"domain": "domain"
}`),
), session.AddedEventMapper),
},
reduce: (&sessionProjection{}).reduceSessionAdded,
@@ -40,13 +43,14 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "INSERT INTO projections.sessions1 (id, instance_id, creation_date, change_date, resource_owner, state, sequence, creator) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
expectedStmt: "INSERT INTO projections.sessions3 (id, instance_id, creation_date, change_date, resource_owner, domain, state, sequence, creator) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
anyArg{},
anyArg{},
"ro-id",
"domain",
domain.SessionStateActive,
uint64(15),
"editor-user",
@@ -76,7 +80,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions1 SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)",
expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
@@ -109,7 +113,39 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions1 SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
time.Date(2023, time.May, 4, 0, 0, 0, 0, time.UTC),
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "instance reduceIntentChecked",
args: args{
event: getEvent(testEvent(
session.AddedType,
session.AggregateType,
[]byte(`{
"checkedAt": "2023-05-04T00:00:00Z"
}`),
), session.IntentCheckedEventMapper),
},
reduce: (&sessionProjection{}).reduceIntentChecked,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, intent_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
@@ -141,7 +177,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions1 SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
@@ -175,7 +211,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions1 SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
@@ -207,7 +243,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.sessions1 WHERE (id = $1) AND (instance_id = $2)",
expectedStmt: "DELETE FROM projections.sessions3 WHERE (id = $1) AND (instance_id = $2)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
@@ -234,7 +270,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.sessions1 WHERE (instance_id = $1)",
expectedStmt: "DELETE FROM projections.sessions3 WHERE (instance_id = $1)",
expectedArgs: []interface{}{
"agg-id",
},
@@ -243,6 +279,39 @@ func TestSessionProjection_reduces(t *testing.T) {
},
},
},
{
name: "reducePasswordChanged",
args: args{
event: getEvent(testEvent(
repository.EventType(user.HumanPasswordChangedType),
user.AggregateType,
[]byte(`{"secret": {
"cryptoType": 0,
"algorithm": "enc",
"keyID": "id",
"crypted": "cGFzc3dvcmQ="
}}`),
), user.HumanPasswordChangedEventMapper),
},
reduce: (&sessionProjection{}).reducePasswordChanged,
want: wantReduce{
aggregateType: eventstore.AggregateType("user"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions3 SET password_checked_at = $1 WHERE (user_id = $2) AND (password_checked_at < $3)",
expectedArgs: []interface{}{
nil,
"agg-id",
anyArg{},
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -34,6 +34,7 @@ type Queries struct {
idpConfigEncryption crypto.EncryptionAlgorithm
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error)
checkPermission domain.PermissionCheck
DefaultLanguage language.Tag
LoginDir http.FileSystem
@@ -55,6 +56,7 @@ func StartQueries(
idpConfigEncryption, otpEncryption, keyEncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm,
zitadelRoles []authz.RoleMapping,
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error),
permissionCheck func(q *Queries) domain.PermissionCheck,
) (repo *Queries, err error) {
statikLoginFS, err := fs.NewWithNamespace("login")
if err != nil {
@@ -95,6 +97,8 @@ func StartQueries(
},
}
repo.checkPermission = permissionCheck(repo)
err = projection.Create(ctx, sqlClient, es, projections, keyEncryptionAlgorithm, certEncryptionAlgorithm)
if err != nil {
return nil, err

View File

@@ -29,9 +29,11 @@ type Session struct {
Sequence uint64
State domain.SessionState
ResourceOwner string
Domain string
Creator string
UserFactor SessionUserFactor
PasswordFactor SessionPasswordFactor
IntentFactor SessionIntentFactor
PasskeyFactor SessionPasskeyFactor
Metadata map[string][]byte
}
@@ -47,6 +49,10 @@ type SessionPasswordFactor struct {
PasswordCheckedAt time.Time
}
type SessionIntentFactor struct {
IntentCheckedAt time.Time
}
type SessionPasskeyFactor struct {
PasskeyCheckedAt time.Time
}
@@ -93,6 +99,10 @@ var (
name: projection.SessionColumnResourceOwner,
table: sessionsTable,
}
SessionColumnDomain = Column{
name: projection.SessionColumnDomain,
table: sessionsTable,
}
SessionColumnInstanceID = Column{
name: projection.SessionColumnInstanceID,
table: sessionsTable,
@@ -113,6 +123,10 @@ var (
name: projection.SessionColumnPasswordCheckedAt,
table: sessionsTable,
}
SessionColumnIntentCheckedAt = Column{
name: projection.SessionColumnIntentCheckedAt,
table: sessionsTable,
}
SessionColumnPasskeyCheckedAt = Column{
name: projection.SessionColumnPasskeyCheckedAt,
table: sessionsTable,
@@ -202,11 +216,13 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
SessionColumnState.identifier(),
SessionColumnResourceOwner.identifier(),
SessionColumnCreator.identifier(),
SessionColumnDomain.identifier(),
SessionColumnUserID.identifier(),
SessionColumnUserCheckedAt.identifier(),
LoginNameNameCol.identifier(),
HumanDisplayNameCol.identifier(),
SessionColumnPasswordCheckedAt.identifier(),
SessionColumnIntentCheckedAt.identifier(),
SessionColumnPasskeyCheckedAt.identifier(),
SessionColumnMetadata.identifier(),
SessionColumnToken.identifier(),
@@ -222,9 +238,11 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
loginName sql.NullString
displayName sql.NullString
passwordCheckedAt sql.NullTime
intentCheckedAt sql.NullTime
passkeyCheckedAt sql.NullTime
metadata database.Map[[]byte]
token sql.NullString
sessionDomain sql.NullString
)
err := row.Scan(
@@ -235,11 +253,13 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
&session.State,
&session.ResourceOwner,
&session.Creator,
&sessionDomain,
&userID,
&userCheckedAt,
&loginName,
&displayName,
&passwordCheckedAt,
&intentCheckedAt,
&passkeyCheckedAt,
&metadata,
&token,
@@ -252,11 +272,13 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
return nil, "", errors.ThrowInternal(err, "QUERY-SAder", "Errors.Internal")
}
session.Domain = sessionDomain.String
session.UserFactor.UserID = userID.String
session.UserFactor.UserCheckedAt = userCheckedAt.Time
session.UserFactor.LoginName = loginName.String
session.UserFactor.DisplayName = displayName.String
session.PasswordFactor.PasswordCheckedAt = passwordCheckedAt.Time
session.IntentFactor.IntentCheckedAt = intentCheckedAt.Time
session.PasskeyFactor.PasskeyCheckedAt = passkeyCheckedAt.Time
session.Metadata = metadata
@@ -273,11 +295,13 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
SessionColumnState.identifier(),
SessionColumnResourceOwner.identifier(),
SessionColumnCreator.identifier(),
SessionColumnDomain.identifier(),
SessionColumnUserID.identifier(),
SessionColumnUserCheckedAt.identifier(),
LoginNameNameCol.identifier(),
HumanDisplayNameCol.identifier(),
SessionColumnPasswordCheckedAt.identifier(),
SessionColumnIntentCheckedAt.identifier(),
SessionColumnPasskeyCheckedAt.identifier(),
SessionColumnMetadata.identifier(),
countColumn.identifier(),
@@ -296,8 +320,10 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
loginName sql.NullString
displayName sql.NullString
passwordCheckedAt sql.NullTime
intentCheckedAt sql.NullTime
passkeyCheckedAt sql.NullTime
metadata database.Map[[]byte]
sessionDomain sql.NullString
)
err := rows.Scan(
@@ -308,11 +334,13 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
&session.State,
&session.ResourceOwner,
&session.Creator,
&sessionDomain,
&userID,
&userCheckedAt,
&loginName,
&displayName,
&passwordCheckedAt,
&intentCheckedAt,
&passkeyCheckedAt,
&metadata,
&sessions.Count,
@@ -321,11 +349,13 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-SAfeg", "Errors.Internal")
}
session.Domain = sessionDomain.String
session.UserFactor.UserID = userID.String
session.UserFactor.UserCheckedAt = userCheckedAt.Time
session.UserFactor.LoginName = loginName.String
session.UserFactor.DisplayName = displayName.String
session.PasswordFactor.PasswordCheckedAt = passwordCheckedAt.Time
session.IntentFactor.IntentCheckedAt = intentCheckedAt.Time
session.PasskeyFactor.PasskeyCheckedAt = passkeyCheckedAt.Time
session.Metadata = metadata

View File

@@ -17,43 +17,47 @@ import (
)
var (
expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions1.id,` +
` projections.sessions1.creation_date,` +
` projections.sessions1.change_date,` +
` projections.sessions1.sequence,` +
` projections.sessions1.state,` +
` projections.sessions1.resource_owner,` +
` projections.sessions1.creator,` +
` projections.sessions1.user_id,` +
` projections.sessions1.user_checked_at,` +
expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions3.id,` +
` projections.sessions3.creation_date,` +
` projections.sessions3.change_date,` +
` projections.sessions3.sequence,` +
` projections.sessions3.state,` +
` projections.sessions3.resource_owner,` +
` projections.sessions3.creator,` +
` projections.sessions3.domain,` +
` projections.sessions3.user_id,` +
` projections.sessions3.user_checked_at,` +
` projections.login_names2.login_name,` +
` projections.users8_humans.display_name,` +
` projections.sessions1.password_checked_at,` +
` projections.sessions1.passkey_checked_at,` +
` projections.sessions1.metadata,` +
` projections.sessions1.token_id` +
` FROM projections.sessions1` +
` LEFT JOIN projections.login_names2 ON projections.sessions1.user_id = projections.login_names2.user_id AND projections.sessions1.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions1.user_id = projections.users8_humans.user_id AND projections.sessions1.instance_id = projections.users8_humans.instance_id` +
` projections.sessions3.password_checked_at,` +
` projections.sessions3.intent_checked_at,` +
` projections.sessions3.passkey_checked_at,` +
` projections.sessions3.metadata,` +
` projections.sessions3.token_id` +
` FROM projections.sessions3` +
` LEFT JOIN projections.login_names2 ON projections.sessions3.user_id = projections.login_names2.user_id AND projections.sessions3.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions3.user_id = projections.users8_humans.user_id AND projections.sessions3.instance_id = projections.users8_humans.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions1.id,` +
` projections.sessions1.creation_date,` +
` projections.sessions1.change_date,` +
` projections.sessions1.sequence,` +
` projections.sessions1.state,` +
` projections.sessions1.resource_owner,` +
` projections.sessions1.creator,` +
` projections.sessions1.user_id,` +
` projections.sessions1.user_checked_at,` +
expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions3.id,` +
` projections.sessions3.creation_date,` +
` projections.sessions3.change_date,` +
` projections.sessions3.sequence,` +
` projections.sessions3.state,` +
` projections.sessions3.resource_owner,` +
` projections.sessions3.creator,` +
` projections.sessions3.domain,` +
` projections.sessions3.user_id,` +
` projections.sessions3.user_checked_at,` +
` projections.login_names2.login_name,` +
` projections.users8_humans.display_name,` +
` projections.sessions1.password_checked_at,` +
` projections.sessions1.passkey_checked_at,` +
` projections.sessions1.metadata,` +
` projections.sessions3.password_checked_at,` +
` projections.sessions3.intent_checked_at,` +
` projections.sessions3.passkey_checked_at,` +
` projections.sessions3.metadata,` +
` COUNT(*) OVER ()` +
` FROM projections.sessions1` +
` LEFT JOIN projections.login_names2 ON projections.sessions1.user_id = projections.login_names2.user_id AND projections.sessions1.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions1.user_id = projections.users8_humans.user_id AND projections.sessions1.instance_id = projections.users8_humans.instance_id` +
` FROM projections.sessions3` +
` LEFT JOIN projections.login_names2 ON projections.sessions3.user_id = projections.login_names2.user_id AND projections.sessions3.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions3.user_id = projections.users8_humans.user_id AND projections.sessions3.instance_id = projections.users8_humans.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
sessionCols = []string{
@@ -64,11 +68,13 @@ var (
"state",
"resource_owner",
"creator",
"domain",
"user_id",
"user_checked_at",
"login_name",
"display_name",
"password_checked_at",
"intent_checked_at",
"passkey_checked_at",
"metadata",
"token",
@@ -82,11 +88,13 @@ var (
"state",
"resource_owner",
"creator",
"domain",
"user_id",
"user_checked_at",
"login_name",
"display_name",
"password_checked_at",
"intent_checked_at",
"passkey_checked_at",
"metadata",
"count",
@@ -132,12 +140,14 @@ func Test_SessionsPrepare(t *testing.T) {
domain.SessionStateActive,
"ro",
"creator",
"domain",
"user-id",
testNow,
"login-name",
"display-name",
testNow,
testNow,
testNow,
[]byte(`{"key": "dmFsdWU="}`),
},
},
@@ -156,6 +166,7 @@ func Test_SessionsPrepare(t *testing.T) {
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator",
Domain: "domain",
UserFactor: SessionUserFactor{
UserID: "user-id",
UserCheckedAt: testNow,
@@ -165,6 +176,9 @@ func Test_SessionsPrepare(t *testing.T) {
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
IntentFactor: SessionIntentFactor{
IntentCheckedAt: testNow,
},
PasskeyFactor: SessionPasskeyFactor{
PasskeyCheckedAt: testNow,
},
@@ -191,12 +205,14 @@ func Test_SessionsPrepare(t *testing.T) {
domain.SessionStateActive,
"ro",
"creator",
"domain",
"user-id",
testNow,
"login-name",
"display-name",
testNow,
testNow,
testNow,
[]byte(`{"key": "dmFsdWU="}`),
},
{
@@ -207,12 +223,14 @@ func Test_SessionsPrepare(t *testing.T) {
domain.SessionStateActive,
"ro",
"creator2",
"domain",
"user-id2",
testNow,
"login-name2",
"display-name2",
testNow,
testNow,
testNow,
[]byte(`{"key": "dmFsdWU="}`),
},
},
@@ -231,6 +249,7 @@ func Test_SessionsPrepare(t *testing.T) {
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator",
Domain: "domain",
UserFactor: SessionUserFactor{
UserID: "user-id",
UserCheckedAt: testNow,
@@ -240,6 +259,9 @@ func Test_SessionsPrepare(t *testing.T) {
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
IntentFactor: SessionIntentFactor{
IntentCheckedAt: testNow,
},
PasskeyFactor: SessionPasskeyFactor{
PasskeyCheckedAt: testNow,
},
@@ -255,6 +277,7 @@ func Test_SessionsPrepare(t *testing.T) {
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator2",
Domain: "domain",
UserFactor: SessionUserFactor{
UserID: "user-id2",
UserCheckedAt: testNow,
@@ -264,6 +287,9 @@ func Test_SessionsPrepare(t *testing.T) {
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
IntentFactor: SessionIntentFactor{
IntentCheckedAt: testNow,
},
PasskeyFactor: SessionPasskeyFactor{
PasskeyCheckedAt: testNow,
},
@@ -343,12 +369,14 @@ func Test_SessionPrepare(t *testing.T) {
domain.SessionStateActive,
"ro",
"creator",
"domain",
"user-id",
testNow,
"login-name",
"display-name",
testNow,
testNow,
testNow,
[]byte(`{"key": "dmFsdWU="}`),
"tokenID",
},
@@ -362,6 +390,7 @@ func Test_SessionPrepare(t *testing.T) {
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator",
Domain: "domain",
UserFactor: SessionUserFactor{
UserID: "user-id",
UserCheckedAt: testNow,
@@ -371,6 +400,9 @@ func Test_SessionPrepare(t *testing.T) {
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
IntentFactor: SessionIntentFactor{
IntentCheckedAt: testNow,
},
PasskeyFactor: SessionPasskeyFactor{
PasskeyCheckedAt: testNow,
},

View File

@@ -6,6 +6,7 @@ import (
"time"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
@@ -64,12 +65,27 @@ var (
name: projection.UserAuthMethodOwnerRemovedCol,
table: userAuthMethodTable,
}
authMethodTypeTable = userAuthMethodTable.setAlias("auth_method_types")
authMethodTypeUserID = UserAuthMethodColumnUserID.setTable(authMethodTypeTable)
authMethodTypeInstanceID = UserAuthMethodColumnInstanceID.setTable(authMethodTypeTable)
authMethodTypeTypes = UserAuthMethodColumnMethodType.setTable(authMethodTypeTable)
authMethodTypeState = UserAuthMethodColumnState.setTable(authMethodTypeTable)
userIDPsCountTable = idpUserLinkTable.setAlias("user_idps_count")
userIDPsCountUserID = IDPUserLinkUserIDCol.setTable(userIDPsCountTable)
userIDPsCountInstanceID = IDPUserLinkInstanceIDCol.setTable(userIDPsCountTable)
userIDPsCountCount = Column{
name: "count",
table: userIDPsCountTable,
}
)
type AuthMethods struct {
SearchResponse
AuthMethods []*AuthMethod
}
type AuthMethod struct {
UserID string
CreationDate time.Time
@@ -83,6 +99,11 @@ type AuthMethod struct {
Type domain.UserAuthMethodType
}
type AuthMethodTypes struct {
SearchResponse
AuthMethodTypes []domain.UserAuthMethodType
}
type UserAuthMethodSearchQueries struct {
SearchRequest
Queries []SearchQuery
@@ -114,6 +135,41 @@ func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMe
return userAuthMethods, err
}
func (q *Queries) ListActiveUserAuthMethodTypes(ctx context.Context, userID string, withOwnerRemoved bool) (userAuthMethodTypes *AuthMethodTypes, err error) {
ctxData := authz.GetCtxData(ctx)
if ctxData.UserID != userID {
if err := q.checkPermission(ctx, domain.PermissionUserRead, ctxData.OrgID, userID); err != nil {
return nil, err
}
}
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareActiveUserAuthMethodTypesQuery(ctx, q.client)
eq := sq.Eq{
UserIDCol.identifier(): userID,
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
}
if !withOwnerRemoved {
eq[UserOwnerRemovedCol.identifier()] = false
}
stmt, args, err := query.Where(eq).ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-Sfdrg", "Errors.Query.InvalidRequest")
}
rows, err := q.client.QueryContext(ctx, stmt, args...)
if err != nil || rows.Err() != nil {
return nil, errors.ThrowInternal(err, "QUERY-SDgr3", "Errors.Internal")
}
userAuthMethodTypes, err = scan(rows)
if err != nil {
return nil, err
}
userAuthMethodTypes.LatestSequence, err = q.latestSequence(ctx, userTable, notifyTable, userAuthMethodTable, idpUserLinkTable)
return userAuthMethodTypes, err
}
func NewUserAuthMethodUserIDSearchQuery(value string) (SearchQuery, error) {
return NewTextQuery(UserAuthMethodColumnUserID, value, TextEquals)
}
@@ -253,3 +309,80 @@ func prepareUserAuthMethodsQuery(ctx context.Context, db prepareDatabase) (sq.Se
}, nil
}
}
func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) {
authMethodsQuery, authMethodsArgs, err := sq.Select(
"DISTINCT("+authMethodTypeTypes.identifier()+")",
authMethodTypeUserID.identifier(),
authMethodTypeInstanceID.identifier()).
From(authMethodTypeTable.identifier()).
Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady}).
ToSql()
if err != nil {
return sq.SelectBuilder{}, nil
}
idpsQuery, _, err := sq.Select(
userIDPsCountUserID.identifier(),
userIDPsCountInstanceID.identifier(),
"COUNT("+userIDPsCountUserID.identifier()+") AS "+userIDPsCountCount.name).
From(userIDPsCountTable.identifier()).
GroupBy(
userIDPsCountUserID.identifier(),
userIDPsCountInstanceID.identifier(),
).
ToSql()
if err != nil {
return sq.SelectBuilder{}, nil
}
return sq.Select(
NotifyPasswordSetCol.identifier(),
authMethodTypeTypes.identifier(),
userIDPsCountCount.identifier()).
From(userTable.identifier()).
LeftJoin(join(NotifyUserIDCol, UserIDCol)).
LeftJoin("("+authMethodsQuery+") AS "+authMethodTypeTable.alias+" ON "+
authMethodTypeUserID.identifier()+" = "+UserIDCol.identifier()+" AND "+
authMethodTypeInstanceID.identifier()+" = "+UserInstanceIDCol.identifier(),
authMethodsArgs...).
LeftJoin("(" + idpsQuery + ") AS " + userIDPsCountTable.alias + " ON " +
userIDPsCountUserID.identifier() + " = " + UserIDCol.identifier() + " AND " +
userIDPsCountInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*AuthMethodTypes, error) {
userAuthMethodTypes := make([]domain.UserAuthMethodType, 0)
var passwordSet sql.NullBool
var idp sql.NullInt64
for rows.Next() {
var authMethodType sql.NullInt16
err := rows.Scan(
&passwordSet,
&authMethodType,
&idp,
)
if err != nil {
return nil, err
}
if authMethodType.Valid {
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodType(authMethodType.Int16))
}
}
if passwordSet.Valid && passwordSet.Bool {
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypePassword)
}
if idp.Valid && idp.Int64 > 0 {
logging.Error("IDP", idp.Int64)
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypeIDP)
}
if err := rows.Close(); err != nil {
return nil, errors.ThrowInternal(err, "QUERY-3n9fl", "Errors.Query.CloseRows")
}
return &AuthMethodTypes{
AuthMethodTypes: userAuthMethodTypes,
SearchResponse: SearchResponse{
Count: uint64(len(userAuthMethodTypes)),
},
}, nil
}
}

View File

@@ -36,6 +36,23 @@ var (
"method_type",
"count",
}
prepareActiveAuthMethodTypesStmt = `SELECT projections.users8_notifications.password_set,` +
` auth_method_types.method_type,` +
` user_idps_count.count` +
` FROM projections.users8` +
` LEFT JOIN projections.users8_notifications ON projections.users8.id = projections.users8_notifications.user_id AND projections.users8.instance_id = projections.users8_notifications.instance_id` +
` LEFT JOIN (SELECT DISTINCT(auth_method_types.method_type), auth_method_types.user_id, auth_method_types.instance_id FROM projections.user_auth_methods4 AS auth_method_types` +
` WHERE auth_method_types.state = $1) AS auth_method_types` +
` ON auth_method_types.user_id = projections.users8.id AND auth_method_types.instance_id = projections.users8.instance_id` +
` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` +
` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` +
` ON user_idps_count.user_id = projections.users8.id AND user_idps_count.instance_id = projections.users8.instance_id` +
` AS OF SYSTEM TIME '-1 ms`
prepareActiveAuthMethodTypesCols = []string{
"password_set",
"method_type",
"idps_count",
}
)
func Test_UserAuthMethodPrepares(t *testing.T) {
@@ -182,6 +199,95 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
},
object: nil,
},
{
name: "prepareActiveUserAuthMethodTypesQuery no result",
prepare: prepareActiveUserAuthMethodTypesQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(prepareActiveAuthMethodTypesStmt),
nil,
nil,
),
},
object: &AuthMethodTypes{AuthMethodTypes: []domain.UserAuthMethodType{}},
},
{
name: "prepareActiveUserAuthMethodTypesQuery one second factor",
prepare: prepareActiveUserAuthMethodTypesQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(prepareActiveAuthMethodTypesStmt),
prepareActiveAuthMethodTypesCols,
[][]driver.Value{
{
true,
domain.UserAuthMethodTypePasswordless,
1,
},
},
),
},
object: &AuthMethodTypes{
SearchResponse: SearchResponse{
Count: 3,
},
AuthMethodTypes: []domain.UserAuthMethodType{
domain.UserAuthMethodTypePasswordless,
domain.UserAuthMethodTypePassword,
domain.UserAuthMethodTypeIDP,
},
},
},
{
name: "prepareActiveUserAuthMethodTypesQuery multiple second factors",
prepare: prepareActiveUserAuthMethodTypesQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(prepareActiveAuthMethodTypesStmt),
prepareActiveAuthMethodTypesCols,
[][]driver.Value{
{
true,
domain.UserAuthMethodTypePasswordless,
1,
},
{
true,
domain.UserAuthMethodTypeOTP,
1,
},
},
),
},
object: &AuthMethodTypes{
SearchResponse: SearchResponse{
Count: 4,
},
AuthMethodTypes: []domain.UserAuthMethodType{
domain.UserAuthMethodTypePasswordless,
domain.UserAuthMethodTypeOTP,
domain.UserAuthMethodTypePassword,
domain.UserAuthMethodTypeIDP,
},
},
},
{
name: "prepareActiveUserAuthMethodTypesQuery sql err",
prepare: prepareActiveUserAuthMethodTypesQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(prepareActiveAuthMethodTypesStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {