| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | package database | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2024-01-04 17:12:20 +01:00
										 |  |  | 	"context" | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	"database/sql" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	sq "github.com/Masterminds/squirrel" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 01:01:45 +02:00
										 |  |  | 	"github.com/zitadel/zitadel/internal/crypto" | 
					
						
							| 
									
										
										
										
											2023-08-22 12:49:22 +02:00
										 |  |  | 	z_db "github.com/zitadel/zitadel/internal/database" | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 	"github.com/zitadel/zitadel/internal/zerrors" | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-25 17:28:20 +01:00
										 |  |  | type Database struct { | 
					
						
							| 
									
										
										
										
											2023-08-22 12:49:22 +02:00
										 |  |  | 	client    *z_db.DB | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	masterKey string | 
					
						
							|  |  |  | 	encrypt   func(key, masterKey string) (encryptedKey string, err error) | 
					
						
							|  |  |  | 	decrypt   func(encryptedKey, masterKey string) (key string, err error) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | const ( | 
					
						
							|  |  |  | 	EncryptionKeysTable  = "system.encryption_keys" | 
					
						
							|  |  |  | 	encryptionKeysIDCol  = "id" | 
					
						
							|  |  |  | 	encryptionKeysKeyCol = "key" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-25 17:28:20 +01:00
										 |  |  | func NewKeyStorage(client *z_db.DB, masterKey string) (*Database, error) { | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	if err := checkMasterKeyLength(masterKey); err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-01-25 17:28:20 +01:00
										 |  |  | 	return &Database{ | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 		client:    client, | 
					
						
							|  |  |  | 		masterKey: masterKey, | 
					
						
							|  |  |  | 		encrypt:   crypto.EncryptAESString, | 
					
						
							|  |  |  | 		decrypt:   crypto.DecryptAESString, | 
					
						
							|  |  |  | 	}, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-25 17:28:20 +01:00
										 |  |  | func (d *Database) ReadKeys() (crypto.Keys, error) { | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	keys := make(map[string]string) | 
					
						
							|  |  |  | 	stmt, args, err := sq.Select(encryptionKeysIDCol, encryptionKeysKeyCol). | 
					
						
							|  |  |  | 		From(EncryptionKeysTable). | 
					
						
							|  |  |  | 		ToSql() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 		return nil, zerrors.ThrowInternal(err, "", "unable to read keys") | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-08-22 12:49:22 +02:00
										 |  |  | 	err = d.client.Query(func(rows *sql.Rows) error { | 
					
						
							|  |  |  | 		for rows.Next() { | 
					
						
							|  |  |  | 			var id, encryptionKey string | 
					
						
							|  |  |  | 			err = rows.Scan(&id, &encryptionKey) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 				return zerrors.ThrowInternal(err, "", "unable to read keys") | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 			} | 
					
						
							| 
									
										
										
										
											2023-08-22 12:49:22 +02:00
										 |  |  | 			key, err := d.decrypt(encryptionKey, d.masterKey) | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 				return zerrors.ThrowInternal(err, "", "unable to decrypt key") | 
					
						
							| 
									
										
										
										
											2023-08-22 12:49:22 +02:00
										 |  |  | 			} | 
					
						
							|  |  |  | 			keys[id] = key | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2023-08-22 12:49:22 +02:00
										 |  |  | 		return nil | 
					
						
							|  |  |  | 	}, stmt, args...) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 		return nil, zerrors.ThrowInternal(err, "", "unable to read keys") | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-08-22 12:49:22 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	return keys, nil | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-25 17:28:20 +01:00
										 |  |  | func (d *Database) ReadKey(id string) (_ *crypto.Key, err error) { | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	stmt, args, err := sq.Select(encryptionKeysKeyCol). | 
					
						
							|  |  |  | 		From(EncryptionKeysTable). | 
					
						
							|  |  |  | 		Where(sq.Eq{encryptionKeysIDCol: id}). | 
					
						
							|  |  |  | 		PlaceholderFormat(sq.Dollar). | 
					
						
							|  |  |  | 		ToSql() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 		return nil, zerrors.ThrowInternal(err, "", "unable to read key") | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-08-22 12:49:22 +02:00
										 |  |  | 	var key string | 
					
						
							|  |  |  | 	err = d.client.QueryRow(func(row *sql.Row) error { | 
					
						
							|  |  |  | 		var encryptionKey string | 
					
						
							|  |  |  | 		err = row.Scan(&encryptionKey) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 			return zerrors.ThrowInternal(err, "", "unable to read key") | 
					
						
							| 
									
										
										
										
											2023-08-22 12:49:22 +02:00
										 |  |  | 		} | 
					
						
							|  |  |  | 		key, err = d.decrypt(encryptionKey, d.masterKey) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 			return zerrors.ThrowInternal(err, "", "unable to decrypt key") | 
					
						
							| 
									
										
										
										
											2023-08-22 12:49:22 +02:00
										 |  |  | 		} | 
					
						
							|  |  |  | 		return nil | 
					
						
							|  |  |  | 	}, stmt, args...) | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 		return nil, zerrors.ThrowInternal(err, "", "unable to read key") | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2023-08-22 12:49:22 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	return &crypto.Key{ | 
					
						
							|  |  |  | 		ID:    id, | 
					
						
							|  |  |  | 		Value: key, | 
					
						
							|  |  |  | 	}, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-25 17:28:20 +01:00
										 |  |  | func (d *Database) CreateKeys(ctx context.Context, keys ...*crypto.Key) error { | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	insert := sq.Insert(EncryptionKeysTable). | 
					
						
							|  |  |  | 		Columns(encryptionKeysIDCol, encryptionKeysKeyCol).PlaceholderFormat(sq.Dollar) | 
					
						
							|  |  |  | 	for _, key := range keys { | 
					
						
							|  |  |  | 		encryptionKey, err := d.encrypt(key.Value, d.masterKey) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 			return zerrors.ThrowInternal(err, "", "unable to encrypt key") | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 		} | 
					
						
							|  |  |  | 		insert = insert.Values(key.ID, encryptionKey) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	stmt, args, err := insert.ToSql() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 		return zerrors.ThrowInternal(err, "", "unable to insert new keys") | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2024-01-04 17:12:20 +01:00
										 |  |  | 	tx, err := d.client.BeginTx(ctx, nil) | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 		return zerrors.ThrowInternal(err, "", "unable to insert new keys") | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	_, err = tx.Exec(stmt, args...) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		tx.Rollback() | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 		return zerrors.ThrowInternal(err, "", "unable to insert new keys") | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	err = tx.Commit() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 		return zerrors.ThrowInternal(err, "", "unable to insert new keys") | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func checkMasterKeyLength(masterKey string) error { | 
					
						
							|  |  |  | 	if length := len([]byte(masterKey)); length != 32 { | 
					
						
							| 
									
										
										
										
											2023-12-08 16:30:55 +02:00
										 |  |  | 		return zerrors.ThrowInternalf(nil, "", "masterkey must be 32 bytes, but is %d", length) | 
					
						
							| 
									
										
										
										
											2022-03-14 07:55:09 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	return nil | 
					
						
							|  |  |  | } |