package command

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/zitadel/oidc/v3/pkg/oidc"

	"github.com/zitadel/zitadel/internal/domain"
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/repository/user"
	"github.com/zitadel/zitadel/internal/zerrors"
)

func TestCommands_RevokeRefreshToken(t *testing.T) {
	type fields struct {
		eventstore *eventstore.Eventstore
	}
	type args struct {
		ctx     context.Context
		userID  string
		orgID   string
		tokenID string
	}
	type res struct {
		want *domain.ObjectDetails
		err  func(error) bool
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			"missing param, error",
			fields{
				eventstore: eventstoreExpect(t),
			},
			args{},
			res{
				err: zerrors.IsErrorInvalidArgument,
			},
		},
		{
			"token not active, error",
			fields{
				eventstore: eventstoreExpect(t,
					expectFilter(),
				),
			},
			args{
				context.Background(),
				"userID",
				"orgID",
				"tokenID",
			},
			res{
				err: zerrors.IsNotFound,
			},
		},
		{
			"push failed, error",
			fields{
				eventstore: eventstoreExpect(t,
					expectFilter(
						eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID",
							"clientID",
							"agentID",
							"de",
							[]string{"clientID1"},
							[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
							[]string{"password"},
							time.Now(),
							1*time.Hour,
							10*time.Hour,
							nil,
						)),
					),
					expectPushFailed(zerrors.ThrowInternal(nil, "ERROR", "internal"),
						user.NewHumanRefreshTokenRemovedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID",
						),
					),
				),
			},
			args{
				context.Background(),
				"userID",
				"orgID",
				"tokenID",
			},
			res{
				err: zerrors.IsInternal,
			},
		},
		{
			"revoke, ok",
			fields{
				eventstore: eventstoreExpect(t,
					expectFilter(
						eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID",
							"clientID",
							"agentID",
							"de",
							[]string{"clientID1"},
							[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
							[]string{"password"},
							time.Now(),
							1*time.Hour,
							10*time.Hour,
							nil,
						)),
					),
					expectPush(
						user.NewHumanRefreshTokenRemovedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID",
						),
					),
				),
			},
			args{
				context.Background(),
				"userID",
				"orgID",
				"tokenID",
			},
			res{
				want: &domain.ObjectDetails{
					ResourceOwner: "orgID",
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Commands{
				eventstore: tt.fields.eventstore,
			}
			got, err := c.RevokeRefreshToken(tt.args.ctx, tt.args.userID, tt.args.orgID, tt.args.tokenID)
			if tt.res.err == nil {
				assert.NoError(t, err)
			}
			if tt.res.err != nil && !tt.res.err(err) {
				t.Errorf("got wrong err: %v ", err)
			}
			if tt.res.err == nil {
				assertObjectDetails(t, tt.res.want, got)
			}
		})
	}
}

func TestCommands_RevokeRefreshTokens(t *testing.T) {
	type fields struct {
		eventstore *eventstore.Eventstore
	}
	type args struct {
		ctx      context.Context
		userID   string
		orgID    string
		tokenIDs []string
	}
	type res struct {
		err func(error) bool
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			"missing tokenIDs, error",
			fields{
				eventstore: eventstoreExpect(t),
			},
			args{
				context.Background(),
				"userID",
				"orgID",
				nil,
			},
			res{
				err: zerrors.IsErrorInvalidArgument,
			},
		},
		{
			"one token not active, error",
			fields{
				eventstore: eventstoreExpect(t,
					expectFilter(
						eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID",
							"clientID",
							"agentID",
							"de",
							[]string{"clientID1"},
							[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
							[]string{"password"},
							time.Now(),
							1*time.Hour,
							10*time.Hour,
							nil,
						)),
					),
					expectFilter(),
				),
			},
			args{
				context.Background(),
				"userID",
				"orgID",
				[]string{"tokenID", "tokenID2"},
			},
			res{
				err: zerrors.IsNotFound,
			},
		},
		{
			"push failed, error",
			fields{
				eventstore: eventstoreExpect(t,
					expectFilter(
						eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID",
							"clientID",
							"agentID",
							"de",
							[]string{"clientID"},
							[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
							[]string{"password"},
							time.Now(),
							1*time.Hour,
							10*time.Hour,
							nil,
						)),
					),
					expectFilter(
						eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID2",
							"clientID2",
							"agentID",
							"de",
							[]string{"clientID2"},
							[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
							[]string{"password"},
							time.Now(),
							1*time.Hour,
							10*time.Hour,
							nil,
						)),
					),
					expectPushFailed(zerrors.ThrowInternal(nil, "ERROR", "internal"),
						user.NewHumanRefreshTokenRemovedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID",
						),
						user.NewHumanRefreshTokenRemovedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID2",
						),
					),
				),
			},
			args{
				context.Background(),
				"userID",
				"orgID",
				[]string{"tokenID", "tokenID2"},
			},
			res{
				err: zerrors.IsInternal,
			},
		},
		{
			"revoke, ok",
			fields{
				eventstore: eventstoreExpect(t,
					expectFilter(
						eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID",
							"clientID",
							"agentID",
							"de",
							[]string{"clientID1"},
							[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
							[]string{"password"},
							time.Now(),
							1*time.Hour,
							10*time.Hour,
							nil,
						)),
					),
					expectFilter(
						eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID2",
							"clientID2",
							"agentID",
							"de",
							[]string{"clientID2"},
							[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
							[]string{"password"},
							time.Now(),
							1*time.Hour,
							10*time.Hour,
							nil,
						)),
					),
					expectPush(
						user.NewHumanRefreshTokenRemovedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID",
						),
						user.NewHumanRefreshTokenRemovedEvent(
							context.Background(),
							&user.NewAggregate("userID", "orgID").Aggregate,
							"tokenID2",
						),
					),
				),
			},
			args{
				context.Background(),
				"userID",
				"orgID",
				[]string{"tokenID", "tokenID2"},
			},
			res{
				err: nil,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Commands{
				eventstore: tt.fields.eventstore,
			}
			err := c.RevokeRefreshTokens(tt.args.ctx, tt.args.userID, tt.args.orgID, tt.args.tokenIDs)
			if tt.res.err == nil {
				assert.NoError(t, err)
			}
			if tt.res.err != nil && !tt.res.err(err) {
				t.Errorf("got wrong err: %v ", err)
			}
		})
	}
}