zitadel/internal/crypto/database/database.go
Florian Forster fa9f581d56
chore(v2): move to new org (#3499)
* chore: move to new org

* logging

* fix: org rename caos -> zitadel

Co-authored-by: adlerhurst <silvan.reusser@gmail.com>
2022-04-26 23:01:45 +00:00

134 lines
3.7 KiB
Go

package database
import (
"database/sql"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/crypto"
caos_errs "github.com/zitadel/zitadel/internal/errors"
)
// database stores encryption keys in an SQL table. Key values are encrypted
// with masterKey (via encrypt) before they are written and decrypted (via
// decrypt) after they are read.
type database struct {
	client    *sql.DB
	masterKey string

	// encrypt turns a plain key value into its stored form.
	encrypt func(key, masterKey string) (string, error)
	// decrypt turns a stored key value back into its plain form.
	decrypt func(encryptedKey, masterKey string) (string, error)
}
// EncryptionKeysTable is the qualified name of the table that holds the
// encrypted keys.
const EncryptionKeysTable = "system.encryption_keys"

// Column names of EncryptionKeysTable.
const (
	encryptionKeysIDCol  = "id"
	encryptionKeysKeyCol = "key"
)
// NewKeyStorage returns a key storage that reads and writes keys through
// client. Keys are encrypted and decrypted with masterKey using
// crypto.EncryptAESString and crypto.DecryptAESString.
//
// It fails if masterKey is not exactly 32 bytes long.
func NewKeyStorage(client *sql.DB, masterKey string) (*database, error) {
	if lengthErr := checkMasterKeyLength(masterKey); lengthErr != nil {
		return nil, lengthErr
	}
	storage := new(database)
	storage.client = client
	storage.masterKey = masterKey
	storage.encrypt = crypto.EncryptAESString
	storage.decrypt = crypto.DecryptAESString
	return storage, nil
}
// ReadKeys loads every key stored in EncryptionKeysTable, decrypts each one
// with the master key and returns them mapped by key ID.
//
// Any failure is returned as an internal error and no partial result is
// returned. Failures include building or running the query, scanning a row,
// iterating the result set, decrypting a key and closing the rows.
func (d *database) ReadKeys() (crypto.Keys, error) {
	stmt, args, err := sq.Select(encryptionKeysIDCol, encryptionKeysKeyCol).
		From(EncryptionKeysTable).
		PlaceholderFormat(sq.Dollar).
		ToSql()
	if err != nil {
		return nil, caos_errs.ThrowInternal(err, "", "unable to read keys")
	}
	rows, err := d.client.Query(stmt, args...)
	if err != nil {
		return nil, caos_errs.ThrowInternal(err, "", "unable to read keys")
	}
	// Releases the rows on every early return below. The explicit Close at
	// the end still reports its error on the success path; Rows.Close is
	// idempotent, so closing twice is safe.
	defer rows.Close()

	keys := make(map[string]string)
	for rows.Next() {
		var id, encryptionKey string
		if err := rows.Scan(&id, &encryptionKey); err != nil {
			return nil, caos_errs.ThrowInternal(err, "", "unable to read keys")
		}
		key, err := d.decrypt(encryptionKey, d.masterKey)
		if err != nil {
			return nil, caos_errs.ThrowInternal(err, "", "unable to decrypt key")
		}
		keys[id] = key
	}
	// rows.Next returns false both at the end of the result set and on an
	// error; only rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, caos_errs.ThrowInternal(err, "", "unable to read keys")
	}
	if err := rows.Close(); err != nil {
		return nil, caos_errs.ThrowInternal(err, "", "unable to close rows")
	}
	return keys, nil
}
// ReadKey loads the key with the given id from EncryptionKeysTable and
// decrypts it with the master key.
//
// A missing key (sql.ErrNoRows from Scan) is reported as an internal error,
// exactly like any other read failure. A decryption failure is also returned
// as an internal error.
func (d *database) ReadKey(id string) (*crypto.Key, error) {
	stmt, args, err := sq.Select(encryptionKeysKeyCol).
		From(EncryptionKeysTable).
		Where(sq.Eq{encryptionKeysIDCol: id}).
		PlaceholderFormat(sq.Dollar).
		ToSql()
	if err != nil {
		return nil, caos_errs.ThrowInternal(err, "", "unable to read key")
	}
	// QueryRow itself never returns an error. Query failures, including
	// "no rows", are deferred until Scan, so one check covers both.
	var encryptionKey string
	if err := d.client.QueryRow(stmt, args...).Scan(&encryptionKey); err != nil {
		return nil, caos_errs.ThrowInternal(err, "", "unable to read key")
	}
	key, err := d.decrypt(encryptionKey, d.masterKey)
	if err != nil {
		return nil, caos_errs.ThrowInternal(err, "", "unable to decrypt key")
	}
	return &crypto.Key{
		ID:    id,
		Value: key,
	}, nil
}
// CreateKeys encrypts the value of every key with the master key and inserts
// all of them into EncryptionKeysTable. It uses a single INSERT statement
// inside a transaction, so either all keys are stored or none are.
//
// Calling it without any keys is a no-op. Every failure is returned as an
// internal error.
func (d *database) CreateKeys(keys ...*crypto.Key) error {
	// squirrel refuses to build an INSERT without VALUES, so with no keys
	// there is nothing to write.
	if len(keys) == 0 {
		return nil
	}
	insert := sq.Insert(EncryptionKeysTable).
		Columns(encryptionKeysIDCol, encryptionKeysKeyCol).
		PlaceholderFormat(sq.Dollar)
	for _, key := range keys {
		encryptionKey, err := d.encrypt(key.Value, d.masterKey)
		if err != nil {
			return caos_errs.ThrowInternal(err, "", "unable to encrypt key")
		}
		insert = insert.Values(key.ID, encryptionKey)
	}
	stmt, args, err := insert.ToSql()
	if err != nil {
		return caos_errs.ThrowInternal(err, "", "unable to insert new keys")
	}
	tx, err := d.client.Begin()
	if err != nil {
		return caos_errs.ThrowInternal(err, "", "unable to insert new keys")
	}
	if _, err := tx.Exec(stmt, args...); err != nil {
		// The Exec error is the one reported to the caller; the Rollback
		// result is deliberately ignored.
		_ = tx.Rollback()
		return caos_errs.ThrowInternal(err, "", "unable to insert new keys")
	}
	if err := tx.Commit(); err != nil {
		return caos_errs.ThrowInternal(err, "", "unable to insert new keys")
	}
	return nil
}
// checkMasterKeyLength returns an internal error unless masterKey is exactly
// 32 bytes long. The length is measured in bytes, not runes.
func checkMasterKeyLength(masterKey string) error {
	const requiredBytes = 32
	length := len(masterKey)
	if length == requiredBytes {
		return nil
	}
	return caos_errs.ThrowInternalf(nil, "", "masterkey must be 32 bytes, but is %d", length)
}