From b70c0c50fd73f134b8618792d89018cb444d8987 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Fri, 9 May 2025 23:12:00 -0500 Subject: [PATCH] ssh/tailssh: fix data race during execution of test In tailssh.go:1284, (*sshSession).startNewRecording starts a fire-and-forget goroutine that can outlive the test that triggered its creation. Among other things, it uses ss.logf, and may call it after the test has already returned. Since we typically use (*testing.T).Logf as the logger, this results in a data race and causes flaky tests. Ideally, we should fix the root cause and/or use a goroutines.Tracker to wait for the goroutine to complete. But with the release approaching, it's too risky to make such changes now. As a workaround, we update the tests to use tstest.WhileTestRunningLogger, which logs to t.Logf while the test is running and disables logging once the test finishes, avoiding the race. While there, we also fix TestSSHAuthFlow not to use log.Printf. Updates #15568 Updates #7707 (probably related) Signed-off-by: Nick Khyl --- ssh/tailssh/tailssh_test.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 980c77414..79479d7fb 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -16,7 +16,6 @@ import ( "errors" "fmt" "io" - "log" "net" "net/http" "net/http/httptest" @@ -48,7 +47,6 @@ import ( "tailscale.com/tsd" "tailscale.com/tstest" "tailscale.com/types/key" - "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" "tailscale.com/types/ptr" @@ -230,7 +228,7 @@ func TestMatchRule(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := &conn{ info: tt.ci, - srv: &server{logf: t.Logf}, + srv: &server{logf: tstest.WhileTestRunningLogger(t)}, } got, gotUser, gotAcceptEnv, err := c.matchRule(tt.rule) if err != tt.wantErr { @@ -349,7 +347,7 @@ func TestEvalSSHPolicy(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := &conn{ info: tt.ci, - srv: &server{logf: t.Logf}, + srv: &server{logf: tstest.WhileTestRunningLogger(t)}, } got, gotUser, gotAcceptEnv, match := c.evalSSHPolicy(tt.policy) if match != tt.wantMatch { @@ -491,7 +489,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { }) s := &server{ - logf: t.Logf, + logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -553,7 +551,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s.logf = t.Logf + s.logf = tstest.WhileTestRunningLogger(t) tstest.Replace(t, &handler, tt.handler) sc, dc := memnet.NewTCPConn(src, dst, 1024) var wg sync.WaitGroup @@ -621,7 +619,7 @@ func TestMultipleRecorders(t *testing.T) { }) s := &server{ - logf: t.Logf, + logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -714,7 +712,7 @@ func TestSSHRecordingNonInteractive(t *testing.T) { }) s := &server{ - logf: t.Logf, + logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -887,13 +885,15 @@ func TestSSHAuthFlow(t *testing.T) { }, } s := &server{ - logf: log.Printf, + logf: tstest.WhileTestRunningLogger(t), } defer s.Shutdown() src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22")) for _, tc := range tests { for _, authMethods := range [][]string{nil, {"publickey", "password"}, {"password", "publickey"}} { t.Run(fmt.Sprintf("%s-skip-none-auth-%v", tc.name, strings.Join(authMethods, "-then-")), func(t *testing.T) { + s.logf = tstest.WhileTestRunningLogger(t) + sc, dc := memnet.NewTCPConn(src, dst, 1024) s.lb = tc.state sshUser := "alice" @@ -1036,7 +1036,7 @@ func TestSSHAuthFlow(t *testing.T) { } func TestSSH(t *testing.T) { - var logf logger.Logf = t.Logf + logf := tstest.WhileTestRunningLogger(t) sys := tsd.NewSystem() eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) if err != nil {