feat(storage): generic cache interface (#8628)

# Which Problems Are Solved

We identified the need of caching.
Currently we have a number of places where we use different ways of
caching, like go maps or LRU.
We might also want shared chaches in the future, like Redis-based or in
special SQL tables.

# How the Problems Are Solved

Define a generic Cache interface which allows different implementations.

- A noop implementation is provided and enabled as.
- An implementation using go maps is provided
  - disabled in defaults.yaml
  - enabled in integration tests
- Authz middleware instance objects are cached using the interface.

# Additional Changes

- Enabled integration test command raceflag
- Fix a race condition in the limits integration test client
- Fix a number of flaky integration tests. (Because zitadel is super
fast now!) 🎸 🚀

# Additional Context

Related to https://github.com/zitadel/zitadel/issues/8648
This commit is contained in:
Tim Möhlmann 2024-09-25 22:40:21 +03:00 committed by GitHub
parent a6ea83168d
commit 4eaa3163b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 1290 additions and 78 deletions

View File

@ -135,7 +135,7 @@ core_integration_server_start: core_integration_setup
.PHONY: core_integration_test_packages
core_integration_test_packages:
go test -count 1 -tags integration -timeout 30m $$(go list -tags integration ./... | grep "integration_test")
go test -race -count 1 -tags integration -timeout 30m $$(go list -tags integration ./... | grep "integration_test")
.PHONY: core_integration_server_stop
core_integration_server_stop:

View File

@ -183,6 +183,37 @@ Database:
Cert: # ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_CERT
Key: # ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_KEY
# Caches are EXPERIMENTAL. The following config may have breaking changes in the future.
# If no config is provided, caching is disabled by default.
# Caches:
# Connectors are reused by caches.
# Connectors:
# Memory connector works with local server memory.
# It is the simplest (and probably fastest) cache implementation.
# Unsuitable for deployments with multiple containers,
# as each container's cache may hold a different state of the same object.
# Memory:
# Enabled: true
# AutoPrune removes invalidated or expired object from the cache.
# AutoPrune:
# Interval: 15m
# TimeOut: 30s
# Instance caches auth middleware instances, gettable by domain or ID.
# Instance:
# Connector must be enabled above.
# When connector is empty, this cache will be disabled.
# Connector: "memory"
# MaxAge: 1h
# LastUsage: 10m
#
# Log enables cache-specific logging. Default to error log to stdout when omitted.
# Log:
# Level: debug
# AddSource: true
# Formatter:
# Format: text
Machine:
# Cloud-hosted VMs need to specify their metadata endpoint so that the machine can be uniquely identified.
Identification:

View File

@ -25,6 +25,7 @@ import (
auth_view "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view"
"github.com/zitadel/zitadel/internal/authz"
authz_es "github.com/zitadel/zitadel/internal/authz/repository/eventsourcing/eventstore"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/config/systemdefaults"
crypto_db "github.com/zitadel/zitadel/internal/crypto/database"
@ -71,6 +72,7 @@ type ProjectionsConfig struct {
EncryptionKeys *encryption.EncryptionKeyConfig
SystemAPIUsers map[string]*internal_authz.SystemAPIUser
Eventstore *eventstore.Config
Caches *cache.CachesConfig
Admin admin_es.Config
Auth auth_es.Config
@ -132,6 +134,7 @@ func projections(
esV4.Querier,
client,
client,
config.Caches,
config.Projections,
config.SystemDefaults,
keys.IDPConfig,

View File

@ -15,6 +15,7 @@ import (
internal_authz "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/config/hook"
"github.com/zitadel/zitadel/internal/config/systemdefaults"
@ -30,6 +31,7 @@ import (
type Config struct {
ForMirror bool
Database database.Config
Caches *cache.CachesConfig
SystemDefaults systemdefaults.SystemDefaults
InternalAuthZ internal_authz.Config
ExternalDomain string

View File

@ -309,6 +309,7 @@ func initProjections(
eventstoreV4.Querier,
queryDBClient,
projectionDBClient,
config.Caches,
config.Projections,
config.SystemDefaults,
keys.IDPConfig,

View File

@ -18,6 +18,7 @@ import (
"github.com/zitadel/zitadel/internal/api/ui/console"
"github.com/zitadel/zitadel/internal/api/ui/login"
auth_es "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/config/hook"
"github.com/zitadel/zitadel/internal/config/network"
@ -48,6 +49,7 @@ type Config struct {
HTTP1HostHeader string
WebAuthNName string
Database database.Config
Caches *cache.CachesConfig
Tracing tracing.Config
Metrics metrics.Config
Profiler profiler.Config

View File

@ -184,6 +184,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
eventstoreV4.Querier,
queryDBClient,
projectionDBClient,
config.Caches,
config.Projections,
config.SystemDefaults,
keys.IDPConfig,

View File

@ -123,7 +123,9 @@ func (s *Server) ImportData(ctx context.Context, req *admin_pb.ImportDataRequest
return nil, ctxTimeout.Err()
case result := <-ch:
logging.OnError(result.err).Errorf("error while importing: %v", result.err)
logging.Infof("Import done: %s", result.count.getProgress())
if result.count != nil {
logging.Infof("Import done: %s", result.count.getProgress())
}
return result.ret, result.err
}
} else {

View File

@ -31,11 +31,11 @@ func TestServer_Limits_AuditLogRetention(t *testing.T) {
farPast := timestamppb.New(beforeTime.Add(-10 * time.Hour).UTC())
zeroCounts := &eventCounts{}
seededCount := requireEventually(t, iamOwnerCtx, isoInstance.Client, userID, projectID, appID, projectGrantID, func(c assert.TestingT, counts *eventCounts) {
counts.assertAll(t, c, "seeded events are > 0", assert.Greater, zeroCounts)
counts.assertAll(c, "seeded events are > 0", assert.Greater, zeroCounts)
}, "wait for seeded event assertions to pass")
produceEvents(iamOwnerCtx, t, isoInstance.Client, userID, appID, projectID, projectGrantID)
addedCount := requireEventually(t, iamOwnerCtx, isoInstance.Client, userID, projectID, appID, projectGrantID, func(c assert.TestingT, counts *eventCounts) {
counts.assertAll(t, c, "added events are > seeded events", assert.Greater, seededCount)
counts.assertAll(c, "added events are > seeded events", assert.Greater, seededCount)
}, "wait for added event assertions to pass")
_, err := integration.SystemClient().SetLimits(CTX, &system.SetLimitsRequest{
InstanceId: isoInstance.ID(),
@ -44,8 +44,8 @@ func TestServer_Limits_AuditLogRetention(t *testing.T) {
require.NoError(t, err)
var limitedCounts *eventCounts
requireEventually(t, iamOwnerCtx, isoInstance.Client, userID, projectID, appID, projectGrantID, func(c assert.TestingT, counts *eventCounts) {
counts.assertAll(t, c, "limited events < added events", assert.Less, addedCount)
counts.assertAll(t, c, "limited events > 0", assert.Greater, zeroCounts)
counts.assertAll(c, "limited events < added events", assert.Less, addedCount)
counts.assertAll(c, "limited events > 0", assert.Greater, zeroCounts)
limitedCounts = counts
}, "wait for limited event assertions to pass")
listedEvents, err := isoInstance.Client.Admin.ListEvents(iamOwnerCtx, &admin.ListEventsRequest{CreationDateFilter: &admin.ListEventsRequest_From{
@ -63,7 +63,7 @@ func TestServer_Limits_AuditLogRetention(t *testing.T) {
})
require.NoError(t, err)
requireEventually(t, iamOwnerCtx, isoInstance.Client, userID, projectID, appID, projectGrantID, func(c assert.TestingT, counts *eventCounts) {
counts.assertAll(t, c, "with reset limit, added events are > seeded events", assert.Greater, seededCount)
counts.assertAll(c, "with reset limit, added events are > seeded events", assert.Greater, seededCount)
}, "wait for reset event assertions to pass")
}
@ -77,7 +77,7 @@ func requireEventually(
) (counts *eventCounts) {
countTimeout := 30 * time.Second
assertTimeout := countTimeout + time.Second
countCtx, cancel := context.WithTimeout(ctx, countTimeout)
countCtx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
require.EventuallyWithT(t, func(c *assert.CollectT) {
counts = countEvents(countCtx, c, cc, userID, projectID, appID, projectGrantID)
@ -168,63 +168,77 @@ type eventCounts struct {
all, myUser, aUser, grant, project, app, org int
}
func (e *eventCounts) assertAll(t *testing.T, c assert.TestingT, name string, compare assert.ComparisonAssertionFunc, than *eventCounts) {
t.Run(name, func(t *testing.T) {
compare(c, e.all, than.all, "ListEvents")
compare(c, e.myUser, than.myUser, "ListMyUserChanges")
compare(c, e.aUser, than.aUser, "ListUserChanges")
compare(c, e.grant, than.grant, "ListProjectGrantChanges")
compare(c, e.project, than.project, "ListProjectChanges")
compare(c, e.app, than.app, "ListAppChanges")
compare(c, e.org, than.org, "ListOrgChanges")
})
func (e *eventCounts) assertAll(c assert.TestingT, name string, compare assert.ComparisonAssertionFunc, than *eventCounts) {
compare(c, e.all, than.all, name+"ListEvents")
compare(c, e.myUser, than.myUser, name+"ListMyUserChanges")
compare(c, e.aUser, than.aUser, name+"ListUserChanges")
compare(c, e.grant, than.grant, name+"ListProjectGrantChanges")
compare(c, e.project, than.project, name+"ListProjectChanges")
compare(c, e.app, than.app, name+"ListAppChanges")
compare(c, e.org, than.org, name+"ListOrgChanges")
}
func countEvents(ctx context.Context, t assert.TestingT, cc *integration.Client, userID, projectID, appID, grantID string) *eventCounts {
counts := new(eventCounts)
var wg sync.WaitGroup
wg.Add(7)
var mutex sync.Mutex
assertResultLocked := func(err error, f func(counts *eventCounts)) {
mutex.Lock()
assert.NoError(t, err)
f(counts)
mutex.Unlock()
}
go func() {
defer wg.Done()
result, err := cc.Admin.ListEvents(ctx, &admin.ListEventsRequest{})
assert.NoError(t, err)
counts.all = len(result.GetEvents())
assertResultLocked(err, func(counts *eventCounts) {
counts.all = len(result.GetEvents())
})
}()
go func() {
defer wg.Done()
result, err := cc.Auth.ListMyUserChanges(ctx, &auth.ListMyUserChangesRequest{})
assert.NoError(t, err)
counts.myUser = len(result.GetResult())
assertResultLocked(err, func(counts *eventCounts) {
counts.myUser = len(result.GetResult())
})
}()
go func() {
defer wg.Done()
result, err := cc.Mgmt.ListUserChanges(ctx, &management.ListUserChangesRequest{UserId: userID})
assert.NoError(t, err)
counts.aUser = len(result.GetResult())
assertResultLocked(err, func(counts *eventCounts) {
counts.aUser = len(result.GetResult())
})
}()
go func() {
defer wg.Done()
result, err := cc.Mgmt.ListAppChanges(ctx, &management.ListAppChangesRequest{ProjectId: projectID, AppId: appID})
assert.NoError(t, err)
counts.app = len(result.GetResult())
assertResultLocked(err, func(counts *eventCounts) {
counts.app = len(result.GetResult())
})
}()
go func() {
defer wg.Done()
result, err := cc.Mgmt.ListOrgChanges(ctx, &management.ListOrgChangesRequest{})
assert.NoError(t, err)
counts.org = len(result.GetResult())
assertResultLocked(err, func(counts *eventCounts) {
counts.org = len(result.GetResult())
})
}()
go func() {
defer wg.Done()
result, err := cc.Mgmt.ListProjectChanges(ctx, &management.ListProjectChangesRequest{ProjectId: projectID})
assert.NoError(t, err)
counts.project = len(result.GetResult())
assertResultLocked(err, func(counts *eventCounts) {
counts.project = len(result.GetResult())
})
}()
go func() {
defer wg.Done()
result, err := cc.Mgmt.ListProjectGrantChanges(ctx, &management.ListProjectGrantChangesRequest{ProjectId: projectID, GrantId: grantID})
assert.NoError(t, err)
counts.grant = len(result.GetResult())
assertResultLocked(err, func(counts *eventCounts) {
counts.grant = len(result.GetResult())
})
}()
wg.Wait()
return counts

View File

@ -11,6 +11,7 @@ import (
"github.com/zitadel/oidc/v3/pkg/client/profile"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
oidc_api "github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/domain"
@ -98,13 +99,19 @@ func TestServer_JWTProfile(t *testing.T) {
tokenSource, err := profile.NewJWTProfileTokenSourceFromKeyFileData(CTX, Instance.OIDCIssuer(), tt.keyData, tt.scope)
require.NoError(t, err)
tokens, err := tokenSource.TokenCtx(CTX)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, tokens)
var tokens *oauth2.Token
require.EventuallyWithT(
t, func(collect *assert.CollectT) {
tokens, err = tokenSource.TokenCtx(CTX)
if tt.wantErr {
assert.Error(collect, err)
return
}
assert.NoError(collect, err)
assert.NotNil(collect, tokens)
},
time.Minute, time.Second,
)
provider, err := rp.NewRelyingPartyOIDC(CTX, Instance.OIDCIssuer(), "", "", redirectURI, tt.scope)
require.NoError(t, err)

106
internal/cache/cache.go vendored Normal file
View File

@ -0,0 +1,106 @@
// Package cache provides abstraction of cache implementations that can be used by zitadel.
package cache
import (
"context"
"time"
"github.com/zitadel/logging"
)
// Cache stores objects with a value of type `V`.
// Objects may be referred to by one or more indices.
// Implementations may encode the value for storage.
// This means non-exported fields may be lost and objects
// with function values may fail to encode.
// See https://pkg.go.dev/encoding/json#Marshal for example.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
// Indices are defined when calling the constructor of an implementation of this interface.
// It is illegal to refer to an idex not defined during construction.
//
// `K` is the type used as key in each index.
// Due to the limitations in type constraints, all indices use the same key type.
//
// Implementations are free to use stricter type constraints or fixed typing.
type Cache[I, K comparable, V Entry[I, K]] interface {
// Get an object through specified index.
// An [IndexUnknownError] may be returned if the index is unknown.
// [ErrCacheMiss] is returned if the key was not found in the index,
// or the object is not valid.
Get(ctx context.Context, index I, key K) (V, bool)
// Set an object.
// Keys are created on each index based in the [Entry.Keys] method.
// If any key maps to an existing object, the object is invalidated,
// regardless if the object has other keys defined in the new entry.
// This to prevent ghost objects when an entry reduces the amount of keys
// for a given index.
Set(ctx context.Context, value V)
// Invalidate an object through specified index.
// Implementations may choose to instantly delete the object,
// defer until prune or a separate cleanup routine.
// Invalidated object are no longer returned from Get.
// It is safe to call Invalidate multiple times or on non-existing entries.
Invalidate(ctx context.Context, index I, key ...K) error
// Delete one or more keys from a specific index.
// An [IndexUnknownError] may be returned if the index is unknown.
// The referred object is not invalidated and may still be accessible though
// other indices and keys.
// It is safe to call Delete multiple times or on non-existing entries
Delete(ctx context.Context, index I, key ...K) error
// Truncate deletes all cached objects.
Truncate(ctx context.Context) error
// Close the cache. Subsequent calls to the cache are not allowed.
Close(ctx context.Context) error
}
// Entry contains a value of type `V` to be cached.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
//
// `K` is the type used as key in an index.
// Due to the limitations in type constraints, all indices use the same key type.
type Entry[I, K comparable] interface {
// Keys returns which keys map to the object in a specified index.
// May return nil if the index in unknown or when there are no keys.
Keys(index I) (key []K)
}
type CachesConfig struct {
Connectors struct {
Memory MemoryConnectorConfig
// SQL database.Config
// Redis redis.Config?
}
Instance *CacheConfig
}
type CacheConfig struct {
Connector string
// Age since an object was added to the cache,
// after which the object is considered invalid.
// 0 disables max age checks.
MaxAge time.Duration
// Age since last use (Get) of an object,
// after which the object is considered invalid.
// 0 disables last use age checks.
LastUseAge time.Duration
// Log allows logging of the specific cache.
// By default only errors are logged to stdout.
Log *logging.Config
}
type MemoryConnectorConfig struct {
Enabled bool
AutoPrune AutoPruneConfig
}

29
internal/cache/error.go vendored Normal file
View File

@ -0,0 +1,29 @@
package cache
import (
"errors"
"fmt"
)
type IndexUnknownError[I comparable] struct {
index I
}
func NewIndexUnknownErr[I comparable](index I) error {
return IndexUnknownError[I]{index}
}
func (i IndexUnknownError[I]) Error() string {
return fmt.Sprintf("index %v unknown", i.index)
}
func (a IndexUnknownError[I]) Is(err error) bool {
if b, ok := err.(IndexUnknownError[I]); ok {
return a.index == b.index
}
return false
}
var (
ErrCacheMiss = errors.New("cache miss")
)

204
internal/cache/gomap/gomap.go vendored Normal file
View File

@ -0,0 +1,204 @@
package gomap
import (
"context"
"errors"
"log/slog"
"maps"
"os"
"sync"
"sync/atomic"
"time"
"github.com/zitadel/zitadel/internal/cache"
)
type mapCache[I, K comparable, V cache.Entry[I, K]] struct {
config *cache.CacheConfig
indexMap map[I]*index[K, V]
logger *slog.Logger
}
// NewCache returns an in-memory Cache implementation based on the builtin go map type.
// Object values are stored as-is and there is no encoding or decoding involved.
func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.CacheConfig) cache.PrunerCache[I, K, V] {
m := &mapCache[I, K, V]{
config: &config,
indexMap: make(map[I]*index[K, V], len(indices)),
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelError,
})),
}
if config.Log != nil {
m.logger = config.Log.Slog()
}
m.logger.InfoContext(background, "map cache logging enabled")
for _, name := range indices {
m.indexMap[name] = &index[K, V]{
config: m.config,
entries: make(map[K]*entry[V]),
}
}
return m
}
func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
i, ok := c.indexMap[index]
if !ok {
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
entry, err := i.Get(key)
if err == nil {
c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key)
return entry.value, true
}
if errors.Is(err, cache.ErrCacheMiss) {
c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key)
return value, false
}
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) {
now := time.Now()
entry := &entry[V]{
value: value,
created: now,
}
entry.lastUse.Store(now.UnixMicro())
for name, i := range c.indexMap {
keys := value.Keys(name)
i.Set(keys, entry)
c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys)
}
}
func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Invalidate(keys)
c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Delete(keys)
c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
for name, index := range c.indexMap {
index.Prune()
c.logger.DebugContext(ctx, "map cache prune", "index", name)
}
return nil
}
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
for name, index := range c.indexMap {
index.Truncate()
c.logger.DebugContext(ctx, "map cache clear", "index", name)
}
return nil
}
func (c *mapCache[I, K, V]) Close(ctx context.Context) error {
return ctx.Err()
}
type index[K comparable, V any] struct {
mutex sync.RWMutex
config *cache.CacheConfig
entries map[K]*entry[V]
}
func (i *index[K, V]) Get(key K) (*entry[V], error) {
i.mutex.RLock()
entry, ok := i.entries[key]
i.mutex.RUnlock()
if ok && entry.isValid(i.config) {
return entry, nil
}
return nil, cache.ErrCacheMiss
}
func (c *index[K, V]) Set(keys []K, entry *entry[V]) {
c.mutex.Lock()
for _, key := range keys {
c.entries[key] = entry
}
c.mutex.Unlock()
}
func (i *index[K, V]) Invalidate(keys []K) {
i.mutex.RLock()
for _, key := range keys {
if entry, ok := i.entries[key]; ok {
entry.invalid.Store(true)
}
}
i.mutex.RUnlock()
}
func (c *index[K, V]) Delete(keys []K) {
c.mutex.Lock()
for _, key := range keys {
delete(c.entries, key)
}
c.mutex.Unlock()
}
func (c *index[K, V]) Prune() {
c.mutex.Lock()
maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool {
return !entry.isValid(c.config)
})
c.mutex.Unlock()
}
func (c *index[K, V]) Truncate() {
c.mutex.Lock()
c.entries = make(map[K]*entry[V])
c.mutex.Unlock()
}
type entry[V any] struct {
value V
created time.Time
invalid atomic.Bool
lastUse atomic.Int64 // UnixMicro time
}
func (e *entry[V]) isValid(c *cache.CacheConfig) bool {
if e.invalid.Load() {
return false
}
now := time.Now()
if c.MaxAge > 0 {
if e.created.Add(c.MaxAge).Before(now) {
e.invalid.Store(true)
return false
}
}
if c.LastUseAge > 0 {
lastUse := e.lastUse.Load()
if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) {
e.invalid.Store(true)
return false
}
e.lastUse.CompareAndSwap(lastUse, now.UnixMicro())
}
return true
}

334
internal/cache/gomap/gomap_test.go vendored Normal file
View File

@ -0,0 +1,334 @@
package gomap
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/cache"
)
type testIndex int
const (
testIndexID testIndex = iota
testIndexName
)
var testIndices = []testIndex{
testIndexID,
testIndexName,
}
type testObject struct {
id string
names []string
}
func (o *testObject) Keys(index testIndex) []string {
switch index {
case testIndexID:
return []string{o.id}
case testIndexName:
return o.names
default:
return nil
}
}
func Test_mapCache_Get(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.CacheConfig{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
defer c.Close(context.Background())
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
type args struct {
index testIndex
key string
}
tests := []struct {
name string
args args
want *testObject
wantOk bool
}{
{
name: "ok",
args: args{
index: testIndexID,
key: "id",
},
want: obj,
wantOk: true,
},
{
name: "miss",
args: args{
index: testIndexID,
key: "spanac",
},
want: nil,
wantOk: false,
},
{
name: "unknown index",
args: args{
index: 99,
key: "id",
},
want: nil,
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.wantOk, ok)
})
}
}
func Test_mapCache_Invalidate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.CacheConfig{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
defer c.Close(context.Background())
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
got, ok := c.Get(context.Background(), testIndexID, "id")
assert.Nil(t, got)
assert.False(t, ok)
}
func Test_mapCache_Delete(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.CacheConfig{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
defer c.Close(context.Background())
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Delete(context.Background(), testIndexName, "bar")
require.NoError(t, err)
// Shouldn't find object by deleted name
got, ok := c.Get(context.Background(), testIndexName, "bar")
assert.Nil(t, got)
assert.False(t, ok)
// Should find object by other name
got, ok = c.Get(context.Background(), testIndexName, "foo")
assert.Equal(t, obj, got)
assert.True(t, ok)
// Should find object by id
got, ok = c.Get(context.Background(), testIndexID, "id")
assert.Equal(t, obj, got)
assert.True(t, ok)
}
func Test_mapCache_Prune(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.CacheConfig{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
defer c.Close(context.Background())
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
// invalidate one entry
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
err = c.(cache.Pruner).Prune(context.Background())
require.NoError(t, err)
// Other object should still be found
got, ok := c.Get(context.Background(), testIndexID, "id2")
assert.Equal(t, objects[1], got)
assert.True(t, ok)
}
func Test_mapCache_Truncate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.CacheConfig{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
defer c.Close(context.Background())
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
err := c.Truncate(context.Background())
require.NoError(t, err)
mc := c.(*mapCache[testIndex, string, *testObject])
for _, index := range mc.indexMap {
index.mutex.RLock()
assert.Len(t, index.entries, 0)
index.mutex.RUnlock()
}
}
func Test_entry_isValid(t *testing.T) {
type fields struct {
created time.Time
invalid bool
lastUse time.Time
}
tests := []struct {
name string
fields fields
config *cache.CacheConfig
want bool
}{
{
name: "invalid",
fields: fields{
created: time.Now(),
invalid: true,
lastUse: time.Now(),
},
config: &cache.CacheConfig{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.CacheConfig{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.CacheConfig{
LastUseAge: time.Second,
},
want: true,
},
{
name: "last use age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.CacheConfig{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "last use age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.CacheConfig{
MaxAge: time.Minute,
},
want: true,
},
{
name: "valid",
fields: fields{
created: time.Now(),
invalid: false,
lastUse: time.Now(),
},
config: &cache.CacheConfig{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &entry[any]{
created: tt.fields.created,
}
e.invalid.Store(tt.fields.invalid)
e.lastUse.Store(tt.fields.lastUse.UnixMicro())
got := e.isValid(tt.config)
assert.Equal(t, tt.want, got)
})
}
}

22
internal/cache/noop/noop.go vendored Normal file
View File

@ -0,0 +1,22 @@
package noop
import (
"context"
"github.com/zitadel/zitadel/internal/cache"
)
type noop[I, K comparable, V cache.Entry[I, K]] struct{}
// NewCache returns a cache that does nothing
func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] {
return noop[I, K, V]{}
}
func (noop[I, K, V]) Set(context.Context, V) {}
func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return }
func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }
func (noop[I, K, V]) Close(context.Context) (err error) { return }

76
internal/cache/pruner.go vendored Normal file
View File

@ -0,0 +1,76 @@
package cache
import (
"context"
"math/rand"
"time"
"github.com/jonboulle/clockwork"
"github.com/zitadel/logging"
)
// Pruner is an optional [Cache] interface.
type Pruner interface {
// Prune deletes all invalidated or expired objects.
Prune(ctx context.Context) error
}
type PrunerCache[I, K comparable, V Entry[I, K]] interface {
Cache[I, K, V]
Pruner
}
type AutoPruneConfig struct {
// Interval at which the cache is automatically pruned.
// 0 or lower disables automatic pruning.
Interval time.Duration
// Timeout for an automatic prune.
// It is recommended to keep the value shorter than AutoPruneInterval
// 0 or lower disables automatic pruning.
Timeout time.Duration
}
func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, name string) (close func()) {
return c.startAutoPrune(background, pruner, name, clockwork.NewRealClock())
}
func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, name string, clock clockwork.Clock) (close func()) {
if c.Interval <= 0 {
return func() {}
}
background, cancel := context.WithCancel(background)
// randomize the first interval
timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval))))
go c.pruneTimer(background, pruner, name, timer)
return cancel
}
func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, name string, timer clockwork.Timer) {
defer func() {
if !timer.Stop() {
<-timer.Chan()
}
}()
for {
select {
case <-background.Done():
return
case <-timer.Chan():
timer.Reset(c.Interval)
err := c.doPrune(background, pruner)
logging.OnError(err).WithField("name", name).Error("cache auto prune")
}
}
}
func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error {
ctx, cancel := context.WithCancel(background)
defer cancel()
if c.Timeout > 0 {
ctx, cancel = context.WithTimeout(background, c.Timeout)
defer cancel()
}
return pruner.Prune(ctx)
}

43
internal/cache/pruner_test.go vendored Normal file
View File

@ -0,0 +1,43 @@
package cache
import (
"context"
"testing"
"time"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
)
type testPruner struct {
called chan struct{}
}
func (p *testPruner) Prune(context.Context) error {
p.called <- struct{}{}
return nil
}
func TestAutoPruneConfig_startAutoPrune(t *testing.T) {
c := AutoPruneConfig{
Interval: time.Second,
Timeout: time.Millisecond,
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
pruner := testPruner{
called: make(chan struct{}),
}
clock := clockwork.NewFakeClock()
close := c.startAutoPrune(ctx, &pruner, "foo", clock)
defer close()
clock.Advance(time.Second)
select {
case _, ok := <-pruner.called:
assert.True(t, ok)
case <-ctx.Done():
t.Fatal(ctx.Err())
}
}

View File

@ -39,9 +39,9 @@ func failureFromEvent(event eventstore.Event, err error) *failure {
func failureFromStatement(statement *Statement, err error) *failure {
return &failure{
sequence: statement.Sequence,
instance: statement.InstanceID,
aggregateID: statement.AggregateID,
aggregateType: statement.AggregateType,
instance: statement.Aggregate.InstanceID,
aggregateID: statement.Aggregate.ID,
aggregateType: statement.Aggregate.Type,
eventDate: statement.CreationDate,
err: err,
}

View File

@ -62,6 +62,7 @@ type Handler struct {
triggeredInstancesSync sync.Map
triggerWithoutEvents Reduce
cacheInvalidations []func(ctx context.Context, aggregates []*eventstore.Aggregate)
}
var _ migration.Migration = (*Handler)(nil)
@ -418,6 +419,12 @@ func (h *Handler) Trigger(ctx context.Context, opts ...TriggerOpt) (_ context.Co
}
}
// RegisterCacheInvalidation registers a function to be called when a cache needs to be invalidated.
// In order to avoid race conditions, this method must be called before [Handler.Start] is called.
func (h *Handler) RegisterCacheInvalidation(invalidate func(ctx context.Context, aggregates []*eventstore.Aggregate)) {
h.cacheInvalidations = append(h.cacheInvalidations, invalidate)
}
// lockInstance tries to lock the instance.
// If the instance is already locked from another process no cancel function is returned
// the instance can be skipped then
@ -486,10 +493,6 @@ func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (add
h.log().OnError(rollbackErr).Debug("unable to rollback tx")
return
}
commitErr := tx.Commit()
if err == nil {
err = commitErr
}
}()
currentState, err := h.currentState(ctx, tx, config)
@ -509,6 +512,17 @@ func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (add
if err != nil {
return additionalIteration, err
}
defer func() {
commitErr := tx.Commit()
if err == nil {
err = commitErr
}
if err == nil && currentState.aggregateID != "" && len(statements) > 0 {
h.invalidateCaches(ctx, aggregatesFromStatements(statements))
}
}()
if len(statements) == 0 {
err = h.setState(tx, currentState)
return additionalIteration, err
@ -522,8 +536,8 @@ func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (add
currentState.position = statements[lastProcessedIndex].Position
currentState.offset = statements[lastProcessedIndex].offset
currentState.aggregateID = statements[lastProcessedIndex].AggregateID
currentState.aggregateType = statements[lastProcessedIndex].AggregateType
currentState.aggregateID = statements[lastProcessedIndex].Aggregate.ID
currentState.aggregateType = statements[lastProcessedIndex].Aggregate.Type
currentState.sequence = statements[lastProcessedIndex].Sequence
currentState.eventTimestamp = statements[lastProcessedIndex].CreationDate
err = h.setState(tx, currentState)
@ -556,8 +570,8 @@ func (h *Handler) generateStatements(ctx context.Context, tx *sql.Tx, currentSta
if idx+1 == len(statements) {
currentState.position = statements[len(statements)-1].Position
currentState.offset = statements[len(statements)-1].offset
currentState.aggregateID = statements[len(statements)-1].AggregateID
currentState.aggregateType = statements[len(statements)-1].AggregateType
currentState.aggregateID = statements[len(statements)-1].Aggregate.ID
currentState.aggregateType = statements[len(statements)-1].Aggregate.Type
currentState.sequence = statements[len(statements)-1].Sequence
currentState.eventTimestamp = statements[len(statements)-1].CreationDate
@ -577,8 +591,8 @@ func (h *Handler) generateStatements(ctx context.Context, tx *sql.Tx, currentSta
func skipPreviouslyReducedStatements(statements []*Statement, currentState *state) int {
for i, statement := range statements {
if statement.Position == currentState.position &&
statement.AggregateID == currentState.aggregateID &&
statement.AggregateType == currentState.aggregateType &&
statement.Aggregate.ID == currentState.aggregateID &&
statement.Aggregate.Type == currentState.aggregateType &&
statement.Sequence == currentState.sequence {
return i
}
@ -667,3 +681,34 @@ func (h *Handler) eventQuery(currentState *state) *eventstore.SearchQueryBuilder
func (h *Handler) ProjectionName() string {
return h.projection.Name()
}
func (h *Handler) invalidateCaches(ctx context.Context, aggregates []*eventstore.Aggregate) {
if len(h.cacheInvalidations) == 0 {
return
}
var wg sync.WaitGroup
wg.Add(len(h.cacheInvalidations))
for _, invalidate := range h.cacheInvalidations {
go func(invalidate func(context.Context, []*eventstore.Aggregate)) {
defer wg.Done()
invalidate(ctx, aggregates)
}(invalidate)
}
wg.Wait()
}
// aggregatesFromStatements returns the unique aggregates from statements.
// Duplicate aggregates are omitted.
func aggregatesFromStatements(statements []*Statement) []*eventstore.Aggregate {
aggregates := make([]*eventstore.Aggregate, 0, len(statements))
for _, statement := range statements {
if !slices.ContainsFunc(aggregates, func(aggregate *eventstore.Aggregate) bool {
return *statement.Aggregate == *aggregate
}) {
aggregates = append(aggregates, statement.Aggregate)
}
}
return aggregates
}

View File

@ -80,12 +80,10 @@ func (h *Handler) reduce(event eventstore.Event) (*Statement, error) {
}
type Statement struct {
AggregateType eventstore.AggregateType
AggregateID string
Sequence uint64
Position float64
CreationDate time.Time
InstanceID string
Aggregate *eventstore.Aggregate
Sequence uint64
Position float64
CreationDate time.Time
offset uint32
@ -108,13 +106,11 @@ var (
func NewStatement(event eventstore.Event, e Exec) *Statement {
return &Statement{
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
Position: event.Position(),
AggregateID: event.Aggregate().ID,
CreationDate: event.CreatedAt(),
InstanceID: event.Aggregate().InstanceID,
Execute: e,
Aggregate: event.Aggregate(),
Sequence: event.Sequence(),
Position: event.Position(),
CreationDate: event.CreatedAt(),
Execute: e,
}
}

View File

@ -6,6 +6,23 @@ ExternalSecure: false
TLS:
Enabled: false
Caches:
Connectors:
Memory:
Enabled: true
AutoPrune:
Interval: 30s
TimeOut: 1s
Instance:
Connector: "memory"
MaxAge: 1m
LastUsage: 30s
Log:
Level: info
AddSource: true
Formatter:
Format: text
Quotas:
Access:
Enabled: true

95
internal/query/cache.go Normal file
View File

@ -0,0 +1,95 @@
package query
import (
"context"
"fmt"
"strings"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/gomap"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/eventstore"
)
type Caches struct {
connectors *cacheConnectors
instance cache.Cache[instanceIndex, string, *authzInstance]
}
func startCaches(background context.Context, conf *cache.CachesConfig) (_ *Caches, err error) {
caches := &Caches{
instance: noop.NewCache[instanceIndex, string, *authzInstance](),
}
if conf == nil {
return caches, nil
}
caches.connectors, err = startCacheConnectors(background, conf)
if err != nil {
return nil, err
}
caches.instance, err = startCache[instanceIndex, string, *authzInstance](background, instanceIndexValues(), "authz_instance", conf.Instance, caches.connectors)
if err != nil {
return nil, err
}
caches.registerInstanceInvalidation()
return caches, nil
}
type cacheConnectors struct {
memory *cache.AutoPruneConfig
// pool *pgxpool.Pool
}
func startCacheConnectors(_ context.Context, conf *cache.CachesConfig) (*cacheConnectors, error) {
connectors := new(cacheConnectors)
if conf.Connectors.Memory.Enabled {
connectors.memory = &conf.Connectors.Memory.AutoPrune
}
return connectors, nil
}
func startCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, name string, conf *cache.CacheConfig, connectors *cacheConnectors) (cache.Cache[I, K, V], error) {
if conf == nil || conf.Connector == "" {
return noop.NewCache[I, K, V](), nil
}
if strings.EqualFold(conf.Connector, "memory") && connectors.memory != nil {
c := gomap.NewCache[I, K, V](background, indices, *conf)
connectors.memory.StartAutoPrune(background, c, name)
return c, nil
}
/* TODO
if strings.EqualFold(conf.Connector, "sql") && connectors.pool != nil {
return ...
}
*/
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
}
type invalidator[I comparable] interface {
Invalidate(ctx context.Context, index I, key ...string) error
}
func cacheInvalidationFunc[I comparable](cache invalidator[I], index I, getID func(*eventstore.Aggregate) string) func(context.Context, []*eventstore.Aggregate) {
return func(ctx context.Context, aggregates []*eventstore.Aggregate) {
ids := make([]string, len(aggregates))
for i, aggregate := range aggregates {
ids[i] = getID(aggregate)
}
err := cache.Invalidate(ctx, index, ids...)
logging.OnError(err).Warn("cache invalidation failed")
}
}
func getAggregateID(aggregate *eventstore.Aggregate) string {
return aggregate.ID
}
func getResourceOwner(aggregate *eventstore.Aggregate) string {
return aggregate.ResourceOwner
}

View File

@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"slices"
"strings"
"time"
@ -17,6 +18,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/feature"
"github.com/zitadel/zitadel/internal/query/projection"
@ -206,22 +208,35 @@ func (q *Queries) InstanceByHost(ctx context.Context, instanceHost, publicHost s
instanceDomain := strings.Split(instanceHost, ":")[0] // remove possible port
publicDomain := strings.Split(publicHost, ":")[0] // remove possible port
instance, scan := scanAuthzInstance()
// in case public domain is the same as the instance domain, we do not need to check it
// and can empty it for the check
if instanceDomain == publicDomain {
publicDomain = ""
instance, ok := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain)
if ok {
return instance, instance.checkDomain(instanceDomain, publicDomain)
}
err = q.client.QueryRowContext(ctx, scan, instanceByDomainQuery, instanceDomain, publicDomain)
return instance, err
instance, scan := scanAuthzInstance()
if err = q.client.QueryRowContext(ctx, scan, instanceByDomainQuery, instanceDomain); err != nil {
return nil, err
}
q.caches.instance.Set(ctx, instance)
return instance, instance.checkDomain(instanceDomain, publicDomain)
}
func (q *Queries) InstanceByID(ctx context.Context, id string) (_ authz.Instance, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
instance, ok := q.caches.instance.Get(ctx, instanceIndexByID, id)
if ok {
return instance, nil
}
instance, scan := scanAuthzInstance()
err = q.client.QueryRowContext(ctx, scan, instanceByIDQuery, id)
logging.OnError(err).WithField("instance_id", id).Warn("instance by ID")
if err == nil {
q.caches.instance.Set(ctx, instance)
}
return instance, err
}
@ -431,6 +446,8 @@ type authzInstance struct {
block *bool
auditLogRetention *time.Duration
features feature.Features
externalDomains database.TextArray[string]
trustedDomains database.TextArray[string]
}
type csp struct {
@ -485,6 +502,31 @@ func (i *authzInstance) Features() feature.Features {
return i.features
}
var errPublicDomain = "public domain %q not trusted"
func (i *authzInstance) checkDomain(instanceDomain, publicDomain string) error {
// in case public domain is empty, or the same as the instance domain, we do not need to check it
if publicDomain == "" || instanceDomain == publicDomain {
return nil
}
if !slices.Contains(i.trustedDomains, publicDomain) {
return zerrors.ThrowNotFound(fmt.Errorf(errPublicDomain, publicDomain), "QUERY-IuGh1", "Errors.IAM.NotFound")
}
return nil
}
// Keys implements [cache.Entry]
func (i *authzInstance) Keys(index instanceIndex) []string {
switch index {
case instanceIndexByID:
return []string{i.id}
case instanceIndexByHost:
return i.externalDomains
default:
return nil
}
}
func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
instance := &authzInstance{}
return instance, func(row *sql.Row) error {
@ -509,6 +551,8 @@ func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
&auditLogRetention,
&block,
&features,
&instance.externalDomains,
&instance.trustedDomains,
)
if errors.Is(err, sql.ErrNoRows) {
return zerrors.ThrowNotFound(nil, "QUERY-1kIjX", "Errors.IAM.NotFound")
@ -534,3 +578,30 @@ func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
return nil
}
}
func (c *Caches) registerInstanceInvalidation() {
invalidate := cacheInvalidationFunc(c.instance, instanceIndexByID, getAggregateID)
projection.InstanceProjection.RegisterCacheInvalidation(invalidate)
projection.InstanceDomainProjection.RegisterCacheInvalidation(invalidate)
projection.InstanceFeatureProjection.RegisterCacheInvalidation(invalidate)
projection.InstanceTrustedDomainProjection.RegisterCacheInvalidation(invalidate)
projection.SecurityPolicyProjection.RegisterCacheInvalidation(invalidate)
// limits uses own aggregate ID, invalidate using resource owner.
invalidate = cacheInvalidationFunc(c.instance, instanceIndexByID, getResourceOwner)
projection.LimitsProjection.RegisterCacheInvalidation(invalidate)
// System feature update should invalidate all instances, so Truncate the cache.
projection.SystemFeatureProjection.RegisterCacheInvalidation(func(ctx context.Context, _ []*eventstore.Aggregate) {
err := c.instance.Truncate(ctx)
logging.OnError(err).Warn("cache truncate failed")
})
}
type instanceIndex int16
//go:generate enumer -type instanceIndex
const (
instanceIndexByID instanceIndex = iota
instanceIndexByHost
)

View File

@ -14,6 +14,16 @@ with domain as (
cross join projections.system_features s
full outer join instance_features i using (instance_id, key)
group by instance_id
), external_domains as (
select ed.instance_id, array_agg(ed.domain) as domains
from domain d
join projections.instance_domains ed on d.instance_id = ed.instance_id
group by ed.instance_id
), trusted_domains as (
select td.instance_id, array_agg(td.domain) as domains
from domain d
join projections.instance_trusted_domains td on d.instance_id = td.instance_id
group by td.instance_id
)
select
i.id,
@ -27,11 +37,13 @@ select
s.enable_impersonation,
l.audit_log_retention,
l.block,
f.features
f.features,
ed.domains as external_domains,
td.domains as trusted_domains
from domain d
join projections.instances i on i.id = d.instance_id
left join projections.instance_trusted_domains td on i.id = td.instance_id
left join projections.security_policies2 s on i.id = s.instance_id
left join projections.limits l on i.id = l.instance_id
left join features f on i.id = f.instance_id
where case when $2 = '' then true else td.domain = $2 end;
left join external_domains ed on i.id = ed.instance_id
left join trusted_domains td on i.id = td.instance_id;

View File

@ -7,6 +7,16 @@ with features as (
cross join projections.system_features s
full outer join projections.instance_features2 i using (key, instance_id)
group by instance_id
), external_domains as (
select instance_id, array_agg(domain) as domains
from projections.instance_domains
where instance_id = $1
group by instance_id
), trusted_domains as (
select instance_id, array_agg(domain) as domains
from projections.instance_trusted_domains
where instance_id = $1
group by instance_id
)
select
i.id,
@ -20,9 +30,13 @@ select
s.enable_impersonation,
l.audit_log_retention,
l.block,
f.features
f.features,
ed.domains as external_domains,
td.domains as trusted_domains
from projections.instances i
left join projections.security_policies2 s on i.id = s.instance_id
left join projections.limits l on i.id = l.instance_id
left join features f on i.id = f.instance_id
left join external_domains ed on i.id = ed.instance_id
left join trusted_domains td on i.id = td.instance_id
where i.id = $1;

View File

@ -0,0 +1,78 @@
// Code generated by "enumer -type instanceIndex"; DO NOT EDIT.
package query
import (
"fmt"
"strings"
)
const _instanceIndexName = "instanceIndexByIDinstanceIndexByHost"
var _instanceIndexIndex = [...]uint8{0, 17, 36}
const _instanceIndexLowerName = "instanceindexbyidinstanceindexbyhost"
func (i instanceIndex) String() string {
if i < 0 || i >= instanceIndex(len(_instanceIndexIndex)-1) {
return fmt.Sprintf("instanceIndex(%d)", i)
}
return _instanceIndexName[_instanceIndexIndex[i]:_instanceIndexIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _instanceIndexNoOp() {
var x [1]struct{}
_ = x[instanceIndexByID-(0)]
_ = x[instanceIndexByHost-(1)]
}
var _instanceIndexValues = []instanceIndex{instanceIndexByID, instanceIndexByHost}
var _instanceIndexNameToValueMap = map[string]instanceIndex{
_instanceIndexName[0:17]: instanceIndexByID,
_instanceIndexLowerName[0:17]: instanceIndexByID,
_instanceIndexName[17:36]: instanceIndexByHost,
_instanceIndexLowerName[17:36]: instanceIndexByHost,
}
var _instanceIndexNames = []string{
_instanceIndexName[0:17],
_instanceIndexName[17:36],
}
// instanceIndexString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func instanceIndexString(s string) (instanceIndex, error) {
if val, ok := _instanceIndexNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _instanceIndexNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to instanceIndex values", s)
}
// instanceIndexValues returns all values of the enum
func instanceIndexValues() []instanceIndex {
return _instanceIndexValues
}
// instanceIndexStrings returns a slice of all String values of the enum
func instanceIndexStrings() []string {
strs := make([]string, len(_instanceIndexNames))
copy(strs, _instanceIndexNames)
return strs
}
// IsAinstanceIndex returns "true" if the value is listed in the enum definition. "false" otherwise
func (i instanceIndex) IsAinstanceIndex() bool {
for _, v := range _instanceIndexValues {
if i == v {
return true
}
}
return false
}

View File

@ -74,8 +74,8 @@ func assertReduce(t *testing.T, stmt *handler.Statement, err error, projection s
if want.err != nil && want.err(err) {
return
}
if stmt.AggregateType != want.aggregateType {
t.Errorf("wrong aggregate type: want: %q got: %q", want.aggregateType, stmt.AggregateType)
if stmt.Aggregate.Type != want.aggregateType {
t.Errorf("wrong aggregate type: want: %q got: %q", want.aggregateType, stmt.Aggregate.Type)
}
if stmt.Sequence != want.sequence {

View File

@ -11,6 +11,7 @@ import (
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/cache"
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
@ -26,6 +27,7 @@ type Queries struct {
eventstore *eventstore.Eventstore
eventStoreV4 es_v4.Querier
client *database.DB
caches *Caches
keyEncryptionAlgorithm crypto.EncryptionAlgorithm
idpConfigEncryption crypto.EncryptionAlgorithm
@ -47,6 +49,7 @@ func StartQueries(
es *eventstore.Eventstore,
esV4 es_v4.Querier,
querySqlClient, projectionSqlClient *database.DB,
caches *cache.CachesConfig,
projections projection.Config,
defaults sd.SystemDefaults,
idpConfigEncryption, otpEncryption, keyEncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm,
@ -86,6 +89,10 @@ func StartQueries(
if startProjections {
projection.Start(ctx)
}
repo.caches, err = startCaches(ctx, caches)
if err != nil {
return nil, err
}
return repo, nil
}