package database

import (
	"context"
	"database/sql"

	sq "github.com/Masterminds/squirrel"

	"github.com/zitadel/zitadel/internal/crypto"
	z_db "github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/zerrors"
)

// Database is a key storage backed by the EncryptionKeysTable table.
// Key values are encrypted with masterKey before they are written and are
// decrypted again when they are read back.
type Database struct {
	// client runs the SQL queries and transactions of this storage.
	client *z_db.DB
	// masterKey is the secret passed to encrypt and decrypt; NewKeyStorage
	// only accepts a value that is exactly 32 bytes long.
	masterKey string
	// encrypt turns a plaintext key into the value persisted in the table.
	encrypt func(plainKey, masterKey string) (encryptedKey string, err error)
	// decrypt turns a persisted value back into the plaintext key.
	decrypt func(encryptedKey, masterKey string) (plainKey string, err error)
}

// EncryptionKeysTable is the name of the table that stores the encrypted keys.
const EncryptionKeysTable = "system.encryption_keys"

// Column names of EncryptionKeysTable.
const (
	// encryptionKeysIDCol holds the identifier of a key.
	encryptionKeysIDCol = "id"
	// encryptionKeysKeyCol holds the encrypted key value.
	encryptionKeysKeyCol = "key"
)

// NewKeyStorage returns a Database key storage that talks to the database
// through client. Keys are encrypted with crypto.EncryptAESString and
// decrypted with crypto.DecryptAESString, both keyed by masterKey.
//
// It returns the error from checkMasterKeyLength if masterKey is not
// exactly 32 bytes long.
func NewKeyStorage(client *z_db.DB, masterKey string) (*Database, error) {
	err := checkMasterKeyLength(masterKey)
	if err != nil {
		return nil, err
	}
	storage := new(Database)
	storage.client = client
	storage.masterKey = masterKey
	storage.encrypt = crypto.EncryptAESString
	storage.decrypt = crypto.DecryptAESString
	return storage, nil
}

// ReadKeys loads every row of EncryptionKeysTable and returns the keys,
// decrypted with the master key, indexed by their ID.
//
// Every failure aborts the read and is returned as an internal error. That
// covers building the query, scanning a row, decrypting a value and iterating
// the result set.
func (d *Database) ReadKeys() (crypto.Keys, error) {
	keys := make(map[string]string)
	stmt, args, err := sq.Select(encryptionKeysIDCol, encryptionKeysKeyCol).
		From(EncryptionKeysTable).
		ToSql()
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "", "unable to read keys")
	}
	err = d.client.Query(func(rows *sql.Rows) error {
		for rows.Next() {
			var id, encryptionKey string
			if err := rows.Scan(&id, &encryptionKey); err != nil {
				return zerrors.ThrowInternal(err, "", "unable to read keys")
			}
			key, err := d.decrypt(encryptionKey, d.masterKey)
			if err != nil {
				return zerrors.ThrowInternal(err, "", "unable to decrypt key")
			}
			keys[id] = key
		}
		// rows.Next returns false both when the result set is exhausted and
		// when iteration failed. Only rows.Err tells the two apart, so without
		// this check a partial key set could be returned as success.
		if err := rows.Err(); err != nil {
			return zerrors.ThrowInternal(err, "", "unable to read keys")
		}
		return nil
	}, stmt, args...)
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "", "unable to read keys")
	}

	return keys, nil
}

// ReadKey fetches the key stored under id from EncryptionKeysTable and returns
// it decrypted with the master key.
//
// Every failure is reported as an internal error: building the query,
// scanning the row or decrypting the stored value. This includes the
// sql.ErrNoRows that row.Scan yields when no row matches id.
func (d *Database) ReadKey(id string) (_ *crypto.Key, err error) {
	query := sq.Select(encryptionKeysKeyCol).
		From(EncryptionKeysTable).
		Where(sq.Eq{encryptionKeysIDCol: id}).
		PlaceholderFormat(sq.Dollar)
	stmt, args, err := query.ToSql()
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "", "unable to read key")
	}

	var plainKey string
	scanRow := func(row *sql.Row) error {
		var stored string
		if scanErr := row.Scan(&stored); scanErr != nil {
			return zerrors.ThrowInternal(scanErr, "", "unable to read key")
		}
		decrypted, decryptErr := d.decrypt(stored, d.masterKey)
		if decryptErr != nil {
			return zerrors.ThrowInternal(decryptErr, "", "unable to decrypt key")
		}
		plainKey = decrypted
		return nil
	}
	if err = d.client.QueryRow(scanRow, stmt, args...); err != nil {
		return nil, zerrors.ThrowInternal(err, "", "unable to read key")
	}

	return &crypto.Key{ID: id, Value: plainKey}, nil
}

// CreateKeys encrypts the Value of every key with the master key. It then
// inserts all of them into EncryptionKeysTable with one multi-row INSERT,
// executed inside a transaction.
//
// ctx governs both the transaction and the INSERT statement itself. On any
// failure nothing is committed and an internal error is returned.
//
// NOTE(review): with zero keys the INSERT has no VALUES rows, and squirrel's
// ToSql is expected to fail. The call then returns an error instead of being
// a no-op; confirm callers rely on that.
func (d *Database) CreateKeys(ctx context.Context, keys ...*crypto.Key) error {
	insert := sq.Insert(EncryptionKeysTable).
		Columns(encryptionKeysIDCol, encryptionKeysKeyCol).
		PlaceholderFormat(sq.Dollar)
	for _, key := range keys {
		encryptionKey, err := d.encrypt(key.Value, d.masterKey)
		if err != nil {
			return zerrors.ThrowInternal(err, "", "unable to encrypt key")
		}
		insert = insert.Values(key.ID, encryptionKey)
	}
	stmt, args, err := insert.ToSql()
	if err != nil {
		return zerrors.ThrowInternal(err, "", "unable to insert new keys")
	}
	tx, err := d.client.BeginTx(ctx, nil)
	if err != nil {
		return zerrors.ThrowInternal(err, "", "unable to insert new keys")
	}
	// Use ExecContext so that cancelling ctx also aborts the running statement.
	// Plain Exec would run it with a background context.
	if _, err = tx.ExecContext(ctx, stmt, args...); err != nil {
		// The insert error is what the caller needs to see; a rollback
		// failure at this point would only obscure it.
		_ = tx.Rollback()
		return zerrors.ThrowInternal(err, "", "unable to insert new keys")
	}
	if err = tx.Commit(); err != nil {
		return zerrors.ThrowInternal(err, "", "unable to insert new keys")
	}
	return nil
}

// checkMasterKeyLength returns an internal error unless masterKey is exactly
// 32 bytes long. The length is measured in bytes, not runes, so a multi-byte
// UTF-8 character counts once per byte.
func checkMasterKeyLength(masterKey string) error {
	const requiredLength = 32
	// len on a string already yields its byte length; converting to []byte
	// first would only add a needless copy.
	if length := len(masterKey); length != requiredLength {
		return zerrors.ThrowInternalf(nil, "", "masterkey must be 32 bytes, but is %d", length)
	}
	return nil
}