mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:27:42 +00:00
feat(cache): redis cache (#8822)
# Which Problems Are Solved Add a cache implementation using Redis single mode. This does not add support for Redis Cluster or sentinel. # How the Problems Are Solved Added the `internal/cache/redis` package. All operations occur atomically, including setting of secondary indexes, using LUA scripts where needed. The [`miniredis`](https://github.com/alicebob/miniredis) package is used to run unit tests. # Additional Changes - Move connector code to `internal/cache/connector/...` and remove duplicate code from `query` and `command` packages. - Fix a missed invalidation on the restrictions projection # Additional Context Closes #8130
This commit is contained in:
69
internal/cache/connector/connector.go
vendored
Normal file
69
internal/cache/connector/connector.go
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages.
|
||||
package connector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
"github.com/zitadel/zitadel/internal/cache/connector/gomap"
|
||||
"github.com/zitadel/zitadel/internal/cache/connector/noop"
|
||||
"github.com/zitadel/zitadel/internal/cache/connector/pg"
|
||||
"github.com/zitadel/zitadel/internal/cache/connector/redis"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
)
|
||||
|
||||
type CachesConfig struct {
|
||||
Connectors struct {
|
||||
Memory gomap.Config
|
||||
Postgres pg.Config
|
||||
Redis redis.Config
|
||||
}
|
||||
Instance *cache.Config
|
||||
Milestones *cache.Config
|
||||
}
|
||||
|
||||
type Connectors struct {
|
||||
Config CachesConfig
|
||||
Memory *gomap.Connector
|
||||
Postgres *pg.Connector
|
||||
Redis *redis.Connector
|
||||
}
|
||||
|
||||
func StartConnectors(conf *CachesConfig, client *database.DB) (Connectors, error) {
|
||||
if conf == nil {
|
||||
return Connectors{}, nil
|
||||
}
|
||||
return Connectors{
|
||||
Config: *conf,
|
||||
Memory: gomap.NewConnector(conf.Connectors.Memory),
|
||||
Postgres: pg.NewConnector(conf.Connectors.Postgres, client),
|
||||
Redis: redis.NewConnector(conf.Connectors.Redis),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) {
|
||||
if conf == nil || conf.Connector == cache.ConnectorUnspecified {
|
||||
return noop.NewCache[I, K, V](), nil
|
||||
}
|
||||
if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil {
|
||||
c := gomap.NewCache[I, K, V](background, indices, *conf)
|
||||
connectors.Memory.Config.StartAutoPrune(background, c, purpose)
|
||||
return c, nil
|
||||
}
|
||||
if conf.Connector == cache.ConnectorPostgres && connectors.Postgres != nil {
|
||||
c, err := pg.NewCache[I, K, V](background, purpose, *conf, indices, connectors.Postgres)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start cache: %w", err)
|
||||
}
|
||||
connectors.Postgres.Config.AutoPrune.StartAutoPrune(background, c, purpose)
|
||||
return c, nil
|
||||
}
|
||||
if conf.Connector == cache.ConnectorRedis && connectors.Redis != nil {
|
||||
db := connectors.Redis.Config.DBOffset + int(purpose)
|
||||
c := redis.NewCache[I, K, V](*conf, connectors.Redis, db, indices)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
|
||||
}
|
23
internal/cache/connector/gomap/connector.go
vendored
Normal file
23
internal/cache/connector/gomap/connector.go
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
AutoPrune cache.AutoPruneConfig
|
||||
}
|
||||
|
||||
type Connector struct {
|
||||
Config cache.AutoPruneConfig
|
||||
}
|
||||
|
||||
func NewConnector(config Config) *Connector {
|
||||
if !config.Enabled {
|
||||
return nil
|
||||
}
|
||||
return &Connector{
|
||||
Config: config.AutoPrune,
|
||||
}
|
||||
}
|
200
internal/cache/connector/gomap/gomap.go
vendored
Normal file
200
internal/cache/connector/gomap/gomap.go
vendored
Normal file
@@ -0,0 +1,200 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
)
|
||||
|
||||
type mapCache[I, K comparable, V cache.Entry[I, K]] struct {
|
||||
config *cache.Config
|
||||
indexMap map[I]*index[K, V]
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewCache returns an in-memory Cache implementation based on the builtin go map type.
|
||||
// Object values are stored as-is and there is no encoding or decoding involved.
|
||||
func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] {
|
||||
m := &mapCache[I, K, V]{
|
||||
config: &config,
|
||||
indexMap: make(map[I]*index[K, V], len(indices)),
|
||||
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: true,
|
||||
Level: slog.LevelError,
|
||||
})),
|
||||
}
|
||||
if config.Log != nil {
|
||||
m.logger = config.Log.Slog()
|
||||
}
|
||||
m.logger.InfoContext(background, "map cache logging enabled")
|
||||
|
||||
for _, name := range indices {
|
||||
m.indexMap[name] = &index[K, V]{
|
||||
config: m.config,
|
||||
entries: make(map[K]*entry[V]),
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
entry, err := i.Get(key)
|
||||
if err == nil {
|
||||
c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key)
|
||||
return entry.value, true
|
||||
}
|
||||
if errors.Is(err, cache.ErrCacheMiss) {
|
||||
c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) {
|
||||
now := time.Now()
|
||||
entry := &entry[V]{
|
||||
value: value,
|
||||
created: now,
|
||||
}
|
||||
entry.lastUse.Store(now.UnixMicro())
|
||||
|
||||
for name, i := range c.indexMap {
|
||||
keys := value.Keys(name)
|
||||
i.Set(keys, entry)
|
||||
c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
return cache.NewIndexUnknownErr(index)
|
||||
}
|
||||
i.Invalidate(keys)
|
||||
c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
return cache.NewIndexUnknownErr(index)
|
||||
}
|
||||
i.Delete(keys)
|
||||
c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
|
||||
for name, index := range c.indexMap {
|
||||
index.Prune()
|
||||
c.logger.DebugContext(ctx, "map cache prune", "index", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
|
||||
for name, index := range c.indexMap {
|
||||
index.Truncate()
|
||||
c.logger.DebugContext(ctx, "map cache truncate", "index", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type index[K comparable, V any] struct {
|
||||
mutex sync.RWMutex
|
||||
config *cache.Config
|
||||
entries map[K]*entry[V]
|
||||
}
|
||||
|
||||
func (i *index[K, V]) Get(key K) (*entry[V], error) {
|
||||
i.mutex.RLock()
|
||||
entry, ok := i.entries[key]
|
||||
i.mutex.RUnlock()
|
||||
if ok && entry.isValid(i.config) {
|
||||
return entry, nil
|
||||
}
|
||||
return nil, cache.ErrCacheMiss
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Set(keys []K, entry *entry[V]) {
|
||||
c.mutex.Lock()
|
||||
for _, key := range keys {
|
||||
c.entries[key] = entry
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (i *index[K, V]) Invalidate(keys []K) {
|
||||
i.mutex.RLock()
|
||||
for _, key := range keys {
|
||||
if entry, ok := i.entries[key]; ok {
|
||||
entry.invalid.Store(true)
|
||||
}
|
||||
}
|
||||
i.mutex.RUnlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Delete(keys []K) {
|
||||
c.mutex.Lock()
|
||||
for _, key := range keys {
|
||||
delete(c.entries, key)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Prune() {
|
||||
c.mutex.Lock()
|
||||
maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool {
|
||||
return !entry.isValid(c.config)
|
||||
})
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Truncate() {
|
||||
c.mutex.Lock()
|
||||
c.entries = make(map[K]*entry[V])
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
type entry[V any] struct {
|
||||
value V
|
||||
created time.Time
|
||||
invalid atomic.Bool
|
||||
lastUse atomic.Int64 // UnixMicro time
|
||||
}
|
||||
|
||||
func (e *entry[V]) isValid(c *cache.Config) bool {
|
||||
if e.invalid.Load() {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
if c.MaxAge > 0 {
|
||||
if e.created.Add(c.MaxAge).Before(now) {
|
||||
e.invalid.Store(true)
|
||||
return false
|
||||
}
|
||||
}
|
||||
if c.LastUseAge > 0 {
|
||||
lastUse := e.lastUse.Load()
|
||||
if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) {
|
||||
e.invalid.Store(true)
|
||||
return false
|
||||
}
|
||||
e.lastUse.CompareAndSwap(lastUse, now.UnixMicro())
|
||||
}
|
||||
return true
|
||||
}
|
329
internal/cache/connector/gomap/gomap_test.go
vendored
Normal file
329
internal/cache/connector/gomap/gomap_test.go
vendored
Normal file
@@ -0,0 +1,329 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
)
|
||||
|
||||
type testIndex int
|
||||
|
||||
const (
|
||||
testIndexID testIndex = iota
|
||||
testIndexName
|
||||
)
|
||||
|
||||
var testIndices = []testIndex{
|
||||
testIndexID,
|
||||
testIndexName,
|
||||
}
|
||||
|
||||
type testObject struct {
|
||||
id string
|
||||
names []string
|
||||
}
|
||||
|
||||
func (o *testObject) Keys(index testIndex) []string {
|
||||
switch index {
|
||||
case testIndexID:
|
||||
return []string{o.id}
|
||||
case testIndexName:
|
||||
return o.names
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mapCache_Get(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
|
||||
type args struct {
|
||||
index testIndex
|
||||
key string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *testObject
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "id",
|
||||
},
|
||||
want: obj,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "miss",
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "spanac",
|
||||
},
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "unknown index",
|
||||
args: args{
|
||||
index: 99,
|
||||
key: "id",
|
||||
},
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
|
||||
assert.Equal(t, tt.want, got)
|
||||
assert.Equal(t, tt.wantOk, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mapCache_Invalidate(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
err := c.Invalidate(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
got, ok := c.Get(context.Background(), testIndexID, "id")
|
||||
assert.Nil(t, got)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Delete(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
err := c.Delete(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Shouldn't find object by deleted name
|
||||
got, ok := c.Get(context.Background(), testIndexName, "bar")
|
||||
assert.Nil(t, got)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Should find object by other name
|
||||
got, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.Equal(t, obj, got)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Should find object by id
|
||||
got, ok = c.Get(context.Background(), testIndexID, "id")
|
||||
assert.Equal(t, obj, got)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Prune(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
|
||||
objects := []*testObject{
|
||||
{
|
||||
id: "id1",
|
||||
names: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
id: "id2",
|
||||
names: []string{"hello"},
|
||||
},
|
||||
}
|
||||
for _, obj := range objects {
|
||||
c.Set(context.Background(), obj)
|
||||
}
|
||||
// invalidate one entry
|
||||
err := c.Invalidate(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.(cache.Pruner).Prune(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Other object should still be found
|
||||
got, ok := c.Get(context.Background(), testIndexID, "id2")
|
||||
assert.Equal(t, objects[1], got)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Truncate(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
objects := []*testObject{
|
||||
{
|
||||
id: "id1",
|
||||
names: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
id: "id2",
|
||||
names: []string{"hello"},
|
||||
},
|
||||
}
|
||||
for _, obj := range objects {
|
||||
c.Set(context.Background(), obj)
|
||||
}
|
||||
|
||||
err := c.Truncate(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
mc := c.(*mapCache[testIndex, string, *testObject])
|
||||
for _, index := range mc.indexMap {
|
||||
index.mutex.RLock()
|
||||
assert.Len(t, index.entries, 0)
|
||||
index.mutex.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_entry_isValid(t *testing.T) {
|
||||
type fields struct {
|
||||
created time.Time
|
||||
invalid bool
|
||||
lastUse time.Time
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
config *cache.Config
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "invalid",
|
||||
fields: fields{
|
||||
created: time.Now(),
|
||||
invalid: true,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "max age exceeded",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute + time.Second)),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "max age disabled",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute + time.Second)),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "last use age exceeded",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute / 2)),
|
||||
invalid: false,
|
||||
lastUse: time.Now().Add(-(time.Second * 2)),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "last use age disabled",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute / 2)),
|
||||
invalid: false,
|
||||
lastUse: time.Now().Add(-(time.Second * 2)),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
fields: fields{
|
||||
created: time.Now(),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &entry[any]{
|
||||
created: tt.fields.created,
|
||||
}
|
||||
e.invalid.Store(tt.fields.invalid)
|
||||
e.lastUse.Store(tt.fields.lastUse.UnixMicro())
|
||||
got := e.isValid(tt.config)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
21
internal/cache/connector/noop/noop.go
vendored
Normal file
21
internal/cache/connector/noop/noop.go
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package noop
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
)
|
||||
|
||||
type noop[I, K comparable, V cache.Entry[I, K]] struct{}
|
||||
|
||||
// NewCache returns a cache that does nothing
|
||||
func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] {
|
||||
return noop[I, K, V]{}
|
||||
}
|
||||
|
||||
func (noop[I, K, V]) Set(context.Context, V) {}
|
||||
func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return }
|
||||
func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
|
||||
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
|
||||
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
|
||||
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }
|
28
internal/cache/connector/pg/connector.go
vendored
Normal file
28
internal/cache/connector/pg/connector.go
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
package pg
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
AutoPrune cache.AutoPruneConfig
|
||||
}
|
||||
|
||||
type Connector struct {
|
||||
PGXPool
|
||||
Dialect string
|
||||
Config Config
|
||||
}
|
||||
|
||||
func NewConnector(config Config, client *database.DB) *Connector {
|
||||
if !config.Enabled {
|
||||
return nil
|
||||
}
|
||||
return &Connector{
|
||||
PGXPool: client.Pool,
|
||||
Dialect: client.Type(),
|
||||
Config: config,
|
||||
}
|
||||
}
|
7
internal/cache/connector/pg/create_partition.sql.tmpl
vendored
Normal file
7
internal/cache/connector/pg/create_partition.sql.tmpl
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
create unlogged table if not exists cache.objects_{{ . }}
|
||||
partition of cache.objects
|
||||
for values in ('{{ . }}');
|
||||
|
||||
create unlogged table if not exists cache.string_keys_{{ . }}
|
||||
partition of cache.string_keys
|
||||
for values in ('{{ . }}');
|
5
internal/cache/connector/pg/delete.sql
vendored
Normal file
5
internal/cache/connector/pg/delete.sql
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
delete from cache.string_keys k
|
||||
where k.cache_name = $1
|
||||
and k.index_id = $2
|
||||
and k.index_key = any($3)
|
||||
;
|
19
internal/cache/connector/pg/get.sql
vendored
Normal file
19
internal/cache/connector/pg/get.sql
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
update cache.objects
|
||||
set last_used_at = now()
|
||||
where cache_name = $1
|
||||
and (
|
||||
select object_id
|
||||
from cache.string_keys k
|
||||
where cache_name = $1
|
||||
and index_id = $2
|
||||
and index_key = $3
|
||||
) = id
|
||||
and case when $4::interval > '0s'
|
||||
then created_at > now()-$4::interval -- max age
|
||||
else true
|
||||
end
|
||||
and case when $5::interval > '0s'
|
||||
then last_used_at > now()-$5::interval -- last use
|
||||
else true
|
||||
end
|
||||
returning payload;
|
9
internal/cache/connector/pg/invalidate.sql
vendored
Normal file
9
internal/cache/connector/pg/invalidate.sql
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
delete from cache.objects o
|
||||
using cache.string_keys k
|
||||
where k.cache_name = $1
|
||||
and k.index_id = $2
|
||||
and k.index_key = any($3)
|
||||
and o.cache_name = k.cache_name
|
||||
and o.id = k.object_id
|
||||
;
|
||||
|
176
internal/cache/connector/pg/pg.go
vendored
Normal file
176
internal/cache/connector/pg/pg.go
vendored
Normal file
@@ -0,0 +1,176 @@
|
||||
package pg
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed create_partition.sql.tmpl
|
||||
createPartitionQuery string
|
||||
createPartitionTmpl = template.Must(template.New("create_partition").Parse(createPartitionQuery))
|
||||
//go:embed set.sql
|
||||
setQuery string
|
||||
//go:embed get.sql
|
||||
getQuery string
|
||||
//go:embed invalidate.sql
|
||||
invalidateQuery string
|
||||
//go:embed delete.sql
|
||||
deleteQuery string
|
||||
//go:embed prune.sql
|
||||
pruneQuery string
|
||||
//go:embed truncate.sql
|
||||
truncateQuery string
|
||||
)
|
||||
|
||||
type PGXPool interface {
|
||||
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
|
||||
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
||||
}
|
||||
|
||||
type pgCache[I ~int, K ~string, V cache.Entry[I, K]] struct {
|
||||
purpose cache.Purpose
|
||||
config *cache.Config
|
||||
indices []I
|
||||
connector *Connector
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewCache returns a cache that stores and retrieves objects using PostgreSQL unlogged tables.
|
||||
func NewCache[I ~int, K ~string, V cache.Entry[I, K]](ctx context.Context, purpose cache.Purpose, config cache.Config, indices []I, connector *Connector) (cache.PrunerCache[I, K, V], error) {
|
||||
c := &pgCache[I, K, V]{
|
||||
purpose: purpose,
|
||||
config: &config,
|
||||
indices: indices,
|
||||
connector: connector,
|
||||
logger: config.Log.Slog().With("cache_purpose", purpose),
|
||||
}
|
||||
c.logger.InfoContext(ctx, "pg cache logging enabled")
|
||||
|
||||
if connector.Dialect == "postgres" {
|
||||
if err := c.createPartition(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *pgCache[I, K, V]) createPartition(ctx context.Context) error {
|
||||
var query strings.Builder
|
||||
if err := createPartitionTmpl.Execute(&query, c.purpose.String()); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := c.connector.Exec(ctx, query.String())
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *pgCache[I, K, V]) Set(ctx context.Context, entry V) {
|
||||
//nolint:errcheck
|
||||
c.set(ctx, entry)
|
||||
}
|
||||
|
||||
func (c *pgCache[I, K, V]) set(ctx context.Context, entry V) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
keys := c.indexKeysFromEntry(entry)
|
||||
c.logger.DebugContext(ctx, "pg cache set", "index_key", keys)
|
||||
|
||||
_, err = c.connector.Exec(ctx, setQuery, c.purpose.String(), keys, entry)
|
||||
if err != nil {
|
||||
c.logger.ErrorContext(ctx, "pg cache set", "err", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *pgCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
|
||||
value, err := c.get(ctx, index, key)
|
||||
if err == nil {
|
||||
c.logger.DebugContext(ctx, "pg cache get", "index", index, "key", key)
|
||||
return value, true
|
||||
}
|
||||
logger := c.logger.With("err", err, "index", index, "key", key)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
logger.InfoContext(ctx, "pg cache miss")
|
||||
return value, false
|
||||
}
|
||||
logger.ErrorContext(ctx, "pg cache get", "err", err)
|
||||
return value, false
|
||||
}
|
||||
|
||||
func (c *pgCache[I, K, V]) get(ctx context.Context, index I, key K) (value V, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if !slices.Contains(c.indices, index) {
|
||||
return value, cache.NewIndexUnknownErr(index)
|
||||
}
|
||||
err = c.connector.QueryRow(ctx, getQuery, c.purpose.String(), index, key, c.config.MaxAge, c.config.LastUseAge).Scan(&value)
|
||||
return value, err
|
||||
}
|
||||
|
||||
func (c *pgCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
_, err = c.connector.Exec(ctx, invalidateQuery, c.purpose.String(), index, keys)
|
||||
c.logger.DebugContext(ctx, "pg cache invalidate", "index", index, "keys", keys)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *pgCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
_, err = c.connector.Exec(ctx, deleteQuery, c.purpose.String(), index, keys)
|
||||
c.logger.DebugContext(ctx, "pg cache delete", "index", index, "keys", keys)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *pgCache[I, K, V]) Prune(ctx context.Context) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
_, err = c.connector.Exec(ctx, pruneQuery, c.purpose.String(), c.config.MaxAge, c.config.LastUseAge)
|
||||
c.logger.DebugContext(ctx, "pg cache prune")
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *pgCache[I, K, V]) Truncate(ctx context.Context) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
_, err = c.connector.Exec(ctx, truncateQuery, c.purpose.String())
|
||||
c.logger.DebugContext(ctx, "pg cache truncate")
|
||||
return err
|
||||
}
|
||||
|
||||
type indexKey[I, K comparable] struct {
|
||||
IndexID I `json:"index_id"`
|
||||
IndexKey K `json:"index_key"`
|
||||
}
|
||||
|
||||
func (c *pgCache[I, K, V]) indexKeysFromEntry(entry V) []indexKey[I, K] {
|
||||
keys := make([]indexKey[I, K], 0, len(c.indices)*3) // naive assumption
|
||||
for _, index := range c.indices {
|
||||
for _, key := range entry.Keys(index) {
|
||||
keys = append(keys, indexKey[I, K]{
|
||||
IndexID: index,
|
||||
IndexKey: key,
|
||||
})
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
526
internal/cache/connector/pg/pg_test.go
vendored
Normal file
526
internal/cache/connector/pg/pg_test.go
vendored
Normal file
@@ -0,0 +1,526 @@
|
||||
package pg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/pashagolub/pgxmock/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
)
|
||||
|
||||
type testIndex int
|
||||
|
||||
const (
|
||||
testIndexID testIndex = iota
|
||||
testIndexName
|
||||
)
|
||||
|
||||
var testIndices = []testIndex{
|
||||
testIndexID,
|
||||
testIndexName,
|
||||
}
|
||||
|
||||
type testObject struct {
|
||||
ID string
|
||||
Name []string
|
||||
}
|
||||
|
||||
func (o *testObject) Keys(index testIndex) []string {
|
||||
switch index {
|
||||
case testIndexID:
|
||||
return []string{o.ID}
|
||||
case testIndexName:
|
||||
return o.Name
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCache(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expect func(pgxmock.PgxCommonIface)
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "error",
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectExec(regexp.QuoteMeta(expectedCreatePartitionQuery)).
|
||||
WillReturnError(pgx.ErrTxClosed)
|
||||
},
|
||||
wantErr: pgx.ErrTxClosed,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectExec(regexp.QuoteMeta(expectedCreatePartitionQuery)).
|
||||
WillReturnResult(pgxmock.NewResult("CREATE TABLE", 0))
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conf := cache.Config{
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
}
|
||||
pool, err := pgxmock.NewPool()
|
||||
require.NoError(t, err)
|
||||
tt.expect(pool)
|
||||
connector := &Connector{
|
||||
PGXPool: pool,
|
||||
Dialect: "postgres",
|
||||
}
|
||||
|
||||
c, err := NewCache[testIndex, string, *testObject](context.Background(), cachePurpose, conf, testIndices, connector)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
if tt.wantErr == nil {
|
||||
assert.NotNil(t, c)
|
||||
}
|
||||
|
||||
err = pool.ExpectationsWereMet()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func Test_pgCache_Set(t *testing.T) {
|
||||
queryExpect := regexp.QuoteMeta(setQuery)
|
||||
type args struct {
|
||||
entry *testObject
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
expect func(pgxmock.PgxCommonIface)
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "error",
|
||||
args: args{
|
||||
&testObject{
|
||||
ID: "id1",
|
||||
Name: []string{"foo", "bar"},
|
||||
},
|
||||
},
|
||||
expect: func(ppi pgxmock.PgxCommonIface) {
|
||||
ppi.ExpectExec(queryExpect).
|
||||
WithArgs(cachePurpose.String(),
|
||||
[]indexKey[testIndex, string]{
|
||||
{IndexID: testIndexID, IndexKey: "id1"},
|
||||
{IndexID: testIndexName, IndexKey: "foo"},
|
||||
{IndexID: testIndexName, IndexKey: "bar"},
|
||||
},
|
||||
&testObject{
|
||||
ID: "id1",
|
||||
Name: []string{"foo", "bar"},
|
||||
}).
|
||||
WillReturnError(pgx.ErrTxClosed)
|
||||
},
|
||||
wantErr: pgx.ErrTxClosed,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
&testObject{
|
||||
ID: "id1",
|
||||
Name: []string{"foo", "bar"},
|
||||
},
|
||||
},
|
||||
expect: func(ppi pgxmock.PgxCommonIface) {
|
||||
ppi.ExpectExec(queryExpect).
|
||||
WithArgs(cachePurpose.String(),
|
||||
[]indexKey[testIndex, string]{
|
||||
{IndexID: testIndexID, IndexKey: "id1"},
|
||||
{IndexID: testIndexName, IndexKey: "foo"},
|
||||
{IndexID: testIndexName, IndexKey: "bar"},
|
||||
},
|
||||
&testObject{
|
||||
ID: "id1",
|
||||
Name: []string{"foo", "bar"},
|
||||
}).
|
||||
WillReturnResult(pgxmock.NewResult("INSERT", 1))
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, pool := prepareCache(t, cache.Config{})
|
||||
defer pool.Close()
|
||||
tt.expect(pool)
|
||||
|
||||
err := c.(*pgCache[testIndex, string, *testObject]).
|
||||
set(context.Background(), tt.args.entry)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
err = pool.ExpectationsWereMet()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_pgCache_Get(t *testing.T) {
|
||||
queryExpect := regexp.QuoteMeta(getQuery)
|
||||
type args struct {
|
||||
index testIndex
|
||||
key string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
config cache.Config
|
||||
args args
|
||||
expect func(pgxmock.PgxCommonIface)
|
||||
want *testObject
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "invalid index",
|
||||
config: cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
args: args{
|
||||
index: 99,
|
||||
key: "id1",
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {},
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "no rows",
|
||||
config: cache.Config{
|
||||
MaxAge: 0,
|
||||
LastUseAge: 0,
|
||||
},
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "id1",
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectQuery(queryExpect).
|
||||
WithArgs(cachePurpose.String(), testIndexID, "id1", time.Duration(0), time.Duration(0)).
|
||||
WillReturnRows(pgxmock.NewRows([]string{"payload"}))
|
||||
},
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "error",
|
||||
config: cache.Config{
|
||||
MaxAge: 0,
|
||||
LastUseAge: 0,
|
||||
},
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "id1",
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectQuery(queryExpect).
|
||||
WithArgs(cachePurpose.String(), testIndexID, "id1", time.Duration(0), time.Duration(0)).
|
||||
WillReturnError(pgx.ErrTxClosed)
|
||||
},
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
config: cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "id1",
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectQuery(queryExpect).
|
||||
WithArgs(cachePurpose.String(), testIndexID, "id1", time.Minute, time.Second).
|
||||
WillReturnRows(
|
||||
pgxmock.NewRows([]string{"payload"}).AddRow(&testObject{
|
||||
ID: "id1",
|
||||
Name: []string{"foo", "bar"},
|
||||
}),
|
||||
)
|
||||
},
|
||||
want: &testObject{
|
||||
ID: "id1",
|
||||
Name: []string{"foo", "bar"},
|
||||
},
|
||||
wantOk: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, pool := prepareCache(t, tt.config)
|
||||
defer pool.Close()
|
||||
tt.expect(pool)
|
||||
|
||||
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
|
||||
assert.Equal(t, tt.wantOk, ok)
|
||||
assert.Equal(t, tt.want, got)
|
||||
err := pool.ExpectationsWereMet()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_pgCache_Invalidate(t *testing.T) {
|
||||
queryExpect := regexp.QuoteMeta(invalidateQuery)
|
||||
type args struct {
|
||||
index testIndex
|
||||
keys []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
config cache.Config
|
||||
args args
|
||||
expect func(pgxmock.PgxCommonIface)
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "error",
|
||||
config: cache.Config{
|
||||
MaxAge: 0,
|
||||
LastUseAge: 0,
|
||||
},
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
keys: []string{"id1", "id2"},
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectExec(queryExpect).
|
||||
WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}).
|
||||
WillReturnError(pgx.ErrTxClosed)
|
||||
},
|
||||
wantErr: pgx.ErrTxClosed,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
config: cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
keys: []string{"id1", "id2"},
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectExec(queryExpect).
|
||||
WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}).
|
||||
WillReturnResult(pgxmock.NewResult("DELETE", 1))
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, pool := prepareCache(t, tt.config)
|
||||
defer pool.Close()
|
||||
tt.expect(pool)
|
||||
|
||||
err := c.Invalidate(context.Background(), tt.args.index, tt.args.keys...)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
err = pool.ExpectationsWereMet()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_pgCache_Delete(t *testing.T) {
|
||||
queryExpect := regexp.QuoteMeta(deleteQuery)
|
||||
type args struct {
|
||||
index testIndex
|
||||
keys []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
config cache.Config
|
||||
args args
|
||||
expect func(pgxmock.PgxCommonIface)
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "error",
|
||||
config: cache.Config{
|
||||
MaxAge: 0,
|
||||
LastUseAge: 0,
|
||||
},
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
keys: []string{"id1", "id2"},
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectExec(queryExpect).
|
||||
WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}).
|
||||
WillReturnError(pgx.ErrTxClosed)
|
||||
},
|
||||
wantErr: pgx.ErrTxClosed,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
config: cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
keys: []string{"id1", "id2"},
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectExec(queryExpect).
|
||||
WithArgs(cachePurpose.String(), testIndexID, []string{"id1", "id2"}).
|
||||
WillReturnResult(pgxmock.NewResult("DELETE", 1))
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, pool := prepareCache(t, tt.config)
|
||||
defer pool.Close()
|
||||
tt.expect(pool)
|
||||
|
||||
err := c.Delete(context.Background(), tt.args.index, tt.args.keys...)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
err = pool.ExpectationsWereMet()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_pgCache_Prune(t *testing.T) {
|
||||
queryExpect := regexp.QuoteMeta(pruneQuery)
|
||||
tests := []struct {
|
||||
name string
|
||||
config cache.Config
|
||||
expect func(pgxmock.PgxCommonIface)
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "error",
|
||||
config: cache.Config{
|
||||
MaxAge: 0,
|
||||
LastUseAge: 0,
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectExec(queryExpect).
|
||||
WithArgs(cachePurpose.String(), time.Duration(0), time.Duration(0)).
|
||||
WillReturnError(pgx.ErrTxClosed)
|
||||
},
|
||||
wantErr: pgx.ErrTxClosed,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
config: cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectExec(queryExpect).
|
||||
WithArgs(cachePurpose.String(), time.Minute, time.Second).
|
||||
WillReturnResult(pgxmock.NewResult("DELETE", 1))
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, pool := prepareCache(t, tt.config)
|
||||
defer pool.Close()
|
||||
tt.expect(pool)
|
||||
|
||||
err := c.Prune(context.Background())
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
err = pool.ExpectationsWereMet()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_pgCache_Truncate(t *testing.T) {
|
||||
queryExpect := regexp.QuoteMeta(truncateQuery)
|
||||
tests := []struct {
|
||||
name string
|
||||
config cache.Config
|
||||
expect func(pgxmock.PgxCommonIface)
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "error",
|
||||
config: cache.Config{
|
||||
MaxAge: 0,
|
||||
LastUseAge: 0,
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectExec(queryExpect).
|
||||
WithArgs(cachePurpose.String()).
|
||||
WillReturnError(pgx.ErrTxClosed)
|
||||
},
|
||||
wantErr: pgx.ErrTxClosed,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
config: cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
expect: func(pci pgxmock.PgxCommonIface) {
|
||||
pci.ExpectExec(queryExpect).
|
||||
WithArgs(cachePurpose.String()).
|
||||
WillReturnResult(pgxmock.NewResult("DELETE", 1))
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, pool := prepareCache(t, tt.config)
|
||||
defer pool.Close()
|
||||
tt.expect(pool)
|
||||
|
||||
err := c.Truncate(context.Background())
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
err = pool.ExpectationsWereMet()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
cachePurpose = cache.PurposeAuthzInstance
|
||||
expectedCreatePartitionQuery = `create unlogged table if not exists cache.objects_authz_instance
|
||||
partition of cache.objects
|
||||
for values in ('authz_instance');
|
||||
|
||||
create unlogged table if not exists cache.string_keys_authz_instance
|
||||
partition of cache.string_keys
|
||||
for values in ('authz_instance');
|
||||
`
|
||||
)
|
||||
|
||||
func prepareCache(t *testing.T, conf cache.Config) (cache.PrunerCache[testIndex, string, *testObject], pgxmock.PgxPoolIface) {
|
||||
conf.Log = &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
}
|
||||
pool, err := pgxmock.NewPool()
|
||||
require.NoError(t, err)
|
||||
|
||||
pool.ExpectExec(regexp.QuoteMeta(expectedCreatePartitionQuery)).
|
||||
WillReturnResult(pgxmock.NewResult("CREATE TABLE", 0))
|
||||
connector := &Connector{
|
||||
PGXPool: pool,
|
||||
Dialect: "postgres",
|
||||
}
|
||||
c, err := NewCache[testIndex, string, *testObject](context.Background(), cachePurpose, conf, testIndices, connector)
|
||||
require.NoError(t, err)
|
||||
return c, pool
|
||||
}
|
18
internal/cache/connector/pg/prune.sql
vendored
Normal file
18
internal/cache/connector/pg/prune.sql
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
delete from cache.objects o
|
||||
where o.cache_name = $1
|
||||
and (
|
||||
case when $2::interval > '0s'
|
||||
then created_at < now()-$2::interval -- max age
|
||||
else false
|
||||
end
|
||||
or case when $3::interval > '0s'
|
||||
then last_used_at < now()-$3::interval -- last use
|
||||
else false
|
||||
end
|
||||
or o.id not in (
|
||||
select object_id
|
||||
from cache.string_keys
|
||||
where cache_name = $1
|
||||
)
|
||||
)
|
||||
;
|
19
internal/cache/connector/pg/set.sql
vendored
Normal file
19
internal/cache/connector/pg/set.sql
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
with object as (
|
||||
insert into cache.objects (cache_name, payload)
|
||||
values ($1, $3)
|
||||
returning id
|
||||
)
|
||||
insert into cache.string_keys (
|
||||
cache_name,
|
||||
index_id,
|
||||
index_key,
|
||||
object_id
|
||||
)
|
||||
select $1, keys.index_id, keys.index_key, id as object_id
|
||||
from object, jsonb_to_recordset($2) keys (
|
||||
index_id bigint,
|
||||
index_key text
|
||||
)
|
||||
on conflict (cache_name, index_id, index_key) do
|
||||
update set object_id = EXCLUDED.object_id
|
||||
;
|
3
internal/cache/connector/pg/truncate.sql
vendored
Normal file
3
internal/cache/connector/pg/truncate.sql
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
delete from cache.objects o
|
||||
where o.cache_name = $1
|
||||
;
|
10
internal/cache/connector/redis/_remove.lua
vendored
Normal file
10
internal/cache/connector/redis/_remove.lua
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
local function remove(object_id)
|
||||
local setKey = keySetKey(object_id)
|
||||
local keys = redis.call("SMEMBERS", setKey)
|
||||
local n = #keys
|
||||
for i = 1, n do
|
||||
redis.call("DEL", keys[i])
|
||||
end
|
||||
redis.call("DEL", setKey)
|
||||
redis.call("DEL", object_id)
|
||||
end
|
3
internal/cache/connector/redis/_select.lua
vendored
Normal file
3
internal/cache/connector/redis/_select.lua
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
-- SELECT ensures the DB namespace for each script.
|
||||
-- When used, it consumes the first ARGV entry.
|
||||
redis.call("SELECT", ARGV[1])
|
17
internal/cache/connector/redis/_util.lua
vendored
Normal file
17
internal/cache/connector/redis/_util.lua
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
-- keySetKey returns the redis key of the set containing all keys to the object.
|
||||
local function keySetKey (object_id)
|
||||
return object_id .. "-keys"
|
||||
end
|
||||
|
||||
local function getTime()
|
||||
return tonumber(redis.call('TIME')[1])
|
||||
end
|
||||
|
||||
-- getCall wrapts redis.call so a nil is returned instead of false.
|
||||
local function getCall (...)
|
||||
local result = redis.call(...)
|
||||
if result == false then
|
||||
return nil
|
||||
end
|
||||
return result
|
||||
end
|
154
internal/cache/connector/redis/connector.go
vendored
Normal file
154
internal/cache/connector/redis/connector.go
vendored
Normal file
@@ -0,0 +1,154 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
|
||||
// The network type, either tcp or unix.
|
||||
// Default is tcp.
|
||||
Network string
|
||||
// host:port address.
|
||||
Addr string
|
||||
// ClientName will execute the `CLIENT SETNAME ClientName` command for each conn.
|
||||
ClientName string
|
||||
// Use the specified Username to authenticate the current connection
|
||||
// with one of the connections defined in the ACL list when connecting
|
||||
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
|
||||
Username string
|
||||
// Optional password. Must match the password specified in the
|
||||
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower),
|
||||
// or the User Password when connecting to a Redis 6.0 instance, or greater,
|
||||
// that is using the Redis ACL system.
|
||||
Password string
|
||||
// Each ZITADEL cache uses an incremental DB namespace.
|
||||
// This option offsets the first DB so it doesn't conflict with other databases on the same server.
|
||||
// Note that ZITADEL uses FLUSHDB command to truncate a cache.
|
||||
// This can have destructive consequences when overlapping DB namespaces are used.
|
||||
DBOffset int
|
||||
|
||||
// Maximum number of retries before giving up.
|
||||
// Default is 3 retries; -1 (not 0) disables retries.
|
||||
MaxRetries int
|
||||
// Minimum backoff between each retry.
|
||||
// Default is 8 milliseconds; -1 disables backoff.
|
||||
MinRetryBackoff time.Duration
|
||||
// Maximum backoff between each retry.
|
||||
// Default is 512 milliseconds; -1 disables backoff.
|
||||
MaxRetryBackoff time.Duration
|
||||
|
||||
// Dial timeout for establishing new connections.
|
||||
// Default is 5 seconds.
|
||||
DialTimeout time.Duration
|
||||
// Timeout for socket reads. If reached, commands will fail
|
||||
// with a timeout instead of blocking. Supported values:
|
||||
// - `0` - default timeout (3 seconds).
|
||||
// - `-1` - no timeout (block indefinitely).
|
||||
// - `-2` - disables SetReadDeadline calls completely.
|
||||
ReadTimeout time.Duration
|
||||
// Timeout for socket writes. If reached, commands will fail
|
||||
// with a timeout instead of blocking. Supported values:
|
||||
// - `0` - default timeout (3 seconds).
|
||||
// - `-1` - no timeout (block indefinitely).
|
||||
// - `-2` - disables SetWriteDeadline calls completely.
|
||||
WriteTimeout time.Duration
|
||||
|
||||
// Type of connection pool.
|
||||
// true for FIFO pool, false for LIFO pool.
|
||||
// Note that FIFO has slightly higher overhead compared to LIFO,
|
||||
// but it helps closing idle connections faster reducing the pool size.
|
||||
PoolFIFO bool
|
||||
// Base number of socket connections.
|
||||
// Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS.
|
||||
// If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize,
|
||||
// you can limit it through MaxActiveConns
|
||||
PoolSize int
|
||||
// Amount of time client waits for connection if all connections
|
||||
// are busy before returning an error.
|
||||
// Default is ReadTimeout + 1 second.
|
||||
PoolTimeout time.Duration
|
||||
// Minimum number of idle connections which is useful when establishing
|
||||
// new connection is slow.
|
||||
// Default is 0. the idle connections are not closed by default.
|
||||
MinIdleConns int
|
||||
// Maximum number of idle connections.
|
||||
// Default is 0. the idle connections are not closed by default.
|
||||
MaxIdleConns int
|
||||
// Maximum number of connections allocated by the pool at a given time.
|
||||
// When zero, there is no limit on the number of connections in the pool.
|
||||
MaxActiveConns int
|
||||
// ConnMaxIdleTime is the maximum amount of time a connection may be idle.
|
||||
// Should be less than server's timeout.
|
||||
//
|
||||
// Expired connections may be closed lazily before reuse.
|
||||
// If d <= 0, connections are not closed due to a connection's idle time.
|
||||
//
|
||||
// Default is 30 minutes. -1 disables idle timeout check.
|
||||
ConnMaxIdleTime time.Duration
|
||||
// ConnMaxLifetime is the maximum amount of time a connection may be reused.
|
||||
//
|
||||
// Expired connections may be closed lazily before reuse.
|
||||
// If <= 0, connections are not closed due to a connection's age.
|
||||
//
|
||||
// Default is to not close idle connections.
|
||||
ConnMaxLifetime time.Duration
|
||||
|
||||
EnableTLS bool
|
||||
|
||||
// Disable set-lib on connect. Default is false.
|
||||
DisableIndentity bool
|
||||
|
||||
// Add suffix to client name. Default is empty.
|
||||
IdentitySuffix string
|
||||
}
|
||||
|
||||
type Connector struct {
|
||||
*redis.Client
|
||||
Config Config
|
||||
}
|
||||
|
||||
func NewConnector(config Config) *Connector {
|
||||
if !config.Enabled {
|
||||
return nil
|
||||
}
|
||||
return &Connector{
|
||||
Client: redis.NewClient(optionsFromConfig(config)),
|
||||
Config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func optionsFromConfig(c Config) *redis.Options {
|
||||
opts := &redis.Options{
|
||||
Network: c.Network,
|
||||
Addr: c.Addr,
|
||||
ClientName: c.ClientName,
|
||||
Protocol: 3,
|
||||
Username: c.Username,
|
||||
Password: c.Password,
|
||||
MaxRetries: c.MaxRetries,
|
||||
MinRetryBackoff: c.MinRetryBackoff,
|
||||
MaxRetryBackoff: c.MaxRetryBackoff,
|
||||
DialTimeout: c.DialTimeout,
|
||||
ReadTimeout: c.ReadTimeout,
|
||||
WriteTimeout: c.WriteTimeout,
|
||||
ContextTimeoutEnabled: true,
|
||||
PoolFIFO: c.PoolFIFO,
|
||||
PoolTimeout: c.PoolTimeout,
|
||||
MinIdleConns: c.MinIdleConns,
|
||||
MaxIdleConns: c.MaxIdleConns,
|
||||
MaxActiveConns: c.MaxActiveConns,
|
||||
ConnMaxIdleTime: c.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: c.ConnMaxLifetime,
|
||||
DisableIndentity: c.DisableIndentity,
|
||||
IdentitySuffix: c.IdentitySuffix,
|
||||
}
|
||||
if c.EnableTLS {
|
||||
opts.TLSConfig = new(tls.Config)
|
||||
}
|
||||
return opts
|
||||
}
|
29
internal/cache/connector/redis/get.lua
vendored
Normal file
29
internal/cache/connector/redis/get.lua
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
local result = redis.call("GET", KEYS[1])
|
||||
if result == false then
|
||||
return nil
|
||||
end
|
||||
local object_id = tostring(result)
|
||||
|
||||
local object = getCall("HGET", object_id, "object")
|
||||
if object == nil then
|
||||
-- object expired, but there are keys that need to be cleaned up
|
||||
remove(object_id)
|
||||
return nil
|
||||
end
|
||||
|
||||
-- max-age must be checked manually
|
||||
local expiry = getCall("HGET", object_id, "expiry")
|
||||
if not (expiry == nil) and expiry > 0 then
|
||||
if getTime() > expiry then
|
||||
remove(object_id)
|
||||
return nil
|
||||
end
|
||||
end
|
||||
|
||||
local usage_lifetime = getCall("HGET", object_id, "usage_lifetime")
|
||||
-- reset usage based TTL
|
||||
if not (usage_lifetime == nil) and tonumber(usage_lifetime) > 0 then
|
||||
redis.call('EXPIRE', object_id, usage_lifetime)
|
||||
end
|
||||
|
||||
return object
|
9
internal/cache/connector/redis/invalidate.lua
vendored
Normal file
9
internal/cache/connector/redis/invalidate.lua
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
local n = #KEYS
|
||||
for i = 1, n do
|
||||
local result = redis.call("GET", KEYS[i])
|
||||
if result == false then
|
||||
return nil
|
||||
end
|
||||
local object_id = tostring(result)
|
||||
remove(object_id)
|
||||
end
|
172
internal/cache/connector/redis/redis.go
vendored
Normal file
172
internal/cache/connector/redis/redis.go
vendored
Normal file
@@ -0,0 +1,172 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed _select.lua
|
||||
selectComponent string
|
||||
//go:embed _util.lua
|
||||
utilComponent string
|
||||
//go:embed _remove.lua
|
||||
removeComponent string
|
||||
//go:embed set.lua
|
||||
setScript string
|
||||
//go:embed get.lua
|
||||
getScript string
|
||||
//go:embed invalidate.lua
|
||||
invalidateScript string
|
||||
|
||||
// Don't mind the creative "import"
|
||||
setParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, setScript}, "\n"))
|
||||
getParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, removeComponent, getScript}, "\n"))
|
||||
invalidateParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, removeComponent, invalidateScript}, "\n"))
|
||||
)
|
||||
|
||||
type redisCache[I, K comparable, V cache.Entry[I, K]] struct {
|
||||
db int
|
||||
config *cache.Config
|
||||
indices []I
|
||||
connector *Connector
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewCache returns a cache that stores and retrieves object using single Redis.
|
||||
func NewCache[I, K comparable, V cache.Entry[I, K]](config cache.Config, client *Connector, db int, indices []I) cache.Cache[I, K, V] {
|
||||
return &redisCache[I, K, V]{
|
||||
config: &config,
|
||||
db: db,
|
||||
indices: indices,
|
||||
connector: client,
|
||||
logger: config.Log.Slog(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *redisCache[I, K, V]) Set(ctx context.Context, value V) {
|
||||
if _, err := c.set(ctx, value); err != nil {
|
||||
c.logger.ErrorContext(ctx, "redis cache set", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *redisCache[I, K, V]) set(ctx context.Context, value V) (objectID string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
// Internal ID used for the object
|
||||
objectID = uuid.NewString()
|
||||
keys := []string{objectID}
|
||||
// flatten the secondary keys
|
||||
for _, index := range c.indices {
|
||||
keys = append(keys, c.redisIndexKeys(index, value.Keys(index)...)...)
|
||||
}
|
||||
var buf strings.Builder
|
||||
err = json.NewEncoder(&buf).Encode(value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = setParsed.Run(ctx, c.connector, keys,
|
||||
c.db, // DB namespace
|
||||
buf.String(), // object
|
||||
int64(c.config.LastUseAge/time.Second), // usage_lifetime
|
||||
int64(c.config.MaxAge/time.Second), // max_age,
|
||||
).Err()
|
||||
// redis.Nil is always returned because the script doesn't have a return value.
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return "", err
|
||||
}
|
||||
return objectID, nil
|
||||
}
|
||||
|
||||
func (c *redisCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
|
||||
var (
|
||||
obj any
|
||||
err error
|
||||
)
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
err = nil
|
||||
}
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
logger := c.logger.With("index", index, "key", key)
|
||||
obj, err = getParsed.Run(ctx, c.connector, c.redisIndexKeys(index, key), c.db).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
logger.ErrorContext(ctx, "redis cache get", "err", err)
|
||||
return value, false
|
||||
}
|
||||
data, ok := obj.(string)
|
||||
if !ok {
|
||||
logger.With("err", err).InfoContext(ctx, "redis cache miss")
|
||||
return value, false
|
||||
}
|
||||
err = json.NewDecoder(strings.NewReader(data)).Decode(&value)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, "redis cache get", "err", fmt.Errorf("decode: %w", err))
|
||||
return value, false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
func (c *redisCache[I, K, V]) Invalidate(ctx context.Context, index I, key ...K) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if len(key) == 0 {
|
||||
return nil
|
||||
}
|
||||
err = invalidateParsed.Run(ctx, c.connector, c.redisIndexKeys(index, key...), c.db).Err()
|
||||
// redis.Nil is always returned because the script doesn't have a return value.
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *redisCache[I, K, V]) Delete(ctx context.Context, index I, key ...K) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if len(key) == 0 {
|
||||
return nil
|
||||
}
|
||||
pipe := c.connector.Pipeline()
|
||||
pipe.Select(ctx, c.db)
|
||||
pipe.Del(ctx, c.redisIndexKeys(index, key...)...)
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *redisCache[I, K, V]) Truncate(ctx context.Context) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
pipe := c.connector.Pipeline()
|
||||
pipe.Select(ctx, c.db)
|
||||
pipe.FlushDB(ctx)
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *redisCache[I, K, V]) redisIndexKeys(index I, keys ...K) []string {
|
||||
out := make([]string, len(keys))
|
||||
for i, k := range keys {
|
||||
out[i] = fmt.Sprintf("%v:%v", index, k)
|
||||
}
|
||||
return out
|
||||
}
|
714
internal/cache/connector/redis/redis_test.go
vendored
Normal file
714
internal/cache/connector/redis/redis_test.go
vendored
Normal file
@@ -0,0 +1,714 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
)
|
||||
|
||||
type testIndex int
|
||||
|
||||
const (
|
||||
testIndexID testIndex = iota
|
||||
testIndexName
|
||||
)
|
||||
|
||||
const (
|
||||
testDB = 99
|
||||
)
|
||||
|
||||
var testIndices = []testIndex{
|
||||
testIndexID,
|
||||
testIndexName,
|
||||
}
|
||||
|
||||
type testObject struct {
|
||||
ID string
|
||||
Name []string
|
||||
}
|
||||
|
||||
func (o *testObject) Keys(index testIndex) []string {
|
||||
switch index {
|
||||
case testIndexID:
|
||||
return []string{o.ID}
|
||||
case testIndexName:
|
||||
return o.Name
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redisCache_set(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
value *testObject
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
config cache.Config
|
||||
args args
|
||||
assertions func(t *testing.T, s *miniredis.Miniredis, objectID string)
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
config: cache.Config{},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
value: &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
},
|
||||
},
|
||||
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
|
||||
s.CheckGet(t, "0:one", objectID)
|
||||
s.CheckGet(t, "1:foo", objectID)
|
||||
s.CheckGet(t, "1:bar", objectID)
|
||||
assert.Empty(t, s.HGet(objectID, "expiry"))
|
||||
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with last use TTL",
|
||||
config: cache.Config{
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
value: &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
},
|
||||
},
|
||||
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
|
||||
s.CheckGet(t, "0:one", objectID)
|
||||
s.CheckGet(t, "1:foo", objectID)
|
||||
s.CheckGet(t, "1:bar", objectID)
|
||||
assert.Empty(t, s.HGet(objectID, "expiry"))
|
||||
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
|
||||
assert.Positive(t, s.TTL(objectID))
|
||||
|
||||
s.FastForward(2 * time.Second)
|
||||
v, err := s.Get(objectID)
|
||||
require.Error(t, err)
|
||||
assert.Empty(t, v)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with last use TTL and max age",
|
||||
config: cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
value: &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
},
|
||||
},
|
||||
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
|
||||
s.CheckGet(t, "0:one", objectID)
|
||||
s.CheckGet(t, "1:foo", objectID)
|
||||
s.CheckGet(t, "1:bar", objectID)
|
||||
assert.NotEmpty(t, s.HGet(objectID, "expiry"))
|
||||
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
|
||||
assert.Positive(t, s.TTL(objectID))
|
||||
|
||||
s.FastForward(2 * time.Second)
|
||||
v, err := s.Get(objectID)
|
||||
require.Error(t, err)
|
||||
assert.Empty(t, v)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with max age TTL",
|
||||
config: cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
value: &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
},
|
||||
},
|
||||
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
|
||||
s.CheckGet(t, "0:one", objectID)
|
||||
s.CheckGet(t, "1:foo", objectID)
|
||||
s.CheckGet(t, "1:bar", objectID)
|
||||
assert.Empty(t, s.HGet(objectID, "expiry"))
|
||||
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
|
||||
assert.Positive(t, s.TTL(objectID))
|
||||
|
||||
s.FastForward(2 * time.Minute)
|
||||
v, err := s.Get(objectID)
|
||||
require.Error(t, err)
|
||||
assert.Empty(t, v)
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, server := prepareCache(t, tt.config)
|
||||
rc := c.(*redisCache[testIndex, string, *testObject])
|
||||
objectID, err := rc.set(tt.args.ctx, tt.args.value)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
t.Log(rc.connector.HGetAll(context.Background(), objectID))
|
||||
tt.assertions(t, server, objectID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redisCache_Get(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
index testIndex
|
||||
key string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
config cache.Config
|
||||
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
|
||||
args args
|
||||
want *testObject
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "connection error",
|
||||
config: cache.Config{},
|
||||
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
s.RequireAuth("foobar")
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: "foo",
|
||||
},
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "get by ID",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexID,
|
||||
key: "one",
|
||||
},
|
||||
want: &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
},
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "get by name",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: "foo",
|
||||
},
|
||||
want: &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
},
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "usage timeout",
|
||||
config: cache.Config{
|
||||
LastUseAge: time.Minute,
|
||||
},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
_, ok := c.Get(context.Background(), testIndexID, "one")
|
||||
require.True(t, ok)
|
||||
s.FastForward(2 * time.Minute)
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: "foo",
|
||||
},
|
||||
want: nil,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "max age timeout",
|
||||
config: cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
_, ok := c.Get(context.Background(), testIndexID, "one")
|
||||
require.True(t, ok)
|
||||
s.FastForward(2 * time.Minute)
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: "foo",
|
||||
},
|
||||
want: nil,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "not found",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: "spanac",
|
||||
},
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "json decode error",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
objectID, err := s.Get(c.(*redisCache[testIndex, string, *testObject]).redisIndexKeys(testIndexID, "one")[0])
|
||||
require.NoError(t, err)
|
||||
s.HSet(objectID, "object", "~~~")
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexID,
|
||||
key: "one",
|
||||
},
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, server := prepareCache(t, tt.config)
|
||||
tt.preparation(t, c, server)
|
||||
t.Log(server.Keys())
|
||||
|
||||
got, ok := c.Get(tt.args.ctx, tt.args.index, tt.args.key)
|
||||
require.Equal(t, tt.wantOK, ok)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redisCache_Invalidate(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
index testIndex
|
||||
key []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
config cache.Config
|
||||
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
|
||||
assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "connection error",
|
||||
config: cache.Config{},
|
||||
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
s.RequireAuth("foobar")
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: []string{"foo"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no keys, noop",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexID,
|
||||
key: []string{},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalidate by ID",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
},
|
||||
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
|
||||
obj, ok := c.Get(context.Background(), testIndexID, "one")
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, obj)
|
||||
obj, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, obj)
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexID,
|
||||
key: []string{"one"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalidate by name",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
},
|
||||
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
|
||||
obj, ok := c.Get(context.Background(), testIndexID, "one")
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, obj)
|
||||
obj, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, obj)
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: []string{"foo"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalidate after timeout",
|
||||
config: cache.Config{
|
||||
LastUseAge: time.Minute,
|
||||
},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
_, ok := c.Get(context.Background(), testIndexID, "one")
|
||||
require.True(t, ok)
|
||||
s.FastForward(2 * time.Minute)
|
||||
},
|
||||
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
|
||||
obj, ok := c.Get(context.Background(), testIndexID, "one")
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, obj)
|
||||
obj, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, obj)
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: []string{"foo"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "not found",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
},
|
||||
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
|
||||
obj, ok := c.Get(context.Background(), testIndexID, "one")
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, obj)
|
||||
obj, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, obj)
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: []string{"spanac"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, server := prepareCache(t, tt.config)
|
||||
tt.preparation(t, c, server)
|
||||
t.Log(server.Keys())
|
||||
|
||||
err := c.Invalidate(tt.args.ctx, tt.args.index, tt.args.key...)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redisCache_Delete(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
index testIndex
|
||||
key []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
config cache.Config
|
||||
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
|
||||
assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "connection error",
|
||||
config: cache.Config{},
|
||||
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
s.RequireAuth("foobar")
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: []string{"foo"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no keys, noop",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexID,
|
||||
key: []string{},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "delete ID",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
},
|
||||
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
|
||||
obj, ok := c.Get(context.Background(), testIndexID, "one")
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, obj)
|
||||
// Get be name should still work
|
||||
obj, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, obj)
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexID,
|
||||
key: []string{"one"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "delete name",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
},
|
||||
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
|
||||
// get by ID should still work
|
||||
obj, ok := c.Get(context.Background(), testIndexID, "one")
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, obj)
|
||||
obj, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, obj)
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: []string{"foo"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "not found",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
},
|
||||
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
|
||||
obj, ok := c.Get(context.Background(), testIndexID, "one")
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, obj)
|
||||
obj, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, obj)
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
index: testIndexName,
|
||||
key: []string{"spanac"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, server := prepareCache(t, tt.config)
|
||||
tt.preparation(t, c, server)
|
||||
t.Log(server.Keys())
|
||||
|
||||
err := c.Delete(tt.args.ctx, tt.args.index, tt.args.key...)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redisCache_Truncate(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
config cache.Config
|
||||
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
|
||||
assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "connection error",
|
||||
config: cache.Config{},
|
||||
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
s.RequireAuth("foobar")
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
config: cache.Config{},
|
||||
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "one",
|
||||
Name: []string{"foo", "bar"},
|
||||
})
|
||||
c.Set(context.Background(), &testObject{
|
||||
ID: "two",
|
||||
Name: []string{"Hello", "World"},
|
||||
})
|
||||
},
|
||||
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
|
||||
obj, ok := c.Get(context.Background(), testIndexID, "one")
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, obj)
|
||||
obj, ok = c.Get(context.Background(), testIndexName, "World")
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, obj)
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, server := prepareCache(t, tt.config)
|
||||
tt.preparation(t, c, server)
|
||||
t.Log(server.Keys())
|
||||
|
||||
err := c.Truncate(tt.args.ctx)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func prepareCache(t *testing.T, conf cache.Config) (cache.Cache[testIndex, string, *testObject], *miniredis.Miniredis) {
|
||||
conf.Log = &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
}
|
||||
server := miniredis.RunT(t)
|
||||
server.Select(testDB)
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Network: "tcp",
|
||||
Addr: server.Addr(),
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
client.Close()
|
||||
server.Close()
|
||||
})
|
||||
connector := NewConnector(Config{
|
||||
Enabled: true,
|
||||
Network: "tcp",
|
||||
Addr: server.Addr(),
|
||||
})
|
||||
c := NewCache[testIndex, string, *testObject](conf, connector, testDB, testIndices)
|
||||
return c, server
|
||||
}
|
27
internal/cache/connector/redis/set.lua
vendored
Normal file
27
internal/cache/connector/redis/set.lua
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
-- KEYS: [1]: object_id; [>1]: index keys.
|
||||
local object_id = KEYS[1]
|
||||
local object = ARGV[2]
|
||||
local usage_lifetime = tonumber(ARGV[3]) -- usage based lifetime in seconds
|
||||
local max_age = tonumber(ARGV[4]) -- max age liftime in seconds
|
||||
|
||||
redis.call("HSET", object_id,"object", object)
|
||||
if usage_lifetime > 0 then
|
||||
redis.call("HSET", object_id, "usage_lifetime", usage_lifetime)
|
||||
-- enable usage based TTL
|
||||
redis.call("EXPIRE", object_id, usage_lifetime)
|
||||
if max_age > 0 then
|
||||
-- set max_age to hash map for expired remove on Get
|
||||
local expiry = getTime() + max_age
|
||||
redis.call("HSET", object_id, "expiry", expiry)
|
||||
end
|
||||
elseif max_age > 0 then
|
||||
-- enable max_age based TTL
|
||||
redis.call("EXPIRE", object_id, max_age)
|
||||
end
|
||||
|
||||
local n = #KEYS
|
||||
local setKey = keySetKey(object_id)
|
||||
for i = 2, n do -- offset to the second element to skip object_id
|
||||
redis.call("SADD", setKey, KEYS[i]) -- set of all keys used for housekeeping
|
||||
redis.call("SET", KEYS[i], object_id) -- key to object_id mapping
|
||||
end
|
Reference in New Issue
Block a user