package query

import (
	"context"
	"database/sql"
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"slices"
	"strings"
	"time"

	sq "github.com/Masterminds/squirrel"
	"github.com/zitadel/logging"
	"golang.org/x/text/language"

	"github.com/zitadel/zitadel/internal/api/authz"
	"github.com/zitadel/zitadel/internal/api/call"
	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
	"github.com/zitadel/zitadel/internal/feature"
	"github.com/zitadel/zitadel/internal/query/projection"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
	"github.com/zitadel/zitadel/internal/zerrors"
)

const (
	// InstancesFilterTableAlias is the alias of the filtering sub-select built by
	// prepareInstancesQuery; the outer query joins the instance and domain columns
	// back onto the IDs selected under this alias.
	InstancesFilterTableAlias = "f"
)

var (
	// instanceTable describes the instance projection table. Its ID column is also
	// registered as the table's instance ID column.
	instanceTable = table{
		name:          projection.InstanceProjectionTable,
		instanceIDCol: projection.InstanceColumnID,
	}
	// limitsTable describes the limits projection table, whose instance ID column is
	// projection.LimitsColumnInstanceID.
	limitsTable = table{
		name:          projection.LimitsProjectionTable,
		instanceIDCol: projection.LimitsColumnInstanceID,
	}
	// InstanceColumn* are the columns of the instance projection table used by the
	// queries in this file.
	InstanceColumnID = Column{
		name:  projection.InstanceColumnID,
		table: instanceTable,
	}
	InstanceColumnName = Column{
		name:  projection.InstanceColumnName,
		table: instanceTable,
	}
	InstanceColumnCreationDate = Column{
		name:  projection.InstanceColumnCreationDate,
		table: instanceTable,
	}
	InstanceColumnChangeDate = Column{
		name:  projection.InstanceColumnChangeDate,
		table: instanceTable,
	}
	InstanceColumnSequence = Column{
		name:  projection.InstanceColumnSequence,
		table: instanceTable,
	}
	InstanceColumnDefaultOrgID = Column{
		name:  projection.InstanceColumnDefaultOrgID,
		table: instanceTable,
	}
	InstanceColumnProjectID = Column{
		name:  projection.InstanceColumnProjectID,
		table: instanceTable,
	}
	InstanceColumnConsoleID = Column{
		name:  projection.InstanceColumnConsoleID,
		table: instanceTable,
	}
	InstanceColumnConsoleAppID = Column{
		name:  projection.InstanceColumnConsoleAppID,
		table: instanceTable,
	}
	InstanceColumnDefaultLanguage = Column{
		name:  projection.InstanceColumnDefaultLanguage,
		table: instanceTable,
	}
	// LimitsColumn* are columns of the limits projection table. They are exported
	// and not referenced in this file; presumably used by other queries — confirm.
	LimitsColumnInstanceID = Column{
		name:  projection.LimitsColumnInstanceID,
		table: limitsTable,
	}
	LimitsColumnAuditLogRetention = Column{
		name:  projection.LimitsColumnAuditLogRetention,
		table: limitsTable,
	}
	LimitsColumnBlock = Column{
		name:  projection.LimitsColumnBlock,
		table: limitsTable,
	}
)

// Instance is the query-side model of one instance as read from the instance
// projection, joined with its domains.
//
// DefaultLang is produced with language.Make from the stored default-language
// string. Domains holds one entry per joined domain row; the scanners in this file
// skip rows whose domain column is NULL.
type Instance struct {
	ID           string
	ChangeDate   time.Time
	CreationDate time.Time
	Sequence     uint64
	Name         string

	DefaultOrgID string
	IAMProjectID string
	ConsoleID    string
	ConsoleAppID string
	DefaultLang  language.Tag
	Domains      []*InstanceDomain
}

// Instances is the result of SearchInstances. SearchResponse.Count is taken from the
// count column of the search query (see prepareInstancesQuery).
type Instances struct {
	SearchResponse
	Instances []*Instance
}

// InstanceSearchQueries combines the generic SearchRequest options with
// instance-specific filters; toQuery applies both to a select builder.
type InstanceSearchQueries struct {
	SearchRequest
	Queries []SearchQuery
}

// NewInstanceIDsListSearchQuery returns a filter matching instances whose ID is one
// of ids (an SQL IN list over InstanceColumnID).
func NewInstanceIDsListSearchQuery(ids ...string) (SearchQuery, error) {
	values := make([]any, 0, len(ids))
	for _, id := range ids {
		values = append(values, id)
	}
	return NewListQuery(InstanceColumnID, values, ListIn)
}

// NewInstanceDomainsListSearchQuery returns a filter matching instances that have at
// least one domain in domains (an SQL IN list over InstanceDomainDomainCol).
func NewInstanceDomainsListSearchQuery(domains ...string) (SearchQuery, error) {
	values := make([]any, 0, len(domains))
	for _, domain := range domains {
		values = append(values, domain)
	}
	return NewListQuery(InstanceDomainDomainCol, values, ListIn)
}

// toQuery applies the embedded SearchRequest and then every filter in Queries, in
// order, to builder and returns the resulting builder.
func (q *InstanceSearchQueries) toQuery(builder sq.SelectBuilder) sq.SelectBuilder {
	builder = q.SearchRequest.toQuery(builder)
	for _, filter := range q.Queries {
		builder = filter.toQuery(builder)
	}
	return builder
}

// SearchInstances runs the instance search described by queries and returns the
// matching instances with their domains plus the count reported by the query.
// Statement-building failures are reported as InvalidArgument, database or scan
// failures as Internal.
func (q *Queries) SearchInstances(ctx context.Context, queries *InstanceSearchQueries) (instances *Instances, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	filterBuilder, wrapFilter, scanInstances := prepareInstancesQuery(ctx, q.client)
	statement, arguments, buildErr := wrapFilter(queries.toQuery(filterBuilder)).ToSql()
	if buildErr != nil {
		return nil, zerrors.ThrowInvalidArgument(buildErr, "QUERY-M9fow", "Errors.Query.SQLStatement")
	}

	queryErr := q.client.QueryContext(ctx, func(rows *sql.Rows) (scanErr error) {
		instances, scanErr = scanInstances(rows)
		return scanErr
	}, statement, arguments...)
	if queryErr != nil {
		return nil, zerrors.ThrowInternal(queryErr, "QUERY-3j98f", "Errors.Internal")
	}
	return instances, nil
}

func (q *Queries) Instance(ctx context.Context, shouldTriggerBulk bool) (instance *Instance, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	if shouldTriggerBulk {
		_, traceSpan := tracing.NewNamedSpan(ctx, "TriggerInstanceProjection")
		ctx, err = projection.InstanceProjection.Trigger(ctx, handler.WithAwaitRunning())
		logging.OnError(err).Debug("trigger failed")
		traceSpan.EndWithError(err)
	}

	stmt, scan := prepareInstanceDomainQuery(ctx, q.client)
	query, args, err := stmt.Where(sq.Eq{
		InstanceColumnID.identifier(): authz.GetInstance(ctx).InstanceID(),
	}).ToSql()
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "QUERY-d9ngs", "Errors.Query.SQLStatement")
	}

	err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
		instance, err = scan(rows)
		return err
	}, query, args...)
	return instance, err
}

var (
	// instanceByDomainQuery is executed by InstanceByHost with the port-stripped
	// instance host as its only argument. Its result columns must match the Scan
	// order in scanAuthzInstance.
	//go:embed instance_by_domain.sql
	instanceByDomainQuery string

	// instanceByIDQuery is executed by InstanceByID with the instance ID as its only
	// argument. Its result columns must match the Scan order in scanAuthzInstance.
	//go:embed instance_by_id.sql
	instanceByIDQuery string
)

// InstanceByHost resolves the instance serving instanceHost and verifies that
// publicHost is acceptable for it (see checkDomain). Any port suffix (text after the
// first ':') is ignored on both hosts.
//
// The instance cache is consulted first, keyed by host. On a miss the instance is
// loaded with instanceByDomainQuery and stored in the cache before the public domain
// is checked. Errors are wrapped with both host values.
func (q *Queries) InstanceByHost(ctx context.Context, instanceHost, publicHost string) (_ authz.Instance, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() {
		if err != nil {
			err = fmt.Errorf("unable to get instance by host: instanceHost %s, publicHost %s: %w", instanceHost, publicHost, err)
		}
		span.EndWithError(err)
	}()

	// Keep only the part before the first colon, dropping a possible port.
	instanceDomain, _, _ := strings.Cut(instanceHost, ":")
	publicDomain, _, _ := strings.Cut(publicHost, ":")

	instance, cached := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain)
	if !cached {
		loaded, scan := scanAuthzInstance()
		if err = q.client.QueryRowContext(ctx, scan, instanceByDomainQuery, instanceDomain); err != nil {
			return nil, err
		}
		q.caches.instance.Set(ctx, loaded)
		instance = loaded
	}
	return instance, instance.checkDomain(instanceDomain, publicDomain)
}

// InstanceByID returns the instance with the given ID. It serves from the instance
// cache when possible. Otherwise it loads the instance with instanceByIDQuery and
// caches it on success. Load failures are logged at warn level and returned.
func (q *Queries) InstanceByID(ctx context.Context, id string) (_ authz.Instance, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	if cached, ok := q.caches.instance.Get(ctx, instanceIndexByID, id); ok {
		return cached, nil
	}

	loaded, scan := scanAuthzInstance()
	err = q.client.QueryRowContext(ctx, scan, instanceByIDQuery, id)
	logging.OnError(err).WithField("instance_id", id).Warn("instance by ID")
	if err != nil {
		return loaded, err
	}
	q.caches.instance.Set(ctx, loaded)
	return loaded, nil
}

// GetDefaultLanguage returns the default language of the instance bound to ctx
// (loaded without triggering the projection), or language.Und if it cannot be loaded.
func (q *Queries) GetDefaultLanguage(ctx context.Context) language.Tag {
	if instance, err := q.Instance(ctx, false); err == nil {
		return instance.DefaultLang
	}
	return language.Und
}

// prepareInstancesQuery builds the two-stage instance search query and its scanner.
//
// It returns:
//   - the filter builder: a DISTINCT select of instance IDs and countColumn over the
//     instance table left-joined with instance domains, to which callers apply their
//     search filters and paging;
//   - a wrapper that embeds the filter as the sub-select InstancesFilterTableAlias and
//     joins the full instance and domain columns back onto the selected IDs;
//   - a scanner that folds the joined rows into Instances, with exactly one entry per
//     instance (in first-seen order) carrying all of its domains. Rows without an
//     instance ID or without a domain are skipped, so an instance that has no domain
//     is not part of the result.
func prepareInstancesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(sq.SelectBuilder) sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) {
	instanceFilterTable := instanceTable.setAlias(InstancesFilterTableAlias)
	instanceFilterIDColumn := InstanceColumnID.setTable(instanceFilterTable)
	instanceFilterCountColumn := InstancesFilterTableAlias + ".count"
	return sq.Select(
			InstanceColumnID.identifier(),
			countColumn.identifier(),
		).Distinct().From(instanceTable.identifier()).
			LeftJoin(join(InstanceDomainInstanceIDCol, InstanceColumnID)),
		func(builder sq.SelectBuilder) sq.SelectBuilder {
			return sq.Select(
				instanceFilterCountColumn,
				instanceFilterIDColumn.identifier(),
				InstanceColumnCreationDate.identifier(),
				InstanceColumnChangeDate.identifier(),
				InstanceColumnSequence.identifier(),
				InstanceColumnName.identifier(),
				InstanceColumnDefaultOrgID.identifier(),
				InstanceColumnProjectID.identifier(),
				InstanceColumnConsoleID.identifier(),
				InstanceColumnConsoleAppID.identifier(),
				InstanceColumnDefaultLanguage.identifier(),
				InstanceDomainDomainCol.identifier(),
				InstanceDomainIsPrimaryCol.identifier(),
				InstanceDomainIsGeneratedCol.identifier(),
				InstanceDomainCreationDateCol.identifier(),
				InstanceDomainChangeDateCol.identifier(),
				InstanceDomainSequenceCol.identifier(),
			).FromSelect(builder, InstancesFilterTableAlias).
				LeftJoin(join(InstanceColumnID, instanceFilterIDColumn)).
				LeftJoin(join(InstanceDomainInstanceIDCol, instanceFilterIDColumn) + db.Timetravel(call.Took(ctx))).
				PlaceholderFormat(sq.Dollar)
		},
		func(rows *sql.Rows) (*Instances, error) {
			instances := make([]*Instance, 0)
			// byID indexes the entries already appended to instances. The outer query
			// has no ORDER BY, so the rows of one instance are not guaranteed to be
			// adjacent. Grouping by ID (instead of comparing with the previous row only)
			// keeps each instance exactly once.
			byID := make(map[string]*Instance)
			var count uint64
			for rows.Next() {
				instance := new(Instance)
				lang := ""
				var (
					domain       sql.NullString
					isPrimary    sql.NullBool
					isGenerated  sql.NullBool
					changeDate   sql.NullTime
					creationDate sql.NullTime
					sequence     sql.NullInt64
				)
				err := rows.Scan(
					&count,
					&instance.ID,
					&instance.CreationDate,
					&instance.ChangeDate,
					&instance.Sequence,
					&instance.Name,
					&instance.DefaultOrgID,
					&instance.IAMProjectID,
					&instance.ConsoleID,
					&instance.ConsoleAppID,
					&lang,
					&domain,
					&isPrimary,
					&isGenerated,
					&changeDate,
					&creationDate,
					&sequence,
				)
				if err != nil {
					return nil, err
				}
				if instance.ID == "" || !domain.Valid {
					continue
				}
				instanceDomain := &InstanceDomain{
					CreationDate: creationDate.Time,
					ChangeDate:   changeDate.Time,
					Sequence:     uint64(sequence.Int64),
					Domain:       domain.String,
					IsPrimary:    isPrimary.Bool,
					IsGenerated:  isGenerated.Bool,
					InstanceID:   instance.ID,
				}
				if existing, ok := byID[instance.ID]; ok {
					existing.Domains = append(existing.Domains, instanceDomain)
					continue
				}
				instance.DefaultLang = language.Make(lang)
				instance.Domains = append(instance.Domains, instanceDomain)
				byID[instance.ID] = instance
				instances = append(instances, instance)
			}
			// rows.Next also returns false on a read error; report it instead of
			// returning a silently truncated result.
			if err := rows.Err(); err != nil {
				return nil, err
			}

			if err := rows.Close(); err != nil {
				return nil, zerrors.ThrowInternal(err, "QUERY-8nlWW", "Errors.Query.CloseRows")
			}

			return &Instances{
				Instances: instances,
				SearchResponse: SearchResponse{
					Count: count,
				},
			}, nil
		}
}

// prepareInstanceDomainQuery builds a select of the instance columns left-joined
// with the instance's domains, plus a scanner that folds the resulting rows into a
// single Instance. Callers add the WHERE clause (see Instance).
//
// The scanner skips rows whose domain is NULL. It returns:
//   - an Internal error if scanning or row iteration fails;
//   - a NotFound error if no row carried an instance ID.
func prepareInstanceDomainQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Instance, error)) {
	return sq.Select(
			InstanceColumnID.identifier(),
			InstanceColumnCreationDate.identifier(),
			InstanceColumnChangeDate.identifier(),
			InstanceColumnSequence.identifier(),
			InstanceColumnName.identifier(),
			InstanceColumnDefaultOrgID.identifier(),
			InstanceColumnProjectID.identifier(),
			InstanceColumnConsoleID.identifier(),
			InstanceColumnConsoleAppID.identifier(),
			InstanceColumnDefaultLanguage.identifier(),
			InstanceDomainDomainCol.identifier(),
			InstanceDomainIsPrimaryCol.identifier(),
			InstanceDomainIsGeneratedCol.identifier(),
			InstanceDomainCreationDateCol.identifier(),
			InstanceDomainChangeDateCol.identifier(),
			InstanceDomainSequenceCol.identifier(),
		).
			From(instanceTable.identifier()).
			LeftJoin(join(InstanceDomainInstanceIDCol, InstanceColumnID) + db.Timetravel(call.Took(ctx))).
			PlaceholderFormat(sq.Dollar),
		func(rows *sql.Rows) (*Instance, error) {
			instance := &Instance{
				Domains: make([]*InstanceDomain, 0),
			}
			lang := ""
			for rows.Next() {
				var (
					domain       sql.NullString
					isPrimary    sql.NullBool
					isGenerated  sql.NullBool
					changeDate   sql.NullTime
					creationDate sql.NullTime
					sequence     sql.NullInt64
				)
				err := rows.Scan(
					&instance.ID,
					&instance.CreationDate,
					&instance.ChangeDate,
					&instance.Sequence,
					&instance.Name,
					&instance.DefaultOrgID,
					&instance.IAMProjectID,
					&instance.ConsoleID,
					&instance.ConsoleAppID,
					&lang,
					&domain,
					&isPrimary,
					&isGenerated,
					&changeDate,
					&creationDate,
					&sequence,
				)
				if err != nil {
					return nil, zerrors.ThrowInternal(err, "QUERY-d9nw", "Errors.Internal")
				}
				if !domain.Valid {
					continue
				}
				instance.Domains = append(instance.Domains, &InstanceDomain{
					CreationDate: creationDate.Time,
					ChangeDate:   changeDate.Time,
					Sequence:     uint64(sequence.Int64),
					Domain:       domain.String,
					IsPrimary:    isPrimary.Bool,
					IsGenerated:  isGenerated.Bool,
					InstanceID:   instance.ID,
				})
			}
			// Check for an iteration error before deciding that nothing matched.
			// Otherwise a failed read would be reported as "not found".
			if err := rows.Err(); err != nil {
				return nil, zerrors.ThrowInternal(err, "QUERY-d9nw", "Errors.Internal")
			}
			if instance.ID == "" {
				return nil, zerrors.ThrowNotFound(nil, "QUERY-n0wng", "Errors.IAM.NotFound")
			}
			instance.DefaultLang = language.Make(lang)
			if err := rows.Close(); err != nil {
				return nil, zerrors.ThrowInternal(err, "QUERY-Dfbe2", "Errors.Query.CloseRows")
			}
			return instance, nil
		}
}

// authzInstance is the authz.Instance implementation returned by InstanceByHost and
// InstanceByID and stored in the instance cache. It is populated by
// scanAuthzInstance; its data is exposed through the accessor methods on
// *authzInstance.
//
// NOTE(review): the JSON tags presumably exist so the cache can serialize entries —
// confirm. encoding/json ignores omitempty on struct-typed fields such as CSP, so
// that option has no effect there.
type authzInstance struct {
	ID              string                     `json:"id,omitempty"`
	IAMProjectID    string                     `json:"iam_project_id,omitempty"`
	ConsoleID       string                     `json:"console_id,omitempty"`
	ConsoleAppID    string                     `json:"console_app_id,omitempty"`
	DefaultLang     language.Tag               `json:"default_lang,omitempty"`
	DefaultOrgID    string                     `json:"default_org_id,omitempty"`
	CSP             csp                        `json:"csp,omitempty"`
	Impersonation   bool                       `json:"impersonation,omitempty"`
	IsBlocked       *bool                      `json:"is_blocked,omitempty"`
	LogRetention    *time.Duration             `json:"log_retention,omitempty"`
	Feature         feature.Features           `json:"feature,omitempty"`
	ExternalDomains database.TextArray[string] `json:"external_domains,omitempty"`
	TrustedDomains  database.TextArray[string] `json:"trusted_domains,omitempty"`
}

// csp holds an instance's iframe-embedding settings. AllowedOrigins is only exposed
// (via SecurityPolicyAllowedOrigins) when EnableIframeEmbedding is true.
type csp struct {
	EnableIframeEmbedding bool                       `json:"enable_iframe_embedding,omitempty"`
	AllowedOrigins        database.TextArray[string] `json:"allowed_origins,omitempty"`
}

// InstanceID returns the instance's ID.
func (i *authzInstance) InstanceID() string {
	return i.ID
}

// ProjectID returns the ID of the instance's IAM project.
func (i *authzInstance) ProjectID() string {
	return i.IAMProjectID
}

// ConsoleClientID returns the stored ConsoleID.
func (i *authzInstance) ConsoleClientID() string {
	return i.ConsoleID
}

// ConsoleApplicationID returns the stored ConsoleAppID.
func (i *authzInstance) ConsoleApplicationID() string {
	return i.ConsoleAppID
}

// DefaultLanguage returns the instance's default language tag.
func (i *authzInstance) DefaultLanguage() language.Tag {
	return i.DefaultLang
}

// DefaultOrganisationID returns the ID of the instance's default organization.
func (i *authzInstance) DefaultOrganisationID() string {
	return i.DefaultOrgID
}

// SecurityPolicyAllowedOrigins returns the CSP's allowed origins when iframe
// embedding is enabled, and nil otherwise.
func (i *authzInstance) SecurityPolicyAllowedOrigins() []string {
	if i.CSP.EnableIframeEmbedding {
		return i.CSP.AllowedOrigins
	}
	return nil
}

// EnableImpersonation reports whether impersonation is enabled for the instance.
func (i *authzInstance) EnableImpersonation() bool {
	return i.Impersonation
}

// Block returns the instance's block flag, or nil when none was loaded.
func (i *authzInstance) Block() *bool {
	return i.IsBlocked
}

// AuditLogRetention returns the configured audit log retention, or nil when none
// was loaded.
func (i *authzInstance) AuditLogRetention() *time.Duration {
	return i.LogRetention
}

// Features returns the feature set loaded for the instance.
func (i *authzInstance) Features() feature.Features {
	return i.Feature
}

// errPublicDomain is the fmt format checkDomain uses to describe an untrusted public
// domain. It is a constant, so the format string cannot be reassigned at runtime.
const errPublicDomain = "public domain %q not trusted"

// checkDomain validates publicDomain against the instance. An empty public domain,
// or one equal to instanceDomain, is always accepted. Any other public domain must
// be listed in TrustedDomains (exact, case-sensitive match); otherwise a NotFound
// error is returned.
func (i *authzInstance) checkDomain(instanceDomain, publicDomain string) error {
	needsTrust := publicDomain != "" && publicDomain != instanceDomain
	if needsTrust && !slices.Contains(i.TrustedDomains, publicDomain) {
		return zerrors.ThrowNotFound(fmt.Errorf(errPublicDomain, publicDomain), "QUERY-IuGh1", "Errors.IAM.NotFound")
	}
	return nil
}

// Keys implements [cache.Entry]. For instanceIndexByID it returns the instance ID.
// For instanceIndexByHost it returns the instance's external domains. Any other
// index yields no keys.
func (i *authzInstance) Keys(index instanceIndex) []string {
	switch index {
	case instanceIndexByHost:
		return i.ExternalDomains
	case instanceIndexByID:
		return []string{i.ID}
	}
	return nil
}

// scanAuthzInstance allocates an empty authzInstance and returns it together with a
// row-scan callback that fills it.
//
// The callback expects the columns in exactly the order of its Scan call, so
// instanceByDomainQuery and instanceByIDQuery must select them that way. Error
// mapping:
//   - sql.ErrNoRows becomes a NotFound error;
//   - other scan failures and undecodable feature JSON become Internal errors.
//
// Nullable log retention and block values are left nil when the database returns
// NULL.
func scanAuthzInstance() (*authzInstance, func(row *sql.Row) error) {
	target := new(authzInstance)
	scan := func(row *sql.Row) error {
		var (
			lang          string
			iframe        sql.NullBool
			impersonation sql.NullBool
			retention     database.NullDuration
			blocked       sql.NullBool
			rawFeatures   []byte
		)
		err := row.Scan(
			&target.ID,
			&target.DefaultOrgID,
			&target.IAMProjectID,
			&target.ConsoleID,
			&target.ConsoleAppID,
			&lang,
			&iframe,
			&target.CSP.AllowedOrigins,
			&impersonation,
			&retention,
			&blocked,
			&rawFeatures,
			&target.ExternalDomains,
			&target.TrustedDomains,
		)
		switch {
		case errors.Is(err, sql.ErrNoRows):
			return zerrors.ThrowNotFound(nil, "QUERY-1kIjX", "Errors.IAM.NotFound")
		case err != nil:
			return zerrors.ThrowInternal(err, "QUERY-d3fas", "Errors.Internal")
		}

		target.DefaultLang = language.Make(lang)
		target.CSP.EnableIframeEmbedding = iframe.Bool
		target.Impersonation = impersonation.Bool
		if retention.Valid {
			target.LogRetention = &retention.Duration
		}
		if blocked.Valid {
			target.IsBlocked = &blocked.Bool
		}

		// An empty features column leaves Feature at its zero value.
		if len(rawFeatures) > 0 {
			if err := json.Unmarshal(rawFeatures, &target.Feature); err != nil {
				return zerrors.ThrowInternal(err, "QUERY-Po8ki", "Errors.Internal")
			}
		}
		return nil
	}
	return target, scan
}

// registerInstanceInvalidation registers invalidation of the instance cache
// (by instanceIndexByID) on the projections below. A system feature change
// truncates the whole instance cache.
func (c *Caches) registerInstanceInvalidation() {
	// For these projections the event aggregate ID is the instance ID.
	byAggregateID := cacheInvalidationFunc(c.instance, instanceIndexByID, getAggregateID)
	projection.InstanceProjection.RegisterCacheInvalidation(byAggregateID)
	projection.InstanceDomainProjection.RegisterCacheInvalidation(byAggregateID)
	projection.InstanceFeatureProjection.RegisterCacheInvalidation(byAggregateID)
	projection.InstanceTrustedDomainProjection.RegisterCacheInvalidation(byAggregateID)
	projection.SecurityPolicyProjection.RegisterCacheInvalidation(byAggregateID)

	// Limits and restrictions have their own aggregate ID, so the affected instance
	// is taken from the resource owner instead.
	byResourceOwner := cacheInvalidationFunc(c.instance, instanceIndexByID, getResourceOwner)
	projection.LimitsProjection.RegisterCacheInvalidation(byResourceOwner)
	projection.RestrictionsProjection.RegisterCacheInvalidation(byResourceOwner)

	// A system feature update concerns every instance: drop all cached entries.
	projection.SystemFeatureProjection.RegisterCacheInvalidation(func(ctx context.Context, _ []*eventstore.Aggregate) {
		err := c.instance.Truncate(ctx)
		logging.OnError(err).Warn("cache truncate failed")
	})
}

// instanceIndex selects which set of keys authzInstance.Keys returns for the
// instance cache. The String method is generated by enumer from the trailing line
// comments (-linecomment).
type instanceIndex int

//go:generate enumer -type instanceIndex -linecomment
const (
	// Empty line comment ensures empty string for unspecified value
	instanceIndexUnspecified instanceIndex = iota //
	// instanceIndexByID keys cached instances by instance ID.
	instanceIndexByID
	// instanceIndexByHost keys cached instances by their external domains.
	instanceIndexByHost
)