chore: move the go code into a subfolder

This commit is contained in:
Florian Forster
2025-08-05 15:20:32 -07:00
parent 4ad22ba456
commit cd2921de26
2978 changed files with 373 additions and 300 deletions

113
apps/api/internal/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,113 @@
// Package cache provides abstraction of cache implementations that can be used by zitadel.
package cache
import (
"context"
"time"
"github.com/zitadel/logging"
)
// Purpose describes which object types are stored by a cache.
type Purpose int
//go:generate enumer -type Purpose -transform snake -trimprefix Purpose
const (
PurposeUnspecified Purpose = iota
PurposeAuthzInstance
PurposeMilestones
PurposeOrganization
PurposeIdPFormCallback
PurposeFederatedLogout
)
// Cache stores objects with a value of type `V`.
// Objects may be referred to by one or more indices.
// Implementations may encode the value for storage.
// This means non-exported fields may be lost and objects
// with function values may fail to encode.
// See https://pkg.go.dev/encoding/json#Marshal for example.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
// Indices are defined when calling the constructor of an implementation of this interface.
// It is illegal to refer to an idex not defined during construction.
//
// `K` is the type used as key in each index.
// Due to the limitations in type constraints, all indices use the same key type.
//
// Implementations are free to use stricter type constraints or fixed typing.
type Cache[I, K comparable, V Entry[I, K]] interface {
// Get an object through specified index.
// An [IndexUnknownError] may be returned if the index is unknown.
// [ErrCacheMiss] is returned if the key was not found in the index,
// or the object is not valid.
Get(ctx context.Context, index I, key K) (V, bool)
// Set an object.
// Keys are created on each index based in the [Entry.Keys] method.
// If any key maps to an existing object, the object is invalidated,
// regardless if the object has other keys defined in the new entry.
// This to prevent ghost objects when an entry reduces the amount of keys
// for a given index.
Set(ctx context.Context, value V)
// Invalidate an object through specified index.
// Implementations may choose to instantly delete the object,
// defer until prune or a separate cleanup routine.
// Invalidated object are no longer returned from Get.
// It is safe to call Invalidate multiple times or on non-existing entries.
Invalidate(ctx context.Context, index I, key ...K) error
// Delete one or more keys from a specific index.
// An [IndexUnknownError] may be returned if the index is unknown.
// The referred object is not invalidated and may still be accessible though
// other indices and keys.
// It is safe to call Delete multiple times or on non-existing entries
Delete(ctx context.Context, index I, key ...K) error
// Truncate deletes all cached objects.
Truncate(ctx context.Context) error
}
// Entry contains a value of type `V` to be cached.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
//
// `K` is the type used as key in an index.
// Due to the limitations in type constraints, all indices use the same key type.
type Entry[I, K comparable] interface {
// Keys returns which keys map to the object in a specified index.
// May return nil if the index in unknown or when there are no keys.
Keys(index I) (key []K)
}
type Connector int
//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text
const (
// Empty line comment ensures empty string for unspecified value
ConnectorUnspecified Connector = iota //
ConnectorMemory
ConnectorPostgres
ConnectorRedis
)
type Config struct {
Connector Connector
// Age since an object was added to the cache,
// after which the object is considered invalid.
// 0 disables max age checks.
MaxAge time.Duration
// Age since last use (Get) of an object,
// after which the object is considered invalid.
// 0 disables last use age checks.
LastUseAge time.Duration
// Log allows logging of the specific cache.
// By default only errors are logged to stdout.
Log *logging.Config
}

View File

@@ -0,0 +1,72 @@
// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages.
package connector
import (
"context"
"fmt"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/connector/gomap"
"github.com/zitadel/zitadel/internal/cache/connector/noop"
"github.com/zitadel/zitadel/internal/cache/connector/pg"
"github.com/zitadel/zitadel/internal/cache/connector/redis"
"github.com/zitadel/zitadel/internal/database"
)
type CachesConfig struct {
Connectors struct {
Memory gomap.Config
Postgres pg.Config
Redis redis.Config
}
Instance *cache.Config
Milestones *cache.Config
Organization *cache.Config
IdPFormCallbacks *cache.Config
FederatedLogouts *cache.Config
}
type Connectors struct {
Config CachesConfig
Memory *gomap.Connector
Postgres *pg.Connector
Redis *redis.Connector
}
func StartConnectors(conf *CachesConfig, client *database.DB) (Connectors, error) {
if conf == nil {
return Connectors{}, nil
}
return Connectors{
Config: *conf,
Memory: gomap.NewConnector(conf.Connectors.Memory),
Postgres: pg.NewConnector(conf.Connectors.Postgres, client),
Redis: redis.NewConnector(conf.Connectors.Redis),
}, nil
}
func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) {
if conf == nil || conf.Connector == cache.ConnectorUnspecified {
return noop.NewCache[I, K, V](), nil
}
if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil {
c := gomap.NewCache[I, K, V](background, indices, *conf)
connectors.Memory.Config.StartAutoPrune(background, c, purpose)
return c, nil
}
if conf.Connector == cache.ConnectorPostgres && connectors.Postgres != nil {
c, err := pg.NewCache[I, K, V](background, purpose, *conf, indices, connectors.Postgres)
if err != nil {
return nil, fmt.Errorf("start cache: %w", err)
}
connectors.Postgres.Config.AutoPrune.StartAutoPrune(background, c, purpose)
return c, nil
}
if conf.Connector == cache.ConnectorRedis && connectors.Redis != nil {
db := connectors.Redis.Config.DBOffset + int(purpose)
c := redis.NewCache[I, K, V](*conf, connectors.Redis, db, indices)
return c, nil
}
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
}

View File

@@ -0,0 +1,23 @@
package gomap
import (
"github.com/zitadel/zitadel/internal/cache"
)
type Config struct {
Enabled bool
AutoPrune cache.AutoPruneConfig
}
type Connector struct {
Config cache.AutoPruneConfig
}
func NewConnector(config Config) *Connector {
if !config.Enabled {
return nil
}
return &Connector{
Config: config.AutoPrune,
}
}

View File

@@ -0,0 +1,200 @@
package gomap
import (
"context"
"errors"
"log/slog"
"maps"
"os"
"sync"
"sync/atomic"
"time"
"github.com/zitadel/zitadel/internal/cache"
)
type mapCache[I, K comparable, V cache.Entry[I, K]] struct {
config *cache.Config
indexMap map[I]*index[K, V]
logger *slog.Logger
}
// NewCache returns an in-memory Cache implementation based on the builtin go map type.
// Object values are stored as-is and there is no encoding or decoding involved.
func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] {
m := &mapCache[I, K, V]{
config: &config,
indexMap: make(map[I]*index[K, V], len(indices)),
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelError,
})),
}
if config.Log != nil {
m.logger = config.Log.Slog()
}
m.logger.InfoContext(background, "map cache logging enabled")
for _, name := range indices {
m.indexMap[name] = &index[K, V]{
config: m.config,
entries: make(map[K]*entry[V]),
}
}
return m
}
func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
i, ok := c.indexMap[index]
if !ok {
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
entry, err := i.Get(key)
if err == nil {
c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key)
return entry.value, true
}
if errors.Is(err, cache.ErrCacheMiss) {
c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key)
return value, false
}
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) {
now := time.Now()
entry := &entry[V]{
value: value,
created: now,
}
entry.lastUse.Store(now.UnixMicro())
for name, i := range c.indexMap {
keys := value.Keys(name)
i.Set(keys, entry)
c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys)
}
}
func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Invalidate(keys)
c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Delete(keys)
c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
for name, index := range c.indexMap {
index.Prune()
c.logger.DebugContext(ctx, "map cache prune", "index", name)
}
return nil
}
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
for name, index := range c.indexMap {
index.Truncate()
c.logger.DebugContext(ctx, "map cache truncate", "index", name)
}
return nil
}
type index[K comparable, V any] struct {
mutex sync.RWMutex
config *cache.Config
entries map[K]*entry[V]
}
func (i *index[K, V]) Get(key K) (*entry[V], error) {
i.mutex.RLock()
entry, ok := i.entries[key]
i.mutex.RUnlock()
if ok && entry.isValid(i.config) {
return entry, nil
}
return nil, cache.ErrCacheMiss
}
func (c *index[K, V]) Set(keys []K, entry *entry[V]) {
c.mutex.Lock()
for _, key := range keys {
c.entries[key] = entry
}
c.mutex.Unlock()
}
func (i *index[K, V]) Invalidate(keys []K) {
i.mutex.RLock()
for _, key := range keys {
if entry, ok := i.entries[key]; ok {
entry.invalid.Store(true)
}
}
i.mutex.RUnlock()
}
func (c *index[K, V]) Delete(keys []K) {
c.mutex.Lock()
for _, key := range keys {
delete(c.entries, key)
}
c.mutex.Unlock()
}
func (c *index[K, V]) Prune() {
c.mutex.Lock()
maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool {
return !entry.isValid(c.config)
})
c.mutex.Unlock()
}
func (c *index[K, V]) Truncate() {
c.mutex.Lock()
c.entries = make(map[K]*entry[V])
c.mutex.Unlock()
}
type entry[V any] struct {
value V
created time.Time
invalid atomic.Bool
lastUse atomic.Int64 // UnixMicro time
}
func (e *entry[V]) isValid(c *cache.Config) bool {
if e.invalid.Load() {
return false
}
now := time.Now()
if c.MaxAge > 0 {
if e.created.Add(c.MaxAge).Before(now) {
e.invalid.Store(true)
return false
}
}
if c.LastUseAge > 0 {
lastUse := e.lastUse.Load()
if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) {
e.invalid.Store(true)
return false
}
e.lastUse.CompareAndSwap(lastUse, now.UnixMicro())
}
return true
}

View File

@@ -0,0 +1,329 @@
package gomap
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/cache"
)
type testIndex int
const (
testIndexID testIndex = iota
testIndexName
)
var testIndices = []testIndex{
testIndexID,
testIndexName,
}
type testObject struct {
id string
names []string
}
func (o *testObject) Keys(index testIndex) []string {
switch index {
case testIndexID:
return []string{o.id}
case testIndexName:
return o.names
default:
return nil
}
}
func Test_mapCache_Get(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
type args struct {
index testIndex
key string
}
tests := []struct {
name string
args args
want *testObject
wantOk bool
}{
{
name: "ok",
args: args{
index: testIndexID,
key: "id",
},
want: obj,
wantOk: true,
},
{
name: "miss",
args: args{
index: testIndexID,
key: "spanac",
},
want: nil,
wantOk: false,
},
{
name: "unknown index",
args: args{
index: 99,
key: "id",
},
want: nil,
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.wantOk, ok)
})
}
}
func Test_mapCache_Invalidate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
got, ok := c.Get(context.Background(), testIndexID, "id")
assert.Nil(t, got)
assert.False(t, ok)
}
func Test_mapCache_Delete(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Delete(context.Background(), testIndexName, "bar")
require.NoError(t, err)
// Shouldn't find object by deleted name
got, ok := c.Get(context.Background(), testIndexName, "bar")
assert.Nil(t, got)
assert.False(t, ok)
// Should find object by other name
got, ok = c.Get(context.Background(), testIndexName, "foo")
assert.Equal(t, obj, got)
assert.True(t, ok)
// Should find object by id
got, ok = c.Get(context.Background(), testIndexID, "id")
assert.Equal(t, obj, got)
assert.True(t, ok)
}
func Test_mapCache_Prune(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
// invalidate one entry
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
err = c.(cache.Pruner).Prune(context.Background())
require.NoError(t, err)
// Other object should still be found
got, ok := c.Get(context.Background(), testIndexID, "id2")
assert.Equal(t, objects[1], got)
assert.True(t, ok)
}
func Test_mapCache_Truncate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
err := c.Truncate(context.Background())
require.NoError(t, err)
mc := c.(*mapCache[testIndex, string, *testObject])
for _, index := range mc.indexMap {
index.mutex.RLock()
assert.Len(t, index.entries, 0)
index.mutex.RUnlock()
}
}
func Test_entry_isValid(t *testing.T) {
type fields struct {
created time.Time
invalid bool
lastUse time.Time
}
tests := []struct {
name string
fields fields
config *cache.Config
want bool
}{
{
name: "invalid",
fields: fields{
created: time.Now(),
invalid: true,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
LastUseAge: time.Second,
},
want: true,
},
{
name: "last use age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "last use age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.Config{
MaxAge: time.Minute,
},
want: true,
},
{
name: "valid",
fields: fields{
created: time.Now(),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &entry[any]{
created: tt.fields.created,
}
e.invalid.Store(tt.fields.invalid)
e.lastUse.Store(tt.fields.lastUse.UnixMicro())
got := e.isValid(tt.config)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,21 @@
package noop
import (
"context"
"github.com/zitadel/zitadel/internal/cache"
)
type noop[I, K comparable, V cache.Entry[I, K]] struct{}
// NewCache returns a cache that does nothing
func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] {
return noop[I, K, V]{}
}
func (noop[I, K, V]) Set(context.Context, V) {}
func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return }
func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }

View File

@@ -0,0 +1,26 @@
package pg
import (
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/database"
)
type Config struct {
Enabled bool
AutoPrune cache.AutoPruneConfig
}
type Connector struct {
PGXPool
Config Config
}
func NewConnector(config Config, client *database.DB) *Connector {
if !config.Enabled {
return nil
}
return &Connector{
PGXPool: client.Pool,
Config: config,
}
}

View File

@@ -0,0 +1,7 @@
create unlogged table if not exists cache.objects_{{ . }}
partition of cache.objects
for values in ('{{ . }}');
create unlogged table if not exists cache.string_keys_{{ . }}
partition of cache.string_keys
for values in ('{{ . }}');

View File

@@ -0,0 +1,5 @@
delete from cache.string_keys k
where k.cache_name = $1
and k.index_id = $2
and k.index_key = any($3)
;

View File

@@ -0,0 +1,19 @@
update cache.objects
set last_used_at = now()
where cache_name = $1
and (
select object_id
from cache.string_keys k
where cache_name = $1
and index_id = $2
and index_key = $3
) = id
and case when $4::interval > '0s'
then created_at > now()-$4::interval -- max age
else true
end
and case when $5::interval > '0s'
then last_used_at > now()-$5::interval -- last use
else true
end
returning payload;

View File

@@ -0,0 +1,9 @@
delete from cache.objects o
using cache.string_keys k
where k.cache_name = $1
and k.index_id = $2
and k.index_key = any($3)
and o.cache_name = k.cache_name
and o.id = k.object_id
;

View File

@@ -0,0 +1,174 @@
package pg
import (
"context"
_ "embed"
"errors"
"log/slog"
"slices"
"strings"
"text/template"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var (
//go:embed create_partition.sql.tmpl
createPartitionQuery string
createPartitionTmpl = template.Must(template.New("create_partition").Parse(createPartitionQuery))
//go:embed set.sql
setQuery string
//go:embed get.sql
getQuery string
//go:embed invalidate.sql
invalidateQuery string
//go:embed delete.sql
deleteQuery string
//go:embed prune.sql
pruneQuery string
//go:embed truncate.sql
truncateQuery string
)
type PGXPool interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}
type pgCache[I ~int, K ~string, V cache.Entry[I, K]] struct {
purpose cache.Purpose
config *cache.Config
indices []I
connector *Connector
logger *slog.Logger
}
// NewCache returns a cache that stores and retrieves objects using PostgreSQL unlogged tables.
func NewCache[I ~int, K ~string, V cache.Entry[I, K]](ctx context.Context, purpose cache.Purpose, config cache.Config, indices []I, connector *Connector) (cache.PrunerCache[I, K, V], error) {
c := &pgCache[I, K, V]{
purpose: purpose,
config: &config,
indices: indices,
connector: connector,
logger: config.Log.Slog().With("cache_purpose", purpose),
}
c.logger.InfoContext(ctx, "pg cache logging enabled")
if err := c.createPartition(ctx); err != nil {
return nil, err
}
return c, nil
}
func (c *pgCache[I, K, V]) createPartition(ctx context.Context) error {
var query strings.Builder
if err := createPartitionTmpl.Execute(&query, c.purpose.String()); err != nil {
return err
}
_, err := c.connector.Exec(ctx, query.String())
return err
}
func (c *pgCache[I, K, V]) Set(ctx context.Context, entry V) {
//nolint:errcheck
c.set(ctx, entry)
}
func (c *pgCache[I, K, V]) set(ctx context.Context, entry V) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
keys := c.indexKeysFromEntry(entry)
c.logger.DebugContext(ctx, "pg cache set", "index_key", keys)
_, err = c.connector.Exec(ctx, setQuery, c.purpose.String(), keys, entry)
if err != nil {
c.logger.ErrorContext(ctx, "pg cache set", "err", err)
return err
}
return nil
}
func (c *pgCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
value, err := c.get(ctx, index, key)
if err == nil {
c.logger.DebugContext(ctx, "pg cache get", "index", index, "key", key)
return value, true
}
logger := c.logger.With("err", err, "index", index, "key", key)
if errors.Is(err, pgx.ErrNoRows) {
logger.InfoContext(ctx, "pg cache miss")
return value, false
}
logger.ErrorContext(ctx, "pg cache get", "err", err)
return value, false
}
func (c *pgCache[I, K, V]) get(ctx context.Context, index I, key K) (value V, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if !slices.Contains(c.indices, index) {
return value, cache.NewIndexUnknownErr(index)
}
err = c.connector.QueryRow(ctx, getQuery, c.purpose.String(), index, key, c.config.MaxAge, c.config.LastUseAge).Scan(&value)
return value, err
}
func (c *pgCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
_, err = c.connector.Exec(ctx, invalidateQuery, c.purpose.String(), index, keys)
c.logger.DebugContext(ctx, "pg cache invalidate", "index", index, "keys", keys)
return err
}
func (c *pgCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
_, err = c.connector.Exec(ctx, deleteQuery, c.purpose.String(), index, keys)
c.logger.DebugContext(ctx, "pg cache delete", "index", index, "keys", keys)
return err
}
func (c *pgCache[I, K, V]) Prune(ctx context.Context) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
_, err = c.connector.Exec(ctx, pruneQuery, c.purpose.String(), c.config.MaxAge, c.config.LastUseAge)
c.logger.DebugContext(ctx, "pg cache prune")
return err
}
func (c *pgCache[I, K, V]) Truncate(ctx context.Context) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
_, err = c.connector.Exec(ctx, truncateQuery, c.purpose.String())
c.logger.DebugContext(ctx, "pg cache truncate")
return err
}
type indexKey[I, K comparable] struct {
IndexID I `json:"index_id"`
IndexKey K `json:"index_key"`
}
func (c *pgCache[I, K, V]) indexKeysFromEntry(entry V) []indexKey[I, K] {
keys := make([]indexKey[I, K], 0, len(c.indices)*3) // naive assumption
for _, index := range c.indices {
for _, key := range entry.Keys(index) {
keys = append(keys, indexKey[I, K]{
IndexID: index,
IndexKey: key,
})
}
}
return keys
}

View File

@@ -0,0 +1,524 @@
package pg
import (
"context"
"regexp"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/pashagolub/pgxmock/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/cache"
)
type testIndex int
const (
testIndexID testIndex = iota
testIndexName
)
var testIndices = []testIndex{
testIndexID,
testIndexName,
}
type testObject struct {
ID string
Name []string
}
func (o *testObject) Keys(index testIndex) []string {
switch index {
case testIndexID:
return []string{o.ID}
case testIndexName:
return o.Name
default:
return nil
}
}
func TestNewCache(t *testing.T) {
tests := []struct {
name string
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(regexp.QuoteMeta(expectedCreatePartitionQuery)).
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "success",
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(regexp.QuoteMeta(expectedCreatePartitionQuery)).
WillReturnResult(pgxmock.NewResult("CREATE TABLE", 0))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf := cache.Config{
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
}
pool, err := pgxmock.NewPool()
require.NoError(t, err)
tt.expect(pool)
connector := &Connector{
PGXPool: pool,
}
c, err := NewCache[testIndex, string, *testObject](context.Background(), cachePurpose, conf, testIndices, connector)
require.ErrorIs(t, err, tt.wantErr)
if tt.wantErr == nil {
assert.NotNil(t, c)
}
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Set(t *testing.T) {
queryExpect := regexp.QuoteMeta(setQuery)
type args struct {
entry *testObject
}
tests := []struct {
name string
args args
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
args: args{
&testObject{
ID: "id1",
Name: []string{"foo", "bar"},
},
},
expect: func(ppi pgxmock.PgxCommonIface) {
ppi.ExpectExec(queryExpect).
WithArgs(cachePurpose.String(),
[]indexKey[testIndex, string]{
{IndexID: testIndexID, IndexKey: "id1"},
{IndexID: testIndexName, IndexKey: "foo"},
{IndexID: testIndexName, IndexKey: "bar"},
},
&testObject{
ID: "id1",
Name: []string{"foo", "bar"},
}).
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "success",
args: args{
&testObject{
ID: "id1",
Name: []string{"foo", "bar"},
},
},
expect: func(ppi pgxmock.PgxCommonIface) {
ppi.ExpectExec(queryExpect).
WithArgs(cachePurpose.String(),
[]indexKey[testIndex, string]{
{IndexID: testIndexID, IndexKey: "id1"},
{IndexID: testIndexName, IndexKey: "foo"},
{IndexID: testIndexName, IndexKey: "bar"},
},
&testObject{
ID: "id1",
Name: []string{"foo", "bar"},
}).
WillReturnResult(pgxmock.NewResult("INSERT", 1))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, cache.Config{})
defer pool.Close()
tt.expect(pool)
err := c.(*pgCache[testIndex, string, *testObject]).
set(context.Background(), tt.args.entry)
require.ErrorIs(t, err, tt.wantErr)
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Get(t *testing.T) {
queryExpect := regexp.QuoteMeta(getQuery)
type args struct {
index testIndex
key string
}
tests := []struct {
name string
config cache.Config
args args
expect func(pgxmock.PgxCommonIface)
want *testObject
wantOk bool
}{
{
name: "invalid index",
config: cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
args: args{
index: 99,
key: "id1",
},
expect: func(pci pgxmock.PgxCommonIface) {},
wantOk: false,
},
{
name: "no rows",
config: cache.Config{
MaxAge: 0,
LastUseAge: 0,
},
args: args{
index: testIndexID,
key: "id1",
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectQuery(queryExpect).
WithArgs(cachePurpose.String(), testIndexID, "id1", time.Duration(0), time.Duration(0)).
WillReturnRows(pgxmock.NewRows([]string{"payload"}))
},
wantOk: false,
},
{
name: "error",
config: cache.Config{
MaxAge: 0,
LastUseAge: 0,
},
args: args{
index: testIndexID,
key: "id1",
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectQuery(queryExpect).
WithArgs(cachePurpose.String(), testIndexID, "id1", time.Duration(0), time.Duration(0)).
WillReturnError(pgx.ErrTxClosed)
},
wantOk: false,
},
{
name: "ok",
config: cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
args: args{
index: testIndexID,
key: "id1",
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectQuery(queryExpect).
WithArgs(cachePurpose.String(), testIndexID, "id1", time.Minute, time.Second).
WillReturnRows(
pgxmock.NewRows([]string{"payload"}).AddRow(&testObject{
ID: "id1",
Name: []string{"foo", "bar"},
}),
)
},
want: &testObject{
ID: "id1",
Name: []string{"foo", "bar"},
},
wantOk: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, tt.config)
defer pool.Close()
tt.expect(pool)
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
assert.Equal(t, tt.wantOk, ok)
assert.Equal(t, tt.want, got)
err := pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Invalidate(t *testing.T) {
queryExpect := regexp.QuoteMeta(invalidateQuery)
type args struct {
index testIndex
keys []string
}
tests := []struct {
name string
config cache.Config
args args
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
config: cache.Config{
MaxAge: 0,
LastUseAge: 0,
},
args: args{
index: testIndexID,
keys: []string{"id1", "id2"},
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}).
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "ok",
config: cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
args: args{
index: testIndexID,
keys: []string{"id1", "id2"},
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}).
WillReturnResult(pgxmock.NewResult("DELETE", 1))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, tt.config)
defer pool.Close()
tt.expect(pool)
err := c.Invalidate(context.Background(), tt.args.index, tt.args.keys...)
assert.ErrorIs(t, err, tt.wantErr)
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Delete(t *testing.T) {
queryExpect := regexp.QuoteMeta(deleteQuery)
type args struct {
index testIndex
keys []string
}
tests := []struct {
name string
config cache.Config
args args
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
config: cache.Config{
MaxAge: 0,
LastUseAge: 0,
},
args: args{
index: testIndexID,
keys: []string{"id1", "id2"},
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}).
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "ok",
config: cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
args: args{
index: testIndexID,
keys: []string{"id1", "id2"},
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}).
WillReturnResult(pgxmock.NewResult("DELETE", 1))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, tt.config)
defer pool.Close()
tt.expect(pool)
err := c.Delete(context.Background(), tt.args.index, tt.args.keys...)
assert.ErrorIs(t, err, tt.wantErr)
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Prune(t *testing.T) {
queryExpect := regexp.QuoteMeta(pruneQuery)
tests := []struct {
name string
config cache.Config
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
config: cache.Config{
MaxAge: 0,
LastUseAge: 0,
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs(cachePurpose.String(), time.Duration(0), time.Duration(0)).
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "ok",
config: cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs(cachePurpose.String(), time.Minute, time.Second).
WillReturnResult(pgxmock.NewResult("DELETE", 1))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, tt.config)
defer pool.Close()
tt.expect(pool)
err := c.Prune(context.Background())
assert.ErrorIs(t, err, tt.wantErr)
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
func Test_pgCache_Truncate(t *testing.T) {
queryExpect := regexp.QuoteMeta(truncateQuery)
tests := []struct {
name string
config cache.Config
expect func(pgxmock.PgxCommonIface)
wantErr error
}{
{
name: "error",
config: cache.Config{
MaxAge: 0,
LastUseAge: 0,
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs(cachePurpose.String()).
WillReturnError(pgx.ErrTxClosed)
},
wantErr: pgx.ErrTxClosed,
},
{
name: "ok",
config: cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
expect: func(pci pgxmock.PgxCommonIface) {
pci.ExpectExec(queryExpect).
WithArgs(cachePurpose.String()).
WillReturnResult(pgxmock.NewResult("DELETE", 1))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, pool := prepareCache(t, tt.config)
defer pool.Close()
tt.expect(pool)
err := c.Truncate(context.Background())
assert.ErrorIs(t, err, tt.wantErr)
err = pool.ExpectationsWereMet()
assert.NoError(t, err)
})
}
}
const (
cachePurpose = cache.PurposeAuthzInstance
expectedCreatePartitionQuery = `create unlogged table if not exists cache.objects_authz_instance
partition of cache.objects
for values in ('authz_instance');
create unlogged table if not exists cache.string_keys_authz_instance
partition of cache.string_keys
for values in ('authz_instance');
`
)
func prepareCache(t *testing.T, conf cache.Config) (cache.PrunerCache[testIndex, string, *testObject], pgxmock.PgxPoolIface) {
conf.Log = &logging.Config{
Level: "debug",
AddSource: true,
}
pool, err := pgxmock.NewPool()
require.NoError(t, err)
pool.ExpectExec(regexp.QuoteMeta(expectedCreatePartitionQuery)).
WillReturnResult(pgxmock.NewResult("CREATE TABLE", 0))
connector := &Connector{
PGXPool: pool,
}
c, err := NewCache[testIndex, string, *testObject](context.Background(), cachePurpose, conf, testIndices, connector)
require.NoError(t, err)
return c, pool
}

View File

@@ -0,0 +1,18 @@
delete from cache.objects o
where o.cache_name = $1
and (
case when $2::interval > '0s'
then created_at < now()-$2::interval -- max age
else false
end
or case when $3::interval > '0s'
then last_used_at < now()-$3::interval -- last use
else false
end
or o.id not in (
select object_id
from cache.string_keys
where cache_name = $1
)
)
;

View File

@@ -0,0 +1,19 @@
with object as (
insert into cache.objects (cache_name, payload)
values ($1, $3)
returning id
)
insert into cache.string_keys (
cache_name,
index_id,
index_key,
object_id
)
select $1, keys.index_id, keys.index_key, id as object_id
from object, jsonb_to_recordset($2) keys (
index_id bigint,
index_key text
)
on conflict (cache_name, index_id, index_key) do
update set object_id = EXCLUDED.object_id
;

View File

@@ -0,0 +1,3 @@
delete from cache.objects o
where o.cache_name = $1
;

View File

@@ -0,0 +1,10 @@
local function remove(object_id)
local setKey = keySetKey(object_id)
local keys = redis.call("SMEMBERS", setKey)
local n = #keys
for i = 1, n do
redis.call("DEL", keys[i])
end
redis.call("DEL", setKey)
redis.call("DEL", object_id)
end

View File

@@ -0,0 +1,3 @@
-- SELECT ensures the DB namespace for each script.
-- When used, it consumes the first ARGV entry.
redis.call("SELECT", ARGV[1])

View File

@@ -0,0 +1,17 @@
-- keySetKey returns the redis key of the set containing all keys to the object.
local function keySetKey (object_id)
return object_id .. "-keys"
end
local function getTime()
return tonumber(redis.call('TIME')[1])
end
-- getCall wrapts redis.call so a nil is returned instead of false.
local function getCall (...)
local result = redis.call(...)
if result == false then
return nil
end
return result
end

View File

@@ -0,0 +1,91 @@
package redis
import (
"context"
"errors"
"time"
"github.com/redis/go-redis/v9"
"github.com/sony/gobreaker/v2"
"github.com/zitadel/logging"
)
const defaultInflightSize = 100000
type CBConfig struct {
// Interval when the counters are reset to 0.
// 0 interval never resets the counters until the CB is opened.
Interval time.Duration
// Amount of consecutive failures permitted
MaxConsecutiveFailures uint32
// The ratio of failed requests out of total requests
MaxFailureRatio float64
// Timeout after opening of the CB, until the state is set to half-open.
Timeout time.Duration
// The allowed amount of requests that are allowed to pass when the CB is half-open.
MaxRetryRequests uint32
}
func (config *CBConfig) readyToTrip(counts gobreaker.Counts) bool {
if config.MaxConsecutiveFailures > 0 && counts.ConsecutiveFailures > config.MaxConsecutiveFailures {
return true
}
if config.MaxFailureRatio > 0 && counts.Requests > 0 {
failureRatio := float64(counts.TotalFailures) / float64(counts.Requests)
return failureRatio > config.MaxFailureRatio
}
return false
}
// limiter implements [redis.Limiter] as a circuit breaker.
type limiter struct {
inflight chan func(success bool)
cb *gobreaker.TwoStepCircuitBreaker[struct{}]
}
func newLimiter(config *CBConfig, maxActiveConns int) redis.Limiter {
if config == nil {
return nil
}
// The size of the inflight channel needs to be big enough for maxActiveConns to prevent blocking.
// When that is 0 (no limit), we must set a sane default.
if maxActiveConns <= 0 {
maxActiveConns = defaultInflightSize
}
return &limiter{
inflight: make(chan func(success bool), maxActiveConns),
cb: gobreaker.NewTwoStepCircuitBreaker[struct{}](gobreaker.Settings{
Name: "redis cache",
MaxRequests: config.MaxRetryRequests,
Interval: config.Interval,
Timeout: config.Timeout,
ReadyToTrip: config.readyToTrip,
OnStateChange: func(name string, from, to gobreaker.State) {
logging.WithFields("name", name, "from", from, "to", to).Warn("circuit breaker state change")
},
}),
}
}
// Allow implements [redis.Limiter].
func (l *limiter) Allow() error {
done, err := l.cb.Allow()
if err != nil {
return err
}
l.inflight <- done
return nil
}
// ReportResult implements [redis.Limiter].
//
// ReportResult checks the error returned by the Redis client.
// `nil`, [redis.Nil] and [context.Canceled] are not considered failures.
// Any other error, like connection or [context.DeadlineExceeded] is counted as a failure.
func (l *limiter) ReportResult(err error) {
done := <-l.inflight
done(err == nil ||
errors.Is(err, redis.Nil) ||
errors.Is(err, context.Canceled) ||
redis.HasErrorPrefix(err, "NOSCRIPT"))
}

View File

@@ -0,0 +1,168 @@
package redis
import (
"context"
"testing"
"time"
"github.com/sony/gobreaker/v2"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/cache"
)
func TestCBConfig_readyToTrip(t *testing.T) {
type fields struct {
MaxConsecutiveFailures uint32
MaxFailureRatio float64
}
type args struct {
counts gobreaker.Counts
}
tests := []struct {
name string
fields fields
args args
want bool
}{
{
name: "disabled",
fields: fields{},
args: args{
counts: gobreaker.Counts{
Requests: 100,
ConsecutiveFailures: 5,
TotalFailures: 10,
},
},
want: false,
},
{
name: "no failures",
fields: fields{
MaxConsecutiveFailures: 5,
MaxFailureRatio: 0.1,
},
args: args{
counts: gobreaker.Counts{
Requests: 100,
ConsecutiveFailures: 0,
TotalFailures: 0,
},
},
want: false,
},
{
name: "some failures",
fields: fields{
MaxConsecutiveFailures: 5,
MaxFailureRatio: 0.1,
},
args: args{
counts: gobreaker.Counts{
Requests: 100,
ConsecutiveFailures: 5,
TotalFailures: 10,
},
},
want: false,
},
{
name: "consecutive exceeded",
fields: fields{
MaxConsecutiveFailures: 5,
MaxFailureRatio: 0.1,
},
args: args{
counts: gobreaker.Counts{
Requests: 100,
ConsecutiveFailures: 6,
TotalFailures: 0,
},
},
want: true,
},
{
name: "ratio exceeded",
fields: fields{
MaxConsecutiveFailures: 5,
MaxFailureRatio: 0.1,
},
args: args{
counts: gobreaker.Counts{
Requests: 100,
ConsecutiveFailures: 1,
TotalFailures: 11,
},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := &CBConfig{
MaxConsecutiveFailures: tt.fields.MaxConsecutiveFailures,
MaxFailureRatio: tt.fields.MaxFailureRatio,
}
if got := config.readyToTrip(tt.args.counts); got != tt.want {
t.Errorf("CBConfig.readyToTrip() = %v, want %v", got, tt.want)
}
})
}
}
func Test_redisCache_limiter(t *testing.T) {
c, _ := prepareCache(t, cache.Config{}, withCircuitBreakerOption(
&CBConfig{
MaxConsecutiveFailures: 2,
MaxFailureRatio: 0.4,
Timeout: 100 * time.Millisecond,
MaxRetryRequests: 1,
},
))
ctx := context.Background()
canceledCtx, cancel := context.WithCancel(ctx)
cancel()
timedOutCtx, cancel := context.WithTimeout(ctx, -1)
defer cancel()
// CB is and should remain closed
for i := 0; i < 10; i++ {
err := c.Truncate(ctx)
require.NoError(t, err)
}
for i := 0; i < 10; i++ {
err := c.Truncate(canceledCtx)
require.ErrorIs(t, err, context.Canceled)
}
// Timeout err should open the CB after more than 2 failures
for i := 0; i < 3; i++ {
err := c.Truncate(timedOutCtx)
if i > 2 {
require.ErrorIs(t, err, gobreaker.ErrOpenState)
} else {
require.ErrorIs(t, err, context.DeadlineExceeded)
}
}
time.Sleep(200 * time.Millisecond)
// CB should be half-open. If the first command fails, the CB will be Open again
err := c.Truncate(timedOutCtx)
require.ErrorIs(t, err, context.DeadlineExceeded)
err = c.Truncate(timedOutCtx)
require.ErrorIs(t, err, gobreaker.ErrOpenState)
// Reset the DB to closed
time.Sleep(200 * time.Millisecond)
err = c.Truncate(ctx)
require.NoError(t, err)
// Exceed the ratio
err = c.Truncate(timedOutCtx)
require.ErrorIs(t, err, context.DeadlineExceeded)
err = c.Truncate(ctx)
require.ErrorIs(t, err, gobreaker.ErrOpenState)
}

View File

@@ -0,0 +1,157 @@
package redis
import (
"crypto/tls"
"time"
"github.com/redis/go-redis/v9"
)
type Config struct {
Enabled bool
// The network type, either tcp or unix.
// Default is tcp.
Network string
// host:port address.
Addr string
// ClientName will execute the `CLIENT SETNAME ClientName` command for each conn.
ClientName string
// Use the specified Username to authenticate the current connection
// with one of the connections defined in the ACL list when connecting
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
Username string
// Optional password. Must match the password specified in the
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower),
// or the User Password when connecting to a Redis 6.0 instance, or greater,
// that is using the Redis ACL system.
Password string
// Each ZITADEL cache uses an incremental DB namespace.
// This option offsets the first DB so it doesn't conflict with other databases on the same server.
// Note that ZITADEL uses FLUSHDB command to truncate a cache.
// This can have destructive consequences when overlapping DB namespaces are used.
DBOffset int
// Maximum number of retries before giving up.
// Default is 3 retries; -1 (not 0) disables retries.
MaxRetries int
// Minimum backoff between each retry.
// Default is 8 milliseconds; -1 disables backoff.
MinRetryBackoff time.Duration
// Maximum backoff between each retry.
// Default is 512 milliseconds; -1 disables backoff.
MaxRetryBackoff time.Duration
// Dial timeout for establishing new connections.
// Default is 5 seconds.
DialTimeout time.Duration
// Timeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking. Supported values:
// - `0` - default timeout (3 seconds).
// - `-1` - no timeout (block indefinitely).
// - `-2` - disables SetReadDeadline calls completely.
ReadTimeout time.Duration
// Timeout for socket writes. If reached, commands will fail
// with a timeout instead of blocking. Supported values:
// - `0` - default timeout (3 seconds).
// - `-1` - no timeout (block indefinitely).
// - `-2` - disables SetWriteDeadline calls completely.
WriteTimeout time.Duration
// Type of connection pool.
// true for FIFO pool, false for LIFO pool.
// Note that FIFO has slightly higher overhead compared to LIFO,
// but it helps closing idle connections faster reducing the pool size.
PoolFIFO bool
// Base number of socket connections.
// Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS.
// If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize,
// you can limit it through MaxActiveConns
PoolSize int
// Amount of time client waits for connection if all connections
// are busy before returning an error.
// Default is ReadTimeout + 1 second.
PoolTimeout time.Duration
// Minimum number of idle connections which is useful when establishing
// new connection is slow.
// Default is 0. the idle connections are not closed by default.
MinIdleConns int
// Maximum number of idle connections.
// Default is 0. the idle connections are not closed by default.
MaxIdleConns int
// Maximum number of connections allocated by the pool at a given time.
// When zero, there is no limit on the number of connections in the pool.
MaxActiveConns int
// ConnMaxIdleTime is the maximum amount of time a connection may be idle.
// Should be less than server's timeout.
//
// Expired connections may be closed lazily before reuse.
// If d <= 0, connections are not closed due to a connection's idle time.
//
// Default is 30 minutes. -1 disables idle timeout check.
ConnMaxIdleTime time.Duration
// ConnMaxLifetime is the maximum amount of time a connection may be reused.
//
// Expired connections may be closed lazily before reuse.
// If <= 0, connections are not closed due to a connection's age.
//
// Default is to not close idle connections.
ConnMaxLifetime time.Duration
EnableTLS bool
// Disable set-lib on connect. Default is false.
DisableIndentity bool
// Add suffix to client name. Default is empty.
IdentitySuffix string
CircuitBreaker *CBConfig
}
type Connector struct {
*redis.Client
Config Config
}
func NewConnector(config Config) *Connector {
if !config.Enabled {
return nil
}
return &Connector{
Client: redis.NewClient(optionsFromConfig(config)),
Config: config,
}
}
func optionsFromConfig(c Config) *redis.Options {
opts := &redis.Options{
Network: c.Network,
Addr: c.Addr,
ClientName: c.ClientName,
Protocol: 3,
Username: c.Username,
Password: c.Password,
MaxRetries: c.MaxRetries,
MinRetryBackoff: c.MinRetryBackoff,
MaxRetryBackoff: c.MaxRetryBackoff,
DialTimeout: c.DialTimeout,
ReadTimeout: c.ReadTimeout,
WriteTimeout: c.WriteTimeout,
ContextTimeoutEnabled: true,
PoolFIFO: c.PoolFIFO,
PoolTimeout: c.PoolTimeout,
MinIdleConns: c.MinIdleConns,
MaxIdleConns: c.MaxIdleConns,
MaxActiveConns: c.MaxActiveConns,
ConnMaxIdleTime: c.ConnMaxIdleTime,
ConnMaxLifetime: c.ConnMaxLifetime,
DisableIndentity: c.DisableIndentity,
IdentitySuffix: c.IdentitySuffix,
Limiter: newLimiter(c.CircuitBreaker, c.MaxActiveConns),
}
if c.EnableTLS {
opts.TLSConfig = new(tls.Config)
}
return opts
}

View File

@@ -0,0 +1,29 @@
local result = redis.call("GET", KEYS[1])
if result == false then
return nil
end
local object_id = tostring(result)
local object = getCall("HGET", object_id, "object")
if object == nil then
-- object expired, but there are keys that need to be cleaned up
remove(object_id)
return nil
end
-- max-age must be checked manually
local expiry = getCall("HGET", object_id, "expiry")
if not (expiry == nil) and tonumber(expiry) > 0 then
if getTime() > tonumber(expiry) then
remove(object_id)
return nil
end
end
local usage_lifetime = getCall("HGET", object_id, "usage_lifetime")
-- reset usage based TTL
if not (usage_lifetime == nil) and tonumber(usage_lifetime) > 0 then
redis.call('EXPIRE', object_id, usage_lifetime)
end
return object

View File

@@ -0,0 +1,9 @@
local n = #KEYS
for i = 1, n do
local result = redis.call("GET", KEYS[i])
if result == false then
return nil
end
local object_id = tostring(result)
remove(object_id)
end

View File

@@ -0,0 +1,172 @@
package redis
import (
"context"
_ "embed"
"encoding/json"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var (
//go:embed _select.lua
selectComponent string
//go:embed _util.lua
utilComponent string
//go:embed _remove.lua
removeComponent string
//go:embed set.lua
setScript string
//go:embed get.lua
getScript string
//go:embed invalidate.lua
invalidateScript string
// Don't mind the creative "import"
setParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, setScript}, "\n"))
getParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, removeComponent, getScript}, "\n"))
invalidateParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, removeComponent, invalidateScript}, "\n"))
)
type redisCache[I, K comparable, V cache.Entry[I, K]] struct {
db int
config *cache.Config
indices []I
connector *Connector
logger *slog.Logger
}
// NewCache returns a cache that stores and retrieves object using single Redis.
func NewCache[I, K comparable, V cache.Entry[I, K]](config cache.Config, client *Connector, db int, indices []I) cache.Cache[I, K, V] {
return &redisCache[I, K, V]{
config: &config,
db: db,
indices: indices,
connector: client,
logger: config.Log.Slog(),
}
}
func (c *redisCache[I, K, V]) Set(ctx context.Context, value V) {
if _, err := c.set(ctx, value); err != nil {
c.logger.ErrorContext(ctx, "redis cache set", "err", err)
}
}
func (c *redisCache[I, K, V]) set(ctx context.Context, value V) (objectID string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
// Internal ID used for the object
objectID = uuid.NewString()
keys := []string{objectID}
// flatten the secondary keys
for _, index := range c.indices {
keys = append(keys, c.redisIndexKeys(index, value.Keys(index)...)...)
}
var buf strings.Builder
err = json.NewEncoder(&buf).Encode(value)
if err != nil {
return "", err
}
err = setParsed.Run(ctx, c.connector, keys,
c.db, // DB namespace
buf.String(), // object
int64(c.config.LastUseAge/time.Second), // usage_lifetime
int64(c.config.MaxAge/time.Second), // max_age,
).Err()
// redis.Nil is always returned because the script doesn't have a return value.
if err != nil && !errors.Is(err, redis.Nil) {
return "", err
}
return objectID, nil
}
func (c *redisCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
var (
obj any
err error
)
ctx, span := tracing.NewSpan(ctx)
defer func() {
if errors.Is(err, redis.Nil) {
err = nil
}
span.EndWithError(err)
}()
logger := c.logger.With("index", index, "key", key)
obj, err = getParsed.Run(ctx, c.connector, c.redisIndexKeys(index, key), c.db).Result()
if err != nil && !errors.Is(err, redis.Nil) {
logger.ErrorContext(ctx, "redis cache get", "err", err)
return value, false
}
data, ok := obj.(string)
if !ok {
logger.With("err", err).InfoContext(ctx, "redis cache miss")
return value, false
}
err = json.NewDecoder(strings.NewReader(data)).Decode(&value)
if err != nil {
logger.ErrorContext(ctx, "redis cache get", "err", fmt.Errorf("decode: %w", err))
return value, false
}
return value, true
}
func (c *redisCache[I, K, V]) Invalidate(ctx context.Context, index I, key ...K) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if len(key) == 0 {
return nil
}
err = invalidateParsed.Run(ctx, c.connector, c.redisIndexKeys(index, key...), c.db).Err()
// redis.Nil is always returned because the script doesn't have a return value.
if err != nil && !errors.Is(err, redis.Nil) {
return err
}
return nil
}
func (c *redisCache[I, K, V]) Delete(ctx context.Context, index I, key ...K) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if len(key) == 0 {
return nil
}
pipe := c.connector.Pipeline()
pipe.Select(ctx, c.db)
pipe.Del(ctx, c.redisIndexKeys(index, key...)...)
_, err = pipe.Exec(ctx)
return err
}
func (c *redisCache[I, K, V]) Truncate(ctx context.Context) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
pipe := c.connector.Pipeline()
pipe.Select(ctx, c.db)
pipe.FlushDB(ctx)
_, err = pipe.Exec(ctx)
return err
}
func (c *redisCache[I, K, V]) redisIndexKeys(index I, keys ...K) []string {
out := make([]string, len(keys))
for i, k := range keys {
out[i] = fmt.Sprintf("%v:%v", index, k)
}
return out
}

View File

@@ -0,0 +1,721 @@
package redis
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/cache"
)
type testIndex int
const (
testIndexID testIndex = iota
testIndexName
)
const (
testDB = 99
)
var testIndices = []testIndex{
testIndexID,
testIndexName,
}
type testObject struct {
ID string
Name []string
}
func (o *testObject) Keys(index testIndex) []string {
switch index {
case testIndexID:
return []string{o.ID}
case testIndexName:
return o.Name
default:
return nil
}
}
func Test_redisCache_set(t *testing.T) {
type args struct {
ctx context.Context
value *testObject
}
tests := []struct {
name string
config cache.Config
args args
assertions func(t *testing.T, s *miniredis.Miniredis, objectID string)
wantErr error
}{
{
name: "ok",
config: cache.Config{},
args: args{
ctx: context.Background(),
value: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
},
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
s.CheckGet(t, "0:one", objectID)
s.CheckGet(t, "1:foo", objectID)
s.CheckGet(t, "1:bar", objectID)
assert.Empty(t, s.HGet(objectID, "expiry"))
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
},
},
{
name: "with last use TTL",
config: cache.Config{
LastUseAge: time.Second,
},
args: args{
ctx: context.Background(),
value: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
},
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
s.CheckGet(t, "0:one", objectID)
s.CheckGet(t, "1:foo", objectID)
s.CheckGet(t, "1:bar", objectID)
assert.Empty(t, s.HGet(objectID, "expiry"))
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
assert.Positive(t, s.TTL(objectID))
s.FastForward(2 * time.Second)
v, err := s.Get(objectID)
require.Error(t, err)
assert.Empty(t, v)
},
},
{
name: "with last use TTL and max age",
config: cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
args: args{
ctx: context.Background(),
value: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
},
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
s.CheckGet(t, "0:one", objectID)
s.CheckGet(t, "1:foo", objectID)
s.CheckGet(t, "1:bar", objectID)
assert.NotEmpty(t, s.HGet(objectID, "expiry"))
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
assert.Positive(t, s.TTL(objectID))
s.FastForward(2 * time.Second)
v, err := s.Get(objectID)
require.Error(t, err)
assert.Empty(t, v)
},
},
{
name: "with max age TTL",
config: cache.Config{
MaxAge: time.Minute,
},
args: args{
ctx: context.Background(),
value: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
},
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
s.CheckGet(t, "0:one", objectID)
s.CheckGet(t, "1:foo", objectID)
s.CheckGet(t, "1:bar", objectID)
assert.Empty(t, s.HGet(objectID, "expiry"))
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
assert.Positive(t, s.TTL(objectID))
s.FastForward(2 * time.Minute)
v, err := s.Get(objectID)
require.Error(t, err)
assert.Empty(t, v)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, server := prepareCache(t, tt.config)
rc := c.(*redisCache[testIndex, string, *testObject])
objectID, err := rc.set(tt.args.ctx, tt.args.value)
require.ErrorIs(t, err, tt.wantErr)
t.Log(rc.connector.HGetAll(context.Background(), objectID))
tt.assertions(t, server, objectID)
})
}
}
func Test_redisCache_Get(t *testing.T) {
type args struct {
ctx context.Context
index testIndex
key string
}
tests := []struct {
name string
config cache.Config
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
args args
want *testObject
wantOK bool
}{
{
name: "connection error",
config: cache.Config{},
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
s.RequireAuth("foobar")
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: "foo",
},
wantOK: false,
},
{
name: "get by ID",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: "one",
},
want: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
wantOK: true,
},
{
name: "get by name",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: "foo",
},
want: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
wantOK: true,
},
{
name: "usage timeout",
config: cache.Config{
LastUseAge: time.Minute,
},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
_, ok := c.Get(context.Background(), testIndexID, "one")
require.True(t, ok)
s.FastForward(2 * time.Minute)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: "foo",
},
want: nil,
wantOK: false,
},
{
name: "max age timeout",
config: cache.Config{
MaxAge: time.Minute,
},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
_, ok := c.Get(context.Background(), testIndexID, "one")
require.True(t, ok)
s.FastForward(2 * time.Minute)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: "foo",
},
want: nil,
wantOK: false,
},
{
name: "not found",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: "spanac",
},
wantOK: false,
},
{
name: "json decode error",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
objectID, err := s.Get(c.(*redisCache[testIndex, string, *testObject]).redisIndexKeys(testIndexID, "one")[0])
require.NoError(t, err)
s.HSet(objectID, "object", "~~~")
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: "one",
},
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, server := prepareCache(t, tt.config)
tt.preparation(t, c, server)
t.Log(server.Keys())
got, ok := c.Get(tt.args.ctx, tt.args.index, tt.args.key)
require.Equal(t, tt.wantOK, ok)
assert.Equal(t, tt.want, got)
})
}
}
func Test_redisCache_Invalidate(t *testing.T) {
type args struct {
ctx context.Context
index testIndex
key []string
}
tests := []struct {
name string
config cache.Config
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
args args
wantErr bool
}{
{
name: "connection error",
config: cache.Config{},
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
s.RequireAuth("foobar")
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"foo"},
},
wantErr: true,
},
{
name: "no keys, noop",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: []string{},
},
wantErr: false,
},
{
name: "invalidate by ID",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.False(t, ok)
assert.Nil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.False(t, ok)
assert.Nil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: []string{"one"},
},
wantErr: false,
},
{
name: "invalidate by name",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.False(t, ok)
assert.Nil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.False(t, ok)
assert.Nil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"foo"},
},
wantErr: false,
},
{
name: "invalidate after timeout",
config: cache.Config{
LastUseAge: time.Minute,
},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
_, ok := c.Get(context.Background(), testIndexID, "one")
require.True(t, ok)
s.FastForward(2 * time.Minute)
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.False(t, ok)
assert.Nil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.False(t, ok)
assert.Nil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"foo"},
},
wantErr: false,
},
{
name: "not found",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.True(t, ok)
assert.NotNil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.True(t, ok)
assert.NotNil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"spanac"},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, server := prepareCache(t, tt.config)
tt.preparation(t, c, server)
t.Log(server.Keys())
err := c.Invalidate(tt.args.ctx, tt.args.index, tt.args.key...)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
func Test_redisCache_Delete(t *testing.T) {
type args struct {
ctx context.Context
index testIndex
key []string
}
tests := []struct {
name string
config cache.Config
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
args args
wantErr bool
}{
{
name: "connection error",
config: cache.Config{},
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
s.RequireAuth("foobar")
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"foo"},
},
wantErr: true,
},
{
name: "no keys, noop",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: []string{},
},
wantErr: false,
},
{
name: "delete ID",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.False(t, ok)
assert.Nil(t, obj)
// Get be name should still work
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.True(t, ok)
assert.NotNil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: []string{"one"},
},
wantErr: false,
},
{
name: "delete name",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
// get by ID should still work
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.True(t, ok)
assert.NotNil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.False(t, ok)
assert.Nil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"foo"},
},
wantErr: false,
},
{
name: "not found",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.True(t, ok)
assert.NotNil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.True(t, ok)
assert.NotNil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"spanac"},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, server := prepareCache(t, tt.config)
tt.preparation(t, c, server)
t.Log(server.Keys())
err := c.Delete(tt.args.ctx, tt.args.index, tt.args.key...)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
func Test_redisCache_Truncate(t *testing.T) {
type args struct {
ctx context.Context
}
tests := []struct {
name string
config cache.Config
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
args args
wantErr bool
}{
{
name: "connection error",
config: cache.Config{},
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
s.RequireAuth("foobar")
},
args: args{
ctx: context.Background(),
},
wantErr: true,
},
{
name: "ok",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
c.Set(context.Background(), &testObject{
ID: "two",
Name: []string{"Hello", "World"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.False(t, ok)
assert.Nil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "World")
assert.False(t, ok)
assert.Nil(t, obj)
},
args: args{
ctx: context.Background(),
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, server := prepareCache(t, tt.config)
tt.preparation(t, c, server)
t.Log(server.Keys())
err := c.Truncate(tt.args.ctx)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
func prepareCache(t *testing.T, conf cache.Config, options ...func(*Config)) (cache.Cache[testIndex, string, *testObject], *miniredis.Miniredis) {
conf.Log = &logging.Config{
Level: "debug",
AddSource: true,
}
server := miniredis.RunT(t)
server.Select(testDB)
connConfig := Config{
Enabled: true,
Network: "tcp",
Addr: server.Addr(),
DisableIndentity: true,
}
for _, option := range options {
option(&connConfig)
}
connector := NewConnector(connConfig)
t.Cleanup(func() {
connector.Close()
server.Close()
})
c := NewCache[testIndex, string, *testObject](conf, connector, testDB, testIndices)
return c, server
}
func withCircuitBreakerOption(cb *CBConfig) func(*Config) {
return func(c *Config) {
c.CircuitBreaker = cb
}
}

View File

@@ -0,0 +1,27 @@
-- KEYS: [1]: object_id; [>1]: index keys.
local object_id = KEYS[1]
local object = ARGV[2]
local usage_lifetime = tonumber(ARGV[3]) -- usage based lifetime in seconds
local max_age = tonumber(ARGV[4]) -- max age liftime in seconds
redis.call("HSET", object_id,"object", object)
if usage_lifetime > 0 then
redis.call("HSET", object_id, "usage_lifetime", usage_lifetime)
-- enable usage based TTL
redis.call("EXPIRE", object_id, usage_lifetime)
if max_age > 0 then
-- set max_age to hash map for expired remove on Get
local expiry = getTime() + max_age
redis.call("HSET", object_id, "expiry", expiry)
end
elseif max_age > 0 then
-- enable max_age based TTL
redis.call("EXPIRE", object_id, max_age)
end
local n = #KEYS
local setKey = keySetKey(object_id)
for i = 2, n do -- offset to the second element to skip object_id
redis.call("SADD", setKey, KEYS[i]) -- set of all keys used for housekeeping
redis.call("SET", KEYS[i], object_id) -- key to object_id mapping
end

View File

@@ -0,0 +1,98 @@
// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT.
package cache
import (
"fmt"
"strings"
)
const _ConnectorName = "memorypostgresredis"
var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19}
const _ConnectorLowerName = "memorypostgresredis"
func (i Connector) String() string {
if i < 0 || i >= Connector(len(_ConnectorIndex)-1) {
return fmt.Sprintf("Connector(%d)", i)
}
return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _ConnectorNoOp() {
var x [1]struct{}
_ = x[ConnectorUnspecified-(0)]
_ = x[ConnectorMemory-(1)]
_ = x[ConnectorPostgres-(2)]
_ = x[ConnectorRedis-(3)]
}
var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis}
var _ConnectorNameToValueMap = map[string]Connector{
_ConnectorName[0:0]: ConnectorUnspecified,
_ConnectorLowerName[0:0]: ConnectorUnspecified,
_ConnectorName[0:6]: ConnectorMemory,
_ConnectorLowerName[0:6]: ConnectorMemory,
_ConnectorName[6:14]: ConnectorPostgres,
_ConnectorLowerName[6:14]: ConnectorPostgres,
_ConnectorName[14:19]: ConnectorRedis,
_ConnectorLowerName[14:19]: ConnectorRedis,
}
var _ConnectorNames = []string{
_ConnectorName[0:0],
_ConnectorName[0:6],
_ConnectorName[6:14],
_ConnectorName[14:19],
}
// ConnectorString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func ConnectorString(s string) (Connector, error) {
if val, ok := _ConnectorNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Connector values", s)
}
// ConnectorValues returns all values of the enum
func ConnectorValues() []Connector {
return _ConnectorValues
}
// ConnectorStrings returns a slice of all String values of the enum
func ConnectorStrings() []string {
strs := make([]string, len(_ConnectorNames))
copy(strs, _ConnectorNames)
return strs
}
// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Connector) IsAConnector() bool {
for _, v := range _ConnectorValues {
if i == v {
return true
}
}
return false
}
// MarshalText implements the encoding.TextMarshaler interface for Connector
func (i Connector) MarshalText() ([]byte, error) {
return []byte(i.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector
func (i *Connector) UnmarshalText(text []byte) error {
var err error
*i, err = ConnectorString(string(text))
return err
}

29
apps/api/internal/cache/error.go vendored Normal file
View File

@@ -0,0 +1,29 @@
package cache
import (
"errors"
"fmt"
)
type IndexUnknownError[I comparable] struct {
index I
}
func NewIndexUnknownErr[I comparable](index I) error {
return IndexUnknownError[I]{index}
}
func (i IndexUnknownError[I]) Error() string {
return fmt.Sprintf("index %v unknown", i.index)
}
func (a IndexUnknownError[I]) Is(err error) bool {
if b, ok := err.(IndexUnknownError[I]); ok {
return a.index == b.index
}
return false
}
var (
ErrCacheMiss = errors.New("cache miss")
)

76
apps/api/internal/cache/pruner.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package cache
import (
"context"
"math/rand"
"time"
"github.com/jonboulle/clockwork"
"github.com/zitadel/logging"
)
// Pruner is an optional [Cache] interface.
type Pruner interface {
// Prune deletes all invalidated or expired objects.
Prune(ctx context.Context) error
}
type PrunerCache[I, K comparable, V Entry[I, K]] interface {
Cache[I, K, V]
Pruner
}
type AutoPruneConfig struct {
// Interval at which the cache is automatically pruned.
// 0 or lower disables automatic pruning.
Interval time.Duration
// Timeout for an automatic prune.
// It is recommended to keep the value shorter than AutoPruneInterval
// 0 or lower disables automatic pruning.
Timeout time.Duration
}
func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) {
return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock())
}
func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) {
if c.Interval <= 0 {
return func() {}
}
background, cancel := context.WithCancel(background)
// randomize the first interval
timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval))))
go c.pruneTimer(background, pruner, purpose, timer)
return cancel
}
func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) {
defer func() {
if !timer.Stop() {
<-timer.Chan()
}
}()
for {
select {
case <-background.Done():
return
case <-timer.Chan():
err := c.doPrune(background, pruner)
logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune")
timer.Reset(c.Interval)
}
}
}
func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error {
ctx, cancel := context.WithCancel(background)
defer cancel()
if c.Timeout > 0 {
ctx, cancel = context.WithTimeout(background, c.Timeout)
defer cancel()
}
return pruner.Prune(ctx)
}

43
apps/api/internal/cache/pruner_test.go vendored Normal file
View File

@@ -0,0 +1,43 @@
package cache
import (
"context"
"testing"
"time"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
)
type testPruner struct {
called chan struct{}
}
func (p *testPruner) Prune(context.Context) error {
p.called <- struct{}{}
return nil
}
func TestAutoPruneConfig_startAutoPrune(t *testing.T) {
c := AutoPruneConfig{
Interval: time.Second,
Timeout: time.Millisecond,
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
pruner := testPruner{
called: make(chan struct{}),
}
clock := clockwork.NewFakeClock()
close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock)
defer close()
clock.Advance(time.Second)
select {
case _, ok := <-pruner.called:
assert.True(t, ok)
case <-ctx.Done():
t.Fatal(ctx.Err())
}
}

View File

@@ -0,0 +1,94 @@
// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT.
package cache
import (
"fmt"
"strings"
)
const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callbackfederated_logout"
var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65, 81}
const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callbackfederated_logout"
func (i Purpose) String() string {
if i < 0 || i >= Purpose(len(_PurposeIndex)-1) {
return fmt.Sprintf("Purpose(%d)", i)
}
return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _PurposeNoOp() {
var x [1]struct{}
_ = x[PurposeUnspecified-(0)]
_ = x[PurposeAuthzInstance-(1)]
_ = x[PurposeMilestones-(2)]
_ = x[PurposeOrganization-(3)]
_ = x[PurposeIdPFormCallback-(4)]
_ = x[PurposeFederatedLogout-(5)]
}
var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback, PurposeFederatedLogout}
var _PurposeNameToValueMap = map[string]Purpose{
_PurposeName[0:11]: PurposeUnspecified,
_PurposeLowerName[0:11]: PurposeUnspecified,
_PurposeName[11:25]: PurposeAuthzInstance,
_PurposeLowerName[11:25]: PurposeAuthzInstance,
_PurposeName[25:35]: PurposeMilestones,
_PurposeLowerName[25:35]: PurposeMilestones,
_PurposeName[35:47]: PurposeOrganization,
_PurposeLowerName[35:47]: PurposeOrganization,
_PurposeName[47:65]: PurposeIdPFormCallback,
_PurposeLowerName[47:65]: PurposeIdPFormCallback,
_PurposeName[65:81]: PurposeFederatedLogout,
_PurposeLowerName[65:81]: PurposeFederatedLogout,
}
var _PurposeNames = []string{
_PurposeName[0:11],
_PurposeName[11:25],
_PurposeName[25:35],
_PurposeName[35:47],
_PurposeName[47:65],
_PurposeName[65:81],
}
// PurposeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func PurposeString(s string) (Purpose, error) {
if val, ok := _PurposeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Purpose values", s)
}
// PurposeValues returns all values of the enum
func PurposeValues() []Purpose {
return _PurposeValues
}
// PurposeStrings returns a slice of all String values of the enum
func PurposeStrings() []string {
strs := make([]string, len(_PurposeNames))
copy(strs, _PurposeNames)
return strs
}
// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Purpose) IsAPurpose() bool {
for _, v := range _PurposeValues {
if i == v {
return true
}
}
return false
}