package command

import (
	"time"

	"github.com/zitadel/zitadel/internal/domain"
	caos_errs "github.com/zitadel/zitadel/internal/errors"
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/repository/oidcsession"
)

// OIDCSessionWriteModel is the write model of an OIDC session aggregate.
// Its exported fields are filled by Reduce from oidcsession events:
//   - UserID, SessionID, ClientID, Audience, Scope, AuthMethods and AuthTime
//     are copied from the AddedEvent, which also sets State to active.
//   - AccessTokenID and AccessTokenExpiration follow the access token added
//     and revoked events.
//   - RefreshTokenID, RefreshTokenExpiration and RefreshTokenIdleExpiration
//     follow the refresh token added, renewed and revoked events.
//
// RefreshToken is not written by any reducer in this file.
type OIDCSessionWriteModel struct {
	eventstore.WriteModel

	UserID                     string
	SessionID                  string
	ClientID                   string
	Audience                   []string
	Scope                      []string
	AuthMethods                []domain.UserAuthMethodType
	AuthTime                   time.Time
	State                      domain.OIDCSessionState
	AccessTokenID              string
	AccessTokenCreation        time.Time
	AccessTokenExpiration      time.Time
	RefreshTokenID             string
	RefreshToken               string
	RefreshTokenExpiration     time.Time
	RefreshTokenIdleExpiration time.Time

	// aggregate is the oidcsession aggregate built from the model's ID and
	// resource owner. reduceAdded rebuilds it from the event's resource owner
	// when the model was created without one.
	aggregate *eventstore.Aggregate
}

// NewOIDCSessionWriteModel returns a write model for the OIDC session with the
// given aggregate id. resourceOwner may be empty; it is used both for the
// embedded WriteModel and for the oidcsession aggregate kept on the model.
func NewOIDCSessionWriteModel(id string, resourceOwner string) *OIDCSessionWriteModel {
	sessionAggregate := oidcsession.NewAggregate(id, resourceOwner)

	model := new(OIDCSessionWriteModel)
	model.WriteModel = eventstore.WriteModel{
		AggregateID:   id,
		ResourceOwner: resourceOwner,
	}
	model.aggregate = &sessionAggregate.Aggregate
	return model
}

// Reduce walks wm.Events in order and hands each event to the reducer that
// matches its concrete type; event types not listed in the switch are skipped.
// Afterwards it returns the result of the embedded WriteModel's Reduce.
func (wm *OIDCSessionWriteModel) Reduce() error {
	for i := range wm.Events {
		switch typed := wm.Events[i].(type) {
		case *oidcsession.AddedEvent:
			wm.reduceAdded(typed)
		case *oidcsession.AccessTokenAddedEvent:
			wm.reduceAccessTokenAdded(typed)
		case *oidcsession.AccessTokenRevokedEvent:
			wm.reduceAccessTokenRevoked(typed)
		case *oidcsession.RefreshTokenAddedEvent:
			wm.reduceRefreshTokenAdded(typed)
		case *oidcsession.RefreshTokenRenewedEvent:
			wm.reduceRefreshTokenRenewed(typed)
		case *oidcsession.RefreshTokenRevokedEvent:
			wm.reduceRefreshTokenRevoked(typed)
		}
	}
	return wm.WriteModel.Reduce()
}

// Query builds the event search for this OIDC session: events of the
// oidcsession aggregate type with the model's AggregateID, limited to the
// event types that Reduce handles. When the model has a resource owner, the
// search is additionally restricted to it.
func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder {
	query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
		AddQuery().
		AggregateTypes(oidcsession.AggregateType).
		AggregateIDs(wm.AggregateID).
		EventTypes(
			oidcsession.AddedType,
			oidcsession.AccessTokenAddedType,
			// Without this type access token revocations are never loaded, so
			// reduceAccessTokenRevoked cannot run and a revoked token would
			// still pass CheckAccessToken.
			oidcsession.AccessTokenRevokedType,
			oidcsession.RefreshTokenAddedType,
			oidcsession.RefreshTokenRenewedType,
			oidcsession.RefreshTokenRevokedType,
		).
		Builder()

	// The model may have been created without a resource owner; only filter
	// on it when it is known.
	if wm.ResourceOwner != "" {
		query.ResourceOwner(wm.ResourceOwner)
	}
	return query
}

// reduceAdded copies the session data of the AddedEvent onto the model and
// marks the session as active.
func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
	wm.State = domain.OIDCSessionStateActive
	wm.UserID, wm.SessionID, wm.ClientID = e.UserID, e.SessionID, e.ClientID
	wm.Audience, wm.Scope = e.Audience, e.Scope
	wm.AuthMethods, wm.AuthTime = e.AuthMethods, e.AuthTime

	if wm.ResourceOwner != "" {
		return
	}
	// The model was created without a resource owner, so rebuild the
	// aggregate with the owner carried by the event.
	owner := e.Aggregate().ResourceOwner
	wm.aggregate = &oidcsession.NewAggregate(wm.AggregateID, owner).Aggregate
}

// reduceAccessTokenAdded records the access token issued by the event: its ID,
// its creation time (the event's creation date) and its expiration (creation
// date plus the event's Lifetime).
func (wm *OIDCSessionWriteModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
	createdAt := e.CreationDate()
	wm.AccessTokenID = e.ID
	// AccessTokenCreation was previously never assigned by any reducer in this
	// file and therefore stayed at the zero time.
	wm.AccessTokenCreation = createdAt
	wm.AccessTokenExpiration = createdAt.Add(e.Lifetime)
}

// reduceAccessTokenRevoked clears the access token ID and moves the access
// token expiration to the creation date of the revocation event.
func (wm *OIDCSessionWriteModel) reduceAccessTokenRevoked(e *oidcsession.AccessTokenRevokedEvent) {
	revokedAt := e.CreationDate()
	wm.AccessTokenID, wm.AccessTokenExpiration = "", revokedAt
}

// reduceRefreshTokenAdded stores the refresh token ID of the event and derives
// both expirations from the event's creation date: the absolute one from
// Lifetime and the idle one from IdleLifetime.
func (wm *OIDCSessionWriteModel) reduceRefreshTokenAdded(e *oidcsession.RefreshTokenAddedEvent) {
	issuedAt := e.CreationDate()
	wm.RefreshTokenIdleExpiration = issuedAt.Add(e.IdleLifetime)
	wm.RefreshTokenExpiration = issuedAt.Add(e.Lifetime)
	wm.RefreshTokenID = e.ID
}

// reduceRefreshTokenRenewed switches the model to the renewed refresh token ID
// and restarts the idle expiration from the event's creation date. The
// absolute refresh token expiration is left untouched.
func (wm *OIDCSessionWriteModel) reduceRefreshTokenRenewed(e *oidcsession.RefreshTokenRenewedEvent) {
	renewedAt := e.CreationDate()
	wm.RefreshTokenID, wm.RefreshTokenIdleExpiration = e.ID, renewedAt.Add(e.IdleLifetime)
}

// reduceRefreshTokenRevoked clears both the refresh token ID and the access
// token ID and sets all three expirations (refresh, refresh idle, access) to
// the creation date of the revocation event.
func (wm *OIDCSessionWriteModel) reduceRefreshTokenRevoked(e *oidcsession.RefreshTokenRevokedEvent) {
	revokedAt := e.CreationDate()

	wm.RefreshTokenID, wm.AccessTokenID = "", ""
	wm.RefreshTokenExpiration = revokedAt
	wm.RefreshTokenIdleExpiration = revokedAt
	wm.AccessTokenExpiration = revokedAt
}

// CheckRefreshToken validates refreshTokenID against the reduced session. It
// returns a precondition-failed error when the session is not active, when
// the ID differs from the model's current refresh token ID, or when either
// the absolute or the idle expiration lies before the current time. It
// returns nil otherwise.
func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error {
	switch {
	case wm.State != domain.OIDCSessionStateActive:
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid")
	case wm.RefreshTokenID != refreshTokenID:
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid")
	}

	// Use one timestamp for both expiration comparisons.
	checkedAt := time.Now()
	expired := checkedAt.After(wm.RefreshTokenExpiration)
	idleExpired := checkedAt.After(wm.RefreshTokenIdleExpiration)
	if expired || idleExpired {
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid")
	}
	return nil
}

// CheckAccessToken validates accessTokenID against the reduced session. It
// returns a precondition-failed error when the session is not active, when
// the ID differs from the model's current access token ID, or when the access
// token expiration lies before the current time. It returns nil otherwise.
func (wm *OIDCSessionWriteModel) CheckAccessToken(accessTokenID string) error {
	switch {
	case wm.State != domain.OIDCSessionStateActive:
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-KL2pk", "Errors.OIDCSession.Token.Invalid")
	case wm.AccessTokenID != accessTokenID:
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JLKW2", "Errors.OIDCSession.Token.Invalid")
	case time.Now().After(wm.AccessTokenExpiration):
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3j3md", "Errors.OIDCSession.Token.Invalid")
	}
	return nil
}

// CheckClient returns nil when clientID is one of the entries of the session's
// Audience and a precondition-failed error otherwise (including when Audience
// is empty).
func (wm *OIDCSessionWriteModel) CheckClient(clientID string) error {
	for i := range wm.Audience {
		if wm.Audience[i] != clientID {
			continue
		}
		return nil
	}
	return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-SKjl3", "Errors.OIDCSession.InvalidClient")
}

// OIDCRefreshTokenID returns refreshTokenID prefixed with the model's
// AggregateID and TokenDelimiter.
func (wm *OIDCSessionWriteModel) OIDCRefreshTokenID(refreshTokenID string) string {
	sessionPrefix := wm.AggregateID + TokenDelimiter
	return sessionPrefix + refreshTokenID
}