mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-20 05:31:40 +00:00
cmd/containerboot: fix unclean shutdown (#10035)
* cmd/containerboot: shut down cleanly on SIGTERM Make sure that tailscaled watcher returns when SIGTERM is received and also that it shuts down before tailscaled exits. Updates tailscale/tailscale#10090 Signed-off-by: Irbe Krumina <irbe@tailscale.com>
This commit is contained in:
parent
7238586652
commit
664ebb14d9
@ -69,6 +69,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@ -181,10 +182,16 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
client, daemonPid, err := startTailscaled(bootCtx, cfg)
|
client, daemonProcess, err := startTailscaled(bootCtx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to bring up tailscale: %v", err)
|
log.Fatalf("failed to bring up tailscale: %v", err)
|
||||||
}
|
}
|
||||||
|
killTailscaled := func() {
|
||||||
|
if err := daemonProcess.Signal(unix.SIGTERM); err != nil {
|
||||||
|
log.Fatalf("error shutting tailscaled down: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer killTailscaled()
|
||||||
|
|
||||||
w, err := client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialPrefs|ipn.NotifyInitialState)
|
w, err := client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialPrefs|ipn.NotifyInitialState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -252,7 +259,7 @@ authLoop:
|
|||||||
|
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background()) // no deadline now that we're in steady state
|
ctx, cancel := contextWithExitSignalWatch()
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if cfg.AuthOnce {
|
if cfg.AuthOnce {
|
||||||
@ -306,12 +313,33 @@ authLoop:
|
|||||||
log.Fatalf("error creating new netfilter runner: %v", err)
|
log.Fatalf("error creating new netfilter runner: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
notifyChan := make(chan ipn.Notify)
|
||||||
|
errChan := make(chan error)
|
||||||
|
go func() {
|
||||||
for {
|
for {
|
||||||
n, err := w.Next()
|
n, err := w.Next()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to read from tailscaled: %v", err)
|
errChan <- err
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
notifyChan <- n
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
runLoop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Although killTailscaled() is deferred earlier, if we
|
||||||
|
// have started the reaper defined below, we need to
|
||||||
|
// kill tailscaled and let reaper clean up child
|
||||||
|
// processes.
|
||||||
|
killTailscaled()
|
||||||
|
break runLoop
|
||||||
|
case err := <-errChan:
|
||||||
|
log.Fatalf("failed to read from tailscaled: %v", err)
|
||||||
|
case n := <-notifyChan:
|
||||||
if n.State != nil && *n.State != ipn.Running {
|
if n.State != nil && *n.State != ipn.Running {
|
||||||
// Something's gone wrong and we've left the authenticated state.
|
// Something's gone wrong and we've left the authenticated state.
|
||||||
// Our container image never recovered gracefully from this, and the
|
// Our container image never recovered gracefully from this, and the
|
||||||
@ -361,11 +389,12 @@ authLoop:
|
|||||||
log.Println("Startup complete, waiting for shutdown signal")
|
log.Println("Startup complete, waiting for shutdown signal")
|
||||||
startupTasksDone = true
|
startupTasksDone = true
|
||||||
|
|
||||||
// Reap all processes, since we are PID1 and need to collect zombies. We can
|
// // Reap all processes, since we are PID1 and need to collect zombies. We can
|
||||||
// only start doing this once we've stopped shelling out to things
|
// // only start doing this once we've stopped shelling out to things
|
||||||
// `tailscale up`, otherwise this goroutine can reap the CLI subprocesses
|
// // `tailscale up`, otherwise this goroutine can reap the CLI subprocesses
|
||||||
// and wedge bringup.
|
// // and wedge bringup.
|
||||||
go func() {
|
reaper := func() {
|
||||||
|
defer wg.Done()
|
||||||
for {
|
for {
|
||||||
var status unix.WaitStatus
|
var status unix.WaitStatus
|
||||||
pid, err := unix.Wait4(-1, &status, 0, nil)
|
pid, err := unix.Wait4(-1, &status, 0, nil)
|
||||||
@ -375,15 +404,20 @@ authLoop:
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Waiting for exited processes: %v", err)
|
log.Fatalf("Waiting for exited processes: %v", err)
|
||||||
}
|
}
|
||||||
if pid == daemonPid {
|
if pid == daemonProcess.Pid {
|
||||||
log.Printf("Tailscaled exited")
|
log.Printf("Tailscaled exited")
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go reaper()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// watchServeConfigChanges watches path for changes, and when it sees one, reads
|
// watchServeConfigChanges watches path for changes, and when it sees one, reads
|
||||||
@ -460,10 +494,8 @@ func readServeConfig(path, certDomain string) (*ipn.ServeConfig, error) {
|
|||||||
return &sc, nil
|
return &sc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient, int, error) {
|
func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient, *os.Process, error) {
|
||||||
args := tailscaledArgs(cfg)
|
args := tailscaledArgs(cfg)
|
||||||
sigCh := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigCh, unix.SIGTERM, unix.SIGINT)
|
|
||||||
// tailscaled runs without context, since it needs to persist
|
// tailscaled runs without context, since it needs to persist
|
||||||
// beyond the startup timeout in ctx.
|
// beyond the startup timeout in ctx.
|
||||||
cmd := exec.Command("tailscaled", args...)
|
cmd := exec.Command("tailscaled", args...)
|
||||||
@ -474,13 +506,8 @@ func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient
|
|||||||
}
|
}
|
||||||
log.Printf("Starting tailscaled")
|
log.Printf("Starting tailscaled")
|
||||||
if err := cmd.Start(); err != nil {
|
if err := cmd.Start(); err != nil {
|
||||||
return nil, 0, fmt.Errorf("starting tailscaled failed: %v", err)
|
return nil, nil, fmt.Errorf("starting tailscaled failed: %v", err)
|
||||||
}
|
}
|
||||||
go func() {
|
|
||||||
<-sigCh
|
|
||||||
log.Printf("Received SIGTERM from container runtime, shutting down tailscaled")
|
|
||||||
cmd.Process.Signal(unix.SIGTERM)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Wait for the socket file to appear, otherwise API ops will racily fail.
|
// Wait for the socket file to appear, otherwise API ops will racily fail.
|
||||||
log.Printf("Waiting for tailscaled socket")
|
log.Printf("Waiting for tailscaled socket")
|
||||||
@ -503,7 +530,7 @@ func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient
|
|||||||
UseSocketOnly: true,
|
UseSocketOnly: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
return tsClient, cmd.Process.Pid, nil
|
return tsClient, cmd.Process, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// tailscaledArgs uses cfg to construct the argv for tailscaled.
|
// tailscaledArgs uses cfg to construct the argv for tailscaled.
|
||||||
@ -801,3 +828,25 @@ func defaultBool(name string, defVal bool) bool {
|
|||||||
}
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// contextWithExitSignalWatch watches for SIGTERM/SIGINT signals. It returns a
|
||||||
|
// context that gets cancelled when a signal is received and a cancel function
|
||||||
|
// that can be called to free the resources when the watch should be stopped.
|
||||||
|
func contextWithExitSignalWatch() (context.Context, func()) {
|
||||||
|
closeChan := make(chan string)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
signalChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-signalChan:
|
||||||
|
cancel()
|
||||||
|
case <-closeChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
f := func() {
|
||||||
|
closeChan <- "goodbye"
|
||||||
|
}
|
||||||
|
return ctx, f
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user