mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 13:05:46 +00:00
ipn/ipnlocal: close foreground sessions on SetServeConfig
This PR ensures zombie foregrounds are shutdown if a new ServeConfig is created that wipes the ongoing foreground ones. For example, "tailscale serve|funnel reset|off" should close all open sessions. Updates #8489 Signed-off-by: Marwan Sulaiman <marwan@tailscale.com>
This commit is contained in:
parent
530aaa52f1
commit
651620623b
@ -8,6 +8,7 @@
|
|||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -289,7 +290,7 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc {
|
|||||||
for {
|
for {
|
||||||
_, err = watcher.Next()
|
_, err = watcher.Next()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
@ -128,6 +128,13 @@ func RegisterNewSSHServer(fn newSSHServerFunc) {
|
|||||||
newSSHServer = fn
|
newSSHServer = fn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// watchSession represents a WatchNotifications channel
|
||||||
|
// and sessionID as required to close targeted buses.
|
||||||
|
type watchSession struct {
|
||||||
|
ch chan *ipn.Notify
|
||||||
|
sessionID string
|
||||||
|
}
|
||||||
|
|
||||||
// LocalBackend is the glue between the major pieces of the Tailscale
|
// LocalBackend is the glue between the major pieces of the Tailscale
|
||||||
// network software: the cloud control plane (via controlclient), the
|
// network software: the cloud control plane (via controlclient), the
|
||||||
// network data plane (via wgengine), and the user-facing UIs and CLIs
|
// network data plane (via wgengine), and the user-facing UIs and CLIs
|
||||||
@ -233,7 +240,7 @@ type LocalBackend struct {
|
|||||||
loginFlags controlclient.LoginFlags
|
loginFlags controlclient.LoginFlags
|
||||||
incomingFiles map[*incomingFile]bool
|
incomingFiles map[*incomingFile]bool
|
||||||
fileWaiters set.HandleSet[context.CancelFunc] // of wake-up funcs
|
fileWaiters set.HandleSet[context.CancelFunc] // of wake-up funcs
|
||||||
notifyWatchers set.HandleSet[chan *ipn.Notify]
|
notifyWatchers set.HandleSet[*watchSession]
|
||||||
lastStatusTime time.Time // status.AsOf value of the last processed status update
|
lastStatusTime time.Time // status.AsOf value of the last processed status update
|
||||||
// directFileRoot, if non-empty, means to write received files
|
// directFileRoot, if non-empty, means to write received files
|
||||||
// directly to this directory, without staging them in an
|
// directly to this directory, without staging them in an
|
||||||
@ -2058,7 +2065,7 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
handle := b.notifyWatchers.Add(ch)
|
handle := b.notifyWatchers.Add(&watchSession{ch, sessionID})
|
||||||
b.mu.Unlock()
|
b.mu.Unlock()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -2103,8 +2110,8 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case n := <-ch:
|
case n, ok := <-ch:
|
||||||
if !fn(n) {
|
if !ok || !fn(n) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2174,9 +2181,9 @@ func (b *LocalBackend) send(n ipn.Notify) {
|
|||||||
n.FilesWaiting = &empty.Message{}
|
n.FilesWaiting = &empty.Message{}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ch := range b.notifyWatchers {
|
for _, sess := range b.notifyWatchers {
|
||||||
select {
|
select {
|
||||||
case ch <- &n:
|
case sess.ch <- &n:
|
||||||
default:
|
default:
|
||||||
// Drop the notification if the channel is full.
|
// Drop the notification if the channel is full.
|
||||||
}
|
}
|
||||||
|
@ -752,9 +752,9 @@ func TestWatchNotificationsCallbacks(t *testing.T) {
|
|||||||
}
|
}
|
||||||
// Send a notification. Range over notifyWatchers to get the channel
|
// Send a notification. Range over notifyWatchers to get the channel
|
||||||
// because WatchNotifications doesn't expose the handle for it.
|
// because WatchNotifications doesn't expose the handle for it.
|
||||||
for _, c := range b.notifyWatchers {
|
for _, sess := range b.notifyWatchers {
|
||||||
select {
|
select {
|
||||||
case c <- n:
|
case sess.ch <- n:
|
||||||
default:
|
default:
|
||||||
t.Fatalf("could not send notification")
|
t.Fatalf("could not send notification")
|
||||||
}
|
}
|
||||||
|
@ -247,16 +247,17 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string
|
|||||||
|
|
||||||
// If etag is present, check that it has
|
// If etag is present, check that it has
|
||||||
// not changed from the last config.
|
// not changed from the last config.
|
||||||
|
prevConfig := b.serveConfig
|
||||||
if etag != "" {
|
if etag != "" {
|
||||||
// Note that we marshal b.serveConfig
|
// Note that we marshal b.serveConfig
|
||||||
// and not use b.lastServeConfJSON as that might
|
// and not use b.lastServeConfJSON as that might
|
||||||
// be a Go nil value, which produces a different
|
// be a Go nil value, which produces a different
|
||||||
// checksum from a JSON "null" value.
|
// checksum from a JSON "null" value.
|
||||||
previousCfg, err := json.Marshal(b.serveConfig)
|
prevBytes, err := json.Marshal(prevConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error encoding previous config: %w", err)
|
return fmt.Errorf("error encoding previous config: %w", err)
|
||||||
}
|
}
|
||||||
sum := sha256.Sum256(previousCfg)
|
sum := sha256.Sum256(prevBytes)
|
||||||
previousEtag := hex.EncodeToString(sum[:])
|
previousEtag := hex.EncodeToString(sum[:])
|
||||||
if etag != previousEtag {
|
if etag != previousEtag {
|
||||||
return ErrETagMismatch
|
return ErrETagMismatch
|
||||||
@ -279,6 +280,26 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string
|
|||||||
}
|
}
|
||||||
|
|
||||||
b.setTCPPortsInterceptedFromNetmapAndPrefsLocked(b.pm.CurrentPrefs())
|
b.setTCPPortsInterceptedFromNetmapAndPrefsLocked(b.pm.CurrentPrefs())
|
||||||
|
|
||||||
|
// clean up and close all previously open foreground sessions
|
||||||
|
// if the current ServeConfig has overwritten them.
|
||||||
|
if prevConfig.Valid() {
|
||||||
|
has := func(string) bool { return false }
|
||||||
|
if b.serveConfig.Valid() {
|
||||||
|
has = b.serveConfig.Foreground().Has
|
||||||
|
}
|
||||||
|
prevConfig.Foreground().Range(func(k string, v ipn.ServeConfigView) (cont bool) {
|
||||||
|
if !has(k) {
|
||||||
|
for _, sess := range b.notifyWatchers {
|
||||||
|
if sess.sessionID == k {
|
||||||
|
close(sess.ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"tailscale.com/ipn"
|
"tailscale.com/ipn"
|
||||||
"tailscale.com/ipn/store/mem"
|
"tailscale.com/ipn/store/mem"
|
||||||
@ -184,6 +185,105 @@ func getEtag(t *testing.T, b any) string {
|
|||||||
return hex.EncodeToString(sum[:])
|
return hex.EncodeToString(sum[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestServeConfigForeground tests the inter-dependency
|
||||||
|
// between a ServeConfig and a WatchIPNBus:
|
||||||
|
// 1. Creating a WatchIPNBus returns a sessionID, that
|
||||||
|
// 2. ServeConfig sets it as the key of the Foreground field.
|
||||||
|
// 3. ServeConfig expects the WatchIPNBus to clean up the Foreground
|
||||||
|
// config when the session is done.
|
||||||
|
// 4. WatchIPNBus expects the ServeConfig to send a signal (close the channel)
|
||||||
|
// if an incoming SetServeConfig removes previous foregrounds.
|
||||||
|
func TestServeConfigForeground(t *testing.T) {
|
||||||
|
b := newTestBackend(t)
|
||||||
|
|
||||||
|
ch1 := make(chan string, 1)
|
||||||
|
go func() {
|
||||||
|
defer close(ch1)
|
||||||
|
b.WatchNotifications(context.Background(), ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) {
|
||||||
|
if roNotify.SessionID != "" {
|
||||||
|
ch1 <- roNotify.SessionID
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
ch2 := make(chan string, 1)
|
||||||
|
go func() {
|
||||||
|
b.WatchNotifications(context.Background(), ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) {
|
||||||
|
if roNotify.SessionID != "" {
|
||||||
|
ch2 <- roNotify.SessionID
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
ch2 <- "again" // let channel know fn was called again
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
var session1 string
|
||||||
|
select {
|
||||||
|
case session1 = <-ch1:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting on watch notifications session id")
|
||||||
|
}
|
||||||
|
|
||||||
|
var session2 string
|
||||||
|
select {
|
||||||
|
case session2 = <-ch2:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting on watch notifications session id")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := b.SetServeConfig(&ipn.ServeConfig{
|
||||||
|
Foreground: map[string]*ipn.ServeConfig{
|
||||||
|
session1: {TCP: map[uint16]*ipn.TCPPortHandler{
|
||||||
|
443: {TCPForward: "http://localhost:3000"}},
|
||||||
|
},
|
||||||
|
session2: {TCP: map[uint16]*ipn.TCPPortHandler{
|
||||||
|
999: {TCPForward: "http://localhost:4000"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setting a new serve config should shut down WatchNotifications
|
||||||
|
// whose session IDs are no longer found: session1 goes, session2 stays.
|
||||||
|
err = b.SetServeConfig(&ipn.ServeConfig{
|
||||||
|
TCP: map[uint16]*ipn.TCPPortHandler{
|
||||||
|
5000: {TCPForward: "http://localhost:5000"},
|
||||||
|
},
|
||||||
|
Foreground: map[string]*ipn.ServeConfig{
|
||||||
|
session2: {TCP: map[uint16]*ipn.TCPPortHandler{
|
||||||
|
999: {TCPForward: "http://localhost:4000"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case _, ok := <-ch1:
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected channel to be closed")
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting on watch notifications closing")
|
||||||
|
}
|
||||||
|
|
||||||
|
// check that the second session is still running
|
||||||
|
b.send(ipn.Notify{})
|
||||||
|
select {
|
||||||
|
case _, ok := <-ch2:
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected second session to remain open")
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting on second session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestServeConfigETag(t *testing.T) {
|
func TestServeConfigETag(t *testing.T) {
|
||||||
b := newTestBackend(t)
|
b := newTestBackend(t)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user