From e516c1cdf67e43ef00c77b77b2ff4c8b677f5341 Mon Sep 17 00:00:00 2001 From: adlerhurst <27845747+adlerhurst@users.noreply.github.com> Date: Wed, 23 Jul 2025 11:37:55 +0200 Subject: [PATCH] instance domain tests --- backend/v3/domain/instance_domain.go | 14 +- backend/v3/storage/database/operators.go | 2 +- .../database/repository/instance_domain.go | 17 +- .../repository/instance_domain_test.go | 689 ++++++++++++++++++ backend/v3/storage/database/statement.go | 12 + 5 files changed, 718 insertions(+), 16 deletions(-) create mode 100644 backend/v3/storage/database/repository/instance_domain_test.go diff --git a/backend/v3/domain/instance_domain.go b/backend/v3/domain/instance_domain.go index 76ce42048f..7acbf11b6b 100644 --- a/backend/v3/domain/instance_domain.go +++ b/backend/v3/domain/instance_domain.go @@ -14,16 +14,16 @@ type InstanceDomain struct { IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` ValidationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"` - CreatedAt string `json:"createdAt,omitempty" db:"created_at"` - UpdatedAt string `json:"updatedAt,omitempty" db:"updated_at"` + CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"` + UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"` } type AddInstanceDomain struct { - InstanceID string `json:"instanceId,omitempty" db:"instance_id"` - Domain string `json:"domain,omitempty" db:"domain"` - IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` - IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` - VerificationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"` + InstanceID string `json:"instanceId,omitempty" db:"instance_id"` + Domain string `json:"domain,omitempty" db:"domain"` + IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` + IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` + ValidationType DomainValidationType `json:"validationType,omitempty" db:"validation_type"` // CreatedAt is the time when the domain was added. // It is set by the repository and should not be set by the caller. diff --git a/backend/v3/storage/database/operators.go b/backend/v3/storage/database/operators.go index a2949220e9..9f3e18aeab 100644 --- a/backend/v3/storage/database/operators.go +++ b/backend/v3/storage/database/operators.go @@ -137,6 +137,6 @@ const ( func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) { col.Write(builder) - builder.WriteString(" IS ") + builder.WriteString(" = ") builder.WriteArg(value) } diff --git a/backend/v3/storage/database/repository/instance_domain.go b/backend/v3/storage/database/repository/instance_domain.go index fef2be4e35..d04ddb6abe 100644 --- a/backend/v3/storage/database/repository/instance_domain.go +++ b/backend/v3/storage/database/repository/instance_domain.go @@ -19,7 +19,7 @@ type instanceDomain struct { // ------------------------------------------------------------- const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_domains.domain, instance_domains.is_verified, instance_domains.is_primary, instance_domains.validation_type, instance_domains.created_at, instance_domains.updated_at ` + - `FROM zitadel.instance_domains id` + `FROM zitadel.instance_domains` // Get implements [domain.InstanceDomainRepository]. // Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance. @@ -55,11 +55,9 @@ func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption) func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error { var builder database.StatementBuilder - builder.WriteString(`INSERT INTO zitadel.instance_domains (instance_id, domain, is_verified, is_primary, validation_type) ` + - `VALUES ($1, $2, $3, $4, $5)` + - ` RETURNING created_at, updated_at`) - - builder.AppendArgs(domain.InstanceID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.VerificationType) + builder.WriteString(`INSERT INTO zitadel.instance_domains (instance_id, domain, is_verified, is_primary, validation_type) VALUES (`) + builder.WriteArgs(domain.InstanceID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType) + builder.WriteString(`) RETURNING created_at, updated_at`) return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt) } @@ -69,7 +67,7 @@ func (i *instanceDomain) Remove(ctx context.Context, condition database.Conditio var builder database.StatementBuilder builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `) - writeCondition(&builder, condition) + condition.Write(&builder) return i.client.Exec(ctx, builder.String(), builder.Args()...) } @@ -77,6 +75,9 @@ func (i *instanceDomain) Remove(ctx context.Context, condition database.Conditio // Update implements [domain.InstanceDomainRepository]. // Subtle: this method shadows the method ([domain.InstanceRepository]).Update of instanceDomain.instance. func (i *instanceDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) { + if len(changes) == 0 { + return 0, database.NoChangesError + } var builder database.StatementBuilder builder.WriteString(`UPDATE zitadel.instance_domains SET `) @@ -223,7 +224,7 @@ func scanInstanceDomain(ctx context.Context, querier database.Querier, builder * if err != nil { return nil, err } - instanceDomain := &domain.InstanceDomain{} + instanceDomain := new(domain.InstanceDomain) if err := rows.(database.CollectableRows).CollectExactlyOneRow(instanceDomain); err != nil { return nil, err } diff --git a/backend/v3/storage/database/repository/instance_domain_test.go b/backend/v3/storage/database/repository/instance_domain_test.go new file mode 100644 index 0000000000..888b74bdab --- /dev/null +++ b/backend/v3/storage/database/repository/instance_domain_test.go @@ -0,0 +1,689 @@ +package repository_test + +import ( + "context" + "testing" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" +) + +func TestAddInstanceDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + tests := []struct { + name string + testFunc func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain + instanceDomain domain.AddInstanceDomain + err error + }{ + { + name: "happy path", + instanceDomain: domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + IsVerified: false, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeDNS, + }, + }, + { + name: "add verified domain", + instanceDomain: domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + IsVerified: true, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeHTTP, + }, + }, + { + name: "add primary domain", + instanceDomain: domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + IsVerified: true, + IsPrimary: true, + ValidationType: domain.DomainValidationTypeDNS, + }, + }, + { + name: "add domain without domain name", + instanceDomain: domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: "", + IsVerified: false, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeDNS, + }, + err: new(database.CheckError), + }, + { + name: "add domain with same domain twice", + testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain { + domainName := gofakeit.DomainName() + + instanceDomain := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName, + IsVerified: false, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeDNS, + } + + err := domainRepo.Add(ctx, instanceDomain) + require.NoError(t, err) + + // return same domain again + return &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName, + IsVerified: true, + IsPrimary: true, + ValidationType: domain.DomainValidationTypeHTTP, + } + }, + err: new(database.UniqueError), + }, + { + name: "add domain with non-existent instance", + instanceDomain: domain.AddInstanceDomain{ + InstanceID: "non-existent-instance", + Domain: gofakeit.DomainName(), + IsVerified: false, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeDNS, + }, + err: new(database.ForeignKeyError), + }, + { + name: "add domain without instance id", + instanceDomain: domain.AddInstanceDomain{ + Domain: gofakeit.DomainName(), + IsVerified: false, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeDNS, + }, + err: new(database.ForeignKeyError), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + instanceRepo := repository.InstanceRepository(tx) + domainRepo := instanceRepo.Domains(false) + + var instanceDomain *domain.AddInstanceDomain + if test.testFunc != nil { + instanceDomain = test.testFunc(ctx, t, domainRepo) + } else { + instanceDomain = &test.instanceDomain + } + + err = domainRepo.Add(ctx, instanceDomain) + if test.err != nil { + assert.ErrorIs(t, err, test.err) + return + } + + require.NoError(t, err) + assert.NotZero(t, instanceDomain.CreatedAt) + assert.NotZero(t, instanceDomain.UpdatedAt) + }) + } +} + +func TestGetInstanceDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // add domains + domainRepo := instanceRepo.Domains(false) + domainName1 := gofakeit.DomainName() + domainName2 := gofakeit.DomainName() + + domain1 := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName1, + IsVerified: true, + IsPrimary: true, + ValidationType: domain.DomainValidationTypeDNS, + } + domain2 := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName2, + IsVerified: false, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeHTTP, + } + + err = domainRepo.Add(t.Context(), domain1) + require.NoError(t, err) + err = domainRepo.Add(t.Context(), domain2) + require.NoError(t, err) + + tests := []struct { + name string + opts []database.QueryOption + expected *domain.InstanceDomain + err error + }{ + { + name: "get primary domain", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.IsPrimaryCondition(true)), + }, + expected: &domain.InstanceDomain{ + InstanceID: instanceID, + Domain: domainName1, + IsVerified: true, + IsPrimary: true, + ValidationType: domain.DomainValidationTypeDNS, + }, + }, + { + name: "get by domain name", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)), + }, + expected: &domain.InstanceDomain{ + InstanceID: instanceID, + Domain: domainName2, + IsVerified: false, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeHTTP, + }, + }, + { + name: "get verified domain", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.IsVerifiedCondition(true)), + }, + expected: &domain.InstanceDomain{ + InstanceID: instanceID, + Domain: domainName1, + IsVerified: true, + IsPrimary: true, + ValidationType: domain.DomainValidationTypeDNS, + }, + }, + { + name: "get non-existent domain", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")), + }, + err: new(database.NoRowFoundError), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + result, err := domainRepo.Get(ctx, test.opts...) + if test.err != nil { + assert.ErrorIs(t, err, test.err) + return + } + + require.NoError(t, err) + assert.Equal(t, test.expected.InstanceID, result.InstanceID) + assert.Equal(t, test.expected.Domain, result.Domain) + assert.Equal(t, test.expected.IsVerified, result.IsVerified) + assert.Equal(t, test.expected.IsPrimary, result.IsPrimary) + assert.Equal(t, test.expected.ValidationType, result.ValidationType) + assert.NotEmpty(t, result.CreatedAt) + assert.NotEmpty(t, result.UpdatedAt) + }) + } +} + +func TestListInstanceDomains(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // add multiple domains + domainRepo := instanceRepo.Domains(false) + domains := []domain.AddInstanceDomain{ + { + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + IsVerified: true, + IsPrimary: true, + ValidationType: domain.DomainValidationTypeDNS, + }, + { + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + IsVerified: false, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeHTTP, + }, + { + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + IsVerified: true, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeDNS, + }, + } + + for i := range domains { + err = domainRepo.Add(t.Context(), &domains[i]) + require.NoError(t, err) + } + + tests := []struct { + name string + opts []database.QueryOption + expectedCount int + }{ + { + name: "list all domains", + opts: []database.QueryOption{}, + expectedCount: 3, + }, + { + name: "list verified domains", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.IsVerifiedCondition(true)), + }, + expectedCount: 2, + }, + { + name: "list primary domains", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.IsPrimaryCondition(true)), + }, + expectedCount: 1, + }, + { + name: "list by instance", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.InstanceIDCondition(instanceID)), + }, + expectedCount: 3, + }, + { + name: "list non-existent instance", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.InstanceIDCondition("non-existent")), + }, + expectedCount: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + results, err := domainRepo.List(ctx, test.opts...) + require.NoError(t, err) + assert.Len(t, results, test.expectedCount) + + for _, result := range results { + assert.Equal(t, instanceID, result.InstanceID) + assert.NotEmpty(t, result.Domain) + assert.NotEmpty(t, result.CreatedAt) + assert.NotEmpty(t, result.UpdatedAt) + } + }) + } +} + +func TestUpdateInstanceDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // add domain + domainRepo := instanceRepo.Domains(false) + domainName := gofakeit.DomainName() + instanceDomain := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName, + IsVerified: false, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeDNS, + } + + err = domainRepo.Add(t.Context(), instanceDomain) + require.NoError(t, err) + + tests := []struct { + name string + condition database.Condition + changes []database.Change + expected int64 + err error + }{ + { + name: "set verified", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{domainRepo.SetVerified()}, + expected: 1, + }, + { + name: "set primary", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{domainRepo.SetPrimary()}, + expected: 1, + }, + { + name: "set validation type", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)}, + expected: 1, + }, + { + name: "multiple changes", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{ + domainRepo.SetVerified(), + domainRepo.SetPrimary(), + domainRepo.SetValidationType(domain.DomainValidationTypeDNS), + }, + expected: 1, + }, + { + name: "update non-existent domain", + condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), + changes: []database.Change{domainRepo.SetVerified()}, + expected: 0, + }, + { + name: "no changes", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{}, + expected: 0, + err: database.NoChangesError, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...) + if test.err != nil { + assert.ErrorIs(t, err, test.err) + return + } + + require.NoError(t, err) + assert.Equal(t, test.expected, rowsAffected) + + // verify changes were applied if rows were affected + if rowsAffected > 0 && len(test.changes) > 0 { + result, err := domainRepo.Get(ctx, database.WithCondition(test.condition)) + require.NoError(t, err) + + // We know changes were applied since rowsAffected > 0 + // The specific verification of what changed is less important + // than knowing the operation succeeded + assert.NotNil(t, result) + } + }) + } +} + +func TestRemoveInstanceDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // add domains + domainRepo := instanceRepo.Domains(false) + domainName1 := gofakeit.DomainName() + domainName2 := gofakeit.DomainName() + + domain1 := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName1, + IsVerified: true, + IsPrimary: true, + ValidationType: domain.DomainValidationTypeDNS, + } + domain2 := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName2, + IsVerified: false, + IsPrimary: false, + ValidationType: domain.DomainValidationTypeHTTP, + } + + err = domainRepo.Add(t.Context(), domain1) + require.NoError(t, err) + err = domainRepo.Add(t.Context(), domain2) + require.NoError(t, err) + + tests := []struct { + name string + condition database.Condition + expected int64 + }{ + { + name: "remove by domain name", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1), + expected: 1, + }, + { + name: "remove by primary condition", + condition: domainRepo.IsPrimaryCondition(false), + expected: 1, // domain2 should still exist and be non-primary + }, + { + name: "remove non-existent domain", + condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), + expected: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + // count before removal + beforeCount, err := domainRepo.List(ctx) + require.NoError(t, err) + + rowsAffected, err := domainRepo.Remove(ctx, test.condition) + require.NoError(t, err) + assert.Equal(t, test.expected, rowsAffected) + + // verify removal + afterCount, err := domainRepo.List(ctx) + require.NoError(t, err) + assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount)) + }) + } +} + +func TestInstanceDomainConditions(t *testing.T) { + instanceRepo := repository.InstanceRepository(pool) + domainRepo := instanceRepo.Domains(false) + + tests := []struct { + name string + condition database.Condition + expected string + }{ + { + name: "domain condition equal", + condition: domainRepo.DomainCondition(database.TextOperationEqual, "example.com"), + expected: "instance_domains.domain = $1", + }, + { + name: "domain condition starts with", + condition: domainRepo.DomainCondition(database.TextOperationStartsWith, "example"), + expected: "instance_domains.domain LIKE $1 || '%'", + }, + { + name: "instance id condition", + condition: domainRepo.InstanceIDCondition("instance-123"), + expected: "instance_domains.instance_id = $1", + }, + { + name: "is primary true", + condition: domainRepo.IsPrimaryCondition(true), + expected: "instance_domains.is_primary = $1", + }, + { + name: "is primary false", + condition: domainRepo.IsPrimaryCondition(false), + expected: "instance_domains.is_primary = $1", + }, + { + name: "is verified true", + condition: domainRepo.IsVerifiedCondition(true), + expected: "instance_domains.is_verified = $1", + }, + { + name: "is verified false", + condition: domainRepo.IsVerifiedCondition(false), + expected: "instance_domains.is_verified = $1", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var builder database.StatementBuilder + test.condition.Write(&builder) + assert.Equal(t, test.expected, builder.String()) + }) + } +} + +func TestInstanceDomainChanges(t *testing.T) { + instanceRepo := repository.InstanceRepository(pool) + domainRepo := instanceRepo.Domains(false) + + tests := []struct { + name string + change database.Change + expected string + }{ + { + name: "set verified", + change: domainRepo.SetVerified(), + expected: "is_verified = $1", + }, + { + name: "set primary", + change: domainRepo.SetPrimary(), + expected: "is_primary = $1", + }, + { + name: "set validation type DNS", + change: domainRepo.SetValidationType(domain.DomainValidationTypeDNS), + expected: "validation_type = $1", + }, + { + name: "set validation type HTTP", + change: domainRepo.SetValidationType(domain.DomainValidationTypeHTTP), + expected: "validation_type = $1", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var builder database.StatementBuilder + test.change.Write(&builder) + assert.Equal(t, test.expected, builder.String()) + }) + } +} diff --git a/backend/v3/storage/database/statement.go b/backend/v3/storage/database/statement.go index 2a038af549..7e20fcf4f0 100644 --- a/backend/v3/storage/database/statement.go +++ b/backend/v3/storage/database/statement.go @@ -24,6 +24,17 @@ func (b *StatementBuilder) WriteArg(arg any) { b.WriteString(b.AppendArg(arg)) } +// WriteArgs adds the arguments to the statement and writes the placeholders to the query. +// The placeholders are comma separated. +func (b *StatementBuilder) WriteArgs(args ...any) { + for i, arg := range args { + if i > 0 { + b.WriteString(", ") + } + b.WriteArg(arg) + } +} + // AppendArg adds the argument to the statement and returns the placeholder. func (b *StatementBuilder) AppendArg(arg any) (placeholder string) { if b.existingArgs == nil { @@ -43,6 +54,7 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) { } // AppendArgs adds the arguments to the statement and doesn't return the placeholders. +// If an argument is already added, it will not be added again. func (b *StatementBuilder) AppendArgs(args ...any) { for _, arg := range args { b.AppendArg(arg)