ipn/ipn{server,test}: extract the LocalAPI test client and server into ipntest

In this PR, we extract the in-process LocalAPI client/server implementation from ipn/ipnserver/server_test.go
into a new ipntest package to be used in high‑level black‑box tests, such as those for the tailscale CLI.

Updates #15575

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl
2025-04-16 16:32:10 -05:00
committed by Nick Khyl
parent 0f4f808e70
commit f0a27066c4
9 changed files with 846 additions and 335 deletions

View File

@@ -179,6 +179,12 @@ func contextWithActor(ctx context.Context, logf logger.Logf, c net.Conn) context
return actorKey.WithValue(ctx, actorOrError{actor: actor, err: err})
}
// NewContextWithActorForTest returns a new context that carries the identity
// of the specified actor. It is used in tests only.
func NewContextWithActorForTest(ctx context.Context, actor ipnauth.Actor) context.Context {
return actorKey.WithValue(ctx, actorOrError{actor: actor})
}
// actorFromContext returns an [ipnauth.Actor] associated with ctx,
// or an error if the context does not carry an actor's identity.
func actorFromContext(ctx context.Context) (ipnauth.Actor, error) {

View File

@@ -0,0 +1,42 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ipnserver
import (
"context"
"net/http"
"tailscale.com/ipn/ipnauth"
)
// BlockWhileInUseByOtherForTest blocks while the actor can't connect to the server because
// the server is in use by a different actor. It is used in tests only.
func (s *Server) BlockWhileInUseByOtherForTest(ctx context.Context, actor ipnauth.Actor) error {
return s.blockWhileIdentityInUse(ctx, actor)
}
// BlockWhileInUseForTest blocks until the server becomes idle (no active requests),
// or the specified context is done. It returns the context's error if it is done.
// It is used in tests only.
func (s *Server) BlockWhileInUseForTest(ctx context.Context) error {
ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx)
s.mu.Lock()
busy := len(s.activeReqs) != 0
s.mu.Unlock()
if busy {
<-ready
}
cleanup()
return ctx.Err()
}
// ServeHTTPForTest responds to a single LocalAPI HTTP request.
// The request's context carries the actor that made the request
// and can be created with [NewContextWithActorForTest].
// It is used in tests only.
func (s *Server) ServeHTTPForTest(w http.ResponseWriter, r *http.Request) {
s.serveHTTP(w, r)
}

View File

@@ -1,76 +1,22 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ipnserver
package ipnserver_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"runtime"
"strconv"
"sync"
"sync/atomic"
"testing"
"tailscale.com/client/local"
"tailscale.com/client/tailscale"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/control/controlclient"
"tailscale.com/envknob"
"tailscale.com/ipn"
"tailscale.com/ipn/ipnauth"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/ipn/store/mem"
"tailscale.com/tsd"
"tailscale.com/tstest"
"tailscale.com/types/logger"
"tailscale.com/types/logid"
"tailscale.com/ipn/lapitest"
"tailscale.com/types/ptr"
"tailscale.com/util/mak"
"tailscale.com/wgengine"
)
func TestWaiterSet(t *testing.T) {
var s waiterSet
wantLen := func(want int, when string) {
t.Helper()
if got := len(s); got != want {
t.Errorf("%s: len = %v; want %v", when, got, want)
}
}
wantLen(0, "initial")
var mu sync.Mutex
ctx, cancel := context.WithCancel(context.Background())
ready, cleanup := s.add(&mu, ctx)
wantLen(1, "after add")
select {
case <-ready:
t.Fatal("should not be ready")
default:
}
s.wakeAll()
<-ready
wantLen(1, "after fire")
cleanup()
wantLen(0, "after cleanup")
// And again but on an already-expired ctx.
cancel()
ready, cleanup = s.add(&mu, ctx)
<-ready // shouldn't block
cleanup()
wantLen(0, "at end")
}
func TestUserConnectDisconnectNonWindows(t *testing.T) {
enableLogging := false
if runtime.GOOS == "windows" {
@@ -78,20 +24,20 @@ func TestUserConnectDisconnectNonWindows(t *testing.T) {
}
ctx := context.Background()
server := startDefaultTestIPNServer(t, ctx, enableLogging)
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
// UserA connects and starts watching the IPN bus.
clientA := server.getClientAs("UserA")
clientA := server.ClientWithName("UserA")
watcherA, _ := clientA.WatchIPNBus(ctx, 0)
// The concept of "current user" is only relevant on Windows
// and it should not be set on non-Windows platforms.
server.checkCurrentUser(nil)
server.CheckCurrentUser(nil)
// Additionally, a different user should be able to connect and use the LocalAPI.
clientB := server.getClientAs("UserB")
clientB := server.ClientWithName("UserB")
if _, gotErr := clientB.Status(ctx); gotErr != nil {
t.Fatalf("Status(%q): want nil; got %v", clientB.User.Name, gotErr)
t.Fatalf("Status(%q): want nil; got %v", clientB.Username(), gotErr)
}
// Watching the IPN bus should also work for UserB.
@@ -100,18 +46,18 @@ func TestUserConnectDisconnectNonWindows(t *testing.T) {
// And if we send a notification, both users should receive it.
wantErrMessage := "test error"
testNotify := ipn.Notify{ErrMessage: ptr.To(wantErrMessage)}
server.mustBackend().DebugNotify(testNotify)
server.Backend().DebugNotify(testNotify)
if n, err := watcherA.Next(); err != nil {
t.Fatalf("IPNBusWatcher.Next(%q): %v", clientA.User.Name, err)
t.Fatalf("IPNBusWatcher.Next(%q): %v", clientA.Username(), err)
} else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage {
t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientA.User.Name, wantErrMessage, gotErrMessage)
t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientA.Username(), wantErrMessage, gotErrMessage)
}
if n, err := watcherB.Next(); err != nil {
t.Fatalf("IPNBusWatcher.Next(%q): %v", clientB.User.Name, err)
t.Fatalf("IPNBusWatcher.Next(%q): %v", clientB.Username(), err)
} else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage {
t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientB.User.Name, wantErrMessage, gotErrMessage)
t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientB.Username(), wantErrMessage, gotErrMessage)
}
}
@@ -120,21 +66,21 @@ func TestUserConnectDisconnectOnWindows(t *testing.T) {
setGOOSForTest(t, "windows")
ctx := context.Background()
server := startDefaultTestIPNServer(t, ctx, enableLogging)
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
client := server.getClientAs("User")
client := server.ClientWithName("User")
_, cancelWatcher := client.WatchIPNBus(ctx, 0)
// On Windows, however, the current user should be set to the user that connected.
server.checkCurrentUser(client.User)
server.CheckCurrentUser(client.Actor)
// Cancel the IPN bus watcher request and wait for the server to unblock.
cancelWatcher()
server.blockWhileInUse(ctx)
server.BlockWhileInUse(ctx)
// The current user should not be set after a disconnect, as no one is
// currently using the server.
server.checkCurrentUser(nil)
server.CheckCurrentUser(nil)
}
func TestIPNAlreadyInUseOnWindows(t *testing.T) {
@@ -142,22 +88,22 @@ func TestIPNAlreadyInUseOnWindows(t *testing.T) {
setGOOSForTest(t, "windows")
ctx := context.Background()
server := startDefaultTestIPNServer(t, ctx, enableLogging)
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
// UserA connects and starts watching the IPN bus.
clientA := server.getClientAs("UserA")
clientA := server.ClientWithName("UserA")
clientA.WatchIPNBus(ctx, 0)
// While UserA is connected, UserB should not be able to connect.
clientB := server.getClientAs("UserB")
clientB := server.ClientWithName("UserB")
if _, gotErr := clientB.Status(ctx); gotErr == nil {
t.Fatalf("Status(%q): want error; got nil", clientB.User.Name)
t.Fatalf("Status(%q): want error; got nil", clientB.Username())
} else if wantError := "401 Unauthorized: Tailscale already in use by UserA"; gotErr.Error() != wantError {
t.Fatalf("Status(%q): want %q; got %q", clientB.User.Name, wantError, gotErr.Error())
t.Fatalf("Status(%q): want %q; got %q", clientB.Username(), wantError, gotErr.Error())
}
// Current user should still be UserA.
server.checkCurrentUser(clientA.User)
server.CheckCurrentUser(clientA.Actor)
}
func TestSequentialOSUserSwitchingOnWindows(t *testing.T) {
@@ -165,22 +111,22 @@ func TestSequentialOSUserSwitchingOnWindows(t *testing.T) {
setGOOSForTest(t, "windows")
ctx := context.Background()
server := startDefaultTestIPNServer(t, ctx, enableLogging)
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
connectDisconnectAsUser := func(name string) {
// User connects and starts watching the IPN bus.
client := server.getClientAs(name)
client := server.ClientWithName(name)
watcher, cancelWatcher := client.WatchIPNBus(ctx, 0)
defer cancelWatcher()
go pumpIPNBus(watcher)
// It should be the current user from the LocalBackend's perspective...
server.checkCurrentUser(client.User)
server.CheckCurrentUser(client.Actor)
// until it disconnects.
cancelWatcher()
server.blockWhileInUse(ctx)
server.BlockWhileInUse(ctx)
// Now, the current user should be unset.
server.checkCurrentUser(nil)
server.CheckCurrentUser(nil)
}
// UserA logs in, uses Tailscale for a bit, then logs out.
@@ -194,11 +140,11 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) {
setGOOSForTest(t, "windows")
ctx := context.Background()
server := startDefaultTestIPNServer(t, ctx, enableLogging)
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
connectDisconnectAsUser := func(name string) {
// User connects and starts watching the IPN bus.
client := server.getClientAs(name)
client := server.ClientWithName(name)
watcher, cancelWatcher := client.WatchIPNBus(ctx, ipn.NotifyInitialState)
defer cancelWatcher()
@@ -206,7 +152,7 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) {
// Get the current user from the LocalBackend's perspective
// as soon as we're connected.
gotUID, gotActor := server.mustBackend().CurrentUserForTest()
gotUID, gotActor := server.Backend().CurrentUserForTest()
// Wait for the first notification to arrive.
// It will either be the initial state we've requested via [ipn.NotifyInitialState],
@@ -225,17 +171,17 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) {
}
// Otherwise, our user should have been the current user since the time we connected.
if gotUID != client.User.UID {
t.Errorf("CurrentUser(Initial): got UID %q; want %q", gotUID, client.User.UID)
if gotUID != client.Actor.UserID() {
t.Errorf("CurrentUser(Initial): got UID %q; want %q", gotUID, client.Actor.UserID())
return
}
if gotActor, ok := gotActor.(*ipnauth.TestActor); !ok || *gotActor != *client.User {
t.Errorf("CurrentUser(Initial): got %v; want %v", gotActor, client.User)
if hasActor := gotActor != nil; !hasActor || gotActor != client.Actor {
t.Errorf("CurrentUser(Initial): got %v; want %v", gotActor, client.Actor)
return
}
// And should still be the current user (as they're still connected)...
server.checkCurrentUser(client.User)
server.CheckCurrentUser(client.Actor)
}
numIterations := 10
@@ -253,11 +199,11 @@ func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) {
}
wg.Wait()
if err := server.blockWhileInUse(ctx); err != nil {
t.Fatalf("blockWhileInUse: %v", err)
if err := server.BlockWhileInUse(ctx); err != nil {
t.Fatalf("BlockUntilIdle: %v", err)
}
server.checkCurrentUser(nil)
server.CheckCurrentUser(nil)
}
}
@@ -266,13 +212,13 @@ func TestBlockWhileIdentityInUse(t *testing.T) {
setGOOSForTest(t, "windows")
ctx := context.Background()
server := startDefaultTestIPNServer(t, ctx, enableLogging)
server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging))
// connectWaitDisconnectAsUser connects as a user with the specified name
// and keeps the IPN bus watcher alive until the context is canceled.
// It returns a channel that is closed when done.
connectWaitDisconnectAsUser := func(ctx context.Context, name string) <-chan struct{} {
client := server.getClientAs(name)
client := server.ClientWithName(name)
watcher, cancelWatcher := client.WatchIPNBus(ctx, 0)
done := make(chan struct{})
@@ -301,8 +247,8 @@ func TestBlockWhileIdentityInUse(t *testing.T) {
// in blockWhileIdentityInUse. But the issue also occurs during
// the normal execution path when UserB connects to the IPN server
// while UserA is disconnecting.
userB := server.makeTestUser("UserB", "ClientB")
server.blockWhileIdentityInUse(ctx, userB)
userB := server.MakeTestActor("UserB", "ClientB")
server.BlockWhileInUseByOther(ctx, userB)
<-userADone
}
}
@@ -313,41 +259,7 @@ func setGOOSForTest(tb testing.TB, goos string) {
tb.Cleanup(func() { envknob.Setenv("TS_DEBUG_FAKE_GOOS", "") })
}
func testLogger(tb testing.TB, enableLogging bool) logger.Logf {
tb.Helper()
if enableLogging {
return tstest.WhileTestRunningLogger(tb)
}
return logger.Discard
}
// newTestIPNServer creates a new IPN server for testing, using the specified local backend.
func newTestIPNServer(tb testing.TB, lb *ipnlocal.LocalBackend, enableLogging bool) *Server {
tb.Helper()
server := New(testLogger(tb, enableLogging), logid.PublicID{}, lb.NetMon())
server.lb.Store(lb)
return server
}
type testIPNClient struct {
tb testing.TB
*local.Client
User *ipnauth.TestActor
}
func (c *testIPNClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*tailscale.IPNBusWatcher, context.CancelFunc) {
c.tb.Helper()
ctx, cancelWatcher := context.WithCancel(ctx)
c.tb.Cleanup(cancelWatcher)
watcher, err := c.Client.WatchIPNBus(ctx, mask)
if err != nil {
c.tb.Fatalf("WatchIPNBus(%q): %v", c.User.Name, err)
}
c.tb.Cleanup(func() { watcher.Close() })
return watcher, cancelWatcher
}
func pumpIPNBus(watcher *tailscale.IPNBusWatcher) {
func pumpIPNBus(watcher *local.IPNBusWatcher) {
for {
_, err := watcher.Next()
if err != nil {
@@ -355,206 +267,3 @@ func pumpIPNBus(watcher *tailscale.IPNBusWatcher) {
}
}
}
type testIPNServer struct {
tb testing.TB
*Server
clientID atomic.Int64
getClient func(*ipnauth.TestActor) *local.Client
actorsMu sync.Mutex
actors map[string]*ipnauth.TestActor
}
func (s *testIPNServer) getClientAs(name string) *testIPNClient {
clientID := fmt.Sprintf("Client-%d", 1+s.clientID.Add(1))
user := s.makeTestUser(name, clientID)
return &testIPNClient{
tb: s.tb,
Client: s.getClient(user),
User: user,
}
}
func (s *testIPNServer) makeTestUser(name string, clientID string) *ipnauth.TestActor {
s.actorsMu.Lock()
defer s.actorsMu.Unlock()
actor := s.actors[name]
if actor == nil {
actor = &ipnauth.TestActor{Name: name}
if envknob.GOOS() == "windows" {
// Historically, as of 2025-01-13, IPN does not distinguish between
// different users on non-Windows devices. Therefore, the UID, which is
// an [ipn.WindowsUserID], should only be populated when the actual or
// fake GOOS is Windows.
actor.UID = ipn.WindowsUserID(fmt.Sprintf("S-1-5-21-1-0-0-%d", 1001+len(s.actors)))
}
mak.Set(&s.actors, name, actor)
s.tb.Cleanup(func() { delete(s.actors, name) })
}
actor = ptr.To(*actor)
actor.CID = ipnauth.ClientIDFrom(clientID)
return actor
}
func (s *testIPNServer) blockWhileInUse(ctx context.Context) error {
ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx)
s.mu.Lock()
busy := len(s.activeReqs) != 0
s.mu.Unlock()
if busy {
<-ready
}
cleanup()
return ctx.Err()
}
func (s *testIPNServer) checkCurrentUser(want *ipnauth.TestActor) {
s.tb.Helper()
var wantUID ipn.WindowsUserID
if want != nil {
wantUID = want.UID
}
gotUID, gotActor := s.mustBackend().CurrentUserForTest()
if gotUID != wantUID {
s.tb.Errorf("CurrentUser: got UID %q; want %q", gotUID, wantUID)
}
if gotActor, ok := gotActor.(*ipnauth.TestActor); ok != (want != nil) || (want != nil && *gotActor != *want) {
s.tb.Errorf("CurrentUser: got %v; want %v", gotActor, want)
}
}
// startTestIPNServer starts a [httptest.Server] that hosts the specified IPN server for the
// duration of the test, using the specified base context for incoming requests.
// It returns a function that creates a [local.Client] as a given [ipnauth.TestActor].
func startTestIPNServer(tb testing.TB, baseContext context.Context, server *Server) *testIPNServer {
tb.Helper()
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actor, err := extractActorFromHeader(r.Header)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
tb.Errorf("extractActorFromHeader: %v", err)
return
}
ctx := newTestContextWithActor(r.Context(), actor)
server.serveHTTP(w, r.Clone(ctx))
}))
ts.Config.Addr = "http://" + apitype.LocalAPIHost
ts.Config.BaseContext = func(_ net.Listener) context.Context { return baseContext }
ts.Config.ErrorLog = logger.StdLogger(logger.WithPrefix(server.logf, "ipnserver: "))
ts.Start()
tb.Cleanup(ts.Close)
return &testIPNServer{
tb: tb,
Server: server,
getClient: func(actor *ipnauth.TestActor) *local.Client {
return &local.Client{Transport: newTestRoundTripper(ts, actor)}
},
}
}
func startDefaultTestIPNServer(tb testing.TB, ctx context.Context, enableLogging bool) *testIPNServer {
tb.Helper()
lb := newLocalBackendWithTestControl(tb, newUnreachableControlClient, enableLogging)
ctx, stopServer := context.WithCancel(ctx)
tb.Cleanup(stopServer)
return startTestIPNServer(tb, ctx, newTestIPNServer(tb, lb, enableLogging))
}
type testRoundTripper struct {
transport http.RoundTripper
actor *ipnauth.TestActor
}
// newTestRoundTripper creates a new [http.RoundTripper] that sends requests
// to the specified test server as the specified actor.
func newTestRoundTripper(ts *httptest.Server, actor *ipnauth.TestActor) *testRoundTripper {
return &testRoundTripper{
transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
var std net.Dialer
return std.DialContext(ctx, network, ts.Listener.Addr().(*net.TCPAddr).String())
}},
actor: actor,
}
}
const testActorHeaderName = "TS-Test-Actor"
// RoundTrip implements [http.RoundTripper] by forwarding the request to the underlying transport
// and including the test actor's identity in the request headers.
func (rt *testRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
actorJSON, err := json.Marshal(&rt.actor)
if err != nil {
// An [http.RoundTripper] must always close the request body, including on error.
if r.Body != nil {
r.Body.Close()
}
return nil, err
}
r = r.Clone(r.Context())
r.Header.Set(testActorHeaderName, string(actorJSON))
return rt.transport.RoundTrip(r)
}
// extractActorFromHeader extracts a test actor from the specified request headers.
func extractActorFromHeader(h http.Header) (*ipnauth.TestActor, error) {
actorJSON := h.Get(testActorHeaderName)
if actorJSON == "" {
return nil, errors.New("missing Test-Actor header")
}
actor := &ipnauth.TestActor{}
if err := json.Unmarshal([]byte(actorJSON), &actor); err != nil {
return nil, fmt.Errorf("invalid Test-Actor header: %v", err)
}
return actor, nil
}
type newControlClientFn func(tb testing.TB, opts controlclient.Options) controlclient.Client
func newLocalBackendWithTestControl(tb testing.TB, newControl newControlClientFn, enableLogging bool) *ipnlocal.LocalBackend {
tb.Helper()
sys := tsd.NewSystem()
store := &mem.Store{}
sys.Set(store)
logf := testLogger(tb, enableLogging)
e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get())
if err != nil {
tb.Fatalf("NewFakeUserspaceEngine: %v", err)
}
tb.Cleanup(e.Close)
sys.Set(e)
b, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0)
if err != nil {
tb.Fatalf("NewLocalBackend: %v", err)
}
tb.Cleanup(b.Shutdown)
b.DisablePortMapperForTest()
b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) {
return newControl(tb, opts), nil
})
return b
}
func newUnreachableControlClient(tb testing.TB, opts controlclient.Options) controlclient.Client {
tb.Helper()
opts.ServerURL = "https://127.0.0.1:1"
cc, err := controlclient.New(opts)
if err != nil {
tb.Fatal(err)
}
return cc
}
// newTestContextWithActor returns a new context that carries the identity
// of the specified actor and can be used for testing.
// It can be retrieved with [actorFromContext].
func newTestContextWithActor(ctx context.Context, actor ipnauth.Actor) context.Context {
return actorKey.WithValue(ctx, actorOrError{actor: actor})
}

View File

@@ -0,0 +1,46 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ipnserver
import (
"context"
"sync"
"testing"
)
func TestWaiterSet(t *testing.T) {
var s waiterSet
wantLen := func(want int, when string) {
t.Helper()
if got := len(s); got != want {
t.Errorf("%s: len = %v; want %v", when, got, want)
}
}
wantLen(0, "initial")
var mu sync.Mutex
ctx, cancel := context.WithCancel(context.Background())
ready, cleanup := s.add(&mu, ctx)
wantLen(1, "after add")
select {
case <-ready:
t.Fatal("should not be ready")
default:
}
s.wakeAll()
<-ready
wantLen(1, "after fire")
cleanup()
wantLen(0, "after cleanup")
// And again but on an already-expired ctx.
cancel()
ready, cleanup = s.add(&mu, ctx)
<-ready // shouldn't block
cleanup()
wantLen(0, "at end")
}