mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:37:31 +00:00
feat(storage): generic cache interface (#8628)
# Which Problems Are Solved We identified the need of caching. Currently we have a number of places where we use different ways of caching, like go maps or LRU. We might also want shared chaches in the future, like Redis-based or in special SQL tables. # How the Problems Are Solved Define a generic Cache interface which allows different implementations. - A noop implementation is provided and enabled as. - An implementation using go maps is provided - disabled in defaults.yaml - enabled in integration tests - Authz middleware instance objects are cached using the interface. # Additional Changes - Enabled integration test command raceflag - Fix a race condition in the limits integration test client - Fix a number of flaky integration tests. (Because zitadel is super fast now!) 🎸 🚀 # Additional Context Related to https://github.com/zitadel/zitadel/issues/8648
This commit is contained in:
95
internal/query/cache.go
Normal file
95
internal/query/cache.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
"github.com/zitadel/zitadel/internal/cache/gomap"
|
||||
"github.com/zitadel/zitadel/internal/cache/noop"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
type Caches struct {
|
||||
connectors *cacheConnectors
|
||||
instance cache.Cache[instanceIndex, string, *authzInstance]
|
||||
}
|
||||
|
||||
func startCaches(background context.Context, conf *cache.CachesConfig) (_ *Caches, err error) {
|
||||
caches := &Caches{
|
||||
instance: noop.NewCache[instanceIndex, string, *authzInstance](),
|
||||
}
|
||||
if conf == nil {
|
||||
return caches, nil
|
||||
}
|
||||
caches.connectors, err = startCacheConnectors(background, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
caches.instance, err = startCache[instanceIndex, string, *authzInstance](background, instanceIndexValues(), "authz_instance", conf.Instance, caches.connectors)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
caches.registerInstanceInvalidation()
|
||||
|
||||
return caches, nil
|
||||
}
|
||||
|
||||
type cacheConnectors struct {
|
||||
memory *cache.AutoPruneConfig
|
||||
// pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func startCacheConnectors(_ context.Context, conf *cache.CachesConfig) (*cacheConnectors, error) {
|
||||
connectors := new(cacheConnectors)
|
||||
if conf.Connectors.Memory.Enabled {
|
||||
connectors.memory = &conf.Connectors.Memory.AutoPrune
|
||||
}
|
||||
|
||||
return connectors, nil
|
||||
}
|
||||
|
||||
func startCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, name string, conf *cache.CacheConfig, connectors *cacheConnectors) (cache.Cache[I, K, V], error) {
|
||||
if conf == nil || conf.Connector == "" {
|
||||
return noop.NewCache[I, K, V](), nil
|
||||
}
|
||||
if strings.EqualFold(conf.Connector, "memory") && connectors.memory != nil {
|
||||
c := gomap.NewCache[I, K, V](background, indices, *conf)
|
||||
connectors.memory.StartAutoPrune(background, c, name)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
/* TODO
|
||||
if strings.EqualFold(conf.Connector, "sql") && connectors.pool != nil {
|
||||
return ...
|
||||
}
|
||||
*/
|
||||
|
||||
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
|
||||
}
|
||||
|
||||
type invalidator[I comparable] interface {
|
||||
Invalidate(ctx context.Context, index I, key ...string) error
|
||||
}
|
||||
|
||||
func cacheInvalidationFunc[I comparable](cache invalidator[I], index I, getID func(*eventstore.Aggregate) string) func(context.Context, []*eventstore.Aggregate) {
|
||||
return func(ctx context.Context, aggregates []*eventstore.Aggregate) {
|
||||
ids := make([]string, len(aggregates))
|
||||
for i, aggregate := range aggregates {
|
||||
ids[i] = getID(aggregate)
|
||||
}
|
||||
err := cache.Invalidate(ctx, index, ids...)
|
||||
logging.OnError(err).Warn("cache invalidation failed")
|
||||
}
|
||||
}
|
||||
|
||||
func getAggregateID(aggregate *eventstore.Aggregate) string {
|
||||
return aggregate.ID
|
||||
}
|
||||
|
||||
func getResourceOwner(aggregate *eventstore.Aggregate) string {
|
||||
return aggregate.ResourceOwner
|
||||
}
|
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/call"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||||
"github.com/zitadel/zitadel/internal/feature"
|
||||
"github.com/zitadel/zitadel/internal/query/projection"
|
||||
@@ -206,22 +208,35 @@ func (q *Queries) InstanceByHost(ctx context.Context, instanceHost, publicHost s
|
||||
|
||||
instanceDomain := strings.Split(instanceHost, ":")[0] // remove possible port
|
||||
publicDomain := strings.Split(publicHost, ":")[0] // remove possible port
|
||||
instance, scan := scanAuthzInstance()
|
||||
// in case public domain is the same as the instance domain, we do not need to check it
|
||||
// and can empty it for the check
|
||||
if instanceDomain == publicDomain {
|
||||
publicDomain = ""
|
||||
|
||||
instance, ok := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain)
|
||||
if ok {
|
||||
return instance, instance.checkDomain(instanceDomain, publicDomain)
|
||||
}
|
||||
err = q.client.QueryRowContext(ctx, scan, instanceByDomainQuery, instanceDomain, publicDomain)
|
||||
return instance, err
|
||||
instance, scan := scanAuthzInstance()
|
||||
if err = q.client.QueryRowContext(ctx, scan, instanceByDomainQuery, instanceDomain); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q.caches.instance.Set(ctx, instance)
|
||||
|
||||
return instance, instance.checkDomain(instanceDomain, publicDomain)
|
||||
}
|
||||
|
||||
func (q *Queries) InstanceByID(ctx context.Context, id string) (_ authz.Instance, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
instance, ok := q.caches.instance.Get(ctx, instanceIndexByID, id)
|
||||
if ok {
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
instance, scan := scanAuthzInstance()
|
||||
err = q.client.QueryRowContext(ctx, scan, instanceByIDQuery, id)
|
||||
logging.OnError(err).WithField("instance_id", id).Warn("instance by ID")
|
||||
if err == nil {
|
||||
q.caches.instance.Set(ctx, instance)
|
||||
}
|
||||
return instance, err
|
||||
}
|
||||
|
||||
@@ -431,6 +446,8 @@ type authzInstance struct {
|
||||
block *bool
|
||||
auditLogRetention *time.Duration
|
||||
features feature.Features
|
||||
externalDomains database.TextArray[string]
|
||||
trustedDomains database.TextArray[string]
|
||||
}
|
||||
|
||||
type csp struct {
|
||||
@@ -485,6 +502,31 @@ func (i *authzInstance) Features() feature.Features {
|
||||
return i.features
|
||||
}
|
||||
|
||||
var errPublicDomain = "public domain %q not trusted"
|
||||
|
||||
func (i *authzInstance) checkDomain(instanceDomain, publicDomain string) error {
|
||||
// in case public domain is empty, or the same as the instance domain, we do not need to check it
|
||||
if publicDomain == "" || instanceDomain == publicDomain {
|
||||
return nil
|
||||
}
|
||||
if !slices.Contains(i.trustedDomains, publicDomain) {
|
||||
return zerrors.ThrowNotFound(fmt.Errorf(errPublicDomain, publicDomain), "QUERY-IuGh1", "Errors.IAM.NotFound")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keys implements [cache.Entry]
|
||||
func (i *authzInstance) Keys(index instanceIndex) []string {
|
||||
switch index {
|
||||
case instanceIndexByID:
|
||||
return []string{i.id}
|
||||
case instanceIndexByHost:
|
||||
return i.externalDomains
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
|
||||
instance := &authzInstance{}
|
||||
return instance, func(row *sql.Row) error {
|
||||
@@ -509,6 +551,8 @@ func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
|
||||
&auditLogRetention,
|
||||
&block,
|
||||
&features,
|
||||
&instance.externalDomains,
|
||||
&instance.trustedDomains,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return zerrors.ThrowNotFound(nil, "QUERY-1kIjX", "Errors.IAM.NotFound")
|
||||
@@ -534,3 +578,30 @@ func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Caches) registerInstanceInvalidation() {
|
||||
invalidate := cacheInvalidationFunc(c.instance, instanceIndexByID, getAggregateID)
|
||||
projection.InstanceProjection.RegisterCacheInvalidation(invalidate)
|
||||
projection.InstanceDomainProjection.RegisterCacheInvalidation(invalidate)
|
||||
projection.InstanceFeatureProjection.RegisterCacheInvalidation(invalidate)
|
||||
projection.InstanceTrustedDomainProjection.RegisterCacheInvalidation(invalidate)
|
||||
projection.SecurityPolicyProjection.RegisterCacheInvalidation(invalidate)
|
||||
|
||||
// limits uses own aggregate ID, invalidate using resource owner.
|
||||
invalidate = cacheInvalidationFunc(c.instance, instanceIndexByID, getResourceOwner)
|
||||
projection.LimitsProjection.RegisterCacheInvalidation(invalidate)
|
||||
|
||||
// System feature update should invalidate all instances, so Truncate the cache.
|
||||
projection.SystemFeatureProjection.RegisterCacheInvalidation(func(ctx context.Context, _ []*eventstore.Aggregate) {
|
||||
err := c.instance.Truncate(ctx)
|
||||
logging.OnError(err).Warn("cache truncate failed")
|
||||
})
|
||||
}
|
||||
|
||||
type instanceIndex int16
|
||||
|
||||
//go:generate enumer -type instanceIndex
|
||||
const (
|
||||
instanceIndexByID instanceIndex = iota
|
||||
instanceIndexByHost
|
||||
)
|
||||
|
@@ -14,6 +14,16 @@ with domain as (
|
||||
cross join projections.system_features s
|
||||
full outer join instance_features i using (instance_id, key)
|
||||
group by instance_id
|
||||
), external_domains as (
|
||||
select ed.instance_id, array_agg(ed.domain) as domains
|
||||
from domain d
|
||||
join projections.instance_domains ed on d.instance_id = ed.instance_id
|
||||
group by ed.instance_id
|
||||
), trusted_domains as (
|
||||
select td.instance_id, array_agg(td.domain) as domains
|
||||
from domain d
|
||||
join projections.instance_trusted_domains td on d.instance_id = td.instance_id
|
||||
group by td.instance_id
|
||||
)
|
||||
select
|
||||
i.id,
|
||||
@@ -27,11 +37,13 @@ select
|
||||
s.enable_impersonation,
|
||||
l.audit_log_retention,
|
||||
l.block,
|
||||
f.features
|
||||
f.features,
|
||||
ed.domains as external_domains,
|
||||
td.domains as trusted_domains
|
||||
from domain d
|
||||
join projections.instances i on i.id = d.instance_id
|
||||
left join projections.instance_trusted_domains td on i.id = td.instance_id
|
||||
left join projections.security_policies2 s on i.id = s.instance_id
|
||||
left join projections.limits l on i.id = l.instance_id
|
||||
left join features f on i.id = f.instance_id
|
||||
where case when $2 = '' then true else td.domain = $2 end;
|
||||
left join external_domains ed on i.id = ed.instance_id
|
||||
left join trusted_domains td on i.id = td.instance_id;
|
||||
|
@@ -7,6 +7,16 @@ with features as (
|
||||
cross join projections.system_features s
|
||||
full outer join projections.instance_features2 i using (key, instance_id)
|
||||
group by instance_id
|
||||
), external_domains as (
|
||||
select instance_id, array_agg(domain) as domains
|
||||
from projections.instance_domains
|
||||
where instance_id = $1
|
||||
group by instance_id
|
||||
), trusted_domains as (
|
||||
select instance_id, array_agg(domain) as domains
|
||||
from projections.instance_trusted_domains
|
||||
where instance_id = $1
|
||||
group by instance_id
|
||||
)
|
||||
select
|
||||
i.id,
|
||||
@@ -20,9 +30,13 @@ select
|
||||
s.enable_impersonation,
|
||||
l.audit_log_retention,
|
||||
l.block,
|
||||
f.features
|
||||
f.features,
|
||||
ed.domains as external_domains,
|
||||
td.domains as trusted_domains
|
||||
from projections.instances i
|
||||
left join projections.security_policies2 s on i.id = s.instance_id
|
||||
left join projections.limits l on i.id = l.instance_id
|
||||
left join features f on i.id = f.instance_id
|
||||
left join external_domains ed on i.id = ed.instance_id
|
||||
left join trusted_domains td on i.id = td.instance_id
|
||||
where i.id = $1;
|
||||
|
78
internal/query/instanceindex_enumer.go
Normal file
78
internal/query/instanceindex_enumer.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Code generated by "enumer -type instanceIndex"; DO NOT EDIT.
|
||||
|
||||
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _instanceIndexName = "instanceIndexByIDinstanceIndexByHost"
|
||||
|
||||
var _instanceIndexIndex = [...]uint8{0, 17, 36}
|
||||
|
||||
const _instanceIndexLowerName = "instanceindexbyidinstanceindexbyhost"
|
||||
|
||||
func (i instanceIndex) String() string {
|
||||
if i < 0 || i >= instanceIndex(len(_instanceIndexIndex)-1) {
|
||||
return fmt.Sprintf("instanceIndex(%d)", i)
|
||||
}
|
||||
return _instanceIndexName[_instanceIndexIndex[i]:_instanceIndexIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _instanceIndexNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[instanceIndexByID-(0)]
|
||||
_ = x[instanceIndexByHost-(1)]
|
||||
}
|
||||
|
||||
var _instanceIndexValues = []instanceIndex{instanceIndexByID, instanceIndexByHost}
|
||||
|
||||
var _instanceIndexNameToValueMap = map[string]instanceIndex{
|
||||
_instanceIndexName[0:17]: instanceIndexByID,
|
||||
_instanceIndexLowerName[0:17]: instanceIndexByID,
|
||||
_instanceIndexName[17:36]: instanceIndexByHost,
|
||||
_instanceIndexLowerName[17:36]: instanceIndexByHost,
|
||||
}
|
||||
|
||||
var _instanceIndexNames = []string{
|
||||
_instanceIndexName[0:17],
|
||||
_instanceIndexName[17:36],
|
||||
}
|
||||
|
||||
// instanceIndexString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func instanceIndexString(s string) (instanceIndex, error) {
|
||||
if val, ok := _instanceIndexNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _instanceIndexNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to instanceIndex values", s)
|
||||
}
|
||||
|
||||
// instanceIndexValues returns all values of the enum
|
||||
func instanceIndexValues() []instanceIndex {
|
||||
return _instanceIndexValues
|
||||
}
|
||||
|
||||
// instanceIndexStrings returns a slice of all String values of the enum
|
||||
func instanceIndexStrings() []string {
|
||||
strs := make([]string, len(_instanceIndexNames))
|
||||
copy(strs, _instanceIndexNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAinstanceIndex returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i instanceIndex) IsAinstanceIndex() bool {
|
||||
for _, v := range _instanceIndexValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
@@ -74,8 +74,8 @@ func assertReduce(t *testing.T, stmt *handler.Statement, err error, projection s
|
||||
if want.err != nil && want.err(err) {
|
||||
return
|
||||
}
|
||||
if stmt.AggregateType != want.aggregateType {
|
||||
t.Errorf("wrong aggregate type: want: %q got: %q", want.aggregateType, stmt.AggregateType)
|
||||
if stmt.Aggregate.Type != want.aggregateType {
|
||||
t.Errorf("wrong aggregate type: want: %q got: %q", want.aggregateType, stmt.Aggregate.Type)
|
||||
}
|
||||
|
||||
if stmt.Sequence != want.sequence {
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
@@ -26,6 +27,7 @@ type Queries struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
eventStoreV4 es_v4.Querier
|
||||
client *database.DB
|
||||
caches *Caches
|
||||
|
||||
keyEncryptionAlgorithm crypto.EncryptionAlgorithm
|
||||
idpConfigEncryption crypto.EncryptionAlgorithm
|
||||
@@ -47,6 +49,7 @@ func StartQueries(
|
||||
es *eventstore.Eventstore,
|
||||
esV4 es_v4.Querier,
|
||||
querySqlClient, projectionSqlClient *database.DB,
|
||||
caches *cache.CachesConfig,
|
||||
projections projection.Config,
|
||||
defaults sd.SystemDefaults,
|
||||
idpConfigEncryption, otpEncryption, keyEncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm,
|
||||
@@ -86,6 +89,10 @@ func StartQueries(
|
||||
if startProjections {
|
||||
projection.Start(ctx)
|
||||
}
|
||||
repo.caches, err = startCaches(ctx, caches)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return repo, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user