mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 16:47:32 +00:00
feat: add crypto pkg
This commit is contained in:
136
internal/crypto/aes.go
Normal file
136
internal/crypto/aes.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/caos/zitadel/internal/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ EncryptionAlg = (*AESCrypto)(nil)
|
||||||
|
|
||||||
|
type AESCrypto struct {
|
||||||
|
keys map[string]string
|
||||||
|
encryptionKeyID string
|
||||||
|
keyIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAESCrypto(config *KeyConfig) (*AESCrypto, error) {
|
||||||
|
keys, ids, err := LoadKeys(config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &AESCrypto{
|
||||||
|
keys: keys,
|
||||||
|
encryptionKeyID: config.EncryptionKeyID,
|
||||||
|
keyIDs: ids,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AESCrypto) Algorithm() string {
|
||||||
|
return "aes"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AESCrypto) Encrypt(value []byte) ([]byte, error) {
|
||||||
|
return EncryptAES(value, a.encryptionKey())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AESCrypto) Decrypt(value []byte, keyID string) ([]byte, error) {
|
||||||
|
key, err := a.decryptionKey(keyID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return DecryptAES(value, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AESCrypto) DecryptString(value []byte, keyID string) (string, error) {
|
||||||
|
key, err := a.decryptionKey(keyID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
b, err := DecryptAES(value, key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AESCrypto) EncryptionKeyID() string {
|
||||||
|
return a.encryptionKeyID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AESCrypto) DecryptionKeyIDs() []string {
|
||||||
|
return a.keyIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AESCrypto) encryptionKey() string {
|
||||||
|
return a.keys[a.encryptionKeyID]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AESCrypto) decryptionKey(keyID string) (string, error) {
|
||||||
|
key, ok := a.keys[keyID]
|
||||||
|
if !ok {
|
||||||
|
return "", errors.ThrowNotFound(nil, "CRYPT-nkj1s", "unknown key id")
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncryptAESString(data string, key string) (string, error) {
|
||||||
|
encrypted, err := EncryptAES([]byte(data), key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.URLEncoding.EncodeToString(encrypted), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncryptAES(plainText []byte, key string) ([]byte, error) {
|
||||||
|
block, err := aes.NewCipher([]byte(key))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cipherText := make([]byte, aes.BlockSize+len(plainText))
|
||||||
|
iv := cipherText[:aes.BlockSize]
|
||||||
|
if _, err = io.ReadFull(rand.Reader, iv); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := cipher.NewCFBEncrypter(block, iv)
|
||||||
|
stream.XORKeyStream(cipherText[aes.BlockSize:], plainText)
|
||||||
|
|
||||||
|
return cipherText, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecryptAESString(data string, key string) (string, error) {
|
||||||
|
text, err := base64.URLEncoding.DecodeString(data)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
decrypted, err := DecryptAES(text, key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(decrypted), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecryptAES(cipherText []byte, key string) ([]byte, error) {
|
||||||
|
block, err := aes.NewCipher([]byte(key))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cipherText) < aes.BlockSize {
|
||||||
|
err = errors.ThrowPreconditionFailed(nil, "CRYPT-23kH1", "cipher text block too short")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
iv := cipherText[:aes.BlockSize]
|
||||||
|
cipherText = cipherText[aes.BlockSize:]
|
||||||
|
|
||||||
|
stream := cipher.NewCFBDecrypter(block, iv)
|
||||||
|
stream.XORKeyStream(cipherText, cipherText)
|
||||||
|
|
||||||
|
return cipherText, err
|
||||||
|
}
|
17
internal/crypto/aes_test.go
Normal file
17
internal/crypto/aes_test.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDecrypt_OK(t *testing.T) {
|
||||||
|
encryptedpw, err := EncryptAESString("ThisIsMySecretPw", "passphrasewhichneedstobe32bytes!")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
decryptedpw, err := DecryptAESString(encryptedpw, "passphrasewhichneedstobe32bytes!")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "ThisIsMySecretPw", decryptedpw)
|
||||||
|
}
|
27
internal/crypto/bcrypt.go
Normal file
27
internal/crypto/bcrypt.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ HashAlg = (*BCrypt)(nil)
|
||||||
|
|
||||||
|
type BCrypt struct {
|
||||||
|
cost int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBCrypt(cost int) *BCrypt {
|
||||||
|
return &BCrypt{cost: cost}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BCrypt) Algorithm() string {
|
||||||
|
return "bcrypt"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BCrypt) Hash(value []byte) ([]byte, error) {
|
||||||
|
return bcrypt.GenerateFromPassword(value, b.cost)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BCrypt) CompareHash(hashed, value []byte) error {
|
||||||
|
return bcrypt.CompareHashAndPassword(hashed, value)
|
||||||
|
}
|
153
internal/crypto/code.go
Normal file
153
internal/crypto/code.go
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/caos/zitadel/internal/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
LowerLetters = []rune("abcdefghijklmnopqrstuvwxyz")
|
||||||
|
UpperLetters = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||||
|
Digits = []rune("0123456789")
|
||||||
|
Symbols = []rune("~!@#$^&*()_+`-={}|[]:<>?,./")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Generator interface {
|
||||||
|
Length() uint
|
||||||
|
Expiry() time.Duration
|
||||||
|
Alg() Crypto
|
||||||
|
Runes() []rune
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncryptionGenerator struct {
|
||||||
|
length uint
|
||||||
|
expiry time.Duration
|
||||||
|
alg EncryptionAlg
|
||||||
|
runes []rune
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *EncryptionGenerator) Length() uint {
|
||||||
|
return g.length
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *EncryptionGenerator) Expiry() time.Duration {
|
||||||
|
return g.expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *EncryptionGenerator) Alg() Crypto {
|
||||||
|
return g.alg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *EncryptionGenerator) Runes() []rune {
|
||||||
|
return g.runes
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEncryptionGenerator(length uint, expiry time.Duration, alg EncryptionAlg, runes []rune) *EncryptionGenerator {
|
||||||
|
return &EncryptionGenerator{
|
||||||
|
length: length,
|
||||||
|
expiry: expiry,
|
||||||
|
alg: alg,
|
||||||
|
runes: runes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type HashGenerator struct {
|
||||||
|
length uint
|
||||||
|
expiry time.Duration
|
||||||
|
alg HashAlg
|
||||||
|
runes []rune
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *HashGenerator) Length() uint {
|
||||||
|
return g.length
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *HashGenerator) Expiry() time.Duration {
|
||||||
|
return g.expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *HashGenerator) Alg() Crypto {
|
||||||
|
return g.alg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *HashGenerator) Runes() []rune {
|
||||||
|
return g.runes
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHashGenerator(length uint, expiry time.Duration, alg HashAlg, runes []rune) *HashGenerator {
|
||||||
|
return &HashGenerator{
|
||||||
|
length: length,
|
||||||
|
expiry: expiry,
|
||||||
|
alg: alg,
|
||||||
|
runes: runes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCode(g Generator) (*CryptoValue, string, error) {
|
||||||
|
code, err := generateRandomString(g.Length(), g.Runes())
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
crypto, err := Crypt([]byte(code), g.Alg())
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
return crypto, code, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsCodeExpired(creationDate time.Time, expiry time.Duration) bool {
|
||||||
|
return creationDate.Add(expiry).Before(time.Now().UTC())
|
||||||
|
}
|
||||||
|
|
||||||
|
func VerifyCode(creationDate time.Time, expiry time.Duration, cryptoCode *CryptoValue, verificationCode string, g Generator) error {
|
||||||
|
if IsCodeExpired(creationDate, expiry) {
|
||||||
|
return errors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "verification code is expired")
|
||||||
|
}
|
||||||
|
switch alg := g.Alg().(type) {
|
||||||
|
case EncryptionAlg:
|
||||||
|
return verifyEncryptedCode(cryptoCode, verificationCode, alg)
|
||||||
|
case HashAlg:
|
||||||
|
return verifyHashedCode(cryptoCode, verificationCode, alg)
|
||||||
|
}
|
||||||
|
return errors.ThrowInvalidArgument(nil, "CODE-fW2gNa", "generator alg is not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateRandomString(length uint, chars []rune) (string, error) {
|
||||||
|
if length == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
max := len(chars) - 1
|
||||||
|
maxStr := int(length - 1)
|
||||||
|
|
||||||
|
str := make([]rune, length)
|
||||||
|
randBytes := make([]byte, length)
|
||||||
|
if _, err := rand.Read(randBytes); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
for i, rb := range randBytes {
|
||||||
|
str[i] = chars[int(rb)%max]
|
||||||
|
if i == maxStr {
|
||||||
|
return string(str), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyEncryptedCode(cryptoCode *CryptoValue, verificationCode string, alg EncryptionAlg) error {
|
||||||
|
code, err := DecryptString(cryptoCode, alg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if code != verificationCode {
|
||||||
|
return errors.ThrowInvalidArgument(nil, "CODE-woT0xc", "verification code is invalid")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyHashedCode(cryptoCode *CryptoValue, verificationCode string, alg HashAlg) error {
|
||||||
|
return CompareHash(cryptoCode, []byte(verificationCode), alg)
|
||||||
|
}
|
92
internal/crypto/code_test.go
Normal file
92
internal/crypto/code_test.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Encrypted_OK(t *testing.T) {
|
||||||
|
mCrypto := NewMockEncryptionAlg(gomock.NewController(t))
|
||||||
|
mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc")
|
||||||
|
mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id")
|
||||||
|
mCrypto.EXPECT().DecryptionKeyIDs().AnyTimes().Return([]string{"id"})
|
||||||
|
mCrypto.EXPECT().Encrypt(gomock.Any()).DoAndReturn(
|
||||||
|
func(code []byte) ([]byte, error) {
|
||||||
|
return code, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||||
|
func(code []byte, _ string) (string, error) {
|
||||||
|
return string(code), nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
generator := NewEncryptionGenerator(6, 0, mCrypto, Digits)
|
||||||
|
|
||||||
|
crypto, code, err := NewCode(generator)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
decrypted, err := DecryptString(crypto, generator.alg)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, code, decrypted)
|
||||||
|
assert.Equal(t, 6, len(decrypted))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Verify_Encrypted_OK(t *testing.T) {
|
||||||
|
mCrypto := NewMockEncryptionAlg(gomock.NewController(t))
|
||||||
|
mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc")
|
||||||
|
mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id")
|
||||||
|
mCrypto.EXPECT().DecryptionKeyIDs().AnyTimes().Return([]string{"id"})
|
||||||
|
mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||||
|
func(code []byte, _ string) (string, error) {
|
||||||
|
return string(code), nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
creationDate := time.Now()
|
||||||
|
code := &CryptoValue{
|
||||||
|
CryptoType: TypeEncryption,
|
||||||
|
Algorithm: "enc",
|
||||||
|
KeyID: "id",
|
||||||
|
Crypted: []byte("code"),
|
||||||
|
}
|
||||||
|
generator := NewEncryptionGenerator(6, 0, mCrypto, Digits)
|
||||||
|
|
||||||
|
err := VerifyCode(creationDate, 1*time.Hour, code, "code", generator)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
func Test_Verify_Encrypted_Invalid_Err(t *testing.T) {
|
||||||
|
mCrypto := NewMockEncryptionAlg(gomock.NewController(t))
|
||||||
|
mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc")
|
||||||
|
mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id")
|
||||||
|
mCrypto.EXPECT().DecryptionKeyIDs().AnyTimes().Return([]string{"id"})
|
||||||
|
mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||||
|
func(code []byte, _ string) (string, error) {
|
||||||
|
return string(code), nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
creationDate := time.Now()
|
||||||
|
code := &CryptoValue{
|
||||||
|
CryptoType: TypeEncryption,
|
||||||
|
Algorithm: "enc",
|
||||||
|
KeyID: "id",
|
||||||
|
Crypted: []byte("code"),
|
||||||
|
}
|
||||||
|
generator := NewEncryptionGenerator(6, 0, mCrypto, Digits)
|
||||||
|
|
||||||
|
err := VerifyCode(creationDate, 1*time.Hour, code, "wrong", generator)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsCodeExpired_Expired(t *testing.T) {
|
||||||
|
creationDate := time.Date(2019, time.April, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
expired := IsCodeExpired(creationDate, 1*time.Hour)
|
||||||
|
assert.True(t, expired)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsCodeExpired_NotExpired(t *testing.T) {
|
||||||
|
creationDate := time.Now()
|
||||||
|
expired := IsCodeExpired(creationDate, 1*time.Hour)
|
||||||
|
assert.False(t, expired)
|
||||||
|
}
|
103
internal/crypto/crypto.go
Normal file
103
internal/crypto/crypto.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/caos/zitadel/internal/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeEncryption CryptoType = iota
|
||||||
|
TypeHash
|
||||||
|
)
|
||||||
|
|
||||||
|
type Crypto interface {
|
||||||
|
Algorithm() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncryptionAlg interface {
|
||||||
|
Crypto
|
||||||
|
EncryptionKeyID() string
|
||||||
|
DecryptionKeyIDs() []string
|
||||||
|
Encrypt(value []byte) ([]byte, error)
|
||||||
|
Decrypt(hashed []byte, keyID string) ([]byte, error)
|
||||||
|
DecryptString(hashed []byte, keyID string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type HashAlg interface {
|
||||||
|
Crypto
|
||||||
|
Hash(value []byte) ([]byte, error)
|
||||||
|
CompareHash(hashed, comparer []byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type CryptoValue struct {
|
||||||
|
CryptoType CryptoType
|
||||||
|
Algorithm string
|
||||||
|
KeyID string
|
||||||
|
Crypted []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type CryptoType int
|
||||||
|
|
||||||
|
func Crypt(value []byte, c Crypto) (*CryptoValue, error) {
|
||||||
|
switch alg := c.(type) {
|
||||||
|
case EncryptionAlg:
|
||||||
|
return Encrypt(value, alg)
|
||||||
|
case HashAlg:
|
||||||
|
return Hash(value, alg)
|
||||||
|
}
|
||||||
|
return nil, errors.ThrowInternal(nil, "CRYPT-r4IaHZ", "algorithm not supported")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Encrypt(value []byte, alg EncryptionAlg) (*CryptoValue, error) {
|
||||||
|
encrypted, err := alg.Encrypt(value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.ThrowInternal(err, "CRYPT-qCD0JB", "error encrypting value")
|
||||||
|
}
|
||||||
|
return &CryptoValue{
|
||||||
|
CryptoType: TypeEncryption,
|
||||||
|
Algorithm: alg.Algorithm(),
|
||||||
|
KeyID: alg.EncryptionKeyID(),
|
||||||
|
Crypted: encrypted,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Decrypt(value *CryptoValue, alg EncryptionAlg) ([]byte, error) {
|
||||||
|
if err := checkEncAlg(value, alg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return alg.Decrypt(value.Crypted, value.KeyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecryptString(value *CryptoValue, alg EncryptionAlg) (string, error) {
|
||||||
|
if err := checkEncAlg(value, alg); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return alg.DecryptString(value.Crypted, value.KeyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkEncAlg(value *CryptoValue, alg EncryptionAlg) error {
|
||||||
|
if value.Algorithm != alg.Algorithm() {
|
||||||
|
return errors.ThrowInvalidArgument(nil, "CRYPT-Nx7XlT", "value was encrypted with a different key")
|
||||||
|
}
|
||||||
|
for _, id := range alg.DecryptionKeyIDs() {
|
||||||
|
if id == value.KeyID {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors.ThrowInvalidArgument(nil, "CRYPT-Kq12vn", "value was encrypted with a different key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Hash(value []byte, alg HashAlg) (*CryptoValue, error) {
|
||||||
|
hashed, err := alg.Hash(value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.ThrowInternal(err, "CRYPT-rBVaJU", "error hashing value")
|
||||||
|
}
|
||||||
|
return &CryptoValue{
|
||||||
|
CryptoType: TypeHash,
|
||||||
|
Algorithm: alg.Algorithm(),
|
||||||
|
Crypted: hashed,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CompareHash(value *CryptoValue, comparer []byte, alg HashAlg) error {
|
||||||
|
return alg.CompareHash(value.Crypted, comparer)
|
||||||
|
}
|
223
internal/crypto/crypto_mock.go
Normal file
223
internal/crypto/crypto_mock.go
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: crypto.go
|
||||||
|
|
||||||
|
// Package crypto is a generated GoMock package.
|
||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
reflect "reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockCrypto is a mock of Crypto interface
|
||||||
|
type MockCrypto struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockCryptoMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockCryptoMockRecorder is the mock recorder for MockCrypto
|
||||||
|
type MockCryptoMockRecorder struct {
|
||||||
|
mock *MockCrypto
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockCrypto creates a new mock instance
|
||||||
|
func NewMockCrypto(ctrl *gomock.Controller) *MockCrypto {
|
||||||
|
mock := &MockCrypto{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockCryptoMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockCrypto) EXPECT() *MockCryptoMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Algorithm mocks base method
|
||||||
|
func (m *MockCrypto) Algorithm() string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Algorithm")
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Algorithm indicates an expected call of Algorithm
|
||||||
|
func (mr *MockCryptoMockRecorder) Algorithm() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Algorithm", reflect.TypeOf((*MockCrypto)(nil).Algorithm))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockEncryptionAlg is a mock of EncryptionAlg interface
|
||||||
|
type MockEncryptionAlg struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockEncryptionAlgMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockEncryptionAlgMockRecorder is the mock recorder for MockEncryptionAlg
|
||||||
|
type MockEncryptionAlgMockRecorder struct {
|
||||||
|
mock *MockEncryptionAlg
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockEncryptionAlg creates a new mock instance
|
||||||
|
func NewMockEncryptionAlg(ctrl *gomock.Controller) *MockEncryptionAlg {
|
||||||
|
mock := &MockEncryptionAlg{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockEncryptionAlgMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockEncryptionAlg) EXPECT() *MockEncryptionAlgMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Algorithm mocks base method
|
||||||
|
func (m *MockEncryptionAlg) Algorithm() string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Algorithm")
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Algorithm indicates an expected call of Algorithm
|
||||||
|
func (mr *MockEncryptionAlgMockRecorder) Algorithm() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Algorithm", reflect.TypeOf((*MockEncryptionAlg)(nil).Algorithm))
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncryptionKeyID mocks base method
|
||||||
|
func (m *MockEncryptionAlg) EncryptionKeyID() string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "EncryptionKeyID")
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncryptionKeyID indicates an expected call of EncryptionKeyID
|
||||||
|
func (mr *MockEncryptionAlgMockRecorder) EncryptionKeyID() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptionKeyID", reflect.TypeOf((*MockEncryptionAlg)(nil).EncryptionKeyID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptionKeyIDs mocks base method
|
||||||
|
func (m *MockEncryptionAlg) DecryptionKeyIDs() []string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DecryptionKeyIDs")
|
||||||
|
ret0, _ := ret[0].([]string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptionKeyIDs indicates an expected call of DecryptionKeyIDs
|
||||||
|
func (mr *MockEncryptionAlgMockRecorder) DecryptionKeyIDs() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptionKeyIDs", reflect.TypeOf((*MockEncryptionAlg)(nil).DecryptionKeyIDs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt mocks base method
|
||||||
|
func (m *MockEncryptionAlg) Encrypt(value []byte) ([]byte, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Encrypt", value)
|
||||||
|
ret0, _ := ret[0].([]byte)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt indicates an expected call of Encrypt
|
||||||
|
func (mr *MockEncryptionAlgMockRecorder) Encrypt(value interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockEncryptionAlg)(nil).Encrypt), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt mocks base method
|
||||||
|
func (m *MockEncryptionAlg) Decrypt(hashed []byte, keyID string) ([]byte, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Decrypt", hashed, keyID)
|
||||||
|
ret0, _ := ret[0].([]byte)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt indicates an expected call of Decrypt
|
||||||
|
func (mr *MockEncryptionAlgMockRecorder) Decrypt(hashed, keyID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockEncryptionAlg)(nil).Decrypt), hashed, keyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptString mocks base method
|
||||||
|
func (m *MockEncryptionAlg) DecryptString(hashed []byte, keyID string) (string, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DecryptString", hashed, keyID)
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptString indicates an expected call of DecryptString
|
||||||
|
func (mr *MockEncryptionAlgMockRecorder) DecryptString(hashed, keyID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptString", reflect.TypeOf((*MockEncryptionAlg)(nil).DecryptString), hashed, keyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockHashAlg is a mock of HashAlg interface
|
||||||
|
type MockHashAlg struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockHashAlgMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockHashAlgMockRecorder is the mock recorder for MockHashAlg
|
||||||
|
type MockHashAlgMockRecorder struct {
|
||||||
|
mock *MockHashAlg
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockHashAlg creates a new mock instance
|
||||||
|
func NewMockHashAlg(ctrl *gomock.Controller) *MockHashAlg {
|
||||||
|
mock := &MockHashAlg{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockHashAlgMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use
|
||||||
|
func (m *MockHashAlg) EXPECT() *MockHashAlgMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Algorithm mocks base method
|
||||||
|
func (m *MockHashAlg) Algorithm() string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Algorithm")
|
||||||
|
ret0, _ := ret[0].(string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Algorithm indicates an expected call of Algorithm
|
||||||
|
func (mr *MockHashAlgMockRecorder) Algorithm() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Algorithm", reflect.TypeOf((*MockHashAlg)(nil).Algorithm))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash mocks base method
|
||||||
|
func (m *MockHashAlg) Hash(value []byte) ([]byte, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Hash", value)
|
||||||
|
ret0, _ := ret[0].([]byte)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash indicates an expected call of Hash
|
||||||
|
func (mr *MockHashAlgMockRecorder) Hash(value interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Hash", reflect.TypeOf((*MockHashAlg)(nil).Hash), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompareHash mocks base method
|
||||||
|
func (m *MockHashAlg) CompareHash(hashed, comparer []byte) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CompareHash", hashed, comparer)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompareHash indicates an expected call of CompareHash
|
||||||
|
func (mr *MockHashAlgMockRecorder) CompareHash(hashed, comparer interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CompareHash", reflect.TypeOf((*MockHashAlg)(nil).CompareHash), hashed, comparer)
|
||||||
|
}
|
273
internal/crypto/crypto_test.go
Normal file
273
internal/crypto/crypto_test.go
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockEncCrypto struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockEncCrypto) Algorithm() string {
|
||||||
|
return "enc"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockEncCrypto) Encrypt(value []byte) ([]byte, error) {
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockEncCrypto) Decrypt(value []byte, _ string) ([]byte, error) {
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockEncCrypto) DecryptString(value []byte, _ string) (string, error) {
|
||||||
|
return string(value), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockEncCrypto) EncryptionKeyID() string {
|
||||||
|
return "keyID"
|
||||||
|
}
|
||||||
|
func (m *mockEncCrypto) DecryptionKeyIDs() []string {
|
||||||
|
return []string{"keyID"}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockHashCrypto struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHashCrypto) Algorithm() string {
|
||||||
|
return "hash"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHashCrypto) Hash(value []byte) ([]byte, error) {
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHashCrypto) CompareHash(hashed, comparer []byte) error {
|
||||||
|
if !bytes.Equal(hashed, comparer) {
|
||||||
|
return errors.New("not equal")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type alg struct{}
|
||||||
|
|
||||||
|
func (a *alg) Algorithm() string {
|
||||||
|
return "alg"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrypt(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
value []byte
|
||||||
|
c Crypto
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *CryptoValue
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"encrypt",
|
||||||
|
args{[]byte("test"), &mockEncCrypto{}},
|
||||||
|
&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"hash",
|
||||||
|
args{[]byte("test"), &mockHashCrypto{}},
|
||||||
|
&CryptoValue{CryptoType: TypeHash, Algorithm: "hash", Crypted: []byte("test")},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"wrong type",
|
||||||
|
args{[]byte("test"), &alg{}},
|
||||||
|
nil,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := Crypt(tt.args.value, tt.args.c)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Crypt() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Crypt() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncrypt(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
value []byte
|
||||||
|
c EncryptionAlg
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *CryptoValue
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"ok",
|
||||||
|
args{[]byte("test"), &mockEncCrypto{}},
|
||||||
|
&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := Encrypt(tt.args.value, tt.args.c)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Encrypt() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Encrypt() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecrypt(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
value *CryptoValue
|
||||||
|
c EncryptionAlg
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"ok",
|
||||||
|
args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")}, &mockEncCrypto{}},
|
||||||
|
[]byte("test"),
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"wrong id",
|
||||||
|
args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID2", Crypted: []byte("test")}, &mockEncCrypto{}},
|
||||||
|
nil,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := Decrypt(tt.args.value, tt.args.c)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Decrypt() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptString(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
value *CryptoValue
|
||||||
|
c EncryptionAlg
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"ok",
|
||||||
|
args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")}, &mockEncCrypto{}},
|
||||||
|
"test",
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"wrong id",
|
||||||
|
args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID2", Crypted: []byte("test")}, &mockEncCrypto{}},
|
||||||
|
"",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := DecryptString(tt.args.value, tt.args.c)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("DecryptString() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("DecryptString() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHash(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
value []byte
|
||||||
|
c HashAlg
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *CryptoValue
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"ok",
|
||||||
|
args{[]byte("test"), &mockHashCrypto{}},
|
||||||
|
&CryptoValue{CryptoType: TypeHash, Algorithm: "hash", Crypted: []byte("test")},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := Hash(tt.args.value, tt.args.c)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Hash() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Hash() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompareHash(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
value *CryptoValue
|
||||||
|
comparer []byte
|
||||||
|
c HashAlg
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"ok",
|
||||||
|
args{&CryptoValue{CryptoType: TypeHash, Algorithm: "hash", Crypted: []byte("test")}, []byte("test"), &mockHashCrypto{}},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"wrong",
|
||||||
|
args{&CryptoValue{CryptoType: TypeHash, Algorithm: "hash", Crypted: []byte("test")}, []byte("test2"), &mockHashCrypto{}},
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if err := CompareHash(tt.args.value, tt.args.comparer, tt.args.c); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("CompareHash() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
3
internal/crypto/generate.go
Normal file
3
internal/crypto/generate.go
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
//go:generate mockgen -source crypto.go -destination ./crypto_mock.go -package crypto
|
64
internal/crypto/key.go
Normal file
64
internal/crypto/key.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/caos/logging"
|
||||||
|
"github.com/caos/utils/errors"
|
||||||
|
|
||||||
|
"github.com/caos/zitadel/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ZitadelKeyPath = "ZITADEL_KEY_PATH"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyConfig struct {
|
||||||
|
EncryptionKeyID string
|
||||||
|
DecryptionKeyIDs []string
|
||||||
|
Path string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Keys map[string]string
|
||||||
|
|
||||||
|
func ReadKeys(path string) (Keys, error) {
|
||||||
|
if path == "" {
|
||||||
|
path = os.Getenv(ZitadelKeyPath)
|
||||||
|
if path == "" {
|
||||||
|
return nil, errors.ThrowInvalidArgument(nil, "CRYPT-56lka", "no path set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
keys := new(Keys)
|
||||||
|
err := config.Read(keys, path)
|
||||||
|
return *keys, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadKeys(config *KeyConfig) (map[string]string, []string, error) {
|
||||||
|
if config == nil {
|
||||||
|
return nil, nil, errors.ThrowInvalidArgument(nil, "CRYPT-dJK8s", "config must not be nil")
|
||||||
|
}
|
||||||
|
readKeys, err := ReadKeys(config.Path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
keys := make(map[string]string)
|
||||||
|
ids := make([]string, 0, len(config.DecryptionKeyIDs)+1)
|
||||||
|
if config.EncryptionKeyID != "" {
|
||||||
|
key, ok := readKeys[config.EncryptionKeyID]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, errors.ThrowInternalf(nil, "CRYPT-v2Kas", "encryption key not found")
|
||||||
|
}
|
||||||
|
keys[config.EncryptionKeyID] = key
|
||||||
|
ids = append(ids, config.EncryptionKeyID)
|
||||||
|
}
|
||||||
|
for _, id := range config.DecryptionKeyIDs {
|
||||||
|
key, ok := readKeys[id]
|
||||||
|
if !ok {
|
||||||
|
logging.Log("CRYPT-s23rf").Warnf("description key %s not found", id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
keys[id] = key
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
return keys, ids, nil
|
||||||
|
}
|
Reference in New Issue
Block a user