Files
zitadel/internal/cache/connector/pg/pg.go
Tim Möhlmann f6f37d3a31 fix(cache): use key versioning (#10657)
# Which Problems Are Solved

Cached objects may have a different schema between Zitadel versions.

# How the Problems Are Solved

Use the current build version in the DB-based cache connectors for PostgreSQL and
Redis.

# Additional Changes

- Cleanup the ZitadelVersion field from the authz Instance
- Fix a potential race condition on global variables in the build package.

# Additional Context

- Closes https://github.com/zitadel/zitadel/issues/10648
- Obsoletes https://github.com/zitadel/zitadel/pull/10646
- Needs to be back-ported to v4 via
https://github.com/zitadel/zitadel/pull/10645
2025-09-15 09:51:54 +00:00

199 lines
5.5 KiB
Go

package pg
import (
"context"
_ "embed"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"text/template"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
// SQL statements and templates embedded at build time from files in this
// package directory. Each query receives the cache purpose name as its first
// argument.
var (
	// createPartitionQuery is the raw text/template source; createPartition
	// renders it with the purpose name before executing it.
	//go:embed create_partition.sql.tmpl
	createPartitionQuery string
	// createPartitionTmpl is parsed once; template.Must panics during package
	// initialization if the embedded template is malformed.
	createPartitionTmpl = template.Must(template.New("create_partition").Parse(createPartitionQuery))
	// setQuery is executed by set with (purpose, index keys, entry).
	//go:embed set.sql
	setQuery string
	// getQuery is executed by get with (purpose, index, versioned key, MaxAge, LastUseAge).
	//go:embed get.sql
	getQuery string
	// invalidateQuery is executed by Invalidate with (purpose, index, versioned keys).
	//go:embed invalidate.sql
	invalidateQuery string
	// deleteQuery is executed by Delete with (purpose, index, versioned keys).
	//go:embed delete.sql
	deleteQuery string
	// pruneQuery is executed by Prune with (purpose, MaxAge, LastUseAge).
	//go:embed prune.sql
	pruneQuery string
	// truncateQuery is executed by Truncate with (purpose) only.
	//go:embed truncate.sql
	truncateQuery string
)
type PGXPool interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}
// pgCache is the PostgreSQL-backed cache returned by NewCache as a
// cache.PrunerCache.
type pgCache[I ~int, K ~string, V cache.Entry[I, K]] struct {
	// purpose identifies the cache; its String() form is passed as the first
	// argument of every query and is used to render the partition template.
	purpose cache.Purpose
	// zitadelVersion is prepended to every key (see versionedKey), so keys
	// written by a different build version do not match on lookup.
	zitadelVersion string
	// config supplies MaxAge and LastUseAge for get and Prune; NewCache derives
	// logger from its Log settings.
	config *cache.Config
	// indices lists the indices this cache accepts: get rejects any other
	// index and set stores keys for each of them.
	indices []I
	// connector executes all SQL statements of this cache.
	connector *Connector
	// logger carries a "cache_purpose" attribute.
	logger *slog.Logger
}
// NewCache returns a cache that stores and retrieves objects using PostgreSQL unlogged tables.
//
// All keys are prefixed with zitadelVersion, so lookups only match entries that
// were written under the same version string. Before returning, the
// create_partition template is rendered for purpose and executed on connector;
// if that fails, its error is returned unchanged and no cache is returned.
func NewCache[I ~int, K ~string, V cache.Entry[I, K]](ctx context.Context, purpose cache.Purpose, zitadelVersion string, config cache.Config, indices []I, connector *Connector) (cache.PrunerCache[I, K, V], error) {
	purposeLogger := config.Log.Slog().With("cache_purpose", purpose)
	purposeLogger.InfoContext(ctx, "pg cache logging enabled")

	pc := &pgCache[I, K, V]{
		purpose:        purpose,
		zitadelVersion: zitadelVersion,
		config:         &config,
		indices:        indices,
		connector:      connector,
		logger:         purposeLogger,
	}
	if err := pc.createPartition(ctx); err != nil {
		return nil, err
	}
	return pc, nil
}
// createPartition renders the create_partition template with the purpose name
// and executes the resulting SQL. Template and database errors are returned
// as-is.
func (c *pgCache[I, K, V]) createPartition(ctx context.Context) error {
	var stmt strings.Builder
	err := createPartitionTmpl.Execute(&stmt, c.purpose.String())
	if err != nil {
		return err
	}
	if _, err = c.connector.Exec(ctx, stmt.String()); err != nil {
		return err
	}
	return nil
}
// Set stores entry under every key it reports for the configured indices.
// It is best-effort: set already logs a failed write, so the error is
// deliberately discarded here.
func (c *pgCache[I, K, V]) Set(ctx context.Context, entry V) {
	_ = c.set(ctx, entry)
}
// set writes entry together with all of its version-prefixed index keys by
// executing setQuery with (purpose, keys, entry). The error is recorded on the
// tracing span, logged at error level, and returned.
func (c *pgCache[I, K, V]) set(ctx context.Context, entry V) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	entryKeys := c.indexKeysFromEntry(entry)
	c.logger.DebugContext(ctx, "pg cache set", "index_key", entryKeys)

	if _, err = c.connector.Exec(ctx, setQuery, c.purpose.String(), entryKeys, entry); err != nil {
		c.logger.ErrorContext(ctx, "pg cache set", "err", err)
		return err
	}
	return nil
}
// Get looks up the entry stored under key for the given index.
//
// It returns (entry, true) on a hit. On any failure it returns the zero value
// of V and false, so callers never observe a partially decoded value.
// A miss (pgx.ErrNoRows) is logged at info level; every other error, such as
// an unknown index or a database failure, is logged at error level.
func (c *pgCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
	found, err := c.get(ctx, index, key)
	if err == nil {
		c.logger.DebugContext(ctx, "pg cache get", "index", index, "key", key)
		return found, true
	}
	// err is attached to the logger exactly once; repeating it on the record
	// would make slog emit a duplicate "err" attribute.
	logger := c.logger.With("err", err, "index", index, "key", key)
	if errors.Is(err, pgx.ErrNoRows) {
		logger.InfoContext(ctx, "pg cache miss")
	} else {
		logger.ErrorContext(ctx, "pg cache get")
	}
	// value is still the zero value of V here.
	return value, false
}
// get reads the entry stored under the version-prefixed key for index.
// MaxAge and LastUseAge from the config are passed to the query. It returns
// cache.NewIndexUnknownErr when index is not one of the configured indices;
// otherwise it returns whatever error QueryRow/Scan reports (pgx.ErrNoRows when
// no row matched). The error is also recorded on the tracing span.
func (c *pgCache[I, K, V]) get(ctx context.Context, index I, key K) (value V, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	if known := slices.Contains(c.indices, index); !known {
		return value, cache.NewIndexUnknownErr(index)
	}

	row := c.connector.QueryRow(ctx, getQuery,
		c.purpose.String(), index, c.versionedKey(key),
		c.config.MaxAge, c.config.LastUseAge,
	)
	err = row.Scan(&value)
	return value, err
}
// Invalidate runs invalidateQuery for index with each of keys prefixed by the
// cache's version. The debug line is written whether or not the query
// succeeded; the query error, if any, is returned and recorded on the span.
func (c *pgCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	_, err = c.connector.Exec(ctx, invalidateQuery,
		c.purpose.String(), index, c.versionedKeys(keys),
	)
	c.logger.DebugContext(ctx, "pg cache invalidate", "index", index, "keys", keys)
	return err
}
// Delete runs deleteQuery for index with each of keys prefixed by the cache's
// version. The debug line is written whether or not the query succeeded; the
// query error, if any, is returned and recorded on the span.
func (c *pgCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	_, err = c.connector.Exec(ctx, deleteQuery,
		c.purpose.String(), index, c.versionedKeys(keys),
	)
	c.logger.DebugContext(ctx, "pg cache delete", "index", index, "keys", keys)
	return err
}
// Prune runs pruneQuery for this cache's purpose with the configured MaxAge
// and LastUseAge (NOTE(review): presumably removes entries beyond those
// limits — see prune.sql). The debug line is written regardless of the outcome;
// the query error is returned and recorded on the span.
func (c *pgCache[I, K, V]) Prune(ctx context.Context) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	maxAge, lastUseAge := c.config.MaxAge, c.config.LastUseAge
	_, err = c.connector.Exec(ctx, pruneQuery, c.purpose.String(), maxAge, lastUseAge)
	c.logger.DebugContext(ctx, "pg cache prune")
	return err
}
// Truncate runs truncateQuery with only the purpose name. No version prefix is
// passed, so it is not restricted to keys of the current version. The debug
// line is written regardless of the outcome; the query error is returned and
// recorded on the span.
func (c *pgCache[I, K, V]) Truncate(ctx context.Context) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	purpose := c.purpose.String()
	_, err = c.connector.Exec(ctx, truncateQuery, purpose)
	c.logger.DebugContext(ctx, "pg cache truncate")
	return err
}
// indexKey is one (index, version-prefixed key) pair under which an entry is
// stored; set passes a slice of these as a single query argument.
// NOTE(review): the json tags suggest the argument is JSON-encoded and that
// set.sql reads "index_id"/"index_key" by name — confirm before renaming.
type indexKey[I comparable] struct {
	IndexID  I      `json:"index_id"`
	IndexKey string `json:"index_key"`
}
// indexKeysFromEntry collects every key entry reports for each configured
// index, prefixing each with the cache's version. The result is ordered by
// index (in c.indices order), then by the order entry.Keys returns.
func (c *pgCache[I, K, V]) indexKeysFromEntry(entry V) []indexKey[I] {
	// Capacity guesses about three keys per index; append grows it if needed.
	out := make([]indexKey[I], 0, len(c.indices)*3)
	for _, idx := range c.indices {
		entryKeys := entry.Keys(idx)
		for i := range entryKeys {
			out = append(out, indexKey[I]{IndexID: idx, IndexKey: c.versionedKey(entryKeys[i])})
		}
	}
	return out
}
// versionedKey returns "<zitadelVersion>:<key>". Reads (get), writes (via
// indexKeysFromEntry), Invalidate and Delete all use this prefix, so entries
// stored under a different version string are never matched.
// key is formatted with %s, so a String method on K (if one exists) is honored.
func (c *pgCache[I, K, V]) versionedKey(key K) string {
	return fmt.Sprintf("%s:%s", c.zitadelVersion, key)
}
// versionedKeys applies versionedKey to every element of keys, keeping their
// order. The result is always a non-nil slice, even when keys is empty.
func (c *pgCache[I, K, V]) versionedKeys(keys []K) []string {
	prefixed := make([]string, 0, len(keys))
	for _, k := range keys {
		prefixed = append(prefixed, c.versionedKey(k))
	}
	return prefixed
}