From 250f2344c8c2292ca9b861cdd12223d0b4719d43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Mon, 4 Nov 2024 11:44:51 +0100 Subject: [PATCH] feat(cache): redis cache (#8822) # Which Problems Are Solved Add a cache implementation using Redis single mode. This does not add support for Redis Cluster or sentinel. # How the Problems Are Solved Added the `internal/cache/redis` package. All operations occur atomically, including setting of secondary indexes, using LUA scripts where needed. The [`miniredis`](https://github.com/alicebob/miniredis) package is used to run unit tests. # Additional Changes - Move connector code to `internal/cache/connector/...` and remove duplicate code from `query` and `command` packages. - Fix a missed invalidation on the restrictions projection # Additional Context Closes #8130 --- .github/workflows/core-integration-test.yml | 4 + Makefile | 2 +- cmd/defaults.yaml | 138 +++- cmd/mirror/projections.go | 13 +- cmd/setup/03.go | 6 +- cmd/setup/config.go | 4 +- cmd/setup/config_change.go | 5 +- cmd/setup/setup.go | 11 +- cmd/start/config.go | 4 +- cmd/start/start.go | 11 +- go.mod | 5 + go.sum | 14 + internal/cache/cache.go | 44 +- internal/cache/connector/connector.go | 69 ++ internal/cache/connector/gomap/connector.go | 23 + internal/cache/{ => connector}/gomap/gomap.go | 8 +- .../cache/{ => connector}/gomap/gomap_test.go | 24 +- internal/cache/{ => connector}/noop/noop.go | 0 internal/cache/connector/pg/connector.go | 28 + .../pg/create_partition.sql.tmpl | 0 internal/cache/{ => connector}/pg/delete.sql | 0 internal/cache/{ => connector}/pg/get.sql | 0 .../cache/{ => connector}/pg/invalidate.sql | 0 internal/cache/{ => connector}/pg/pg.go | 40 +- internal/cache/{ => connector}/pg/pg_test.go | 89 ++- internal/cache/{ => connector}/pg/prune.sql | 0 internal/cache/{ => connector}/pg/set.sql | 0 .../cache/{ => connector}/pg/truncate.sql | 0 internal/cache/connector/redis/_remove.lua | 10 + internal/cache/connector/redis/_select.lua | 3 + internal/cache/connector/redis/_util.lua | 17 + internal/cache/connector/redis/connector.go | 154 ++++ internal/cache/connector/redis/get.lua | 29 + internal/cache/connector/redis/invalidate.lua | 9 + internal/cache/connector/redis/redis.go | 172 +++++ internal/cache/connector/redis/redis_test.go | 714 ++++++++++++++++++ internal/cache/connector/redis/set.lua | 27 + internal/cache/connector_enumer.go | 98 +++ internal/cache/pruner.go | 14 +- internal/cache/pruner_test.go | 2 +- internal/cache/purpose_enumer.go | 82 ++ internal/command/cache.go | 69 +- internal/command/command.go | 7 +- internal/command/instance_test.go | 2 +- internal/command/milestone_test.go | 6 +- .../integration/config/docker-compose.yaml | 6 + internal/integration/config/zitadel.yaml | 16 +- internal/query/cache.go | 72 +- internal/query/instance.go | 3 +- internal/query/query.go | 6 +- 50 files changed, 1767 insertions(+), 293 deletions(-) create mode 100644 internal/cache/connector/connector.go create mode 100644 internal/cache/connector/gomap/connector.go rename internal/cache/{ => connector}/gomap/gomap.go (95%) rename internal/cache/{ => connector}/gomap/gomap_test.go (94%) rename internal/cache/{ => connector}/noop/noop.go (100%) create mode 100644 internal/cache/connector/pg/connector.go rename internal/cache/{ => connector}/pg/create_partition.sql.tmpl (100%) rename internal/cache/{ => connector}/pg/delete.sql (100%) rename internal/cache/{ => connector}/pg/get.sql (100%) rename internal/cache/{ => connector}/pg/invalidate.sql (100%) rename internal/cache/{ => connector}/pg/pg.go (78%) rename internal/cache/{ => connector}/pg/pg_test.go (83%) rename internal/cache/{ => connector}/pg/prune.sql (100%) rename internal/cache/{ => connector}/pg/set.sql (100%) rename internal/cache/{ => connector}/pg/truncate.sql (100%) create mode 100644 internal/cache/connector/redis/_remove.lua create mode 100644 internal/cache/connector/redis/_select.lua create mode 100644 internal/cache/connector/redis/_util.lua create mode 100644 internal/cache/connector/redis/connector.go create mode 100644 internal/cache/connector/redis/get.lua create mode 100644 internal/cache/connector/redis/invalidate.lua create mode 100644 internal/cache/connector/redis/redis.go create mode 100644 internal/cache/connector/redis/redis_test.go create mode 100644 internal/cache/connector/redis/set.lua create mode 100644 internal/cache/connector_enumer.go create mode 100644 internal/cache/purpose_enumer.go diff --git a/.github/workflows/core-integration-test.yml b/.github/workflows/core-integration-test.yml index 2673d4addf..cc9d898f5c 100644 --- a/.github/workflows/core-integration-test.yml +++ b/.github/workflows/core-integration-test.yml @@ -36,6 +36,10 @@ jobs: --health-timeout 5s --health-retries 5 --health-start-period 10s + cache: + image: redis:latest + ports: + - 6379:6379 steps: - uses: actions/checkout@v4 diff --git a/Makefile b/Makefile index 6a41683390..27e76c0614 100644 --- a/Makefile +++ b/Makefile @@ -113,7 +113,7 @@ core_unit_test: .PHONY: core_integration_db_up core_integration_db_up: - docker compose -f internal/integration/config/docker-compose.yaml up --pull always --wait $${INTEGRATION_DB_FLAVOR} + docker compose -f internal/integration/config/docker-compose.yaml up --pull always --wait $${INTEGRATION_DB_FLAVOR} cache .PHONY: core_integration_db_down core_integration_db_down: diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index f691fd2af2..8015dd8dad 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -185,34 +185,136 @@ Database: # Caches are EXPERIMENTAL. The following config may have breaking changes in the future. # If no config is provided, caching is disabled by default. -# Caches: +Caches: # Connectors are reused by caches. -# Connectors: + Connectors: # Memory connector works with local server memory. # It is the simplest (and probably fastest) cache implementation. # Unsuitable for deployments with multiple containers, # as each container's cache may hold a different state of the same object. -# Memory: -# Enabled: true + Memory: + Enabled: false # AutoPrune removes invalidated or expired object from the cache. -# AutoPrune: -# Interval: 15m -# TimeOut: 30s + AutoPrune: + Interval: 1m + TimeOut: 5s + Postgres: + Enabled: false + AutoPrune: + Interval: 15m + TimeOut: 30s + Redis: + Enabled: false + # The network type, either tcp or unix. + # Default is tcp. + # Network string + # host:port address. + Addr: localhost:6379 + # ClientName will execute the `CLIENT SETNAME ClientName` command for each conn. + ClientName: ZITADEL_cache + # Use the specified Username to authenticate the current connection + # with one of the connections defined in the ACL list when connecting + # to a Redis 6.0 instance, or greater, that is using the Redis ACL system. + Username: zitadel + # Optional password. Must match the password specified in the + # requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower), + # or the User Password when connecting to a Redis 6.0 instance, or greater, + # that is using the Redis ACL system. + Password: "" + # Each ZITADEL cache uses an incremental DB namespace. + # This option offsets the first DB so it doesn't conflict with other databases on the same server. + # Note that ZITADEL uses FLUSHDB command to truncate a cache. + # This can have destructive consequences when overlapping DB namespaces are used. + DBOffset: 10 + # Maximum number of retries before giving up. + # Default is 3 retries; -1 (not 0) disables retries. + MaxRetries: 3 + # Minimum backoff between each retry. + # Default is 8 milliseconds; -1 disables backoff. + MinRetryBackoff: 8ms + # Maximum backoff between each retry. + # Default is 512 milliseconds; -1 disables backoff. + MaxRetryBackoff: 512ms + # Dial timeout for establishing new connections. + # Default is 5 seconds. + DialTimeout: 1s + # Timeout for socket reads. If reached, commands will fail + # with a timeout instead of blocking. Supported values: + # - `0` - default timeout (3 seconds). + # - `-1` - no timeout (block indefinitely). + # - `-2` - disables SetReadDeadline calls completely. + ReadTimeout: 100ms + # Timeout for socket writes. If reached, commands will fail + # with a timeout instead of blocking. Supported values: + # - `0` - default timeout (3 seconds). + # - `-1` - no timeout (block indefinitely). + # - `-2` - disables SetWriteDeadline calls completely. + WriteTimeout: 100ms + # Type of connection pool. + # true for FIFO pool, false for LIFO pool. + # Note that FIFO has slightly higher overhead compared to LIFO, + # but it helps closing idle connections faster reducing the pool size. + PoolFIFO: false + # Base number of socket connections. + # Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS. + # If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize, + # you can limit it through MaxActiveConns + PoolSize: 20 + # Amount of time client waits for connection if all connections + # are busy before returning an error. + # Default is ReadTimeout + 1 second. + PoolTimeout: 100ms + # Minimum number of idle connections which is useful when establishing + # new connection is slow. + # Default is 0. the idle connections are not closed by default. + MinIdleConns: 5 + # Maximum number of idle connections. + # Default is 0. the idle connections are not closed by default. + MaxIdleConns: 10 + # Maximum number of connections allocated by the pool at a given time. + # When zero, there is no limit on the number of connections in the pool. + MaxActiveConns: 40 + # ConnMaxIdleTime is the maximum amount of time a connection may be idle. + # Should be less than server's timeout. + # Expired connections may be closed lazily before reuse. + # If d <= 0, connections are not closed due to a connection's idle time. + # Default is 30 minutes. -1 disables idle timeout check. + ConnMaxIdleTime: 30m + # ConnMaxLifetime is the maximum amount of time a connection may be reused. + # Expired connections may be closed lazily before reuse. + # If <= 0, connections are not closed due to a connection's age. + # Default is to not close idle connections. + ConnMaxLifetime: -1 + # Enable TLS server authentication using the default system bundle. + EnableTLS: false + # Disable set-lib on connect. Default is false. + DisableIndentity: false + # Add suffix to client name. Default is empty. + IdentitySuffix: "" # Instance caches auth middleware instances, gettable by domain or ID. -# Instance: + Instance: # Connector must be enabled above. # When connector is empty, this cache will be disabled. -# Connector: "memory" -# MaxAge: 1h -# LastUsage: 10m -# - # Log enables cache-specific logging. Default to error log to stdout when omitted. -# Log: -# Level: debug -# AddSource: true -# Formatter: -# Format: text + Connector: "" + MaxAge: 1h + LastUsage: 10m + # Log enables cache-specific logging. Default to error log to stderr when omitted. + Log: + Level: error + AddSource: true + Formatter: + Format: text + # Milestones caches instance milestone state, gettable by instance ID + Milestones: + Connector: "" + MaxAge: 1h + LastUsage: 10m + Log: + Level: error + AddSource: true + Formatter: + Format: text Machine: # Cloud-hosted VMs need to specify their metadata endpoint so that the machine can be uniquely identified. diff --git a/cmd/mirror/projections.go b/cmd/mirror/projections.go index 9b7ec02cb8..cffc4921ca 100644 --- a/cmd/mirror/projections.go +++ b/cmd/mirror/projections.go @@ -25,7 +25,7 @@ import ( auth_view "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view" "github.com/zitadel/zitadel/internal/authz" authz_es "github.com/zitadel/zitadel/internal/authz/repository/eventsourcing/eventstore" - "github.com/zitadel/zitadel/internal/cache" + "github.com/zitadel/zitadel/internal/cache/connector" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/config/systemdefaults" crypto_db "github.com/zitadel/zitadel/internal/crypto/database" @@ -72,7 +72,7 @@ type ProjectionsConfig struct { EncryptionKeys *encryption.EncryptionKeyConfig SystemAPIUsers map[string]*internal_authz.SystemAPIUser Eventstore *eventstore.Config - Caches *cache.CachesConfig + Caches *connector.CachesConfig Admin admin_es.Config Auth auth_es.Config @@ -128,13 +128,16 @@ func projections( sessionTokenVerifier := internal_authz.SessionTokenVerifier(keys.OIDC) + cacheConnectors, err := connector.StartConnectors(config.Caches, client) + logging.OnError(err).Fatal("unable to start caches") + queries, err := query.StartQueries( ctx, es, esV4.Querier, client, client, - config.Caches, + cacheConnectors, config.Projections, config.SystemDefaults, keys.IDPConfig, @@ -161,9 +164,9 @@ func projections( DisplayName: config.WebAuthNName, ExternalSecure: config.ExternalSecure, } - commands, err := command.StartCommands( + commands, err := command.StartCommands(ctx, es, - config.Caches, + cacheConnectors, config.SystemDefaults, config.InternalAuthZ.RolePermissionMappings, staticStorage, diff --git a/cmd/setup/03.go b/cmd/setup/03.go index 4311418388..4d4231ea9c 100644 --- a/cmd/setup/03.go +++ b/cmd/setup/03.go @@ -9,6 +9,7 @@ import ( "golang.org/x/text/language" "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/cache/connector" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/crypto" @@ -64,8 +65,9 @@ func (mig *FirstInstance) Execute(ctx context.Context, _ eventstore.Event) error return err } - cmd, err := command.StartCommands(mig.es, - nil, + cmd, err := command.StartCommands(ctx, + mig.es, + connector.Connectors{}, mig.defaults, mig.zitadelRoles, nil, diff --git a/cmd/setup/config.go b/cmd/setup/config.go index 09044456ea..57681c8bc1 100644 --- a/cmd/setup/config.go +++ b/cmd/setup/config.go @@ -15,7 +15,7 @@ import ( internal_authz "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/oidc" "github.com/zitadel/zitadel/internal/api/ui/login" - "github.com/zitadel/zitadel/internal/cache" + "github.com/zitadel/zitadel/internal/cache/connector" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/config/hook" "github.com/zitadel/zitadel/internal/config/systemdefaults" @@ -31,7 +31,7 @@ import ( type Config struct { ForMirror bool Database database.Config - Caches *cache.CachesConfig + Caches *connector.CachesConfig SystemDefaults systemdefaults.SystemDefaults InternalAuthZ internal_authz.Config ExternalDomain string diff --git a/cmd/setup/config_change.go b/cmd/setup/config_change.go index 08f0c3c3d6..f38508af2c 100644 --- a/cmd/setup/config_change.go +++ b/cmd/setup/config_change.go @@ -3,6 +3,7 @@ package setup import ( "context" + "github.com/zitadel/zitadel/internal/cache/connector" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/eventstore" @@ -31,9 +32,9 @@ func (mig *externalConfigChange) Check(lastRun map[string]interface{}) bool { } func (mig *externalConfigChange) Execute(ctx context.Context, _ eventstore.Event) error { - cmd, err := command.StartCommands( + cmd, err := command.StartCommands(ctx, mig.es, - nil, + connector.Connectors{}, mig.defaults, nil, nil, diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index 7ffef5e853..e0784654b1 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -22,6 +22,7 @@ import ( auth_view "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view" "github.com/zitadel/zitadel/internal/authz" authz_es "github.com/zitadel/zitadel/internal/authz/repository/eventsourcing/eventstore" + "github.com/zitadel/zitadel/internal/cache/connector" "github.com/zitadel/zitadel/internal/command" cryptoDB "github.com/zitadel/zitadel/internal/crypto/database" "github.com/zitadel/zitadel/internal/database" @@ -346,13 +347,17 @@ func initProjections( } sessionTokenVerifier := internal_authz.SessionTokenVerifier(keys.OIDC) + + cacheConnectors, err := connector.StartConnectors(config.Caches, queryDBClient) + logging.OnError(err).Fatal("unable to start caches") + queries, err := query.StartQueries( ctx, eventstoreClient, eventstoreV4.Querier, queryDBClient, projectionDBClient, - config.Caches, + cacheConnectors, config.Projections, config.SystemDefaults, keys.IDPConfig, @@ -394,9 +399,9 @@ func initProjections( permissionCheck := func(ctx context.Context, permission, orgID, resourceID string) (err error) { return internal_authz.CheckPermission(ctx, authZRepo, config.InternalAuthZ.RolePermissionMappings, permission, orgID, resourceID) } - commands, err := command.StartCommands( + commands, err := command.StartCommands(ctx, eventstoreClient, - config.Caches, + cacheConnectors, config.SystemDefaults, config.InternalAuthZ.RolePermissionMappings, staticStorage, diff --git a/cmd/start/config.go b/cmd/start/config.go index ea432e6296..26c4b84b50 100644 --- a/cmd/start/config.go +++ b/cmd/start/config.go @@ -18,7 +18,7 @@ import ( "github.com/zitadel/zitadel/internal/api/ui/console" "github.com/zitadel/zitadel/internal/api/ui/login" auth_es "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing" - "github.com/zitadel/zitadel/internal/cache" + "github.com/zitadel/zitadel/internal/cache/connector" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/config/hook" "github.com/zitadel/zitadel/internal/config/network" @@ -49,7 +49,7 @@ type Config struct { HTTP1HostHeader string WebAuthNName string Database database.Config - Caches *cache.CachesConfig + Caches *connector.CachesConfig Tracing tracing.Config Metrics metrics.Config Profiler profiler.Config diff --git a/cmd/start/start.go b/cmd/start/start.go index 8de1105307..e816b5bb52 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -69,6 +69,7 @@ import ( "github.com/zitadel/zitadel/internal/authz" authz_repo "github.com/zitadel/zitadel/internal/authz/repository" authz_es "github.com/zitadel/zitadel/internal/authz/repository/eventsourcing/eventstore" + "github.com/zitadel/zitadel/internal/cache/connector" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/crypto" cryptoDB "github.com/zitadel/zitadel/internal/crypto/database" @@ -177,6 +178,10 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server })) sessionTokenVerifier := internal_authz.SessionTokenVerifier(keys.OIDC) + cacheConnectors, err := connector.StartConnectors(config.Caches, queryDBClient) + if err != nil { + return fmt.Errorf("unable to start caches: %w", err) + } queries, err := query.StartQueries( ctx, @@ -184,7 +189,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server eventstoreV4.Querier, queryDBClient, projectionDBClient, - config.Caches, + cacheConnectors, config.Projections, config.SystemDefaults, keys.IDPConfig, @@ -222,9 +227,9 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server DisplayName: config.WebAuthNName, ExternalSecure: config.ExternalSecure, } - commands, err := command.StartCommands( + commands, err := command.StartCommands(ctx, eventstoreClient, - config.Caches, + cacheConnectors, config.SystemDefaults, config.InternalAuthZ.RolePermissionMappings, storage, diff --git a/go.mod b/go.mod index 1e4f67eb7d..cf4e755605 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.24.0 github.com/Masterminds/squirrel v1.5.4 github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b + github.com/alicebob/miniredis/v2 v2.33.0 github.com/benbjohnson/clock v1.3.5 github.com/boombuler/barcode v1.0.2 github.com/brianvoe/gofakeit/v6 v6.28.0 @@ -52,6 +53,7 @@ require ( github.com/pashagolub/pgxmock/v4 v4.3.0 github.com/pquerna/otp v1.4.0 github.com/rakyll/statik v0.1.7 + github.com/redis/go-redis/v9 v9.7.0 github.com/rs/cors v1.11.1 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/sony/sonyflake v1.2.0 @@ -94,8 +96,10 @@ require ( cloud.google.com/go/auth v0.6.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.0 // indirect + github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect github.com/bmatcuk/doublestar/v4 v4.7.1 // indirect github.com/crewjam/httperr v0.2.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -121,6 +125,7 @@ require ( github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect github.com/zenazn/goji v1.0.1 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/time v0.5.0 // indirect diff --git a/go.sum b/go.sum index 8645aa8417..015fea1b80 100644 --- a/go.sum +++ b/go.sum @@ -56,6 +56,10 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.33.0 h1:uvTF0EDeu9RLnUEG27Db5I68ESoIxTiXbNUiji6lZrA= +github.com/alicebob/miniredis/v2 v2.33.0/go.mod h1:MhP4a3EU7aENRi9aO+tHfTBZicLqQevyi/DJpoj6mi0= github.com/amdonov/xmlsig v0.1.0 h1:i0iQ3neKLmUhcfIRgiiR3eRPKgXZj+n5lAfqnfKoeXI= github.com/amdonov/xmlsig v0.1.0/go.mod h1:jTR/jO0E8fSl/cLvMesP+RjxyV4Ux4WL1Ip64ZnQpA0= github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y= @@ -87,6 +91,10 @@ github.com/boombuler/barcode v1.0.2 h1:79yrbttoZrLGkL/oOI8hBrUKucwOL0oOjUgEguGMc github.com/boombuler/barcode v1.0.2/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/brianvoe/gofakeit/v6 v6.28.0 h1:Xib46XXuQfmlLS2EXRuJpqcw8St6qSZz75OUo0tgAW4= github.com/brianvoe/gofakeit/v6 v6.28.0/go.mod h1:Xj58BMSnFqcn/fAQeSK+/PLtC5kSb7FJIq4JyGa8vEs= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= github.com/cenkalti/backoff/v4 v4.1.1/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= @@ -127,6 +135,8 @@ github.com/descope/virtualwebauthn v1.0.2/go.mod h1:iJvinjD1iZYqQ09J5lF0+795OdDb github.com/desertbit/timer v0.0.0-20180107155436-c41aec40b27f h1:U5y3Y5UE0w7amNe7Z5G/twsBW0KEalRQXZzf8ufSh9I= github.com/desertbit/timer v0.0.0-20180107155436-c41aec40b27f/go.mod h1:xH/i4TFMt8koVQZ6WFms69WAsDWr2XsYL3Hkl7jkoLE= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dop251/goja v0.0.0-20240627195025-eb1f15ee67d2 h1:4Ew88p5s9dwIk5/woUyqI9BD89NgZoUNH4/rM/h2UDg= @@ -620,6 +630,8 @@ github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoG github.com/rakyll/statik v0.1.7 h1:OF3QCZUuyPxuGEP7B4ypUa7sB/iHtqOTDYZXGM8KOdQ= github.com/rakyll/statik v0.1.7/go.mod h1:AlZONWzMtEnMs7W4e/1LURLiI49pIMmp6V9Unghqrcc= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E= +github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -719,6 +731,8 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8= github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= github.com/zitadel/logging v0.6.1 h1:Vyzk1rl9Kq9RCevcpX6ujUaTYFX43aa4LkvV1TvUk+Y= diff --git a/internal/cache/cache.go b/internal/cache/cache.go index c6d01b928e..9e92f50988 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -6,8 +6,16 @@ import ( "time" "github.com/zitadel/logging" +) - "github.com/zitadel/zitadel/internal/database/postgres" +// Purpose describes which object types are stored by a cache. +type Purpose int + +//go:generate enumer -type Purpose -transform snake -trimprefix Purpose +const ( + PurposeUnspecified Purpose = iota + PurposeAuthzInstance + PurposeMilestones ) // Cache stores objects with a value of type `V`. @@ -72,18 +80,19 @@ type Entry[I, K comparable] interface { Keys(index I) (key []K) } -type CachesConfig struct { - Connectors struct { - Memory MemoryConnectorConfig - Postgres PostgresConnectorConfig - // Redis redis.Config? - } - Instance *CacheConfig - Milestones *CacheConfig -} +type Connector int -type CacheConfig struct { - Connector string +//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text +const ( + // Empty line comment ensures empty string for unspecified value + ConnectorUnspecified Connector = iota // + ConnectorMemory + ConnectorPostgres + ConnectorRedis +) + +type Config struct { + Connector Connector // Age since an object was added to the cache, // after which the object is considered invalid. @@ -99,14 +108,3 @@ type CacheConfig struct { // By default only errors are logged to stdout. Log *logging.Config } - -type MemoryConnectorConfig struct { - Enabled bool - AutoPrune AutoPruneConfig -} - -type PostgresConnectorConfig struct { - Enabled bool - AutoPrune AutoPruneConfig - Connection postgres.Config -} diff --git a/internal/cache/connector/connector.go b/internal/cache/connector/connector.go new file mode 100644 index 0000000000..0c4fb9ccc6 --- /dev/null +++ b/internal/cache/connector/connector.go @@ -0,0 +1,69 @@ +// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages. +package connector + +import ( + "context" + "fmt" + + "github.com/zitadel/zitadel/internal/cache" + "github.com/zitadel/zitadel/internal/cache/connector/gomap" + "github.com/zitadel/zitadel/internal/cache/connector/noop" + "github.com/zitadel/zitadel/internal/cache/connector/pg" + "github.com/zitadel/zitadel/internal/cache/connector/redis" + "github.com/zitadel/zitadel/internal/database" +) + +type CachesConfig struct { + Connectors struct { + Memory gomap.Config + Postgres pg.Config + Redis redis.Config + } + Instance *cache.Config + Milestones *cache.Config +} + +type Connectors struct { + Config CachesConfig + Memory *gomap.Connector + Postgres *pg.Connector + Redis *redis.Connector +} + +func StartConnectors(conf *CachesConfig, client *database.DB) (Connectors, error) { + if conf == nil { + return Connectors{}, nil + } + return Connectors{ + Config: *conf, + Memory: gomap.NewConnector(conf.Connectors.Memory), + Postgres: pg.NewConnector(conf.Connectors.Postgres, client), + Redis: redis.NewConnector(conf.Connectors.Redis), + }, nil +} + +func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) { + if conf == nil || conf.Connector == cache.ConnectorUnspecified { + return noop.NewCache[I, K, V](), nil + } + if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil { + c := gomap.NewCache[I, K, V](background, indices, *conf) + connectors.Memory.Config.StartAutoPrune(background, c, purpose) + return c, nil + } + if conf.Connector == cache.ConnectorPostgres && connectors.Postgres != nil { + c, err := pg.NewCache[I, K, V](background, purpose, *conf, indices, connectors.Postgres) + if err != nil { + return nil, fmt.Errorf("start cache: %w", err) + } + connectors.Postgres.Config.AutoPrune.StartAutoPrune(background, c, purpose) + return c, nil + } + if conf.Connector == cache.ConnectorRedis && connectors.Redis != nil { + db := connectors.Redis.Config.DBOffset + int(purpose) + c := redis.NewCache[I, K, V](*conf, connectors.Redis, db, indices) + return c, nil + } + + return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector) +} diff --git a/internal/cache/connector/gomap/connector.go b/internal/cache/connector/gomap/connector.go new file mode 100644 index 0000000000..7ed09c7a72 --- /dev/null +++ b/internal/cache/connector/gomap/connector.go @@ -0,0 +1,23 @@ +package gomap + +import ( + "github.com/zitadel/zitadel/internal/cache" +) + +type Config struct { + Enabled bool + AutoPrune cache.AutoPruneConfig +} + +type Connector struct { + Config cache.AutoPruneConfig +} + +func NewConnector(config Config) *Connector { + if !config.Enabled { + return nil + } + return &Connector{ + Config: config.AutoPrune, + } +} diff --git a/internal/cache/gomap/gomap.go b/internal/cache/connector/gomap/gomap.go similarity index 95% rename from internal/cache/gomap/gomap.go rename to internal/cache/connector/gomap/gomap.go index 160fe4e315..dff9f04143 100644 --- a/internal/cache/gomap/gomap.go +++ b/internal/cache/connector/gomap/gomap.go @@ -14,14 +14,14 @@ import ( ) type mapCache[I, K comparable, V cache.Entry[I, K]] struct { - config *cache.CacheConfig + config *cache.Config indexMap map[I]*index[K, V] logger *slog.Logger } // NewCache returns an in-memory Cache implementation based on the builtin go map type. // Object values are stored as-is and there is no encoding or decoding involved. -func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.CacheConfig) cache.PrunerCache[I, K, V] { +func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] { m := &mapCache[I, K, V]{ config: &config, indexMap: make(map[I]*index[K, V], len(indices)), @@ -116,7 +116,7 @@ func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error { type index[K comparable, V any] struct { mutex sync.RWMutex - config *cache.CacheConfig + config *cache.Config entries map[K]*entry[V] } @@ -177,7 +177,7 @@ type entry[V any] struct { lastUse atomic.Int64 // UnixMicro time } -func (e *entry[V]) isValid(c *cache.CacheConfig) bool { +func (e *entry[V]) isValid(c *cache.Config) bool { if e.invalid.Load() { return false } diff --git a/internal/cache/gomap/gomap_test.go b/internal/cache/connector/gomap/gomap_test.go similarity index 94% rename from internal/cache/gomap/gomap_test.go rename to internal/cache/connector/gomap/gomap_test.go index 7f41900833..810788b554 100644 --- a/internal/cache/gomap/gomap_test.go +++ b/internal/cache/connector/gomap/gomap_test.go @@ -41,7 +41,7 @@ func (o *testObject) Keys(index testIndex) []string { } func Test_mapCache_Get(t *testing.T) { - c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.CacheConfig{ + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ MaxAge: time.Second, LastUseAge: time.Second / 4, Log: &logging.Config{ @@ -103,7 +103,7 @@ func Test_mapCache_Get(t *testing.T) { } func Test_mapCache_Invalidate(t *testing.T) { - c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.CacheConfig{ + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ MaxAge: time.Second, LastUseAge: time.Second / 4, Log: &logging.Config{ @@ -124,7 +124,7 @@ func Test_mapCache_Invalidate(t *testing.T) { } func Test_mapCache_Delete(t *testing.T) { - c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.CacheConfig{ + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ MaxAge: time.Second, LastUseAge: time.Second / 4, Log: &logging.Config{ @@ -157,7 +157,7 @@ func Test_mapCache_Delete(t *testing.T) { } func Test_mapCache_Prune(t *testing.T) { - c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.CacheConfig{ + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ MaxAge: time.Second, LastUseAge: time.Second / 4, Log: &logging.Config{ @@ -193,7 +193,7 @@ func Test_mapCache_Prune(t *testing.T) { } func Test_mapCache_Truncate(t *testing.T) { - c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.CacheConfig{ + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ MaxAge: time.Second, LastUseAge: time.Second / 4, Log: &logging.Config{ @@ -235,7 +235,7 @@ func Test_entry_isValid(t *testing.T) { tests := []struct { name string fields fields - config *cache.CacheConfig + config *cache.Config want bool }{ { @@ -245,7 +245,7 @@ func Test_entry_isValid(t *testing.T) { invalid: true, lastUse: time.Now(), }, - config: &cache.CacheConfig{ + config: &cache.Config{ MaxAge: time.Minute, LastUseAge: time.Second, }, @@ -258,7 +258,7 @@ func Test_entry_isValid(t *testing.T) { invalid: false, lastUse: time.Now(), }, - config: &cache.CacheConfig{ + config: &cache.Config{ MaxAge: time.Minute, LastUseAge: time.Second, }, @@ -271,7 +271,7 @@ func Test_entry_isValid(t *testing.T) { invalid: false, lastUse: time.Now(), }, - config: &cache.CacheConfig{ + config: &cache.Config{ LastUseAge: time.Second, }, want: true, @@ -283,7 +283,7 @@ func Test_entry_isValid(t *testing.T) { invalid: false, lastUse: time.Now().Add(-(time.Second * 2)), }, - config: &cache.CacheConfig{ + config: &cache.Config{ MaxAge: time.Minute, LastUseAge: time.Second, }, @@ -296,7 +296,7 @@ func Test_entry_isValid(t *testing.T) { invalid: false, lastUse: time.Now().Add(-(time.Second * 2)), }, - config: &cache.CacheConfig{ + config: &cache.Config{ MaxAge: time.Minute, }, want: true, @@ -308,7 +308,7 @@ func Test_entry_isValid(t *testing.T) { invalid: false, lastUse: time.Now(), }, - config: &cache.CacheConfig{ + config: &cache.Config{ MaxAge: time.Minute, LastUseAge: time.Second, }, diff --git a/internal/cache/noop/noop.go b/internal/cache/connector/noop/noop.go similarity index 100% rename from internal/cache/noop/noop.go rename to internal/cache/connector/noop/noop.go diff --git a/internal/cache/connector/pg/connector.go b/internal/cache/connector/pg/connector.go new file mode 100644 index 0000000000..9a89cf5f6a --- /dev/null +++ b/internal/cache/connector/pg/connector.go @@ -0,0 +1,28 @@ +package pg + +import ( + "github.com/zitadel/zitadel/internal/cache" + "github.com/zitadel/zitadel/internal/database" +) + +type Config struct { + Enabled bool + AutoPrune cache.AutoPruneConfig +} + +type Connector struct { + PGXPool + Dialect string + Config Config +} + +func NewConnector(config Config, client *database.DB) *Connector { + if !config.Enabled { + return nil + } + return &Connector{ + PGXPool: client.Pool, + Dialect: client.Type(), + Config: config, + } +} diff --git a/internal/cache/pg/create_partition.sql.tmpl b/internal/cache/connector/pg/create_partition.sql.tmpl similarity index 100% rename from internal/cache/pg/create_partition.sql.tmpl rename to internal/cache/connector/pg/create_partition.sql.tmpl diff --git a/internal/cache/pg/delete.sql b/internal/cache/connector/pg/delete.sql similarity index 100% rename from internal/cache/pg/delete.sql rename to internal/cache/connector/pg/delete.sql diff --git a/internal/cache/pg/get.sql b/internal/cache/connector/pg/get.sql similarity index 100% rename from internal/cache/pg/get.sql rename to internal/cache/connector/pg/get.sql diff --git a/internal/cache/pg/invalidate.sql b/internal/cache/connector/pg/invalidate.sql similarity index 100% rename from internal/cache/pg/invalidate.sql rename to internal/cache/connector/pg/invalidate.sql diff --git a/internal/cache/pg/pg.go b/internal/cache/connector/pg/pg.go similarity index 78% rename from internal/cache/pg/pg.go rename to internal/cache/connector/pg/pg.go index aee0315327..18215b68ed 100644 --- a/internal/cache/pg/pg.go +++ b/internal/cache/connector/pg/pg.go @@ -40,25 +40,25 @@ type PGXPool interface { } type pgCache[I ~int, K ~string, V cache.Entry[I, K]] struct { - name string - config *cache.CacheConfig - indices []I - pool PGXPool - logger *slog.Logger + purpose cache.Purpose + config *cache.Config + indices []I + connector *Connector + logger *slog.Logger } // NewCache returns a cache that stores and retrieves objects using PostgreSQL unlogged tables. -func NewCache[I ~int, K ~string, V cache.Entry[I, K]](ctx context.Context, name string, config cache.CacheConfig, indices []I, pool PGXPool, dialect string) (cache.PrunerCache[I, K, V], error) { +func NewCache[I ~int, K ~string, V cache.Entry[I, K]](ctx context.Context, purpose cache.Purpose, config cache.Config, indices []I, connector *Connector) (cache.PrunerCache[I, K, V], error) { c := &pgCache[I, K, V]{ - name: name, - config: &config, - indices: indices, - pool: pool, - logger: config.Log.Slog().With("cache_name", name), + purpose: purpose, + config: &config, + indices: indices, + connector: connector, + logger: config.Log.Slog().With("cache_purpose", purpose), } c.logger.InfoContext(ctx, "pg cache logging enabled") - if dialect == "postgres" { + if connector.Dialect == "postgres" { if err := c.createPartition(ctx); err != nil { return nil, err } @@ -68,10 +68,10 @@ func NewCache[I ~int, K ~string, V cache.Entry[I, K]](ctx context.Context, name func (c *pgCache[I, K, V]) createPartition(ctx context.Context) error { var query strings.Builder - if err := createPartitionTmpl.Execute(&query, c.name); err != nil { + if err := createPartitionTmpl.Execute(&query, c.purpose.String()); err != nil { return err } - _, err := c.pool.Exec(ctx, query.String()) + _, err := c.connector.Exec(ctx, query.String()) return err } @@ -87,7 +87,7 @@ func (c *pgCache[I, K, V]) set(ctx context.Context, entry V) (err error) { keys := c.indexKeysFromEntry(entry) c.logger.DebugContext(ctx, "pg cache set", "index_key", keys) - _, err = c.pool.Exec(ctx, setQuery, c.name, keys, entry) + _, err = c.connector.Exec(ctx, setQuery, c.purpose.String(), keys, entry) if err != nil { c.logger.ErrorContext(ctx, "pg cache set", "err", err) return err @@ -117,7 +117,7 @@ func (c *pgCache[I, K, V]) get(ctx context.Context, index I, key K) (value V, er if !slices.Contains(c.indices, index) { return value, cache.NewIndexUnknownErr(index) } - err = c.pool.QueryRow(ctx, getQuery, c.name, index, key, c.config.MaxAge, c.config.LastUseAge).Scan(&value) + err = c.connector.QueryRow(ctx, getQuery, c.purpose.String(), index, key, c.config.MaxAge, c.config.LastUseAge).Scan(&value) return value, err } @@ -125,7 +125,7 @@ func (c *pgCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) ( ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - _, err = c.pool.Exec(ctx, invalidateQuery, c.name, index, keys) + _, err = c.connector.Exec(ctx, invalidateQuery, c.purpose.String(), index, keys) c.logger.DebugContext(ctx, "pg cache invalidate", "index", index, "keys", keys) return err } @@ -134,7 +134,7 @@ func (c *pgCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) (err ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - _, err = c.pool.Exec(ctx, deleteQuery, c.name, index, keys) + _, err = c.connector.Exec(ctx, deleteQuery, c.purpose.String(), index, keys) c.logger.DebugContext(ctx, "pg cache delete", "index", index, "keys", keys) return err } @@ -143,7 +143,7 @@ func (c *pgCache[I, K, V]) Prune(ctx context.Context) (err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - _, err = c.pool.Exec(ctx, pruneQuery, c.name, c.config.MaxAge, c.config.LastUseAge) + _, err = c.connector.Exec(ctx, pruneQuery, c.purpose.String(), c.config.MaxAge, c.config.LastUseAge) c.logger.DebugContext(ctx, "pg cache prune") return err } @@ -152,7 +152,7 @@ func (c *pgCache[I, K, V]) Truncate(ctx context.Context) (err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - _, err = c.pool.Exec(ctx, truncateQuery, c.name) + _, err = c.connector.Exec(ctx, truncateQuery, c.purpose.String()) c.logger.DebugContext(ctx, "pg cache truncate") return err } diff --git a/internal/cache/pg/pg_test.go b/internal/cache/connector/pg/pg_test.go similarity index 83% rename from internal/cache/pg/pg_test.go rename to internal/cache/connector/pg/pg_test.go index 9206a220f2..f5980ad845 100644 --- a/internal/cache/pg/pg_test.go +++ b/internal/cache/connector/pg/pg_test.go @@ -67,7 +67,7 @@ func TestNewCache(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - conf := cache.CacheConfig{ + conf := cache.Config{ Log: &logging.Config{ Level: "debug", AddSource: true, @@ -76,8 +76,12 @@ func TestNewCache(t *testing.T) { pool, err := pgxmock.NewPool() require.NoError(t, err) tt.expect(pool) + connector := &Connector{ + PGXPool: pool, + Dialect: "postgres", + } - c, err := NewCache[testIndex, string, *testObject](context.Background(), cacheName, conf, testIndices, pool, "postgres") + c, err := NewCache[testIndex, string, *testObject](context.Background(), cachePurpose, conf, testIndices, connector) require.ErrorIs(t, err, tt.wantErr) if tt.wantErr == nil { assert.NotNil(t, c) @@ -111,7 +115,7 @@ func Test_pgCache_Set(t *testing.T) { }, expect: func(ppi pgxmock.PgxCommonIface) { ppi.ExpectExec(queryExpect). - WithArgs("test", + WithArgs(cachePurpose.String(), []indexKey[testIndex, string]{ {IndexID: testIndexID, IndexKey: "id1"}, {IndexID: testIndexName, IndexKey: "foo"}, @@ -135,7 +139,7 @@ func Test_pgCache_Set(t *testing.T) { }, expect: func(ppi pgxmock.PgxCommonIface) { ppi.ExpectExec(queryExpect). - WithArgs("test", + WithArgs(cachePurpose.String(), []indexKey[testIndex, string]{ {IndexID: testIndexID, IndexKey: "id1"}, {IndexID: testIndexName, IndexKey: "foo"}, @@ -151,7 +155,7 @@ func Test_pgCache_Set(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c, pool := prepareCache(t, cache.CacheConfig{}) + c, pool := prepareCache(t, cache.Config{}) defer pool.Close() tt.expect(pool) @@ -173,7 +177,7 @@ func Test_pgCache_Get(t *testing.T) { } tests := []struct { name string - config cache.CacheConfig + config cache.Config args args expect func(pgxmock.PgxCommonIface) want *testObject @@ -181,7 +185,7 @@ func Test_pgCache_Get(t *testing.T) { }{ { name: "invalid index", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: time.Minute, LastUseAge: time.Second, }, @@ -194,7 +198,7 @@ func Test_pgCache_Get(t *testing.T) { }, { name: "no rows", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: 0, LastUseAge: 0, }, @@ -204,14 +208,14 @@ func Test_pgCache_Get(t *testing.T) { }, expect: func(pci pgxmock.PgxCommonIface) { pci.ExpectQuery(queryExpect). - WithArgs("test", testIndexID, "id1", time.Duration(0), time.Duration(0)). + WithArgs(cachePurpose.String(), testIndexID, "id1", time.Duration(0), time.Duration(0)). WillReturnRows(pgxmock.NewRows([]string{"payload"})) }, wantOk: false, }, { name: "error", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: 0, LastUseAge: 0, }, @@ -221,14 +225,14 @@ func Test_pgCache_Get(t *testing.T) { }, expect: func(pci pgxmock.PgxCommonIface) { pci.ExpectQuery(queryExpect). - WithArgs("test", testIndexID, "id1", time.Duration(0), time.Duration(0)). + WithArgs(cachePurpose.String(), testIndexID, "id1", time.Duration(0), time.Duration(0)). WillReturnError(pgx.ErrTxClosed) }, wantOk: false, }, { name: "ok", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: time.Minute, LastUseAge: time.Second, }, @@ -238,7 +242,7 @@ func Test_pgCache_Get(t *testing.T) { }, expect: func(pci pgxmock.PgxCommonIface) { pci.ExpectQuery(queryExpect). - WithArgs("test", testIndexID, "id1", time.Minute, time.Second). + WithArgs(cachePurpose.String(), testIndexID, "id1", time.Minute, time.Second). WillReturnRows( pgxmock.NewRows([]string{"payload"}).AddRow(&testObject{ ID: "id1", @@ -276,14 +280,14 @@ func Test_pgCache_Invalidate(t *testing.T) { } tests := []struct { name string - config cache.CacheConfig + config cache.Config args args expect func(pgxmock.PgxCommonIface) wantErr error }{ { name: "error", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: 0, LastUseAge: 0, }, @@ -293,14 +297,14 @@ func Test_pgCache_Invalidate(t *testing.T) { }, expect: func(pci pgxmock.PgxCommonIface) { pci.ExpectExec(queryExpect). - WithArgs("test", testIndexID, []string{"id1", "id2"}). + WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}). WillReturnError(pgx.ErrTxClosed) }, wantErr: pgx.ErrTxClosed, }, { name: "ok", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: time.Minute, LastUseAge: time.Second, }, @@ -310,7 +314,7 @@ func Test_pgCache_Invalidate(t *testing.T) { }, expect: func(pci pgxmock.PgxCommonIface) { pci.ExpectExec(queryExpect). - WithArgs("test", testIndexID, []string{"id1", "id2"}). + WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}). WillReturnResult(pgxmock.NewResult("DELETE", 1)) }, }, @@ -338,14 +342,14 @@ func Test_pgCache_Delete(t *testing.T) { } tests := []struct { name string - config cache.CacheConfig + config cache.Config args args expect func(pgxmock.PgxCommonIface) wantErr error }{ { name: "error", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: 0, LastUseAge: 0, }, @@ -355,14 +359,14 @@ func Test_pgCache_Delete(t *testing.T) { }, expect: func(pci pgxmock.PgxCommonIface) { pci.ExpectExec(queryExpect). - WithArgs("test", testIndexID, []string{"id1", "id2"}). + WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}). WillReturnError(pgx.ErrTxClosed) }, wantErr: pgx.ErrTxClosed, }, { name: "ok", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: time.Minute, LastUseAge: time.Second, }, @@ -372,7 +376,7 @@ func Test_pgCache_Delete(t *testing.T) { }, expect: func(pci pgxmock.PgxCommonIface) { pci.ExpectExec(queryExpect). - WithArgs("test", testIndexID, []string{"id1", "id2"}). + WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}). WillReturnResult(pgxmock.NewResult("DELETE", 1)) }, }, @@ -396,32 +400,32 @@ func Test_pgCache_Prune(t *testing.T) { queryExpect := regexp.QuoteMeta(pruneQuery) tests := []struct { name string - config cache.CacheConfig + config cache.Config expect func(pgxmock.PgxCommonIface) wantErr error }{ { name: "error", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: 0, LastUseAge: 0, }, expect: func(pci pgxmock.PgxCommonIface) { pci.ExpectExec(queryExpect). - WithArgs("test", time.Duration(0), time.Duration(0)). + WithArgs(cachePurpose.String(), time.Duration(0), time.Duration(0)). WillReturnError(pgx.ErrTxClosed) }, wantErr: pgx.ErrTxClosed, }, { name: "ok", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: time.Minute, LastUseAge: time.Second, }, expect: func(pci pgxmock.PgxCommonIface) { pci.ExpectExec(queryExpect). - WithArgs("test", time.Minute, time.Second). + WithArgs(cachePurpose.String(), time.Minute, time.Second). WillReturnResult(pgxmock.NewResult("DELETE", 1)) }, }, @@ -445,32 +449,32 @@ func Test_pgCache_Truncate(t *testing.T) { queryExpect := regexp.QuoteMeta(truncateQuery) tests := []struct { name string - config cache.CacheConfig + config cache.Config expect func(pgxmock.PgxCommonIface) wantErr error }{ { name: "error", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: 0, LastUseAge: 0, }, expect: func(pci pgxmock.PgxCommonIface) { pci.ExpectExec(queryExpect). - WithArgs("test"). + WithArgs(cachePurpose.String()). WillReturnError(pgx.ErrTxClosed) }, wantErr: pgx.ErrTxClosed, }, { name: "ok", - config: cache.CacheConfig{ + config: cache.Config{ MaxAge: time.Minute, LastUseAge: time.Second, }, expect: func(pci pgxmock.PgxCommonIface) { pci.ExpectExec(queryExpect). - WithArgs("test"). + WithArgs(cachePurpose.String()). WillReturnResult(pgxmock.NewResult("DELETE", 1)) }, }, @@ -491,18 +495,18 @@ func Test_pgCache_Truncate(t *testing.T) { } const ( - cacheName = "test" - expectedCreatePartitionQuery = `create unlogged table if not exists cache.objects_test + cachePurpose = cache.PurposeAuthzInstance + expectedCreatePartitionQuery = `create unlogged table if not exists cache.objects_authz_instance partition of cache.objects -for values in ('test'); +for values in ('authz_instance'); -create unlogged table if not exists cache.string_keys_test +create unlogged table if not exists cache.string_keys_authz_instance partition of cache.string_keys -for values in ('test'); +for values in ('authz_instance'); ` ) -func prepareCache(t *testing.T, conf cache.CacheConfig) (cache.PrunerCache[testIndex, string, *testObject], pgxmock.PgxPoolIface) { +func prepareCache(t *testing.T, conf cache.Config) (cache.PrunerCache[testIndex, string, *testObject], pgxmock.PgxPoolIface) { conf.Log = &logging.Config{ Level: "debug", AddSource: true, @@ -512,8 +516,11 @@ func prepareCache(t *testing.T, conf cache.CacheConfig) (cache.PrunerCache[testI pool.ExpectExec(regexp.QuoteMeta(expectedCreatePartitionQuery)). WillReturnResult(pgxmock.NewResult("CREATE TABLE", 0)) - - c, err := NewCache[testIndex, string, *testObject](context.Background(), cacheName, conf, testIndices, pool, "postgres") + connector := &Connector{ + PGXPool: pool, + Dialect: "postgres", + } + c, err := NewCache[testIndex, string, *testObject](context.Background(), cachePurpose, conf, testIndices, connector) require.NoError(t, err) return c, pool } diff --git a/internal/cache/pg/prune.sql b/internal/cache/connector/pg/prune.sql similarity index 100% rename from internal/cache/pg/prune.sql rename to internal/cache/connector/pg/prune.sql diff --git a/internal/cache/pg/set.sql b/internal/cache/connector/pg/set.sql similarity index 100% rename from internal/cache/pg/set.sql rename to internal/cache/connector/pg/set.sql diff --git a/internal/cache/pg/truncate.sql b/internal/cache/connector/pg/truncate.sql similarity index 100% rename from internal/cache/pg/truncate.sql rename to internal/cache/connector/pg/truncate.sql diff --git a/internal/cache/connector/redis/_remove.lua b/internal/cache/connector/redis/_remove.lua new file mode 100644 index 0000000000..cbd7f5a797 --- /dev/null +++ b/internal/cache/connector/redis/_remove.lua @@ -0,0 +1,10 @@ +local function remove(object_id) + local setKey = keySetKey(object_id) + local keys = redis.call("SMEMBERS", setKey) + local n = #keys + for i = 1, n do + redis.call("DEL", keys[i]) + end + redis.call("DEL", setKey) + redis.call("DEL", object_id) +end diff --git a/internal/cache/connector/redis/_select.lua b/internal/cache/connector/redis/_select.lua new file mode 100644 index 0000000000..d87bb3f5c0 --- /dev/null +++ b/internal/cache/connector/redis/_select.lua @@ -0,0 +1,3 @@ +-- SELECT ensures the DB namespace for each script. +-- When used, it consumes the first ARGV entry. +redis.call("SELECT", ARGV[1]) diff --git a/internal/cache/connector/redis/_util.lua b/internal/cache/connector/redis/_util.lua new file mode 100644 index 0000000000..4563c3df6e --- /dev/null +++ b/internal/cache/connector/redis/_util.lua @@ -0,0 +1,17 @@ +-- keySetKey returns the redis key of the set containing all keys to the object. +local function keySetKey (object_id) + return object_id .. "-keys" +end + +local function getTime() + return tonumber(redis.call('TIME')[1]) +end + +-- getCall wrapts redis.call so a nil is returned instead of false. +local function getCall (...) + local result = redis.call(...) + if result == false then + return nil + end + return result +end diff --git a/internal/cache/connector/redis/connector.go b/internal/cache/connector/redis/connector.go new file mode 100644 index 0000000000..2d0498dfa0 --- /dev/null +++ b/internal/cache/connector/redis/connector.go @@ -0,0 +1,154 @@ +package redis + +import ( + "crypto/tls" + "time" + + "github.com/redis/go-redis/v9" +) + +type Config struct { + Enabled bool + + // The network type, either tcp or unix. + // Default is tcp. + Network string + // host:port address. + Addr string + // ClientName will execute the `CLIENT SETNAME ClientName` command for each conn. + ClientName string + // Use the specified Username to authenticate the current connection + // with one of the connections defined in the ACL list when connecting + // to a Redis 6.0 instance, or greater, that is using the Redis ACL system. + Username string + // Optional password. Must match the password specified in the + // requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower), + // or the User Password when connecting to a Redis 6.0 instance, or greater, + // that is using the Redis ACL system. + Password string + // Each ZITADEL cache uses an incremental DB namespace. + // This option offsets the first DB so it doesn't conflict with other databases on the same server. + // Note that ZITADEL uses FLUSHDB command to truncate a cache. + // This can have destructive consequences when overlapping DB namespaces are used. + DBOffset int + + // Maximum number of retries before giving up. + // Default is 3 retries; -1 (not 0) disables retries. + MaxRetries int + // Minimum backoff between each retry. + // Default is 8 milliseconds; -1 disables backoff. + MinRetryBackoff time.Duration + // Maximum backoff between each retry. + // Default is 512 milliseconds; -1 disables backoff. + MaxRetryBackoff time.Duration + + // Dial timeout for establishing new connections. + // Default is 5 seconds. + DialTimeout time.Duration + // Timeout for socket reads. If reached, commands will fail + // with a timeout instead of blocking. Supported values: + // - `0` - default timeout (3 seconds). + // - `-1` - no timeout (block indefinitely). + // - `-2` - disables SetReadDeadline calls completely. + ReadTimeout time.Duration + // Timeout for socket writes. If reached, commands will fail + // with a timeout instead of blocking. Supported values: + // - `0` - default timeout (3 seconds). + // - `-1` - no timeout (block indefinitely). + // - `-2` - disables SetWriteDeadline calls completely. + WriteTimeout time.Duration + + // Type of connection pool. + // true for FIFO pool, false for LIFO pool. + // Note that FIFO has slightly higher overhead compared to LIFO, + // but it helps closing idle connections faster reducing the pool size. + PoolFIFO bool + // Base number of socket connections. + // Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS. + // If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize, + // you can limit it through MaxActiveConns + PoolSize int + // Amount of time client waits for connection if all connections + // are busy before returning an error. + // Default is ReadTimeout + 1 second. + PoolTimeout time.Duration + // Minimum number of idle connections which is useful when establishing + // new connection is slow. + // Default is 0. the idle connections are not closed by default. + MinIdleConns int + // Maximum number of idle connections. + // Default is 0. the idle connections are not closed by default. + MaxIdleConns int + // Maximum number of connections allocated by the pool at a given time. + // When zero, there is no limit on the number of connections in the pool. + MaxActiveConns int + // ConnMaxIdleTime is the maximum amount of time a connection may be idle. + // Should be less than server's timeout. + // + // Expired connections may be closed lazily before reuse. + // If d <= 0, connections are not closed due to a connection's idle time. + // + // Default is 30 minutes. -1 disables idle timeout check. + ConnMaxIdleTime time.Duration + // ConnMaxLifetime is the maximum amount of time a connection may be reused. + // + // Expired connections may be closed lazily before reuse. + // If <= 0, connections are not closed due to a connection's age. + // + // Default is to not close idle connections. + ConnMaxLifetime time.Duration + + EnableTLS bool + + // Disable set-lib on connect. Default is false. + DisableIndentity bool + + // Add suffix to client name. Default is empty. + IdentitySuffix string +} + +type Connector struct { + *redis.Client + Config Config +} + +func NewConnector(config Config) *Connector { + if !config.Enabled { + return nil + } + return &Connector{ + Client: redis.NewClient(optionsFromConfig(config)), + Config: config, + } +} + +func optionsFromConfig(c Config) *redis.Options { + opts := &redis.Options{ + Network: c.Network, + Addr: c.Addr, + ClientName: c.ClientName, + Protocol: 3, + Username: c.Username, + Password: c.Password, + MaxRetries: c.MaxRetries, + MinRetryBackoff: c.MinRetryBackoff, + MaxRetryBackoff: c.MaxRetryBackoff, + DialTimeout: c.DialTimeout, + ReadTimeout: c.ReadTimeout, + WriteTimeout: c.WriteTimeout, + ContextTimeoutEnabled: true, + PoolFIFO: c.PoolFIFO, + PoolTimeout: c.PoolTimeout, + MinIdleConns: c.MinIdleConns, + MaxIdleConns: c.MaxIdleConns, + MaxActiveConns: c.MaxActiveConns, + ConnMaxIdleTime: c.ConnMaxIdleTime, + ConnMaxLifetime: c.ConnMaxLifetime, + DisableIndentity: c.DisableIndentity, + IdentitySuffix: c.IdentitySuffix, + } + if c.EnableTLS { + opts.TLSConfig = new(tls.Config) + } + return opts +} diff --git a/internal/cache/connector/redis/get.lua b/internal/cache/connector/redis/get.lua new file mode 100644 index 0000000000..cfb3e89d8a --- /dev/null +++ b/internal/cache/connector/redis/get.lua @@ -0,0 +1,29 @@ +local result = redis.call("GET", KEYS[1]) +if result == false then + return nil +end +local object_id = tostring(result) + +local object = getCall("HGET", object_id, "object") +if object == nil then + -- object expired, but there are keys that need to be cleaned up + remove(object_id) + return nil +end + +-- max-age must be checked manually +local expiry = getCall("HGET", object_id, "expiry") +if not (expiry == nil) and expiry > 0 then + if getTime() > expiry then + remove(object_id) + return nil + end +end + +local usage_lifetime = getCall("HGET", object_id, "usage_lifetime") +-- reset usage based TTL +if not (usage_lifetime == nil) and tonumber(usage_lifetime) > 0 then + redis.call('EXPIRE', object_id, usage_lifetime) +end + +return object diff --git a/internal/cache/connector/redis/invalidate.lua b/internal/cache/connector/redis/invalidate.lua new file mode 100644 index 0000000000..e2a766ac72 --- /dev/null +++ b/internal/cache/connector/redis/invalidate.lua @@ -0,0 +1,9 @@ +local n = #KEYS +for i = 1, n do + local result = redis.call("GET", KEYS[i]) + if result == false then + return nil + end + local object_id = tostring(result) + remove(object_id) +end diff --git a/internal/cache/connector/redis/redis.go b/internal/cache/connector/redis/redis.go new file mode 100644 index 0000000000..fef15f6d55 --- /dev/null +++ b/internal/cache/connector/redis/redis.go @@ -0,0 +1,172 @@ +package redis + +import ( + "context" + _ "embed" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/google/uuid" + "github.com/redis/go-redis/v9" + + "github.com/zitadel/zitadel/internal/cache" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +var ( + //go:embed _select.lua + selectComponent string + //go:embed _util.lua + utilComponent string + //go:embed _remove.lua + removeComponent string + //go:embed set.lua + setScript string + //go:embed get.lua + getScript string + //go:embed invalidate.lua + invalidateScript string + + // Don't mind the creative "import" + setParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, setScript}, "\n")) + getParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, removeComponent, getScript}, "\n")) + invalidateParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, removeComponent, invalidateScript}, "\n")) +) + +type redisCache[I, K comparable, V cache.Entry[I, K]] struct { + db int + config *cache.Config + indices []I + connector *Connector + logger *slog.Logger +} + +// NewCache returns a cache that stores and retrieves object using single Redis. +func NewCache[I, K comparable, V cache.Entry[I, K]](config cache.Config, client *Connector, db int, indices []I) cache.Cache[I, K, V] { + return &redisCache[I, K, V]{ + config: &config, + db: db, + indices: indices, + connector: client, + logger: config.Log.Slog(), + } +} + +func (c *redisCache[I, K, V]) Set(ctx context.Context, value V) { + if _, err := c.set(ctx, value); err != nil { + c.logger.ErrorContext(ctx, "redis cache set", "err", err) + } +} + +func (c *redisCache[I, K, V]) set(ctx context.Context, value V) (objectID string, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + // Internal ID used for the object + objectID = uuid.NewString() + keys := []string{objectID} + // flatten the secondary keys + for _, index := range c.indices { + keys = append(keys, c.redisIndexKeys(index, value.Keys(index)...)...) + } + var buf strings.Builder + err = json.NewEncoder(&buf).Encode(value) + if err != nil { + return "", err + } + err = setParsed.Run(ctx, c.connector, keys, + c.db, // DB namespace + buf.String(), // object + int64(c.config.LastUseAge/time.Second), // usage_lifetime + int64(c.config.MaxAge/time.Second), // max_age, + ).Err() + // redis.Nil is always returned because the script doesn't have a return value. + if err != nil && !errors.Is(err, redis.Nil) { + return "", err + } + return objectID, nil +} + +func (c *redisCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) { + var ( + obj any + err error + ) + ctx, span := tracing.NewSpan(ctx) + defer func() { + if errors.Is(err, redis.Nil) { + err = nil + } + span.EndWithError(err) + }() + + logger := c.logger.With("index", index, "key", key) + obj, err = getParsed.Run(ctx, c.connector, c.redisIndexKeys(index, key), c.db).Result() + if err != nil && !errors.Is(err, redis.Nil) { + logger.ErrorContext(ctx, "redis cache get", "err", err) + return value, false + } + data, ok := obj.(string) + if !ok { + logger.With("err", err).InfoContext(ctx, "redis cache miss") + return value, false + } + err = json.NewDecoder(strings.NewReader(data)).Decode(&value) + if err != nil { + logger.ErrorContext(ctx, "redis cache get", "err", fmt.Errorf("decode: %w", err)) + return value, false + } + return value, true +} + +func (c *redisCache[I, K, V]) Invalidate(ctx context.Context, index I, key ...K) (err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + if len(key) == 0 { + return nil + } + err = invalidateParsed.Run(ctx, c.connector, c.redisIndexKeys(index, key...), c.db).Err() + // redis.Nil is always returned because the script doesn't have a return value. + if err != nil && !errors.Is(err, redis.Nil) { + return err + } + return nil +} + +func (c *redisCache[I, K, V]) Delete(ctx context.Context, index I, key ...K) (err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + if len(key) == 0 { + return nil + } + pipe := c.connector.Pipeline() + pipe.Select(ctx, c.db) + pipe.Del(ctx, c.redisIndexKeys(index, key...)...) + _, err = pipe.Exec(ctx) + return err +} + +func (c *redisCache[I, K, V]) Truncate(ctx context.Context) (err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + pipe := c.connector.Pipeline() + pipe.Select(ctx, c.db) + pipe.FlushDB(ctx) + _, err = pipe.Exec(ctx) + return err +} + +func (c *redisCache[I, K, V]) redisIndexKeys(index I, keys ...K) []string { + out := make([]string, len(keys)) + for i, k := range keys { + out[i] = fmt.Sprintf("%v:%v", index, k) + } + return out +} diff --git a/internal/cache/connector/redis/redis_test.go b/internal/cache/connector/redis/redis_test.go new file mode 100644 index 0000000000..3f45be1502 --- /dev/null +++ b/internal/cache/connector/redis/redis_test.go @@ -0,0 +1,714 @@ +package redis + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/cache" +) + +type testIndex int + +const ( + testIndexID testIndex = iota + testIndexName +) + +const ( + testDB = 99 +) + +var testIndices = []testIndex{ + testIndexID, + testIndexName, +} + +type testObject struct { + ID string + Name []string +} + +func (o *testObject) Keys(index testIndex) []string { + switch index { + case testIndexID: + return []string{o.ID} + case testIndexName: + return o.Name + default: + return nil + } +} + +func Test_redisCache_set(t *testing.T) { + type args struct { + ctx context.Context + value *testObject + } + tests := []struct { + name string + config cache.Config + args args + assertions func(t *testing.T, s *miniredis.Miniredis, objectID string) + wantErr error + }{ + { + name: "ok", + config: cache.Config{}, + args: args{ + ctx: context.Background(), + value: &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }, + }, + assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) { + s.CheckGet(t, "0:one", objectID) + s.CheckGet(t, "1:foo", objectID) + s.CheckGet(t, "1:bar", objectID) + assert.Empty(t, s.HGet(objectID, "expiry")) + assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object")) + }, + }, + { + name: "with last use TTL", + config: cache.Config{ + LastUseAge: time.Second, + }, + args: args{ + ctx: context.Background(), + value: &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }, + }, + assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) { + s.CheckGet(t, "0:one", objectID) + s.CheckGet(t, "1:foo", objectID) + s.CheckGet(t, "1:bar", objectID) + assert.Empty(t, s.HGet(objectID, "expiry")) + assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object")) + assert.Positive(t, s.TTL(objectID)) + + s.FastForward(2 * time.Second) + v, err := s.Get(objectID) + require.Error(t, err) + assert.Empty(t, v) + }, + }, + { + name: "with last use TTL and max age", + config: cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + args: args{ + ctx: context.Background(), + value: &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }, + }, + assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) { + s.CheckGet(t, "0:one", objectID) + s.CheckGet(t, "1:foo", objectID) + s.CheckGet(t, "1:bar", objectID) + assert.NotEmpty(t, s.HGet(objectID, "expiry")) + assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object")) + assert.Positive(t, s.TTL(objectID)) + + s.FastForward(2 * time.Second) + v, err := s.Get(objectID) + require.Error(t, err) + assert.Empty(t, v) + }, + }, + { + name: "with max age TTL", + config: cache.Config{ + MaxAge: time.Minute, + }, + args: args{ + ctx: context.Background(), + value: &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }, + }, + assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) { + s.CheckGet(t, "0:one", objectID) + s.CheckGet(t, "1:foo", objectID) + s.CheckGet(t, "1:bar", objectID) + assert.Empty(t, s.HGet(objectID, "expiry")) + assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object")) + assert.Positive(t, s.TTL(objectID)) + + s.FastForward(2 * time.Minute) + v, err := s.Get(objectID) + require.Error(t, err) + assert.Empty(t, v) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, server := prepareCache(t, tt.config) + rc := c.(*redisCache[testIndex, string, *testObject]) + objectID, err := rc.set(tt.args.ctx, tt.args.value) + require.ErrorIs(t, err, tt.wantErr) + t.Log(rc.connector.HGetAll(context.Background(), objectID)) + tt.assertions(t, server, objectID) + }) + } +} + +func Test_redisCache_Get(t *testing.T) { + type args struct { + ctx context.Context + index testIndex + key string + } + tests := []struct { + name string + config cache.Config + preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) + args args + want *testObject + wantOK bool + }{ + { + name: "connection error", + config: cache.Config{}, + preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + s.RequireAuth("foobar") + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: "foo", + }, + wantOK: false, + }, + { + name: "get by ID", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + }, + args: args{ + ctx: context.Background(), + index: testIndexID, + key: "one", + }, + want: &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }, + wantOK: true, + }, + { + name: "get by name", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: "foo", + }, + want: &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }, + wantOK: true, + }, + { + name: "usage timeout", + config: cache.Config{ + LastUseAge: time.Minute, + }, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + _, ok := c.Get(context.Background(), testIndexID, "one") + require.True(t, ok) + s.FastForward(2 * time.Minute) + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: "foo", + }, + want: nil, + wantOK: false, + }, + { + name: "max age timeout", + config: cache.Config{ + MaxAge: time.Minute, + }, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + _, ok := c.Get(context.Background(), testIndexID, "one") + require.True(t, ok) + s.FastForward(2 * time.Minute) + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: "foo", + }, + want: nil, + wantOK: false, + }, + { + name: "not found", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: "spanac", + }, + wantOK: false, + }, + { + name: "json decode error", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + objectID, err := s.Get(c.(*redisCache[testIndex, string, *testObject]).redisIndexKeys(testIndexID, "one")[0]) + require.NoError(t, err) + s.HSet(objectID, "object", "~~~") + }, + args: args{ + ctx: context.Background(), + index: testIndexID, + key: "one", + }, + wantOK: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, server := prepareCache(t, tt.config) + tt.preparation(t, c, server) + t.Log(server.Keys()) + + got, ok := c.Get(tt.args.ctx, tt.args.index, tt.args.key) + require.Equal(t, tt.wantOK, ok) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_redisCache_Invalidate(t *testing.T) { + type args struct { + ctx context.Context + index testIndex + key []string + } + tests := []struct { + name string + config cache.Config + preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) + assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) + args args + wantErr bool + }{ + { + name: "connection error", + config: cache.Config{}, + preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + s.RequireAuth("foobar") + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: []string{"foo"}, + }, + wantErr: true, + }, + { + name: "no keys, noop", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + }, + args: args{ + ctx: context.Background(), + index: testIndexID, + key: []string{}, + }, + wantErr: false, + }, + { + name: "invalidate by ID", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + }, + assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) { + obj, ok := c.Get(context.Background(), testIndexID, "one") + assert.False(t, ok) + assert.Nil(t, obj) + obj, ok = c.Get(context.Background(), testIndexName, "foo") + assert.False(t, ok) + assert.Nil(t, obj) + }, + args: args{ + ctx: context.Background(), + index: testIndexID, + key: []string{"one"}, + }, + wantErr: false, + }, + { + name: "invalidate by name", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + }, + assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) { + obj, ok := c.Get(context.Background(), testIndexID, "one") + assert.False(t, ok) + assert.Nil(t, obj) + obj, ok = c.Get(context.Background(), testIndexName, "foo") + assert.False(t, ok) + assert.Nil(t, obj) + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: []string{"foo"}, + }, + wantErr: false, + }, + { + name: "invalidate after timeout", + config: cache.Config{ + LastUseAge: time.Minute, + }, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + _, ok := c.Get(context.Background(), testIndexID, "one") + require.True(t, ok) + s.FastForward(2 * time.Minute) + }, + assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) { + obj, ok := c.Get(context.Background(), testIndexID, "one") + assert.False(t, ok) + assert.Nil(t, obj) + obj, ok = c.Get(context.Background(), testIndexName, "foo") + assert.False(t, ok) + assert.Nil(t, obj) + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: []string{"foo"}, + }, + wantErr: false, + }, + { + name: "not found", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + }, + assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) { + obj, ok := c.Get(context.Background(), testIndexID, "one") + assert.True(t, ok) + assert.NotNil(t, obj) + obj, ok = c.Get(context.Background(), testIndexName, "foo") + assert.True(t, ok) + assert.NotNil(t, obj) + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: []string{"spanac"}, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, server := prepareCache(t, tt.config) + tt.preparation(t, c, server) + t.Log(server.Keys()) + + err := c.Invalidate(tt.args.ctx, tt.args.index, tt.args.key...) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func Test_redisCache_Delete(t *testing.T) { + type args struct { + ctx context.Context + index testIndex + key []string + } + tests := []struct { + name string + config cache.Config + preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) + assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) + args args + wantErr bool + }{ + { + name: "connection error", + config: cache.Config{}, + preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + s.RequireAuth("foobar") + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: []string{"foo"}, + }, + wantErr: true, + }, + { + name: "no keys, noop", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + }, + args: args{ + ctx: context.Background(), + index: testIndexID, + key: []string{}, + }, + wantErr: false, + }, + { + name: "delete ID", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + }, + assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) { + obj, ok := c.Get(context.Background(), testIndexID, "one") + assert.False(t, ok) + assert.Nil(t, obj) + // Get be name should still work + obj, ok = c.Get(context.Background(), testIndexName, "foo") + assert.True(t, ok) + assert.NotNil(t, obj) + }, + args: args{ + ctx: context.Background(), + index: testIndexID, + key: []string{"one"}, + }, + wantErr: false, + }, + { + name: "delete name", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + }, + assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) { + // get by ID should still work + obj, ok := c.Get(context.Background(), testIndexID, "one") + assert.True(t, ok) + assert.NotNil(t, obj) + obj, ok = c.Get(context.Background(), testIndexName, "foo") + assert.False(t, ok) + assert.Nil(t, obj) + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: []string{"foo"}, + }, + wantErr: false, + }, + { + name: "not found", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + }, + assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) { + obj, ok := c.Get(context.Background(), testIndexID, "one") + assert.True(t, ok) + assert.NotNil(t, obj) + obj, ok = c.Get(context.Background(), testIndexName, "foo") + assert.True(t, ok) + assert.NotNil(t, obj) + }, + args: args{ + ctx: context.Background(), + index: testIndexName, + key: []string{"spanac"}, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, server := prepareCache(t, tt.config) + tt.preparation(t, c, server) + t.Log(server.Keys()) + + err := c.Delete(tt.args.ctx, tt.args.index, tt.args.key...) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func Test_redisCache_Truncate(t *testing.T) { + type args struct { + ctx context.Context + } + tests := []struct { + name string + config cache.Config + preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) + assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) + args args + wantErr bool + }{ + { + name: "connection error", + config: cache.Config{}, + preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + s.RequireAuth("foobar") + }, + args: args{ + ctx: context.Background(), + }, + wantErr: true, + }, + { + name: "ok", + config: cache.Config{}, + preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) { + c.Set(context.Background(), &testObject{ + ID: "one", + Name: []string{"foo", "bar"}, + }) + c.Set(context.Background(), &testObject{ + ID: "two", + Name: []string{"Hello", "World"}, + }) + }, + assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) { + obj, ok := c.Get(context.Background(), testIndexID, "one") + assert.False(t, ok) + assert.Nil(t, obj) + obj, ok = c.Get(context.Background(), testIndexName, "World") + assert.False(t, ok) + assert.Nil(t, obj) + }, + args: args{ + ctx: context.Background(), + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, server := prepareCache(t, tt.config) + tt.preparation(t, c, server) + t.Log(server.Keys()) + + err := c.Truncate(tt.args.ctx) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func prepareCache(t *testing.T, conf cache.Config) (cache.Cache[testIndex, string, *testObject], *miniredis.Miniredis) { + conf.Log = &logging.Config{ + Level: "debug", + AddSource: true, + } + server := miniredis.RunT(t) + server.Select(testDB) + client := redis.NewClient(&redis.Options{ + Network: "tcp", + Addr: server.Addr(), + }) + t.Cleanup(func() { + client.Close() + server.Close() + }) + connector := NewConnector(Config{ + Enabled: true, + Network: "tcp", + Addr: server.Addr(), + }) + c := NewCache[testIndex, string, *testObject](conf, connector, testDB, testIndices) + return c, server +} diff --git a/internal/cache/connector/redis/set.lua b/internal/cache/connector/redis/set.lua new file mode 100644 index 0000000000..8c586bb47b --- /dev/null +++ b/internal/cache/connector/redis/set.lua @@ -0,0 +1,27 @@ +-- KEYS: [1]: object_id; [>1]: index keys. +local object_id = KEYS[1] +local object = ARGV[2] +local usage_lifetime = tonumber(ARGV[3]) -- usage based lifetime in seconds +local max_age = tonumber(ARGV[4]) -- max age liftime in seconds + +redis.call("HSET", object_id,"object", object) +if usage_lifetime > 0 then + redis.call("HSET", object_id, "usage_lifetime", usage_lifetime) + -- enable usage based TTL + redis.call("EXPIRE", object_id, usage_lifetime) + if max_age > 0 then + -- set max_age to hash map for expired remove on Get + local expiry = getTime() + max_age + redis.call("HSET", object_id, "expiry", expiry) + end +elseif max_age > 0 then + -- enable max_age based TTL + redis.call("EXPIRE", object_id, max_age) +end + +local n = #KEYS +local setKey = keySetKey(object_id) +for i = 2, n do -- offset to the second element to skip object_id + redis.call("SADD", setKey, KEYS[i]) -- set of all keys used for housekeeping + redis.call("SET", KEYS[i], object_id) -- key to object_id mapping +end diff --git a/internal/cache/connector_enumer.go b/internal/cache/connector_enumer.go new file mode 100644 index 0000000000..7ea014db16 --- /dev/null +++ b/internal/cache/connector_enumer.go @@ -0,0 +1,98 @@ +// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT. + +package cache + +import ( + "fmt" + "strings" +) + +const _ConnectorName = "memorypostgresredis" + +var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19} + +const _ConnectorLowerName = "memorypostgresredis" + +func (i Connector) String() string { + if i < 0 || i >= Connector(len(_ConnectorIndex)-1) { + return fmt.Sprintf("Connector(%d)", i) + } + return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _ConnectorNoOp() { + var x [1]struct{} + _ = x[ConnectorUnspecified-(0)] + _ = x[ConnectorMemory-(1)] + _ = x[ConnectorPostgres-(2)] + _ = x[ConnectorRedis-(3)] +} + +var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis} + +var _ConnectorNameToValueMap = map[string]Connector{ + _ConnectorName[0:0]: ConnectorUnspecified, + _ConnectorLowerName[0:0]: ConnectorUnspecified, + _ConnectorName[0:6]: ConnectorMemory, + _ConnectorLowerName[0:6]: ConnectorMemory, + _ConnectorName[6:14]: ConnectorPostgres, + _ConnectorLowerName[6:14]: ConnectorPostgres, + _ConnectorName[14:19]: ConnectorRedis, + _ConnectorLowerName[14:19]: ConnectorRedis, +} + +var _ConnectorNames = []string{ + _ConnectorName[0:0], + _ConnectorName[0:6], + _ConnectorName[6:14], + _ConnectorName[14:19], +} + +// ConnectorString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func ConnectorString(s string) (Connector, error) { + if val, ok := _ConnectorNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to Connector values", s) +} + +// ConnectorValues returns all values of the enum +func ConnectorValues() []Connector { + return _ConnectorValues +} + +// ConnectorStrings returns a slice of all String values of the enum +func ConnectorStrings() []string { + strs := make([]string, len(_ConnectorNames)) + copy(strs, _ConnectorNames) + return strs +} + +// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise +func (i Connector) IsAConnector() bool { + for _, v := range _ConnectorValues { + if i == v { + return true + } + } + return false +} + +// MarshalText implements the encoding.TextMarshaler interface for Connector +func (i Connector) MarshalText() ([]byte, error) { + return []byte(i.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector +func (i *Connector) UnmarshalText(text []byte) error { + var err error + *i, err = ConnectorString(string(text)) + return err +} diff --git a/internal/cache/pruner.go b/internal/cache/pruner.go index d4b0b41266..959762d410 100644 --- a/internal/cache/pruner.go +++ b/internal/cache/pruner.go @@ -31,22 +31,22 @@ type AutoPruneConfig struct { Timeout time.Duration } -func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, name string) (close func()) { - return c.startAutoPrune(background, pruner, name, clockwork.NewRealClock()) +func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) { + return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock()) } -func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, name string, clock clockwork.Clock) (close func()) { +func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) { if c.Interval <= 0 { return func() {} } background, cancel := context.WithCancel(background) // randomize the first interval timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval)))) - go c.pruneTimer(background, pruner, name, timer) + go c.pruneTimer(background, pruner, purpose, timer) return cancel } -func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, name string, timer clockwork.Timer) { +func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) { defer func() { if !timer.Stop() { <-timer.Chan() @@ -58,9 +58,9 @@ func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, case <-background.Done(): return case <-timer.Chan(): - timer.Reset(c.Interval) err := c.doPrune(background, pruner) - logging.OnError(err).WithField("name", name).Error("cache auto prune") + logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune") + timer.Reset(c.Interval) } } } diff --git a/internal/cache/pruner_test.go b/internal/cache/pruner_test.go index ababe81e59..faaedeb88c 100644 --- a/internal/cache/pruner_test.go +++ b/internal/cache/pruner_test.go @@ -30,7 +30,7 @@ func TestAutoPruneConfig_startAutoPrune(t *testing.T) { called: make(chan struct{}), } clock := clockwork.NewFakeClock() - close := c.startAutoPrune(ctx, &pruner, "foo", clock) + close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock) defer close() clock.Advance(time.Second) diff --git a/internal/cache/purpose_enumer.go b/internal/cache/purpose_enumer.go new file mode 100644 index 0000000000..bae47476ff --- /dev/null +++ b/internal/cache/purpose_enumer.go @@ -0,0 +1,82 @@ +// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT. + +package cache + +import ( + "fmt" + "strings" +) + +const _PurposeName = "unspecifiedauthz_instancemilestones" + +var _PurposeIndex = [...]uint8{0, 11, 25, 35} + +const _PurposeLowerName = "unspecifiedauthz_instancemilestones" + +func (i Purpose) String() string { + if i < 0 || i >= Purpose(len(_PurposeIndex)-1) { + return fmt.Sprintf("Purpose(%d)", i) + } + return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _PurposeNoOp() { + var x [1]struct{} + _ = x[PurposeUnspecified-(0)] + _ = x[PurposeAuthzInstance-(1)] + _ = x[PurposeMilestones-(2)] +} + +var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones} + +var _PurposeNameToValueMap = map[string]Purpose{ + _PurposeName[0:11]: PurposeUnspecified, + _PurposeLowerName[0:11]: PurposeUnspecified, + _PurposeName[11:25]: PurposeAuthzInstance, + _PurposeLowerName[11:25]: PurposeAuthzInstance, + _PurposeName[25:35]: PurposeMilestones, + _PurposeLowerName[25:35]: PurposeMilestones, +} + +var _PurposeNames = []string{ + _PurposeName[0:11], + _PurposeName[11:25], + _PurposeName[25:35], +} + +// PurposeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func PurposeString(s string) (Purpose, error) { + if val, ok := _PurposeNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to Purpose values", s) +} + +// PurposeValues returns all values of the enum +func PurposeValues() []Purpose { + return _PurposeValues +} + +// PurposeStrings returns a slice of all String values of the enum +func PurposeStrings() []string { + strs := make([]string, len(_PurposeNames)) + copy(strs, _PurposeNames) + return strs +} + +// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise +func (i Purpose) IsAPurpose() bool { + for _, v := range _PurposeValues { + if i == v { + return true + } + } + return false +} diff --git a/internal/command/cache.go b/internal/command/cache.go index bf976bd2d7..384577738e 100644 --- a/internal/command/cache.go +++ b/internal/command/cache.go @@ -2,81 +2,20 @@ package command import ( "context" - "fmt" - "strings" "github.com/zitadel/zitadel/internal/cache" - "github.com/zitadel/zitadel/internal/cache/gomap" - "github.com/zitadel/zitadel/internal/cache/noop" - "github.com/zitadel/zitadel/internal/cache/pg" - "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/cache/connector" ) type Caches struct { - connectors *cacheConnectors milestones cache.Cache[milestoneIndex, string, *MilestonesReached] } -func startCaches(background context.Context, conf *cache.CachesConfig, client *database.DB) (_ *Caches, err error) { - caches := &Caches{ - milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](), - } - if conf == nil { - return caches, nil - } - caches.connectors, err = startCacheConnectors(background, conf, client) - if err != nil { - return nil, err - } - caches.milestones, err = startCache[milestoneIndex, string, *MilestonesReached](background, []milestoneIndex{milestoneIndexInstanceID}, "milestones", conf.Instance, caches.connectors) +func startCaches(background context.Context, connectors connector.Connectors) (_ *Caches, err error) { + caches := new(Caches) + caches.milestones, err = connector.StartCache[milestoneIndex, string, *MilestonesReached](background, []milestoneIndex{milestoneIndexInstanceID}, cache.PurposeMilestones, connectors.Config.Milestones, connectors) if err != nil { return nil, err } return caches, nil } - -type cacheConnectors struct { - memory *cache.AutoPruneConfig - postgres *pgxPoolCacheConnector -} - -type pgxPoolCacheConnector struct { - *cache.AutoPruneConfig - client *database.DB -} - -func startCacheConnectors(_ context.Context, conf *cache.CachesConfig, client *database.DB) (_ *cacheConnectors, err error) { - connectors := new(cacheConnectors) - if conf.Connectors.Memory.Enabled { - connectors.memory = &conf.Connectors.Memory.AutoPrune - } - if conf.Connectors.Postgres.Enabled { - connectors.postgres = &pgxPoolCacheConnector{ - AutoPruneConfig: &conf.Connectors.Postgres.AutoPrune, - client: client, - } - } - return connectors, nil -} - -func startCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, name string, conf *cache.CacheConfig, connectors *cacheConnectors) (cache.Cache[I, K, V], error) { - if conf == nil || conf.Connector == "" { - return noop.NewCache[I, K, V](), nil - } - if strings.EqualFold(conf.Connector, "memory") && connectors.memory != nil { - c := gomap.NewCache[I, K, V](background, indices, *conf) - connectors.memory.StartAutoPrune(background, c, name) - return c, nil - } - if strings.EqualFold(conf.Connector, "postgres") && connectors.postgres != nil { - client := connectors.postgres.client - c, err := pg.NewCache[I, K, V](background, name, *conf, indices, client.Pool, client.Type()) - if err != nil { - return nil, fmt.Errorf("query start cache: %w", err) - } - connectors.postgres.StartAutoPrune(background, c, name) - return c, nil - } - - return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector) -} diff --git a/internal/command/command.go b/internal/command/command.go index 7c56f05b86..bc3f189a4a 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -18,7 +18,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" api_http "github.com/zitadel/zitadel/internal/api/http" - "github.com/zitadel/zitadel/internal/cache" + "github.com/zitadel/zitadel/internal/cache/connector" "github.com/zitadel/zitadel/internal/command/preparation" sd "github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/crypto" @@ -98,8 +98,9 @@ type Commands struct { } func StartCommands( + ctx context.Context, es *eventstore.Eventstore, - cachesConfig *cache.CachesConfig, + cacheConnectors connector.Connectors, defaults sd.SystemDefaults, zitadelRoles []authz.RoleMapping, staticStore static.Storage, @@ -131,7 +132,7 @@ func StartCommands( if err != nil { return nil, fmt.Errorf("password hasher: %w", err) } - caches, err := startCaches(context.TODO(), cachesConfig, es.Client()) + caches, err := startCaches(ctx, cacheConnectors) if err != nil { return nil, fmt.Errorf("caches: %w", err) } diff --git a/internal/command/instance_test.go b/internal/command/instance_test.go index c60b2763b3..301077b268 100644 --- a/internal/command/instance_test.go +++ b/internal/command/instance_test.go @@ -13,7 +13,7 @@ import ( "golang.org/x/text/language" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/cache/noop" + "github.com/zitadel/zitadel/internal/cache/connector/noop" "github.com/zitadel/zitadel/internal/command/preparation" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" diff --git a/internal/command/milestone_test.go b/internal/command/milestone_test.go index 819db9d098..3c4bffc704 100644 --- a/internal/command/milestone_test.go +++ b/internal/command/milestone_test.go @@ -10,8 +10,8 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/cache" - "github.com/zitadel/zitadel/internal/cache/gomap" - "github.com/zitadel/zitadel/internal/cache/noop" + "github.com/zitadel/zitadel/internal/cache/connector/gomap" + "github.com/zitadel/zitadel/internal/cache/connector/noop" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/repository/milestone" ) @@ -183,7 +183,7 @@ func TestCommands_GetMilestonesReached(t *testing.T) { cache := gomap.NewCache[milestoneIndex, string, *MilestonesReached]( context.Background(), []milestoneIndex{milestoneIndexInstanceID}, - cache.CacheConfig{Connector: "memory"}, + cache.Config{Connector: cache.ConnectorMemory}, ) cache.Set(context.Background(), cached) diff --git a/internal/integration/config/docker-compose.yaml b/internal/integration/config/docker-compose.yaml index 1749b9f0ab..19c68ae405 100644 --- a/internal/integration/config/docker-compose.yaml +++ b/internal/integration/config/docker-compose.yaml @@ -23,3 +23,9 @@ services: start_period: '20s' ports: - 5432:5432 + + cache: + restart: 'always' + image: 'redis:latest' + ports: + - 6379:6379 diff --git a/internal/integration/config/zitadel.yaml b/internal/integration/config/zitadel.yaml index b1482f6e1a..378dc2f09b 100644 --- a/internal/integration/config/zitadel.yaml +++ b/internal/integration/config/zitadel.yaml @@ -10,13 +10,21 @@ Caches: Connectors: Postgres: Enabled: true - AutoPrune: - Interval: 30s - TimeOut: 1s + Redis: + Enabled: true Instance: + Connector: "redis" + MaxAge: 1h + LastUsage: 10m + Log: + Level: info + AddSource: true + Formatter: + Format: text + Milestones: Connector: "postgres" MaxAge: 1h - LastUsage: 30m + LastUsage: 10m Log: Level: info AddSource: true diff --git a/internal/query/cache.go b/internal/query/cache.go index 2722377891..55f7bb3db6 100644 --- a/internal/query/cache.go +++ b/internal/query/cache.go @@ -2,90 +2,28 @@ package query import ( "context" - "fmt" - "strings" "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/cache" - "github.com/zitadel/zitadel/internal/cache/gomap" - "github.com/zitadel/zitadel/internal/cache/noop" - "github.com/zitadel/zitadel/internal/cache/pg" - "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/cache/connector" "github.com/zitadel/zitadel/internal/eventstore" ) type Caches struct { - connectors *cacheConnectors - instance cache.Cache[instanceIndex, string, *authzInstance] + instance cache.Cache[instanceIndex, string, *authzInstance] } -func startCaches(background context.Context, conf *cache.CachesConfig, client *database.DB) (_ *Caches, err error) { - caches := &Caches{ - instance: noop.NewCache[instanceIndex, string, *authzInstance](), - } - if conf == nil { - return caches, nil - } - caches.connectors, err = startCacheConnectors(background, conf, client) - if err != nil { - return nil, err - } - caches.instance, err = startCache[instanceIndex, string, *authzInstance](background, instanceIndexValues(), "authz_instance", conf.Instance, caches.connectors) +func startCaches(background context.Context, connectors connector.Connectors) (_ *Caches, err error) { + caches := new(Caches) + caches.instance, err = connector.StartCache[instanceIndex, string, *authzInstance](background, instanceIndexValues(), cache.PurposeAuthzInstance, connectors.Config.Instance, connectors) if err != nil { return nil, err } caches.registerInstanceInvalidation() - return caches, nil } -type cacheConnectors struct { - memory *cache.AutoPruneConfig - postgres *pgxPoolCacheConnector -} - -type pgxPoolCacheConnector struct { - *cache.AutoPruneConfig - client *database.DB -} - -func startCacheConnectors(_ context.Context, conf *cache.CachesConfig, client *database.DB) (_ *cacheConnectors, err error) { - connectors := new(cacheConnectors) - if conf.Connectors.Memory.Enabled { - connectors.memory = &conf.Connectors.Memory.AutoPrune - } - if conf.Connectors.Postgres.Enabled { - connectors.postgres = &pgxPoolCacheConnector{ - AutoPruneConfig: &conf.Connectors.Postgres.AutoPrune, - client: client, - } - } - return connectors, nil -} - -func startCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, name string, conf *cache.CacheConfig, connectors *cacheConnectors) (cache.Cache[I, K, V], error) { - if conf == nil || conf.Connector == "" { - return noop.NewCache[I, K, V](), nil - } - if strings.EqualFold(conf.Connector, "memory") && connectors.memory != nil { - c := gomap.NewCache[I, K, V](background, indices, *conf) - connectors.memory.StartAutoPrune(background, c, name) - return c, nil - } - if strings.EqualFold(conf.Connector, "postgres") && connectors.postgres != nil { - client := connectors.postgres.client - c, err := pg.NewCache[I, K, V](background, name, *conf, indices, client.Pool, client.Type()) - if err != nil { - return nil, fmt.Errorf("query start cache: %w", err) - } - connectors.postgres.StartAutoPrune(background, c, name) - return c, nil - } - - return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector) -} - type invalidator[I comparable] interface { Invalidate(ctx context.Context, index I, key ...string) error } diff --git a/internal/query/instance.go b/internal/query/instance.go index ef74f0ebdd..549c05a233 100644 --- a/internal/query/instance.go +++ b/internal/query/instance.go @@ -587,9 +587,10 @@ func (c *Caches) registerInstanceInvalidation() { projection.InstanceTrustedDomainProjection.RegisterCacheInvalidation(invalidate) projection.SecurityPolicyProjection.RegisterCacheInvalidation(invalidate) - // limits uses own aggregate ID, invalidate using resource owner. + // These projections have their own aggregate ID, invalidate using resource owner. invalidate = cacheInvalidationFunc(c.instance, instanceIndexByID, getResourceOwner) projection.LimitsProjection.RegisterCacheInvalidation(invalidate) + projection.RestrictionsProjection.RegisterCacheInvalidation(invalidate) // System feature update should invalidate all instances, so Truncate the cache. projection.SystemFeatureProjection.RegisterCacheInvalidation(func(ctx context.Context, _ []*eventstore.Aggregate) { diff --git a/internal/query/query.go b/internal/query/query.go index 590326d07b..b39dbe9ca1 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -11,7 +11,7 @@ import ( "golang.org/x/text/language" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/cache" + "github.com/zitadel/zitadel/internal/cache/connector" sd "github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/database" @@ -49,7 +49,7 @@ func StartQueries( es *eventstore.Eventstore, esV4 es_v4.Querier, querySqlClient, projectionSqlClient *database.DB, - caches *cache.CachesConfig, + cacheConnectors connector.Connectors, projections projection.Config, defaults sd.SystemDefaults, idpConfigEncryption, otpEncryption, keyEncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm, @@ -89,7 +89,7 @@ func StartQueries( if startProjections { projection.Start(ctx) } - repo.caches, err = startCaches(ctx, caches, querySqlClient) + repo.caches, err = startCaches(ctx, cacheConnectors) if err != nil { return nil, err }