fix: org unique check (#5033)

- all verified of domains are checked
- domains are checked case insensitive
- name is checked case insensitive
This commit is contained in:
Silvan 2023-01-16 10:55:19 +01:00 committed by GitHub
parent e7a97b1f3b
commit 74c1c39207
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 130 additions and 14 deletions

View File

@ -9,7 +9,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
domain_pkg "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
@ -68,7 +68,7 @@ type Org struct {
CreationDate time.Time
ChangeDate time.Time
ResourceOwner string
State domain.OrgState
State domain_pkg.OrgState
Sequence uint64
Name string
@ -156,15 +156,19 @@ func (q *Queries) IsOrgUnique(ctx context.Context, name, domain string) (isUniqu
sq.And{
sq.Eq{
OrgColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
OrgDomainIsVerifiedCol.identifier(): true,
},
sq.Or{
sq.Eq{
OrgColumnDomain.identifier(): domain,
sq.ILike{
OrgDomainDomainCol.identifier(): domain,
},
sq.Eq{
sq.ILike{
OrgColumnName.identifier(): name,
},
},
sq.NotEq{
OrgColumnState.identifier(): domain_pkg.OrgStateRemoved,
},
}).ToSql()
if err != nil {
return false, errors.ThrowInternal(err, "QUERY-Dgbe2", "Errors.Query.SQLStatement")
@ -346,7 +350,9 @@ func prepareOrgWithDomainsQuery() (sq.SelectBuilder, func(*sql.Row) (*Org, error
func prepareOrgUniqueQuery() (sq.SelectBuilder, func(*sql.Row) (bool, error)) {
return sq.Select(uniqueColumn.identifier()).
From(orgsTable.identifier()).PlaceholderFormat(sq.Dollar),
From(orgsTable.identifier()).
LeftJoin(join(OrgDomainOrgIDCol, OrgColumnID)).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (isUnique bool, err error) {
err = row.Scan(&isUnique)
if err != nil {

View File

@ -1,15 +1,23 @@
package query
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
errs "errors"
"fmt"
"regexp"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/zitadel/zitadel/internal/domain"
errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/errors"
)
var (
orgUniqueQuery = "SELECT COUNT(*) = 0 FROM projections.orgs LEFT JOIN projections.org_domains2 ON projections.orgs.id = projections.org_domains2.org_id AND projections.orgs.instance_id = projections.org_domains2.instance_id WHERE (projections.org_domains2.is_verified = $1 AND projections.orgs.instance_id = $2 AND (projections.org_domains2.domain ILIKE $3 OR projections.orgs.name ILIKE $4) AND projections.orgs.org_state <> $5)"
orgUniqueCols = []string{"is_unique"}
)
func Test_OrgPrepares(t *testing.T) {
@ -198,7 +206,7 @@ func Test_OrgPrepares(t *testing.T) {
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
if !errs.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
@ -224,7 +232,7 @@ func Test_OrgPrepares(t *testing.T) {
nil,
),
err: func(err error) (error, bool) {
if !errs.IsNotFound(err) {
if !errors.IsNotFound(err) {
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
}
return nil, true
@ -296,7 +304,7 @@ func Test_OrgPrepares(t *testing.T) {
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
if !errs.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
@ -315,7 +323,7 @@ func Test_OrgPrepares(t *testing.T) {
nil,
),
err: func(err error) (error, bool) {
if !errs.IsInternal(err) {
if !errors.IsInternal(err) {
return fmt.Errorf("err should be zitadel.Internal got: %w", err), false
}
return nil, true
@ -350,7 +358,7 @@ func Test_OrgPrepares(t *testing.T) {
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
if !errs.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
@ -365,3 +373,105 @@ func Test_OrgPrepares(t *testing.T) {
})
}
}
func TestQueries_IsOrgUnique(t *testing.T) {
type args struct {
name string
domain string
}
type want struct {
err func(error) bool
sqlExpectations sqlExpectation
isUnique bool
}
tests := []struct {
name string
args args
want want
}{
{
name: "existing domain",
args: args{
domain: "exists",
name: "",
},
want: want{
isUnique: false,
sqlExpectations: mockQueries(orgUniqueQuery, orgUniqueCols, [][]driver.Value{{false}}, true, "", "exists", "", domain.OrgStateRemoved),
},
},
{
name: "existing name",
args: args{
domain: "",
name: "exists",
},
want: want{
isUnique: false,
sqlExpectations: mockQueries(orgUniqueQuery, orgUniqueCols, [][]driver.Value{{false}}, true, "", "", "exists", domain.OrgStateRemoved),
},
},
{
name: "existing name and domain",
args: args{
domain: "exists",
name: "exists",
},
want: want{
isUnique: false,
sqlExpectations: mockQueries(orgUniqueQuery, orgUniqueCols, [][]driver.Value{{false}}, true, "", "exists", "exists", domain.OrgStateRemoved),
},
},
{
name: "not existing",
args: args{
domain: "not-exists",
name: "not-exists",
},
want: want{
isUnique: true,
sqlExpectations: mockQueries(orgUniqueQuery, orgUniqueCols, [][]driver.Value{{true}}, true, "", "not-exists", "not-exists", domain.OrgStateRemoved),
},
},
{
name: "no arg",
args: args{
domain: "",
name: "",
},
want: want{
isUnique: false,
err: errors.IsErrorInvalidArgument,
},
},
}
for _, tt := range tests {
client, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatalf("unable to mock db: %v", err)
}
if tt.want.sqlExpectations != nil {
tt.want.sqlExpectations(mock)
}
t.Run(tt.name, func(t *testing.T) {
q := &Queries{
client: client,
}
gotIsUnique, err := q.IsOrgUnique(context.Background(), tt.args.name, tt.args.domain)
if (tt.want.err == nil && err != nil) || (err != nil && tt.want.err != nil && !tt.want.err(err)) {
t.Errorf("Queries.IsOrgUnique() unexpected error = %v", err)
return
}
if gotIsUnique != tt.want.isUnique {
t.Errorf("Queries.IsOrgUnique() = %v, want %v", gotIsUnique, tt.want.isUnique)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("expectation was met: %v", err)
}
})
}
}