fix(sessions): add an expiration date filter to list sessions api (#10384)

# Which Problems Are Solved

The deletion of expired sessions does not go through even though a
success response is returned to the user. These expired and supposedly
deleted (to the user) sessions are then returned when the `ListSessions`
API is called.

This PR fixes this issue by:
1. Allowing deletion of expired sessions
2. Providing an `expiration_date` filter in `ListSession` API to filter
sessions by expiration date

# How the Problems Are Solved

1. Remove expired session check during deletion
2. Add an `expiration_date` filter to the  `ListSession` API

# Additional Changes
N/A

# Additional Context
- Closes #10045

---------

Co-authored-by: Marco A. <marco@zitadel.com>
This commit is contained in:
Gayathri Vijayan
2025-08-07 14:58:59 +02:00
committed by Stefan Benz
parent 07f22e7e2a
commit 5df28465a4
8 changed files with 280 additions and 13 deletions

View File

@@ -15,11 +15,13 @@ import (
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/integration" "github.com/zitadel/zitadel/internal/integration"
objpb "github.com/zitadel/zitadel/pkg/grpc/object"
"github.com/zitadel/zitadel/pkg/grpc/object/v2" "github.com/zitadel/zitadel/pkg/grpc/object/v2"
"github.com/zitadel/zitadel/pkg/grpc/session/v2" "github.com/zitadel/zitadel/pkg/grpc/session/v2"
) )
func TestServer_GetSession(t *testing.T) { func TestServer_GetSession(t *testing.T) {
t.Parallel()
type args struct { type args struct {
ctx context.Context ctx context.Context
req *session.GetSessionRequest req *session.GetSessionRequest
@@ -211,6 +213,7 @@ func TestServer_GetSession(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var sequence uint64 var sequence uint64
if tt.args.dep != nil { if tt.args.dep != nil {
sequence = tt.args.dep(LoginCTX, t, tt.args.req) sequence = tt.args.dep(LoginCTX, t, tt.args.req)
@@ -223,9 +226,7 @@ func TestServer_GetSession(t *testing.T) {
assert.Error(ttt, err) assert.Error(ttt, err)
return return
} }
if !assert.NoError(ttt, err) { require.NoError(ttt, err)
return
}
tt.want.Session.Id = tt.args.req.SessionId tt.want.Session.Id = tt.args.req.SessionId
tt.want.Session.Sequence = sequence tt.want.Session.Sequence = sequence
@@ -302,6 +303,7 @@ func createSession(ctx context.Context, t *testing.T, userID string, userAgent s
} }
func TestServer_ListSessions(t *testing.T) { func TestServer_ListSessions(t *testing.T) {
t.Parallel()
type args struct { type args struct {
ctx context.Context ctx context.Context
req *session.ListSessionsRequest req *session.ListSessionsRequest
@@ -679,9 +681,48 @@ func TestServer_ListSessions(t *testing.T) {
wantFactors: []wantFactor{wantUserFactor}, wantFactors: []wantFactor{wantUserFactor},
wantErr: true, wantErr: true,
}, },
{
name: "list sessions, expiration date query, ok",
args: args{
IAMOwnerCTX,
&session.ListSessionsRequest{},
func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr {
info := createSession(ctx, t, User.GetUserId(), "useragent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")})
request.Queries = append(request.Queries,
&session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}},
&session.SearchQuery{Query: &session.SearchQuery_ExpirationDateQuery{
ExpirationDateQuery: &session.ExpirationDateQuery{ExpirationDate: timestamppb.Now(),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS,
}}})
return []*sessionAttr{info}
},
},
wantExpirationWindow: time.Minute * 5,
wantFactors: []wantFactor{wantUserFactor},
want: &session.ListSessionsResponse{
Details: &object.ListDetails{
TotalResult: 1,
Timestamp: timestamppb.Now(),
},
Sessions: []*session.Session{
{
Metadata: map[string][]byte{"key": []byte("value")},
UserAgent: &session.UserAgent{
FingerprintId: gu.Ptr("useragent"),
Ip: gu.Ptr("1.2.3.4"),
Description: gu.Ptr("Description"),
Header: map[string]*session.UserAgent_HeaderValues{
"foo": {Values: []string{"foo", "bar"}},
},
},
},
},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
infos := tt.args.dep(LoginCTX, t, tt.args.req) infos := tt.args.dep(LoginCTX, t, tt.args.req)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute) retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute)
@@ -691,19 +732,15 @@ func TestServer_ListSessions(t *testing.T) {
assert.Error(ttt, err) assert.Error(ttt, err)
return return
} }
if !assert.NoError(ttt, err) { require.NoError(ttt, err)
return
}
// expected count of sessions is not equal to created dependencies // expected count of sessions is not equal to created dependencies
if !assert.Len(ttt, tt.want.Sessions, len(infos)) { require.Len(ttt, tt.want.Sessions, len(infos))
return
}
// expected count of sessions is not equal to received sessions // expected count of sessions is not equal to received sessions
if !assert.Equal(ttt, got.Details.TotalResult, tt.want.Details.TotalResult) || !assert.Len(ttt, got.Sessions, len(tt.want.Sessions)) { require.Equal(ttt, tt.want.Details.TotalResult, got.Details.TotalResult)
return require.Len(ttt, got.Sessions, len(tt.want.Sessions))
}
for i := range infos { for i := range infos {
tt.want.Sessions[i].Id = infos[i].ID tt.want.Sessions[i].Id = infos[i].ID
@@ -727,3 +764,61 @@ func TestServer_ListSessions(t *testing.T) {
}) })
} }
} }
func TestServer_ListSessions_with_expiration_date_filter(t *testing.T) {
t.Parallel()
// session with no expiration
session1, err := Client.CreateSession(IAMOwnerCTX, &session.CreateSessionRequest{})
require.NoError(t, err)
// session with expiration
session2, err := Client.CreateSession(IAMOwnerCTX, &session.CreateSessionRequest{
Lifetime: durationpb.New(1 * time.Second),
})
require.NoError(t, err)
// wait until the second session expires
time.Sleep(2 * time.Second)
// with comparison method GREATER_OR_EQUALS, only the active session should be returned
listSessionsResponse1, err := Client.ListSessions(IAMOwnerCTX,
&session.ListSessionsRequest{
Queries: []*session.SearchQuery{
{
Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{session1.SessionId}}},
},
{
Query: &session.SearchQuery_ExpirationDateQuery{
ExpirationDateQuery: &session.ExpirationDateQuery{
ExpirationDate: timestamppb.Now(),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS,
},
},
},
},
})
require.NoError(t, err)
require.Len(t, listSessionsResponse1.Sessions, 1)
assert.Equal(t, session1.SessionId, listSessionsResponse1.Sessions[0].Id)
// with comparison method LESS_OR_EQUALS, only the expired session should be returned
listSessionsResponse2, err := Client.ListSessions(IAMOwnerCTX,
&session.ListSessionsRequest{
Queries: []*session.SearchQuery{
{
Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{session2.SessionId}}},
},
{
Query: &session.SearchQuery_ExpirationDateQuery{
ExpirationDateQuery: &session.ExpirationDateQuery{
ExpirationDate: timestamppb.Now(),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS,
},
},
},
},
})
require.NoError(t, err)
require.Len(t, listSessionsResponse2.Sessions, 1)
assert.Equal(t, session2.SessionId, listSessionsResponse2.Sessions[0].Id)
}

View File

@@ -930,6 +930,27 @@ func TestServer_DeleteSession_with_permission(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestServer_DeleteSession_expired(t *testing.T) {
createResp, err := Client.CreateSession(LoginCTX, &session.CreateSessionRequest{
Lifetime: durationpb.New(5 * time.Second),
})
require.NoError(t, err)
// wait until the token expires
time.Sleep(10 * time.Second)
_, err = Client.DeleteSession(Instance.WithAuthorizationToken(context.Background(), integration.UserTypeOrgOwner), &session.DeleteSessionRequest{
SessionId: createResp.GetSessionId(),
SessionToken: gu.Ptr(createResp.GetSessionToken()),
})
require.NoError(t, err)
// get session should return an error
sessionResp, err := Client.GetSession(Instance.WithAuthorizationToken(context.Background(), integration.UserTypeOrgOwner),
&session.GetSessionRequest{SessionId: createResp.GetSessionId()})
require.Error(t, err)
require.Nil(t, sessionResp)
}
func Test_ZITADEL_API_missing_authentication(t *testing.T) { func Test_ZITADEL_API_missing_authentication(t *testing.T) {
// create new, empty session // create new, empty session
createResp, err := Client.CreateSession(LoginCTX, &session.CreateSessionRequest{}) createResp, err := Client.CreateSession(LoginCTX, &session.CreateSessionRequest{})

View File

@@ -110,6 +110,8 @@ func sessionQueryToQuery(ctx context.Context, sq *session.SearchQuery) (query.Se
} }
} }
return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-x8n23uh", "List.Query.Invalid") return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-x8n23uh", "List.Query.Invalid")
case *session.SearchQuery_ExpirationDateQuery:
return expirationDateQueryToQuery(q.ExpirationDateQuery)
default: default:
return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid") return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid")
} }
@@ -124,6 +126,30 @@ func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery,
return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison) return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison)
} }
func expirationDateQueryToQuery(q *session.ExpirationDateQuery) (query.SearchQuery, error) {
comparison := timestampComparisons[q.GetMethod()]
// to obtain sessions with a set expiration date
expirationDateQuery, err := query.NewExpirationDateQuery(q.GetExpirationDate().AsTime(), comparison)
if err != nil {
return nil, err
}
switch comparison {
case query.TimestampEquals, query.TimestampLess, query.TimestampLessOrEquals:
return expirationDateQuery, nil
case query.TimestampGreater, query.TimestampGreaterOrEquals:
// to obtain sessions without an expiration date
expirationDateIsNullQuery, err := query.NewIsNullQuery(query.SessionColumnExpiration)
if err != nil {
return nil, err
}
return query.NewOrQuery(expirationDateQuery, expirationDateIsNullQuery)
default:
return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-Dwigt", "List.Query.InvalidComparisonMethod")
}
}
func fieldNameToSessionColumn(field session.SessionFieldName) query.Column { func fieldNameToSessionColumn(field session.SessionFieldName) query.Column {
switch field { switch field {
case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE: case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE:

View File

@@ -24,6 +24,7 @@ import (
var ( var (
creationDate = time.Date(2023, 10, 10, 14, 15, 0, 0, time.UTC) creationDate = time.Date(2023, 10, 10, 14, 15, 0, 0, time.UTC)
expiration = creationDate.Add(90 * time.Second)
) )
func Test_sessionsToPb(t *testing.T) { func Test_sessionsToPb(t *testing.T) {
@@ -315,6 +316,18 @@ func mustNewTimestampQuery(t testing.TB, column query.Column, ts time.Time, comp
return q return q
} }
func mustNewIsNullQuery(t testing.TB, column query.Column) query.SearchQuery {
q, err := query.NewIsNullQuery(column)
require.NoError(t, err)
return q
}
func mustNewOrQuery(t testing.TB, queries ...query.SearchQuery) query.SearchQuery {
q, err := query.NewOrQuery(queries...)
require.NoError(t, err)
return q
}
func Test_listSessionsRequestToQuery(t *testing.T) { func Test_listSessionsRequestToQuery(t *testing.T) {
type args struct { type args struct {
ctx context.Context ctx context.Context
@@ -398,6 +411,12 @@ func Test_listSessionsRequestToQuery(t *testing.T) {
{Query: &session.SearchQuery_UserAgentQuery{ {Query: &session.SearchQuery_UserAgentQuery{
UserAgentQuery: &session.UserAgentQuery{}, UserAgentQuery: &session.UserAgentQuery{},
}}, }},
{Query: &session.SearchQuery_ExpirationDateQuery{
ExpirationDateQuery: &session.ExpirationDateQuery{
ExpirationDate: timestamppb.New(expiration),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS,
},
}},
}, },
}, },
}, },
@@ -414,6 +433,7 @@ func Test_listSessionsRequestToQuery(t *testing.T) {
mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampGreater), mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampGreater),
mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals),
mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent", query.TextEquals), mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent", query.TextEquals),
mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampLessOrEquals),
}, },
}, },
}, },
@@ -674,6 +694,91 @@ func Test_sessionQueryToQuery(t *testing.T) {
}}, }},
want: mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent2", query.TextEquals), want: mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent2", query.TextEquals),
}, },
{
name: "expiration date query with default method",
args: args{
context.Background(),
&session.SearchQuery{
Query: &session.SearchQuery_ExpirationDateQuery{
ExpirationDateQuery: &session.ExpirationDateQuery{
ExpirationDate: timestamppb.New(expiration),
},
},
}},
want: mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampEquals),
},
{
name: "expiration date query with comparison method equals",
args: args{
context.Background(),
&session.SearchQuery{
Query: &session.SearchQuery_ExpirationDateQuery{
ExpirationDateQuery: &session.ExpirationDateQuery{
ExpirationDate: timestamppb.New(expiration),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS,
},
},
}},
want: mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampEquals),
},
{
name: "expiration date query with comparison method less",
args: args{
context.Background(),
&session.SearchQuery{
Query: &session.SearchQuery_ExpirationDateQuery{
ExpirationDateQuery: &session.ExpirationDateQuery{
ExpirationDate: timestamppb.New(expiration),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS,
},
},
}},
want: mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampLess),
},
{
name: "expiration date query with comparison method less or equals",
args: args{
context.Background(),
&session.SearchQuery{
Query: &session.SearchQuery_ExpirationDateQuery{
ExpirationDateQuery: &session.ExpirationDateQuery{
ExpirationDate: timestamppb.New(expiration),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS,
},
},
}},
want: mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampLessOrEquals),
},
{
name: "expiration date query with with comparison method greater",
args: args{
context.Background(),
&session.SearchQuery{
Query: &session.SearchQuery_ExpirationDateQuery{
ExpirationDateQuery: &session.ExpirationDateQuery{
ExpirationDate: timestamppb.New(expiration),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER,
},
},
}},
want: mustNewOrQuery(t, mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampGreater),
mustNewIsNullQuery(t, query.SessionColumnExpiration)),
},
{
name: "expiration date query with with comparison method greater or equals",
args: args{
context.Background(),
&session.SearchQuery{
Query: &session.SearchQuery_ExpirationDateQuery{
ExpirationDateQuery: &session.ExpirationDateQuery{
ExpirationDate: timestamppb.New(expiration),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS,
},
},
}},
want: mustNewOrQuery(t, mustNewTimestampQuery(t, query.SessionColumnExpiration, expiration, query.TimestampGreaterOrEquals),
mustNewIsNullQuery(t, query.SessionColumnExpiration)),
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View File

@@ -346,7 +346,9 @@ func (c *Commands) terminateSession(ctx context.Context, sessionID, sessionToken
return nil, err return nil, err
} }
} }
if sessionWriteModel.CheckIsActive() != nil {
// exclude expiration check as expired tokens can be deleted
if sessionWriteModel.State == domain.SessionStateUnspecified || sessionWriteModel.State == domain.SessionStateTerminated {
return writeModelToObjectDetails(&sessionWriteModel.WriteModel), nil return writeModelToObjectDetails(&sessionWriteModel.WriteModel), nil
} }
terminate := session.NewTerminateEvent(ctx, &session.NewAggregate(sessionWriteModel.AggregateID, sessionWriteModel.ResourceOwner).Aggregate) terminate := session.NewTerminateEvent(ctx, &session.NewAggregate(sessionWriteModel.AggregateID, sessionWriteModel.ResourceOwner).Aggregate)

View File

@@ -675,6 +675,9 @@ type TimestampQuery struct {
} }
func NewTimestampQuery(c Column, value time.Time, compare TimestampComparison) (*TimestampQuery, error) { func NewTimestampQuery(c Column, value time.Time, compare TimestampComparison) (*TimestampQuery, error) {
if c.isZero() {
return nil, ErrMissingColumn
}
return &TimestampQuery{ return &TimestampQuery{
Column: c, Column: c,
Compare: compare, Compare: compare,

View File

@@ -362,6 +362,10 @@ func NewCreationDateQuery(datetime time.Time, compare TimestampComparison) (Sear
return NewTimestampQuery(SessionColumnCreationDate, datetime, compare) return NewTimestampQuery(SessionColumnCreationDate, datetime, compare)
} }
func NewExpirationDateQuery(datetime time.Time, compare TimestampComparison) (SearchQuery, error) {
return NewTimestampQuery(SessionColumnExpiration, datetime, compare)
}
func prepareSessionQuery() (sq.SelectBuilder, func(*sql.Row) (*Session, string, error)) { func prepareSessionQuery() (sq.SelectBuilder, func(*sql.Row) (*Session, string, error)) {
return sq.Select( return sq.Select(
SessionColumnID.identifier(), SessionColumnID.identifier(),

View File

@@ -138,6 +138,7 @@ message SearchQuery {
CreationDateQuery creation_date_query = 3; CreationDateQuery creation_date_query = 3;
CreatorQuery creator_query = 4; CreatorQuery creator_query = 4;
UserAgentQuery user_agent_query = 5; UserAgentQuery user_agent_query = 5;
ExpirationDateQuery expiration_date_query = 6;
} }
} }
@@ -183,6 +184,16 @@ message UserAgentQuery {
]; ];
} }
message ExpirationDateQuery {
google.protobuf.Timestamp expiration_date = 1;
zitadel.v1.TimestampQueryMethod method = 2 [
(validate.rules).enum.defined_only = true,
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "defines which timestamp comparison method is used";
}
];
}
message UserAgent { message UserAgent {
optional string fingerprint_id = 1; optional string fingerprint_id = 1;
optional string ip = 2; optional string ip = 2;