From 74c1c39207ad2782b401d3e61c0037e0632af13f Mon Sep 17 00:00:00 2001 From: Silvan Date: Mon, 16 Jan 2023 10:55:19 +0100 Subject: [PATCH] fix: org unique check (#5033) - all verified of domains are checked - domains are checked case insensitive - name is checked case insensitive --- internal/query/org.go | 20 +++--- internal/query/org_test.go | 124 ++++++++++++++++++++++++++++++++++--- 2 files changed, 130 insertions(+), 14 deletions(-) diff --git a/internal/query/org.go b/internal/query/org.go index d81c8c1d62..5b91c4cf68 100644 --- a/internal/query/org.go +++ b/internal/query/org.go @@ -9,7 +9,7 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/domain" + domain_pkg "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -68,7 +68,7 @@ type Org struct { CreationDate time.Time ChangeDate time.Time ResourceOwner string - State domain.OrgState + State domain_pkg.OrgState Sequence uint64 Name string @@ -155,16 +155,20 @@ func (q *Queries) IsOrgUnique(ctx context.Context, name, domain string) (isUniqu stmt, args, err := query.Where( sq.And{ sq.Eq{ - OrgColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), + OrgColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), + OrgDomainIsVerifiedCol.identifier(): true, }, sq.Or{ - sq.Eq{ - OrgColumnDomain.identifier(): domain, + sq.ILike{ + OrgDomainDomainCol.identifier(): domain, }, - sq.Eq{ + sq.ILike{ OrgColumnName.identifier(): name, }, }, + sq.NotEq{ + OrgColumnState.identifier(): domain_pkg.OrgStateRemoved, + }, }).ToSql() if err != nil { return false, errors.ThrowInternal(err, "QUERY-Dgbe2", "Errors.Query.SQLStatement") @@ -346,7 +350,9 @@ func prepareOrgWithDomainsQuery() (sq.SelectBuilder, func(*sql.Row) (*Org, error func prepareOrgUniqueQuery() (sq.SelectBuilder, func(*sql.Row) (bool, error)) { return sq.Select(uniqueColumn.identifier()). - From(orgsTable.identifier()).PlaceholderFormat(sq.Dollar), + From(orgsTable.identifier()). + LeftJoin(join(OrgDomainOrgIDCol, OrgColumnID)). + PlaceholderFormat(sq.Dollar), func(row *sql.Row) (isUnique bool, err error) { err = row.Scan(&isUnique) if err != nil { diff --git a/internal/query/org_test.go b/internal/query/org_test.go index ce6f91a23b..93acdc1555 100644 --- a/internal/query/org_test.go +++ b/internal/query/org_test.go @@ -1,15 +1,23 @@ package query import ( + "context" "database/sql" "database/sql/driver" - "errors" + errs "errors" "fmt" "regexp" "testing" + "github.com/DATA-DOG/go-sqlmock" + "github.com/zitadel/zitadel/internal/domain" - errs "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/errors" +) + +var ( + orgUniqueQuery = "SELECT COUNT(*) = 0 FROM projections.orgs LEFT JOIN projections.org_domains2 ON projections.orgs.id = projections.org_domains2.org_id AND projections.orgs.instance_id = projections.org_domains2.instance_id WHERE (projections.org_domains2.is_verified = $1 AND projections.orgs.instance_id = $2 AND (projections.org_domains2.domain ILIKE $3 OR projections.orgs.name ILIKE $4) AND projections.orgs.org_state <> $5)" + orgUniqueCols = []string{"is_unique"} ) func Test_OrgPrepares(t *testing.T) { @@ -198,7 +206,7 @@ func Test_OrgPrepares(t *testing.T) { sql.ErrConnDone, ), err: func(err error) (error, bool) { - if !errors.Is(err, sql.ErrConnDone) { + if !errs.Is(err, sql.ErrConnDone) { return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false } return nil, true @@ -224,7 +232,7 @@ func Test_OrgPrepares(t *testing.T) { nil, ), err: func(err error) (error, bool) { - if !errs.IsNotFound(err) { + if !errors.IsNotFound(err) { return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false } return nil, true @@ -296,7 +304,7 @@ func Test_OrgPrepares(t *testing.T) { sql.ErrConnDone, ), err: func(err error) (error, bool) { - if !errors.Is(err, sql.ErrConnDone) { + if !errs.Is(err, sql.ErrConnDone) { return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false } return nil, true @@ -315,7 +323,7 @@ func Test_OrgPrepares(t *testing.T) { nil, ), err: func(err error) (error, bool) { - if !errs.IsInternal(err) { + if !errors.IsInternal(err) { return fmt.Errorf("err should be zitadel.Internal got: %w", err), false } return nil, true @@ -350,7 +358,7 @@ func Test_OrgPrepares(t *testing.T) { sql.ErrConnDone, ), err: func(err error) (error, bool) { - if !errors.Is(err, sql.ErrConnDone) { + if !errs.Is(err, sql.ErrConnDone) { return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false } return nil, true @@ -365,3 +373,105 @@ func Test_OrgPrepares(t *testing.T) { }) } } + +func TestQueries_IsOrgUnique(t *testing.T) { + type args struct { + name string + domain string + } + type want struct { + err func(error) bool + sqlExpectations sqlExpectation + isUnique bool + } + tests := []struct { + name string + args args + want want + }{ + { + name: "existing domain", + args: args{ + domain: "exists", + name: "", + }, + want: want{ + isUnique: false, + sqlExpectations: mockQueries(orgUniqueQuery, orgUniqueCols, [][]driver.Value{{false}}, true, "", "exists", "", domain.OrgStateRemoved), + }, + }, + { + name: "existing name", + args: args{ + domain: "", + name: "exists", + }, + want: want{ + isUnique: false, + sqlExpectations: mockQueries(orgUniqueQuery, orgUniqueCols, [][]driver.Value{{false}}, true, "", "", "exists", domain.OrgStateRemoved), + }, + }, + { + name: "existing name and domain", + args: args{ + domain: "exists", + name: "exists", + }, + want: want{ + isUnique: false, + sqlExpectations: mockQueries(orgUniqueQuery, orgUniqueCols, [][]driver.Value{{false}}, true, "", "exists", "exists", domain.OrgStateRemoved), + }, + }, + { + name: "not existing", + args: args{ + domain: "not-exists", + name: "not-exists", + }, + want: want{ + isUnique: true, + sqlExpectations: mockQueries(orgUniqueQuery, orgUniqueCols, [][]driver.Value{{true}}, true, "", "not-exists", "not-exists", domain.OrgStateRemoved), + }, + }, + { + name: "no arg", + args: args{ + domain: "", + name: "", + }, + want: want{ + isUnique: false, + err: errors.IsErrorInvalidArgument, + }, + }, + } + for _, tt := range tests { + client, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + if err != nil { + t.Fatalf("unable to mock db: %v", err) + } + if tt.want.sqlExpectations != nil { + tt.want.sqlExpectations(mock) + } + + t.Run(tt.name, func(t *testing.T) { + q := &Queries{ + client: client, + } + + gotIsUnique, err := q.IsOrgUnique(context.Background(), tt.args.name, tt.args.domain) + if (tt.want.err == nil && err != nil) || (err != nil && tt.want.err != nil && !tt.want.err(err)) { + t.Errorf("Queries.IsOrgUnique() unexpected error = %v", err) + return + } + if gotIsUnique != tt.want.isUnique { + t.Errorf("Queries.IsOrgUnique() = %v, want %v", gotIsUnique, tt.want.isUnique) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectation was met: %v", err) + } + }) + + } +}