mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-11 04:53:39 +00:00
353 lines
8.0 KiB
Go
353 lines
8.0 KiB
Go
|
package consistent
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"reflect"
|
||
|
"slices"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/zitadel/zitadel/backend/internal/port"
|
||
|
"github.com/zitadel/zitadel/backend/internal/port/storage"
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
_ port.Executor[port.Object] = (*MapRepository[port.Object])(nil)
|
||
|
_ port.Querier[port.Object] = (*MapRepository[port.Object])(nil)
|
||
|
_ port.Client[port.Object] = (*MapRepository[port.Object])(nil)
|
||
|
// _ port.Pool[any] = (*MapRepository[any])(nil)
|
||
|
|
||
|
_ port.Executor[port.Object] = (*MapTransaction[port.Object])(nil)
|
||
|
_ port.Querier[port.Object] = (*MapTransaction[port.Object])(nil)
|
||
|
_ port.Transaction[port.Object] = (*MapTransaction[port.Object])(nil)
|
||
|
)
|
||
|
var pkFields = &sync.Map{}
|
||
|
|
||
|
type MapRepository[T port.Object] struct {
|
||
|
objects []*row[T]
|
||
|
}
|
||
|
|
||
|
type row[T port.Object] struct {
|
||
|
object *T
|
||
|
mu *sync.RWMutex
|
||
|
}
|
||
|
|
||
|
func (m *MapRepository[T]) Begin(ctx context.Context) (port.Transaction[T], error) {
|
||
|
return &MapTransaction[T]{
|
||
|
repo: m,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (m *MapRepository[T]) Release(ctx context.Context) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Create implements [port.Executor]
|
||
|
func (m *MapRepository[T]) Create(ctx context.Context, object T) (err error) {
|
||
|
tx, err := m.Begin(ctx)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer tx.End(ctx, err)
|
||
|
|
||
|
return tx.Create(ctx, object)
|
||
|
}
|
||
|
|
||
|
// Delete implements [port.Executor]
|
||
|
func (m *MapRepository[T]) Delete(ctx context.Context, filters []port.Filter) (err error) {
|
||
|
tx, err := m.Begin(ctx)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer tx.End(ctx, err)
|
||
|
|
||
|
return tx.Delete(ctx, filters)
|
||
|
}
|
||
|
|
||
|
// Update implements [port.Executor]
|
||
|
func (m *MapRepository[T]) Update(ctx context.Context, columns []*port.Column, filters []port.Filter) (err error) {
|
||
|
tx, err := m.Begin(ctx)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer tx.End(ctx, err)
|
||
|
|
||
|
return tx.Update(ctx, columns, filters)
|
||
|
}
|
||
|
|
||
|
// Get implements [port.Querier]
|
||
|
func (m *MapRepository[T]) Get(ctx context.Context, filters []port.Filter) (o T, err error) {
|
||
|
tx, err := m.Begin(ctx)
|
||
|
if err != nil {
|
||
|
return o, err
|
||
|
}
|
||
|
defer tx.End(ctx, err)
|
||
|
|
||
|
return tx.Get(ctx, filters)
|
||
|
}
|
||
|
|
||
|
// List implements [port.Querier]
|
||
|
func (m *MapRepository[T]) List(ctx context.Context, filters []port.Filter) (_ []T, err error) {
|
||
|
tx, err := m.Begin(ctx)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer tx.End(ctx, err)
|
||
|
|
||
|
return tx.List(ctx, filters)
|
||
|
}
|
||
|
|
||
|
type MapTransaction[T port.Object] struct {
|
||
|
repo *MapRepository[T]
|
||
|
changes []*change[T]
|
||
|
}
|
||
|
|
||
|
type change[T port.Object] struct {
|
||
|
row *row[T]
|
||
|
changed *T
|
||
|
typ changeType
|
||
|
}
|
||
|
|
||
|
type changeType uint8
|
||
|
|
||
|
const (
|
||
|
createChangeType changeType = iota
|
||
|
updateChangeType
|
||
|
deleteChangeType
|
||
|
)
|
||
|
|
||
|
func (m *MapTransaction[T]) End(ctx context.Context, gotErr error) error {
|
||
|
if ctx.Err() == nil && gotErr == nil {
|
||
|
return m.commit()
|
||
|
}
|
||
|
return m.rollback()
|
||
|
}
|
||
|
|
||
|
func (m *MapTransaction[T]) commit() error {
|
||
|
for _, change := range m.changes {
|
||
|
switch change.typ {
|
||
|
case createChangeType:
|
||
|
change.row.object = change.changed
|
||
|
case deleteChangeType:
|
||
|
m.repo.objects = slices.DeleteFunc(m.repo.objects, func(r *row[T]) bool {
|
||
|
return r == change.row
|
||
|
})
|
||
|
change.row.mu.Unlock()
|
||
|
change.row = nil
|
||
|
continue
|
||
|
case updateChangeType:
|
||
|
*change.row.object = *change.changed
|
||
|
default:
|
||
|
return errors.New("unknown change type")
|
||
|
}
|
||
|
change.row.mu.Unlock()
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *MapTransaction[T]) rollback() error {
|
||
|
for _, change := range m.changes {
|
||
|
if change.changed == nil {
|
||
|
// TODO: remove from repo.objects
|
||
|
}
|
||
|
change.row.mu.Unlock()
|
||
|
change = nil
|
||
|
}
|
||
|
m.changes = nil
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Create implements port.Executor.
|
||
|
func (m *MapTransaction[T]) Create(ctx context.Context, object T) error {
|
||
|
_, err := m.get(ctx, pkFilter(object))
|
||
|
if err == nil {
|
||
|
return storage.ErrAlreadyExists
|
||
|
}
|
||
|
|
||
|
mu := &sync.RWMutex{}
|
||
|
mu.Lock()
|
||
|
m.repo.objects = append(m.repo.objects, &row[T]{
|
||
|
object: &object,
|
||
|
mu: mu,
|
||
|
})
|
||
|
m.changes = append(m.changes, &change[T]{
|
||
|
row: m.repo.objects[len(m.repo.objects)-1],
|
||
|
typ: createChangeType,
|
||
|
})
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Delete implements port.Executor.
|
||
|
func (m *MapTransaction[T]) Delete(ctx context.Context, filters []port.Filter) error {
|
||
|
rows, err := m.list(ctx, filters)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for _, row := range rows {
|
||
|
row.mu.Lock()
|
||
|
m.changes = append(m.changes, &change[T]{
|
||
|
row: row,
|
||
|
typ: deleteChangeType,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Update implements port.Executor.
|
||
|
func (m *MapTransaction[T]) Update(ctx context.Context, columns []*port.Column, filters []port.Filter) error {
|
||
|
rows, err := m.list(ctx, filters)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
for _, row := range rows {
|
||
|
row.mu.Lock()
|
||
|
changed, err := update(row, columns)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
m.changes = append(m.changes, &change[T]{
|
||
|
row: row,
|
||
|
changed: changed,
|
||
|
typ: updateChangeType,
|
||
|
})
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Get implements [port.Querier]
|
||
|
func (m *MapTransaction[T]) Get(ctx context.Context, filters []port.Filter) (o T, err error) {
|
||
|
r, err := m.get(ctx, filters)
|
||
|
if err != nil {
|
||
|
return o, err
|
||
|
}
|
||
|
r.mu.RLock()
|
||
|
|
||
|
return o, nil
|
||
|
}
|
||
|
|
||
|
func (m *MapTransaction[T]) get(ctx context.Context, filters []port.Filter) (o *row[T], err error) {
|
||
|
rows, err := m.list(ctx, filters)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if len(rows) > 1 {
|
||
|
return nil, errors.New("multiple rows returned")
|
||
|
}
|
||
|
return rows[0], nil
|
||
|
}
|
||
|
|
||
|
// List implements [port.Querier]
|
||
|
func (m *MapTransaction[T]) List(ctx context.Context, filters []port.Filter) ([]T, error) {
|
||
|
rows, err := m.list(ctx, filters)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
objects := make([]T, len(rows))
|
||
|
for i, row := range rows {
|
||
|
row.mu.RLock()
|
||
|
objects[i] = *row.object
|
||
|
}
|
||
|
return objects, nil
|
||
|
}
|
||
|
|
||
|
func (m *MapTransaction[T]) list(ctx context.Context, filters []port.Filter) (res []*row[T], err error) {
|
||
|
res = slices.Clone(m.repo.objects)
|
||
|
for _, filter := range filters {
|
||
|
res = slices.DeleteFunc(res, func(r *row[T]) bool {
|
||
|
v := reflect.ValueOf(r.object)
|
||
|
definition := getDefinition[T]()
|
||
|
return v.Field(definition.columnIndexes[filter.Column().Name]).Interface() != filter.Column().Value
|
||
|
})
|
||
|
res = slices.Clip(res)
|
||
|
}
|
||
|
return res, nil
|
||
|
}
|
||
|
|
||
|
func pkFilter[T port.Object](o T) []port.Filter {
|
||
|
definition := getDefinition[T]()
|
||
|
|
||
|
v := reflect.ValueOf(o)
|
||
|
filters := make([]port.Filter, len(definition.pkFieldIndexes))
|
||
|
|
||
|
for _, idx := range definition.pkFieldIndexes {
|
||
|
filters = append(filters, port.NewEqualFilter(
|
||
|
&port.Column{
|
||
|
Name: definition.columnNames[idx],
|
||
|
Value: v.Field(idx).Interface(),
|
||
|
},
|
||
|
))
|
||
|
}
|
||
|
|
||
|
return filters
|
||
|
}
|
||
|
|
||
|
type structDefinition struct {
|
||
|
columnNames []string
|
||
|
columnIndexes map[string]int
|
||
|
pkFieldIndexes []int
|
||
|
fkFieldIndexes []int
|
||
|
}
|
||
|
|
||
|
func getDefinition[T port.Object]() *structDefinition {
|
||
|
definition, ok := pkFields.Load(reflect.TypeFor[T]())
|
||
|
if !ok {
|
||
|
definition = registerPkFields[T]()
|
||
|
}
|
||
|
return definition.(*structDefinition)
|
||
|
}
|
||
|
|
||
|
func registerPkFields[T port.Object]() *structDefinition {
|
||
|
t := reflect.TypeFor[T]()
|
||
|
|
||
|
definition := &structDefinition{
|
||
|
columnIndexes: make(map[string]int, t.NumField()),
|
||
|
columnNames: make([]string, t.NumField()),
|
||
|
pkFieldIndexes: make([]int, 0, t.NumField()),
|
||
|
fkFieldIndexes: make([]int, 0, t.NumField()),
|
||
|
}
|
||
|
|
||
|
for i := 0; i < t.NumField(); i++ {
|
||
|
tags := t.Field(i).Tag.Get("consistent")
|
||
|
fields := strings.Split(tags, ",")
|
||
|
if fields[0] == "" {
|
||
|
fields[0] = t.Field(i).Name
|
||
|
}
|
||
|
definition.columnIndexes[fields[0]] = i
|
||
|
definition.columnNames[i] = fields[0]
|
||
|
|
||
|
for _, field := range fields[1:] {
|
||
|
switch field {
|
||
|
case "pk":
|
||
|
definition.pkFieldIndexes = append(definition.pkFieldIndexes, i)
|
||
|
case "fk":
|
||
|
definition.fkFieldIndexes = append(definition.fkFieldIndexes, i)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
pkFields.Store(t, definition)
|
||
|
return definition
|
||
|
}
|
||
|
|
||
|
func update[T port.Object](r *row[T], columns []*port.Column) (o *T, err error) {
|
||
|
v := reflect.Zero(reflect.TypeOf(r.object))
|
||
|
current := reflect.ValueOf(r.object)
|
||
|
definition := getDefinition[T]()
|
||
|
|
||
|
fields:
|
||
|
for i := 0; i < v.NumField(); i++ {
|
||
|
for _, column := range columns {
|
||
|
if definition.columnIndexes[column.Name] != i {
|
||
|
continue
|
||
|
}
|
||
|
v.Field(i).Set(reflect.ValueOf(column.Value))
|
||
|
continue fields
|
||
|
}
|
||
|
v.Field(i).Set(current.Field(i))
|
||
|
}
|
||
|
|
||
|
return r.object, nil
|
||
|
}
|