fix: set domain verified if domain policy does not require validation (#4061)

* fix: set domain verified if domain policy does not require validation

* handle domain claimed
This commit is contained in:
Livio Spring 2022-07-28 13:18:31 +02:00 committed by GitHub
parent 0b742233f9
commit 096e12d3d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 192 additions and 144 deletions

View File

@ -140,16 +140,17 @@ func (s *Server) ListOrgDomains(ctx context.Context, req *mgmt_pb.ListOrgDomains
}
func (s *Server) AddOrgDomain(ctx context.Context, req *mgmt_pb.AddOrgDomainRequest) (*mgmt_pb.AddOrgDomainResponse, error) {
domain, err := s.command.AddOrgDomain(ctx, AddOrgDomainRequestToDomain(ctx, req), nil)
orgID := authz.GetCtxData(ctx).OrgID
userIDs, err := s.getClaimedUserIDsOfOrgDomain(ctx, req.Domain, orgID)
if err != nil {
return nil, err
}
details, err := s.command.AddOrgDomain(ctx, orgID, req.Domain, userIDs)
if err != nil {
return nil, err
}
return &mgmt_pb.AddOrgDomainResponse{
Details: object.AddToDetailsPb(
domain.Sequence,
domain.ChangeDate,
domain.ResourceOwner,
),
Details: object.DomainToAddDetailsPb(details),
}, nil
}

View File

@ -6,7 +6,6 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command/preparation"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
caos_errs "github.com/zitadel/zitadel/internal/errors"
@ -47,10 +46,7 @@ func (c *Commands) SetUpOrg(ctx context.Context, o *OrgSetup, userIDs ...string)
c.AddOrgMemberCommand(orgAgg, userID, roles...),
}
if o.CustomDomain != "" {
validations = append(validations, AddOrgDomain(orgAgg, o.CustomDomain))
for _, userID := range userIDs {
validations = append(validations, c.prepareUserDomainClaimed(userID))
}
validations = append(validations, c.prepareAddOrgDomain(orgAgg, o.CustomDomain, userIDs))
}
cmds, err := preparation.PrepareCommands(ctx, c.eventstore.Filter, validations...)
@ -246,43 +242,6 @@ func ExistsOrg(ctx context.Context, filter preparation.FilterToQueryReducer, id
return exists, nil
}
func (c *Commands) setUpOrg(
ctx context.Context,
organisation *domain.Org,
admin *domain.Human,
loginPolicy *domain.DomainPolicy,
pwPolicy *domain.PasswordComplexityPolicy,
initCodeGenerator crypto.Generator,
phoneCodeGenerator crypto.Generator,
claimedUserIDs []string,
selfregistered bool,
) (orgAgg *eventstore.Aggregate, org *OrgWriteModel, human *HumanWriteModel, orgMember *OrgMemberWriteModel, events []eventstore.Command, err error) {
orgAgg, orgWriteModel, addOrgEvents, err := c.addOrg(ctx, organisation, claimedUserIDs)
if err != nil {
return nil, nil, nil, nil, nil, err
}
var userEvents []eventstore.Command
if selfregistered {
userEvents, human, err = c.registerHuman(ctx, orgAgg.ID, admin, nil, loginPolicy, pwPolicy, initCodeGenerator, phoneCodeGenerator)
} else {
userEvents, human, err = c.addHuman(ctx, orgAgg.ID, admin, loginPolicy, pwPolicy, initCodeGenerator, phoneCodeGenerator)
}
if err != nil {
return nil, nil, nil, nil, nil, err
}
addOrgEvents = append(addOrgEvents, userEvents...)
addedMember := NewOrgMemberWriteModel(orgAgg.ID, human.AggregateID)
orgMemberAgg := OrgAggregateFromWriteModel(&addedMember.WriteModel)
orgMemberEvent, err := c.addOrgMember(ctx, orgMemberAgg, addedMember, domain.NewMember(orgMemberAgg.ID, human.AggregateID, domain.RoleOrgOwner))
if err != nil {
return nil, nil, nil, nil, nil, err
}
addOrgEvents = append(addOrgEvents, orgMemberEvent)
return orgAgg, orgWriteModel, human, addedMember, addOrgEvents, nil
}
func (c *Commands) addOrg(ctx context.Context, organisation *domain.Org, claimedUserIDs []string) (_ *eventstore.Aggregate, _ *OrgWriteModel, _ []eventstore.Command, err error) {
if !organisation.IsValid() {
return nil, nil, nil, caos_errs.ThrowInvalidArgument(nil, "COMM-deLSk", "Errors.Org.Invalid")

View File

@ -17,26 +17,34 @@ import (
"github.com/zitadel/zitadel/internal/repository/org"
)
func AddOrgDomain(a *org.Aggregate, domain string) preparation.Validation {
func (c *Commands) prepareAddOrgDomain(a *org.Aggregate, addDomain string, userIDs []string) preparation.Validation {
return func() (preparation.CreateCommands, error) {
if domain = strings.TrimSpace(domain); domain == "" {
if addDomain = strings.TrimSpace(addDomain); addDomain == "" {
return nil, errors.ThrowInvalidArgument(nil, "ORG-r3h4J", "Errors.Invalid.Argument")
}
return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) {
existing, err := orgDomain(ctx, filter, a.ID, domain)
existing, err := orgDomain(ctx, filter, a.ID, addDomain)
if err != nil && !errs.Is(err, errors.ThrowNotFound(nil, "", "")) {
return nil, err
}
if existing != nil && existing.Verified {
if existing != nil && existing.State == domain.OrgDomainStateActive {
return nil, errors.ThrowAlreadyExists(nil, "V2-e1wse", "Errors.Already.Exists")
}
domainPolicy, err := domainPolicyWriteModel(ctx, filter)
if err != nil {
return nil, err
}
events := []eventstore.Command{org.NewDomainAddedEvent(ctx, &a.Aggregate, domain)}
events := []eventstore.Command{org.NewDomainAddedEvent(ctx, &a.Aggregate, addDomain)}
if !domainPolicy.ValidateOrgDomains {
events = append(events, org.NewDomainVerifiedEvent(ctx, &a.Aggregate, domain))
events = append(events, org.NewDomainVerifiedEvent(ctx, &a.Aggregate, addDomain))
for _, userID := range userIDs {
claimedEvent, err := c.prepareUserDomainClaimed(ctx, filter, userID)
if err != nil {
logging.WithFields("userid", userID).WithError(err).Error("could not claim user")
continue
}
events = append(events, claimedEvent)
}
}
return events, nil
}, nil
@ -93,25 +101,17 @@ func orgDomain(ctx context.Context, filter preparation.FilterToQueryReducer, org
return wm, nil
}
func (c *Commands) AddOrgDomain(ctx context.Context, orgDomain *domain.OrgDomain, claimedUserIDs []string) (*domain.OrgDomain, error) {
if !orgDomain.IsValid() {
return nil, errors.ThrowInvalidArgument(nil, "ORG-R24hb", "Errors.Org.InvalidDomain")
}
domainWriteModel := NewOrgDomainWriteModel(orgDomain.AggregateID, orgDomain.Domain)
orgAgg := OrgAggregateFromWriteModel(&domainWriteModel.WriteModel)
events, err := c.addOrgDomain(ctx, orgAgg, domainWriteModel, orgDomain, claimedUserIDs)
func (c *Commands) AddOrgDomain(ctx context.Context, orgID, domain string, claimedUserIDs []string) (*domain.ObjectDetails, error) {
orgAgg := org.NewAggregate(orgID)
cmds, err := preparation.PrepareCommands(ctx, c.eventstore.Filter, c.prepareAddOrgDomain(orgAgg, domain, claimedUserIDs))
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...)
pushedEvents, err := c.eventstore.Push(ctx, cmds...)
if err != nil {
return nil, err
}
err = AppendAndReduce(domainWriteModel, pushedEvents...)
if err != nil {
return nil, err
}
return orgDomainWriteModelToOrgDomain(domainWriteModel), nil
return pushedEventsToObjectDetails(pushedEvents), nil
}
func (c *Commands) GenerateOrgDomainValidation(ctx context.Context, orgDomain *domain.OrgDomain) (token, url string, err error) {

View File

@ -25,9 +25,11 @@ import (
func TestAddDomain(t *testing.T) {
type args struct {
a *org.Aggregate
domain string
filter preparation.FilterToQueryReducer
a *org.Aggregate
domain string
claimedUserIDs []string
idGenerator id.Generator
filter preparation.FilterToQueryReducer
}
agg := org.NewAggregate("test")
@ -50,8 +52,9 @@ func TestAddDomain(t *testing.T) {
{
name: "correct (should verify domain)",
args: args{
a: agg,
domain: "domain",
a: agg,
domain: "domain",
claimedUserIDs: []string{"userID1"},
filter: func(ctx context.Context, queryFactory *eventstore.SearchQueryBuilder) ([]eventstore.Event, error) {
return []eventstore.Event{
org.NewDomainPolicyAddedEvent(ctx, &agg.Aggregate, true, true, true),
@ -67,26 +70,52 @@ func TestAddDomain(t *testing.T) {
{
name: "correct (should not verify domain)",
args: args{
a: agg,
domain: "domain",
filter: func(ctx context.Context, queryFactory *eventstore.SearchQueryBuilder) ([]eventstore.Event, error) {
return []eventstore.Event{
org.NewDomainPolicyAddedEvent(ctx, &agg.Aggregate, true, false, false),
}, nil
},
a: agg,
domain: "domain",
claimedUserIDs: []string{"userID1"},
idGenerator: id_mock.ExpectID(t, "newID"),
filter: func() func(ctx context.Context, queryFactory *eventstore.SearchQueryBuilder) ([]eventstore.Event, error) {
i := 0 //TODO: we should fix this in the future to use some kind of mock struct and expect filter calls
return func(ctx context.Context, queryFactory *eventstore.SearchQueryBuilder) ([]eventstore.Event, error) {
if i == 2 {
i++
return []eventstore.Event{user.NewHumanAddedEvent(
ctx,
&user.NewAggregate("userID1", "org2").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Und,
domain.GenderUnspecified,
"email",
false,
)}, nil
}
if i == 3 {
i++
return []eventstore.Event{org.NewDomainPolicyAddedEvent(ctx, &agg.Aggregate, false, false, false)}, nil
}
i++
return []eventstore.Event{org.NewDomainPolicyAddedEvent(ctx, &agg.Aggregate, true, false, false)}, nil
}
}(),
},
want: Want{
Commands: []eventstore.Command{
org.NewDomainAddedEvent(context.Background(), &agg.Aggregate, "domain"),
org.NewDomainVerifiedEvent(context.Background(), &agg.Aggregate, "domain"),
user.NewDomainClaimedEvent(context.Background(), &user.NewAggregate("userID1", "org2").Aggregate, "newID@temporary.domain", "username", false),
},
},
},
{
name: "already verified",
args: args{
a: agg,
domain: "domain",
a: agg,
domain: "domain",
claimedUserIDs: []string{"userID1"},
filter: func(ctx context.Context, queryFactory *eventstore.SearchQueryBuilder) ([]eventstore.Event, error) {
return []eventstore.Event{
org.NewDomainAddedEvent(ctx, &agg.Aggregate, "domain"),
@ -102,7 +131,13 @@ func TestAddDomain(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
AssertValidation(t, AddOrgDomain(tt.args.a, tt.args.domain), tt.args.filter, tt.want)
AssertValidation(
t,
authz.WithRequestedDomain(context.Background(), "domain"),
(&Commands{idGenerator: tt.args.idGenerator}).prepareAddOrgDomain(tt.args.a, tt.args.domain, tt.args.claimedUserIDs),
tt.args.filter,
tt.want,
)
})
}
}
@ -143,7 +178,7 @@ func TestVerifyDomain(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
AssertValidation(t, VerifyOrgDomain(tt.args.a, tt.args.domain), nil, tt.want)
AssertValidation(t, context.Background(), VerifyOrgDomain(tt.args.a, tt.args.domain), nil, tt.want)
})
}
}
@ -238,7 +273,7 @@ func TestSetDomainPrimary(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
AssertValidation(t, SetPrimaryOrgDomain(tt.args.a, tt.args.domain), tt.args.filter, tt.want)
AssertValidation(t, context.Background(), SetPrimaryOrgDomain(tt.args.a, tt.args.domain), tt.args.filter, tt.want)
})
}
}
@ -249,11 +284,12 @@ func TestCommandSide_AddOrgDomain(t *testing.T) {
}
type args struct {
ctx context.Context
domain *domain.OrgDomain
orgID string
domain string
claimedUserIDs []string
}
type res struct {
want *domain.OrgDomain
want *domain.ObjectDetails
err func(error) bool
}
tests := []struct {
@ -270,8 +306,7 @@ func TestCommandSide_AddOrgDomain(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
domain: &domain.OrgDomain{},
ctx: context.Background(),
},
res: res{
err: errors.IsErrorInvalidArgument,
@ -299,13 +334,9 @@ func TestCommandSide_AddOrgDomain(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
domain: &domain.OrgDomain{
ObjectRoot: models.ObjectRoot{
AggregateID: "org1",
},
Domain: "domain.ch",
},
ctx: context.Background(),
orgID: "org1",
domain: "domain.ch",
},
res: res{
err: errors.IsErrorAlreadyExists,
@ -324,6 +355,16 @@ func TestCommandSide_AddOrgDomain(t *testing.T) {
),
),
),
expectFilter(
eventFromEventPusher(
org.NewDomainPolicyAddedEvent(context.Background(),
&org.NewAggregate("org1").Aggregate,
true,
true,
true,
),
),
),
expectPush(
[]*repository.Event{
eventFromEventPusher(org.NewDomainAddedEvent(context.Background(),
@ -335,21 +376,13 @@ func TestCommandSide_AddOrgDomain(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
domain: &domain.OrgDomain{
ObjectRoot: models.ObjectRoot{
AggregateID: "org1",
},
Domain: "domain.ch",
},
ctx: context.Background(),
orgID: "org1",
domain: "domain.ch",
},
res: res{
want: &domain.OrgDomain{
ObjectRoot: models.ObjectRoot{
AggregateID: "org1",
ResourceOwner: "org1",
},
Domain: "domain.ch",
want: &domain.ObjectDetails{
ResourceOwner: "org1",
},
},
},
@ -359,7 +392,7 @@ func TestCommandSide_AddOrgDomain(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore,
}
got, err := r.AddOrgDomain(tt.args.ctx, tt.args.domain, tt.args.claimedUserIDs)
got, err := r.AddOrgDomain(tt.args.ctx, tt.args.orgID, tt.args.domain, tt.args.claimedUserIDs)
if tt.res.err == nil {
assert.NoError(t, err)
}

View File

@ -165,7 +165,7 @@ func TestAddMember(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
AssertValidation(t, (&Commands{zitadelRoles: tt.args.zitadelRoles}).AddOrgMemberCommand(tt.args.a, tt.args.userID, tt.args.roles...), tt.args.filter, tt.want)
AssertValidation(t, context.Background(), (&Commands{zitadelRoles: tt.args.zitadelRoles}).AddOrgMemberCommand(tt.args.a, tt.args.userID, tt.args.roles...), tt.args.filter, tt.want)
})
}
}

View File

@ -62,7 +62,7 @@ func TestAddOrg(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
AssertValidation(t, AddOrgCommand(authz.WithRequestedDomain(context.Background(), "localhost"), tt.args.a, tt.args.name), nil, tt.want)
AssertValidation(t, context.Background(), AddOrgCommand(authz.WithRequestedDomain(context.Background(), "localhost"), tt.args.a, tt.args.name), nil, tt.want)
})
}
}

View File

@ -23,7 +23,7 @@ type CommandVerifier interface {
}
//AssertValidation checks if the validation works as inteded
func AssertValidation(t *testing.T, validation preparation.Validation, filter preparation.FilterToQueryReducer, want Want) {
func AssertValidation(t *testing.T, ctx context.Context, validation preparation.Validation, filter preparation.FilterToQueryReducer, want Want) {
t.Helper()
creates, err := validation()
@ -34,7 +34,7 @@ func AssertValidation(t *testing.T, validation preparation.Validation, filter pr
if err != nil {
return
}
cmds, err := creates(context.Background(), filter)
cmds, err := creates(ctx, filter)
if !errors.Is(err, want.CreateErr) {
t.Errorf("wrong create err = (%[1]T): %[1]v, want (%[2]T): %[2]v", err, want.CreateErr)
return

View File

@ -128,6 +128,7 @@ func TestAddAPIConfig(t *testing.T) {
idGenerator: tt.fields.idGenerator,
}
AssertValidation(t,
context.Background(),
c.AddAPIAppCommand(
&addAPIApp{
AddApp: AddApp{

View File

@ -180,6 +180,7 @@ func TestAddOIDCApp(t *testing.T) {
idGenerator: tt.fields.idGenerator,
}
AssertValidation(t,
context.Background(),
c.AddOIDCAppCommand(
tt.args.app,
tt.args.clientSecretAlg,

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
id_mock "github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/member"
@ -1065,7 +1066,7 @@ func TestAddProject(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
AssertValidation(t, AddProjectCommand(tt.args.a, tt.args.name, tt.args.owner, false, false, false, tt.args.privateLabelingSetting), nil, tt.want)
AssertValidation(t, context.Background(), AddProjectCommand(tt.args.a, tt.args.name, tt.args.owner, false, false, false, tt.args.privateLabelingSetting), nil, tt.want)
})
}
}

View File

@ -339,36 +339,31 @@ func (c *Commands) userDomainClaimed(ctx context.Context, userID string) (events
}, changedUserGrant, nil
}
func (c *Commands) prepareUserDomainClaimed(userID string) preparation.Validation {
return func() (_ preparation.CreateCommands, err error) {
return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) {
userWriteModel, err := userWriteModelByID(ctx, filter, userID, "")
if err != nil {
return nil, err
}
if !userWriteModel.UserState.Exists() {
return nil, caos_errs.ThrowNotFound(nil, "COMMAND-ii9K0", "Errors.User.NotFound")
}
domainPolicy, err := domainPolicyWriteModel(ctx, filter)
if err != nil {
return nil, err
}
userAgg := UserAggregateFromWriteModel(&userWriteModel.WriteModel)
id, err := c.idGenerator.Next()
if err != nil {
return nil, err
}
return []eventstore.Command{user.NewDomainClaimedEvent(
ctx,
userAgg,
fmt.Sprintf("%s@temporary.%s", id, authz.GetInstance(ctx).RequestedDomain()),
userWriteModel.UserName,
domainPolicy.UserLoginMustBeDomain),
}, nil
}, nil
func (c *Commands) prepareUserDomainClaimed(ctx context.Context, filter preparation.FilterToQueryReducer, userID string) (*user.DomainClaimedEvent, error) {
userWriteModel, err := userWriteModelByID(ctx, filter, userID, "")
if err != nil {
return nil, err
}
if !userWriteModel.UserState.Exists() {
return nil, caos_errs.ThrowNotFound(nil, "COMMAND-ii9K0", "Errors.User.NotFound")
}
domainPolicy, err := domainPolicyWriteModel(ctx, filter)
if err != nil {
return nil, err
}
userAgg := UserAggregateFromWriteModel(&userWriteModel.WriteModel)
id, err := c.idGenerator.Next()
if err != nil {
return nil, err
}
return user.NewDomainClaimedEvent(
ctx,
userAgg,
fmt.Sprintf("%s@temporary.%s", id, authz.GetInstance(ctx).RequestedDomain()),
userWriteModel.UserName,
domainPolicy.UserLoginMustBeDomain), nil
}
func (c *Commands) UserDomainClaimedSent(ctx context.Context, orgID, userID string) (err error) {

View File

@ -2948,7 +2948,7 @@ func TestAddHumanCommand(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
AssertValidation(t, AddHumanCommand(tt.args.a, tt.args.human, tt.args.passwordAlg, tt.args.codeAlg), tt.args.filter, tt.want)
AssertValidation(t, context.Background(), AddHumanCommand(tt.args.a, tt.args.human, tt.args.passwordAlg, tt.args.codeAlg), tt.args.filter, tt.want)
})
}
}

View File

@ -200,6 +200,10 @@ func (p *userProjection) reducers() []handler.AggregateReducer {
Event: user.UserUserNameChangedType,
Reduce: p.reduceUserNameChanged,
},
{
Event: user.UserDomainClaimedType,
Reduce: p.reduceDomainClaimed,
},
{
Event: user.HumanProfileChangedType,
Reduce: p.reduceHumanProfileChanged,
@ -518,6 +522,26 @@ func (p *userProjection) reduceUserNameChanged(event eventstore.Event) (*handler
), nil
}
func (p *userProjection) reduceDomainClaimed(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*user.DomainClaimedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-ASwf3", "reduce.wrong.event.type %s", user.UserDomainClaimedType)
}
return crdb.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(UserChangeDateCol, e.CreationDate()),
handler.NewCol(UserUsernameCol, e.UserName),
handler.NewCol(UserSequenceCol, e.Sequence()),
},
[]handler.Condition{
handler.NewCond(UserIDCol, e.Aggregate().ID),
handler.NewCond(UserInstanceIDCol, e.Aggregate().InstanceID),
},
), nil
}
func (p *userProjection) reduceHumanProfileChanged(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*user.HumanProfileChangedEvent)
if !ok {

View File

@ -733,6 +733,39 @@ func TestUserProjection_reduces(t *testing.T) {
},
},
},
{
name: "reduceDomainClaimed",
args: args{
event: getEvent(testEvent(
repository.EventType(user.UserDomainClaimedType),
user.AggregateType,
[]byte(`{
"username": "id@temporary.domain"
}`),
), user.DomainClaimedEventMapper),
},
reduce: (&userProjection{}).reduceDomainClaimed,
want: wantReduce{
aggregateType: user.AggregateType,
sequence: 15,
previousSequence: 10,
projection: UserTable,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.users2 SET (change_date, username, sequence) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{
anyArg{},
"id@temporary.domain",
uint64(15),
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "reduceHumanProfileChanged",
args: args{