package middleware

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"reflect"
	"testing"

	"github.com/stretchr/testify/assert"
	"golang.org/x/text/language"

	"github.com/zitadel/zitadel/internal/api/authz"
)

func Test_instanceInterceptor_Handler(t *testing.T) {
	type fields struct {
		verifier   authz.InstanceVerifier
		headerName string
	}
	type args struct {
		request *http.Request
	}
	type res struct {
		statusCode int
		context    context.Context
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			"setInstance error",
			fields{
				verifier:   &mockInstanceVerifier{},
				headerName: "header",
			},
			args{
				request: httptest.NewRequest("", "/url", nil),
			},
			res{
				statusCode: 404,
				context:    nil,
			},
		},
		{
			"setInstance ok",
			fields{
				verifier:   &mockInstanceVerifier{"host"},
				headerName: "header",
			},
			args{
				request: func() *http.Request {
					r := httptest.NewRequest("", "/url", nil)
					r.Header.Set("header", "host")
					return r
				}(),
			},
			res{
				statusCode: 200,
				context:    authz.WithInstance(context.Background(), &mockInstance{}),
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &instanceInterceptor{
				verifier:   tt.fields.verifier,
				headerName: tt.fields.headerName,
			}
			next := &testHandler{}
			got := a.HandlerFunc(next.ServeHTTP)
			rr := httptest.NewRecorder()
			got.ServeHTTP(rr, tt.args.request)
			assert.Equal(t, tt.res.statusCode, rr.Code)
			assert.Equal(t, tt.res.context, next.context)
		})
	}
}

func Test_instanceInterceptor_HandlerFunc(t *testing.T) {
	type fields struct {
		verifier   authz.InstanceVerifier
		headerName string
	}
	type args struct {
		request *http.Request
	}
	type res struct {
		statusCode int
		context    context.Context
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			"setInstance error",
			fields{
				verifier:   &mockInstanceVerifier{},
				headerName: "header",
			},
			args{
				request: httptest.NewRequest("", "/url", nil),
			},
			res{
				statusCode: 404,
				context:    nil,
			},
		},
		{
			"setInstance ok",
			fields{
				verifier:   &mockInstanceVerifier{"host"},
				headerName: "header",
			},
			args{
				request: func() *http.Request {
					r := httptest.NewRequest("", "/url", nil)
					r.Header.Set("header", "host")
					return r
				}(),
			},
			res{
				statusCode: 200,
				context:    authz.WithInstance(context.Background(), &mockInstance{}),
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &instanceInterceptor{
				verifier:   tt.fields.verifier,
				headerName: tt.fields.headerName,
			}
			next := &testHandler{}
			got := a.HandlerFunc(next.ServeHTTP)
			rr := httptest.NewRecorder()
			got.ServeHTTP(rr, tt.args.request)
			assert.Equal(t, tt.res.statusCode, rr.Code)
			assert.Equal(t, tt.res.context, next.context)
		})
	}
}

func Test_setInstance(t *testing.T) {
	type args struct {
		r          *http.Request
		verifier   authz.InstanceVerifier
		headerName string
	}
	type res struct {
		want context.Context
		err  bool
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			"hostname not found, error",
			args{
				r: func() *http.Request {
					r := httptest.NewRequest("", "/url", nil)
					return r
				}(),
				verifier:   &mockInstanceVerifier{},
				headerName: "",
			},
			res{
				want: nil,
				err:  true,
			},
		},
		{
			"invalid host, error",
			args{
				r: func() *http.Request {
					r := httptest.NewRequest("", "/url", nil)
					r.Header.Set("header", "host2")
					return r
				}(),
				verifier:   &mockInstanceVerifier{"host"},
				headerName: "header",
			},
			res{
				want: nil,
				err:  true,
			},
		},
		{
			"valid host",
			args{
				r: func() *http.Request {
					r := httptest.NewRequest("", "/url", nil)
					r.Header.Set("header", "host")
					return r
				}(),
				verifier:   &mockInstanceVerifier{"host"},
				headerName: "header",
			},
			res{
				want: authz.WithInstance(context.Background(), &mockInstance{}),
				err:  false,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := setInstance(tt.args.r, tt.args.verifier, tt.args.headerName)
			if (err != nil) != tt.res.err {
				t.Errorf("setInstance() error = %v, wantErr %v", err, tt.res.err)
				return
			}
			if !reflect.DeepEqual(got, tt.res.want) {
				t.Errorf("setInstance() got = %v, want %v", got, tt.res.want)
			}
		})
	}
}

type testHandler struct {
	context context.Context
}

func (t *testHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) {
	t.context = r.Context()
}

type mockInstanceVerifier struct {
	host string
}

func (m *mockInstanceVerifier) InstanceByHost(_ context.Context, host string) (authz.Instance, error) {
	if host != m.host {
		return nil, fmt.Errorf("invalid host")
	}
	return &mockInstance{}, nil
}

func (m *mockInstanceVerifier) InstanceByID(context.Context) (authz.Instance, error) {
	return nil, nil
}

type mockInstance struct{}

func (m *mockInstance) InstanceID() string {
	return "instanceID"
}

func (m *mockInstance) ProjectID() string {
	return "projectID"
}

func (m *mockInstance) ConsoleClientID() string {
	return "consoleClientID"
}

func (m *mockInstance) ConsoleApplicationID() string {
	return "consoleApplicationID"
}

func (m *mockInstance) DefaultLanguage() language.Tag {
	return language.English
}

func (m *mockInstance) DefaultOrganisationID() string {
	return "orgID"
}

func (m *mockInstance) RequestedDomain() string {
	return "zitadel.cloud"
}

func (m *mockInstance) RequestedHost() string {
	return "zitadel.cloud:443"
}

func (m *mockInstance) SecurityPolicyAllowedOrigins() []string {
	return nil
}