mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:07:31 +00:00
feat: encryption keys in database (#3265)
* enable overwrite of adminUser fields in defaults.yaml * create schema and table * cli: create keys * cli: create keys * read encryptionkey from db * merge v2 * file names * cleanup defaults.yaml * remove custom errors * load encryptionKeys on start * cleanup * fix merge * update system defaults * fix error message
This commit is contained in:
524
internal/crypto/database/database_test.go
Normal file
524
internal/crypto/database/database_test.go
Normal file
@@ -0,0 +1,524 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/caos/zitadel/internal/crypto"
|
||||
caos_errs "github.com/caos/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
func Test_database_ReadKeys(t *testing.T) {
|
||||
type fields struct {
|
||||
client db
|
||||
masterKey string
|
||||
decrypt func(encryptedKey, masterKey string) (key string, err error)
|
||||
}
|
||||
type res struct {
|
||||
keys crypto.Keys
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"query fails, error",
|
||||
fields{
|
||||
client: dbMock(t, expectQueryErr("SELECT id, key FROM system.encryption_keys", sql.ErrConnDone)),
|
||||
masterKey: "",
|
||||
decrypt: nil,
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, sql.ErrConnDone)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"decryption error",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
"SELECT id, key FROM system.encryption_keys",
|
||||
[]string{"id", "key"},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
})),
|
||||
masterKey: "wrong key",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return "", fmt.Errorf("wrong masterkey")
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: caos_errs.IsInternal,
|
||||
},
|
||||
},
|
||||
{
|
||||
"single key ok",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
"SELECT id, key FROM system.encryption_keys",
|
||||
[]string{"id", "key"},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
})),
|
||||
masterKey: "masterKey",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return encryptedKey, nil
|
||||
},
|
||||
},
|
||||
res{
|
||||
keys: crypto.Keys(map[string]string{"id1": "key1"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
"multiple keys ok",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
"SELECT id, key FROM system.encryption_keys",
|
||||
[]string{"id", "key"},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
{
|
||||
"id2",
|
||||
"key2",
|
||||
},
|
||||
})),
|
||||
masterKey: "masterKey",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return encryptedKey, nil
|
||||
},
|
||||
},
|
||||
res{
|
||||
keys: crypto.Keys(map[string]string{"id1": "key1", "id2": "key2"}),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := &database{
|
||||
client: tt.fields.client.db,
|
||||
masterKey: tt.fields.masterKey,
|
||||
decrypt: tt.fields.decrypt,
|
||||
}
|
||||
got, err := d.ReadKeys()
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
} else if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.keys, got)
|
||||
}
|
||||
if err := tt.fields.client.mock.ExpectationsWereMet(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_database_ReadKey(t *testing.T) {
|
||||
type fields struct {
|
||||
client db
|
||||
masterKey string
|
||||
decrypt func(encryptedKey, masterKey string) (key string, err error)
|
||||
}
|
||||
type args struct {
|
||||
id string
|
||||
}
|
||||
type res struct {
|
||||
key *crypto.Key
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"query fails, error",
|
||||
fields{
|
||||
client: dbMock(t, expectQueryErr("SELECT key FROM system.encryption_keys WHERE id = $1", sql.ErrConnDone)),
|
||||
masterKey: "",
|
||||
decrypt: nil,
|
||||
},
|
||||
args{
|
||||
id: "id1",
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, sql.ErrConnDone)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"key not found err",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
"SELECT key FROM system.encryption_keys WHERE id = $1",
|
||||
nil,
|
||||
nil,
|
||||
"id1")),
|
||||
masterKey: "masterKey",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return encryptedKey, nil
|
||||
},
|
||||
},
|
||||
args{
|
||||
id: "id1",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.IsInternal,
|
||||
},
|
||||
},
|
||||
{
|
||||
"decryption error",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
"SELECT key FROM system.encryption_keys WHERE id = $1",
|
||||
[]string{"key"},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"key1",
|
||||
},
|
||||
},
|
||||
"id1",
|
||||
)),
|
||||
masterKey: "wrong key",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return "", fmt.Errorf("wrong masterkey")
|
||||
},
|
||||
},
|
||||
args{
|
||||
id: "id1",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.IsInternal,
|
||||
},
|
||||
},
|
||||
{
|
||||
"key ok",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
"SELECT key FROM system.encryption_keys WHERE id = $1",
|
||||
[]string{"key"},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"key1",
|
||||
},
|
||||
},
|
||||
"id1",
|
||||
)),
|
||||
masterKey: "masterKey",
|
||||
decrypt: func(encryptedKey, masterKey string) (key string, err error) {
|
||||
return encryptedKey, nil
|
||||
},
|
||||
},
|
||||
args{
|
||||
id: "id1",
|
||||
},
|
||||
res{
|
||||
key: &crypto.Key{
|
||||
ID: "id1",
|
||||
Value: "key1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := &database{
|
||||
client: tt.fields.client.db,
|
||||
masterKey: tt.fields.masterKey,
|
||||
decrypt: tt.fields.decrypt,
|
||||
}
|
||||
got, err := d.ReadKey(tt.args.id)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
} else if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.key, got)
|
||||
}
|
||||
if err := tt.fields.client.mock.ExpectationsWereMet(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_database_CreateKeys(t *testing.T) {
|
||||
type fields struct {
|
||||
client db
|
||||
masterKey string
|
||||
encrypt func(key, masterKey string) (encryptedKey string, err error)
|
||||
}
|
||||
type args struct {
|
||||
keys []*crypto.Key
|
||||
}
|
||||
type res struct {
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"encryption fails, error",
|
||||
fields{
|
||||
client: dbMock(t),
|
||||
masterKey: "",
|
||||
encrypt: func(key, masterKey string) (encryptedKey string, err error) {
|
||||
return "", fmt.Errorf("encryption failed")
|
||||
},
|
||||
},
|
||||
args{
|
||||
keys: []*crypto.Key{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: caos_errs.IsInternal,
|
||||
},
|
||||
},
|
||||
{
|
||||
"insert fails, error",
|
||||
fields{
|
||||
client: dbMock(t,
|
||||
expectBegin(nil),
|
||||
expectExec("INSERT INTO system.encryption_keys (id,key) VALUES ($1,$2)", sql.ErrTxDone),
|
||||
expectRollback(nil),
|
||||
),
|
||||
masterKey: "masterkey",
|
||||
encrypt: func(key, masterKey string) (encryptedKey string, err error) {
|
||||
return key, nil
|
||||
},
|
||||
},
|
||||
args{
|
||||
keys: []*crypto.Key{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, sql.ErrTxDone)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"single insert ok",
|
||||
fields{
|
||||
client: dbMock(t,
|
||||
expectBegin(nil),
|
||||
expectExec("INSERT INTO system.encryption_keys (id,key) VALUES ($1,$2)", nil, "id1", "key1"),
|
||||
expectCommit(nil),
|
||||
),
|
||||
masterKey: "masterkey",
|
||||
encrypt: func(key, masterKey string) (encryptedKey string, err error) {
|
||||
return key, nil
|
||||
},
|
||||
},
|
||||
args{
|
||||
keys: []*crypto.Key{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"multiple insert ok",
|
||||
fields{
|
||||
client: dbMock(t,
|
||||
expectBegin(nil),
|
||||
expectExec("INSERT INTO system.encryption_keys (id,key) VALUES ($1,$2)", nil, "id1", "key1", "id2", "key2"),
|
||||
expectCommit(nil),
|
||||
),
|
||||
masterKey: "masterkey",
|
||||
encrypt: func(key, masterKey string) (encryptedKey string, err error) {
|
||||
return key, nil
|
||||
},
|
||||
},
|
||||
args{
|
||||
keys: []*crypto.Key{
|
||||
{
|
||||
"id1",
|
||||
"key1",
|
||||
},
|
||||
{
|
||||
"id2",
|
||||
"key2",
|
||||
},
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := &database{
|
||||
client: tt.fields.client.db,
|
||||
masterKey: tt.fields.masterKey,
|
||||
encrypt: tt.fields.encrypt,
|
||||
}
|
||||
err := d.CreateKeys(tt.args.keys...)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
} else if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v", err)
|
||||
}
|
||||
if err := tt.fields.client.mock.ExpectationsWereMet(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_checkMasterKeyLength(t *testing.T) {
|
||||
type args struct {
|
||||
masterKey string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
err func(error) bool
|
||||
}{
|
||||
{
|
||||
"invalid length",
|
||||
args{
|
||||
masterKey: "",
|
||||
},
|
||||
caos_errs.IsInternal,
|
||||
},
|
||||
{
|
||||
"valid length",
|
||||
args{
|
||||
masterKey: "!themasterkeywhichis32byteslong!",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checkMasterKeyLength(tt.args.masterKey)
|
||||
if tt.err == nil {
|
||||
assert.NoError(t, err)
|
||||
} else if tt.err != nil && !tt.err(err) {
|
||||
t.Errorf("got wrong err: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type db struct {
|
||||
mock sqlmock.Sqlmock
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db {
|
||||
t.Helper()
|
||||
client, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create sql mock: %v", err)
|
||||
}
|
||||
for _, expectation := range expectations {
|
||||
expectation(mock)
|
||||
}
|
||||
return db{
|
||||
mock: mock,
|
||||
db: client,
|
||||
}
|
||||
}
|
||||
|
||||
func expectQueryErr(query string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
|
||||
result := sqlmock.NewRows(cols)
|
||||
count := uint64(len(rows))
|
||||
for _, row := range rows {
|
||||
if cols[len(cols)-1] == "count" {
|
||||
row = append(row, count)
|
||||
}
|
||||
result.AddRow(row...)
|
||||
}
|
||||
q.WillReturnRows(result)
|
||||
q.RowsWillBeClosed()
|
||||
}
|
||||
}
|
||||
|
||||
func expectExec(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectExec(regexp.QuoteMeta(stmt)).WithArgs(args...)
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
return
|
||||
}
|
||||
query.WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectBegin(err error) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectBegin()
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func expectCommit(err error) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectCommit()
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func expectRollback(err error) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
query := m.ExpectRollback()
|
||||
if err != nil {
|
||||
query.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user