722 lines
18 KiB
Go
Raw Normal View History

package redis
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/cache"
)
// testIndex enumerates the lookup indices maintained for testObject.
type testIndex int

// Index identifiers; testIndexID is 0 and testIndexName is 1.
const (
	testIndexID testIndex = iota
	testIndexName
)

// testDB is the database number the test server selects and the cache uses.
const testDB = 99

// testIndices lists every index the test cache maintains, in declaration order.
var testIndices = []testIndex{testIndexID, testIndexName}

// testObject is the value type stored in the cache under test.
type testObject struct {
	ID   string
	Name []string
}

// Keys returns the keys of o for the given index: the single ID for
// testIndexID, every name for testIndexName, and nil for any other index.
// The Name slice is returned as-is, not copied.
func (o *testObject) Keys(index testIndex) []string {
	if index == testIndexID {
		return []string{o.ID}
	}
	if index == testIndexName {
		return o.Name
	}
	return nil
}
// Test_redisCache_set checks that set stores the JSON-encoded object in the
// "object" field of a hash named after the returned object ID, and writes one
// "<index>:<key>" index key per key pointing at that ID. It also checks that an
// "expiry" field is recorded only when both MaxAge and LastUseAge are set, and
// that a TTL is placed on the object hash when LastUseAge or MaxAge is set.
func Test_redisCache_set(t *testing.T) {
	// assertStored checks the index keys and stored payload shared by all cases.
	assertStored := func(t *testing.T, s *miniredis.Miniredis, objectID string) {
		t.Helper()
		s.CheckGet(t, "0:one", objectID)
		s.CheckGet(t, "1:foo", objectID)
		s.CheckGet(t, "1:bar", objectID)
		assert.JSONEq(t, `{"ID":"one","Name":["foo","bar"]}`, s.HGet(objectID, "object"))
	}
	// assertExpiresAfter checks that objectID carries a TTL and is gone once the
	// server clock has been advanced by d.
	// Existence must be checked directly: objectID is a hash, so s.Get returns
	// ErrWrongType while the key still exists, and an error check would pass
	// even if the key had never expired.
	assertExpiresAfter := func(t *testing.T, s *miniredis.Miniredis, objectID string, d time.Duration) {
		t.Helper()
		assert.Positive(t, s.TTL(objectID))
		s.FastForward(d)
		assert.False(t, s.Exists(objectID), "object %q should have expired", objectID)
	}

	type args struct {
		ctx   context.Context
		value *testObject
	}
	tests := []struct {
		name       string
		config     cache.Config
		args       args
		assertions func(t *testing.T, s *miniredis.Miniredis, objectID string)
		wantErr    error
	}{
		{
			name:   "ok",
			config: cache.Config{},
			args: args{
				ctx: context.Background(),
				value: &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				},
			},
			assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
				assertStored(t, s, objectID)
				assert.Empty(t, s.HGet(objectID, "expiry"))
			},
		},
		{
			name: "with last use TTL",
			config: cache.Config{
				LastUseAge: time.Second,
			},
			args: args{
				ctx: context.Background(),
				value: &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				},
			},
			assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
				assertStored(t, s, objectID)
				assert.Empty(t, s.HGet(objectID, "expiry"))
				assertExpiresAfter(t, s, objectID, 2*time.Second)
			},
		},
		{
			name: "with last use TTL and max age",
			config: cache.Config{
				MaxAge:     time.Minute,
				LastUseAge: time.Second,
			},
			args: args{
				ctx: context.Background(),
				value: &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				},
			},
			assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
				assertStored(t, s, objectID)
				assert.NotEmpty(t, s.HGet(objectID, "expiry"))
				assertExpiresAfter(t, s, objectID, 2*time.Second)
			},
		},
		{
			name: "with max age TTL",
			config: cache.Config{
				MaxAge: time.Minute,
			},
			args: args{
				ctx: context.Background(),
				value: &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				},
			},
			assertions: func(t *testing.T, s *miniredis.Miniredis, objectID string) {
				assertStored(t, s, objectID)
				assert.Empty(t, s.HGet(objectID, "expiry"))
				assertExpiresAfter(t, s, objectID, 2*time.Minute)
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c, server := prepareCache(t, tt.config)
			rc := c.(*redisCache[testIndex, string, *testObject])
			objectID, err := rc.set(tt.args.ctx, tt.args.value)
			require.ErrorIs(t, err, tt.wantErr)
			t.Log(rc.connector.HGetAll(context.Background(), objectID))
			tt.assertions(t, server, objectID)
		})
	}
}
// Test_redisCache_Get exercises Get through both indices, expiry driven by
// LastUseAge and MaxAge, unknown keys, an unauthenticated connection, and a
// stored payload that is not valid JSON. In every failure case Get is expected
// to report ok == false and return a nil object.
func Test_redisCache_Get(t *testing.T) {
	type testCache = cache.Cache[testIndex, string, *testObject]

	// newOne returns a fresh copy of the object most cases store and expect.
	newOne := func() *testObject {
		return &testObject{
			ID:   "one",
			Name: []string{"foo", "bar"},
		}
	}
	// storeOne seeds c with the object produced by newOne.
	storeOne := func(c testCache) {
		c.Set(context.Background(), newOne())
	}
	// storeReadAndAge stores the object, reads it back once by ID, and then
	// advances the server clock by two minutes.
	storeReadAndAge := func(t *testing.T, c testCache, s *miniredis.Miniredis) {
		storeOne(c)
		_, ok := c.Get(context.Background(), testIndexID, "one")
		require.True(t, ok)
		s.FastForward(2 * time.Minute)
	}

	type args struct {
		ctx   context.Context
		index testIndex
		key   string
	}
	tests := []struct {
		name        string
		config      cache.Config
		preparation func(t *testing.T, c testCache, s *miniredis.Miniredis)
		args        args
		want        *testObject
		wantOK      bool
	}{
		{
			name:   "connection error",
			config: cache.Config{},
			preparation: func(_ *testing.T, _ testCache, s *miniredis.Miniredis) {
				s.RequireAuth("foobar")
			},
			args:   args{ctx: context.Background(), index: testIndexName, key: "foo"},
			wantOK: false,
		},
		{
			name:   "get by ID",
			config: cache.Config{},
			preparation: func(_ *testing.T, c testCache, _ *miniredis.Miniredis) {
				storeOne(c)
			},
			args:   args{ctx: context.Background(), index: testIndexID, key: "one"},
			want:   newOne(),
			wantOK: true,
		},
		{
			name:   "get by name",
			config: cache.Config{},
			preparation: func(_ *testing.T, c testCache, _ *miniredis.Miniredis) {
				storeOne(c)
			},
			args:   args{ctx: context.Background(), index: testIndexName, key: "foo"},
			want:   newOne(),
			wantOK: true,
		},
		{
			name:        "usage timeout",
			config:      cache.Config{LastUseAge: time.Minute},
			preparation: storeReadAndAge,
			args:        args{ctx: context.Background(), index: testIndexName, key: "foo"},
			want:        nil,
			wantOK:      false,
		},
		{
			name:        "max age timeout",
			config:      cache.Config{MaxAge: time.Minute},
			preparation: storeReadAndAge,
			args:        args{ctx: context.Background(), index: testIndexName, key: "foo"},
			want:        nil,
			wantOK:      false,
		},
		{
			name:   "not found",
			config: cache.Config{},
			preparation: func(_ *testing.T, c testCache, _ *miniredis.Miniredis) {
				storeOne(c)
			},
			args:   args{ctx: context.Background(), index: testIndexName, key: "spanac"},
			wantOK: false,
		},
		{
			name:   "json decode error",
			config: cache.Config{},
			preparation: func(t *testing.T, c testCache, s *miniredis.Miniredis) {
				storeOne(c)
				// Resolve the object ID through the ID index, then corrupt its payload.
				indexKey := c.(*redisCache[testIndex, string, *testObject]).redisIndexKeys(testIndexID, "one")[0]
				objectID, err := s.Get(indexKey)
				require.NoError(t, err)
				s.HSet(objectID, "object", "~~~")
			},
			args:   args{ctx: context.Background(), index: testIndexID, key: "one"},
			wantOK: false,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			c, server := prepareCache(t, tc.config)
			tc.preparation(t, c, server)
			t.Log(server.Keys())

			got, ok := c.Get(tc.args.ctx, tc.args.index, tc.args.key)
			require.Equal(t, tc.wantOK, ok)
			assert.Equal(t, tc.want, got)
		})
	}
}
// Test_redisCache_Invalidate checks that invalidating a key of an object makes
// the object unreachable through both indices, including after a LastUseAge
// timeout. It also checks that unknown keys leave the object reachable, that an
// empty key list does not fail, and that a connection failure is reported as an
// error. Each case's assertions run after a successful Invalidate.
func Test_redisCache_Invalidate(t *testing.T) {
	type args struct {
		ctx   context.Context
		index testIndex
		key   []string
	}
	tests := []struct {
		name        string
		config      cache.Config
		preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
		assertions  func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
		args        args
		wantErr     bool
	}{
		{
			name:   "connection error",
			config: cache.Config{},
			preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				s.RequireAuth("foobar")
			},
			args: args{
				ctx:   context.Background(),
				index: testIndexName,
				key:   []string{"foo"},
			},
			wantErr: true,
		},
		{
			name:   "no keys, noop",
			config: cache.Config{},
			preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				c.Set(context.Background(), &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				})
			},
			args: args{
				ctx:   context.Background(),
				index: testIndexID,
				key:   []string{},
			},
			wantErr: false,
		},
		{
			name:   "invalidate by ID",
			config: cache.Config{},
			preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				c.Set(context.Background(), &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				})
			},
			assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
				obj, ok := c.Get(context.Background(), testIndexID, "one")
				assert.False(t, ok)
				assert.Nil(t, obj)
				obj, ok = c.Get(context.Background(), testIndexName, "foo")
				assert.False(t, ok)
				assert.Nil(t, obj)
			},
			args: args{
				ctx:   context.Background(),
				index: testIndexID,
				key:   []string{"one"},
			},
			wantErr: false,
		},
		{
			name:   "invalidate by name",
			config: cache.Config{},
			preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				c.Set(context.Background(), &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				})
			},
			assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
				obj, ok := c.Get(context.Background(), testIndexID, "one")
				assert.False(t, ok)
				assert.Nil(t, obj)
				obj, ok = c.Get(context.Background(), testIndexName, "foo")
				assert.False(t, ok)
				assert.Nil(t, obj)
			},
			args: args{
				ctx:   context.Background(),
				index: testIndexName,
				key:   []string{"foo"},
			},
			wantErr: false,
		},
		{
			name: "invalidate after timeout",
			config: cache.Config{
				LastUseAge: time.Minute,
			},
			preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				c.Set(context.Background(), &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				})
				_, ok := c.Get(context.Background(), testIndexID, "one")
				require.True(t, ok)
				s.FastForward(2 * time.Minute)
			},
			assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
				obj, ok := c.Get(context.Background(), testIndexID, "one")
				assert.False(t, ok)
				assert.Nil(t, obj)
				obj, ok = c.Get(context.Background(), testIndexName, "foo")
				assert.False(t, ok)
				assert.Nil(t, obj)
			},
			args: args{
				ctx:   context.Background(),
				index: testIndexName,
				key:   []string{"foo"},
			},
			wantErr: false,
		},
		{
			name:   "not found",
			config: cache.Config{},
			preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				c.Set(context.Background(), &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				})
			},
			assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
				obj, ok := c.Get(context.Background(), testIndexID, "one")
				assert.True(t, ok)
				assert.NotNil(t, obj)
				obj, ok = c.Get(context.Background(), testIndexName, "foo")
				assert.True(t, ok)
				assert.NotNil(t, obj)
			},
			args: args{
				ctx:   context.Background(),
				index: testIndexName,
				key:   []string{"spanac"},
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c, server := prepareCache(t, tt.config)
			tt.preparation(t, c, server)
			t.Log(server.Keys())
			err := c.Invalidate(tt.args.ctx, tt.args.index, tt.args.key...)
			if tt.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			// Verify the post-conditions; cases without assertions only check the error.
			if tt.assertions != nil {
				tt.assertions(t, c)
			}
		})
	}
}
// Test_redisCache_Delete checks that Delete removes only the given index keys:
// the object stays reachable through its other index. It also checks that
// unknown keys leave the object reachable, that an empty key list does not
// fail, and that a connection failure is reported as an error. Each case's
// assertions run after a successful Delete.
func Test_redisCache_Delete(t *testing.T) {
	type args struct {
		ctx   context.Context
		index testIndex
		key   []string
	}
	tests := []struct {
		name        string
		config      cache.Config
		preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
		assertions  func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
		args        args
		wantErr     bool
	}{
		{
			name:   "connection error",
			config: cache.Config{},
			preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				s.RequireAuth("foobar")
			},
			args: args{
				ctx:   context.Background(),
				index: testIndexName,
				key:   []string{"foo"},
			},
			wantErr: true,
		},
		{
			name:   "no keys, noop",
			config: cache.Config{},
			preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				c.Set(context.Background(), &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				})
			},
			args: args{
				ctx:   context.Background(),
				index: testIndexID,
				key:   []string{},
			},
			wantErr: false,
		},
		{
			name:   "delete ID",
			config: cache.Config{},
			preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				c.Set(context.Background(), &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				})
			},
			assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
				obj, ok := c.Get(context.Background(), testIndexID, "one")
				assert.False(t, ok)
				assert.Nil(t, obj)
				// get by name should still work
				obj, ok = c.Get(context.Background(), testIndexName, "foo")
				assert.True(t, ok)
				assert.NotNil(t, obj)
			},
			args: args{
				ctx:   context.Background(),
				index: testIndexID,
				key:   []string{"one"},
			},
			wantErr: false,
		},
		{
			name:   "delete name",
			config: cache.Config{},
			preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				c.Set(context.Background(), &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				})
			},
			assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
				// get by ID should still work
				obj, ok := c.Get(context.Background(), testIndexID, "one")
				assert.True(t, ok)
				assert.NotNil(t, obj)
				obj, ok = c.Get(context.Background(), testIndexName, "foo")
				assert.False(t, ok)
				assert.Nil(t, obj)
			},
			args: args{
				ctx:   context.Background(),
				index: testIndexName,
				key:   []string{"foo"},
			},
			wantErr: false,
		},
		{
			name:   "not found",
			config: cache.Config{},
			preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				c.Set(context.Background(), &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				})
			},
			assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
				obj, ok := c.Get(context.Background(), testIndexID, "one")
				assert.True(t, ok)
				assert.NotNil(t, obj)
				obj, ok = c.Get(context.Background(), testIndexName, "foo")
				assert.True(t, ok)
				assert.NotNil(t, obj)
			},
			args: args{
				ctx:   context.Background(),
				index: testIndexName,
				key:   []string{"spanac"},
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c, server := prepareCache(t, tt.config)
			tt.preparation(t, c, server)
			t.Log(server.Keys())
			err := c.Delete(tt.args.ctx, tt.args.index, tt.args.key...)
			if tt.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			// Verify the post-conditions; cases without assertions only check the error.
			if tt.assertions != nil {
				tt.assertions(t, c)
			}
		})
	}
}
// Test_redisCache_Truncate checks that Truncate leaves none of the previously
// stored objects reachable through either index, and that a connection failure
// is reported as an error. Each case's assertions run after a successful
// Truncate.
func Test_redisCache_Truncate(t *testing.T) {
	type args struct {
		ctx context.Context
	}
	tests := []struct {
		name        string
		config      cache.Config
		preparation func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis)
		assertions  func(t *testing.T, c cache.Cache[testIndex, string, *testObject])
		args        args
		wantErr     bool
	}{
		{
			name:   "connection error",
			config: cache.Config{},
			preparation: func(_ *testing.T, _ cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				s.RequireAuth("foobar")
			},
			args: args{
				ctx: context.Background(),
			},
			wantErr: true,
		},
		{
			name:   "ok",
			config: cache.Config{},
			preparation: func(t *testing.T, c cache.Cache[testIndex, string, *testObject], s *miniredis.Miniredis) {
				c.Set(context.Background(), &testObject{
					ID:   "one",
					Name: []string{"foo", "bar"},
				})
				c.Set(context.Background(), &testObject{
					ID:   "two",
					Name: []string{"Hello", "World"},
				})
			},
			assertions: func(t *testing.T, c cache.Cache[testIndex, string, *testObject]) {
				obj, ok := c.Get(context.Background(), testIndexID, "one")
				assert.False(t, ok)
				assert.Nil(t, obj)
				obj, ok = c.Get(context.Background(), testIndexName, "World")
				assert.False(t, ok)
				assert.Nil(t, obj)
			},
			args: args{
				ctx: context.Background(),
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c, server := prepareCache(t, tt.config)
			tt.preparation(t, c, server)
			t.Log(server.Keys())
			err := c.Truncate(tt.args.ctx)
			if tt.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			// Verify the post-conditions; cases without assertions only check the error.
			if tt.assertions != nil {
				tt.assertions(t, c)
			}
		})
	}
}
// prepareCache starts a miniredis server bound to t, selects testDB on it, and
// connects a Connector over TCP. It returns a cache for testObject built from
// conf (with debug logging forced on) together with the server, so tests can
// inspect or manipulate server state directly. Each option is applied to the
// connection Config before the connector is created. The connector and server
// are closed via t.Cleanup.
func prepareCache(t *testing.T, conf cache.Config, options ...func(*Config)) (cache.Cache[testIndex, string, *testObject], *miniredis.Miniredis) {
	server := miniredis.RunT(t)
	server.Select(testDB)

	conf.Log = &logging.Config{
		Level:     "debug",
		AddSource: true,
	}

	connConfig := Config{
		Enabled:          true,
		Network:          "tcp",
		Addr:             server.Addr(),
		DisableIndentity: true,
	}
	for _, apply := range options {
		apply(&connConfig)
	}

	connector := NewConnector(connConfig)
	t.Cleanup(func() {
		connector.Close()
		server.Close()
	})

	return NewCache[testIndex, string, *testObject](conf, connector, testDB, testIndices), server
}
// withCircuitBreakerOption returns a Config mutator, suitable as a
// prepareCache option, that sets the connection's CircuitBreaker to cb.
func withCircuitBreakerOption(cb *CBConfig) func(*Config) {
	setCircuitBreaker := func(conf *Config) {
		conf.CircuitBreaker = cb
	}
	return setCircuitBreaker
}