From 9a708b1b7886f8048b6f1316dd752b880b284f37 Mon Sep 17 00:00:00 2001 From: sp132 <125546043+sp132@users.noreply.github.com> Date: Wed, 8 Nov 2023 12:32:13 +0200 Subject: [PATCH] feat: extend session search service (#6746) * feat: extend session search service (#6029) add two more searching criteria - human user id and session creation date optional sorting by the session creation date * fix: use correct column identifier * fix: implement Col() * chore: fix unit tests * chore: fix linter warnings --------- Co-authored-by: Fabi --- internal/api/grpc/session/v2/session.go | 44 +++++- internal/api/grpc/session/v2/session_test.go | 81 ++++++++++- internal/query/search_query.go | 129 +++++++++++++----- internal/query/session.go | 11 +- proto/zitadel/object.proto | 9 +- proto/zitadel/session/v2beta/session.proto | 22 +++ .../session/v2beta/session_service.proto | 1 + 7 files changed, 251 insertions(+), 46 deletions(-) diff --git a/internal/api/grpc/session/v2/session.go b/internal/api/grpc/session/v2/session.go index ec6c06a4bd..a3aa481010 100644 --- a/internal/api/grpc/session/v2/session.go +++ b/internal/api/grpc/session/v2/session.go @@ -18,9 +18,20 @@ import ( "github.com/zitadel/zitadel/internal/domain" caos_errs "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/query" + objpb "github.com/zitadel/zitadel/pkg/grpc/object" session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" ) +var ( + timestampComparisons = map[objpb.TimestampQueryMethod]query.TimestampComparison{ + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS: query.TimestampEquals, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER: query.TimestampGreater, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS: query.TimestampGreaterOrEquals, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS: query.TimestampLess, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS: query.TimestampLessOrEquals, + } +) + func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken()) if err != nil { @@ -240,9 +251,10 @@ func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRe } return &query.SessionsSearchQueries{ SearchRequest: query.SearchRequest{ - Offset: offset, - Limit: limit, - Asc: asc, + Offset: offset, + Limit: limit, + Asc: asc, + SortingColumn: fieldNameToSessionColumn(req.GetSortingColumn()), }, Queries: queries, }, nil @@ -250,8 +262,8 @@ func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRe func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) (_ []query.SearchQuery, err error) { q := make([]query.SearchQuery, len(queries)+1) - for i, query := range queries { - q[i], err = sessionQueryToQuery(query) + for i, v := range queries { + q[i], err = sessionQueryToQuery(v) if err != nil { return nil, err } @@ -264,10 +276,14 @@ func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) return q, nil } -func sessionQueryToQuery(query *session.SearchQuery) (query.SearchQuery, error) { - switch q := query.Query.(type) { +func sessionQueryToQuery(sq *session.SearchQuery) (query.SearchQuery, error) { + switch q := sq.Query.(type) { case *session.SearchQuery_IdsQuery: return idsQueryToQuery(q.IdsQuery) + case *session.SearchQuery_UserIdQuery: + return query.NewUserIDSearchQuery(q.UserIdQuery.GetId()) + case *session.SearchQuery_CreationDateQuery: + return creationDateQueryToQuery(q.CreationDateQuery) default: return nil, caos_errs.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid") } @@ -277,6 +293,20 @@ func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) { return query.NewSessionIDsSearchQuery(q.Ids) } +func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery, error) { + comparison := timestampComparisons[q.GetMethod()] + return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison) +} + +func fieldNameToSessionColumn(field session.SessionFieldName) query.Column { + switch field { + case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE: + return query.SessionColumnCreationDate + default: + return query.Column{} + } +} + func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCommand, map[string][]byte, *domain.UserAgent, time.Duration, error) { checks, err := s.checksToCommand(ctx, req.Checks) if err != nil { diff --git a/internal/api/grpc/session/v2/session_test.go b/internal/api/grpc/session/v2/session_test.go index 0f243bab2f..18f2cc4e67 100644 --- a/internal/api/grpc/session/v2/session_test.go +++ b/internal/api/grpc/session/v2/session_test.go @@ -14,6 +14,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/api/authz" + objpb "github.com/zitadel/zitadel/pkg/grpc/object" "github.com/zitadel/zitadel/internal/domain" caos_errs "github.com/zitadel/zitadel/internal/errors" @@ -22,6 +23,10 @@ import ( session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" ) +var ( + creationDate = time.Date(2023, 10, 10, 14, 15, 0, 0, time.UTC) +) + func Test_sessionsToPb(t *testing.T) { now := time.Now() past := now.Add(-time.Hour) @@ -309,11 +314,18 @@ func mustNewListQuery(t testing.TB, column query.Column, list []any, compare que return q } +func mustNewTimestampQuery(t testing.TB, column query.Column, ts time.Time, compare query.TimestampComparison) query.SearchQuery { + q, err := query.NewTimestampQuery(column, ts, compare) + require.NoError(t, err) + return q +} + func Test_listSessionsRequestToQuery(t *testing.T) { type args struct { ctx context.Context req *session.ListSessionsRequest } + tests := []struct { name string args args @@ -337,6 +349,26 @@ func Test_listSessionsRequestToQuery(t *testing.T) { }, }, }, + { + name: "default request with sorting column", + args: args{ + ctx: authz.NewMockContext("123", "456", "789"), + req: &session.ListSessionsRequest{ + SortingColumn: session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE, + }, + }, + want: &query.SessionsSearchQueries{ + SearchRequest: query.SearchRequest{ + Offset: 0, + Limit: 0, + SortingColumn: query.SessionColumnCreationDate, + Asc: false, + }, + Queries: []query.SearchQuery{ + mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), + }, + }, + }, { name: "with list query and sessions", args: args{ @@ -358,6 +390,17 @@ func Test_listSessionsRequestToQuery(t *testing.T) { Ids: []string{"4", "5", "6"}, }, }}, + {Query: &session.SearchQuery_UserIdQuery{ + UserIdQuery: &session.UserIDQuery{ + Id: "10", + }, + }}, + {Query: &session.SearchQuery_CreationDateQuery{ + CreationDateQuery: &session.CreationDateQuery{ + CreationDate: timestamppb.New(creationDate), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER, + }, + }}, }, }, }, @@ -370,6 +413,8 @@ func Test_listSessionsRequestToQuery(t *testing.T) { Queries: []query.SearchQuery{ mustNewListQuery(t, query.SessionColumnID, []interface{}{"1", "2", "3"}, query.ListIn), mustNewListQuery(t, query.SessionColumnID, []interface{}{"4", "5", "6"}, query.ListIn), + mustNewTextQuery(t, query.SessionColumnUserID, "10", query.TextEquals), + mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampGreater), mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), }, }, @@ -487,7 +532,7 @@ func Test_sessionQueryToQuery(t *testing.T) { wantErr: caos_errs.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid"), }, { - name: "query", + name: "ids query", args: args{&session.SearchQuery{ Query: &session.SearchQuery_IdsQuery{ IdsQuery: &session.IDsQuery{ @@ -497,6 +542,40 @@ func Test_sessionQueryToQuery(t *testing.T) { }}, want: mustNewListQuery(t, query.SessionColumnID, []interface{}{"1", "2", "3"}, query.ListIn), }, + { + name: "user id query", + args: args{&session.SearchQuery{ + Query: &session.SearchQuery_UserIdQuery{ + UserIdQuery: &session.UserIDQuery{ + Id: "10", + }, + }, + }}, + want: mustNewTextQuery(t, query.SessionColumnUserID, "10", query.TextEquals), + }, + { + name: "creation date query", + args: args{&session.SearchQuery{ + Query: &session.SearchQuery_CreationDateQuery{ + CreationDateQuery: &session.CreationDateQuery{ + CreationDate: timestamppb.New(creationDate), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS, + }, + }, + }}, + want: mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampLess), + }, + { + name: "creation date query with default method", + args: args{&session.SearchQuery{ + Query: &session.SearchQuery_CreationDateQuery{ + CreationDateQuery: &session.CreationDateQuery{ + CreationDate: timestamppb.New(creationDate), + }, + }, + }}, + want: mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampEquals), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/query/search_query.go b/internal/query/search_query.go index ce9d31238f..73713823f7 100644 --- a/internal/query/search_query.go +++ b/internal/query/search_query.go @@ -3,6 +3,7 @@ package query import ( "errors" "reflect" + "time" sq "github.com/Masterminds/squirrel" @@ -231,36 +232,41 @@ func (q *InTextQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder { return query.Where(q.comp()) } -func (s *InTextQuery) comp() sq.Sqlizer { +func (q *InTextQuery) comp() sq.Sqlizer { // This translates to an IN query - return sq.Eq{s.Column.identifier(): s.Values} + return sq.Eq{q.Column.identifier(): q.Values} } func (q *TextQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder { return query.Where(q.comp()) } -func (s *TextQuery) comp() sq.Sqlizer { - switch s.Compare { +func (q *TextQuery) comp() sq.Sqlizer { + switch q.Compare { case TextEquals: - return sq.Eq{s.Column.identifier(): s.Text} + return sq.Eq{q.Column.identifier(): q.Text} + case TextNotEquals: + return sq.NotEq{q.Column.identifier(): q.Text} case TextEqualsIgnoreCase: - return sq.ILike{s.Column.identifier(): s.Text} + return sq.ILike{q.Column.identifier(): q.Text} case TextStartsWith: - return sq.Like{s.Column.identifier(): s.Text + "%"} + return sq.Like{q.Column.identifier(): q.Text + "%"} case TextStartsWithIgnoreCase: - return sq.ILike{s.Column.identifier(): s.Text + "%"} + return sq.ILike{q.Column.identifier(): q.Text + "%"} case TextEndsWith: - return sq.Like{s.Column.identifier(): "%" + s.Text} + return sq.Like{q.Column.identifier(): "%" + q.Text} case TextEndsWithIgnoreCase: - return sq.ILike{s.Column.identifier(): "%" + s.Text} + return sq.ILike{q.Column.identifier(): "%" + q.Text} case TextContains: - return sq.Like{s.Column.identifier(): "%" + s.Text + "%"} + return sq.Like{q.Column.identifier(): "%" + q.Text + "%"} case TextContainsIgnoreCase: - return sq.ILike{s.Column.identifier(): "%" + s.Text + "%"} + return sq.ILike{q.Column.identifier(): "%" + q.Text + "%"} case TextListContains: - return &listContains{col: s.Column, args: []interface{}{s.Text}} + return &listContains{col: q.Column, args: []interface{}{q.Text}} + case textCompareMax: + return nil } + return nil } @@ -341,19 +347,22 @@ func (q *NumberQuery) Col() Column { return q.Column } -func (s *NumberQuery) comp() sq.Sqlizer { - switch s.Compare { +func (q *NumberQuery) comp() sq.Sqlizer { + switch q.Compare { case NumberEquals: - return sq.Eq{s.Column.identifier(): s.Number} + return sq.Eq{q.Column.identifier(): q.Number} case NumberNotEquals: - return sq.NotEq{s.Column.identifier(): s.Number} + return sq.NotEq{q.Column.identifier(): q.Number} case NumberLess: - return sq.Lt{s.Column.identifier(): s.Number} + return sq.Lt{q.Column.identifier(): q.Number} case NumberGreater: - return sq.Gt{s.Column.identifier(): s.Number} + return sq.Gt{q.Column.identifier(): q.Number} case NumberListContains: - return &listContains{col: s.Column, args: []interface{}{s.Number}} + return &listContains{col: q.Column, args: []interface{}{q.Number}} + case numberCompareMax: + return nil } + return nil } @@ -442,19 +451,19 @@ func (q *ListQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder { return query.Where(q.comp()) } -func (s *ListQuery) comp() sq.Sqlizer { - switch s.Compare { - case ListIn: - if subSelect, ok := s.Data.(*SubSelect); ok { - subSelect, args, err := subSelect.comp().ToSql() - if err != nil { - return nil - } - return sq.Expr(s.Column.identifier()+" IN ( "+subSelect+" )", args...) - } - return sq.Eq{s.Column.identifier(): s.Data} +func (q *ListQuery) comp() sq.Sqlizer { + if q.Compare != ListIn { + return nil } - return nil + + if subSelect, ok := q.Data.(*SubSelect); ok { + subSelect, args, err := subSelect.comp().ToSql() + if err != nil { + return nil + } + return sq.Expr(q.Column.identifier()+" IN ( "+subSelect+" )", args...) + } + return sq.Eq{q.Column.identifier(): q.Data} } func (q *ListQuery) Col() Column { @@ -524,16 +533,64 @@ func (q *BoolQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder { return query.Where(q.comp()) } -func (s *BoolQuery) comp() sq.Sqlizer { - return sq.Eq{s.Column.identifier(): s.Value} +func (q *BoolQuery) comp() sq.Sqlizer { + return sq.Eq{q.Column.identifier(): q.Value} +} + +type TimestampComparison int + +const ( + TimestampEquals TimestampComparison = iota + TimestampGreater + TimestampGreaterOrEquals + TimestampLess + TimestampLessOrEquals +) + +type TimestampQuery struct { + Column Column + Compare TimestampComparison + Value time.Time +} + +func NewTimestampQuery(c Column, value time.Time, compare TimestampComparison) (*TimestampQuery, error) { + return &TimestampQuery{ + Column: c, + Compare: compare, + Value: value, + }, nil +} + +func (q *TimestampQuery) Col() Column { + return q.Column +} + +func (q *TimestampQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder { + return query.Where(q.comp()) +} + +func (q *TimestampQuery) comp() sq.Sqlizer { + switch q.Compare { + case TimestampEquals: + return sq.Eq{q.Column.identifier(): q.Value} + case TimestampGreater: + return sq.Gt{q.Column.identifier(): q.Value} + case TimestampGreaterOrEquals: + return sq.GtOrEq{q.Column.identifier(): q.Value} + case TimestampLess: + return sq.Lt{q.Column.identifier(): q.Value} + case TimestampLessOrEquals: + return sq.LtOrEq{q.Column.identifier(): q.Value} + } + return nil } var ( - //countColumn represents the default counter for search responses + // countColumn represents the default counter for search responses countColumn = Column{ name: "COUNT(*) OVER ()", } - //uniqueColumn shows if there are any results + // uniqueColumn shows if there are any results uniqueColumn = Column{ name: "COUNT(*) = 0", } diff --git a/internal/query/session.go b/internal/query/session.go index 75a31ab27f..acf5ac3bd5 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -239,7 +239,8 @@ func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQue stmt, args, err := queries.toQuery(query). Where(sq.Eq{ SessionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), - }).ToSql() + }). + ToSql() if err != nil { return nil, errors.ThrowInvalidArgument(err, "QUERY-sn9Jf", "Errors.Query.InvalidRequest") } @@ -268,6 +269,14 @@ func NewSessionCreatorSearchQuery(creator string) (SearchQuery, error) { return NewTextQuery(SessionColumnCreator, creator, TextEquals) } +func NewUserIDSearchQuery(id string) (SearchQuery, error) { + return NewTextQuery(SessionColumnUserID, id, TextEquals) +} + +func NewCreationDateQuery(datetime time.Time, compare TimestampComparison) (SearchQuery, error) { + return NewTimestampQuery(SessionColumnCreationDate, datetime, compare) +} + func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Session, string, error)) { return sq.Select( SessionColumnID.identifier(), diff --git a/proto/zitadel/object.proto b/proto/zitadel/object.proto index 95a3ba45eb..7d4f189e52 100644 --- a/proto/zitadel/object.proto +++ b/proto/zitadel/object.proto @@ -92,7 +92,14 @@ enum TextQueryMethod { TEXT_QUERY_METHOD_ENDS_WITH_IGNORE_CASE = 7; } - enum ListQueryMethod { LIST_QUERY_METHOD_IN = 0; } + +enum TimestampQueryMethod { + TIMESTAMP_QUERY_METHOD_EQUALS = 0; + TIMESTAMP_QUERY_METHOD_GREATER = 1; + TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS = 2; + TIMESTAMP_QUERY_METHOD_LESS = 3; + TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS = 4; +} diff --git a/proto/zitadel/session/v2beta/session.proto b/proto/zitadel/session/v2beta/session.proto index 5de49c1855..09f3c6852c 100644 --- a/proto/zitadel/session/v2beta/session.proto +++ b/proto/zitadel/session/v2beta/session.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package zitadel.session.v2beta; +import "zitadel/object.proto"; import "google/protobuf/timestamp.proto"; import "protoc-gen-openapiv2/options/annotations.proto"; import "validate/validate.proto"; @@ -137,6 +138,8 @@ message SearchQuery { option (validate.required) = true; IDsQuery ids_query = 1; + UserIDQuery user_id_query = 2; + CreationDateQuery creation_date_query = 3; } } @@ -144,6 +147,20 @@ message IDsQuery { repeated string ids = 1; } +message UserIDQuery { + string id = 1; +} + +message CreationDateQuery { + google.protobuf.Timestamp creation_date = 1; + zitadel.v1.TimestampQueryMethod method = 2 [ + (validate.rules).enum.defined_only = true, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "defines which timestamp comparison method is used"; + } + ]; +} + message UserAgent { optional string fingerprint_id = 1; optional string ip = 2; @@ -157,4 +174,9 @@ message UserAgent { repeated string values = 1; } map header = 4; +} + +enum SessionFieldName { + SESSION_FIELD_NAME_UNSPECIFIED = 0; + SESSION_FIELD_NAME_CREATION_DATE = 1; } \ No newline at end of file diff --git a/proto/zitadel/session/v2beta/session_service.proto b/proto/zitadel/session/v2beta/session_service.proto index 714e7779d0..7a3a42a400 100644 --- a/proto/zitadel/session/v2beta/session_service.proto +++ b/proto/zitadel/session/v2beta/session_service.proto @@ -248,6 +248,7 @@ service SessionService { message ListSessionsRequest{ zitadel.object.v2beta.ListQuery query = 1; repeated SearchQuery queries = 2; + zitadel.session.v2beta.SessionFieldName sorting_column = 3; } message ListSessionsResponse{