mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:37:31 +00:00
feat(storage): read only transactions for queries (#6415)
* fix: tests * bastle wie en grosse * fix(database): scan as callback * fix tests * fix merge failures * remove as of system time * refactor: remove unused test * refacotr: remove unused lines
This commit is contained in:
@@ -6,11 +6,12 @@ import (
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
z_db "github.com/zitadel/zitadel/internal/database"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
type database struct {
|
||||
client *sql.DB
|
||||
client *z_db.DB
|
||||
masterKey string
|
||||
encrypt func(key, masterKey string) (encryptedKey string, err error)
|
||||
decrypt func(encryptedKey, masterKey string) (key string, err error)
|
||||
@@ -22,7 +23,7 @@ const (
|
||||
encryptionKeysKeyCol = "key"
|
||||
)
|
||||
|
||||
func NewKeyStorage(client *sql.DB, masterKey string) (*database, error) {
|
||||
func NewKeyStorage(client *z_db.DB, masterKey string) (*database, error) {
|
||||
if err := checkMasterKeyLength(masterKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -42,32 +43,30 @@ func (d *database) ReadKeys() (crypto.Keys, error) {
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "", "unable to read keys")
|
||||
}
|
||||
rows, err := d.client.Query(stmt, args...)
|
||||
err = d.client.Query(func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var id, encryptionKey string
|
||||
err = rows.Scan(&id, &encryptionKey)
|
||||
if err != nil {
|
||||
return caos_errs.ThrowInternal(err, "", "unable to read keys")
|
||||
}
|
||||
key, err := d.decrypt(encryptionKey, d.masterKey)
|
||||
if err != nil {
|
||||
return caos_errs.ThrowInternal(err, "", "unable to decrypt key")
|
||||
}
|
||||
keys[id] = key
|
||||
}
|
||||
return nil
|
||||
}, stmt, args...)
|
||||
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "", "unable to read keys")
|
||||
}
|
||||
for rows.Next() {
|
||||
var id, encryptionKey string
|
||||
err = rows.Scan(&id, &encryptionKey)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "", "unable to read keys")
|
||||
}
|
||||
key, err := d.decrypt(encryptionKey, d.masterKey)
|
||||
if err != nil {
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "", "unable to close rows")
|
||||
}
|
||||
return nil, caos_errs.ThrowInternal(err, "", "unable to decrypt key")
|
||||
}
|
||||
keys[id] = key
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "", "unable to close rows")
|
||||
}
|
||||
return keys, err
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (d *database) ReadKey(id string) (*crypto.Key, error) {
|
||||
func (d *database) ReadKey(id string) (_ *crypto.Key, err error) {
|
||||
stmt, args, err := sq.Select(encryptionKeysKeyCol).
|
||||
From(EncryptionKeysTable).
|
||||
Where(sq.Eq{encryptionKeysIDCol: id}).
|
||||
@@ -76,19 +75,23 @@ func (d *database) ReadKey(id string) (*crypto.Key, error) {
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "", "unable to read key")
|
||||
}
|
||||
row := d.client.QueryRow(stmt, args...)
|
||||
var key string
|
||||
err = d.client.QueryRow(func(row *sql.Row) error {
|
||||
var encryptionKey string
|
||||
err = row.Scan(&encryptionKey)
|
||||
if err != nil {
|
||||
return caos_errs.ThrowInternal(err, "", "unable to read key")
|
||||
}
|
||||
key, err = d.decrypt(encryptionKey, d.masterKey)
|
||||
if err != nil {
|
||||
return caos_errs.ThrowInternal(err, "", "unable to decrypt key")
|
||||
}
|
||||
return nil
|
||||
}, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "", "unable to read key")
|
||||
}
|
||||
var encryptionKey string
|
||||
err = row.Scan(&encryptionKey)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "", "unable to read key")
|
||||
}
|
||||
key, err := d.decrypt(encryptionKey, d.masterKey)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "", "unable to decrypt key")
|
||||
}
|
||||
|
||||
return &crypto.Key{
|
||||
ID: id,
|
||||
Value: key,
|
||||
|
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
z_db "github.com/zitadel/zitadel/internal/database"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
@@ -46,7 +47,7 @@ func Test_database_ReadKeys(t *testing.T) {
|
||||
{
|
||||
"decryption error",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
client: dbMock(t, expectQueryScanErr(
|
||||
"SELECT id, key FROM system.encryption_keys",
|
||||
[]string{"id", "key"},
|
||||
[][]driver.Value{
|
||||
@@ -172,7 +173,7 @@ func Test_database_ReadKey(t *testing.T) {
|
||||
{
|
||||
"key not found err",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
client: dbMock(t, expectQueryScanErr(
|
||||
"SELECT key FROM system.encryption_keys WHERE id = $1",
|
||||
nil,
|
||||
nil,
|
||||
@@ -192,7 +193,7 @@ func Test_database_ReadKey(t *testing.T) {
|
||||
{
|
||||
"decryption error",
|
||||
fields{
|
||||
client: dbMock(t, expectQuery(
|
||||
client: dbMock(t, expectQueryScanErr(
|
||||
"SELECT key FROM system.encryption_keys WHERE id = $1",
|
||||
[]string{"key"},
|
||||
[][]driver.Value{
|
||||
@@ -445,7 +446,7 @@ func Test_checkMasterKeyLength(t *testing.T) {
|
||||
|
||||
type db struct {
|
||||
mock sqlmock.Sqlmock
|
||||
db *sql.DB
|
||||
db *z_db.DB
|
||||
}
|
||||
|
||||
func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db {
|
||||
@@ -459,19 +460,41 @@ func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db {
|
||||
}
|
||||
return db{
|
||||
mock: mock,
|
||||
db: client,
|
||||
db: &z_db.DB{DB: client},
|
||||
}
|
||||
}
|
||||
|
||||
func expectQueryErr(query string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin()
|
||||
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnError(err)
|
||||
m.ExpectRollback()
|
||||
}
|
||||
}
|
||||
|
||||
func expectQueryScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin()
|
||||
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
|
||||
m.ExpectRollback()
|
||||
result := sqlmock.NewRows(cols)
|
||||
count := uint64(len(rows))
|
||||
for _, row := range rows {
|
||||
if cols[len(cols)-1] == "count" {
|
||||
row = append(row, count)
|
||||
}
|
||||
result.AddRow(row...)
|
||||
}
|
||||
q.WillReturnRows(result)
|
||||
q.RowsWillBeClosed()
|
||||
}
|
||||
}
|
||||
|
||||
func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin()
|
||||
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
|
||||
m.ExpectCommit()
|
||||
result := sqlmock.NewRows(cols)
|
||||
count := uint64(len(rows))
|
||||
for _, row := range rows {
|
||||
|
Reference in New Issue
Block a user