mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 04:07:31 +00:00
chore: move the go code into a subfolder
This commit is contained in:
176
apps/api/internal/eventstore/v3/event.go
Normal file
176
apps/api/internal/eventstore/v3/event.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
_ eventstore.Event = (*event)(nil)
|
||||
)
|
||||
|
||||
type command struct {
|
||||
InstanceID string
|
||||
AggregateType string
|
||||
AggregateID string
|
||||
CommandType string
|
||||
Revision uint16
|
||||
Payload Payload
|
||||
Creator string
|
||||
Owner string
|
||||
}
|
||||
|
||||
func (c *command) Aggregate() *eventstore.Aggregate {
|
||||
return &eventstore.Aggregate{
|
||||
ID: c.AggregateID,
|
||||
Type: eventstore.AggregateType(c.AggregateType),
|
||||
ResourceOwner: c.Owner,
|
||||
InstanceID: c.InstanceID,
|
||||
Version: eventstore.Version("v" + strconv.Itoa(int(c.Revision))),
|
||||
}
|
||||
}
|
||||
|
||||
type event struct {
|
||||
command *command
|
||||
createdAt time.Time
|
||||
sequence uint64
|
||||
position decimal.Decimal
|
||||
}
|
||||
|
||||
// TODO: remove on v3
|
||||
func commandToEventOld(sequence *latestSequence, cmd eventstore.Command) (_ *event, err error) {
|
||||
var payload Payload
|
||||
if cmd.Payload() != nil {
|
||||
payload, err = json.Marshal(cmd.Payload())
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("marshal payload failed")
|
||||
return nil, zerrors.ThrowInternal(err, "V3-MInPK", "Errors.Internal")
|
||||
}
|
||||
}
|
||||
return &event{
|
||||
command: &command{
|
||||
InstanceID: sequence.aggregate.InstanceID,
|
||||
AggregateType: string(sequence.aggregate.Type),
|
||||
AggregateID: sequence.aggregate.ID,
|
||||
CommandType: string(cmd.Type()),
|
||||
Revision: cmd.Revision(),
|
||||
Payload: payload,
|
||||
Creator: cmd.Creator(),
|
||||
Owner: sequence.aggregate.ResourceOwner,
|
||||
},
|
||||
sequence: sequence.sequence,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func commandsToEvents(ctx context.Context, cmds []eventstore.Command) (_ []eventstore.Event, _ []*command, err error) {
|
||||
events := make([]eventstore.Event, len(cmds))
|
||||
commands := make([]*command, len(cmds))
|
||||
for i, cmd := range cmds {
|
||||
if cmd.Aggregate().InstanceID == "" {
|
||||
cmd.Aggregate().InstanceID = authz.GetInstance(ctx).InstanceID()
|
||||
}
|
||||
events[i], err = commandToEvent(cmd)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
commands[i] = events[i].(*event).command
|
||||
}
|
||||
return events, commands, nil
|
||||
}
|
||||
|
||||
func commandToEvent(cmd eventstore.Command) (_ eventstore.Event, err error) {
|
||||
var payload Payload
|
||||
if cmd.Payload() != nil {
|
||||
payload, err = json.Marshal(cmd.Payload())
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("marshal payload failed")
|
||||
return nil, zerrors.ThrowInternal(err, "V3-MInPK", "Errors.Internal")
|
||||
}
|
||||
}
|
||||
|
||||
command := &command{
|
||||
InstanceID: cmd.Aggregate().InstanceID,
|
||||
AggregateType: string(cmd.Aggregate().Type),
|
||||
AggregateID: cmd.Aggregate().ID,
|
||||
CommandType: string(cmd.Type()),
|
||||
Revision: cmd.Revision(),
|
||||
Payload: payload,
|
||||
Creator: cmd.Creator(),
|
||||
Owner: cmd.Aggregate().ResourceOwner,
|
||||
}
|
||||
|
||||
return &event{
|
||||
command: command,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreationDate implements [eventstore.Event]
|
||||
func (e *event) CreationDate() time.Time {
|
||||
return e.CreatedAt()
|
||||
}
|
||||
|
||||
// EditorUser implements [eventstore.Event]
|
||||
func (e *event) EditorUser() string {
|
||||
return e.Creator()
|
||||
}
|
||||
|
||||
// Aggregate implements [eventstore.Event]
|
||||
func (e *event) Aggregate() *eventstore.Aggregate {
|
||||
return e.command.Aggregate()
|
||||
}
|
||||
|
||||
// Creator implements [eventstore.Event]
|
||||
func (e *event) Creator() string {
|
||||
return e.command.Creator
|
||||
}
|
||||
|
||||
// Revision implements [eventstore.Event]
|
||||
func (e *event) Revision() uint16 {
|
||||
return e.command.Revision
|
||||
}
|
||||
|
||||
// Type implements [eventstore.Event]
|
||||
func (e *event) Type() eventstore.EventType {
|
||||
return eventstore.EventType(e.command.CommandType)
|
||||
}
|
||||
|
||||
// CreatedAt implements [eventstore.Event]
|
||||
func (e *event) CreatedAt() time.Time {
|
||||
return e.createdAt
|
||||
}
|
||||
|
||||
// Sequence implements [eventstore.Event]
|
||||
func (e *event) Sequence() uint64 {
|
||||
return e.sequence
|
||||
}
|
||||
|
||||
// Position implements [eventstore.Event]
|
||||
func (e *event) Position() decimal.Decimal {
|
||||
return e.position
|
||||
}
|
||||
|
||||
// Unmarshal implements [eventstore.Event]
|
||||
func (e *event) Unmarshal(ptr any) error {
|
||||
if len(e.command.Payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := json.Unmarshal(e.command.Payload, ptr); err != nil {
|
||||
return zerrors.ThrowInternal(err, "V3-u8qVo", "Errors.Internal")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DataAsBytes implements [eventstore.Event]
|
||||
func (e *event) DataAsBytes() []byte {
|
||||
return e.command.Payload
|
||||
}
|
482
apps/api/internal/eventstore/v3/event_test.go
Normal file
482
apps/api/internal/eventstore/v3/event_test.go
Normal file
@@ -0,0 +1,482 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
func Test_commandToEvent(t *testing.T) {
|
||||
payload := struct {
|
||||
ID string
|
||||
}{
|
||||
ID: "test",
|
||||
}
|
||||
payloadMarshalled, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal of payload failed: %v", err)
|
||||
}
|
||||
type args struct {
|
||||
command eventstore.Command
|
||||
}
|
||||
type want struct {
|
||||
event *event
|
||||
err func(t *testing.T, err error)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "no payload",
|
||||
args: args{
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: nil,
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
event: mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
nil,
|
||||
).(*event),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "struct payload",
|
||||
args: args{
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: payload,
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
event: mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
payloadMarshalled,
|
||||
).(*event),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pointer payload",
|
||||
args: args{
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: &payload,
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
event: mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
payloadMarshalled,
|
||||
).(*event),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid payload",
|
||||
args: args{
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: func() {},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.want.err == nil {
|
||||
tt.want.err = func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := commandToEvent(tt.args.command)
|
||||
|
||||
tt.want.err(t, err)
|
||||
if tt.want.event == nil {
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, tt.want.event, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_commandToEventOld(t *testing.T) {
|
||||
payload := struct {
|
||||
ID string
|
||||
}{
|
||||
ID: "test",
|
||||
}
|
||||
payloadMarshalled, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal of payload failed: %v", err)
|
||||
}
|
||||
type args struct {
|
||||
sequence *latestSequence
|
||||
command eventstore.Command
|
||||
}
|
||||
type want struct {
|
||||
event *event
|
||||
err func(t *testing.T, err error)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "no payload",
|
||||
args: args{
|
||||
sequence: &latestSequence{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
sequence: 0,
|
||||
},
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: nil,
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
event: mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
nil,
|
||||
).(*event),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "struct payload",
|
||||
args: args{
|
||||
sequence: &latestSequence{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
sequence: 0,
|
||||
},
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: payload,
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
event: mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
payloadMarshalled,
|
||||
).(*event),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pointer payload",
|
||||
args: args{
|
||||
sequence: &latestSequence{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
sequence: 0,
|
||||
},
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: &payload,
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
event: mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
payloadMarshalled,
|
||||
).(*event),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid payload",
|
||||
args: args{
|
||||
sequence: &latestSequence{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
sequence: 0,
|
||||
},
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: func() {},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.want.err == nil {
|
||||
tt.want.err = func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := commandToEventOld(tt.args.sequence, tt.args.command)
|
||||
|
||||
tt.want.err(t, err)
|
||||
assert.Equal(t, tt.want.event, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_commandsToEvents(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
payload := struct {
|
||||
ID string
|
||||
}{
|
||||
ID: "test",
|
||||
}
|
||||
payloadMarshalled, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal of payload failed: %v", err)
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
cmds []eventstore.Command
|
||||
}
|
||||
type want struct {
|
||||
events []eventstore.Event
|
||||
commands []*command
|
||||
err func(t *testing.T, err error)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "no commands",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
cmds: nil,
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{},
|
||||
commands: []*command{},
|
||||
err: func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single command no payload",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
cmds: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{
|
||||
mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
nil,
|
||||
),
|
||||
},
|
||||
commands: []*command{
|
||||
{
|
||||
InstanceID: "instance",
|
||||
AggregateType: "type",
|
||||
AggregateID: "V3-Red9I",
|
||||
Owner: "ro",
|
||||
CommandType: "event.type",
|
||||
Revision: 1,
|
||||
Payload: nil,
|
||||
Creator: "creator",
|
||||
},
|
||||
},
|
||||
err: func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single command no instance id",
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(ctx, "instance from ctx"),
|
||||
cmds: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregateWithInstance("V3-Red9I", ""),
|
||||
payload: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{
|
||||
mockEvent(
|
||||
mockAggregateWithInstance("V3-Red9I", "instance from ctx"),
|
||||
0,
|
||||
nil,
|
||||
),
|
||||
},
|
||||
commands: []*command{
|
||||
{
|
||||
InstanceID: "instance from ctx",
|
||||
AggregateType: "type",
|
||||
AggregateID: "V3-Red9I",
|
||||
Owner: "ro",
|
||||
CommandType: "event.type",
|
||||
Revision: 1,
|
||||
Payload: nil,
|
||||
Creator: "creator",
|
||||
},
|
||||
},
|
||||
err: func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single command with payload",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
cmds: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: payload,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{
|
||||
mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
payloadMarshalled,
|
||||
),
|
||||
},
|
||||
commands: []*command{
|
||||
{
|
||||
InstanceID: "instance",
|
||||
AggregateType: "type",
|
||||
AggregateID: "V3-Red9I",
|
||||
Owner: "ro",
|
||||
CommandType: "event.type",
|
||||
Revision: 1,
|
||||
Payload: payloadMarshalled,
|
||||
Creator: "creator",
|
||||
},
|
||||
},
|
||||
err: func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple commands",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
cmds: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: payload,
|
||||
},
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{
|
||||
mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
payloadMarshalled,
|
||||
),
|
||||
mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
nil,
|
||||
),
|
||||
},
|
||||
commands: []*command{
|
||||
{
|
||||
InstanceID: "instance",
|
||||
AggregateType: "type",
|
||||
AggregateID: "V3-Red9I",
|
||||
CommandType: "event.type",
|
||||
Revision: 1,
|
||||
Payload: payloadMarshalled,
|
||||
Creator: "creator",
|
||||
Owner: "ro",
|
||||
},
|
||||
{
|
||||
InstanceID: "instance",
|
||||
AggregateType: "type",
|
||||
AggregateID: "V3-Red9I",
|
||||
CommandType: "event.type",
|
||||
Revision: 1,
|
||||
Payload: nil,
|
||||
Creator: "creator",
|
||||
Owner: "ro",
|
||||
},
|
||||
},
|
||||
err: func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid command",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
cmds: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: func() {},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: nil,
|
||||
commands: nil,
|
||||
err: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotEvents, gotCommands, err := commandsToEvents(tt.args.ctx, tt.args.cmds)
|
||||
|
||||
tt.want.err(t, err)
|
||||
assert.Equal(t, tt.want.events, gotEvents)
|
||||
require.Len(t, gotCommands, len(tt.want.commands))
|
||||
for i, wantCommand := range tt.want.commands {
|
||||
assertCommand(t, wantCommand, gotCommands[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func assertCommand(t *testing.T, want, got *command) {
|
||||
t.Helper()
|
||||
assert.Equal(t, want.CommandType, got.CommandType)
|
||||
assert.Equal(t, want.Payload, got.Payload)
|
||||
assert.Equal(t, want.Creator, got.Creator)
|
||||
assert.Equal(t, want.Owner, got.Owner)
|
||||
assert.Equal(t, want.AggregateID, got.AggregateID)
|
||||
assert.Equal(t, want.AggregateType, got.AggregateType)
|
||||
assert.Equal(t, want.InstanceID, got.InstanceID)
|
||||
assert.Equal(t, want.Revision, got.Revision)
|
||||
}
|
202
apps/api/internal/eventstore/v3/eventstore.go
Normal file
202
apps/api/internal/eventstore/v3/eventstore.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
func init() {
|
||||
dialect.RegisterAfterConnect(RegisterEventstoreTypes)
|
||||
}
|
||||
|
||||
var (
|
||||
// pushPlaceholderFmt defines how data are inserted into the events table
|
||||
pushPlaceholderFmt = "($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $%d)"
|
||||
// uniqueConstraintPlaceholderFmt defines the format of the unique constraint error returned from the database
|
||||
uniqueConstraintPlaceholderFmt = "(%s, %s, %s)"
|
||||
|
||||
_ eventstore.Pusher = (*Eventstore)(nil)
|
||||
)
|
||||
|
||||
type Eventstore struct {
|
||||
client *database.DB
|
||||
}
|
||||
|
||||
var (
|
||||
textType = &pgtype.Type{
|
||||
Name: "text",
|
||||
OID: pgtype.TextOID,
|
||||
Codec: pgtype.TextCodec{},
|
||||
}
|
||||
commandType = &pgtype.Type{
|
||||
Codec: &pgtype.CompositeCodec{
|
||||
Fields: []pgtype.CompositeCodecField{
|
||||
{
|
||||
Name: "instance_id",
|
||||
Type: textType,
|
||||
},
|
||||
{
|
||||
Name: "aggregate_type",
|
||||
Type: textType,
|
||||
},
|
||||
{
|
||||
Name: "aggregate_id",
|
||||
Type: textType,
|
||||
},
|
||||
{
|
||||
Name: "command_type",
|
||||
Type: textType,
|
||||
},
|
||||
{
|
||||
Name: "revision",
|
||||
Type: &pgtype.Type{
|
||||
Name: "int2",
|
||||
OID: pgtype.Int2OID,
|
||||
Codec: pgtype.Int2Codec{},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "payload",
|
||||
Type: &pgtype.Type{
|
||||
Name: "jsonb",
|
||||
OID: pgtype.JSONBOID,
|
||||
Codec: &pgtype.JSONBCodec{
|
||||
Marshal: json.Marshal,
|
||||
Unmarshal: json.Unmarshal,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "creator",
|
||||
Type: textType,
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Type: textType,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
commandArrayCodec = &pgtype.Type{
|
||||
Codec: &pgtype.ArrayCodec{
|
||||
ElementType: commandType,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
var typeMu sync.Mutex
|
||||
|
||||
func RegisterEventstoreTypes(ctx context.Context, conn *pgx.Conn) error {
|
||||
// conn.TypeMap is not thread safe
|
||||
typeMu.Lock()
|
||||
defer typeMu.Unlock()
|
||||
|
||||
m := conn.TypeMap()
|
||||
|
||||
var cmd *command
|
||||
if _, ok := m.TypeForValue(cmd); ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if commandType.OID == 0 || commandArrayCodec.OID == 0 {
|
||||
err := conn.QueryRow(ctx, "select oid, typarray from pg_type where typname = $1 and typnamespace = (select oid from pg_namespace where nspname = $2)", "command", "eventstore").
|
||||
Scan(&commandType.OID, &commandArrayCodec.OID)
|
||||
if err != nil {
|
||||
logging.WithError(err).Debug("failed to get oid for command type")
|
||||
return nil
|
||||
}
|
||||
if commandType.OID == 0 || commandArrayCodec.OID == 0 {
|
||||
logging.Debug("oid for command type not found")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
m.RegisterTypes([]*pgtype.Type{
|
||||
{
|
||||
Name: "eventstore.command",
|
||||
Codec: commandType.Codec,
|
||||
OID: commandType.OID,
|
||||
},
|
||||
{
|
||||
Name: "command",
|
||||
Codec: commandType.Codec,
|
||||
OID: commandType.OID,
|
||||
},
|
||||
{
|
||||
Name: "eventstore._command",
|
||||
Codec: commandArrayCodec.Codec,
|
||||
OID: commandArrayCodec.OID,
|
||||
},
|
||||
{
|
||||
Name: "_command",
|
||||
Codec: commandArrayCodec.Codec,
|
||||
OID: commandArrayCodec.OID,
|
||||
},
|
||||
})
|
||||
dialect.RegisterDefaultPgTypeVariants[command](m, "eventstore.command", "eventstore._command")
|
||||
dialect.RegisterDefaultPgTypeVariants[command](m, "command", "_command")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Client implements the [eventstore.Pusher]
|
||||
func (es *Eventstore) Client() *database.DB {
|
||||
return es.client
|
||||
}
|
||||
|
||||
func NewEventstore(client *database.DB) *Eventstore {
|
||||
return &Eventstore{client: client}
|
||||
}
|
||||
|
||||
func (es *Eventstore) Health(ctx context.Context) error {
|
||||
return es.client.PingContext(ctx)
|
||||
}
|
||||
|
||||
var errTypesNotFound = errors.New("types not found")
|
||||
|
||||
func CheckExecutionPlan(ctx context.Context, conn *sql.Conn) error {
|
||||
return conn.Raw(func(driverConn any) error {
|
||||
if _, ok := driverConn.(sqlmock.SqlmockCommon); ok {
|
||||
return nil
|
||||
}
|
||||
conn, ok := driverConn.(*stdlib.Conn)
|
||||
if !ok {
|
||||
return errTypesNotFound
|
||||
}
|
||||
|
||||
return RegisterEventstoreTypes(ctx, conn.Conn())
|
||||
})
|
||||
}
|
||||
|
||||
func (es *Eventstore) pushTx(ctx context.Context, client database.ContextQueryExecuter) (tx database.Tx, deferrable func(err error) error, err error) {
|
||||
tx, ok := client.(database.Tx)
|
||||
if ok {
|
||||
return tx, nil, nil
|
||||
}
|
||||
beginner, ok := client.(database.Beginner)
|
||||
if !ok {
|
||||
beginner = es.client
|
||||
}
|
||||
|
||||
tx, err = beginner.BeginTx(ctx, &sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
ReadOnly: false,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return tx, func(err error) error { return database.CloseTransaction(tx, err) }, nil
|
||||
}
|
369
apps/api/internal/eventstore/v3/field.go
Normal file
369
apps/api/internal/eventstore/v3/field.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type fieldValue struct {
|
||||
value []byte
|
||||
}
|
||||
|
||||
func (value *fieldValue) Unmarshal(ptr any) error {
|
||||
return json.Unmarshal(value.value, ptr)
|
||||
}
|
||||
|
||||
func (es *Eventstore) FillFields(ctx context.Context, events ...eventstore.FillFieldsEvent) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
tx, err := es.client.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return
|
||||
}
|
||||
err = tx.Commit()
|
||||
}()
|
||||
|
||||
return handleFieldFillEvents(ctx, tx, events)
|
||||
}
|
||||
|
||||
// Search implements the [eventstore.Search] method
|
||||
func (es *Eventstore) Search(ctx context.Context, conditions ...map[eventstore.FieldType]any) (result []*eventstore.SearchResult, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
var builder strings.Builder
|
||||
args := buildSearchStatement(ctx, &builder, conditions...)
|
||||
|
||||
err = es.client.QueryContext(
|
||||
ctx,
|
||||
func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var (
|
||||
res eventstore.SearchResult
|
||||
value fieldValue
|
||||
)
|
||||
err = rows.Scan(
|
||||
&res.Aggregate.InstanceID,
|
||||
&res.Aggregate.ResourceOwner,
|
||||
&res.Aggregate.Type,
|
||||
&res.Aggregate.ID,
|
||||
&res.Object.Type,
|
||||
&res.Object.ID,
|
||||
&res.Object.Revision,
|
||||
&res.FieldName,
|
||||
&value.value,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.Value = &value
|
||||
|
||||
result = append(result, &res)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
builder.String(),
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
const searchQueryPrefix = `SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1`
|
||||
|
||||
func buildSearchStatement(ctx context.Context, builder *strings.Builder, conditions ...map[eventstore.FieldType]any) []any {
|
||||
args := make([]any, 0, len(conditions)*4+1)
|
||||
args = append(args, authz.GetInstance(ctx).InstanceID())
|
||||
|
||||
builder.WriteString(searchQueryPrefix)
|
||||
|
||||
builder.WriteString(" AND ")
|
||||
if len(conditions) > 1 {
|
||||
builder.WriteRune('(')
|
||||
}
|
||||
for i, condition := range conditions {
|
||||
if i > 0 {
|
||||
builder.WriteString(" OR ")
|
||||
}
|
||||
if len(condition) > 1 {
|
||||
builder.WriteRune('(')
|
||||
}
|
||||
args = append(args, buildSearchCondition(builder, len(args)+1, condition)...)
|
||||
if len(condition) > 1 {
|
||||
builder.WriteRune(')')
|
||||
}
|
||||
}
|
||||
if len(conditions) > 1 {
|
||||
builder.WriteRune(')')
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
func buildSearchCondition(builder *strings.Builder, index int, conditions map[eventstore.FieldType]any) []any {
|
||||
args := make([]any, 0, len(conditions))
|
||||
|
||||
orderedCondition := make([]eventstore.FieldType, 0, len(conditions))
|
||||
for field := range conditions {
|
||||
orderedCondition = append(orderedCondition, field)
|
||||
}
|
||||
slices.Sort(orderedCondition)
|
||||
|
||||
for _, field := range orderedCondition {
|
||||
if len(args) > 0 {
|
||||
builder.WriteString(" AND ")
|
||||
}
|
||||
builder.WriteString(fieldNameByType(field, conditions[field]))
|
||||
builder.WriteString(" = $")
|
||||
builder.WriteString(strconv.Itoa(index + len(args)))
|
||||
args = append(args, conditions[field])
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
func (es *Eventstore) handleFieldCommands(ctx context.Context, tx database.Tx, commands []eventstore.Command) error {
|
||||
for _, command := range commands {
|
||||
if len(command.Fields()) > 0 {
|
||||
if err := handleFieldOperations(ctx, tx, command.Fields()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFieldFillEvents(ctx context.Context, tx database.Tx, events []eventstore.FillFieldsEvent) error {
|
||||
for _, event := range events {
|
||||
if len(event.Fields()) == 0 {
|
||||
continue
|
||||
}
|
||||
if err := handleFieldOperations(ctx, tx, event.Fields()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFieldOperations(ctx context.Context, tx database.Tx, operations []*eventstore.FieldOperation) error {
|
||||
for _, operation := range operations {
|
||||
if operation.Set != nil {
|
||||
if err := handleFieldSet(ctx, tx, operation.Set); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if operation.Remove != nil {
|
||||
if err := handleSearchDelete(ctx, tx, operation.Remove); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFieldSet(ctx context.Context, tx database.Tx, field *eventstore.Field) error {
|
||||
if len(field.UpsertConflictFields) == 0 {
|
||||
return handleSearchInsert(ctx, tx, field)
|
||||
}
|
||||
return handleSearchUpsert(ctx, tx, field)
|
||||
}
|
||||
|
||||
const (
|
||||
insertField = `INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)`
|
||||
)
|
||||
|
||||
func handleSearchInsert(ctx context.Context, tx database.Tx, field *eventstore.Field) error {
|
||||
value, err := json.Marshal(field.Value.Value)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInvalidArgument(err, "V3-fcrW1", "unable to marshal field value")
|
||||
}
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
insertField,
|
||||
|
||||
field.Aggregate.InstanceID,
|
||||
field.Aggregate.ResourceOwner,
|
||||
field.Aggregate.Type,
|
||||
field.Aggregate.ID,
|
||||
field.Object.Type,
|
||||
field.Object.ID,
|
||||
field.Object.Revision,
|
||||
field.FieldName,
|
||||
value,
|
||||
field.Value.MustBeUnique,
|
||||
field.Value.ShouldIndex,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const (
|
||||
fieldsUpsertPrefix = `WITH upsert AS (UPDATE eventstore.fields SET (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) WHERE `
|
||||
fieldsUpsertSuffix = ` RETURNING * ) INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 WHERE NOT EXISTS (SELECT 1 FROM upsert)`
|
||||
)
|
||||
|
||||
func handleSearchUpsert(ctx context.Context, tx database.Tx, field *eventstore.Field) error {
|
||||
value, err := json.Marshal(field.Value.Value)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInvalidArgument(err, "V3-fcrW1", "unable to marshal field value")
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
writeUpsertField(field.UpsertConflictFields),
|
||||
|
||||
field.Aggregate.InstanceID,
|
||||
field.Aggregate.ResourceOwner,
|
||||
field.Aggregate.Type,
|
||||
field.Aggregate.ID,
|
||||
field.Object.Type,
|
||||
field.Object.ID,
|
||||
field.Object.Revision,
|
||||
field.FieldName,
|
||||
value,
|
||||
field.Value.MustBeUnique,
|
||||
field.Value.ShouldIndex,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func writeUpsertField(fields []eventstore.FieldType) string {
|
||||
var builder strings.Builder
|
||||
|
||||
builder.WriteString(fieldsUpsertPrefix)
|
||||
for i, fieldName := range fields {
|
||||
if i > 0 {
|
||||
builder.WriteString(" AND ")
|
||||
}
|
||||
name, index := searchFieldNameAndIndexByTypeForPush(fieldName)
|
||||
|
||||
builder.WriteString(name)
|
||||
builder.WriteString(" = ")
|
||||
builder.WriteString(index)
|
||||
}
|
||||
builder.WriteString(fieldsUpsertSuffix)
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
const removeSearch = `DELETE FROM eventstore.fields WHERE `
|
||||
|
||||
func handleSearchDelete(ctx context.Context, tx database.Tx, clauses map[eventstore.FieldType]any) error {
|
||||
if len(clauses) == 0 {
|
||||
return zerrors.ThrowInvalidArgument(nil, "V3-oqlBZ", "no conditions")
|
||||
}
|
||||
stmt, args := writeDeleteField(clauses)
|
||||
_, err := tx.ExecContext(ctx, stmt, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
func writeDeleteField(clauses map[eventstore.FieldType]any) (string, []any) {
|
||||
var (
|
||||
builder strings.Builder
|
||||
args = make([]any, 0, len(clauses))
|
||||
)
|
||||
builder.WriteString(removeSearch)
|
||||
|
||||
orderedCondition := make([]eventstore.FieldType, 0, len(clauses))
|
||||
for field := range clauses {
|
||||
orderedCondition = append(orderedCondition, field)
|
||||
}
|
||||
slices.Sort(orderedCondition)
|
||||
|
||||
for _, fieldName := range orderedCondition {
|
||||
if len(args) > 0 {
|
||||
builder.WriteString(" AND ")
|
||||
}
|
||||
builder.WriteString(fieldNameByType(fieldName, clauses[fieldName]))
|
||||
|
||||
builder.WriteString(" = $")
|
||||
builder.WriteString(strconv.Itoa(len(args) + 1))
|
||||
|
||||
args = append(args, clauses[fieldName])
|
||||
}
|
||||
|
||||
return builder.String(), args
|
||||
}
|
||||
|
||||
func fieldNameByType(typ eventstore.FieldType, value any) string {
|
||||
switch typ {
|
||||
case eventstore.FieldTypeAggregateID:
|
||||
return "aggregate_id"
|
||||
case eventstore.FieldTypeAggregateType:
|
||||
return "aggregate_type"
|
||||
case eventstore.FieldTypeInstanceID:
|
||||
return "instance_id"
|
||||
case eventstore.FieldTypeResourceOwner:
|
||||
return "resource_owner"
|
||||
case eventstore.FieldTypeFieldName:
|
||||
return "field_name"
|
||||
case eventstore.FieldTypeObjectType:
|
||||
return "object_type"
|
||||
case eventstore.FieldTypeObjectID:
|
||||
return "object_id"
|
||||
case eventstore.FieldTypeObjectRevision:
|
||||
return "object_revision"
|
||||
case eventstore.FieldTypeValue:
|
||||
return valueColumn(value)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func searchFieldNameAndIndexByTypeForPush(typ eventstore.FieldType) (string, string) {
|
||||
switch typ {
|
||||
case eventstore.FieldTypeInstanceID:
|
||||
return "instance_id", "$1"
|
||||
case eventstore.FieldTypeResourceOwner:
|
||||
return "resource_owner", "$2"
|
||||
case eventstore.FieldTypeAggregateType:
|
||||
return "aggregate_type", "$3"
|
||||
case eventstore.FieldTypeAggregateID:
|
||||
return "aggregate_id", "$4"
|
||||
case eventstore.FieldTypeObjectType:
|
||||
return "object_type", "$5"
|
||||
case eventstore.FieldTypeObjectID:
|
||||
return "object_id", "$6"
|
||||
case eventstore.FieldTypeObjectRevision:
|
||||
return "object_revision", "$7"
|
||||
case eventstore.FieldTypeFieldName:
|
||||
return "field_name", "$8"
|
||||
case eventstore.FieldTypeValue:
|
||||
return "value", "$9"
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func valueColumn(value any) string {
|
||||
//nolint: exhaustive
|
||||
switch reflect.TypeOf(value).Kind() {
|
||||
case reflect.Bool:
|
||||
return "bool_value"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64:
|
||||
return "number_value"
|
||||
case reflect.String:
|
||||
return "text_value"
|
||||
}
|
||||
return ""
|
||||
}
|
260
apps/api/internal/eventstore/v3/field_test.go
Normal file
260
apps/api/internal/eventstore/v3/field_test.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
func Test_handleSearchDelete(t *testing.T) {
|
||||
type args struct {
|
||||
clauses map[eventstore.FieldType]any
|
||||
}
|
||||
type want struct {
|
||||
stmt string
|
||||
args []any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "1 condition",
|
||||
args: args{
|
||||
clauses: map[eventstore.FieldType]any{
|
||||
eventstore.FieldTypeInstanceID: "i_id",
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
stmt: "DELETE FROM eventstore.fields WHERE instance_id = $1",
|
||||
args: []any{"i_id"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2 conditions",
|
||||
args: args{
|
||||
clauses: map[eventstore.FieldType]any{
|
||||
eventstore.FieldTypeInstanceID: "i_id",
|
||||
eventstore.FieldTypeAggregateID: "a_id",
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
stmt: "DELETE FROM eventstore.fields WHERE aggregate_id = $1 AND instance_id = $2",
|
||||
args: []any{"a_id", "i_id"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
stmt, args := writeDeleteField(tt.args.clauses)
|
||||
if stmt != tt.want.stmt {
|
||||
t.Errorf("handleSearchDelete() stmt = %q, want %q", stmt, tt.want.stmt)
|
||||
}
|
||||
assert.Equal(t, tt.want.args, args)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_writeUpsertField(t *testing.T) {
|
||||
type args struct {
|
||||
fields []eventstore.FieldType
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "1 field",
|
||||
args: args{
|
||||
fields: []eventstore.FieldType{
|
||||
eventstore.FieldTypeInstanceID,
|
||||
},
|
||||
},
|
||||
want: "WITH upsert AS (UPDATE eventstore.fields SET (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) WHERE instance_id = $1 RETURNING * ) INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 WHERE NOT EXISTS (SELECT 1 FROM upsert)",
|
||||
},
|
||||
{
|
||||
name: "2 fields",
|
||||
args: args{
|
||||
fields: []eventstore.FieldType{
|
||||
eventstore.FieldTypeInstanceID,
|
||||
eventstore.FieldTypeAggregateType,
|
||||
},
|
||||
},
|
||||
want: "WITH upsert AS (UPDATE eventstore.fields SET (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) WHERE instance_id = $1 AND aggregate_type = $3 RETURNING * ) INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 WHERE NOT EXISTS (SELECT 1 FROM upsert)",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := writeUpsertField(tt.args.fields); got != tt.want {
|
||||
t.Errorf("writeUpsertField() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildSearchCondition(t *testing.T) {
|
||||
type args struct {
|
||||
index int
|
||||
conditions map[eventstore.FieldType]any
|
||||
}
|
||||
type want struct {
|
||||
stmt string
|
||||
args []any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "1 condition",
|
||||
args: args{
|
||||
index: 1,
|
||||
conditions: map[eventstore.FieldType]any{
|
||||
eventstore.FieldTypeAggregateID: "a_id",
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
stmt: "aggregate_id = $1",
|
||||
args: []any{"a_id"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "3 condition",
|
||||
args: args{
|
||||
index: 1,
|
||||
conditions: map[eventstore.FieldType]any{
|
||||
eventstore.FieldTypeAggregateID: "a_id",
|
||||
eventstore.FieldTypeInstanceID: "i_id",
|
||||
eventstore.FieldTypeAggregateType: "a_type",
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
stmt: "aggregate_type = $1 AND aggregate_id = $2 AND instance_id = $3",
|
||||
args: []any{"a_type", "a_id", "i_id"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
|
||||
if got := buildSearchCondition(&builder, tt.args.index, tt.args.conditions); !reflect.DeepEqual(got, tt.want.args) {
|
||||
t.Errorf("buildSearchCondition() = %v, want %v", got, tt.want)
|
||||
}
|
||||
if tt.want.stmt != builder.String() {
|
||||
t.Errorf("buildSearchCondition() stmt = %q, want %q", builder.String(), tt.want.stmt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildSearchStatement(t *testing.T) {
|
||||
type args struct {
|
||||
index int
|
||||
conditions []map[eventstore.FieldType]any
|
||||
}
|
||||
type want struct {
|
||||
stmt string
|
||||
args []any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "1 condition with 1 field",
|
||||
args: args{
|
||||
index: 1,
|
||||
conditions: []map[eventstore.FieldType]any{
|
||||
{
|
||||
eventstore.FieldTypeAggregateID: "a_id",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND aggregate_id = $2",
|
||||
args: []any{"a_id"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "1 condition with 3 fields",
|
||||
args: args{
|
||||
index: 1,
|
||||
conditions: []map[eventstore.FieldType]any{
|
||||
{
|
||||
eventstore.FieldTypeAggregateID: "a_id",
|
||||
eventstore.FieldTypeInstanceID: "i_id",
|
||||
eventstore.FieldTypeAggregateType: "a_type",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND (aggregate_type = $2 AND aggregate_id = $3 AND instance_id = $4)",
|
||||
args: []any{"a_type", "a_id", "i_id"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2 condition with 1 field",
|
||||
args: args{
|
||||
index: 1,
|
||||
conditions: []map[eventstore.FieldType]any{
|
||||
{
|
||||
eventstore.FieldTypeAggregateID: "a_id",
|
||||
},
|
||||
{
|
||||
eventstore.FieldTypeAggregateType: "a_type",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND (aggregate_id = $2 OR aggregate_type = $3)",
|
||||
args: []any{"a_id", "a_type"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2 condition with 2 fields",
|
||||
args: args{
|
||||
index: 1,
|
||||
conditions: []map[eventstore.FieldType]any{
|
||||
{
|
||||
eventstore.FieldTypeAggregateID: "a_id1",
|
||||
eventstore.FieldTypeAggregateType: "a_type1",
|
||||
},
|
||||
{
|
||||
eventstore.FieldTypeAggregateID: "a_id2",
|
||||
eventstore.FieldTypeAggregateType: "a_type2",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND ((aggregate_type = $2 AND aggregate_id = $3) OR (aggregate_type = $4 AND aggregate_id = $5))",
|
||||
args: []any{"a_type1", "a_id1", "a_type2", "a_id2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
tt.want.args = append([]any{"i_id"}, tt.want.args...)
|
||||
ctx := authz.WithInstanceID(context.Background(), "i_id")
|
||||
|
||||
if got := buildSearchStatement(ctx, &builder, tt.args.conditions...); !reflect.DeepEqual(got, tt.want.args) {
|
||||
t.Errorf("buildSearchStatement() = %v, want %v", got, tt.want)
|
||||
}
|
||||
if tt.want.stmt != builder.String() {
|
||||
t.Errorf("buildSearchStatement() stmt = %q, want %q", builder.String(), tt.want.stmt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
83
apps/api/internal/eventstore/v3/mock_test.go
Normal file
83
apps/api/internal/eventstore/v3/mock_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
var _ eventstore.Command = (*mockCommand)(nil)
|
||||
|
||||
type mockCommand struct {
|
||||
aggregate *eventstore.Aggregate
|
||||
payload any
|
||||
constraints []*eventstore.UniqueConstraint
|
||||
}
|
||||
|
||||
// Aggregate implements [eventstore.Command]
|
||||
func (m *mockCommand) Aggregate() *eventstore.Aggregate {
|
||||
return m.aggregate
|
||||
}
|
||||
|
||||
// Creator implements [eventstore.Command]
|
||||
func (m *mockCommand) Creator() string {
|
||||
return "creator"
|
||||
}
|
||||
|
||||
// Revision implements [eventstore.Command]
|
||||
func (m *mockCommand) Revision() uint16 {
|
||||
return 1
|
||||
}
|
||||
|
||||
// Type implements [eventstore.Command]
|
||||
func (m *mockCommand) Type() eventstore.EventType {
|
||||
return "event.type"
|
||||
}
|
||||
|
||||
// Payload implements [eventstore.Command]
|
||||
func (m *mockCommand) Payload() any {
|
||||
return m.payload
|
||||
}
|
||||
|
||||
// UniqueConstraints implements [eventstore.Command]
|
||||
func (m *mockCommand) UniqueConstraints() []*eventstore.UniqueConstraint {
|
||||
return m.constraints
|
||||
}
|
||||
|
||||
func (e *mockCommand) Fields() []*eventstore.FieldOperation {
|
||||
return nil
|
||||
}
|
||||
|
||||
func mockEvent(aggregate *eventstore.Aggregate, sequence uint64, payload Payload) eventstore.Event {
|
||||
return &event{
|
||||
command: &command{
|
||||
InstanceID: aggregate.InstanceID,
|
||||
AggregateType: string(aggregate.Type),
|
||||
AggregateID: aggregate.ID,
|
||||
Owner: aggregate.ResourceOwner,
|
||||
Creator: "creator",
|
||||
Revision: 1,
|
||||
CommandType: "event.type",
|
||||
Payload: payload,
|
||||
},
|
||||
sequence: sequence,
|
||||
}
|
||||
}
|
||||
|
||||
func mockAggregate(id string) *eventstore.Aggregate {
|
||||
return &eventstore.Aggregate{
|
||||
ID: id,
|
||||
Type: "type",
|
||||
ResourceOwner: "ro",
|
||||
InstanceID: "instance",
|
||||
Version: "v1",
|
||||
}
|
||||
}
|
||||
|
||||
func mockAggregateWithInstance(id, instance string) *eventstore.Aggregate {
|
||||
return &eventstore.Aggregate{
|
||||
ID: id,
|
||||
InstanceID: instance,
|
||||
Type: "type",
|
||||
ResourceOwner: "ro",
|
||||
Version: "v1",
|
||||
}
|
||||
}
|
108
apps/api/internal/eventstore/v3/push.go
Normal file
108
apps/api/internal/eventstore/v3/push.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
var pushTxOpts = &sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
ReadOnly: false,
|
||||
}
|
||||
|
||||
func (es *Eventstore) Push(ctx context.Context, client database.ContextQueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
events, err = es.writeCommands(ctx, client, commands)
|
||||
if isSetupNotExecutedError(err) {
|
||||
return es.pushWithoutFunc(ctx, client, commands...)
|
||||
}
|
||||
|
||||
return events, err
|
||||
}
|
||||
|
||||
func (es *Eventstore) writeCommands(ctx context.Context, client database.ContextQueryExecuter, commands []eventstore.Command) (_ []eventstore.Event, err error) {
|
||||
var conn *sql.Conn
|
||||
switch c := client.(type) {
|
||||
case database.Client:
|
||||
conn, err = c.Conn(ctx)
|
||||
case nil:
|
||||
conn, err = es.client.Conn(ctx)
|
||||
client = conn
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conn != nil {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
tx, close, err := es.pushTx(ctx, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if close != nil {
|
||||
defer func() {
|
||||
err = close(err)
|
||||
}()
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, fmt.Sprintf("SET LOCAL application_name = '%s'", fmt.Sprintf("zitadel_es_pusher_%s", authz.GetInstance(ctx).InstanceID())))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
events, err := writeEvents(ctx, tx, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = handleUniqueConstraints(ctx, tx, commands); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = es.handleFieldCommands(ctx, tx, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func writeEvents(ctx context.Context, tx database.Tx, commands []eventstore.Command) (_ []eventstore.Event, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
events, cmds, err := commandsToEvents(ctx, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := tx.QueryContext(ctx, `select owner, created_at, "sequence", position from eventstore.push($1::eventstore.command[])`, cmds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for i := 0; rows.Next(); i++ {
|
||||
err = rows.Scan(&events[i].(*event).command.Owner, &events[i].(*event).createdAt, &events[i].(*event).sequence, &events[i].(*event).position)
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("failed to scan events")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return events, nil
|
||||
}
|
18
apps/api/internal/eventstore/v3/push.sql
Normal file
18
apps/api/internal/eventstore/v3/push.sql
Normal file
@@ -0,0 +1,18 @@
|
||||
INSERT INTO eventstore.events2 (
|
||||
instance_id
|
||||
, "owner"
|
||||
, aggregate_type
|
||||
, aggregate_id
|
||||
, revision
|
||||
|
||||
, creator
|
||||
, event_type
|
||||
, payload
|
||||
, "sequence"
|
||||
, created_at
|
||||
|
||||
, "position"
|
||||
, in_tx_order
|
||||
) VALUES
|
||||
%s
|
||||
RETURNING created_at, "position";
|
253
apps/api/internal/eventstore/v3/push_test.go
Normal file
253
apps/api/internal/eventstore/v3/push_test.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/database/postgres"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
func Test_mapCommands(t *testing.T) {
|
||||
type args struct {
|
||||
commands []eventstore.Command
|
||||
sequences []*latestSequence
|
||||
}
|
||||
type want struct {
|
||||
events []eventstore.Event
|
||||
placeHolders []string
|
||||
args []any
|
||||
err func(t *testing.T, err error)
|
||||
shouldPanic bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "no commands",
|
||||
args: args{
|
||||
commands: []eventstore.Command{},
|
||||
sequences: []*latestSequence{},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{},
|
||||
placeHolders: []string{},
|
||||
args: []any{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "one command",
|
||||
args: args{
|
||||
commands: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-VEIvq"),
|
||||
},
|
||||
},
|
||||
sequences: []*latestSequence{
|
||||
{
|
||||
aggregate: mockAggregate("V3-VEIvq"),
|
||||
sequence: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{
|
||||
mockEvent(
|
||||
mockAggregate("V3-VEIvq"),
|
||||
1,
|
||||
nil,
|
||||
),
|
||||
},
|
||||
placeHolders: []string{
|
||||
"($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)",
|
||||
},
|
||||
args: []any{
|
||||
"instance",
|
||||
"ro",
|
||||
"type",
|
||||
"V3-VEIvq",
|
||||
uint16(1),
|
||||
"creator",
|
||||
"event.type",
|
||||
Payload(nil),
|
||||
uint64(1),
|
||||
0,
|
||||
},
|
||||
err: func(t *testing.T, err error) {},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple commands same aggregate",
|
||||
args: args{
|
||||
commands: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-VEIvq"),
|
||||
},
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-VEIvq"),
|
||||
},
|
||||
},
|
||||
sequences: []*latestSequence{
|
||||
{
|
||||
aggregate: mockAggregate("V3-VEIvq"),
|
||||
sequence: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{
|
||||
mockEvent(
|
||||
mockAggregate("V3-VEIvq"),
|
||||
6,
|
||||
nil,
|
||||
),
|
||||
mockEvent(
|
||||
mockAggregate("V3-VEIvq"),
|
||||
7,
|
||||
nil,
|
||||
),
|
||||
},
|
||||
placeHolders: []string{
|
||||
"($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)",
|
||||
"($11, $12, $13, $14, $15, $16, $17, $18, $19, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $20)",
|
||||
},
|
||||
args: []any{
|
||||
// first event
|
||||
"instance",
|
||||
"ro",
|
||||
"type",
|
||||
"V3-VEIvq",
|
||||
uint16(1),
|
||||
"creator",
|
||||
"event.type",
|
||||
Payload(nil),
|
||||
uint64(6),
|
||||
0,
|
||||
// second event
|
||||
"instance",
|
||||
"ro",
|
||||
"type",
|
||||
"V3-VEIvq",
|
||||
uint16(1),
|
||||
"creator",
|
||||
"event.type",
|
||||
Payload(nil),
|
||||
uint64(7),
|
||||
1,
|
||||
},
|
||||
err: func(t *testing.T, err error) {},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "one command per aggregate",
|
||||
args: args{
|
||||
commands: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-VEIvq"),
|
||||
},
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-IT6VN"),
|
||||
},
|
||||
},
|
||||
sequences: []*latestSequence{
|
||||
{
|
||||
aggregate: mockAggregate("V3-VEIvq"),
|
||||
sequence: 5,
|
||||
},
|
||||
{
|
||||
aggregate: mockAggregate("V3-IT6VN"),
|
||||
sequence: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{
|
||||
mockEvent(
|
||||
mockAggregate("V3-VEIvq"),
|
||||
6,
|
||||
nil,
|
||||
),
|
||||
mockEvent(
|
||||
mockAggregate("V3-IT6VN"),
|
||||
1,
|
||||
nil,
|
||||
),
|
||||
},
|
||||
placeHolders: []string{
|
||||
"($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)",
|
||||
"($11, $12, $13, $14, $15, $16, $17, $18, $19, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $20)",
|
||||
},
|
||||
args: []any{
|
||||
// first event
|
||||
"instance",
|
||||
"ro",
|
||||
"type",
|
||||
"V3-VEIvq",
|
||||
uint16(1),
|
||||
"creator",
|
||||
"event.type",
|
||||
Payload(nil),
|
||||
uint64(6),
|
||||
0,
|
||||
// second event
|
||||
"instance",
|
||||
"ro",
|
||||
"type",
|
||||
"V3-IT6VN",
|
||||
uint16(1),
|
||||
"creator",
|
||||
"event.type",
|
||||
Payload(nil),
|
||||
uint64(1),
|
||||
1,
|
||||
},
|
||||
err: func(t *testing.T, err error) {},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing sequence",
|
||||
args: args{
|
||||
commands: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-VEIvq"),
|
||||
},
|
||||
},
|
||||
sequences: []*latestSequence{},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{},
|
||||
placeHolders: []string{},
|
||||
args: []any{},
|
||||
err: func(t *testing.T, err error) {},
|
||||
shouldPanic: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.want.err == nil {
|
||||
tt.want.err = func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
// is used to set the the [pushPlaceholderFmt]
|
||||
NewEventstore(&database.DB{Database: new(postgres.Config)})
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer func() {
|
||||
cause := recover()
|
||||
assert.Equal(t, tt.want.shouldPanic, cause != nil)
|
||||
}()
|
||||
gotEvents, gotPlaceHolders, gotArgs, err := mapCommands(tt.args.commands, tt.args.sequences)
|
||||
tt.want.err(t, err)
|
||||
|
||||
assert.ElementsMatch(t, tt.want.events, gotEvents)
|
||||
assert.ElementsMatch(t, tt.want.placeHolders, gotPlaceHolders)
|
||||
assert.ElementsMatch(t, tt.want.args, gotArgs)
|
||||
})
|
||||
}
|
||||
}
|
162
apps/api/internal/eventstore/v3/push_without_func.go
Normal file
162
apps/api/internal/eventstore/v3/push_without_func.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
// checks whether the error is caused because setup step 39 was not executed
|
||||
func isSetupNotExecutedError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) {
|
||||
return (pgErr.Code == "42704" && strings.Contains(pgErr.Message, "eventstore.command")) ||
|
||||
(pgErr.Code == "42883" && strings.Contains(pgErr.Message, "eventstore.push"))
|
||||
}
|
||||
return errors.Is(err, errTypesNotFound)
|
||||
}
|
||||
|
||||
var (
|
||||
//go:embed push.sql
|
||||
pushStmt string
|
||||
)
|
||||
|
||||
// pushWithoutFunc implements pushing events before setup step 39 was introduced.
|
||||
// TODO: remove with v3
|
||||
func (es *Eventstore) pushWithoutFunc(ctx context.Context, client database.ContextQueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) {
|
||||
tx, closeTx, err := es.pushTx(ctx, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = closeTx(err)
|
||||
}()
|
||||
|
||||
var (
|
||||
sequences []*latestSequence
|
||||
)
|
||||
sequences, err = latestSequences(ctx, tx, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
events, err = es.writeEventsOld(ctx, tx, sequences, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = handleUniqueConstraints(ctx, tx, commands); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = es.handleFieldCommands(ctx, tx, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (es *Eventstore) writeEventsOld(ctx context.Context, tx database.Tx, sequences []*latestSequence, commands []eventstore.Command) ([]eventstore.Event, error) {
|
||||
events, placeholders, args, err := mapCommands(commands, sequences)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := tx.QueryContext(ctx, fmt.Sprintf(pushStmt, strings.Join(placeholders, ", ")), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for i := 0; rows.Next(); i++ {
|
||||
err = rows.Scan(&events[i].(*event).createdAt, &events[i].(*event).position)
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("failed to scan events")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
pgErr := new(pgconn.PgError)
|
||||
if errors.As(err, &pgErr) {
|
||||
// Check if push tries to write an event just written
|
||||
// by another transaction
|
||||
if pgErr.Code == "40001" {
|
||||
// TODO: @livio-a should we return the parent or not?
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "V3-p5xAn", "Errors.AlreadyExists")
|
||||
}
|
||||
}
|
||||
logging.WithError(rows.Err()).Warn("failed to push events")
|
||||
return nil, zerrors.ThrowInternal(err, "V3-VGnZY", "Errors.Internal")
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
const argsPerCommand = 10
|
||||
|
||||
func mapCommands(commands []eventstore.Command, sequences []*latestSequence) (events []eventstore.Event, placeholders []string, args []any, err error) {
|
||||
events = make([]eventstore.Event, len(commands))
|
||||
args = make([]any, 0, len(commands)*argsPerCommand)
|
||||
placeholders = make([]string, len(commands))
|
||||
|
||||
for i, command := range commands {
|
||||
sequence := searchSequenceByCommand(sequences, command)
|
||||
if sequence == nil {
|
||||
logging.WithFields(
|
||||
"aggType", command.Aggregate().Type,
|
||||
"aggID", command.Aggregate().ID,
|
||||
"instance", command.Aggregate().InstanceID,
|
||||
).Panic("no sequence found")
|
||||
// added return for linting
|
||||
return nil, nil, nil, nil
|
||||
}
|
||||
sequence.sequence++
|
||||
|
||||
events[i], err = commandToEventOld(sequence, command)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
placeholders[i] = fmt.Sprintf(pushPlaceholderFmt,
|
||||
i*argsPerCommand+1,
|
||||
i*argsPerCommand+2,
|
||||
i*argsPerCommand+3,
|
||||
i*argsPerCommand+4,
|
||||
i*argsPerCommand+5,
|
||||
i*argsPerCommand+6,
|
||||
i*argsPerCommand+7,
|
||||
i*argsPerCommand+8,
|
||||
i*argsPerCommand+9,
|
||||
i*argsPerCommand+10,
|
||||
)
|
||||
|
||||
args = append(args,
|
||||
events[i].(*event).command.InstanceID,
|
||||
events[i].(*event).command.Owner,
|
||||
events[i].(*event).command.AggregateType,
|
||||
events[i].(*event).command.AggregateID,
|
||||
events[i].(*event).command.Revision,
|
||||
events[i].(*event).command.Creator,
|
||||
events[i].(*event).command.CommandType,
|
||||
events[i].(*event).command.Payload,
|
||||
events[i].(*event).sequence,
|
||||
i,
|
||||
)
|
||||
}
|
||||
|
||||
return events, placeholders, args, nil
|
||||
}
|
144
apps/api/internal/eventstore/v3/sequence.go
Normal file
144
apps/api/internal/eventstore/v3/sequence.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type latestSequence struct {
|
||||
aggregate *eventstore.Aggregate
|
||||
sequence uint64
|
||||
}
|
||||
|
||||
//go:embed sequences_query.sql
|
||||
var latestSequencesStmt string
|
||||
|
||||
func latestSequences(ctx context.Context, tx database.Tx, commands []eventstore.Command) ([]*latestSequence, error) {
|
||||
sequences := commandsToSequences(ctx, commands)
|
||||
|
||||
conditions, args := sequencesToSql(sequences)
|
||||
rows, err := tx.QueryContext(ctx, fmt.Sprintf(latestSequencesStmt, strings.Join(conditions, " UNION ALL ")), args...)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "V3-5jU5z", "Errors.Internal")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
if err := scanToSequence(rows, sequences); err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "V3-Ydiwv", "Errors.Internal")
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
return nil, zerrors.ThrowInternal(rows.Err(), "V3-XApDk", "Errors.Internal")
|
||||
}
|
||||
return sequences, nil
|
||||
}
|
||||
|
||||
func searchSequenceByCommand(sequences []*latestSequence, command eventstore.Command) *latestSequence {
|
||||
for _, sequence := range sequences {
|
||||
if sequence.aggregate.Type == command.Aggregate().Type &&
|
||||
sequence.aggregate.ID == command.Aggregate().ID &&
|
||||
sequence.aggregate.InstanceID == command.Aggregate().InstanceID {
|
||||
return sequence
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func searchSequence(sequences []*latestSequence, aggregateType eventstore.AggregateType, aggregateID, instanceID string) *latestSequence {
|
||||
for _, sequence := range sequences {
|
||||
if sequence.aggregate.Type == aggregateType &&
|
||||
sequence.aggregate.ID == aggregateID &&
|
||||
sequence.aggregate.InstanceID == instanceID {
|
||||
return sequence
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func commandsToSequences(ctx context.Context, commands []eventstore.Command) []*latestSequence {
|
||||
sequences := make([]*latestSequence, 0, len(commands))
|
||||
|
||||
for _, command := range commands {
|
||||
if searchSequenceByCommand(sequences, command) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if command.Aggregate().InstanceID == "" {
|
||||
command.Aggregate().InstanceID = authz.GetInstance(ctx).InstanceID()
|
||||
}
|
||||
sequences = append(sequences, &latestSequence{
|
||||
aggregate: command.Aggregate(),
|
||||
})
|
||||
}
|
||||
|
||||
return sequences
|
||||
}
|
||||
|
||||
const argsPerCondition = 3
|
||||
|
||||
func sequencesToSql(sequences []*latestSequence) (conditions []string, args []any) {
|
||||
args = make([]interface{}, 0, len(sequences)*argsPerCondition)
|
||||
conditions = make([]string, len(sequences))
|
||||
|
||||
for i, sequence := range sequences {
|
||||
conditions[i] = fmt.Sprintf(`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $%d AND aggregate_type = $%d AND aggregate_id = $%d ORDER BY "sequence" DESC LIMIT 1)`,
|
||||
i*argsPerCondition+1,
|
||||
i*argsPerCondition+2,
|
||||
i*argsPerCondition+3,
|
||||
)
|
||||
args = append(args, sequence.aggregate.InstanceID, sequence.aggregate.Type, sequence.aggregate.ID)
|
||||
}
|
||||
|
||||
return conditions, args
|
||||
}
|
||||
|
||||
func scanToSequence(rows *sql.Rows, sequences []*latestSequence) error {
|
||||
var aggregateType eventstore.AggregateType
|
||||
var aggregateID, instanceID string
|
||||
var currentSequence uint64
|
||||
var resourceOwner string
|
||||
|
||||
if err := rows.Scan(&instanceID, &resourceOwner, &aggregateType, &aggregateID, ¤tSequence); err != nil {
|
||||
return zerrors.ThrowInternal(err, "V3-OIWqj", "Errors.Internal")
|
||||
}
|
||||
|
||||
sequence := searchSequence(sequences, aggregateType, aggregateID, instanceID)
|
||||
if sequence == nil {
|
||||
logging.WithFields(
|
||||
"aggType", aggregateType,
|
||||
"aggID", aggregateID,
|
||||
"instance", instanceID,
|
||||
).Panic("no sequence found")
|
||||
// added return for linting
|
||||
return nil
|
||||
}
|
||||
sequence.sequence = currentSequence
|
||||
if resourceOwner != "" && sequence.aggregate.ResourceOwner != "" && sequence.aggregate.ResourceOwner != resourceOwner {
|
||||
logging.WithFields(
|
||||
"current_sequence", sequence.sequence,
|
||||
"instance_id", sequence.aggregate.InstanceID,
|
||||
"agg_type", sequence.aggregate.Type,
|
||||
"agg_id", sequence.aggregate.ID,
|
||||
"current_owner", resourceOwner,
|
||||
"provided_owner", sequence.aggregate.ResourceOwner,
|
||||
).Info("would have set wrong resource owner")
|
||||
}
|
||||
// set resource owner from previous events
|
||||
if resourceOwner != "" {
|
||||
sequence.aggregate.ResourceOwner = resourceOwner
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
293
apps/api/internal/eventstore/v3/sequence_test.go
Normal file
293
apps/api/internal/eventstore/v3/sequence_test.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
func Test_searchSequence(t *testing.T) {
|
||||
sequence := &latestSequence{
|
||||
aggregate: mockAggregate("V3-p1BWC"),
|
||||
sequence: 1,
|
||||
}
|
||||
type args struct {
|
||||
sequences []*latestSequence
|
||||
aggregateType eventstore.AggregateType
|
||||
aggregateID string
|
||||
instanceID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *latestSequence
|
||||
}{
|
||||
{
|
||||
name: "type missmatch",
|
||||
args: args{
|
||||
sequences: []*latestSequence{
|
||||
sequence,
|
||||
},
|
||||
aggregateType: "wrong",
|
||||
aggregateID: "V3-p1BWC",
|
||||
instanceID: "instance",
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "id missmatch",
|
||||
args: args{
|
||||
sequences: []*latestSequence{
|
||||
sequence,
|
||||
},
|
||||
aggregateType: "type",
|
||||
aggregateID: "wrong",
|
||||
instanceID: "instance",
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "instance missmatch",
|
||||
args: args{
|
||||
sequences: []*latestSequence{
|
||||
sequence,
|
||||
},
|
||||
aggregateType: "type",
|
||||
aggregateID: "V3-p1BWC",
|
||||
instanceID: "wrong",
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "match",
|
||||
args: args{
|
||||
sequences: []*latestSequence{
|
||||
sequence,
|
||||
},
|
||||
aggregateType: "type",
|
||||
aggregateID: "V3-p1BWC",
|
||||
instanceID: "instance",
|
||||
},
|
||||
want: sequence,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := searchSequence(tt.args.sequences, tt.args.aggregateType, tt.args.aggregateID, tt.args.instanceID); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("searchSequence() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_commandsToSequences(t *testing.T) {
|
||||
aggregate := mockAggregate("V3-MKHTF")
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
commands []eventstore.Command
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []*latestSequence
|
||||
}{
|
||||
{
|
||||
name: "no command",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
commands: []eventstore.Command{},
|
||||
},
|
||||
want: []*latestSequence{},
|
||||
},
|
||||
{
|
||||
name: "one command",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
commands: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: aggregate,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []*latestSequence{
|
||||
{
|
||||
aggregate: aggregate,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two commands same aggregate",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
commands: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: aggregate,
|
||||
},
|
||||
&mockCommand{
|
||||
aggregate: aggregate,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []*latestSequence{
|
||||
{
|
||||
aggregate: aggregate,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two commands different aggregates",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
commands: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: aggregate,
|
||||
},
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-cZkCy"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []*latestSequence{
|
||||
{
|
||||
aggregate: aggregate,
|
||||
},
|
||||
{
|
||||
aggregate: mockAggregate("V3-cZkCy"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "instance set in command",
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "V3-ANV4p"),
|
||||
commands: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: &eventstore.Aggregate{
|
||||
ID: "V3-bF0Sa",
|
||||
Type: "type",
|
||||
ResourceOwner: "to",
|
||||
InstanceID: "instance",
|
||||
Version: "v1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []*latestSequence{
|
||||
{
|
||||
aggregate: &eventstore.Aggregate{
|
||||
ID: "V3-bF0Sa",
|
||||
Type: "type",
|
||||
ResourceOwner: "to",
|
||||
InstanceID: "instance",
|
||||
Version: "v1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "instance from context",
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "V3-ANV4p"),
|
||||
commands: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: &eventstore.Aggregate{
|
||||
ID: "V3-bF0Sa",
|
||||
Type: "type",
|
||||
ResourceOwner: "to",
|
||||
Version: "v1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []*latestSequence{
|
||||
{
|
||||
aggregate: &eventstore.Aggregate{
|
||||
ID: "V3-bF0Sa",
|
||||
Type: "type",
|
||||
ResourceOwner: "to",
|
||||
InstanceID: "V3-ANV4p",
|
||||
Version: "v1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := commandsToSequences(tt.args.ctx, tt.args.commands)
|
||||
assert.ElementsMatch(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sequencesToSql(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arg []*latestSequence
|
||||
wantConditions []string
|
||||
wantArgs []any
|
||||
}{
|
||||
{
|
||||
name: "no sequence",
|
||||
arg: []*latestSequence{},
|
||||
wantConditions: []string{},
|
||||
wantArgs: []any{},
|
||||
},
|
||||
{
|
||||
name: "one",
|
||||
arg: []*latestSequence{
|
||||
{
|
||||
aggregate: mockAggregate("V3-SbpGB"),
|
||||
},
|
||||
},
|
||||
wantConditions: []string{
|
||||
`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND aggregate_id = $3 ORDER BY "sequence" DESC LIMIT 1)`,
|
||||
},
|
||||
wantArgs: []any{
|
||||
"instance",
|
||||
eventstore.AggregateType("type"),
|
||||
"V3-SbpGB",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
arg: []*latestSequence{
|
||||
{
|
||||
aggregate: mockAggregate("V3-SbpGB"),
|
||||
},
|
||||
{
|
||||
aggregate: mockAggregate("V3-0X3yt"),
|
||||
},
|
||||
},
|
||||
wantConditions: []string{
|
||||
`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND aggregate_id = $3 ORDER BY "sequence" DESC LIMIT 1)`,
|
||||
`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $4 AND aggregate_type = $5 AND aggregate_id = $6 ORDER BY "sequence" DESC LIMIT 1)`,
|
||||
},
|
||||
wantArgs: []any{
|
||||
"instance",
|
||||
eventstore.AggregateType("type"),
|
||||
"V3-SbpGB",
|
||||
"instance",
|
||||
eventstore.AggregateType("type"),
|
||||
"V3-0X3yt",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotConditions, gotArgs := sequencesToSql(tt.arg)
|
||||
if !reflect.DeepEqual(gotConditions, tt.wantConditions) {
|
||||
t.Errorf("sequencesToSql() gotConditions = %v, want %v", gotConditions, tt.wantConditions)
|
||||
}
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("sequencesToSql() gotArgs = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
18
apps/api/internal/eventstore/v3/sequences_query.sql
Normal file
18
apps/api/internal/eventstore/v3/sequences_query.sql
Normal file
@@ -0,0 +1,18 @@
|
||||
WITH existing AS (
|
||||
%s
|
||||
) SELECT
|
||||
e.instance_id
|
||||
, e.owner
|
||||
, e.aggregate_type
|
||||
, e.aggregate_id
|
||||
, e.sequence
|
||||
FROM
|
||||
eventstore.events2 e
|
||||
JOIN
|
||||
existing
|
||||
ON
|
||||
e.instance_id = existing.instance_id
|
||||
AND e.aggregate_type = existing.aggregate_type
|
||||
AND e.aggregate_id = existing.aggregate_id
|
||||
AND e.sequence = existing.sequence
|
||||
FOR UPDATE;
|
25
apps/api/internal/eventstore/v3/type.go
Normal file
25
apps/api/internal/eventstore/v3/type.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package eventstore
|
||||
|
||||
import "database/sql/driver"
|
||||
|
||||
// Payload represents a byte array that may be null.
|
||||
// Payload implements the sql.Scanner interface
|
||||
type Payload []byte
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (data *Payload) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*data = nil
|
||||
return nil
|
||||
}
|
||||
*data = Payload(value.([]byte))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (data Payload) Value() (driver.Value, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return []byte(data), nil
|
||||
}
|
100
apps/api/internal/eventstore/v3/unique_constraints.go
Normal file
100
apps/api/internal/eventstore/v3/unique_constraints.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed unique_constraints_delete.sql
|
||||
deleteConstraintStmt string
|
||||
//go:embed unique_constraints_delete_placeholders.sql
|
||||
deleteConstraintPlaceholdersStmt string
|
||||
//go:embed unique_constraints_add.sql
|
||||
addConstraintStmt string
|
||||
)
|
||||
|
||||
func handleUniqueConstraints(ctx context.Context, tx database.Tx, commands []eventstore.Command) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
deletePlaceholders := make([]string, 0)
|
||||
deleteArgs := make([]any, 0)
|
||||
|
||||
addPlaceholders := make([]string, 0)
|
||||
addArgs := make([]any, 0)
|
||||
addConstraints := map[string]*eventstore.UniqueConstraint{}
|
||||
deleteConstraints := map[string]*eventstore.UniqueConstraint{}
|
||||
|
||||
for _, command := range commands {
|
||||
for _, constraint := range command.UniqueConstraints() {
|
||||
instanceID := command.Aggregate().InstanceID
|
||||
if constraint.IsGlobal {
|
||||
instanceID = ""
|
||||
}
|
||||
switch constraint.Action {
|
||||
case eventstore.UniqueConstraintAdd:
|
||||
constraint.UniqueField = strings.ToLower(constraint.UniqueField)
|
||||
addPlaceholders = append(addPlaceholders, fmt.Sprintf("($%d, $%d, $%d)", len(addArgs)+1, len(addArgs)+2, len(addArgs)+3))
|
||||
addArgs = append(addArgs, instanceID, constraint.UniqueType, constraint.UniqueField)
|
||||
addConstraints[fmt.Sprintf(uniqueConstraintPlaceholderFmt, instanceID, constraint.UniqueType, constraint.UniqueField)] = constraint
|
||||
case eventstore.UniqueConstraintRemove:
|
||||
deletePlaceholders = append(deletePlaceholders, fmt.Sprintf(deleteConstraintPlaceholdersStmt, len(deleteArgs)+1, len(deleteArgs)+2, len(deleteArgs)+3))
|
||||
deleteArgs = append(deleteArgs, instanceID, constraint.UniqueType, constraint.UniqueField)
|
||||
deleteConstraints[fmt.Sprintf(uniqueConstraintPlaceholderFmt, instanceID, constraint.UniqueType, constraint.UniqueField)] = constraint
|
||||
case eventstore.UniqueConstraintInstanceRemove:
|
||||
deletePlaceholders = append(deletePlaceholders, fmt.Sprintf("(instance_id = $%d)", len(deleteArgs)+1))
|
||||
deleteArgs = append(deleteArgs, instanceID)
|
||||
deleteConstraints[fmt.Sprintf(uniqueConstraintPlaceholderFmt, instanceID, constraint.UniqueType, constraint.UniqueField)] = constraint
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(deletePlaceholders) > 0 {
|
||||
_, err := tx.ExecContext(ctx, fmt.Sprintf(deleteConstraintStmt, strings.Join(deletePlaceholders, " OR ")), deleteArgs...)
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("delete unique constraint failed")
|
||||
errMessage := "Errors.Internal"
|
||||
if constraint := constraintFromErr(err, deleteConstraints); constraint != nil {
|
||||
errMessage = constraint.ErrorMessage
|
||||
}
|
||||
return zerrors.ThrowInternal(err, "V3-C8l3V", errMessage)
|
||||
}
|
||||
}
|
||||
if len(addPlaceholders) > 0 {
|
||||
_, err := tx.ExecContext(ctx, fmt.Sprintf(addConstraintStmt, strings.Join(addPlaceholders, ", ")), addArgs...)
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("add unique constraint failed")
|
||||
errMessage := "Errors.Internal"
|
||||
if constraint := constraintFromErr(err, addConstraints); constraint != nil {
|
||||
errMessage = constraint.ErrorMessage
|
||||
}
|
||||
return zerrors.ThrowAlreadyExists(err, "V3-DKcYh", errMessage)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func constraintFromErr(err error, constraints map[string]*eventstore.UniqueConstraint) *eventstore.UniqueConstraint {
|
||||
pgErr := new(pgconn.PgError)
|
||||
if !errors.As(err, &pgErr) {
|
||||
return nil
|
||||
}
|
||||
for key, constraint := range constraints {
|
||||
if strings.Contains(pgErr.Detail, key) {
|
||||
return constraint
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -0,0 +1,6 @@
|
||||
INSERT INTO eventstore.unique_constraints (
|
||||
instance_id
|
||||
, unique_type
|
||||
, unique_field
|
||||
) VALUES
|
||||
%s
|
@@ -0,0 +1 @@
|
||||
DELETE FROM eventstore.unique_constraints WHERE %s
|
@@ -0,0 +1,13 @@
|
||||
-- the query is so complex because we accidentally stored unique constraint case sensitive
|
||||
-- the query checks first if there is a case sensitive match and afterwards if there is a case insensitive match
|
||||
(instance_id = $%[1]d AND unique_type = $%[2]d AND unique_field = (
|
||||
SELECT unique_field from (
|
||||
SELECT instance_id, unique_type, unique_field
|
||||
FROM eventstore.unique_constraints
|
||||
WHERE instance_id = $%[1]d AND unique_type = $%[2]d AND unique_field = $%[3]d
|
||||
UNION ALL
|
||||
SELECT instance_id, unique_type, unique_field
|
||||
FROM eventstore.unique_constraints
|
||||
WHERE instance_id = $%[1]d AND unique_type = $%[2]d AND unique_field = LOWER($%[3]d)
|
||||
) AS case_insensitive_constraints LIMIT 1)
|
||||
)
|
Reference in New Issue
Block a user