mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-25 19:15:34 +00:00
tstest: add method to Replace values for tests
We have many function pointers that we replace for the duration of test and restore it on test completion, add method to do that. Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
12100320d2
commit
b9ebf7cf14
@ -1075,16 +1075,12 @@ func TestUpdatePrefs(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if tt.sshOverTailscale {
|
if tt.sshOverTailscale {
|
||||||
old := getSSHClientEnvVar
|
tstest.Replace(t, &getSSHClientEnvVar, func() string { return "100.100.100.100 1 1" })
|
||||||
getSSHClientEnvVar = func() string { return "100.100.100.100 1 1" }
|
|
||||||
t.Cleanup(func() { getSSHClientEnvVar = old })
|
|
||||||
} else if isSSHOverTailscale() {
|
} else if isSSHOverTailscale() {
|
||||||
// The test is being executed over a "real" tailscale SSH
|
// The test is being executed over a "real" tailscale SSH
|
||||||
// session, but sshOverTailscale is unset. Make the test appear
|
// session, but sshOverTailscale is unset. Make the test appear
|
||||||
// as if it's not over tailscale SSH.
|
// as if it's not over tailscale SSH.
|
||||||
old := getSSHClientEnvVar
|
tstest.Replace(t, &getSSHClientEnvVar, func() string { return "" })
|
||||||
getSSHClientEnvVar = func() string { return "" }
|
|
||||||
t.Cleanup(func() { getSSHClientEnvVar = old })
|
|
||||||
}
|
}
|
||||||
if tt.env.goos == "" {
|
if tt.env.goos == "" {
|
||||||
tt.env.goos = "linux"
|
tt.env.goos = "linux"
|
||||||
|
@ -13,6 +13,8 @@
|
|||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"tailscale.com/tstest"
|
||||||
)
|
)
|
||||||
|
|
||||||
var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
|
var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
|
||||||
@ -142,9 +144,7 @@ func TestResolverAllHostStaticResult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldTryBootstrap(t *testing.T) {
|
func TestShouldTryBootstrap(t *testing.T) {
|
||||||
oldDebug := debug
|
tstest.Replace(t, &debug, func() bool { return true })
|
||||||
t.Cleanup(func() { debug = oldDebug })
|
|
||||||
debug = func() bool { return true }
|
|
||||||
|
|
||||||
type step struct {
|
type step struct {
|
||||||
ip netip.Addr // IP we pretended to dial
|
ip netip.Addr // IP we pretended to dial
|
||||||
|
@ -22,6 +22,7 @@
|
|||||||
"tailscale.com/net/stun"
|
"tailscale.com/net/stun"
|
||||||
"tailscale.com/net/stun/stuntest"
|
"tailscale.com/net/stun/stuntest"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/tstest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHairpinSTUN(t *testing.T) {
|
func TestHairpinSTUN(t *testing.T) {
|
||||||
@ -679,9 +680,7 @@ func TestNoCaptivePortalWhenUDP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
oldTransport := noRedirectClient.Transport
|
tstest.Replace(t, &noRedirectClient.Transport, http.RoundTripper(tr))
|
||||||
t.Cleanup(func() { noRedirectClient.Transport = oldTransport })
|
|
||||||
noRedirectClient.Transport = tr
|
|
||||||
|
|
||||||
stunAddr, cleanup := stuntest.Serve(t)
|
stunAddr, cleanup := stuntest.Serve(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
@ -6,12 +6,29 @@
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"tailscale.com/logtail/backoff"
|
"tailscale.com/logtail/backoff"
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Replace replaces the value of target with val.
|
||||||
|
// The old value is restored when the test ends.
|
||||||
|
func Replace[T any](t *testing.T, target *T, val T) {
|
||||||
|
t.Helper()
|
||||||
|
if target == nil {
|
||||||
|
t.Fatalf("Replace: nil pointer")
|
||||||
|
}
|
||||||
|
old := *target
|
||||||
|
t.Cleanup(func() {
|
||||||
|
*target = old
|
||||||
|
})
|
||||||
|
|
||||||
|
*target = val
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// WaitFor retries try for up to maxWait.
|
// WaitFor retries try for up to maxWait.
|
||||||
// It returns nil once try returns nil the first time.
|
// It returns nil once try returns nil the first time.
|
||||||
// If maxWait passes without success, it returns try's last error.
|
// If maxWait passes without success, it returns try's last error.
|
||||||
|
24
tstest/tstest_test.go
Normal file
24
tstest/tstest_test.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package tstest
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestReplace(t *testing.T) {
|
||||||
|
before := "before"
|
||||||
|
done := false
|
||||||
|
t.Run("replace", func(t *testing.T) {
|
||||||
|
Replace(t, &before, "after")
|
||||||
|
if before != "after" {
|
||||||
|
t.Errorf("before = %q; want %q", before, "after")
|
||||||
|
}
|
||||||
|
done = true
|
||||||
|
})
|
||||||
|
if !done {
|
||||||
|
t.Fatal("subtest didn't run")
|
||||||
|
}
|
||||||
|
if before != "before" {
|
||||||
|
t.Errorf("before = %q; want %q", before, "before")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user