diff --git a/internal/api/grpc/session/v2/integration_test/query_test.go b/internal/api/grpc/session/v2/integration_test/query_test.go index 66f8c9b304..6b5ee78f2a 100644 --- a/internal/api/grpc/session/v2/integration_test/query_test.go +++ b/internal/api/grpc/session/v2/integration_test/query_test.go @@ -15,11 +15,13 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/integration" + objpb "github.com/zitadel/zitadel/pkg/grpc/object" "github.com/zitadel/zitadel/pkg/grpc/object/v2" "github.com/zitadel/zitadel/pkg/grpc/session/v2" ) func TestServer_GetSession(t *testing.T) { + t.Parallel() type args struct { ctx context.Context req *session.GetSessionRequest @@ -211,6 +213,7 @@ func TestServer_GetSession(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() var sequence uint64 if tt.args.dep != nil { sequence = tt.args.dep(LoginCTX, t, tt.args.req) @@ -223,9 +226,7 @@ func TestServer_GetSession(t *testing.T) { assert.Error(ttt, err) return } - if !assert.NoError(ttt, err) { - return - } + require.NoError(ttt, err) tt.want.Session.Id = tt.args.req.SessionId tt.want.Session.Sequence = sequence @@ -302,6 +303,7 @@ func createSession(ctx context.Context, t *testing.T, userID string, userAgent s } func TestServer_ListSessions(t *testing.T) { + t.Parallel() type args struct { ctx context.Context req *session.ListSessionsRequest @@ -679,9 +681,48 @@ func TestServer_ListSessions(t *testing.T) { wantFactors: []wantFactor{wantUserFactor}, wantErr: true, }, + { + name: "list sessions, expiration date query, ok", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "useragent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}, + &session.SearchQuery{Query: &session.SearchQuery_ExpirationDateQuery{ + ExpirationDateQuery: &session.ExpirationDateQuery{ExpirationDate: timestamppb.Now(), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS, + }}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("useragent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() infos := tt.args.dep(LoginCTX, t, tt.args.req) retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute) @@ -691,19 +732,15 @@ func TestServer_ListSessions(t *testing.T) { assert.Error(ttt, err) return } - if !assert.NoError(ttt, err) { - return - } + require.NoError(ttt, err) // expected count of sessions is not equal to created dependencies - if !assert.Len(ttt, tt.want.Sessions, len(infos)) { - return - } + require.Len(ttt, tt.want.Sessions, len(infos)) + // expected count of sessions is not equal to received sessions - if !assert.Equal(ttt, got.Details.TotalResult, tt.want.Details.TotalResult) || !assert.Len(ttt, got.Sessions, len(tt.want.Sessions)) { - return - } + require.Equal(ttt, tt.want.Details.TotalResult, got.Details.TotalResult) + require.Len(ttt, got.Sessions, len(tt.want.Sessions)) for i := range infos { tt.want.Sessions[i].Id = infos[i].ID @@ -727,3 +764,61 @@ func TestServer_ListSessions(t *testing.T) { }) } } + +func TestServer_ListSessions_with_expiration_date_filter(t *testing.T) { + t.Parallel() + // session with no expiration + session1, err := Client.CreateSession(IAMOwnerCTX, &session.CreateSessionRequest{}) + require.NoError(t, err) + + // session with expiration + session2, err := Client.CreateSession(IAMOwnerCTX, &session.CreateSessionRequest{ + Lifetime: durationpb.New(1 * time.Second), + }) + require.NoError(t, err) + + // wait until the second session expires + time.Sleep(2 * time.Second) + + // with comparison method GREATER_OR_EQUALS, only the active session should be returned + listSessionsResponse1, err := Client.ListSessions(IAMOwnerCTX, + &session.ListSessionsRequest{ + Queries: []*session.SearchQuery{ + { + Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{session1.SessionId}}}, + }, + { + Query: &session.SearchQuery_ExpirationDateQuery{ + ExpirationDateQuery: &session.ExpirationDateQuery{ + ExpirationDate: timestamppb.Now(), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS, + }, + }, + }, + }, + }) + require.NoError(t, err) + require.Len(t, listSessionsResponse1.Sessions, 1) + assert.Equal(t, session1.SessionId, listSessionsResponse1.Sessions[0].Id) + + // with comparison method LESS_OR_EQUALS, only the expired session should be returned + listSessionsResponse2, err := Client.ListSessions(IAMOwnerCTX, + &session.ListSessionsRequest{ + Queries: []*session.SearchQuery{ + { + Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{session2.SessionId}}}, + }, + { + Query: &session.SearchQuery_ExpirationDateQuery{ + ExpirationDateQuery: &session.ExpirationDateQuery{ + ExpirationDate: timestamppb.Now(), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS, + }, + }, + }, + }, + }) + require.NoError(t, err) + require.Len(t, listSessionsResponse2.Sessions, 1) + assert.Equal(t, session2.SessionId, listSessionsResponse2.Sessions[0].Id) +} diff --git a/internal/api/grpc/session/v2/integration_test/session_test.go b/internal/api/grpc/session/v2/integration_test/session_test.go index 6c0c079e48..9533c1ac82 100644 --- a/internal/api/grpc/session/v2/integration_test/session_test.go +++ b/internal/api/grpc/session/v2/integration_test/session_test.go @@ -930,6 +930,27 @@ func TestServer_DeleteSession_with_permission(t *testing.T) { require.NoError(t, err) } +func TestServer_DeleteSession_expired(t *testing.T) { + createResp, err := Client.CreateSession(LoginCTX, &session.CreateSessionRequest{ + Lifetime: durationpb.New(5 * time.Second), + }) + require.NoError(t, err) + + // wait until the token expires + time.Sleep(10 * time.Second) + _, err = Client.DeleteSession(Instance.WithAuthorizationToken(context.Background(), integration.UserTypeOrgOwner), &session.DeleteSessionRequest{ + SessionId: createResp.GetSessionId(), + SessionToken: gu.Ptr(createResp.GetSessionToken()), + }) + require.NoError(t, err) + + // get session should return an error + sessionResp, err := Client.GetSession(Instance.WithAuthorizationToken(context.Background(), integration.UserTypeOrgOwner), + &session.GetSessionRequest{SessionId: createResp.GetSessionId()}) + require.Error(t, err) + require.Nil(t, sessionResp) +} + func Test_ZITADEL_API_missing_authentication(t *testing.T) { // create new, empty session createResp, err := Client.CreateSession(LoginCTX, &session.CreateSessionRequest{}) diff --git a/internal/api/grpc/session/v2/query.go b/internal/api/grpc/session/v2/query.go index 73303dd9e8..9945555c3c 100644 --- a/internal/api/grpc/session/v2/query.go +++ b/internal/api/grpc/session/v2/query.go @@ -109,6 +109,8 @@ func sessionQueryToQuery(ctx context.Context, sq *session.SearchQuery) (query.Se } } return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-x8n23uh", "List.Query.Invalid") + case *session.SearchQuery_ExpirationDateQuery: + return expirationDateQueryToQuery(q.ExpirationDateQuery) default: return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid") } @@ -123,6 +125,30 @@ func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery, return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison) } +func expirationDateQueryToQuery(q *session.ExpirationDateQuery) (query.SearchQuery, error) { + comparison := timestampComparisons[q.GetMethod()] + + // to obtain sessions with a set expiration date + expirationDateQuery, err := query.NewExpirationDateQuery(q.GetExpirationDate().AsTime(), comparison) + if err != nil { + return nil, err + } + + switch comparison { + case query.TimestampEquals, query.TimestampLess, query.TimestampLessOrEquals: + return expirationDateQuery, nil + case query.TimestampGreater, query.TimestampGreaterOrEquals: + // to obtain sessions without an expiration date + expirationDateIsNullQuery, err := query.NewIsNullQuery(query.SessionColumnExpiration) + if err != nil { + return nil, err + } + return query.NewOrQuery(expirationDateQuery, expirationDateIsNullQuery) + default: + return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-Dwigt", "List.Query.InvalidComparisonMethod") + } +} + func fieldNameToSessionColumn(field session.SessionFieldName) query.Column { switch field { case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE: diff --git a/internal/api/grpc/session/v2/session_test.go b/internal/api/grpc/session/v2/session_test.go index ce4f5115f2..3015b9655d 100644 --- a/internal/api/grpc/session/v2/session_test.go +++ b/internal/api/grpc/session/v2/session_test.go @@ -24,6 +24,7 @@ import ( var ( creationDate = time.Date(2023, 10, 10, 14, 15, 0, 0, time.UTC) + expiration = creationDate.Add(90 * time.Second) ) func Test_sessionsToPb(t *testing.T) { @@ -315,6 +316,18 @@ func mustNewTimestampQuery(t testing.TB, column query.Column, ts time.Time, comp return q } +func mustNewIsNullQuery(t testing.TB, column query.Column) query.SearchQuery { + q, err := query.NewIsNullQuery(column) + require.NoError(t, err) + return q +} + +func mustNewOrQuery(t testing.TB, queries ...query.SearchQuery) query.SearchQuery { + q, err := query.NewOrQuery(queries...) + require.NoError(t, err) + return q +} + func Test_listSessionsRequestToQuery(t *testing.T) { type args struct { ctx context.Context @@ -398,6 +411,12 @@ func Test_listSessionsRequestToQuery(t *testing.T) { {Query: &session.SearchQuery_UserAgentQuery{ UserAgentQuery: &session.UserAgentQuery{}, }}, + {Query: &session.SearchQuery_ExpirationDateQuery{ + ExpirationDateQuery: &session.ExpirationDateQuery{ + ExpirationDate: timestamppb.New(expiration), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS, + }, + }}, }, }, }, @@ -414,6 +433,7 @@ func Test_listSessionsRequestToQuery(t *testing.T) { mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampGreater), mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent", query.TextEquals), + mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampLessOrEquals), }, }, }, @@ -674,6 +694,91 @@ func Test_sessionQueryToQuery(t *testing.T) { }}, want: mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent2", query.TextEquals), }, + { + name: "expiration date query with default method", + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_ExpirationDateQuery{ + ExpirationDateQuery: &session.ExpirationDateQuery{ + ExpirationDate: timestamppb.New(expiration), + }, + }, + }}, + want: mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampEquals), + }, + { + name: "expiration date query with comparison method equals", + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_ExpirationDateQuery{ + ExpirationDateQuery: &session.ExpirationDateQuery{ + ExpirationDate: timestamppb.New(expiration), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS, + }, + }, + }}, + want: mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampEquals), + }, + { + name: "expiration date query with comparison method less", + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_ExpirationDateQuery{ + ExpirationDateQuery: &session.ExpirationDateQuery{ + ExpirationDate: timestamppb.New(expiration), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS, + }, + }, + }}, + want: mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampLess), + }, + { + name: "expiration date query with comparison method less or equals", + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_ExpirationDateQuery{ + ExpirationDateQuery: &session.ExpirationDateQuery{ + ExpirationDate: timestamppb.New(expiration), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS, + }, + }, + }}, + want: mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampLessOrEquals), + }, + { + name: "expiration date query with with comparison method greater", + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_ExpirationDateQuery{ + ExpirationDateQuery: &session.ExpirationDateQuery{ + ExpirationDate: timestamppb.New(expiration), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER, + }, + }, + }}, + want: mustNewOrQuery(t, mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampGreater), + mustNewIsNullQuery(t, query.SessionColumnExpiration)), + }, + { + name: "expiration date query with with comparison method greater or equals", + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_ExpirationDateQuery{ + ExpirationDateQuery: &session.ExpirationDateQuery{ + ExpirationDate: timestamppb.New(expiration), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS, + }, + }, + }}, + want: mustNewOrQuery(t, mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampGreaterOrEquals), + mustNewIsNullQuery(t, query.SessionColumnExpiration)), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/command/session.go b/internal/command/session.go index 87eb56139b..8bbca8f776 100644 --- a/internal/command/session.go +++ b/internal/command/session.go @@ -346,7 +346,9 @@ func (c *Commands) terminateSession(ctx context.Context, sessionID, sessionToken return nil, err } } - if sessionWriteModel.CheckIsActive() != nil { + + // exclude expiration check as expired tokens can be deleted + if sessionWriteModel.State == domain.SessionStateUnspecified || sessionWriteModel.State == domain.SessionStateTerminated { return writeModelToObjectDetails(&sessionWriteModel.WriteModel), nil } terminate := session.NewTerminateEvent(ctx, &session.NewAggregate(sessionWriteModel.AggregateID, sessionWriteModel.ResourceOwner).Aggregate) diff --git a/internal/query/search_query.go b/internal/query/search_query.go index d6dd710d1e..4e0e65c489 100644 --- a/internal/query/search_query.go +++ b/internal/query/search_query.go @@ -675,6 +675,9 @@ type TimestampQuery struct { } func NewTimestampQuery(c Column, value time.Time, compare TimestampComparison) (*TimestampQuery, error) { + if c.isZero() { + return nil, ErrMissingColumn + } return &TimestampQuery{ Column: c, Compare: compare, diff --git a/internal/query/session.go b/internal/query/session.go index 111eb462a0..0c9cffb56a 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -342,6 +342,10 @@ func NewCreationDateQuery(datetime time.Time, compare TimestampComparison) (Sear return NewTimestampQuery(SessionColumnCreationDate, datetime, compare) } +func NewExpirationDateQuery(datetime time.Time, compare TimestampComparison) (SearchQuery, error) { + return NewTimestampQuery(SessionColumnExpiration, datetime, compare) +} + func prepareSessionQuery() (sq.SelectBuilder, func(*sql.Row) (*Session, string, error)) { return sq.Select( SessionColumnID.identifier(), diff --git a/proto/zitadel/session/v2/session.proto b/proto/zitadel/session/v2/session.proto index 7ab6b77610..d05bd89c1a 100644 --- a/proto/zitadel/session/v2/session.proto +++ b/proto/zitadel/session/v2/session.proto @@ -138,6 +138,7 @@ message SearchQuery { CreationDateQuery creation_date_query = 3; CreatorQuery creator_query = 4; UserAgentQuery user_agent_query = 5; + ExpirationDateQuery expiration_date_query = 6; } } @@ -183,6 +184,16 @@ message UserAgentQuery { ]; } +message ExpirationDateQuery { + google.protobuf.Timestamp expiration_date = 1; + zitadel.v1.TimestampQueryMethod method = 2 [ + (validate.rules).enum.defined_only = true, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "defines which timestamp comparison method is used"; + } + ]; +} + message UserAgent { optional string fingerprint_id = 1; optional string ip = 2;