zitadel/internal/crypto/crypto_test.go
Tim Möhlmann ad9422a7d0
fix(crypto): check for nil client secret (#7729)
When an app is created without a secret, using another type of authentication method
such as JWT, and the authentication type is switched afterwards, the app would remain without a generated secret.
If client authentication with a secret is then attempted, for example on the token endpoint, the handler would panic in the crypto.CompareHash function on the nil pointer to the CryptoValue.

This fix introduces a nil pointer check in crypto.CompareHash and returns an error.

The issue was reported over discord: https://discord.com/channels/927474939156643850/1222971118730875020
Possible fix was suggested here: https://github.com/zitadel/zitadel/pull/6999#discussion_r1553503088
This bug only applies to zitadel versions <=2.49.1.
2024-04-09 08:44:52 +02:00

279 lines
5.8 KiB
Go

package crypto
import (
"bytes"
"errors"
"reflect"
"testing"
)
// mockEncCrypto is a pass-through test double used as the encryption
// algorithm in the tests of this file. It never transforms its input and
// always reports the algorithm "enc" with the single key ID "keyID".
type mockEncCrypto struct{}

// Algorithm returns the fixed algorithm name "enc".
func (*mockEncCrypto) Algorithm() string { return "enc" }

// Encrypt returns plaintext unchanged and never fails.
func (*mockEncCrypto) Encrypt(plaintext []byte) ([]byte, error) {
	return plaintext, nil
}

// Decrypt returns ciphertext unchanged; the key ID argument is ignored.
func (*mockEncCrypto) Decrypt(ciphertext []byte, _ string) ([]byte, error) {
	return ciphertext, nil
}

// DecryptString returns ciphertext converted to a string; the key ID
// argument is ignored.
func (*mockEncCrypto) DecryptString(ciphertext []byte, _ string) (string, error) {
	plain := string(ciphertext)
	return plain, nil
}

// EncryptionKeyID returns the fixed key ID "keyID".
func (*mockEncCrypto) EncryptionKeyID() string { return "keyID" }

// DecryptionKeyIDs returns a one-element slice holding the encryption key ID.
func (m *mockEncCrypto) DecryptionKeyIDs() []string {
	return []string{m.EncryptionKeyID()}
}
// mockHashCrypto is a test double used as the hash algorithm in the tests of
// this file. Its "hash" is the identity function, so comparing a hash means
// comparing the raw bytes.
type mockHashCrypto struct{}

// Algorithm returns the fixed algorithm name "hash".
func (*mockHashCrypto) Algorithm() string { return "hash" }

// Hash returns input unchanged and never fails.
func (*mockHashCrypto) Hash(input []byte) ([]byte, error) {
	return input, nil
}

// CompareHash returns nil when both byte slices are equal and a
// "not equal" error otherwise.
func (*mockHashCrypto) CompareHash(hashed, comparer []byte) error {
	if bytes.Equal(hashed, comparer) {
		return nil
	}
	return errors.New("not equal")
}
// alg is a minimal test double that only reports an algorithm name. The
// "wrong type" case of TestCrypt passes it to Crypt and expects an error.
type alg struct{}

// Algorithm returns the fixed algorithm name "alg".
func (*alg) Algorithm() string {
	const name = "alg"
	return name
}
// TestCrypt runs Crypt against an encryption mock, a hash mock and a type
// that is neither, checking the returned CryptoValue and error state.
func TestCrypt(t *testing.T) {
	type args struct {
		value []byte
		c     Crypto
	}
	type testCase struct {
		name    string
		args    args
		want    *CryptoValue
		wantErr bool
	}
	cases := []testCase{
		{
			name: "encrypt",
			args: args{value: []byte("test"), c: &mockEncCrypto{}},
			want: &CryptoValue{
				CryptoType: TypeEncryption,
				Algorithm:  "enc",
				KeyID:      "keyID",
				Crypted:    []byte("test"),
			},
			wantErr: false,
		},
		{
			name: "hash",
			args: args{value: []byte("test"), c: &mockHashCrypto{}},
			want: &CryptoValue{
				CryptoType: TypeHash,
				Algorithm:  "hash",
				Crypted:    []byte("test"),
			},
			wantErr: false,
		},
		{
			name:    "wrong type",
			args:    args{value: []byte("test"), c: &alg{}},
			want:    nil,
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := Crypt(tc.args.value, tc.args.c)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("Crypt() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("Crypt() = %v, want %v", got, tc.want)
		})
	}
}
// TestEncrypt checks that Encrypt wraps the mock's output in a CryptoValue
// carrying the encryption type, algorithm name and key ID.
func TestEncrypt(t *testing.T) {
	type args struct {
		value []byte
		c     EncryptionAlgorithm
	}
	type testCase struct {
		name    string
		args    args
		want    *CryptoValue
		wantErr bool
	}
	cases := []testCase{
		{
			name: "ok",
			args: args{value: []byte("test"), c: &mockEncCrypto{}},
			want: &CryptoValue{
				CryptoType: TypeEncryption,
				Algorithm:  "enc",
				KeyID:      "keyID",
				Crypted:    []byte("test"),
			},
			wantErr: false,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := Encrypt(tc.args.value, tc.args.c)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("Encrypt() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("Encrypt() = %v, want %v", got, tc.want)
		})
	}
}
// TestDecrypt checks Decrypt with a value whose KeyID matches the mock's
// key ID ("ok") and with one whose KeyID does not ("wrong id"), which is
// expected to produce an error.
func TestDecrypt(t *testing.T) {
	type args struct {
		value *CryptoValue
		c     EncryptionAlgorithm
	}
	type testCase struct {
		name    string
		args    args
		want    []byte
		wantErr bool
	}
	cases := []testCase{
		{
			name: "ok",
			args: args{
				value: &CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")},
				c:     &mockEncCrypto{},
			},
			want:    []byte("test"),
			wantErr: false,
		},
		{
			name: "wrong id",
			args: args{
				value: &CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID2", Crypted: []byte("test")},
				c:     &mockEncCrypto{},
			},
			want:    nil,
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := Decrypt(tc.args.value, tc.args.c)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("Decrypt() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("Decrypt() = %v, want %v", got, tc.want)
		})
	}
}
// TestDecryptString checks DecryptString with a value whose KeyID matches
// the mock's key ID ("ok") and with one whose KeyID does not ("wrong id"),
// which is expected to produce an error and an empty string.
func TestDecryptString(t *testing.T) {
	type args struct {
		value *CryptoValue
		c     EncryptionAlgorithm
	}
	type testCase struct {
		name    string
		args    args
		want    string
		wantErr bool
	}
	cases := []testCase{
		{
			name: "ok",
			args: args{
				value: &CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")},
				c:     &mockEncCrypto{},
			},
			want:    "test",
			wantErr: false,
		},
		{
			name: "wrong id",
			args: args{
				value: &CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID2", Crypted: []byte("test")},
				c:     &mockEncCrypto{},
			},
			want:    "",
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := DecryptString(tc.args.value, tc.args.c)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("DecryptString() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("DecryptString() = %v, want %v", got, tc.want)
		})
	}
}
// TestHash checks that Hash wraps the mock's output in a CryptoValue
// carrying the hash type and algorithm name.
func TestHash(t *testing.T) {
	type args struct {
		value []byte
		c     HashAlgorithm
	}
	type testCase struct {
		name    string
		args    args
		want    *CryptoValue
		wantErr bool
	}
	cases := []testCase{
		{
			name: "ok",
			args: args{value: []byte("test"), c: &mockHashCrypto{}},
			want: &CryptoValue{
				CryptoType: TypeHash,
				Algorithm:  "hash",
				Crypted:    []byte("test"),
			},
			wantErr: false,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := Hash(tc.args.value, tc.args.c)
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("Hash() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("Hash() = %v, want %v", got, tc.want)
		})
	}
}
// TestCompareHash checks CompareHash for three situations: a matching
// comparer, a non-matching comparer, and a nil CryptoValue.
func TestCompareHash(t *testing.T) {
	type args struct {
		value    *CryptoValue
		comparer []byte
		c        HashAlgorithm
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{
			name: "ok",
			args: args{
				value:    &CryptoValue{CryptoType: TypeHash, Algorithm: "hash", Crypted: []byte("test")},
				comparer: []byte("test"),
				c:        &mockHashCrypto{},
			},
			wantErr: false,
		},
		{
			// The comparer must differ from the stored value; otherwise this
			// case merely repeats "ok" and the mismatch path is never tested.
			name: "wrong",
			args: args{
				value:    &CryptoValue{CryptoType: TypeHash, Algorithm: "hash", Crypted: []byte("test")},
				comparer: []byte("test2"),
				c:        &mockHashCrypto{},
			},
			wantErr: true,
		},
		{
			// A nil CryptoValue must yield an error rather than a panic.
			name: "nil",
			args: args{
				value:    nil,
				comparer: []byte("test2"),
				c:        &mockHashCrypto{},
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := CompareHash(tt.args.value, tt.args.comparer, tt.args.c); (err != nil) != tt.wantErr {
				t.Errorf("CompareHash() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}