derp: add sclient.done channel, simplify some context passing

This is mostly prep for a few future CLs, making sure we always have a
close-on-dead done channel available to select on when doing other
channel operations.
This commit is contained in:
Brad Fitzpatrick 2020-03-21 18:28:34 -07:00
parent ea90780066
commit 1453aecb44

View File

@ -222,6 +222,9 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error
// At this point we trust the client so we don't time out. // At this point we trust the client so we don't time out.
nc.SetDeadline(time.Time{}) nc.SetDeadline(time.Time{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c := &sclient{ c := &sclient{
s: s, s: s,
key: clientKey, key: clientKey,
@ -229,6 +232,7 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error
br: br, br: br,
bw: bw, bw: bw,
logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v/%x: ", remoteAddr, clientKey)), logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v/%x: ", remoteAddr, clientKey)),
done: ctx.Done(),
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
connectedAt: time.Now(), connectedAt: time.Now(),
sendQueue: make(chan pkt, perClientSendQueueDepth), sendQueue: make(chan pkt, perClientSendQueueDepth),
@ -248,8 +252,6 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error
} }
func (c *sclient) run() error { func (c *sclient) run() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
defer func() { defer func() {
// Atomically close+remove send queue, so racing writers don't // Atomically close+remove send queue, so racing writers don't
// send to closed channel. // send to closed channel.
@ -259,7 +261,7 @@ func (c *sclient) run() error {
c.mu.Unlock() c.mu.Unlock()
}() }()
go c.sender(ctx) go c.sender()
for { for {
ft, fl, err := readFrameHeader(c.br) ft, fl, err := readFrameHeader(c.br)
@ -270,9 +272,9 @@ func (c *sclient) run() error {
case frameNotePreferred: case frameNotePreferred:
err = c.handleFrameNotePreferred(ft, fl) err = c.handleFrameNotePreferred(ft, fl)
case frameSendPacket: case frameSendPacket:
err = c.handleFrameSendPacket(ctx, ft, fl) err = c.handleFrameSendPacket(ft, fl)
default: default:
err = c.handleUnknownFrame(ctx, ft, fl) err = c.handleUnknownFrame(ft, fl)
} }
if err != nil { if err != nil {
return err return err
@ -280,7 +282,7 @@ func (c *sclient) run() error {
} }
} }
func (c *sclient) handleUnknownFrame(ctx context.Context, ft frameType, fl uint32) error { func (c *sclient) handleUnknownFrame(ft frameType, fl uint32) error {
_, err := io.CopyN(ioutil.Discard, c.br, int64(fl)) _, err := io.CopyN(ioutil.Discard, c.br, int64(fl))
return err return err
} }
@ -297,10 +299,10 @@ func (c *sclient) handleFrameNotePreferred(ft frameType, fl uint32) error {
return nil return nil
} }
func (c *sclient) handleFrameSendPacket(ctx context.Context, ft frameType, fl uint32) error { func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
s := c.s s := c.s
dstKey, contents, err := s.recvPacket(ctx, c.br, fl) dstKey, contents, err := s.recvPacket(c.br, fl)
if err != nil { if err != nil {
return fmt.Errorf("client %x: recvPacket: %v", c.key, err) return fmt.Errorf("client %x: recvPacket: %v", c.key, err)
} }
@ -446,7 +448,7 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *cl
return clientKey, info, nil return clientKey, info, nil
} }
func (s *Server) recvPacket(ctx context.Context, br *bufio.Reader, frameLen uint32) (dstKey key.Public, contents []byte, err error) { func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.Public, contents []byte, err error) {
if frameLen < keyLen { if frameLen < keyLen {
return key.Public{}, nil, errors.New("short send packet frame") return key.Public{}, nil, errors.New("short send packet frame")
} }
@ -476,6 +478,7 @@ type sclient struct {
key key.Public key key.Public
info clientInfo info clientInfo
logf logger.Logf logf logger.Logf
done <-chan struct{} // closed when connection closes
remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String() remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String()
// Owned by run, not thread-safe. // Owned by run, not thread-safe.
@ -521,16 +524,16 @@ func (c *sclient) setPreferred(v bool) {
} }
} }
func (c *sclient) sender(ctx context.Context) { func (c *sclient) sender() {
// If the sender shuts down unilaterally due to an error, close so // If the sender shuts down unilaterally due to an error, close so
// that the receive loop unblocks and cleans up the rest. // that the receive loop unblocks and cleans up the rest.
defer c.nc.Close() defer c.nc.Close()
if err := c.sendLoop(ctx); err != nil { if err := c.sendLoop(); err != nil {
c.logf("sender failed: %v", err) c.logf("sender failed: %v", err)
} }
} }
func (c *sclient) sendLoop(ctx context.Context) error { func (c *sclient) sendLoop() error {
c.mu.RLock() c.mu.RLock()
queue := c.sendQueue queue := c.sendQueue
c.mu.RUnlock() c.mu.RUnlock()
@ -566,7 +569,7 @@ func (c *sclient) sendLoop(ctx context.Context) error {
for { for {
select { select {
case <-ctx.Done(): case <-c.done:
return nil return nil
case pkt, ok := <-queue: case pkt, ok := <-queue: