package database

import (
	"reflect"
	"testing"
)

func TestStatement_WriteArgs(t *testing.T) {
	type args struct {
		args []any
	}
	tests := []struct {
		name string
		args args
		want wantQuery
	}{
		{
			name: "no args",
			args: args{
				args: nil,
			},
		},
		{
			name: "1 arg",
			args: args{
				args: []any{"asdf"},
			},
			want: wantQuery{
				query: "$1",
				args:  []any{"asdf"},
			},
		},
		{
			name: "n args",
			args: args{
				args: []any{"asdf", "jkl", 1},
			},
			want: wantQuery{
				query: "$1, $2, $3",
				args:  []any{"asdf", "jkl", 1},
			},
		},
	}
	for _, tt := range tests {
		var stmt Statement
		t.Run(tt.name, func(t *testing.T) {
			stmt.WriteArgs(tt.args.args...)
			assertQuery(t, &stmt, tt.want)
		})
	}
}

type wantQuery struct {
	query string
	args  []any
}

func assertQuery(t *testing.T, stmt *Statement, want wantQuery) {
	if want.query != stmt.String() {
		t.Errorf("unexpected query: want: %q got: %q", want.query, stmt.String())
	}

	if len(want.args) != len(stmt.Args()) {
		t.Errorf("unexpected length of args: want %d, got %d", len(want.args), len(stmt.Args()))
		return
	}

	for i, wantArg := range want.args {
		if !reflect.DeepEqual(wantArg, stmt.Args()[i]) {
			t.Errorf("unexpected arg at position %d: want: %v, got: %v", i, wantArg, stmt.Args()[i])
		}
	}
}