diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index 18f814a6d2..a188d1446b 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -291,6 +291,19 @@ Caches: DisableIndentity: false # Add suffix to client name. Default is empty. IdentitySuffix: "" + # Implementation of [Circuit Breaker Pattern](https://learn.microsoft.com/en-us/previous-versions/msp-n-p/dn589784(v=pandp.10)?redirectedfrom=MSDN) + CircuitBreaker: + # Interval at which the counters are reset to 0. + # An interval of 0 never resets the counters until the CB is opened. + Interval: 0 + # Number of consecutive failures permitted. + MaxConsecutiveFailures: 5 + # The maximum permitted ratio of failed requests to total requests. + MaxFailureRatio: 0.1 + # Timeout after opening of the CB, until the state is set to half-open. + Timeout: 60s + # The number of requests that are allowed to pass when the CB is half-open. + MaxRetryRequests: 1 # Instance caches auth middleware instances, gettable by domain or ID. Instance: diff --git a/go.mod b/go.mod index cf4e755605..2928d4dbfb 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( github.com/redis/go-redis/v9 v9.7.0 github.com/rs/cors v1.11.1 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 + github.com/sony/gobreaker/v2 v2.0.0 github.com/sony/sonyflake v1.2.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 diff --git a/go.sum b/go.sum index 015fea1b80..ad9e8914cf 100644 --- a/go.sum +++ b/go.sum @@ -670,6 +670,8 @@ github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIK github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= +github.com/sony/gobreaker/v2 v2.0.0 h1:23AaR4JQ65y4rz8JWMzgXw2gKOykZ/qfqYunll4OwJ4= +github.com/sony/gobreaker/v2 v2.0.0/go.mod h1:8JnRUz80DJ1/ne8M8v7nmTs2713i58nIt4s7XcGe/DI= 
github.com/sony/sonyflake v1.2.0 h1:Pfr3A+ejSg+0SPqpoAmQgEtNDAhc2G1SUYk205qVMLQ= github.com/sony/sonyflake v1.2.0/go.mod h1:LORtCywH/cq10ZbyfhKrHYgAUGH7mOBa76enV9txy/Y= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
diff --git a/internal/cache/connector/redis/circuit_breaker.go b/internal/cache/connector/redis/circuit_breaker.go
new file mode 100644
index 0000000000..1e06b7387e
--- /dev/null
+++ b/internal/cache/connector/redis/circuit_breaker.go
@@ -0,0 +1,106 @@
+package redis
+
+import (
+	"context"
+	"errors"
+	"time"
+
+	"github.com/redis/go-redis/v9"
+	"github.com/sony/gobreaker/v2"
+	"github.com/zitadel/logging"
+)
+
+// defaultInflightSize is the capacity of the limiter's inflight channel
+// when no maximum of active connections is configured.
+const defaultInflightSize = 100000
+
+// CBConfig holds the settings of the circuit breaker (CB) used by the Redis limiter.
+type CBConfig struct {
+	// Interval at which the counters are reset to 0.
+	// With an interval of 0 the counters are never reset until the CB is opened.
+	Interval time.Duration
+	// Number of consecutive failures permitted. A value of 0 disables this check.
+	MaxConsecutiveFailures uint32
+	// Maximum permitted ratio of failed requests to total requests.
+	// A value of 0 disables this check.
+	MaxFailureRatio float64
+	// Time to wait after the CB is opened, until the state is set to half-open.
+	Timeout time.Duration
+	// Number of requests that are allowed to pass when the CB is half-open.
+	MaxRetryRequests uint32
+}
+
+// readyToTrip reports whether counts exceeds one of the configured limits.
+// It is passed to the circuit breaker as its ReadyToTrip setting.
+func (config *CBConfig) readyToTrip(counts gobreaker.Counts) bool {
+	maxConsecutive, maxRatio := config.MaxConsecutiveFailures, config.MaxFailureRatio
+	if maxConsecutive > 0 && counts.ConsecutiveFailures > maxConsecutive {
+		return true
+	}
+	// The ratio is only checked when enabled and at least one request was counted.
+	if maxRatio <= 0 || counts.Requests == 0 {
+		return false
+	}
+	return float64(counts.TotalFailures)/float64(counts.Requests) > maxRatio
+}
+
+// limiter is a circuit breaker that implements [redis.Limiter].
+type limiter struct {
+	// inflight queues the callbacks obtained in Allow, until ReportResult consumes them.
+	inflight chan func(success bool)
+	cb       *gobreaker.TwoStepCircuitBreaker[struct{}]
+}
+
+// newLimiter returns a [redis.Limiter] backed by a circuit breaker built from config.
+// It returns nil when config is nil.
+func newLimiter(config *CBConfig, maxActiveConns int) redis.Limiter {
+	if config == nil {
+		return nil
+	}
+
+	// The inflight channel needs room for maxActiveConns entries to prevent blocking.
+	// When that is 0 (no limit), a sane default is used instead.
+	size := maxActiveConns
+	if size <= 0 {
+		size = defaultInflightSize
+	}
+
+	settings := gobreaker.Settings{
+		Name:        "redis cache",
+		MaxRequests: config.MaxRetryRequests,
+		Interval:    config.Interval,
+		Timeout:     config.Timeout,
+		ReadyToTrip: config.readyToTrip,
+		OnStateChange: func(name string, from, to gobreaker.State) {
+			logging.WithFields("name", name, "from", from, "to", to).Warn("circuit breaker state change")
+		},
+	}
+	return &limiter{
+		inflight: make(chan func(success bool), size),
+		cb:       gobreaker.NewTwoStepCircuitBreaker[struct{}](settings),
+	}
+}
+
+// Allow implements [redis.Limiter].
+//
+// When the circuit breaker permits the request, its callback is queued
+// on the inflight channel. Otherwise the circuit breaker's error is returned.
+func (l *limiter) Allow() error {
+	report, err := l.cb.Allow()
+	if err == nil {
+		l.inflight <- report
+	}
+	return err
+}
+
+// ReportResult implements [redis.Limiter].
+//
+// ReportResult inspects the error returned by the Redis client.
+// `nil`, [redis.Nil] and [context.Canceled] are reported as success.
+// Any other error, like connection errors or [context.DeadlineExceeded], is reported as a failure.
+func (l *limiter) ReportResult(err error) {
+	done := <-l.inflight
+	done(err == nil ||
+		errors.Is(err, redis.Nil) ||
+		errors.Is(err, context.Canceled))
+}
diff --git a/internal/cache/connector/redis/circuit_breaker_test.go b/internal/cache/connector/redis/circuit_breaker_test.go
new file mode 100644
index 0000000000..ba61d18071
--- /dev/null
+++ b/internal/cache/connector/redis/circuit_breaker_test.go
@@ -0,0 +1,173 @@
+package redis
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/sony/gobreaker/v2"
+	"github.com/stretchr/testify/require"
+
+	"github.com/zitadel/zitadel/internal/cache"
+)
+
+// TestCBConfig_readyToTrip verifies that readyToTrip only reports true
+// when the counts exceed an enabled limit.
+func TestCBConfig_readyToTrip(t *testing.T) {
+	type fields struct {
+		MaxConsecutiveFailures uint32
+		MaxFailureRatio        float64
+	}
+	type args struct {
+		counts gobreaker.Counts
+	}
+	tests := []struct {
+		name   string
+		fields fields
+		args   args
+		want   bool
+	}{
+		{
+			name:   "disabled",
+			fields: fields{},
+			args: args{
+				counts: gobreaker.Counts{
+					Requests:            100,
+					ConsecutiveFailures: 5,
+					TotalFailures:       10,
+				},
+			},
+			want: false,
+		},
+		{
+			name: "no failures",
+			fields: fields{
+				MaxConsecutiveFailures: 5,
+				MaxFailureRatio:        0.1,
+			},
+			args: args{
+				counts: gobreaker.Counts{
+					Requests:            100,
+					ConsecutiveFailures: 0,
+					TotalFailures:       0,
+				},
+			},
+			want: false,
+		},
+		{
+			name: "some failures",
+			fields: fields{
+				MaxConsecutiveFailures: 5,
+				MaxFailureRatio:        0.1,
+			},
+			args: args{
+				counts: gobreaker.Counts{
+					Requests:            100,
+					ConsecutiveFailures: 5,
+					TotalFailures:       10,
+				},
+			},
+			want: false,
+		},
+		{
+			name: "consecutive exceeded",
+			fields: fields{
+				MaxConsecutiveFailures: 5,
+				MaxFailureRatio:        0.1,
+			},
+			args: args{
+				counts: gobreaker.Counts{
+					Requests:            100,
+					ConsecutiveFailures: 6,
+					TotalFailures:       0,
+				},
+			},
+			want: true,
+		},
+		{
+			name: "ratio exceeded",
+			fields: fields{
+				MaxConsecutiveFailures: 5,
+				MaxFailureRatio:        0.1,
+			},
+			args: args{
+				counts: gobreaker.Counts{
+					Requests:            100,
+					ConsecutiveFailures: 1,
+					TotalFailures:       11,
+				},
+			},
+			want: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			config := &CBConfig{
+				MaxConsecutiveFailures: tt.fields.MaxConsecutiveFailures,
+				MaxFailureRatio:        tt.fields.MaxFailureRatio,
+			}
+			if got := config.readyToTrip(tt.args.counts); got != tt.want {
+				t.Errorf("CBConfig.readyToTrip() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+// Test_redisCache_limiter drives the circuit breaker of a cache connected to
+// miniredis through its closed, open and half-open states.
+func Test_redisCache_limiter(t *testing.T) {
+	c, _ := prepareCache(t, cache.Config{}, withCircuitBreakerOption(
+		&CBConfig{
+			MaxConsecutiveFailures: 2,
+			MaxFailureRatio:        0.4,
+			Timeout:                100 * time.Millisecond,
+			MaxRetryRequests:       1,
+		},
+	))
+
+	ctx := context.Background()
+	canceledCtx, cancel := context.WithCancel(ctx)
+	cancel()
+	timedOutCtx, cancel := context.WithTimeout(ctx, -1)
+	defer cancel()
+
+	// CB is and should remain closed
+	for i := 0; i < 10; i++ {
+		err := c.Truncate(ctx)
+		require.NoError(t, err)
+	}
+	for i := 0; i < 10; i++ {
+		err := c.Truncate(canceledCtx)
+		require.ErrorIs(t, err, context.Canceled)
+	}
+
+	// Timeout err should open the CB after more than 2 failures.
+	// The 3rd failure trips the CB, so the 4th call must be rejected with ErrOpenState.
+	for i := 0; i < 4; i++ {
+		err := c.Truncate(timedOutCtx)
+		if i > 2 {
+			require.ErrorIs(t, err, gobreaker.ErrOpenState)
+		} else {
+			require.ErrorIs(t, err, context.DeadlineExceeded)
+		}
+	}
+
+	time.Sleep(200 * time.Millisecond)
+
+	// CB should be half-open. If the first command fails, the CB will be Open again
+	err := c.Truncate(timedOutCtx)
+	require.ErrorIs(t, err, context.DeadlineExceeded)
+	err = c.Truncate(timedOutCtx)
+	require.ErrorIs(t, err, gobreaker.ErrOpenState)
+
+	// Reset the CB to closed
+	time.Sleep(200 * time.Millisecond)
+	err = c.Truncate(ctx)
+	require.NoError(t, err)
+
+	// Exceed the ratio
+	err = c.Truncate(timedOutCtx)
+	require.ErrorIs(t, err, context.DeadlineExceeded)
+	err = c.Truncate(ctx)
+	require.ErrorIs(t, err, gobreaker.ErrOpenState)
+}
diff --git a/internal/cache/connector/redis/connector.go b/internal/cache/connector/redis/connector.go
index 2d0498dfa0..a10a0c25d0 100644
--- a/internal/cache/connector/redis/connector.go
+++ b/internal/cache/connector/redis/connector.go
@@ -105,6 +105,9 @@ type Config struct {
 
 	// Add suffix to client name. Default is empty.
 	IdentitySuffix string
+
+	// Circuit breaker settings. When nil, no circuit breaker is used.
+	CircuitBreaker *CBConfig
 }
 
 type Connector struct {
@@ -146,6 +149,7 @@ func optionsFromConfig(c Config) *redis.Options {
 		ConnMaxLifetime:  c.ConnMaxLifetime,
 		DisableIndentity: c.DisableIndentity,
 		IdentitySuffix:   c.IdentitySuffix,
+		Limiter:          newLimiter(c.CircuitBreaker, c.MaxActiveConns),
 	}
 	if c.EnableTLS {
 		opts.TLSConfig = new(tls.Config)
diff --git a/internal/cache/connector/redis/redis_test.go b/internal/cache/connector/redis/redis_test.go
index 3f45be1502..1909f55e44 100644
--- a/internal/cache/connector/redis/redis_test.go
+++ b/internal/cache/connector/redis/redis_test.go
@@ -6,7 +6,6 @@ import (
 	"time"
 
 	"github.com/alicebob/miniredis/v2"
-	"github.com/redis/go-redis/v9"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"github.com/zitadel/logging"
@@ -689,26 +688,37 @@ func Test_redisCache_Truncate(t *testing.T) {
 	}
 }
 
-func prepareCache(t *testing.T, conf cache.Config) (cache.Cache[testIndex, string, *testObject], *miniredis.Miniredis) {
+// prepareCache starts a miniredis server and returns a cache connected to it, together with the server.
+// The options may modify the connector's Config before the connector is created.
+func prepareCache(t *testing.T, conf cache.Config, options ...func(*Config)) (cache.Cache[testIndex, string, *testObject], *miniredis.Miniredis) {
 	conf.Log = &logging.Config{
 		Level:     "debug",
 		AddSource: true,
 	}
 	server := miniredis.RunT(t)
 	server.Select(testDB)
-	client := redis.NewClient(&redis.Options{
-		Network: "tcp",
-		Addr:    server.Addr(),
-	})
+
+	connConfig := Config{
+		Enabled:          true,
+		Network:          "tcp",
+		Addr:             server.Addr(),
+		DisableIndentity: true,
+	}
+	for _, option := range options {
+		option(&connConfig)
+	}
+	connector := NewConnector(connConfig)
 	t.Cleanup(func() {
-		client.Close()
+		connector.Close()
 		server.Close()
 	})
-	connector := NewConnector(Config{
-		Enabled: true,
-		Network: "tcp",
-		Addr:    server.Addr(),
-	})
 	c := NewCache[testIndex, string, *testObject](conf, connector, testDB, testIndices)
 	return c, server
 }
+
+// withCircuitBreakerOption returns an option for prepareCache that sets the circuit breaker config.
+func withCircuitBreakerOption(cb *CBConfig) func(*Config) {
+	return func(c *Config) {
+		c.CircuitBreaker = cb
+	}
+}