mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 00:57:33 +00:00
chore: move the go code into a subfolder
This commit is contained in:
146
apps/api/internal/crypto/aes.go
Normal file
146
apps/api/internal/crypto/aes.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
var _ EncryptionAlgorithm = (*AESCrypto)(nil)
|
||||
|
||||
type AESCrypto struct {
|
||||
keys map[string]string
|
||||
encryptionKeyID string
|
||||
keyIDs []string
|
||||
}
|
||||
|
||||
func NewAESCrypto(config *KeyConfig, keyStorage KeyStorage) (*AESCrypto, error) {
|
||||
keys, ids, err := LoadKeys(config, keyStorage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AESCrypto{
|
||||
keys: keys,
|
||||
encryptionKeyID: config.EncryptionKeyID,
|
||||
keyIDs: ids,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *AESCrypto) Algorithm() string {
|
||||
return "aes"
|
||||
}
|
||||
|
||||
func (a *AESCrypto) Encrypt(value []byte) ([]byte, error) {
|
||||
return EncryptAES(value, a.encryptionKey())
|
||||
}
|
||||
|
||||
func (a *AESCrypto) Decrypt(value []byte, keyID string) ([]byte, error) {
|
||||
key, err := a.decryptionKey(keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DecryptAES(value, key)
|
||||
}
|
||||
|
||||
// DecryptString decrypts the value using the key identified by keyID.
|
||||
// When the decrypted value contains non-UTF8 characters an error is returned.
|
||||
func (a *AESCrypto) DecryptString(value []byte, keyID string) (string, error) {
|
||||
b, err := a.Decrypt(value, keyID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !utf8.Valid(b) {
|
||||
return "", zerrors.ThrowPreconditionFailed(err, "CRYPT-hiCh0", "non-UTF-8 in decrypted string")
|
||||
}
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func (a *AESCrypto) EncryptionKeyID() string {
|
||||
return a.encryptionKeyID
|
||||
}
|
||||
|
||||
func (a *AESCrypto) DecryptionKeyIDs() []string {
|
||||
return a.keyIDs
|
||||
}
|
||||
|
||||
func (a *AESCrypto) encryptionKey() string {
|
||||
return a.keys[a.encryptionKeyID]
|
||||
}
|
||||
|
||||
func (a *AESCrypto) decryptionKey(keyID string) (string, error) {
|
||||
key, ok := a.keys[keyID]
|
||||
if !ok {
|
||||
return "", zerrors.ThrowNotFound(nil, "CRYPT-nkj1s", "unknown key id")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func EncryptAESString(data string, key string) (string, error) {
|
||||
encrypted, err := EncryptAES([]byte(data), key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
func EncryptAES(plainText []byte, key string) ([]byte, error) {
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxSize := 64 * 1024 * 1024
|
||||
if len(plainText) > maxSize {
|
||||
return nil, zerrors.ThrowPreconditionFailedf(nil, "CRYPT-AGg4t3", "data too large, max bytes: %v", maxSize)
|
||||
}
|
||||
cipherText := make([]byte, aes.BlockSize+len(plainText))
|
||||
iv := cipherText[:aes.BlockSize]
|
||||
if _, err = io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stream := cipher.NewCFBEncrypter(block, iv)
|
||||
stream.XORKeyStream(cipherText[aes.BlockSize:], plainText)
|
||||
|
||||
return cipherText, nil
|
||||
}
|
||||
|
||||
func DecryptAESString(data string, key string) (string, error) {
|
||||
text, err := base64.URLEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
decrypted, err := DecryptAES(text, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(decrypted), nil
|
||||
}
|
||||
|
||||
func DecryptAES(text []byte, key string) ([]byte, error) {
|
||||
cipherText := make([]byte, len(text))
|
||||
copy(cipherText, text)
|
||||
|
||||
block, err := aes.NewCipher([]byte(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(cipherText) < aes.BlockSize {
|
||||
err = zerrors.ThrowPreconditionFailed(nil, "CRYPT-23kH1", "cipher text block too short")
|
||||
return nil, err
|
||||
}
|
||||
iv := cipherText[:aes.BlockSize]
|
||||
cipherText = cipherText[aes.BlockSize:]
|
||||
|
||||
stream := cipher.NewCFBDecrypter(block, iv)
|
||||
stream.XORKeyStream(cipherText, cipherText)
|
||||
|
||||
return cipherText, err
|
||||
}
|
109
apps/api/internal/crypto/aes_test.go
Normal file
109
apps/api/internal/crypto/aes_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type mockKeyStorage struct {
|
||||
keys Keys
|
||||
}
|
||||
|
||||
func (s *mockKeyStorage) ReadKeys() (Keys, error) {
|
||||
return s.keys, nil
|
||||
}
|
||||
|
||||
func (s *mockKeyStorage) ReadKey(id string) (*Key, error) {
|
||||
return &Key{
|
||||
ID: id,
|
||||
Value: s.keys[id],
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*mockKeyStorage) CreateKeys(context.Context, ...*Key) error {
|
||||
return errors.New("mockKeyStorage.CreateKeys not implemented")
|
||||
}
|
||||
|
||||
func newTestAESCrypto(t testing.TB) *AESCrypto {
|
||||
keyConfig := &KeyConfig{
|
||||
EncryptionKeyID: "keyID",
|
||||
DecryptionKeyIDs: []string{"keyID"},
|
||||
}
|
||||
keys := Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"}
|
||||
aesCrypto, err := NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys})
|
||||
require.NoError(t, err)
|
||||
return aesCrypto
|
||||
}
|
||||
|
||||
func TestAESCrypto_DecryptString(t *testing.T) {
|
||||
aesCrypto := newTestAESCrypto(t)
|
||||
const input = "SecretData"
|
||||
crypted, err := aesCrypto.Encrypt([]byte(input))
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
value []byte
|
||||
keyID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "unknown key id error",
|
||||
args: args{
|
||||
value: crypted,
|
||||
keyID: "foo",
|
||||
},
|
||||
wantErr: zerrors.ThrowNotFound(nil, "CRYPT-nkj1s", "unknown key id"),
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
value: crypted,
|
||||
keyID: "keyID",
|
||||
},
|
||||
want: input,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got, err := aesCrypto.DecryptString(tt.args.value, tt.args.keyID)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzAESCrypto_DecryptString(f *testing.F) {
|
||||
aesCrypto := newTestAESCrypto(f)
|
||||
tests := []string{
|
||||
" ",
|
||||
"SecretData",
|
||||
"FooBar",
|
||||
"HelloWorld",
|
||||
}
|
||||
for _, input := range tests {
|
||||
tc, err := aesCrypto.Encrypt([]byte(input))
|
||||
require.NoError(f, err)
|
||||
f.Add(tc)
|
||||
}
|
||||
f.Fuzz(func(t *testing.T, value []byte) {
|
||||
got, err := aesCrypto.DecryptString(value, "keyID")
|
||||
if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "CRYPT-23kH1", "cipher text block too short")) {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "CRYPT-hiCh0", "non-UTF-8 in decrypted string")) {
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.True(t, utf8.ValidString(got), "result is not valid UTF-8")
|
||||
})
|
||||
}
|
173
apps/api/internal/crypto/code.go
Normal file
173
apps/api/internal/crypto/code.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
lowerLetters = []rune("abcdefghijklmnopqrstuvwxyz")
|
||||
upperLetters = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
digits = []rune("0123456789")
|
||||
symbols = []rune("~!@#$^&*()_+`-={}|[]:<>?,./")
|
||||
)
|
||||
|
||||
type GeneratorConfig struct {
|
||||
Length uint
|
||||
Expiry time.Duration
|
||||
IncludeLowerLetters bool
|
||||
IncludeUpperLetters bool
|
||||
IncludeDigits bool
|
||||
IncludeSymbols bool
|
||||
}
|
||||
|
||||
type Generator interface {
|
||||
Length() uint
|
||||
Expiry() time.Duration
|
||||
Alg() EncryptionAlgorithm
|
||||
Runes() []rune
|
||||
}
|
||||
|
||||
type generator struct {
|
||||
length uint
|
||||
expiry time.Duration
|
||||
runes []rune
|
||||
}
|
||||
|
||||
func (g *generator) Length() uint {
|
||||
return g.length
|
||||
}
|
||||
|
||||
func (g *generator) Expiry() time.Duration {
|
||||
return g.expiry
|
||||
}
|
||||
|
||||
func (g *generator) Runes() []rune {
|
||||
return g.runes
|
||||
}
|
||||
|
||||
type encryptionGenerator struct {
|
||||
generator
|
||||
alg EncryptionAlgorithm
|
||||
}
|
||||
|
||||
func (g *encryptionGenerator) Alg() EncryptionAlgorithm {
|
||||
return g.alg
|
||||
}
|
||||
|
||||
func NewEncryptionGenerator(config GeneratorConfig, algorithm EncryptionAlgorithm) Generator {
|
||||
return &encryptionGenerator{
|
||||
newGenerator(config),
|
||||
algorithm,
|
||||
}
|
||||
}
|
||||
|
||||
type HashGenerator struct {
|
||||
generator
|
||||
hasher *Hasher
|
||||
}
|
||||
|
||||
func NewHashGenerator(config GeneratorConfig, hasher *Hasher) *HashGenerator {
|
||||
return &HashGenerator{
|
||||
newGenerator(config),
|
||||
hasher,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *HashGenerator) NewCode() (encoded, plain string, err error) {
|
||||
plain, err = GenerateRandomString(g.Length(), g.Runes())
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
encoded, err = g.hasher.Hash(plain)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return encoded, plain, nil
|
||||
}
|
||||
|
||||
func newGenerator(config GeneratorConfig) generator {
|
||||
var runes []rune
|
||||
if config.IncludeLowerLetters {
|
||||
runes = append(runes, lowerLetters...)
|
||||
}
|
||||
if config.IncludeUpperLetters {
|
||||
runes = append(runes, upperLetters...)
|
||||
}
|
||||
if config.IncludeDigits {
|
||||
runes = append(runes, digits...)
|
||||
}
|
||||
if config.IncludeSymbols {
|
||||
runes = append(runes, symbols...)
|
||||
}
|
||||
return generator{
|
||||
length: config.Length,
|
||||
expiry: config.Expiry,
|
||||
runes: runes,
|
||||
}
|
||||
}
|
||||
|
||||
func NewCode(g Generator) (*CryptoValue, string, error) {
|
||||
code, err := GenerateRandomString(g.Length(), g.Runes())
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
crypto, err := Crypt([]byte(code), g.Alg())
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return crypto, code, nil
|
||||
}
|
||||
|
||||
func IsCodeExpired(creationDate time.Time, expiry time.Duration) bool {
|
||||
if expiry == 0 {
|
||||
return false
|
||||
}
|
||||
return creationDate.Add(expiry).Before(time.Now().UTC())
|
||||
}
|
||||
|
||||
func VerifyCode(creationDate time.Time, expiry time.Duration, cryptoCode *CryptoValue, verificationCode string, algorithm EncryptionAlgorithm) error {
|
||||
if IsCodeExpired(creationDate, expiry) {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired")
|
||||
}
|
||||
return verifyEncryptedCode(cryptoCode, verificationCode, algorithm)
|
||||
}
|
||||
|
||||
func GenerateRandomString(length uint, chars []rune) (string, error) {
|
||||
if length == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
max := len(chars) - 1
|
||||
maxStr := int(length - 1)
|
||||
|
||||
str := make([]rune, length)
|
||||
randBytes := make([]byte, length)
|
||||
if _, err := rand.Read(randBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
for i, rb := range randBytes {
|
||||
str[i] = chars[int(rb)%max]
|
||||
if i == maxStr {
|
||||
return string(str), nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func verifyEncryptedCode(cryptoCode *CryptoValue, verificationCode string, alg EncryptionAlgorithm) error {
|
||||
if cryptoCode == nil {
|
||||
return zerrors.ThrowInvalidArgument(nil, "CRYPT-aqrFV", "Errors.User.Code.CryptoCodeNil")
|
||||
}
|
||||
code, err := DecryptString(cryptoCode, alg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if code != verificationCode {
|
||||
return zerrors.ThrowInvalidArgument(nil, "CODE-woT0xc", "Errors.User.Code.Invalid")
|
||||
}
|
||||
return nil
|
||||
}
|
96
apps/api/internal/crypto/code_mock.go
Normal file
96
apps/api/internal/crypto/code_mock.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: code.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source code.go -destination ./code_mock.go -package crypto
|
||||
//
|
||||
|
||||
// Package crypto is a generated GoMock package.
|
||||
package crypto
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockGenerator is a mock of Generator interface.
|
||||
type MockGenerator struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockGeneratorMockRecorder
|
||||
}
|
||||
|
||||
// MockGeneratorMockRecorder is the mock recorder for MockGenerator.
|
||||
type MockGeneratorMockRecorder struct {
|
||||
mock *MockGenerator
|
||||
}
|
||||
|
||||
// NewMockGenerator creates a new mock instance.
|
||||
func NewMockGenerator(ctrl *gomock.Controller) *MockGenerator {
|
||||
mock := &MockGenerator{ctrl: ctrl}
|
||||
mock.recorder = &MockGeneratorMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockGenerator) EXPECT() *MockGeneratorMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Alg mocks base method.
|
||||
func (m *MockGenerator) Alg() EncryptionAlgorithm {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Alg")
|
||||
ret0, _ := ret[0].(EncryptionAlgorithm)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Alg indicates an expected call of Alg.
|
||||
func (mr *MockGeneratorMockRecorder) Alg() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Alg", reflect.TypeOf((*MockGenerator)(nil).Alg))
|
||||
}
|
||||
|
||||
// Expiry mocks base method.
|
||||
func (m *MockGenerator) Expiry() time.Duration {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Expiry")
|
||||
ret0, _ := ret[0].(time.Duration)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Expiry indicates an expected call of Expiry.
|
||||
func (mr *MockGeneratorMockRecorder) Expiry() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Expiry", reflect.TypeOf((*MockGenerator)(nil).Expiry))
|
||||
}
|
||||
|
||||
// Length mocks base method.
|
||||
func (m *MockGenerator) Length() uint {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Length")
|
||||
ret0, _ := ret[0].(uint)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Length indicates an expected call of Length.
|
||||
func (mr *MockGeneratorMockRecorder) Length() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Length", reflect.TypeOf((*MockGenerator)(nil).Length))
|
||||
}
|
||||
|
||||
// Runes mocks base method.
|
||||
func (m *MockGenerator) Runes() []rune {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Runes")
|
||||
ret0, _ := ret[0].([]rune)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Runes indicates an expected call of Runes.
|
||||
func (mr *MockGeneratorMockRecorder) Runes() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Runes", reflect.TypeOf((*MockGenerator)(nil).Runes))
|
||||
}
|
73
apps/api/internal/crypto/code_mocker.go
Normal file
73
apps/api/internal/crypto/code_mocker.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func CreateMockEncryptionAlg(ctrl *gomock.Controller) EncryptionAlgorithm {
|
||||
return createMockEncryptionAlgorithm(
|
||||
ctrl,
|
||||
func(code []byte) ([]byte, error) {
|
||||
return code, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// CreateMockEncryptionAlgWithCode compares the length of the value to be encrypted with the length of the provided code.
|
||||
// It will return an error if they do not match.
|
||||
// The provided code will be used to encrypt in favor of the value passed to the encryption.
|
||||
// This function is intended to be used where the passed value is not in control, but where the returned encryption requires a static value.
|
||||
func CreateMockEncryptionAlgWithCode(ctrl *gomock.Controller, code string) EncryptionAlgorithm {
|
||||
return createMockEncryptionAlgorithm(
|
||||
ctrl,
|
||||
func(c []byte) ([]byte, error) {
|
||||
if len(c) != len(code) {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "id", "invalid code length - expected %d, got %d", len(code), len(c))
|
||||
}
|
||||
return []byte(code), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func createMockEncryptionAlgorithm(ctrl *gomock.Controller, encryptFunction func(c []byte) ([]byte, error)) *MockEncryptionAlgorithm {
|
||||
mCrypto := NewMockEncryptionAlgorithm(ctrl)
|
||||
mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc")
|
||||
mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id")
|
||||
mCrypto.EXPECT().DecryptionKeyIDs().AnyTimes().Return([]string{"id"})
|
||||
mCrypto.EXPECT().Encrypt(gomock.Any()).AnyTimes().DoAndReturn(
|
||||
encryptFunction,
|
||||
)
|
||||
mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(
|
||||
func(code []byte, keyID string) (string, error) {
|
||||
if keyID != "id" {
|
||||
return "", zerrors.ThrowInternal(nil, "id", "invalid key id")
|
||||
}
|
||||
return string(code), nil
|
||||
},
|
||||
)
|
||||
mCrypto.EXPECT().Decrypt(gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(
|
||||
func(code []byte, keyID string) ([]byte, error) {
|
||||
if keyID != "id" {
|
||||
return nil, zerrors.ThrowInternal(nil, "id", "invalid key id")
|
||||
}
|
||||
return code, nil
|
||||
},
|
||||
)
|
||||
return mCrypto
|
||||
}
|
||||
|
||||
func createMockCrypto(t *testing.T) EncryptionAlgorithm {
|
||||
mCrypto := NewMockEncryptionAlgorithm(gomock.NewController(t))
|
||||
mCrypto.EXPECT().Algorithm().AnyTimes().Return("crypto")
|
||||
return mCrypto
|
||||
}
|
||||
|
||||
func createMockGenerator(t *testing.T, crypto EncryptionAlgorithm) Generator {
|
||||
mGenerator := NewMockGenerator(gomock.NewController(t))
|
||||
mGenerator.EXPECT().Alg().AnyTimes().Return(crypto)
|
||||
return mGenerator
|
||||
}
|
209
apps/api/internal/crypto/code_test.go
Normal file
209
apps/api/internal/crypto/code_test.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestIsCodeExpired(t *testing.T) {
|
||||
type args struct {
|
||||
creationDate time.Time
|
||||
expiry time.Duration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
"not expired",
|
||||
args{
|
||||
creationDate: time.Now(),
|
||||
expiry: time.Duration(5 * time.Minute),
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"never expires",
|
||||
args{
|
||||
creationDate: time.Now().Add(-5 * time.Minute),
|
||||
expiry: 0,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"expired",
|
||||
args{
|
||||
creationDate: time.Now().Add(-5 * time.Minute),
|
||||
expiry: time.Duration(5 * time.Minute),
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsCodeExpired(tt.args.creationDate, tt.args.expiry); got != tt.want {
|
||||
t.Errorf("IsCodeExpired() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyCode(t *testing.T) {
|
||||
type args struct {
|
||||
creationDate time.Time
|
||||
expiry time.Duration
|
||||
cryptoCode *CryptoValue
|
||||
verificationCode string
|
||||
g Generator
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"expired",
|
||||
args{
|
||||
creationDate: time.Now().Add(-5 * time.Minute),
|
||||
expiry: 5 * time.Minute,
|
||||
cryptoCode: nil,
|
||||
verificationCode: "",
|
||||
g: createMockGenerator(t, createMockCrypto(t)),
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"unsupported alg err",
|
||||
args{
|
||||
creationDate: time.Now(),
|
||||
expiry: 5 * time.Minute,
|
||||
cryptoCode: nil,
|
||||
verificationCode: "code",
|
||||
g: createMockGenerator(t, createMockCrypto(t)),
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"encryption alg ok",
|
||||
args{
|
||||
creationDate: time.Now(),
|
||||
expiry: 5 * time.Minute,
|
||||
cryptoCode: &CryptoValue{
|
||||
CryptoType: TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
verificationCode: "code",
|
||||
g: createMockGenerator(t, CreateMockEncryptionAlg(gomock.NewController(t))),
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := VerifyCode(tt.args.creationDate, tt.args.expiry, tt.args.cryptoCode, tt.args.verificationCode, tt.args.g.Alg()); (err != nil) != tt.wantErr {
|
||||
t.Errorf("VerifyCode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_verifyEncryptedCode(t *testing.T) {
|
||||
type args struct {
|
||||
cryptoCode *CryptoValue
|
||||
verificationCode string
|
||||
alg EncryptionAlgorithm
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"nil error",
|
||||
args{
|
||||
cryptoCode: nil,
|
||||
verificationCode: "",
|
||||
alg: CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"wrong cryptotype error",
|
||||
args{
|
||||
cryptoCode: &CryptoValue{
|
||||
CryptoType: TypeHash,
|
||||
Crypted: nil,
|
||||
},
|
||||
verificationCode: "",
|
||||
alg: CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"wrong algorithm error",
|
||||
args{
|
||||
cryptoCode: &CryptoValue{
|
||||
CryptoType: TypeEncryption,
|
||||
Algorithm: "enc2",
|
||||
Crypted: nil,
|
||||
},
|
||||
verificationCode: "",
|
||||
alg: CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"wrong key id error",
|
||||
args{
|
||||
cryptoCode: &CryptoValue{
|
||||
CryptoType: TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
Crypted: nil,
|
||||
},
|
||||
verificationCode: "wrong",
|
||||
alg: CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"wrong verification code error",
|
||||
args{
|
||||
cryptoCode: &CryptoValue{
|
||||
CryptoType: TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
verificationCode: "wrong",
|
||||
alg: CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"verification code ok",
|
||||
args{
|
||||
cryptoCode: &CryptoValue{
|
||||
CryptoType: TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
verificationCode: "code",
|
||||
alg: CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := verifyEncryptedCode(tt.args.cryptoCode, tt.args.verificationCode, tt.args.alg); (err != nil) != tt.wantErr {
|
||||
t.Errorf("verifyEncryptedCode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
143
apps/api/internal/crypto/crypto.go
Normal file
143
apps/api/internal/crypto/crypto.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
TypeEncryption CryptoType = iota
|
||||
TypeHash // Depcrecated: use [passwap.Swapper] instead
|
||||
)
|
||||
|
||||
type EncryptionAlgorithm interface {
|
||||
Algorithm() string
|
||||
EncryptionKeyID() string
|
||||
DecryptionKeyIDs() []string
|
||||
Encrypt(value []byte) ([]byte, error)
|
||||
Decrypt(hashed []byte, keyID string) ([]byte, error)
|
||||
|
||||
// DecryptString decrypts the value using the key identified by keyID.
|
||||
// When the decrypted value contains non-UTF8 characters an error is returned.
|
||||
DecryptString(hashed []byte, keyID string) (string, error)
|
||||
}
|
||||
|
||||
type CryptoValue struct {
|
||||
CryptoType CryptoType
|
||||
Algorithm string
|
||||
KeyID string
|
||||
Crypted []byte
|
||||
}
|
||||
|
||||
func (c *CryptoValue) Value() (driver.Value, error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(c)
|
||||
}
|
||||
|
||||
func (c *CryptoValue) Scan(src interface{}) error {
|
||||
if b, ok := src.([]byte); ok {
|
||||
return json.Unmarshal(b, c)
|
||||
}
|
||||
if s, ok := src.(string); ok {
|
||||
return json.Unmarshal([]byte(s), c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type CryptoType int
|
||||
|
||||
func Crypt(value []byte, alg EncryptionAlgorithm) (*CryptoValue, error) {
|
||||
return Encrypt(value, alg)
|
||||
}
|
||||
|
||||
func Encrypt(value []byte, alg EncryptionAlgorithm) (*CryptoValue, error) {
|
||||
encrypted, err := alg.Encrypt(value)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "CRYPT-qCD0JB", "error encrypting value")
|
||||
}
|
||||
return &CryptoValue{
|
||||
CryptoType: TypeEncryption,
|
||||
Algorithm: alg.Algorithm(),
|
||||
KeyID: alg.EncryptionKeyID(),
|
||||
Crypted: encrypted,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func EncryptJSON(obj any, alg EncryptionAlgorithm) (*CryptoValue, error) {
|
||||
data, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "CRYPT-Ei6doF", "error encrypting value")
|
||||
}
|
||||
return Encrypt(data, alg)
|
||||
}
|
||||
|
||||
func Decrypt(value *CryptoValue, alg EncryptionAlgorithm) ([]byte, error) {
|
||||
if err := checkEncryptionAlgorithm(value, alg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return alg.Decrypt(value.Crypted, value.KeyID)
|
||||
}
|
||||
|
||||
func DecryptJSON(value *CryptoValue, dst any, alg EncryptionAlgorithm) error {
|
||||
data, err := Decrypt(value, alg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = json.Unmarshal(data, dst); err != nil {
|
||||
return zerrors.ThrowInternal(err, "CRYPT-Jaik2R", "error decrypting value")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptString decrypts the value using the key identified by keyID.
|
||||
// When the decrypted value contains non-UTF8 characters an error is returned.
|
||||
func DecryptString(value *CryptoValue, alg EncryptionAlgorithm) (string, error) {
|
||||
if err := checkEncryptionAlgorithm(value, alg); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return alg.DecryptString(value.Crypted, value.KeyID)
|
||||
}
|
||||
|
||||
func checkEncryptionAlgorithm(value *CryptoValue, alg EncryptionAlgorithm) error {
|
||||
if value.Algorithm != alg.Algorithm() {
|
||||
return zerrors.ThrowInvalidArgument(nil, "CRYPT-Nx7XlT", "value was encrypted with a different key")
|
||||
}
|
||||
for _, id := range alg.DecryptionKeyIDs() {
|
||||
if id == value.KeyID {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return zerrors.ThrowInvalidArgument(nil, "CRYPT-Kq12vn", "value was encrypted with a different key")
|
||||
}
|
||||
|
||||
func CheckToken(alg EncryptionAlgorithm, token string, content string) error {
|
||||
if token == "" {
|
||||
return zerrors.ThrowPermissionDenied(nil, "CRYPTO-Sfefs", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
data, err := base64.RawURLEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
return zerrors.ThrowPermissionDenied(err, "CRYPTO-Swg31", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
decryptedToken, err := alg.DecryptString(data, alg.EncryptionKeyID())
|
||||
if err != nil {
|
||||
return zerrors.ThrowPermissionDenied(err, "CRYPTO-Sf4gt", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
if decryptedToken != content {
|
||||
return zerrors.ThrowPermissionDenied(nil, "CRYPTO-CRYPTO", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SecretOrEncodedHash returns the Crypted value from legacy [CryptoValue] if it is not nil.
|
||||
// otherwise it will returns the encoded hash string.
|
||||
func SecretOrEncodedHash(secret *CryptoValue, encoded string) string {
|
||||
if secret != nil {
|
||||
return string(secret.Crypted)
|
||||
}
|
||||
return encoded
|
||||
}
|
126
apps/api/internal/crypto/crypto_mock.go
Normal file
126
apps/api/internal/crypto/crypto_mock.go
Normal file
@@ -0,0 +1,126 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: crypto.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source crypto.go -destination ./crypto_mock.go -package crypto
|
||||
//
|
||||
|
||||
// Package crypto is a generated GoMock package.
|
||||
package crypto
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockEncryptionAlgorithm is a mock of EncryptionAlgorithm interface.
|
||||
type MockEncryptionAlgorithm struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockEncryptionAlgorithmMockRecorder
|
||||
}
|
||||
|
||||
// MockEncryptionAlgorithmMockRecorder is the mock recorder for MockEncryptionAlgorithm.
|
||||
type MockEncryptionAlgorithmMockRecorder struct {
|
||||
mock *MockEncryptionAlgorithm
|
||||
}
|
||||
|
||||
// NewMockEncryptionAlgorithm creates a new mock instance.
|
||||
func NewMockEncryptionAlgorithm(ctrl *gomock.Controller) *MockEncryptionAlgorithm {
|
||||
mock := &MockEncryptionAlgorithm{ctrl: ctrl}
|
||||
mock.recorder = &MockEncryptionAlgorithmMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockEncryptionAlgorithm) EXPECT() *MockEncryptionAlgorithmMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Algorithm mocks base method.
|
||||
func (m *MockEncryptionAlgorithm) Algorithm() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Algorithm")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Algorithm indicates an expected call of Algorithm.
|
||||
func (mr *MockEncryptionAlgorithmMockRecorder) Algorithm() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Algorithm", reflect.TypeOf((*MockEncryptionAlgorithm)(nil).Algorithm))
|
||||
}
|
||||
|
||||
// Decrypt mocks base method.
|
||||
func (m *MockEncryptionAlgorithm) Decrypt(hashed []byte, keyID string) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Decrypt", hashed, keyID)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Decrypt indicates an expected call of Decrypt.
|
||||
func (mr *MockEncryptionAlgorithmMockRecorder) Decrypt(hashed, keyID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockEncryptionAlgorithm)(nil).Decrypt), hashed, keyID)
|
||||
}
|
||||
|
||||
// DecryptString mocks base method.
|
||||
func (m *MockEncryptionAlgorithm) DecryptString(hashed []byte, keyID string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DecryptString", hashed, keyID)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DecryptString indicates an expected call of DecryptString.
|
||||
func (mr *MockEncryptionAlgorithmMockRecorder) DecryptString(hashed, keyID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptString", reflect.TypeOf((*MockEncryptionAlgorithm)(nil).DecryptString), hashed, keyID)
|
||||
}
|
||||
|
||||
// DecryptionKeyIDs mocks base method.
|
||||
func (m *MockEncryptionAlgorithm) DecryptionKeyIDs() []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DecryptionKeyIDs")
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DecryptionKeyIDs indicates an expected call of DecryptionKeyIDs.
|
||||
func (mr *MockEncryptionAlgorithmMockRecorder) DecryptionKeyIDs() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecryptionKeyIDs", reflect.TypeOf((*MockEncryptionAlgorithm)(nil).DecryptionKeyIDs))
|
||||
}
|
||||
|
||||
// Encrypt mocks base method.
|
||||
func (m *MockEncryptionAlgorithm) Encrypt(value []byte) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Encrypt", value)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Encrypt indicates an expected call of Encrypt.
|
||||
func (mr *MockEncryptionAlgorithmMockRecorder) Encrypt(value any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockEncryptionAlgorithm)(nil).Encrypt), value)
|
||||
}
|
||||
|
||||
// EncryptionKeyID mocks base method.
|
||||
func (m *MockEncryptionAlgorithm) EncryptionKeyID() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "EncryptionKeyID")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// EncryptionKeyID indicates an expected call of EncryptionKeyID.
|
||||
func (mr *MockEncryptionAlgorithmMockRecorder) EncryptionKeyID() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptionKeyID", reflect.TypeOf((*MockEncryptionAlgorithm)(nil).EncryptionKeyID))
|
||||
}
|
198
apps/api/internal/crypto/crypto_test.go
Normal file
198
apps/api/internal/crypto/crypto_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockEncCrypto struct {
|
||||
}
|
||||
|
||||
func (m *mockEncCrypto) Algorithm() string {
|
||||
return "enc"
|
||||
}
|
||||
|
||||
func (m *mockEncCrypto) Encrypt(value []byte) ([]byte, error) {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (m *mockEncCrypto) Decrypt(value []byte, _ string) ([]byte, error) {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (m *mockEncCrypto) DecryptString(value []byte, _ string) (string, error) {
|
||||
return string(value), nil
|
||||
}
|
||||
|
||||
func (m *mockEncCrypto) EncryptionKeyID() string {
|
||||
return "keyID"
|
||||
}
|
||||
func (m *mockEncCrypto) DecryptionKeyIDs() []string {
|
||||
return []string{"keyID"}
|
||||
}
|
||||
|
||||
type mockHashCrypto struct {
|
||||
}
|
||||
|
||||
func (m *mockHashCrypto) Algorithm() string {
|
||||
return "hash"
|
||||
}
|
||||
|
||||
func (m *mockHashCrypto) Hash(value []byte) ([]byte, error) {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (m *mockHashCrypto) CompareHash(hashed, comparer []byte) error {
|
||||
if !bytes.Equal(hashed, comparer) {
|
||||
return errors.New("not equal")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type alg struct{}
|
||||
|
||||
func (a *alg) Algorithm() string {
|
||||
return "alg"
|
||||
}
|
||||
|
||||
func TestCrypt(t *testing.T) {
|
||||
type args struct {
|
||||
value []byte
|
||||
c EncryptionAlgorithm
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *CryptoValue
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"encrypt",
|
||||
args{[]byte("test"), &mockEncCrypto{}},
|
||||
&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")},
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Crypt(tt.args.value, tt.args.c)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Crypt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Crypt() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt(t *testing.T) {
|
||||
type args struct {
|
||||
value []byte
|
||||
c EncryptionAlgorithm
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *CryptoValue
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"ok",
|
||||
args{[]byte("test"), &mockEncCrypto{}},
|
||||
&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")},
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Encrypt(tt.args.value, tt.args.c)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Encrypt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Encrypt() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecrypt(t *testing.T) {
|
||||
type args struct {
|
||||
value *CryptoValue
|
||||
c EncryptionAlgorithm
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"ok",
|
||||
args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")}, &mockEncCrypto{}},
|
||||
[]byte("test"),
|
||||
false,
|
||||
},
|
||||
{
|
||||
"wrong id",
|
||||
args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID2", Crypted: []byte("test")}, &mockEncCrypto{}},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Decrypt(tt.args.value, tt.args.c)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Decrypt() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Decrypt() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptString(t *testing.T) {
|
||||
type args struct {
|
||||
value *CryptoValue
|
||||
c EncryptionAlgorithm
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"ok",
|
||||
args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID", Crypted: []byte("test")}, &mockEncCrypto{}},
|
||||
"test",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"wrong id",
|
||||
args{&CryptoValue{CryptoType: TypeEncryption, Algorithm: "enc", KeyID: "keyID2", Crypted: []byte("test")}, &mockEncCrypto{}},
|
||||
"",
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := DecryptString(tt.args.value, tt.args.c)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("DecryptString() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("DecryptString() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
137
apps/api/internal/crypto/database/database.go
Normal file
137
apps/api/internal/crypto/database/database.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
z_db "github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
client *z_db.DB
|
||||
masterKey string
|
||||
encrypt func(key, masterKey string) (encryptedKey string, err error)
|
||||
decrypt func(encryptedKey, masterKey string) (key string, err error)
|
||||
}
|
||||
|
||||
const (
|
||||
EncryptionKeysTable = "system.encryption_keys"
|
||||
encryptionKeysIDCol = "id"
|
||||
encryptionKeysKeyCol = "key"
|
||||
)
|
||||
|
||||
func NewKeyStorage(client *z_db.DB, masterKey string) (*Database, error) {
|
||||
if err := checkMasterKeyLength(masterKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Database{
|
||||
client: client,
|
||||
masterKey: masterKey,
|
||||
encrypt: crypto.EncryptAESString,
|
||||
decrypt: crypto.DecryptAESString,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Database) ReadKeys() (crypto.Keys, error) {
|
||||
keys := make(map[string]string)
|
||||
stmt, args, err := sq.Select(encryptionKeysIDCol, encryptionKeysKeyCol).
|
||||
From(EncryptionKeysTable).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "", "unable to read keys")
|
||||
}
|
||||
err = d.client.Query(func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var id, encryptionKey string
|
||||
err = rows.Scan(&id, &encryptionKey)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "", "unable to read keys")
|
||||
}
|
||||
key, err := d.decrypt(encryptionKey, d.masterKey)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "", "unable to decrypt key")
|
||||
}
|
||||
keys[id] = key
|
||||
}
|
||||
return nil
|
||||
}, stmt, args...)
|
||||
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "", "unable to read keys")
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (d *Database) ReadKey(id string) (_ *crypto.Key, err error) {
|
||||
stmt, args, err := sq.Select(encryptionKeysKeyCol).
|
||||
From(EncryptionKeysTable).
|
||||
Where(sq.Eq{encryptionKeysIDCol: id}).
|
||||
PlaceholderFormat(sq.Dollar).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "", "unable to read key")
|
||||
}
|
||||
var key string
|
||||
err = d.client.QueryRow(func(row *sql.Row) error {
|
||||
var encryptionKey string
|
||||
err = row.Scan(&encryptionKey)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "", "unable to read key")
|
||||
}
|
||||
key, err = d.decrypt(encryptionKey, d.masterKey)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "", "unable to decrypt key")
|
||||
}
|
||||
return nil
|
||||
}, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "", "unable to read key")
|
||||
}
|
||||
|
||||
return &crypto.Key{
|
||||
ID: id,
|
||||
Value: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Database) CreateKeys(ctx context.Context, keys ...*crypto.Key) error {
|
||||
insert := sq.Insert(EncryptionKeysTable).
|
||||
Columns(encryptionKeysIDCol, encryptionKeysKeyCol).PlaceholderFormat(sq.Dollar)
|
||||
for _, key := range keys {
|
||||
encryptionKey, err := d.encrypt(key.Value, d.masterKey)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "", "unable to encrypt key")
|
||||
}
|
||||
insert = insert.Values(key.ID, encryptionKey)
|
||||
}
|
||||
stmt, args, err := insert.ToSql()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "", "unable to insert new keys")
|
||||
}
|
||||
tx, err := d.client.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "", "unable to insert new keys")
|
||||
}
|
||||
_, err = tx.Exec(stmt, args...)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return zerrors.ThrowInternal(err, "", "unable to insert new keys")
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "", "unable to insert new keys")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkMasterKeyLength(masterKey string) error {
|
||||
if length := len([]byte(masterKey)); length != 32 {
|
||||
return zerrors.ThrowInternalf(nil, "", "masterkey must be 32 bytes, but is %d", length)
|
||||
}
|
||||
return nil
|
||||
}
|
543
apps/api/internal/crypto/database/database_test.go
Normal file
543
apps/api/internal/crypto/database/database_test.go
Normal file
@@ -0,0 +1,543 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
z_db "github.com/zitadel/zitadel/internal/database"
|
||||
db_mock "github.com/zitadel/zitadel/internal/database/mock"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func Test_database_ReadKeys(t *testing.T) {
|
||||
type fields struct {
|
||||
client db
|
||||
masterKey string
|
||||
decrypt func(encryptedKey, masterKey string) (key string, err error)
|
||||
}
|
||||
type res struct {
|
||||
keys crypto.Keys
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"query fails, error",
|
||||
fields{
|
||||
client: dbMock(t, expectQueryErr("SELECT id, key FROM system.encryption_keys", sql.ErrConnDone)),
|
||||
masterKey: "",
|
||||
decrypt: nil,
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, sql.ErrConnDone)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"decryption error",
|
||||
fields{
|
||||
client: dbMock(t, expectQueryScanErr(
|
||||
"SELECT id, key FROM system.encryption_keys",
|
||||
[]string{"id", "key"},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
})),
|
||||
masterKey: "wrong key",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return "", fmt.Errorf("wrong masterkey")
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: zerrors.IsInternal,
|
||||
},
|
||||
},
|
||||
{
|
||||
"single key ok",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
"SELECT id, key FROM system.encryption_keys",
|
||||
[]string{"id", "key"},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
})),
|
||||
masterKey: "masterKey",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return encryptedKey, nil
|
||||
},
|
||||
},
|
||||
res{
|
||||
keys: crypto.Keys(map[string]string{"id1": "key1"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
"multiple keys ok",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
"SELECT id, key FROM system.encryption_keys",
|
||||
[]string{"id", "key"},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
{
|
||||
"id2",
|
||||
"key2",
|
||||
},
|
||||
})),
|
||||
masterKey: "masterKey",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return encryptedKey, nil
|
||||
},
|
||||
},
|
||||
res{
|
||||
keys: crypto.Keys(map[string]string{"id1": "key1", "id2": "key2"}),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := &Database{
|
||||
client: tt.fields.client.db,
|
||||
masterKey: tt.fields.masterKey,
|
||||
decrypt: tt.fields.decrypt,
|
||||
}
|
||||
got, err := d.ReadKeys()
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
} else if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.keys, got)
|
||||
}
|
||||
if err := tt.fields.client.mock.ExpectationsWereMet(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_database_ReadKey(t *testing.T) {
|
||||
type fields struct {
|
||||
client db
|
||||
masterKey string
|
||||
decrypt func(encryptedKey, masterKey string) (key string, err error)
|
||||
}
|
||||
type args struct {
|
||||
id string
|
||||
}
|
||||
type res struct {
|
||||
key *crypto.Key
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"query fails, error",
|
||||
fields{
|
||||
client: dbMock(t, expectQueryErr("SELECT key FROM system.encryption_keys WHERE id = $1", sql.ErrConnDone)),
|
||||
masterKey: "",
|
||||
decrypt: nil,
|
||||
},
|
||||
args{
|
||||
id: "id1",
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, sql.ErrConnDone)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"key not found err",
|
||||
fields{
|
||||
client: dbMock(t, expectQueryScanErr(
|
||||
"SELECT key FROM system.encryption_keys WHERE id = $1",
|
||||
nil,
|
||||
nil,
|
||||
"id1")),
|
||||
masterKey: "masterKey",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return encryptedKey, nil
|
||||
},
|
||||
},
|
||||
args{
|
||||
id: "id1",
|
||||
},
|
||||
res{
|
||||
err: zerrors.IsInternal,
|
||||
},
|
||||
},
|
||||
{
|
||||
"decryption error",
|
||||
fields{
|
||||
client: dbMock(t, expectQueryScanErr(
|
||||
"SELECT key FROM system.encryption_keys WHERE id = $1",
|
||||
[]string{"key"},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"key1",
|
||||
},
|
||||
},
|
||||
"id1",
|
||||
)),
|
||||
masterKey: "wrong key",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return "", fmt.Errorf("wrong masterkey")
|
||||
},
|
||||
},
|
||||
args{
|
||||
id: "id1",
|
||||
},
|
||||
res{
|
||||
err: zerrors.IsInternal,
|
||||
},
|
||||
},
|
||||
{
|
||||
"key ok",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
"SELECT key FROM system.encryption_keys WHERE id = $1",
|
||||
[]string{"key"},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"key1",
|
||||
},
|
||||
},
|
||||
"id1",
|
||||
)),
|
||||
masterKey: "masterKey",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return encryptedKey, nil
|
||||
},
|
||||
},
|
||||
args{
|
||||
id: "id1",
|
||||
},
|
||||
res{
|
||||
key: &crypto.Key{
|
||||
ID: "id1",
|
||||
Value: "key1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := &Database{
|
||||
client: tt.fields.client.db,
|
||||
masterKey: tt.fields.masterKey,
|
||||
decrypt: tt.fields.decrypt,
|
||||
}
|
||||
got, err := d.ReadKey(tt.args.id)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
} else if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.key, got)
|
||||
}
|
||||
if err := tt.fields.client.mock.ExpectationsWereMet(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_database_CreateKeys(t *testing.T) {
|
||||
type fields struct {
|
||||
client db
|
||||
masterKey string
|
||||
encrypt func(key, masterKey string) (encryptedKey string, err error)
|
||||
}
|
||||
type args struct {
|
||||
keys []*crypto.Key
|
||||
}
|
||||
type res struct {
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"encryption fails, error",
|
||||
fields{
|
||||
client: dbMock(t),
|
||||
masterKey: "",
|
||||
encrypt: func(key, masterKey string) (encryptedKey string, err error) {
|
||||
return "", fmt.Errorf("encryption failed")
|
||||
},
|
||||
},
|
||||
args{
|
||||
keys: []*crypto.Key{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: zerrors.IsInternal,
|
||||
},
|
||||
},
|
||||
{
|
||||
"insert fails, error",
|
||||
fields{
|
||||
client: dbMock(t,
|
||||
expectBegin(nil),
|
||||
expectExec("INSERT INTO system.encryption_keys (id,key) VALUES ($1,$2)", sql.ErrTxDone),
|
||||
expectRollback(nil),
|
||||
),
|
||||
masterKey: "masterkey",
|
||||
encrypt: func(key, masterKey string) (encryptedKey string, err error) {
|
||||
return key, nil
|
||||
},
|
||||
},
|
||||
args{
|
||||
keys: []*crypto.Key{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, sql.ErrTxDone)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"single insert ok",
|
||||
fields{
|
||||
client: dbMock(t,
|
||||
expectBegin(nil),
|
||||
expectExec("INSERT INTO system.encryption_keys (id,key) VALUES ($1,$2)", nil, "id1", "key1"),
|
||||
expectCommit(nil),
|
||||
),
|
||||
masterKey: "masterkey",
|
||||
encrypt: func(key, masterKey string) (encryptedKey string, err error) {
|
||||
return key, nil
|
||||
},
|
||||
},
|
||||
args{
|
||||
keys: []*crypto.Key{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"multiple insert ok",
|
||||
fields{
|
||||
client: dbMock(t,
|
||||
expectBegin(nil),
|
||||
expectExec("INSERT INTO system.encryption_keys (id,key) VALUES ($1,$2)", nil, "id1", "key1", "id2", "key2"),
|
||||
expectCommit(nil),
|
||||
),
|
||||
masterKey: "masterkey",
|
||||
encrypt: func(key, masterKey string) (encryptedKey string, err error) {
|
||||
return key, nil
|
||||
},
|
||||
},
|
||||
args{
|
||||
keys: []*crypto.Key{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
{
|
||||
"id2",
|
||||
"key2",
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := &Database{
|
||||
client: tt.fields.client.db,
|
||||
masterKey: tt.fields.masterKey,
|
||||
encrypt: tt.fields.encrypt,
|
||||
}
|
||||
err := d.CreateKeys(context.Background(), tt.args.keys...)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
} else if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v", err)
|
||||
}
|
||||
if err := tt.fields.client.mock.ExpectationsWereMet(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_checkMasterKeyLength(t *testing.T) {
|
||||
type args struct {
|
||||
masterKey string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
err func(error) bool
|
||||
}{
|
||||
{
|
||||
"invalid length",
|
||||
args{
|
||||
masterKey: "",
|
||||
},
|
||||
zerrors.IsInternal,
|
||||
},
|
||||
{
|
||||
"valid length",
|
||||
args{
|
||||
masterKey: "!themasterkeywhichis32byteslong!",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checkMasterKeyLength(tt.args.masterKey)
|
||||
if tt.err == nil {
|
||||
assert.NoError(t, err)
|
||||
} else if tt.err != nil && !tt.err(err) {
|
||||
t.Errorf("got wrong err: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type db struct {
|
||||
mock sqlmock.Sqlmock
|
||||
db *z_db.DB
|
||||
}
|
||||
|
||||
func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db {
|
||||
t.Helper()
|
||||
client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create sql mock: %v", err)
|
||||
}
|
||||
for _, expectation := range expectations {
|
||||
expectation(mock)
|
||||
}
|
||||
return db{
|
||||
mock: mock,
|
||||
db: &z_db.DB{DB: client},
|
||||
}
|
||||
}
|
||||
|
||||
func expectQueryErr(query string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectQueryScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
|
||||
result := m.NewRows(cols)
|
||||
count := uint64(len(rows))
|
||||
for _, row := range rows {
|
||||
if cols[len(cols)-1] == "count" {
|
||||
row = append(row, count)
|
||||
}
|
||||
result.AddRow(row...)
|
||||
}
|
||||
q.WillReturnRows(result)
|
||||
q.RowsWillBeClosed()
|
||||
}
|
||||
}
|
||||
|
||||
func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
|
||||
result := m.NewRows(cols)
|
||||
count := uint64(len(rows))
|
||||
for _, row := range rows {
|
||||
if cols[len(cols)-1] == "count" {
|
||||
row = append(row, count)
|
||||
}
|
||||
result.AddRow(row...)
|
||||
}
|
||||
q.WillReturnRows(result)
|
||||
q.RowsWillBeClosed()
|
||||
}
|
||||
}
|
||||
|
||||
func expectExec(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectExec(regexp.QuoteMeta(stmt)).WithArgs(args...)
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
return
|
||||
}
|
||||
query.WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectBegin(err error) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectBegin()
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func expectCommit(err error) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectCommit()
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func expectRollback(err error) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectRollback()
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
116
apps/api/internal/crypto/ellipticcurve_enumer.go
Normal file
116
apps/api/internal/crypto/ellipticcurve_enumer.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Code generated by "enumer -type EllipticCurve -trimprefix EllipticCurve -text -json -linecomment"; DO NOT EDIT.
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _EllipticCurveName = "P256P384P512"
|
||||
|
||||
var _EllipticCurveIndex = [...]uint8{0, 0, 4, 8, 12}
|
||||
|
||||
const _EllipticCurveLowerName = "p256p384p512"
|
||||
|
||||
func (i EllipticCurve) String() string {
|
||||
if i < 0 || i >= EllipticCurve(len(_EllipticCurveIndex)-1) {
|
||||
return fmt.Sprintf("EllipticCurve(%d)", i)
|
||||
}
|
||||
return _EllipticCurveName[_EllipticCurveIndex[i]:_EllipticCurveIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _EllipticCurveNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[EllipticCurveUnspecified-(0)]
|
||||
_ = x[EllipticCurveP256-(1)]
|
||||
_ = x[EllipticCurveP384-(2)]
|
||||
_ = x[EllipticCurveP512-(3)]
|
||||
}
|
||||
|
||||
var _EllipticCurveValues = []EllipticCurve{EllipticCurveUnspecified, EllipticCurveP256, EllipticCurveP384, EllipticCurveP512}
|
||||
|
||||
var _EllipticCurveNameToValueMap = map[string]EllipticCurve{
|
||||
_EllipticCurveName[0:0]: EllipticCurveUnspecified,
|
||||
_EllipticCurveLowerName[0:0]: EllipticCurveUnspecified,
|
||||
_EllipticCurveName[0:4]: EllipticCurveP256,
|
||||
_EllipticCurveLowerName[0:4]: EllipticCurveP256,
|
||||
_EllipticCurveName[4:8]: EllipticCurveP384,
|
||||
_EllipticCurveLowerName[4:8]: EllipticCurveP384,
|
||||
_EllipticCurveName[8:12]: EllipticCurveP512,
|
||||
_EllipticCurveLowerName[8:12]: EllipticCurveP512,
|
||||
}
|
||||
|
||||
var _EllipticCurveNames = []string{
|
||||
_EllipticCurveName[0:0],
|
||||
_EllipticCurveName[0:4],
|
||||
_EllipticCurveName[4:8],
|
||||
_EllipticCurveName[8:12],
|
||||
}
|
||||
|
||||
// EllipticCurveString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func EllipticCurveString(s string) (EllipticCurve, error) {
|
||||
if val, ok := _EllipticCurveNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _EllipticCurveNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to EllipticCurve values", s)
|
||||
}
|
||||
|
||||
// EllipticCurveValues returns all values of the enum
|
||||
func EllipticCurveValues() []EllipticCurve {
|
||||
return _EllipticCurveValues
|
||||
}
|
||||
|
||||
// EllipticCurveStrings returns a slice of all String values of the enum
|
||||
func EllipticCurveStrings() []string {
|
||||
strs := make([]string, len(_EllipticCurveNames))
|
||||
copy(strs, _EllipticCurveNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAEllipticCurve returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i EllipticCurve) IsAEllipticCurve() bool {
|
||||
for _, v := range _EllipticCurveValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for EllipticCurve
|
||||
func (i EllipticCurve) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(i.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for EllipticCurve
|
||||
func (i *EllipticCurve) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return fmt.Errorf("EllipticCurve should be a string, got %s", data)
|
||||
}
|
||||
|
||||
var err error
|
||||
*i, err = EllipticCurveString(s)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface for EllipticCurve
|
||||
func (i EllipticCurve) MarshalText() ([]byte, error) {
|
||||
return []byte(i.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface for EllipticCurve
|
||||
func (i *EllipticCurve) UnmarshalText(text []byte) error {
|
||||
var err error
|
||||
*i, err = EllipticCurveString(string(text))
|
||||
return err
|
||||
}
|
44
apps/api/internal/crypto/file/file.go
Normal file
44
apps/api/internal/crypto/file/file.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/config"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
ZitadelKeyPath = "ZITADEL_KEY_PATH"
|
||||
)
|
||||
|
||||
type Storage struct{}
|
||||
|
||||
func (d *Storage) ReadKeys() (crypto.Keys, error) {
|
||||
path := os.Getenv(ZitadelKeyPath)
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("no path set, %s is empty", ZitadelKeyPath)
|
||||
}
|
||||
keys := new(crypto.Keys)
|
||||
err := config.Read(keys, path)
|
||||
return *keys, err
|
||||
}
|
||||
|
||||
func (d *Storage) ReadKey(id string) (*crypto.Key, error) {
|
||||
keys, err := d.ReadKeys()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, ok := keys[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("key no found")
|
||||
}
|
||||
return &crypto.Key{
|
||||
ID: id,
|
||||
Value: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Storage) CreateKeys(keys ...*crypto.Key) error {
|
||||
return fmt.Errorf("this provider is not able to store new keys")
|
||||
}
|
4
apps/api/internal/crypto/generate.go
Normal file
4
apps/api/internal/crypto/generate.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package crypto
|
||||
|
||||
//go:generate mockgen -source crypto.go -destination ./crypto_mock.go -package crypto
|
||||
//go:generate mockgen -source code.go -destination ./code_mock.go -package crypto
|
70
apps/api/internal/crypto/key.go
Normal file
70
apps/api/internal/crypto/key.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type KeyConfig struct {
|
||||
EncryptionKeyID string
|
||||
DecryptionKeyIDs []string
|
||||
}
|
||||
|
||||
type Keys map[string]string
|
||||
|
||||
type Key struct {
|
||||
ID string
|
||||
Value string
|
||||
}
|
||||
|
||||
func NewKey(id string) (*Key, error) {
|
||||
randBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(randBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Key{
|
||||
ID: id,
|
||||
Value: string(randBytes),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func LoadKey(id string, keyStorage KeyStorage) (string, error) {
|
||||
key, err := keyStorage.ReadKey(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return key.Value, nil
|
||||
}
|
||||
|
||||
func LoadKeys(config *KeyConfig, keyStorage KeyStorage) (Keys, []string, error) {
|
||||
if config == nil {
|
||||
return nil, nil, zerrors.ThrowInvalidArgument(nil, "CRYPT-dJK8s", "config must not be nil")
|
||||
}
|
||||
readKeys, err := keyStorage.ReadKeys()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
keys := make(Keys)
|
||||
ids := make([]string, 0, len(config.DecryptionKeyIDs)+1)
|
||||
if config.EncryptionKeyID != "" {
|
||||
key, ok := readKeys[config.EncryptionKeyID]
|
||||
if !ok {
|
||||
return nil, nil, zerrors.ThrowInternalf(nil, "CRYPT-v2Kas", "encryption key %s not found", config.EncryptionKeyID)
|
||||
}
|
||||
keys[config.EncryptionKeyID] = key
|
||||
ids = append(ids, config.EncryptionKeyID)
|
||||
}
|
||||
for _, id := range config.DecryptionKeyIDs {
|
||||
key, ok := readKeys[id]
|
||||
if !ok {
|
||||
logging.Errorf("description key %s not found", id)
|
||||
continue
|
||||
}
|
||||
keys[id] = key
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return keys, ids, nil
|
||||
}
|
9
apps/api/internal/crypto/key_storage.go
Normal file
9
apps/api/internal/crypto/key_storage.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package crypto
|
||||
|
||||
import "context"
|
||||
|
||||
type KeyStorage interface {
|
||||
ReadKeys() (Keys, error)
|
||||
ReadKey(id string) (*Key, error)
|
||||
CreateKeys(context.Context, ...*Key) error
|
||||
}
|
353
apps/api/internal/crypto/passwap.go
Normal file
353
apps/api/internal/crypto/passwap.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/passwap"
|
||||
"github.com/zitadel/passwap/argon2"
|
||||
"github.com/zitadel/passwap/bcrypt"
|
||||
"github.com/zitadel/passwap/md5"
|
||||
"github.com/zitadel/passwap/md5plain"
|
||||
"github.com/zitadel/passwap/md5salted"
|
||||
"github.com/zitadel/passwap/pbkdf2"
|
||||
"github.com/zitadel/passwap/phpass"
|
||||
"github.com/zitadel/passwap/scrypt"
|
||||
"github.com/zitadel/passwap/sha2"
|
||||
"github.com/zitadel/passwap/verifier"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type Hasher struct {
|
||||
*passwap.Swapper
|
||||
Prefixes []string
|
||||
HexSupported bool
|
||||
}
|
||||
|
||||
func (h *Hasher) EncodingSupported(encodedHash string) bool {
|
||||
for _, prefix := range h.Prefixes {
|
||||
if strings.HasPrefix(encodedHash, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if h.HexSupported {
|
||||
_, err := hex.DecodeString(encodedHash)
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type HashName string
|
||||
|
||||
const (
|
||||
HashNameArgon2 HashName = "argon2" // used for the common argon2 verifier
|
||||
HashNameArgon2i HashName = "argon2i" // hash only
|
||||
HashNameArgon2id HashName = "argon2id" // hash only
|
||||
HashNameBcrypt HashName = "bcrypt" // hash and verify
|
||||
HashNameMd5 HashName = "md5" // verify only, as hashing with md5 is insecure and deprecated
|
||||
HashNameMd5Plain HashName = "md5plain" // verify only, as hashing with md5 is insecure and deprecated
|
||||
HashNameMd5Salted HashName = "md5salted" // verify only, as hashing with md5 is insecure and deprecated
|
||||
HashNamePHPass HashName = "phpass" // verify only, as hashing with md5 is insecure and deprecated
|
||||
HashNameSha2 HashName = "sha2" // hash and verify
|
||||
HashNameScrypt HashName = "scrypt" // hash and verify
|
||||
HashNamePBKDF2 HashName = "pbkdf2" // hash and verify
|
||||
)
|
||||
|
||||
type HashMode string
|
||||
|
||||
// HashMode defines a underlying [hash.Hash] implementation
|
||||
// for algorithms like pbkdf2
|
||||
const (
|
||||
HashModeSHA1 HashMode = "sha1"
|
||||
HashModeSHA224 HashMode = "sha224"
|
||||
HashModeSHA256 HashMode = "sha256"
|
||||
HashModeSHA384 HashMode = "sha384"
|
||||
HashModeSHA512 HashMode = "sha512"
|
||||
)
|
||||
|
||||
type HashConfig struct {
|
||||
Verifiers []HashName
|
||||
Hasher HasherConfig
|
||||
}
|
||||
|
||||
func (c *HashConfig) NewHasher() (*Hasher, error) {
|
||||
verifiers, vPrefixes, err := c.buildVerifiers()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "CRYPT-sahW9", "password hash config invalid")
|
||||
}
|
||||
hasher, hPrefixes, err := c.Hasher.buildHasher()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "CRYPT-Que4r", "password hash config invalid")
|
||||
}
|
||||
return &Hasher{
|
||||
Swapper: passwap.NewSwapper(hasher, verifiers...),
|
||||
Prefixes: append(hPrefixes, vPrefixes...),
|
||||
HexSupported: slices.Contains(c.Verifiers, HashNameMd5Plain),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type prefixVerifier struct {
|
||||
prefixes []string
|
||||
verifier verifier.Verifier
|
||||
}
|
||||
|
||||
// map HashNames to Verifier instances.
|
||||
var knowVerifiers = map[HashName]prefixVerifier{
|
||||
HashNameArgon2: {
|
||||
// only argon2i and argon2id are suppored.
|
||||
// The Prefix constant also covers argon2d.
|
||||
prefixes: []string{argon2.Prefix},
|
||||
verifier: argon2.Verifier,
|
||||
},
|
||||
HashNameBcrypt: {
|
||||
prefixes: []string{bcrypt.Prefix},
|
||||
verifier: bcrypt.Verifier,
|
||||
},
|
||||
HashNameMd5: {
|
||||
prefixes: []string{md5.Prefix},
|
||||
verifier: md5.Verifier,
|
||||
},
|
||||
HashNameMd5Plain: {
|
||||
prefixes: nil, // hex encoded without identifier or prefix
|
||||
verifier: md5plain.Verifier,
|
||||
},
|
||||
HashNameScrypt: {
|
||||
prefixes: []string{scrypt.Prefix, scrypt.Prefix_Linux},
|
||||
verifier: scrypt.Verifier,
|
||||
},
|
||||
HashNamePBKDF2: {
|
||||
prefixes: []string{pbkdf2.Prefix},
|
||||
verifier: pbkdf2.Verifier,
|
||||
},
|
||||
HashNameMd5Salted: {
|
||||
prefixes: []string{md5salted.Prefix},
|
||||
verifier: md5salted.Verifier,
|
||||
},
|
||||
HashNameSha2: {
|
||||
prefixes: []string{sha2.Sha256Identifier, sha2.Sha512Identifier},
|
||||
verifier: sha2.Verifier,
|
||||
},
|
||||
HashNamePHPass: {
|
||||
prefixes: []string{phpass.IdentifierP, phpass.IdentifierH},
|
||||
verifier: phpass.Verifier,
|
||||
},
|
||||
}
|
||||
|
||||
func (c *HashConfig) buildVerifiers() (verifiers []verifier.Verifier, prefixes []string, err error) {
|
||||
verifiers = make([]verifier.Verifier, len(c.Verifiers))
|
||||
prefixes = make([]string, 0, len(c.Verifiers)+1)
|
||||
for i, name := range c.Verifiers {
|
||||
v, ok := knowVerifiers[name]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("invalid verifier %q", name)
|
||||
}
|
||||
verifiers[i] = v.verifier
|
||||
prefixes = append(prefixes, v.prefixes...)
|
||||
}
|
||||
return verifiers, prefixes, nil
|
||||
}
|
||||
|
||||
type HasherConfig struct {
|
||||
Algorithm HashName
|
||||
Params map[string]any `mapstructure:",remain"`
|
||||
}
|
||||
|
||||
func (c *HasherConfig) buildHasher() (hasher passwap.Hasher, prefixes []string, err error) {
|
||||
switch c.Algorithm {
|
||||
case HashNameArgon2i:
|
||||
return c.argon2i()
|
||||
case HashNameArgon2id:
|
||||
return c.argon2id()
|
||||
case HashNameBcrypt:
|
||||
return c.bcrypt()
|
||||
case HashNameScrypt:
|
||||
return c.scrypt()
|
||||
case HashNamePBKDF2:
|
||||
return c.pbkdf2()
|
||||
case HashNameSha2:
|
||||
return c.sha2()
|
||||
case "":
|
||||
return nil, nil, fmt.Errorf("missing hasher algorithm")
|
||||
case HashNameArgon2, HashNameMd5, HashNameMd5Plain, HashNameMd5Salted, HashNamePHPass:
|
||||
fallthrough
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("invalid algorithm %q", c.Algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
// decodeParams uses a mapstructure decoder from the Params map to dst.
|
||||
// The decoder fails when there are unused fields in dst.
|
||||
// It uses weak input typing, to allow conversion of env strings to ints.
|
||||
func (c *HasherConfig) decodeParams(dst any) error {
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
ErrorUnused: false,
|
||||
ErrorUnset: true,
|
||||
WeaklyTypedInput: true,
|
||||
Result: dst,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return decoder.Decode(c.Params)
|
||||
}
|
||||
|
||||
// argon2Params decodes [HasherConfig.Params] into a [argon2.Params] used as defaults.
|
||||
// p is passed a copy and therfore will not be modified.
|
||||
func (c *HasherConfig) argon2Params(p argon2.Params) (argon2.Params, error) {
|
||||
var dst struct {
|
||||
Time uint32 `mapstructure:"Time"`
|
||||
Memory uint32 `mapstructure:"Memory"`
|
||||
Threads uint8 `mapstructure:"Threads"`
|
||||
}
|
||||
if err := c.decodeParams(&dst); err != nil {
|
||||
return argon2.Params{}, fmt.Errorf("decode argon2i params: %w", err)
|
||||
}
|
||||
p.Time = dst.Time
|
||||
p.Memory = dst.Memory
|
||||
p.Threads = dst.Threads
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (c *HasherConfig) argon2i() (passwap.Hasher, []string, error) {
|
||||
p, err := c.argon2Params(argon2.RecommendedIParams)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return argon2.NewArgon2i(p), []string{argon2.Prefix}, nil
|
||||
}
|
||||
|
||||
func (c *HasherConfig) argon2id() (passwap.Hasher, []string, error) {
|
||||
p, err := c.argon2Params(argon2.RecommendedIDParams)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return argon2.NewArgon2id(p), []string{argon2.Prefix}, nil
|
||||
}
|
||||
|
||||
func (c *HasherConfig) bcryptCost() (int, error) {
|
||||
var dst = struct {
|
||||
Cost int `mapstructure:"Cost"`
|
||||
}{}
|
||||
if err := c.decodeParams(&dst); err != nil {
|
||||
return 0, fmt.Errorf("decode bcrypt params: %w", err)
|
||||
}
|
||||
return dst.Cost, nil
|
||||
}
|
||||
|
||||
func (c *HasherConfig) bcrypt() (passwap.Hasher, []string, error) {
|
||||
cost, err := c.bcryptCost()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return bcrypt.New(cost), []string{bcrypt.Prefix}, nil
|
||||
}
|
||||
|
||||
func (c *HasherConfig) scryptParams() (scrypt.Params, error) {
|
||||
var dst = struct {
|
||||
Cost int `mapstructure:"Cost"`
|
||||
}{}
|
||||
if err := c.decodeParams(&dst); err != nil {
|
||||
return scrypt.Params{}, fmt.Errorf("decode scrypt params: %w", err)
|
||||
}
|
||||
p := scrypt.RecommendedParams // copy
|
||||
p.N = 1 << dst.Cost
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (c *HasherConfig) scrypt() (passwap.Hasher, []string, error) {
|
||||
p, err := c.scryptParams()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return scrypt.New(p), []string{scrypt.Prefix, scrypt.Prefix_Linux}, nil
|
||||
}
|
||||
|
||||
func (c *HasherConfig) pbkdf2Params() (p pbkdf2.Params, _ HashMode, _ error) {
|
||||
var dst = struct {
|
||||
Rounds uint32 `mapstructure:"Rounds"`
|
||||
Hash HashMode `mapstructure:"Hash"`
|
||||
}{}
|
||||
if err := c.decodeParams(&dst); err != nil {
|
||||
return p, "", fmt.Errorf("decode pbkdf2 params: %w", err)
|
||||
}
|
||||
switch dst.Hash {
|
||||
case HashModeSHA1:
|
||||
p = pbkdf2.RecommendedSHA1Params
|
||||
case HashModeSHA224:
|
||||
p = pbkdf2.RecommendedSHA224Params
|
||||
case HashModeSHA256:
|
||||
p = pbkdf2.RecommendedSHA256Params
|
||||
case HashModeSHA384:
|
||||
p = pbkdf2.RecommendedSHA384Params
|
||||
case HashModeSHA512:
|
||||
p = pbkdf2.RecommendedSHA512Params
|
||||
}
|
||||
p.Rounds = dst.Rounds
|
||||
return p, dst.Hash, nil
|
||||
}
|
||||
|
||||
func (c *HasherConfig) pbkdf2() (passwap.Hasher, []string, error) {
|
||||
p, hash, err := c.pbkdf2Params()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
prefix := []string{pbkdf2.Prefix}
|
||||
switch hash {
|
||||
case HashModeSHA1:
|
||||
return pbkdf2.NewSHA1(p), prefix, nil
|
||||
case HashModeSHA224:
|
||||
return pbkdf2.NewSHA224(p), prefix, nil
|
||||
case HashModeSHA256:
|
||||
return pbkdf2.NewSHA256(p), prefix, nil
|
||||
case HashModeSHA384:
|
||||
return pbkdf2.NewSHA384(p), prefix, nil
|
||||
case HashModeSHA512:
|
||||
return pbkdf2.NewSHA512(p), prefix, nil
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unsupported pbkdf2 hash mode: %s", hash)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HasherConfig) sha2Params() (use512 bool, rounds int, err error) {
|
||||
var dst = struct {
|
||||
Rounds uint32 `mapstructure:"Rounds"`
|
||||
Hash HashMode `mapstructure:"Hash"`
|
||||
}{}
|
||||
if err := c.decodeParams(&dst); err != nil {
|
||||
return false, 0, fmt.Errorf("decode sha2 params: %w", err)
|
||||
}
|
||||
switch dst.Hash {
|
||||
case HashModeSHA256:
|
||||
use512 = false
|
||||
case HashModeSHA512:
|
||||
use512 = true
|
||||
case HashModeSHA1, HashModeSHA224, HashModeSHA384:
|
||||
fallthrough
|
||||
default:
|
||||
return false, 0, fmt.Errorf("cannot use %s with sha2", dst.Hash)
|
||||
}
|
||||
if dst.Rounds > sha2.RoundsMax {
|
||||
return false, 0, fmt.Errorf("rounds with sha2 cannot be larger than %d", sha2.RoundsMax)
|
||||
} else {
|
||||
rounds = int(dst.Rounds)
|
||||
}
|
||||
return use512, rounds, nil
|
||||
}
|
||||
|
||||
func (c *HasherConfig) sha2() (passwap.Hasher, []string, error) {
|
||||
use512, rounds, err := c.sha2Params()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if use512 {
|
||||
return sha2.New512(rounds), []string{sha2.Sha256Identifier, sha2.Sha512Identifier}, nil
|
||||
} else {
|
||||
return sha2.New256(rounds), []string{sha2.Sha256Identifier, sha2.Sha512Identifier}, nil
|
||||
}
|
||||
}
|
880
apps/api/internal/crypto/passwap_test.go
Normal file
880
apps/api/internal/crypto/passwap_test.go
Normal file
@@ -0,0 +1,880 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/passwap/argon2"
|
||||
"github.com/zitadel/passwap/bcrypt"
|
||||
"github.com/zitadel/passwap/md5"
|
||||
"github.com/zitadel/passwap/md5salted"
|
||||
"github.com/zitadel/passwap/pbkdf2"
|
||||
"github.com/zitadel/passwap/scrypt"
|
||||
"github.com/zitadel/passwap/sha2"
|
||||
)
|
||||
|
||||
func TestPasswordHasher_EncodingSupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
encodedHash string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty string, false",
|
||||
encodedHash: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "scrypt, false",
|
||||
encodedHash: "$scrypt$ln=16,r=8,p=1$cmFuZG9tc2FsdGlzaGFyZA$Rh+NnJNo1I6nRwaNqbDm6kmADswD1+7FTKZ7Ln9D8nQ",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "bcrypt, true",
|
||||
encodedHash: "$2y$12$hXUrnqdq1RIIYZ2HPytIIe5lXdIvbhqrTvdPsSF7o.jFh817Z6lwm",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "argo2i, true",
|
||||
encodedHash: "$argon2i$v=19$m=4096,t=3,p=1$cmFuZG9tc2FsdGlzaGFyZA$YMvo8AUoNtnKYGqeODruCjHdiEbl1pKL2MsYy9VgU/E",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "argo2id, true",
|
||||
encodedHash: "$argon2d$v=19$m=4096,t=3,p=1$cmFuZG9tc2FsdGlzaGFyZA$CB0Du96aj3fQVcVSqb0LIA6Z6fpStjzjVkaC3RlpK9A",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &Hasher{
|
||||
Prefixes: []string{bcrypt.Prefix, argon2.Prefix},
|
||||
}
|
||||
got := h.EncodingSupported(tt.encodedHash)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHashConfig_PasswordHasher(t *testing.T) {
|
||||
type fields struct {
|
||||
Verifiers []HashName
|
||||
Hasher HasherConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantPrefixes []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid verifier",
|
||||
fields: fields{
|
||||
Verifiers: []HashName{
|
||||
HashNameArgon2,
|
||||
HashNameBcrypt,
|
||||
HashNameMd5,
|
||||
HashNameMd5Salted,
|
||||
HashNamePHPass,
|
||||
HashNameScrypt,
|
||||
HashNameSha2,
|
||||
"foobar",
|
||||
},
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameBcrypt,
|
||||
Params: map[string]any{
|
||||
"cost": 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid hasher",
|
||||
fields: fields{
|
||||
Verifiers: []HashName{
|
||||
HashNameArgon2,
|
||||
HashNameBcrypt,
|
||||
HashNameMd5,
|
||||
HashNameScrypt,
|
||||
},
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: "foobar",
|
||||
Params: map[string]any{
|
||||
"cost": 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing algorithm",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid md5",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameMd5,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid md5plain",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameMd5Plain,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid md5salted",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameMd5Salted,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid phpass",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNamePHPass,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid argon2",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameArgon2,
|
||||
Params: map[string]any{
|
||||
"time": 3,
|
||||
"memory": 32768,
|
||||
"threads": 4,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "argon2i, error",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameArgon2i,
|
||||
Params: map[string]any{
|
||||
"time": 3,
|
||||
"threads": 4,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "argon2i, ok",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameArgon2i,
|
||||
Params: map[string]any{
|
||||
"time": 3,
|
||||
"memory": 32768,
|
||||
"threads": 4,
|
||||
},
|
||||
},
|
||||
Verifiers: []HashName{HashNameBcrypt, HashNameMd5, HashNameScrypt, HashNameMd5Plain, HashNameMd5Salted},
|
||||
},
|
||||
wantPrefixes: []string{argon2.Prefix, bcrypt.Prefix, md5.Prefix, scrypt.Prefix, scrypt.Prefix_Linux, md5salted.Prefix},
|
||||
},
|
||||
{
|
||||
name: "argon2id, error",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameArgon2id,
|
||||
Params: map[string]any{
|
||||
"time": 3,
|
||||
"threads": 4,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "argon2id, ok",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameArgon2id,
|
||||
Params: map[string]any{
|
||||
"time": 3,
|
||||
"memory": 32768,
|
||||
"threads": 4,
|
||||
},
|
||||
},
|
||||
Verifiers: []HashName{HashNameBcrypt, HashNameMd5, HashNameScrypt, HashNameMd5Plain, HashNameMd5Salted},
|
||||
},
|
||||
wantPrefixes: []string{argon2.Prefix, bcrypt.Prefix, md5.Prefix, scrypt.Prefix, scrypt.Prefix_Linux, md5salted.Prefix},
|
||||
},
|
||||
{
|
||||
name: "bcrypt, error",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameBcrypt,
|
||||
Params: map[string]any{
|
||||
"foo": 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "bcrypt, ok",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameBcrypt,
|
||||
Params: map[string]any{
|
||||
"cost": 3,
|
||||
},
|
||||
},
|
||||
Verifiers: []HashName{HashNameArgon2, HashNameMd5, HashNameScrypt, HashNameMd5Plain, HashNameMd5Salted},
|
||||
},
|
||||
wantPrefixes: []string{bcrypt.Prefix, argon2.Prefix, md5.Prefix, scrypt.Prefix, scrypt.Prefix_Linux, md5salted.Prefix},
|
||||
},
|
||||
{
|
||||
name: "scrypt, error",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameScrypt,
|
||||
Params: map[string]any{
|
||||
"cost": "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "scrypt, ok",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameScrypt,
|
||||
Params: map[string]any{
|
||||
"cost": 3,
|
||||
},
|
||||
},
|
||||
Verifiers: []HashName{HashNameArgon2, HashNameBcrypt, HashNameMd5, HashNameMd5Plain, HashNameMd5Salted},
|
||||
},
|
||||
wantPrefixes: []string{scrypt.Prefix, scrypt.Prefix_Linux, argon2.Prefix, bcrypt.Prefix, md5.Prefix, md5salted.Prefix},
|
||||
},
|
||||
{
|
||||
name: "pbkdf2, parse error",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNamePBKDF2,
|
||||
Params: map[string]any{
|
||||
"cost": "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "pbkdf2, hash mode error",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNamePBKDF2,
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": "foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "pbkdf2, sha1",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNamePBKDF2,
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": HashModeSHA1,
|
||||
},
|
||||
},
|
||||
Verifiers: []HashName{HashNameArgon2, HashNameBcrypt, HashNameMd5, HashNameMd5Plain, HashNameMd5Salted},
|
||||
},
|
||||
wantPrefixes: []string{pbkdf2.Prefix, argon2.Prefix, bcrypt.Prefix, md5.Prefix, md5salted.Prefix},
|
||||
},
|
||||
{
|
||||
name: "pbkdf2, sha224",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNamePBKDF2,
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": HashModeSHA224,
|
||||
},
|
||||
},
|
||||
Verifiers: []HashName{HashNameArgon2, HashNameBcrypt, HashNameMd5, HashNameMd5Plain, HashNameMd5Salted},
|
||||
},
|
||||
wantPrefixes: []string{pbkdf2.Prefix, argon2.Prefix, bcrypt.Prefix, md5.Prefix, md5salted.Prefix},
|
||||
},
|
||||
{
|
||||
name: "pbkdf2, sha256",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNamePBKDF2,
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": HashModeSHA256,
|
||||
},
|
||||
},
|
||||
Verifiers: []HashName{HashNameArgon2, HashNameBcrypt, HashNameMd5, HashNameMd5Plain, HashNameMd5Salted},
|
||||
},
|
||||
wantPrefixes: []string{pbkdf2.Prefix, argon2.Prefix, bcrypt.Prefix, md5.Prefix, md5salted.Prefix},
|
||||
},
|
||||
{
|
||||
name: "pbkdf2, sha384",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNamePBKDF2,
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": HashModeSHA384,
|
||||
},
|
||||
},
|
||||
Verifiers: []HashName{HashNameArgon2, HashNameBcrypt, HashNameMd5, HashNameMd5Plain, HashNameMd5Salted},
|
||||
},
|
||||
wantPrefixes: []string{pbkdf2.Prefix, argon2.Prefix, bcrypt.Prefix, md5.Prefix, md5salted.Prefix},
|
||||
},
|
||||
{
|
||||
name: "pbkdf2, sha512",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNamePBKDF2,
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": HashModeSHA512,
|
||||
},
|
||||
},
|
||||
Verifiers: []HashName{HashNameArgon2, HashNameBcrypt, HashNameMd5, HashNameMd5Plain, HashNameMd5Salted},
|
||||
},
|
||||
wantPrefixes: []string{pbkdf2.Prefix, argon2.Prefix, bcrypt.Prefix, md5.Prefix, md5salted.Prefix},
|
||||
},
|
||||
{
|
||||
name: "sha2, parse error",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameSha2,
|
||||
Params: map[string]any{
|
||||
"cost": "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "pbkdf2, hash mode error",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameSha2,
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": "foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sha2, sha256",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameSha2,
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": HashModeSHA256,
|
||||
},
|
||||
},
|
||||
Verifiers: []HashName{HashNameArgon2, HashNameBcrypt, HashNameMd5, HashNameMd5Plain},
|
||||
},
|
||||
wantPrefixes: []string{sha2.Sha256Identifier, sha2.Sha512Identifier, argon2.Prefix, bcrypt.Prefix, md5.Prefix},
|
||||
},
|
||||
{
|
||||
name: "sha2, sha512",
|
||||
fields: fields{
|
||||
Hasher: HasherConfig{
|
||||
Algorithm: HashNameSha2,
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": HashModeSHA512,
|
||||
},
|
||||
},
|
||||
Verifiers: []HashName{HashNameArgon2, HashNameBcrypt, HashNameMd5, HashNameMd5Plain},
|
||||
},
|
||||
wantPrefixes: []string{sha2.Sha256Identifier, sha2.Sha512Identifier, argon2.Prefix, bcrypt.Prefix, md5.Prefix},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &HashConfig{
|
||||
Verifiers: tt.fields.Verifiers,
|
||||
Hasher: tt.fields.Hasher,
|
||||
}
|
||||
got, err := c.NewHasher()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
if tt.wantPrefixes != nil {
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, tt.wantPrefixes, got.Prefixes)
|
||||
encoded, err := got.Hash("password")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, encoded)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasherConfig_decodeParams(t *testing.T) {
|
||||
type dst struct {
|
||||
A int
|
||||
B uint32
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
params map[string]any
|
||||
want dst
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "unused",
|
||||
params: map[string]any{
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
"c": 3,
|
||||
},
|
||||
want: dst{
|
||||
A: 1,
|
||||
B: 2,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unset",
|
||||
params: map[string]any{
|
||||
"a": 1,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong type",
|
||||
params: map[string]any{
|
||||
"a": 1,
|
||||
"b": "2",
|
||||
},
|
||||
want: dst{
|
||||
A: 1,
|
||||
B: 2,
|
||||
},
|
||||
wantErr: false, // https://github.com/zitadel/zitadel/issues/6913
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
params: map[string]any{
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
},
|
||||
want: dst{
|
||||
A: 1,
|
||||
B: 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &HasherConfig{
|
||||
Params: tt.params,
|
||||
}
|
||||
var got dst
|
||||
err := c.decodeParams(&got)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasherConfig_argon2Params(t *testing.T) {
|
||||
type fields struct {
|
||||
Params map[string]any
|
||||
}
|
||||
type args struct {
|
||||
p argon2.Params
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want argon2.Params
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "decode error",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
p: argon2.RecommendedIDParams,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"time": 2,
|
||||
"memory": 256,
|
||||
"threads": 8,
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
p: argon2.RecommendedIDParams,
|
||||
},
|
||||
want: argon2.Params{
|
||||
Time: 2,
|
||||
Memory: 256,
|
||||
Threads: 8,
|
||||
KeyLen: 32,
|
||||
SaltLen: 16,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &HasherConfig{
|
||||
Params: tt.fields.Params,
|
||||
}
|
||||
got, err := c.argon2Params(tt.args.p)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasherConfig_bcryptCost(t *testing.T) {
|
||||
type fields struct {
|
||||
Params map[string]any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "decode error",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"cost": 12,
|
||||
},
|
||||
},
|
||||
want: 12,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &HasherConfig{
|
||||
Params: tt.fields.Params,
|
||||
}
|
||||
got, err := c.bcryptCost()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasherConfig_scryptParams(t *testing.T) {
|
||||
type fields struct {
|
||||
Params map[string]any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want scrypt.Params
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "decode error",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"cost": 2,
|
||||
},
|
||||
},
|
||||
want: scrypt.Params{
|
||||
N: 4,
|
||||
R: 8,
|
||||
P: 1,
|
||||
KeyLen: 32,
|
||||
SaltLen: 16,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &HasherConfig{
|
||||
Params: tt.fields.Params,
|
||||
}
|
||||
got, err := c.scryptParams()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasherConfig_pbkdf2Params(t *testing.T) {
|
||||
type fields struct {
|
||||
Params map[string]any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantP pbkdf2.Params
|
||||
wantHash HashMode
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "decode error",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sha1",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": "sha1",
|
||||
},
|
||||
},
|
||||
wantP: pbkdf2.Params{
|
||||
Rounds: 12,
|
||||
KeyLen: sha1.Size,
|
||||
SaltLen: 16,
|
||||
},
|
||||
wantHash: HashModeSHA1,
|
||||
},
|
||||
{
|
||||
name: "sha224",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": "sha224",
|
||||
},
|
||||
},
|
||||
wantP: pbkdf2.Params{
|
||||
Rounds: 12,
|
||||
KeyLen: sha256.Size224,
|
||||
SaltLen: 16,
|
||||
},
|
||||
wantHash: HashModeSHA224,
|
||||
},
|
||||
{
|
||||
name: "sha256",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": "sha256",
|
||||
},
|
||||
},
|
||||
wantP: pbkdf2.Params{
|
||||
Rounds: 12,
|
||||
KeyLen: sha256.Size,
|
||||
SaltLen: 16,
|
||||
},
|
||||
wantHash: HashModeSHA256,
|
||||
},
|
||||
{
|
||||
name: "sha384",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": "sha384",
|
||||
},
|
||||
},
|
||||
wantP: pbkdf2.Params{
|
||||
Rounds: 12,
|
||||
KeyLen: sha512.Size384,
|
||||
SaltLen: 16,
|
||||
},
|
||||
wantHash: HashModeSHA384,
|
||||
},
|
||||
{
|
||||
name: "sha512",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": "sha512",
|
||||
},
|
||||
},
|
||||
wantP: pbkdf2.Params{
|
||||
Rounds: 12,
|
||||
KeyLen: sha512.Size,
|
||||
SaltLen: 16,
|
||||
},
|
||||
wantHash: HashModeSHA512,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &HasherConfig{
|
||||
Params: tt.fields.Params,
|
||||
}
|
||||
gotP, gotHash, err := c.pbkdf2Params()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantP, gotP)
|
||||
assert.Equal(t, tt.wantHash, gotHash)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasherConfig_sha2Params(t *testing.T) {
|
||||
type fields struct {
|
||||
Params map[string]any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want512 bool
|
||||
wantRounds int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "decode error",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sha1",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": "sha1",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sha224",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": "sha224",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sha256",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"Rounds": 5000,
|
||||
"Hash": "sha256",
|
||||
},
|
||||
},
|
||||
want512: false,
|
||||
wantRounds: 5000,
|
||||
},
|
||||
{
|
||||
name: "sha384",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"Rounds": 12,
|
||||
"Hash": "sha384",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sha512",
|
||||
fields: fields{
|
||||
Params: map[string]any{
|
||||
"Rounds": 15000,
|
||||
"Hash": "sha512",
|
||||
},
|
||||
},
|
||||
want512: true,
|
||||
wantRounds: 15000,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &HasherConfig{
|
||||
Params: tt.fields.Params,
|
||||
}
|
||||
got512, gotRounds, err := c.sha2Params()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want512, got512)
|
||||
assert.Equal(t, tt.wantRounds, gotRounds)
|
||||
})
|
||||
}
|
||||
}
|
229
apps/api/internal/crypto/rsa.go
Normal file
229
apps/api/internal/crypto/rsa.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GenerateKeyPair(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) {
|
||||
privkey, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return privkey, &privkey.PublicKey, nil
|
||||
}
|
||||
|
||||
type CertificateInformations struct {
|
||||
SerialNumber *big.Int
|
||||
Organisation []string
|
||||
CommonName string
|
||||
NotBefore time.Time
|
||||
NotAfter time.Time
|
||||
KeyUsage x509.KeyUsage
|
||||
ExtKeyUsage []x509.ExtKeyUsage
|
||||
}
|
||||
|
||||
func GenerateEncryptedKeyPairWithCACertificate(bits int, keyAlg, certAlg EncryptionAlgorithm, informations *CertificateInformations) (*CryptoValue, *CryptoValue, *CryptoValue, error) {
|
||||
privateKey, publicKey, cert, err := GenerateCACertificate(bits, informations)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
encryptPriv, encryptPub, encryptCaCert, err := EncryptKeysAndCert(privateKey, publicKey, cert, keyAlg, certAlg)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return encryptPriv, encryptPub, encryptCaCert, nil
|
||||
}
|
||||
|
||||
func GenerateEncryptedKeyPairWithCertificate(bits int, keyAlg, certAlg EncryptionAlgorithm, caPrivateKey *rsa.PrivateKey, caCertificate []byte, informations *CertificateInformations) (*CryptoValue, *CryptoValue, *CryptoValue, error) {
|
||||
privateKey, publicKey, cert, err := GenerateCertificate(bits, caPrivateKey, caCertificate, informations)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
encryptPriv, encryptPub, encryptCaCert, err := EncryptKeysAndCert(privateKey, publicKey, cert, keyAlg, certAlg)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return encryptPriv, encryptPub, encryptCaCert, nil
|
||||
}
|
||||
|
||||
func GenerateCACertificate(bits int, informations *CertificateInformations) (*rsa.PrivateKey, *rsa.PublicKey, []byte, error) {
|
||||
return generateCertificate(bits, nil, nil, informations)
|
||||
}
|
||||
|
||||
func GenerateCertificate(bits int, caPrivateKey *rsa.PrivateKey, ca []byte, informations *CertificateInformations) (*rsa.PrivateKey, *rsa.PublicKey, []byte, error) {
|
||||
return generateCertificate(bits, caPrivateKey, ca, informations)
|
||||
}
|
||||
|
||||
func generateCertificate(bits int, caPrivateKey *rsa.PrivateKey, ca []byte, informations *CertificateInformations) (*rsa.PrivateKey, *rsa.PublicKey, []byte, error) {
|
||||
notBefore := time.Now()
|
||||
if !informations.NotBefore.IsZero() {
|
||||
notBefore = informations.NotBefore
|
||||
}
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: informations.SerialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: informations.CommonName,
|
||||
Organization: informations.Organisation,
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: informations.NotAfter,
|
||||
KeyUsage: informations.KeyUsage,
|
||||
ExtKeyUsage: informations.ExtKeyUsage,
|
||||
}
|
||||
|
||||
certPrivKey, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
certBytes := make([]byte, 0)
|
||||
if ca == nil {
|
||||
cert.IsCA = true
|
||||
cert.BasicConstraintsValid = true
|
||||
|
||||
certBytes, err = x509.CreateCertificate(rand.Reader, cert, cert, &certPrivKey.PublicKey, certPrivKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
} else {
|
||||
caCert, err := x509.ParseCertificate(ca)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
certBytes, err = x509.CreateCertificate(rand.Reader, cert, caCert, &certPrivKey.PublicKey, caPrivateKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
x509Cert, err := x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
certPem, err := CertificateToBytes(x509Cert)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
return certPrivKey, &certPrivKey.PublicKey, certPem, nil
|
||||
}
|
||||
|
||||
func PrivateKeyToBytes(priv *rsa.PrivateKey) []byte {
|
||||
return pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(priv),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func PublicKeyToBytes(pub *rsa.PublicKey) ([]byte, error) {
|
||||
pubASN1, err := x509.MarshalPKIXPublicKey(pub)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pubBytes := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: pubASN1,
|
||||
})
|
||||
|
||||
return pubBytes, nil
|
||||
}
|
||||
|
||||
func BytesToPrivateKey(priv []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(priv)
|
||||
enc := x509.IsEncryptedPEMBlock(block)
|
||||
b := block.Bytes
|
||||
var err error
|
||||
if enc {
|
||||
b, err = x509.DecryptPEMBlock(block, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
key, err := x509.ParsePKCS1PrivateKey(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
var ErrEmpty = errors.New("cannot decode, empty data")
|
||||
|
||||
func BytesToPublicKey(pub []byte) (*rsa.PublicKey, error) {
|
||||
if len(pub) == 0 {
|
||||
return nil, ErrEmpty
|
||||
}
|
||||
block, _ := pem.Decode(pub)
|
||||
if block == nil {
|
||||
return nil, ErrEmpty
|
||||
}
|
||||
ifc, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, ok := ifc.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func EncryptKeys(privateKey *rsa.PrivateKey, publicKey *rsa.PublicKey, alg EncryptionAlgorithm) (*CryptoValue, *CryptoValue, error) {
|
||||
encryptedPrivateKey, err := Encrypt(PrivateKeyToBytes(privateKey), alg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pubKey, err := PublicKeyToBytes(publicKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
encryptedPublicKey, err := Encrypt(pubKey, alg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return encryptedPrivateKey, encryptedPublicKey, nil
|
||||
}
|
||||
|
||||
func CertificateToBytes(cert *x509.Certificate) ([]byte, error) {
|
||||
certPem := new(bytes.Buffer)
|
||||
if err := pem.Encode(certPem, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return certPem.Bytes(), nil
|
||||
}
|
||||
|
||||
func BytesToCertificate(data []byte) ([]byte, error) {
|
||||
block, _ := pem.Decode(data)
|
||||
if block == nil || block.Type != "CERTIFICATE" {
|
||||
return nil, fmt.Errorf("failed to decode PEM block containing public key")
|
||||
}
|
||||
return block.Bytes, nil
|
||||
}
|
||||
|
||||
func EncryptKeysAndCert(privateKey *rsa.PrivateKey, publicKey *rsa.PublicKey, cert []byte, keyAlg, certAlg EncryptionAlgorithm) (*CryptoValue, *CryptoValue, *CryptoValue, error) {
|
||||
encryptedPrivateKey, encryptedPublicKey, err := EncryptKeys(privateKey, publicKey, keyAlg)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
encryptedCertificate, err := Encrypt(cert, certAlg)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return encryptedPrivateKey, encryptedPublicKey, encryptedCertificate, nil
|
||||
}
|
136
apps/api/internal/crypto/rsabits_enumer.go
Normal file
136
apps/api/internal/crypto/rsabits_enumer.go
Normal file
@@ -0,0 +1,136 @@
|
||||
// Code generated by "enumer -type RSABits -trimprefix RSABits -text -json -linecomment"; DO NOT EDIT.
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
_RSABitsName_0 = ""
|
||||
_RSABitsLowerName_0 = ""
|
||||
_RSABitsName_1 = "2048"
|
||||
_RSABitsLowerName_1 = "2048"
|
||||
_RSABitsName_2 = "3072"
|
||||
_RSABitsLowerName_2 = "3072"
|
||||
_RSABitsName_3 = "4096"
|
||||
_RSABitsLowerName_3 = "4096"
|
||||
)
|
||||
|
||||
var (
|
||||
_RSABitsIndex_0 = [...]uint8{0, 0}
|
||||
_RSABitsIndex_1 = [...]uint8{0, 4}
|
||||
_RSABitsIndex_2 = [...]uint8{0, 4}
|
||||
_RSABitsIndex_3 = [...]uint8{0, 4}
|
||||
)
|
||||
|
||||
func (i RSABits) String() string {
|
||||
switch {
|
||||
case i == 0:
|
||||
return _RSABitsName_0
|
||||
case i == 2048:
|
||||
return _RSABitsName_1
|
||||
case i == 3072:
|
||||
return _RSABitsName_2
|
||||
case i == 4096:
|
||||
return _RSABitsName_3
|
||||
default:
|
||||
return fmt.Sprintf("RSABits(%d)", i)
|
||||
}
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _RSABitsNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[RSABitsUnspecified-(0)]
|
||||
_ = x[RSABits2048-(2048)]
|
||||
_ = x[RSABits3072-(3072)]
|
||||
_ = x[RSABits4096-(4096)]
|
||||
}
|
||||
|
||||
var _RSABitsValues = []RSABits{RSABitsUnspecified, RSABits2048, RSABits3072, RSABits4096}
|
||||
|
||||
var _RSABitsNameToValueMap = map[string]RSABits{
|
||||
_RSABitsName_0[0:0]: RSABitsUnspecified,
|
||||
_RSABitsLowerName_0[0:0]: RSABitsUnspecified,
|
||||
_RSABitsName_1[0:4]: RSABits2048,
|
||||
_RSABitsLowerName_1[0:4]: RSABits2048,
|
||||
_RSABitsName_2[0:4]: RSABits3072,
|
||||
_RSABitsLowerName_2[0:4]: RSABits3072,
|
||||
_RSABitsName_3[0:4]: RSABits4096,
|
||||
_RSABitsLowerName_3[0:4]: RSABits4096,
|
||||
}
|
||||
|
||||
var _RSABitsNames = []string{
|
||||
_RSABitsName_0[0:0],
|
||||
_RSABitsName_1[0:4],
|
||||
_RSABitsName_2[0:4],
|
||||
_RSABitsName_3[0:4],
|
||||
}
|
||||
|
||||
// RSABitsString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func RSABitsString(s string) (RSABits, error) {
|
||||
if val, ok := _RSABitsNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _RSABitsNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to RSABits values", s)
|
||||
}
|
||||
|
||||
// RSABitsValues returns all values of the enum
|
||||
func RSABitsValues() []RSABits {
|
||||
return _RSABitsValues
|
||||
}
|
||||
|
||||
// RSABitsStrings returns a slice of all String values of the enum
|
||||
func RSABitsStrings() []string {
|
||||
strs := make([]string, len(_RSABitsNames))
|
||||
copy(strs, _RSABitsNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsARSABits returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i RSABits) IsARSABits() bool {
|
||||
for _, v := range _RSABitsValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for RSABits
|
||||
func (i RSABits) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(i.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for RSABits
|
||||
func (i *RSABits) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return fmt.Errorf("RSABits should be a string, got %s", data)
|
||||
}
|
||||
|
||||
var err error
|
||||
*i, err = RSABitsString(s)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface for RSABits
|
||||
func (i RSABits) MarshalText() ([]byte, error) {
|
||||
return []byte(i.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface for RSABits
|
||||
func (i *RSABits) UnmarshalText(text []byte) error {
|
||||
var err error
|
||||
*i, err = RSABitsString(string(text))
|
||||
return err
|
||||
}
|
116
apps/api/internal/crypto/rsahasher_enumer.go
Normal file
116
apps/api/internal/crypto/rsahasher_enumer.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Code generated by "enumer -type RSAHasher -trimprefix RSAHasher -text -json -linecomment"; DO NOT EDIT.
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _RSAHasherName = "SHA256SHA384SHA512"
|
||||
|
||||
var _RSAHasherIndex = [...]uint8{0, 0, 6, 12, 18}
|
||||
|
||||
const _RSAHasherLowerName = "sha256sha384sha512"
|
||||
|
||||
func (i RSAHasher) String() string {
|
||||
if i < 0 || i >= RSAHasher(len(_RSAHasherIndex)-1) {
|
||||
return fmt.Sprintf("RSAHasher(%d)", i)
|
||||
}
|
||||
return _RSAHasherName[_RSAHasherIndex[i]:_RSAHasherIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _RSAHasherNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[RSAHasherUnspecified-(0)]
|
||||
_ = x[RSAHasherSHA256-(1)]
|
||||
_ = x[RSAHasherSHA384-(2)]
|
||||
_ = x[RSAHasherSHA512-(3)]
|
||||
}
|
||||
|
||||
var _RSAHasherValues = []RSAHasher{RSAHasherUnspecified, RSAHasherSHA256, RSAHasherSHA384, RSAHasherSHA512}
|
||||
|
||||
var _RSAHasherNameToValueMap = map[string]RSAHasher{
|
||||
_RSAHasherName[0:0]: RSAHasherUnspecified,
|
||||
_RSAHasherLowerName[0:0]: RSAHasherUnspecified,
|
||||
_RSAHasherName[0:6]: RSAHasherSHA256,
|
||||
_RSAHasherLowerName[0:6]: RSAHasherSHA256,
|
||||
_RSAHasherName[6:12]: RSAHasherSHA384,
|
||||
_RSAHasherLowerName[6:12]: RSAHasherSHA384,
|
||||
_RSAHasherName[12:18]: RSAHasherSHA512,
|
||||
_RSAHasherLowerName[12:18]: RSAHasherSHA512,
|
||||
}
|
||||
|
||||
var _RSAHasherNames = []string{
|
||||
_RSAHasherName[0:0],
|
||||
_RSAHasherName[0:6],
|
||||
_RSAHasherName[6:12],
|
||||
_RSAHasherName[12:18],
|
||||
}
|
||||
|
||||
// RSAHasherString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func RSAHasherString(s string) (RSAHasher, error) {
|
||||
if val, ok := _RSAHasherNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _RSAHasherNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to RSAHasher values", s)
|
||||
}
|
||||
|
||||
// RSAHasherValues returns all values of the enum
|
||||
func RSAHasherValues() []RSAHasher {
|
||||
return _RSAHasherValues
|
||||
}
|
||||
|
||||
// RSAHasherStrings returns a slice of all String values of the enum
|
||||
func RSAHasherStrings() []string {
|
||||
strs := make([]string, len(_RSAHasherNames))
|
||||
copy(strs, _RSAHasherNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsARSAHasher returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i RSAHasher) IsARSAHasher() bool {
|
||||
for _, v := range _RSAHasherValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for RSAHasher
|
||||
func (i RSAHasher) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(i.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for RSAHasher
|
||||
func (i *RSAHasher) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return fmt.Errorf("RSAHasher should be a string, got %s", data)
|
||||
}
|
||||
|
||||
var err error
|
||||
*i, err = RSAHasherString(s)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface for RSAHasher
|
||||
func (i RSAHasher) MarshalText() ([]byte, error) {
|
||||
return []byte(i.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface for RSAHasher
|
||||
func (i *RSAHasher) UnmarshalText(text []byte) error {
|
||||
var err error
|
||||
*i, err = RSAHasherString(string(text))
|
||||
return err
|
||||
}
|
2
apps/api/internal/crypto/testdata/fuzz/FuzzAESCrypto_DecryptString/8d609af8fa2eb76f
vendored
Normal file
2
apps/api/internal/crypto/testdata/fuzz/FuzzAESCrypto_DecryptString/8d609af8fa2eb76f
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
[]byte("0010120C001010070")
|
241
apps/api/internal/crypto/web_key.go
Normal file
241
apps/api/internal/crypto/web_key.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/muhlemmer/gu"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type KeyUsage int32
|
||||
|
||||
const (
|
||||
KeyUsageSigning KeyUsage = iota
|
||||
KeyUsageSAMLMetadataSigning
|
||||
KeyUsageSAMLResponseSinging
|
||||
KeyUsageSAMLCA
|
||||
)
|
||||
|
||||
func (u KeyUsage) String() string {
|
||||
switch u {
|
||||
case KeyUsageSigning:
|
||||
return "sig"
|
||||
case KeyUsageSAMLCA:
|
||||
return "saml_ca"
|
||||
case KeyUsageSAMLResponseSinging:
|
||||
return "saml_response_sig"
|
||||
case KeyUsageSAMLMetadataSigning:
|
||||
return "saml_metadata_sig"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
//go:generate enumer -type WebKeyConfigType -trimprefix WebKeyConfigType -text -json -linecomment
|
||||
type WebKeyConfigType int
|
||||
|
||||
const (
|
||||
WebKeyConfigTypeUnspecified WebKeyConfigType = iota //
|
||||
WebKeyConfigTypeRSA
|
||||
WebKeyConfigTypeECDSA
|
||||
WebKeyConfigTypeED25519
|
||||
)
|
||||
|
||||
//go:generate enumer -type RSABits -trimprefix RSABits -text -json -linecomment
|
||||
type RSABits int
|
||||
|
||||
const (
|
||||
RSABitsUnspecified RSABits = 0 //
|
||||
RSABits2048 RSABits = 2048
|
||||
RSABits3072 RSABits = 3072
|
||||
RSABits4096 RSABits = 4096
|
||||
)
|
||||
|
||||
type RSAHasher int
|
||||
|
||||
//go:generate enumer -type RSAHasher -trimprefix RSAHasher -text -json -linecomment
|
||||
const (
|
||||
RSAHasherUnspecified RSAHasher = iota //
|
||||
RSAHasherSHA256
|
||||
RSAHasherSHA384
|
||||
RSAHasherSHA512
|
||||
)
|
||||
|
||||
type EllipticCurve int
|
||||
|
||||
//go:generate enumer -type EllipticCurve -trimprefix EllipticCurve -text -json -linecomment
|
||||
const (
|
||||
EllipticCurveUnspecified EllipticCurve = iota //
|
||||
EllipticCurveP256
|
||||
EllipticCurveP384
|
||||
EllipticCurveP512
|
||||
)
|
||||
|
||||
type WebKeyConfig interface {
|
||||
Alg() jose.SignatureAlgorithm
|
||||
Type() WebKeyConfigType // Type is needed to make Unmarshal work
|
||||
IsValid() error
|
||||
}
|
||||
|
||||
func UnmarshalWebKeyConfig(data []byte, configType WebKeyConfigType) (config WebKeyConfig, err error) {
|
||||
switch configType {
|
||||
case WebKeyConfigTypeUnspecified:
|
||||
return nil, zerrors.ThrowInternal(nil, "CRYPT-Ii3AiH", "Errors.Internal")
|
||||
case WebKeyConfigTypeRSA:
|
||||
config = new(WebKeyRSAConfig)
|
||||
case WebKeyConfigTypeECDSA:
|
||||
config = new(WebKeyECDSAConfig)
|
||||
case WebKeyConfigTypeED25519:
|
||||
config = new(WebKeyED25519Config)
|
||||
default:
|
||||
return nil, zerrors.ThrowInternal(nil, "CRYPT-Eig8ho", "Errors.Internal")
|
||||
}
|
||||
if err = json.Unmarshal(data, config); err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "CRYPT-waeR0N", "Errors.Internal")
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
type WebKeyRSAConfig struct {
|
||||
Bits RSABits
|
||||
Hasher RSAHasher
|
||||
}
|
||||
|
||||
func (c WebKeyRSAConfig) Alg() jose.SignatureAlgorithm {
|
||||
switch c.Hasher {
|
||||
case RSAHasherUnspecified:
|
||||
return ""
|
||||
case RSAHasherSHA256:
|
||||
return jose.RS256
|
||||
case RSAHasherSHA384:
|
||||
return jose.RS384
|
||||
case RSAHasherSHA512:
|
||||
return jose.RS512
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (WebKeyRSAConfig) Type() WebKeyConfigType {
|
||||
return WebKeyConfigTypeRSA
|
||||
}
|
||||
|
||||
func (c WebKeyRSAConfig) IsValid() error {
|
||||
if !c.Bits.IsARSABits() || c.Bits == RSABitsUnspecified {
|
||||
return zerrors.ThrowInvalidArgument(nil, "CRYPTO-eaz3T", "Errors.WebKey.Config")
|
||||
}
|
||||
if !c.Hasher.IsARSAHasher() || c.Hasher == RSAHasherUnspecified {
|
||||
return zerrors.ThrowInvalidArgument(nil, "CRYPTO-ODie7", "Errors.WebKey.Config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type WebKeyECDSAConfig struct {
|
||||
Curve EllipticCurve
|
||||
}
|
||||
|
||||
func (c WebKeyECDSAConfig) Alg() jose.SignatureAlgorithm {
|
||||
switch c.Curve {
|
||||
case EllipticCurveUnspecified:
|
||||
return ""
|
||||
case EllipticCurveP256:
|
||||
return jose.ES256
|
||||
case EllipticCurveP384:
|
||||
return jose.ES384
|
||||
case EllipticCurveP512:
|
||||
return jose.ES512
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (WebKeyECDSAConfig) Type() WebKeyConfigType {
|
||||
return WebKeyConfigTypeECDSA
|
||||
}
|
||||
|
||||
func (c WebKeyECDSAConfig) IsValid() error {
|
||||
if !c.Curve.IsAEllipticCurve() || c.Curve == EllipticCurveUnspecified {
|
||||
return zerrors.ThrowInvalidArgument(nil, "CRYPTO-Ii2ai", "Errors.WebKey.Config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c WebKeyECDSAConfig) GetCurve() elliptic.Curve {
|
||||
switch c.Curve {
|
||||
case EllipticCurveUnspecified:
|
||||
return nil
|
||||
case EllipticCurveP256:
|
||||
return elliptic.P256()
|
||||
case EllipticCurveP384:
|
||||
return elliptic.P384()
|
||||
case EllipticCurveP512:
|
||||
return elliptic.P521()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type WebKeyED25519Config struct{}
|
||||
|
||||
func (WebKeyED25519Config) Alg() jose.SignatureAlgorithm {
|
||||
return jose.EdDSA
|
||||
}
|
||||
|
||||
func (WebKeyED25519Config) Type() WebKeyConfigType {
|
||||
return WebKeyConfigTypeED25519
|
||||
}
|
||||
|
||||
func (WebKeyED25519Config) IsValid() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GenerateEncryptedWebKey(keyID string, alg EncryptionAlgorithm, genConfig WebKeyConfig) (encryptedPrivate *CryptoValue, public *jose.JSONWebKey, err error) {
|
||||
private, public, err := generateWebKey(keyID, genConfig)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
encryptedPrivate, err = EncryptJSON(private, alg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return encryptedPrivate, public, nil
|
||||
}
|
||||
|
||||
func generateWebKey(keyID string, genConfig WebKeyConfig) (private, public *jose.JSONWebKey, err error) {
|
||||
if err = genConfig.IsValid(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var key any
|
||||
switch conf := genConfig.(type) {
|
||||
case *WebKeyRSAConfig:
|
||||
key, err = rsa.GenerateKey(rand.Reader, int(conf.Bits))
|
||||
case *WebKeyECDSAConfig:
|
||||
key, err = ecdsa.GenerateKey(conf.GetCurve(), rand.Reader)
|
||||
case *WebKeyED25519Config:
|
||||
_, key, err = ed25519.GenerateKey(rand.Reader)
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unknown webkey config type %T", genConfig)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
private = newJSONWebkey(key, keyID, genConfig.Alg())
|
||||
return private, gu.Ptr(private.Public()), err
|
||||
}
|
||||
|
||||
func newJSONWebkey(key any, keyID string, algorithm jose.SignatureAlgorithm) *jose.JSONWebKey {
|
||||
return &jose.JSONWebKey{
|
||||
Key: key,
|
||||
KeyID: keyID,
|
||||
Algorithm: string(algorithm),
|
||||
Use: KeyUsageSigning.String(),
|
||||
}
|
||||
}
|
269
apps/api/internal/crypto/web_key_test.go
Normal file
269
apps/api/internal/crypto/web_key_test.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"testing"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestUnmarshalWebKeyConfig(t *testing.T) {
|
||||
type args struct {
|
||||
data []byte
|
||||
configType WebKeyConfigType
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantConfig WebKeyConfig
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "unspecified",
|
||||
args: args{
|
||||
[]byte(`{}`),
|
||||
WebKeyConfigTypeUnspecified,
|
||||
},
|
||||
wantErr: zerrors.ThrowInternal(nil, "CRYPT-Ii3AiH", "Errors.Internal"),
|
||||
},
|
||||
{
|
||||
name: "rsa",
|
||||
args: args{
|
||||
[]byte(`{"bits":"2048", "hasher":"sha256"}`),
|
||||
WebKeyConfigTypeRSA,
|
||||
},
|
||||
wantConfig: &WebKeyRSAConfig{
|
||||
Bits: RSABits2048,
|
||||
Hasher: RSAHasherSHA256,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ecdsa",
|
||||
args: args{
|
||||
[]byte(`{"curve":"p256"}`),
|
||||
WebKeyConfigTypeECDSA,
|
||||
},
|
||||
wantConfig: &WebKeyECDSAConfig{
|
||||
Curve: EllipticCurveP256,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ed25519",
|
||||
args: args{
|
||||
[]byte(`{}`),
|
||||
WebKeyConfigTypeED25519,
|
||||
},
|
||||
wantConfig: &WebKeyED25519Config{},
|
||||
},
|
||||
{
|
||||
name: "unknown type error",
|
||||
args: args{
|
||||
[]byte(`{"curve":0}`),
|
||||
99,
|
||||
},
|
||||
wantErr: zerrors.ThrowInternal(nil, "CRYPT-Eig8ho", "Errors.Internal"),
|
||||
},
|
||||
{
|
||||
name: "unmarshal error",
|
||||
args: args{
|
||||
[]byte(`~~`),
|
||||
WebKeyConfigTypeED25519,
|
||||
},
|
||||
wantErr: zerrors.ThrowInternal(nil, "CRYPT-waeR0N", "Errors.Internal"),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotConfig, err := UnmarshalWebKeyConfig(tt.args.data, tt.args.configType)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, gotConfig, tt.wantConfig)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebKeyECDSAConfig_Alg(t *testing.T) {
|
||||
type fields struct {
|
||||
Curve EllipticCurve
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want jose.SignatureAlgorithm
|
||||
}{
|
||||
{
|
||||
name: "unspecified",
|
||||
fields: fields{
|
||||
Curve: EllipticCurveUnspecified,
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "P256",
|
||||
fields: fields{
|
||||
Curve: EllipticCurveP256,
|
||||
},
|
||||
want: jose.ES256,
|
||||
},
|
||||
{
|
||||
name: "P384",
|
||||
fields: fields{
|
||||
Curve: EllipticCurveP384,
|
||||
},
|
||||
want: jose.ES384,
|
||||
},
|
||||
{
|
||||
name: "P512",
|
||||
fields: fields{
|
||||
Curve: EllipticCurveP512,
|
||||
},
|
||||
want: jose.ES512,
|
||||
},
|
||||
{
|
||||
name: "default",
|
||||
fields: fields{
|
||||
Curve: 99,
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := WebKeyECDSAConfig{
|
||||
Curve: tt.fields.Curve,
|
||||
}
|
||||
got := c.Alg()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebKeyECDSAConfig_GetCurve(t *testing.T) {
|
||||
type fields struct {
|
||||
Curve EllipticCurve
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want elliptic.Curve
|
||||
}{
|
||||
{
|
||||
name: "unspecified",
|
||||
fields: fields{EllipticCurveUnspecified},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "P256",
|
||||
fields: fields{EllipticCurveP256},
|
||||
want: elliptic.P256(),
|
||||
},
|
||||
{
|
||||
name: "P384",
|
||||
fields: fields{EllipticCurveP384},
|
||||
want: elliptic.P384(),
|
||||
},
|
||||
{
|
||||
name: "P512",
|
||||
fields: fields{EllipticCurveP512},
|
||||
want: elliptic.P521(),
|
||||
},
|
||||
{
|
||||
name: "default",
|
||||
fields: fields{99},
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := WebKeyECDSAConfig{
|
||||
Curve: tt.fields.Curve,
|
||||
}
|
||||
got := c.GetCurve()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_generateEncryptedWebKey(t *testing.T) {
|
||||
type args struct {
|
||||
keyID string
|
||||
genConfig WebKeyConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
assertPrivate func(t *testing.T, got *jose.JSONWebKey)
|
||||
assertPublic func(t *testing.T, got *jose.JSONWebKey)
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid",
|
||||
args: args{
|
||||
keyID: "keyID",
|
||||
genConfig: &WebKeyRSAConfig{
|
||||
Bits: RSABitsUnspecified,
|
||||
Hasher: RSAHasherSHA256,
|
||||
},
|
||||
},
|
||||
wantErr: zerrors.ThrowInvalidArgument(nil, "CRYPTO-eaz3T", "Errors.WebKey.Config"),
|
||||
},
|
||||
{
|
||||
name: "RSA",
|
||||
args: args{
|
||||
keyID: "keyID",
|
||||
genConfig: &WebKeyRSAConfig{
|
||||
Bits: RSABits2048,
|
||||
Hasher: RSAHasherSHA256,
|
||||
},
|
||||
},
|
||||
assertPrivate: assertJSONWebKey("keyID", "RS256", "sig", false),
|
||||
assertPublic: assertJSONWebKey("keyID", "RS256", "sig", true),
|
||||
},
|
||||
{
|
||||
name: "ECDSA",
|
||||
args: args{
|
||||
keyID: "keyID",
|
||||
genConfig: &WebKeyECDSAConfig{
|
||||
Curve: EllipticCurveP256,
|
||||
},
|
||||
},
|
||||
assertPrivate: assertJSONWebKey("keyID", "ES256", "sig", false),
|
||||
assertPublic: assertJSONWebKey("keyID", "ES256", "sig", true),
|
||||
},
|
||||
{
|
||||
name: "ED25519",
|
||||
args: args{
|
||||
keyID: "keyID",
|
||||
genConfig: &WebKeyED25519Config{},
|
||||
},
|
||||
assertPrivate: assertJSONWebKey("keyID", "EdDSA", "sig", false),
|
||||
assertPublic: assertJSONWebKey("keyID", "EdDSA", "sig", true),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotPrivate, gotPublic, err := generateWebKey(tt.args.keyID, tt.args.genConfig)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
if tt.assertPrivate != nil {
|
||||
tt.assertPrivate(t, gotPrivate)
|
||||
}
|
||||
if tt.assertPublic != nil {
|
||||
tt.assertPublic(t, gotPublic)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func assertJSONWebKey(keyID, algorithm, use string, isPublic bool) func(t *testing.T, got *jose.JSONWebKey) {
|
||||
return func(t *testing.T, got *jose.JSONWebKey) {
|
||||
assert.NotNil(t, got)
|
||||
assert.NotNil(t, got.Key, "key")
|
||||
assert.Equal(t, keyID, got.KeyID, "keyID")
|
||||
assert.Equal(t, algorithm, got.Algorithm, "algorithm")
|
||||
assert.Equal(t, use, got.Use, "user")
|
||||
assert.Equal(t, isPublic, got.IsPublic(), "isPublic")
|
||||
}
|
||||
}
|
116
apps/api/internal/crypto/webkeyconfigtype_enumer.go
Normal file
116
apps/api/internal/crypto/webkeyconfigtype_enumer.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Code generated by "enumer -type WebKeyConfigType -trimprefix WebKeyConfigType -text -json -linecomment"; DO NOT EDIT.
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _WebKeyConfigTypeName = "RSAECDSAED25519"
|
||||
|
||||
var _WebKeyConfigTypeIndex = [...]uint8{0, 0, 3, 8, 15}
|
||||
|
||||
const _WebKeyConfigTypeLowerName = "rsaecdsaed25519"
|
||||
|
||||
func (i WebKeyConfigType) String() string {
|
||||
if i < 0 || i >= WebKeyConfigType(len(_WebKeyConfigTypeIndex)-1) {
|
||||
return fmt.Sprintf("WebKeyConfigType(%d)", i)
|
||||
}
|
||||
return _WebKeyConfigTypeName[_WebKeyConfigTypeIndex[i]:_WebKeyConfigTypeIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _WebKeyConfigTypeNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[WebKeyConfigTypeUnspecified-(0)]
|
||||
_ = x[WebKeyConfigTypeRSA-(1)]
|
||||
_ = x[WebKeyConfigTypeECDSA-(2)]
|
||||
_ = x[WebKeyConfigTypeED25519-(3)]
|
||||
}
|
||||
|
||||
var _WebKeyConfigTypeValues = []WebKeyConfigType{WebKeyConfigTypeUnspecified, WebKeyConfigTypeRSA, WebKeyConfigTypeECDSA, WebKeyConfigTypeED25519}
|
||||
|
||||
var _WebKeyConfigTypeNameToValueMap = map[string]WebKeyConfigType{
|
||||
_WebKeyConfigTypeName[0:0]: WebKeyConfigTypeUnspecified,
|
||||
_WebKeyConfigTypeLowerName[0:0]: WebKeyConfigTypeUnspecified,
|
||||
_WebKeyConfigTypeName[0:3]: WebKeyConfigTypeRSA,
|
||||
_WebKeyConfigTypeLowerName[0:3]: WebKeyConfigTypeRSA,
|
||||
_WebKeyConfigTypeName[3:8]: WebKeyConfigTypeECDSA,
|
||||
_WebKeyConfigTypeLowerName[3:8]: WebKeyConfigTypeECDSA,
|
||||
_WebKeyConfigTypeName[8:15]: WebKeyConfigTypeED25519,
|
||||
_WebKeyConfigTypeLowerName[8:15]: WebKeyConfigTypeED25519,
|
||||
}
|
||||
|
||||
var _WebKeyConfigTypeNames = []string{
|
||||
_WebKeyConfigTypeName[0:0],
|
||||
_WebKeyConfigTypeName[0:3],
|
||||
_WebKeyConfigTypeName[3:8],
|
||||
_WebKeyConfigTypeName[8:15],
|
||||
}
|
||||
|
||||
// WebKeyConfigTypeString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func WebKeyConfigTypeString(s string) (WebKeyConfigType, error) {
|
||||
if val, ok := _WebKeyConfigTypeNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _WebKeyConfigTypeNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to WebKeyConfigType values", s)
|
||||
}
|
||||
|
||||
// WebKeyConfigTypeValues returns all values of the enum
|
||||
func WebKeyConfigTypeValues() []WebKeyConfigType {
|
||||
return _WebKeyConfigTypeValues
|
||||
}
|
||||
|
||||
// WebKeyConfigTypeStrings returns a slice of all String values of the enum
|
||||
func WebKeyConfigTypeStrings() []string {
|
||||
strs := make([]string, len(_WebKeyConfigTypeNames))
|
||||
copy(strs, _WebKeyConfigTypeNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAWebKeyConfigType returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i WebKeyConfigType) IsAWebKeyConfigType() bool {
|
||||
for _, v := range _WebKeyConfigTypeValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for WebKeyConfigType
|
||||
func (i WebKeyConfigType) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(i.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for WebKeyConfigType
|
||||
func (i *WebKeyConfigType) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return fmt.Errorf("WebKeyConfigType should be a string, got %s", data)
|
||||
}
|
||||
|
||||
var err error
|
||||
*i, err = WebKeyConfigTypeString(s)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface for WebKeyConfigType
|
||||
func (i WebKeyConfigType) MarshalText() ([]byte, error) {
|
||||
return []byte(i.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface for WebKeyConfigType
|
||||
func (i *WebKeyConfigType) UnmarshalText(text []byte) error {
|
||||
var err error
|
||||
*i, err = WebKeyConfigTypeString(string(text))
|
||||
return err
|
||||
}
|
Reference in New Issue
Block a user