ipn/ipnlocal: close foreground sessions on SetServeConfig

This PR ensures zombie foreground sessions are shut down when a new
ServeConfig is set that wipes out the ongoing foreground ones. For
example, "tailscale serve|funnel reset|off" should close all open
sessions.

Updates #8489

Signed-off-by: Marwan Sulaiman <marwan@tailscale.com>
Author: Marwan Sulaiman <marwan@tailscale.com>
Date:   2023-09-18 10:30:58 -04:00
Commit: 651620623b (parent 530aaa52f1)

5 changed files with 140 additions and 11 deletions
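The mechanism underlying the whole change is Go's channel-close semantics: closing a watcher's notification channel makes every pending and future receive return immediately with ok == false, which doubles as a shutdown signal. A minimal standalone sketch of that pattern (the names here are illustrative, not from this codebase):

package main

import "fmt"

func main() {
	ch := make(chan string, 1)
	done := make(chan struct{})

	// Watcher: consumes notifications until the channel is closed.
	go func() {
		defer close(done)
		for {
			msg, ok := <-ch
			if !ok {
				fmt.Println("bus closed; watcher shutting down")
				return
			}
			fmt.Println("notify:", msg)
		}
	}()

	ch <- "hello"
	close(ch) // the "your session was removed" signal
	<-done
}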

cmd/tailscale/cli/serve.go

@@ -8,6 +8,7 @@
 	"errors"
 	"flag"
 	"fmt"
+	"io"
 	"log"
 	"net"
 	"net/url"
@@ -289,7 +290,7 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc {
 	for {
 		_, err = watcher.Next()
 		if err != nil {
-			if errors.Is(err, context.Canceled) {
+			if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
 				return nil
 			}
 			return err
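On the CLI side, io.EOF is now treated like context cancellation: when the backend closes a foreground session's bus, the watcher surfaces EOF and the serve command exits cleanly rather than reporting an error. A rough sketch of that consumer shape, with fakeNext standing in for the real watcher's Next method (an assumption for illustration, not the actual API):

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
)

// fakeNext pretends the backend closed the bus, which the client
// observes as io.EOF.
func fakeNext() (struct{}, error) {
	return struct{}{}, io.EOF
}

func watch() error {
	for {
		if _, err := fakeNext(); err != nil {
			// Both conditions mean "stop watching", not "failure".
			if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
				return nil
			}
			return err
		}
	}
}

func main() {
	fmt.Println(watch()) // <nil>
}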

ipn/ipnlocal/local.go

@@ -128,6 +128,13 @@ func RegisterNewSSHServer(fn newSSHServerFunc) {
 	newSSHServer = fn
 }
 
+// watchSession represents a WatchNotifications channel
+// and sessionID as required to close targeted buses.
+type watchSession struct {
+	ch        chan *ipn.Notify
+	sessionID string
+}
+
 // LocalBackend is the glue between the major pieces of the Tailscale
 // network software: the cloud control plane (via controlclient), the
 // network data plane (via wgengine), and the user-facing UIs and CLIs
@@ -233,7 +240,7 @@ type LocalBackend struct {
 	loginFlags     controlclient.LoginFlags
 	incomingFiles  map[*incomingFile]bool
 	fileWaiters    set.HandleSet[context.CancelFunc] // of wake-up funcs
-	notifyWatchers set.HandleSet[chan *ipn.Notify]
+	notifyWatchers set.HandleSet[*watchSession]
 	lastStatusTime time.Time // status.AsOf value of the last processed status update
 
 	// directFileRoot, if non-empty, means to write received files
 	// directly to this directory, without staging them in an
@@ -2058,7 +2065,7 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa
 		}
 	}
 
-	handle := b.notifyWatchers.Add(ch)
+	handle := b.notifyWatchers.Add(&watchSession{ch, sessionID})
 	b.mu.Unlock()
 
 	defer func() {
@@ -2103,8 +2110,8 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa
 		select {
 		case <-ctx.Done():
 			return
-		case n := <-ch:
-			if !fn(n) {
+		case n, ok := <-ch:
+			if !ok || !fn(n) {
 				return
 			}
 		}
@@ -2174,9 +2181,9 @@ func (b *LocalBackend) send(n ipn.Notify) {
 		n.FilesWaiting = &empty.Message{}
 	}
 
-	for _, ch := range b.notifyWatchers {
+	for _, sess := range b.notifyWatchers {
 		select {
-		case sess.ch <- &n:
+		case sess.ch <- &n:
 		default:
 			// Drop the notification if the channel is full.
 		}
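Note that the fan-out keeps the existing non-blocking send: a watcher whose buffer is full drops the notification instead of stalling the backend. The idiom in isolation (a hypothetical helper, not part of this change):

package main

import "fmt"

// tryNotify attempts a send without blocking; if the buffer is full
// the value is dropped and false is returned.
func tryNotify[T any](ch chan<- T, v T) bool {
	select {
	case ch <- v:
		return true
	default:
		return false
	}
}

func main() {
	ch := make(chan int, 1)
	fmt.Println(tryNotify(ch, 1)) // true: buffer has room
	fmt.Println(tryNotify(ch, 2)) // false: buffer full, dropped
}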

ipn/ipnlocal/local_test.go

@@ -752,9 +752,9 @@ func TestWatchNotificationsCallbacks(t *testing.T) {
 	}
 
 	// Send a notification. Range over notifyWatchers to get the channel
 	// because WatchNotifications doesn't expose the handle for it.
-	for _, c := range b.notifyWatchers {
+	for _, sess := range b.notifyWatchers {
 		select {
-		case c <- n:
+		case sess.ch <- n:
 		default:
 			t.Fatalf("could not send notification")
 		}

ipn/ipnlocal/serve.go

@@ -247,16 +247,17 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string
 
 	// If etag is present, check that it has
 	// not changed from the last config.
+	prevConfig := b.serveConfig
 	if etag != "" {
 		// Note that we marshal b.serveConfig
 		// and not use b.lastServeConfJSON as that might
 		// be a Go nil value, which produces a different
 		// checksum from a JSON "null" value.
-		previousCfg, err := json.Marshal(b.serveConfig)
+		prevBytes, err := json.Marshal(prevConfig)
 		if err != nil {
 			return fmt.Errorf("error encoding previous config: %w", err)
 		}
-		sum := sha256.Sum256(previousCfg)
+		sum := sha256.Sum256(prevBytes)
 		previousEtag := hex.EncodeToString(sum[:])
 		if etag != previousEtag {
 			return ErrETagMismatch
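The comment about nil values is worth unpacking: json.Marshal of a nil config pointer produces the JSON literal null, while an unset cached byte slice would be empty, and the two hash to different etags. A small standard-library-only illustration:

package main

import (
	"crypto/sha256"
	"encoding/json"
	"fmt"
)

func main() {
	var cfg *struct{} // stands in for a nil config
	b, _ := json.Marshal(cfg)
	fmt.Printf("%q\n", b) // "null"

	nullSum := sha256.Sum256(b)    // hash of the JSON null document
	emptySum := sha256.Sum256(nil) // hash of an unset byte slice
	fmt.Println(nullSum == emptySum) // false: the two etags differ
}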
@@ -279,6 +280,26 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string
 	}
 
 	b.setTCPPortsInterceptedFromNetmapAndPrefsLocked(b.pm.CurrentPrefs())
 
+	// clean up and close all previously open foreground sessions
+	// if the current ServeConfig has overwritten them.
+	if prevConfig.Valid() {
+		has := func(string) bool { return false }
+		if b.serveConfig.Valid() {
+			has = b.serveConfig.Foreground().Has
+		}
+		prevConfig.Foreground().Range(func(k string, v ipn.ServeConfigView) (cont bool) {
+			if !has(k) {
+				for _, sess := range b.notifyWatchers {
+					if sess.sessionID == k {
+						close(sess.ch)
+					}
+				}
+			}
+			return true
+		})
+	}
+
 	return nil
 }
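Logically, the cleanup above is a set difference over session IDs: every key present in the previous Foreground map but absent from the new one gets its channel closed. The same idea with plain maps instead of ipn.ServeConfigView (illustrative types, not from the codebase):

package main

import "fmt"

type session struct {
	id string
	ch chan struct{}
}

// closeRemoved closes every session whose ID is in prev but not in
// next, mirroring the Foreground().Range loop above.
func closeRemoved(prev, next map[string]bool, sessions []*session) {
	for id := range prev {
		if !next[id] {
			for _, s := range sessions {
				if s.id == id {
					close(s.ch) // receiver sees ok == false and exits
				}
			}
		}
	}
}

func main() {
	s := &session{id: "sess-1", ch: make(chan struct{})}
	closeRemoved(map[string]bool{"sess-1": true}, map[string]bool{}, []*session{s})
	_, ok := <-s.ch
	fmt.Println("still open:", ok) // still open: false
}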

ipn/ipnlocal/serve_test.go

@@ -20,6 +20,7 @@
 	"path/filepath"
 	"strings"
 	"testing"
+	"time"
 
 	"tailscale.com/ipn"
 	"tailscale.com/ipn/store/mem"
@@ -184,6 +185,105 @@ func getEtag(t *testing.T, b any) string {
 	return hex.EncodeToString(sum[:])
 }
+// TestServeConfigForeground tests the inter-dependency
+// between a ServeConfig and a WatchIPNBus:
+//  1. Creating a WatchIPNBus returns a sessionID, which
+//  2. ServeConfig sets as the key of the Foreground field.
+//  3. ServeConfig expects the WatchIPNBus to clean up the Foreground
+//     config when the session is done.
+//  4. WatchIPNBus expects the ServeConfig to send a signal (close the
+//     channel) if an incoming SetServeConfig removes previous foregrounds.
+func TestServeConfigForeground(t *testing.T) {
+	b := newTestBackend(t)
+
+	ch1 := make(chan string, 1)
+	go func() {
+		defer close(ch1)
+		b.WatchNotifications(context.Background(), ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) {
+			if roNotify.SessionID != "" {
+				ch1 <- roNotify.SessionID
+			}
+			return true
+		})
+	}()
+
+	ch2 := make(chan string, 1)
+	go func() {
+		b.WatchNotifications(context.Background(), ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) {
+			if roNotify.SessionID != "" {
+				ch2 <- roNotify.SessionID
+				return true
+			}
+			ch2 <- "again" // let channel know fn was called again
+			return true
+		})
+	}()
+
+	var session1 string
+	select {
+	case session1 = <-ch1:
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting on watch notifications session id")
+	}
+
+	var session2 string
+	select {
+	case session2 = <-ch2:
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting on watch notifications session id")
+	}
+
+	err := b.SetServeConfig(&ipn.ServeConfig{
+		Foreground: map[string]*ipn.ServeConfig{
+			session1: {TCP: map[uint16]*ipn.TCPPortHandler{
+				443: {TCPForward: "http://localhost:3000"}},
+			},
+			session2: {TCP: map[uint16]*ipn.TCPPortHandler{
+				999: {TCPForward: "http://localhost:4000"}},
+			},
+		},
+	}, "")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Setting a new serve config should shut down WatchNotifications
+	// whose session IDs are no longer found: session1 goes, session2 stays.
+	err = b.SetServeConfig(&ipn.ServeConfig{
+		TCP: map[uint16]*ipn.TCPPortHandler{
+			5000: {TCPForward: "http://localhost:5000"},
+		},
+		Foreground: map[string]*ipn.ServeConfig{
+			session2: {TCP: map[uint16]*ipn.TCPPortHandler{
+				999: {TCPForward: "http://localhost:4000"}},
+			},
+		},
+	}, "")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	select {
+	case _, ok := <-ch1:
+		if ok {
+			t.Fatal("expected channel to be closed")
+		}
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting on watch notifications closing")
+	}
+
+	// check that the second session is still running
+	b.send(ipn.Notify{})
+	select {
+	case _, ok := <-ch2:
+		if !ok {
+			t.Fatal("expected second session to remain open")
+		}
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting on second session")
+	}
+}
 func TestServeConfigETag(t *testing.T) {
 	b := newTestBackend(t)