From 2636a83d0e7118169764a7fc97cb0058d7969e27 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sun, 25 Aug 2024 20:36:03 -0700 Subject: [PATCH] cmd/tta: pull out test driver dialing into a type, fix bugs There were a few places it could get wedged (notably the dial without a timeout). And add a knob for verbose debug logs. And keep two idle connections always. Updates #13038 Change-Id: I952ad182d7111481d97a83c12aa2ff4bfdc55fe8 Signed-off-by: Brad Fitzpatrick --- cmd/tta/tta.go | 150 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 111 insertions(+), 39 deletions(-) diff --git a/cmd/tta/tta.go b/cmd/tta/tta.go index 755e51a90..4a4c4a6be 100644 --- a/cmd/tta/tta.go +++ b/cmd/tta/tta.go @@ -24,6 +24,7 @@ "os" "os/exec" "regexp" + "strconv" "strings" "sync" "time" @@ -31,6 +32,7 @@ "tailscale.com/atomicfile" "tailscale.com/client/tailscale" "tailscale.com/hostinfo" + "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/util/set" "tailscale.com/version/distro" @@ -85,6 +87,7 @@ func main() { } flag.Parse() + debug := false if distro.Get() == distro.Gokrazy { cmdLine, _ := os.ReadFile("/proc/cmdline") explicitNS := false @@ -93,7 +96,11 @@ func main() { err := atomicfile.WriteFile("/tmp/resolv.conf", []byte("nameserver "+ns+"\n"), 0644) log.Printf("Wrote /tmp/resolv.conf: %v", err) explicitNS = true - break + continue + } + if v, ok := strings.CutPrefix(s, "tta.debug="); ok { + debug, _ = strconv.ParseBool(v) + continue } } if !explicitNS { @@ -134,28 +141,11 @@ func main() { }) var hs http.Server hs.Handler = &serveMux - var ( - stMu sync.Mutex - newSet = set.Set[net.Conn]{} // conns in StateNew - ) - needConnCh := make(chan bool, 1) - hs.ConnState = func(c net.Conn, s http.ConnState) { - stMu.Lock() - defer stMu.Unlock() - oldLen := len(newSet) - switch s { - case http.StateNew: - newSet.Add(c) - default: - newSet.Delete(c) - } - if oldLen != 0 && len(newSet) == 0 { - select { - case needConnCh <- true: - default: - } - } + revSt := revDialState{ + needConnCh: make(chan bool, 1), + debug: debug, } + hs.ConnState = revSt.connState conns := make(chan net.Conn, 1) lcRP := httputil.NewSingleHostReverseProxy(must.Get(url.Parse("http://local-tailscaled.sock"))) @@ -193,26 +183,14 @@ func main() { } }() - var lastErr string - needConnCh <- true - for { - <-needConnCh - c, err := connect() - if err != nil { - s := err.Error() - if s != lastErr { - log.Printf("Connect failure: %v", s) - } - lastErr = s - time.Sleep(time.Second) - continue - } - conns <- c - } + revSt.runDialOutLoop(conns) } func connect() (net.Conn, error) { - c, err := net.Dial("tcp", *driverAddr) + var d net.Dialer + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + c, err := d.DialContext(ctx, "tcp", *driverAddr) if err != nil { return nil, err } @@ -240,6 +218,100 @@ func (cl chanListener) Addr() net.Addr { } } +type revDialState struct { + needConnCh chan bool + debug bool + + mu sync.Mutex + newSet set.Set[net.Conn] // conns in StateNew + onNew map[net.Conn]func() +} + +func (s *revDialState) connState(c net.Conn, cs http.ConnState) { + s.mu.Lock() + defer s.mu.Unlock() + oldLen := len(s.newSet) + switch cs { + case http.StateNew: + if f, ok := s.onNew[c]; ok { + f() + delete(s.onNew, c) + } + s.newSet.Make() + s.newSet.Add(c) + default: + s.newSet.Delete(c) + } + s.vlogf("ConnState: %p now %v; newSet %v=>%v", c, s, oldLen, len(s.newSet)) + if len(s.newSet) < 2 { + select { + case s.needConnCh <- true: + default: + } + } +} + +func (s *revDialState) waitNeedConnect() { + for { + s.mu.Lock() + need := len(s.newSet) < 2 + s.mu.Unlock() + if need { + return + } + <-s.needConnCh + } +} + +func (s *revDialState) vlogf(format string, arg ...any) { + if !s.debug { + return + } + log.Printf(format, arg...) +} + +func (s *revDialState) runDialOutLoop(conns chan<- net.Conn) { + var lastErr string + connected := false + + for { + s.vlogf("[dial-driver] waiting need connect...") + s.waitNeedConnect() + s.vlogf("[dial-driver] connecting...") + t0 := time.Now() + c, err := connect() + if err != nil { + s := err.Error() + if s != lastErr { + log.Printf("[dial-driver] connect failure: %v", s) + } + lastErr = s + time.Sleep(time.Second) + continue + } + if !connected { + connected = true + log.Printf("Connected to %v", *driverAddr) + } + s.vlogf("[dial-driver] connected %v => %v after %v", c.LocalAddr(), c.RemoteAddr(), time.Since(t0)) + + inHTTP := make(chan struct{}) + s.mu.Lock() + mak.Set(&s.onNew, c, func() { close(inHTTP) }) + s.mu.Unlock() + + s.vlogf("[dial-driver] sending...") + conns <- c + s.vlogf("[dial-driver] sent; waiting") + select { + case <-inHTTP: + s.vlogf("[dial-driver] conn in HTTP") + case <-time.After(2 * time.Second): + s.vlogf("[dial-driver] timeout waiting for conn to be accepted into HTTP") + } + } +} + func addFirewallHandler(w http.ResponseWriter, r *http.Request) { if addFirewall == nil { http.Error(w, "firewall not supported", 500)