package query

import (
	"context"
	"database/sql"
	_ "embed"
	"errors"
	"fmt"
	"strings"
	"time"

	sq "github.com/Masterminds/squirrel"
	"github.com/zitadel/logging"

	"github.com/zitadel/zitadel/internal/api/authz"
	"github.com/zitadel/zitadel/internal/api/call"
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/query/projection"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
	"github.com/zitadel/zitadel/internal/zerrors"
)

// State captures the progress stored for a projection in the current states table.
//
// LastRun holds the value of the last_updated column. prepareLatestState fills only LastRun,
// Position and EventCreatedAt; prepareCurrentStateQuery fills every field.
type State struct {
	LastRun time.Time

	// Position, EventCreatedAt (event_date), AggregateID, AggregateType and Sequence mirror the
	// corresponding columns; presumably they describe the last event the projection processed —
	// NOTE(review): confirm against the projection handler. NULL columns are read as zero values.
	Position       float64
	EventCreatedAt time.Time
	AggregateID    string
	AggregateType  eventstore.AggregateType
	Sequence       uint64
}

// CurrentStates is the result of SearchCurrentStates: the matching projection states plus the
// total count taken from the query's count column (SearchResponse.Count).
type CurrentStates struct {
	SearchResponse
	CurrentStates []*CurrentState
}

// CurrentState is the State of a single projection, identified by its projection_name.
type CurrentState struct {
	ProjectionName string
	State
}

// CurrentStateSearchQueries describes a SearchCurrentStates request: the generic SearchRequest
// is applied first, then every filter in Queries (see toQuery).
type CurrentStateSearchQueries struct {
	SearchRequest
	Queries []SearchQuery
}

// NewCurrentStatesInstanceIDSearchQuery builds a filter matching current states whose
// instance_id is exactly instanceID.
func NewCurrentStatesInstanceIDSearchQuery(instanceID string) (SearchQuery, error) {
	filter, err := NewTextQuery(CurrentStateColInstanceID, instanceID, TextEquals)
	return filter, err
}

// NewCurrentStatesProjectionSearchQuery builds a filter matching current states whose
// projection_name is exactly projectionName.
//
// The parameter is deliberately not called "projection" so it does not shadow the imported
// projection package.
func NewCurrentStatesProjectionSearchQuery(projectionName string) (SearchQuery, error) {
	return NewTextQuery(CurrentStateColProjectionName, projectionName, TextEquals)
}

// toQuery applies the generic search request and then each filter of q.Queries, in order, to
// the given select builder and returns the resulting builder.
func (q *CurrentStateSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	builder := q.SearchRequest.toQuery(query)
	for _, filter := range q.Queries {
		builder = filter.toQuery(builder)
	}
	return builder
}

// SearchCurrentStates returns the projection states matching queries together with the total
// count. A failure to build the SQL statement is reported as an invalid-argument error; a
// failure while executing or scanning it is reported as an internal error.
func (q *Queries) SearchCurrentStates(ctx context.Context, queries *CurrentStateSearchQueries) (currentStates *CurrentStates, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	baseQuery, scanRows := prepareCurrentStateQuery(ctx, q.client)
	stmt, args, buildErr := queries.toQuery(baseQuery).ToSql()
	if buildErr != nil {
		return nil, zerrors.ThrowInvalidArgument(buildErr, "QUERY-MmFef", "Errors.Query.InvalidRequest")
	}

	// collect stores the scanned result in the named return value.
	collect := func(rows *sql.Rows) (scanErr error) {
		currentStates, scanErr = scanRows(rows)
		return scanErr
	}
	if err = q.client.QueryContext(ctx, collect, stmt, args...); err != nil {
		return nil, zerrors.ThrowInternal(err, "QUERY-22H8f", "Errors.Internal")
	}
	return currentStates, nil
}

// latestState returns the state with the most recent event date among the given projections,
// restricted to the instance found in ctx. The scanner from prepareLatestState turns a missing
// row into a zero-value State, so "no state yet" is not an error. A failure to build the
// statement is wrapped as an internal error.
func (q *Queries) latestState(ctx context.Context, projections ...table) (state *State, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	query, scan := prepareLatestState(ctx, q.client)

	// Match any of the requested projections. The loop variable is not named "projection" to
	// avoid shadowing the imported projection package.
	or := make(sq.Or, len(projections))
	for i, p := range projections {
		or[i] = sq.Eq{CurrentStateColProjectionName.identifier(): p.name}
	}

	// Only the first (newest) row is scanned by QueryRowContext, so LIMIT 1 lets the database
	// stop after it instead of producing every matching row.
	stmt, args, err := query.
		Where(or).
		Where(sq.Eq{CurrentStateColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}).
		OrderBy(CurrentStateColEventDate.identifier() + " DESC").
		Limit(1).
		ToSql()
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "QUERY-5CfX9", "Errors.Query.SQLStatement")
	}

	err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
		state, err = scan(row)
		return err
	}, stmt, args...)

	return state, err
}

// ClearCurrentSequence resets the projection named projectionName inside a single transaction:
// it locks the projection's row(s) in the current states table (checkAndLock), truncates the
// projection's tables (tablesForReset) and resets the stored event date (reset).
//
// The transaction is committed exactly once when every step succeeds and rolled back on any
// failure, including a failed commit.
func (q *Queries) ClearCurrentSequence(ctx context.Context, projectionName string) (err error) {
	tx, err := q.client.Begin()
	if err != nil {
		return zerrors.ThrowInternal(err, "QUERY-9iOpr", "Errors.RemoveFailed")
	}
	// Only roll back here; committing is done once, explicitly, at the end of the function.
	// Committing in this defer as well would hit sql.ErrTxDone after a successful Commit and
	// turn every successful reset into an error. After a failed Commit, Rollback reports
	// sql.ErrTxDone, which is only logged.
	defer func() {
		if err == nil {
			return
		}
		rollbackErr := tx.Rollback()
		logging.OnError(rollbackErr).Debug("rollback failed")
	}()

	name, err := q.checkAndLock(tx, projectionName)
	if err != nil {
		return err
	}

	tables, err := tablesForReset(ctx, tx, name)
	if err != nil {
		return err
	}
	if err = reset(ctx, tx, tables, name); err != nil {
		return err
	}
	if err = tx.Commit(); err != nil {
		return zerrors.ThrowInternal(err, "QUERY-Sfvsc", "Errors.Internal")
	}
	return nil
}

// checkAndLock looks up projectionName in the current states table with SELECT ... FOR UPDATE
// inside tx and returns the name as stored in the database. If the statement cannot be built,
// the lookup fails (including when no row exists) or the stored name is empty, it returns an
// internal error.
func (q *Queries) checkAndLock(tx *sql.Tx, projectionName string) (name string, err error) {
	selectBuilder := sq.Select(CurrentStateColProjectionName.identifier()).
		From(currentStateTable.identifier()).
		Where(sq.Eq{CurrentStateColProjectionName.identifier(): projectionName}).
		Suffix("FOR UPDATE").
		PlaceholderFormat(sq.Dollar)

	stmt, args, err := selectBuilder.ToSql()
	if err != nil {
		return "", zerrors.ThrowInternal(err, "QUERY-UJTUy", "Errors.Internal")
	}

	var stored string
	scanErr := tx.QueryRow(stmt, args...).Scan(&stored)
	if scanErr != nil || stored == "" {
		return "", zerrors.ThrowInternal(scanErr, "QUERY-ej8fn", "Errors.ProjectionName.Invalid")
	}
	return stored, nil
}

// tablesForReset lists the tables belonging to the projection projectionName, which must have
// the form "<schema>.<table prefix>"; any other shape yields an invalid-argument error.
//
// It returns every table of the schema whose name starts with the prefix, excluding the
// bookkeeping tables locks, current_sequences, current_states, failed_events and
// failed_events2. Names are returned schema-qualified ("<schema>.<table>"). The lookup uses the
// bracketed "[show tables from ...]" data-source syntax (CockroachDB). All database errors,
// including errors raised while iterating the result, are wrapped as internal errors.
func tablesForReset(ctx context.Context, tx *sql.Tx, projectionName string) (tables []string, err error) {
	names := strings.Split(projectionName, ".")
	if len(names) != 2 {
		return nil, zerrors.ThrowInvalidArgument(nil, "QUERY-wk1jr", "Errors.InvalidArgument")
	}
	schema := names[0]
	tablePrefix := names[1]

	// NOTE(review): schema is concatenated into the statement because identifiers cannot be
	// bound as parameters; ClearCurrentSequence passes the name read back from the database by
	// checkAndLock, not raw user input. tablePrefix ends up in a LIKE pattern, so any "_" or "%"
	// it contains acts as a wildcard.
	tablesQuery, args, err := sq.Select("table_name").
		From("[show tables from " + schema + "]").
		Where(
			sq.And{
				sq.Eq{"type": "table"},
				sq.NotEq{"table_name": []string{"locks", "current_sequences", "current_states", "failed_events", "failed_events2"}},
				sq.Like{"table_name": tablePrefix + "%"},
			}).
		PlaceholderFormat(sq.Dollar).
		ToSql()
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "QUERY-ASff2", "Errors.ProjectionName.Invalid")
	}

	rows, err := tx.QueryContext(ctx, tablesQuery, args...)
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "QUERY-Dgfw", "Errors.ProjectionName.Invalid")
	}
	defer rows.Close()

	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			return nil, zerrors.ThrowInternal(err, "QUERY-ej8fn", "Errors.ProjectionName.Invalid")
		}
		tables = append(tables, schema+"."+tableName)
	}

	// rows.Next also returns false when reading a row fails; wrap that error like every other
	// database error of this function instead of returning it raw.
	if err := rows.Err(); err != nil {
		return nil, zerrors.ThrowInternal(err, "QUERY-Xk3rT", "Errors.ProjectionName.Invalid")
	}

	return tables, nil
}

// reset truncates every table in tables (TRUNCATE ... cascade) and then sets event_date to 0
// on every current states row of projectionName, regardless of instance. All statements run
// on tx, so the caller decides whether to commit or roll back.
//
// tables must contain trusted, schema-qualified identifiers: they are interpolated into the
// TRUNCATE statement because identifiers cannot be bound as parameters. Every statement uses
// ctx, so cancellation and deadlines of the caller apply.
func reset(ctx context.Context, tx *sql.Tx, tables []string, projectionName string) error {
	for _, tableName := range tables {
		if _, err := tx.ExecContext(ctx, fmt.Sprintf("TRUNCATE %s cascade", tableName)); err != nil {
			return zerrors.ThrowInternal(err, "QUERY-3n92f", "Errors.RemoveFailed")
		}
	}

	// NOTE(review): only event_date is reset here; position, sequence and the aggregate columns
	// are left untouched, and event_date is scanned as a time value elsewhere in this file.
	// Confirm that the column accepts 0 and that the projection handler re-processes events
	// when only event_date is reset.
	update, args, err := sq.Update(currentStateTable.identifier()).
		Set(CurrentStateColEventDate.name, 0).
		Where(sq.Eq{
			CurrentStateColProjectionName.name: projectionName,
		}).
		PlaceholderFormat(sq.Dollar).
		ToSql()
	if err != nil {
		return zerrors.ThrowInternal(err, "QUERY-Ff3tw", "Errors.RemoveFailed")
	}
	if _, err = tx.ExecContext(ctx, update, args...); err != nil {
		return zerrors.ThrowInternal(err, "QUERY-NFiws", "Errors.RemoveFailed")
	}
	return nil
}

// prepareLatestState returns the base SELECT of event date, position and last update over the
// current states table (with the database's time-travel suffix), plus a single-row scanner.
//
// The scanner maps NULL columns to zero values and treats sql.ErrNoRows as "no state yet",
// returning a zero-value State instead of an error. Any other scan error is wrapped as an
// internal error.
func prepareLatestState(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*State, error)) {
	columns := []string{
		CurrentStateColEventDate.identifier(),
		CurrentStateColPosition.identifier(),
		CurrentStateColLastUpdated.identifier(),
	}
	source := currentStateTable.identifier() + db.Timetravel(call.Took(ctx))
	builder := sq.Select(columns...).From(source).PlaceholderFormat(sq.Dollar)

	scanRow := func(row *sql.Row) (*State, error) {
		var eventDate, lastUpdated sql.NullTime
		var position sql.NullFloat64

		scanErr := row.Scan(&eventDate, &position, &lastUpdated)
		if scanErr != nil && !errors.Is(scanErr, sql.ErrNoRows) {
			return nil, zerrors.ThrowInternal(scanErr, "QUERY-aAZ1D", "Errors.Internal")
		}

		latest := &State{
			LastRun:        lastUpdated.Time,
			Position:       position.Float64,
			EventCreatedAt: eventDate.Time,
		}
		return latest, nil
	}

	return builder, scanRow
}

// prepareCurrentStateQuery returns the base SELECT for searching projection states (all state
// columns plus the count column, with the database's time-travel suffix) and a scanner that
// converts the result rows into CurrentStates.
//
// NULL columns are mapped to zero values. The count of the last scanned row becomes
// SearchResponse.Count; it stays 0 when there are no rows. Scan and iteration errors are
// returned as-is (the caller wraps them); a failure to close the rows is wrapped as an internal
// error.
func prepareCurrentStateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*CurrentStates, error)) {
	return sq.Select(
			CurrentStateColLastUpdated.identifier(),
			CurrentStateColEventDate.identifier(),
			CurrentStateColPosition.identifier(),
			CurrentStateColProjectionName.identifier(),
			CurrentStateColAggregateType.identifier(),
			CurrentStateColAggregateID.identifier(),
			CurrentStateColSequence.identifier(),
			countColumn.identifier()).
			From(currentStateTable.identifier() + db.Timetravel(call.Took(ctx))).
			PlaceholderFormat(sq.Dollar),
		func(rows *sql.Rows) (*CurrentStates, error) {
			// Non-nil so an empty result is an empty list rather than nil.
			states := make([]*CurrentState, 0)
			var count uint64
			for rows.Next() {
				currentState := new(CurrentState)
				var (
					lastRun         sql.NullTime
					eventDate       sql.NullTime
					currentPosition sql.NullFloat64
					aggregateType   sql.NullString
					aggregateID     sql.NullString
					sequence        sql.NullInt64
				)

				err := rows.Scan(
					&lastRun,
					&eventDate,
					&currentPosition,
					&currentState.ProjectionName,
					&aggregateType,
					&aggregateID,
					&sequence,
					&count,
				)
				if err != nil {
					return nil, err
				}
				currentState.State.EventCreatedAt = eventDate.Time
				currentState.State.LastRun = lastRun.Time
				currentState.Position = currentPosition.Float64
				currentState.AggregateType = eventstore.AggregateType(aggregateType.String)
				currentState.AggregateID = aggregateID.String
				currentState.Sequence = uint64(sequence.Int64)
				states = append(states, currentState)
			}

			// rows.Next also returns false when reading the next row fails; without this check
			// a failed read would be reported as a successful, truncated result.
			if err := rows.Err(); err != nil {
				return nil, err
			}

			if err := rows.Close(); err != nil {
				return nil, zerrors.ThrowInternal(err, "QUERY-jbJ77", "Errors.Query.CloseRows")
			}

			return &CurrentStates{
				CurrentStates: states,
				SearchResponse: SearchResponse{
					Count: count,
				},
			}, nil
		}
}

// Table and column descriptors for the projection current states table.
var (
	// currentStateTable is the current states table (projection.CurrentStateTable); its rows
	// are scoped by the instance_id column.
	currentStateTable = table{
		instanceIDCol: "instance_id",
		name:          projection.CurrentStateTable,
	}

	// CurrentStateColEventDate is the event_date column of currentStateTable.
	CurrentStateColEventDate = Column{
		table: currentStateTable,
		name:  "event_date",
	}
	// CurrentStateColPosition is the position column of currentStateTable.
	CurrentStateColPosition = Column{
		table: currentStateTable,
		name:  "position",
	}
	// CurrentStateColAggregateType is the aggregate_type column of currentStateTable.
	CurrentStateColAggregateType = Column{
		table: currentStateTable,
		name:  "aggregate_type",
	}
	// CurrentStateColAggregateID is the aggregate_id column of currentStateTable.
	CurrentStateColAggregateID = Column{
		table: currentStateTable,
		name:  "aggregate_id",
	}
	// CurrentStateColSequence is the sequence column of currentStateTable.
	CurrentStateColSequence = Column{
		table: currentStateTable,
		name:  "sequence",
	}
	// CurrentStateColLastUpdated is the last_updated column of currentStateTable.
	CurrentStateColLastUpdated = Column{
		table: currentStateTable,
		name:  "last_updated",
	}
	// CurrentStateColProjectionName is the projection_name column of currentStateTable.
	CurrentStateColProjectionName = Column{
		table: currentStateTable,
		name:  "projection_name",
	}
	// CurrentStateColInstanceID is the instance_id column of currentStateTable.
	CurrentStateColInstanceID = Column{
		table: currentStateTable,
		name:  "instance_id",
	}
)

// Table and column descriptors for the projection locks table.
var (
	// locksTable is the locks table (projection.LocksTable); its rows are scoped by the
	// instance_id column.
	locksTable = table{
		instanceIDCol: "instance_id",
		name:          projection.LocksTable,
	}

	// LocksColLockerID is the locker_id column of locksTable.
	LocksColLockerID = Column{
		table: locksTable,
		name:  "locker_id",
	}
	// LocksColUntil is the locked_until column of locksTable.
	LocksColUntil = Column{
		table: locksTable,
		name:  "locked_until",
	}
	// LocksColProjectionName is the projection_name column of locksTable.
	LocksColProjectionName = Column{
		table: locksTable,
		name:  "projection_name",
	}
)