tsnet: cleanup resources upon start failure (#5301)

In a partially initialized state, we should cleanup
all prior resources when an error occurs.

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
This commit is contained in:
Joe Tsai 2022-08-04 16:20:48 -07:00 committed by GitHub
parent f0d6f173c9
commit b1fff4499f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,6 +10,7 @@
import (
"context"
"fmt"
"io"
"io/ioutil"
"log"
"net"
@ -182,7 +183,10 @@ func (s *Server) getAuthKey() string {
return os.Getenv("TS_AUTHKEY")
}
func (s *Server) start() error {
func (s *Server) start() (reterr error) {
var closePool closeOnErrorPool
defer closePool.closeAllIfError(&reterr)
exe, err := os.Executable()
if err != nil {
return err
@ -244,6 +248,7 @@ func (s *Server) start() error {
if err != nil {
return fmt.Errorf("error creating filch: %w", err)
}
closePool.add(s.logbuffer)
c := logtail.Config{
Collection: lpc.Collection,
PrivateID: lpc.PrivateID,
@ -259,11 +264,13 @@ func (s *Server) start() error {
HTTPC: &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost)},
}
s.logtail = logtail.NewLogger(c, logf)
closePool.addFunc(func() { s.logtail.Shutdown(context.Background()) })
s.linkMon, err = monitor.New(logf)
if err != nil {
return err
}
closePool.add(s.linkMon)
s.dialer = new(tsdial.Dialer) // mutated below (before used)
eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
@ -274,6 +281,7 @@ func (s *Server) start() error {
if err != nil {
return err
}
closePool.add(s.dialer)
tunDev, magicConn, dns, ok := eng.(wgengine.InternalsGetter).GetInternals()
if !ok {
@ -317,6 +325,7 @@ func (s *Server) start() error {
lb.SetVarRoot(s.rootPath)
logf("tsnet starting with hostname %q, varRoot %q", s.hostname, s.rootPath)
s.lb = lb
closePool.addFunc(func() { s.lb.Shutdown() })
lb.SetDecompressor(func() (controlclient.Decompressor, error) {
return smallzstd.NewDecoder(nil)
})
@ -357,9 +366,22 @@ func (s *Server) start() error {
logf("localapi serve error: %v", err)
}
}()
closePool.add(s.localAPIListener)
return nil
}
type closeOnErrorPool []func()
func (p *closeOnErrorPool) add(c io.Closer) { *p = append(*p, func() { c.Close() }) }
func (p *closeOnErrorPool) addFunc(fn func()) { *p = append(*p, fn) }
func (p closeOnErrorPool) closeAllIfError(errp *error) {
if *errp != nil {
for _, closeFn := range p {
closeFn()
}
}
}
func (s *Server) logf(format string, a ...interface{}) {
if s.logtail != nil {
s.logtail.Logf(format, a...)