// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package main

import (
	"encoding/json"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"reflect"
	"testing"

	"tailscale.com/tstest"
)

// BenchmarkHandleBootstrapDNS measures handleBootstrapDNS answering a single
// bootstrap-dns query from parallel goroutines, with the response discarded.
func BenchmarkHandleBootstrapDNS(b *testing.B) {
	tstest.Replace(b, bootstrapDNS, "log.tailscale.io,login.tailscale.com,controlplane.tailscale.com,login.us.tailscale.com")
	refreshBootstrapDNS()

	// bitbucketResponseWriter holds no state, so one value is shared by all
	// goroutines. NOTE(review): req is also shared; this assumes
	// handleBootstrapDNS only reads from it — confirm.
	w := new(bitbucketResponseWriter)
	req, err := http.NewRequest(http.MethodGet, "https://localhost/bootstrap-dns?q="+url.QueryEscape("log.tailscale.io"), nil)
	if err != nil {
		b.Fatal(err)
	}

	b.ReportAllocs()
	b.ResetTimer()
	// Name the PB parameter pb so it does not shadow the outer *testing.B.
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			handleBootstrapDNS(w, req)
		}
	})
}

// bitbucketResponseWriter is an http.ResponseWriter that throws away
// everything written to it. It carries no state, so a single value can be
// shared by concurrent callers.
type bitbucketResponseWriter struct{}

// Header hands back a brand-new, empty header map each time it is called, so
// any headers set on it are simply dropped.
func (*bitbucketResponseWriter) Header() http.Header { return http.Header{} }

// Write discards p and reports that all of it was written.
func (*bitbucketResponseWriter) Write(p []byte) (int, error) { return len(p), nil }

// WriteHeader ignores the status code.
func (*bitbucketResponseWriter) WriteHeader(int) {}

// getBootstrapDNS sends a bootstrap-dns query for q straight to
// handleBootstrapDNS (no network listener) and returns the decoded
// dnsEntryMap. It fails the test if the request cannot be built, if the
// handler does not respond with 200 OK, or if the body does not decode as a
// dnsEntryMap.
func getBootstrapDNS(t *testing.T, q string) dnsEntryMap {
	t.Helper()
	req, err := http.NewRequest(http.MethodGet, "https://localhost/bootstrap-dns?q="+url.QueryEscape(q), nil)
	if err != nil {
		t.Fatalf("building request for %q: %v", q, err)
	}
	w := httptest.NewRecorder()
	handleBootstrapDNS(w, req)

	res := w.Result()
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		t.Fatalf("got status=%d; want %d", res.StatusCode, http.StatusOK)
	}
	var ips dnsEntryMap
	if err = json.NewDecoder(res.Body).Decode(&ips); err != nil {
		t.Fatalf("error decoding response body: %v", err)
	}
	return ips
}

func TestUnpublishedDNS(t *testing.T) {
	const published = "login.tailscale.com"
	const unpublished = "log.tailscale.io"

	prev1, prev2 := *bootstrapDNS, *unpublishedDNS
	*bootstrapDNS = published
	*unpublishedDNS = unpublished
	t.Cleanup(func() {
		*bootstrapDNS = prev1
		*unpublishedDNS = prev2
	})

	refreshBootstrapDNS()
	refreshUnpublishedDNS()

	hasResponse := func(q string) bool {
		_, found := getBootstrapDNS(t, q)[q]
		return found
	}

	if !hasResponse(published) {
		t.Errorf("expected response for: %s", published)
	}
	if !hasResponse(unpublished) {
		t.Errorf("expected response for: %s", unpublished)
	}

	// Verify that querying for a random query or a real query does not
	// leak our unpublished domain
	m1 := getBootstrapDNS(t, published)
	if _, found := m1[unpublished]; found {
		t.Errorf("found unpublished domain %s: %+v", unpublished, m1)
	}
	m2 := getBootstrapDNS(t, "random.example.com")
	if _, found := m2[unpublished]; found {
		t.Errorf("found unpublished domain %s: %+v", unpublished, m2)
	}
}

// resetMetrics sets the published and unpublished DNS hit and miss counters
// back to zero, giving each test case a clean baseline.
func resetMetrics() {
	// Hits first, then misses; each counter is independent of the others.
	publishedDNSHits.Set(0)
	unpublishedDNSHits.Set(0)
	publishedDNSMisses.Set(0)
	unpublishedDNSMisses.Set(0)
}

// TestUnpublishedDNSEmptyList checks the unpublished-cache metrics. A name
// that is present in unpublishedDNSCache but has an empty IP list must be
// counted as a miss and answered with the published map, exactly like a
// name that is absent. A name with IPs must be counted as a hit.
func TestUnpublishedDNSEmptyList(t *testing.T) {
	// Seed the caches directly rather than calling the refresh functions.
	published := dnsEntryMap{
		"tailscale.com": {net.IPv4(10, 10, 10, 10)},
	}
	dnsCache.Store(published)
	dnsCacheBytes.Store([]byte(`{"tailscale.com":["10.10.10.10"]}`))

	// "log.tailscale.io" has an empty list (should be a miss);
	// "controlplane.tailscale.com" has an address (should be a hit).
	unpublishedDNSCache.Store(dnsEntryMap{
		"log.tailscale.io":           {},
		"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)},
	})

	t.Run("CacheMiss", func(t *testing.T) {
		// The first name is in the unpublished map but empty; the second is
		// not in the map at all.
		names := []string{"log.tailscale.io", "login.tailscale.com"}
		for _, name := range names {
			resetMetrics()
			got := getBootstrapDNS(t, name)

			// On a miss the handler should fall back to the published map.
			if !reflect.DeepEqual(got, published) {
				t.Errorf("got ips=%+v; want %+v", got, published)
			}
			if hits := unpublishedDNSHits.Value(); hits != 0 {
				t.Errorf("got hits=%d; want 0", hits)
			}
			if misses := unpublishedDNSMisses.Value(); misses != 1 {
				t.Errorf("got misses=%d; want 1", misses)
			}
		}
	})

	// A populated entry should be returned on its own and count as a hit.
	t.Run("CacheHit", func(t *testing.T) {
		resetMetrics()
		const name = "controlplane.tailscale.com"
		got := getBootstrapDNS(t, name)
		expected := dnsEntryMap{name: {net.IPv4(1, 2, 3, 4)}}
		if !reflect.DeepEqual(got, expected) {
			t.Errorf("got ips=%+v; want %+v", got, expected)
		}
		if hits := unpublishedDNSHits.Value(); hits != 1 {
			t.Errorf("got hits=%d; want 1", hits)
		}
		if misses := unpublishedDNSMisses.Value(); misses != 0 {
			t.Errorf("got misses=%d; want 0", misses)
		}
	})
}