net/dns: ensure the Windows configurator does not touch the hosts file unless the configuration actually changed

We build up maps of both the existing MagicDNS configuration in hosts
and the desired MagicDNS configuration, compare the two, and only
write out a new one if there are changes. The comparison doesn't need
to be perfect, as the occasional false-positive is fine, but this
should greatly reduce rewrites of the hosts file.

I also changed the hosts updating code to remove the CRLF/LF conversion
stuff, and use Fprintf instead of Frintln to let us write those inline.

Updates #14428

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
Aaron Klotz 2025-01-13 13:47:56 -07:00
parent 27477983e3
commit d818a58a77
2 changed files with 107 additions and 19 deletions

View File

@ -8,10 +8,12 @@ import (
"bytes"
"errors"
"fmt"
"maps"
"net/netip"
"os"
"os/exec"
"path/filepath"
"slices"
"sort"
"strings"
"sync"
@ -140,9 +142,8 @@ func (m *windowsManager) setSplitDNS(resolvers []netip.Addr, domains []dnsname.F
return m.nrptDB.WriteSplitDNSConfig(servers, domains)
}
func setTailscaleHosts(prevHostsFile []byte, hosts []*HostEntry) ([]byte, error) {
b := bytes.ReplaceAll(prevHostsFile, []byte("\r\n"), []byte("\n"))
sc := bufio.NewScanner(bytes.NewReader(b))
func setTailscaleHosts(logf logger.Logf, prevHostsFile []byte, hosts []*HostEntry) ([]byte, error) {
sc := bufio.NewScanner(bytes.NewReader(prevHostsFile))
const (
header = "# TailscaleHostsSectionStart"
footer = "# TailscaleHostsSectionEnd"
@ -151,6 +152,32 @@ func setTailscaleHosts(prevHostsFile []byte, hosts []*HostEntry) ([]byte, error)
"# This section contains MagicDNS entries for Tailscale.",
"# Do not edit this section manually.",
}
prevEntries := make(map[netip.Addr][]string)
addPrevEntry := func(line string) {
if line == "" || line[0] == '#' {
return
}
parts := strings.Split(line, " ")
if len(parts) < 1 {
return
}
addr, err := netip.ParseAddr(parts[0])
if err != nil {
logf("Parsing address from hosts: %v", err)
return
}
prevEntries[addr] = parts[1:]
}
nextEntries := make(map[netip.Addr][]string, len(hosts))
for _, he := range hosts {
nextEntries[he.Addr] = he.Hosts
}
var out bytes.Buffer
var inSection bool
for sc.Scan() {
@ -164,26 +191,34 @@ func setTailscaleHosts(prevHostsFile []byte, hosts []*HostEntry) ([]byte, error)
continue
}
if inSection {
addPrevEntry(line)
continue
}
fmt.Fprintln(&out, line)
fmt.Fprintf(&out, "%s\r\n", line)
}
if err := sc.Err(); err != nil {
return nil, err
}
if len(hosts) > 0 {
fmt.Fprintln(&out, header)
for _, c := range comments {
fmt.Fprintln(&out, c)
}
fmt.Fprintln(&out)
for _, he := range hosts {
fmt.Fprintf(&out, "%s %s\n", he.Addr, strings.Join(he.Hosts, " "))
}
fmt.Fprintln(&out)
fmt.Fprintln(&out, footer)
unchanged := maps.EqualFunc(prevEntries, nextEntries, func(a, b []string) bool {
return slices.Equal(a, b)
})
if unchanged {
return nil, nil
}
return bytes.ReplaceAll(out.Bytes(), []byte("\n"), []byte("\r\n")), nil
if len(hosts) > 0 {
fmt.Fprintf(&out, "%s\r\n", header)
for _, c := range comments {
fmt.Fprintf(&out, "%s\r\n", c)
}
fmt.Fprintf(&out, "\r\n")
for _, he := range hosts {
fmt.Fprintf(&out, "%s %s\r\n", he.Addr, strings.Join(he.Hosts, " "))
}
fmt.Fprintf(&out, "\r\n%s\r\n", footer)
}
return out.Bytes(), nil
}
// setHosts sets the hosts file to contain the given host entries.
@ -197,10 +232,15 @@ func (m *windowsManager) setHosts(hosts []*HostEntry) error {
if err != nil {
return err
}
outB, err := setTailscaleHosts(b, hosts)
outB, err := setTailscaleHosts(m.logf, b, hosts)
if err != nil {
return err
}
if outB == nil {
// No change to hosts file, therefore no write necessary.
return nil
}
const fileMode = 0 // ignored on windows.
// This can fail spuriously with an access denied error, so retry it a

View File

@ -15,6 +15,7 @@ import (
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
"tailscale.com/util/winutil"
"tailscale.com/util/winutil/gp"
@ -24,9 +25,56 @@ const testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}"
func TestHostFileNewLines(t *testing.T) {
in := []byte("#foo\r\n#bar\n#baz\n")
want := []byte("#foo\r\n#bar\r\n#baz\r\n")
want := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron\r\n\r\n# TailscaleHostsSectionEnd\r\n")
got, err := setTailscaleHosts(in, nil)
he := []*HostEntry{
&HostEntry{
Addr: netip.MustParseAddr("192.168.1.1"),
Hosts: []string{"aaron"},
},
}
got, err := setTailscaleHosts(logger.Discard, in, he)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got, want) {
t.Errorf("got %q, want %q\n", got, want)
}
}
func TestHostFileUnchanged(t *testing.T) {
in := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron\r\n\r\n# TailscaleHostsSectionEnd\r\n")
he := []*HostEntry{
&HostEntry{
Addr: netip.MustParseAddr("192.168.1.1"),
Hosts: []string{"aaron"},
},
}
got, err := setTailscaleHosts(logger.Discard, in, he)
if err != nil {
t.Fatal(err)
}
if got != nil {
t.Errorf("got %q, want nil\n", got)
}
}
func TestHostFileChanged(t *testing.T) {
in := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron1\r\n\r\n# TailscaleHostsSectionEnd\r\n")
want := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron1\r\n192.168.1.2 aaron2\r\n\r\n# TailscaleHostsSectionEnd\r\n")
he := []*HostEntry{
&HostEntry{
Addr: netip.MustParseAddr("192.168.1.1"),
Hosts: []string{"aaron1"},
},
&HostEntry{
Addr: netip.MustParseAddr("192.168.1.2"),
Hosts: []string{"aaron2"},
},
}
got, err := setTailscaleHosts(logger.Discard, in, he)
if err != nil {
t.Fatal(err)
}