feat(api): new session service (#5801)

* backup new protoc plugin

* backup

* session

* backup

* initial implementation

* change to specific events

* implement tests

* cleanup

* refactor: use new protoc plugin for api v2

* change package

* simplify code

* cleanup

* cleanup

* fix merge

* start queries

* fix tests

* improve returned values

* add token to projection

* tests

* test db map

* update query

* permission checks

* fix tests and linting

* rework token creation

* i18n

* refactor token check and fix tests

* session to PB test

* request to query tests

* cleanup proto

* test user check

* add comment

* simplify database map type

* Update docs/docs/guides/integrate/access-zitadel-system-api.md

Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>

* fix test

* cleanup

* docs

---------

Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
This commit is contained in:
Livio Spring
2023-05-05 17:34:53 +02:00
committed by GitHub
parent 74377c2c37
commit c2cb84cd24
55 changed files with 3911 additions and 106 deletions

View File

@@ -180,12 +180,20 @@ func (q *Queries) IsOrgUnique(ctx context.Context, name, domain string) (isUniqu
return scan(row)
}
func (q *Queries) ExistsOrg(ctx context.Context, id string) (err error) {
func (q *Queries) ExistsOrg(ctx context.Context, id, domain string) (verifiedID string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
_, err = q.OrgByID(ctx, true, id)
return err
var org *Org
if id != "" {
org, err = q.OrgByID(ctx, true, id)
} else {
org, err = q.OrgByVerifiedDomain(ctx, domain)
}
if err != nil {
return "", err
}
return org.ID, nil
}
func (q *Queries) SearchOrgs(ctx context.Context, queries *OrgSearchQueries) (orgs *Orgs, err error) {

View File

@@ -65,6 +65,7 @@ var (
NotificationsProjection interface{}
NotificationsQuotaProjection interface{}
DeviceAuthProjection *deviceAuthProjection
SessionProjection *sessionProjection
)
type projection interface {
@@ -141,6 +142,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es *eventstore.Eventsto
SecurityPolicyProjection = newSecurityPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["security_policies"]))
NotificationPolicyProjection = newNotificationPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["notification_policies"]))
DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"]))
SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
newProjectionsList()
return nil
}
@@ -237,5 +239,6 @@ func newProjectionsList() {
SecurityPolicyProjection,
NotificationPolicyProjection,
DeviceAuthProjection,
SessionProjection,
}
}

View File

@@ -0,0 +1,221 @@
package projection
import (
"context"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/session"
)
const (
SessionsProjectionTable = "projections.sessions"
SessionColumnID = "id"
SessionColumnCreationDate = "creation_date"
SessionColumnChangeDate = "change_date"
SessionColumnSequence = "sequence"
SessionColumnState = "state"
SessionColumnResourceOwner = "resource_owner"
SessionColumnInstanceID = "instance_id"
SessionColumnCreator = "creator"
SessionColumnUserID = "user_id"
SessionColumnUserCheckedAt = "user_checked_at"
SessionColumnPasswordCheckedAt = "password_checked_at"
SessionColumnMetadata = "metadata"
SessionColumnTokenID = "token_id"
)
type sessionProjection struct {
crdb.StatementHandler
}
func newSessionProjection(ctx context.Context, config crdb.StatementHandlerConfig) *sessionProjection {
p := new(sessionProjection)
config.ProjectionName = SessionsProjectionTable
config.Reducers = p.reducers()
config.InitCheck = crdb.NewMultiTableCheck(
crdb.NewTable([]*crdb.Column{
crdb.NewColumn(SessionColumnID, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnCreationDate, crdb.ColumnTypeTimestamp),
crdb.NewColumn(SessionColumnChangeDate, crdb.ColumnTypeTimestamp),
crdb.NewColumn(SessionColumnSequence, crdb.ColumnTypeInt64),
crdb.NewColumn(SessionColumnState, crdb.ColumnTypeEnum),
crdb.NewColumn(SessionColumnResourceOwner, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnInstanceID, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnCreator, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnUserID, crdb.ColumnTypeText, crdb.Nullable()),
crdb.NewColumn(SessionColumnUserCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()),
crdb.NewColumn(SessionColumnPasswordCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()),
crdb.NewColumn(SessionColumnMetadata, crdb.ColumnTypeJSONB, crdb.Nullable()),
crdb.NewColumn(SessionColumnTokenID, crdb.ColumnTypeText, crdb.Nullable()),
},
crdb.NewPrimaryKey(SessionColumnInstanceID, SessionColumnID),
),
)
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
return p
}
func (p *sessionProjection) reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{
{
Aggregate: session.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: session.AddedType,
Reduce: p.reduceSessionAdded,
},
{
Event: session.UserCheckedType,
Reduce: p.reduceUserChecked,
},
{
Event: session.PasswordCheckedType,
Reduce: p.reducePasswordChecked,
},
{
Event: session.TokenSetType,
Reduce: p.reduceTokenSet,
},
{
Event: session.MetadataSetType,
Reduce: p.reduceMetadataSet,
},
{
Event: session.TerminateType,
Reduce: p.reduceSessionTerminated,
},
},
},
{
Aggregate: instance.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: instance.InstanceRemovedEventType,
Reduce: reduceInstanceRemovedHelper(SMSColumnInstanceID),
},
},
},
}
}
func (p *sessionProjection) reduceSessionAdded(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.AddedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-Sfrgf", "reduce.wrong.event.type %s", session.AddedType)
}
return crdb.NewCreateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnID, e.Aggregate().ID),
handler.NewCol(SessionColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCol(SessionColumnCreationDate, e.CreationDate()),
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnResourceOwner, e.Aggregate().ResourceOwner),
handler.NewCol(SessionColumnState, domain.SessionStateActive),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnCreator, e.User),
},
), nil
}
func (p *sessionProjection) reduceUserChecked(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.UserCheckedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-saDg5", "reduce.wrong.event.type %s", session.UserCheckedType)
}
return crdb.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnUserID, e.UserID),
handler.NewCol(SessionColumnUserCheckedAt, e.CheckedAt),
},
[]handler.Condition{
handler.NewCond(SessionColumnID, e.Aggregate().ID),
handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}
func (p *sessionProjection) reducePasswordChecked(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.PasswordCheckedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-SDgrb", "reduce.wrong.event.type %s", session.PasswordCheckedType)
}
return crdb.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnPasswordCheckedAt, e.CheckedAt),
},
[]handler.Condition{
handler.NewCond(SessionColumnID, e.Aggregate().ID),
handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}
func (p *sessionProjection) reduceTokenSet(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.TokenSetEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-SAfd3", "reduce.wrong.event.type %s", session.TokenSetType)
}
return crdb.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnTokenID, e.TokenID),
},
[]handler.Condition{
handler.NewCond(SessionColumnID, e.Aggregate().ID),
handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}
func (p *sessionProjection) reduceMetadataSet(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.MetadataSetEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-SAfd3", "reduce.wrong.event.type %s", session.MetadataSetType)
}
return crdb.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnMetadata, e.Metadata),
},
[]handler.Condition{
handler.NewCond(SessionColumnID, e.Aggregate().ID),
handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}
func (p *sessionProjection) reduceSessionTerminated(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*session.TerminateEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-SAftn", "reduce.wrong.event.type %s", session.TerminateType)
}
return crdb.NewDeleteStatement(
e,
[]handler.Condition{
handler.NewCond(SessionColumnID, e.Aggregate().ID),
handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}

View File

@@ -0,0 +1,260 @@
package projection
import (
"testing"
"time"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/session"
)
func TestSessionProjection_reduces(t *testing.T) {
type args struct {
event func(t *testing.T) eventstore.Event
}
tests := []struct {
name string
args args
reduce func(event eventstore.Event) (*handler.Statement, error)
want wantReduce
}{
{
name: "instance reduceSessionAdded",
args: args{
event: getEvent(testEvent(
session.AddedType,
session.AggregateType,
[]byte(`{}`),
), session.AddedEventMapper),
},
reduce: (&sessionProjection{}).reduceSessionAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "INSERT INTO projections.sessions (id, instance_id, creation_date, change_date, resource_owner, state, sequence, creator) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
anyArg{},
anyArg{},
"ro-id",
domain.SessionStateActive,
uint64(15),
"editor-user",
},
},
},
},
},
},
{
name: "instance reduceUserChecked",
args: args{
event: getEvent(testEvent(
session.AddedType,
session.AggregateType,
[]byte(`{
"userId": "user-id",
"checkedAt": "2023-05-04T00:00:00Z"
}`),
), session.UserCheckedEventMapper),
},
reduce: (&sessionProjection{}).reduceUserChecked,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
"user-id",
time.Date(2023, time.May, 4, 0, 0, 0, 0, time.UTC),
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "instance reducePasswordChecked",
args: args{
event: getEvent(testEvent(
session.AddedType,
session.AggregateType,
[]byte(`{
"checkedAt": "2023-05-04T00:00:00Z"
}`),
), session.PasswordCheckedEventMapper),
},
reduce: (&sessionProjection{}).reducePasswordChecked,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
time.Date(2023, time.May, 4, 0, 0, 0, 0, time.UTC),
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "instance reduceTokenSet",
args: args{
event: getEvent(testEvent(
session.TokenSetType,
session.AggregateType,
[]byte(`{
"tokenID": "tokenID"
}`),
), session.TokenSetEventMapper),
},
reduce: (&sessionProjection{}).reduceTokenSet,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
"tokenID",
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "instance reduceMetadataSet",
args: args{
event: getEvent(testEvent(
session.MetadataSetType,
session.AggregateType,
[]byte(`{
"metadata": {
"key": "dmFsdWU="
}
}`),
), session.MetadataSetEventMapper),
},
reduce: (&sessionProjection{}).reduceMetadataSet,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.sessions SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
map[string][]byte{
"key": []byte("value"),
},
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "instance reduceSessionTerminated",
args: args{
event: getEvent(testEvent(
session.TerminateType,
session.AggregateType,
[]byte(`{}`),
), session.TerminateEventMapper),
},
reduce: (&sessionProjection{}).reduceSessionTerminated,
want: wantReduce{
aggregateType: eventstore.AggregateType("session"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.sessions WHERE (id = $1) AND (instance_id = $2)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "instance reduceInstanceRemoved",
args: args{
event: getEvent(testEvent(
repository.EventType(instance.InstanceRemovedEventType),
instance.AggregateType,
nil,
), instance.InstanceRemovedEventMapper),
},
reduce: reduceInstanceRemovedHelper(SessionColumnInstanceID),
want: wantReduce{
aggregateType: eventstore.AggregateType("instance"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.sessions WHERE (instance_id = $1)",
expectedArgs: []interface{}{
"agg-id",
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event := baseEvent(t)
got, err := tt.reduce(event)
if !errors.IsErrorInvalidArgument(err) {
t.Errorf("no wrong event mapping: %v, got: %v", err, got)
}
event = tt.args.event(t)
got, err = tt.reduce(event)
assertReduce(t, got, err, SessionsProjectionTable, tt.want)
})
}
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/session"
usr_repo "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/repository/usergrant"
)
@@ -30,7 +31,8 @@ type Queries struct {
eventstore *eventstore.Eventstore
client *database.DB
idpConfigEncryption crypto.EncryptionAlgorithm
idpConfigEncryption crypto.EncryptionAlgorithm
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error)
DefaultLanguage language.Tag
LoginDir http.FileSystem
@@ -43,7 +45,16 @@ type Queries struct {
multifactors domain.MultifactorConfigs
}
func StartQueries(ctx context.Context, es *eventstore.Eventstore, sqlClient *database.DB, projections projection.Config, defaults sd.SystemDefaults, idpConfigEncryption, otpEncryption, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm, zitadelRoles []authz.RoleMapping) (repo *Queries, err error) {
func StartQueries(
ctx context.Context,
es *eventstore.Eventstore,
sqlClient *database.DB,
projections projection.Config,
defaults sd.SystemDefaults,
idpConfigEncryption, otpEncryption, keyEncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm,
zitadelRoles []authz.RoleMapping,
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error),
) (repo *Queries, err error) {
statikLoginFS, err := fs.NewWithNamespace("login")
if err != nil {
return nil, fmt.Errorf("unable to start login statik dir")
@@ -63,6 +74,7 @@ func StartQueries(ctx context.Context, es *eventstore.Eventstore, sqlClient *dat
LoginTranslationFileContents: make(map[string][]byte),
NotificationTranslationFileContents: make(map[string][]byte),
zitadelRoles: zitadelRoles,
sessionTokenVerifier: sessionTokenVerifier,
}
iam_repo.RegisterEventMappers(repo.eventstore)
usr_repo.RegisterEventMappers(repo.eventstore)
@@ -71,6 +83,7 @@ func StartQueries(ctx context.Context, es *eventstore.Eventstore, sqlClient *dat
action.RegisterEventMappers(repo.eventstore)
keypair.RegisterEventMappers(repo.eventstore)
usergrant.RegisterEventMappers(repo.eventstore)
session.RegisterEventMappers(repo.eventstore)
repo.idpConfigEncryption = idpConfigEncryption
repo.multifactors = domain.MultifactorConfigs{

320
internal/query/session.go Normal file
View File

@@ -0,0 +1,320 @@
package query
import (
"context"
"database/sql"
errs "errors"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type Sessions struct {
SearchResponse
Sessions []*Session
}
type Session struct {
ID string
CreationDate time.Time
ChangeDate time.Time
Sequence uint64
State domain.SessionState
ResourceOwner string
Creator string
UserFactor SessionUserFactor
PasswordFactor SessionPasswordFactor
Metadata map[string][]byte
}
type SessionUserFactor struct {
UserID string
UserCheckedAt time.Time
LoginName string
DisplayName string
}
type SessionPasswordFactor struct {
PasswordCheckedAt time.Time
}
type SessionsSearchQueries struct {
SearchRequest
Queries []SearchQuery
}
func (q *SessionsSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
query = q.SearchRequest.toQuery(query)
for _, q := range q.Queries {
query = q.toQuery(query)
}
return query
}
var (
sessionsTable = table{
name: projection.SessionsProjectionTable,
instanceIDCol: projection.SessionColumnInstanceID,
}
SessionColumnID = Column{
name: projection.SessionColumnID,
table: sessionsTable,
}
SessionColumnCreationDate = Column{
name: projection.SessionColumnCreationDate,
table: sessionsTable,
}
SessionColumnChangeDate = Column{
name: projection.SessionColumnChangeDate,
table: sessionsTable,
}
SessionColumnSequence = Column{
name: projection.SessionColumnSequence,
table: sessionsTable,
}
SessionColumnState = Column{
name: projection.SessionColumnState,
table: sessionsTable,
}
SessionColumnResourceOwner = Column{
name: projection.SessionColumnResourceOwner,
table: sessionsTable,
}
SessionColumnInstanceID = Column{
name: projection.SessionColumnInstanceID,
table: sessionsTable,
}
SessionColumnCreator = Column{
name: projection.SessionColumnCreator,
table: sessionsTable,
}
SessionColumnUserID = Column{
name: projection.SessionColumnUserID,
table: sessionsTable,
}
SessionColumnUserCheckedAt = Column{
name: projection.SessionColumnUserCheckedAt,
table: sessionsTable,
}
SessionColumnPasswordCheckedAt = Column{
name: projection.SessionColumnPasswordCheckedAt,
table: sessionsTable,
}
SessionColumnMetadata = Column{
name: projection.SessionColumnMetadata,
table: sessionsTable,
}
SessionColumnToken = Column{
name: projection.SessionColumnTokenID,
table: sessionsTable,
}
)
func (q *Queries) SessionByID(ctx context.Context, id, sessionToken string) (_ *Session, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareSessionQuery(ctx, q.client)
stmt, args, err := query.Where(
sq.Eq{
SessionColumnID.identifier(): id,
SessionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-dn9JW", "Errors.Query.SQLStatement")
}
row := q.client.QueryRowContext(ctx, stmt, args...)
session, tokenID, err := scan(row)
if err != nil {
return nil, err
}
if sessionToken == "" {
return session, nil
}
if err := q.sessionTokenVerifier(ctx, sessionToken, session.ID, tokenID); err != nil {
return nil, errors.ThrowPermissionDenied(nil, "QUERY-dsfr3", "Errors.PermissionDenied")
}
return session, nil
}
func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries) (_ *Sessions, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareSessionsQuery(ctx, q.client)
stmt, args, err := queries.toQuery(query).
Where(sq.Eq{
SessionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-sn9Jf", "Errors.Query.InvalidRequest")
}
rows, err := q.client.QueryContext(ctx, stmt, args...)
if err != nil || rows.Err() != nil {
return nil, errors.ThrowInternal(err, "QUERY-Sfg42", "Errors.Internal")
}
sessions, err := scan(rows)
if err != nil {
return nil, err
}
sessions.LatestSequence, err = q.latestSequence(ctx, sessionsTable)
return sessions, err
}
func NewSessionIDsSearchQuery(ids []string) (SearchQuery, error) {
list := make([]interface{}, len(ids))
for i, value := range ids {
list[i] = value
}
return NewListQuery(SessionColumnID, list, ListIn)
}
func NewSessionCreatorSearchQuery(creator string) (SearchQuery, error) {
return NewTextQuery(SessionColumnCreator, creator, TextEquals)
}
func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Session, string, error)) {
return sq.Select(
SessionColumnID.identifier(),
SessionColumnCreationDate.identifier(),
SessionColumnChangeDate.identifier(),
SessionColumnSequence.identifier(),
SessionColumnState.identifier(),
SessionColumnResourceOwner.identifier(),
SessionColumnCreator.identifier(),
SessionColumnUserID.identifier(),
SessionColumnUserCheckedAt.identifier(),
LoginNameNameCol.identifier(),
HumanDisplayNameCol.identifier(),
SessionColumnPasswordCheckedAt.identifier(),
SessionColumnMetadata.identifier(),
SessionColumnToken.identifier(),
).From(sessionsTable.identifier()).
LeftJoin(join(LoginNameUserIDCol, SessionColumnUserID)).
LeftJoin(join(HumanUserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Session, string, error) {
session := new(Session)
var (
userID sql.NullString
userCheckedAt sql.NullTime
loginName sql.NullString
displayName sql.NullString
passwordCheckedAt sql.NullTime
metadata database.Map[[]byte]
token sql.NullString
)
err := row.Scan(
&session.ID,
&session.CreationDate,
&session.ChangeDate,
&session.Sequence,
&session.State,
&session.ResourceOwner,
&session.Creator,
&userID,
&userCheckedAt,
&loginName,
&displayName,
&passwordCheckedAt,
&metadata,
&token,
)
if err != nil {
if errs.Is(err, sql.ErrNoRows) {
return nil, "", errors.ThrowNotFound(err, "QUERY-SFeaa", "Errors.Session.NotExisting")
}
return nil, "", errors.ThrowInternal(err, "QUERY-SAder", "Errors.Internal")
}
session.UserFactor.UserID = userID.String
session.UserFactor.UserCheckedAt = userCheckedAt.Time
session.UserFactor.LoginName = loginName.String
session.UserFactor.DisplayName = displayName.String
session.PasswordFactor.PasswordCheckedAt = passwordCheckedAt.Time
session.Metadata = metadata
return session, token.String, nil
}
}
func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Sessions, error)) {
return sq.Select(
SessionColumnID.identifier(),
SessionColumnCreationDate.identifier(),
SessionColumnChangeDate.identifier(),
SessionColumnSequence.identifier(),
SessionColumnState.identifier(),
SessionColumnResourceOwner.identifier(),
SessionColumnCreator.identifier(),
SessionColumnUserID.identifier(),
SessionColumnUserCheckedAt.identifier(),
LoginNameNameCol.identifier(),
HumanDisplayNameCol.identifier(),
SessionColumnPasswordCheckedAt.identifier(),
SessionColumnMetadata.identifier(),
countColumn.identifier(),
).From(sessionsTable.identifier()).
LeftJoin(join(LoginNameUserIDCol, SessionColumnUserID)).
LeftJoin(join(HumanUserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Sessions, error) {
sessions := &Sessions{Sessions: []*Session{}}
for rows.Next() {
session := new(Session)
var (
userID sql.NullString
userCheckedAt sql.NullTime
loginName sql.NullString
displayName sql.NullString
passwordCheckedAt sql.NullTime
metadata database.Map[[]byte]
)
err := rows.Scan(
&session.ID,
&session.CreationDate,
&session.ChangeDate,
&session.Sequence,
&session.State,
&session.ResourceOwner,
&session.Creator,
&userID,
&userCheckedAt,
&loginName,
&displayName,
&passwordCheckedAt,
&metadata,
&sessions.Count,
)
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-SAfeg", "Errors.Internal")
}
session.UserFactor.UserID = userID.String
session.UserFactor.UserCheckedAt = userCheckedAt.Time
session.UserFactor.LoginName = loginName.String
session.UserFactor.DisplayName = displayName.String
session.PasswordFactor.PasswordCheckedAt = passwordCheckedAt.Time
session.Metadata = metadata
sessions.Sessions = append(sessions.Sessions, session)
}
return sessions, nil
}
}

View File

@@ -0,0 +1,396 @@
package query
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
sq "github.com/Masterminds/squirrel"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/domain"
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions.id,` +
` projections.sessions.creation_date,` +
` projections.sessions.change_date,` +
` projections.sessions.sequence,` +
` projections.sessions.state,` +
` projections.sessions.resource_owner,` +
` projections.sessions.creator,` +
` projections.sessions.user_id,` +
` projections.sessions.user_checked_at,` +
` projections.login_names2.login_name,` +
` projections.users8_humans.display_name,` +
` projections.sessions.password_checked_at,` +
` projections.sessions.metadata,` +
` projections.sessions.token_id` +
` FROM projections.sessions` +
` LEFT JOIN projections.login_names2 ON projections.sessions.user_id = projections.login_names2.user_id AND projections.sessions.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions.user_id = projections.users8_humans.user_id AND projections.sessions.instance_id = projections.users8_humans.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions.id,` +
` projections.sessions.creation_date,` +
` projections.sessions.change_date,` +
` projections.sessions.sequence,` +
` projections.sessions.state,` +
` projections.sessions.resource_owner,` +
` projections.sessions.creator,` +
` projections.sessions.user_id,` +
` projections.sessions.user_checked_at,` +
` projections.login_names2.login_name,` +
` projections.users8_humans.display_name,` +
` projections.sessions.password_checked_at,` +
` projections.sessions.metadata,` +
` COUNT(*) OVER ()` +
` FROM projections.sessions` +
` LEFT JOIN projections.login_names2 ON projections.sessions.user_id = projections.login_names2.user_id AND projections.sessions.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions.user_id = projections.users8_humans.user_id AND projections.sessions.instance_id = projections.users8_humans.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
sessionCols = []string{
"id",
"creation_date",
"change_date",
"sequence",
"state",
"resource_owner",
"creator",
"user_id",
"user_checked_at",
"login_name",
"display_name",
"password_checked_at",
"metadata",
"token",
}
sessionsCols = []string{
"id",
"creation_date",
"change_date",
"sequence",
"state",
"resource_owner",
"creator",
"user_id",
"user_checked_at",
"login_name",
"display_name",
"password_checked_at",
"metadata",
"count",
}
)
func Test_SessionsPrepare(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareSessionsQuery no result",
prepare: prepareSessionsQuery,
want: want{
sqlExpectations: mockQueries(
expectedSessionsQuery,
nil,
nil,
),
},
object: &Sessions{Sessions: []*Session{}},
},
{
name: "prepareSessionQuery",
prepare: prepareSessionsQuery,
want: want{
sqlExpectations: mockQueries(
expectedSessionsQuery,
sessionsCols,
[][]driver.Value{
{
"session-id",
testNow,
testNow,
uint64(20211109),
domain.SessionStateActive,
"ro",
"creator",
"user-id",
testNow,
"login-name",
"display-name",
testNow,
[]byte(`{"key": "dmFsdWU="}`),
},
},
),
},
object: &Sessions{
SearchResponse: SearchResponse{
Count: 1,
},
Sessions: []*Session{
{
ID: "session-id",
CreationDate: testNow,
ChangeDate: testNow,
Sequence: 20211109,
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator",
UserFactor: SessionUserFactor{
UserID: "user-id",
UserCheckedAt: testNow,
LoginName: "login-name",
DisplayName: "display-name",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
Metadata: map[string][]byte{
"key": []byte("value"),
},
},
},
},
},
{
name: "prepareSessionsQuery multiple result",
prepare: prepareSessionsQuery,
want: want{
sqlExpectations: mockQueries(
expectedSessionsQuery,
sessionsCols,
[][]driver.Value{
{
"session-id",
testNow,
testNow,
uint64(20211109),
domain.SessionStateActive,
"ro",
"creator",
"user-id",
testNow,
"login-name",
"display-name",
testNow,
[]byte(`{"key": "dmFsdWU="}`),
},
{
"session-id2",
testNow,
testNow,
uint64(20211109),
domain.SessionStateActive,
"ro",
"creator2",
"user-id2",
testNow,
"login-name2",
"display-name2",
testNow,
[]byte(`{"key": "dmFsdWU="}`),
},
},
),
},
object: &Sessions{
SearchResponse: SearchResponse{
Count: 2,
},
Sessions: []*Session{
{
ID: "session-id",
CreationDate: testNow,
ChangeDate: testNow,
Sequence: 20211109,
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator",
UserFactor: SessionUserFactor{
UserID: "user-id",
UserCheckedAt: testNow,
LoginName: "login-name",
DisplayName: "display-name",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
Metadata: map[string][]byte{
"key": []byte("value"),
},
},
{
ID: "session-id2",
CreationDate: testNow,
ChangeDate: testNow,
Sequence: 20211109,
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator2",
UserFactor: SessionUserFactor{
UserID: "user-id2",
UserCheckedAt: testNow,
LoginName: "login-name2",
DisplayName: "display-name2",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
Metadata: map[string][]byte{
"key": []byte("value"),
},
},
},
},
},
{
name: "prepareSessionsQuery sql err",
prepare: prepareSessionsQuery,
want: want{
sqlExpectations: mockQueryErr(
expectedSessionsQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}
func Test_SessionPrepare(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareSessionQuery no result",
prepare: prepareSessionQueryTesting(t, ""),
want: want{
sqlExpectations: mockQueries(
expectedSessionQuery,
nil,
nil,
),
err: func(err error) (error, bool) {
if !errs.IsNotFound(err) {
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
}
return nil, true
},
},
object: (*Session)(nil),
},
{
name: "prepareSessionQuery found",
prepare: prepareSessionQueryTesting(t, "tokenID"),
want: want{
sqlExpectations: mockQuery(
expectedSessionQuery,
sessionCols,
[]driver.Value{
"session-id",
testNow,
testNow,
uint64(20211109),
domain.SessionStateActive,
"ro",
"creator",
"user-id",
testNow,
"login-name",
"display-name",
testNow,
[]byte(`{"key": "dmFsdWU="}`),
"tokenID",
},
),
},
object: &Session{
ID: "session-id",
CreationDate: testNow,
ChangeDate: testNow,
Sequence: 20211109,
State: domain.SessionStateActive,
ResourceOwner: "ro",
Creator: "creator",
UserFactor: SessionUserFactor{
UserID: "user-id",
UserCheckedAt: testNow,
LoginName: "login-name",
DisplayName: "display-name",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
},
Metadata: map[string][]byte{
"key": []byte("value"),
},
},
},
{
name: "prepareSessionQuery sql err",
prepare: prepareSessionQueryTesting(t, ""),
want: want{
sqlExpectations: mockQueryErr(
expectedSessionQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}
func prepareSessionQueryTesting(t *testing.T, token string) func(context.Context, prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Session, error)) {
return func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Session, error)) {
builder, scan := prepareSessionQuery(ctx, db)
return builder, func(row *sql.Row) (*Session, error) {
session, tokenID, err := scan(row)
require.Equal(t, tokenID, token)
return session, err
}
}
}