From 3da2db0851d5d1bd0c20e6e13eea27c4f655667c Mon Sep 17 00:00:00 2001 From: Stefan Benz <46600784+stebenz@users.noreply.github.com> Date: Fri, 27 Dec 2024 17:01:37 +0100 Subject: [PATCH] fix: add search queries to ListSessions and remodel integration tests --- cmd/start/start.go | 4 +- .../session/v2/integration_test/query_test.go | 233 +++++++++++++++++ .../v2/integration_test/server_test.go | 74 ++++++ .../v2/integration_test/session_test.go | 137 ++-------- internal/api/grpc/session/v2/query.go | 242 ++++++++++++++++++ internal/api/grpc/session/v2/server.go | 9 +- internal/api/grpc/session/v2/session.go | 230 ----------------- internal/api/grpc/session/v2beta/server.go | 9 +- internal/api/grpc/session/v2beta/session.go | 4 +- .../eventstore/token_verifier.go | 2 +- internal/notification/handlers/queries.go | 2 +- .../notification/handlers/user_notifier.go | 4 +- .../handlers/user_notifier_legacy.go | 4 +- internal/query/session.go | 68 ++++- proto/zitadel/session/v2/session.proto | 10 +- 15 files changed, 663 insertions(+), 369 deletions(-) create mode 100644 internal/api/grpc/session/v2/integration_test/query_test.go create mode 100644 internal/api/grpc/session/v2/integration_test/server_test.go create mode 100644 internal/api/grpc/session/v2/query.go diff --git a/cmd/start/start.go b/cmd/start/start.go index 72ab9ea862..f6a53f6502 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -438,7 +438,7 @@ func startAPIs( if err := apis.RegisterService(ctx, user_v2.CreateServer(commands, queries, keys.User, keys.IDPConfig, idp.CallbackURL(), idp.SAMLRootURL(), assets.AssetAPI(), permissionCheck)); err != nil { return nil, err } - if err := apis.RegisterService(ctx, session_v2beta.CreateServer(commands, queries)); err != nil { + if err := apis.RegisterService(ctx, session_v2beta.CreateServer(commands, queries, permissionCheck)); err != nil { return nil, err } if err := apis.RegisterService(ctx, settings_v2beta.CreateServer(commands, queries)); err != nil { @@ -450,7 +450,7 @@ func startAPIs( if err := apis.RegisterService(ctx, feature_v2beta.CreateServer(commands, queries)); err != nil { return nil, err } - if err := apis.RegisterService(ctx, session_v2.CreateServer(commands, queries)); err != nil { + if err := apis.RegisterService(ctx, session_v2.CreateServer(commands, queries, permissionCheck)); err != nil { return nil, err } if err := apis.RegisterService(ctx, settings_v2.CreateServer(commands, queries)); err != nil { diff --git a/internal/api/grpc/session/v2/integration_test/query_test.go b/internal/api/grpc/session/v2/integration_test/query_test.go new file mode 100644 index 0000000000..1889be1588 --- /dev/null +++ b/internal/api/grpc/session/v2/integration_test/query_test.go @@ -0,0 +1,233 @@ +//go:build integration + +package session_test + +import ( + "context" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/pkg/grpc/session/v2" +) + +func TestServer_GetSession(t *testing.T) { + type args struct { + ctx context.Context + req *session.GetSessionRequest + dep func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 + } + tests := []struct { + name string + args args + want *session.GetSessionResponse + wantFactors []wantFactor + wantExpirationWindow time.Duration + wantErr bool + }{ + { + name: "get session, no id provided", + args: args{ + CTX, + &session.GetSessionRequest{ + SessionId: "", + }, + nil, + }, + wantErr: true, + }, + { + name: "get session, not found", + args: args{ + CTX, + &session.GetSessionRequest{ + SessionId: "unknown", + }, + nil, + }, + wantErr: true, + }, + { + name: "get session, no permission", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + return resp.GetDetails().GetSequence() + }, + }, + wantErr: true, + }, + { + name: "get session, permission, ok", + args: args{ + CTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, token, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, user agent, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{ + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + { + name: "get session, lifetime, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + Lifetime: durationpb.New(5 * time.Minute), + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + wantExpirationWindow: 5 * time.Minute, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, metadata, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + Metadata: map[string][]byte{"foo": []byte("bar")}, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{ + Metadata: map[string][]byte{"foo": []byte("bar")}, + }, + }, + }, + { + name: "get session, user, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: User.GetUserId(), + }, + }, + }, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var sequence uint64 + if tt.args.dep != nil { + sequence = tt.args.dep(tt.args.ctx, t, tt.args.req) + } + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, err := Client.GetSession(tt.args.ctx, tt.args.req) + if tt.wantErr { + assert.Error(ttt, err) + return + } + if !assert.NoError(ttt, err) { + return + } + + tt.want.Session.Id = tt.args.req.SessionId + tt.want.Session.Sequence = sequence + verifySession(ttt, got.GetSession(), tt.want.GetSession(), time.Minute, tt.wantExpirationWindow, User.GetUserId(), tt.wantFactors...) + }, retryDuration, tick) + }) + } +} diff --git a/internal/api/grpc/session/v2/integration_test/server_test.go b/internal/api/grpc/session/v2/integration_test/server_test.go new file mode 100644 index 0000000000..70e2146069 --- /dev/null +++ b/internal/api/grpc/session/v2/integration_test/server_test.go @@ -0,0 +1,74 @@ +//go:build integration + +package session_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/pkg/grpc/session/v2" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +var ( + CTX context.Context + IAMOwnerCTX context.Context + UserCTX context.Context + Instance *integration.Instance + Client session.SessionServiceClient + User *user.AddHumanUserResponse + DeactivatedUser *user.AddHumanUserResponse + LockedUser *user.AddHumanUserResponse +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + Instance = integration.NewInstance(ctx) + Client = Instance.Client.SessionV2 + + CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) + IAMOwnerCTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) + UserCTX = Instance.WithAuthorization(ctx, integration.UserTypeNoPermission) + User = createFullUser(CTX) + DeactivatedUser = createDeactivatedUser(CTX) + LockedUser = createLockedUser(CTX) + return m.Run() + }()) +} + +func createFullUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + Instance.Client.UserV2.VerifyEmail(ctx, &user.VerifyEmailRequest{ + UserId: userResp.GetUserId(), + VerificationCode: userResp.GetEmailCode(), + }) + Instance.Client.UserV2.VerifyPhone(ctx, &user.VerifyPhoneRequest{ + UserId: userResp.GetUserId(), + VerificationCode: userResp.GetPhoneCode(), + }) + Instance.SetUserPassword(ctx, userResp.GetUserId(), integration.UserPassword, false) + Instance.RegisterUserPasskey(ctx, userResp.GetUserId()) + return userResp +} + +func createDeactivatedUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + _, err := Instance.Client.UserV2.DeactivateUser(ctx, &user.DeactivateUserRequest{UserId: userResp.GetUserId()}) + logging.OnError(err).Fatal("deactivate human user") + return userResp +} + +func createLockedUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + _, err := Instance.Client.UserV2.LockUser(ctx, &user.LockUserRequest{UserId: userResp.GetUserId()}) + logging.OnError(err).Fatal("lock human user") + return userResp +} diff --git a/internal/api/grpc/session/v2/integration_test/session_test.go b/internal/api/grpc/session/v2/integration_test/session_test.go index ccd08f3471..18ba018f91 100644 --- a/internal/api/grpc/session/v2/integration_test/session_test.go +++ b/internal/api/grpc/session/v2/integration_test/session_test.go @@ -5,7 +5,6 @@ package session_test import ( "context" "fmt" - "os" "testing" "time" @@ -14,7 +13,6 @@ import ( "github.com/pquerna/otp/totp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/logging" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -29,63 +27,7 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -var ( - CTX context.Context - IAMOwnerCTX context.Context - Instance *integration.Instance - Client session.SessionServiceClient - User *user.AddHumanUserResponse - DeactivatedUser *user.AddHumanUserResponse - LockedUser *user.AddHumanUserResponse -) - -func TestMain(m *testing.M) { - os.Exit(func() int { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) - defer cancel() - - Instance = integration.NewInstance(ctx) - Client = Instance.Client.SessionV2 - - CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) - IAMOwnerCTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) - User = createFullUser(CTX) - DeactivatedUser = createDeactivatedUser(CTX) - LockedUser = createLockedUser(CTX) - return m.Run() - }()) -} - -func createFullUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - Instance.Client.UserV2.VerifyEmail(ctx, &user.VerifyEmailRequest{ - UserId: userResp.GetUserId(), - VerificationCode: userResp.GetEmailCode(), - }) - Instance.Client.UserV2.VerifyPhone(ctx, &user.VerifyPhoneRequest{ - UserId: userResp.GetUserId(), - VerificationCode: userResp.GetPhoneCode(), - }) - Instance.SetUserPassword(ctx, userResp.GetUserId(), integration.UserPassword, false) - Instance.RegisterUserPasskey(ctx, userResp.GetUserId()) - return userResp -} - -func createDeactivatedUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - _, err := Instance.Client.UserV2.DeactivateUser(ctx, &user.DeactivateUserRequest{UserId: userResp.GetUserId()}) - logging.OnError(err).Fatal("deactivate human user") - return userResp -} - -func createLockedUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - _, err := Instance.Client.UserV2.LockUser(ctx, &user.LockUserRequest{UserId: userResp.GetUserId()}) - logging.OnError(err).Fatal("lock human user") - return userResp -} - -func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, userAgent *session.UserAgent, expirationWindow time.Duration, userID string, factors ...wantFactor) *session.Session { +func verifyCurrentSession(t *testing.T, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, userAgent *session.UserAgent, expirationWindow time.Duration, userID string, factors ...wantFactor) *session.Session { t.Helper() require.NotEmpty(t, id) require.NotEmpty(t, token) @@ -96,15 +38,25 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo }) require.NoError(t, err) s := resp.GetSession() + want := &session.Session{ + Id: id, + Sequence: sequence, + Metadata: metadata, + UserAgent: userAgent, + } + verifySession(t, s, want, window, expirationWindow, userID, factors...) + return s +} - assert.Equal(t, id, s.GetId()) +func verifySession(t assert.TestingT, s *session.Session, want *session.Session, window time.Duration, expirationWindow time.Duration, userID string, factors ...wantFactor) { + assert.Equal(t, want.Id, s.GetId()) assert.WithinRange(t, s.GetCreationDate().AsTime(), time.Now().Add(-window), time.Now().Add(window)) assert.WithinRange(t, s.GetChangeDate().AsTime(), time.Now().Add(-window), time.Now().Add(window)) - assert.Equal(t, sequence, s.GetSequence()) - assert.Equal(t, metadata, s.GetMetadata()) + assert.Equal(t, want.Sequence, s.GetSequence()) + assert.Equal(t, want.Metadata, s.GetMetadata()) - if !proto.Equal(userAgent, s.GetUserAgent()) { - t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), userAgent) + if !proto.Equal(want.UserAgent, s.GetUserAgent()) { + t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), want.UserAgent) } if expirationWindow == 0 { assert.Nil(t, s.GetExpirationDate()) @@ -113,7 +65,6 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo } verifyFactors(t, s.GetFactors(), window, userID, factors) - return s } type wantFactor int @@ -129,7 +80,7 @@ const ( wantOTPEmailFactor ) -func verifyFactors(t testing.TB, factors *session.Factors, window time.Duration, userID string, want []wantFactor) { +func verifyFactors(t assert.TestingT, factors *session.Factors, window time.Duration, userID string, want []wantFactor) { for _, w := range want { switch w { case wantUserFactor: @@ -194,8 +145,15 @@ func TestServer_CreateSession(t *testing.T) { }, }, { - name: "user agent", + name: "full session", req: &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: User.GetUserId(), + }, + }, + }, Metadata: map[string][]byte{"foo": []byte("bar")}, UserAgent: &session.UserAgent{ FingerprintId: gu.Ptr("fingerPrintID"), @@ -205,6 +163,7 @@ func TestServer_CreateSession(t *testing.T) { "foo": {Values: []string{"foo", "bar"}}, }, }, + Lifetime: durationpb.New(5 * time.Minute), }, want: &session.CreateSessionResponse{ Details: &object.Details{ @@ -212,14 +171,6 @@ func TestServer_CreateSession(t *testing.T) { ResourceOwner: Instance.ID(), }, }, - wantUserAgent: &session.UserAgent{ - FingerprintId: gu.Ptr("fingerPrintID"), - Ip: gu.Ptr("1.2.3.4"), - Description: gu.Ptr("Description"), - Header: map[string]*session.UserAgent_HeaderValues{ - "foo": {Values: []string{"foo", "bar"}}, - }, - }, }, { name: "negative lifetime", @@ -229,40 +180,6 @@ func TestServer_CreateSession(t *testing.T) { }, wantErr: true, }, - { - name: "lifetime", - req: &session.CreateSessionRequest{ - Metadata: map[string][]byte{"foo": []byte("bar")}, - Lifetime: durationpb.New(5 * time.Minute), - }, - want: &session.CreateSessionResponse{ - Details: &object.Details{ - ChangeDate: timestamppb.Now(), - ResourceOwner: Instance.ID(), - }, - }, - wantExpirationWindow: 5 * time.Minute, - }, - { - name: "with user", - req: &session.CreateSessionRequest{ - Checks: &session.Checks{ - User: &session.CheckUser{ - Search: &session.CheckUser_UserId{ - UserId: User.GetUserId(), - }, - }, - }, - Metadata: map[string][]byte{"foo": []byte("bar")}, - }, - want: &session.CreateSessionResponse{ - Details: &object.Details{ - ChangeDate: timestamppb.Now(), - ResourceOwner: Instance.ID(), - }, - }, - wantFactors: []wantFactor{wantUserFactor}, - }, { name: "deactivated user", req: &session.CreateSessionRequest{ @@ -340,8 +257,6 @@ func TestServer_CreateSession(t *testing.T) { } require.NoError(t, err) integration.AssertDetails(t, tt.want, got) - - verifyCurrentSession(t, got.GetSessionId(), got.GetSessionToken(), got.GetDetails().GetSequence(), time.Minute, tt.req.GetMetadata(), tt.wantUserAgent, tt.wantExpirationWindow, User.GetUserId(), tt.wantFactors...) }) } } diff --git a/internal/api/grpc/session/v2/query.go b/internal/api/grpc/session/v2/query.go new file mode 100644 index 0000000000..4b250a72af --- /dev/null +++ b/internal/api/grpc/session/v2/query.go @@ -0,0 +1,242 @@ +package session + +import ( + "context" + "time" + + "github.com/muhlemmer/gu" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/zerrors" + objpb "github.com/zitadel/zitadel/pkg/grpc/object" + "github.com/zitadel/zitadel/pkg/grpc/session/v2" +) + +var ( + timestampComparisons = map[objpb.TimestampQueryMethod]query.TimestampComparison{ + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS: query.TimestampEquals, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER: query.TimestampGreater, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS: query.TimestampGreaterOrEquals, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS: query.TimestampLess, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS: query.TimestampLessOrEquals, + } +) + +func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { + res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken(), s.checkPermission) + if err != nil { + return nil, err + } + return &session.GetSessionResponse{ + Session: sessionToPb(res), + }, nil +} + +func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequest) (*session.ListSessionsResponse, error) { + queries, err := listSessionsRequestToQuery(ctx, req) + if err != nil { + return nil, err + } + sessions, err := s.query.SearchSessions(ctx, queries, s.checkPermission) + if err != nil { + return nil, err + } + return &session.ListSessionsResponse{ + Details: object.ToListDetails(sessions.SearchResponse), + Sessions: sessionsToPb(sessions.Sessions), + }, nil +} + +func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRequest) (*query.SessionsSearchQueries, error) { + offset, limit, asc := object.ListQueryToQuery(req.Query) + queries, err := sessionQueriesToQuery(ctx, req.GetQueries()) + if err != nil { + return nil, err + } + return &query.SessionsSearchQueries{ + SearchRequest: query.SearchRequest{ + Offset: offset, + Limit: limit, + Asc: asc, + SortingColumn: fieldNameToSessionColumn(req.GetSortingColumn()), + }, + Queries: queries, + }, nil +} + +func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) (_ []query.SearchQuery, err error) { + q := make([]query.SearchQuery, len(queries)) + for i, v := range queries { + q[i], err = sessionQueryToQuery(ctx, v) + if err != nil { + return nil, err + } + } + return q, nil +} + +func sessionQueryToQuery(ctx context.Context, sq *session.SearchQuery) (query.SearchQuery, error) { + switch q := sq.Query.(type) { + case *session.SearchQuery_IdsQuery: + return idsQueryToQuery(q.IdsQuery) + case *session.SearchQuery_UserIdQuery: + return query.NewUserIDSearchQuery(q.UserIdQuery.GetId()) + case *session.SearchQuery_CreationDateQuery: + return creationDateQueryToQuery(q.CreationDateQuery) + case *session.SearchQuery_OwnCreatorQuery: + return query.NewSessionCreatorSearchQuery(authz.GetCtxData(ctx).UserID) + case *session.SearchQuery_OwnUseragentQuery: + return query.NewSessionUserAgentFingerprintIDSearchQuery(authz.GetCtxData(ctx).AgentID) + default: + return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid") + } +} + +func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) { + return query.NewSessionIDsSearchQuery(q.Ids) +} + +func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery, error) { + comparison := timestampComparisons[q.GetMethod()] + return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison) +} + +func fieldNameToSessionColumn(field session.SessionFieldName) query.Column { + switch field { + case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE: + return query.SessionColumnCreationDate + default: + return query.Column{} + } +} + +func sessionsToPb(sessions []*query.Session) []*session.Session { + s := make([]*session.Session, len(sessions)) + for i, session := range sessions { + s[i] = sessionToPb(session) + } + return s +} + +func sessionToPb(s *query.Session) *session.Session { + return &session.Session{ + Id: s.ID, + CreationDate: timestamppb.New(s.CreationDate), + ChangeDate: timestamppb.New(s.ChangeDate), + Sequence: s.Sequence, + Factors: factorsToPb(s), + Metadata: s.Metadata, + UserAgent: userAgentToPb(s.UserAgent), + ExpirationDate: expirationToPb(s.Expiration), + } +} + +func userAgentToPb(ua domain.UserAgent) *session.UserAgent { + if ua.IsEmpty() { + return nil + } + + out := &session.UserAgent{ + FingerprintId: ua.FingerprintID, + Description: ua.Description, + } + if ua.IP != nil { + out.Ip = gu.Ptr(ua.IP.String()) + } + if ua.Header == nil { + return out + } + out.Header = make(map[string]*session.UserAgent_HeaderValues, len(ua.Header)) + for k, v := range ua.Header { + out.Header[k] = &session.UserAgent_HeaderValues{ + Values: v, + } + } + return out +} + +func expirationToPb(expiration time.Time) *timestamppb.Timestamp { + if expiration.IsZero() { + return nil + } + return timestamppb.New(expiration) +} + +func factorsToPb(s *query.Session) *session.Factors { + user := userFactorToPb(s.UserFactor) + if user == nil { + return nil + } + return &session.Factors{ + User: user, + Password: passwordFactorToPb(s.PasswordFactor), + WebAuthN: webAuthNFactorToPb(s.WebAuthNFactor), + Intent: intentFactorToPb(s.IntentFactor), + Totp: totpFactorToPb(s.TOTPFactor), + OtpSms: otpFactorToPb(s.OTPSMSFactor), + OtpEmail: otpFactorToPb(s.OTPEmailFactor), + } +} + +func passwordFactorToPb(factor query.SessionPasswordFactor) *session.PasswordFactor { + if factor.PasswordCheckedAt.IsZero() { + return nil + } + return &session.PasswordFactor{ + VerifiedAt: timestamppb.New(factor.PasswordCheckedAt), + } +} + +func intentFactorToPb(factor query.SessionIntentFactor) *session.IntentFactor { + if factor.IntentCheckedAt.IsZero() { + return nil + } + return &session.IntentFactor{ + VerifiedAt: timestamppb.New(factor.IntentCheckedAt), + } +} + +func webAuthNFactorToPb(factor query.SessionWebAuthNFactor) *session.WebAuthNFactor { + if factor.WebAuthNCheckedAt.IsZero() { + return nil + } + return &session.WebAuthNFactor{ + VerifiedAt: timestamppb.New(factor.WebAuthNCheckedAt), + UserVerified: factor.UserVerified, + } +} + +func totpFactorToPb(factor query.SessionTOTPFactor) *session.TOTPFactor { + if factor.TOTPCheckedAt.IsZero() { + return nil + } + return &session.TOTPFactor{ + VerifiedAt: timestamppb.New(factor.TOTPCheckedAt), + } +} + +func otpFactorToPb(factor query.SessionOTPFactor) *session.OTPFactor { + if factor.OTPCheckedAt.IsZero() { + return nil + } + return &session.OTPFactor{ + VerifiedAt: timestamppb.New(factor.OTPCheckedAt), + } +} + +func userFactorToPb(factor query.SessionUserFactor) *session.UserFactor { + if factor.UserID == "" || factor.UserCheckedAt.IsZero() { + return nil + } + return &session.UserFactor{ + VerifiedAt: timestamppb.New(factor.UserCheckedAt), + Id: factor.UserID, + LoginName: factor.LoginName, + DisplayName: factor.DisplayName, + OrganizationId: factor.ResourceOwner, + } +} diff --git a/internal/api/grpc/session/v2/server.go b/internal/api/grpc/session/v2/server.go index e94336bf47..ee534cb26c 100644 --- a/internal/api/grpc/session/v2/server.go +++ b/internal/api/grpc/session/v2/server.go @@ -6,6 +6,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/pkg/grpc/session/v2" ) @@ -16,6 +17,8 @@ type Server struct { session.UnimplementedSessionServiceServer command *command.Commands query *query.Queries + + checkPermission domain.PermissionCheck } type Config struct{} @@ -23,10 +26,12 @@ type Config struct{} func CreateServer( command *command.Commands, query *query.Queries, + checkPermission domain.PermissionCheck, ) *Server { return &Server{ - command: command, - query: query, + command: command, + query: query, + checkPermission: checkPermission, } } diff --git a/internal/api/grpc/session/v2/session.go b/internal/api/grpc/session/v2/session.go index aa25fa0ae3..7562d64350 100644 --- a/internal/api/grpc/session/v2/session.go +++ b/internal/api/grpc/session/v2/session.go @@ -6,56 +6,17 @@ import ( "net/http" "time" - "github.com/muhlemmer/gu" "golang.org/x/text/language" "google.golang.org/protobuf/types/known/structpb" - "google.golang.org/protobuf/types/known/timestamppb" - "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/zerrors" - objpb "github.com/zitadel/zitadel/pkg/grpc/object" "github.com/zitadel/zitadel/pkg/grpc/session/v2" ) -var ( - timestampComparisons = map[objpb.TimestampQueryMethod]query.TimestampComparison{ - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS: query.TimestampEquals, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER: query.TimestampGreater, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS: query.TimestampGreaterOrEquals, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS: query.TimestampLess, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS: query.TimestampLessOrEquals, - } -) - -func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { - res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken()) - if err != nil { - return nil, err - } - return &session.GetSessionResponse{ - Session: sessionToPb(res), - }, nil -} - -func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequest) (*session.ListSessionsResponse, error) { - queries, err := listSessionsRequestToQuery(ctx, req) - if err != nil { - return nil, err - } - sessions, err := s.query.SearchSessions(ctx, queries) - if err != nil { - return nil, err - } - return &session.ListSessionsResponse{ - Details: object.ToListDetails(sessions.SearchResponse), - Sessions: sessionsToPb(sessions.Sessions), - }, nil -} - func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRequest) (*session.CreateSessionResponse, error) { checks, metadata, userAgent, lifetime, err := s.createSessionRequestToCommand(ctx, req) if err != nil { @@ -110,197 +71,6 @@ func (s *Server) DeleteSession(ctx context.Context, req *session.DeleteSessionRe }, nil } -func sessionsToPb(sessions []*query.Session) []*session.Session { - s := make([]*session.Session, len(sessions)) - for i, session := range sessions { - s[i] = sessionToPb(session) - } - return s -} - -func sessionToPb(s *query.Session) *session.Session { - return &session.Session{ - Id: s.ID, - CreationDate: timestamppb.New(s.CreationDate), - ChangeDate: timestamppb.New(s.ChangeDate), - Sequence: s.Sequence, - Factors: factorsToPb(s), - Metadata: s.Metadata, - UserAgent: userAgentToPb(s.UserAgent), - ExpirationDate: expirationToPb(s.Expiration), - } -} - -func userAgentToPb(ua domain.UserAgent) *session.UserAgent { - if ua.IsEmpty() { - return nil - } - - out := &session.UserAgent{ - FingerprintId: ua.FingerprintID, - Description: ua.Description, - } - if ua.IP != nil { - out.Ip = gu.Ptr(ua.IP.String()) - } - if ua.Header == nil { - return out - } - out.Header = make(map[string]*session.UserAgent_HeaderValues, len(ua.Header)) - for k, v := range ua.Header { - out.Header[k] = &session.UserAgent_HeaderValues{ - Values: v, - } - } - return out -} - -func expirationToPb(expiration time.Time) *timestamppb.Timestamp { - if expiration.IsZero() { - return nil - } - return timestamppb.New(expiration) -} - -func factorsToPb(s *query.Session) *session.Factors { - user := userFactorToPb(s.UserFactor) - if user == nil { - return nil - } - return &session.Factors{ - User: user, - Password: passwordFactorToPb(s.PasswordFactor), - WebAuthN: webAuthNFactorToPb(s.WebAuthNFactor), - Intent: intentFactorToPb(s.IntentFactor), - Totp: totpFactorToPb(s.TOTPFactor), - OtpSms: otpFactorToPb(s.OTPSMSFactor), - OtpEmail: otpFactorToPb(s.OTPEmailFactor), - } -} - -func passwordFactorToPb(factor query.SessionPasswordFactor) *session.PasswordFactor { - if factor.PasswordCheckedAt.IsZero() { - return nil - } - return &session.PasswordFactor{ - VerifiedAt: timestamppb.New(factor.PasswordCheckedAt), - } -} - -func intentFactorToPb(factor query.SessionIntentFactor) *session.IntentFactor { - if factor.IntentCheckedAt.IsZero() { - return nil - } - return &session.IntentFactor{ - VerifiedAt: timestamppb.New(factor.IntentCheckedAt), - } -} - -func webAuthNFactorToPb(factor query.SessionWebAuthNFactor) *session.WebAuthNFactor { - if factor.WebAuthNCheckedAt.IsZero() { - return nil - } - return &session.WebAuthNFactor{ - VerifiedAt: timestamppb.New(factor.WebAuthNCheckedAt), - UserVerified: factor.UserVerified, - } -} - -func totpFactorToPb(factor query.SessionTOTPFactor) *session.TOTPFactor { - if factor.TOTPCheckedAt.IsZero() { - return nil - } - return &session.TOTPFactor{ - VerifiedAt: timestamppb.New(factor.TOTPCheckedAt), - } -} - -func otpFactorToPb(factor query.SessionOTPFactor) *session.OTPFactor { - if factor.OTPCheckedAt.IsZero() { - return nil - } - return &session.OTPFactor{ - VerifiedAt: timestamppb.New(factor.OTPCheckedAt), - } -} - -func userFactorToPb(factor query.SessionUserFactor) *session.UserFactor { - if factor.UserID == "" || factor.UserCheckedAt.IsZero() { - return nil - } - return &session.UserFactor{ - VerifiedAt: timestamppb.New(factor.UserCheckedAt), - Id: factor.UserID, - LoginName: factor.LoginName, - DisplayName: factor.DisplayName, - OrganizationId: factor.ResourceOwner, - } -} - -func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRequest) (*query.SessionsSearchQueries, error) { - offset, limit, asc := object.ListQueryToQuery(req.Query) - queries, err := sessionQueriesToQuery(ctx, req.GetQueries()) - if err != nil { - return nil, err - } - return &query.SessionsSearchQueries{ - SearchRequest: query.SearchRequest{ - Offset: offset, - Limit: limit, - Asc: asc, - SortingColumn: fieldNameToSessionColumn(req.GetSortingColumn()), - }, - Queries: queries, - }, nil -} - -func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) (_ []query.SearchQuery, err error) { - q := make([]query.SearchQuery, len(queries)+1) - for i, v := range queries { - q[i], err = sessionQueryToQuery(v) - if err != nil { - return nil, err - } - } - creatorQuery, err := query.NewSessionCreatorSearchQuery(authz.GetCtxData(ctx).UserID) - if err != nil { - return nil, err - } - q[len(queries)] = creatorQuery - return q, nil -} - -func sessionQueryToQuery(sq *session.SearchQuery) (query.SearchQuery, error) { - switch q := sq.Query.(type) { - case *session.SearchQuery_IdsQuery: - return idsQueryToQuery(q.IdsQuery) - case *session.SearchQuery_UserIdQuery: - return query.NewUserIDSearchQuery(q.UserIdQuery.GetId()) - case *session.SearchQuery_CreationDateQuery: - return creationDateQueryToQuery(q.CreationDateQuery) - default: - return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid") - } -} - -func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) { - return query.NewSessionIDsSearchQuery(q.Ids) -} - -func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery, error) { - comparison := timestampComparisons[q.GetMethod()] - return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison) -} - -func fieldNameToSessionColumn(field session.SessionFieldName) query.Column { - switch field { - case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE: - return query.SessionColumnCreationDate - default: - return query.Column{} - } -} - func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCommand, map[string][]byte, *domain.UserAgent, time.Duration, error) { checks, err := s.checksToCommand(ctx, req.Checks) if err != nil { diff --git a/internal/api/grpc/session/v2beta/server.go b/internal/api/grpc/session/v2beta/server.go index 550d013ad5..cf0d0c27f0 100644 --- a/internal/api/grpc/session/v2beta/server.go +++ b/internal/api/grpc/session/v2beta/server.go @@ -6,6 +6,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" ) @@ -16,6 +17,8 @@ type Server struct { session.UnimplementedSessionServiceServer command *command.Commands query *query.Queries + + checkPermission domain.PermissionCheck } type Config struct{} @@ -23,10 +26,12 @@ type Config struct{} func CreateServer( command *command.Commands, query *query.Queries, + checkPermission domain.PermissionCheck, ) *Server { return &Server{ - command: command, - query: query, + command: command, + query: query, + checkPermission: checkPermission, } } diff --git a/internal/api/grpc/session/v2beta/session.go b/internal/api/grpc/session/v2beta/session.go index 7e67a4b3ff..3b36b8ba83 100644 --- a/internal/api/grpc/session/v2beta/session.go +++ b/internal/api/grpc/session/v2beta/session.go @@ -32,7 +32,7 @@ var ( ) func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { - res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken()) + res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken(), s.checkPermission) if err != nil { return nil, err } @@ -46,7 +46,7 @@ func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequ if err != nil { return nil, err } - sessions, err := s.query.SearchSessions(ctx, queries) + sessions, err := s.query.SearchSessions(ctx, queries, s.checkPermission) if err != nil { return nil, err } diff --git a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go index 9dec3fcf00..b707631c22 100644 --- a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go +++ b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go @@ -159,7 +159,7 @@ func (repo *TokenVerifierRepo) verifySessionToken(ctx context.Context, sessionID ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - session, err := repo.Query.SessionByID(ctx, true, sessionID, token) + session, err := repo.Query.SessionByID(ctx, true, sessionID, token, nil) if err != nil { return "", "", "", err } diff --git a/internal/notification/handlers/queries.go b/internal/notification/handlers/queries.go index 1c8d37598e..a3d68e4797 100644 --- a/internal/notification/handlers/queries.go +++ b/internal/notification/handlers/queries.go @@ -20,7 +20,7 @@ type Queries interface { GetNotifyUserByID(ctx context.Context, shouldTriggered bool, userID string) (*query.NotifyUser, error) CustomTextListByTemplate(ctx context.Context, aggregateID, template string, withOwnerRemoved bool) (*query.CustomTexts, error) SearchInstanceDomains(ctx context.Context, queries *query.InstanceDomainSearchQueries) (*query.InstanceDomains, error) - SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (*query.Session, error) + SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string, check domain.PermissionCheck) (*query.Session, error) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (*query.NotificationPolicy, error) SearchMilestones(ctx context.Context, instanceIDs []string, queries *query.MilestonesSearchQueries) (*query.Milestones, error) NotificationProviderByIDAndType(ctx context.Context, aggID string, providerType domain.NotificationProviderType) (*query.DebugNotificationProvider, error) diff --git a/internal/notification/handlers/user_notifier.go b/internal/notification/handlers/user_notifier.go index ec30ab476f..c24b87c2f6 100644 --- a/internal/notification/handlers/user_notifier.go +++ b/internal/notification/handlers/user_notifier.go @@ -400,7 +400,7 @@ func (u *userNotifier) reduceSessionOTPSMSChallenged(event eventstore.Event) (*h if alreadyHandled { return nil } - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return err } @@ -496,7 +496,7 @@ func (u *userNotifier) reduceSessionOTPEmailChallenged(event eventstore.Event) ( if alreadyHandled { return nil } - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return err } diff --git a/internal/notification/handlers/user_notifier_legacy.go b/internal/notification/handlers/user_notifier_legacy.go index 7df31cdf91..4bfa1a796e 100644 --- a/internal/notification/handlers/user_notifier_legacy.go +++ b/internal/notification/handlers/user_notifier_legacy.go @@ -324,7 +324,7 @@ func (u *userNotifierLegacy) reduceSessionOTPSMSChallenged(event eventstore.Even return handler.NewNoOpStatement(e), nil } ctx := HandlerContext(event.Aggregate()) - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return nil, err } @@ -428,7 +428,7 @@ func (u *userNotifierLegacy) reduceSessionOTPEmailChallenged(event eventstore.Ev return handler.NewNoOpStatement(e), nil } ctx := HandlerContext(event.Aggregate()) - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return nil, err } diff --git a/internal/query/session.go b/internal/query/session.go index 54afbde064..53607510f7 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -6,6 +6,7 @@ import ( "errors" "net" "net/http" + "slices" "time" sq "github.com/Masterminds/squirrel" @@ -80,6 +81,24 @@ type SessionsSearchQueries struct { Queries []SearchQuery } +func sessionsCheckPermission(ctx context.Context, sessions *Sessions, permissionCheck domain.PermissionCheck) { + sessions.Sessions = slices.DeleteFunc(sessions.Sessions, + func(session *Session) bool { + return sessionCheckPermission(ctx, session.ResourceOwner, session.Creator, permissionCheck) != nil + }, + ) +} + +func sessionCheckPermission(ctx context.Context, resourceOwner string, creator string, permissionCheck domain.PermissionCheck) error { + data := authz.GetCtxData(ctx) + if data.UserID != creator { + if err := permissionCheck(ctx, domain.PermissionSessionRead, resourceOwner, ""); err != nil { + return err + } + } + return nil +} + func (q *SessionsSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder { query = q.SearchRequest.toQuery(query) for _, q := range q.Queries { @@ -195,7 +214,24 @@ var ( } ) -func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (session *Session, err error) { +func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string, permissionCheck domain.PermissionCheck) (session *Session, err error) { + session, tokenID, err := q.sessionByID(ctx, shouldTriggerBulk, id) + if err != nil { + return nil, err + } + if sessionToken == "" { + if err := sessionCheckPermission(ctx, session.ResourceOwner, session.Creator, permissionCheck); err != nil { + return nil, err + } + return session, nil + } + if err := q.sessionTokenVerifier(ctx, sessionToken, session.ID, tokenID); err != nil { + return nil, zerrors.ThrowPermissionDenied(nil, "QUERY-dsfr3", "Errors.PermissionDenied") + } + return session, nil +} + +func (q *Queries) sessionByID(ctx context.Context, shouldTriggerBulk bool, id string) (session *Session, tokenID string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -214,27 +250,31 @@ func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, s }, ).ToSql() if err != nil { - return nil, zerrors.ThrowInternal(err, "QUERY-dn9JW", "Errors.Query.SQLStatement") + return nil, "", zerrors.ThrowInternal(err, "QUERY-dn9JW", "Errors.Query.SQLStatement") } - var tokenID string err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { session, tokenID, err = scan(row) return err }, stmt, args...) if err != nil { - return nil, err + return nil, "", err } - if sessionToken == "" { - return session, nil - } - if err := q.sessionTokenVerifier(ctx, sessionToken, session.ID, tokenID); err != nil { - return nil, zerrors.ThrowPermissionDenied(nil, "QUERY-dsfr3", "Errors.PermissionDenied") - } - return session, nil + return session, tokenID, nil } -func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries) (sessions *Sessions, err error) { +func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries, permissionCheck domain.PermissionCheck) (*Sessions, error) { + sessions, err := q.searchSessions(ctx, queries) + if err != nil { + return nil, err + } + if permissionCheck != nil { + sessionsCheckPermission(ctx, sessions, permissionCheck) + } + return sessions, nil +} + +func (q *Queries) searchSessions(ctx context.Context, queries *SessionsSearchQueries) (sessions *Sessions, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -272,6 +312,10 @@ func NewSessionCreatorSearchQuery(creator string) (SearchQuery, error) { return NewTextQuery(SessionColumnCreator, creator, TextEquals) } +func NewSessionUserAgentFingerprintIDSearchQuery(fingerprintID string) (SearchQuery, error) { + return NewTextQuery(SessionColumnUserAgentFingerprintID, fingerprintID, TextEquals) +} + func NewUserIDSearchQuery(id string) (SearchQuery, error) { return NewTextQuery(SessionColumnUserID, id, TextEquals) } diff --git a/proto/zitadel/session/v2/session.proto b/proto/zitadel/session/v2/session.proto index 2c17d81f99..9074be4cb4 100644 --- a/proto/zitadel/session/v2/session.proto +++ b/proto/zitadel/session/v2/session.proto @@ -136,6 +136,8 @@ message SearchQuery { IDsQuery ids_query = 1; UserIDQuery user_id_query = 2; CreationDateQuery creation_date_query = 3; + OwnCreatorQuery own_creator_query = 4; + OwnUserAgentQuery own_useragent_query = 5; } } @@ -157,9 +159,13 @@ message CreationDateQuery { ]; } +message OwnCreatorQuery {} + +message OwnUserAgentQuery {} + message UserAgent { optional string fingerprint_id = 1; - optional string ip = 2; + optional string ip = 2; optional string description = 3; // A header may have multiple values. @@ -169,7 +175,7 @@ message UserAgent { message HeaderValues { repeated string values = 1; } - map header = 4; + map header = 4; } enum SessionFieldName {