diff --git a/docs/docs/apis/actionsv2/introduction.md b/docs/docs/apis/actionsv2/introduction.md index 16adaac423..e6362c1aee 100644 --- a/docs/docs/apis/actionsv2/introduction.md +++ b/docs/docs/apis/actionsv2/introduction.md @@ -86,6 +86,11 @@ And if you use a different service, for example `zitadel.session.v2.SessionServi ### Targets and Includes +:::info +Includes are limited to 3 levels, which mean that include1->include2->include3 is the maximum for now. +If you have feedback to the include logic, or a reason why 3 levels are not enough, please open [an issue on github](https://github.com/zitadel/zitadel/issues) or [start a discussion on github](https://github.com/zitadel/zitadel/discussions)/[start a topic on discord](https://zitadel.com/chat) +::: + An execution can not only contain a list of Targets, but also Includes. The Includes can be defined in the Execution directly, which means you include all defined Targets by a before set Execution. diff --git a/internal/api/grpc/action/v3alpha/execution_integration_test.go b/internal/api/grpc/action/v3alpha/execution_integration_test.go index 4ca1b97d6f..fcbdb2b576 100644 --- a/internal/api/grpc/action/v3alpha/execution_integration_test.go +++ b/internal/api/grpc/action/v3alpha/execution_integration_test.go @@ -196,6 +196,33 @@ func TestServer_SetExecution_Request_Include(t *testing.T) { executionTargetsSingleTarget(targetResp.GetId()), ) + circularExecutionService := &action.Condition{ + ConditionType: &action.Condition_Request{ + Request: &action.RequestExecution{ + Condition: &action.RequestExecution_Service{ + Service: "zitadel.session.v2beta.SessionService", + }, + }, + }, + } + Tester.SetExecution(CTX, t, + circularExecutionService, + executionTargetsSingleInclude(executionCond), + ) + circularExecutionMethod := &action.Condition{ + ConditionType: &action.Condition_Request{ + Request: &action.RequestExecution{ + Condition: &action.RequestExecution_Method{ + Method: "/zitadel.session.v2beta.SessionService/ListSessions", + }, + }, + }, + } + Tester.SetExecution(CTX, t, + circularExecutionMethod, + executionTargetsSingleInclude(circularExecutionService), + ) + tests := []struct { name string ctx context.Context @@ -203,6 +230,15 @@ func TestServer_SetExecution_Request_Include(t *testing.T) { want *action.SetExecutionResponse wantErr bool }{ + { + name: "method, circular error", + ctx: CTX, + req: &action.SetExecutionRequest{ + Condition: circularExecutionService, + Targets: executionTargetsSingleInclude(circularExecutionMethod), + }, + wantErr: true, + }, { name: "method, ok", ctx: CTX, @@ -247,30 +283,6 @@ func TestServer_SetExecution_Request_Include(t *testing.T) { }, }, }, - /* circular - { - name: "all, ok", - ctx: CTX, - req: &action.SetExecutionRequest{ - Condition: &action.Condition{ - ConditionType: &action.Condition_Request{ - Request: &action.RequestExecution{ - Condition: &action.RequestExecution_All{ - All: true, - }, - }, - }, - }, - Targets: executionTargetsSingleInclude(executionCond), - }, - want: &action.SetExecutionResponse{ - Details: &object.Details{ - ChangeDate: timestamppb.Now(), - ResourceOwner: Tester.Instance.InstanceID(), - }, - }, - }, - */ } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/command/action_v2_execution.go b/internal/command/action_v2_execution.go index 422e053daf..7fb08a4a32 100644 --- a/internal/command/action_v2_execution.go +++ b/internal/command/action_v2_execution.go @@ -216,7 +216,9 @@ func (e *SetExecution) Existing(c *Commands, ctx context.Context, resourceOwner if len(includes) > 0 && !c.existsExecutionsByIDs(ctx, includes, resourceOwner) { return zerrors.ThrowNotFound(nil, "COMMAND-slgj0l4cdz", "Errors.Execution.IncludeNotFound") } - return nil + get, set := createIncludeCacheFunctions() + // maxLevels could be configurable, but set as 3 for now + return checkForIncludeCircular(ctx, e.AggregateID, resourceOwner, includes, c.getExecutionIncludes(get, set), 3) } func (c *Commands) setExecution(ctx context.Context, set *SetExecution, resourceOwner string) (_ *domain.ObjectDetails, err error) { @@ -309,3 +311,75 @@ func (c *Commands) getExecutionWriteModelByID(ctx context.Context, id string, re } return wm, nil } + +func createIncludeCacheFunctions() (func(s string) ([]string, bool), func(s string, strings []string)) { + tempCache := make(map[string][]string) + return func(s string) ([]string, bool) { + include, ok := tempCache[s] + return include, ok + }, func(s string, strings []string) { + tempCache[s] = strings + } +} + +type includeCacheFunc func(ctx context.Context, id string, resourceOwner string) ([]string, error) + +func checkForIncludeCircular(ctx context.Context, id string, resourceOwner string, includes []string, cache includeCacheFunc, maxLevels int) error { + if len(includes) == 0 { + return nil + } + level := 0 + for _, include := range includes { + if id == include { + return zerrors.ThrowPreconditionFailed(nil, "COMMAND-mo1cmjp5k7", "Errors.Execution.CircularInclude") + } + if err := checkForIncludeCircularRecur(ctx, []string{id}, resourceOwner, include, cache, maxLevels, level); err != nil { + return err + } + } + return nil +} + +func (c *Commands) getExecutionIncludes( + getCache func(string) ([]string, bool), + setCache func(string, []string), +) includeCacheFunc { + return func(ctx context.Context, id string, resourceOwner string) ([]string, error) { + included, ok := getCache(id) + if !ok { + included, err := c.getExecutionWriteModelByID(ctx, id, resourceOwner) + if err != nil { + return nil, err + } + includes := included.IncludeList() + setCache(id, includes) + return includes, nil + } + return included, nil + } +} + +func checkForIncludeCircularRecur(ctx context.Context, ids []string, resourceOwner string, include string, cache includeCacheFunc, maxLevels, level int) error { + included, err := cache(ctx, include, resourceOwner) + if err != nil { + return err + } + currentLevel := level + 1 + if currentLevel >= maxLevels { + return zerrors.ThrowPreconditionFailed(nil, "COMMAND-gbhd3g57oo", "Errors.Execution.MaxLevelsInclude") + } + for _, includedInclude := range included { + if include == includedInclude { + return zerrors.ThrowPreconditionFailed(nil, "COMMAND-iuch02i656", "Errors.Execution.CircularInclude") + } + for _, id := range ids { + if includedInclude == id { + return zerrors.ThrowPreconditionFailed(nil, "COMMAND-819opvhgjv", "Errors.Execution.CircularInclude") + } + } + if err := checkForIncludeCircularRecur(ctx, append(ids, include), resourceOwner, includedInclude, cache, maxLevels, currentLevel); err != nil { + return err + } + } + return nil +} diff --git a/internal/command/action_v2_execution_model.go b/internal/command/action_v2_execution_model.go index c53992856e..30cab0f56e 100644 --- a/internal/command/action_v2_execution_model.go +++ b/internal/command/action_v2_execution_model.go @@ -3,6 +3,7 @@ package command import ( "slices" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/repository/execution" ) @@ -15,6 +16,16 @@ type ExecutionWriteModel struct { ExecutionTargets []*execution.Target } +func (e *ExecutionWriteModel) IncludeList() []string { + includes := make([]string, 0) + for i := range e.ExecutionTargets { + if e.ExecutionTargets[i].Type == domain.ExecutionTargetTypeInclude { + includes = append(includes, e.ExecutionTargets[i].Target) + } + } + return includes +} + func (e *ExecutionWriteModel) Exists() bool { return len(e.ExecutionTargets) > 0 || len(e.Includes) > 0 || len(e.Targets) > 0 } diff --git a/internal/command/action_v2_execution_test.go b/internal/command/action_v2_execution_test.go index c8f91f49b2..5a9c0ecb1d 100644 --- a/internal/command/action_v2_execution_test.go +++ b/internal/command/action_v2_execution_test.go @@ -348,6 +348,16 @@ func TestCommands_SetExecutionRequest(t *testing.T) { "push ok, method include", fields{ eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + execution.NewSetEventV2(context.Background(), + execution.NewAggregate("request/include", "instance"), + []*execution.Target{ + {Type: domain.ExecutionTargetTypeTarget, Target: "target"}, + }, + ), + ), + ), expectFilter( eventFromEventPusher( execution.NewSetEventV2(context.Background(), @@ -419,6 +429,16 @@ func TestCommands_SetExecutionRequest(t *testing.T) { "push ok, service include", fields{ eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + execution.NewSetEventV2(context.Background(), + execution.NewAggregate("request/include", "instance"), + []*execution.Target{ + {Type: domain.ExecutionTargetTypeTarget, Target: "target"}, + }, + ), + ), + ), expectFilter( eventFromEventPusher( execution.NewSetEventV2(context.Background(), @@ -489,6 +509,16 @@ func TestCommands_SetExecutionRequest(t *testing.T) { "push ok, all include", fields{ eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + execution.NewSetEventV2(context.Background(), + execution.NewAggregate("request/include", "instance"), + []*execution.Target{ + {Type: domain.ExecutionTargetTypeTarget, Target: "target"}, + }, + ), + ), + ), expectFilter( eventFromEventPusher( execution.NewSetEventV2(context.Background(), @@ -2373,3 +2403,470 @@ func TestCommands_DeleteExecutionFunction(t *testing.T) { }) } } + +func mockExecutionIncludesCache(cache map[string][]string) includeCacheFunc { + return func(ctx context.Context, id string, resourceOwner string) ([]string, error) { + included, ok := cache[id] + if !ok { + return nil, zerrors.ThrowPreconditionFailed(nil, "", "cache failed") + } + return included, nil + } +} + +func TestCommands_checkForIncludeCircular(t *testing.T) { + type args struct { + ctx context.Context + id string + resourceOwner string + includes []string + cache map[string][]string + } + type res struct { + err func(error) bool + } + tests := []struct { + name string + args args + res res + }{ + { + "not found, error", + args{ + ctx: context.Background(), + id: "id", + resourceOwner: "", + includes: []string{"notexistent"}, + cache: map[string][]string{}, + }, + res{ + err: zerrors.IsPreconditionFailed, + }, + }, + { + "single, ok", + args{ + ctx: context.Background(), + id: "id1", + resourceOwner: "", + includes: []string{"id2"}, + cache: map[string][]string{ + "id2": {}, + }, + }, + res{}, + }, + { + "single, circular", + args{ + ctx: context.Background(), + id: "id1", + resourceOwner: "", + includes: []string{"id1"}, + cache: map[string][]string{}, + }, + res{ + err: zerrors.IsPreconditionFailed, + }, + }, + { + "multi 3, ok", + args{ + ctx: context.Background(), + id: "id1", + resourceOwner: "", + includes: []string{"id2"}, + cache: map[string][]string{ + "id2": {"id3"}, + "id3": {}, + }, + }, + res{}, + }, + { + "multi 3, circular", + args{ + ctx: context.Background(), + id: "id1", + resourceOwner: "", + includes: []string{"id2"}, + cache: map[string][]string{ + "id2": {"id3"}, + "id3": {"id1"}, + }, + }, + res{ + err: zerrors.IsPreconditionFailed, + }, + }, + { + "multi 5, ok", + args{ + ctx: context.Background(), + id: "id1", + resourceOwner: "", + includes: []string{"id11", "id12"}, + cache: map[string][]string{ + "id11": {"id21", "id23"}, + "id12": {"id22"}, + "id21": {}, + "id22": {}, + "id23": {}, + }, + }, + res{}, + }, + { + "multi 5, circular", + args{ + ctx: context.Background(), + id: "id1", + resourceOwner: "", + includes: []string{"id11", "id12"}, + cache: map[string][]string{ + "id11": {"id21", "id23"}, + "id12": {"id22"}, + "id21": {}, + "id22": {}, + "id23": {"id1"}, + }, + }, + res{ + err: zerrors.IsPreconditionFailed, + }, + }, + { + "multi 5, circular", + args{ + ctx: context.Background(), + id: "id1", + resourceOwner: "", + includes: []string{"id11", "id12"}, + cache: map[string][]string{ + "id11": {"id21", "id23"}, + "id12": {"id22"}, + "id21": {}, + "id22": {}, + "id23": {"id11"}, + }, + }, + res{ + err: zerrors.IsPreconditionFailed, + }, + }, + { + "multi 5, circular", + args{ + ctx: context.Background(), + id: "id1", + resourceOwner: "", + includes: []string{"id11", "id12"}, + cache: map[string][]string{ + "id11": {"id21", "id23"}, + "id12": {"id22"}, + "id21": {"id11"}, + "id22": {}, + "id23": {}, + }, + }, + res{ + err: zerrors.IsPreconditionFailed, + }, + }, + { + "multi 5, circular", + args{ + ctx: context.Background(), + id: "id1", + resourceOwner: "", + includes: []string{"id11", "id12"}, + cache: map[string][]string{ + "id11": {"id21", "id23"}, + "id12": {"id22"}, + "id21": {}, + "id22": {"id12"}, + "id23": {}, + }, + }, + res{ + err: zerrors.IsPreconditionFailed, + }, + }, + { + "multi 3, maxlevel", + args{ + ctx: context.Background(), + id: "id1", + resourceOwner: "", + includes: []string{"id2"}, + cache: map[string][]string{ + "id2": {"id3"}, + "id3": {}, + }, + }, + res{}, + }, + { + "multi 4, over maxlevel", + args{ + ctx: context.Background(), + id: "id1", + resourceOwner: "", + includes: []string{"id2"}, + cache: map[string][]string{ + "id2": {"id3"}, + "id3": {"id4"}, + "id4": {}, + }, + }, + res{ + err: zerrors.IsPreconditionFailed, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := mockExecutionIncludesCache(tt.args.cache) + err := checkForIncludeCircular(tt.args.ctx, tt.args.id, tt.args.resourceOwner, tt.args.includes, f, 3) + if tt.res.err == nil { + assert.NoError(t, err) + } + if tt.res.err != nil && !tt.res.err(err) { + t.Errorf("got wrong err: %v ", err) + } + }) + } +} + +func mockExecutionIncludesCacheFuncs(cache map[string][]string) (func(string) ([]string, bool), func(string, []string)) { + return func(s string) ([]string, bool) { + includes, ok := cache[s] + return includes, ok + }, func(s string, strings []string) { + cache[s] = strings + } +} + +func TestCommands_getExecutionIncludes(t *testing.T) { + type fields struct { + eventstore func(t *testing.T) *eventstore.Eventstore + } + type args struct { + ctx context.Context + cache map[string][]string + id string + resourceOwner string + } + type res struct { + includes []string + cache map[string][]string + err func(error) bool + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "new empty, ok", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + execution.NewSetEventV2(context.Background(), + execution.NewAggregate("request/include", "instance"), + []*execution.Target{ + {Type: domain.ExecutionTargetTypeTarget, Target: "target"}, + }, + ), + ), + ), + ), + }, + args{ + ctx: context.Background(), + cache: map[string][]string{}, + id: "id", + resourceOwner: "instance", + }, + res{ + includes: []string{}, + cache: map[string][]string{"id": {}}, + }, + }, + { + "new includes, ok", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + execution.NewSetEventV2(context.Background(), + execution.NewAggregate("request/include", "instance"), + []*execution.Target{ + {Type: domain.ExecutionTargetTypeInclude, Target: "include"}, + }, + ), + ), + ), + ), + }, + args{ + ctx: context.Background(), + cache: map[string][]string{}, + id: "id", + resourceOwner: "instance", + }, + res{ + includes: []string{"include"}, + cache: map[string][]string{"id": {"include"}}, + }, + }, + { + "found, ok", + fields{ + eventstore: expectEventstore(), + }, + args{ + ctx: context.Background(), + cache: map[string][]string{"id": nil}, + id: "id", + resourceOwner: "instance", + }, + res{ + includes: nil, + cache: map[string][]string{"id": nil}, + }, + }, + { + "found includes, ok", + fields{ + eventstore: expectEventstore(), + }, + args{ + ctx: context.Background(), + cache: map[string][]string{"id": {"include1", "include2", "include3"}}, + id: "id", + resourceOwner: "instance", + }, + res{ + includes: []string{"include1", "include2", "include3"}, + cache: map[string][]string{"id": {"include1", "include2", "include3"}}, + }, + }, + { + "found multiple, ok", + fields{ + eventstore: expectEventstore(), + }, + args{ + ctx: context.Background(), + cache: map[string][]string{ + "id1": {"include1", "include2", "include3"}, + "id2": {"include1", "include2", "include3"}, + "id3": {"include1", "include2", "include3"}, + }, + id: "id2", + resourceOwner: "instance", + }, + res{ + includes: []string{"include1", "include2", "include3"}, + cache: map[string][]string{ + "id1": {"include1", "include2", "include3"}, + "id2": {"include1", "include2", "include3"}, + "id3": {"include1", "include2", "include3"}, + }, + }, + }, + { + "new multiple, ok", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + execution.NewSetEventV2(context.Background(), + execution.NewAggregate("request/include", "instance"), + []*execution.Target{ + {Type: domain.ExecutionTargetTypeTarget, Target: "target"}, + }, + ), + ), + ), + ), + }, + args{ + ctx: context.Background(), + cache: map[string][]string{ + "id1": {"include1", "include2", "include3"}, + "id2": {"include1", "include2", "include3"}, + "id3": {"include1", "include2", "include3"}, + }, + id: "id", + resourceOwner: "instance", + }, + res{ + includes: []string{}, + cache: map[string][]string{ + "id1": {"include1", "include2", "include3"}, + "id2": {"include1", "include2", "include3"}, + "id3": {"include1", "include2", "include3"}, + "id": {}, + }, + }, + }, + { + "new multiple includes, ok", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + execution.NewSetEventV2(context.Background(), + execution.NewAggregate("request/include", "instance"), + []*execution.Target{ + {Type: domain.ExecutionTargetTypeInclude, Target: "include"}, + }, + ), + ), + ), + ), + }, + args{ + ctx: context.Background(), + cache: map[string][]string{ + "id1": {"include1", "include2", "include3"}, + "id2": {"include1", "include2", "include3"}, + "id3": {"include1", "include2", "include3"}, + }, + id: "id", + resourceOwner: "instance", + }, + res{ + includes: []string{"include"}, + cache: map[string][]string{ + "id1": {"include1", "include2", "include3"}, + "id2": {"include1", "include2", "include3"}, + "id3": {"include1", "include2", "include3"}, + "id": {"include"}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore(t), + } + includes, err := c.getExecutionIncludes(mockExecutionIncludesCacheFuncs(tt.args.cache))(tt.args.ctx, tt.args.id, tt.args.resourceOwner) + if tt.res.err == nil { + assert.NoError(t, err) + } + if tt.res.err != nil && !tt.res.err(err) { + t.Errorf("got wrong err: %v ", err) + } + if tt.res.err == nil { + assert.Equal(t, tt.res.cache, tt.args.cache) + assert.Equal(t, tt.res.includes, includes) + } + }) + } +}