package crypto

import (
	"crypto/rand"
	"crypto/subtle"
	"fmt"
	"math/big"
	"time"

	"github.com/caos/zitadel/internal/config/types"
	"github.com/caos/zitadel/internal/errors"
)

// Character classes a generator's rune set is assembled from. newGenerator
// appends the enabled classes in this order: lower-case letters, upper-case
// letters, digits, symbols.
var (
	lowerLetters = []rune("abcdefghijklmnopqrstuvwxyz")
	upperLetters = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
	digits       = []rune("0123456789")
	symbols      = []rune("~!@#$^&*()_+`-={}|[]:<>?,./")
)

// GeneratorConfig holds the settings newGenerator reads to build a generator:
// the code length, the expiry, and which character classes make up the rune
// set codes are drawn from.
type GeneratorConfig struct {
	Length              uint           // number of characters per generated code
	Expiry              types.Duration // its Duration becomes the generator's expiry
	IncludeLowerLetters bool           // add a-z to the rune set
	IncludeUpperLetters bool           // add A-Z to the rune set
	IncludeDigits       bool           // add 0-9 to the rune set
	IncludeSymbols      bool           // add the characters in symbols to the rune set
}

// Generator describes how codes are produced: how long they are, which runes
// they are drawn from, their expiry, and the algorithm applied to them.
type Generator interface {
	// Length returns the number of characters NewCode requests per code.
	Length() uint
	// Expiry returns the configured expiry duration.
	Expiry() time.Duration
	// Alg returns the algorithm NewCode hands to Crypt. VerifyCode inspects
	// its dynamic type to choose between encrypted and hashed verification.
	Alg() Crypto
	// Runes returns the characters NewCode draws codes from.
	Runes() []rune
}

// generator carries the algorithm-independent settings shared by
// encryptionGenerator and hashGenerator, which embed it.
type generator struct {
	length uint          // number of characters per code
	expiry time.Duration // configured expiry duration
	runes  []rune        // characters codes are drawn from
}

// Length returns the configured code length.
func (g *generator) Length() uint {
	return g.length
}

// Expiry returns the configured expiry duration.
func (g *generator) Expiry() time.Duration {
	return g.expiry
}

// Runes returns the generator's rune set. The slice is returned as is, not
// copied, so callers share its backing array.
func (g *generator) Runes() []rune {
	return g.runes
}

// encryptionGenerator combines the shared generator settings with an
// EncryptionAlgorithm.
type encryptionGenerator struct {
	generator
	alg EncryptionAlgorithm
}

// Alg returns the generator's encryption algorithm as a Crypto.
func (g *encryptionGenerator) Alg() Crypto {
	return g.alg
}

// NewEncryptionGenerator returns a Generator whose length, expiry and rune set
// come from config and whose Alg is the given encryption algorithm.
func NewEncryptionGenerator(config GeneratorConfig, algorithm EncryptionAlgorithm) Generator {
	return &encryptionGenerator{
		generator: newGenerator(config),
		alg:       algorithm,
	}
}

// hashGenerator combines the shared generator settings with a HashAlgorithm.
type hashGenerator struct {
	generator
	alg HashAlgorithm
}

// Alg returns the generator's hash algorithm as a Crypto.
func (g *hashGenerator) Alg() Crypto {
	return g.alg
}

// NewHashGenerator returns a Generator whose length, expiry and rune set come
// from config and whose Alg is the given hash algorithm.
func NewHashGenerator(config GeneratorConfig, algorithm HashAlgorithm) Generator {
	return &hashGenerator{
		generator: newGenerator(config),
		alg:       algorithm,
	}
}

// newGenerator builds the shared generator settings from config. The rune set
// is the concatenation of every enabled character class, in the fixed order
// lower-case letters, upper-case letters, digits, symbols; it stays nil when
// no class is enabled.
func newGenerator(config GeneratorConfig) generator {
	classes := []struct {
		enabled bool
		chars   []rune
	}{
		{config.IncludeLowerLetters, lowerLetters},
		{config.IncludeUpperLetters, upperLetters},
		{config.IncludeDigits, digits},
		{config.IncludeSymbols, symbols},
	}

	var alphabet []rune
	for _, class := range classes {
		if class.enabled {
			alphabet = append(alphabet, class.chars...)
		}
	}

	return generator{
		length: config.Length,
		expiry: config.Expiry.Duration,
		runes:  alphabet,
	}
}

// NewCode generates a random code of g.Length() characters drawn from
// g.Runes() and passes it through Crypt with g.Alg().
//
// It returns the value produced by Crypt together with the plain code. If
// either step fails, the error is returned with a nil value and an empty code.
func NewCode(g Generator) (*CryptoValue, string, error) {
	plain, err := generateRandomString(g.Length(), g.Runes())
	if err != nil {
		return nil, "", err
	}

	value, err := Crypt([]byte(plain), g.Alg())
	if err != nil {
		return nil, "", err
	}

	return value, plain, nil
}

// IsCodeExpired reports whether a code created at creationDate has outlived
// expiry, i.e. whether creationDate+expiry lies strictly before the current
// UTC time. An expiry of zero means the code never expires.
func IsCodeExpired(creationDate time.Time, expiry time.Duration) bool {
	if expiry == 0 {
		return false
	}
	deadline := creationDate.Add(expiry)
	return time.Now().UTC().After(deadline)
}

// VerifyCode checks verificationCode against the stored cryptoCode.
//
// The code is first checked for expiry using the creationDate and expiry
// arguments (g.Expiry() is not consulted); an expired code yields a
// precondition-failed error. The check is then delegated according to the
// dynamic type of g.Alg(): an EncryptionAlgorithm is tried first, then a
// HashAlgorithm. Any other algorithm yields an invalid-argument error.
func VerifyCode(creationDate time.Time, expiry time.Duration, cryptoCode *CryptoValue, verificationCode string, g Generator) error {
	if IsCodeExpired(creationDate, expiry) {
		return errors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired")
	}

	alg := g.Alg()
	if encryption, ok := alg.(EncryptionAlgorithm); ok {
		return verifyEncryptedCode(cryptoCode, verificationCode, encryption)
	}
	if hash, ok := alg.(HashAlgorithm); ok {
		return verifyHashedCode(cryptoCode, verificationCode, hash)
	}
	return errors.ThrowInvalidArgument(nil, "CODE-fW2gNa", "Errors.User.Code.GeneratorAlgNotSupported")
}

// generateRandomString returns a string of length runes, each drawn
// independently and uniformly from chars using crypto/rand.
//
// A length of zero yields the empty string. A non-zero length with an empty
// chars is reported as an error, since no character can be drawn. Errors from
// the random source are returned unchanged.
func generateRandomString(length uint, chars []rune) (string, error) {
	if length == 0 {
		return "", nil
	}
	if len(chars) == 0 {
		return "", fmt.Errorf("cannot generate %d random characters from an empty character set", length)
	}

	// rand.Int draws uniformly from [0, len(chars)), so every character,
	// including the last one, is equally likely. Reducing a random byte modulo
	// the set size would skew the distribution whenever the size does not
	// divide 256.
	bound := big.NewInt(int64(len(chars)))
	str := make([]rune, length)
	for i := range str {
		idx, err := rand.Int(rand.Reader, bound)
		if err != nil {
			return "", err
		}
		str[i] = chars[idx.Int64()]
	}
	return string(str), nil
}

// verifyEncryptedCode decrypts cryptoCode with alg and checks that the result
// equals verificationCode.
//
// It returns an invalid-argument error when cryptoCode is nil or when the
// codes differ, and passes through any error from DecryptString.
func verifyEncryptedCode(cryptoCode *CryptoValue, verificationCode string, alg EncryptionAlgorithm) error {
	if cryptoCode == nil {
		return errors.ThrowInvalidArgument(nil, "CRYPT-aqrFV", "Errors.User.Code.CryptoCodeNil")
	}
	code, err := DecryptString(cryptoCode, alg)
	if err != nil {
		return err
	}

	// Compare in constant time: verificationCode is caller-supplied, and an
	// early-exit string comparison would leak, through timing, how many
	// leading characters of the secret code were guessed correctly.
	if subtle.ConstantTimeCompare([]byte(code), []byte(verificationCode)) != 1 {
		return errors.ThrowInvalidArgument(nil, "CODE-woT0xc", "Errors.User.Code.Invalid")
	}
	return nil
}

// verifyHashedCode checks verificationCode against the hashed cryptoCode by
// delegating to CompareHash with alg, and returns its result. A nil cryptoCode
// is rejected with an invalid-argument error before any comparison.
//
// NOTE(review): the nil-check message here is plain English, while
// verifyEncryptedCode uses the key "Errors.User.Code.CryptoCodeNil" for the
// same condition. Confirm whether the two should match.
func verifyHashedCode(cryptoCode *CryptoValue, verificationCode string, alg HashAlgorithm) error {
	if cryptoCode == nil {
		return errors.ThrowInvalidArgument(nil, "CRYPT-2q3r", "cryptoCode must not be nil")
	}

	candidate := []byte(verificationCode)
	return CompareHash(cryptoCode, candidate, alg)
}