perf(cache): pgx pool connector (#8703)

# Which Problems Are Solved

Cache implementation using a PGX connection pool.

# How the Problems Are Solved

Defines a new schema `cache` in the zitadel database.
A table for string keys and a table for objects is defined.
For postgreSQL, tables are unlogged and partitioned by cache name for
performance.

Cockroach does not have unlogged tables and partitioning is an
enterprise feature that uses alternative syntax combined with sharding.
Regular tables are used here.

# Additional Changes

- `postgres.Config` can return a pxg pool. See following discussion

# Additional Context

- Part of https://github.com/zitadel/zitadel/issues/8648
- Closes https://github.com/zitadel/zitadel/issues/8647

---------

Co-authored-by: Silvan <silvan.reusser@gmail.com>
This commit is contained in:
Tim Möhlmann 2024-10-04 16:15:41 +03:00 committed by GitHub
parent bee0744d46
commit 25dc7bfe72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 1034 additions and 140 deletions

37
cmd/setup/34.go Normal file
View File

@ -0,0 +1,37 @@
package setup
import (
"context"
_ "embed"
"fmt"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
)
var (
//go:embed 34/cockroach/34_cache_schema.sql
addCacheSchemaCockroach string
//go:embed 34/postgres/34_cache_schema.sql
addCacheSchemaPostgres string
)
type AddCacheSchema struct {
dbClient *database.DB
}
func (mig *AddCacheSchema) Execute(ctx context.Context, _ eventstore.Event) (err error) {
switch mig.dbClient.Type() {
case "cockroach":
_, err = mig.dbClient.ExecContext(ctx, addCacheSchemaCockroach)
case "postgres":
_, err = mig.dbClient.ExecContext(ctx, addCacheSchemaPostgres)
default:
err = fmt.Errorf("add cache schema: unsupported db type %q", mig.dbClient.Type())
}
return err
}
func (mig *AddCacheSchema) String() string {
return "34_add_cache_schema"
}

View File

@ -0,0 +1,27 @@
create schema if not exists cache;
create table if not exists cache.objects (
cache_name varchar not null,
id uuid not null default gen_random_uuid(),
created_at timestamptz not null default now(),
last_used_at timestamptz not null default now(),
payload jsonb not null,
primary key(cache_name, id)
);
create table if not exists cache.string_keys(
cache_name varchar not null check (cache_name <> ''),
index_id integer not null check (index_id > 0),
index_key varchar not null check (index_key <> ''),
object_id uuid not null,
primary key (cache_name, index_id, index_key),
constraint fk_object
foreign key(cache_name, object_id)
references cache.objects(cache_name, id)
on delete cascade
);
create index if not exists string_keys_object_id_idx
on cache.string_keys (cache_name, object_id); -- for delete cascade

View File

@ -0,0 +1,29 @@
create schema if not exists cache;
create unlogged table if not exists cache.objects (
cache_name varchar not null,
id uuid not null default gen_random_uuid(),
created_at timestamptz not null default now(),
last_used_at timestamptz not null default now(),
payload jsonb not null,
primary key(cache_name, id)
)
partition by list (cache_name);
create unlogged table if not exists cache.string_keys(
cache_name varchar not null check (cache_name <> ''),
index_id integer not null check (index_id > 0),
index_key varchar not null check (index_key <> ''),
object_id uuid not null,
primary key (cache_name, index_id, index_key),
constraint fk_object
foreign key(cache_name, object_id)
references cache.objects(cache_name, id)
on delete cascade
)
partition by list (cache_name);
create index if not exists string_keys_object_id_idx
on cache.string_keys (cache_name, object_id); -- for delete cascade

View File

@ -90,36 +90,37 @@ func MustNewConfig(v *viper.Viper) *Config {
}
type Steps struct {
s1ProjectionTable *ProjectionTable
s2AssetsTable *AssetTable
FirstInstance *FirstInstance
s5LastFailed *LastFailed
s6OwnerRemoveColumns *OwnerRemoveColumns
s7LogstoreTables *LogstoreTables
s8AuthTokens *AuthTokenIndexes
CorrectCreationDate *CorrectCreationDate
s12AddOTPColumns *AddOTPColumns
s13FixQuotaProjection *FixQuotaConstraints
s14NewEventsTable *NewEventsTable
s15CurrentStates *CurrentProjectionState
s16UniqueConstraintsLower *UniqueConstraintToLower
s17AddOffsetToUniqueConstraints *AddOffsetToCurrentStates
s18AddLowerFieldsToLoginNames *AddLowerFieldsToLoginNames
s19AddCurrentStatesIndex *AddCurrentSequencesIndex
s20AddByUserSessionIndex *AddByUserIndexToSession
s21AddBlockFieldToLimits *AddBlockFieldToLimits
s22ActiveInstancesIndex *ActiveInstanceEvents
s23CorrectGlobalUniqueConstraints *CorrectGlobalUniqueConstraints
s24AddActorToAuthTokens *AddActorToAuthTokens
s25User11AddLowerFieldsToVerifiedEmail *User11AddLowerFieldsToVerifiedEmail
s26AuthUsers3 *AuthUsers3
s27IDPTemplate6SAMLNameIDFormat *IDPTemplate6SAMLNameIDFormat
s28AddFieldTable *AddFieldTable
s29FillFieldsForProjectGrant *FillFieldsForProjectGrant
s30FillFieldsForOrgDomainVerified *FillFieldsForOrgDomainVerified
s31AddAggregateIndexToFields *AddAggregateIndexToFields
s32AddAuthSessionID *AddAuthSessionID
s1ProjectionTable *ProjectionTable
s2AssetsTable *AssetTable
FirstInstance *FirstInstance
s5LastFailed *LastFailed
s6OwnerRemoveColumns *OwnerRemoveColumns
s7LogstoreTables *LogstoreTables
s8AuthTokens *AuthTokenIndexes
CorrectCreationDate *CorrectCreationDate
s12AddOTPColumns *AddOTPColumns
s13FixQuotaProjection *FixQuotaConstraints
s14NewEventsTable *NewEventsTable
s15CurrentStates *CurrentProjectionState
s16UniqueConstraintsLower *UniqueConstraintToLower
s17AddOffsetToUniqueConstraints *AddOffsetToCurrentStates
s18AddLowerFieldsToLoginNames *AddLowerFieldsToLoginNames
s19AddCurrentStatesIndex *AddCurrentSequencesIndex
s20AddByUserSessionIndex *AddByUserIndexToSession
s21AddBlockFieldToLimits *AddBlockFieldToLimits
s22ActiveInstancesIndex *ActiveInstanceEvents
s23CorrectGlobalUniqueConstraints *CorrectGlobalUniqueConstraints
s24AddActorToAuthTokens *AddActorToAuthTokens
s25User11AddLowerFieldsToVerifiedEmail *User11AddLowerFieldsToVerifiedEmail
s26AuthUsers3 *AuthUsers3
s27IDPTemplate6SAMLNameIDFormat *IDPTemplate6SAMLNameIDFormat
s28AddFieldTable *AddFieldTable
s29FillFieldsForProjectGrant *FillFieldsForProjectGrant
s30FillFieldsForOrgDomainVerified *FillFieldsForOrgDomainVerified
s31AddAggregateIndexToFields *AddAggregateIndexToFields
s32AddAuthSessionID *AddAuthSessionID
s33SMSConfigs3TwilioAddVerifyServiceSid *SMSConfigs3TwilioAddVerifyServiceSid
s34AddCacheSchema *AddCacheSchema
}
func MustNewSteps(v *viper.Viper) *Steps {

View File

@ -162,6 +162,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s31AddAggregateIndexToFields = &AddAggregateIndexToFields{dbClient: esPusherDBClient}
steps.s32AddAuthSessionID = &AddAuthSessionID{dbClient: esPusherDBClient}
steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: esPusherDBClient}
steps.s34AddCacheSchema = &AddCacheSchema{dbClient: queryDBClient}
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
logging.OnError(err).Fatal("unable to start projections")
@ -204,6 +205,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s26AuthUsers3,
steps.s29FillFieldsForProjectGrant,
steps.s30FillFieldsForOrgDomainVerified,
steps.s34AddCacheSchema,
} {
mustExecuteMigration(ctx, eventstoreClient, step, "migration failed")
}

9
go.mod
View File

@ -38,7 +38,7 @@ require (
github.com/h2non/gock v1.2.0
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/improbable-eng/grpc-web v0.15.0
github.com/jackc/pgx/v5 v5.6.0
github.com/jackc/pgx/v5 v5.7.0
github.com/jarcoal/jpath v0.0.0-20140328210829-f76b8b2dbf52
github.com/jinzhu/gorm v1.9.16
github.com/k3a/html2text v1.2.1
@ -49,6 +49,7 @@ require (
github.com/muhlemmer/gu v0.3.1
github.com/muhlemmer/httpforwarded v0.1.0
github.com/nicksnyder/go-i18n/v2 v2.4.0
github.com/pashagolub/pgxmock/v4 v4.3.0
github.com/pquerna/otp v1.4.0
github.com/rakyll/statik v0.1.7
github.com/rs/cors v1.11.0
@ -76,12 +77,12 @@ require (
go.opentelemetry.io/otel/sdk/metric v1.28.0
go.opentelemetry.io/otel/trace v1.28.0
go.uber.org/mock v0.4.0
golang.org/x/crypto v0.25.0
golang.org/x/crypto v0.27.0
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
golang.org/x/net v0.26.0
golang.org/x/oauth2 v0.22.0
golang.org/x/sync v0.8.0
golang.org/x/text v0.17.0
golang.org/x/text v0.18.0
google.golang.org/api v0.187.0
google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094
google.golang.org/grpc v1.65.0
@ -169,7 +170,7 @@ require (
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jonboulle/clockwork v0.4.0
github.com/klauspost/compress v1.17.9 // indirect

18
go.sum
View File

@ -400,10 +400,10 @@ github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLf
github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.0 h1:FG6VLIdzvAPhnYqP14sQ2xhFLkiUQHCs6ySqO91kF4g=
github.com/jackc/pgx/v5 v5.7.0/go.mod h1:awP1KNnjylvpxHuHP63gzjhnGkI1iw+PMoIwvoleN/8=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jarcoal/jpath v0.0.0-20140328210829-f76b8b2dbf52 h1:jny9eqYPwkG8IVy7foUoRjQmFLcArCSz+uPsL6KS0HQ=
@ -567,6 +567,8 @@ github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnh
github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4=
github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM=
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pashagolub/pgxmock/v4 v4.3.0 h1:DqT7fk0OCK6H0GvqtcMsLpv8cIwWqdxWgfZNLeHCb/s=
github.com/pashagolub/pgxmock/v4 v4.3.0/go.mod h1:9VoVHXwS3XR/yPtKGzwQvwZX1kzGB9sM8SviDcHDa3A=
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
@ -791,8 +793,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw=
@ -932,8 +934,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=

View File

@ -6,6 +6,7 @@ import (
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database/postgres"
)
// Cache stores objects with a value of type `V`.
@ -55,9 +56,6 @@ type Cache[I, K comparable, V Entry[I, K]] interface {
// Truncate deletes all cached objects.
Truncate(ctx context.Context) error
// Close the cache. Subsequent calls to the cache are not allowed.
Close(ctx context.Context) error
}
// Entry contains a value of type `V` to be cached.
@ -75,8 +73,8 @@ type Entry[I, K comparable] interface {
type CachesConfig struct {
Connectors struct {
Memory MemoryConnectorConfig
// SQL database.Config
Memory MemoryConnectorConfig
Postgres PostgresConnectorConfig
// Redis redis.Config?
}
Instance *CacheConfig
@ -104,3 +102,9 @@ type MemoryConnectorConfig struct {
Enabled bool
AutoPrune AutoPruneConfig
}
type PostgresConnectorConfig struct {
Enabled bool
AutoPrune AutoPruneConfig
Connection postgres.Config
}

View File

@ -109,15 +109,11 @@ func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
for name, index := range c.indexMap {
index.Truncate()
c.logger.DebugContext(ctx, "map cache clear", "index", name)
c.logger.DebugContext(ctx, "map cache truncate", "index", name)
}
return nil
}
func (c *mapCache[I, K, V]) Close(ctx context.Context) error {
return ctx.Err()
}
type index[K comparable, V any] struct {
mutex sync.RWMutex
config *cache.CacheConfig

View File

@ -49,7 +49,6 @@ func Test_mapCache_Get(t *testing.T) {
AddSource: true,
},
})
defer c.Close(context.Background())
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
@ -112,7 +111,6 @@ func Test_mapCache_Invalidate(t *testing.T) {
AddSource: true,
},
})
defer c.Close(context.Background())
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
@ -134,7 +132,6 @@ func Test_mapCache_Delete(t *testing.T) {
AddSource: true,
},
})
defer c.Close(context.Background())
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
@ -168,7 +165,6 @@ func Test_mapCache_Prune(t *testing.T) {
AddSource: true,
},
})
defer c.Close(context.Background())
objects := []*testObject{
{
@ -205,7 +201,6 @@ func Test_mapCache_Truncate(t *testing.T) {
AddSource: true,
},
})
defer c.Close(context.Background())
objects := []*testObject{
{
id: "id1",

View File

@ -19,4 +19,3 @@ func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }
func (noop[I, K, V]) Close(context.Context) (err error) { return }

View File

@ -0,0 +1,7 @@
create unlogged table if not exists cache.objects_{{ . }}
partition of cache.objects
for values in ('{{ . }}');
create unlogged table if not exists cache.string_keys_{{ . }}
partition of cache.string_keys
for values in ('{{ . }}');

5
internal/cache/pg/delete.sql vendored Normal file
View File

@ -0,0 +1,5 @@
delete from cache.string_keys k
where k.cache_name = $1
and k.index_id = $2
and k.index_key = any($3)
;

19
internal/cache/pg/get.sql vendored Normal file
View File

@ -0,0 +1,19 @@
update cache.objects
set last_used_at = now()
where cache_name = $1
and (
select object_id
from cache.string_keys k
where cache_name = $1
and index_id = $2
and index_key = $3
) = id
and case when $4::interval > '0s'
then created_at > now()-$4::interval -- max age
else true
end
and case when $5::interval > '0s'
then last_used_at > now()-$5::interval -- last use
else true
end
returning payload;

9
internal/cache/pg/invalidate.sql vendored Normal file
View File

@ -0,0 +1,9 @@
delete from cache.objects o
using cache.string_keys k
where k.cache_name = $1
and k.index_id = $2
and k.index_key = any($3)
and o.cache_name = k.cache_name
and o.id = k.object_id
;

176
internal/cache/pg/pg.go vendored Normal file
View File

@ -0,0 +1,176 @@
package pg
import (
"context"
_ "embed"
"errors"
"log/slog"
"strings"
"text/template"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"golang.org/x/exp/slices"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var (
//go:embed create_partition.sql.tmpl
createPartitionQuery string
createPartitionTmpl = template.Must(template.New("create_partition").Parse(createPartitionQuery))
//go:embed set.sql
setQuery string
//go:embed get.sql
getQuery string
//go:embed invalidate.sql
invalidateQuery string
//go:embed delete.sql
deleteQuery string
//go:embed prune.sql
pruneQuery string
//go:embed truncate.sql
truncateQuery string
)
type PGXPool interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}
type pgCache[I ~int, K ~string, V cache.Entry[I, K]] struct {
name string
config *cache.CacheConfig
indices []I
pool PGXPool
logger *slog.Logger
}
// NewCache returns a cache that stores and retrieves objects using PostgreSQL unlogged tables.
func NewCache[I ~int, K ~string, V cache.Entry[I, K]](ctx context.Context, name string, config cache.CacheConfig, indices []I, pool PGXPool, dialect string) (cache.PrunerCache[I, K, V], error) {
c := &pgCache[I, K, V]{
name: name,
config: &config,
indices: indices,
pool: pool,
logger: config.Log.Slog().With("cache_name", name),
}
c.logger.InfoContext(ctx, "pg cache logging enabled")
if dialect == "postgres" {
if err := c.createPartition(ctx); err != nil {
return nil, err
}
}
return c, nil
}
func (c *pgCache[I, K, V]) createPartition(ctx context.Context) error {
var query strings.Builder
if err := createPartitionTmpl.Execute(&query, c.name); err != nil {
return err
}
_, err := c.pool.Exec(ctx, query.String())
return err
}
func (c *pgCache[I, K, V]) Set(ctx context.Context, entry V) {
//nolint:errcheck
c.set(ctx, entry)
}
func (c *pgCache[I, K, V]) set(ctx context.Context, entry V) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
keys := c.indexKeysFromEntry(entry)
c.logger.DebugContext(ctx, "pg cache set", "index_key", keys)
_, err = c.pool.Exec(ctx, setQuery, c.name, keys, entry)
if err != nil {
c.logger.ErrorContext(ctx, "pg cache set", "err", err)
return err
}
return nil
}
func (c *pgCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
value, err := c.get(ctx, index, key)
if err == nil {
c.logger.DebugContext(ctx, "pg cache get", "index", index, "key", key)
return value, true
}
logger := c.logger.With("err", err, "index", index, "key", key)
if errors.Is(err, pgx.ErrNoRows) {
logger.InfoContext(ctx, "pg cache miss")
return value, false
}
logger.ErrorContext(ctx, "pg cache get", "err", err)
return value, false
}
func (c *pgCache[I, K, V]) get(ctx context.Context, index I, key K) (value V, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if !slices.Contains(c.indices, index) {
return value, cache.NewIndexUnknownErr(index)
}
err = c.pool.QueryRow(ctx, getQuery, c.name, index, key, c.config.MaxAge, c.config.LastUseAge).Scan(&value)
return value, err
}
func (c *pgCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
_, err = c.pool.Exec(ctx, invalidateQuery, c.name, index, keys)
c.logger.DebugContext(ctx, "pg cache invalidate", "index", index, "keys", keys)
return err
}
func (c *pgCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
_, err = c.pool.Exec(ctx, deleteQuery, c.name, index, keys)
c.logger.DebugContext(ctx, "pg cache delete", "index", index, "keys", keys)
return err
}
func (c *pgCache[I, K, V]) Prune(ctx context.Context) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
_, err = c.pool.Exec(ctx, pruneQuery, c.name, c.config.MaxAge, c.config.LastUseAge)
c.logger.DebugContext(ctx, "pg cache prune")
return err
}
func (c *pgCache[I, K, V]) Truncate(ctx context.Context) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
_, err = c.pool.Exec(ctx, truncateQuery, c.name)
c.logger.DebugContext(ctx, "pg cache truncate")
return err
}
type indexKey[I, K comparable] struct {
IndexID I `json:"index_id"`
IndexKey K `json:"index_key"`
}
func (c *pgCache[I, K, V]) indexKeysFromEntry(entry V) []indexKey[I, K] {
keys := make([]indexKey[I, K], 0, len(c.indices)*3) // naive assumption
for _, index := range c.indices {
for _, key := range entry.Keys(index) {
keys = append(keys, indexKey[I, K]{
IndexID: index,
IndexKey: key,
})
}
}
return keys
}

519
internal/cache/pg/pg_test.go vendored Normal file
View File

@ -0,0 +1,519 @@
package pg
import (
"context"
"regexp"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/pashagolub/pgxmock/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/cache"
)
type testIndex int
const (
testIndexID testIndex = iota
testIndexName
)
var testIndices = []testIndex{
testIndexID,
testIndexName,
}
type testObject struct {
ID string
Name []string
}
func (o *testObject) Keys(index testIndex) []string {
switch index {
case testIndexID:
return []string{o.ID}
case testIndexName:
return o.Name
default:
return nil
}
}
func TestNewCache(t *testing.T) {
tests := []struct {
name string
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(regexp.QuoteMeta(expectedCreatePartitionQuery)).
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "success",
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(regexp.QuoteMeta(expectedCreatePartitionQuery)).
WillReturnResult(pgxmock.NewResult("CREATE TABLE", 0))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf := cache.CacheConfig{
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
}
pool, err := pgxmock.NewPool()
require.NoError(t, err)
tt.expect(pool)
c, err := NewCache[testIndex, string, *testObject](context.Background(), cacheName, conf, testIndices, pool, "postgres")
require.ErrorIs(t, err, tt.wantErr)
if tt.wantErr == nil {
assert.NotNil(t, c)
}
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Set(t *testing.T) {
queryExpect := regexp.QuoteMeta(setQuery)
type args struct {
entry *testObject
}
tests := []struct {
name string
args args
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
args: args{
&testObject{
ID: "id1",
Name: []string{"foo", "bar"},
},
},
expect: func(ppi pgxmock.PgxCommonIface) {
ppi.ExpectExec(queryExpect).
WithArgs("test",
[]indexKey[testIndex, string]{
{IndexID: testIndexID, IndexKey: "id1"},
{IndexID: testIndexName, IndexKey: "foo"},
{IndexID: testIndexName, IndexKey: "bar"},
},
&testObject{
ID: "id1",
Name: []string{"foo", "bar"},
}).
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "success",
args: args{
&testObject{
ID: "id1",
Name: []string{"foo", "bar"},
},
},
expect: func(ppi pgxmock.PgxCommonIface) {
ppi.ExpectExec(queryExpect).
WithArgs("test",
[]indexKey[testIndex, string]{
{IndexID: testIndexID, IndexKey: "id1"},
{IndexID: testIndexName, IndexKey: "foo"},
{IndexID: testIndexName, IndexKey: "bar"},
},
&testObject{
ID: "id1",
Name: []string{"foo", "bar"},
}).
WillReturnResult(pgxmock.NewResult("INSERT", 1))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, cache.CacheConfig{})
defer pool.Close()
tt.expect(pool)
err := c.(*pgCache[testIndex, string, *testObject]).
set(context.Background(), tt.args.entry)
require.ErrorIs(t, err, tt.wantErr)
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Get(t *testing.T) {
queryExpect := regexp.QuoteMeta(getQuery)
type args struct {
index testIndex
key string
}
tests := []struct {
name string
config cache.CacheConfig
args args
expect func(pgxmock.PgxCommonIface)
want *testObject
wantOk bool
}{
{
name: "invalid index",
config: cache.CacheConfig{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
args: args{
index: 99,
key: "id1",
},
expect: func(pci pgxmock.PgxCommonIface) {},
wantOk: false,
},
{
name: "no rows",
config: cache.CacheConfig{
MaxAge: 0,
LastUseAge: 0,
},
args: args{
index: testIndexID,
key: "id1",
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectQuery(queryExpect).
WithArgs("test", testIndexID, "id1", time.Duration(0), time.Duration(0)).
WillReturnRows(pgxmock.NewRows([]string{"payload"}))
},
wantOk: false,
},
{
name: "error",
config: cache.CacheConfig{
MaxAge: 0,
LastUseAge: 0,
},
args: args{
index: testIndexID,
key: "id1",
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectQuery(queryExpect).
WithArgs("test", testIndexID, "id1", time.Duration(0), time.Duration(0)).
WillReturnError(pgx.ErrTxClosed)
},
wantOk: false,
},
{
name: "ok",
config: cache.CacheConfig{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
args: args{
index: testIndexID,
key: "id1",
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectQuery(queryExpect).
WithArgs("test", testIndexID, "id1", time.Minute, time.Second).
WillReturnRows(
pgxmock.NewRows([]string{"payload"}).AddRow(&testObject{
ID: "id1",
Name: []string{"foo", "bar"},
}),
)
},
want: &testObject{
ID: "id1",
Name: []string{"foo", "bar"},
},
wantOk: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, tt.config)
defer pool.Close()
tt.expect(pool)
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
assert.Equal(t, tt.wantOk, ok)
assert.Equal(t, tt.want, got)
err := pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Invalidate(t *testing.T) {
queryExpect := regexp.QuoteMeta(invalidateQuery)
type args struct {
index testIndex
keys []string
}
tests := []struct {
name string
config cache.CacheConfig
args args
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
config: cache.CacheConfig{
MaxAge: 0,
LastUseAge: 0,
},
args: args{
index: testIndexID,
keys: []string{"id1", "id2"},
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs("test", testIndexID, []string{"id1", "id2"}).
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "ok",
config: cache.CacheConfig{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
args: args{
index: testIndexID,
keys: []string{"id1", "id2"},
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs("test", testIndexID, []string{"id1", "id2"}).
WillReturnResult(pgxmock.NewResult("DELETE", 1))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, tt.config)
defer pool.Close()
tt.expect(pool)
err := c.Invalidate(context.Background(), tt.args.index, tt.args.keys...)
assert.ErrorIs(t, err, tt.wantErr)
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Delete(t *testing.T) {
queryExpect := regexp.QuoteMeta(deleteQuery)
type args struct {
index testIndex
keys []string
}
tests := []struct {
name string
config cache.CacheConfig
args args
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
config: cache.CacheConfig{
MaxAge: 0,
LastUseAge: 0,
},
args: args{
index: testIndexID,
keys: []string{"id1", "id2"},
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs("test", testIndexID, []string{"id1", "id2"}).
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "ok",
config: cache.CacheConfig{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
args: args{
index: testIndexID,
keys: []string{"id1", "id2"},
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs("test", testIndexID, []string{"id1", "id2"}).
WillReturnResult(pgxmock.NewResult("DELETE", 1))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, tt.config)
defer pool.Close()
tt.expect(pool)
err := c.Delete(context.Background(), tt.args.index, tt.args.keys...)
assert.ErrorIs(t, err, tt.wantErr)
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Prune(t *testing.T) {
queryExpect := regexp.QuoteMeta(pruneQuery)
tests := []struct {
name string
config cache.CacheConfig
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
config: cache.CacheConfig{
MaxAge: 0,
LastUseAge: 0,
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs("test", time.Duration(0), time.Duration(0)).
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "ok",
config: cache.CacheConfig{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs("test", time.Minute, time.Second).
WillReturnResult(pgxmock.NewResult("DELETE", 1))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, tt.config)
defer pool.Close()
tt.expect(pool)
err := c.Prune(context.Background())
assert.ErrorIs(t, err, tt.wantErr)
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Truncate(t *testing.T) {
queryExpect := regexp.QuoteMeta(truncateQuery)
tests := []struct {
name string
config cache.CacheConfig
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
config: cache.CacheConfig{
MaxAge: 0,
LastUseAge: 0,
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs("test").
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "ok",
config: cache.CacheConfig{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs("test").
WillReturnResult(pgxmock.NewResult("DELETE", 1))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, tt.config)
defer pool.Close()
tt.expect(pool)
err := c.Truncate(context.Background())
assert.ErrorIs(t, err, tt.wantErr)
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
const (
cacheName = "test"
expectedCreatePartitionQuery = `create unlogged table if not exists cache.objects_test
partition of cache.objects
for values in ('test');
create unlogged table if not exists cache.string_keys_test
partition of cache.string_keys
for values in ('test');
`
)
func prepareCache(t *testing.T, conf cache.CacheConfig) (cache.PrunerCache[testIndex, string, *testObject], pgxmock.PgxPoolIface) {
conf.Log = &logging.Config{
Level: "debug",
AddSource: true,
}
pool, err := pgxmock.NewPool()
require.NoError(t, err)
pool.ExpectExec(regexp.QuoteMeta(expectedCreatePartitionQuery)).
WillReturnResult(pgxmock.NewResult("CREATE TABLE", 0))
c, err := NewCache[testIndex, string, *testObject](context.Background(), cacheName, conf, testIndices, pool, "postgres")
require.NoError(t, err)
return c, pool
}

18
internal/cache/pg/prune.sql vendored Normal file
View File

@ -0,0 +1,18 @@
delete from cache.objects o
where o.cache_name = $1
and (
case when $2::interval > '0s'
then created_at < now()-$2::interval -- max age
else false
end
or case when $3::interval > '0s'
then last_used_at < now()-$3::interval -- last use
else false
end
or o.id not in (
select object_id
from cache.string_keys
where cache_name = $1
)
)
;

19
internal/cache/pg/set.sql vendored Normal file
View File

@ -0,0 +1,19 @@
with object as (
insert into cache.objects (cache_name, payload)
values ($1, $3)
returning id
)
insert into cache.string_keys (
cache_name,
index_id,
index_key,
object_id
)
select $1, keys.index_id, keys.index_key, id as object_id
from object, jsonb_to_recordset($2) keys (
index_id bigint,
index_key text
)
on conflict (cache_name, index_id, index_key) do
update set object_id = EXCLUDED.object_id
;

3
internal/cache/pg/truncate.sql vendored Normal file
View File

@ -0,0 +1,3 @@
delete from cache.objects o
where o.cache_name = $1
;

View File

@ -71,15 +71,15 @@ func (_ *Config) Decode(configs []interface{}) (dialect.Connector, error) {
return connector, nil
}
func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, error) {
func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, *pgxpool.Pool, error) {
connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns, pusherRatio, spoolerRatio, purpose)
if err != nil {
return nil, err
return nil, nil, err
}
config, err := pgxpool.ParseConfig(c.String(useAdmin, purpose.AppName()))
if err != nil {
return nil, err
return nil, nil, err
}
if connConfig.MaxOpenConns != 0 {
@ -91,14 +91,14 @@ func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpo
pool, err := pgxpool.NewWithConfig(context.Background(), config)
if err != nil {
return nil, err
return nil, nil, err
}
if err := pool.Ping(context.Background()); err != nil {
return nil, err
return nil, nil, err
}
return stdlib.OpenDBFromPool(pool), nil
return stdlib.OpenDBFromPool(pool), pool, nil
}
func (c *Config) DatabaseName() string {

View File

@ -8,6 +8,7 @@ import (
"reflect"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/logging"
@ -31,6 +32,7 @@ func (c *Config) SetConnector(connector dialect.Connector) {
type DB struct {
*sql.DB
dialect.Database
Pool *pgxpool.Pool
}
func (db *DB) Query(scan func(*sql.Rows) error, query string, args ...any) error {
@ -113,7 +115,7 @@ func QueryJSONObject[T any](ctx context.Context, db *DB, query string, args ...a
}
func Connect(config Config, useAdmin bool, purpose dialect.DBPurpose) (*DB, error) {
client, err := config.connector.Connect(useAdmin, config.EventPushConnRatio, config.ProjectionSpoolerConnRatio, purpose)
client, pool, err := config.connector.Connect(useAdmin, config.EventPushConnRatio, config.ProjectionSpoolerConnRatio, purpose)
if err != nil {
return nil, err
}
@ -125,6 +127,7 @@ func Connect(config Config, useAdmin bool, purpose dialect.DBPurpose) (*DB, erro
return &DB{
DB: client,
Database: config.connector,
Pool: pool,
}, nil
}

View File

@ -4,6 +4,8 @@ import (
"database/sql"
"sync"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
type Dialect struct {
@ -53,7 +55,7 @@ func (p DBPurpose) AppName() string {
}
type Connector interface {
Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose DBPurpose) (*sql.DB, error)
Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose DBPurpose) (*sql.DB, *pgxpool.Pool, error)
Password() string
Database
}

View File

@ -72,15 +72,15 @@ func (_ *Config) Decode(configs []interface{}) (dialect.Connector, error) {
return connector, nil
}
func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, error) {
func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, *pgxpool.Pool, error) {
connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns, pusherRatio, spoolerRatio, purpose)
if err != nil {
return nil, err
return nil, nil, err
}
config, err := pgxpool.ParseConfig(c.String(useAdmin, purpose.AppName()))
if err != nil {
return nil, err
return nil, nil, err
}
if connConfig.MaxOpenConns != 0 {
@ -95,14 +95,14 @@ func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpo
config,
)
if err != nil {
return nil, err
return nil, nil, err
}
if err := pool.Ping(context.Background()); err != nil {
return nil, err
return nil, nil, err
}
return stdlib.OpenDBFromPool(pool), nil
return stdlib.OpenDBFromPool(pool), pool, nil
}
func (c *Config) DatabaseName() string {

View File

@ -8,15 +8,15 @@ TLS:
Caches:
Connectors:
Memory:
Postgres:
Enabled: true
AutoPrune:
Interval: 30s
TimeOut: 1s
Instance:
Connector: "memory"
MaxAge: 1m
LastUsage: 30s
Connector: "postgres"
MaxAge: 1h
LastUsage: 30m
Log:
Level: info
AddSource: true

View File

@ -10,6 +10,8 @@ import (
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/gomap"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/cache/pg"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
)
@ -18,14 +20,14 @@ type Caches struct {
instance cache.Cache[instanceIndex, string, *authzInstance]
}
func startCaches(background context.Context, conf *cache.CachesConfig) (_ *Caches, err error) {
func startCaches(background context.Context, conf *cache.CachesConfig, client *database.DB) (_ *Caches, err error) {
caches := &Caches{
instance: noop.NewCache[instanceIndex, string, *authzInstance](),
}
if conf == nil {
return caches, nil
}
caches.connectors, err = startCacheConnectors(background, conf)
caches.connectors, err = startCacheConnectors(background, conf, client)
if err != nil {
return nil, err
}
@ -39,20 +41,30 @@ func startCaches(background context.Context, conf *cache.CachesConfig) (_ *Cache
}
type cacheConnectors struct {
memory *cache.AutoPruneConfig
// pool *pgxpool.Pool
memory *cache.AutoPruneConfig
postgres *pgxPoolCacheConnector
}
func startCacheConnectors(_ context.Context, conf *cache.CachesConfig) (*cacheConnectors, error) {
type pgxPoolCacheConnector struct {
*cache.AutoPruneConfig
client *database.DB
}
func startCacheConnectors(_ context.Context, conf *cache.CachesConfig, client *database.DB) (_ *cacheConnectors, err error) {
connectors := new(cacheConnectors)
if conf.Connectors.Memory.Enabled {
connectors.memory = &conf.Connectors.Memory.AutoPrune
}
if conf.Connectors.Postgres.Enabled {
connectors.postgres = &pgxPoolCacheConnector{
AutoPruneConfig: &conf.Connectors.Postgres.AutoPrune,
client: client,
}
}
return connectors, nil
}
func startCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, name string, conf *cache.CacheConfig, connectors *cacheConnectors) (cache.Cache[I, K, V], error) {
func startCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, name string, conf *cache.CacheConfig, connectors *cacheConnectors) (cache.Cache[I, K, V], error) {
if conf == nil || conf.Connector == "" {
return noop.NewCache[I, K, V](), nil
}
@ -61,12 +73,15 @@ func startCache[I, K comparable, V cache.Entry[I, K]](background context.Context
connectors.memory.StartAutoPrune(background, c, name)
return c, nil
}
/* TODO
if strings.EqualFold(conf.Connector, "sql") && connectors.pool != nil {
return ...
if strings.EqualFold(conf.Connector, "postgres") && connectors.postgres != nil {
client := connectors.postgres.client
c, err := pg.NewCache[I, K, V](background, name, *conf, indices, client.Pool, client.Type())
if err != nil {
return nil, fmt.Errorf("query start cache: %w", err)
}
connectors.postgres.StartAutoPrune(background, c, name)
return c, nil
}
*/
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
}

View File

@ -435,71 +435,71 @@ func prepareInstanceDomainQuery(ctx context.Context, db prepareDatabase) (sq.Sel
}
type authzInstance struct {
id string
iamProjectID string
consoleID string
consoleAppID string
defaultLang language.Tag
defaultOrgID string
csp csp
enableImpersonation bool
block *bool
auditLogRetention *time.Duration
features feature.Features
externalDomains database.TextArray[string]
trustedDomains database.TextArray[string]
ID string `json:"id,omitempty"`
IAMProjectID string `json:"iam_project_id,omitempty"`
ConsoleID string `json:"console_id,omitempty"`
ConsoleAppID string `json:"console_app_id,omitempty"`
DefaultLang language.Tag `json:"default_lang,omitempty"`
DefaultOrgID string `json:"default_org_id,omitempty"`
CSP csp `json:"csp,omitempty"`
Impersonation bool `json:"impersonation,omitempty"`
IsBlocked *bool `json:"is_blocked,omitempty"`
LogRetention *time.Duration `json:"log_retention,omitempty"`
Feature feature.Features `json:"feature,omitempty"`
ExternalDomains database.TextArray[string] `json:"external_domains,omitempty"`
TrustedDomains database.TextArray[string] `json:"trusted_domains,omitempty"`
}
type csp struct {
enableIframeEmbedding bool
allowedOrigins database.TextArray[string]
EnableIframeEmbedding bool `json:"enable_iframe_embedding,omitempty"`
AllowedOrigins database.TextArray[string] `json:"allowed_origins,omitempty"`
}
func (i *authzInstance) InstanceID() string {
return i.id
return i.ID
}
func (i *authzInstance) ProjectID() string {
return i.iamProjectID
return i.IAMProjectID
}
func (i *authzInstance) ConsoleClientID() string {
return i.consoleID
return i.ConsoleID
}
func (i *authzInstance) ConsoleApplicationID() string {
return i.consoleAppID
return i.ConsoleAppID
}
func (i *authzInstance) DefaultLanguage() language.Tag {
return i.defaultLang
return i.DefaultLang
}
func (i *authzInstance) DefaultOrganisationID() string {
return i.defaultOrgID
return i.DefaultOrgID
}
func (i *authzInstance) SecurityPolicyAllowedOrigins() []string {
if !i.csp.enableIframeEmbedding {
if !i.CSP.EnableIframeEmbedding {
return nil
}
return i.csp.allowedOrigins
return i.CSP.AllowedOrigins
}
func (i *authzInstance) EnableImpersonation() bool {
return i.enableImpersonation
return i.Impersonation
}
func (i *authzInstance) Block() *bool {
return i.block
return i.IsBlocked
}
func (i *authzInstance) AuditLogRetention() *time.Duration {
return i.auditLogRetention
return i.LogRetention
}
func (i *authzInstance) Features() feature.Features {
return i.features
return i.Feature
}
var errPublicDomain = "public domain %q not trusted"
@ -509,7 +509,7 @@ func (i *authzInstance) checkDomain(instanceDomain, publicDomain string) error {
if publicDomain == "" || instanceDomain == publicDomain {
return nil
}
if !slices.Contains(i.trustedDomains, publicDomain) {
if !slices.Contains(i.TrustedDomains, publicDomain) {
return zerrors.ThrowNotFound(fmt.Errorf(errPublicDomain, publicDomain), "QUERY-IuGh1", "Errors.IAM.NotFound")
}
return nil
@ -519,9 +519,9 @@ func (i *authzInstance) checkDomain(instanceDomain, publicDomain string) error {
func (i *authzInstance) Keys(index instanceIndex) []string {
switch index {
case instanceIndexByID:
return []string{i.id}
return []string{i.ID}
case instanceIndexByHost:
return i.externalDomains
return i.ExternalDomains
default:
return nil
}
@ -539,20 +539,20 @@ func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
features []byte
)
err := row.Scan(
&instance.id,
&instance.defaultOrgID,
&instance.iamProjectID,
&instance.consoleID,
&instance.consoleAppID,
&instance.ID,
&instance.DefaultOrgID,
&instance.IAMProjectID,
&instance.ConsoleID,
&instance.ConsoleAppID,
&lang,
&enableIframeEmbedding,
&instance.csp.allowedOrigins,
&instance.CSP.AllowedOrigins,
&enableImpersonation,
&auditLogRetention,
&block,
&features,
&instance.externalDomains,
&instance.trustedDomains,
&instance.ExternalDomains,
&instance.TrustedDomains,
)
if errors.Is(err, sql.ErrNoRows) {
return zerrors.ThrowNotFound(nil, "QUERY-1kIjX", "Errors.IAM.NotFound")
@ -560,19 +560,19 @@ func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
if err != nil {
return zerrors.ThrowInternal(err, "QUERY-d3fas", "Errors.Internal")
}
instance.defaultLang = language.Make(lang)
instance.DefaultLang = language.Make(lang)
if auditLogRetention.Valid {
instance.auditLogRetention = &auditLogRetention.Duration
instance.LogRetention = &auditLogRetention.Duration
}
if block.Valid {
instance.block = &block.Bool
instance.IsBlocked = &block.Bool
}
instance.csp.enableIframeEmbedding = enableIframeEmbedding.Bool
instance.enableImpersonation = enableImpersonation.Bool
instance.CSP.EnableIframeEmbedding = enableIframeEmbedding.Bool
instance.Impersonation = enableImpersonation.Bool
if len(features) == 0 {
return nil
}
if err = json.Unmarshal(features, &instance.features); err != nil {
if err = json.Unmarshal(features, &instance.Feature); err != nil {
return zerrors.ThrowInternal(err, "QUERY-Po8ki", "Errors.Internal")
}
return nil
@ -598,10 +598,12 @@ func (c *Caches) registerInstanceInvalidation() {
})
}
type instanceIndex int16
type instanceIndex int
//go:generate enumer -type instanceIndex
//go:generate enumer -type instanceIndex -linecomment
const (
instanceIndexByID instanceIndex = iota
// Empty line comment ensures empty string for unspecified value
instanceIndexUnspecified instanceIndex = iota //
instanceIndexByID
instanceIndexByHost
)

View File

@ -1,4 +1,4 @@
// Code generated by "enumer -type instanceIndex"; DO NOT EDIT.
// Code generated by "enumer -type instanceIndex -linecomment"; DO NOT EDIT.
package query
@ -9,7 +9,7 @@ import (
const _instanceIndexName = "instanceIndexByIDinstanceIndexByHost"
var _instanceIndexIndex = [...]uint8{0, 17, 36}
var _instanceIndexIndex = [...]uint8{0, 0, 17, 36}
const _instanceIndexLowerName = "instanceindexbyidinstanceindexbyhost"
@ -24,13 +24,16 @@ func (i instanceIndex) String() string {
// Re-run the stringer command to generate them again.
func _instanceIndexNoOp() {
var x [1]struct{}
_ = x[instanceIndexByID-(0)]
_ = x[instanceIndexByHost-(1)]
_ = x[instanceIndexUnspecified-(0)]
_ = x[instanceIndexByID-(1)]
_ = x[instanceIndexByHost-(2)]
}
var _instanceIndexValues = []instanceIndex{instanceIndexByID, instanceIndexByHost}
var _instanceIndexValues = []instanceIndex{instanceIndexUnspecified, instanceIndexByID, instanceIndexByHost}
var _instanceIndexNameToValueMap = map[string]instanceIndex{
_instanceIndexName[0:0]: instanceIndexUnspecified,
_instanceIndexLowerName[0:0]: instanceIndexUnspecified,
_instanceIndexName[0:17]: instanceIndexByID,
_instanceIndexLowerName[0:17]: instanceIndexByID,
_instanceIndexName[17:36]: instanceIndexByHost,
@ -38,6 +41,7 @@ var _instanceIndexNameToValueMap = map[string]instanceIndex{
}
var _instanceIndexNames = []string{
_instanceIndexName[0:0],
_instanceIndexName[0:17],
_instanceIndexName[17:36],
}

View File

@ -89,7 +89,7 @@ func StartQueries(
if startProjections {
projection.Start(ctx)
}
repo.caches, err = startCaches(ctx, caches)
repo.caches, err = startCaches(ctx, caches, querySqlClient)
if err != nil {
return nil, err
}