2025-01-06 08:00:35 +01:00

353 lines
8.0 KiB
Go

package consistent
import (
"context"
"errors"
"reflect"
"slices"
"strings"
"sync"
"github.com/zitadel/zitadel/backend/internal/port"
"github.com/zitadel/zitadel/backend/internal/port/storage"
)
var (
_ port.Executor[port.Object] = (*MapRepository[port.Object])(nil)
_ port.Querier[port.Object] = (*MapRepository[port.Object])(nil)
_ port.Client[port.Object] = (*MapRepository[port.Object])(nil)
// _ port.Pool[any] = (*MapRepository[any])(nil)
_ port.Executor[port.Object] = (*MapTransaction[port.Object])(nil)
_ port.Querier[port.Object] = (*MapTransaction[port.Object])(nil)
_ port.Transaction[port.Object] = (*MapTransaction[port.Object])(nil)
)
var pkFields = &sync.Map{}
type MapRepository[T port.Object] struct {
objects []*row[T]
}
type row[T port.Object] struct {
object *T
mu *sync.RWMutex
}
func (m *MapRepository[T]) Begin(ctx context.Context) (port.Transaction[T], error) {
return &MapTransaction[T]{
repo: m,
}, nil
}
func (m *MapRepository[T]) Release(ctx context.Context) error {
return nil
}
// Create implements [port.Executor]
func (m *MapRepository[T]) Create(ctx context.Context, object T) (err error) {
tx, err := m.Begin(ctx)
if err != nil {
return err
}
defer tx.End(ctx, err)
return tx.Create(ctx, object)
}
// Delete implements [port.Executor]
func (m *MapRepository[T]) Delete(ctx context.Context, filters []port.Filter) (err error) {
tx, err := m.Begin(ctx)
if err != nil {
return err
}
defer tx.End(ctx, err)
return tx.Delete(ctx, filters)
}
// Update implements [port.Executor]
func (m *MapRepository[T]) Update(ctx context.Context, columns []*port.Column, filters []port.Filter) (err error) {
tx, err := m.Begin(ctx)
if err != nil {
return err
}
defer tx.End(ctx, err)
return tx.Update(ctx, columns, filters)
}
// Get implements [port.Querier]
func (m *MapRepository[T]) Get(ctx context.Context, filters []port.Filter) (o T, err error) {
tx, err := m.Begin(ctx)
if err != nil {
return o, err
}
defer tx.End(ctx, err)
return tx.Get(ctx, filters)
}
// List implements [port.Querier]
func (m *MapRepository[T]) List(ctx context.Context, filters []port.Filter) (_ []T, err error) {
tx, err := m.Begin(ctx)
if err != nil {
return nil, err
}
defer tx.End(ctx, err)
return tx.List(ctx, filters)
}
type MapTransaction[T port.Object] struct {
repo *MapRepository[T]
changes []*change[T]
}
type change[T port.Object] struct {
row *row[T]
changed *T
typ changeType
}
type changeType uint8
const (
createChangeType changeType = iota
updateChangeType
deleteChangeType
)
func (m *MapTransaction[T]) End(ctx context.Context, gotErr error) error {
if ctx.Err() == nil && gotErr == nil {
return m.commit()
}
return m.rollback()
}
func (m *MapTransaction[T]) commit() error {
for _, change := range m.changes {
switch change.typ {
case createChangeType:
change.row.object = change.changed
case deleteChangeType:
m.repo.objects = slices.DeleteFunc(m.repo.objects, func(r *row[T]) bool {
return r == change.row
})
change.row.mu.Unlock()
change.row = nil
continue
case updateChangeType:
*change.row.object = *change.changed
default:
return errors.New("unknown change type")
}
change.row.mu.Unlock()
}
return nil
}
func (m *MapTransaction[T]) rollback() error {
for _, change := range m.changes {
if change.changed == nil {
// TODO: remove from repo.objects
}
change.row.mu.Unlock()
change = nil
}
m.changes = nil
return nil
}
// Create implements port.Executor.
func (m *MapTransaction[T]) Create(ctx context.Context, object T) error {
_, err := m.get(ctx, pkFilter(object))
if err == nil {
return storage.ErrAlreadyExists
}
mu := &sync.RWMutex{}
mu.Lock()
m.repo.objects = append(m.repo.objects, &row[T]{
object: &object,
mu: mu,
})
m.changes = append(m.changes, &change[T]{
row: m.repo.objects[len(m.repo.objects)-1],
typ: createChangeType,
})
return nil
}
// Delete implements port.Executor.
func (m *MapTransaction[T]) Delete(ctx context.Context, filters []port.Filter) error {
rows, err := m.list(ctx, filters)
if err != nil {
return err
}
for _, row := range rows {
row.mu.Lock()
m.changes = append(m.changes, &change[T]{
row: row,
typ: deleteChangeType,
})
}
return nil
}
// Update implements port.Executor.
func (m *MapTransaction[T]) Update(ctx context.Context, columns []*port.Column, filters []port.Filter) error {
rows, err := m.list(ctx, filters)
if err != nil {
return err
}
for _, row := range rows {
row.mu.Lock()
changed, err := update(row, columns)
if err != nil {
return err
}
m.changes = append(m.changes, &change[T]{
row: row,
changed: changed,
typ: updateChangeType,
})
}
return nil
}
// Get implements [port.Querier]
func (m *MapTransaction[T]) Get(ctx context.Context, filters []port.Filter) (o T, err error) {
r, err := m.get(ctx, filters)
if err != nil {
return o, err
}
r.mu.RLock()
return o, nil
}
func (m *MapTransaction[T]) get(ctx context.Context, filters []port.Filter) (o *row[T], err error) {
rows, err := m.list(ctx, filters)
if err != nil {
return nil, err
}
if len(rows) > 1 {
return nil, errors.New("multiple rows returned")
}
return rows[0], nil
}
// List implements [port.Querier]
func (m *MapTransaction[T]) List(ctx context.Context, filters []port.Filter) ([]T, error) {
rows, err := m.list(ctx, filters)
if err != nil {
return nil, err
}
objects := make([]T, len(rows))
for i, row := range rows {
row.mu.RLock()
objects[i] = *row.object
}
return objects, nil
}
func (m *MapTransaction[T]) list(ctx context.Context, filters []port.Filter) (res []*row[T], err error) {
res = slices.Clone(m.repo.objects)
for _, filter := range filters {
res = slices.DeleteFunc(res, func(r *row[T]) bool {
v := reflect.ValueOf(r.object)
definition := getDefinition[T]()
return v.Field(definition.columnIndexes[filter.Column().Name]).Interface() != filter.Column().Value
})
res = slices.Clip(res)
}
return res, nil
}
func pkFilter[T port.Object](o T) []port.Filter {
definition := getDefinition[T]()
v := reflect.ValueOf(o)
filters := make([]port.Filter, len(definition.pkFieldIndexes))
for _, idx := range definition.pkFieldIndexes {
filters = append(filters, port.NewEqualFilter(
&port.Column{
Name: definition.columnNames[idx],
Value: v.Field(idx).Interface(),
},
))
}
return filters
}
type structDefinition struct {
columnNames []string
columnIndexes map[string]int
pkFieldIndexes []int
fkFieldIndexes []int
}
func getDefinition[T port.Object]() *structDefinition {
definition, ok := pkFields.Load(reflect.TypeFor[T]())
if !ok {
definition = registerPkFields[T]()
}
return definition.(*structDefinition)
}
func registerPkFields[T port.Object]() *structDefinition {
t := reflect.TypeFor[T]()
definition := &structDefinition{
columnIndexes: make(map[string]int, t.NumField()),
columnNames: make([]string, t.NumField()),
pkFieldIndexes: make([]int, 0, t.NumField()),
fkFieldIndexes: make([]int, 0, t.NumField()),
}
for i := 0; i < t.NumField(); i++ {
tags := t.Field(i).Tag.Get("consistent")
fields := strings.Split(tags, ",")
if fields[0] == "" {
fields[0] = t.Field(i).Name
}
definition.columnIndexes[fields[0]] = i
definition.columnNames[i] = fields[0]
for _, field := range fields[1:] {
switch field {
case "pk":
definition.pkFieldIndexes = append(definition.pkFieldIndexes, i)
case "fk":
definition.fkFieldIndexes = append(definition.fkFieldIndexes, i)
}
}
}
pkFields.Store(t, definition)
return definition
}
func update[T port.Object](r *row[T], columns []*port.Column) (o *T, err error) {
v := reflect.Zero(reflect.TypeOf(r.object))
current := reflect.ValueOf(r.object)
definition := getDefinition[T]()
fields:
for i := 0; i < v.NumField(); i++ {
for _, column := range columns {
if definition.columnIndexes[column.Name] != i {
continue
}
v.Field(i).Set(reflect.ValueOf(column.Value))
continue fields
}
v.Field(i).Set(current.Field(i))
}
return r.object, nil
}