From b1fff4499f40cdcfa8630638156adeb6f0ebd954 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Thu, 4 Aug 2022 16:20:48 -0700 Subject: [PATCH] tsnet: cleanup resources upon start failure (#5301) In a partially initialized state, we should cleanup all prior resources when an error occurs. Signed-off-by: Joe Tsai --- tsnet/tsnet.go | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index d9439927c..ffc57deed 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -10,6 +10,7 @@ import ( "context" "fmt" + "io" "io/ioutil" "log" "net" @@ -182,7 +183,10 @@ func (s *Server) getAuthKey() string { return os.Getenv("TS_AUTHKEY") } -func (s *Server) start() error { +func (s *Server) start() (reterr error) { + var closePool closeOnErrorPool + defer closePool.closeAllIfError(&reterr) + exe, err := os.Executable() if err != nil { return err @@ -244,6 +248,7 @@ func (s *Server) start() error { if err != nil { return fmt.Errorf("error creating filch: %w", err) } + closePool.add(s.logbuffer) c := logtail.Config{ Collection: lpc.Collection, PrivateID: lpc.PrivateID, @@ -259,11 +264,13 @@ func (s *Server) start() error { HTTPC: &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost)}, } s.logtail = logtail.NewLogger(c, logf) + closePool.addFunc(func() { s.logtail.Shutdown(context.Background()) }) s.linkMon, err = monitor.New(logf) if err != nil { return err } + closePool.add(s.linkMon) s.dialer = new(tsdial.Dialer) // mutated below (before used) eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ @@ -274,6 +281,7 @@ func (s *Server) start() error { if err != nil { return err } + closePool.add(s.dialer) tunDev, magicConn, dns, ok := eng.(wgengine.InternalsGetter).GetInternals() if !ok { @@ -317,6 +325,7 @@ func (s *Server) start() error { lb.SetVarRoot(s.rootPath) logf("tsnet starting with hostname %q, varRoot %q", s.hostname, s.rootPath) s.lb = lb + closePool.addFunc(func() { s.lb.Shutdown() }) lb.SetDecompressor(func() (controlclient.Decompressor, error) { return smallzstd.NewDecoder(nil) }) @@ -357,9 +366,22 @@ func (s *Server) start() error { logf("localapi serve error: %v", err) } }() + closePool.add(s.localAPIListener) return nil } +type closeOnErrorPool []func() + +func (p *closeOnErrorPool) add(c io.Closer) { *p = append(*p, func() { c.Close() }) } +func (p *closeOnErrorPool) addFunc(fn func()) { *p = append(*p, fn) } +func (p closeOnErrorPool) closeAllIfError(errp *error) { + if *errp != nil { + for _, closeFn := range p { + closeFn() + } + } +} + func (s *Server) logf(format string, a ...interface{}) { if s.logtail != nil { s.logtail.Logf(format, a...)