net/dns: add a simple test for resolv.conf inotify watcher

Updates #14699

Signed-off-by: Anton <anton@tailscale.com>
This commit is contained in:
Anton 2025-02-11 17:26:07 +00:00 committed by Anton Tolchanov
parent b865ceea20
commit c4984632ca
2 changed files with 75 additions and 13 deletions

View File

@ -6,20 +6,28 @@ package dns
import (
"bytes"
"context"
"fmt"
"github.com/illarion/gonotify/v2"
"tailscale.com/health"
)
func (m *directManager) runFileWatcher() {
ctx, cancel := context.WithCancel(m.ctx)
if err := watchFile(m.ctx, "/etc/", resolvConf, m.checkForFileTrample); err != nil {
// This is all best effort for now, so surface warnings to users.
m.logf("dns: inotify: %s", err)
}
}
// watchFile sets up an inotify watch for a given directory and
// calls the callback function every time a particular file is changed.
// The filename should be located in the provided directory.
func watchFile(ctx context.Context, dir, filename string, cb func()) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
in, err := gonotify.NewInotify(ctx)
if err != nil {
// Oh well, we tried. This is all best effort for now, to
// surface warnings to users.
m.logf("dns: inotify new: %v", err)
return
return fmt.Errorf("NewInotify: %w", err)
}
const events = gonotify.IN_ATTRIB |
@ -29,22 +37,20 @@ func (m *directManager) runFileWatcher() {
gonotify.IN_MODIFY |
gonotify.IN_MOVE
if err := in.AddWatch("/etc/", events); err != nil {
m.logf("dns: inotify addwatch: %v", err)
return
if err := in.AddWatch(dir, events); err != nil {
return fmt.Errorf("AddWatch: %w", err)
}
for {
events, err := in.Read()
if ctx.Err() != nil {
return
return ctx.Err()
}
if err != nil {
m.logf("dns: inotify read: %v", err)
return
return fmt.Errorf("Read: %w", err)
}
var match bool
for _, ev := range events {
if ev.Name == resolvConf {
if ev.Name == filename {
match = true
break
}
@ -52,7 +58,7 @@ func (m *directManager) runFileWatcher() {
if !match {
continue
}
m.checkForFileTrample()
cb()
}
}

View File

@ -0,0 +1,56 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package dns
import (
"context"
"errors"
"fmt"
"os"
"sync/atomic"
"testing"
"time"
"golang.org/x/sync/errgroup"
)
func TestWatchFile(t *testing.T) {
dir := t.TempDir()
filepath := dir + "/test.txt"
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var callbackCalled atomic.Bool
callbackDone := make(chan bool)
callback := func() {
callbackDone <- true
callbackCalled.Store(true)
}
var eg errgroup.Group
eg.Go(func() error { return watchFile(ctx, dir, filepath, callback) })
// Keep writing until we get a callback.
func() {
for i := range 10000 {
if err := os.WriteFile(filepath, []byte(fmt.Sprintf("write%d", i)), 0644); err != nil {
t.Fatal(err)
}
select {
case <-callbackDone:
return
case <-time.After(10 * time.Millisecond):
}
}
}()
cancel()
if err := eg.Wait(); err != nil && !errors.Is(err, context.Canceled) {
t.Error(err)
}
if !callbackCalled.Load() {
t.Error("callback was not called")
}
}