zitadel/internal/command/milestone_test.go

630 lines
14 KiB
Go
Raw Normal View History

package command
import (
"context"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/connector/gomap"
"github.com/zitadel/zitadel/internal/cache/connector/noop"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone"
)
// TestCommands_GetMilestonesReached checks that GetMilestonesReached returns
// the cached MilestonesReached when the context's instance ID has a cache
// entry (the eventstore mock has no expectations in that case), and otherwise
// builds the result from the milestone.ReachedEvent events returned by the
// eventstore filter, propagating any filter error.
func TestCommands_GetMilestonesReached(t *testing.T) {
	// cached is stored in the cache of every sub-test; only the "cached" case
	// uses a context with its instance ID.
	cached := &MilestonesReached{
		InstanceID:                        "cached-id",
		InstanceCreated:                   true,
		AuthenticationSucceededOnInstance: true,
	}
	ctx := authz.WithInstanceID(context.Background(), "instanceID")
	aggregate := milestone.NewAggregate(ctx)
	type fields struct {
		eventstore func(*testing.T) *eventstore.Eventstore
	}
	type args struct {
		ctx context.Context
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *MilestonesReached
		wantErr error
	}{
		{
			name: "cached",
			fields: fields{
				eventstore: expectEventstore(),
			},
			args: args{
				ctx: authz.WithInstanceID(context.Background(), "cached-id"),
			},
			want: cached,
		},
		{
			name: "filter error",
			fields: fields{
				eventstore: expectEventstore(
					expectFilterError(io.ErrClosedPipe),
				),
			},
			args: args{
				ctx: ctx,
			},
			wantErr: io.ErrClosedPipe,
		},
		{
			name: "no events, all false",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(),
				),
			},
			args: args{
				ctx: ctx,
			},
			want: &MilestonesReached{
				InstanceID: "instanceID",
			},
		},
		{
			name: "instance created",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.InstanceCreated)),
					),
				),
			},
			args: args{
				ctx: ctx,
			},
			want: &MilestonesReached{
				InstanceID:      "instanceID",
				InstanceCreated: true,
			},
		},
		{
			name: "instance auth",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
					),
				),
			},
			args: args{
				ctx: ctx,
			},
			want: &MilestonesReached{
				InstanceID:                        "instanceID",
				AuthenticationSucceededOnInstance: true,
			},
		},
		{
			name: "project created",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated)),
					),
				),
			},
			args: args{
				ctx: ctx,
			},
			want: &MilestonesReached{
				InstanceID:     "instanceID",
				ProjectCreated: true,
			},
		},
		{
			name: "app created",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated)),
					),
				),
			},
			args: args{
				ctx: ctx,
			},
			want: &MilestonesReached{
				InstanceID:         "instanceID",
				ApplicationCreated: true,
			},
		},
		{
			name: "app auth",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
					),
				),
			},
			args: args{
				ctx: ctx,
			},
			want: &MilestonesReached{
				InstanceID:                           "instanceID",
				AuthenticationSucceededOnApplication: true,
			},
		},
		{
			name: "instance deleted",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.InstanceDeleted)),
					),
				),
			},
			args: args{
				ctx: ctx,
			},
			want: &MilestonesReached{
				InstanceID:      "instanceID",
				InstanceDeleted: true,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Named milestoneCache (not cache) so the imported cache package
			// stays accessible and unshadowed in this closure.
			milestoneCache := gomap.NewCache[milestoneIndex, string, *MilestonesReached](
				context.Background(),
				[]milestoneIndex{milestoneIndexInstanceID},
				cache.Config{Connector: cache.ConnectorMemory},
			)
			milestoneCache.Set(context.Background(), cached)
			c := &Commands{
				eventstore: tt.fields.eventstore(t),
				caches: &Caches{
					milestones: milestoneCache,
				},
			}
			got, err := c.GetMilestonesReached(tt.args.ctx)
			require.ErrorIs(t, err, tt.wantErr)
			assert.Equal(t, tt.want, got)
		})
	}
}
func TestCommands_milestonesCompleted(t *testing.T) {
c := &Commands{
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
ctx := authz.WithInstanceID(context.Background(), "instanceID")
arg := &MilestonesReached{
InstanceID: "instanceID",
InstanceCreated: true,
AuthenticationSucceededOnInstance: true,
ProjectCreated: true,
ApplicationCreated: true,
AuthenticationSucceededOnApplication: true,
InstanceDeleted: false,
}
c.setCachedMilestonesReached(ctx, arg)
got, ok := c.getCachedMilestonesReached(ctx)
assert.True(t, ok)
assert.Equal(t, arg, got)
}
// TestCommands_MilestonePushed checks that MilestonePushed pushes a
// milestone.PushedEvent built from the instance aggregate, the milestone type,
// the endpoints and the Commands' externalDomain, and that an error from the
// eventstore push is returned to the caller.
func TestCommands_MilestonePushed(t *testing.T) {
	const (
		testInstanceID = "instanceID"
		domain         = "example.com"
	)
	aggregate := milestone.NewInstanceAggregate(testInstanceID)

	type deps struct {
		eventstore func(*testing.T) *eventstore.Eventstore
	}
	type input struct {
		ctx        context.Context
		instanceID string
		msType     milestone.Type
		endpoints  []string
	}
	cases := []struct {
		name    string
		fields  deps
		args    input
		wantErr error
	}{
		{
			name: "milestone pushed",
			fields: deps{
				eventstore: expectEventstore(
					expectPush(
						milestone.NewPushedEvent(
							context.Background(),
							aggregate,
							milestone.ApplicationCreated,
							[]string{"foo.com", "bar.com"},
							domain,
						),
					),
				),
			},
			args: input{
				ctx:        context.Background(),
				instanceID: testInstanceID,
				msType:     milestone.ApplicationCreated,
				endpoints:  []string{"foo.com", "bar.com"},
			},
		},
		{
			name: "pusher error",
			fields: deps{
				eventstore: expectEventstore(
					expectPushFailed(
						io.ErrClosedPipe,
						milestone.NewPushedEvent(
							context.Background(),
							aggregate,
							milestone.ApplicationCreated,
							[]string{"foo.com", "bar.com"},
							domain,
						),
					),
				),
			},
			args: input{
				ctx:        context.Background(),
				instanceID: testInstanceID,
				msType:     milestone.ApplicationCreated,
				endpoints:  []string{"foo.com", "bar.com"},
			},
			wantErr: io.ErrClosedPipe,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			commands := &Commands{
				eventstore:     tc.fields.eventstore(t),
				externalDomain: domain,
			}
			in := tc.args
			assert.ErrorIs(t, commands.MilestonePushed(in.ctx, in.instanceID, in.msType, in.endpoints), tc.wantErr)
		})
	}
}
// TestOIDCSessionEvents_SetMilestones checks which milestone.ReachedEvent
// SetMilestones appends to the OIDC session's events for a login:
//   - AuthenticationSucceededOnInstance for a human login through the console
//     client when no milestone is reached yet;
//   - AuthenticationSucceededOnApplication for a human login through another
//     client once the instance milestone is reached;
//   - nothing when the relevant milestones are already reached, including a
//     machine login through a non-console client.
//
// An error while loading the reached milestones must be returned unchanged.
func TestOIDCSessionEvents_SetMilestones(t *testing.T) {
	ctx := authz.WithInstanceID(context.Background(), "instanceID")
	ctx = authz.WithConsoleClientID(ctx, "console")
	aggregate := milestone.NewAggregate(ctx)
	type fields struct {
		eventstore func(*testing.T) *eventstore.Eventstore
	}
	type args struct {
		ctx      context.Context
		clientID string
		isHuman  bool
	}
	tests := []struct {
		name       string
		fields     fields
		args       args
		wantEvents []eventstore.Command
		wantErr    error
	}{
		{
			name: "get error",
			fields: fields{
				eventstore: expectEventstore(
					expectFilterError(io.ErrClosedPipe),
				),
			},
			args: args{
				ctx:      ctx,
				clientID: "client",
				isHuman:  true,
			},
			wantErr: io.ErrClosedPipe,
		},
		{
			name: "milestones already reached",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
					),
				),
			},
			args: args{
				ctx:      ctx,
				clientID: "client",
				isHuman:  true,
			},
			wantErr: nil,
		},
		{
			name: "auth on instance",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(),
				),
			},
			args: args{
				ctx:      ctx,
				clientID: "console",
				isHuman:  true,
			},
			wantEvents: []eventstore.Command{
				milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance),
			},
			wantErr: nil,
		},
		{
			name: "subsequent console login, no milestone",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
					),
				),
			},
			args: args{
				ctx:      ctx,
				clientID: "console",
				isHuman:  true,
			},
			wantErr: nil,
		},
		{
			name: "subsequent machine login, no milestone",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
					),
				),
			},
			args: args{
				ctx:      ctx,
				clientID: "client",
				isHuman:  false,
			},
			wantErr: nil,
		},
		{
			name: "auth on app",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
					),
				),
			},
			args: args{
				ctx:      ctx,
				clientID: "client",
				isHuman:  true,
			},
			wantEvents: []eventstore.Command{
				milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication),
			},
			wantErr: nil,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Commands{
				eventstore: tt.fields.eventstore(t),
				caches: &Caches{
					milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
				},
			}
			s := &OIDCSessionEvents{
				commands: c,
			}
			postCommit, err := s.SetMilestones(tt.args.ctx, tt.args.clientID, tt.args.isHuman)
			// Assert the error before invoking postCommit: a nil callback would
			// otherwise panic and hide the actual assertion failure.
			require.ErrorIs(t, err, tt.wantErr)
			// The callback is expected to be callable on every path, including
			// the error path.
			require.NotNil(t, postCommit)
			postCommit(tt.args.ctx)
			assert.Equal(t, tt.wantEvents, s.events)
		})
	}
}
// TestCommands_projectCreatedMilestone checks that projectCreatedMilestone
// appends a ProjectCreated milestone.ReachedEvent to cmds only when that
// milestone is not reached yet, and propagates errors from loading the reached
// milestones. For a context carrying a system membership, the eventstore mock
// has no expectations and no command is expected.
func TestCommands_projectCreatedMilestone(t *testing.T) {
	ctx := authz.WithInstanceID(context.Background(), "instanceID")
	systemCtx := authz.SetCtxData(ctx, authz.CtxData{
		SystemMemberships: authz.Memberships{
			&authz.Membership{
				MemberType: authz.MemberTypeSystem,
			},
		},
	})
	aggregate := milestone.NewAggregate(ctx)
	type fields struct {
		eventstore func(*testing.T) *eventstore.Eventstore
	}
	type args struct {
		ctx context.Context
	}
	tests := []struct {
		name       string
		fields     fields
		args       args
		wantEvents []eventstore.Command
		wantErr    error
	}{
		{
			name: "system user",
			fields: fields{
				eventstore: expectEventstore(),
			},
			args: args{
				ctx: systemCtx,
			},
			wantErr: nil,
		},
		{
			name: "get error",
			fields: fields{
				eventstore: expectEventstore(
					expectFilterError(io.ErrClosedPipe),
				),
			},
			args: args{
				ctx: ctx,
			},
			wantErr: io.ErrClosedPipe,
		},
		{
			name: "milestone already reached",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated)),
					),
				),
			},
			args: args{
				ctx: ctx,
			},
			wantErr: nil,
		},
		{
			name: "milestone reached event",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(),
				),
			},
			args: args{
				ctx: ctx,
			},
			wantEvents: []eventstore.Command{
				milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated),
			},
			wantErr: nil,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Commands{
				eventstore: tt.fields.eventstore(t),
				caches: &Caches{
					milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
				},
			}
			var cmds []eventstore.Command
			postCommit, err := c.projectCreatedMilestone(tt.args.ctx, &cmds)
			// Assert the error before invoking postCommit: a nil callback would
			// otherwise panic and hide the actual assertion failure.
			require.ErrorIs(t, err, tt.wantErr)
			// The callback is expected to be callable on every path, including
			// the error path.
			require.NotNil(t, postCommit)
			postCommit(tt.args.ctx)
			assert.Equal(t, tt.wantEvents, cmds)
		})
	}
}
// TestCommands_applicationCreatedMilestone checks that
// applicationCreatedMilestone appends an ApplicationCreated
// milestone.ReachedEvent to cmds only when that milestone is not reached yet,
// and propagates errors from loading the reached milestones. For a context
// carrying a system membership, the eventstore mock has no expectations and no
// command is expected.
func TestCommands_applicationCreatedMilestone(t *testing.T) {
	ctx := authz.WithInstanceID(context.Background(), "instanceID")
	systemCtx := authz.SetCtxData(ctx, authz.CtxData{
		SystemMemberships: authz.Memberships{
			&authz.Membership{
				MemberType: authz.MemberTypeSystem,
			},
		},
	})
	aggregate := milestone.NewAggregate(ctx)
	type fields struct {
		eventstore func(*testing.T) *eventstore.Eventstore
	}
	type args struct {
		ctx context.Context
	}
	tests := []struct {
		name       string
		fields     fields
		args       args
		wantEvents []eventstore.Command
		wantErr    error
	}{
		{
			name: "system user",
			fields: fields{
				eventstore: expectEventstore(),
			},
			args: args{
				ctx: systemCtx,
			},
			wantErr: nil,
		},
		{
			name: "get error",
			fields: fields{
				eventstore: expectEventstore(
					expectFilterError(io.ErrClosedPipe),
				),
			},
			args: args{
				ctx: ctx,
			},
			wantErr: io.ErrClosedPipe,
		},
		{
			name: "milestone already reached",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated)),
					),
				),
			},
			args: args{
				ctx: ctx,
			},
			wantErr: nil,
		},
		{
			name: "milestone reached event",
			fields: fields{
				eventstore: expectEventstore(
					expectFilter(),
				),
			},
			args: args{
				ctx: ctx,
			},
			wantEvents: []eventstore.Command{
				milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated),
			},
			wantErr: nil,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Commands{
				eventstore: tt.fields.eventstore(t),
				caches: &Caches{
					milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
				},
			}
			var cmds []eventstore.Command
			postCommit, err := c.applicationCreatedMilestone(tt.args.ctx, &cmds)
			// Assert the error before invoking postCommit: a nil callback would
			// otherwise panic and hide the actual assertion failure.
			require.ErrorIs(t, err, tt.wantErr)
			// The callback is expected to be callable on every path, including
			// the error path.
			require.NotNil(t, postCommit)
			postCommit(tt.args.ctx)
			assert.Equal(t, tt.wantEvents, cmds)
		})
	}
}
// setMilestonesCompletedForTest is a test-only helper (per its name) that
// records instanceID in c.milestonesCompleted with an empty-struct marker value.
func (c *Commands) setMilestonesCompletedForTest(instanceID string) {
	var marker struct{}
	c.milestonesCompleted.Store(instanceID, marker)
}