2022-01-12 12:22:04 +00:00
package query
import (
2023-11-21 12:11:38 +00:00
"context"
2022-01-12 12:22:04 +00:00
"crypto/rsa"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
2023-11-21 12:11:38 +00:00
"io"
2022-01-12 12:22:04 +00:00
"math/big"
"regexp"
"testing"
2023-11-21 12:11:38 +00:00
"time"
2022-01-12 12:22:04 +00:00
2023-11-21 12:11:38 +00:00
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
2023-11-22 10:56:43 +00:00
"go.uber.org/mock/gomock"
2023-11-21 12:11:38 +00:00
"github.com/zitadel/zitadel/internal/api/authz"
2022-04-26 23:01:45 +00:00
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
errs "github.com/zitadel/zitadel/internal/errors"
2023-11-21 12:11:38 +00:00
"github.com/zitadel/zitadel/internal/eventstore"
key_repo "github.com/zitadel/zitadel/internal/repository/keypair"
2022-01-12 12:22:04 +00:00
)
2023-02-27 21:36:43 +00:00
// Expected SQL statements and mocked result columns for the key projection
// queries exercised by Test_KeyPrepares. The statements are wrapped with
// regexp.QuoteMeta before being handed to the mocked database.
var (
	// preparePublicKeysStmt is the statement expected from preparePublicKeysQuery:
	// key metadata from projections.keys4, LEFT JOINed with projections.keys4_public
	// on both id and instance_id, plus a window COUNT(*) OVER () for the total.
	preparePublicKeysStmt = ` SELECT projections.keys4.id, ` +
		` projections.keys4.creation_date, ` +
		` projections.keys4.change_date, ` +
		` projections.keys4.sequence, ` +
		` projections.keys4.resource_owner, ` +
		` projections.keys4.algorithm, ` +
		` projections.keys4.use, ` +
		` projections.keys4_public.expiry, ` +
		` projections.keys4_public.key, ` +
		` COUNT(*) OVER () ` +
		` FROM projections.keys4 ` +
		` LEFT JOIN projections.keys4_public ON projections.keys4.id = projections.keys4_public.id AND projections.keys4.instance_id = projections.keys4_public.instance_id ` +
		` AS OF SYSTEM TIME '-1 ms' `
	// preparePublicKeysCols names the columns of the mocked result rows, in
	// SELECT order. The private-key cases reuse this slice because both
	// statements select the same column names.
	preparePublicKeysCols = []string{
		"id",
		"creation_date",
		"change_date",
		"sequence",
		"resource_owner",
		"algorithm",
		"use",
		"expiry",
		"key",
		"count",
	}
	// preparePrivateKeysStmt mirrors preparePublicKeysStmt but joins
	// projections.keys4_private instead of projections.keys4_public.
	preparePrivateKeysStmt = ` SELECT projections.keys4.id, ` +
		` projections.keys4.creation_date, ` +
		` projections.keys4.change_date, ` +
		` projections.keys4.sequence, ` +
		` projections.keys4.resource_owner, ` +
		` projections.keys4.algorithm, ` +
		` projections.keys4.use, ` +
		` projections.keys4_private.expiry, ` +
		` projections.keys4_private.key, ` +
		` COUNT(*) OVER () ` +
		` FROM projections.keys4 ` +
		` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id AND projections.keys4.instance_id = projections.keys4_private.instance_id ` +
		` AS OF SYSTEM TIME '-1 ms' `
)
2022-01-12 12:22:04 +00:00
func Test_KeyPrepares ( t * testing . T ) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := [ ] struct {
name string
prepare interface { }
want want
object interface { }
} {
{
name : "preparePublicKeysQuery no result" ,
prepare : preparePublicKeysQuery ,
want : want {
sqlExpectations : mockQueries (
2023-02-27 21:36:43 +00:00
regexp . QuoteMeta ( preparePublicKeysStmt ) ,
2022-01-12 12:22:04 +00:00
nil ,
nil ,
) ,
err : func ( err error ) ( error , bool ) {
if ! errs . IsNotFound ( err ) {
return fmt . Errorf ( "err should be zitadel.NotFoundError got: %w" , err ) , false
}
return nil , true
} ,
} ,
object : & PublicKeys { Keys : [ ] PublicKey { } } ,
} ,
{
name : "preparePublicKeysQuery found" ,
prepare : preparePublicKeysQuery ,
want : want {
sqlExpectations : mockQueries (
2023-02-27 21:36:43 +00:00
regexp . QuoteMeta ( preparePublicKeysStmt ) ,
preparePublicKeysCols ,
2022-01-12 12:22:04 +00:00
[ ] [ ] driver . Value {
{
"key-id" ,
testNow ,
testNow ,
uint64 ( 20211109 ) ,
"ro" ,
"RS256" ,
0 ,
testNow ,
[ ] byte ( "-----BEGIN RSA PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsvX9P58JFxEs5C+L+H7W\nduFSWL5EPzber7C2m94klrSV6q0bAcrYQnGwFOlveThsY200hRbadKaKjHD7qIKH\nDEe0IY2PSRht33Jye52AwhkRw+M3xuQH/7R8LydnsNFk2KHpr5X2SBv42e37LjkE\nslKSaMRgJW+v0KZ30piY8QsdFRKKaVg5/Ajt1YToM1YVsdHXJ3vmXFMtypLdxwUD\ndIaLEX6pFUkU75KSuEQ/E2luT61Q3ta9kOWm9+0zvi7OMcbdekJT7mzcVnh93R1c\n13ZhQCLbh9A7si8jKFtaMWevjayrvqQABEcTN9N4Hoxcyg6l4neZtRDk75OMYcqm\nDQIDAQAB\n-----END RSA PUBLIC KEY-----\n" ) ,
} ,
} ,
) ,
} ,
object : & PublicKeys {
SearchResponse : SearchResponse {
Count : 1 ,
} ,
Keys : [ ] PublicKey {
& rsaPublicKey {
key : key {
id : "key-id" ,
creationDate : testNow ,
changeDate : testNow ,
sequence : 20211109 ,
resourceOwner : "ro" ,
algorithm : "RS256" ,
use : domain . KeyUsageSigning ,
} ,
expiry : testNow ,
publicKey : & rsa . PublicKey {
E : 65537 ,
N : fromBase16 ( "b2f5fd3f9f0917112ce42f8bf87ed676e15258be443f36deafb0b69bde2496b495eaad1b01cad84271b014e96f79386c636d348516da74a68a8c70fba882870c47b4218d8f49186ddf72727b9d80c21911c3e337c6e407ffb47c2f2767b0d164d8a1e9af95f6481bf8d9edfb2e3904b2529268c460256fafd0a677d29898f10b1d15128a695839fc08edd584e8335615b1d1d7277be65c532dca92ddc7050374868b117ea9154914ef9292b8443f13696e4fad50ded6bd90e5a6f7ed33be2ece31c6dd7a4253ee6cdc56787ddd1d5cd776614022db87d03bb22f23285b5a3167af8dacabbea40004471337d3781e8c5cca0ea5e27799b510e4ef938c61caa60d" ) ,
} ,
} ,
} ,
} ,
} ,
{
name : "preparePublicKeysQuery sql err" ,
prepare : preparePublicKeysQuery ,
want : want {
sqlExpectations : mockQueryErr (
2023-02-27 21:36:43 +00:00
regexp . QuoteMeta ( preparePublicKeysStmt ) ,
2022-01-12 12:22:04 +00:00
sql . ErrConnDone ,
) ,
err : func ( err error ) ( error , bool ) {
if ! errors . Is ( err , sql . ErrConnDone ) {
return fmt . Errorf ( "err should be sql.ErrConnDone got: %w" , err ) , false
}
return nil , true
} ,
} ,
2023-08-22 10:49:22 +00:00
object : ( * PublicKeys ) ( nil ) ,
2022-01-12 12:22:04 +00:00
} ,
{
name : "preparePrivateKeysQuery no result" ,
prepare : preparePrivateKeysQuery ,
want : want {
sqlExpectations : mockQueries (
2023-02-27 21:36:43 +00:00
regexp . QuoteMeta ( preparePrivateKeysStmt ) ,
2022-01-12 12:22:04 +00:00
nil ,
nil ,
) ,
err : func ( err error ) ( error , bool ) {
if ! errs . IsNotFound ( err ) {
return fmt . Errorf ( "err should be zitadel.NotFoundError got: %w" , err ) , false
}
return nil , true
} ,
} ,
object : & PrivateKeys { Keys : [ ] PrivateKey { } } ,
} ,
{
name : "preparePrivateKeysQuery found" ,
prepare : preparePrivateKeysQuery ,
want : want {
sqlExpectations : mockQueries (
2023-02-27 21:36:43 +00:00
regexp . QuoteMeta ( preparePrivateKeysStmt ) ,
preparePublicKeysCols ,
2022-01-12 12:22:04 +00:00
[ ] [ ] driver . Value {
{
"key-id" ,
testNow ,
testNow ,
uint64 ( 20211109 ) ,
"ro" ,
"RS256" ,
0 ,
testNow ,
[ ] byte ( ` { "Algorithm": "enc", "Crypted": "cHJpdmF0ZUtleQ==", "CryptoType": 0, "KeyID": "id"} ` ) ,
} ,
} ,
) ,
} ,
object : & PrivateKeys {
SearchResponse : SearchResponse {
Count : 1 ,
} ,
Keys : [ ] PrivateKey {
& privateKey {
key : key {
id : "key-id" ,
creationDate : testNow ,
changeDate : testNow ,
sequence : 20211109 ,
resourceOwner : "ro" ,
algorithm : "RS256" ,
use : domain . KeyUsageSigning ,
} ,
expiry : testNow ,
privateKey : & crypto . CryptoValue {
CryptoType : crypto . TypeEncryption ,
Algorithm : "enc" ,
KeyID : "id" ,
Crypted : [ ] byte ( "privateKey" ) ,
} ,
} ,
} ,
} ,
} ,
{
name : "preparePrivateKeysQuery sql err" ,
prepare : preparePrivateKeysQuery ,
want : want {
sqlExpectations : mockQueryErr (
2023-02-27 21:36:43 +00:00
regexp . QuoteMeta ( preparePrivateKeysStmt ) ,
2022-01-12 12:22:04 +00:00
sql . ErrConnDone ,
) ,
err : func ( err error ) ( error , bool ) {
if ! errors . Is ( err , sql . ErrConnDone ) {
return fmt . Errorf ( "err should be sql.ErrConnDone got: %w" , err ) , false
}
return nil , true
} ,
} ,
2023-08-22 10:49:22 +00:00
object : ( * PrivateKeys ) ( nil ) ,
2022-01-12 12:22:04 +00:00
} ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
2023-02-27 21:36:43 +00:00
assertPrepare ( t , tt . prepare , tt . object , tt . want . sqlExpectations , tt . want . err , defaultPrepareArgs ... )
2022-01-12 12:22:04 +00:00
} )
}
}
// fromBase16 parses a hexadecimal string into a *big.Int for use in test
// fixtures. It panics with "bad number: <input>" if the input is not valid
// base-16.
func fromBase16(base16 string) *big.Int {
	n := new(big.Int)
	if _, ok := n.SetString(base16, 16); ok {
		return n
	}
	panic("bad number: " + base16)
}
2023-11-21 12:11:38 +00:00
// pubKey is a PEM-encoded ("PUBLIC KEY", PKIX) 2048-bit RSA public key. It is
// the decrypted key material returned by the encryption mock in
// TestQueries_GetActivePublicKeyByID. Raw string literals are used verbatim, so
// the PEM armor lines and base64 body must not contain stray whitespace.
// Otherwise pem.Decode finds no block and crypto.BytesToPublicKey fails.
const pubKey = `-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAs38btwb3c7r0tMaQpGvB
mY+mPwMU/LpfuPoC0k2t4RsKp0fv40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuy
rFALIj3Ff1UcKIk0hOH5DDsfh7/q2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/M
fSydZdcmIqlkUpfQmtzExw9+tSe5Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNu
MZbmlCoBru+rC8ITlTX/0V1ZcsSbL8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a
+kjL/KGZbR14Ua2eo6tonBZLC5DHWM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly
6QIDAQAB
-----END PUBLIC KEY-----
`
func TestQueries_GetActivePublicKeyByID ( t * testing . T ) {
now := time . Now ( )
future := now . Add ( time . Hour )
tests := [ ] struct {
name string
eventstore func ( * testing . T ) * eventstore . Eventstore
encryption func ( * testing . T ) * crypto . MockEncryptionAlgorithm
want * rsaPublicKey
wantErr error
} {
{
name : "filter error" ,
eventstore : expectEventstore (
expectFilterError ( io . ErrClosedPipe ) ,
) ,
wantErr : io . ErrClosedPipe ,
} ,
{
name : "not found error" ,
eventstore : expectEventstore (
expectFilter ( ) ,
) ,
wantErr : errs . ThrowNotFound ( nil , "QUERY-Ahf7x" , "Errors.Key.NotFound" ) ,
} ,
{
name : "expired error" ,
eventstore : expectEventstore (
expectFilter (
eventFromEventPusher ( key_repo . NewAddedEvent ( context . Background ( ) ,
& eventstore . Aggregate {
ID : "keyID" ,
Type : key_repo . AggregateType ,
ResourceOwner : "instanceID" ,
InstanceID : "instanceID" ,
Version : key_repo . AggregateVersion ,
} ,
domain . KeyUsageSigning , "alg" ,
& crypto . CryptoValue {
CryptoType : crypto . TypeEncryption ,
Algorithm : "alg" ,
KeyID : "keyID" ,
Crypted : [ ] byte ( "private" ) ,
} ,
& crypto . CryptoValue {
CryptoType : crypto . TypeEncryption ,
Algorithm : "alg" ,
KeyID : "keyID" ,
Crypted : [ ] byte ( "public" ) ,
} ,
now . Add ( - time . Hour ) ,
now . Add ( - time . Hour ) ,
) ) ,
) ,
) ,
wantErr : errs . ThrowInvalidArgument ( nil , "QUERY-ciF4k" , "Errors.Key.ExpireBeforeNow" ) ,
} ,
{
name : "decrypt error" ,
eventstore : expectEventstore (
expectFilter (
eventFromEventPusher ( key_repo . NewAddedEvent ( context . Background ( ) ,
& eventstore . Aggregate {
ID : "keyID" ,
Type : key_repo . AggregateType ,
ResourceOwner : "instanceID" ,
InstanceID : "instanceID" ,
Version : key_repo . AggregateVersion ,
} ,
domain . KeyUsageSigning , "alg" ,
& crypto . CryptoValue {
CryptoType : crypto . TypeEncryption ,
Algorithm : "alg" ,
KeyID : "keyID" ,
Crypted : [ ] byte ( "private" ) ,
} ,
& crypto . CryptoValue {
CryptoType : crypto . TypeEncryption ,
Algorithm : "alg" ,
KeyID : "keyID" ,
Crypted : [ ] byte ( "public" ) ,
} ,
future ,
future ,
) ) ,
) ,
) ,
encryption : func ( t * testing . T ) * crypto . MockEncryptionAlgorithm {
encryption := crypto . NewMockEncryptionAlgorithm ( gomock . NewController ( t ) )
expect := encryption . EXPECT ( )
expect . Algorithm ( ) . Return ( "alg" )
expect . DecryptionKeyIDs ( ) . Return ( [ ] string { } )
return encryption
} ,
wantErr : errs . ThrowInternal ( nil , "QUERY-Ie4oh" , "Errors.Internal" ) ,
} ,
{
name : "parse error" ,
eventstore : expectEventstore (
expectFilter (
eventFromEventPusher ( key_repo . NewAddedEvent ( context . Background ( ) ,
& eventstore . Aggregate {
ID : "keyID" ,
Type : key_repo . AggregateType ,
ResourceOwner : "instanceID" ,
InstanceID : "instanceID" ,
Version : key_repo . AggregateVersion ,
} ,
domain . KeyUsageSigning , "alg" ,
& crypto . CryptoValue {
CryptoType : crypto . TypeEncryption ,
Algorithm : "alg" ,
KeyID : "keyID" ,
Crypted : [ ] byte ( "private" ) ,
} ,
& crypto . CryptoValue {
CryptoType : crypto . TypeEncryption ,
Algorithm : "alg" ,
KeyID : "keyID" ,
Crypted : [ ] byte ( "public" ) ,
} ,
future ,
future ,
) ) ,
) ,
) ,
encryption : func ( t * testing . T ) * crypto . MockEncryptionAlgorithm {
encryption := crypto . NewMockEncryptionAlgorithm ( gomock . NewController ( t ) )
expect := encryption . EXPECT ( )
expect . Algorithm ( ) . Return ( "alg" )
expect . DecryptionKeyIDs ( ) . Return ( [ ] string { "keyID" } )
expect . Decrypt ( [ ] byte ( "public" ) , "keyID" ) . Return ( [ ] byte ( "foo" ) , nil )
return encryption
} ,
wantErr : errs . ThrowInternal ( nil , "QUERY-Kai2Z" , "Errors.Internal" ) ,
} ,
{
name : "success" ,
eventstore : expectEventstore (
expectFilter (
eventFromEventPusher ( key_repo . NewAddedEvent ( context . Background ( ) ,
& eventstore . Aggregate {
ID : "keyID" ,
Type : key_repo . AggregateType ,
ResourceOwner : "instanceID" ,
InstanceID : "instanceID" ,
Version : key_repo . AggregateVersion ,
} ,
domain . KeyUsageSigning , "alg" ,
& crypto . CryptoValue {
CryptoType : crypto . TypeEncryption ,
Algorithm : "alg" ,
KeyID : "keyID" ,
Crypted : [ ] byte ( "private" ) ,
} ,
& crypto . CryptoValue {
CryptoType : crypto . TypeEncryption ,
Algorithm : "alg" ,
KeyID : "keyID" ,
Crypted : [ ] byte ( "public" ) ,
} ,
future ,
future ,
) ) ,
) ,
) ,
encryption : func ( t * testing . T ) * crypto . MockEncryptionAlgorithm {
encryption := crypto . NewMockEncryptionAlgorithm ( gomock . NewController ( t ) )
expect := encryption . EXPECT ( )
expect . Algorithm ( ) . Return ( "alg" )
expect . DecryptionKeyIDs ( ) . Return ( [ ] string { "keyID" } )
expect . Decrypt ( [ ] byte ( "public" ) , "keyID" ) . Return ( [ ] byte ( pubKey ) , nil )
return encryption
} ,
want : & rsaPublicKey {
key : key {
id : "keyID" ,
resourceOwner : "instanceID" ,
algorithm : "alg" ,
use : domain . KeyUsageSigning ,
} ,
expiry : future ,
publicKey : func ( ) * rsa . PublicKey {
publicKey , err := crypto . BytesToPublicKey ( [ ] byte ( pubKey ) )
if err != nil {
panic ( err )
}
return publicKey
} ( ) ,
} ,
} ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
q := & Queries {
eventstore : tt . eventstore ( t ) ,
}
if tt . encryption != nil {
q . keyEncryptionAlgorithm = tt . encryption ( t )
}
ctx := authz . NewMockContext ( "instanceID" , "orgID" , "loginClient" )
key , err := q . GetActivePublicKeyByID ( ctx , "keyID" , now )
if tt . wantErr != nil {
require . ErrorIs ( t , err , tt . wantErr )
return
}
require . NoError ( t , err )
require . NotNil ( t , key )
got := key . ( * rsaPublicKey )
assert . WithinDuration ( t , tt . want . expiry , got . expiry , time . Second )
tt . want . expiry = time . Time { }
got . expiry = time . Time { }
assert . Equal ( t , tt . want , got )
} )
}
}