mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-07 08:07:46 +00:00
fix(crypto): reject decrypted strings with non-UTF8 characters. (#8374)
# Which Problems Are Solved We noticed logging where 500: Internal Server errors were returned from the token endpoint, mostly for the `refresh_token` grant. The error was thrown by the database as it received non-UTF8 strings for token IDs Zitadel uses symmetric encryption for opaque tokens, including refresh tokens. Encrypted values are base64 encoded. It appeared to be possible to send garbage base64 to the token endpoint, which will pass decryption and string-splitting. In those cases the resulting ID is not a valid UTF-8 string. Invalid non-UTF8 strings are now rejected during token decryption. # How the Problems Are Solved - `AESCrypto.DecryptString()` checks if the decrypted bytes only contain valid UTF-8 characters before converting them into a string. - `AESCrypto.Decrypt()` is unmodified and still allows decryption on non-UTF8 byte strings. - `FromRefreshToken` now uses `DecryptString` instead of `Decrypt` # Additional Changes - Unit tests added for `FromRefreshToken` and `AESCrypto.DecryptString()`. - Fuzz tests added for `FromRefreshToken` and `AESCrypto.DecryptString()`. This was to pinpoint the problem - Testdata with values that resulted in invalid strings are committed. In the pipeline this results in the Fuzz tests to execute as regular unit-test cases. As we don't use the `-fuzz` flag in the pipeline no further fuzzing is performed. # Additional Context - Closes #7765 - https://go.dev/doc/tutorial/fuzz
This commit is contained in:
parent
3d071fc505
commit
4e3fd305ab
@ -293,7 +293,7 @@ func (c *Commands) decryptRefreshToken(refreshToken string) (sessionID, refreshT
|
||||
}
|
||||
decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID())
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
return "", "", zerrors.ThrowInvalidArgument(err, "OIDCS-Jei0i", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
return parseRefreshToken(decrypted)
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
@ -46,15 +47,17 @@ func (a *AESCrypto) Decrypt(value []byte, keyID string) ([]byte, error) {
|
||||
return DecryptAES(value, key)
|
||||
}
|
||||
|
||||
// DecryptString decrypts the value using the key identified by keyID.
|
||||
// When the decrypted value contains non-UTF8 characters an error is returned.
|
||||
func (a *AESCrypto) DecryptString(value []byte, keyID string) (string, error) {
|
||||
key, err := a.decryptionKey(keyID)
|
||||
b, err := a.Decrypt(value, keyID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b, err := DecryptAES(value, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
if !utf8.Valid(b) {
|
||||
return "", zerrors.ThrowPreconditionFailed(err, "CRYPT-hiCh0", "non-UTF-8 in decrypted string")
|
||||
}
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
|
@ -1,18 +1,109 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
// TODO: refactor test style
|
||||
func TestDecrypt_OK(t *testing.T) {
|
||||
encryptedpw, err := EncryptAESString("ThisIsMySecretPw", "passphrasewhichneedstobe32bytes!")
|
||||
assert.NoError(t, err)
|
||||
|
||||
decryptedpw, err := DecryptAESString(encryptedpw, "passphrasewhichneedstobe32bytes!")
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "ThisIsMySecretPw", decryptedpw)
|
||||
type mockKeyStorage struct {
|
||||
keys Keys
|
||||
}
|
||||
|
||||
func (s *mockKeyStorage) ReadKeys() (Keys, error) {
|
||||
return s.keys, nil
|
||||
}
|
||||
|
||||
func (s *mockKeyStorage) ReadKey(id string) (*Key, error) {
|
||||
return &Key{
|
||||
ID: id,
|
||||
Value: s.keys[id],
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*mockKeyStorage) CreateKeys(context.Context, ...*Key) error {
|
||||
return errors.New("mockKeyStorage.CreateKeys not implemented")
|
||||
}
|
||||
|
||||
func newTestAESCrypto(t testing.TB) *AESCrypto {
|
||||
keyConfig := &KeyConfig{
|
||||
EncryptionKeyID: "keyID",
|
||||
DecryptionKeyIDs: []string{"keyID"},
|
||||
}
|
||||
keys := Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"}
|
||||
aesCrypto, err := NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys})
|
||||
require.NoError(t, err)
|
||||
return aesCrypto
|
||||
}
|
||||
|
||||
func TestAESCrypto_DecryptString(t *testing.T) {
|
||||
aesCrypto := newTestAESCrypto(t)
|
||||
const input = "SecretData"
|
||||
crypted, err := aesCrypto.Encrypt([]byte(input))
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
value []byte
|
||||
keyID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "unknown key id error",
|
||||
args: args{
|
||||
value: crypted,
|
||||
keyID: "foo",
|
||||
},
|
||||
wantErr: zerrors.ThrowNotFound(nil, "CRYPT-nkj1s", "unknown key id"),
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
value: crypted,
|
||||
keyID: "keyID",
|
||||
},
|
||||
want: input,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got, err := aesCrypto.DecryptString(tt.args.value, tt.args.keyID)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzAESCrypto_DecryptString(f *testing.F) {
|
||||
aesCrypto := newTestAESCrypto(f)
|
||||
tests := []string{
|
||||
" ",
|
||||
"SecretData",
|
||||
"FooBar",
|
||||
"HelloWorld",
|
||||
}
|
||||
for _, input := range tests {
|
||||
tc, err := aesCrypto.Encrypt([]byte(input))
|
||||
require.NoError(f, err)
|
||||
f.Add(tc)
|
||||
}
|
||||
f.Fuzz(func(t *testing.T, value []byte) {
|
||||
got, err := aesCrypto.DecryptString(value, "keyID")
|
||||
if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "CRYPT-23kH1", "cipher text block too short")) {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "CRYPT-hiCh0", "non-UTF-8 in decrypted string")) {
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.True(t, utf8.ValidString(got), "result is not valid UTF-8")
|
||||
})
|
||||
}
|
||||
|
@ -19,6 +19,9 @@ type EncryptionAlgorithm interface {
|
||||
DecryptionKeyIDs() []string
|
||||
Encrypt(value []byte) ([]byte, error)
|
||||
Decrypt(hashed []byte, keyID string) ([]byte, error)
|
||||
|
||||
// DecryptString decrypts the value using the key identified by keyID.
|
||||
// When the decrypted value contains non-UTF8 characters an error is returned.
|
||||
DecryptString(hashed []byte, keyID string) (string, error)
|
||||
}
|
||||
|
||||
@ -72,6 +75,8 @@ func Decrypt(value *CryptoValue, alg EncryptionAlgorithm) ([]byte, error) {
|
||||
return alg.Decrypt(value.Crypted, value.KeyID)
|
||||
}
|
||||
|
||||
// DecryptString decrypts the value using the key identified by keyID.
|
||||
// When the decrypted value contains non-UTF8 characters an error is returned.
|
||||
func DecryptString(value *CryptoValue, alg EncryptionAlgorithm) (string, error) {
|
||||
if err := checkEncryptionAlgorithm(value, alg); err != nil {
|
||||
return "", err
|
||||
|
2
internal/crypto/testdata/fuzz/FuzzAESCrypto_DecryptString/8d609af8fa2eb76f
vendored
Normal file
2
internal/crypto/testdata/fuzz/FuzzAESCrypto_DecryptString/8d609af8fa2eb76f
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
[]byte("0010120C001010070")
|
@ -25,13 +25,13 @@ func FromRefreshToken(refreshToken string, algorithm crypto.EncryptionAlgorithm)
|
||||
if err != nil {
|
||||
return "", "", "", zerrors.ThrowInvalidArgument(err, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
decrypted, err := algorithm.Decrypt(decoded, algorithm.EncryptionKeyID())
|
||||
decrypted, err := algorithm.DecryptString(decoded, algorithm.EncryptionKeyID())
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
return "", "", "", zerrors.ThrowInvalidArgument(err, "DOMAIN-rie9A", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
split := strings.Split(string(decrypted), ":")
|
||||
split := strings.Split(decrypted, ":")
|
||||
if len(split) != 3 {
|
||||
return "", "", "", zerrors.ThrowInvalidArgument(nil, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid")
|
||||
return "", "", "", zerrors.ThrowInvalidArgument(nil, "DOMAIN-Se8oh", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
return split[0], split[1], split[2], nil
|
||||
}
|
||||
|
129
internal/domain/refresh_token_test.go
Normal file
129
internal/domain/refresh_token_test.go
Normal file
@ -0,0 +1,129 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type mockKeyStorage struct {
|
||||
keys crypto.Keys
|
||||
}
|
||||
|
||||
func (s *mockKeyStorage) ReadKeys() (crypto.Keys, error) {
|
||||
return s.keys, nil
|
||||
}
|
||||
|
||||
func (s *mockKeyStorage) ReadKey(id string) (*crypto.Key, error) {
|
||||
return &crypto.Key{
|
||||
ID: id,
|
||||
Value: s.keys[id],
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*mockKeyStorage) CreateKeys(context.Context, ...*crypto.Key) error {
|
||||
return errors.New("mockKeyStorage.CreateKeys not implemented")
|
||||
}
|
||||
|
||||
func TestFromRefreshToken(t *testing.T) {
|
||||
const (
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
)
|
||||
|
||||
keyConfig := &crypto.KeyConfig{
|
||||
EncryptionKeyID: "keyID",
|
||||
DecryptionKeyIDs: []string{"keyID"},
|
||||
}
|
||||
keys := crypto.Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"}
|
||||
algorithm, err := crypto.NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys})
|
||||
require.NoError(t, err)
|
||||
|
||||
refreshToken, err := NewRefreshToken(userID, tokenID, algorithm)
|
||||
require.NoError(t, err)
|
||||
|
||||
invalidRefreshToken, err := algorithm.Encrypt([]byte(userID + ":" + tokenID))
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
refreshToken string
|
||||
algorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantUserID string
|
||||
wantTokenID string
|
||||
wantToken string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid base64",
|
||||
args: args{"~~~", algorithm},
|
||||
wantErr: zerrors.ThrowInvalidArgument(nil, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid"),
|
||||
},
|
||||
{
|
||||
name: "short cipher text",
|
||||
args: args{"DEADBEEF", algorithm},
|
||||
wantErr: zerrors.ThrowInvalidArgument(err, "DOMAIN-rie9A", "Errors.User.RefreshToken.Invalid"),
|
||||
},
|
||||
{
|
||||
name: "incorrect amount of segments",
|
||||
args: args{base64.RawURLEncoding.EncodeToString(invalidRefreshToken), algorithm},
|
||||
wantErr: zerrors.ThrowInvalidArgument(nil, "DOMAIN-Se8oh", "Errors.User.RefreshToken.Invalid"),
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{refreshToken, algorithm},
|
||||
wantUserID: userID,
|
||||
wantTokenID: tokenID,
|
||||
wantToken: tokenID,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotUserID, gotTokenID, gotToken, err := FromRefreshToken(tt.args.refreshToken, tt.args.algorithm)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantUserID, gotUserID)
|
||||
assert.Equal(t, tt.wantTokenID, gotTokenID)
|
||||
assert.Equal(t, tt.wantToken, gotToken)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Fuzz test invalid inputs. None of the inputs should result in a success.
|
||||
func FuzzFromRefreshToken(f *testing.F) {
|
||||
keyConfig := &crypto.KeyConfig{
|
||||
EncryptionKeyID: "keyID",
|
||||
DecryptionKeyIDs: []string{"keyID"},
|
||||
}
|
||||
keys := crypto.Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"}
|
||||
algorithm, err := crypto.NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys})
|
||||
require.NoError(f, err)
|
||||
|
||||
invalidRefreshToken, err := algorithm.Encrypt([]byte("userID:tokenID"))
|
||||
require.NoError(f, err)
|
||||
|
||||
tests := []string{
|
||||
"~~~", // invalid base64
|
||||
"DEADBEEF", // short cipher text
|
||||
base64.RawURLEncoding.EncodeToString(invalidRefreshToken), // incorrect amount of segments
|
||||
}
|
||||
for _, tc := range tests {
|
||||
f.Add(tc)
|
||||
}
|
||||
|
||||
f.Fuzz(func(t *testing.T, refreshToken string) {
|
||||
gotUserID, gotTokenID, gotToken, err := FromRefreshToken(refreshToken, algorithm)
|
||||
target := zerrors.InvalidArgumentError{ZitadelError: new(zerrors.ZitadelError)}
|
||||
t.Log(gotUserID, gotTokenID, gotToken)
|
||||
require.ErrorAs(t, err, &target)
|
||||
})
|
||||
}
|
2
internal/domain/testdata/fuzz/FuzzFromRefreshToken/576e811604c701eb
vendored
Normal file
2
internal/domain/testdata/fuzz/FuzzFromRefreshToken/576e811604c701eb
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("0000050000000000000000000000000")
|
@ -1,6 +1,8 @@
|
||||
package zerrors
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
_ InvalidArgument = (*InvalidArgumentError)(nil)
|
||||
@ -39,6 +41,15 @@ func (err *InvalidArgumentError) Is(target error) bool {
|
||||
return err.ZitadelError.Is(t.ZitadelError)
|
||||
}
|
||||
|
||||
func (err *InvalidArgumentError) As(target any) bool {
|
||||
targetErr, ok := target.(*InvalidArgumentError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
*targetErr = *err
|
||||
return true
|
||||
}
|
||||
|
||||
func (err *InvalidArgumentError) Unwrap() error {
|
||||
return err.ZitadelError
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user