feat(cache): redis cache (#8822)

# Which Problems Are Solved

Add a cache implementation using Redis single mode. This does not add
support for Redis Cluster or sentinel.

# How the Problems Are Solved

Added the `internal/cache/redis` package. All operations occur
atomically, including setting of secondary indexes, using LUA scripts
where needed.

The [`miniredis`](https://github.com/alicebob/miniredis) package is used
to run unit tests.

# Additional Changes

- Move connector code to `internal/cache/connector/...` and remove
duplicate code from `query` and `command` packages.
- Fix a missed invalidation on the restrictions projection

# Additional Context

Closes #8130
This commit is contained in:
Tim Möhlmann
2024-11-04 11:44:51 +01:00
committed by GitHub
parent 9c3e5e467b
commit 250f2344c8
50 changed files with 1767 additions and 293 deletions

View File

@@ -0,0 +1,10 @@
local function remove(object_id)
local setKey = keySetKey(object_id)
local keys = redis.call("SMEMBERS", setKey)
local n = #keys
for i = 1, n do
redis.call("DEL", keys[i])
end
redis.call("DEL", setKey)
redis.call("DEL", object_id)
end

View File

@@ -0,0 +1,3 @@
-- SELECT ensures the DB namespace for each script.
-- When used, it consumes the first ARGV entry.
redis.call("SELECT", ARGV[1])

View File

@@ -0,0 +1,17 @@
-- keySetKey returns the redis key of the set containing all keys to the object.
local function keySetKey (object_id)
return object_id .. "-keys"
end
local function getTime()
return tonumber(redis.call('TIME')[1])
end
-- getCall wrapts redis.call so a nil is returned instead of false.
local function getCall (...)
local result = redis.call(...)
if result == false then
return nil
end
return result
end

View File

@@ -0,0 +1,154 @@
package redis
import (
"crypto/tls"
"time"
"github.com/redis/go-redis/v9"
)
type Config struct {
Enabled bool
// The network type, either tcp or unix.
// Default is tcp.
Network string
// host:port address.
Addr string
// ClientName will execute the `CLIENT SETNAME ClientName` command for each conn.
ClientName string
// Use the specified Username to authenticate the current connection
// with one of the connections defined in the ACL list when connecting
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
Username string
// Optional password. Must match the password specified in the
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower),
// or the User Password when connecting to a Redis 6.0 instance, or greater,
// that is using the Redis ACL system.
Password string
// Each ZITADEL cache uses an incremental DB namespace.
// This option offsets the first DB so it doesn't conflict with other databases on the same server.
// Note that ZITADEL uses FLUSHDB command to truncate a cache.
// This can have destructive consequences when overlapping DB namespaces are used.
DBOffset int
// Maximum number of retries before giving up.
// Default is 3 retries; -1 (not 0) disables retries.
MaxRetries int
// Minimum backoff between each retry.
// Default is 8 milliseconds; -1 disables backoff.
MinRetryBackoff time.Duration
// Maximum backoff between each retry.
// Default is 512 milliseconds; -1 disables backoff.
MaxRetryBackoff time.Duration
// Dial timeout for establishing new connections.
// Default is 5 seconds.
DialTimeout time.Duration
// Timeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking. Supported values:
// - `0` - default timeout (3 seconds).
// - `-1` - no timeout (block indefinitely).
// - `-2` - disables SetReadDeadline calls completely.
ReadTimeout time.Duration
// Timeout for socket writes. If reached, commands will fail
// with a timeout instead of blocking. Supported values:
// - `0` - default timeout (3 seconds).
// - `-1` - no timeout (block indefinitely).
// - `-2` - disables SetWriteDeadline calls completely.
WriteTimeout time.Duration
// Type of connection pool.
// true for FIFO pool, false for LIFO pool.
// Note that FIFO has slightly higher overhead compared to LIFO,
// but it helps closing idle connections faster reducing the pool size.
PoolFIFO bool
// Base number of socket connections.
// Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS.
// If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize,
// you can limit it through MaxActiveConns
PoolSize int
// Amount of time client waits for connection if all connections
// are busy before returning an error.
// Default is ReadTimeout + 1 second.
PoolTimeout time.Duration
// Minimum number of idle connections which is useful when establishing
// new connection is slow.
// Default is 0. the idle connections are not closed by default.
MinIdleConns int
// Maximum number of idle connections.
// Default is 0. the idle connections are not closed by default.
MaxIdleConns int
// Maximum number of connections allocated by the pool at a given time.
// When zero, there is no limit on the number of connections in the pool.
MaxActiveConns int
// ConnMaxIdleTime is the maximum amount of time a connection may be idle.
// Should be less than server's timeout.
//
// Expired connections may be closed lazily before reuse.
// If d <= 0, connections are not closed due to a connection's idle time.
//
// Default is 30 minutes. -1 disables idle timeout check.
ConnMaxIdleTime time.Duration
// ConnMaxLifetime is the maximum amount of time a connection may be reused.
//
// Expired connections may be closed lazily before reuse.
// If <= 0, connections are not closed due to a connection's age.
//
// Default is to not close idle connections.
ConnMaxLifetime time.Duration
EnableTLS bool
// Disable set-lib on connect. Default is false.
DisableIndentity bool
// Add suffix to client name. Default is empty.
IdentitySuffix string
}
type Connector struct {
*redis.Client
Config Config
}
func NewConnector(config Config) *Connector {
if !config.Enabled {
return nil
}
return &Connector{
Client: redis.NewClient(optionsFromConfig(config)),
Config: config,
}
}
func optionsFromConfig(c Config) *redis.Options {
opts := &redis.Options{
Network: c.Network,
Addr: c.Addr,
ClientName: c.ClientName,
Protocol: 3,
Username: c.Username,
Password: c.Password,
MaxRetries: c.MaxRetries,
MinRetryBackoff: c.MinRetryBackoff,
MaxRetryBackoff: c.MaxRetryBackoff,
DialTimeout: c.DialTimeout,
ReadTimeout: c.ReadTimeout,
WriteTimeout: c.WriteTimeout,
ContextTimeoutEnabled: true,
PoolFIFO: c.PoolFIFO,
PoolTimeout: c.PoolTimeout,
MinIdleConns: c.MinIdleConns,
MaxIdleConns: c.MaxIdleConns,
MaxActiveConns: c.MaxActiveConns,
ConnMaxIdleTime: c.ConnMaxIdleTime,
ConnMaxLifetime: c.ConnMaxLifetime,
DisableIndentity: c.DisableIndentity,
IdentitySuffix: c.IdentitySuffix,
}
if c.EnableTLS {
opts.TLSConfig = new(tls.Config)
}
return opts
}

29
internal/cache/connector/redis/get.lua vendored Normal file
View File

@@ -0,0 +1,29 @@
local result = redis.call("GET", KEYS[1])
if result == false then
return nil
end
local object_id = tostring(result)
local object = getCall("HGET", object_id, "object")
if object == nil then
-- object expired, but there are keys that need to be cleaned up
remove(object_id)
return nil
end
-- max-age must be checked manually
local expiry = getCall("HGET", object_id, "expiry")
if not (expiry == nil) and expiry > 0 then
if getTime() > expiry then
remove(object_id)
return nil
end
end
local usage_lifetime = getCall("HGET", object_id, "usage_lifetime")
-- reset usage based TTL
if not (usage_lifetime == nil) and tonumber(usage_lifetime) > 0 then
redis.call('EXPIRE', object_id, usage_lifetime)
end
return object

View File

@@ -0,0 +1,9 @@
local n = #KEYS
for i = 1, n do
local result = redis.call("GET", KEYS[i])
if result == false then
return nil
end
local object_id = tostring(result)
remove(object_id)
end

172
internal/cache/connector/redis/redis.go vendored Normal file
View File

@@ -0,0 +1,172 @@
package redis
import (
"context"
_ "embed"
"encoding/json"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var (
//go:embed _select.lua
selectComponent string
//go:embed _util.lua
utilComponent string
//go:embed _remove.lua
removeComponent string
//go:embed set.lua
setScript string
//go:embed get.lua
getScript string
//go:embed invalidate.lua
invalidateScript string
// Don't mind the creative "import"
setParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, setScript}, "\n"))
getParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, removeComponent, getScript}, "\n"))
invalidateParsed = redis.NewScript(strings.Join([]string{selectComponent, utilComponent, removeComponent, invalidateScript}, "\n"))
)
type redisCache[I, K comparable, V cache.Entry[I, K]] struct {
db int
config *cache.Config
indices []I
connector *Connector
logger *slog.Logger
}
// NewCache returns a cache that stores and retrieves object using single Redis.
func NewCache[I, K comparable, V cache.Entry[I, K]](config cache.Config, client *Connector, db int, indices []I) cache.Cache[I, K, V] {
return &redisCache[I, K, V]{
config: &config,
db: db,
indices: indices,
connector: client,
logger: config.Log.Slog(),
}
}
func (c *redisCache[I, K, V]) Set(ctx context.Context, value V) {
if _, err := c.set(ctx, value); err != nil {
c.logger.ErrorContext(ctx, "redis cache set", "err", err)
}
}
func (c *redisCache[I, K, V]) set(ctx context.Context, value V) (objectID string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
// Internal ID used for the object
objectID = uuid.NewString()
keys := []string{objectID}
// flatten the secondary keys
for _, index := range c.indices {
keys = append(keys, c.redisIndexKeys(index, value.Keys(index)...)...)
}
var buf strings.Builder
err = json.NewEncoder(&buf).Encode(value)
if err != nil {
return "", err
}
err = setParsed.Run(ctx, c.connector, keys,
c.db, // DB namespace
buf.String(), // object
int64(c.config.LastUseAge/time.Second), // usage_lifetime
int64(c.config.MaxAge/time.Second), // max_age,
).Err()
// redis.Nil is always returned because the script doesn't have a return value.
if err != nil && !errors.Is(err, redis.Nil) {
return "", err
}
return objectID, nil
}
func (c *redisCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
var (
obj any
err error
)
ctx, span := tracing.NewSpan(ctx)
defer func() {
if errors.Is(err, redis.Nil) {
err = nil
}
span.EndWithError(err)
}()
logger := c.logger.With("index", index, "key", key)
obj, err = getParsed.Run(ctx, c.connector, c.redisIndexKeys(index, key), c.db).Result()
if err != nil && !errors.Is(err, redis.Nil) {
logger.ErrorContext(ctx, "redis cache get", "err", err)
return value, false
}
data, ok := obj.(string)
if !ok {
logger.With("err", err).InfoContext(ctx, "redis cache miss")
return value, false
}
err = json.NewDecoder(strings.NewReader(data)).Decode(&value)
if err != nil {
logger.ErrorContext(ctx, "redis cache get", "err", fmt.Errorf("decode: %w", err))
return value, false
}
return value, true
}
func (c *redisCache[I, K, V]) Invalidate(ctx context.Context, index I, key ...K) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if len(key) == 0 {
return nil
}
err = invalidateParsed.Run(ctx, c.connector, c.redisIndexKeys(index, key...), c.db).Err()
// redis.Nil is always returned because the script doesn't have a return value.
if err != nil && !errors.Is(err, redis.Nil) {
return err
}
return nil
}
func (c *redisCache[I, K, V]) Delete(ctx context.Context, index I, key ...K) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if len(key) == 0 {
return nil
}
pipe := c.connector.Pipeline()
pipe.Select(ctx, c.db)
pipe.Del(ctx, c.redisIndexKeys(index, key...)...)
_, err = pipe.Exec(ctx)
return err
}
func (c *redisCache[I, K, V]) Truncate(ctx context.Context) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
pipe := c.connector.Pipeline()
pipe.Select(ctx, c.db)
pipe.FlushDB(ctx)
_, err = pipe.Exec(ctx)
return err
}
func (c *redisCache[I, K, V]) redisIndexKeys(index I, keys ...K) []string {
out := make([]string, len(keys))
for i, k := range keys {
out[i] = fmt.Sprintf("%v:%v", index, k)
}
return out
}

View File

@@ -0,0 +1,714 @@
package redis
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/cache"
)
type testIndex int
const (
testIndexID testIndex = iota
testIndexName
)
const (
testDB = 99
)
var testIndices = []testIndex{
testIndexID,
testIndexName,
}
type testObject struct {
ID string
Name []string
}
func (o *testObject) Keys(index testIndex) []string {
switch index {
case testIndexID:
return []string{o.ID}
case testIndexName:
return o.Name
default:
return nil
}
}
func Test_redisCache_set(t *testing.T) {
type args struct {
ctx context.Context
value *testObject
}
tests := []struct {
name string
config cache.Config
args args
assertions func(t *testing.T, s *miniredis.Miniredis, objectID string)
wantErr error
}{
{
name: "ok",
config: cache.Config{},
args: args{
ctx: context.Background(),
value: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
},
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
s.CheckGet(t, "0:one", objectID)
s.CheckGet(t, "1:foo", objectID)
s.CheckGet(t, "1:bar", objectID)
assert.Empty(t, s.HGet(objectID, "expiry"))
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
},
},
{
name: "with last use TTL",
config: cache.Config{
LastUseAge: time.Second,
},
args: args{
ctx: context.Background(),
value: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
},
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
s.CheckGet(t, "0:one", objectID)
s.CheckGet(t, "1:foo", objectID)
s.CheckGet(t, "1:bar", objectID)
assert.Empty(t, s.HGet(objectID, "expiry"))
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
assert.Positive(t, s.TTL(objectID))
s.FastForward(2 * time.Second)
v, err := s.Get(objectID)
require.Error(t, err)
assert.Empty(t, v)
},
},
{
name: "with last use TTL and max age",
config: cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
args: args{
ctx: context.Background(),
value: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
},
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
s.CheckGet(t, "0:one", objectID)
s.CheckGet(t, "1:foo", objectID)
s.CheckGet(t, "1:bar", objectID)
assert.NotEmpty(t, s.HGet(objectID, "expiry"))
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
assert.Positive(t, s.TTL(objectID))
s.FastForward(2 * time.Second)
v, err := s.Get(objectID)
require.Error(t, err)
assert.Empty(t, v)
},
},
{
name: "with max age TTL",
config: cache.Config{
MaxAge: time.Minute,
},
args: args{
ctx: context.Background(),
value: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
},
assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
s.CheckGet(t, "0:one", objectID)
s.CheckGet(t, "1:foo", objectID)
s.CheckGet(t, "1:bar", objectID)
assert.Empty(t, s.HGet(objectID, "expiry"))
assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
assert.Positive(t, s.TTL(objectID))
s.FastForward(2 * time.Minute)
v, err := s.Get(objectID)
require.Error(t, err)
assert.Empty(t, v)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, server := prepareCache(t, tt.config)
rc := c.(*redisCache[testIndex, string, *testObject])
objectID, err := rc.set(tt.args.ctx, tt.args.value)
require.ErrorIs(t, err, tt.wantErr)
t.Log(rc.connector.HGetAll(context.Background(), objectID))
tt.assertions(t, server, objectID)
})
}
}
func Test_redisCache_Get(t *testing.T) {
type args struct {
ctx context.Context
index testIndex
key string
}
tests := []struct {
name string
config cache.Config
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
args args
want *testObject
wantOK bool
}{
{
name: "connection error",
config: cache.Config{},
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
s.RequireAuth("foobar")
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: "foo",
},
wantOK: false,
},
{
name: "get by ID",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: "one",
},
want: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
wantOK: true,
},
{
name: "get by name",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: "foo",
},
want: &testObject{
ID: "one",
Name: []string{"foo", "bar"},
},
wantOK: true,
},
{
name: "usage timeout",
config: cache.Config{
LastUseAge: time.Minute,
},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
_, ok := c.Get(context.Background(), testIndexID, "one")
require.True(t, ok)
s.FastForward(2 * time.Minute)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: "foo",
},
want: nil,
wantOK: false,
},
{
name: "max age timeout",
config: cache.Config{
MaxAge: time.Minute,
},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
_, ok := c.Get(context.Background(), testIndexID, "one")
require.True(t, ok)
s.FastForward(2 * time.Minute)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: "foo",
},
want: nil,
wantOK: false,
},
{
name: "not found",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: "spanac",
},
wantOK: false,
},
{
name: "json decode error",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
objectID, err := s.Get(c.(*redisCache[testIndex, string, *testObject]).redisIndexKeys(testIndexID, "one")[0])
require.NoError(t, err)
s.HSet(objectID, "object", "~~~")
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: "one",
},
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, server := prepareCache(t, tt.config)
tt.preparation(t, c, server)
t.Log(server.Keys())
got, ok := c.Get(tt.args.ctx, tt.args.index, tt.args.key)
require.Equal(t, tt.wantOK, ok)
assert.Equal(t, tt.want, got)
})
}
}
func Test_redisCache_Invalidate(t *testing.T) {
type args struct {
ctx context.Context
index testIndex
key []string
}
tests := []struct {
name string
config cache.Config
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
args args
wantErr bool
}{
{
name: "connection error",
config: cache.Config{},
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
s.RequireAuth("foobar")
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"foo"},
},
wantErr: true,
},
{
name: "no keys, noop",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: []string{},
},
wantErr: false,
},
{
name: "invalidate by ID",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.False(t, ok)
assert.Nil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.False(t, ok)
assert.Nil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: []string{"one"},
},
wantErr: false,
},
{
name: "invalidate by name",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.False(t, ok)
assert.Nil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.False(t, ok)
assert.Nil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"foo"},
},
wantErr: false,
},
{
name: "invalidate after timeout",
config: cache.Config{
LastUseAge: time.Minute,
},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
_, ok := c.Get(context.Background(), testIndexID, "one")
require.True(t, ok)
s.FastForward(2 * time.Minute)
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.False(t, ok)
assert.Nil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.False(t, ok)
assert.Nil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"foo"},
},
wantErr: false,
},
{
name: "not found",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.True(t, ok)
assert.NotNil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.True(t, ok)
assert.NotNil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"spanac"},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, server := prepareCache(t, tt.config)
tt.preparation(t, c, server)
t.Log(server.Keys())
err := c.Invalidate(tt.args.ctx, tt.args.index, tt.args.key...)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
func Test_redisCache_Delete(t *testing.T) {
type args struct {
ctx context.Context
index testIndex
key []string
}
tests := []struct {
name string
config cache.Config
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
args args
wantErr bool
}{
{
name: "connection error",
config: cache.Config{},
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
s.RequireAuth("foobar")
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"foo"},
},
wantErr: true,
},
{
name: "no keys, noop",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: []string{},
},
wantErr: false,
},
{
name: "delete ID",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.False(t, ok)
assert.Nil(t, obj)
// Get be name should still work
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.True(t, ok)
assert.NotNil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexID,
key: []string{"one"},
},
wantErr: false,
},
{
name: "delete name",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
// get by ID should still work
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.True(t, ok)
assert.NotNil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.False(t, ok)
assert.Nil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"foo"},
},
wantErr: false,
},
{
name: "not found",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.True(t, ok)
assert.NotNil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "foo")
assert.True(t, ok)
assert.NotNil(t, obj)
},
args: args{
ctx: context.Background(),
index: testIndexName,
key: []string{"spanac"},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, server := prepareCache(t, tt.config)
tt.preparation(t, c, server)
t.Log(server.Keys())
err := c.Delete(tt.args.ctx, tt.args.index, tt.args.key...)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
func Test_redisCache_Truncate(t *testing.T) {
type args struct {
ctx context.Context
}
tests := []struct {
name string
config cache.Config
preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
assertions func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
args args
wantErr bool
}{
{
name: "connection error",
config: cache.Config{},
preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
s.RequireAuth("foobar")
},
args: args{
ctx: context.Background(),
},
wantErr: true,
},
{
name: "ok",
config: cache.Config{},
preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
c.Set(context.Background(), &testObject{
ID: "one",
Name: []string{"foo", "bar"},
})
c.Set(context.Background(), &testObject{
ID: "two",
Name: []string{"Hello", "World"},
})
},
assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
obj, ok := c.Get(context.Background(), testIndexID, "one")
assert.False(t, ok)
assert.Nil(t, obj)
obj, ok = c.Get(context.Background(), testIndexName, "World")
assert.False(t, ok)
assert.Nil(t, obj)
},
args: args{
ctx: context.Background(),
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, server := prepareCache(t, tt.config)
tt.preparation(t, c, server)
t.Log(server.Keys())
err := c.Truncate(tt.args.ctx)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
})
}
}
func prepareCache(t *testing.T, conf cache.Config) (cache.Cache[testIndex, string, *testObject], *miniredis.Miniredis) {
conf.Log = &logging.Config{
Level: "debug",
AddSource: true,
}
server := miniredis.RunT(t)
server.Select(testDB)
client := redis.NewClient(&redis.Options{
Network: "tcp",
Addr: server.Addr(),
})
t.Cleanup(func() {
client.Close()
server.Close()
})
connector := NewConnector(Config{
Enabled: true,
Network: "tcp",
Addr: server.Addr(),
})
c := NewCache[testIndex, string, *testObject](conf, connector, testDB, testIndices)
return c, server
}

27
internal/cache/connector/redis/set.lua vendored Normal file
View File

@@ -0,0 +1,27 @@
-- KEYS: [1]: object_id; [>1]: index keys.
local object_id = KEYS[1]
local object = ARGV[2]
local usage_lifetime = tonumber(ARGV[3]) -- usage based lifetime in seconds
local max_age = tonumber(ARGV[4]) -- max age liftime in seconds
redis.call("HSET", object_id,"object", object)
if usage_lifetime > 0 then
redis.call("HSET", object_id, "usage_lifetime", usage_lifetime)
-- enable usage based TTL
redis.call("EXPIRE", object_id, usage_lifetime)
if max_age > 0 then
-- set max_age to hash map for expired remove on Get
local expiry = getTime() + max_age
redis.call("HSET", object_id, "expiry", expiry)
end
elseif max_age > 0 then
-- enable max_age based TTL
redis.call("EXPIRE", object_id, max_age)
end
local n = #KEYS
local setKey = keySetKey(object_id)
for i = 2, n do -- offset to the second element to skip object_id
redis.call("SADD", setKey, KEYS[i]) -- set of all keys used for housekeeping
redis.call("SET", KEYS[i], object_id) -- key to object_id mapping
end