diff --git a/internal/crypto/code.go b/internal/crypto/code.go index 8738aef074..7cff29a746 100644 --- a/internal/crypto/code.go +++ b/internal/crypto/code.go @@ -137,6 +137,9 @@ func generateRandomString(length uint, chars []rune) (string, error) { } func verifyEncryptedCode(cryptoCode *CryptoValue, verificationCode string, alg EncryptionAlgorithm) error { + if cryptoCode == nil { + return errors.ThrowInvalidArgument(nil, "CRYPT-aqrFV", "cryptoCode must not be nil") + } code, err := DecryptString(cryptoCode, alg) if err != nil { return err @@ -149,5 +152,8 @@ func verifyEncryptedCode(cryptoCode *CryptoValue, verificationCode string, alg E } func verifyHashedCode(cryptoCode *CryptoValue, verificationCode string, alg HashAlgorithm) error { + if cryptoCode == nil { + return errors.ThrowInvalidArgument(nil, "CRYPT-2q3r", "cryptoCode must not be nil") + } return CompareHash(cryptoCode, []byte(verificationCode), alg) } diff --git a/internal/crypto/code_test.go b/internal/crypto/code_test.go index 4f62f22898..8e1947c2d8 100644 --- a/internal/crypto/code_test.go +++ b/internal/crypto/code_test.go @@ -5,10 +5,11 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" + + "github.com/caos/zitadel/internal/errors" ) -func Test_Encrypted_OK(t *testing.T) { +func createMockEncryptionAlg(t *testing.T) EncryptionAlgorithm { mCrypto := NewMockEncryptionAlgorithm(gomock.NewController(t)) mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc") mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id") @@ -19,74 +20,333 @@ func Test_Encrypted_OK(t *testing.T) { }, ) mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn( - func(code []byte, _ string) (string, error) { + func(code []byte, keyID string) (string, error) { + if keyID != "id" { + return "", errors.ThrowInternal(nil, "id", "invalid key id") + } return string(code), nil }, ) - generator := NewEncryptionGenerator(6, 0, mCrypto, Digits) - - crypto, code, err := NewCode(generator) - assert.NoError(t, err) - - decrypted, err := DecryptString(crypto, generator.alg) - assert.NoError(t, err) - assert.Equal(t, code, decrypted) - assert.Equal(t, 6, len(decrypted)) + return mCrypto } -func Test_Verify_Encrypted_OK(t *testing.T) { - mCrypto := NewMockEncryptionAlgorithm(gomock.NewController(t)) - mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc") - mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id") - mCrypto.EXPECT().DecryptionKeyIDs().AnyTimes().Return([]string{"id"}) - mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn( - func(code []byte, _ string) (string, error) { - return string(code), nil +func createMockHashAlg(t *testing.T) HashAlgorithm { + mCrypto := NewMockHashAlgorithm(gomock.NewController(t)) + mCrypto.EXPECT().Algorithm().AnyTimes().Return("hash") + mCrypto.EXPECT().Hash(gomock.Any()).DoAndReturn( + func(code []byte) ([]byte, error) { + return code, nil }, ) - creationDate := time.Now() - code := &CryptoValue{ - CryptoType: TypeEncryption, - Algorithm: "enc", - KeyID: "id", - Crypted: []byte("code"), + mCrypto.EXPECT().CompareHash(gomock.Any(), gomock.Any()).DoAndReturn( + func(hashed, comparer []byte) error { + if string(hashed) != string(comparer) { + return errors.ThrowInternal(nil, "id", "invalid") + } + return nil + }, + ) + return mCrypto +} + +func createMockCrypto(t *testing.T) Crypto { + mCrypto := NewMockCrypto(gomock.NewController(t)) + mCrypto.EXPECT().Algorithm().AnyTimes().Return("crypto") + return mCrypto +} + +func createMockGenerator(t *testing.T, crypto Crypto) Generator { + mGenerator := NewMockGenerator(gomock.NewController(t)) + mGenerator.EXPECT().Alg().AnyTimes().Return(crypto) + return mGenerator +} + +func TestIsCodeExpired(t *testing.T) { + type args struct { + creationDate time.Time + expiry time.Duration } - generator := NewEncryptionGenerator(6, 0, mCrypto, Digits) - - err := VerifyCode(creationDate, 1*time.Hour, code, "code", generator) - assert.NoError(t, err) -} -func Test_Verify_Encrypted_Invalid_Err(t *testing.T) { - mCrypto := NewMockEncryptionAlgorithm(gomock.NewController(t)) - mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc") - mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id") - mCrypto.EXPECT().DecryptionKeyIDs().AnyTimes().Return([]string{"id"}) - mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn( - func(code []byte, _ string) (string, error) { - return string(code), nil + tests := []struct { + name string + args args + want bool + }{ + { + "not expired", + args{ + creationDate: time.Now(), + expiry: time.Duration(5 * time.Minute), + }, + false, + }, + { + "expired", + args{ + creationDate: time.Now().Add(-5 * time.Minute), + expiry: time.Duration(5 * time.Minute), + }, + true, }, - ) - creationDate := time.Now() - code := &CryptoValue{ - CryptoType: TypeEncryption, - Algorithm: "enc", - KeyID: "id", - Crypted: []byte("code"), } - generator := NewEncryptionGenerator(6, 0, mCrypto, Digits) - - err := VerifyCode(creationDate, 1*time.Hour, code, "wrong", generator) - assert.Error(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsCodeExpired(tt.args.creationDate, tt.args.expiry); got != tt.want { + t.Errorf("IsCodeExpired() = %v, want %v", got, tt.want) + } + }) + } } -func TestIsCodeExpired_Expired(t *testing.T) { - creationDate := time.Date(2019, time.April, 1, 0, 0, 0, 0, time.UTC) - expired := IsCodeExpired(creationDate, 1*time.Hour) - assert.True(t, expired) +func TestVerifyCode(t *testing.T) { + type args struct { + creationDate time.Time + expiry time.Duration + cryptoCode *CryptoValue + verificationCode string + g Generator + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "expired", + args{ + creationDate: time.Now().Add(-5 * time.Minute), + expiry: 5 * time.Minute, + cryptoCode: nil, + verificationCode: "", + g: nil, + }, + true, + }, + { + "unsupported alg err", + args{ + creationDate: time.Now(), + expiry: 5 * time.Minute, + cryptoCode: nil, + verificationCode: "code", + g: createMockGenerator(t, createMockCrypto(t)), + }, + true, + }, + { + "encryption alg ok", + args{ + creationDate: time.Now(), + expiry: 5 * time.Minute, + cryptoCode: &CryptoValue{ + CryptoType: TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("code"), + }, + verificationCode: "code", + g: createMockGenerator(t, createMockEncryptionAlg(t)), + }, + false, + }, + { + "hash alg ok", + args{ + creationDate: time.Now(), + expiry: 5 * time.Minute, + cryptoCode: &CryptoValue{ + CryptoType: TypeHash, + Algorithm: "hash", + Crypted: []byte("code"), + }, + verificationCode: "code", + g: createMockGenerator(t, createMockHashAlg(t)), + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := VerifyCode(tt.args.creationDate, tt.args.expiry, tt.args.cryptoCode, tt.args.verificationCode, tt.args.g); (err != nil) != tt.wantErr { + t.Errorf("VerifyCode() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } } -func TestIsCodeExpired_NotExpired(t *testing.T) { - creationDate := time.Now() - expired := IsCodeExpired(creationDate, 1*time.Hour) - assert.False(t, expired) +func Test_verifyEncryptedCode(t *testing.T) { + type args struct { + cryptoCode *CryptoValue + verificationCode string + alg EncryptionAlgorithm + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "nil error", + args{ + cryptoCode: nil, + verificationCode: "", + alg: createMockEncryptionAlg(t), + }, + true, + }, + { + "wrong cryptotype error", + args{ + cryptoCode: &CryptoValue{ + CryptoType: TypeHash, + Crypted: nil, + }, + verificationCode: "", + alg: createMockEncryptionAlg(t), + }, + true, + }, + { + "wrong algorithm error", + args{ + cryptoCode: &CryptoValue{ + CryptoType: TypeEncryption, + Algorithm: "enc2", + Crypted: nil, + }, + verificationCode: "", + alg: createMockEncryptionAlg(t), + }, + true, + }, + { + "wrong key id error", + args{ + cryptoCode: &CryptoValue{ + CryptoType: TypeEncryption, + Algorithm: "enc", + Crypted: nil, + }, + verificationCode: "wrong", + alg: createMockEncryptionAlg(t), + }, + true, + }, + { + "wrong verification code error", + args{ + cryptoCode: &CryptoValue{ + CryptoType: TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("code"), + }, + verificationCode: "wrong", + alg: createMockEncryptionAlg(t), + }, + true, + }, + { + "verification code ok", + args{ + cryptoCode: &CryptoValue{ + CryptoType: TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("code"), + }, + verificationCode: "code", + alg: createMockEncryptionAlg(t), + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := verifyEncryptedCode(tt.args.cryptoCode, tt.args.verificationCode, tt.args.alg); (err != nil) != tt.wantErr { + t.Errorf("verifyEncryptedCode() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_verifyHashedCode(t *testing.T) { + type args struct { + cryptoCode *CryptoValue + verificationCode string + alg HashAlgorithm + } + tests := []struct { + name string + args args + wantErr bool + }{ + + { + "nil error", + args{ + cryptoCode: nil, + verificationCode: "", + alg: createMockHashAlg(t), + }, + true, + }, + { + "wrong cryptotype error", + args{ + cryptoCode: &CryptoValue{ + CryptoType: TypeEncryption, + Crypted: nil, + }, + verificationCode: "", + alg: createMockHashAlg(t), + }, + true, + }, + { + "wrong algorithm error", + args{ + cryptoCode: &CryptoValue{ + CryptoType: TypeHash, + Algorithm: "hash2", + Crypted: nil, + }, + verificationCode: "", + alg: createMockHashAlg(t), + }, + true, + }, + { + "wrong verification code error", + args{ + cryptoCode: &CryptoValue{ + CryptoType: TypeHash, + Algorithm: "hash", + Crypted: []byte("code"), + }, + verificationCode: "wrong", + alg: createMockHashAlg(t), + }, + true, + }, + { + "verification code ok", + args{ + cryptoCode: &CryptoValue{ + CryptoType: TypeHash, + Algorithm: "hash", + Crypted: []byte("code"), + }, + verificationCode: "code", + alg: createMockHashAlg(t), + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := verifyHashedCode(tt.args.cryptoCode, tt.args.verificationCode, tt.args.alg); (err != nil) != tt.wantErr { + t.Errorf("verifyHashedCode() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } } diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index c1b46d69d5..d7f3269321 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -99,5 +99,8 @@ func Hash(value []byte, alg HashAlgorithm) (*CryptoValue, error) { } func CompareHash(value *CryptoValue, comparer []byte, alg HashAlgorithm) error { + if value.Algorithm != alg.Algorithm() { + return errors.ThrowInvalidArgument(nil, "CRYPT-HF32f", "value was hash with a different algorithm") + } return alg.CompareHash(value.Crypted, comparer) } diff --git a/internal/crypto/generate.go b/internal/crypto/generate.go index 4c43e7e9b3..fd3de9f759 100644 --- a/internal/crypto/generate.go +++ b/internal/crypto/generate.go @@ -1,3 +1,4 @@ package crypto //go:generate mockgen -source crypto.go -destination ./crypto_mock.go -package crypto +//go:generate mockgen -source code.go -destination ./code_mock.go -package crypto