package crypto

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/base64"
	"io"

	"github.com/zitadel/zitadel/internal/zerrors"
)

// Compile-time assertion that *AESCrypto implements EncryptionAlgorithm.
var _ EncryptionAlgorithm = (*AESCrypto)(nil)

// AESCrypto is an EncryptionAlgorithm backed by EncryptAES / DecryptAES.
//
// It holds a set of keys addressed by ID:
//   - keys maps a key ID to the raw key material handed to aes.NewCipher;
//   - encryptionKeyID names the entry in keys that Encrypt always uses;
//   - keyIDs is the ID list returned by LoadKeys alongside keys and is what
//     DecryptionKeyIDs reports (presumably the same IDs as the keys of the
//     map — TODO confirm against LoadKeys).
//
// Decryption may use any key present in keys, which allows old ciphertexts to
// stay readable after the encryption key is rotated.
type AESCrypto struct {
	keys            map[string]string
	encryptionKeyID string
	keyIDs          []string
}

// NewAESCrypto builds an AESCrypto from the keys that LoadKeys resolves for
// config from keyStorage. New ciphertexts are produced with the key named by
// config.EncryptionKeyID; any loaded key can be used to decrypt. Errors from
// LoadKeys are returned unchanged.
//
// NOTE(review): this does not check that config.EncryptionKeyID is among the
// loaded keys. If it is missing, Encrypt fails later with an AES key-size
// error rather than failing here at construction time.
func NewAESCrypto(config *KeyConfig, keyStorage KeyStorage) (*AESCrypto, error) {
	loadedKeys, loadedIDs, err := LoadKeys(config, keyStorage)
	if err != nil {
		return nil, err
	}
	c := &AESCrypto{keys: loadedKeys, keyIDs: loadedIDs}
	c.encryptionKeyID = config.EncryptionKeyID
	return c, nil
}

// Algorithm reports the identifier of this encryption algorithm, "aes".
func (a *AESCrypto) Algorithm() string {
	const name = "aes"
	return name
}

// Encrypt encrypts value with EncryptAES, using the key that encryptionKeyID
// selects. If that ID has no entry in the key map, the key is empty and
// aes.NewCipher's key-size error is returned.
func (a *AESCrypto) Encrypt(value []byte) ([]byte, error) {
	key := a.encryptionKey()
	return EncryptAES(value, key)
}

// Decrypt decrypts value, which is the IV followed by the ciphertext as
// produced by Encrypt, using the key registered under keyID. An unknown
// keyID yields a NotFound error. Otherwise the result of DecryptAES is
// returned as is.
func (a *AESCrypto) Decrypt(value []byte, keyID string) ([]byte, error) {
	key, err := a.decryptionKey(keyID)
	if err == nil {
		return DecryptAES(value, key)
	}
	return nil, err
}

// DecryptString behaves like Decrypt but returns the plaintext as a string.
// On any error (unknown keyID or decryption failure) it returns "".
func (a *AESCrypto) DecryptString(value []byte, keyID string) (string, error) {
	plain, err := a.Decrypt(value, keyID)
	if err != nil {
		return "", err
	}
	return string(plain), nil
}

// EncryptionKeyID returns the ID of the key that Encrypt uses.
func (a *AESCrypto) EncryptionKeyID() string {
	id := a.encryptionKeyID
	return id
}

// DecryptionKeyIDs returns the key IDs loaded when the AESCrypto was built.
// The returned slice is the one stored on a, not a copy, so callers must not
// modify it.
func (a *AESCrypto) DecryptionKeyIDs() []string {
	ids := a.keyIDs
	return ids
}

// encryptionKey returns the key material registered under a.encryptionKeyID.
// It returns "" when that ID is absent from a.keys, and aes.NewCipher rejects
// an empty key.
func (a *AESCrypto) encryptionKey() string {
	id := a.encryptionKeyID
	return a.keys[id]
}

// decryptionKey looks up the key material for keyID. It returns a NotFound
// error (CRYPT-nkj1s) when keyID is not in a.keys.
func (a *AESCrypto) decryptionKey(keyID string) (string, error) {
	if key, found := a.keys[keyID]; found {
		return key, nil
	}
	return "", zerrors.ThrowNotFound(nil, "CRYPT-nkj1s", "unknown key id")
}

// EncryptAESString encrypts data with EncryptAES and returns the IV plus
// ciphertext encoded as padded, URL-safe base64 (base64.URLEncoding).
// DecryptAESString is its inverse.
func EncryptAESString(data string, key string) (string, error) {
	raw := []byte(data)
	sealed, err := EncryptAES(raw, key)
	if err == nil {
		return base64.URLEncoding.EncodeToString(sealed), nil
	}
	return "", err
}

// EncryptAES encrypts plainText with AES in CFB mode under key.
//
// key must be a valid AES key (16, 24 or 32 bytes, per aes.NewCipher).
// Otherwise aes.NewCipher's error is returned. Plaintexts larger than 64 MiB
// are rejected with a precondition-failed error.
//
// The result is a fresh random IV (aes.BlockSize bytes, read from
// crypto/rand) followed by the ciphertext. It is therefore aes.BlockSize bytes
// longer than plainText, which is the layout DecryptAES expects.
//
// NOTE(review): CFB is unauthenticated, so modifications to the ciphertext
// are not detected on decryption. cipher.NewCFBEncrypter is deprecated as of
// Go 1.24 for this reason. The mode is presumably kept for compatibility with
// already stored data.
func EncryptAES(plainText []byte, key string) ([]byte, error) {
	const maxSize = 64 << 20 // 64 MiB

	block, err := aes.NewCipher([]byte(key))
	if err != nil {
		return nil, err
	}
	if len(plainText) > maxSize {
		return nil, zerrors.ThrowPreconditionFailedf(nil, "CRYPT-AGg4t3", "data too large, max bytes: %v", maxSize)
	}

	out := make([]byte, aes.BlockSize+len(plainText))
	iv, body := out[:aes.BlockSize], out[aes.BlockSize:]
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		return nil, err
	}
	cipher.NewCFBEncrypter(block, iv).XORKeyStream(body, plainText)
	return out, nil
}

// DecryptAESString is the inverse of EncryptAESString. It decodes data from
// padded, URL-safe base64 (base64.URLEncoding), decrypts the bytes with
// DecryptAES and returns the plaintext as a string.
//
// Malformed base64 input is reported as an error (the base64 decoding error)
// rather than being treated as an empty plaintext.
func DecryptAESString(data string, key string) (string, error) {
	text, err := base64.URLEncoding.DecodeString(data)
	if err != nil {
		return "", err
	}
	decrypted, err := DecryptAES(text, key)
	if err != nil {
		return "", err
	}
	return string(decrypted), nil
}

// DecryptAES reverses EncryptAES. text must be the aes.BlockSize-byte IV
// followed by the AES-CFB ciphertext, and key must be a valid AES key (16, 24
// or 32 bytes, per aes.NewCipher). It returns the plaintext in a newly
// allocated slice; text is only read, never modified.
//
// Errors: aes.NewCipher's error for an invalid key, or a precondition-failed
// error (CRYPT-23kH1) when text is shorter than one block.
//
// NOTE(review): CFB is unauthenticated, so a tampered ciphertext decrypts to
// altered plaintext without any error. cipher.NewCFBDecrypter is deprecated
// as of Go 1.24 for this reason. It is kept here so that values encrypted by
// EncryptAES remain decryptable.
func DecryptAES(text []byte, key string) ([]byte, error) {
	// The key is validated before the input length, so a bad key is reported
	// first regardless of the input.
	block, err := aes.NewCipher([]byte(key))
	if err != nil {
		return nil, err
	}
	if len(text) < aes.BlockSize {
		return nil, zerrors.ThrowPreconditionFailed(nil, "CRYPT-23kH1", "cipher text block too short")
	}

	iv, cipherText := text[:aes.BlockSize], text[aes.BlockSize:]
	// Decrypting into a separate, exactly-sized buffer avoids copying the
	// whole input (IV included) up front and leaves the caller's slice intact.
	plainText := make([]byte, len(cipherText))
	cipher.NewCFBDecrypter(block, iv).XORKeyStream(plainText, cipherText)
	return plainText, nil
}