feat: Instance commands (#3385)

* fix: add events for domain

* fix: add/remove domain command side

* fix: add/remove domain command side

* fix: add/remove domain query side

* fix: create instance

* fix: merge v2

* fix: instance domain

* fix: instance domain

* fix: instance domain

* fix: instance domain

* fix: remove domain.IAMID from writemodels

* fix: remove domain.IAMID from writemodels

* fix: remove domain.IAMID from writemodels

* fix: remove domain.IAMID from writemodels

* fix: remove domain.IAMID from writemodels

* fix: remove domain.IAMID from writemodels

* fix: remove domain.IAMID from writemodels

* fix: remove domain.IAMID from writemodels

* fix: remove domain.IAMID from writemodels

* fix: remove domain.IAMID from api

* fix: remove domain.IAMID

* fix: remove domain.IAMID

* fix: add instance domain queries

* fix: fix after merge

* Update auth_request.go

* fix keypair

* remove unused code

* feat: read instance id from context

* feat: remove unused code

* feat: use instance id from context

* some fixes

Co-authored-by: Livio Amstutz <livio.a@gmail.com>
This commit is contained in:
Fabi
2022-04-05 07:58:09 +02:00
committed by GitHub
parent 7d6a10015a
commit c740ee5d81
156 changed files with 6360 additions and 3951 deletions

View File

@@ -140,7 +140,7 @@ func (q *Queries) GetDefaultLoginTexts(ctx context.Context, lang string) (*domai
return nil, errors.ThrowInternal(err, "TEXT-M0p4s", "Errors.TranslationFile.ReadError")
}
loginText.IsDefault = true
loginText.AggregateID = domain.IAMID
loginText.AggregateID = authz.GetInstance(ctx).InstanceID()
return loginText, nil
}
@@ -149,7 +149,7 @@ func (q *Queries) GetCustomLoginTexts(ctx context.Context, aggregateID, lang str
if err != nil {
return nil, err
}
return CustomTextsToLoginDomain(aggregateID, lang, texts), err
return CustomTextsToLoginDomain(authz.GetInstance(ctx).InstanceID(), aggregateID, lang, texts), err
}
func (q *Queries) IAMLoginTexts(ctx context.Context, lang string) (*domain.CustomLoginText, error) {
@@ -161,7 +161,7 @@ func (q *Queries) IAMLoginTexts(ctx context.Context, lang string) (*domain.Custo
if err := yaml.Unmarshal(contents, &loginTextMap); err != nil {
return nil, errors.ThrowInternal(err, "QUERY-m0Jf3", "Errors.TranslationFile.ReadError")
}
texts, err := q.CustomTextList(ctx, domain.IAMID, domain.LoginCustomText, lang)
texts, err := q.CustomTextList(ctx, authz.GetInstance(ctx).InstanceID(), domain.LoginCustomText, lang)
if err != nil {
return nil, err
}
@@ -181,7 +181,7 @@ func (q *Queries) IAMLoginTexts(ctx context.Context, lang string) (*domain.Custo
if err := json.Unmarshal(jsonbody, &loginText); err != nil {
return nil, errors.ThrowInternal(err, "QUERY-m93Jf", "Errors.TranslationFile.MergeError")
}
loginText.AggregateID = domain.IAMID
loginText.AggregateID = authz.GetInstance(ctx).InstanceID()
loginText.IsDefault = true
return loginText, nil
}
@@ -276,7 +276,7 @@ func CustomTextToDomain(text *CustomText) *domain.CustomText {
}
}
func CustomTextsToLoginDomain(aggregateID, lang string, texts *CustomTexts) *domain.CustomLoginText {
func CustomTextsToLoginDomain(instanceID, aggregateID, lang string, texts *CustomTexts) *domain.CustomLoginText {
langTag := language.Make(lang)
result := &domain.CustomLoginText{
ObjectRoot: models.ObjectRoot{
@@ -285,7 +285,7 @@ func CustomTextsToLoginDomain(aggregateID, lang string, texts *CustomTexts) *dom
Language: langTag,
}
if len(texts.CustomTexts) == 0 {
result.AggregateID = domain.IAMID
result.AggregateID = instanceID
result.IsDefault = true
}
for _, text := range texts.CustomTexts {

View File

@@ -170,7 +170,7 @@ func (q *Queries) FeaturesByOrgID(ctx context.Context, orgID string) (*Features,
FeatureColumnAggregateID.identifier(): orgID,
},
sq.Eq{
FeatureColumnAggregateID.identifier(): domain.IAMID,
FeatureColumnAggregateID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
}).
@@ -187,7 +187,7 @@ func (q *Queries) FeaturesByOrgID(ctx context.Context, orgID string) (*Features,
func (q *Queries) DefaultFeatures(ctx context.Context) (*Features, error) {
query, scan := prepareFeaturesQuery()
stmt, args, err := query.Where(sq.Eq{
FeatureColumnAggregateID.identifier(): domain.IAMID,
FeatureColumnAggregateID.identifier(): authz.GetInstance(ctx).InstanceID(),
FeatureColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).ToSql()
if err != nil {

View File

@@ -193,7 +193,7 @@ func (q *Queries) IDPByIDAndResourceOwner(ctx context.Context, id, resourceOwner
IDPResourceOwnerCol.identifier(): resourceOwner,
},
sq.Eq{
IDPResourceOwnerCol.identifier(): domain.IAMID,
IDPResourceOwnerCol.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
},

View File

@@ -0,0 +1,148 @@
package query
import (
"context"
"database/sql"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/caos/zitadel/internal/api/authz"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/query/projection"
)
type InstanceDomain struct {
CreationDate time.Time
ChangeDate time.Time
Sequence uint64
Domain string
InstanceID string
IsGenerated bool
}
type InstanceDomains struct {
SearchResponse
Domains []*InstanceDomain
}
type InstanceDomainSearchQueries struct {
SearchRequest
Queries []SearchQuery
}
func (q *InstanceDomainSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
query = q.SearchRequest.toQuery(query)
for _, q := range q.Queries {
query = q.toQuery(query)
}
return query
}
func NewInstanceDomainDomainSearchQuery(method TextComparison, value string) (SearchQuery, error) {
return NewTextQuery(InstanceDomainDomainCol, value, method)
}
func NewInstanceDomainInstanceIDSearchQuery(value string) (SearchQuery, error) {
return NewTextQuery(InstanceDomainInstanceIDCol, value, TextEquals)
}
func NewInstanceDomainGeneratedSearchQuery(verified bool) (SearchQuery, error) {
return NewBoolQuery(InstanceDomainIsGeneratedCol, verified)
}
func (q *Queries) SearchInstanceDomains(ctx context.Context, queries *InstanceDomainSearchQueries) (domains *InstanceDomains, err error) {
query, scan := prepareInstanceDomainsQuery()
stmt, args, err := queries.toQuery(query).
Where(sq.Eq{
InstanceDomainInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
}).ToSql()
if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-inlsF", "Errors.Query.SQLStatement")
}
rows, err := q.client.QueryContext(ctx, stmt, args...)
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-38Fni", "Errors.Internal")
}
domains, err = scan(rows)
if err != nil {
return nil, err
}
domains.LatestSequence, err = q.latestSequence(ctx, instanceDomainsTable)
return domains, err
}
func prepareInstanceDomainsQuery() (sq.SelectBuilder, func(*sql.Rows) (*InstanceDomains, error)) {
return sq.Select(
InstanceDomainCreationDateCol.identifier(),
InstanceDomainChangeDateCol.identifier(),
InstanceDomainSequenceCol.identifier(),
InstanceDomainDomainCol.identifier(),
InstanceDomainInstanceIDCol.identifier(),
InstanceDomainIsGeneratedCol.identifier(),
countColumn.identifier(),
).From(instanceDomainsTable.identifier()).PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*InstanceDomains, error) {
domains := make([]*InstanceDomain, 0)
var count uint64
for rows.Next() {
domain := new(InstanceDomain)
err := rows.Scan(
&domain.CreationDate,
&domain.ChangeDate,
&domain.Sequence,
&domain.Domain,
&domain.InstanceID,
&domain.IsGenerated,
&count,
)
if err != nil {
return nil, err
}
domains = append(domains, domain)
}
if err := rows.Close(); err != nil {
return nil, errors.ThrowInternal(err, "QUERY-8nlWW", "Errors.Query.CloseRows")
}
return &InstanceDomains{
Domains: domains,
SearchResponse: SearchResponse{
Count: count,
},
}, nil
}
}
var (
instanceDomainsTable = table{
name: projection.InstanceDomainTable,
}
InstanceDomainCreationDateCol = Column{
name: projection.InstanceDomainCreationDateCol,
table: instanceDomainsTable,
}
InstanceDomainChangeDateCol = Column{
name: projection.InstanceDomainChangeDateCol,
table: instanceDomainsTable,
}
InstanceDomainSequenceCol = Column{
name: projection.InstanceDomainSequenceCol,
table: instanceDomainsTable,
}
InstanceDomainDomainCol = Column{
name: projection.InstanceDomainDomainCol,
table: instanceDomainsTable,
}
InstanceDomainInstanceIDCol = Column{
name: projection.InstanceDomainInstanceIDCol,
table: instanceDomainsTable,
}
InstanceDomainIsGeneratedCol = Column{
name: projection.InstanceDomainIsGeneratedCol,
table: instanceDomainsTable,
}
)

View File

@@ -0,0 +1,188 @@
package query
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
)
func Test_InstanceDomainPrepares(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareDomainsQuery no result",
prepare: prepareInstanceDomainsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.instance_domains.creation_date,`+
` projections.instance_domains.change_date,`+
` projections.instance_domains.sequence,`+
` projections.instance_domains.domain,`+
` projections.instance_domains.instance_id,`+
` projections.instance_domains.is_generated,`+
` COUNT(*) OVER ()`+
` FROM projections.instance_domains`),
nil,
nil,
),
},
object: &InstanceDomains{Domains: []*InstanceDomain{}},
},
{
name: "prepareDomainsQuery one result",
prepare: prepareInstanceDomainsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.instance_domains.creation_date,`+
` projections.instance_domains.change_date,`+
` projections.instance_domains.sequence,`+
` projections.instance_domains.domain,`+
` projections.instance_domains.instance_id,`+
` projections.instance_domains.is_generated,`+
` COUNT(*) OVER ()`+
` FROM projections.instance_domains`),
[]string{
"creation_date",
"change_date",
"sequence",
"domain",
"instance_id",
"is_generated",
"count",
},
[][]driver.Value{
{
testNow,
testNow,
uint64(20211109),
"zitadel.ch",
"inst-id",
true,
},
},
),
},
object: &InstanceDomains{
SearchResponse: SearchResponse{
Count: 1,
},
Domains: []*InstanceDomain{
{
CreationDate: testNow,
ChangeDate: testNow,
Sequence: 20211109,
Domain: "zitadel.ch",
InstanceID: "inst-id",
IsGenerated: true,
},
},
},
},
{
name: "prepareDomainsQuery multiple result",
prepare: prepareInstanceDomainsQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(`SELECT projections.instance_domains.creation_date,`+
` projections.instance_domains.change_date,`+
` projections.instance_domains.sequence,`+
` projections.instance_domains.domain,`+
` projections.instance_domains.instance_id,`+
` projections.instance_domains.is_generated,`+
` COUNT(*) OVER ()`+
` FROM projections.instance_domains`),
[]string{
"creation_date",
"change_date",
"sequence",
"domain",
"instance_id",
"is_generated",
"count",
},
[][]driver.Value{
{
testNow,
testNow,
uint64(20211109),
"zitadel.ch",
"inst-id",
true,
},
{
testNow,
testNow,
uint64(20211109),
"zitadel.com",
"inst-id",
false,
},
},
),
},
object: &InstanceDomains{
SearchResponse: SearchResponse{
Count: 2,
},
Domains: []*InstanceDomain{
{
CreationDate: testNow,
ChangeDate: testNow,
Sequence: 20211109,
Domain: "zitadel.ch",
InstanceID: "inst-id",
IsGenerated: true,
},
{
CreationDate: testNow,
ChangeDate: testNow,
Sequence: 20211109,
Domain: "zitadel.com",
InstanceID: "inst-id",
IsGenerated: false,
},
},
},
},
{
name: "prepareDomainsQuery sql err",
prepare: prepareInstanceDomainsQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(`SELECT projections.instance_domains.creation_date,`+
` projections.instance_domains.change_date,`+
` projections.instance_domains.sequence,`+
` projections.instance_domains.domain,`+
` projections.instance_domains.instance_id,`+
` projections.instance_domains.is_generated,`+
` COUNT(*) OVER ()`+
` FROM projections.instance_domains`),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
})
}
}

View File

@@ -49,7 +49,7 @@ func (q *Queries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string) (*La
LabelPolicyColID.identifier(): orgID,
},
sq.Eq{
LabelPolicyColID.identifier(): domain.IAMID,
LabelPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
sq.Eq{
@@ -76,7 +76,7 @@ func (q *Queries) PreviewLabelPolicyByOrg(ctx context.Context, orgID string) (*L
LabelPolicyColID.identifier(): orgID,
},
sq.Eq{
LabelPolicyColID.identifier(): domain.IAMID,
LabelPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
sq.Eq{
@@ -97,7 +97,7 @@ func (q *Queries) PreviewLabelPolicyByOrg(ctx context.Context, orgID string) (*L
func (q *Queries) DefaultActiveLabelPolicy(ctx context.Context) (*LabelPolicy, error) {
stmt, scan := prepareLabelPolicyQuery()
query, args, err := stmt.Where(sq.Eq{
LabelPolicyColID.identifier(): domain.IAMID,
LabelPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(),
LabelPolicyColState.identifier(): domain.LabelPolicyStateActive,
LabelPolicyColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).
@@ -114,7 +114,7 @@ func (q *Queries) DefaultActiveLabelPolicy(ctx context.Context) (*LabelPolicy, e
func (q *Queries) DefaultPreviewLabelPolicy(ctx context.Context) (*LabelPolicy, error) {
stmt, scan := prepareLabelPolicyQuery()
query, args, err := stmt.Where(sq.Eq{
LabelPolicyColID.identifier(): domain.IAMID,
LabelPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(),
LabelPolicyColState.identifier(): domain.LabelPolicyStatePreview,
LabelPolicyColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).

View File

@@ -87,7 +87,7 @@ func (q *Queries) LockoutPolicyByOrg(ctx context.Context, orgID string) (*Lockou
LockoutColID.identifier(): orgID,
},
sq.Eq{
LockoutColID.identifier(): domain.IAMID,
LockoutColID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
}).
@@ -104,7 +104,7 @@ func (q *Queries) LockoutPolicyByOrg(ctx context.Context, orgID string) (*Lockou
func (q *Queries) DefaultLockoutPolicy(ctx context.Context) (*LockoutPolicy, error) {
stmt, scan := prepareLockoutPolicyQuery()
query, args, err := stmt.Where(sq.Eq{
LockoutColID.identifier(): domain.IAMID,
LockoutColID.identifier(): authz.GetInstance(ctx).InstanceID(),
LockoutColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).
OrderBy(LockoutColIsDefault.identifier()).

View File

@@ -141,7 +141,7 @@ func (q *Queries) LoginPolicyByID(ctx context.Context, orgID string) (*LoginPoli
LoginPolicyColumnOrgID.identifier(): orgID,
},
sq.Eq{
LoginPolicyColumnOrgID.identifier(): domain.IAMID,
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
}).
@@ -158,7 +158,7 @@ func (q *Queries) LoginPolicyByID(ctx context.Context, orgID string) (*LoginPoli
func (q *Queries) DefaultLoginPolicy(ctx context.Context) (*LoginPolicy, error) {
query, scan := prepareLoginPolicyQuery()
stmt, args, err := query.Where(sq.Eq{
LoginPolicyColumnOrgID.identifier(): domain.IAMID,
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).OrderBy(LoginPolicyColumnIsDefault.identifier()).ToSql()
if err != nil {
@@ -181,7 +181,7 @@ func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (*Second
LoginPolicyColumnOrgID.identifier(): orgID,
},
sq.Eq{
LoginPolicyColumnOrgID.identifier(): domain.IAMID,
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
}).
@@ -203,7 +203,7 @@ func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (*Second
func (q *Queries) DefaultSecondFactors(ctx context.Context) (*SecondFactors, error) {
query, scan := prepareLoginPolicy2FAsQuery()
stmt, args, err := query.Where(sq.Eq{
LoginPolicyColumnOrgID.identifier(): domain.IAMID,
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).OrderBy(LoginPolicyColumnIsDefault.identifier()).ToSql()
if err != nil {
@@ -231,7 +231,7 @@ func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (*MultiFa
LoginPolicyColumnOrgID.identifier(): orgID,
},
sq.Eq{
LoginPolicyColumnOrgID.identifier(): domain.IAMID,
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
}).
@@ -253,7 +253,7 @@ func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (*MultiFa
func (q *Queries) DefaultMultiFactors(ctx context.Context) (*MultiFactors, error) {
query, scan := prepareLoginPolicyMFAsQuery()
stmt, args, err := query.Where(sq.Eq{
LoginPolicyColumnOrgID.identifier(): domain.IAMID,
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).OrderBy(LoginPolicyColumnIsDefault.identifier()).ToSql()
if err != nil {

View File

@@ -75,7 +75,7 @@ func (q *Queries) MailTemplateByOrg(ctx context.Context, orgID string) (*MailTem
MailTemplateColAggregateID.identifier(): orgID,
},
sq.Eq{
MailTemplateColAggregateID.identifier(): domain.IAMID,
MailTemplateColAggregateID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
}).
@@ -92,7 +92,7 @@ func (q *Queries) MailTemplateByOrg(ctx context.Context, orgID string) (*MailTem
func (q *Queries) DefaultMailTemplate(ctx context.Context) (*MailTemplate, error) {
stmt, scan := prepareMailTemplateQuery()
query, args, err := stmt.Where(sq.Eq{
MailTemplateColAggregateID.identifier(): domain.IAMID,
MailTemplateColAggregateID.identifier(): authz.GetInstance(ctx).InstanceID(),
MailTemplateColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).
OrderBy(MailTemplateColIsDefault.identifier()).

View File

@@ -129,7 +129,7 @@ func (q *Queries) MessageTextByOrg(ctx context.Context, orgID string) (*MessageT
MessageTextColAggregateID.identifier(): orgID,
},
sq.Eq{
MessageTextColAggregateID.identifier(): domain.IAMID,
MessageTextColAggregateID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
}).
@@ -146,7 +146,7 @@ func (q *Queries) MessageTextByOrg(ctx context.Context, orgID string) (*MessageT
func (q *Queries) DefaultMessageText(ctx context.Context) (*MessageText, error) {
stmt, scan := prepareMessageTextQuery()
query, args, err := stmt.Where(sq.Eq{
MessageTextColAggregateID.identifier(): domain.IAMID,
MessageTextColAggregateID.identifier(): authz.GetInstance(ctx).InstanceID(),
MessageTextColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).
Limit(1).ToSql()
@@ -203,7 +203,7 @@ func (q *Queries) IAMMessageTextByTypeAndLanguage(ctx context.Context, messageTy
if err := yaml.Unmarshal(contents, &notificationTextMap); err != nil {
return nil, errors.ThrowInternal(err, "QUERY-ekjFF", "Errors.TranslationFile.ReadError")
}
texts, err := q.CustomTextList(ctx, domain.IAMID, messageType, language)
texts, err := q.CustomTextList(ctx, authz.GetInstance(ctx).InstanceID(), messageType, language)
if err != nil {
return nil, err
}
@@ -225,7 +225,7 @@ func (q *Queries) IAMMessageTextByTypeAndLanguage(ctx context.Context, messageTy
}
result := notificationText.GetMessageTextByType(messageType)
result.IsDefault = true
result.AggregateID = domain.IAMID
result.AggregateID = authz.GetInstance(ctx).InstanceID()
return result, nil
}

View File

@@ -82,7 +82,7 @@ func (q *Queries) DomainPolicyByOrg(ctx context.Context, orgID string) (*DomainP
DomainPolicyColID.identifier(): orgID,
},
sq.Eq{
DomainPolicyColID.identifier(): domain.IAMID,
DomainPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
}).
@@ -99,7 +99,7 @@ func (q *Queries) DomainPolicyByOrg(ctx context.Context, orgID string) (*DomainP
func (q *Queries) DefaultDomainPolicy(ctx context.Context) (*DomainPolicy, error) {
stmt, scan := prepareDomainPolicyQuery()
query, args, err := stmt.Where(sq.Eq{
DomainPolicyColID.identifier(): domain.IAMID,
DomainPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(),
DomainPolicyColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).
OrderBy(DomainPolicyColIsDefault.identifier()).

View File

@@ -86,7 +86,7 @@ func (q *Queries) PasswordAgePolicyByOrg(ctx context.Context, orgID string) (*Pa
PasswordAgeColID.identifier(): orgID,
},
sq.Eq{
PasswordAgeColID.identifier(): domain.IAMID,
PasswordAgeColID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
}).
@@ -103,7 +103,7 @@ func (q *Queries) PasswordAgePolicyByOrg(ctx context.Context, orgID string) (*Pa
func (q *Queries) DefaultPasswordAgePolicy(ctx context.Context) (*PasswordAgePolicy, error) {
stmt, scan := preparePasswordAgePolicyQuery()
query, args, err := stmt.Where(sq.Eq{
PasswordAgeColID.identifier(): domain.IAMID,
PasswordAgeColID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).
OrderBy(PasswordAgeColIsDefault.identifier()).
Limit(1).ToSql()

View File

@@ -43,7 +43,7 @@ func (q *Queries) PasswordComplexityPolicyByOrg(ctx context.Context, orgID strin
PasswordComplexityColID.identifier(): orgID,
},
sq.Eq{
PasswordComplexityColID.identifier(): domain.IAMID,
PasswordComplexityColID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
}).
@@ -60,7 +60,7 @@ func (q *Queries) PasswordComplexityPolicyByOrg(ctx context.Context, orgID strin
func (q *Queries) DefaultPasswordComplexityPolicy(ctx context.Context) (*PasswordComplexityPolicy, error) {
stmt, scan := preparePasswordComplexityPolicyQuery()
query, args, err := stmt.Where(sq.Eq{
PasswordComplexityColID.identifier(): domain.IAMID,
PasswordComplexityColID.identifier(): authz.GetInstance(ctx).InstanceID(),
PasswordComplexityColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).
OrderBy(PasswordComplexityColIsDefault.identifier()).

View File

@@ -91,7 +91,7 @@ func (q *Queries) PrivacyPolicyByOrg(ctx context.Context, orgID string) (*Privac
PrivacyColID.identifier(): orgID,
},
sq.Eq{
PrivacyColID.identifier(): domain.IAMID,
PrivacyColID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
},
}).
@@ -108,7 +108,7 @@ func (q *Queries) PrivacyPolicyByOrg(ctx context.Context, orgID string) (*Privac
func (q *Queries) DefaultPrivacyPolicy(ctx context.Context) (*PrivacyPolicy, error) {
stmt, scan := preparePrivacyPolicyQuery()
query, args, err := stmt.Where(sq.Eq{
PrivacyColID.identifier(): domain.IAMID,
PrivacyColID.identifier(): authz.GetInstance(ctx).InstanceID(),
PrivacyColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
}).
OrderBy(PrivacyColIsDefault.identifier()).

View File

@@ -0,0 +1,96 @@
package projection
import (
"context"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/handler"
"github.com/caos/zitadel/internal/eventstore/handler/crdb"
"github.com/caos/zitadel/internal/repository/instance"
)
const (
InstanceDomainTable = "projections.instance_domains"
InstanceDomainInstanceIDCol = "instance_id"
InstanceDomainCreationDateCol = "creation_date"
InstanceDomainChangeDateCol = "change_date"
InstanceDomainSequenceCol = "sequence"
InstanceDomainDomainCol = "domain"
InstanceDomainIsGeneratedCol = "is_generated"
)
type InstanceDomainProjection struct {
crdb.StatementHandler
}
func NewInstanceDomainProjection(ctx context.Context, config crdb.StatementHandlerConfig) *InstanceDomainProjection {
p := new(InstanceDomainProjection)
config.ProjectionName = InstanceDomainTable
config.Reducers = p.reducers()
config.InitCheck = crdb.NewTableCheck(
crdb.NewTable([]*crdb.Column{
crdb.NewColumn(InstanceDomainInstanceIDCol, crdb.ColumnTypeText),
crdb.NewColumn(InstanceDomainCreationDateCol, crdb.ColumnTypeTimestamp),
crdb.NewColumn(InstanceDomainChangeDateCol, crdb.ColumnTypeTimestamp),
crdb.NewColumn(InstanceDomainSequenceCol, crdb.ColumnTypeInt64),
crdb.NewColumn(InstanceDomainDomainCol, crdb.ColumnTypeText),
crdb.NewColumn(InstanceDomainIsGeneratedCol, crdb.ColumnTypeBool),
},
crdb.NewPrimaryKey(InstanceDomainInstanceIDCol, InstanceDomainDomainCol),
),
)
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
return p
}
func (p *InstanceDomainProjection) reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{
{
Aggregate: instance.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: instance.InstanceDomainAddedEventType,
Reduce: p.reduceDomainAdded,
},
{
Event: instance.InstanceDomainRemovedEventType,
Reduce: p.reduceDomainRemoved,
},
},
},
}
}
func (p *InstanceDomainProjection) reduceDomainAdded(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*instance.DomainAddedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "PROJE-38nNf", "reduce.wrong.event.type %s", instance.InstanceDomainAddedEventType)
}
return crdb.NewCreateStatement(
e,
[]handler.Column{
handler.NewCol(InstanceDomainCreationDateCol, e.CreationDate()),
handler.NewCol(InstanceDomainChangeDateCol, e.CreationDate()),
handler.NewCol(InstanceDomainSequenceCol, e.Sequence()),
handler.NewCol(InstanceDomainDomainCol, e.Domain),
handler.NewCol(InstanceDomainInstanceIDCol, e.Aggregate().ID),
handler.NewCol(InstanceDomainIsGeneratedCol, e.Generated),
},
), nil
}
func (p *InstanceDomainProjection) reduceDomainRemoved(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*instance.DomainRemovedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "PROJE-388Nk", "reduce.wrong.event.type %s", instance.InstanceDomainRemovedEventType)
}
return crdb.NewDeleteStatement(
e,
[]handler.Condition{
handler.NewCond(InstanceDomainDomainCol, e.Domain),
handler.NewCond(InstanceDomainInstanceIDCol, e.Aggregate().ID),
},
), nil
}

View File

@@ -0,0 +1,97 @@
package projection
import (
"testing"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/handler"
"github.com/caos/zitadel/internal/eventstore/repository"
"github.com/caos/zitadel/internal/repository/instance"
)
func TestInstanceDomainProjection_reduces(t *testing.T) {
type args struct {
event func(t *testing.T) eventstore.Event
}
tests := []struct {
name string
args args
reduce func(event eventstore.Event) (*handler.Statement, error)
want wantReduce
}{
{
name: "reduceDomainAdded",
args: args{
event: getEvent(testEvent(
repository.EventType(instance.InstanceDomainAddedEventType),
instance.AggregateType,
[]byte(`{"domain": "domain.new", "generated": true}`),
), instance.DomainAddedEventMapper),
},
reduce: (&InstanceDomainProjection{}).reduceDomainAdded,
want: wantReduce{
projection: InstanceDomainTable,
aggregateType: eventstore.AggregateType("instance"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "INSERT INTO projections.instance_domains (creation_date, change_date, sequence, domain, instance_id, is_generated) VALUES ($1, $2, $3, $4, $5, $6)",
expectedArgs: []interface{}{
anyArg{},
anyArg{},
uint64(15),
"domain.new",
"agg-id",
true,
},
},
},
},
},
},
{
name: "reduceDomainRemoved",
args: args{
event: getEvent(testEvent(
repository.EventType(instance.InstanceDomainRemovedEventType),
instance.AggregateType,
[]byte(`{"domain": "domain.new"}`),
), instance.DomainRemovedEventMapper),
},
reduce: (&InstanceDomainProjection{}).reduceDomainRemoved,
want: wantReduce{
projection: InstanceDomainTable,
aggregateType: eventstore.AggregateType("instance"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.instance_domains WHERE (domain = $1) AND (instance_id = $2)",
expectedArgs: []interface{}{
"domain.new",
"agg-id",
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event := baseEvent(t)
got, err := tt.reduce(event)
if _, ok := err.(errors.InvalidArgument); !ok {
t.Errorf("no wrong event mapping: %v, got: %v", err, got)
}
event = tt.args.event(t)
got, err = tt.reduce(event)
assertReduce(t, got, err, tt.want)
})
}
}

View File

@@ -59,6 +59,7 @@ func Start(ctx context.Context, sqlClient *sql.DB, es *eventstore.Eventstore, co
NewUserProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["users"]))
NewLoginNameProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["login_names"]))
NewOrgMemberProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["org_members"]))
NewInstanceDomainProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["instance_domains"]))
NewInstanceMemberProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["iam_members"]))
NewProjectMemberProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["project_members"]))
NewProjectGrantMemberProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["project_grant_members"]))

View File

@@ -3,6 +3,7 @@ package query
import (
"context"
"github.com/caos/zitadel/internal/api/authz"
"github.com/caos/zitadel/internal/domain"
)
@@ -11,7 +12,7 @@ func (q *Queries) MyZitadelPermissions(ctx context.Context, orgID, userID string
if err != nil {
return nil, err
}
orgIDsQuery, err := NewMembershipResourceOwnersSearchQuery(orgID, domain.IAMID)
orgIDsQuery, err := NewMembershipResourceOwnersSearchQuery(orgID, authz.GetInstance(ctx).InstanceID())
if err != nil {
return nil, err
}