zitadel/internal/crypto/database/database_test.go
Silvan 617708e0e5 fix(db): always use begin tx (#7142)
* fix(db): always use begin tx

* fix(handler): timeout for begin
2024-01-04 17:48:45 +01:00

549 lines
11 KiB
Go

package database
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/crypto"
z_db "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/zerrors"
)
// Test_database_ReadKeys checks that ReadKeys reads every stored key inside a
// transaction, decrypts it and maps query/decryption failures to errors.
func Test_database_ReadKeys(t *testing.T) {
	const selectKeysStmt = "SELECT id, key FROM system.encryption_keys"

	// passthrough "decrypts" by handing back the stored value unchanged.
	passthrough := func(encryptedKey, masterKey string) (key string, err error) {
		return encryptedKey, nil
	}

	type fields struct {
		client    db
		masterKey string
		decrypt   func(encryptedKey, masterKey string) (key string, err error)
	}
	type res struct {
		keys crypto.Keys
		err  func(error) bool
	}

	tests := []struct {
		name   string
		fields fields
		res    res
	}{
		{
			name: "query fails, error",
			fields: fields{
				client: dbMock(t, expectQueryErr(selectKeysStmt, sql.ErrConnDone)),
			},
			res: res{
				err: func(err error) bool { return errors.Is(err, sql.ErrConnDone) },
			},
		},
		{
			name: "decryption error",
			fields: fields{
				client: dbMock(t, expectQueryScanErr(
					selectKeysStmt,
					[]string{"id", "key"},
					[][]driver.Value{{"id1", "key1"}},
				)),
				masterKey: "wrong key",
				decrypt: func(encryptedKey, masterKey string) (key string, err error) {
					return "", fmt.Errorf("wrong masterkey")
				},
			},
			res: res{err: zerrors.IsInternal},
		},
		{
			name: "single key ok",
			fields: fields{
				client: dbMock(t, expectQuery(
					selectKeysStmt,
					[]string{"id", "key"},
					[][]driver.Value{{"id1", "key1"}},
				)),
				masterKey: "masterKey",
				decrypt:   passthrough,
			},
			res: res{keys: crypto.Keys{"id1": "key1"}},
		},
		{
			name: "multiple keys ok",
			fields: fields{
				client: dbMock(t, expectQuery(
					selectKeysStmt,
					[]string{"id", "key"},
					[][]driver.Value{
						{"id1", "key1"},
						{"id2", "key2"},
					},
				)),
				masterKey: "masterKey",
				decrypt:   passthrough,
			},
			res: res{keys: crypto.Keys{"id1": "key1", "id2": "key2"}},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			d := &database{
				client:    tc.fields.client.db,
				masterKey: tc.fields.masterKey,
				decrypt:   tc.fields.decrypt,
			}
			got, err := d.ReadKeys()

			switch {
			case tc.res.err == nil:
				assert.NoError(t, err)
				assert.Equal(t, tc.res.keys, got)
			case !tc.res.err(err):
				t.Errorf("got wrong err: %v", err)
			}

			if mockErr := tc.fields.client.mock.ExpectationsWereMet(); mockErr != nil {
				t.Error(mockErr)
			}
		})
	}
}
// Test_database_ReadKey checks that ReadKey loads a single key by id inside a
// transaction, decrypts it and maps query/lookup/decryption failures to errors.
func Test_database_ReadKey(t *testing.T) {
	const selectKeyStmt = "SELECT key FROM system.encryption_keys WHERE id = $1"

	// passthrough "decrypts" by handing back the stored value unchanged.
	passthrough := func(encryptedKey, masterKey string) (key string, err error) {
		return encryptedKey, nil
	}

	type fields struct {
		client    db
		masterKey string
		decrypt   func(encryptedKey, masterKey string) (key string, err error)
	}
	type args struct {
		id string
	}
	type res struct {
		key *crypto.Key
		err func(error) bool
	}

	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			name: "query fails, error",
			fields: fields{
				client: dbMock(t, expectQueryErr(selectKeyStmt, sql.ErrConnDone)),
			},
			args: args{id: "id1"},
			res: res{
				err: func(err error) bool { return errors.Is(err, sql.ErrConnDone) },
			},
		},
		{
			name: "key not found err",
			fields: fields{
				client:    dbMock(t, expectQueryScanErr(selectKeyStmt, nil, nil, "id1")),
				masterKey: "masterKey",
				decrypt:   passthrough,
			},
			args: args{id: "id1"},
			res:  res{err: zerrors.IsInternal},
		},
		{
			name: "decryption error",
			fields: fields{
				client: dbMock(t, expectQueryScanErr(
					selectKeyStmt,
					[]string{"key"},
					[][]driver.Value{{"key1"}},
					"id1",
				)),
				masterKey: "wrong key",
				decrypt: func(encryptedKey, masterKey string) (key string, err error) {
					return "", fmt.Errorf("wrong masterkey")
				},
			},
			args: args{id: "id1"},
			res:  res{err: zerrors.IsInternal},
		},
		{
			name: "key ok",
			fields: fields{
				client: dbMock(t, expectQuery(
					selectKeyStmt,
					[]string{"key"},
					[][]driver.Value{{"key1"}},
					"id1",
				)),
				masterKey: "masterKey",
				decrypt:   passthrough,
			},
			args: args{id: "id1"},
			res: res{
				key: &crypto.Key{ID: "id1", Value: "key1"},
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			d := &database{
				client:    tc.fields.client.db,
				masterKey: tc.fields.masterKey,
				decrypt:   tc.fields.decrypt,
			}
			got, err := d.ReadKey(tc.args.id)

			switch {
			case tc.res.err == nil:
				assert.NoError(t, err)
				assert.Equal(t, tc.res.key, got)
			case !tc.res.err(err):
				t.Errorf("got wrong err: %v", err)
			}

			if mockErr := tc.fields.client.mock.ExpectationsWereMet(); mockErr != nil {
				t.Error(mockErr)
			}
		})
	}
}
// Test_database_CreateKeys checks that CreateKeys encrypts the given keys and
// inserts them in a single transaction, committing on success and rolling back
// when the insert fails.
func Test_database_CreateKeys(t *testing.T) {
	const insertStmt = "INSERT INTO system.encryption_keys (id,key) VALUES ($1,$2)"

	// identity "encrypts" by handing back the plain key unchanged.
	identity := func(key, masterKey string) (encryptedKey string, err error) {
		return key, nil
	}

	type fields struct {
		client    db
		masterKey string
		encrypt   func(key, masterKey string) (encryptedKey string, err error)
	}
	type args struct {
		keys []*crypto.Key
	}
	type res struct {
		err func(error) bool
	}

	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			name: "encryption fails, error",
			fields: fields{
				client: dbMock(t),
				encrypt: func(key, masterKey string) (encryptedKey string, err error) {
					return "", fmt.Errorf("encryption failed")
				},
			},
			args: args{keys: []*crypto.Key{{"id1", "key1"}}},
			res:  res{err: zerrors.IsInternal},
		},
		{
			name: "insert fails, error",
			fields: fields{
				client: dbMock(t,
					expectBegin(nil),
					expectExec(insertStmt, sql.ErrTxDone),
					expectRollback(nil),
				),
				masterKey: "masterkey",
				encrypt:   identity,
			},
			args: args{keys: []*crypto.Key{{"id1", "key1"}}},
			res: res{
				err: func(err error) bool { return errors.Is(err, sql.ErrTxDone) },
			},
		},
		{
			name: "single insert ok",
			fields: fields{
				client: dbMock(t,
					expectBegin(nil),
					expectExec(insertStmt, nil, "id1", "key1"),
					expectCommit(nil),
				),
				masterKey: "masterkey",
				encrypt:   identity,
			},
			args: args{keys: []*crypto.Key{{"id1", "key1"}}},
		},
		{
			name: "multiple insert ok",
			fields: fields{
				client: dbMock(t,
					expectBegin(nil),
					expectExec(insertStmt, nil, "id1", "key1", "id2", "key2"),
					expectCommit(nil),
				),
				masterKey: "masterkey",
				encrypt:   identity,
			},
			args: args{keys: []*crypto.Key{
				{"id1", "key1"},
				{"id2", "key2"},
			}},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			d := &database{
				client:    tc.fields.client.db,
				masterKey: tc.fields.masterKey,
				encrypt:   tc.fields.encrypt,
			}
			err := d.CreateKeys(context.Background(), tc.args.keys...)

			switch {
			case tc.res.err == nil:
				assert.NoError(t, err)
			case !tc.res.err(err):
				t.Errorf("got wrong err: %v", err)
			}

			if mockErr := tc.fields.client.mock.ExpectationsWereMet(); mockErr != nil {
				t.Error(mockErr)
			}
		})
	}
}
// Test_checkMasterKeyLength checks that an empty master key is rejected with an
// internal error while a 32-byte key is accepted.
func Test_checkMasterKeyLength(t *testing.T) {
	tests := []struct {
		name      string
		masterKey string
		wantErr   func(error) bool // nil means no error is expected
	}{
		{name: "invalid length", masterKey: "", wantErr: zerrors.IsInternal},
		{name: "valid length", masterKey: "!themasterkeywhichis32byteslong!"},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			err := checkMasterKeyLength(tc.masterKey)
			switch {
			case tc.wantErr == nil:
				assert.NoError(t, err)
			case !tc.wantErr(err):
				t.Errorf("got wrong err: %v", err)
			}
		})
	}
}
// db bundles a sqlmock-backed database handle with the mock used to register
// and verify the expected SQL calls against it.
type db struct {
	// mock registers expectations and is checked via ExpectationsWereMet at the
	// end of each test case.
	mock sqlmock.Sqlmock
	// db wraps the mocked *sql.DB and is handed to the code under test as its
	// client.
	db *z_db.DB
}
// dbMock creates a sqlmock-backed *z_db.DB and applies the given expectation
// registrars to the mock in order. It fails the test immediately if the mock
// cannot be created. The underlying *sql.DB is closed once t (including all of
// its subtests) has finished, so mocked connections are not leaked across tests.
func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db {
	t.Helper()
	client, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("unable to create sql mock: %v", err)
	}
	t.Cleanup(func() {
		// No ExpectClose is registered, so sqlmock may report the Close as
		// unexpected. Expectations were already verified by the test body,
		// so that error is deliberately ignored here.
		_ = client.Close()
	})
	for _, expectation := range expectations {
		expectation(mock)
	}
	return db{
		mock: mock,
		db:   &z_db.DB{DB: client},
	}
}
// expectQueryErr returns a registrar for a transaction in which query (matched
// literally) is called with args and fails with err, after which the
// transaction is rolled back.
func expectQueryErr(query string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) {
	pattern := regexp.QuoteMeta(query)
	return func(mock sqlmock.Sqlmock) {
		mock.ExpectBegin()
		failing := mock.ExpectQuery(pattern)
		failing.WithArgs(args...).WillReturnError(err)
		mock.ExpectRollback()
	}
}
// expectQueryScanErr returns a registrar for a transaction in which stmt
// (matched literally) is called with args, returns rows with columns cols, and
// the transaction is then rolled back. In this file it backs the cases where
// the code under test fails after querying (decryption error, key not found).
// If the last column is named "count", the total number of rows is appended to
// every row.
func expectQueryScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		m.ExpectBegin()
		q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
		m.ExpectRollback()

		// Evaluated once; the length guard avoids an index panic on empty cols.
		withCount := len(cols) > 0 && cols[len(cols)-1] == "count"
		count := uint64(len(rows))
		result := sqlmock.NewRows(cols)
		for _, row := range rows {
			if withCount {
				// Clip capacity so append never writes into the caller's slice.
				row = append(row[:len(row):len(row)], count)
			}
			result.AddRow(row...)
		}
		q.WillReturnRows(result)
		q.RowsWillBeClosed()
	}
}
// expectQuery returns a registrar for a transaction in which stmt (matched
// literally) is called with args, returns rows with columns cols, and the
// transaction is then committed. If the last column is named "count", the
// total number of rows is appended to every row.
func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		m.ExpectBegin()
		q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
		m.ExpectCommit()

		// Evaluated once; the length guard avoids an index panic on empty cols.
		withCount := len(cols) > 0 && cols[len(cols)-1] == "count"
		count := uint64(len(rows))
		result := sqlmock.NewRows(cols)
		for _, row := range rows {
			if withCount {
				// Clip capacity so append never writes into the caller's slice.
				row = append(row[:len(row):len(row)], count)
			}
			result.AddRow(row...)
		}
		q.WillReturnRows(result)
		q.RowsWillBeClosed()
	}
}
// expectExec returns a registrar for a single Exec of stmt (matched literally)
// with args. When err is non-nil the Exec fails with it; otherwise it reports
// a result with last insert id 1 and one affected row.
func expectExec(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) {
	return func(mock sqlmock.Sqlmock) {
		exec := mock.ExpectExec(regexp.QuoteMeta(stmt)).WithArgs(args...)
		if err == nil {
			exec.WillReturnResult(sqlmock.NewResult(1, 1))
			return
		}
		exec.WillReturnError(err)
	}
}
// expectBegin returns a registrar for a transaction start that fails with err
// when err is non-nil and succeeds otherwise.
func expectBegin(err error) func(m sqlmock.Sqlmock) {
	return func(mock sqlmock.Sqlmock) {
		if begin := mock.ExpectBegin(); err != nil {
			begin.WillReturnError(err)
		}
	}
}
// expectCommit returns a registrar for a transaction commit that fails with
// err when err is non-nil and succeeds otherwise.
func expectCommit(err error) func(m sqlmock.Sqlmock) {
	return func(mock sqlmock.Sqlmock) {
		if commit := mock.ExpectCommit(); err != nil {
			commit.WillReturnError(err)
		}
	}
}
// expectRollback returns a registrar for a transaction rollback that fails
// with err when err is non-nil and succeeds otherwise.
func expectRollback(err error) func(m sqlmock.Sqlmock) {
	return func(mock sqlmock.Sqlmock) {
		if rollback := mock.ExpectRollback(); err != nil {
			rollback.WillReturnError(err)
		}
	}
}