Tim Möhlmann a2f60f2e7a
perf(query): org permission function for resources (#9677)
# Which Problems Are Solved

Classic permission checks execute for every returned row on resource
based search APIs. Complete background and problem definition can be
found here: https://github.com/zitadel/zitadel/issues/9188

# How the Problems Are Solved

- PermissionClause function now supports dynamic query building, so it
supports multiple cases.
- PermissionClause is applied to all list resources which support org
level permissions.
- Wrap permission logic into wrapper functions so we keep the business
logic clean.

# Additional Changes

- Handle org ID optimization in the query package, so it is reusable for
all resources, instead of extracting the filter in the API.
- Cleanup and test system user conversion in the authz package. (context
middleware)
- Fix: `core_integration_db_up` make recipe was missing the postgres
service.

# Additional Context

- Related to https://github.com/zitadel/zitadel/issues/9190
2025-04-15 18:38:25 +02:00

262 lines
5.8 KiB
Go

package database
import (
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"reflect"
	"strings"
	"time"

	"github.com/jackc/pgx/v5/pgtype"
)
// TextArray is a slice of string-like values that can be read from and
// written to the database through its Scan and Value methods.
type TextArray[T ~string] pgtype.FlatArray[T]
// Scan implements the [database/sql.Scanner] interface.
//
// The source is first decoded into a plain []string by a pgtype scanner,
// after which every element is converted to T.
func (s *TextArray[T]) Scan(src any) error {
	var raw []string
	if err := pgtype.NewMap().SQLScanner(&raw).Scan(src); err != nil {
		return err
	}

	converted := make(TextArray[T], 0, len(raw))
	for _, item := range raw {
		converted = append(converted, T(item))
	}
	*s = converted
	return nil
}
// Value implements the [database/sql/driver.Valuer] interface.
//
// An empty array is sent as NULL. Otherwise the elements are rendered as a
// brace-delimited array literal such as {a,b}. Elements that would be
// ambiguous inside such a literal are double-quoted and escaped, all other
// elements are written unchanged.
func (s TextArray[T]) Value() (driver.Value, error) {
	if len(s) == 0 {
		return nil, nil
	}

	typed := make([]string, len(s))
	for i, value := range s {
		typed[i] = quoteTextArrayElement(string(value))
	}

	return []byte("{" + strings.Join(typed, ",") + "}"), nil
}

// quoteTextArrayElement returns element in a form that is safe to embed in a
// brace-delimited array literal. Empty strings, the word NULL (in any case)
// and elements containing braces, commas, double quotes, backslashes or
// whitespace are wrapped in double quotes, with backslashes and double quotes
// escaped by a backslash. Any other element is returned as is.
func quoteTextArrayElement(element string) string {
	needsQuoting := element == "" ||
		strings.EqualFold(element, "null") ||
		strings.ContainsAny(element, "{},\"\\ \t\n\r\v\f")
	if !needsQuoting {
		return element
	}

	// Backslashes must be escaped first, otherwise the backslashes added
	// for the double quotes would be escaped again.
	element = strings.ReplaceAll(element, `\`, `\\`)
	element = strings.ReplaceAll(element, `"`, `\"`)
	return `"` + element + `"`
}
// ByteArray is a slice of byte-like values that can be read from and
// written to the database through its Scan and Value methods.
type ByteArray[T ~byte] pgtype.FlatArray[T]
// Scan implements the [database/sql.Scanner] interface.
//
// A []byte source is used directly; any other source type is decoded
// through a pgtype scanner. Every byte is then converted to T.
func (s *ByteArray[T]) Scan(src any) error {
	raw, isBytes := src.([]byte)
	if !isBytes {
		// tests use a different src type
		if err := pgtype.NewMap().SQLScanner(&raw).Scan(src); err != nil {
			return err
		}
	}

	converted := make(ByteArray[T], 0, len(raw))
	for _, b := range raw {
		converted = append(converted, T(b))
	}
	*s = converted
	return nil
}
// Value implements the [database/sql/driver.Valuer] interface.
//
// An empty array yields a nil value; otherwise the elements are returned
// as a newly allocated []byte.
func (s ByteArray[T]) Value() (driver.Value, error) {
	if len(s) == 0 {
		return nil, nil
	}

	raw := make([]byte, 0, len(s))
	for _, item := range s {
		raw = append(raw, byte(item))
	}
	return raw, nil
}
// numberField is the type constraint for the elements of [NumberArray].
// It admits every type whose underlying type is one of the listed integer types.
type numberField interface {
~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~int | ~uint
}
// numberTypeField is the type constraint for the intermediate element type
// used by castedScan. Unlike numberField it only admits the listed built-in
// integer types themselves, not types derived from them.
type numberTypeField interface {
int8 | uint8 | int16 | uint16 | int32 | uint32 | int64 | uint64 | int | uint
}
// Compile-time check that *NumberArray implements [sql.Scanner].
var _ sql.Scanner = (*NumberArray[int8])(nil)
// NumberArray is a slice of integer-like values that can be read from the
// database through its Scan method.
type NumberArray[F numberField] pgtype.FlatArray[F]
// Scan implements the [database/sql.Scanner] interface.
//
// The kind of the element type F selects the built-in integer type that is
// used as intermediate scan target. After a successful scan the intermediate
// values are converted to F and stored in the array.
func (a *NumberArray[F]) Scan(src any) (err error) {
	var (
		assign func()
		target sql.Scanner
	)

	//nolint: exhaustive
	// only defined types
	switch reflect.TypeOf(*a).Elem().Kind() {
	case reflect.Int8:
		assign, target = castedScan[int8](a)
	case reflect.Uint8, reflect.Int16:
		// uint8 is scanned through int16 as a workaround:
		// pgx thinks we want to scan a byte array if we provide uint8
		assign, target = castedScan[int16](a)
	case reflect.Uint16:
		assign, target = castedScan[uint16](a)
	case reflect.Int32:
		assign, target = castedScan[int32](a)
	case reflect.Uint32:
		assign, target = castedScan[uint32](a)
	case reflect.Int64:
		assign, target = castedScan[int64](a)
	case reflect.Uint64:
		assign, target = castedScan[uint64](a)
	case reflect.Int:
		assign, target = castedScan[int](a)
	case reflect.Uint:
		assign, target = castedScan[uint](a)
	}

	err = target.Scan(src)
	if err != nil {
		return err
	}

	// only copy the values over once scanning succeeded
	assign()
	return nil
}
// castedScan prepares scanning into a through an intermediate slice of T.
//
// The returned scanner decodes the source into the intermediate slice.
// The returned mapper replaces the content of a with the intermediate
// values converted to F; it is meant to be called after the scanner succeeded.
func castedScan[T numberTypeField, F numberField](a *NumberArray[F]) (mapper func(), scanner sql.Scanner) {
	var scanned []T

	scanner = pgtype.NewMap().SQLScanner(&scanned)
	mapper = func() {
		result := make(NumberArray[F], 0, len(scanned))
		for _, n := range scanned {
			result = append(result, F(n))
		}
		*a = result
	}

	return mapper, scanner
}
// Map is a string-keyed map that is exchanged with the database as JSON.
// It implements the [database/sql.Scanner] and [database/sql/driver.Valuer] interfaces.
type Map[V any] map[string]V

// Scan implements the [database/sql.Scanner] interface.
//
// src may be nil, a []byte or a string containing JSON. A nil or empty
// source leaves the map untouched. Any other source type results in an
// error instead of a panic.
func (m *Map[V]) Scan(src any) error {
	var data []byte
	switch typed := src.(type) {
	case nil:
		return nil
	case []byte:
		data = typed
	case string:
		data = []byte(typed)
	default:
		return fmt.Errorf("database: cannot scan %T into %T", src, m)
	}

	if len(data) == 0 {
		return nil
	}
	// m already is a pointer, there is no need to take its address again.
	return json.Unmarshal(data, m)
}

// Value implements the [database/sql/driver.Valuer] interface.
//
// An empty or nil map yields a nil value; otherwise the map is marshaled to JSON.
func (m Map[V]) Value() (driver.Value, error) {
	if len(m) == 0 {
		return nil, nil
	}
	return json.Marshal(m)
}
// Duration is a [time.Duration] that can be read from the database
// through its Scan method.
type Duration time.Duration
// Scan implements the [database/sql.Scanner] interface.
//
// Durations, intervals (each by value or by pointer) and int64 values are
// converted directly. Every other source is first scanned into a
// [pgtype.Interval], which is then converted.
func (d *Duration) Scan(src any) error {
	switch v := src.(type) {
	case time.Duration:
		*d = Duration(v)
	case *time.Duration:
		*d = Duration(*v)
	case int64:
		*d = Duration(v)
	case pgtype.Interval:
		*d = intervalToDuration(&v)
	case *pgtype.Interval:
		*d = intervalToDuration(v)
	default:
		var interval pgtype.Interval
		if err := interval.Scan(src); err != nil {
			return err
		}
		*d = intervalToDuration(&interval)
	}
	return nil
}
// intervalToDuration converts an interval to a Duration by adding up its
// microseconds, days and months. A day counts as 24 hours and a month as 30 days.
func intervalToDuration(interval *pgtype.Interval) Duration {
	const (
		day   = 24 * time.Hour
		month = 30 * day
	)

	micros := time.Duration(interval.Microseconds * 1000)
	days := time.Duration(interval.Days) * day
	months := time.Duration(interval.Months) * month

	return Duration(micros + days + months)
}
// NullDuration can be used for NULL intervals.
// If Valid is false, the scanned value was NULL.
// This behavior is similar to [database/sql.NullString].
type NullDuration struct {
// Valid is set to false by Scan when the source was NULL and to true otherwise.
Valid bool
// Duration is the scanned value. Scan sets it to 0 when the source was NULL.
Duration time.Duration
}
// Scan implements the [database/sql.Scanner] interface.
//
// A nil source resets the duration to 0 and marks it as invalid.
// Any other source is scanned as a Duration and marked as valid on success.
func (d *NullDuration) Scan(src any) error {
	if src == nil {
		d.Duration, d.Valid = 0, false
		return nil
	}

	var scanned Duration
	if err := scanned.Scan(src); err != nil {
		return err
	}

	d.Valid = true
	d.Duration = time.Duration(scanned)
	return nil
}
// JSONArray allows sending and receiving JSON arrays to and from the database.
// It implements the [database/sql.Scanner] and [database/sql/driver.Valuer] interfaces.
// Values are marshaled and unmarshaled using the [encoding/json] package.
type JSONArray[T any] []T

// NewJSONArray wraps an existing slice into a JSONArray.
// The slice is converted, not copied.
func NewJSONArray[T any](a []T) JSONArray[T] {
	return JSONArray[T](a)
}

// Scan implements the [database/sql.Scanner] interface.
//
// src may be nil, a []byte or a string containing JSON. A nil or empty
// source resets the array to nil. Any other source type results in an
// error instead of a panic.
func (a *JSONArray[T]) Scan(src any) error {
	var data []byte
	switch typed := src.(type) {
	case nil:
		// data stays empty, which resets the array below.
	case []byte:
		data = typed
	case string:
		data = []byte(typed)
	default:
		return fmt.Errorf("database: cannot scan %T into %T", src, a)
	}

	if len(data) == 0 {
		*a = nil
		return nil
	}
	return json.Unmarshal(data, a)
}

// Value implements the [database/sql/driver.Valuer] interface.
//
// A nil array yields a nil value. A non-nil array, including an empty one,
// is marshaled to JSON.
func (a JSONArray[T]) Value() (driver.Value, error) {
	if a == nil {
		return nil, nil
	}
	return json.Marshal(a)
}