tailscale/ipn/ipnserver/server_test.go
Nick Khyl 6fac2903e1 ipn/ipnserver: fix race condition where LocalBackend is reset after a different user connects
In this commit, we add a test (which fails without this fix) to verify that ipn/ipnserver.Server correctly
sets and unsets the current user when two different clients send requests concurrently
(A sends request, B sends request, A's request completes, B's request completes).

The expectation is that the user who wins the race becomes the current user
from the LocalBackend's perspective, remaining in this state until they disconnect,
after which a different user should be able to connect and use the LocalBackend.

We then fix the second of two bugs in (*Server).addActiveHTTPRequest, where a race
condition causes the LocalBackend's state to be reset after a new client connects,
instead of after the last active request of the previous client completes and the server
becomes idle.

Fixes tailscale/corp#25804

Signed-off-by: Nick Khyl <nickk@tailscale.com>
2025-01-14 15:54:43 -06:00

501 lines
15 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ipnserver
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"runtime"
"strconv"
"sync"
"sync/atomic"
"testing"
"tailscale.com/client/tailscale"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/control/controlclient"
"tailscale.com/envknob"
"tailscale.com/ipn"
"tailscale.com/ipn/ipnauth"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/ipn/store/mem"
"tailscale.com/tsd"
"tailscale.com/tstest"
"tailscale.com/types/logger"
"tailscale.com/types/logid"
"tailscale.com/types/ptr"
"tailscale.com/util/mak"
"tailscale.com/wgengine"
)
// TestWaiterSet exercises the add / wakeAll / cleanup lifecycle of a waiterSet,
// first with a live context and then with one that has already been canceled.
func TestWaiterSet(t *testing.T) {
	var set waiterSet
	checkLen := func(want int, stage string) {
		t.Helper()
		if n := len(set); n != want {
			t.Errorf("%s: len = %v; want %v", stage, n, want)
		}
	}
	checkLen(0, "initial")

	var mu sync.Mutex
	ctx, cancel := context.WithCancel(context.Background())

	// A freshly added waiter must not be ready until it is woken.
	readyCh, release := set.add(&mu, ctx)
	checkLen(1, "after add")
	select {
	case <-readyCh:
		t.Fatal("should not be ready")
	default:
	}

	// Waking all waiters makes this one ready, but its entry stays in the
	// set until its cleanup func is called.
	set.wakeAll()
	<-readyCh
	checkLen(1, "after fire")
	release()
	checkLen(0, "after cleanup")

	// And again but on an already-expired ctx: the waiter should be ready
	// immediately, without an explicit wakeAll.
	cancel()
	readyCh, release = set.add(&mu, ctx)
	<-readyCh // shouldn't block
	release()
	checkLen(0, "at end")
}
// TestUserConnectDisconnectNonWindows verifies that on non-Windows platforms
// the server does not track a "current user", and that several users can use
// the LocalAPI and watch the IPN bus at the same time.
func TestUserConnectDisconnectNonWindows(t *testing.T) {
	enableLogging := false
	if runtime.GOOS == "windows" {
		setGOOSForTest(t, "linux")
	}

	ctx := context.Background()
	server := startDefaultTestIPNServer(t, ctx, enableLogging)

	// UserA connects and starts watching the IPN bus.
	clientA := server.getClientAs("UserA")
	watcherA, _ := clientA.WatchIPNBus(ctx, 0)

	// The concept of "current user" is only relevant on Windows
	// and it should not be set on non-Windows platforms.
	server.checkCurrentUser(nil)

	// Additionally, a different user should be able to connect and use the LocalAPI.
	clientB := server.getClientAs("UserB")
	if _, gotErr := clientB.Status(ctx); gotErr != nil {
		t.Fatalf("Status(%q): want nil; got %v", clientB.User.Name, gotErr)
	}

	// Watching the IPN bus should also work for UserB.
	watcherB, _ := clientB.WatchIPNBus(ctx, 0)

	// And if we send a notification, both users should receive it.
	wantErrMessage := "test error"
	testNotify := ipn.Notify{ErrMessage: ptr.To(wantErrMessage)}
	server.mustBackend().DebugNotify(testNotify)

	// wantNotify reads the next notification from w and fails the test unless
	// it carries wantErrMessage.
	wantNotify := func(w *tailscale.IPNBusWatcher, client *testIPNClient) {
		t.Helper()
		n, err := w.Next()
		if err != nil {
			t.Fatalf("IPNBusWatcher.Next(%q): %v", client.User.Name, err)
		}
		// Report the dereferenced message text; formatting the *string
		// itself would only show a pointer address.
		switch got := n.ErrMessage; {
		case got == nil:
			t.Fatalf("IPNBusWatcher.Next(%q): want %q; got no error message", client.User.Name, wantErrMessage)
		case *got != wantErrMessage:
			t.Fatalf("IPNBusWatcher.Next(%q): want %q; got %q", client.User.Name, wantErrMessage, *got)
		}
	}
	wantNotify(watcherA, clientA)
	wantNotify(watcherB, clientB)
}
// TestUserConnectDisconnectOnWindows verifies that on Windows the connecting
// user becomes the current user, and that the current user is cleared once
// that user disconnects and the server becomes idle.
func TestUserConnectDisconnectOnWindows(t *testing.T) {
	const enableLogging = false
	setGOOSForTest(t, "windows")

	ctx := context.Background()
	srv := startDefaultTestIPNServer(t, ctx, enableLogging)

	user := srv.getClientAs("User")
	_, stopWatching := user.WatchIPNBus(ctx, 0)

	// Unlike other platforms, Windows tracks the connected user as current.
	srv.checkCurrentUser(user.User)

	// Drop the IPN bus watcher and wait for the server to unblock.
	stopWatching()
	srv.blockWhileInUse(ctx)

	// With nobody using the server, no current user should remain.
	srv.checkCurrentUser(nil)
}
// TestIPNAlreadyInUseOnWindows verifies that on Windows a second user is
// rejected with 401 while another user is connected, and that the first user
// stays the current user.
func TestIPNAlreadyInUseOnWindows(t *testing.T) {
	const enableLogging = false
	setGOOSForTest(t, "windows")

	ctx := context.Background()
	srv := startDefaultTestIPNServer(t, ctx, enableLogging)

	// UserA connects and keeps an IPN bus watcher open.
	owner := srv.getClientAs("UserA")
	owner.WatchIPNBus(ctx, 0)

	// UserB tries to use the LocalAPI while UserA is still connected.
	intruder := srv.getClientAs("UserB")
	const wantError = "401 Unauthorized: Tailscale already in use by UserA"
	switch _, err := intruder.Status(ctx); {
	case err == nil:
		t.Fatalf("Status(%q): want error; got nil", intruder.User.Name)
	case err.Error() != wantError:
		t.Fatalf("Status(%q): want %q; got %q", intruder.User.Name, wantError, err.Error())
	}

	// UserA must still be the current user.
	srv.checkCurrentUser(owner.User)
}
// TestSequentialOSUserSwitchingOnWindows verifies that on Windows users who
// connect one after another each become the current user while connected,
// and that the current user is cleared after each of them disconnects.
func TestSequentialOSUserSwitchingOnWindows(t *testing.T) {
	const enableLogging = false
	setGOOSForTest(t, "windows")

	ctx := context.Background()
	srv := startDefaultTestIPNServer(t, ctx, enableLogging)

	useTailscaleAs := func(name string) {
		// Connect and keep draining the IPN bus in the background.
		c := srv.getClientAs(name)
		w, stop := c.WatchIPNBus(ctx, 0)
		defer stop()
		go pumpIPNBus(w)

		// While connected, this user is the LocalBackend's current user...
		srv.checkCurrentUser(c.User)

		// ...and once it disconnects and the server unblocks, nobody is.
		stop()
		srv.blockWhileInUse(ctx)
		srv.checkCurrentUser(nil)
	}

	// UserA, then UserB: each logs in, uses Tailscale for a bit, then logs out.
	for _, name := range []string{"UserA", "UserB"} {
		useTailscaleAs(name)
	}
}
// TestConcurrentOSUserSwitchingOnWindows verifies that on Windows, when many
// users race to connect, a user who is not told the IPN is in use by another
// user has been the current user from the moment it connected, and remains so
// while still connected.
func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) {
	enableLogging := false
	setGOOSForTest(t, "windows")

	ctx := context.Background()
	server := startDefaultTestIPNServer(t, ctx, enableLogging)

	// connectDisconnectAsUser runs on goroutines other than the test's own,
	// so it reports failures with t.Errorf and returns instead of calling
	// t.Fatal, which may only be called from the goroutine running the test.
	// NOTE(review): testIPNClient.WatchIPNBus still calls Fatalf on error.
	connectDisconnectAsUser := func(name string) {
		// User connects and starts watching the IPN bus.
		client := server.getClientAs(name)
		watcher, cancelWatcher := client.WatchIPNBus(ctx, ipn.NotifyInitialState)
		defer cancelWatcher()

		runtime.Gosched()

		// Get the current user from the LocalBackend's perspective
		// as soon as we're connected.
		gotUID, gotActor := server.mustBackend().CurrentUserForTest()

		// Wait for the first notification to arrive.
		// It will either be the initial state we've requested via [ipn.NotifyInitialState],
		// returned by an actual handler, or a "fake" notification sent by the server
		// itself to indicate that it is being used by someone else.
		n, err := watcher.Next()
		if err != nil {
			t.Errorf("IPNBusWatcher.Next(%q): %v", client.User.Name, err)
			return
		}

		// If our user lost the race and the IPN is in use by another user,
		// we should just return. For the sake of this test, we're not
		// interested in waiting for the server to become idle.
		if n.State != nil && *n.State == ipn.InUseOtherUser {
			return
		}

		// Otherwise, our user should have been the current user since the time we connected.
		if gotUID != client.User.UID {
			t.Errorf("CurrentUser(Initial): got UID %q; want %q", gotUID, client.User.UID)
			return
		}
		// Use a separate name for the asserted value so the error message
		// reports the actor the backend actually returned, even when it is
		// not a *ipnauth.TestActor; also guard against a typed-nil pointer.
		if gotTestActor, ok := gotActor.(*ipnauth.TestActor); !ok || gotTestActor == nil || *gotTestActor != *client.User {
			t.Errorf("CurrentUser(Initial): got %v; want %v", gotActor, client.User)
			return
		}

		// And should still be the current user (as they're still connected)...
		server.checkCurrentUser(client.User)
	}

	numIterations := 10
	for range numIterations {
		numGoRoutines := 100
		var wg sync.WaitGroup
		wg.Add(numGoRoutines)
		for i := range numGoRoutines {
			// User logs in, uses Tailscale for a bit, then logs out
			// in parallel with other users doing the same.
			go func() {
				defer wg.Done()
				connectDisconnectAsUser("User-" + strconv.Itoa(i))
			}()
		}
		wg.Wait()
	}
}
// setGOOSForTest sets the TS_DEBUG_FAKE_GOOS envknob to goos and resets it
// to the empty string when tb finishes.
func setGOOSForTest(tb testing.TB, goos string) {
	tb.Helper()
	const knob = "TS_DEBUG_FAKE_GOOS"
	envknob.Setenv(knob, goos)
	tb.Cleanup(func() {
		envknob.Setenv(knob, "")
	})
}
// testLogger returns the logger to use in tb: logger.Discard unless
// enableLogging is set, in which case it is tstest.WhileTestRunningLogger(tb).
func testLogger(tb testing.TB, enableLogging bool) logger.Logf {
	tb.Helper()
	if !enableLogging {
		return logger.Discard
	}
	return tstest.WhileTestRunningLogger(tb)
}
// newTestIPNServer returns a new IPN [Server] for testing that is backed by lb
// and shares lb's network monitor. Its log output goes through testLogger,
// so it is discarded unless enableLogging is set.
func newTestIPNServer(tb testing.TB, lb *ipnlocal.LocalBackend, enableLogging bool) *Server {
	tb.Helper()
	logf := testLogger(tb, enableLogging)
	srv := New(logf, logid.PublicID{}, lb.NetMon())
	srv.lb.Store(lb)
	return srv
}
// testIPNClient is a LocalAPI client used by tests. Requests it sends are
// made on behalf of User.
type testIPNClient struct {
	// tb is the test that owns the client; used for Helper, Cleanup and Fatalf.
	tb testing.TB
	*tailscale.LocalClient
	// User is the test actor this client acts as.
	User *ipnauth.TestActor
}
// WatchIPNBus starts watching the IPN bus with the given mask as c.User and
// returns the watcher along with a func that cancels the watch.
// The watch is also canceled, and the watcher closed, when the test ends.
// If the watch cannot be started, the test fails via Fatalf.
func (c *testIPNClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*tailscale.IPNBusWatcher, context.CancelFunc) {
	c.tb.Helper()

	watchCtx, stop := context.WithCancel(ctx)
	c.tb.Cleanup(stop)

	w, err := c.LocalClient.WatchIPNBus(watchCtx, mask)
	if err != nil {
		c.tb.Fatalf("WatchIPNBus(%q): %v", c.User.Name, err)
	}
	c.tb.Cleanup(func() { w.Close() })
	return w, stop
}
// pumpIPNBus reads and discards notifications from watcher until Next
// returns an error.
func pumpIPNBus(watcher *tailscale.IPNBusWatcher) {
	var err error
	for err == nil {
		_, err = watcher.Next()
	}
}
// testIPNServer is an IPN [Server] that is being served over HTTP for the
// duration of a test, plus what is needed to create clients for it.
type testIPNServer struct {
	tb testing.TB
	*Server
	// clientID is a counter used to generate client IDs in getClientAs.
	clientID atomic.Int64
	// getClient returns a LocalClient that sends requests as the given actor.
	getClient func(*ipnauth.TestActor) *tailscale.LocalClient

	// actorsMu guards actors.
	actorsMu sync.Mutex
	// actors maps a test user's name to its template actor (see makeTestUser).
	actors map[string]*ipnauth.TestActor
}
// getClientAs returns a new client that sends requests as the test user with
// the given name. Every call gets its own client ID (Client-1, Client-2, ...),
// even when the same user name is reused.
func (s *testIPNServer) getClientAs(name string) *testIPNClient {
	// Add returns the incremented value, so the first client is Client-1.
	// (Adding another 1 on top would make the first client Client-2.)
	clientID := fmt.Sprintf("Client-%d", s.clientID.Add(1))
	user := s.makeTestUser(name, clientID)
	return &testIPNClient{
		tb:          s.tb,
		LocalClient: s.getClient(user),
		User:        user,
	}
}
// makeTestUser returns a test actor for the user with the given name, bound
// to clientID.
//
// On the first call for a given name, it creates a template actor and stores
// it in s.actors. When the (real or fake) GOOS is Windows, the template is
// also assigned a SID-like UID derived from the number of known users.
// The template is removed from s.actors when the test ends. Later calls
// for the same name reuse the template, so all clients of a user share its
// Name and UID. The returned actor is always a fresh copy, so setting its
// CID does not affect other clients.
func (s *testIPNServer) makeTestUser(name string, clientID string) *ipnauth.TestActor {
	s.actorsMu.Lock()
	defer s.actorsMu.Unlock()
	actor := s.actors[name]
	if actor == nil {
		actor = &ipnauth.TestActor{Name: name}
		if envknob.GOOS() == "windows" {
			// Historically, as of 2025-01-13, IPN does not distinguish between
			// different users on non-Windows devices. Therefore, the UID, which is
			// an [ipn.WindowsUserID], should only be populated when the actual or
			// fake GOOS is Windows.
			actor.UID = ipn.WindowsUserID(fmt.Sprintf("S-1-5-21-1-0-0-%d", 1001+len(s.actors)))
		}
		mak.Set(&s.actors, name, actor)
		s.tb.Cleanup(func() {
			// s.actors is guarded by actorsMu, so the cleanup must hold it
			// too rather than mutating the map unsynchronized.
			s.actorsMu.Lock()
			defer s.actorsMu.Unlock()
			delete(s.actors, name)
		})
	}
	actor = ptr.To(*actor)
	actor.CID = ipnauth.ClientIDFrom(clientID)
	return actor
}
// blockWhileInUse blocks while the server has active requests, or until ctx
// is done, and then returns ctx.Err().
//
// It returns immediately if the server is already idle. Without that check it
// could block forever: zeroReqWaiter waiters are (presumably) only woken on the
// transition to zero active requests. That transition may already have
// happened before this call, for example when a canceled watch finishes
// quickly.
//
// NOTE(review): assumes Server.activeReqs tracks in-flight requests and is
// guarded by s.mu — confirm against server.go.
func (s *testIPNServer) blockWhileInUse(ctx context.Context) error {
	// Register the waiter before checking whether the server is busy, so a
	// transition to idle between the check and the wait cannot be missed.
	ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx)

	s.mu.Lock()
	busy := len(s.activeReqs) != 0
	s.mu.Unlock()

	if busy {
		<-ready
	}
	cleanup()
	return ctx.Err()
}
// checkCurrentUser reports a test error unless the LocalBackend's current
// user matches want. If want is non-nil, the backend's UID must equal
// want.UID and its actor must be a *ipnauth.TestActor equal to *want.
// If want is nil, the UID must be empty and the actor must not be a
// *ipnauth.TestActor.
func (s *testIPNServer) checkCurrentUser(want *ipnauth.TestActor) {
	s.tb.Helper()
	var wantUID ipn.WindowsUserID
	if want != nil {
		wantUID = want.UID
	}
	gotUID, gotActor := s.mustBackend().CurrentUserForTest()
	if gotUID != wantUID {
		s.tb.Errorf("CurrentUser: got UID %q; want %q", gotUID, wantUID)
	}

	// Keep the asserted value under a separate name so the error message
	// below reports the actor the backend actually returned, even when it is
	// not a *ipnauth.TestActor. Also guard against a typed-nil pointer.
	gotTestActor, isTestActor := gotActor.(*ipnauth.TestActor)
	var match bool
	if want == nil {
		match = !isTestActor
	} else {
		match = isTestActor && gotTestActor != nil && *gotTestActor == *want
	}
	if !match {
		s.tb.Errorf("CurrentUser: got %v; want %v", gotActor, want)
	}
}
// startTestIPNServer starts an [httptest.Server] that serves the specified IPN
// server for the duration of the test, using baseContext as the base context
// for incoming requests. The HTTP server is closed when the test finishes.
//
// Each request must carry a test actor in the testActorHeaderName header.
// Requests without a valid one are rejected with 400 Bad Request and fail
// the test.
//
// It returns a [testIPNServer] whose getClient func creates a
// [tailscale.LocalClient] that sends requests as a given [ipnauth.TestActor].
func startTestIPNServer(tb testing.TB, baseContext context.Context, server *Server) *testIPNServer {
	tb.Helper()
	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		actor, err := extractActorFromHeader(r.Header)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			tb.Errorf("extractActorFromHeader: %v", err)
			return
		}
		// Hand the request to the IPN server with the actor's identity
		// attached to its context.
		ctx := newTestContextWithActor(r.Context(), actor)
		server.serveHTTP(w, r.Clone(ctx))
	}))
	ts.Config.Addr = "http://" + apitype.LocalAPIHost
	ts.Config.BaseContext = func(_ net.Listener) context.Context { return baseContext }
	ts.Config.ErrorLog = logger.StdLogger(logger.WithPrefix(server.logf, "ipnserver: "))
	ts.Start()
	tb.Cleanup(ts.Close)
	return &testIPNServer{
		tb:     tb,
		Server: server,
		getClient: func(actor *ipnauth.TestActor) *tailscale.LocalClient {
			return &tailscale.LocalClient{Transport: newTestRoundTripper(ts, actor)}
		},
	}
}
// startDefaultTestIPNServer starts a test IPN server backed by a new
// LocalBackend whose control client points at an unreachable address.
// Incoming requests use a context derived from ctx that is canceled when
// the test finishes.
func startDefaultTestIPNServer(tb testing.TB, ctx context.Context, enableLogging bool) *testIPNServer {
	tb.Helper()
	lb := newLocalBackendWithTestControl(tb, newUnreachableControlClient, enableLogging)
	baseCtx, cancel := context.WithCancel(ctx)
	tb.Cleanup(cancel)
	srv := newTestIPNServer(tb, lb, enableLogging)
	return startTestIPNServer(tb, baseCtx, srv)
}
// testRoundTripper is an [http.RoundTripper] that forwards each request to
// transport after attaching actor's identity as a request header.
type testRoundTripper struct {
	transport http.RoundTripper
	actor     *ipnauth.TestActor
}
// newTestRoundTripper returns an [http.RoundTripper] that sends every request
// to the test server ts, whatever host the request names, as the given actor.
func newTestRoundTripper(ts *httptest.Server, actor *ipnauth.TestActor) *testRoundTripper {
	dial := func(ctx context.Context, network, _ string) (net.Conn, error) {
		// Ignore the requested address and always dial the test server.
		serverAddr := ts.Listener.Addr().(*net.TCPAddr).String()
		var d net.Dialer
		return d.DialContext(ctx, network, serverAddr)
	}
	return &testRoundTripper{
		transport: &http.Transport{DialContext: dial},
		actor:     actor,
	}
}
// testActorHeaderName is the HTTP request header in which testRoundTripper
// sends the JSON-encoded test actor to the test server's handler.
const testActorHeaderName = "TS-Test-Actor"
// RoundTrip implements [http.RoundTripper]. It sends a clone of r through the
// underlying transport, with rt.actor JSON-encoded in the test actor header.
func (rt *testRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	encoded, err := json.Marshal(&rt.actor)
	if err == nil {
		out := r.Clone(r.Context())
		out.Header.Set(testActorHeaderName, string(encoded))
		return rt.transport.RoundTrip(out)
	}
	// An [http.RoundTripper] must always close the request body, including on error.
	if r.Body != nil {
		r.Body.Close()
	}
	return nil, err
}
// extractActorFromHeader decodes the test actor that a testRoundTripper stored
// as JSON in the testActorHeaderName header of h. It returns an error if the
// header is missing, is not valid JSON, or decodes to a null actor.
func extractActorFromHeader(h http.Header) (*ipnauth.TestActor, error) {
	actorJSON := h.Get(testActorHeaderName)
	if actorJSON == "" {
		return nil, errors.New("missing " + testActorHeaderName + " header")
	}
	// Decode through a pointer-to-pointer so a JSON null shows up as a nil
	// actor, which is rejected below.
	var actor *ipnauth.TestActor
	if err := json.Unmarshal([]byte(actorJSON), &actor); err != nil {
		return nil, fmt.Errorf("invalid %s header: %w", testActorHeaderName, err)
	}
	if actor == nil {
		// Returning (nil, nil) here would let the caller wrap a nil
		// *TestActor in a non-nil ipnauth.Actor interface.
		return nil, fmt.Errorf("invalid %s header: null actor", testActorHeaderName)
	}
	return actor, nil
}
// newControlClientFn creates the control client that a test LocalBackend
// uses for the given options.
type newControlClientFn func(tb testing.TB, opts controlclient.Options) controlclient.Client
// newLocalBackendWithTestControl returns a LocalBackend for testing. It uses
// an in-memory state store and a fake userspace engine, disables the port
// mapper, and creates its control clients with newControl. The engine is
// closed and the backend shut down when the test finishes.
func newLocalBackendWithTestControl(tb testing.TB, newControl newControlClientFn, enableLogging bool) *ipnlocal.LocalBackend {
	tb.Helper()
	sys := new(tsd.System)
	sys.Set(new(mem.Store))
	logf := testLogger(tb, enableLogging)

	eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry())
	if err != nil {
		tb.Fatalf("NewFakeUserspaceEngine: %v", err)
	}
	tb.Cleanup(eng.Close)
	sys.Set(eng)

	lb, err := ipnlocal.NewLocalBackend(logf, logid.PublicID{}, sys, 0)
	if err != nil {
		tb.Fatalf("NewLocalBackend: %v", err)
	}
	tb.Cleanup(lb.Shutdown)
	lb.DisablePortMapperForTest()
	lb.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) {
		return newControl(tb, opts), nil
	})
	return lb
}
// newUnreachableControlClient returns a control client built from opts but
// pointed at https://127.0.0.1:1 instead of the configured server. It fails
// the test if the client cannot be created. opts is passed by value, so the
// caller's copy is not modified.
func newUnreachableControlClient(tb testing.TB, opts controlclient.Options) controlclient.Client {
	tb.Helper()
	const unreachableURL = "https://127.0.0.1:1"
	opts.ServerURL = unreachableURL
	client, err := controlclient.New(opts)
	if err != nil {
		tb.Fatal(err)
	}
	return client
}
// newTestContextWithActor returns a child of ctx that carries actor's
// identity, for use in tests. The actor can be retrieved with
// [actorFromContext].
func newTestContextWithActor(ctx context.Context, actor ipnauth.Actor) context.Context {
	val := actorOrError{actor: actor}
	return actorKey.WithValue(ctx, val)
}