diff --git a/internal/crypto/aes.go b/internal/crypto/aes.go new file mode 100644 index 0000000000..31d3e46847 --- /dev/null +++ b/internal/crypto/aes.go @@ -0,0 +1,136 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "io" + + "github.com/caos/zitadel/internal/errors" +) + +var _ EncryptionAlg = (*AESCrypto)(nil) + +type AESCrypto struct { + keys map[string]string + encryptionKeyID string + keyIDs []string +} + +func NewAESCrypto(config *KeyConfig) (*AESCrypto, error) { + keys, ids, err := LoadKeys(config) + if err != nil { + return nil, err + } + return &AESCrypto{ + keys: keys, + encryptionKeyID: config.EncryptionKeyID, + keyIDs: ids, + }, nil +} + +func (a *AESCrypto) Algorithm() string { + return "aes" +} + +func (a *AESCrypto) Encrypt(value []byte) ([]byte, error) { + return EncryptAES(value, a.encryptionKey()) +} + +func (a *AESCrypto) Decrypt(value []byte, keyID string) ([]byte, error) { + key, err := a.decryptionKey(keyID) + if err != nil { + return nil, err + } + return DecryptAES(value, key) +} + +func (a *AESCrypto) DecryptString(value []byte, keyID string) (string, error) { + key, err := a.decryptionKey(keyID) + if err != nil { + return "", err + } + b, err := DecryptAES(value, key) + if err != nil { + return "", err + } + return string(b), nil +} + +func (a *AESCrypto) EncryptionKeyID() string { + return a.encryptionKeyID +} + +func (a *AESCrypto) DecryptionKeyIDs() []string { + return a.keyIDs +} + +func (a *AESCrypto) encryptionKey() string { + return a.keys[a.encryptionKeyID] +} + +func (a *AESCrypto) decryptionKey(keyID string) (string, error) { + key, ok := a.keys[keyID] + if !ok { + return "", errors.ThrowNotFound(nil, "CRYPT-nkj1s", "unknown key id") + } + return key, nil +} + +func EncryptAESString(data string, key string) (string, error) { + encrypted, err := EncryptAES([]byte(data), key) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(encrypted), nil +} + +func EncryptAES(plainText []byte, key string) ([]byte, error) { + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return nil, err + } + + cipherText := make([]byte, aes.BlockSize+len(plainText)) + iv := cipherText[:aes.BlockSize] + if _, err = io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(cipherText[aes.BlockSize:], plainText) + + return cipherText, nil +} + +func DecryptAESString(data string, key string) (string, error) { + text, err := base64.URLEncoding.DecodeString(data) + if err != nil { + return "", nil + } + decrypted, err := DecryptAES(text, key) + if err != nil { + return "", err + } + return string(decrypted), nil +} + +func DecryptAES(cipherText []byte, key string) ([]byte, error) { + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return nil, err + } + + if len(cipherText) < aes.BlockSize { + err = errors.ThrowPreconditionFailed(nil, "CRYPT-23kH1", "cipher text block too short") + return nil, err + } + iv := cipherText[:aes.BlockSize] + cipherText = cipherText[aes.BlockSize:] + + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(cipherText, cipherText) + + return cipherText, err +} diff --git a/internal/crypto/aes_test.go b/internal/crypto/aes_test.go new file mode 100644 index 0000000000..c7e31de3d5 --- /dev/null +++ b/internal/crypto/aes_test.go @@ -0,0 +1,17 @@ +package crypto + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDecrypt_OK(t *testing.T) { + encryptedpw, err := EncryptAESString("ThisIsMySecretPw", "passphrasewhichneedstobe32bytes!") + assert.NoError(t, err) + + decryptedpw, err := DecryptAESString(encryptedpw, "passphrasewhichneedstobe32bytes!") + assert.NoError(t, err) + + assert.Equal(t, "ThisIsMySecretPw", decryptedpw) +} diff --git a/internal/crypto/bcrypt.go b/internal/crypto/bcrypt.go new file mode 100644 index 0000000000..e7daffb27d --- /dev/null +++ b/internal/crypto/bcrypt.go @@ -0,0 +1,27 @@ +package crypto + +import ( + "golang.org/x/crypto/bcrypt" +) + +var _ HashAlg = (*BCrypt)(nil) + +type BCrypt struct { + cost int +} + +func NewBCrypt(cost int) *BCrypt { + return &BCrypt{cost: cost} +} + +func (b *BCrypt) Algorithm() string { + return "bcrypt" +} + +func (b *BCrypt) Hash(value []byte) ([]byte, error) { + return bcrypt.GenerateFromPassword(value, b.cost) +} + +func (b *BCrypt) CompareHash(hashed, value []byte) error { + return bcrypt.CompareHashAndPassword(hashed, value) +} diff --git a/internal/crypto/code.go b/internal/crypto/code.go new file mode 100644 index 0000000000..0537710da6 --- /dev/null +++ b/internal/crypto/code.go @@ -0,0 +1,153 @@ +package crypto + +import ( + "crypto/rand" + "time" + + "github.com/caos/zitadel/internal/errors" +) + +var ( + LowerLetters = []rune("abcdefghijklmnopqrstuvwxyz") + UpperLetters = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + Digits = []rune("0123456789") + Symbols = []rune("~!@#$^&*()_+`-={}|[]:<>?,./") +) + +type Generator interface { + Length() uint + Expiry() time.Duration + Alg() Crypto + Runes() []rune +} + +type EncryptionGenerator struct { + length uint + expiry time.Duration + alg EncryptionAlg + runes []rune +} + +func (g *EncryptionGenerator) Length() uint { + return g.length +} + +func (g *EncryptionGenerator) Expiry() time.Duration { + return g.expiry +} + +func (g *EncryptionGenerator) Alg() Crypto { + return g.alg +} + +func (g *EncryptionGenerator) Runes() []rune { + return g.runes +} + +func NewEncryptionGenerator(length uint, expiry time.Duration, alg EncryptionAlg, runes []rune) *EncryptionGenerator { + return &EncryptionGenerator{ + length: length, + expiry: expiry, + alg: alg, + runes: runes, + } +} + +type HashGenerator struct { + length uint + expiry time.Duration + alg HashAlg + runes []rune +} + +func (g *HashGenerator) Length() uint { + return g.length +} + +func (g *HashGenerator) Expiry() time.Duration { + return g.expiry +} + +func (g *HashGenerator) Alg() Crypto { + return g.alg +} + +func (g *HashGenerator) Runes() []rune { + return g.runes +} + +func NewHashGenerator(length uint, expiry time.Duration, alg HashAlg, runes []rune) *HashGenerator { + return &HashGenerator{ + length: length, + expiry: expiry, + alg: alg, + runes: runes, + } +} + +func NewCode(g Generator) (*CryptoValue, string, error) { + code, err := generateRandomString(g.Length(), g.Runes()) + if err != nil { + return nil, "", err + } + crypto, err := Crypt([]byte(code), g.Alg()) + if err != nil { + return nil, "", err + } + return crypto, code, nil +} + +func IsCodeExpired(creationDate time.Time, expiry time.Duration) bool { + return creationDate.Add(expiry).Before(time.Now().UTC()) +} + +func VerifyCode(creationDate time.Time, expiry time.Duration, cryptoCode *CryptoValue, verificationCode string, g Generator) error { + if IsCodeExpired(creationDate, expiry) { + return errors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "verification code is expired") + } + switch alg := g.Alg().(type) { + case EncryptionAlg: + return verifyEncryptedCode(cryptoCode, verificationCode, alg) + case HashAlg: + return verifyHashedCode(cryptoCode, verificationCode, alg) + } + return errors.ThrowInvalidArgument(nil, "CODE-fW2gNa", "generator alg is not supported") +} + +func generateRandomString(length uint, chars []rune) (string, error) { + if length == 0 { + return "", nil + } + + max := len(chars) - 1 + maxStr := int(length - 1) + + str := make([]rune, length) + randBytes := make([]byte, length) + if _, err := rand.Read(randBytes); err != nil { + return "", err + } + for i, rb := range randBytes { + str[i] = chars[int(rb)%max] + if i == maxStr { + return string(str), nil + } + } + return "", nil +} + +func verifyEncryptedCode(cryptoCode *CryptoValue, verificationCode string, alg EncryptionAlg) error { + code, err := DecryptString(cryptoCode, alg) + if err != nil { + return err + } + + if code != verificationCode { + return errors.ThrowInvalidArgument(nil, "CODE-woT0xc", "verification code is invalid") + } + return nil +} + +func verifyHashedCode(cryptoCode *CryptoValue, verificationCode string, alg HashAlg) error { + return CompareHash(cryptoCode, []byte(verificationCode), alg) +} diff --git a/internal/crypto/code_test.go b/internal/crypto/code_test.go new file mode 100644 index 0000000000..f7f55390e7 --- /dev/null +++ b/internal/crypto/code_test.go @@ -0,0 +1,92 @@ +package crypto + +import ( + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func Test_Encrypted_OK(t *testing.T) { + mCrypto := NewMockEncryptionAlg(gomock.NewController(t)) + mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc") + mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id") + mCrypto.EXPECT().DecryptionKeyIDs().AnyTimes().Return([]string{"id"}) + mCrypto.EXPECT().Encrypt(gomock.Any()).DoAndReturn( + func(code []byte) ([]byte, error) { + return code, nil + }, + ) + mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn( + func(code []byte, _ string) (string, error) { + return string(code), nil + }, + ) + generator := NewEncryptionGenerator(6, 0, mCrypto, Digits) + + crypto, code, err := NewCode(generator) + assert.NoError(t, err) + + decrypted, err := DecryptString(crypto, generator.alg) + assert.NoError(t, err) + assert.Equal(t, code, decrypted) + assert.Equal(t, 6, len(decrypted)) +} + +func Test_Verify_Encrypted_OK(t *testing.T) { + mCrypto := NewMockEncryptionAlg(gomock.NewController(t)) + mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc") + mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id") + mCrypto.EXPECT().DecryptionKeyIDs().AnyTimes().Return([]string{"id"}) + mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn( + func(code []byte, _ string) (string, error) { + return string(code), nil + }, + ) + creationDate := time.Now() + code := &CryptoValue{ + CryptoType: TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("code"), + } + generator := NewEncryptionGenerator(6, 0, mCrypto, Digits) + + err := VerifyCode(creationDate, 1*time.Hour, code, "code", generator) + assert.NoError(t, err) +} +func Test_Verify_Encrypted_Invalid_Err(t *testing.T) { + mCrypto := NewMockEncryptionAlg(gomock.NewController(t)) + mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc") + mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id") + mCrypto.EXPECT().DecryptionKeyIDs().AnyTimes().Return([]string{"id"}) + mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn( + func(code []byte, _ string) (string, error) { + return string(code), nil + }, + ) + creationDate := time.Now() + code := &CryptoValue{ + CryptoType: TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("code"), + } + generator := NewEncryptionGenerator(6, 0, mCrypto, Digits) + + err := VerifyCode(creationDate, 1*time.Hour, code, "wrong", generator) + assert.Error(t, err) +} + +func TestIsCodeExpired_Expired(t *testing.T) { + creationDate := time.Date(2019, time.April, 1, 0, 0, 0, 0, time.UTC) + expired := IsCodeExpired(creationDate, 1*time.Hour) + assert.True(t, expired) +} + +func TestIsCodeExpired_NotExpired(t *testing.T) { + creationDate := time.Now() + expired := IsCodeExpired(creationDate, 1*time.Hour) + assert.False(t, expired) +} diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go new file mode 100644 index 0000000000..38ee06577b --- /dev/null +++ b/internal/crypto/crypto.go @@ -0,0 +1,103 @@ +package crypto + +import ( + "github.com/caos/zitadel/internal/errors" +) + +const ( + TypeEncryption CryptoType = iota + TypeHash +) + +type Crypto interface { + Algorithm() string +} + +type EncryptionAlg interface { + Crypto + EncryptionKeyID() string + DecryptionKeyIDs() []string + Encrypt(value []byte) ([]byte, error) + Decrypt(hashed []byte, keyID string) ([]byte, error) + DecryptString(hashed []byte, keyID string) (string, error) +} + +type HashAlg interface { + Crypto + Hash(value []byte) ([]byte, error) + CompareHash(hashed, comparer []byte) error +} + +type CryptoValue struct { + CryptoType CryptoType + Algorithm string + KeyID string + Crypted []byte +} + +type CryptoType int + +func Crypt(value []byte, c Crypto) (*CryptoValue, error) { + switch alg := c.(type) { + case EncryptionAlg: + return Encrypt(value, alg) + case HashAlg: + return Hash(value, alg) + } + return nil, errors.ThrowInternal(nil, "CRYPT-r4IaHZ", "algorithm not supported") +} + +func Encrypt(value []byte, alg EncryptionAlg) (*CryptoValue, error) { + encrypted, err := alg.Encrypt(value) + if err != nil { + return nil, errors.ThrowInternal(err, "CRYPT-qCD0JB", "error encrypting value") + } + return &CryptoValue{ + CryptoType: TypeEncryption, + Algorithm: alg.Algorithm(), + KeyID: alg.EncryptionKeyID(), + Crypted: encrypted, + }, nil +} + +func Decrypt(value *CryptoValue, alg EncryptionAlg) ([]byte, error) { + if err := checkEncAlg(value, alg); err != nil { + return nil, err + } + return alg.Decrypt(value.Crypted, value.KeyID) +} + +func DecryptString(value *CryptoValue, alg EncryptionAlg) (string, error) { + if err := checkEncAlg(value, alg); err != nil { + return "", err + } + return alg.DecryptString(value.Crypted, value.KeyID) +} + +func checkEncAlg(value *CryptoValue, alg EncryptionAlg) error { + if value.Algorithm != alg.Algorithm() { + return errors.ThrowInvalidArgument(nil, "CRYPT-Nx7XlT", "value was encrypted with a different key") + } + for _, id := range alg.DecryptionKeyIDs() { + if id == value.KeyID { + return nil + } + } + return errors.ThrowInvalidArgument(nil, "CRYPT-Kq12vn", "value was encrypted with a different key") +} + +func Hash(value []byte, alg HashAlg) (*CryptoValue, error) { + hashed, err := alg.Hash(value) + if err != nil { + return nil, errors.ThrowInternal(err, "CRYPT-rBVaJU", "error hashing value") + } + return &CryptoValue{ + CryptoType: TypeHash, + Algorithm: alg.Algorithm(), + Crypted: hashed, + }, nil +} + +func CompareHash(value *CryptoValue, comparer []byte, alg HashAlg) error { + return alg.CompareHash(value.Crypted, comparer) +} diff --git a/internal/crypto/crypto_mock.go b/internal/crypto/crypto_mock.go new file mode 100644 index 0000000000..3785c285ae --- /dev/null +++ b/internal/crypto/crypto_mock.go @@ -0,0 +1,223 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: crypto.go + +// Package crypto is a generated GoMock package. +package crypto + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockCrypto is a mock of Crypto interface +type MockCrypto struct { + ctrl *gomock.Controller + recorder *MockCryptoMockRecorder +} + +// MockCryptoMockRecorder is the mock recorder for MockCrypto +type MockCryptoMockRecorder struct { + mock *MockCrypto +} + +// NewMockCrypto creates a new mock instance +func NewMockCrypto(ctrl *gomock.Controller) *MockCrypto { + mock := &MockCrypto{ctrl: ctrl} + mock.recorder = &MockCryptoMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockCrypto) EXPECT() *MockCryptoMockRecorder { + return m.recorder +} + +// Algorithm mocks base method +func (m *MockCrypto) Algorithm() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Algorithm") + ret0, _ := ret[0].(string) + return ret0 +} + +// Algorithm indicates an expected call of Algorithm +func (mr *MockCryptoMockRecorder) Algorithm() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Algorithm", reflect.TypeOf((*MockCrypto)(nil).Algorithm)) +} + +// MockEncryptionAlg is a mock of EncryptionAlg interface +type MockEncryptionAlg struct { + ctrl *gomock.Controller + recorder *MockEncryptionAlgMockRecorder +} + +// MockEncryptionAlgMockRecorder is the mock recorder for MockEncryptionAlg +type MockEncryptionAlgMockRecorder struct { + mock *MockEncryptionAlg +} + +// NewMockEncryptionAlg creates a new mock instance +func NewMockEncryptionAlg(ctrl *gomock.Controller) *MockEncryptionAlg { + mock := &MockEncryptionAlg{ctrl: ctrl} + mock.recorder = &MockEncryptionAlgMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockEncryptionAlg) EXPECT() *MockEncryptionAlgMockRecorder { + return m.recorder +} + +// Algorithm mocks base method +func (m *MockEncryptionAlg) Algorithm() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Algorithm") + ret0, _ := ret[0].(string) + return ret0 +} + +// Algorithm indicates an expected call of Algorithm +func (mr *MockEncryptionAlgMockRecorder) Algorithm() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Algorithm", reflect.TypeOf((*MockEncryptionAlg)(nil).Algorithm)) +} + +// EncryptionKeyID mocks base method +func (m *MockEncryptionAlg) EncryptionKeyID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EncryptionKeyID") + ret0, _ := ret[0].(string) + return ret0 +} + +// EncryptionKeyID indicates an expected call of EncryptionKeyID +func (mr *MockEncryptionAlgMockRecorder) EncryptionKeyID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptionKeyID", reflect.TypeOf((*MockEncryptionAlg)(nil).EncryptionKeyID)) +} + +// DecryptionKeyIDs mocks base method +func (m *MockEncryptionAlg) DecryptionKeyIDs() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DecryptionKeyIDs") + ret0, _ := ret[0].([]string) + return ret0 +} + +// DecryptionKeyIDs indicates an expected call of DecryptionKeyIDs +func (mr *MockEncryptionAlgMockRecorder) DecryptionKeyIDs() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptionKeyIDs", reflect.TypeOf((*MockEncryptionAlg)(nil).DecryptionKeyIDs)) +} + +// Encrypt mocks base method +func (m *MockEncryptionAlg) Encrypt(value []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encrypt", value) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Encrypt indicates an expected call of Encrypt +func (mr *MockEncryptionAlgMockRecorder) Encrypt(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockEncryptionAlg)(nil).Encrypt), value) +} + +// Decrypt mocks base method +func (m *MockEncryptionAlg) Decrypt(hashed []byte, keyID string) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decrypt", hashed, keyID) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Decrypt indicates an expected call of Decrypt +func (mr *MockEncryptionAlgMockRecorder) Decrypt(hashed, keyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockEncryptionAlg)(nil).Decrypt), hashed, keyID) +} + +// DecryptString mocks base method +func (m *MockEncryptionAlg) DecryptString(hashed []byte, keyID string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DecryptString", hashed, keyID) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DecryptString indicates an expected call of DecryptString +func (mr *MockEncryptionAlgMockRecorder) DecryptString(hashed, keyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptString", reflect.TypeOf((*MockEncryptionAlg)(nil).DecryptString), hashed, keyID) +} + +// MockHashAlg is a mock of HashAlg interface +type MockHashAlg struct { + ctrl *gomock.Controller + recorder *MockHashAlgMockRecorder +} + +// MockHashAlgMockRecorder is the mock recorder for MockHashAlg +type MockHashAlgMockRecorder struct { + mock *MockHashAlg +} + +// NewMockHashAlg creates a new mock instance +func NewMockHashAlg(ctrl *gomock.Controller) *MockHashAlg { + mock := &MockHashAlg{ctrl: ctrl} + mock.recorder = &MockHashAlgMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockHashAlg) EXPECT() *MockHashAlgMockRecorder { + return m.recorder +} + +// Algorithm mocks base method +func (m *MockHashAlg) Algorithm() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Algorithm") + ret0, _ := ret[0].(string) + return ret0 +} + +// Algorithm indicates an expected call of Algorithm +func (mr *MockHashAlgMockRecorder) Algorithm() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Algorithm", reflect.TypeOf((*MockHashAlg)(nil).Algorithm)) +} + +// Hash mocks base method +func (m *MockHashAlg) Hash(value []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Hash", value) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Hash indicates an expected call of Hash +func (mr *MockHashAlgMockRecorder) Hash(value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Hash", reflect.TypeOf((*MockHashAlg)(nil).Hash), value) +} + +// CompareHash mocks base method +func (m *MockHashAlg) CompareHash(hashed, comparer []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CompareHash", hashed, comparer) + ret0, _ := ret[0].(error) + return ret0 +} + +// CompareHash indicates an expected call of CompareHash +func (mr *MockHashAlgMockRecorder) CompareHash(hashed, comparer interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CompareHash", reflect.TypeOf((*MockHashAlg)(nil).CompareHash), hashed, comparer) +} diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go new file mode 100644 index 0000000000..76334884b9 --- /dev/null +++ b/internal/crypto/crypto_test.go @@ -0,0 +1,273 @@ +package crypto + +import ( + "bytes" + "errors" + "reflect" + "testing" +) + +type mockEncCrypto struct { +} + +func (m *mockEncCrypto) Algorithm() string { + return "enc" +} + +func (m *mockEncCrypto) Encrypt(value []byte) ([]byte, error) { + return value, nil +} + +func (m *mockEncCrypto) Decrypt(value []byte, _ string) ([]byte, error) { + return value, nil +} + +func (m *mockEncCrypto) DecryptString(value []byte, _ string) (string, error) { + return string(value), nil +} + +func (m *mockEncCrypto) EncryptionKeyID() string { + return "keyID" +} +func (m *mockEncCrypto) DecryptionKeyIDs() []string { + return []string{"keyID"} +} + +type mockHashCrypto struct { +} + +func (m *mockHashCrypto) Algorithm() string { + return "hash" +} + +func (m *mockHashCrypto) Hash(value []byte) ([]byte, error) { + return value, nil +} + +func (m *mockHashCrypto) CompareHash(hashed, comparer []byte) error { + if !bytes.Equal(hashed, comparer) { + return errors.New("not equal") + } + return nil +} + +type alg struct{} + +func (a *alg) Algorithm() string { + return "alg" +} + +func TestCrypt(t *testing.T) { + type args struct { + value []byte + c Crypto + } + tests := []struct { + name string + args args + want *CryptoValue + wantErr bool + }{ + { + "encrypt", + args{[]byte("test"), &mockEncCrypto{}}, + &CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")}, + false, + }, + { + "hash", + args{[]byte("test"), &mockHashCrypto{}}, + &CryptoValue{CryptoType: TypeHash, Algorithm: "hash", Crypted: []byte("test")}, + false, + }, + { + "wrong type", + args{[]byte("test"), &alg{}}, + nil, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Crypt(tt.args.value, tt.args.c) + if (err != nil) != tt.wantErr { + t.Errorf("Crypt() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Crypt() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEncrypt(t *testing.T) { + type args struct { + value []byte + c EncryptionAlg + } + tests := []struct { + name string + args args + want *CryptoValue + wantErr bool + }{ + { + "ok", + args{[]byte("test"), &mockEncCrypto{}}, + &CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Encrypt(tt.args.value, tt.args.c) + if (err != nil) != tt.wantErr { + t.Errorf("Encrypt() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Encrypt() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDecrypt(t *testing.T) { + type args struct { + value *CryptoValue + c EncryptionAlg + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + { + "ok", + args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")}, &mockEncCrypto{}}, + []byte("test"), + false, + }, + { + "wrong id", + args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID2", Crypted: []byte("test")}, &mockEncCrypto{}}, + nil, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Decrypt(tt.args.value, tt.args.c) + if (err != nil) != tt.wantErr { + t.Errorf("Decrypt() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Decrypt() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDecryptString(t *testing.T) { + type args struct { + value *CryptoValue + c EncryptionAlg + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + "ok", + args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")}, &mockEncCrypto{}}, + "test", + false, + }, + { + "wrong id", + args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID2", Crypted: []byte("test")}, &mockEncCrypto{}}, + "", + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := DecryptString(tt.args.value, tt.args.c) + if (err != nil) != tt.wantErr { + t.Errorf("DecryptString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("DecryptString() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHash(t *testing.T) { + type args struct { + value []byte + c HashAlg + } + tests := []struct { + name string + args args + want *CryptoValue + wantErr bool + }{ + { + "ok", + args{[]byte("test"), &mockHashCrypto{}}, + &CryptoValue{CryptoType: TypeHash, Algorithm: "hash", Crypted: []byte("test")}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Hash(tt.args.value, tt.args.c) + if (err != nil) != tt.wantErr { + t.Errorf("Hash() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Hash() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCompareHash(t *testing.T) { + type args struct { + value *CryptoValue + comparer []byte + c HashAlg + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "ok", + args{&CryptoValue{CryptoType: TypeHash, Algorithm: "hash", Crypted: []byte("test")}, []byte("test"), &mockHashCrypto{}}, + false, + }, + { + "wrong", + args{&CryptoValue{CryptoType: TypeHash, Algorithm: "hash", Crypted: []byte("test")}, []byte("test2"), &mockHashCrypto{}}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := CompareHash(tt.args.value, tt.args.comparer, tt.args.c); (err != nil) != tt.wantErr { + t.Errorf("CompareHash() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/crypto/generate.go b/internal/crypto/generate.go new file mode 100644 index 0000000000..4c43e7e9b3 --- /dev/null +++ b/internal/crypto/generate.go @@ -0,0 +1,3 @@ +package crypto + +//go:generate mockgen -source crypto.go -destination ./crypto_mock.go -package crypto diff --git a/internal/crypto/key.go b/internal/crypto/key.go new file mode 100644 index 0000000000..da7f6de9b9 --- /dev/null +++ b/internal/crypto/key.go @@ -0,0 +1,64 @@ +package crypto + +import ( + "os" + + "github.com/caos/logging" + "github.com/caos/utils/errors" + + "github.com/caos/zitadel/internal/config" +) + +const ( + ZitadelKeyPath = "ZITADEL_KEY_PATH" +) + +type KeyConfig struct { + EncryptionKeyID string + DecryptionKeyIDs []string + Path string +} + +type Keys map[string]string + +func ReadKeys(path string) (Keys, error) { + if path == "" { + path = os.Getenv(ZitadelKeyPath) + if path == "" { + return nil, errors.ThrowInvalidArgument(nil, "CRYPT-56lka", "no path set") + } + } + keys := new(Keys) + err := config.Read(keys, path) + return *keys, err +} + +func LoadKeys(config *KeyConfig) (map[string]string, []string, error) { + if config == nil { + return nil, nil, errors.ThrowInvalidArgument(nil, "CRYPT-dJK8s", "config must not be nil") + } + readKeys, err := ReadKeys(config.Path) + if err != nil { + return nil, nil, err + } + keys := make(map[string]string) + ids := make([]string, 0, len(config.DecryptionKeyIDs)+1) + if config.EncryptionKeyID != "" { + key, ok := readKeys[config.EncryptionKeyID] + if !ok { + return nil, nil, errors.ThrowInternalf(nil, "CRYPT-v2Kas", "encryption key not found") + } + keys[config.EncryptionKeyID] = key + ids = append(ids, config.EncryptionKeyID) + } + for _, id := range config.DecryptionKeyIDs { + key, ok := readKeys[id] + if !ok { + logging.Log("CRYPT-s23rf").Warnf("description key %s not found", id) + continue + } + keys[id] = key + ids = append(ids, id) + } + return keys, ids, nil +}