// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package dns

import (
	"io/ioutil"
	"os"
	"path/filepath"
	"testing"

	"inet.af/netaddr"
	"tailscale.com/util/dnsname"
)

// TestSetDNS exercises directManager against a temporary filesystem root.
//
// It checks three things:
//   - SetDNS with a non-empty config rewrites resolv.conf and saves the
//     original contents to a backup file.
//   - SetDNS with a zero-value OSConfig restores the original resolv.conf
//     and removes the backup.
//   - Close restores the original resolv.conf and removes the backup.
func TestSetDNS(t *testing.T) {
	const orig = "nameserver 9.9.9.9 # orig"
	tmp := t.TempDir()
	resolvPath := filepath.Join(tmp, "etc", "resolv.conf")
	backupPath := filepath.Join(tmp, "etc", "resolv.pre-tailscale-backup.conf")

	if err := os.MkdirAll(filepath.Dir(resolvPath), 0777); err != nil {
		t.Fatal(err)
	}
	if err := ioutil.WriteFile(resolvPath, []byte(orig), 0644); err != nil {
		t.Fatal(err)
	}

	// readFile returns the contents of path, failing the test on any error.
	readFile := func(t *testing.T, path string) string {
		t.Helper()
		b, err := ioutil.ReadFile(path)
		if err != nil {
			t.Fatal(err)
		}
		return string(b)
	}
	// assertBaseState checks that resolv.conf holds the original contents
	// and that no backup file remains.
	assertBaseState := func(t *testing.T) {
		// Mark as a helper so failures point at the calling scenario,
		// not at this closure.
		t.Helper()
		if got := readFile(t, resolvPath); got != orig {
			t.Fatalf("resolv.conf:\n%s, want:\n%s", got, orig)
		}
		if _, err := os.Stat(backupPath); !os.IsNotExist(err) {
			t.Fatalf("resolv.conf backup: want it to be gone but: %v", err)
		}
	}

	// directFS with a prefix makes the manager operate under tmp instead
	// of the real root filesystem.
	m := directManager{fs: directFS{prefix: tmp}}
	if err := m.SetDNS(OSConfig{
		Nameservers:   []netaddr.IP{netaddr.MustParseIP("8.8.8.8"), netaddr.MustParseIP("8.8.4.4")},
		SearchDomains: []dnsname.FQDN{"ts.net.", "ts-dns.test."},
		MatchDomains:  []dnsname.FQDN{"ignored."},
	}); err != nil {
		t.Fatal(err)
	}
	// MatchDomains is not expected in the output; search domains are
	// written without their trailing dots.
	want := `# resolv.conf(5) file generated by tailscale
# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN

nameserver 8.8.8.8
nameserver 8.8.4.4
search ts.net ts-dns.test
`
	if got := readFile(t, resolvPath); got != want {
		t.Fatalf("resolv.conf:\n%s, want:\n%s", got, want)
	}
	if got := readFile(t, backupPath); got != orig {
		t.Fatalf("resolv.conf backup:\n%s, want:\n%s", got, orig)
	}

	// Test that an empty (zero-value) OSConfig cleans up resolv.conf.
	if err := m.SetDNS(OSConfig{}); err != nil {
		t.Fatal(err)
	}
	assertBaseState(t)

	// Test that Close cleans up resolv.conf.
	if err := m.SetDNS(OSConfig{Nameservers: []netaddr.IP{netaddr.MustParseIP("8.8.8.8")}}); err != nil {
		t.Fatal(err)
	}
	if err := m.Close(); err != nil {
		t.Fatal(err)
	}
	assertBaseState(t)
}