package database

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"regexp"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/assert"

	"github.com/zitadel/zitadel/internal/crypto"
	z_db "github.com/zitadel/zitadel/internal/database"
	db_mock "github.com/zitadel/zitadel/internal/database/mock"
	"github.com/zitadel/zitadel/internal/zerrors"
)

// Test_database_ReadKeys runs Database.ReadKeys against a sqlmock-backed
// client and verifies, per case, either the returned keys or the returned
// error, and that every registered mock expectation was consumed.
func Test_database_ReadKeys(t *testing.T) {
	// Every case issues the same statement, so it is spelled out once.
	const stmt = "SELECT id, key FROM system.encryption_keys"

	// passthrough is a decrypt stub that hands the stored value back unchanged.
	passthrough := func(encryptedKey, masterKey string) (key string, err error) {
		return encryptedKey, nil
	}

	type fields struct {
		client    db
		masterKey string
		decrypt   func(encryptedKey, masterKey string) (key string, err error)
	}
	type res struct {
		keys crypto.Keys
		err  func(error) bool
	}
	tests := []struct {
		name   string
		fields fields
		res    res
	}{
		{
			name: "query fails, error",
			fields: fields{
				// masterKey and decrypt stay at their zero values for this case.
				client: dbMock(t, expectQueryErr(stmt, sql.ErrConnDone)),
			},
			res: res{
				err: func(err error) bool {
					return errors.Is(err, sql.ErrConnDone)
				},
			},
		},
		{
			name: "decryption error",
			fields: fields{
				client: dbMock(t, expectQueryScanErr(
					stmt,
					[]string{"id", "key"},
					[][]driver.Value{
						{"id1", "key1"},
					},
				)),
				masterKey: "wrong key",
				decrypt: func(encryptedKey, masterKey string) (key string, err error) {
					return "", fmt.Errorf("wrong masterkey")
				},
			},
			res: res{
				err: zerrors.IsInternal,
			},
		},
		{
			name: "single key ok",
			fields: fields{
				client: dbMock(t, expectQuery(
					stmt,
					[]string{"id", "key"},
					[][]driver.Value{
						{"id1", "key1"},
					},
				)),
				masterKey: "masterKey",
				decrypt:   passthrough,
			},
			res: res{
				keys: crypto.Keys(map[string]string{"id1": "key1"}),
			},
		},
		{
			name: "multiple keys ok",
			fields: fields{
				client: dbMock(t, expectQuery(
					stmt,
					[]string{"id", "key"},
					[][]driver.Value{
						{"id1", "key1"},
						{"id2", "key2"},
					},
				)),
				masterKey: "masterKey",
				decrypt:   passthrough,
			},
			res: res{
				keys: crypto.Keys(map[string]string{"id1": "key1", "id2": "key2"}),
			},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			d := &Database{
				client:    tc.fields.client.db,
				masterKey: tc.fields.masterKey,
				decrypt:   tc.fields.decrypt,
			}
			keys, err := d.ReadKeys()
			if check := tc.res.err; check != nil {
				// Failure cases only validate the error; the result is not inspected.
				if !check(err) {
					t.Errorf("got wrong err: %v", err)
				}
			} else {
				assert.NoError(t, err)
				assert.Equal(t, tc.res.keys, keys)
			}
			if unmet := tc.fields.client.mock.ExpectationsWereMet(); unmet != nil {
				t.Error(unmet)
			}
		})
	}
}

// Test_database_ReadKey runs Database.ReadKey for a single id against a
// sqlmock-backed client and verifies, per case, either the returned key or
// the returned error, and that every registered mock expectation was consumed.
func Test_database_ReadKey(t *testing.T) {
	// Every case issues the same statement, so it is spelled out once.
	const stmt = "SELECT key FROM system.encryption_keys WHERE id = $1"

	// passthrough is a decrypt stub that hands the stored value back unchanged.
	passthrough := func(encryptedKey, masterKey string) (key string, err error) {
		return encryptedKey, nil
	}

	type fields struct {
		client    db
		masterKey string
		decrypt   func(encryptedKey, masterKey string) (key string, err error)
	}
	type args struct {
		id string
	}
	type res struct {
		key *crypto.Key
		err func(error) bool
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			name: "query fails, error",
			fields: fields{
				// masterKey and decrypt stay at their zero values for this case.
				client: dbMock(t, expectQueryErr(stmt, sql.ErrConnDone)),
			},
			args: args{id: "id1"},
			res: res{
				err: func(err error) bool {
					return errors.Is(err, sql.ErrConnDone)
				},
			},
		},
		{
			name: "key not found err",
			fields: fields{
				// No columns and no rows: the query yields an empty result set.
				client:    dbMock(t, expectQueryScanErr(stmt, nil, nil, "id1")),
				masterKey: "masterKey",
				decrypt:   passthrough,
			},
			args: args{id: "id1"},
			res: res{
				err: zerrors.IsInternal,
			},
		},
		{
			name: "decryption error",
			fields: fields{
				client: dbMock(t, expectQueryScanErr(
					stmt,
					[]string{"key"},
					[][]driver.Value{
						{"key1"},
					},
					"id1",
				)),
				masterKey: "wrong key",
				decrypt: func(encryptedKey, masterKey string) (key string, err error) {
					return "", fmt.Errorf("wrong masterkey")
				},
			},
			args: args{id: "id1"},
			res: res{
				err: zerrors.IsInternal,
			},
		},
		{
			name: "key ok",
			fields: fields{
				client: dbMock(t, expectQuery(
					stmt,
					[]string{"key"},
					[][]driver.Value{
						{"key1"},
					},
					"id1",
				)),
				masterKey: "masterKey",
				decrypt:   passthrough,
			},
			args: args{id: "id1"},
			res: res{
				key: &crypto.Key{
					ID:    "id1",
					Value: "key1",
				},
			},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			d := &Database{
				client:    tc.fields.client.db,
				masterKey: tc.fields.masterKey,
				decrypt:   tc.fields.decrypt,
			}
			key, err := d.ReadKey(tc.args.id)
			switch check := tc.res.err; {
			case check == nil:
				assert.NoError(t, err)
				assert.Equal(t, tc.res.key, key)
			case !check(err):
				// Failure cases only validate the error; the result is not inspected.
				t.Errorf("got wrong err: %v", err)
			}
			if unmet := tc.fields.client.mock.ExpectationsWereMet(); unmet != nil {
				t.Error(unmet)
			}
		})
	}
}

// Test_database_CreateKeys runs Database.CreateKeys against a sqlmock-backed
// client and verifies, per case, the returned error and that every registered
// mock expectation (begin, exec, commit/rollback) was consumed.
func Test_database_CreateKeys(t *testing.T) {
	// Every case that reaches the database registers the same statement prefix.
	const stmt = "INSERT INTO system.encryption_keys (id,key) VALUES ($1,$2)"

	// passthrough is an encrypt stub that returns the plain key unchanged.
	passthrough := func(key, masterKey string) (encryptedKey string, err error) {
		return key, nil
	}

	type fields struct {
		client    db
		masterKey string
		encrypt   func(key, masterKey string) (encryptedKey string, err error)
	}
	type args struct {
		keys []*crypto.Key
	}
	type res struct {
		err func(error) bool
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			name: "encryption fails, error",
			fields: fields{
				// No expectations are registered for this case.
				client: dbMock(t),
				encrypt: func(key, masterKey string) (encryptedKey string, err error) {
					return "", fmt.Errorf("encryption failed")
				},
			},
			args: args{
				keys: []*crypto.Key{
					{ID: "id1", Value: "key1"},
				},
			},
			res: res{
				err: zerrors.IsInternal,
			},
		},
		{
			name: "insert fails, error",
			fields: fields{
				client: dbMock(t,
					expectBegin(nil),
					expectExec(stmt, sql.ErrTxDone),
					expectRollback(nil),
				),
				masterKey: "masterkey",
				encrypt:   passthrough,
			},
			args: args{
				keys: []*crypto.Key{
					{ID: "id1", Value: "key1"},
				},
			},
			res: res{
				err: func(err error) bool {
					return errors.Is(err, sql.ErrTxDone)
				},
			},
		},
		{
			name: "single insert ok",
			fields: fields{
				client: dbMock(t,
					expectBegin(nil),
					expectExec(stmt, nil, "id1", "key1"),
					expectCommit(nil),
				),
				masterKey: "masterkey",
				encrypt:   passthrough,
			},
			args: args{
				keys: []*crypto.Key{
					{ID: "id1", Value: "key1"},
				},
			},
		},
		{
			name: "multiple insert ok",
			fields: fields{
				client: dbMock(t,
					expectBegin(nil),
					expectExec(stmt, nil, "id1", "key1", "id2", "key2"),
					expectCommit(nil),
				),
				masterKey: "masterkey",
				encrypt:   passthrough,
			},
			args: args{
				keys: []*crypto.Key{
					{ID: "id1", Value: "key1"},
					{ID: "id2", Value: "key2"},
				},
			},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			d := &Database{
				client:    tc.fields.client.db,
				masterKey: tc.fields.masterKey,
				encrypt:   tc.fields.encrypt,
			}
			err := d.CreateKeys(context.Background(), tc.args.keys...)
			if check := tc.res.err; check == nil {
				assert.NoError(t, err)
			} else if !check(err) {
				t.Errorf("got wrong err: %v", err)
			}
			if unmet := tc.fields.client.mock.ExpectationsWereMet(); unmet != nil {
				t.Error(unmet)
			}
		})
	}
}

// Test_checkMasterKeyLength feeds checkMasterKeyLength an empty master key and
// a 32-character one, expecting an error accepted by zerrors.IsInternal for
// the former and no error for the latter.
func Test_checkMasterKeyLength(t *testing.T) {
	type args struct {
		masterKey string
	}
	tests := []struct {
		name string
		args args
		err  func(error) bool
	}{
		{
			name: "invalid length",
			args: args{masterKey: ""},
			err:  zerrors.IsInternal,
		},
		{
			name: "valid length",
			args: args{masterKey: "!themasterkeywhichis32byteslong!"},
			err:  nil,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			err := checkMasterKeyLength(tc.args.masterKey)
			if tc.err == nil {
				assert.NoError(t, err)
				return
			}
			if !tc.err(err) {
				t.Errorf("got wrong err: %v", err)
			}
		})
	}
}

// db bundles a mocked database handle with the sqlmock controller that
// drives it, so tests can hand the handle to the code under test and later
// verify the registered expectations.
type db struct {
	// mock is the sqlmock controller on which expectations are registered and checked.
	mock sqlmock.Sqlmock
	// db wraps the mocked *sql.DB connection in the project's database type.
	db   *z_db.DB
}

// dbMock creates a sqlmock connection that uses db_mock.TypeConverter as its
// value converter, applies the given expectations to it in order, and returns
// the connection together with its mock controller. The test is aborted via
// t.Fatalf if the mock cannot be created.
func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db {
	t.Helper()

	converter := sqlmock.ValueConverterOption(new(db_mock.TypeConverter))
	conn, controller, err := sqlmock.New(converter)
	if err != nil {
		t.Fatalf("unable to create sql mock: %v", err)
	}

	// Expectations are registered in the order supplied by the caller.
	for i := range expectations {
		expectations[i](controller)
	}

	return db{
		db:   &z_db.DB{DB: conn},
		mock: controller,
	}
}

// expectQueryErr returns an expectation that registers a query matching the
// given statement literally (it is regexp-escaped) with the given args, and
// makes that query return err.
func expectQueryErr(query string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) {
	// Escape once up front; the statement is matched as literal text.
	pattern := regexp.QuoteMeta(query)
	return func(m sqlmock.Sqlmock) {
		expected := m.ExpectQuery(pattern).WithArgs(args...)
		expected.WillReturnError(err)
	}
}

// expectQueryScanErr returns an expectation that registers a query matching
// stmt literally (the statement is regexp-escaped) with the given args, and
// answers it with a result set built from cols and rows.
//
// When the last column is named "count", the total number of rows is appended
// to every row as a uint64, so callers supply only the leading columns.
// The expectation is additionally flagged with RowsWillBeClosed.
//
// NOTE(review): the name suggests this is meant for cases where the code
// under test fails after the query succeeded; the helper itself only
// registers the query and its rows.
func expectQueryScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
		result := m.NewRows(cols)

		// Decide once whether the count column has to be filled in. The length
		// guard keeps an empty cols slice from panicking with index out of range.
		appendCount := len(cols) > 0 && cols[len(cols)-1] == "count"
		count := uint64(len(rows))
		for _, row := range rows {
			if appendCount {
				// Clip the capacity so append always copies and never writes
				// into the caller's backing array.
				row = append(row[:len(row):len(row)], count)
			}
			result.AddRow(row...)
		}
		q.WillReturnRows(result)
		q.RowsWillBeClosed()
	}
}

// expectQuery returns an expectation that registers a query matching stmt
// literally (the statement is regexp-escaped) with the given args, and answers
// it with a result set built from cols and rows.
//
// When the last column is named "count", the total number of rows is appended
// to every row as a uint64, so callers supply only the leading columns.
// The expectation is additionally flagged with RowsWillBeClosed.
func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
		result := m.NewRows(cols)

		// Decide once whether the count column has to be filled in. The length
		// guard keeps an empty cols slice from panicking with index out of range.
		appendCount := len(cols) > 0 && cols[len(cols)-1] == "count"
		count := uint64(len(rows))
		for _, row := range rows {
			if appendCount {
				// Clip the capacity so append always copies and never writes
				// into the caller's backing array.
				row = append(row[:len(row):len(row)], count)
			}
			result.AddRow(row...)
		}
		q.WillReturnRows(result)
		q.RowsWillBeClosed()
	}
}

// expectExec returns an expectation that registers an exec matching stmt
// literally (the statement is regexp-escaped) with the given args. A non-nil
// err is returned from the exec; otherwise the exec yields
// sqlmock.NewResult(1, 1).
func expectExec(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		exec := m.ExpectExec(regexp.QuoteMeta(stmt)).WithArgs(args...)
		if err == nil {
			exec.WillReturnResult(sqlmock.NewResult(1, 1))
			return
		}
		exec.WillReturnError(err)
	}
}

// expectBegin returns an expectation that registers a transaction begin.
// A non-nil err is returned from the begin call.
func expectBegin(err error) func(m sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		begin := m.ExpectBegin()
		if err == nil {
			return
		}
		begin.WillReturnError(err)
	}
}

// expectCommit returns an expectation that registers a transaction commit.
// A non-nil err is returned from the commit call.
func expectCommit(err error) func(m sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		commit := m.ExpectCommit()
		if err == nil {
			return
		}
		commit.WillReturnError(err)
	}
}

// expectRollback returns an expectation that registers a transaction rollback.
// A non-nil err is returned from the rollback call.
func expectRollback(err error) func(m sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		rollback := m.ExpectRollback()
		if err == nil {
			return
		}
		rollback.WillReturnError(err)
	}
}