chore: move the go code into a subfolder

This commit is contained in:
Florian Forster
2025-08-05 15:20:32 -07:00
parent 4ad22ba456
commit cd2921de26
2978 changed files with 373 additions and 300 deletions

View File

@@ -0,0 +1,176 @@
package eventstore
import (
"context"
"encoding/json"
"strconv"
"time"
"github.com/shopspring/decimal"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
var (
_ eventstore.Event = (*event)(nil)
)
type command struct {
InstanceID string
AggregateType string
AggregateID string
CommandType string
Revision uint16
Payload Payload
Creator string
Owner string
}
func (c *command) Aggregate() *eventstore.Aggregate {
return &eventstore.Aggregate{
ID: c.AggregateID,
Type: eventstore.AggregateType(c.AggregateType),
ResourceOwner: c.Owner,
InstanceID: c.InstanceID,
Version: eventstore.Version("v" + strconv.Itoa(int(c.Revision))),
}
}
type event struct {
command *command
createdAt time.Time
sequence uint64
position decimal.Decimal
}
// TODO: remove on v3
func commandToEventOld(sequence *latestSequence, cmd eventstore.Command) (_ *event, err error) {
var payload Payload
if cmd.Payload() != nil {
payload, err = json.Marshal(cmd.Payload())
if err != nil {
logging.WithError(err).Warn("marshal payload failed")
return nil, zerrors.ThrowInternal(err, "V3-MInPK", "Errors.Internal")
}
}
return &event{
command: &command{
InstanceID: sequence.aggregate.InstanceID,
AggregateType: string(sequence.aggregate.Type),
AggregateID: sequence.aggregate.ID,
CommandType: string(cmd.Type()),
Revision: cmd.Revision(),
Payload: payload,
Creator: cmd.Creator(),
Owner: sequence.aggregate.ResourceOwner,
},
sequence: sequence.sequence,
}, nil
}
func commandsToEvents(ctx context.Context, cmds []eventstore.Command) (_ []eventstore.Event, _ []*command, err error) {
events := make([]eventstore.Event, len(cmds))
commands := make([]*command, len(cmds))
for i, cmd := range cmds {
if cmd.Aggregate().InstanceID == "" {
cmd.Aggregate().InstanceID = authz.GetInstance(ctx).InstanceID()
}
events[i], err = commandToEvent(cmd)
if err != nil {
return nil, nil, err
}
commands[i] = events[i].(*event).command
}
return events, commands, nil
}
func commandToEvent(cmd eventstore.Command) (_ eventstore.Event, err error) {
var payload Payload
if cmd.Payload() != nil {
payload, err = json.Marshal(cmd.Payload())
if err != nil {
logging.WithError(err).Warn("marshal payload failed")
return nil, zerrors.ThrowInternal(err, "V3-MInPK", "Errors.Internal")
}
}
command := &command{
InstanceID: cmd.Aggregate().InstanceID,
AggregateType: string(cmd.Aggregate().Type),
AggregateID: cmd.Aggregate().ID,
CommandType: string(cmd.Type()),
Revision: cmd.Revision(),
Payload: payload,
Creator: cmd.Creator(),
Owner: cmd.Aggregate().ResourceOwner,
}
return &event{
command: command,
}, nil
}
// CreationDate implements [eventstore.Event]
func (e *event) CreationDate() time.Time {
return e.CreatedAt()
}
// EditorUser implements [eventstore.Event]
func (e *event) EditorUser() string {
return e.Creator()
}
// Aggregate implements [eventstore.Event]
func (e *event) Aggregate() *eventstore.Aggregate {
return e.command.Aggregate()
}
// Creator implements [eventstore.Event]
func (e *event) Creator() string {
return e.command.Creator
}
// Revision implements [eventstore.Event]
func (e *event) Revision() uint16 {
return e.command.Revision
}
// Type implements [eventstore.Event]
func (e *event) Type() eventstore.EventType {
return eventstore.EventType(e.command.CommandType)
}
// CreatedAt implements [eventstore.Event]
func (e *event) CreatedAt() time.Time {
return e.createdAt
}
// Sequence implements [eventstore.Event]
func (e *event) Sequence() uint64 {
return e.sequence
}
// Position implements [eventstore.Event]
func (e *event) Position() decimal.Decimal {
return e.position
}
// Unmarshal implements [eventstore.Event]
func (e *event) Unmarshal(ptr any) error {
if len(e.command.Payload) == 0 {
return nil
}
if err := json.Unmarshal(e.command.Payload, ptr); err != nil {
return zerrors.ThrowInternal(err, "V3-u8qVo", "Errors.Internal")
}
return nil
}
// DataAsBytes implements [eventstore.Event]
func (e *event) DataAsBytes() []byte {
return e.command.Payload
}

View File

@@ -0,0 +1,482 @@
package eventstore
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
)
func Test_commandToEvent(t *testing.T) {
payload := struct {
ID string
}{
ID: "test",
}
payloadMarshalled, err := json.Marshal(payload)
if err != nil {
t.Fatalf("marshal of payload failed: %v", err)
}
type args struct {
command eventstore.Command
}
type want struct {
event *event
err func(t *testing.T, err error)
}
tests := []struct {
name string
args args
want want
}{
{
name: "no payload",
args: args{
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: nil,
},
},
want: want{
event: mockEvent(
mockAggregate("V3-Red9I"),
0,
nil,
).(*event),
},
},
{
name: "struct payload",
args: args{
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: payload,
},
},
want: want{
event: mockEvent(
mockAggregate("V3-Red9I"),
0,
payloadMarshalled,
).(*event),
},
},
{
name: "pointer payload",
args: args{
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: &payload,
},
},
want: want{
event: mockEvent(
mockAggregate("V3-Red9I"),
0,
payloadMarshalled,
).(*event),
},
},
{
name: "invalid payload",
args: args{
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: func() {},
},
},
want: want{
err: func(t *testing.T, err error) {
assert.Error(t, err)
},
},
},
}
for _, tt := range tests {
if tt.want.err == nil {
tt.want.err = func(t *testing.T, err error) {
require.NoError(t, err)
}
}
t.Run(tt.name, func(t *testing.T) {
got, err := commandToEvent(tt.args.command)
tt.want.err(t, err)
if tt.want.event == nil {
assert.Nil(t, got)
return
}
assert.Equal(t, tt.want.event, got)
})
}
}
func Test_commandToEventOld(t *testing.T) {
payload := struct {
ID string
}{
ID: "test",
}
payloadMarshalled, err := json.Marshal(payload)
if err != nil {
t.Fatalf("marshal of payload failed: %v", err)
}
type args struct {
sequence *latestSequence
command eventstore.Command
}
type want struct {
event *event
err func(t *testing.T, err error)
}
tests := []struct {
name string
args args
want want
}{
{
name: "no payload",
args: args{
sequence: &latestSequence{
aggregate: mockAggregate("V3-Red9I"),
sequence: 0,
},
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: nil,
},
},
want: want{
event: mockEvent(
mockAggregate("V3-Red9I"),
0,
nil,
).(*event),
},
},
{
name: "struct payload",
args: args{
sequence: &latestSequence{
aggregate: mockAggregate("V3-Red9I"),
sequence: 0,
},
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: payload,
},
},
want: want{
event: mockEvent(
mockAggregate("V3-Red9I"),
0,
payloadMarshalled,
).(*event),
},
},
{
name: "pointer payload",
args: args{
sequence: &latestSequence{
aggregate: mockAggregate("V3-Red9I"),
sequence: 0,
},
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: &payload,
},
},
want: want{
event: mockEvent(
mockAggregate("V3-Red9I"),
0,
payloadMarshalled,
).(*event),
},
},
{
name: "invalid payload",
args: args{
sequence: &latestSequence{
aggregate: mockAggregate("V3-Red9I"),
sequence: 0,
},
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: func() {},
},
},
want: want{
err: func(t *testing.T, err error) {
assert.Error(t, err)
},
},
},
}
for _, tt := range tests {
if tt.want.err == nil {
tt.want.err = func(t *testing.T, err error) {
require.NoError(t, err)
}
}
t.Run(tt.name, func(t *testing.T) {
got, err := commandToEventOld(tt.args.sequence, tt.args.command)
tt.want.err(t, err)
assert.Equal(t, tt.want.event, got)
})
}
}
func Test_commandsToEvents(t *testing.T) {
ctx := context.Background()
payload := struct {
ID string
}{
ID: "test",
}
payloadMarshalled, err := json.Marshal(payload)
if err != nil {
t.Fatalf("marshal of payload failed: %v", err)
}
type args struct {
ctx context.Context
cmds []eventstore.Command
}
type want struct {
events []eventstore.Event
commands []*command
err func(t *testing.T, err error)
}
tests := []struct {
name string
args args
want want
}{
{
name: "no commands",
args: args{
ctx: ctx,
cmds: nil,
},
want: want{
events: []eventstore.Event{},
commands: []*command{},
err: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
},
{
name: "single command no payload",
args: args{
ctx: ctx,
cmds: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: nil,
},
},
},
want: want{
events: []eventstore.Event{
mockEvent(
mockAggregate("V3-Red9I"),
0,
nil,
),
},
commands: []*command{
{
InstanceID: "instance",
AggregateType: "type",
AggregateID: "V3-Red9I",
Owner: "ro",
CommandType: "event.type",
Revision: 1,
Payload: nil,
Creator: "creator",
},
},
err: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
},
{
name: "single command no instance id",
args: args{
ctx: authz.WithInstanceID(ctx, "instance from ctx"),
cmds: []eventstore.Command{
&mockCommand{
aggregate: mockAggregateWithInstance("V3-Red9I", ""),
payload: nil,
},
},
},
want: want{
events: []eventstore.Event{
mockEvent(
mockAggregateWithInstance("V3-Red9I", "instance from ctx"),
0,
nil,
),
},
commands: []*command{
{
InstanceID: "instance from ctx",
AggregateType: "type",
AggregateID: "V3-Red9I",
Owner: "ro",
CommandType: "event.type",
Revision: 1,
Payload: nil,
Creator: "creator",
},
},
err: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
},
{
name: "single command with payload",
args: args{
ctx: ctx,
cmds: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: payload,
},
},
},
want: want{
events: []eventstore.Event{
mockEvent(
mockAggregate("V3-Red9I"),
0,
payloadMarshalled,
),
},
commands: []*command{
{
InstanceID: "instance",
AggregateType: "type",
AggregateID: "V3-Red9I",
Owner: "ro",
CommandType: "event.type",
Revision: 1,
Payload: payloadMarshalled,
Creator: "creator",
},
},
err: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
},
{
name: "multiple commands",
args: args{
ctx: ctx,
cmds: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: payload,
},
&mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: nil,
},
},
},
want: want{
events: []eventstore.Event{
mockEvent(
mockAggregate("V3-Red9I"),
0,
payloadMarshalled,
),
mockEvent(
mockAggregate("V3-Red9I"),
0,
nil,
),
},
commands: []*command{
{
InstanceID: "instance",
AggregateType: "type",
AggregateID: "V3-Red9I",
CommandType: "event.type",
Revision: 1,
Payload: payloadMarshalled,
Creator: "creator",
Owner: "ro",
},
{
InstanceID: "instance",
AggregateType: "type",
AggregateID: "V3-Red9I",
CommandType: "event.type",
Revision: 1,
Payload: nil,
Creator: "creator",
Owner: "ro",
},
},
err: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
},
{
name: "invalid command",
args: args{
ctx: ctx,
cmds: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: func() {},
},
},
},
want: want{
events: nil,
commands: nil,
err: func(t *testing.T, err error) {
assert.Error(t, err)
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotEvents, gotCommands, err := commandsToEvents(tt.args.ctx, tt.args.cmds)
tt.want.err(t, err)
assert.Equal(t, tt.want.events, gotEvents)
require.Len(t, gotCommands, len(tt.want.commands))
for i, wantCommand := range tt.want.commands {
assertCommand(t, wantCommand, gotCommands[i])
}
})
}
}
func assertCommand(t *testing.T, want, got *command) {
t.Helper()
assert.Equal(t, want.CommandType, got.CommandType)
assert.Equal(t, want.Payload, got.Payload)
assert.Equal(t, want.Creator, got.Creator)
assert.Equal(t, want.Owner, got.Owner)
assert.Equal(t, want.AggregateID, got.AggregateID)
assert.Equal(t, want.AggregateType, got.AggregateType)
assert.Equal(t, want.InstanceID, got.InstanceID)
assert.Equal(t, want.Revision, got.Revision)
}

View File

@@ -0,0 +1,202 @@
package eventstore
import (
"context"
"database/sql"
"encoding/json"
"errors"
"sync"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/stdlib"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/eventstore"
)
func init() {
dialect.RegisterAfterConnect(RegisterEventstoreTypes)
}
var (
// pushPlaceholderFmt defines how data are inserted into the events table
pushPlaceholderFmt = "($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $%d)"
// uniqueConstraintPlaceholderFmt defines the format of the unique constraint error returned from the database
uniqueConstraintPlaceholderFmt = "(%s, %s, %s)"
_ eventstore.Pusher = (*Eventstore)(nil)
)
type Eventstore struct {
client *database.DB
}
var (
textType = &pgtype.Type{
Name: "text",
OID: pgtype.TextOID,
Codec: pgtype.TextCodec{},
}
commandType = &pgtype.Type{
Codec: &pgtype.CompositeCodec{
Fields: []pgtype.CompositeCodecField{
{
Name: "instance_id",
Type: textType,
},
{
Name: "aggregate_type",
Type: textType,
},
{
Name: "aggregate_id",
Type: textType,
},
{
Name: "command_type",
Type: textType,
},
{
Name: "revision",
Type: &pgtype.Type{
Name: "int2",
OID: pgtype.Int2OID,
Codec: pgtype.Int2Codec{},
},
},
{
Name: "payload",
Type: &pgtype.Type{
Name: "jsonb",
OID: pgtype.JSONBOID,
Codec: &pgtype.JSONBCodec{
Marshal: json.Marshal,
Unmarshal: json.Unmarshal,
},
},
},
{
Name: "creator",
Type: textType,
},
{
Name: "owner",
Type: textType,
},
},
},
}
commandArrayCodec = &pgtype.Type{
Codec: &pgtype.ArrayCodec{
ElementType: commandType,
},
}
)
var typeMu sync.Mutex
func RegisterEventstoreTypes(ctx context.Context, conn *pgx.Conn) error {
// conn.TypeMap is not thread safe
typeMu.Lock()
defer typeMu.Unlock()
m := conn.TypeMap()
var cmd *command
if _, ok := m.TypeForValue(cmd); ok {
return nil
}
if commandType.OID == 0 || commandArrayCodec.OID == 0 {
err := conn.QueryRow(ctx, "select oid, typarray from pg_type where typname = $1 and typnamespace = (select oid from pg_namespace where nspname = $2)", "command", "eventstore").
Scan(&commandType.OID, &commandArrayCodec.OID)
if err != nil {
logging.WithError(err).Debug("failed to get oid for command type")
return nil
}
if commandType.OID == 0 || commandArrayCodec.OID == 0 {
logging.Debug("oid for command type not found")
return nil
}
}
m.RegisterTypes([]*pgtype.Type{
{
Name: "eventstore.command",
Codec: commandType.Codec,
OID: commandType.OID,
},
{
Name: "command",
Codec: commandType.Codec,
OID: commandType.OID,
},
{
Name: "eventstore._command",
Codec: commandArrayCodec.Codec,
OID: commandArrayCodec.OID,
},
{
Name: "_command",
Codec: commandArrayCodec.Codec,
OID: commandArrayCodec.OID,
},
})
dialect.RegisterDefaultPgTypeVariants[command](m, "eventstore.command", "eventstore._command")
dialect.RegisterDefaultPgTypeVariants[command](m, "command", "_command")
return nil
}
// Client implements the [eventstore.Pusher]
func (es *Eventstore) Client() *database.DB {
return es.client
}
func NewEventstore(client *database.DB) *Eventstore {
return &Eventstore{client: client}
}
func (es *Eventstore) Health(ctx context.Context) error {
return es.client.PingContext(ctx)
}
var errTypesNotFound = errors.New("types not found")
func CheckExecutionPlan(ctx context.Context, conn *sql.Conn) error {
return conn.Raw(func(driverConn any) error {
if _, ok := driverConn.(sqlmock.SqlmockCommon); ok {
return nil
}
conn, ok := driverConn.(*stdlib.Conn)
if !ok {
return errTypesNotFound
}
return RegisterEventstoreTypes(ctx, conn.Conn())
})
}
func (es *Eventstore) pushTx(ctx context.Context, client database.ContextQueryExecuter) (tx database.Tx, deferrable func(err error) error, err error) {
tx, ok := client.(database.Tx)
if ok {
return tx, nil, nil
}
beginner, ok := client.(database.Beginner)
if !ok {
beginner = es.client
}
tx, err = beginner.BeginTx(ctx, &sql.TxOptions{
Isolation: sql.LevelReadCommitted,
ReadOnly: false,
})
if err != nil {
return nil, nil, err
}
return tx, func(err error) error { return database.CloseTransaction(tx, err) }, nil
}

View File

@@ -0,0 +1,369 @@
package eventstore
import (
"context"
"database/sql"
_ "embed"
"encoding/json"
"reflect"
"slices"
"strconv"
"strings"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
type fieldValue struct {
value []byte
}
func (value *fieldValue) Unmarshal(ptr any) error {
return json.Unmarshal(value.value, ptr)
}
func (es *Eventstore) FillFields(ctx context.Context, events ...eventstore.FillFieldsEvent) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer span.End()
tx, err := es.client.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil {
return err
}
defer func() {
if err != nil {
_ = tx.Rollback()
return
}
err = tx.Commit()
}()
return handleFieldFillEvents(ctx, tx, events)
}
// Search implements the [eventstore.Search] method
func (es *Eventstore) Search(ctx context.Context, conditions ...map[eventstore.FieldType]any) (result []*eventstore.SearchResult, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
var builder strings.Builder
args := buildSearchStatement(ctx, &builder, conditions...)
err = es.client.QueryContext(
ctx,
func(rows *sql.Rows) error {
for rows.Next() {
var (
res eventstore.SearchResult
value fieldValue
)
err = rows.Scan(
&res.Aggregate.InstanceID,
&res.Aggregate.ResourceOwner,
&res.Aggregate.Type,
&res.Aggregate.ID,
&res.Object.Type,
&res.Object.ID,
&res.Object.Revision,
&res.FieldName,
&value.value,
)
if err != nil {
return err
}
res.Value = &value
result = append(result, &res)
}
return nil
},
builder.String(),
args...,
)
if err != nil {
return nil, err
}
return result, nil
}
const searchQueryPrefix = `SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1`
func buildSearchStatement(ctx context.Context, builder *strings.Builder, conditions ...map[eventstore.FieldType]any) []any {
args := make([]any, 0, len(conditions)*4+1)
args = append(args, authz.GetInstance(ctx).InstanceID())
builder.WriteString(searchQueryPrefix)
builder.WriteString(" AND ")
if len(conditions) > 1 {
builder.WriteRune('(')
}
for i, condition := range conditions {
if i > 0 {
builder.WriteString(" OR ")
}
if len(condition) > 1 {
builder.WriteRune('(')
}
args = append(args, buildSearchCondition(builder, len(args)+1, condition)...)
if len(condition) > 1 {
builder.WriteRune(')')
}
}
if len(conditions) > 1 {
builder.WriteRune(')')
}
return args
}
func buildSearchCondition(builder *strings.Builder, index int, conditions map[eventstore.FieldType]any) []any {
args := make([]any, 0, len(conditions))
orderedCondition := make([]eventstore.FieldType, 0, len(conditions))
for field := range conditions {
orderedCondition = append(orderedCondition, field)
}
slices.Sort(orderedCondition)
for _, field := range orderedCondition {
if len(args) > 0 {
builder.WriteString(" AND ")
}
builder.WriteString(fieldNameByType(field, conditions[field]))
builder.WriteString(" = $")
builder.WriteString(strconv.Itoa(index + len(args)))
args = append(args, conditions[field])
}
return args
}
func (es *Eventstore) handleFieldCommands(ctx context.Context, tx database.Tx, commands []eventstore.Command) error {
for _, command := range commands {
if len(command.Fields()) > 0 {
if err := handleFieldOperations(ctx, tx, command.Fields()); err != nil {
return err
}
}
}
return nil
}
func handleFieldFillEvents(ctx context.Context, tx database.Tx, events []eventstore.FillFieldsEvent) error {
for _, event := range events {
if len(event.Fields()) == 0 {
continue
}
if err := handleFieldOperations(ctx, tx, event.Fields()); err != nil {
return err
}
}
return nil
}
func handleFieldOperations(ctx context.Context, tx database.Tx, operations []*eventstore.FieldOperation) error {
for _, operation := range operations {
if operation.Set != nil {
if err := handleFieldSet(ctx, tx, operation.Set); err != nil {
return err
}
continue
}
if operation.Remove != nil {
if err := handleSearchDelete(ctx, tx, operation.Remove); err != nil {
return err
}
}
}
return nil
}
func handleFieldSet(ctx context.Context, tx database.Tx, field *eventstore.Field) error {
if len(field.UpsertConflictFields) == 0 {
return handleSearchInsert(ctx, tx, field)
}
return handleSearchUpsert(ctx, tx, field)
}
const (
insertField = `INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)`
)
func handleSearchInsert(ctx context.Context, tx database.Tx, field *eventstore.Field) error {
value, err := json.Marshal(field.Value.Value)
if err != nil {
return zerrors.ThrowInvalidArgument(err, "V3-fcrW1", "unable to marshal field value")
}
_, err = tx.ExecContext(
ctx,
insertField,
field.Aggregate.InstanceID,
field.Aggregate.ResourceOwner,
field.Aggregate.Type,
field.Aggregate.ID,
field.Object.Type,
field.Object.ID,
field.Object.Revision,
field.FieldName,
value,
field.Value.MustBeUnique,
field.Value.ShouldIndex,
)
return err
}
const (
fieldsUpsertPrefix = `WITH upsert AS (UPDATE eventstore.fields SET (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) WHERE `
fieldsUpsertSuffix = ` RETURNING * ) INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 WHERE NOT EXISTS (SELECT 1 FROM upsert)`
)
func handleSearchUpsert(ctx context.Context, tx database.Tx, field *eventstore.Field) error {
value, err := json.Marshal(field.Value.Value)
if err != nil {
return zerrors.ThrowInvalidArgument(err, "V3-fcrW1", "unable to marshal field value")
}
_, err = tx.ExecContext(
ctx,
writeUpsertField(field.UpsertConflictFields),
field.Aggregate.InstanceID,
field.Aggregate.ResourceOwner,
field.Aggregate.Type,
field.Aggregate.ID,
field.Object.Type,
field.Object.ID,
field.Object.Revision,
field.FieldName,
value,
field.Value.MustBeUnique,
field.Value.ShouldIndex,
)
return err
}
func writeUpsertField(fields []eventstore.FieldType) string {
var builder strings.Builder
builder.WriteString(fieldsUpsertPrefix)
for i, fieldName := range fields {
if i > 0 {
builder.WriteString(" AND ")
}
name, index := searchFieldNameAndIndexByTypeForPush(fieldName)
builder.WriteString(name)
builder.WriteString(" = ")
builder.WriteString(index)
}
builder.WriteString(fieldsUpsertSuffix)
return builder.String()
}
const removeSearch = `DELETE FROM eventstore.fields WHERE `
func handleSearchDelete(ctx context.Context, tx database.Tx, clauses map[eventstore.FieldType]any) error {
if len(clauses) == 0 {
return zerrors.ThrowInvalidArgument(nil, "V3-oqlBZ", "no conditions")
}
stmt, args := writeDeleteField(clauses)
_, err := tx.ExecContext(ctx, stmt, args...)
return err
}
func writeDeleteField(clauses map[eventstore.FieldType]any) (string, []any) {
var (
builder strings.Builder
args = make([]any, 0, len(clauses))
)
builder.WriteString(removeSearch)
orderedCondition := make([]eventstore.FieldType, 0, len(clauses))
for field := range clauses {
orderedCondition = append(orderedCondition, field)
}
slices.Sort(orderedCondition)
for _, fieldName := range orderedCondition {
if len(args) > 0 {
builder.WriteString(" AND ")
}
builder.WriteString(fieldNameByType(fieldName, clauses[fieldName]))
builder.WriteString(" = $")
builder.WriteString(strconv.Itoa(len(args) + 1))
args = append(args, clauses[fieldName])
}
return builder.String(), args
}
func fieldNameByType(typ eventstore.FieldType, value any) string {
switch typ {
case eventstore.FieldTypeAggregateID:
return "aggregate_id"
case eventstore.FieldTypeAggregateType:
return "aggregate_type"
case eventstore.FieldTypeInstanceID:
return "instance_id"
case eventstore.FieldTypeResourceOwner:
return "resource_owner"
case eventstore.FieldTypeFieldName:
return "field_name"
case eventstore.FieldTypeObjectType:
return "object_type"
case eventstore.FieldTypeObjectID:
return "object_id"
case eventstore.FieldTypeObjectRevision:
return "object_revision"
case eventstore.FieldTypeValue:
return valueColumn(value)
}
return ""
}
func searchFieldNameAndIndexByTypeForPush(typ eventstore.FieldType) (string, string) {
switch typ {
case eventstore.FieldTypeInstanceID:
return "instance_id", "$1"
case eventstore.FieldTypeResourceOwner:
return "resource_owner", "$2"
case eventstore.FieldTypeAggregateType:
return "aggregate_type", "$3"
case eventstore.FieldTypeAggregateID:
return "aggregate_id", "$4"
case eventstore.FieldTypeObjectType:
return "object_type", "$5"
case eventstore.FieldTypeObjectID:
return "object_id", "$6"
case eventstore.FieldTypeObjectRevision:
return "object_revision", "$7"
case eventstore.FieldTypeFieldName:
return "field_name", "$8"
case eventstore.FieldTypeValue:
return "value", "$9"
}
return "", ""
}
func valueColumn(value any) string {
//nolint: exhaustive
switch reflect.TypeOf(value).Kind() {
case reflect.Bool:
return "bool_value"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64:
return "number_value"
case reflect.String:
return "text_value"
}
return ""
}

View File

@@ -0,0 +1,260 @@
package eventstore
import (
"context"
_ "embed"
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
)
func Test_handleSearchDelete(t *testing.T) {
type args struct {
clauses map[eventstore.FieldType]any
}
type want struct {
stmt string
args []any
}
tests := []struct {
name string
args args
want want
}{
{
name: "1 condition",
args: args{
clauses: map[eventstore.FieldType]any{
eventstore.FieldTypeInstanceID: "i_id",
},
},
want: want{
stmt: "DELETE FROM eventstore.fields WHERE instance_id = $1",
args: []any{"i_id"},
},
},
{
name: "2 conditions",
args: args{
clauses: map[eventstore.FieldType]any{
eventstore.FieldTypeInstanceID: "i_id",
eventstore.FieldTypeAggregateID: "a_id",
},
},
want: want{
stmt: "DELETE FROM eventstore.fields WHERE aggregate_id = $1 AND instance_id = $2",
args: []any{"a_id", "i_id"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stmt, args := writeDeleteField(tt.args.clauses)
if stmt != tt.want.stmt {
t.Errorf("handleSearchDelete() stmt = %q, want %q", stmt, tt.want.stmt)
}
assert.Equal(t, tt.want.args, args)
})
}
}
func Test_writeUpsertField(t *testing.T) {
type args struct {
fields []eventstore.FieldType
}
tests := []struct {
name string
args args
want string
}{
{
name: "1 field",
args: args{
fields: []eventstore.FieldType{
eventstore.FieldTypeInstanceID,
},
},
want: "WITH upsert AS (UPDATE eventstore.fields SET (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) WHERE instance_id = $1 RETURNING * ) INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 WHERE NOT EXISTS (SELECT 1 FROM upsert)",
},
{
name: "2 fields",
args: args{
fields: []eventstore.FieldType{
eventstore.FieldTypeInstanceID,
eventstore.FieldTypeAggregateType,
},
},
want: "WITH upsert AS (UPDATE eventstore.fields SET (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) WHERE instance_id = $1 AND aggregate_type = $3 RETURNING * ) INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 WHERE NOT EXISTS (SELECT 1 FROM upsert)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := writeUpsertField(tt.args.fields); got != tt.want {
t.Errorf("writeUpsertField() = %q, want %q", got, tt.want)
}
})
}
}
func Test_buildSearchCondition(t *testing.T) {
type args struct {
index int
conditions map[eventstore.FieldType]any
}
type want struct {
stmt string
args []any
}
tests := []struct {
name string
args args
want want
}{
{
name: "1 condition",
args: args{
index: 1,
conditions: map[eventstore.FieldType]any{
eventstore.FieldTypeAggregateID: "a_id",
},
},
want: want{
stmt: "aggregate_id = $1",
args: []any{"a_id"},
},
},
{
name: "3 condition",
args: args{
index: 1,
conditions: map[eventstore.FieldType]any{
eventstore.FieldTypeAggregateID: "a_id",
eventstore.FieldTypeInstanceID: "i_id",
eventstore.FieldTypeAggregateType: "a_type",
},
},
want: want{
stmt: "aggregate_type = $1 AND aggregate_id = $2 AND instance_id = $3",
args: []any{"a_type", "a_id", "i_id"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var builder strings.Builder
if got := buildSearchCondition(&builder, tt.args.index, tt.args.conditions); !reflect.DeepEqual(got, tt.want.args) {
t.Errorf("buildSearchCondition() = %v, want %v", got, tt.want)
}
if tt.want.stmt != builder.String() {
t.Errorf("buildSearchCondition() stmt = %q, want %q", builder.String(), tt.want.stmt)
}
})
}
}
func Test_buildSearchStatement(t *testing.T) {
type args struct {
index int
conditions []map[eventstore.FieldType]any
}
type want struct {
stmt string
args []any
}
tests := []struct {
name string
args args
want want
}{
{
name: "1 condition with 1 field",
args: args{
index: 1,
conditions: []map[eventstore.FieldType]any{
{
eventstore.FieldTypeAggregateID: "a_id",
},
},
},
want: want{
stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND aggregate_id = $2",
args: []any{"a_id"},
},
},
{
name: "1 condition with 3 fields",
args: args{
index: 1,
conditions: []map[eventstore.FieldType]any{
{
eventstore.FieldTypeAggregateID: "a_id",
eventstore.FieldTypeInstanceID: "i_id",
eventstore.FieldTypeAggregateType: "a_type",
},
},
},
want: want{
stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND (aggregate_type = $2 AND aggregate_id = $3 AND instance_id = $4)",
args: []any{"a_type", "a_id", "i_id"},
},
},
{
name: "2 condition with 1 field",
args: args{
index: 1,
conditions: []map[eventstore.FieldType]any{
{
eventstore.FieldTypeAggregateID: "a_id",
},
{
eventstore.FieldTypeAggregateType: "a_type",
},
},
},
want: want{
stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND (aggregate_id = $2 OR aggregate_type = $3)",
args: []any{"a_id", "a_type"},
},
},
{
name: "2 condition with 2 fields",
args: args{
index: 1,
conditions: []map[eventstore.FieldType]any{
{
eventstore.FieldTypeAggregateID: "a_id1",
eventstore.FieldTypeAggregateType: "a_type1",
},
{
eventstore.FieldTypeAggregateID: "a_id2",
eventstore.FieldTypeAggregateType: "a_type2",
},
},
},
want: want{
stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND ((aggregate_type = $2 AND aggregate_id = $3) OR (aggregate_type = $4 AND aggregate_id = $5))",
args: []any{"a_type1", "a_id1", "a_type2", "a_id2"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var builder strings.Builder
tt.want.args = append([]any{"i_id"}, tt.want.args...)
ctx := authz.WithInstanceID(context.Background(), "i_id")
if got := buildSearchStatement(ctx, &builder, tt.args.conditions...); !reflect.DeepEqual(got, tt.want.args) {
t.Errorf("buildSearchStatement() = %v, want %v", got, tt.want)
}
if tt.want.stmt != builder.String() {
t.Errorf("buildSearchStatement() stmt = %q, want %q", builder.String(), tt.want.stmt)
}
})
}
}

View File

@@ -0,0 +1,83 @@
package eventstore
import (
"github.com/zitadel/zitadel/internal/eventstore"
)
var _ eventstore.Command = (*mockCommand)(nil)
type mockCommand struct {
aggregate *eventstore.Aggregate
payload any
constraints []*eventstore.UniqueConstraint
}
// Aggregate implements [eventstore.Command]
func (m *mockCommand) Aggregate() *eventstore.Aggregate {
return m.aggregate
}
// Creator implements [eventstore.Command]
func (m *mockCommand) Creator() string {
return "creator"
}
// Revision implements [eventstore.Command]
func (m *mockCommand) Revision() uint16 {
return 1
}
// Type implements [eventstore.Command]
func (m *mockCommand) Type() eventstore.EventType {
return "event.type"
}
// Payload implements [eventstore.Command]
func (m *mockCommand) Payload() any {
return m.payload
}
// UniqueConstraints implements [eventstore.Command]
func (m *mockCommand) UniqueConstraints() []*eventstore.UniqueConstraint {
return m.constraints
}
func (e *mockCommand) Fields() []*eventstore.FieldOperation {
return nil
}
func mockEvent(aggregate *eventstore.Aggregate, sequence uint64, payload Payload) eventstore.Event {
return &event{
command: &command{
InstanceID: aggregate.InstanceID,
AggregateType: string(aggregate.Type),
AggregateID: aggregate.ID,
Owner: aggregate.ResourceOwner,
Creator: "creator",
Revision: 1,
CommandType: "event.type",
Payload: payload,
},
sequence: sequence,
}
}
func mockAggregate(id string) *eventstore.Aggregate {
return &eventstore.Aggregate{
ID: id,
Type: "type",
ResourceOwner: "ro",
InstanceID: "instance",
Version: "v1",
}
}
func mockAggregateWithInstance(id, instance string) *eventstore.Aggregate {
return &eventstore.Aggregate{
ID: id,
InstanceID: instance,
Type: "type",
ResourceOwner: "ro",
Version: "v1",
}
}

View File

@@ -0,0 +1,108 @@
package eventstore
import (
"context"
"database/sql"
_ "embed"
"fmt"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var pushTxOpts = &sql.TxOptions{
Isolation: sql.LevelReadCommitted,
ReadOnly: false,
}
func (es *Eventstore) Push(ctx context.Context, client database.ContextQueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
events, err = es.writeCommands(ctx, client, commands)
if isSetupNotExecutedError(err) {
return es.pushWithoutFunc(ctx, client, commands...)
}
return events, err
}
func (es *Eventstore) writeCommands(ctx context.Context, client database.ContextQueryExecuter, commands []eventstore.Command) (_ []eventstore.Event, err error) {
var conn *sql.Conn
switch c := client.(type) {
case database.Client:
conn, err = c.Conn(ctx)
case nil:
conn, err = es.client.Conn(ctx)
client = conn
}
if err != nil {
return nil, err
}
if conn != nil {
defer conn.Close()
}
tx, close, err := es.pushTx(ctx, client)
if err != nil {
return nil, err
}
if close != nil {
defer func() {
err = close(err)
}()
}
_, err = tx.ExecContext(ctx, fmt.Sprintf("SET LOCAL application_name = '%s'", fmt.Sprintf("zitadel_es_pusher_%s", authz.GetInstance(ctx).InstanceID())))
if err != nil {
return nil, err
}
events, err := writeEvents(ctx, tx, commands)
if err != nil {
return nil, err
}
if err = handleUniqueConstraints(ctx, tx, commands); err != nil {
return nil, err
}
err = es.handleFieldCommands(ctx, tx, commands)
if err != nil {
return nil, err
}
return events, nil
}
func writeEvents(ctx context.Context, tx database.Tx, commands []eventstore.Command) (_ []eventstore.Event, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
events, cmds, err := commandsToEvents(ctx, commands)
if err != nil {
return nil, err
}
rows, err := tx.QueryContext(ctx, `select owner, created_at, "sequence", position from eventstore.push($1::eventstore.command[])`, cmds)
if err != nil {
return nil, err
}
defer rows.Close()
for i := 0; rows.Next(); i++ {
err = rows.Scan(&events[i].(*event).command.Owner, &events[i].(*event).createdAt, &events[i].(*event).sequence, &events[i].(*event).position)
if err != nil {
logging.WithError(err).Warn("failed to scan events")
return nil, err
}
}
if err = rows.Err(); err != nil {
return nil, err
}
return events, nil
}

View File

@@ -0,0 +1,18 @@
INSERT INTO eventstore.events2 (
instance_id
, "owner"
, aggregate_type
, aggregate_id
, revision
, creator
, event_type
, payload
, "sequence"
, created_at
, "position"
, in_tx_order
) VALUES
%s
RETURNING created_at, "position";

View File

@@ -0,0 +1,253 @@
package eventstore
import (
_ "embed"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/postgres"
"github.com/zitadel/zitadel/internal/eventstore"
)
func Test_mapCommands(t *testing.T) {
type args struct {
commands []eventstore.Command
sequences []*latestSequence
}
type want struct {
events []eventstore.Event
placeHolders []string
args []any
err func(t *testing.T, err error)
shouldPanic bool
}
tests := []struct {
name string
args args
want want
}{
{
name: "no commands",
args: args{
commands: []eventstore.Command{},
sequences: []*latestSequence{},
},
want: want{
events: []eventstore.Event{},
placeHolders: []string{},
args: []any{},
},
},
{
name: "one command",
args: args{
commands: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-VEIvq"),
},
},
sequences: []*latestSequence{
{
aggregate: mockAggregate("V3-VEIvq"),
sequence: 0,
},
},
},
want: want{
events: []eventstore.Event{
mockEvent(
mockAggregate("V3-VEIvq"),
1,
nil,
),
},
placeHolders: []string{
"($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)",
},
args: []any{
"instance",
"ro",
"type",
"V3-VEIvq",
uint16(1),
"creator",
"event.type",
Payload(nil),
uint64(1),
0,
},
err: func(t *testing.T, err error) {},
},
},
{
name: "multiple commands same aggregate",
args: args{
commands: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-VEIvq"),
},
&mockCommand{
aggregate: mockAggregate("V3-VEIvq"),
},
},
sequences: []*latestSequence{
{
aggregate: mockAggregate("V3-VEIvq"),
sequence: 5,
},
},
},
want: want{
events: []eventstore.Event{
mockEvent(
mockAggregate("V3-VEIvq"),
6,
nil,
),
mockEvent(
mockAggregate("V3-VEIvq"),
7,
nil,
),
},
placeHolders: []string{
"($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)",
"($11, $12, $13, $14, $15, $16, $17, $18, $19, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $20)",
},
args: []any{
// first event
"instance",
"ro",
"type",
"V3-VEIvq",
uint16(1),
"creator",
"event.type",
Payload(nil),
uint64(6),
0,
// second event
"instance",
"ro",
"type",
"V3-VEIvq",
uint16(1),
"creator",
"event.type",
Payload(nil),
uint64(7),
1,
},
err: func(t *testing.T, err error) {},
},
},
{
name: "one command per aggregate",
args: args{
commands: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-VEIvq"),
},
&mockCommand{
aggregate: mockAggregate("V3-IT6VN"),
},
},
sequences: []*latestSequence{
{
aggregate: mockAggregate("V3-VEIvq"),
sequence: 5,
},
{
aggregate: mockAggregate("V3-IT6VN"),
sequence: 0,
},
},
},
want: want{
events: []eventstore.Event{
mockEvent(
mockAggregate("V3-VEIvq"),
6,
nil,
),
mockEvent(
mockAggregate("V3-IT6VN"),
1,
nil,
),
},
placeHolders: []string{
"($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)",
"($11, $12, $13, $14, $15, $16, $17, $18, $19, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $20)",
},
args: []any{
// first event
"instance",
"ro",
"type",
"V3-VEIvq",
uint16(1),
"creator",
"event.type",
Payload(nil),
uint64(6),
0,
// second event
"instance",
"ro",
"type",
"V3-IT6VN",
uint16(1),
"creator",
"event.type",
Payload(nil),
uint64(1),
1,
},
err: func(t *testing.T, err error) {},
},
},
{
name: "missing sequence",
args: args{
commands: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-VEIvq"),
},
},
sequences: []*latestSequence{},
},
want: want{
events: []eventstore.Event{},
placeHolders: []string{},
args: []any{},
err: func(t *testing.T, err error) {},
shouldPanic: true,
},
},
}
for _, tt := range tests {
if tt.want.err == nil {
tt.want.err = func(t *testing.T, err error) {
require.NoError(t, err)
}
}
// is used to set the the [pushPlaceholderFmt]
NewEventstore(&database.DB{Database: new(postgres.Config)})
t.Run(tt.name, func(t *testing.T) {
defer func() {
cause := recover()
assert.Equal(t, tt.want.shouldPanic, cause != nil)
}()
gotEvents, gotPlaceHolders, gotArgs, err := mapCommands(tt.args.commands, tt.args.sequences)
tt.want.err(t, err)
assert.ElementsMatch(t, tt.want.events, gotEvents)
assert.ElementsMatch(t, tt.want.placeHolders, gotPlaceHolders)
assert.ElementsMatch(t, tt.want.args, gotArgs)
})
}
}

View File

@@ -0,0 +1,162 @@
package eventstore
import (
"context"
_ "embed"
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
// checks whether the error is caused because setup step 39 was not executed
func isSetupNotExecutedError(err error) bool {
if err == nil {
return false
}
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return (pgErr.Code == "42704" && strings.Contains(pgErr.Message, "eventstore.command")) ||
(pgErr.Code == "42883" && strings.Contains(pgErr.Message, "eventstore.push"))
}
return errors.Is(err, errTypesNotFound)
}
var (
//go:embed push.sql
pushStmt string
)
// pushWithoutFunc implements pushing events before setup step 39 was introduced.
// TODO: remove with v3
func (es *Eventstore) pushWithoutFunc(ctx context.Context, client database.ContextQueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) {
tx, closeTx, err := es.pushTx(ctx, client)
if err != nil {
return nil, err
}
defer func() {
err = closeTx(err)
}()
var (
sequences []*latestSequence
)
sequences, err = latestSequences(ctx, tx, commands)
if err != nil {
return nil, err
}
events, err = es.writeEventsOld(ctx, tx, sequences, commands)
if err != nil {
return nil, err
}
if err = handleUniqueConstraints(ctx, tx, commands); err != nil {
return nil, err
}
err = es.handleFieldCommands(ctx, tx, commands)
if err != nil {
return nil, err
}
return events, nil
}
func (es *Eventstore) writeEventsOld(ctx context.Context, tx database.Tx, sequences []*latestSequence, commands []eventstore.Command) ([]eventstore.Event, error) {
events, placeholders, args, err := mapCommands(commands, sequences)
if err != nil {
return nil, err
}
rows, err := tx.QueryContext(ctx, fmt.Sprintf(pushStmt, strings.Join(placeholders, ", ")), args...)
if err != nil {
return nil, err
}
defer rows.Close()
for i := 0; rows.Next(); i++ {
err = rows.Scan(&events[i].(*event).createdAt, &events[i].(*event).position)
if err != nil {
logging.WithError(err).Warn("failed to scan events")
return nil, err
}
}
if err := rows.Err(); err != nil {
pgErr := new(pgconn.PgError)
if errors.As(err, &pgErr) {
// Check if push tries to write an event just written
// by another transaction
if pgErr.Code == "40001" {
// TODO: @livio-a should we return the parent or not?
return nil, zerrors.ThrowInvalidArgument(err, "V3-p5xAn", "Errors.AlreadyExists")
}
}
logging.WithError(rows.Err()).Warn("failed to push events")
return nil, zerrors.ThrowInternal(err, "V3-VGnZY", "Errors.Internal")
}
return events, nil
}
const argsPerCommand = 10
func mapCommands(commands []eventstore.Command, sequences []*latestSequence) (events []eventstore.Event, placeholders []string, args []any, err error) {
events = make([]eventstore.Event, len(commands))
args = make([]any, 0, len(commands)*argsPerCommand)
placeholders = make([]string, len(commands))
for i, command := range commands {
sequence := searchSequenceByCommand(sequences, command)
if sequence == nil {
logging.WithFields(
"aggType", command.Aggregate().Type,
"aggID", command.Aggregate().ID,
"instance", command.Aggregate().InstanceID,
).Panic("no sequence found")
// added return for linting
return nil, nil, nil, nil
}
sequence.sequence++
events[i], err = commandToEventOld(sequence, command)
if err != nil {
return nil, nil, nil, err
}
placeholders[i] = fmt.Sprintf(pushPlaceholderFmt,
i*argsPerCommand+1,
i*argsPerCommand+2,
i*argsPerCommand+3,
i*argsPerCommand+4,
i*argsPerCommand+5,
i*argsPerCommand+6,
i*argsPerCommand+7,
i*argsPerCommand+8,
i*argsPerCommand+9,
i*argsPerCommand+10,
)
args = append(args,
events[i].(*event).command.InstanceID,
events[i].(*event).command.Owner,
events[i].(*event).command.AggregateType,
events[i].(*event).command.AggregateID,
events[i].(*event).command.Revision,
events[i].(*event).command.Creator,
events[i].(*event).command.CommandType,
events[i].(*event).command.Payload,
events[i].(*event).sequence,
i,
)
}
return events, placeholders, args, nil
}

View File

@@ -0,0 +1,144 @@
package eventstore
import (
"context"
"database/sql"
_ "embed"
"fmt"
"strings"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
type latestSequence struct {
aggregate *eventstore.Aggregate
sequence uint64
}
//go:embed sequences_query.sql
var latestSequencesStmt string
func latestSequences(ctx context.Context, tx database.Tx, commands []eventstore.Command) ([]*latestSequence, error) {
sequences := commandsToSequences(ctx, commands)
conditions, args := sequencesToSql(sequences)
rows, err := tx.QueryContext(ctx, fmt.Sprintf(latestSequencesStmt, strings.Join(conditions, " UNION ALL ")), args...)
if err != nil {
return nil, zerrors.ThrowInternal(err, "V3-5jU5z", "Errors.Internal")
}
defer rows.Close()
for rows.Next() {
if err := scanToSequence(rows, sequences); err != nil {
return nil, zerrors.ThrowInternal(err, "V3-Ydiwv", "Errors.Internal")
}
}
if rows.Err() != nil {
return nil, zerrors.ThrowInternal(rows.Err(), "V3-XApDk", "Errors.Internal")
}
return sequences, nil
}
func searchSequenceByCommand(sequences []*latestSequence, command eventstore.Command) *latestSequence {
for _, sequence := range sequences {
if sequence.aggregate.Type == command.Aggregate().Type &&
sequence.aggregate.ID == command.Aggregate().ID &&
sequence.aggregate.InstanceID == command.Aggregate().InstanceID {
return sequence
}
}
return nil
}
func searchSequence(sequences []*latestSequence, aggregateType eventstore.AggregateType, aggregateID, instanceID string) *latestSequence {
for _, sequence := range sequences {
if sequence.aggregate.Type == aggregateType &&
sequence.aggregate.ID == aggregateID &&
sequence.aggregate.InstanceID == instanceID {
return sequence
}
}
return nil
}
func commandsToSequences(ctx context.Context, commands []eventstore.Command) []*latestSequence {
sequences := make([]*latestSequence, 0, len(commands))
for _, command := range commands {
if searchSequenceByCommand(sequences, command) != nil {
continue
}
if command.Aggregate().InstanceID == "" {
command.Aggregate().InstanceID = authz.GetInstance(ctx).InstanceID()
}
sequences = append(sequences, &latestSequence{
aggregate: command.Aggregate(),
})
}
return sequences
}
const argsPerCondition = 3
func sequencesToSql(sequences []*latestSequence) (conditions []string, args []any) {
args = make([]interface{}, 0, len(sequences)*argsPerCondition)
conditions = make([]string, len(sequences))
for i, sequence := range sequences {
conditions[i] = fmt.Sprintf(`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $%d AND aggregate_type = $%d AND aggregate_id = $%d ORDER BY "sequence" DESC LIMIT 1)`,
i*argsPerCondition+1,
i*argsPerCondition+2,
i*argsPerCondition+3,
)
args = append(args, sequence.aggregate.InstanceID, sequence.aggregate.Type, sequence.aggregate.ID)
}
return conditions, args
}
func scanToSequence(rows *sql.Rows, sequences []*latestSequence) error {
var aggregateType eventstore.AggregateType
var aggregateID, instanceID string
var currentSequence uint64
var resourceOwner string
if err := rows.Scan(&instanceID, &resourceOwner, &aggregateType, &aggregateID, &currentSequence); err != nil {
return zerrors.ThrowInternal(err, "V3-OIWqj", "Errors.Internal")
}
sequence := searchSequence(sequences, aggregateType, aggregateID, instanceID)
if sequence == nil {
logging.WithFields(
"aggType", aggregateType,
"aggID", aggregateID,
"instance", instanceID,
).Panic("no sequence found")
// added return for linting
return nil
}
sequence.sequence = currentSequence
if resourceOwner != "" && sequence.aggregate.ResourceOwner != "" && sequence.aggregate.ResourceOwner != resourceOwner {
logging.WithFields(
"current_sequence", sequence.sequence,
"instance_id", sequence.aggregate.InstanceID,
"agg_type", sequence.aggregate.Type,
"agg_id", sequence.aggregate.ID,
"current_owner", resourceOwner,
"provided_owner", sequence.aggregate.ResourceOwner,
).Info("would have set wrong resource owner")
}
// set resource owner from previous events
if resourceOwner != "" {
sequence.aggregate.ResourceOwner = resourceOwner
}
return nil
}

View File

@@ -0,0 +1,293 @@
package eventstore
import (
"context"
_ "embed"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
)
func Test_searchSequence(t *testing.T) {
sequence := &latestSequence{
aggregate: mockAggregate("V3-p1BWC"),
sequence: 1,
}
type args struct {
sequences []*latestSequence
aggregateType eventstore.AggregateType
aggregateID string
instanceID string
}
tests := []struct {
name string
args args
want *latestSequence
}{
{
name: "type missmatch",
args: args{
sequences: []*latestSequence{
sequence,
},
aggregateType: "wrong",
aggregateID: "V3-p1BWC",
instanceID: "instance",
},
want: nil,
},
{
name: "id missmatch",
args: args{
sequences: []*latestSequence{
sequence,
},
aggregateType: "type",
aggregateID: "wrong",
instanceID: "instance",
},
want: nil,
},
{
name: "instance missmatch",
args: args{
sequences: []*latestSequence{
sequence,
},
aggregateType: "type",
aggregateID: "V3-p1BWC",
instanceID: "wrong",
},
want: nil,
},
{
name: "match",
args: args{
sequences: []*latestSequence{
sequence,
},
aggregateType: "type",
aggregateID: "V3-p1BWC",
instanceID: "instance",
},
want: sequence,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := searchSequence(tt.args.sequences, tt.args.aggregateType, tt.args.aggregateID, tt.args.instanceID); !reflect.DeepEqual(got, tt.want) {
t.Errorf("searchSequence() = %v, want %v", got, tt.want)
}
})
}
}
func Test_commandsToSequences(t *testing.T) {
aggregate := mockAggregate("V3-MKHTF")
type args struct {
ctx context.Context
commands []eventstore.Command
}
tests := []struct {
name string
args args
want []*latestSequence
}{
{
name: "no command",
args: args{
ctx: context.Background(),
commands: []eventstore.Command{},
},
want: []*latestSequence{},
},
{
name: "one command",
args: args{
ctx: context.Background(),
commands: []eventstore.Command{
&mockCommand{
aggregate: aggregate,
},
},
},
want: []*latestSequence{
{
aggregate: aggregate,
},
},
},
{
name: "two commands same aggregate",
args: args{
ctx: context.Background(),
commands: []eventstore.Command{
&mockCommand{
aggregate: aggregate,
},
&mockCommand{
aggregate: aggregate,
},
},
},
want: []*latestSequence{
{
aggregate: aggregate,
},
},
},
{
name: "two commands different aggregates",
args: args{
ctx: context.Background(),
commands: []eventstore.Command{
&mockCommand{
aggregate: aggregate,
},
&mockCommand{
aggregate: mockAggregate("V3-cZkCy"),
},
},
},
want: []*latestSequence{
{
aggregate: aggregate,
},
{
aggregate: mockAggregate("V3-cZkCy"),
},
},
},
{
name: "instance set in command",
args: args{
ctx: authz.WithInstanceID(context.Background(), "V3-ANV4p"),
commands: []eventstore.Command{
&mockCommand{
aggregate: &eventstore.Aggregate{
ID: "V3-bF0Sa",
Type: "type",
ResourceOwner: "to",
InstanceID: "instance",
Version: "v1",
},
},
},
},
want: []*latestSequence{
{
aggregate: &eventstore.Aggregate{
ID: "V3-bF0Sa",
Type: "type",
ResourceOwner: "to",
InstanceID: "instance",
Version: "v1",
},
},
},
},
{
name: "instance from context",
args: args{
ctx: authz.WithInstanceID(context.Background(), "V3-ANV4p"),
commands: []eventstore.Command{
&mockCommand{
aggregate: &eventstore.Aggregate{
ID: "V3-bF0Sa",
Type: "type",
ResourceOwner: "to",
Version: "v1",
},
},
},
},
want: []*latestSequence{
{
aggregate: &eventstore.Aggregate{
ID: "V3-bF0Sa",
Type: "type",
ResourceOwner: "to",
InstanceID: "V3-ANV4p",
Version: "v1",
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := commandsToSequences(tt.args.ctx, tt.args.commands)
assert.ElementsMatch(t, tt.want, got)
})
}
}
func Test_sequencesToSql(t *testing.T) {
tests := []struct {
name string
arg []*latestSequence
wantConditions []string
wantArgs []any
}{
{
name: "no sequence",
arg: []*latestSequence{},
wantConditions: []string{},
wantArgs: []any{},
},
{
name: "one",
arg: []*latestSequence{
{
aggregate: mockAggregate("V3-SbpGB"),
},
},
wantConditions: []string{
`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND aggregate_id = $3 ORDER BY "sequence" DESC LIMIT 1)`,
},
wantArgs: []any{
"instance",
eventstore.AggregateType("type"),
"V3-SbpGB",
},
},
{
name: "multiple",
arg: []*latestSequence{
{
aggregate: mockAggregate("V3-SbpGB"),
},
{
aggregate: mockAggregate("V3-0X3yt"),
},
},
wantConditions: []string{
`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND aggregate_id = $3 ORDER BY "sequence" DESC LIMIT 1)`,
`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $4 AND aggregate_type = $5 AND aggregate_id = $6 ORDER BY "sequence" DESC LIMIT 1)`,
},
wantArgs: []any{
"instance",
eventstore.AggregateType("type"),
"V3-SbpGB",
"instance",
eventstore.AggregateType("type"),
"V3-0X3yt",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotConditions, gotArgs := sequencesToSql(tt.arg)
if !reflect.DeepEqual(gotConditions, tt.wantConditions) {
t.Errorf("sequencesToSql() gotConditions = %v, want %v", gotConditions, tt.wantConditions)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("sequencesToSql() gotArgs = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}

View File

@@ -0,0 +1,18 @@
WITH existing AS (
%s
) SELECT
e.instance_id
, e.owner
, e.aggregate_type
, e.aggregate_id
, e.sequence
FROM
eventstore.events2 e
JOIN
existing
ON
e.instance_id = existing.instance_id
AND e.aggregate_type = existing.aggregate_type
AND e.aggregate_id = existing.aggregate_id
AND e.sequence = existing.sequence
FOR UPDATE;

View File

@@ -0,0 +1,25 @@
package eventstore
import "database/sql/driver"
// Payload represents a byte array that may be null.
// Payload implements the sql.Scanner interface
type Payload []byte
// Scan implements the Scanner interface.
func (data *Payload) Scan(value interface{}) error {
if value == nil {
*data = nil
return nil
}
*data = Payload(value.([]byte))
return nil
}
// Value implements the driver Valuer interface.
func (data Payload) Value() (driver.Value, error) {
if len(data) == 0 {
return nil, nil
}
return []byte(data), nil
}

View File

@@ -0,0 +1,100 @@
package eventstore
import (
"context"
_ "embed"
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
var (
//go:embed unique_constraints_delete.sql
deleteConstraintStmt string
//go:embed unique_constraints_delete_placeholders.sql
deleteConstraintPlaceholdersStmt string
//go:embed unique_constraints_add.sql
addConstraintStmt string
)
func handleUniqueConstraints(ctx context.Context, tx database.Tx, commands []eventstore.Command) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
deletePlaceholders := make([]string, 0)
deleteArgs := make([]any, 0)
addPlaceholders := make([]string, 0)
addArgs := make([]any, 0)
addConstraints := map[string]*eventstore.UniqueConstraint{}
deleteConstraints := map[string]*eventstore.UniqueConstraint{}
for _, command := range commands {
for _, constraint := range command.UniqueConstraints() {
instanceID := command.Aggregate().InstanceID
if constraint.IsGlobal {
instanceID = ""
}
switch constraint.Action {
case eventstore.UniqueConstraintAdd:
constraint.UniqueField = strings.ToLower(constraint.UniqueField)
addPlaceholders = append(addPlaceholders, fmt.Sprintf("($%d, $%d, $%d)", len(addArgs)+1, len(addArgs)+2, len(addArgs)+3))
addArgs = append(addArgs, instanceID, constraint.UniqueType, constraint.UniqueField)
addConstraints[fmt.Sprintf(uniqueConstraintPlaceholderFmt, instanceID, constraint.UniqueType, constraint.UniqueField)] = constraint
case eventstore.UniqueConstraintRemove:
deletePlaceholders = append(deletePlaceholders, fmt.Sprintf(deleteConstraintPlaceholdersStmt, len(deleteArgs)+1, len(deleteArgs)+2, len(deleteArgs)+3))
deleteArgs = append(deleteArgs, instanceID, constraint.UniqueType, constraint.UniqueField)
deleteConstraints[fmt.Sprintf(uniqueConstraintPlaceholderFmt, instanceID, constraint.UniqueType, constraint.UniqueField)] = constraint
case eventstore.UniqueConstraintInstanceRemove:
deletePlaceholders = append(deletePlaceholders, fmt.Sprintf("(instance_id = $%d)", len(deleteArgs)+1))
deleteArgs = append(deleteArgs, instanceID)
deleteConstraints[fmt.Sprintf(uniqueConstraintPlaceholderFmt, instanceID, constraint.UniqueType, constraint.UniqueField)] = constraint
}
}
}
if len(deletePlaceholders) > 0 {
_, err := tx.ExecContext(ctx, fmt.Sprintf(deleteConstraintStmt, strings.Join(deletePlaceholders, " OR ")), deleteArgs...)
if err != nil {
logging.WithError(err).Warn("delete unique constraint failed")
errMessage := "Errors.Internal"
if constraint := constraintFromErr(err, deleteConstraints); constraint != nil {
errMessage = constraint.ErrorMessage
}
return zerrors.ThrowInternal(err, "V3-C8l3V", errMessage)
}
}
if len(addPlaceholders) > 0 {
_, err := tx.ExecContext(ctx, fmt.Sprintf(addConstraintStmt, strings.Join(addPlaceholders, ", ")), addArgs...)
if err != nil {
logging.WithError(err).Warn("add unique constraint failed")
errMessage := "Errors.Internal"
if constraint := constraintFromErr(err, addConstraints); constraint != nil {
errMessage = constraint.ErrorMessage
}
return zerrors.ThrowAlreadyExists(err, "V3-DKcYh", errMessage)
}
}
return nil
}
func constraintFromErr(err error, constraints map[string]*eventstore.UniqueConstraint) *eventstore.UniqueConstraint {
pgErr := new(pgconn.PgError)
if !errors.As(err, &pgErr) {
return nil
}
for key, constraint := range constraints {
if strings.Contains(pgErr.Detail, key) {
return constraint
}
}
return nil
}

View File

@@ -0,0 +1,6 @@
INSERT INTO eventstore.unique_constraints (
instance_id
, unique_type
, unique_field
) VALUES
%s

View File

@@ -0,0 +1 @@
DELETE FROM eventstore.unique_constraints WHERE %s

View File

@@ -0,0 +1,13 @@
-- the query is so complex because we accidentally stored unique constraint case sensitive
-- the query checks first if there is a case sensitive match and afterwards if there is a case insensitive match
(instance_id = $%[1]d AND unique_type = $%[2]d AND unique_field = (
SELECT unique_field from (
SELECT instance_id, unique_type, unique_field
FROM eventstore.unique_constraints
WHERE instance_id = $%[1]d AND unique_type = $%[2]d AND unique_field = $%[3]d
UNION ALL
SELECT instance_id, unique_type, unique_field
FROM eventstore.unique_constraints
WHERE instance_id = $%[1]d AND unique_type = $%[2]d AND unique_field = LOWER($%[3]d)
) AS case_insensitive_constraints LIMIT 1)
)