package query

import (
	"context"
	"database/sql"
	errs "errors"

	sq "github.com/Masterminds/squirrel"

	"github.com/zitadel/zitadel/internal/api/authz"
	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/domain"
	"github.com/zitadel/zitadel/internal/errors"
	"github.com/zitadel/zitadel/internal/query/projection"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
)

var (
	// deviceAuthTable describes the device authorization projection table,
	// keyed per instance via DeviceAuthColumnInstanceID.
	deviceAuthTable = table{
		instanceIDCol: projection.DeviceAuthColumnInstanceID,
		name:          projection.DeviceAuthProjectionTable,
	}
	// DeviceAuthColumnID is the id column of the device auth projection.
	DeviceAuthColumnID = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnID,
	}
	// DeviceAuthColumnClientID is the client id column of the device auth projection.
	DeviceAuthColumnClientID = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnClientID,
	}
	// DeviceAuthColumnDeviceCode is the device code column of the device auth projection.
	DeviceAuthColumnDeviceCode = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnDeviceCode,
	}
	// DeviceAuthColumnUserCode is the user code column of the device auth projection.
	DeviceAuthColumnUserCode = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnUserCode,
	}
	// DeviceAuthColumnExpires is the expiration column of the device auth projection.
	DeviceAuthColumnExpires = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnExpires,
	}
	// DeviceAuthColumnScopes is the scopes column of the device auth projection.
	DeviceAuthColumnScopes = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnScopes,
	}
	// DeviceAuthColumnState is the state column of the device auth projection.
	DeviceAuthColumnState = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnState,
	}
	// DeviceAuthColumnSubject is the subject column of the device auth projection.
	DeviceAuthColumnSubject = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnSubject,
	}
	// DeviceAuthColumnCreationDate is the creation date column of the device auth projection.
	DeviceAuthColumnCreationDate = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnCreationDate,
	}
	// DeviceAuthColumnChangeDate is the change date column of the device auth projection.
	DeviceAuthColumnChangeDate = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnChangeDate,
	}
	// DeviceAuthColumnSequence is the sequence column of the device auth projection.
	DeviceAuthColumnSequence = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnSequence,
	}
	// DeviceAuthColumnInstanceID is the instance id column of the device auth projection.
	DeviceAuthColumnInstanceID = Column{
		table: deviceAuthTable,
		name:  projection.DeviceAuthColumnInstanceID,
	}
)

// DeviceAuthByDeviceCode returns the device authorization that belongs to the
// instance found in ctx and matches both clientID and deviceCode.
//
// Only the columns selected by prepareDeviceAuthQuery are populated (id as
// AggregateID, client id, scopes, expiration, state, subject). The lookup does
// not filter on state or expiration; callers must check those themselves.
// A missing row yields a not-found error; a failure to build the SQL yields an
// internal error.
func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, clientID, deviceCode string) (_ *domain.DeviceAuth, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	builder, scanRow := prepareDeviceAuthQuery(ctx, q.client)
	// err here is the named result (same function block), so the deferred
	// span observes any failure.
	sqlStmt, sqlArgs, err := builder.
		Where(sq.Eq{
			DeviceAuthColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
			DeviceAuthColumnClientID.identifier():   clientID,
			DeviceAuthColumnDeviceCode.identifier(): deviceCode,
		}).
		ToSql()
	if err != nil {
		return nil, errors.ThrowInternal(err, "QUERY-uk1Oh", "Errors.Query.SQLStatement")
	}

	row := q.client.QueryRowContext(ctx, sqlStmt, sqlArgs...)
	return scanRow(row)
}

// DeviceAuthByUserCode returns the device authorization that belongs to the
// instance found in ctx and whose user code equals userCode.
//
// Unlike DeviceAuthByDeviceCode, the lookup is not restricted to a client ID.
// It also does not filter on state or expiration; callers must check those
// themselves. A missing row yields a not-found error; a failure to build the
// SQL yields an internal error.
func (q *Queries) DeviceAuthByUserCode(ctx context.Context, userCode string) (_ *domain.DeviceAuth, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	builder, scanRow := prepareDeviceAuthQuery(ctx, q.client)
	// err here is the named result (same function block), so the deferred
	// span observes any failure.
	sqlStmt, sqlArgs, err := builder.
		Where(sq.Eq{
			DeviceAuthColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
			DeviceAuthColumnUserCode.identifier():   userCode,
		}).
		ToSql()
	if err != nil {
		return nil, errors.ThrowInternal(err, "QUERY-Axu7l", "Errors.Query.SQLStatement")
	}

	row := q.client.QueryRowContext(ctx, sqlStmt, sqlArgs...)
	return scanRow(row)
}

// deviceAuthSelectColumns lists the columns read by prepareDeviceAuthQuery.
// The order is significant: it must match the order of the destinations
// passed to row.Scan in the scan function built by prepareDeviceAuthQuery.
var deviceAuthSelectColumns = []string{
	DeviceAuthColumnID.identifier(),
	DeviceAuthColumnClientID.identifier(),
	DeviceAuthColumnScopes.identifier(),
	DeviceAuthColumnExpires.identifier(),
	DeviceAuthColumnState.identifier(),
	DeviceAuthColumnSubject.identifier(),
}

// prepareDeviceAuthQuery returns a SELECT of deviceAuthSelectColumns from the
// device auth projection table, using $n placeholders, along with a function
// that scans a single result row into a domain.DeviceAuth.
//
// The builder has no WHERE clause, so callers add their own filters. ctx and db
// are not used in this function. NOTE(review): presumably they are kept so the
// signature matches the other prepare*Query helpers; confirm.
//
// The scan function maps sql.ErrNoRows to a not-found error
// ("Errors.DeviceAuth.NotExisting") and any other scan failure to an internal
// error.
func prepareDeviceAuthQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*domain.DeviceAuth, error)) {
	builder := sq.Select(deviceAuthSelectColumns...).
		From(deviceAuthTable.identifier()).
		PlaceholderFormat(sq.Dollar)

	scan := func(row *sql.Row) (*domain.DeviceAuth, error) {
		var (
			result = new(domain.DeviceAuth)
			scopes database.StringArray
		)
		// Destinations mirror deviceAuthSelectColumns; the id column lands in AggregateID.
		if err := row.Scan(
			&result.AggregateID,
			&result.ClientID,
			&scopes,
			&result.Expires,
			&result.State,
			&result.Subject,
		); err != nil {
			if errs.Is(err, sql.ErrNoRows) {
				return nil, errors.ThrowNotFound(err, "QUERY-Sah9a", "Errors.DeviceAuth.NotExisting")
			}
			return nil, errors.ThrowInternal(err, "QUERY-Voo3o", "Errors.Internal")
		}

		result.Scopes = scopes
		return result, nil
	}

	return builder, scan
}