package database

import (
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"reflect"
	"strings"
	"time"

	"github.com/jackc/pgx/v5/pgtype"
)

// TextArray is a slice of string-like values that can be read from and
// written to a database text array column.
type TextArray[T ~string] pgtype.FlatArray[T]

// Scan implements the [database/sql.Scanner] interface.
//
// src is decoded into a plain []string by pgx and every element is then
// converted to T. On success the previous content of s is replaced; on error
// s is left unchanged.
func (s *TextArray[T]) Scan(src any) error {
	var raw []string
	if err := pgtype.NewMap().SQLScanner(&raw).Scan(src); err != nil {
		return err
	}

	converted := make(TextArray[T], 0, len(raw))
	for _, item := range raw {
		converted = append(converted, T(item))
	}
	*s = converted

	return nil
}

// Value implements the [database/sql/driver.Valuer] interface.
//
// An empty (or nil) array is stored as NULL. Otherwise the elements are
// rendered as a PostgreSQL array literal such as {a,b,"c d "}. Elements that
// would be ambiguous inside the literal are double-quoted and escaped, all
// other elements are written verbatim.
func (s TextArray[T]) Value() (driver.Value, error) {
	if len(s) == 0 {
		return nil, nil
	}

	typed := make([]string, len(s))
	for i, value := range s {
		typed[i] = quoteTextArrayElement(string(value))
	}

	return []byte("{" + strings.Join(typed, ",") + "}"), nil
}

// quoteTextArrayElement returns element in a form that is safe to embed in a
// PostgreSQL array literal. The element is wrapped in double quotes (with
// backslashes and double quotes escaped) if it is empty, spells NULL, starts
// or ends with whitespace, or contains one of the characters { } , " \.
// Elements that need no quoting are returned unchanged.
func quoteTextArrayElement(element string) string {
	// Whitespace characters PostgreSQL skips around unquoted array elements.
	const arraySpace = " \t\n\r\v\f"

	needsQuoting := element == "" ||
		strings.EqualFold(element, "NULL") ||
		strings.ContainsAny(element, `{},"\`) ||
		strings.Trim(element, arraySpace) != element
	if !needsQuoting {
		return element
	}

	return `"` + strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(element) + `"`
}

// ByteArray is a slice of byte-like values that can be read from and written
// to the database as raw bytes.
type ByteArray[T ~byte] pgtype.FlatArray[T]

// Scan implements the [database/sql.Scanner] interface.
//
// A []byte src is used directly; any other src is handed to pgx for decoding.
// Every byte is then converted to T and the result replaces the content of s.
func (s *ByteArray[T]) Scan(src any) error {
	raw, isBytes := src.([]byte)
	if !isBytes {
		// tests use a different src type
		if err := pgtype.NewMap().SQLScanner(&raw).Scan(src); err != nil {
			return err
		}
	}

	converted := make(ByteArray[T], len(raw))
	for i := range raw {
		converted[i] = T(raw[i])
	}
	*s = converted

	return nil
}

// Value implements the [database/sql/driver.Valuer] interface.
//
// An empty (or nil) array is stored as NULL; otherwise a fresh []byte holding
// the converted elements is returned.
func (s ByteArray[T]) Value() (driver.Value, error) {
	if len(s) == 0 {
		return nil, nil
	}

	raw := make([]byte, 0, len(s))
	for _, item := range s {
		raw = append(raw, byte(item))
	}

	return raw, nil
}

// numberField constrains the element type of [NumberArray]: any type whose
// underlying type is one of the listed signed or unsigned integer types.
type numberField interface {
	~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~int | ~uint
}

// numberTypeField lists the exact predeclared integer types (no derived
// types). castedScan uses it for the intermediate slice that is scanned
// before the values are converted to the final element type.
type numberTypeField interface {
	int8 | uint8 | int16 | uint16 | int32 | uint32 | int64 | uint64 | int | uint
}

// Compile-time check that *NumberArray implements [database/sql.Scanner].
var _ sql.Scanner = (*NumberArray[int8])(nil)

// NumberArray is a slice of integer-like values that can be scanned from the
// database.
type NumberArray[F numberField] pgtype.FlatArray[F]

// Scan implements the [database/sql.Scanner] interface.
//
// The kind of the element type F is determined via reflection and used to
// pick a predeclared integer type for the intermediate scan target (see
// castedScan). After a successful scan the values are converted to F and
// stored in a; on error a is left unchanged.
func (a *NumberArray[F]) Scan(src any) (err error) {
	var (
		scanner sql.Scanner
		convert func()
	)

	//nolint: exhaustive
	// only defined types
	switch reflect.TypeOf(*a).Elem().Kind() {
	case reflect.Int8:
		convert, scanner = castedScan[int8](a)
	case reflect.Uint8, reflect.Int16:
		// uint8 is scanned through int16 as a workaround: pgx assumes a byte
		// array is wanted if it is handed a []uint8.
		convert, scanner = castedScan[int16](a)
	case reflect.Uint16:
		convert, scanner = castedScan[uint16](a)
	case reflect.Int32:
		convert, scanner = castedScan[int32](a)
	case reflect.Uint32:
		convert, scanner = castedScan[uint32](a)
	case reflect.Int64:
		convert, scanner = castedScan[int64](a)
	case reflect.Uint64:
		convert, scanner = castedScan[uint64](a)
	case reflect.Int:
		convert, scanner = castedScan[int](a)
	case reflect.Uint:
		convert, scanner = castedScan[uint](a)
	}

	err = scanner.Scan(src)
	if err == nil {
		// Only copy the scanned values into a once scanning succeeded.
		convert()
	}

	return err
}

// castedScan prepares scanning into a via an intermediate slice of T.
//
// The returned scanner decodes a source value into that intermediate slice.
// The returned mapper must be called after a successful scan; it replaces the
// content of a with the scanned values converted to F.
func castedScan[T numberTypeField, F numberField](a *NumberArray[F]) (mapper func(), scanner sql.Scanner) {
	// scanned is shared between the scanner (writer) and the mapper (reader).
	scanned := new([]T)

	scanner = pgtype.NewMap().SQLScanner(scanned)
	mapper = func() {
		converted := make(NumberArray[F], len(*scanned))
		for i := range *scanned {
			converted[i] = F((*scanned)[i])
		}
		*a = converted
	}

	return mapper, scanner
}

// Map is a string-keyed map that is stored in the database as a JSON object.
type Map[V any] map[string]V

// Scan implements the [database/sql.Scanner] interface.
//
// src may be nil, a []byte or a string holding a JSON object. A nil src or an
// empty payload leaves m untouched. Any other src type is reported as an
// error instead of panicking, and malformed JSON yields the error returned by
// [encoding/json.Unmarshal].
func (m *Map[V]) Scan(src any) error {
	var raw []byte

	switch value := src.(type) {
	case nil:
		return nil
	case []byte:
		raw = value
	case string:
		// Some drivers hand over text/JSON columns as string.
		raw = []byte(value)
	default:
		return fmt.Errorf("database: cannot scan %T into Map", src)
	}

	if len(raw) == 0 {
		return nil
	}

	return json.Unmarshal(raw, m)
}

// Value implements the [database/sql/driver.Valuer] interface.
//
// An empty (or nil) map is stored as NULL; otherwise the map is encoded as
// JSON.
func (m Map[V]) Value() (driver.Value, error) {
	if len(m) == 0 {
		return nil, nil
	}
	return json.Marshal(m)
}

// Duration is a [time.Duration] that can be scanned from the database.
type Duration time.Duration

// Scan implements the [database/sql.Scanner] interface.
//
// Supported sources are time.Duration and pgtype.Interval (each by value or
// by pointer) and int64, which is converted to a Duration directly. Every
// other source is first scanned into a pgtype.Interval, whose error is
// returned if that fails.
func (d *Duration) Scan(src any) error {
	switch value := src.(type) {
	case int64:
		*d = Duration(value)
	case time.Duration:
		*d = Duration(value)
	case *time.Duration:
		*d = Duration(*value)
	case pgtype.Interval:
		*d = intervalToDuration(&value)
	case *pgtype.Interval:
		*d = intervalToDuration(value)
	default:
		// Let pgx decode whatever representation the driver handed over.
		var parsed pgtype.Interval
		if err := parsed.Scan(src); err != nil {
			return err
		}
		*d = intervalToDuration(&parsed)
	}

	return nil
}

// intervalToDuration flattens a pgtype.Interval into a single Duration.
// A day is counted as 24 hours and a month as 30 days.
func intervalToDuration(interval *pgtype.Interval) Duration {
	const (
		day   = 24 * time.Hour
		month = 30 * day
	)

	micros := time.Duration(interval.Microseconds * 1000)
	days := time.Duration(interval.Days) * day
	months := time.Duration(interval.Months) * month

	return Duration(micros + days + months)
}

// NullDuration represents a duration that may be NULL in the database.
// Valid is false if the scanned value was NULL, in which case Duration is 0.
// This behavior is similar to [database/sql.NullString].
type NullDuration struct {
	Valid    bool
	Duration time.Duration
}

// Scan implements the [database/sql.Scanner] interface.
//
// A nil src resets d to the NULL state. Any other src is scanned as a
// [Duration]; if that fails the error is returned and d is left unchanged.
func (d *NullDuration) Scan(src any) error {
	if src == nil {
		*d = NullDuration{}
		return nil
	}

	var scanned Duration
	if err := scanned.Scan(src); err != nil {
		return err
	}

	*d = NullDuration{Valid: true, Duration: time.Duration(scanned)}
	return nil
}