mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-05 14:37:45 +00:00
fix(oidc): ignore public key expiry for ID Token hints (#7293)
* fix(oidc): ignore public key expiry for ID Token hints This splits the key sets used for access token and ID token hints. ID Token hints should be able to be verified by with public keys that are already expired. However, we do not want to change this behavior for Access Tokens, where an error for an expired public key is still returned. The public key cache is modified to purge public keys based on last use, instead of expiry. The cache is shared between both verifiers. * resolve review comments * pin oidc 3.11
This commit is contained in:
parent
5e23ea55b2
commit
df57a64ed7
@ -219,7 +219,7 @@ docker compose --file ./e2e/config/host.docker.internal/docker-compose.yaml down
|
||||
In order to run the integrations tests for the gRPC API, PostgreSQL and CockroachDB must be started and initialized:
|
||||
|
||||
```bash
|
||||
export INTEGRATION_DB_FLAVOR="cockroach" ZITADEL_MASTERKEY="MasterkeyNeedsToHave32Characters"
|
||||
export INTEGRATION_DB_FLAVOR="postgres" ZITADEL_MASTERKEY="MasterkeyNeedsToHave32Characters"
|
||||
docker compose -f internal/integration/config/docker-compose.yaml up --pull always --wait ${INTEGRATION_DB_FLAVOR}
|
||||
make core_integration_test
|
||||
docker compose -f internal/integration/config/docker-compose.yaml down
|
||||
|
@ -337,6 +337,7 @@ OIDC:
|
||||
TriggerIntrospectionProjections: false
|
||||
# Allows fallback to the Legacy Introspection implementation
|
||||
LegacyIntrospection: false
|
||||
PublicKeyCacheMaxAge: 24h # ZITADEL_OIDC_PUBLICKEYCACHEMAXAGE
|
||||
|
||||
SAML:
|
||||
ProviderConfig:
|
||||
|
10
go.mod
10
go.mod
@ -61,19 +61,19 @@ require (
|
||||
github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203
|
||||
github.com/ttacon/libphonenumber v1.2.1
|
||||
github.com/zitadel/logging v0.5.0
|
||||
github.com/zitadel/oidc/v3 v3.10.2
|
||||
github.com/zitadel/oidc/v3 v3.11.0
|
||||
github.com/zitadel/passwap v0.5.0
|
||||
github.com/zitadel/saml v0.1.3
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1
|
||||
go.opentelemetry.io/otel v1.21.0
|
||||
go.opentelemetry.io/otel v1.22.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.44.0
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.21.0
|
||||
go.opentelemetry.io/otel/metric v1.21.0
|
||||
go.opentelemetry.io/otel/metric v1.22.0
|
||||
go.opentelemetry.io/otel/sdk v1.21.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.21.0
|
||||
go.opentelemetry.io/otel/trace v1.21.0
|
||||
go.opentelemetry.io/otel/trace v1.22.0
|
||||
go.uber.org/mock v0.4.0
|
||||
golang.org/x/crypto v0.18.0
|
||||
golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3
|
||||
@ -155,7 +155,7 @@ require (
|
||||
github.com/gofrs/uuid v4.4.0+incompatible // indirect
|
||||
github.com/golang/geo v0.0.0-20230421003525-6adc56603217 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/uuid v1.5.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
|
||||
github.com/gorilla/handlers v1.5.2 // indirect
|
||||
|
20
go.sum
20
go.sum
@ -334,8 +334,8 @@ github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8
|
||||
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
|
||||
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
|
||||
github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas=
|
||||
@ -782,8 +782,8 @@ github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8=
|
||||
github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||
github.com/zitadel/logging v0.5.0 h1:Kunouvqse/efXy4UDvFw5s3vP+Z4AlHo3y8wF7stXHA=
|
||||
github.com/zitadel/logging v0.5.0/go.mod h1:IzP5fzwFhzzyxHkSmfF8dsyqFsQRJLLcQmwhIBzlGsE=
|
||||
github.com/zitadel/oidc/v3 v3.10.2 h1:nowZrpOBR4tdIlYXE8/l5Nl84QDYwyHpccIE1l2OAd4=
|
||||
github.com/zitadel/oidc/v3 v3.10.2/go.mod h1:nfjWH8ps4B7T0JGJyLLOIUlhr0Z4becyGKui/sXYpA8=
|
||||
github.com/zitadel/oidc/v3 v3.11.0 h1:g3sOT1ith+Yc8ExDrywe0WJLKg7Fvhs7txiYX3fEcWY=
|
||||
github.com/zitadel/oidc/v3 v3.11.0/go.mod h1:UehVNuuqOYrBSFqNeHLzCpt+/Wd+LI0c9Ok87UEO73g=
|
||||
github.com/zitadel/passwap v0.5.0 h1:kFMoRyo0GnxtOz7j9+r/CsRwSCjHGRaAKoUe69NwPvs=
|
||||
github.com/zitadel/passwap v0.5.0/go.mod h1:uqY7D3jqdTFcKsW0Q3Pcv5qDMmSHpVTzUZewUKC1KZA=
|
||||
github.com/zitadel/saml v0.1.3 h1:LI4DOCVyyU1qKPkzs3vrGcA5J3H4pH3+CL9zr9ShkpM=
|
||||
@ -801,8 +801,8 @@ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.4
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1/go.mod h1:4UoMYEZOC0yN/sPGH76KPkkU7zgiEWYWL9vwmbnTJPE=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 h1:aFJWCqJMNjENlcleuuOkGAPH82y0yULBScfXcIEdS24=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1/go.mod h1:sEGXWArGqc3tVa+ekntsN65DmVbVeW+7lTKTjZF3/Fo=
|
||||
go.opentelemetry.io/otel v1.21.0 h1:hzLeKBZEL7Okw2mGzZ0cc4k/A7Fta0uoPgaJCr8fsFc=
|
||||
go.opentelemetry.io/otel v1.21.0/go.mod h1:QZzNPQPm1zLX4gZK4cMi+71eaorMSGT3A4znnUvNNEo=
|
||||
go.opentelemetry.io/otel v1.22.0 h1:xS7Ku+7yTFvDfDraDIJVpw7XPyuHlB9MCiqqX5mcJ6Y=
|
||||
go.opentelemetry.io/otel v1.22.0/go.mod h1:eoV4iAi3Ea8LkAEI9+GFT44O6T/D0GWAVFyZVCC6pMI=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0 h1:cl5P5/GIfFh4t6xyruOgJP5QiA1pw4fYYdv6nc6CBWw=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0/go.mod h1:zgBdWWAu7oEEMC06MMKc5NLbA/1YDXV1sMpSqEeLQLg=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0 h1:tIqheXEFWAZ7O8A7m+J0aPTmpJN3YQ7qetUAdkkkKpk=
|
||||
@ -811,14 +811,14 @@ go.opentelemetry.io/otel/exporters/prometheus v0.44.0 h1:08qeJgaPC0YEBu2PQMbqU3r
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.44.0/go.mod h1:ERL2uIeBtg4TxZdojHUwzZfIFlUIjZtxubT5p4h1Gjg=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.21.0 h1:VhlEQAPp9R1ktYfrPk5SOryw1e9LDDTZCbIPFrho0ec=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.21.0/go.mod h1:kB3ufRbfU+CQ4MlUcqtW8Z7YEOBeK2DJ6CmR5rYYF3E=
|
||||
go.opentelemetry.io/otel/metric v1.21.0 h1:tlYWfeo+Bocx5kLEloTjbcDwBuELRrIFxwdQ36PlJu4=
|
||||
go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM=
|
||||
go.opentelemetry.io/otel/metric v1.22.0 h1:lypMQnGyJYeuYPhOM/bgjbFM6WE44W1/T45er4d8Hhg=
|
||||
go.opentelemetry.io/otel/metric v1.22.0/go.mod h1:evJGjVpZv0mQ5QBRJoBF64yMuOf4xCWdXjK8pzFvliY=
|
||||
go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8=
|
||||
go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.21.0 h1:smhI5oD714d6jHE6Tie36fPx4WDFIg+Y6RfAY4ICcR0=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.21.0/go.mod h1:FJ8RAsoPGv/wYMgBdUJXOm+6pzFY3YdljnXtv1SBE8Q=
|
||||
go.opentelemetry.io/otel/trace v1.21.0 h1:WD9i5gzvoUPuXIXH24ZNBudiarZDKuekPqi/E8fpfLc=
|
||||
go.opentelemetry.io/otel/trace v1.21.0/go.mod h1:LGbsEB0f9LGjN+OZaQQ26sohbOmiMR+BaslueVtS/qQ=
|
||||
go.opentelemetry.io/otel/trace v1.22.0 h1:Hg6pPujv0XG9QaVbGOBVHunyuLcCC3jN7WEhPx83XD0=
|
||||
go.opentelemetry.io/otel/trace v1.22.0/go.mod h1:RbbHXVqKES9QhzZq/fE5UnOSILqRt40a21sPw2He1xo=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
|
||||
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
|
@ -39,7 +39,7 @@ func (s *Server) verifyAccessToken(ctx context.Context, tkn string) (*accessToke
|
||||
}
|
||||
tokenID, subject = split[0], split[1]
|
||||
} else {
|
||||
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.keySet)
|
||||
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.accessTokenKeySet)
|
||||
claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, tkn, verifier)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowPermissionDenied(err, "OIDC-Eib8e", "token is not valid or has expired")
|
||||
|
@ -4,7 +4,6 @@ package oidc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -330,8 +329,6 @@ func TestServer_VerifyClient(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fmt.Printf("\n\n%s\n\n", tt.client.keyData)
|
||||
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(CTX, tt.client.authReqClientID, Tester.Users[integration.FirstInstanceUsersKey][integration.Login].ID, redirectURI, oidc.ScopeOpenID)
|
||||
require.NoError(t, err)
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
@ -22,31 +23,55 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
// keySetCache implements oidc.KeySet for Access Token verification.
|
||||
// Public Keys are cached in a 2-dimensional map of Instance ID and Key ID.
|
||||
type cachedPublicKey struct {
|
||||
lastUse atomic.Int64 // unix micro time.
|
||||
query.PublicKey
|
||||
}
|
||||
|
||||
func newCachedPublicKey(key query.PublicKey, now time.Time) *cachedPublicKey {
|
||||
cachedKey := &cachedPublicKey{
|
||||
PublicKey: key,
|
||||
}
|
||||
cachedKey.setLastUse(now)
|
||||
return cachedKey
|
||||
}
|
||||
|
||||
func (c *cachedPublicKey) setLastUse(now time.Time) {
|
||||
c.lastUse.Store(now.UnixMicro())
|
||||
}
|
||||
|
||||
func (c *cachedPublicKey) getLastUse() time.Time {
|
||||
return time.UnixMicro(c.lastUse.Load())
|
||||
}
|
||||
|
||||
func (c *cachedPublicKey) expired(now time.Time, validity time.Duration) bool {
|
||||
return c.getLastUse().Add(validity).Before(now)
|
||||
}
|
||||
|
||||
// publicKeyCache caches public keys in a 2-dimensional map of Instance ID and Key ID.
|
||||
// When a key is not present the queryKey function is called to obtain the key
|
||||
// from the database.
|
||||
type keySetCache struct {
|
||||
type publicKeyCache struct {
|
||||
mtx sync.RWMutex
|
||||
instanceKeys map[string]map[string]query.PublicKey
|
||||
queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)
|
||||
instanceKeys map[string]map[string]*cachedPublicKey
|
||||
queryKey func(ctx context.Context, keyID string) (query.PublicKey, error)
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
// newKeySet initializes a keySetCache and starts a purging Go routine,
|
||||
// which runs once every purgeInterval.
|
||||
// newPublicKeyCache initializes a keySetCache starts a purging Go routine.
|
||||
// The purge routine deletes all public keys that are older than maxAge.
|
||||
// When the passed context is done, the purge routine will terminate.
|
||||
func newKeySet(background context.Context, purgeInterval time.Duration, queryKey func(ctx context.Context, keyID string, current time.Time) (query.PublicKey, error)) *keySetCache {
|
||||
k := &keySetCache{
|
||||
instanceKeys: make(map[string]map[string]query.PublicKey),
|
||||
func newPublicKeyCache(background context.Context, maxAge time.Duration, queryKey func(ctx context.Context, keyID string) (query.PublicKey, error)) *publicKeyCache {
|
||||
k := &publicKeyCache{
|
||||
instanceKeys: make(map[string]map[string]*cachedPublicKey),
|
||||
queryKey: queryKey,
|
||||
clock: clockwork.FromContext(background), // defaults to real clock
|
||||
}
|
||||
go k.purgeOnInterval(background, k.clock.NewTicker(purgeInterval))
|
||||
go k.purgeOnInterval(background, k.clock.NewTicker(maxAge/5), maxAge)
|
||||
return k
|
||||
}
|
||||
|
||||
func (k *keySetCache) purgeOnInterval(background context.Context, ticker clockwork.Ticker) {
|
||||
func (k *publicKeyCache) purgeOnInterval(background context.Context, ticker clockwork.Ticker, maxAge time.Duration) {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
@ -59,7 +84,7 @@ func (k *keySetCache) purgeOnInterval(background context.Context, ticker clockwo
|
||||
k.mtx.Lock()
|
||||
for instanceID, keys := range k.instanceKeys {
|
||||
for keyID, key := range keys {
|
||||
if key.Expiry().Before(k.clock.Now()) {
|
||||
if key.expired(k.clock.Now(), maxAge) {
|
||||
delete(keys, keyID)
|
||||
}
|
||||
}
|
||||
@ -71,19 +96,18 @@ func (k *keySetCache) purgeOnInterval(background context.Context, ticker clockwo
|
||||
}
|
||||
}
|
||||
|
||||
func (k *keySetCache) setKey(instanceID, keyID string, key query.PublicKey) {
|
||||
func (k *publicKeyCache) setKey(instanceID, keyID string, cachedKey *cachedPublicKey) {
|
||||
k.mtx.Lock()
|
||||
defer k.mtx.Unlock()
|
||||
|
||||
if keys, ok := k.instanceKeys[instanceID]; ok {
|
||||
keys[keyID] = key
|
||||
keys[keyID] = cachedKey
|
||||
return
|
||||
}
|
||||
|
||||
k.instanceKeys[instanceID] = map[string]query.PublicKey{keyID: key}
|
||||
k.instanceKeys[instanceID] = map[string]*cachedPublicKey{keyID: cachedKey}
|
||||
}
|
||||
|
||||
func (k *keySetCache) getKey(ctx context.Context, keyID string) (_ *jose.JSONWebKey, err error) {
|
||||
func (k *publicKeyCache) getKey(ctx context.Context, keyID string) (_ *cachedPublicKey, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
@ -94,22 +118,20 @@ func (k *keySetCache) getKey(ctx context.Context, keyID string) (_ *jose.JSONWeb
|
||||
k.mtx.RUnlock()
|
||||
|
||||
if ok {
|
||||
if key.Expiry().After(k.clock.Now()) {
|
||||
return jsonWebkey(key), nil
|
||||
key.setLastUse(k.clock.Now())
|
||||
} else {
|
||||
newKey, err := k.queryKey(ctx, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Zoh9E", "Errors.Key.ExpireBeforeNow")
|
||||
key = newCachedPublicKey(newKey, k.clock.Now())
|
||||
k.setKey(instanceID, keyID, key)
|
||||
}
|
||||
|
||||
key, err = k.queryKey(ctx, keyID, k.clock.Now())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
k.setKey(instanceID, keyID, key)
|
||||
return jsonWebkey(key), nil
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// VerifySignature implements the oidc.KeySet interface.
|
||||
func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
|
||||
func (k *publicKeyCache) verifySignature(ctx context.Context, jws *jose.JSONWebSignature, checkKeyExpiry bool) (_ []byte, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
@ -123,7 +145,45 @@ func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSign
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jws.Verify(key)
|
||||
if checkKeyExpiry && key.Expiry().Before(k.clock.Now()) {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
|
||||
}
|
||||
return jws.Verify(jsonWebkey(key))
|
||||
}
|
||||
|
||||
type oidcKeySet struct {
|
||||
*publicKeyCache
|
||||
|
||||
keyExpiryCheck bool
|
||||
}
|
||||
|
||||
// newOidcKeySet returns an oidc.KeySet implementation around the passed cache.
|
||||
// It is advised to reuse the same cache if different key set configurations are required.
|
||||
func newOidcKeySet(cache *publicKeyCache, opts ...keySetOption) *oidcKeySet {
|
||||
k := &oidcKeySet{
|
||||
publicKeyCache: cache,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(k)
|
||||
}
|
||||
return k
|
||||
}
|
||||
|
||||
// VerifySignature implements the oidc.KeySet interface.
|
||||
func (k *oidcKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
|
||||
return k.verifySignature(ctx, jws, k.keyExpiryCheck)
|
||||
}
|
||||
|
||||
type keySetOption func(*oidcKeySet)
|
||||
|
||||
// withKeyExpiryCheck forces VerifySignature to check the expiry of the public key.
|
||||
// Note that public key expiry is not part of the standard,
|
||||
// but is currently established behavior of zitadel.
|
||||
// We might want to remove this check in the future.
|
||||
func withKeyExpiryCheck(check bool) keySetOption {
|
||||
return func(k *oidcKeySet) {
|
||||
k.keyExpiryCheck = check
|
||||
}
|
||||
}
|
||||
|
||||
func jsonWebkey(key query.PublicKey) *jose.JSONWebKey {
|
||||
|
@ -66,100 +66,100 @@ var (
|
||||
seq: 3,
|
||||
expiry: clock.Now().Add(10 * time.Hour),
|
||||
},
|
||||
"exp1": {
|
||||
id: "key2",
|
||||
alg: "alg",
|
||||
use: domain.KeyUsageSigning,
|
||||
seq: 4,
|
||||
expiry: clock.Now().Add(-time.Hour),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func queryKeyDB(_ context.Context, keyID string, current time.Time) (query.PublicKey, error) {
|
||||
func queryKeyDB(_ context.Context, keyID string) (query.PublicKey, error) {
|
||||
if key, ok := keyDB[keyID]; ok {
|
||||
return key, nil
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func Test_keySetCache(t *testing.T) {
|
||||
func Test_publicKeyCache(t *testing.T) {
|
||||
background, cancel := context.WithCancel(
|
||||
clockwork.AddToContext(context.Background(), clock),
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
// create an empty keySet with a purge go routine, runs every Hour
|
||||
keySet := newKeySet(background, time.Hour, queryKeyDB)
|
||||
// create an empty cache with a purge go routine, runs every minute.
|
||||
// keys are cached for at least 1 Hour after last use.
|
||||
cache := newPublicKeyCache(background, time.Hour, queryKeyDB)
|
||||
ctx := authz.NewMockContext("instanceID", "orgID", "userID")
|
||||
|
||||
// query error
|
||||
_, err := keySet.getKey(ctx, "key9")
|
||||
_, err := cache.getKey(ctx, "key9")
|
||||
require.Error(t, err)
|
||||
|
||||
want := &jose.JSONWebKey{
|
||||
KeyID: "key1",
|
||||
Algorithm: "alg",
|
||||
Use: domain.KeyUsageSigning.String(),
|
||||
}
|
||||
|
||||
// get key first time, populate the cache
|
||||
got, err := keySet.getKey(ctx, "key1")
|
||||
got, err := cache.getKey(ctx, "key1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, keyDB["key1"], got.PublicKey)
|
||||
|
||||
// move time forward
|
||||
clock.Advance(5 * time.Minute)
|
||||
clock.Advance(15 * time.Minute)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// key should still be in cache
|
||||
keySet.mtx.RLock()
|
||||
_, ok := keySet.instanceKeys["instanceID"]["key1"]
|
||||
cache.mtx.RLock()
|
||||
_, ok := cache.instanceKeys["instanceID"]["key1"]
|
||||
require.True(t, ok)
|
||||
keySet.mtx.RUnlock()
|
||||
|
||||
// the key is expired, should error
|
||||
_, err = keySet.getKey(ctx, "key1")
|
||||
require.Error(t, err)
|
||||
|
||||
want = &jose.JSONWebKey{
|
||||
KeyID: "key2",
|
||||
Algorithm: "alg",
|
||||
Use: domain.KeyUsageSigning.String(),
|
||||
}
|
||||
|
||||
// get the second key from DB
|
||||
got, err = keySet.getKey(ctx, "key2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
cache.mtx.RUnlock()
|
||||
|
||||
// move time forward
|
||||
clock.Advance(time.Hour)
|
||||
clock.Advance(50 * time.Minute)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// first key shoud be purged, second still present
|
||||
keySet.mtx.RLock()
|
||||
_, ok = keySet.instanceKeys["instanceID"]["key1"]
|
||||
require.False(t, ok)
|
||||
_, ok = keySet.instanceKeys["instanceID"]["key2"]
|
||||
require.True(t, ok)
|
||||
keySet.mtx.RUnlock()
|
||||
|
||||
// get the second key from cache
|
||||
got, err = keySet.getKey(ctx, "key2")
|
||||
// get the second key from DB
|
||||
got, err = cache.getKey(ctx, "key2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, keyDB["key2"], got.PublicKey)
|
||||
|
||||
// move time forward
|
||||
clock.Advance(10 * time.Hour)
|
||||
clock.Advance(15 * time.Minute)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// first key should be purged, second still present
|
||||
cache.mtx.RLock()
|
||||
_, ok = cache.instanceKeys["instanceID"]["key1"]
|
||||
require.False(t, ok)
|
||||
_, ok = cache.instanceKeys["instanceID"]["key2"]
|
||||
require.True(t, ok)
|
||||
cache.mtx.RUnlock()
|
||||
|
||||
// get the second key from cache
|
||||
got, err = cache.getKey(ctx, "key2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, keyDB["key2"], got.PublicKey)
|
||||
|
||||
// move time forward
|
||||
clock.Advance(2 * time.Hour)
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
// now the cache should be empty
|
||||
keySet.mtx.RLock()
|
||||
assert.Empty(t, keySet.instanceKeys)
|
||||
keySet.mtx.RUnlock()
|
||||
cache.mtx.RLock()
|
||||
assert.Empty(t, cache.instanceKeys)
|
||||
cache.mtx.RUnlock()
|
||||
}
|
||||
|
||||
func Test_keySetCache_VerifySignature(t *testing.T) {
|
||||
func Test_oidcKeySet_VerifySignature(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
k := newKeySet(ctx, time.Second, queryKeyDB)
|
||||
cache := newPublicKeyCache(ctx, time.Second, queryKeyDB)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts []keySetOption
|
||||
jws *jose.JSONWebSignature
|
||||
}{
|
||||
{
|
||||
@ -186,9 +186,33 @@ func Test_keySetCache_VerifySignature(t *testing.T) {
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired, no check",
|
||||
jws: &jose.JSONWebSignature{
|
||||
Signatures: []jose.Signature{{
|
||||
Header: jose.Header{
|
||||
KeyID: "exp1",
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired, with check",
|
||||
jws: &jose.JSONWebSignature{
|
||||
Signatures: []jose.Signature{{
|
||||
Header: jose.Header{
|
||||
KeyID: "exp1",
|
||||
},
|
||||
}},
|
||||
},
|
||||
opts: []keySetOption{
|
||||
withKeyExpiryCheck(true),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := newOidcKeySet(cache, tt.opts...)
|
||||
_, err := k.VerifySignature(ctx, tt.jws)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
@ -43,6 +43,7 @@ type Config struct {
|
||||
DefaultLoginURLV2 string
|
||||
DefaultLogoutURLV2 string
|
||||
Features Features
|
||||
PublicKeyCacheMaxAge time.Duration
|
||||
}
|
||||
|
||||
type EndpointConfig struct {
|
||||
@ -104,13 +105,17 @@ func NewServer(
|
||||
return nil, zerrors.ThrowInternal(err, "OIDC-EGrqd", "cannot create op config: %w")
|
||||
}
|
||||
storage := newStorage(config, command, query, repo, encryptionAlg, es, projections, externalSecure)
|
||||
var options []op.Option
|
||||
keyCache := newPublicKeyCache(context.TODO(), config.PublicKeyCacheMaxAge, query.GetPublicKeyByID)
|
||||
accessTokenKeySet := newOidcKeySet(keyCache, withKeyExpiryCheck(true))
|
||||
idTokenHintKeySet := newOidcKeySet(keyCache)
|
||||
|
||||
options := []op.Option{
|
||||
op.WithAccessTokenKeySet(accessTokenKeySet),
|
||||
op.WithIDTokenHintKeySet(idTokenHintKeySet),
|
||||
}
|
||||
if !externalSecure {
|
||||
options = append(options, op.WithAllowInsecure())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "OIDC-D3gq1", "cannot create options: %w")
|
||||
}
|
||||
provider, err := op.NewProvider(
|
||||
opConfig,
|
||||
storage,
|
||||
@ -127,7 +132,8 @@ func NewServer(
|
||||
repo: repo,
|
||||
query: query,
|
||||
command: command,
|
||||
keySet: newKeySet(context.TODO(), time.Hour, query.GetActivePublicKeyByID),
|
||||
accessTokenKeySet: accessTokenKeySet,
|
||||
idTokenHintKeySet: idTokenHintKeySet,
|
||||
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
|
||||
defaultLoginURLV2: config.DefaultLoginURLV2,
|
||||
defaultLogoutURLV2: config.DefaultLogoutURLV2,
|
||||
|
@ -23,10 +23,11 @@ type Server struct {
|
||||
*op.LegacyServer
|
||||
features Features
|
||||
|
||||
repo repository.Repository
|
||||
query *query.Queries
|
||||
command *command.Commands
|
||||
keySet *keySetCache
|
||||
repo repository.Repository
|
||||
query *query.Queries
|
||||
command *command.Commands
|
||||
accessTokenKeySet *oidcKeySet
|
||||
idTokenHintKeySet *oidcKeySet
|
||||
|
||||
defaultLoginURL string
|
||||
defaultLoginURLV2 string
|
||||
|
@ -399,7 +399,7 @@ func (wm *PublicKeyReadModel) Query() *eventstore.SearchQueryBuilder {
|
||||
Builder()
|
||||
}
|
||||
|
||||
func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, current time.Time) (_ PublicKey, err error) {
|
||||
func (q *Queries) GetPublicKeyByID(ctx context.Context, keyID string) (_ PublicKey, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
@ -410,9 +410,6 @@ func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, curr
|
||||
if model.Algorithm == "" || model.Key == nil {
|
||||
return nil, zerrors.ThrowNotFound(err, "QUERY-Ahf7x", "Errors.Key.NotFound")
|
||||
}
|
||||
if model.Expiry.Before(current) {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow")
|
||||
}
|
||||
keyValue, err := crypto.Decrypt(model.Key, q.keyEncryptionAlgorithm)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "QUERY-Ie4oh", "Errors.Internal")
|
||||
|
@ -269,7 +269,7 @@ MZbmlCoBru+rC8ITlTX/0V1ZcsSbL8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a
|
||||
-----END PUBLIC KEY-----
|
||||
`
|
||||
|
||||
func TestQueries_GetActivePublicKeyByID(t *testing.T) {
|
||||
func TestQueries_GetPublicKeyByID(t *testing.T) {
|
||||
now := time.Now()
|
||||
future := now.Add(time.Hour)
|
||||
|
||||
@ -294,38 +294,6 @@ func TestQueries_GetActivePublicKeyByID(t *testing.T) {
|
||||
),
|
||||
wantErr: zerrors.ThrowNotFound(nil, "QUERY-Ahf7x", "Errors.Key.NotFound"),
|
||||
},
|
||||
{
|
||||
name: "expired error",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(key_repo.NewAddedEvent(context.Background(),
|
||||
&eventstore.Aggregate{
|
||||
ID: "keyID",
|
||||
Type: key_repo.AggregateType,
|
||||
ResourceOwner: "instanceID",
|
||||
InstanceID: "instanceID",
|
||||
Version: key_repo.AggregateVersion,
|
||||
},
|
||||
domain.KeyUsageSigning, "alg",
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "alg",
|
||||
KeyID: "keyID",
|
||||
Crypted: []byte("private"),
|
||||
},
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "alg",
|
||||
KeyID: "keyID",
|
||||
Crypted: []byte("public"),
|
||||
},
|
||||
now.Add(-time.Hour),
|
||||
now.Add(-time.Hour),
|
||||
)),
|
||||
),
|
||||
),
|
||||
wantErr: zerrors.ThrowInvalidArgument(nil, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow"),
|
||||
},
|
||||
{
|
||||
name: "decrypt error",
|
||||
eventstore: expectEventstore(
|
||||
@ -470,7 +438,7 @@ func TestQueries_GetActivePublicKeyByID(t *testing.T) {
|
||||
q.keyEncryptionAlgorithm = tt.encryption(t)
|
||||
}
|
||||
ctx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
key, err := q.GetActivePublicKeyByID(ctx, "keyID", now)
|
||||
key, err := q.GetPublicKeyByID(ctx, "keyID")
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
return
|
||||
|
Loading…
x
Reference in New Issue
Block a user