diff --git a/internal/command/oidc_session.go b/internal/command/oidc_session.go index fc384d0d30..1fe82198bf 100644 --- a/internal/command/oidc_session.go +++ b/internal/command/oidc_session.go @@ -293,7 +293,7 @@ func (c *Commands) decryptRefreshToken(refreshToken string) (sessionID, refreshT } decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID()) if err != nil { - return "", "", err + return "", "", zerrors.ThrowInvalidArgument(err, "OIDCS-Jei0i", "Errors.User.RefreshToken.Invalid") } return parseRefreshToken(decrypted) } diff --git a/internal/crypto/aes.go b/internal/crypto/aes.go index e943c2ca8e..f57a78fb85 100644 --- a/internal/crypto/aes.go +++ b/internal/crypto/aes.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "encoding/base64" "io" + "unicode/utf8" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -46,15 +47,17 @@ func (a *AESCrypto) Decrypt(value []byte, keyID string) ([]byte, error) { return DecryptAES(value, key) } +// DecryptString decrypts the value using the key identified by keyID. +// When the decrypted value contains non-UTF8 characters an error is returned. func (a *AESCrypto) DecryptString(value []byte, keyID string) (string, error) { - key, err := a.decryptionKey(keyID) + b, err := a.Decrypt(value, keyID) if err != nil { return "", err } - b, err := DecryptAES(value, key) - if err != nil { - return "", err + if !utf8.Valid(b) { + return "", zerrors.ThrowPreconditionFailed(err, "CRYPT-hiCh0", "non-UTF-8 in decrypted string") } + return string(b), nil } diff --git a/internal/crypto/aes_test.go b/internal/crypto/aes_test.go index 5731f320eb..128fd6c4dc 100644 --- a/internal/crypto/aes_test.go +++ b/internal/crypto/aes_test.go @@ -1,18 +1,109 @@ package crypto import ( + "context" + "errors" "testing" + "unicode/utf8" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/zerrors" ) -// TODO: refactor test style -func TestDecrypt_OK(t *testing.T) { - encryptedpw, err := EncryptAESString("ThisIsMySecretPw", "passphrasewhichneedstobe32bytes!") - assert.NoError(t, err) - - decryptedpw, err := DecryptAESString(encryptedpw, "passphrasewhichneedstobe32bytes!") - assert.NoError(t, err) - - assert.Equal(t, "ThisIsMySecretPw", decryptedpw) +type mockKeyStorage struct { + keys Keys +} + +func (s *mockKeyStorage) ReadKeys() (Keys, error) { + return s.keys, nil +} + +func (s *mockKeyStorage) ReadKey(id string) (*Key, error) { + return &Key{ + ID: id, + Value: s.keys[id], + }, nil +} + +func (*mockKeyStorage) CreateKeys(context.Context, ...*Key) error { + return errors.New("mockKeyStorage.CreateKeys not implemented") +} + +func newTestAESCrypto(t testing.TB) *AESCrypto { + keyConfig := &KeyConfig{ + EncryptionKeyID: "keyID", + DecryptionKeyIDs: []string{"keyID"}, + } + keys := Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"} + aesCrypto, err := NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys}) + require.NoError(t, err) + return aesCrypto +} + +func TestAESCrypto_DecryptString(t *testing.T) { + aesCrypto := newTestAESCrypto(t) + const input = "SecretData" + crypted, err := aesCrypto.Encrypt([]byte(input)) + require.NoError(t, err) + + type args struct { + value []byte + keyID string + } + tests := []struct { + name string + args args + want string + wantErr error + }{ + { + name: "unknown key id error", + args: args{ + value: crypted, + keyID: "foo", + }, + wantErr: zerrors.ThrowNotFound(nil, "CRYPT-nkj1s", "unknown key id"), + }, + { + name: "ok", + args: args{ + value: crypted, + keyID: "keyID", + }, + want: input, + }, + } + for _, tt := range tests { + got, err := aesCrypto.DecryptString(tt.args.value, tt.args.keyID) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + } +} + +func FuzzAESCrypto_DecryptString(f *testing.F) { + aesCrypto := newTestAESCrypto(f) + tests := []string{ + " ", + "SecretData", + "FooBar", + "HelloWorld", + } + for _, input := range tests { + tc, err := aesCrypto.Encrypt([]byte(input)) + require.NoError(f, err) + f.Add(tc) + } + f.Fuzz(func(t *testing.T, value []byte) { + got, err := aesCrypto.DecryptString(value, "keyID") + if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "CRYPT-23kH1", "cipher text block too short")) { + return + } + if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "CRYPT-hiCh0", "non-UTF-8 in decrypted string")) { + return + } + require.NoError(t, err) + assert.True(t, utf8.ValidString(got), "result is not valid UTF-8") + }) } diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index 2e8e4a71b0..a74f97a054 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -19,6 +19,9 @@ type EncryptionAlgorithm interface { DecryptionKeyIDs() []string Encrypt(value []byte) ([]byte, error) Decrypt(hashed []byte, keyID string) ([]byte, error) + + // DecryptString decrypts the value using the key identified by keyID. + // When the decrypted value contains non-UTF8 characters an error is returned. DecryptString(hashed []byte, keyID string) (string, error) } @@ -72,6 +75,8 @@ func Decrypt(value *CryptoValue, alg EncryptionAlgorithm) ([]byte, error) { return alg.Decrypt(value.Crypted, value.KeyID) } +// DecryptString decrypts the value using the key identified by keyID. +// When the decrypted value contains non-UTF8 characters an error is returned. func DecryptString(value *CryptoValue, alg EncryptionAlgorithm) (string, error) { if err := checkEncryptionAlgorithm(value, alg); err != nil { return "", err diff --git a/internal/crypto/testdata/fuzz/FuzzAESCrypto_DecryptString/8d609af8fa2eb76f b/internal/crypto/testdata/fuzz/FuzzAESCrypto_DecryptString/8d609af8fa2eb76f new file mode 100644 index 0000000000..233de8fb25 --- /dev/null +++ b/internal/crypto/testdata/fuzz/FuzzAESCrypto_DecryptString/8d609af8fa2eb76f @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("0010120C001010070") diff --git a/internal/domain/refresh_token.go b/internal/domain/refresh_token.go index 6f2d883df5..25ab32f45b 100644 --- a/internal/domain/refresh_token.go +++ b/internal/domain/refresh_token.go @@ -25,13 +25,13 @@ func FromRefreshToken(refreshToken string, algorithm crypto.EncryptionAlgorithm) if err != nil { return "", "", "", zerrors.ThrowInvalidArgument(err, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid") } - decrypted, err := algorithm.Decrypt(decoded, algorithm.EncryptionKeyID()) + decrypted, err := algorithm.DecryptString(decoded, algorithm.EncryptionKeyID()) if err != nil { - return "", "", "", err + return "", "", "", zerrors.ThrowInvalidArgument(err, "DOMAIN-rie9A", "Errors.User.RefreshToken.Invalid") } - split := strings.Split(string(decrypted), ":") + split := strings.Split(decrypted, ":") if len(split) != 3 { - return "", "", "", zerrors.ThrowInvalidArgument(nil, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid") + return "", "", "", zerrors.ThrowInvalidArgument(nil, "DOMAIN-Se8oh", "Errors.User.RefreshToken.Invalid") } return split[0], split[1], split[2], nil } diff --git a/internal/domain/refresh_token_test.go b/internal/domain/refresh_token_test.go new file mode 100644 index 0000000000..e2719bd238 --- /dev/null +++ b/internal/domain/refresh_token_test.go @@ -0,0 +1,129 @@ +package domain + +import ( + "encoding/base64" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/context" + + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type mockKeyStorage struct { + keys crypto.Keys +} + +func (s *mockKeyStorage) ReadKeys() (crypto.Keys, error) { + return s.keys, nil +} + +func (s *mockKeyStorage) ReadKey(id string) (*crypto.Key, error) { + return &crypto.Key{ + ID: id, + Value: s.keys[id], + }, nil +} + +func (*mockKeyStorage) CreateKeys(context.Context, ...*crypto.Key) error { + return errors.New("mockKeyStorage.CreateKeys not implemented") +} + +func TestFromRefreshToken(t *testing.T) { + const ( + userID = "userID" + tokenID = "tokenID" + ) + + keyConfig := &crypto.KeyConfig{ + EncryptionKeyID: "keyID", + DecryptionKeyIDs: []string{"keyID"}, + } + keys := crypto.Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"} + algorithm, err := crypto.NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys}) + require.NoError(t, err) + + refreshToken, err := NewRefreshToken(userID, tokenID, algorithm) + require.NoError(t, err) + + invalidRefreshToken, err := algorithm.Encrypt([]byte(userID + ":" + tokenID)) + require.NoError(t, err) + + type args struct { + refreshToken string + algorithm crypto.EncryptionAlgorithm + } + tests := []struct { + name string + args args + wantUserID string + wantTokenID string + wantToken string + wantErr error + }{ + { + name: "invalid base64", + args: args{"~~~", algorithm}, + wantErr: zerrors.ThrowInvalidArgument(nil, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid"), + }, + { + name: "short cipher text", + args: args{"DEADBEEF", algorithm}, + wantErr: zerrors.ThrowInvalidArgument(err, "DOMAIN-rie9A", "Errors.User.RefreshToken.Invalid"), + }, + { + name: "incorrect amount of segments", + args: args{base64.RawURLEncoding.EncodeToString(invalidRefreshToken), algorithm}, + wantErr: zerrors.ThrowInvalidArgument(nil, "DOMAIN-Se8oh", "Errors.User.RefreshToken.Invalid"), + }, + { + name: "success", + args: args{refreshToken, algorithm}, + wantUserID: userID, + wantTokenID: tokenID, + wantToken: tokenID, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotUserID, gotTokenID, gotToken, err := FromRefreshToken(tt.args.refreshToken, tt.args.algorithm) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantUserID, gotUserID) + assert.Equal(t, tt.wantTokenID, gotTokenID) + assert.Equal(t, tt.wantToken, gotToken) + }) + } +} + +// Fuzz test invalid inputs. None of the inputs should result in a success. +func FuzzFromRefreshToken(f *testing.F) { + keyConfig := &crypto.KeyConfig{ + EncryptionKeyID: "keyID", + DecryptionKeyIDs: []string{"keyID"}, + } + keys := crypto.Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"} + algorithm, err := crypto.NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys}) + require.NoError(f, err) + + invalidRefreshToken, err := algorithm.Encrypt([]byte("userID:tokenID")) + require.NoError(f, err) + + tests := []string{ + "~~~", // invalid base64 + "DEADBEEF", // short cipher text + base64.RawURLEncoding.EncodeToString(invalidRefreshToken), // incorrect amount of segments + } + for _, tc := range tests { + f.Add(tc) + } + + f.Fuzz(func(t *testing.T, refreshToken string) { + gotUserID, gotTokenID, gotToken, err := FromRefreshToken(refreshToken, algorithm) + target := zerrors.InvalidArgumentError{ZitadelError: new(zerrors.ZitadelError)} + t.Log(gotUserID, gotTokenID, gotToken) + require.ErrorAs(t, err, &target) + }) +} diff --git a/internal/domain/testdata/fuzz/FuzzFromRefreshToken/576e811604c701eb b/internal/domain/testdata/fuzz/FuzzFromRefreshToken/576e811604c701eb new file mode 100644 index 0000000000..0e9296b076 --- /dev/null +++ b/internal/domain/testdata/fuzz/FuzzFromRefreshToken/576e811604c701eb @@ -0,0 +1,2 @@ +go test fuzz v1 +string("0000050000000000000000000000000") diff --git a/internal/zerrors/invalid_argument.go b/internal/zerrors/invalid_argument.go index b2a33fc860..e97519e660 100644 --- a/internal/zerrors/invalid_argument.go +++ b/internal/zerrors/invalid_argument.go @@ -1,6 +1,8 @@ package zerrors -import "fmt" +import ( + "fmt" +) var ( _ InvalidArgument = (*InvalidArgumentError)(nil) @@ -39,6 +41,15 @@ func (err *InvalidArgumentError) Is(target error) bool { return err.ZitadelError.Is(t.ZitadelError) } +func (err *InvalidArgumentError) As(target any) bool { + targetErr, ok := target.(*InvalidArgumentError) + if !ok { + return false + } + *targetErr = *err + return true +} + func (err *InvalidArgumentError) Unwrap() error { return err.ZitadelError }