From c71bf85b7a9dfbb7b499a6a90dfdd1924eefffa3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Thu, 12 Oct 2023 15:16:59 +0300 Subject: [PATCH] feat(api/v2): store user agent details in the session (#6711) This change adds the ability to set and get user agent data, such as fingerprint, IP, request headers and a description to the session. All fields are optional. Closes #6028 --- internal/api/grpc/session/v2/session.go | 56 ++++++- .../session/v2/session_integration_test.go | 84 +++++++--- internal/api/grpc/session/v2/session_test.go | 158 +++++++++++++++++- internal/command/auth_request_test.go | 44 ++++- internal/command/oidc_session_test.go | 22 ++- internal/command/session.go | 8 +- internal/command/session_test.go | 88 ++++++++-- internal/domain/user_agent.go | 17 ++ internal/query/prepare_test.go | 12 +- internal/query/projection/session.go | 94 +++++++---- internal/query/projection/session_test.go | 38 +++-- internal/query/session.go | 33 ++++ internal/query/sessions_test.go | 107 +++++++----- internal/repository/session/session.go | 3 + proto/zitadel/session/v2beta/session.proto | 16 ++ .../session/v2beta/session_service.proto | 1 + 16 files changed, 634 insertions(+), 147 deletions(-) create mode 100644 internal/domain/user_agent.go diff --git a/internal/api/grpc/session/v2/session.go b/internal/api/grpc/session/v2/session.go index fe8e0d8744..be6907ae08 100644 --- a/internal/api/grpc/session/v2/session.go +++ b/internal/api/grpc/session/v2/session.go @@ -2,10 +2,13 @@ package session import ( "context" + "net" + "net/http" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/muhlemmer/gu" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/command" @@ -41,7 +44,7 @@ func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequ } func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRequest) (*session.CreateSessionResponse, error) { - checks, metadata, err := s.createSessionRequestToCommand(ctx, req) + checks, metadata, userAgent, err := s.createSessionRequestToCommand(ctx, req) if err != nil { return nil, err } @@ -50,7 +53,7 @@ func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRe return nil, err } - set, err := s.command.CreateSession(ctx, cmds, metadata) + set, err := s.command.CreateSession(ctx, cmds, metadata, userAgent) if err != nil { return nil, err } @@ -113,9 +116,34 @@ func sessionToPb(s *query.Session) *session.Session { Sequence: s.Sequence, Factors: factorsToPb(s), Metadata: s.Metadata, + UserAgent: userAgentToPb(s.UserAgent), } } +func userAgentToPb(ua domain.UserAgent) *session.UserAgent { + if ua.IsEmpty() { + return nil + } + + out := &session.UserAgent{ + FingerprintId: ua.FingerprintID, + Description: ua.Description, + } + if ua.IP != nil { + out.Ip = gu.Ptr(ua.IP.String()) + } + if ua.Header == nil { + return out + } + out.Header = make(map[string]*session.UserAgent_HeaderValues, len(ua.Header)) + for k, v := range ua.Header { + out.Header[k] = &session.UserAgent_HeaderValues{ + Values: v, + } + } + return out +} + func factorsToPb(s *query.Session) *session.Factors { user := userFactorToPb(s.UserFactor) if user == nil { @@ -236,12 +264,30 @@ func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) { return query.NewSessionIDsSearchQuery(q.Ids) } -func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCommand, map[string][]byte, error) { +func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCommand, map[string][]byte, *domain.UserAgent, error) { checks, err := s.checksToCommand(ctx, req.Checks) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - return checks, req.GetMetadata(), nil + return checks, req.GetMetadata(), userAgentToCommand(req.GetUserAgent()), nil +} + +func userAgentToCommand(userAgent *session.UserAgent) *domain.UserAgent { + if userAgent == nil { + return nil + } + out := &domain.UserAgent{ + FingerprintID: userAgent.FingerprintId, + IP: net.ParseIP(userAgent.GetIp()), + Description: userAgent.Description, + } + if len(userAgent.Header) > 0 { + out.Header = make(http.Header, len(userAgent.Header)) + for k, values := range userAgent.Header { + out.Header[k] = values.GetValues() + } + } + return out } func (s *Server) setSessionRequestToCommand(ctx context.Context, req *session.SetSessionRequest) ([]command.SessionCommand, error) { diff --git a/internal/api/grpc/session/v2/session_integration_test.go b/internal/api/grpc/session/v2/session_integration_test.go index 1dcad7bc53..9aba59ee37 100644 --- a/internal/api/grpc/session/v2/session_integration_test.go +++ b/internal/api/grpc/session/v2/session_integration_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" "github.com/zitadel/zitadel/internal/integration" object "github.com/zitadel/zitadel/pkg/grpc/object/v2beta" @@ -53,7 +54,7 @@ func TestMain(m *testing.M) { }()) } -func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, factors ...wantFactor) *session.Session { +func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, userAgent *session.UserAgent, factors ...wantFactor) *session.Session { t.Helper() require.NotEmpty(t, id) require.NotEmpty(t, token) @@ -70,6 +71,11 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo assert.WithinRange(t, s.GetChangeDate().AsTime(), time.Now().Add(-window), time.Now().Add(window)) assert.Equal(t, sequence, s.GetSequence()) assert.Equal(t, metadata, s.GetMetadata()) + + if !proto.Equal(userAgent, s.GetUserAgent()) { + t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), userAgent) + } + verifyFactors(t, s.GetFactors(), window, factors) return s } @@ -131,11 +137,12 @@ func verifyFactors(t testing.TB, factors *session.Factors, window time.Duration, func TestServer_CreateSession(t *testing.T) { tests := []struct { - name string - req *session.CreateSessionRequest - want *session.CreateSessionResponse - wantErr bool - wantFactors []wantFactor + name string + req *session.CreateSessionRequest + want *session.CreateSessionResponse + wantErr bool + wantFactors []wantFactor + wantUserAgent *session.UserAgent }{ { name: "empty session", @@ -148,6 +155,33 @@ func TestServer_CreateSession(t *testing.T) { }, }, }, + { + name: "user agent", + req: &session.CreateSessionRequest{ + Metadata: map[string][]byte{"foo": []byte("bar")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + want: &session.CreateSessionResponse{ + Details: &object.Details{ + ResourceOwner: Tester.Organisation.ID, + }, + }, + wantUserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, { name: "with user", req: &session.CreateSessionRequest{ @@ -219,7 +253,7 @@ func TestServer_CreateSession(t *testing.T) { require.NoError(t, err) integration.AssertDetails(t, tt.want, got) - verifyCurrentSession(t, got.GetSessionId(), got.GetSessionToken(), got.GetDetails().GetSequence(), time.Minute, tt.req.GetMetadata(), tt.wantFactors...) + verifyCurrentSession(t, got.GetSessionId(), got.GetSessionToken(), got.GetDetails().GetSequence(), time.Minute, tt.req.GetMetadata(), tt.wantUserAgent, tt.wantFactors...) }) } } @@ -242,7 +276,7 @@ func TestServer_CreateSession_webauthn(t *testing.T) { }, }) require.NoError(t, err) - verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil) + verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil) assertionData, err := Tester.WebAuthN.CreateAssertionResponse(createResp.GetChallenges().GetWebAuthN().GetPublicKeyCredentialRequestOptions(), true) require.NoError(t, err) @@ -258,7 +292,7 @@ func TestServer_CreateSession_webauthn(t *testing.T) { }, }) require.NoError(t, err) - verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor, wantWebAuthNFactorUserVerified) + verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactorUserVerified) } func TestServer_CreateSession_successfulIntent(t *testing.T) { @@ -274,7 +308,7 @@ func TestServer_CreateSession_successfulIntent(t *testing.T) { }, }) require.NoError(t, err) - verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil) + verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil) intentID, token, _, _ := Tester.CreateSuccessfulOAuthIntent(t, idpID, User.GetUserId(), "id") updateResp, err := Client.SetSession(CTX, &session.SetSessionRequest{ @@ -288,7 +322,7 @@ func TestServer_CreateSession_successfulIntent(t *testing.T) { }, }) require.NoError(t, err) - verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor, wantIntentFactor) + verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantIntentFactor) } func TestServer_CreateSession_successfulIntentUnknownUserID(t *testing.T) { @@ -304,7 +338,7 @@ func TestServer_CreateSession_successfulIntentUnknownUserID(t *testing.T) { }, }) require.NoError(t, err) - verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil) + verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil) idpUserID := "id" intentID, token, _, _ := Tester.CreateSuccessfulOAuthIntent(t, idpID, "", idpUserID) @@ -331,7 +365,7 @@ func TestServer_CreateSession_successfulIntentUnknownUserID(t *testing.T) { }, }) require.NoError(t, err) - verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor, wantIntentFactor) + verifyCurrentSession(t, createResp.GetSessionId(), updateResp.GetSessionToken(), updateResp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantIntentFactor) } func TestServer_CreateSession_startedIntentFalseToken(t *testing.T) { @@ -347,7 +381,7 @@ func TestServer_CreateSession_startedIntentFalseToken(t *testing.T) { }, }) require.NoError(t, err) - verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil) + verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, nil) intentID := Tester.CreateIntent(t, idpID) _, err = Client.SetSession(CTX, &session.SetSessionRequest{ @@ -399,7 +433,7 @@ func TestServer_SetSession_flow(t *testing.T) { createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) require.NoError(t, err) sessionToken := createResp.GetSessionToken() - verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, createResp.GetDetails().GetSequence(), time.Minute, nil) + verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, createResp.GetDetails().GetSequence(), time.Minute, nil, nil) t.Run("check user", func(t *testing.T) { resp, err := Client.SetSession(CTX, &session.SetSessionRequest{ @@ -415,7 +449,7 @@ func TestServer_SetSession_flow(t *testing.T) { }) require.NoError(t, err) sessionToken = resp.GetSessionToken() - verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor) + verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor) }) t.Run("check webauthn, user verified (passkey)", func(t *testing.T) { @@ -430,7 +464,7 @@ func TestServer_SetSession_flow(t *testing.T) { }, }) require.NoError(t, err) - verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil) + verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil) sessionToken = resp.GetSessionToken() assertionData, err := Tester.WebAuthN.CreateAssertionResponse(resp.GetChallenges().GetWebAuthN().GetPublicKeyCredentialRequestOptions(), true) @@ -447,7 +481,7 @@ func TestServer_SetSession_flow(t *testing.T) { }) require.NoError(t, err) sessionToken = resp.GetSessionToken() - verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor, wantWebAuthNFactorUserVerified) + verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactorUserVerified) }) userAuthCtx := Tester.WithAuthorizationToken(CTX, sessionToken) @@ -474,7 +508,7 @@ func TestServer_SetSession_flow(t *testing.T) { }, }) require.NoError(t, err) - verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil) + verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil) sessionToken = resp.GetSessionToken() assertionData, err := Tester.WebAuthN.CreateAssertionResponse(resp.GetChallenges().GetWebAuthN().GetPublicKeyCredentialRequestOptions(), false) @@ -491,7 +525,7 @@ func TestServer_SetSession_flow(t *testing.T) { }) require.NoError(t, err) sessionToken = resp.GetSessionToken() - verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor, wantWebAuthNFactor) + verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactor) }) } }) @@ -510,7 +544,7 @@ func TestServer_SetSession_flow(t *testing.T) { }) require.NoError(t, err) sessionToken = resp.GetSessionToken() - verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor, wantWebAuthNFactor, wantTOTPFactor) + verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactor, wantTOTPFactor) }) t.Run("check OTP SMS", func(t *testing.T) { @@ -522,7 +556,7 @@ func TestServer_SetSession_flow(t *testing.T) { }, }) require.NoError(t, err) - verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil) + verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil) sessionToken = resp.GetSessionToken() otp := resp.GetChallenges().GetOtpSms() @@ -539,7 +573,7 @@ func TestServer_SetSession_flow(t *testing.T) { }) require.NoError(t, err) sessionToken = resp.GetSessionToken() - verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor, wantWebAuthNFactor, wantOTPSMSFactor) + verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactor, wantOTPSMSFactor) }) t.Run("check OTP Email", func(t *testing.T) { @@ -553,7 +587,7 @@ func TestServer_SetSession_flow(t *testing.T) { }, }) require.NoError(t, err) - verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil) + verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil, nil) sessionToken = resp.GetSessionToken() otp := resp.GetChallenges().GetOtpEmail() @@ -570,7 +604,7 @@ func TestServer_SetSession_flow(t *testing.T) { }) require.NoError(t, err) sessionToken = resp.GetSessionToken() - verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor, wantWebAuthNFactor, wantOTPEmailFactor) + verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, nil, wantUserFactor, wantWebAuthNFactor, wantOTPEmailFactor) }) } diff --git a/internal/api/grpc/session/v2/session_test.go b/internal/api/grpc/session/v2/session_test.go index 33804caba5..ae33fab4c7 100644 --- a/internal/api/grpc/session/v2/session_test.go +++ b/internal/api/grpc/session/v2/session_test.go @@ -2,15 +2,18 @@ package session import ( "context" + "net" + "net/http" "testing" "time" + "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zitadel/zitadel/internal/api/authz" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/domain" caos_errs "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/query" @@ -23,7 +26,7 @@ func Test_sessionsToPb(t *testing.T) { past := now.Add(-time.Hour) sessions := []*query.Session{ - { // no factor + { // no factor, with user agent ID: "999", CreationDate: now, ChangeDate: now, @@ -32,6 +35,12 @@ func Test_sessionsToPb(t *testing.T) { ResourceOwner: "me", Creator: "he", Metadata: map[string][]byte{"hello": []byte("world")}, + UserAgent: domain.UserAgent{ + FingerprintID: gu.Ptr("fingerprintID"), + Description: gu.Ptr("description"), + IP: net.IPv4(1, 2, 3, 4), + Header: http.Header{"foo": []string{"foo", "bar"}}, + }, }, { // user factor ID: "999", @@ -114,13 +123,21 @@ func Test_sessionsToPb(t *testing.T) { } want := []*session.Session{ - { // no factor + { // no factor, with user agent Id: "999", CreationDate: timestamppb.New(now), ChangeDate: timestamppb.New(now), Sequence: 123, Factors: nil, Metadata: map[string][]byte{"hello": []byte("world")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerprintID"), + Description: gu.Ptr("description"), + Ip: gu.Ptr("1.2.3.4"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, }, { // user factor Id: "999", @@ -208,6 +225,71 @@ func Test_sessionsToPb(t *testing.T) { } } +func Test_userAgentToPb(t *testing.T) { + type args struct { + ua domain.UserAgent + } + tests := []struct { + name string + args args + want *session.UserAgent + }{ + { + name: "empty", + args: args{domain.UserAgent{}}, + }, + { + name: "fingerprint id and description", + args: args{domain.UserAgent{ + FingerprintID: gu.Ptr("fingerPrintID"), + Description: gu.Ptr("description"), + }}, + want: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Description: gu.Ptr("description"), + }, + }, + { + name: "with ip", + args: args{domain.UserAgent{ + FingerprintID: gu.Ptr("fingerPrintID"), + Description: gu.Ptr("description"), + IP: net.IPv4(1, 2, 3, 4), + }}, + want: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Description: gu.Ptr("description"), + Ip: gu.Ptr("1.2.3.4"), + }, + }, + { + name: "with header", + args: args{domain.UserAgent{ + FingerprintID: gu.Ptr("fingerPrintID"), + Description: gu.Ptr("description"), + Header: http.Header{ + "foo": []string{"foo", "bar"}, + "hello": []string{"world"}, + }, + }}, + want: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Description: gu.Ptr("description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + "hello": {Values: []string{"world"}}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := userAgentToPb(tt.args.ua) + assert.Equal(t, tt.want, got) + }) + } +} + func mustNewTextQuery(t testing.TB, column query.Column, value string, compare query.TextComparison) query.SearchQuery { q, err := query.NewTextQuery(column, value, compare) require.NoError(t, err) @@ -510,3 +592,73 @@ func Test_userVerificationRequirementToDomain(t *testing.T) { }) } } + +func Test_userAgentToCommand(t *testing.T) { + type args struct { + userAgent *session.UserAgent + } + tests := []struct { + name string + args args + want *domain.UserAgent + }{ + { + name: "nil", + args: args{nil}, + want: nil, + }, + { + name: "all fields", + args: args{&session.UserAgent{ + FingerprintId: gu.Ptr("fp1"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: map[string]*session.UserAgent_HeaderValues{ + "hello": { + Values: []string{"foo", "bar"}, + }, + }, + }}, + want: &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{ + "hello": []string{"foo", "bar"}, + }, + }, + }, + { + name: "invalid ip", + args: args{&session.UserAgent{ + FingerprintId: gu.Ptr("fp1"), + Ip: gu.Ptr("oops"), + Description: gu.Ptr("firefox"), + Header: map[string]*session.UserAgent_HeaderValues{ + "hello": { + Values: []string{"foo", "bar"}, + }, + }, + }}, + want: &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: nil, + Description: gu.Ptr("firefox"), + Header: http.Header{ + "hello": []string{"foo", "bar"}, + }, + }, + }, + { + name: "nil fields", + args: args{&session.UserAgent{}}, + want: &domain.UserAgent{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := userAgentToCommand(tt.args.userAgent) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/command/auth_request_test.go b/internal/command/auth_request_test.go index 9e6bf8328d..61f7c9184e 100644 --- a/internal/command/auth_request_test.go +++ b/internal/command/auth_request_test.go @@ -2,6 +2,8 @@ package command import ( "context" + "net" + "net/http" "testing" "time" @@ -358,7 +360,15 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) { ), expectFilter( eventFromEventPusher( - session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(mockCtx, + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), ), ), tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) { @@ -401,7 +411,15 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) { ), expectFilter( eventFromEventPusher( - session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(mockCtx, + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), ), ), tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) { @@ -444,8 +462,15 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) { ), expectFilter( eventFromEventPusher( - session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate), - ), + session.NewAddedEvent(mockCtx, + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), eventFromEventPusher( session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "userID", testNow), @@ -523,8 +548,15 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) { ), expectFilter( eventFromEventPusher( - session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate), - ), + session.NewAddedEvent(mockCtx, + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), eventFromEventPusher( session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "userID", testNow), diff --git a/internal/command/oidc_session_test.go b/internal/command/oidc_session_test.go index aba917fa24..59d8f7d3f3 100644 --- a/internal/command/oidc_session_test.go +++ b/internal/command/oidc_session_test.go @@ -2,6 +2,8 @@ package command import ( "context" + "net" + "net/http" "testing" "time" @@ -164,7 +166,15 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) { ), expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate), + session.NewAddedEvent(context.Background(), + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + ), ), eventFromEventPusher( session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, @@ -365,7 +375,15 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) { ), expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate), + session.NewAddedEvent(context.Background(), + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + ), ), eventFromEventPusher( session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, diff --git a/internal/command/session.go b/internal/command/session.go index a58c5d2d3a..caf3056f76 100644 --- a/internal/command/session.go +++ b/internal/command/session.go @@ -166,8 +166,8 @@ func (s *SessionCommands) Exec(ctx context.Context) error { return nil } -func (s *SessionCommands) Start(ctx context.Context) { - s.eventCommands = append(s.eventCommands, session.NewAddedEvent(ctx, s.sessionWriteModel.aggregate)) +func (s *SessionCommands) Start(ctx context.Context, userAgent *domain.UserAgent) { + s.eventCommands = append(s.eventCommands, session.NewAddedEvent(ctx, s.sessionWriteModel.aggregate, userAgent)) } func (s *SessionCommands) UserChecked(ctx context.Context, userID string, checkedAt time.Time) error { @@ -280,7 +280,7 @@ func (s *SessionCommands) commands(ctx context.Context) (string, []eventstore.Co return token, s.eventCommands, nil } -func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, metadata map[string][]byte) (set *SessionChanged, err error) { +func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, metadata map[string][]byte, userAgent *domain.UserAgent) (set *SessionChanged, err error) { sessionID, err := c.idGenerator.Next() if err != nil { return nil, err @@ -291,7 +291,7 @@ func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, met return nil, err } cmd := c.NewSessionCommands(cmds, sessionWriteModel) - cmd.Start(ctx) + cmd.Start(ctx, userAgent) return c.updateSession(ctx, cmd, metadata) } diff --git a/internal/command/session_test.go b/internal/command/session_test.go index 57f3ba971c..048ed51e84 100644 --- a/internal/command/session_test.go +++ b/internal/command/session_test.go @@ -3,10 +3,13 @@ package command import ( "context" "io" + "net" + "net/http" "testing" "time" "github.com/golang/mock/gomock" + "github.com/muhlemmer/gu" "github.com/pquerna/otp/totp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -145,9 +148,10 @@ func TestCommands_CreateSession(t *testing.T) { tokenCreator func(sessionID string) (string, string, error) } type args struct { - ctx context.Context - checks []SessionCommand - metadata map[string][]byte + ctx context.Context + checks []SessionCommand + metadata map[string][]byte + userAgent *domain.UserAgent } type res struct { want *SessionChanged @@ -200,12 +204,26 @@ func TestCommands_CreateSession(t *testing.T) { }, args{ ctx: authz.NewMockContext("", "org1", ""), + userAgent: &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, }, []expect{ expectFilter(), expectPush( eventPusherToEvents( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate), + session.NewAddedEvent(context.Background(), + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + ), session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID", ), @@ -229,7 +247,7 @@ func TestCommands_CreateSession(t *testing.T) { idGenerator: tt.fields.idGenerator, sessionTokenCreator: tt.fields.tokenCreator, } - got, err := c.CreateSession(tt.args.ctx, tt.args.checks, tt.args.metadata) + got, err := c.CreateSession(tt.args.ctx, tt.args.checks, tt.args.metadata, tt.args.userAgent) require.ErrorIs(t, err, tt.res.err) assert.Equal(t, tt.res.want, got) }) @@ -278,7 +296,15 @@ func TestCommands_UpdateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID")), @@ -303,7 +329,15 @@ func TestCommands_UpdateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID")), @@ -868,7 +902,15 @@ func TestCommands_TerminateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID")), @@ -893,7 +935,15 @@ func TestCommands_TerminateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID")), @@ -922,7 +972,15 @@ func TestCommands_TerminateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID"), @@ -953,7 +1011,15 @@ func TestCommands_TerminateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), + &session.NewAggregate("sessionID", "org1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID"), diff --git a/internal/domain/user_agent.go b/internal/domain/user_agent.go new file mode 100644 index 0000000000..ca72c6bec4 --- /dev/null +++ b/internal/domain/user_agent.go @@ -0,0 +1,17 @@ +package domain + +import ( + "net" + httplib "net/http" +) + +type UserAgent struct { + FingerprintID *string `json:"fingerprint_id,omitempty"` + IP net.IP `json:"ip,omitempty"` + Description *string `json:"description,omitempty"` + Header httplib.Header `json:"header,omitempty"` +} + +func (ua UserAgent) IsEmpty() bool { + return ua.FingerprintID == nil && len(ua.IP) == 0 && ua.Description == nil && ua.Header == nil +} diff --git a/internal/query/prepare_test.go b/internal/query/prepare_test.go index 242b387408..c9770c0dd1 100644 --- a/internal/query/prepare_test.go +++ b/internal/query/prepare_test.go @@ -54,7 +54,7 @@ func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExp } return isErr(err) } - object, ok, didScan := execScan(&database.DB{DB: client}, builder, scan, errCheck) + object, ok, didScan := execScan(t, &database.DB{DB: client}, builder, scan, errCheck) if !ok { t.Error(object) return false @@ -168,7 +168,7 @@ var ( selectBuilderType = reflect.TypeOf(sq.SelectBuilder{}) ) -func execScan(client *database.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (object interface{}, ok bool, didScan bool) { +func execScan(t testing.TB, client *database.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (object interface{}, ok bool, didScan bool) { scanType := reflect.TypeOf(scan) err := validateScan(scanType) if err != nil { @@ -177,7 +177,7 @@ func execScan(client *database.DB, builder sq.SelectBuilder, scan interface{}, e stmt, args, err := builder.ToSql() if err != nil { - return fmt.Errorf("unexpeted error from sql builder: %w", err), false, false + return fmt.Errorf("unexpected error from sql builder: %w", err), false, false } //resultSet represents *sql.Row or *sql.Rows, @@ -199,6 +199,9 @@ func execScan(client *database.DB, builder sq.SelectBuilder, scan interface{}, e // if scan(*sql.Row)... } else if scanType.In(0).AssignableTo(rowType) { err = client.QueryRow(func(r *sql.Row) error { + if r.Err() != nil { + return r.Err() + } didScan = true res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(r)}) if err, ok := res[1].Interface().(error); ok { @@ -213,6 +216,9 @@ func execScan(client *database.DB, builder sq.SelectBuilder, scan interface{}, e if err != nil { err, ok := errCheck(err) + if !ok { + t.Fatal(err) + } if didScan { return res[0].Interface(), ok, didScan } diff --git a/internal/query/projection/session.go b/internal/query/projection/session.go index 7305d6a87c..f51d7c1b2a 100644 --- a/internal/query/projection/session.go +++ b/internal/query/projection/session.go @@ -14,27 +14,31 @@ import ( ) const ( - SessionsProjectionTable = "projections.sessions5" + SessionsProjectionTable = "projections.sessions6" - SessionColumnID = "id" - SessionColumnCreationDate = "creation_date" - SessionColumnChangeDate = "change_date" - SessionColumnSequence = "sequence" - SessionColumnState = "state" - SessionColumnResourceOwner = "resource_owner" - SessionColumnInstanceID = "instance_id" - SessionColumnCreator = "creator" - SessionColumnUserID = "user_id" - SessionColumnUserCheckedAt = "user_checked_at" - SessionColumnPasswordCheckedAt = "password_checked_at" - SessionColumnIntentCheckedAt = "intent_checked_at" - SessionColumnWebAuthNCheckedAt = "webauthn_checked_at" - SessionColumnWebAuthNUserVerified = "webauthn_user_verified" - SessionColumnTOTPCheckedAt = "totp_checked_at" - SessionColumnOTPSMSCheckedAt = "otp_sms_checked_at" - SessionColumnOTPEmailCheckedAt = "otp_email_checked_at" - SessionColumnMetadata = "metadata" - SessionColumnTokenID = "token_id" + SessionColumnID = "id" + SessionColumnCreationDate = "creation_date" + SessionColumnChangeDate = "change_date" + SessionColumnSequence = "sequence" + SessionColumnState = "state" + SessionColumnResourceOwner = "resource_owner" + SessionColumnInstanceID = "instance_id" + SessionColumnCreator = "creator" + SessionColumnUserID = "user_id" + SessionColumnUserCheckedAt = "user_checked_at" + SessionColumnPasswordCheckedAt = "password_checked_at" + SessionColumnIntentCheckedAt = "intent_checked_at" + SessionColumnWebAuthNCheckedAt = "webauthn_checked_at" + SessionColumnWebAuthNUserVerified = "webauthn_user_verified" + SessionColumnTOTPCheckedAt = "totp_checked_at" + SessionColumnOTPSMSCheckedAt = "otp_sms_checked_at" + SessionColumnOTPEmailCheckedAt = "otp_email_checked_at" + SessionColumnMetadata = "metadata" + SessionColumnTokenID = "token_id" + SessionColumnUserAgentFingerprintID = "user_agent_fingerprint_id" + SessionColumnUserAgentIP = "user_agent_ip" + SessionColumnUserAgentDescription = "user_agent_description" + SessionColumnUserAgentHeader = "user_agent_header" ) type sessionProjection struct { @@ -66,8 +70,16 @@ func newSessionProjection(ctx context.Context, config crdb.StatementHandlerConfi crdb.NewColumn(SessionColumnOTPEmailCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()), crdb.NewColumn(SessionColumnMetadata, crdb.ColumnTypeJSONB, crdb.Nullable()), crdb.NewColumn(SessionColumnTokenID, crdb.ColumnTypeText, crdb.Nullable()), + crdb.NewColumn(SessionColumnUserAgentFingerprintID, crdb.ColumnTypeText, crdb.Nullable()), + crdb.NewColumn(SessionColumnUserAgentIP, crdb.ColumnTypeText, crdb.Nullable()), + crdb.NewColumn(SessionColumnUserAgentDescription, crdb.ColumnTypeText, crdb.Nullable()), + crdb.NewColumn(SessionColumnUserAgentHeader, crdb.ColumnTypeJSONB, crdb.Nullable()), }, crdb.NewPrimaryKey(SessionColumnInstanceID, SessionColumnID), + crdb.WithIndex(crdb.NewIndex( + SessionColumnUserAgentFingerprintID+"_idx", + []string{SessionColumnUserAgentFingerprintID}, + )), ), ) p.StatementHandler = crdb.NewStatementHandler(ctx, config) @@ -152,19 +164,35 @@ func (p *sessionProjection) reduceSessionAdded(event eventstore.Event) (*handler return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-Sfrgf", "reduce.wrong.event.type %s", session.AddedType) } - return crdb.NewCreateStatement( - e, - []handler.Column{ - handler.NewCol(SessionColumnID, e.Aggregate().ID), - handler.NewCol(SessionColumnInstanceID, e.Aggregate().InstanceID), - handler.NewCol(SessionColumnCreationDate, e.CreationDate()), - handler.NewCol(SessionColumnChangeDate, e.CreationDate()), - handler.NewCol(SessionColumnResourceOwner, e.Aggregate().ResourceOwner), - handler.NewCol(SessionColumnState, domain.SessionStateActive), - handler.NewCol(SessionColumnSequence, e.Sequence()), - handler.NewCol(SessionColumnCreator, e.User), - }, - ), nil + cols := make([]handler.Column, 0, 12) + cols = append(cols, + handler.NewCol(SessionColumnID, e.Aggregate().ID), + handler.NewCol(SessionColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCol(SessionColumnCreationDate, e.CreationDate()), + handler.NewCol(SessionColumnChangeDate, e.CreationDate()), + handler.NewCol(SessionColumnResourceOwner, e.Aggregate().ResourceOwner), + handler.NewCol(SessionColumnState, domain.SessionStateActive), + handler.NewCol(SessionColumnSequence, e.Sequence()), + handler.NewCol(SessionColumnCreator, e.User), + ) + if e.UserAgent != nil { + cols = append(cols, + handler.NewCol(SessionColumnUserAgentFingerprintID, e.UserAgent.FingerprintID), + handler.NewCol(SessionColumnUserAgentDescription, e.UserAgent.Description), + ) + if e.UserAgent.IP != nil { + cols = append(cols, + handler.NewCol(SessionColumnUserAgentIP, e.UserAgent.IP.String()), + ) + } + if e.UserAgent.Header != nil { + cols = append(cols, + handler.NewJSONCol(SessionColumnUserAgentHeader, e.UserAgent.Header), + ) + } + } + + return crdb.NewCreateStatement(e, cols), nil } func (p *sessionProjection) reduceUserChecked(event eventstore.Event) (*handler.Statement, error) { diff --git a/internal/query/projection/session_test.go b/internal/query/projection/session_test.go index 8ac52b7484..e18f1ca59c 100644 --- a/internal/query/projection/session_test.go +++ b/internal/query/projection/session_test.go @@ -4,6 +4,8 @@ import ( "testing" "time" + "github.com/muhlemmer/gu" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore" @@ -31,7 +33,15 @@ func TestSessionProjection_reduces(t *testing.T) { session.AddedType, session.AggregateType, []byte(`{ - "domain": "domain" + "domain": "domain", + "user_agent": { + "fingerprint_id": "fp1", + "ip": "1.2.3.4", + "description": "firefox", + "header": { + "foo": ["bar"] + } + } }`), ), session.AddedEventMapper), }, @@ -43,7 +53,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "INSERT INTO projections.sessions5 (id, instance_id, creation_date, change_date, resource_owner, state, sequence, creator) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", + expectedStmt: "INSERT INTO projections.sessions6 (id, instance_id, creation_date, change_date, resource_owner, state, sequence, creator, user_agent_fingerprint_id, user_agent_description, user_agent_ip, user_agent_header) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)", expectedArgs: []interface{}{ "agg-id", "instance-id", @@ -53,6 +63,10 @@ func TestSessionProjection_reduces(t *testing.T) { domain.SessionStateActive, uint64(15), "editor-user", + gu.Ptr("fp1"), + gu.Ptr("firefox"), + "1.2.3.4", + []byte(`{"foo":["bar"]}`), }, }, }, @@ -79,7 +93,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)", + expectedStmt: "UPDATE projections.sessions6 SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -112,7 +126,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions6 SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -145,7 +159,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, webauthn_checked_at, webauthn_user_verified) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)", + expectedStmt: "UPDATE projections.sessions6 SET (change_date, sequence, webauthn_checked_at, webauthn_user_verified) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -178,7 +192,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, intent_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions6 SET (change_date, sequence, intent_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -210,7 +224,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, totp_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions6 SET (change_date, sequence, totp_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -242,7 +256,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions6 SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -276,7 +290,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions6 SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -308,7 +322,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "DELETE FROM projections.sessions5 WHERE (id = $1) AND (instance_id = $2)", + expectedStmt: "DELETE FROM projections.sessions6 WHERE (id = $1) AND (instance_id = $2)", expectedArgs: []interface{}{ "agg-id", "instance-id", @@ -335,7 +349,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "DELETE FROM projections.sessions5 WHERE (instance_id = $1)", + expectedStmt: "DELETE FROM projections.sessions6 WHERE (instance_id = $1)", expectedArgs: []interface{}{ "agg-id", }, @@ -366,7 +380,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions5 SET password_checked_at = $1 WHERE (user_id = $2) AND (password_checked_at < $3)", + expectedStmt: "UPDATE projections.sessions6 SET password_checked_at = $1 WHERE (user_id = $2) AND (password_checked_at < $3)", expectedArgs: []interface{}{ nil, "agg-id", diff --git a/internal/query/session.go b/internal/query/session.go index c098d3d110..73ff5a9688 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" errs "errors" + "net" + "net/http" "time" sq "github.com/Masterminds/squirrel" @@ -38,6 +40,7 @@ type Session struct { OTPSMSFactor SessionOTPFactor OTPEmailFactor SessionOTPFactor Metadata map[string][]byte + UserAgent domain.UserAgent } type SessionUserFactor struct { @@ -163,6 +166,22 @@ var ( name: projection.SessionColumnTokenID, table: sessionsTable, } + SessionColumnUserAgentFingerprintID = Column{ + name: projection.SessionColumnUserAgentFingerprintID, + table: sessionsTable, + } + SessionColumnUserAgentIP = Column{ + name: projection.SessionColumnUserAgentIP, + table: sessionsTable, + } + SessionColumnUserAgentDescription = Column{ + name: projection.SessionColumnUserAgentDescription, + table: sessionsTable, + } + SessionColumnUserAgentHeader = Column{ + name: projection.SessionColumnUserAgentHeader, + table: sessionsTable, + } ) func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (session *Session, err error) { @@ -261,6 +280,10 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil SessionColumnOTPEmailCheckedAt.identifier(), SessionColumnMetadata.identifier(), SessionColumnToken.identifier(), + SessionColumnUserAgentFingerprintID.identifier(), + SessionColumnUserAgentIP.identifier(), + SessionColumnUserAgentDescription.identifier(), + SessionColumnUserAgentHeader.identifier(), ).From(sessionsTable.identifier()). LeftJoin(join(LoginNameUserIDCol, SessionColumnUserID)). LeftJoin(join(HumanUserIDCol, SessionColumnUserID)). @@ -283,6 +306,8 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil otpEmailCheckedAt sql.NullTime metadata database.Map[[]byte] token sql.NullString + userAgentIP sql.NullString + userAgentHeader database.Map[[]string] ) err := row.Scan( @@ -307,6 +332,10 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil &otpEmailCheckedAt, &metadata, &token, + &session.UserAgent.FingerprintID, + &userAgentIP, + &session.UserAgent.Description, + &userAgentHeader, ) if err != nil { @@ -329,7 +358,11 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil session.OTPSMSFactor.OTPCheckedAt = otpSMSCheckedAt.Time session.OTPEmailFactor.OTPCheckedAt = otpEmailCheckedAt.Time session.Metadata = metadata + session.UserAgent.Header = http.Header(userAgentHeader) + if userAgentIP.Valid { + session.UserAgent.IP = net.ParseIP(userAgentIP.String) + } return session, token.String, nil } } diff --git a/internal/query/sessions_test.go b/internal/query/sessions_test.go index fa5209bdd3..1bae095ec6 100644 --- a/internal/query/sessions_test.go +++ b/internal/query/sessions_test.go @@ -6,10 +6,13 @@ import ( "database/sql/driver" "errors" "fmt" + "net" + "net/http" "regexp" "testing" sq "github.com/Masterminds/squirrel" + "github.com/muhlemmer/gu" "github.com/stretchr/testify/require" "github.com/zitadel/zitadel/internal/domain" @@ -17,57 +20,61 @@ import ( ) var ( - expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions5.id,` + - ` projections.sessions5.creation_date,` + - ` projections.sessions5.change_date,` + - ` projections.sessions5.sequence,` + - ` projections.sessions5.state,` + - ` projections.sessions5.resource_owner,` + - ` projections.sessions5.creator,` + - ` projections.sessions5.user_id,` + - ` projections.sessions5.user_checked_at,` + + expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions6.id,` + + ` projections.sessions6.creation_date,` + + ` projections.sessions6.change_date,` + + ` projections.sessions6.sequence,` + + ` projections.sessions6.state,` + + ` projections.sessions6.resource_owner,` + + ` projections.sessions6.creator,` + + ` projections.sessions6.user_id,` + + ` projections.sessions6.user_checked_at,` + ` projections.login_names2.login_name,` + ` projections.users8_humans.display_name,` + ` projections.users8.resource_owner,` + - ` projections.sessions5.password_checked_at,` + - ` projections.sessions5.intent_checked_at,` + - ` projections.sessions5.webauthn_checked_at,` + - ` projections.sessions5.webauthn_user_verified,` + - ` projections.sessions5.totp_checked_at,` + - ` projections.sessions5.otp_sms_checked_at,` + - ` projections.sessions5.otp_email_checked_at,` + - ` projections.sessions5.metadata,` + - ` projections.sessions5.token_id` + - ` FROM projections.sessions5` + - ` LEFT JOIN projections.login_names2 ON projections.sessions5.user_id = projections.login_names2.user_id AND projections.sessions5.instance_id = projections.login_names2.instance_id` + - ` LEFT JOIN projections.users8_humans ON projections.sessions5.user_id = projections.users8_humans.user_id AND projections.sessions5.instance_id = projections.users8_humans.instance_id` + - ` LEFT JOIN projections.users8 ON projections.sessions5.user_id = projections.users8.id AND projections.sessions5.instance_id = projections.users8.instance_id` + + ` projections.sessions6.password_checked_at,` + + ` projections.sessions6.intent_checked_at,` + + ` projections.sessions6.webauthn_checked_at,` + + ` projections.sessions6.webauthn_user_verified,` + + ` projections.sessions6.totp_checked_at,` + + ` projections.sessions6.otp_sms_checked_at,` + + ` projections.sessions6.otp_email_checked_at,` + + ` projections.sessions6.metadata,` + + ` projections.sessions6.token_id,` + + ` projections.sessions6.user_agent_fingerprint_id,` + + ` projections.sessions6.user_agent_ip,` + + ` projections.sessions6.user_agent_description,` + + ` projections.sessions6.user_agent_header` + + ` FROM projections.sessions6` + + ` LEFT JOIN projections.login_names2 ON projections.sessions6.user_id = projections.login_names2.user_id AND projections.sessions6.instance_id = projections.login_names2.instance_id` + + ` LEFT JOIN projections.users8_humans ON projections.sessions6.user_id = projections.users8_humans.user_id AND projections.sessions6.instance_id = projections.users8_humans.instance_id` + + ` LEFT JOIN projections.users8 ON projections.sessions6.user_id = projections.users8.id AND projections.sessions6.instance_id = projections.users8.instance_id` + ` AS OF SYSTEM TIME '-1 ms'`) - expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions5.id,` + - ` projections.sessions5.creation_date,` + - ` projections.sessions5.change_date,` + - ` projections.sessions5.sequence,` + - ` projections.sessions5.state,` + - ` projections.sessions5.resource_owner,` + - ` projections.sessions5.creator,` + - ` projections.sessions5.user_id,` + - ` projections.sessions5.user_checked_at,` + + expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions6.id,` + + ` projections.sessions6.creation_date,` + + ` projections.sessions6.change_date,` + + ` projections.sessions6.sequence,` + + ` projections.sessions6.state,` + + ` projections.sessions6.resource_owner,` + + ` projections.sessions6.creator,` + + ` projections.sessions6.user_id,` + + ` projections.sessions6.user_checked_at,` + ` projections.login_names2.login_name,` + ` projections.users8_humans.display_name,` + ` projections.users8.resource_owner,` + - ` projections.sessions5.password_checked_at,` + - ` projections.sessions5.intent_checked_at,` + - ` projections.sessions5.webauthn_checked_at,` + - ` projections.sessions5.webauthn_user_verified,` + - ` projections.sessions5.totp_checked_at,` + - ` projections.sessions5.otp_sms_checked_at,` + - ` projections.sessions5.otp_email_checked_at,` + - ` projections.sessions5.metadata,` + + ` projections.sessions6.password_checked_at,` + + ` projections.sessions6.intent_checked_at,` + + ` projections.sessions6.webauthn_checked_at,` + + ` projections.sessions6.webauthn_user_verified,` + + ` projections.sessions6.totp_checked_at,` + + ` projections.sessions6.otp_sms_checked_at,` + + ` projections.sessions6.otp_email_checked_at,` + + ` projections.sessions6.metadata,` + ` COUNT(*) OVER ()` + - ` FROM projections.sessions5` + - ` LEFT JOIN projections.login_names2 ON projections.sessions5.user_id = projections.login_names2.user_id AND projections.sessions5.instance_id = projections.login_names2.instance_id` + - ` LEFT JOIN projections.users8_humans ON projections.sessions5.user_id = projections.users8_humans.user_id AND projections.sessions5.instance_id = projections.users8_humans.instance_id` + - ` LEFT JOIN projections.users8 ON projections.sessions5.user_id = projections.users8.id AND projections.sessions5.instance_id = projections.users8.instance_id` + + ` FROM projections.sessions6` + + ` LEFT JOIN projections.login_names2 ON projections.sessions6.user_id = projections.login_names2.user_id AND projections.sessions6.instance_id = projections.login_names2.instance_id` + + ` LEFT JOIN projections.users8_humans ON projections.sessions6.user_id = projections.users8_humans.user_id AND projections.sessions6.instance_id = projections.users8_humans.instance_id` + + ` LEFT JOIN projections.users8 ON projections.sessions6.user_id = projections.users8.id AND projections.sessions6.instance_id = projections.users8.instance_id` + ` AS OF SYSTEM TIME '-1 ms'`) sessionCols = []string{ @@ -92,6 +99,10 @@ var ( "otp_email_checked_at", "metadata", "token", + "user_agent_fingerprint_id", + "user_agent_ip", + "user_agent_description", + "user_agent_header", } sessionsCols = []string{ @@ -443,6 +454,10 @@ func Test_SessionPrepare(t *testing.T) { testNow, []byte(`{"key": "dmFsdWU="}`), "tokenID", + "fingerPrintID", + "1.2.3.4", + "agentDescription", + []byte(`{"foo":["foo","bar"]}`), }, ), }, @@ -483,6 +498,12 @@ func Test_SessionPrepare(t *testing.T) { Metadata: map[string][]byte{ "key": []byte("value"), }, + UserAgent: domain.UserAgent{ + FingerprintID: gu.Ptr("fingerPrintID"), + IP: net.IPv4(1, 2, 3, 4), + Description: gu.Ptr("agentDescription"), + Header: http.Header{"foo": []string{"foo", "bar"}}, + }, }, }, { diff --git a/internal/repository/session/session.go b/internal/repository/session/session.go index c34ef52424..1966c178eb 100644 --- a/internal/repository/session/session.go +++ b/internal/repository/session/session.go @@ -35,6 +35,7 @@ const ( type AddedEvent struct { eventstore.BaseEvent `json:"-"` + UserAgent *domain.UserAgent `json:"user_agent,omitempty"` } func (e *AddedEvent) Data() interface{} { @@ -47,6 +48,7 @@ func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { func NewAddedEvent(ctx context.Context, aggregate *eventstore.Aggregate, + userAgent *domain.UserAgent, ) *AddedEvent { return &AddedEvent{ BaseEvent: *eventstore.NewBaseEventForPush( @@ -54,6 +56,7 @@ func NewAddedEvent(ctx context.Context, aggregate, AddedType, ), + UserAgent: userAgent, } } diff --git a/proto/zitadel/session/v2beta/session.proto b/proto/zitadel/session/v2beta/session.proto index ddbe143361..b0bfd4fb14 100644 --- a/proto/zitadel/session/v2beta/session.proto +++ b/proto/zitadel/session/v2beta/session.proto @@ -39,6 +39,7 @@ message Session { description: "\"custom key value list\""; } ]; + UserAgent user_agent = 7; } message Factors { @@ -131,3 +132,18 @@ message SearchQuery { message IDsQuery { repeated string ids = 1; } + +message UserAgent { + optional string fingerprint_id = 1; + optional string ip = 2; + optional string description = 3; + + // A header may have multiple values. + // In Go, headers are defined + // as map[string][]string, but protobuf + // doesn't allow this scheme. + message HeaderValues { + repeated string values = 1; + } + map header = 4; +} \ No newline at end of file diff --git a/proto/zitadel/session/v2beta/session_service.proto b/proto/zitadel/session/v2beta/session_service.proto index 745c430019..718a29381f 100644 --- a/proto/zitadel/session/v2beta/session_service.proto +++ b/proto/zitadel/session/v2beta/session_service.proto @@ -274,6 +274,7 @@ message CreateSessionRequest{ } ]; RequestChallenges challenges = 3; + UserAgent user_agent = 4; } message CreateSessionResponse{