ipn: misc cleanup

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2020-02-25 07:36:32 -08:00
parent 04e6b77774
commit 367ffde21a
3 changed files with 38 additions and 26 deletions

View File

@ -76,16 +76,21 @@ func pump(logf logger.Logf, ctx context.Context, bs *ipn.BackendServer, s net.Co
} }
} }
func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e wgengine.Engine) error { func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e wgengine.Engine) (err error) {
bo := backoff.Backoff{Name: "ipnserver"} runDone := make(chan error, 1)
defer func() { runDone <- err }()
listen, _, err := safesocket.Listen(opts.SocketPath, uint16(opts.Port)) listen, _, err := safesocket.Listen(opts.SocketPath, uint16(opts.Port))
if err != nil { if err != nil {
return fmt.Errorf("safesocket.Listen: %v", err) return fmt.Errorf("safesocket.Listen: %v", err)
} }
// Go listeners can't take a context, close it instead. // Go listeners can't take a context, close it instead.
go func() { go func() {
<-rctx.Done() select {
case <-rctx.Done():
case <-runDone:
}
listen.Close() listen.Close()
}() }()
logf("Listening on %v\n", listen.Addr()) logf("Listening on %v\n", listen.Addr())
@ -130,13 +135,11 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w
}) })
} }
var oldS net.Conn var (
//lint:ignore SA4006 ctx is never used, but has to be defined so oldS net.Conn
// that it can be assigned to in the following for loop. It's a ctx context.Context
// bit of necessary code convolution to work around Go's variable cancel context.CancelFunc
// shadowing rules. )
ctx, cancel := context.WithCancel(rctx)
stopAll := func() { stopAll := func() {
// Currently we only support one client connection at a time. // Currently we only support one client connection at a time.
// Theoretically we could allow multiple clients, by passing // Theoretically we could allow multiple clients, by passing
@ -150,6 +153,8 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w
} }
} }
bo := backoff.Backoff{Name: "ipnserver"}
for i := 1; rctx.Err() == nil; i++ { for i := 1; rctx.Err() == nil; i++ {
s, err = listen.Accept() s, err = listen.Accept()
if err != nil { if err != nil {
@ -160,10 +165,11 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w
logf("%d: Incoming control connection.\n", i) logf("%d: Incoming control connection.\n", i)
stopAll() stopAll()
ctx, cancel = context.WithCancel(context.Background()) ctx, cancel = context.WithCancel(rctx)
oldS = s oldS = s
go func(ctx context.Context, bs *ipn.BackendServer, s net.Conn, i int) { go func(ctx context.Context, bs *ipn.BackendServer, s net.Conn, i int) {
// TODO: move this prefixing-Logf code into a new helper in types/logger?
si := fmt.Sprintf("%d: ", i) si := fmt.Sprintf("%d: ", i)
pump(func(fmt string, args ...interface{}) { pump(func(fmt string, args ...interface{}) {
logf(si+fmt, args...) logf(si+fmt, args...)

View File

@ -71,7 +71,7 @@ func NewLocalBackend(logf logger.Logf, logid string, store StateStore, e wgengin
logf("skipping portlist: %s\n", err) logf("skipping portlist: %s\n", err)
} }
b := LocalBackend{ b := &LocalBackend{
logf: logf, logf: logf,
e: e, e: e,
store: store, store: store,
@ -86,7 +86,7 @@ func NewLocalBackend(logf logger.Logf, logid string, store StateStore, e wgengin
go b.runPoller() go b.runPoller()
} }
return &b, nil return b, nil
} }
func (b *LocalBackend) Shutdown() { func (b *LocalBackend) Shutdown() {

View File

@ -48,9 +48,9 @@ type Command struct {
type BackendServer struct { type BackendServer struct {
logf logger.Logf logf logger.Logf
b Backend // the Backend we are serving up b Backend // the Backend we are serving up
sendNotifyMsg func(b []byte) // send a notification message sendNotifyMsg func(jsonMsg []byte) // send a notification message
GotQuit bool // a Quit command was received GotQuit bool // a Quit command was received
} }
func NewBackendServer(logf logger.Logf, b Backend, sendNotifyMsg func(b []byte)) *BackendServer { func NewBackendServer(logf logger.Logf, b Backend, sendNotifyMsg func(b []byte)) *BackendServer {
@ -70,13 +70,14 @@ func (bs *BackendServer) send(n Notify) {
bs.sendNotifyMsg(b) bs.sendNotifyMsg(b)
} }
// Inform the BackendServer of an incoming message. // GotCommandMsg parses the incoming message b as a JSON Command and
// calls GotCommand with it.
func (bs *BackendServer) GotCommandMsg(b []byte) error { func (bs *BackendServer) GotCommandMsg(b []byte) error {
cmd := Command{} cmd := &Command{}
if err := json.Unmarshal(b, &cmd); err != nil { if err := json.Unmarshal(b, cmd); err != nil {
return err return err
} }
return bs.GotCommand(&cmd) return bs.GotCommand(cmd)
} }
func (bs *BackendServer) GotCommand(cmd *Command) error { func (bs *BackendServer) GotCommand(cmd *Command) error {
@ -130,11 +131,11 @@ func (bs *BackendServer) Reset() error {
type BackendClient struct { type BackendClient struct {
logf logger.Logf logf logger.Logf
sendCommandMsg func(b []byte) sendCommandMsg func(jsonb []byte)
notify func(n Notify) notify func(Notify)
} }
func NewBackendClient(logf logger.Logf, sendCommandMsg func(b []byte)) *BackendClient { func NewBackendClient(logf logger.Logf, sendCommandMsg func(jsonb []byte)) *BackendClient {
return &BackendClient{ return &BackendClient{
logf: logf, logf: logf,
sendCommandMsg: sendCommandMsg, sendCommandMsg: sendCommandMsg,
@ -203,7 +204,8 @@ func (bc *BackendClient) FakeExpireAfter(x time.Duration) {
bc.send(Command{FakeExpireAfter: &FakeExpireAfterArgs{Duration: x}}) bc.send(Command{FakeExpireAfter: &FakeExpireAfterArgs{Duration: x}})
} }
const MSG_MAX = 1024 * 1024 // MaxMessageSize is the maximum message size, in bytes.
const MaxMessageSize = 1 << 20
// TODO(apenwarr): incremental json decode? // TODO(apenwarr): incremental json decode?
// That would let us avoid storing the whole byte array uselessly in RAM. // That would let us avoid storing the whole byte array uselessly in RAM.
@ -214,7 +216,7 @@ func ReadMsg(r io.Reader) ([]byte, error) {
return nil, err return nil, err
} }
n := binary.LittleEndian.Uint32(cb) n := binary.LittleEndian.Uint32(cb)
if n > 1024*1024 { if n > MaxMessageSize {
return nil, fmt.Errorf("ipn.Read: message too large: %v bytes", n) return nil, fmt.Errorf("ipn.Read: message too large: %v bytes", n)
} }
b := make([]byte, n) b := make([]byte, n)
@ -229,8 +231,12 @@ func ReadMsg(r io.Reader) ([]byte, error) {
// That would save RAM, at the expense of having to encode once so that // That would save RAM, at the expense of having to encode once so that
// we can produce the initial byte count. // we can produce the initial byte count.
func WriteMsg(w io.Writer, b []byte) error { func WriteMsg(w io.Writer, b []byte) error {
// TODO(bradfitz): this does two writes to w, which likely
// does two writes on the wire, two frame generations, etc. We
// should take a concrete buffered type, or use a sync.Pool to
// allocate a buf and do one write.
cb := make([]byte, 4) cb := make([]byte, 4)
if len(b) > MSG_MAX { if len(b) > MaxMessageSize {
return fmt.Errorf("ipn.Write: message too large: %v bytes", len(b)) return fmt.Errorf("ipn.Write: message too large: %v bytes", len(b))
} }
binary.LittleEndian.PutUint32(cb, uint32(len(b))) binary.LittleEndian.PutUint32(cb, uint32(len(b)))