From e7904057e05978b724504d248e83c1700a98abc9 Mon Sep 17 00:00:00 2001 From: Silvan Date: Tue, 23 Apr 2024 13:23:50 +0200 Subject: [PATCH] perf: cache auth request in memory (#7824) * perf: cache auth request in memory (cherry picked from commit 25030c69b97b81350a9919faed34e7de9005f725) --- cmd/defaults.yaml | 3 + go.mod | 1 + go.sum | 2 + .../eventsourcing/eventstore/auth_request.go | 87 +++++++++------ .../eventstore/auth_request_test.go | 1 + .../repository/eventsourcing/repository.go | 7 +- .../auth_request/repository/cache/cache.go | 100 +++++++++++++++--- .../repository/mock/repository.mock.go | 12 +++ .../auth_request/repository/repository.go | 1 + internal/database/database.go | 2 +- 10 files changed, 165 insertions(+), 51 deletions(-) diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index 2437a4ef3d..5652ca2fcb 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -274,6 +274,9 @@ Auth: # from HandleActiveInstances duration in the past until the projections current time # If set to 0 (default), every instance is always considered active HandleActiveInstances: 0s #ZITADEL_AUTH_SPOOLER_HANDLEACTIVEINSTANCES + # Defines the amount of auth requests stored in the LRU caches. + # There are two caches implemented one for id and one for code + AmountOfCachedAuthRequests: 128 #ZITADEL_AUTH_AMOUNTOFCACHEDAUTHREQUESTS Admin: # See Projections.BulkLimit diff --git a/go.mod b/go.mod index 99fcd1f494..9fdc3a7cf6 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway v1.16.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 github.com/h2non/gock v1.2.0 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/improbable-eng/grpc-web v0.15.0 github.com/jackc/pgx/v5 v5.5.5 github.com/jarcoal/jpath v0.0.0-20140328210829-f76b8b2dbf52 diff --git a/go.sum b/go.sum index 4b6f340626..85420a32f2 100644 --- a/go.sum +++ b/go.sum @@ -385,6 +385,8 @@ github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09 github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index b7c85ab79e..ca95108205 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -656,39 +656,52 @@ func (repo *AuthRequestRepo) fillPolicies(ctx context.Context, request *domain.A } } - loginPolicy, idpProviders, err := repo.getLoginPolicyAndIDPProviders(ctx, orgID) - if err != nil { - return err + if request.LoginPolicy == nil || len(request.AllowedExternalIDPs) == 0 { + loginPolicy, idpProviders, err := repo.getLoginPolicyAndIDPProviders(ctx, orgID) + if err != nil { + return err + } + request.LoginPolicy = queryLoginPolicyToDomain(loginPolicy) + if len(idpProviders) > 0 { + request.AllowedExternalIDPs = idpProviders + } } - request.LoginPolicy = queryLoginPolicyToDomain(loginPolicy) - if idpProviders != nil { - request.AllowedExternalIDPs = idpProviders + if request.LockoutPolicy == nil { + lockoutPolicy, err := repo.getLockoutPolicy(ctx, orgID) + if err != nil { + return err + } + request.LockoutPolicy = lockoutPolicyToDomain(lockoutPolicy) } - lockoutPolicy, err := repo.getLockoutPolicy(ctx, orgID) - if err != nil { - return err + if request.PrivacyPolicy == nil { + privacyPolicy, err := repo.GetPrivacyPolicy(ctx, orgID) + if err != nil { + return err + } + request.PrivacyPolicy = privacyPolicy } - request.LockoutPolicy = lockoutPolicyToDomain(lockoutPolicy) - privacyPolicy, err := repo.GetPrivacyPolicy(ctx, orgID) - if err != nil { - return err + if request.LabelPolicy == nil { + labelPolicy, err := repo.getLabelPolicy(ctx, request.PrivateLabelingOrgID(orgID)) + if err != nil { + return err + } + request.LabelPolicy = labelPolicy } - request.PrivacyPolicy = privacyPolicy - labelPolicy, err := repo.getLabelPolicy(ctx, request.PrivateLabelingOrgID(orgID)) - if err != nil { - return err + if len(request.DefaultTranslations) == 0 { + defaultLoginTranslations, err := repo.getLoginTexts(ctx, instance.InstanceID()) + if err != nil { + return err + } + request.DefaultTranslations = defaultLoginTranslations } - request.LabelPolicy = labelPolicy - defaultLoginTranslations, err := repo.getLoginTexts(ctx, instance.InstanceID()) - if err != nil { - return err + if len(request.OrgTranslations) == 0 { + orgLoginTranslations, err := repo.getLoginTexts(ctx, orgID) + if err != nil { + return err + } + request.OrgTranslations = orgLoginTranslations } - request.DefaultTranslations = defaultLoginTranslations - orgLoginTranslations, err := repo.getLoginTexts(ctx, orgID) - if err != nil { - return err - } - request.OrgTranslations = orgLoginTranslations + repo.AuthRequests.CacheAuthRequest(ctx, request) return nil } @@ -801,6 +814,7 @@ func (repo *AuthRequestRepo) checkDomainDiscovery(ctx context.Context, request * } request.LoginHint = loginName request.Prompt = append(request.Prompt, domain.PromptCreate) // to trigger registration + repo.AuthRequests.CacheAuthRequest(ctx, request) return true, nil } @@ -872,22 +886,25 @@ func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(ctx context.Con return nil, err } -func (repo *AuthRequestRepo) checkLoginPolicyWithResourceOwner(ctx context.Context, request *domain.AuthRequest, resourceOwner string) error { - loginPolicy, idpProviders, err := repo.getLoginPolicyAndIDPProviders(ctx, resourceOwner) - if err != nil { - return err +func (repo *AuthRequestRepo) checkLoginPolicyWithResourceOwner(ctx context.Context, request *domain.AuthRequest, resourceOwner string) (err error) { + if request.LoginPolicy == nil { + loginPolicy, idps, err := repo.getLoginPolicyAndIDPProviders(ctx, resourceOwner) + if err != nil { + return err + } + request.LoginPolicy = queryLoginPolicyToDomain(loginPolicy) + request.AllowedExternalIDPs = idps } - if len(request.LinkingUsers) != 0 && !loginPolicy.AllowExternalIDPs { + if len(request.LinkingUsers) != 0 && !request.LoginPolicy.AllowExternalIDP { return zerrors.ThrowInvalidArgument(nil, "LOGIN-s9sio", "Errors.User.NotAllowedToLink") } if len(request.LinkingUsers) != 0 { - exists := linkingIDPConfigExistingInAllowedIDPs(request.LinkingUsers, idpProviders) + exists := linkingIDPConfigExistingInAllowedIDPs(request.LinkingUsers, request.AllowedExternalIDPs) if !exists { return zerrors.ThrowInvalidArgument(nil, "LOGIN-Dj89o", "Errors.User.NotAllowedToLink") } } - request.LoginPolicy = queryLoginPolicyToDomain(loginPolicy) - request.AllowedExternalIDPs = idpProviders + repo.AuthRequests.CacheAuthRequest(ctx, request) return nil } diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go index 99e0c78ec6..d5dcf0257d 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go @@ -486,6 +486,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) { AuthRequests: func() cache.AuthRequestCache { m := mock.NewMockAuthRequestCache(gomock.NewController(t)) m.EXPECT().UpdateAuthRequest(gomock.Any(), gomock.Any()) + m.EXPECT().CacheAuthRequest(gomock.Any(), gomock.Any()) return m }(), userSessionViewProvider: &mockViewUserSession{ diff --git a/internal/auth/repository/eventsourcing/repository.go b/internal/auth/repository/eventsourcing/repository.go index 19cea1c6ba..20f753863c 100644 --- a/internal/auth/repository/eventsourcing/repository.go +++ b/internal/auth/repository/eventsourcing/repository.go @@ -17,8 +17,9 @@ import ( ) type Config struct { - SearchLimit uint64 - Spooler auth_handler.Config + SearchLimit uint64 + Spooler auth_handler.Config + AmountOfCachedAuthRequests uint16 } type EsRepository struct { @@ -39,7 +40,7 @@ func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, c auth_handler.Register(ctx, conf.Spooler, view, queries) auth_handler.Start(ctx) - authReq := cache.Start(dbClient) + authReq := cache.Start(dbClient, conf.AmountOfCachedAuthRequests) userRepo := eventstore.UserRepo{ SearchLimit: conf.SearchLimit, diff --git a/internal/auth_request/repository/cache/cache.go b/internal/auth_request/repository/cache/cache.go index 63c442ef2d..9919d717de 100644 --- a/internal/auth_request/repository/cache/cache.go +++ b/internal/auth_request/repository/cache/cache.go @@ -8,6 +8,9 @@ import ( "fmt" "time" + "github.com/hashicorp/golang-lru/v2" + "github.com/zitadel/logging" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" @@ -15,12 +18,21 @@ import ( ) type AuthRequestCache struct { - client *database.DB + client *database.DB + idCache *lru.Cache[string, *domain.AuthRequest] + codeCache *lru.Cache[string, *domain.AuthRequest] } -func Start(dbClient *database.DB) *AuthRequestCache { +func Start(dbClient *database.DB, amountOfCachedAuthRequests uint16) *AuthRequestCache { + idCache, err := lru.New[string, *domain.AuthRequest](int(amountOfCachedAuthRequests)) + logging.OnError(err).Info("auth request cache disabled") + codeCache, err := lru.New[string, *domain.AuthRequest](int(amountOfCachedAuthRequests)) + logging.OnError(err).Info("auth request cache disabled") + return &AuthRequestCache{ - client: dbClient, + client: dbClient, + idCache: idCache, + codeCache: codeCache, } } @@ -29,22 +41,38 @@ func (c *AuthRequestCache) Health(ctx context.Context) error { } func (c *AuthRequestCache) GetAuthRequestByID(ctx context.Context, id string) (*domain.AuthRequest, error) { - return c.getAuthRequest("id", id, authz.GetInstance(ctx).InstanceID()) + if authRequest, ok := c.getCachedByID(ctx, id); ok { + return authRequest, nil + } + request, err := c.getAuthRequest(ctx, "id", id, authz.GetInstance(ctx).InstanceID()) + if err != nil { + return nil, err + } + c.CacheAuthRequest(ctx, request) + return request, nil } func (c *AuthRequestCache) GetAuthRequestByCode(ctx context.Context, code string) (*domain.AuthRequest, error) { - return c.getAuthRequest("code", code, authz.GetInstance(ctx).InstanceID()) + if authRequest, ok := c.getCachedByCode(ctx, code); ok { + return authRequest, nil + } + request, err := c.getAuthRequest(ctx, "code", code, authz.GetInstance(ctx).InstanceID()) + if err != nil { + return nil, err + } + c.CacheAuthRequest(ctx, request) + return request, nil } -func (c *AuthRequestCache) SaveAuthRequest(_ context.Context, request *domain.AuthRequest) error { - return c.saveAuthRequest(request, "INSERT INTO auth.auth_requests (id, request, instance_id, creation_date, change_date, request_type) VALUES($1, $2, $3, $4, $4, $5)", request.CreationDate, request.Request.Type()) +func (c *AuthRequestCache) SaveAuthRequest(ctx context.Context, request *domain.AuthRequest) error { + return c.saveAuthRequest(ctx, request, "INSERT INTO auth.auth_requests (id, request, instance_id, creation_date, change_date, request_type) VALUES($1, $2, $3, $4, $4, $5)", request.CreationDate, request.Request.Type()) } -func (c *AuthRequestCache) UpdateAuthRequest(_ context.Context, request *domain.AuthRequest) error { +func (c *AuthRequestCache) UpdateAuthRequest(ctx context.Context, request *domain.AuthRequest) error { if request.ChangeDate.IsZero() { request.ChangeDate = time.Now() } - return c.saveAuthRequest(request, "UPDATE auth.auth_requests SET request = $2, instance_id = $3, change_date = $4, code = $5 WHERE id = $1", request.ChangeDate, request.Code) + return c.saveAuthRequest(ctx, request, "UPDATE auth.auth_requests SET request = $2, instance_id = $3, change_date = $4, code = $5 WHERE id = $1", request.ChangeDate, request.Code) } func (c *AuthRequestCache) DeleteAuthRequest(ctx context.Context, id string) error { @@ -52,14 +80,16 @@ func (c *AuthRequestCache) DeleteAuthRequest(ctx context.Context, id string) err if err != nil { return zerrors.ThrowInternal(err, "CACHE-dsHw3", "unable to delete auth request") } + c.deleteFromCache(ctx, id) return nil } -func (c *AuthRequestCache) getAuthRequest(key, value, instanceID string) (*domain.AuthRequest, error) { +func (c *AuthRequestCache) getAuthRequest(ctx context.Context, key, value, instanceID string) (*domain.AuthRequest, error) { var b []byte var requestType domain.AuthRequestType query := fmt.Sprintf("SELECT request, request_type FROM auth.auth_requests WHERE instance_id = $1 and %s = $2", key) - err := c.client.QueryRow( + err := c.client.QueryRowContext( + ctx, func(row *sql.Row) error { return row.Scan(&b, &requestType) }, @@ -81,7 +111,7 @@ func (c *AuthRequestCache) getAuthRequest(key, value, instanceID string) (*domai return request, nil } -func (c *AuthRequestCache) saveAuthRequest(request *domain.AuthRequest, query string, date time.Time, param interface{}) error { +func (c *AuthRequestCache) saveAuthRequest(ctx context.Context, request *domain.AuthRequest, query string, date time.Time, param interface{}) error { b, err := json.Marshal(request) if err != nil { return zerrors.ThrowInternal(err, "CACHE-os0GH", "Errors.Internal") @@ -90,5 +120,51 @@ func (c *AuthRequestCache) saveAuthRequest(request *domain.AuthRequest, query st if err != nil { return zerrors.ThrowInternal(err, "CACHE-su3GK", "Errors.Internal") } + c.CacheAuthRequest(ctx, request) return nil } + +func (c *AuthRequestCache) getCachedByID(ctx context.Context, id string) (*domain.AuthRequest, bool) { + if c.idCache == nil { + return nil, false + } + authRequest, ok := c.idCache.Get(cacheKey(ctx, id)) + logging.WithFields("hit", ok, "type", "id").Info("get from auth request cache") + return authRequest, ok +} + +func (c *AuthRequestCache) getCachedByCode(ctx context.Context, code string) (*domain.AuthRequest, bool) { + if c.codeCache == nil { + return nil, false + } + authRequest, ok := c.codeCache.Get(cacheKey(ctx, code)) + logging.WithFields("hit", ok, "type", "code").Info("get from auth request cache") + return authRequest, ok +} + +func (c *AuthRequestCache) CacheAuthRequest(ctx context.Context, request *domain.AuthRequest) { + if c.idCache == nil { + return + } + c.idCache.Add(cacheKey(ctx, request.ID), request) + if request.Code != "" { + c.codeCache.Add(cacheKey(ctx, request.Code), request) + } +} + +func cacheKey(ctx context.Context, value string) string { + return fmt.Sprintf("%s-%s", authz.GetInstance(ctx).InstanceID(), value) +} + +func (c *AuthRequestCache) deleteFromCache(ctx context.Context, id string) { + if c.idCache == nil { + return + } + idKey := cacheKey(ctx, id) + request, ok := c.idCache.Get(idKey) + if !ok { + return + } + c.idCache.Remove(idKey) + c.codeCache.Remove(cacheKey(ctx, request.Code)) +} diff --git a/internal/auth_request/repository/mock/repository.mock.go b/internal/auth_request/repository/mock/repository.mock.go index 773018b214..c05e5010fe 100644 --- a/internal/auth_request/repository/mock/repository.mock.go +++ b/internal/auth_request/repository/mock/repository.mock.go @@ -40,6 +40,18 @@ func (m *MockAuthRequestCache) EXPECT() *MockAuthRequestCacheMockRecorder { return m.recorder } +// CacheAuthRequest mocks base method. +func (m *MockAuthRequestCache) CacheAuthRequest(arg0 context.Context, arg1 *domain.AuthRequest) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CacheAuthRequest", arg0, arg1) +} + +// CacheAuthRequest indicates an expected call of CacheAuthRequest. +func (mr *MockAuthRequestCacheMockRecorder) CacheAuthRequest(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CacheAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).CacheAuthRequest), arg0, arg1) +} + // DeleteAuthRequest mocks base method. func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() diff --git a/internal/auth_request/repository/repository.go b/internal/auth_request/repository/repository.go index 5c850656be..af9e291fa6 100644 --- a/internal/auth_request/repository/repository.go +++ b/internal/auth_request/repository/repository.go @@ -12,6 +12,7 @@ type AuthRequestCache interface { GetAuthRequestByID(ctx context.Context, id string) (*domain.AuthRequest, error) GetAuthRequestByCode(ctx context.Context, code string) (*domain.AuthRequest, error) SaveAuthRequest(ctx context.Context, request *domain.AuthRequest) error + CacheAuthRequest(ctx context.Context, request *domain.AuthRequest) UpdateAuthRequest(ctx context.Context, request *domain.AuthRequest) error DeleteAuthRequest(ctx context.Context, id string) error } diff --git a/internal/database/database.go b/internal/database/database.go index e64645294b..77baaa7bd2 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -40,7 +40,7 @@ func (db *DB) Query(scan func(*sql.Rows) error, query string, args ...any) error func (db *DB) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) (err error) { ctx, spanBeginTx := tracing.NewNamedSpan(ctx, "db.BeginTx") - tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true, Isolation: sql.LevelReadCommitted}) spanBeginTx.EndWithError(err) if err != nil { return err