From d225df2d46d14cca7f09507440e1cb6942cb50e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Wed, 15 Nov 2023 18:29:03 +0200 Subject: [PATCH] test instance keys query --- internal/query/key.go | 4 +- internal/query/key_test.go | 238 +++++++++++++++++++++++++++++++++++ internal/query/query_test.go | 81 ++++++++++++ 3 files changed, 322 insertions(+), 1 deletion(-) diff --git a/internal/query/key.go b/internal/query/key.go index 64dde168e05..4020fb49032 100644 --- a/internal/query/key.go +++ b/internal/query/key.go @@ -358,6 +358,7 @@ type PublicKeyReadModel struct { Algorithm string Key *crypto.CryptoValue Expiry time.Time + Usage domain.KeyUsage } func NewPublicKeyReadModel(keyID, resourceOwner string) *PublicKeyReadModel { @@ -380,6 +381,7 @@ func (wm *PublicKeyReadModel) Reduce() error { wm.Algorithm = e.Algorithm wm.Key = e.PublicKey.Key wm.Expiry = e.PublicKey.Expiry + wm.Usage = e.Usage } } return wm.ReadModel.Reduce() @@ -427,7 +429,7 @@ func (q *Queries) GetActivePublicKeyByID(ctx context.Context, keyID string, curr sequence: model.ProcessedSequence, resourceOwner: model.ResourceOwner, algorithm: model.Algorithm, - // use: , TBD, what events update this and do we need it? + use: model.Usage, }, expiry: model.Expiry, publicKey: publicKey, diff --git a/internal/query/key_test.go b/internal/query/key_test.go index a600b23eee1..b90077096e3 100644 --- a/internal/query/key_test.go +++ b/internal/query/key_test.go @@ -1,18 +1,27 @@ package query import ( + "context" "crypto/rsa" "database/sql" "database/sql/driver" "errors" "fmt" + "io" "math/big" "regexp" "testing" + "time" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" errs "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + key_repo "github.com/zitadel/zitadel/internal/repository/keypair" ) var ( @@ -247,3 +256,232 @@ func fromBase16(base16 string) *big.Int { } return i } + +const pubKey = `-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAs38btwb3c7r0tMaQpGvB +mY+mPwMU/LpfuPoC0k2t4RsKp0fv40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuy +rFALIj3Ff1UcKIk0hOH5DDsfh7/q2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/M +fSydZdcmIqlkUpfQmtzExw9+tSe5Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNu +MZbmlCoBru+rC8ITlTX/0V1ZcsSbL8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a ++kjL/KGZbR14Ua2eo6tonBZLC5DHWM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly +6QIDAQAB +-----END PUBLIC KEY----- +` + +func TestQueries_GetActivePublicKeyByID(t *testing.T) { + now := time.Now() + future := now.Add(time.Hour) + + tests := []struct { + name string + eventstore func(*testing.T) *eventstore.Eventstore + encryption func(*testing.T) *crypto.MockEncryptionAlgorithm + want *rsaPublicKey + wantErr error + }{ + { + name: "filter error", + eventstore: expectEventstore( + expectFilterError(io.ErrClosedPipe), + ), + wantErr: io.ErrClosedPipe, + }, + { + name: "not found error", + eventstore: expectEventstore( + expectFilter(), + ), + wantErr: errs.ThrowNotFound(nil, "QUERY-Ahf7x", "Errors.Key.NotFound"), + }, + { + name: "expired error", + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher(key_repo.NewAddedEvent(context.Background(), + &eventstore.Aggregate{ + ID: "keyID", + Type: key_repo.AggregateType, + ResourceOwner: "instanceID", + InstanceID: "instanceID", + Version: key_repo.AggregateVersion, + }, + domain.KeyUsageSigning, "alg", + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "alg", + KeyID: "keyID", + Crypted: []byte("private"), + }, + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "alg", + KeyID: "keyID", + Crypted: []byte("public"), + }, + now.Add(-time.Hour), + now.Add(-time.Hour), + )), + ), + ), + wantErr: errs.ThrowInvalidArgument(nil, "QUERY-ciF4k", "Errors.Key.ExpireBeforeNow"), + }, + { + name: "decrypt error", + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher(key_repo.NewAddedEvent(context.Background(), + &eventstore.Aggregate{ + ID: "keyID", + Type: key_repo.AggregateType, + ResourceOwner: "instanceID", + InstanceID: "instanceID", + Version: key_repo.AggregateVersion, + }, + domain.KeyUsageSigning, "alg", + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "alg", + KeyID: "keyID", + Crypted: []byte("private"), + }, + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "alg", + KeyID: "keyID", + Crypted: []byte("public"), + }, + future, + future, + )), + ), + ), + encryption: func(t *testing.T) *crypto.MockEncryptionAlgorithm { + encryption := crypto.NewMockEncryptionAlgorithm(gomock.NewController(t)) + expect := encryption.EXPECT() + expect.Algorithm().Return("alg") + expect.DecryptionKeyIDs().Return([]string{}) + return encryption + }, + wantErr: errs.ThrowInternal(nil, "QUERY-Ie4oh", "Errors.Internal"), + }, + { + name: "parse error", + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher(key_repo.NewAddedEvent(context.Background(), + &eventstore.Aggregate{ + ID: "keyID", + Type: key_repo.AggregateType, + ResourceOwner: "instanceID", + InstanceID: "instanceID", + Version: key_repo.AggregateVersion, + }, + domain.KeyUsageSigning, "alg", + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "alg", + KeyID: "keyID", + Crypted: []byte("private"), + }, + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "alg", + KeyID: "keyID", + Crypted: []byte("public"), + }, + future, + future, + )), + ), + ), + encryption: func(t *testing.T) *crypto.MockEncryptionAlgorithm { + encryption := crypto.NewMockEncryptionAlgorithm(gomock.NewController(t)) + expect := encryption.EXPECT() + expect.Algorithm().Return("alg") + expect.DecryptionKeyIDs().Return([]string{"keyID"}) + expect.Decrypt([]byte("public"), "keyID").Return([]byte("foo"), nil) + return encryption + }, + wantErr: errs.ThrowInternal(nil, "QUERY-Kai2Z", "Errors.Internal"), + }, + { + name: "success", + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher(key_repo.NewAddedEvent(context.Background(), + &eventstore.Aggregate{ + ID: "keyID", + Type: key_repo.AggregateType, + ResourceOwner: "instanceID", + InstanceID: "instanceID", + Version: key_repo.AggregateVersion, + }, + domain.KeyUsageSigning, "alg", + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "alg", + KeyID: "keyID", + Crypted: []byte("private"), + }, + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "alg", + KeyID: "keyID", + Crypted: []byte("public"), + }, + future, + future, + )), + ), + ), + encryption: func(t *testing.T) *crypto.MockEncryptionAlgorithm { + encryption := crypto.NewMockEncryptionAlgorithm(gomock.NewController(t)) + expect := encryption.EXPECT() + expect.Algorithm().Return("alg") + expect.DecryptionKeyIDs().Return([]string{"keyID"}) + expect.Decrypt([]byte("public"), "keyID").Return([]byte(pubKey), nil) + return encryption + }, + want: &rsaPublicKey{ + key: key{ + id: "keyID", + resourceOwner: "instanceID", + algorithm: "alg", + use: domain.KeyUsageSigning, + }, + expiry: future, + publicKey: func() *rsa.PublicKey { + publicKey, err := crypto.BytesToPublicKey([]byte(pubKey)) + if err != nil { + panic(err) + } + return publicKey + }(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := &Queries{ + eventstore: tt.eventstore(t), + } + if tt.encryption != nil { + q.keyEncryptionAlgorithm = tt.encryption(t) + } + ctx := authz.NewMockContext("instanceID", "orgID", "loginClient") + key, err := q.GetActivePublicKeyByID(ctx, "keyID", now) + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + return + } + require.NoError(t, err) + require.NotNil(t, key) + + got := key.(*rsaPublicKey) + assert.WithinDuration(t, tt.want.expiry, got.expiry, time.Second) + tt.want.expiry = time.Time{} + got.expiry = time.Time{} + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/query/query_test.go b/internal/query/query_test.go index fc16ee9fad9..f1dee698147 100644 --- a/internal/query/query_test.go +++ b/internal/query/query_test.go @@ -1,11 +1,92 @@ package query import ( + "database/sql" "testing" + "time" "github.com/stretchr/testify/assert" + + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/repository" + "github.com/zitadel/zitadel/internal/eventstore/repository/mock" + action_repo "github.com/zitadel/zitadel/internal/repository/action" + "github.com/zitadel/zitadel/internal/repository/authrequest" + "github.com/zitadel/zitadel/internal/repository/feature" + "github.com/zitadel/zitadel/internal/repository/idpintent" + iam_repo "github.com/zitadel/zitadel/internal/repository/instance" + key_repo "github.com/zitadel/zitadel/internal/repository/keypair" + "github.com/zitadel/zitadel/internal/repository/limits" + "github.com/zitadel/zitadel/internal/repository/oidcsession" + "github.com/zitadel/zitadel/internal/repository/org" + proj_repo "github.com/zitadel/zitadel/internal/repository/project" + quota_repo "github.com/zitadel/zitadel/internal/repository/quota" + "github.com/zitadel/zitadel/internal/repository/session" + usr_repo "github.com/zitadel/zitadel/internal/repository/user" + "github.com/zitadel/zitadel/internal/repository/usergrant" ) +type expect func(mockRepository *mock.MockRepository) + +func expectEventstore(expects ...expect) func(*testing.T) *eventstore.Eventstore { + return func(t *testing.T) *eventstore.Eventstore { + m := mock.NewRepo(t) + for _, e := range expects { + e(m) + } + es := eventstore.NewEventstore( + &eventstore.Config{ + Querier: m.MockQuerier, + Pusher: m.MockPusher, + }, + ) + iam_repo.RegisterEventMappers(es) + org.RegisterEventMappers(es) + usr_repo.RegisterEventMappers(es) + proj_repo.RegisterEventMappers(es) + usergrant.RegisterEventMappers(es) + key_repo.RegisterEventMappers(es) + action_repo.RegisterEventMappers(es) + session.RegisterEventMappers(es) + idpintent.RegisterEventMappers(es) + authrequest.RegisterEventMappers(es) + oidcsession.RegisterEventMappers(es) + quota_repo.RegisterEventMappers(es) + limits.RegisterEventMappers(es) + feature.RegisterEventMappers(es) + return es + } +} + +func expectFilter(events ...eventstore.Event) expect { + return func(m *mock.MockRepository) { + m.ExpectFilterEvents(events...) + } +} +func expectFilterError(err error) expect { + return func(m *mock.MockRepository) { + m.ExpectFilterEventsError(err) + } +} + +func eventFromEventPusher(event eventstore.Command) *repository.Event { + data, _ := eventstore.EventData(event) + return &repository.Event{ + InstanceID: event.Aggregate().InstanceID, + ID: "", + Seq: 0, + CreationDate: time.Time{}, + Typ: event.Type(), + Data: data, + EditorUser: event.Creator(), + Version: event.Aggregate().Version, + AggregateID: event.Aggregate().ID, + AggregateType: event.Aggregate().Type, + ResourceOwner: sql.NullString{String: event.Aggregate().ResourceOwner, Valid: event.Aggregate().ResourceOwner != ""}, + Constraints: event.UniqueConstraints(), + } +} + func Test_cleanStaticQueries(t *testing.T) { query := `select foo,