Files
zitadel/internal/query/instance.go
Livio Spring 7520450e11 fix: sanitize host headers before use
# Which Problems Are Solved

Host headers, which are used to identify the instance and are also reflected in public responses such as OIDC discovery endpoints, email links and more, were not handled correctly. Although they were matched against existing instances, they were not properly sanitized.

# How the Problems Are Solved

Sanitize the host header, including validation of the port (if one is provided).

# Additional Changes

None

# Additional Context

- requires backports

(cherry picked from commit 72a5c33e6a)
2025-10-29 10:07:05 +01:00

650 lines
19 KiB
Go

package query
import (
"context"
"database/sql"
_ "embed"
"encoding/json"
"errors"
"fmt"
"slices"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/logging"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
target_domain "github.com/zitadel/zitadel/internal/execution/target"
"github.com/zitadel/zitadel/internal/feature"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
const (
// InstancesFilterTableAlias is the SQL alias under which prepareInstancesQuery
// embeds the filtering sub-select into its outer query.
InstancesFilterTableAlias = "f"
)
// Table and column descriptors used to build the instance queries in this file.
var (
// instanceTable describes the instance projection table.
instanceTable = table{
name: projection.InstanceProjectionTable,
instanceIDCol: projection.InstanceColumnID,
}
// limitsTable describes the limits projection table.
limitsTable = table{
name: projection.LimitsProjectionTable,
instanceIDCol: projection.LimitsColumnInstanceID,
}
// Columns of instanceTable.
InstanceColumnID = Column{
name: projection.InstanceColumnID,
table: instanceTable,
}
InstanceColumnName = Column{
name: projection.InstanceColumnName,
table: instanceTable,
}
InstanceColumnCreationDate = Column{
name: projection.InstanceColumnCreationDate,
table: instanceTable,
}
InstanceColumnChangeDate = Column{
name: projection.InstanceColumnChangeDate,
table: instanceTable,
}
InstanceColumnSequence = Column{
name: projection.InstanceColumnSequence,
table: instanceTable,
}
InstanceColumnDefaultOrgID = Column{
name: projection.InstanceColumnDefaultOrgID,
table: instanceTable,
}
InstanceColumnProjectID = Column{
name: projection.InstanceColumnProjectID,
table: instanceTable,
}
InstanceColumnConsoleID = Column{
name: projection.InstanceColumnConsoleID,
table: instanceTable,
}
InstanceColumnConsoleAppID = Column{
name: projection.InstanceColumnConsoleAppID,
table: instanceTable,
}
InstanceColumnDefaultLanguage = Column{
name: projection.InstanceColumnDefaultLanguage,
table: instanceTable,
}
// Columns of limitsTable.
LimitsColumnInstanceID = Column{
name: projection.LimitsColumnInstanceID,
table: limitsTable,
}
LimitsColumnAuditLogRetention = Column{
name: projection.LimitsColumnAuditLogRetention,
table: limitsTable,
}
LimitsColumnBlock = Column{
name: projection.LimitsColumnBlock,
table: limitsTable,
}
)
// Instance is the read model of a single instance as loaded from the instance
// projection table, together with the domains joined from the instance domain
// projection.
type Instance struct {
ID string
ChangeDate time.Time
CreationDate time.Time
Sequence uint64
Name string
DefaultOrgID string
IAMProjectID string
ConsoleID string
ConsoleAppID string
// DefaultLang is built with language.Make from the stored language string.
DefaultLang language.Tag
// Domains holds one entry per joined domain row; rows without a domain are skipped by the scanners.
Domains []*InstanceDomain
}
// Instances is the result of SearchInstances: the matched instances plus the
// embedded SearchResponse, whose Count is filled from the query's count column.
type Instances struct {
SearchResponse
Instances []*Instance
}
// InstanceSearchQueries bundles the embedded SearchRequest with the individual
// SearchQuery filters that toQuery applies to an instance select.
type InstanceSearchQueries struct {
SearchRequest
Queries []SearchQuery
}
// NewInstanceIDsListSearchQuery builds a ListIn query on InstanceColumnID for
// the given instance IDs.
func NewInstanceIDsListSearchQuery(ids ...string) (SearchQuery, error) {
	values := make([]any, 0, len(ids))
	for _, id := range ids {
		values = append(values, id)
	}
	return NewListQuery(InstanceColumnID, values, ListIn)
}
// NewInstanceDomainsListSearchQuery builds a ListIn query on
// InstanceDomainDomainCol for the given domains.
func NewInstanceDomainsListSearchQuery(domains ...string) (SearchQuery, error) {
	values := make([]any, 0, len(domains))
	for _, domain := range domains {
		values = append(values, domain)
	}
	return NewListQuery(InstanceDomainDomainCol, values, ListIn)
}
// toQuery applies the embedded SearchRequest to query first and then every
// entry of q.Queries, in order, returning the resulting builder.
func (q *InstanceSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	builder := q.SearchRequest.toQuery(query)
	for idx := range q.Queries {
		builder = q.Queries[idx].toQuery(builder)
	}
	return builder
}
// ActiveInstances returns the keys currently held in the activeInstances cache.
// InstanceByHost and InstanceByID add the instance ID to that cache whenever
// they succeed.
func (q *Queries) ActiveInstances() []string {
return q.caches.activeInstances.Keys()
}
// SearchInstances lists the instances matching queries together with their
// domains.
//
// A failure to build the SQL statement is returned as an invalid-argument
// error; any failure while executing or scanning is returned as an internal
// error. The tracing span is ended with the returned error.
func (q *Queries) SearchInstances(ctx context.Context, queries *InstanceSearchQueries) (instances *Instances, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	innerSelect, wrapSelect, scanRows := prepareInstancesQuery(queries.SortingColumn, queries.Asc)
	statement, arguments, buildErr := wrapSelect(queries.toQuery(innerSelect)).ToSql()
	if buildErr != nil {
		return nil, zerrors.ThrowInvalidArgument(buildErr, "QUERY-M9fow", "Errors.Query.SQLStatement")
	}

	var found *Instances
	queryErr := q.client.QueryContext(ctx, func(rows *sql.Rows) (scanErr error) {
		found, scanErr = scanRows(rows)
		return scanErr
	}, statement, arguments...)
	if queryErr != nil {
		return nil, zerrors.ThrowInternal(queryErr, "QUERY-3j98f", "Errors.Internal")
	}
	return found, nil
}
// Instance loads the instance identified by the instance ID stored in ctx
// (authz.GetInstance), including its domains.
//
// When shouldTriggerBulk is set, the instance projection is triggered first;
// a trigger failure is only logged at debug level and does not abort the query.
func (q *Queries) Instance(ctx context.Context, shouldTriggerBulk bool) (instance *Instance, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	if shouldTriggerBulk {
		_, triggerSpan := tracing.NewNamedSpan(ctx, "TriggerInstanceProjection")
		ctx, err = projection.InstanceProjection.Trigger(ctx, handler.WithAwaitRunning())
		logging.OnError(err).Debug("trigger failed")
		triggerSpan.EndWithError(err)
	}

	builder, scanRows := prepareInstanceDomainQuery()
	byInstanceID := sq.Eq{
		InstanceColumnID.identifier(): authz.GetInstance(ctx).InstanceID(),
	}
	statement, arguments, err := builder.Where(byInstanceID).ToSql()
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "QUERY-d9ngs", "Errors.Query.SQLStatement")
	}

	err = q.client.QueryContext(ctx, func(rows *sql.Rows) (scanErr error) {
		instance, scanErr = scanRows(rows)
		return scanErr
	}, statement, arguments...)
	return instance, err
}
// SQL statements embedded at build time.
var (
// instanceByDomainQuery holds instance_by_domain.sql; InstanceByHost runs it
// with the instance domain as its only argument.
//go:embed instance_by_domain.sql
instanceByDomainQuery string
// instanceByIDQuery holds instance_by_id.sql; InstanceByID runs it with the
// instance ID as its only argument.
//go:embed instance_by_id.sql
instanceByIDQuery string
)
// InstanceByHost resolves the instance for instanceDomain, first from the
// instance cache (index instanceIndexByHost) and otherwise from the database,
// in which case the loaded instance is stored in the cache.
//
// After the instance is found, publicDomain is validated with checkDomain; a
// failing check is returned alongside the instance. On success the instance ID
// is added to the activeInstances cache. Every returned error is wrapped with
// both domain values before the tracing span is ended.
func (q *Queries) InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (_ authz.Instance, err error) {
	var found *authzInstance
	ctx, span := tracing.NewSpan(ctx)
	defer func() {
		if err == nil {
			q.caches.activeInstances.Add(found.ID, true)
		} else {
			err = fmt.Errorf("unable to get instance by domain: instanceDomain %s, publicHostname %s: %w", instanceDomain, publicDomain, err)
		}
		span.EndWithError(err)
	}()

	// found is reassigned (not shadowed) so the deferred function sees it.
	found, hit := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain)
	if hit {
		return found, found.checkDomain(instanceDomain, publicDomain)
	}

	found, scanRow := scanAuthzInstance()
	err = q.client.QueryRowContext(ctx, scanRow, instanceByDomainQuery, instanceDomain)
	if err != nil {
		return nil, err
	}
	q.caches.instance.Set(ctx, found)
	return found, found.checkDomain(instanceDomain, publicDomain)
}
// InstanceByID resolves the instance with the given id, first from the
// instance cache (index instanceIndexByID) and otherwise from the database.
//
// A database failure is logged as a warning and returned together with the
// (unpopulated) scan target; a successful load is stored in the cache. On
// success the id is added to the activeInstances cache.
func (q *Queries) InstanceByID(ctx context.Context, id string) (_ authz.Instance, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()
	// Runs before the span is ended (deferred calls execute last-in, first-out).
	defer func() {
		if err == nil {
			q.caches.activeInstances.Add(id, true)
		}
	}()

	if cached, hit := q.caches.instance.Get(ctx, instanceIndexByID, id); hit {
		return cached, nil
	}

	loaded, scanRow := scanAuthzInstance()
	err = q.client.QueryRowContext(ctx, scanRow, instanceByIDQuery, id)
	logging.OnError(err).WithField("instance_id", id).Warn("instance by ID")
	if err != nil {
		return loaded, err
	}
	q.caches.instance.Set(ctx, loaded)
	return loaded, err
}
// GetDefaultLanguage returns the default language of the instance in ctx, or
// language.Und when the instance cannot be loaded.
func (q *Queries) GetDefaultLanguage(ctx context.Context) language.Tag {
	if instance, err := q.Instance(ctx, false); err == nil {
		return instance.DefaultLang
	}
	return language.Und
}
// prepareInstancesQuery builds the query parts and the row scanner used by
// SearchInstances to list instances together with their domains.
//
// It returns, in order:
//   - the inner select over the instance table (left-joined with the instance
//     domains), to which SearchInstances applies InstanceSearchQueries.toQuery;
//   - a function that embeds this inner select, aliased as
//     InstancesFilterTableAlias, into the outer select loading the instance and
//     domain columns and applying the ordering;
//   - a scanner folding the joined rows into an Instances result.
//
// When sortBy is not the zero Column it is added to the inner select list and
// used for the outer ORDER BY; " DESC" is appended unless isAscedingSort is set.
func prepareInstancesQuery(sortBy Column, isAscedingSort bool) (sq.SelectBuilder, func(sq.SelectBuilder) sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
	instanceFilterTable := instanceTable.setAlias(InstancesFilterTableAlias)
	instanceFilterIDColumn := InstanceColumnID.setTable(instanceFilterTable)
	instanceFilterCountColumn := InstancesFilterTableAlias + ".count"

	selector := sq.Select(InstanceColumnID.identifier(), countColumn.identifier())
	if !sortBy.isZero() {
		selector = sq.Select(InstanceColumnID.identifier(), countColumn.identifier(), sortBy.identifier())
	}

	return selector.Distinct().From(instanceTable.identifier()).
			LeftJoin(join(InstanceDomainInstanceIDCol, InstanceColumnID)),
		func(builder sq.SelectBuilder) sq.SelectBuilder {
			outerQuery := sq.Select(
				instanceFilterCountColumn,
				instanceFilterIDColumn.identifier(),
				InstanceColumnCreationDate.identifier(),
				InstanceColumnChangeDate.identifier(),
				InstanceColumnSequence.identifier(),
				InstanceColumnName.identifier(),
				InstanceColumnDefaultOrgID.identifier(),
				InstanceColumnProjectID.identifier(),
				InstanceColumnConsoleID.identifier(),
				InstanceColumnConsoleAppID.identifier(),
				InstanceColumnDefaultLanguage.identifier(),
				InstanceDomainDomainCol.identifier(),
				InstanceDomainIsPrimaryCol.identifier(),
				InstanceDomainIsGeneratedCol.identifier(),
				InstanceDomainCreationDateCol.identifier(),
				InstanceDomainChangeDateCol.identifier(),
				InstanceDomainSequenceCol.identifier(),
			).FromSelect(builder, InstancesFilterTableAlias).
				LeftJoin(join(InstanceColumnID, instanceFilterIDColumn)).
				LeftJoin(join(InstanceDomainInstanceIDCol, instanceFilterIDColumn)).
				PlaceholderFormat(sq.Dollar)

			if sortBy.isZero() {
				return outerQuery
			}
			sorting := sortBy.identifier()
			if !isAscedingSort {
				sorting += " DESC"
			}
			return outerQuery.OrderBy(sorting)
		},
		func(rows *sql.Rows) (*Instances, error) {
			instances := make([]*Instance, 0)
			var lastInstance *Instance
			var count uint64
			for rows.Next() {
				instance := new(Instance)
				lang := ""
				// The domain columns come from a LEFT JOIN and may be NULL.
				var (
					domain       sql.NullString
					isPrimary    sql.NullBool
					isGenerated  sql.NullBool
					creationDate sql.NullTime
					changeDate   sql.NullTime
					sequence     sql.NullInt64
				)
				// Scan targets are positional and must follow the select list
				// of the outer query: the domain creation date is selected
				// before the domain change date.
				err := rows.Scan(
					&count,
					&instance.ID,
					&instance.CreationDate,
					&instance.ChangeDate,
					&instance.Sequence,
					&instance.Name,
					&instance.DefaultOrgID,
					&instance.IAMProjectID,
					&instance.ConsoleID,
					&instance.ConsoleAppID,
					&lang,
					&domain,
					&isPrimary,
					&isGenerated,
					&creationDate,
					&changeDate,
					&sequence,
				)
				if err != nil {
					return nil, err
				}
				// Rows without an instance ID or without a domain are skipped.
				if instance.ID == "" || !domain.Valid {
					continue
				}
				instance.DefaultLang = language.Make(lang)
				instanceDomain := &InstanceDomain{
					CreationDate: creationDate.Time,
					ChangeDate:   changeDate.Time,
					Sequence:     uint64(sequence.Int64),
					Domain:       domain.String,
					IsPrimary:    isPrimary.Bool,
					IsGenerated:  isGenerated.Bool,
					InstanceID:   instance.ID,
				}
				// Consecutive rows of the same instance only contribute another domain.
				if lastInstance != nil && instance.ID == lastInstance.ID {
					lastInstance.Domains = append(lastInstance.Domains, instanceDomain)
					continue
				}
				lastInstance = instance
				instance.Domains = append(instance.Domains, instanceDomain)
				instances = append(instances, instance)
			}
			if err := rows.Close(); err != nil {
				return nil, zerrors.ThrowInternal(err, "QUERY-8nlWW", "Errors.Query.CloseRows")
			}
			return &Instances{
				Instances: instances,
				SearchResponse: SearchResponse{
					Count: count,
				},
			}, nil
		}
}
// prepareInstanceDomainQuery returns the select statement loading an instance
// left-joined with its domains, and the scanner that folds the resulting rows
// into a single Instance.
//
// The scanner returns a not-found error when no row yielded an instance ID and
// an internal error when scanning or closing the rows fails.
func prepareInstanceDomainQuery() (sq.SelectBuilder, func(*sql.Rows) (*Instance, error)) {
	return sq.Select(
			InstanceColumnID.identifier(),
			InstanceColumnCreationDate.identifier(),
			InstanceColumnChangeDate.identifier(),
			InstanceColumnSequence.identifier(),
			InstanceColumnName.identifier(),
			InstanceColumnDefaultOrgID.identifier(),
			InstanceColumnProjectID.identifier(),
			InstanceColumnConsoleID.identifier(),
			InstanceColumnConsoleAppID.identifier(),
			InstanceColumnDefaultLanguage.identifier(),
			InstanceDomainDomainCol.identifier(),
			InstanceDomainIsPrimaryCol.identifier(),
			InstanceDomainIsGeneratedCol.identifier(),
			InstanceDomainCreationDateCol.identifier(),
			InstanceDomainChangeDateCol.identifier(),
			InstanceDomainSequenceCol.identifier(),
		).
			From(instanceTable.identifier()).
			LeftJoin(join(InstanceDomainInstanceIDCol, InstanceColumnID)).
			PlaceholderFormat(sq.Dollar),
		func(rows *sql.Rows) (*Instance, error) {
			instance := &Instance{
				Domains: make([]*InstanceDomain, 0),
			}
			lang := ""
			for rows.Next() {
				// The domain columns come from a LEFT JOIN and may be NULL.
				var (
					domain       sql.NullString
					isPrimary    sql.NullBool
					isGenerated  sql.NullBool
					creationDate sql.NullTime
					changeDate   sql.NullTime
					sequence     sql.NullInt64
				)
				// Scan targets are positional and must follow the select list
				// above: the domain creation date is selected before the domain
				// change date.
				err := rows.Scan(
					&instance.ID,
					&instance.CreationDate,
					&instance.ChangeDate,
					&instance.Sequence,
					&instance.Name,
					&instance.DefaultOrgID,
					&instance.IAMProjectID,
					&instance.ConsoleID,
					&instance.ConsoleAppID,
					&lang,
					&domain,
					&isPrimary,
					&isGenerated,
					&creationDate,
					&changeDate,
					&sequence,
				)
				if err != nil {
					return nil, zerrors.ThrowInternal(err, "QUERY-d9nw", "Errors.Internal")
				}
				// An instance row without a joined domain adds no domain entry.
				if !domain.Valid {
					continue
				}
				instance.Domains = append(instance.Domains, &InstanceDomain{
					CreationDate: creationDate.Time,
					ChangeDate:   changeDate.Time,
					Sequence:     uint64(sequence.Int64),
					Domain:       domain.String,
					IsPrimary:    isPrimary.Bool,
					IsGenerated:  isGenerated.Bool,
					InstanceID:   instance.ID,
				})
			}
			if instance.ID == "" {
				return nil, zerrors.ThrowNotFound(nil, "QUERY-n0wng", "Errors.IAM.NotFound")
			}
			instance.DefaultLang = language.Make(lang)
			if err := rows.Close(); err != nil {
				return nil, zerrors.ThrowInternal(err, "QUERY-Dfbe2", "Errors.Query.CloseRows")
			}
			return instance, nil
		}
}
// authzInstance is the instance representation returned by InstanceByHost and
// InstanceByID as an authz.Instance and stored in the instance cache. It is
// populated by scanAuthzInstance and carries JSON tags for serialization.
type authzInstance struct {
ID string `json:"id,omitempty"`
IAMProjectID string `json:"iam_project_id,omitempty"`
ConsoleID string `json:"console_id,omitempty"`
ConsoleAppID string `json:"console_app_id,omitempty"`
DefaultLang language.Tag `json:"default_lang,omitempty"`
DefaultOrgID string `json:"default_org_id,omitempty"`
CSP csp `json:"csp,omitempty"`
Impersonation bool `json:"impersonation,omitempty"`
// IsBlocked stays nil when the scanned block column is NULL.
IsBlocked *bool `json:"is_blocked,omitempty"`
// LogRetention stays nil when the scanned audit log retention column is NULL.
LogRetention *time.Duration `json:"log_retention,omitempty"`
Feature feature.Features `json:"feature,omitempty"`
// ExternalDomains are the cache keys for the instanceIndexByHost index (see Keys).
ExternalDomains database.TextArray[string] `json:"external_domains,omitempty"`
// TrustedDomains are the public domains accepted by checkDomain.
TrustedDomains database.TextArray[string] `json:"trusted_domains,omitempty"`
ExecutionTargets target_domain.Router `json:"execution_targets,omitzero"`
}
// csp holds the security policy settings of an instance: whether iframe
// embedding is enabled and which origins are allowed
// (see SecurityPolicyAllowedOrigins).
type csp struct {
EnableIframeEmbedding bool `json:"enable_iframe_embedding,omitempty"`
AllowedOrigins database.TextArray[string] `json:"allowed_origins,omitempty"`
}
// InstanceID returns the ID of the instance.
func (i *authzInstance) InstanceID() string {
return i.ID
}
// ProjectID returns the IAM project ID of the instance.
func (i *authzInstance) ProjectID() string {
return i.IAMProjectID
}
// ConsoleClientID returns the ConsoleID field of the instance.
func (i *authzInstance) ConsoleClientID() string {
return i.ConsoleID
}
// ConsoleApplicationID returns the ConsoleAppID field of the instance.
func (i *authzInstance) ConsoleApplicationID() string {
return i.ConsoleAppID
}
// DefaultLanguage returns the default language tag of the instance.
func (i *authzInstance) DefaultLanguage() language.Tag {
return i.DefaultLang
}
// DefaultOrganisationID returns the ID of the instance's default organization.
func (i *authzInstance) DefaultOrganisationID() string {
return i.DefaultOrgID
}
// SecurityPolicyAllowedOrigins returns the configured allowed origins, but
// only while iframe embedding is enabled; otherwise it returns nil.
func (i *authzInstance) SecurityPolicyAllowedOrigins() []string {
	if i.CSP.EnableIframeEmbedding {
		return i.CSP.AllowedOrigins
	}
	return nil
}
// EnableImpersonation reports the value of the instance's Impersonation flag.
func (i *authzInstance) EnableImpersonation() bool {
return i.Impersonation
}
// Block returns the IsBlocked pointer, which is nil when no block value was
// scanned for the instance.
func (i *authzInstance) Block() *bool {
return i.IsBlocked
}
// AuditLogRetention returns the LogRetention pointer, which is nil when no
// audit log retention was scanned for the instance.
func (i *authzInstance) AuditLogRetention() *time.Duration {
return i.LogRetention
}
// Features returns the feature settings of the instance.
func (i *authzInstance) Features() feature.Features {
return i.Feature
}
// ExecutionRouter returns the router built from the instance's execution targets.
func (i *authzInstance) ExecutionRouter() target_domain.Router {
return i.ExecutionTargets
}
// errPublicDomain is the format string of the error wrapped by checkDomain
// when a public domain is rejected.
var errPublicDomain = "public domain %q not trusted"

// checkDomain verifies that publicDomain may be used with this instance.
//
// An empty public domain, or one equal to instanceDomain, needs no check.
// Any other value must be listed in the instance's trusted domains; otherwise
// a not-found error is returned.
func (i *authzInstance) checkDomain(instanceDomain, publicDomain string) error {
	switch {
	case publicDomain == "", publicDomain == instanceDomain:
		return nil
	case slices.Contains(i.TrustedDomains, publicDomain):
		return nil
	}
	return zerrors.ThrowNotFound(fmt.Errorf(errPublicDomain, publicDomain), "QUERY-IuGh1", "Errors.IAM.NotFound")
}
// Keys implements [cache.Entry]
//
// It returns the instance ID for instanceIndexByID, the external domains for
// instanceIndexByHost and nil for any other index.
func (i *authzInstance) Keys(index instanceIndex) []string {
	if index == instanceIndexByID {
		return []string{i.ID}
	}
	if index == instanceIndexByHost {
		return i.ExternalDomains
	}
	return nil
}
// scanAuthzInstance returns an empty authzInstance together with the scan
// function that populates it from a single row.
//
// The scan function returns a not-found error when the row is absent
// (sql.ErrNoRows) and an internal error when scanning or JSON decoding fails.
// Nullable columns are handled as follows: IsBlocked and LogRetention are only
// set when their column is valid, the boolean flags default to false, and the
// feature and execution target JSON payloads are each decoded only when
// non-empty.
func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
	instance := &authzInstance{}
	return instance, func(row *sql.Row) error {
		var (
			lang                  string
			enableIframeEmbedding sql.NullBool
			enableImpersonation   sql.NullBool
			auditLogRetention     database.NullDuration
			block                 sql.NullBool
			features              []byte
			executionTargetsBytes []byte
		)
		err := row.Scan(
			&instance.ID,
			&instance.DefaultOrgID,
			&instance.IAMProjectID,
			&instance.ConsoleID,
			&instance.ConsoleAppID,
			&lang,
			&enableIframeEmbedding,
			&instance.CSP.AllowedOrigins,
			&enableImpersonation,
			&auditLogRetention,
			&block,
			&features,
			&instance.ExternalDomains,
			&instance.TrustedDomains,
			&executionTargetsBytes,
		)
		if errors.Is(err, sql.ErrNoRows) {
			return zerrors.ThrowNotFound(nil, "QUERY-1kIjX", "Errors.IAM.NotFound")
		}
		if err != nil {
			return zerrors.ThrowInternal(err, "QUERY-d3fas", "Errors.Internal")
		}
		instance.DefaultLang = language.Make(lang)
		if auditLogRetention.Valid {
			instance.LogRetention = &auditLogRetention.Duration
		}
		if block.Valid {
			instance.IsBlocked = &block.Bool
		}
		instance.CSP.EnableIframeEmbedding = enableIframeEmbedding.Bool
		instance.Impersonation = enableImpersonation.Bool

		// Features and execution targets are independent payloads: missing
		// feature JSON must not prevent the execution targets from being loaded.
		if len(features) > 0 {
			if err = json.Unmarshal(features, &instance.Feature); err != nil {
				return zerrors.ThrowInternal(err, "QUERY-Po8ki", "Errors.Internal")
			}
		}
		if len(executionTargetsBytes) > 0 {
			var targets []target_domain.Target
			if err := json.Unmarshal(executionTargetsBytes, &targets); err != nil {
				return zerrors.ThrowInternal(err, "QUERY-aeKa2", "Errors.Internal")
			}
			instance.ExecutionTargets = target_domain.NewRouter(targets)
		}
		return nil
	}
}
// registerInstanceInvalidation registers invalidation callbacks for the
// instance cache on every projection whose data is part of an authzInstance.
func (c *Caches) registerInstanceInvalidation() {
	// Projections whose aggregate ID is used as the instance cache key.
	byAggregateID := cacheInvalidationFunc(c.instance, instanceIndexByID, getAggregateID)
	projection.InstanceProjection.RegisterCacheInvalidation(byAggregateID)
	projection.InstanceDomainProjection.RegisterCacheInvalidation(byAggregateID)
	projection.InstanceFeatureProjection.RegisterCacheInvalidation(byAggregateID)
	projection.InstanceTrustedDomainProjection.RegisterCacheInvalidation(byAggregateID)
	projection.SecurityPolicyProjection.RegisterCacheInvalidation(byAggregateID)

	// These projections have their own aggregate ID, invalidate using resource owner.
	byResourceOwner := cacheInvalidationFunc(c.instance, instanceIndexByID, getResourceOwner)
	projection.LimitsProjection.RegisterCacheInvalidation(byResourceOwner)
	projection.RestrictionsProjection.RegisterCacheInvalidation(byResourceOwner)
	projection.ExecutionProjection.RegisterCacheInvalidation(byResourceOwner)
	projection.TargetProjection.RegisterCacheInvalidation(byResourceOwner)

	// System feature update should invalidate all instances, so Truncate the cache.
	projection.SystemFeatureProjection.RegisterCacheInvalidation(func(ctx context.Context, _ []*eventstore.Aggregate) {
		truncateErr := c.instance.Truncate(ctx)
		logging.OnError(truncateErr).Warn("cache truncate failed")
	})
}
// instanceIndex enumerates the indices of the instance cache; authzInstance.Keys
// returns the keys of a cache entry for a given index.
type instanceIndex int
//go:generate enumer -type instanceIndex -linecomment
const (
// Empty line comment ensures empty string for unspecified value
instanceIndexUnspecified instanceIndex = iota //
instanceIndexByID
instanceIndexByHost
)