// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package dns

import (
	"encoding/binary"
	"io"
	"net"
	"testing"

	"github.com/google/go-cmp/cmp"
	dns "golang.org/x/net/dns/dnsmessage"
	"inet.af/netaddr"
	"tailscale.com/net/tsdial"
	"tailscale.com/util/dnsname"
)

// mkDNSRequest builds the wire-format bytes of a DNS query containing a
// single question for domain with record type tp and class INET, followed
// by an empty OPT pseudo-record in the additional section.
//
// It is a test helper: any error from the message builder causes a panic
// rather than being returned.
func mkDNSRequest(domain dnsname.FQDN, tp dns.Type) []byte {
	builder := dns.NewBuilder(nil, dns.Header{})

	if err := builder.StartQuestions(); err != nil {
		panic(err)
	}
	question := dns.Question{
		Name:  dns.MustNewName(domain.WithTrailingDot()),
		Type:  tp,
		Class: dns.ClassINET,
	}
	if err := builder.Question(question); err != nil {
		panic(err)
	}

	if err := builder.StartAdditionals(); err != nil {
		panic(err)
	}
	// For OPT records the class field is repurposed to carry the
	// requester's maximum UDP payload size (RFC 6891), here 4095.
	ednsHeader := dns.ResourceHeader{
		Name:  dns.MustNewName("."),
		Type:  dns.TypeOPT,
		Class: dns.Class(4095),
	}
	if err := builder.OPTResource(ednsHeader, dns.OPTResource{}); err != nil {
		panic(err)
	}

	// Treat a failed Finish like every other builder error instead of
	// silently returning a nil or partial payload.
	payload, err := builder.Finish()
	if err != nil {
		panic(err)
	}
	return payload
}

// TestDNSOverTCP sends A queries to a Manager over an in-memory pipe using
// DNS-over-TCP framing (each message preceded by its length as a big-endian
// uint16) and checks that the answers match the hosts configured via m.Set.
func TestDNSOverTCP(t *testing.T) {
	f := fakeOSConfigurator{
		SplitDNS: true,
		BaseConfig: OSConfig{
			Nameservers:   mustIPs("8.8.8.8"),
			SearchDomains: fqdns("coffee.shop"),
		},
	}
	m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil)
	m.resolver.TestOnlySetHook(f.SetResolver)
	m.Set(Config{
		Hosts: hosts(
			"dave.ts.com.", "1.2.3.4",
			"bradfitz.ts.com.", "2.3.4.5"),
		Routes:        upstreams("ts.com", ""),
		SearchDomains: fqdns("tailscale.com", "universe.tf"),
	})
	defer m.Down()

	// c is the client end driven by this test; s is handed to the manager.
	c, s := net.Pipe()
	defer s.Close()
	go m.HandleTCPConn(s, netaddr.IPPort{})
	defer c.Close()

	wantResults := map[dnsname.FQDN]string{
		"dave.ts.com.":     "1.2.3.4",
		"bradfitz.ts.com.": "2.3.4.5",
	}

	// Send every request before reading any response.
	for domain := range wantResults {
		b := mkDNSRequest(domain, dns.TypeA)
		if err := binary.Write(c, binary.BigEndian, uint16(len(b))); err != nil {
			t.Fatalf("writing len for %q: %v", domain, err)
		}
		if _, err := c.Write(b); err != nil {
			t.Fatalf("writing request for %q: %v", domain, err)
		}
	}

	// Map iteration order is unspecified, so responses are matched to
	// requests by the question name echoed in each reply rather than by
	// position.
	results := map[dnsname.FQDN]string{}
	for i := 0; i < len(wantResults); i++ {
		var respLength uint16
		if err := binary.Read(c, binary.BigEndian, &respLength); err != nil {
			t.Fatalf("reading len: %v", err)
		}
		resp := make([]byte, int(respLength))
		if _, err := io.ReadFull(c, resp); err != nil {
			t.Fatalf("reading data: %v", err)
		}

		var parser dns.Parser
		if _, err := parser.Start(resp); err != nil {
			t.Errorf("parser.Start() failed: %v", err)
			continue
		}
		q, err := parser.Question()
		if err != nil {
			t.Errorf("parser.Question(): %v", err)
			continue
		}
		if err := parser.SkipAllQuestions(); err != nil {
			t.Errorf("parser.SkipAllQuestions(): %v", err)
			continue
		}
		ah, err := parser.AnswerHeader()
		if err != nil {
			t.Errorf("parser.AnswerHeader(): %v", err)
			continue
		}
		if ah.Type != dns.TypeA {
			t.Errorf("unexpected answer type: got %v, want %v", ah.Type, dns.TypeA)
			continue
		}
		res, err := parser.AResource()
		if err != nil {
			t.Errorf("parser.AResource(): %v", err)
			continue
		}
		results[dnsname.FQDN(q.Name.String())] = net.IP(res.A[:]).String()
	}
	c.Close()

	// cmp.Diff marks its first argument with '-' and its second with '+'.
	if diff := cmp.Diff(wantResults, results); diff != "" {
		t.Errorf("wrong results (-want +got)\n%s", diff)
	}
}