feat: extend session search service (#6746)

* feat: extend session search service (#6029)

add two more searching criteria - human user id and session creation date

optional sorting by the session creation date

* fix: use correct column identifier

* fix: implement Col()

* chore: fix unit tests

* chore: fix linter warnings

---------

Co-authored-by: Fabi <fabienne@zitadel.com>
This commit is contained in:
sp132 2023-11-08 12:32:13 +02:00 committed by GitHub
parent 0d3788b757
commit 9a708b1b78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 251 additions and 46 deletions

View File

@ -18,9 +18,20 @@ import (
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
objpb "github.com/zitadel/zitadel/pkg/grpc/object"
session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta"
)
var (
timestampComparisons = map[objpb.TimestampQueryMethod]query.TimestampComparison{
objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS: query.TimestampEquals,
objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER: query.TimestampGreater,
objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS: query.TimestampGreaterOrEquals,
objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS: query.TimestampLess,
objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS: query.TimestampLessOrEquals,
}
)
func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) {
res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken())
if err != nil {
@ -240,9 +251,10 @@ func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRe
}
return &query.SessionsSearchQueries{
SearchRequest: query.SearchRequest{
Offset: offset,
Limit: limit,
Asc: asc,
Offset: offset,
Limit: limit,
Asc: asc,
SortingColumn: fieldNameToSessionColumn(req.GetSortingColumn()),
},
Queries: queries,
}, nil
@ -250,8 +262,8 @@ func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRe
func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) (_ []query.SearchQuery, err error) {
q := make([]query.SearchQuery, len(queries)+1)
for i, query := range queries {
q[i], err = sessionQueryToQuery(query)
for i, v := range queries {
q[i], err = sessionQueryToQuery(v)
if err != nil {
return nil, err
}
@ -264,10 +276,14 @@ func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery)
return q, nil
}
func sessionQueryToQuery(query *session.SearchQuery) (query.SearchQuery, error) {
switch q := query.Query.(type) {
func sessionQueryToQuery(sq *session.SearchQuery) (query.SearchQuery, error) {
switch q := sq.Query.(type) {
case *session.SearchQuery_IdsQuery:
return idsQueryToQuery(q.IdsQuery)
case *session.SearchQuery_UserIdQuery:
return query.NewUserIDSearchQuery(q.UserIdQuery.GetId())
case *session.SearchQuery_CreationDateQuery:
return creationDateQueryToQuery(q.CreationDateQuery)
default:
return nil, caos_errs.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid")
}
@ -277,6 +293,20 @@ func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) {
return query.NewSessionIDsSearchQuery(q.Ids)
}
func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery, error) {
comparison := timestampComparisons[q.GetMethod()]
return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison)
}
func fieldNameToSessionColumn(field session.SessionFieldName) query.Column {
switch field {
case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE:
return query.SessionColumnCreationDate
default:
return query.Column{}
}
}
func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCommand, map[string][]byte, *domain.UserAgent, time.Duration, error) {
checks, err := s.checksToCommand(ctx, req.Checks)
if err != nil {

View File

@ -14,6 +14,7 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/authz"
objpb "github.com/zitadel/zitadel/pkg/grpc/object"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
@ -22,6 +23,10 @@ import (
session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta"
)
var (
creationDate = time.Date(2023, 10, 10, 14, 15, 0, 0, time.UTC)
)
func Test_sessionsToPb(t *testing.T) {
now := time.Now()
past := now.Add(-time.Hour)
@ -309,11 +314,18 @@ func mustNewListQuery(t testing.TB, column query.Column, list []any, compare que
return q
}
func mustNewTimestampQuery(t testing.TB, column query.Column, ts time.Time, compare query.TimestampComparison) query.SearchQuery {
q, err := query.NewTimestampQuery(column, ts, compare)
require.NoError(t, err)
return q
}
func Test_listSessionsRequestToQuery(t *testing.T) {
type args struct {
ctx context.Context
req *session.ListSessionsRequest
}
tests := []struct {
name string
args args
@ -337,6 +349,26 @@ func Test_listSessionsRequestToQuery(t *testing.T) {
},
},
},
{
name: "default request with sorting column",
args: args{
ctx: authz.NewMockContext("123", "456", "789"),
req: &session.ListSessionsRequest{
SortingColumn: session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE,
},
},
want: &query.SessionsSearchQueries{
SearchRequest: query.SearchRequest{
Offset: 0,
Limit: 0,
SortingColumn: query.SessionColumnCreationDate,
Asc: false,
},
Queries: []query.SearchQuery{
mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals),
},
},
},
{
name: "with list query and sessions",
args: args{
@ -358,6 +390,17 @@ func Test_listSessionsRequestToQuery(t *testing.T) {
Ids: []string{"4", "5", "6"},
},
}},
{Query: &session.SearchQuery_UserIdQuery{
UserIdQuery: &session.UserIDQuery{
Id: "10",
},
}},
{Query: &session.SearchQuery_CreationDateQuery{
CreationDateQuery: &session.CreationDateQuery{
CreationDate: timestamppb.New(creationDate),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER,
},
}},
},
},
},
@ -370,6 +413,8 @@ func Test_listSessionsRequestToQuery(t *testing.T) {
Queries: []query.SearchQuery{
mustNewListQuery(t, query.SessionColumnID, []interface{}{"1", "2", "3"}, query.ListIn),
mustNewListQuery(t, query.SessionColumnID, []interface{}{"4", "5", "6"}, query.ListIn),
mustNewTextQuery(t, query.SessionColumnUserID, "10", query.TextEquals),
mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampGreater),
mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals),
},
},
@ -487,7 +532,7 @@ func Test_sessionQueryToQuery(t *testing.T) {
wantErr: caos_errs.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid"),
},
{
name: "query",
name: "ids query",
args: args{&session.SearchQuery{
Query: &session.SearchQuery_IdsQuery{
IdsQuery: &session.IDsQuery{
@ -497,6 +542,40 @@ func Test_sessionQueryToQuery(t *testing.T) {
}},
want: mustNewListQuery(t, query.SessionColumnID, []interface{}{"1", "2", "3"}, query.ListIn),
},
{
name: "user id query",
args: args{&session.SearchQuery{
Query: &session.SearchQuery_UserIdQuery{
UserIdQuery: &session.UserIDQuery{
Id: "10",
},
},
}},
want: mustNewTextQuery(t, query.SessionColumnUserID, "10", query.TextEquals),
},
{
name: "creation date query",
args: args{&session.SearchQuery{
Query: &session.SearchQuery_CreationDateQuery{
CreationDateQuery: &session.CreationDateQuery{
CreationDate: timestamppb.New(creationDate),
Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS,
},
},
}},
want: mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampLess),
},
{
name: "creation date query with default method",
args: args{&session.SearchQuery{
Query: &session.SearchQuery_CreationDateQuery{
CreationDateQuery: &session.CreationDateQuery{
CreationDate: timestamppb.New(creationDate),
},
},
}},
want: mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampEquals),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@ -3,6 +3,7 @@ package query
import (
"errors"
"reflect"
"time"
sq "github.com/Masterminds/squirrel"
@ -231,36 +232,41 @@ func (q *InTextQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
return query.Where(q.comp())
}
func (s *InTextQuery) comp() sq.Sqlizer {
func (q *InTextQuery) comp() sq.Sqlizer {
// This translates to an IN query
return sq.Eq{s.Column.identifier(): s.Values}
return sq.Eq{q.Column.identifier(): q.Values}
}
func (q *TextQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
return query.Where(q.comp())
}
func (s *TextQuery) comp() sq.Sqlizer {
switch s.Compare {
func (q *TextQuery) comp() sq.Sqlizer {
switch q.Compare {
case TextEquals:
return sq.Eq{s.Column.identifier(): s.Text}
return sq.Eq{q.Column.identifier(): q.Text}
case TextNotEquals:
return sq.NotEq{q.Column.identifier(): q.Text}
case TextEqualsIgnoreCase:
return sq.ILike{s.Column.identifier(): s.Text}
return sq.ILike{q.Column.identifier(): q.Text}
case TextStartsWith:
return sq.Like{s.Column.identifier(): s.Text + "%"}
return sq.Like{q.Column.identifier(): q.Text + "%"}
case TextStartsWithIgnoreCase:
return sq.ILike{s.Column.identifier(): s.Text + "%"}
return sq.ILike{q.Column.identifier(): q.Text + "%"}
case TextEndsWith:
return sq.Like{s.Column.identifier(): "%" + s.Text}
return sq.Like{q.Column.identifier(): "%" + q.Text}
case TextEndsWithIgnoreCase:
return sq.ILike{s.Column.identifier(): "%" + s.Text}
return sq.ILike{q.Column.identifier(): "%" + q.Text}
case TextContains:
return sq.Like{s.Column.identifier(): "%" + s.Text + "%"}
return sq.Like{q.Column.identifier(): "%" + q.Text + "%"}
case TextContainsIgnoreCase:
return sq.ILike{s.Column.identifier(): "%" + s.Text + "%"}
return sq.ILike{q.Column.identifier(): "%" + q.Text + "%"}
case TextListContains:
return &listContains{col: s.Column, args: []interface{}{s.Text}}
return &listContains{col: q.Column, args: []interface{}{q.Text}}
case textCompareMax:
return nil
}
return nil
}
@ -341,19 +347,22 @@ func (q *NumberQuery) Col() Column {
return q.Column
}
func (s *NumberQuery) comp() sq.Sqlizer {
switch s.Compare {
func (q *NumberQuery) comp() sq.Sqlizer {
switch q.Compare {
case NumberEquals:
return sq.Eq{s.Column.identifier(): s.Number}
return sq.Eq{q.Column.identifier(): q.Number}
case NumberNotEquals:
return sq.NotEq{s.Column.identifier(): s.Number}
return sq.NotEq{q.Column.identifier(): q.Number}
case NumberLess:
return sq.Lt{s.Column.identifier(): s.Number}
return sq.Lt{q.Column.identifier(): q.Number}
case NumberGreater:
return sq.Gt{s.Column.identifier(): s.Number}
return sq.Gt{q.Column.identifier(): q.Number}
case NumberListContains:
return &listContains{col: s.Column, args: []interface{}{s.Number}}
return &listContains{col: q.Column, args: []interface{}{q.Number}}
case numberCompareMax:
return nil
}
return nil
}
@ -442,19 +451,19 @@ func (q *ListQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
return query.Where(q.comp())
}
func (s *ListQuery) comp() sq.Sqlizer {
switch s.Compare {
case ListIn:
if subSelect, ok := s.Data.(*SubSelect); ok {
subSelect, args, err := subSelect.comp().ToSql()
if err != nil {
return nil
}
return sq.Expr(s.Column.identifier()+" IN ( "+subSelect+" )", args...)
}
return sq.Eq{s.Column.identifier(): s.Data}
func (q *ListQuery) comp() sq.Sqlizer {
if q.Compare != ListIn {
return nil
}
return nil
if subSelect, ok := q.Data.(*SubSelect); ok {
subSelect, args, err := subSelect.comp().ToSql()
if err != nil {
return nil
}
return sq.Expr(q.Column.identifier()+" IN ( "+subSelect+" )", args...)
}
return sq.Eq{q.Column.identifier(): q.Data}
}
func (q *ListQuery) Col() Column {
@ -524,16 +533,64 @@ func (q *BoolQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
return query.Where(q.comp())
}
func (s *BoolQuery) comp() sq.Sqlizer {
return sq.Eq{s.Column.identifier(): s.Value}
func (q *BoolQuery) comp() sq.Sqlizer {
return sq.Eq{q.Column.identifier(): q.Value}
}
type TimestampComparison int
const (
TimestampEquals TimestampComparison = iota
TimestampGreater
TimestampGreaterOrEquals
TimestampLess
TimestampLessOrEquals
)
type TimestampQuery struct {
Column Column
Compare TimestampComparison
Value time.Time
}
func NewTimestampQuery(c Column, value time.Time, compare TimestampComparison) (*TimestampQuery, error) {
return &TimestampQuery{
Column: c,
Compare: compare,
Value: value,
}, nil
}
func (q *TimestampQuery) Col() Column {
return q.Column
}
func (q *TimestampQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
return query.Where(q.comp())
}
func (q *TimestampQuery) comp() sq.Sqlizer {
switch q.Compare {
case TimestampEquals:
return sq.Eq{q.Column.identifier(): q.Value}
case TimestampGreater:
return sq.Gt{q.Column.identifier(): q.Value}
case TimestampGreaterOrEquals:
return sq.GtOrEq{q.Column.identifier(): q.Value}
case TimestampLess:
return sq.Lt{q.Column.identifier(): q.Value}
case TimestampLessOrEquals:
return sq.LtOrEq{q.Column.identifier(): q.Value}
}
return nil
}
var (
//countColumn represents the default counter for search responses
// countColumn represents the default counter for search responses
countColumn = Column{
name: "COUNT(*) OVER ()",
}
//uniqueColumn shows if there are any results
// uniqueColumn shows if there are any results
uniqueColumn = Column{
name: "COUNT(*) = 0",
}

View File

@ -239,7 +239,8 @@ func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQue
stmt, args, err := queries.toQuery(query).
Where(sq.Eq{
SessionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).ToSql()
}).
ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-sn9Jf", "Errors.Query.InvalidRequest")
}
@ -268,6 +269,14 @@ func NewSessionCreatorSearchQuery(creator string) (SearchQuery, error) {
return NewTextQuery(SessionColumnCreator, creator, TextEquals)
}
func NewUserIDSearchQuery(id string) (SearchQuery, error) {
return NewTextQuery(SessionColumnUserID, id, TextEquals)
}
func NewCreationDateQuery(datetime time.Time, compare TimestampComparison) (SearchQuery, error) {
return NewTimestampQuery(SessionColumnCreationDate, datetime, compare)
}
func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Session, string, error)) {
return sq.Select(
SessionColumnID.identifier(),

View File

@ -92,7 +92,14 @@ enum TextQueryMethod {
TEXT_QUERY_METHOD_ENDS_WITH_IGNORE_CASE = 7;
}
enum ListQueryMethod {
LIST_QUERY_METHOD_IN = 0;
}
enum TimestampQueryMethod {
TIMESTAMP_QUERY_METHOD_EQUALS = 0;
TIMESTAMP_QUERY_METHOD_GREATER = 1;
TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS = 2;
TIMESTAMP_QUERY_METHOD_LESS = 3;
TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS = 4;
}

View File

@ -2,6 +2,7 @@ syntax = "proto3";
package zitadel.session.v2beta;
import "zitadel/object.proto";
import "google/protobuf/timestamp.proto";
import "protoc-gen-openapiv2/options/annotations.proto";
import "validate/validate.proto";
@ -137,6 +138,8 @@ message SearchQuery {
option (validate.required) = true;
IDsQuery ids_query = 1;
UserIDQuery user_id_query = 2;
CreationDateQuery creation_date_query = 3;
}
}
@ -144,6 +147,20 @@ message IDsQuery {
repeated string ids = 1;
}
message UserIDQuery {
string id = 1;
}
message CreationDateQuery {
google.protobuf.Timestamp creation_date = 1;
zitadel.v1.TimestampQueryMethod method = 2 [
(validate.rules).enum.defined_only = true,
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "defines which timestamp comparison method is used";
}
];
}
message UserAgent {
optional string fingerprint_id = 1;
optional string ip = 2;
@ -158,3 +175,8 @@ message UserAgent {
}
map<string,HeaderValues> header = 4;
}
enum SessionFieldName {
SESSION_FIELD_NAME_UNSPECIFIED = 0;
SESSION_FIELD_NAME_CREATION_DATE = 1;
}

View File

@ -248,6 +248,7 @@ service SessionService {
message ListSessionsRequest{
zitadel.object.v2beta.ListQuery query = 1;
repeated SearchQuery queries = 2;
zitadel.session.v2beta.SessionFieldName sorting_column = 3;
}
message ListSessionsResponse{