2022-08-31 09:52:43 +02:00
|
|
|
package database
|
|
|
|
|
|
|
|
import (
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"reflect"
	"strings"
	"time"

	"github.com/jackc/pgx/v5/pgtype"
)
|
|
|
|
|
2024-03-27 14:48:22 +01:00
|
|
|
type TextArray[T ~string] pgtype.FlatArray[T]
|
2022-08-31 09:52:43 +02:00
|
|
|
|
2023-05-05 17:34:53 +02:00
|
|
|
// Scan implements the [database/sql.Scanner] interface.
|
2024-03-27 14:48:22 +01:00
|
|
|
func (s *TextArray[T]) Scan(src any) error {
|
|
|
|
var typedArray []string
|
|
|
|
err := pgtype.NewMap().SQLScanner(&typedArray).Scan(src)
|
|
|
|
if err != nil {
|
2022-08-31 09:52:43 +02:00
|
|
|
return err
|
|
|
|
}
|
2024-03-27 14:48:22 +01:00
|
|
|
|
|
|
|
(*s) = make(TextArray[T], len(typedArray))
|
|
|
|
for i, value := range typedArray {
|
|
|
|
(*s)[i] = T(value)
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
2022-08-31 09:52:43 +02:00
|
|
|
}
|
|
|
|
|
2023-05-05 17:34:53 +02:00
|
|
|
// Value implements the [database/sql/driver.Valuer] interface.
|
2024-03-27 14:48:22 +01:00
|
|
|
func (s TextArray[T]) Value() (driver.Value, error) {
|
2022-08-31 09:52:43 +02:00
|
|
|
if len(s) == 0 {
|
|
|
|
return nil, nil
|
|
|
|
}
|
|
|
|
|
2024-03-27 14:48:22 +01:00
|
|
|
typed := make([]string, len(s))
|
2022-08-31 09:52:43 +02:00
|
|
|
|
2024-03-27 14:48:22 +01:00
|
|
|
for i, value := range s {
|
|
|
|
typed[i] = string(value)
|
|
|
|
}
|
2022-08-31 09:52:43 +02:00
|
|
|
|
2024-03-27 14:48:22 +01:00
|
|
|
return []byte("{" + strings.Join(typed, ",") + "}"), nil
|
2022-08-31 09:52:43 +02:00
|
|
|
}
|
|
|
|
|
2024-03-27 14:48:22 +01:00
|
|
|
type ByteArray[T ~byte] pgtype.FlatArray[T]
|
2022-08-31 09:52:43 +02:00
|
|
|
|
2023-05-05 17:34:53 +02:00
|
|
|
// Scan implements the [database/sql.Scanner] interface.
|
2024-03-27 14:48:22 +01:00
|
|
|
func (s *ByteArray[T]) Scan(src any) error {
|
|
|
|
var typedArray []byte
|
|
|
|
typedArray, ok := src.([]byte)
|
|
|
|
if !ok {
|
|
|
|
// tests use a different src type
|
|
|
|
err := pgtype.NewMap().SQLScanner(&typedArray).Scan(src)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2022-08-31 09:52:43 +02:00
|
|
|
}
|
2024-03-27 14:48:22 +01:00
|
|
|
|
|
|
|
(*s) = make(ByteArray[T], len(typedArray))
|
|
|
|
for i, value := range typedArray {
|
|
|
|
(*s)[i] = T(value)
|
2022-08-31 09:52:43 +02:00
|
|
|
}
|
2024-03-27 14:48:22 +01:00
|
|
|
|
2022-08-31 09:52:43 +02:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2023-05-05 17:34:53 +02:00
|
|
|
// Value implements the [database/sql/driver.Valuer] interface.
|
2024-03-27 14:48:22 +01:00
|
|
|
func (s ByteArray[T]) Value() (driver.Value, error) {
|
|
|
|
if len(s) == 0 {
|
2022-08-31 09:52:43 +02:00
|
|
|
return nil, nil
|
|
|
|
}
|
2024-03-27 14:48:22 +01:00
|
|
|
typed := make([]byte, len(s))
|
|
|
|
|
|
|
|
for i, value := range s {
|
|
|
|
typed[i] = byte(value)
|
|
|
|
}
|
|
|
|
|
|
|
|
return typed, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// numberField constrains NumberArray elements to the integer kinds. The ~
// terms also admit named types whose underlying type is one of them.
type numberField interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
}

// numberTypeField is the same set of integer kinds, restricted to the
// predeclared types themselves (no ~). castedScan uses it as the element type
// of the intermediate slice pgx scans into.
type numberTypeField interface {
	int | int8 | int16 | int32 | int64 |
		uint | uint8 | uint16 | uint32 | uint64
}
|
|
|
|
|
|
|
|
var _ sql.Scanner = (*NumberArray[int8])(nil)
|
|
|
|
|
|
|
|
type NumberArray[F numberField] pgtype.FlatArray[F]
|
|
|
|
|
|
|
|
// Scan implements the [database/sql.Scanner] interface.
|
|
|
|
func (a *NumberArray[F]) Scan(src any) (err error) {
|
|
|
|
var (
|
|
|
|
mapper func()
|
|
|
|
scanner sql.Scanner
|
|
|
|
)
|
|
|
|
|
|
|
|
//nolint: exhaustive
|
|
|
|
// only defined types
|
|
|
|
switch reflect.TypeOf(*a).Elem().Kind() {
|
|
|
|
case reflect.Int8:
|
|
|
|
mapper, scanner = castedScan[int8](a)
|
|
|
|
case reflect.Uint8:
|
|
|
|
// we provide int16 is a workaround because pgx thinks we want to scan a byte array if we provide uint8
|
|
|
|
mapper, scanner = castedScan[int16](a)
|
|
|
|
case reflect.Int16:
|
|
|
|
mapper, scanner = castedScan[int16](a)
|
|
|
|
case reflect.Uint16:
|
|
|
|
mapper, scanner = castedScan[uint16](a)
|
|
|
|
case reflect.Int32:
|
|
|
|
mapper, scanner = castedScan[int32](a)
|
|
|
|
case reflect.Uint32:
|
|
|
|
mapper, scanner = castedScan[uint32](a)
|
|
|
|
case reflect.Int64:
|
|
|
|
mapper, scanner = castedScan[int64](a)
|
|
|
|
case reflect.Uint64:
|
|
|
|
mapper, scanner = castedScan[uint64](a)
|
|
|
|
case reflect.Int:
|
|
|
|
mapper, scanner = castedScan[int](a)
|
|
|
|
case reflect.Uint:
|
|
|
|
mapper, scanner = castedScan[uint](a)
|
|
|
|
}
|
|
|
|
|
|
|
|
if err = scanner.Scan(src); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
mapper()
|
2022-08-31 09:52:43 +02:00
|
|
|
|
2024-03-27 14:48:22 +01:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func castedScan[T numberTypeField, F numberField](a *NumberArray[F]) (mapper func(), scanner sql.Scanner) {
|
|
|
|
var typedArray []T
|
|
|
|
|
|
|
|
mapper = func() {
|
|
|
|
(*a) = make(NumberArray[F], len(typedArray))
|
|
|
|
for i, value := range typedArray {
|
|
|
|
(*a)[i] = F(value)
|
|
|
|
}
|
2022-08-31 09:52:43 +02:00
|
|
|
}
|
2024-03-27 14:48:22 +01:00
|
|
|
scanner = pgtype.NewMap().SQLScanner(&typedArray)
|
2022-08-31 09:52:43 +02:00
|
|
|
|
2024-03-27 14:48:22 +01:00
|
|
|
return mapper, scanner
|
2022-08-31 09:52:43 +02:00
|
|
|
}
|
2023-05-05 17:34:53 +02:00
|
|
|
|
|
|
|
// Map is a string-keyed map persisted as a JSON object.
type Map[V any] map[string]V

// Scan implements the [database/sql.Scanner] interface.
//
// The source is handled as follows:
//   - []byte or string: decoded as a JSON document;
//   - NULL (nil) or empty: resets *m to a nil map;
//   - any other type: returns an error instead of panicking.
//
// A successful decode always goes into a fresh map that then replaces *m, so
// keys from a previous Scan into the same variable cannot leak into the new
// value. On error *m is left untouched.
func (m *Map[V]) Scan(src any) error {
	var data []byte
	switch v := src.(type) {
	case nil:
		*m = nil
		return nil
	case []byte:
		data = v
	case string:
		data = []byte(v)
	default:
		return fmt.Errorf("unsupported Scan source type %T for Map", src)
	}

	if len(data) == 0 {
		*m = nil
		return nil
	}

	// Decoding into the existing map would merge keys rather than replace them.
	var decoded Map[V]
	if err := json.Unmarshal(data, &decoded); err != nil {
		return err
	}
	*m = decoded

	return nil
}

// Value implements the [database/sql/driver.Valuer] interface.
//
// An empty map is stored as NULL. Otherwise the JSON encoding of m is returned
// as []byte.
func (m Map[V]) Value() (driver.Value, error) {
	if len(m) == 0 {
		return nil, nil
	}

	data, err := json.Marshal(m)
	if err != nil {
		// Return a plain nil rather than a typed-nil []byte in the interface.
		return nil, err
	}

	return data, nil
}
|
2023-09-15 16:58:45 +02:00
|
|
|
|
|
|
|
type Duration time.Duration
|
|
|
|
|
|
|
|
// Scan implements the [database/sql.Scanner] interface.
|
|
|
|
func (d *Duration) Scan(src any) error {
|
2024-03-27 14:48:22 +01:00
|
|
|
switch duration := src.(type) {
|
|
|
|
case *time.Duration:
|
|
|
|
*d = Duration(*duration)
|
|
|
|
return nil
|
|
|
|
case time.Duration:
|
|
|
|
*d = Duration(duration)
|
|
|
|
return nil
|
|
|
|
case *pgtype.Interval:
|
|
|
|
*d = intervalToDuration(duration)
|
|
|
|
return nil
|
|
|
|
case pgtype.Interval:
|
|
|
|
*d = intervalToDuration(&duration)
|
|
|
|
return nil
|
|
|
|
case int64:
|
|
|
|
*d = Duration(duration)
|
|
|
|
return nil
|
|
|
|
}
|
2023-09-15 16:58:45 +02:00
|
|
|
interval := new(pgtype.Interval)
|
|
|
|
if err := interval.Scan(src); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2024-03-27 14:48:22 +01:00
|
|
|
*d = intervalToDuration(interval)
|
2023-09-15 16:58:45 +02:00
|
|
|
return nil
|
|
|
|
}
|
2023-10-25 13:42:00 +02:00
|
|
|
|
2024-03-27 14:48:22 +01:00
|
|
|
func intervalToDuration(interval *pgtype.Interval) Duration {
|
|
|
|
return Duration(time.Duration(interval.Microseconds*1000) + time.Duration(interval.Days)*24*time.Hour + time.Duration(interval.Months)*30*24*time.Hour)
|
|
|
|
}
|
|
|
|
|
2023-10-25 13:42:00 +02:00
|
|
|
// NullDuration can be used for NULL intervals.
|
|
|
|
// If Valid is false, the scanned value was NULL
|
|
|
|
// This behavior is similar to [database/sql.NullString]
|
|
|
|
type NullDuration struct {
|
|
|
|
Valid bool
|
|
|
|
Duration time.Duration
|
|
|
|
}
|
|
|
|
|
|
|
|
// Scan implements the [database/sql.Scanner] interface.
|
|
|
|
func (d *NullDuration) Scan(src any) error {
|
|
|
|
if src == nil {
|
|
|
|
d.Duration, d.Valid = 0, false
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
duration := new(Duration)
|
|
|
|
if err := duration.Scan(src); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
d.Duration, d.Valid = time.Duration(*duration), true
|
|
|
|
return nil
|
|
|
|
}
|