feat(storage): generic cache interface (#8628)

# Which Problems Are Solved

We identified the need of caching.
Currently we have a number of places where we use different ways of
caching, like go maps or LRU.
We might also want shared chaches in the future, like Redis-based or in
special SQL tables.

# How the Problems Are Solved

Define a generic Cache interface which allows different implementations.

- A noop implementation is provided and enabled as.
- An implementation using go maps is provided
  - disabled in defaults.yaml
  - enabled in integration tests
- Authz middleware instance objects are cached using the interface.

# Additional Changes

- Enabled integration test command raceflag
- Fix a race condition in the limits integration test client
- Fix a number of flaky integration tests. (Because zitadel is super
fast now!) 🎸 🚀

# Additional Context

Related to https://github.com/zitadel/zitadel/issues/8648
This commit is contained in:
Tim Möhlmann
2024-09-25 22:40:21 +03:00
committed by GitHub
parent a6ea83168d
commit 4eaa3163b6
28 changed files with 1290 additions and 78 deletions

95
internal/query/cache.go Normal file
View File

@@ -0,0 +1,95 @@
package query
import (
"context"
"fmt"
"strings"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/gomap"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/eventstore"
)
type Caches struct {
connectors *cacheConnectors
instance cache.Cache[instanceIndex, string, *authzInstance]
}
func startCaches(background context.Context, conf *cache.CachesConfig) (_ *Caches, err error) {
caches := &Caches{
instance: noop.NewCache[instanceIndex, string, *authzInstance](),
}
if conf == nil {
return caches, nil
}
caches.connectors, err = startCacheConnectors(background, conf)
if err != nil {
return nil, err
}
caches.instance, err = startCache[instanceIndex, string, *authzInstance](background, instanceIndexValues(), "authz_instance", conf.Instance, caches.connectors)
if err != nil {
return nil, err
}
caches.registerInstanceInvalidation()
return caches, nil
}
type cacheConnectors struct {
memory *cache.AutoPruneConfig
// pool *pgxpool.Pool
}
func startCacheConnectors(_ context.Context, conf *cache.CachesConfig) (*cacheConnectors, error) {
connectors := new(cacheConnectors)
if conf.Connectors.Memory.Enabled {
connectors.memory = &conf.Connectors.Memory.AutoPrune
}
return connectors, nil
}
func startCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, name string, conf *cache.CacheConfig, connectors *cacheConnectors) (cache.Cache[I, K, V], error) {
if conf == nil || conf.Connector == "" {
return noop.NewCache[I, K, V](), nil
}
if strings.EqualFold(conf.Connector, "memory") && connectors.memory != nil {
c := gomap.NewCache[I, K, V](background, indices, *conf)
connectors.memory.StartAutoPrune(background, c, name)
return c, nil
}
/* TODO
if strings.EqualFold(conf.Connector, "sql") && connectors.pool != nil {
return ...
}
*/
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
}
type invalidator[I comparable] interface {
Invalidate(ctx context.Context, index I, key ...string) error
}
func cacheInvalidationFunc[I comparable](cache invalidator[I], index I, getID func(*eventstore.Aggregate) string) func(context.Context, []*eventstore.Aggregate) {
return func(ctx context.Context, aggregates []*eventstore.Aggregate) {
ids := make([]string, len(aggregates))
for i, aggregate := range aggregates {
ids[i] = getID(aggregate)
}
err := cache.Invalidate(ctx, index, ids...)
logging.OnError(err).Warn("cache invalidation failed")
}
}
func getAggregateID(aggregate *eventstore.Aggregate) string {
return aggregate.ID
}
func getResourceOwner(aggregate *eventstore.Aggregate) string {
return aggregate.ResourceOwner
}

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"slices"
"strings"
"time"
@@ -17,6 +18,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/feature"
"github.com/zitadel/zitadel/internal/query/projection"
@@ -206,22 +208,35 @@ func (q *Queries) InstanceByHost(ctx context.Context, instanceHost, publicHost s
instanceDomain := strings.Split(instanceHost, ":")[0] // remove possible port
publicDomain := strings.Split(publicHost, ":")[0] // remove possible port
instance, scan := scanAuthzInstance()
// in case public domain is the same as the instance domain, we do not need to check it
// and can empty it for the check
if instanceDomain == publicDomain {
publicDomain = ""
instance, ok := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain)
if ok {
return instance, instance.checkDomain(instanceDomain, publicDomain)
}
err = q.client.QueryRowContext(ctx, scan, instanceByDomainQuery, instanceDomain, publicDomain)
return instance, err
instance, scan := scanAuthzInstance()
if err = q.client.QueryRowContext(ctx, scan, instanceByDomainQuery, instanceDomain); err != nil {
return nil, err
}
q.caches.instance.Set(ctx, instance)
return instance, instance.checkDomain(instanceDomain, publicDomain)
}
func (q *Queries) InstanceByID(ctx context.Context, id string) (_ authz.Instance, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
instance, ok := q.caches.instance.Get(ctx, instanceIndexByID, id)
if ok {
return instance, nil
}
instance, scan := scanAuthzInstance()
err = q.client.QueryRowContext(ctx, scan, instanceByIDQuery, id)
logging.OnError(err).WithField("instance_id", id).Warn("instance by ID")
if err == nil {
q.caches.instance.Set(ctx, instance)
}
return instance, err
}
@@ -431,6 +446,8 @@ type authzInstance struct {
block *bool
auditLogRetention *time.Duration
features feature.Features
externalDomains database.TextArray[string]
trustedDomains database.TextArray[string]
}
type csp struct {
@@ -485,6 +502,31 @@ func (i *authzInstance) Features() feature.Features {
return i.features
}
var errPublicDomain = "public domain %q not trusted"
func (i *authzInstance) checkDomain(instanceDomain, publicDomain string) error {
// in case public domain is empty, or the same as the instance domain, we do not need to check it
if publicDomain == "" || instanceDomain == publicDomain {
return nil
}
if !slices.Contains(i.trustedDomains, publicDomain) {
return zerrors.ThrowNotFound(fmt.Errorf(errPublicDomain, publicDomain), "QUERY-IuGh1", "Errors.IAM.NotFound")
}
return nil
}
// Keys implements [cache.Entry]
func (i *authzInstance) Keys(index instanceIndex) []string {
switch index {
case instanceIndexByID:
return []string{i.id}
case instanceIndexByHost:
return i.externalDomains
default:
return nil
}
}
func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
instance := &authzInstance{}
return instance, func(row *sql.Row) error {
@@ -509,6 +551,8 @@ func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
&auditLogRetention,
&block,
&features,
&instance.externalDomains,
&instance.trustedDomains,
)
if errors.Is(err, sql.ErrNoRows) {
return zerrors.ThrowNotFound(nil, "QUERY-1kIjX", "Errors.IAM.NotFound")
@@ -534,3 +578,30 @@ func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
return nil
}
}
func (c *Caches) registerInstanceInvalidation() {
invalidate := cacheInvalidationFunc(c.instance, instanceIndexByID, getAggregateID)
projection.InstanceProjection.RegisterCacheInvalidation(invalidate)
projection.InstanceDomainProjection.RegisterCacheInvalidation(invalidate)
projection.InstanceFeatureProjection.RegisterCacheInvalidation(invalidate)
projection.InstanceTrustedDomainProjection.RegisterCacheInvalidation(invalidate)
projection.SecurityPolicyProjection.RegisterCacheInvalidation(invalidate)
// limits uses own aggregate ID, invalidate using resource owner.
invalidate = cacheInvalidationFunc(c.instance, instanceIndexByID, getResourceOwner)
projection.LimitsProjection.RegisterCacheInvalidation(invalidate)
// System feature update should invalidate all instances, so Truncate the cache.
projection.SystemFeatureProjection.RegisterCacheInvalidation(func(ctx context.Context, _ []*eventstore.Aggregate) {
err := c.instance.Truncate(ctx)
logging.OnError(err).Warn("cache truncate failed")
})
}
type instanceIndex int16
//go:generate enumer -type instanceIndex
const (
instanceIndexByID instanceIndex = iota
instanceIndexByHost
)

View File

@@ -14,6 +14,16 @@ with domain as (
cross join projections.system_features s
full outer join instance_features i using (instance_id, key)
group by instance_id
), external_domains as (
select ed.instance_id, array_agg(ed.domain) as domains
from domain d
join projections.instance_domains ed on d.instance_id = ed.instance_id
group by ed.instance_id
), trusted_domains as (
select td.instance_id, array_agg(td.domain) as domains
from domain d
join projections.instance_trusted_domains td on d.instance_id = td.instance_id
group by td.instance_id
)
select
i.id,
@@ -27,11 +37,13 @@ select
s.enable_impersonation,
l.audit_log_retention,
l.block,
f.features
f.features,
ed.domains as external_domains,
td.domains as trusted_domains
from domain d
join projections.instances i on i.id = d.instance_id
left join projections.instance_trusted_domains td on i.id = td.instance_id
left join projections.security_policies2 s on i.id = s.instance_id
left join projections.limits l on i.id = l.instance_id
left join features f on i.id = f.instance_id
where case when $2 = '' then true else td.domain = $2 end;
left join external_domains ed on i.id = ed.instance_id
left join trusted_domains td on i.id = td.instance_id;

View File

@@ -7,6 +7,16 @@ with features as (
cross join projections.system_features s
full outer join projections.instance_features2 i using (key, instance_id)
group by instance_id
), external_domains as (
select instance_id, array_agg(domain) as domains
from projections.instance_domains
where instance_id = $1
group by instance_id
), trusted_domains as (
select instance_id, array_agg(domain) as domains
from projections.instance_trusted_domains
where instance_id = $1
group by instance_id
)
select
i.id,
@@ -20,9 +30,13 @@ select
s.enable_impersonation,
l.audit_log_retention,
l.block,
f.features
f.features,
ed.domains as external_domains,
td.domains as trusted_domains
from projections.instances i
left join projections.security_policies2 s on i.id = s.instance_id
left join projections.limits l on i.id = l.instance_id
left join features f on i.id = f.instance_id
left join external_domains ed on i.id = ed.instance_id
left join trusted_domains td on i.id = td.instance_id
where i.id = $1;

View File

@@ -0,0 +1,78 @@
// Code generated by "enumer -type instanceIndex"; DO NOT EDIT.
package query
import (
"fmt"
"strings"
)
const _instanceIndexName = "instanceIndexByIDinstanceIndexByHost"
var _instanceIndexIndex = [...]uint8{0, 17, 36}
const _instanceIndexLowerName = "instanceindexbyidinstanceindexbyhost"
func (i instanceIndex) String() string {
if i < 0 || i >= instanceIndex(len(_instanceIndexIndex)-1) {
return fmt.Sprintf("instanceIndex(%d)", i)
}
return _instanceIndexName[_instanceIndexIndex[i]:_instanceIndexIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _instanceIndexNoOp() {
var x [1]struct{}
_ = x[instanceIndexByID-(0)]
_ = x[instanceIndexByHost-(1)]
}
var _instanceIndexValues = []instanceIndex{instanceIndexByID, instanceIndexByHost}
var _instanceIndexNameToValueMap = map[string]instanceIndex{
_instanceIndexName[0:17]: instanceIndexByID,
_instanceIndexLowerName[0:17]: instanceIndexByID,
_instanceIndexName[17:36]: instanceIndexByHost,
_instanceIndexLowerName[17:36]: instanceIndexByHost,
}
var _instanceIndexNames = []string{
_instanceIndexName[0:17],
_instanceIndexName[17:36],
}
// instanceIndexString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func instanceIndexString(s string) (instanceIndex, error) {
if val, ok := _instanceIndexNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _instanceIndexNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to instanceIndex values", s)
}
// instanceIndexValues returns all values of the enum
func instanceIndexValues() []instanceIndex {
return _instanceIndexValues
}
// instanceIndexStrings returns a slice of all String values of the enum
func instanceIndexStrings() []string {
strs := make([]string, len(_instanceIndexNames))
copy(strs, _instanceIndexNames)
return strs
}
// IsAinstanceIndex returns "true" if the value is listed in the enum definition. "false" otherwise
func (i instanceIndex) IsAinstanceIndex() bool {
for _, v := range _instanceIndexValues {
if i == v {
return true
}
}
return false
}

View File

@@ -74,8 +74,8 @@ func assertReduce(t *testing.T, stmt *handler.Statement, err error, projection s
if want.err != nil && want.err(err) {
return
}
if stmt.AggregateType != want.aggregateType {
t.Errorf("wrong aggregate type: want: %q got: %q", want.aggregateType, stmt.AggregateType)
if stmt.Aggregate.Type != want.aggregateType {
t.Errorf("wrong aggregate type: want: %q got: %q", want.aggregateType, stmt.Aggregate.Type)
}
if stmt.Sequence != want.sequence {

View File

@@ -11,6 +11,7 @@ import (
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/cache"
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
@@ -26,6 +27,7 @@ type Queries struct {
eventstore *eventstore.Eventstore
eventStoreV4 es_v4.Querier
client *database.DB
caches *Caches
keyEncryptionAlgorithm crypto.EncryptionAlgorithm
idpConfigEncryption crypto.EncryptionAlgorithm
@@ -47,6 +49,7 @@ func StartQueries(
es *eventstore.Eventstore,
esV4 es_v4.Querier,
querySqlClient, projectionSqlClient *database.DB,
caches *cache.CachesConfig,
projections projection.Config,
defaults sd.SystemDefaults,
idpConfigEncryption, otpEncryption, keyEncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm,
@@ -86,6 +89,10 @@ func StartQueries(
if startProjections {
projection.Start(ctx)
}
repo.caches, err = startCaches(ctx, caches)
if err != nil {
return nil, err
}
return repo, nil
}