mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-16 11:41:39 +00:00
derp: add sclient.done channel, simplify some context passing
This is mostly prep for a few future CLs, making sure we always have a close-on-dead done channel available to select on when doing other channel operations.
This commit is contained in:
parent
ea90780066
commit
1453aecb44
@ -222,6 +222,9 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error
|
|||||||
// At this point we trust the client so we don't time out.
|
// At this point we trust the client so we don't time out.
|
||||||
nc.SetDeadline(time.Time{})
|
nc.SetDeadline(time.Time{})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
c := &sclient{
|
c := &sclient{
|
||||||
s: s,
|
s: s,
|
||||||
key: clientKey,
|
key: clientKey,
|
||||||
@ -229,6 +232,7 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error
|
|||||||
br: br,
|
br: br,
|
||||||
bw: bw,
|
bw: bw,
|
||||||
logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v/%x: ", remoteAddr, clientKey)),
|
logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v/%x: ", remoteAddr, clientKey)),
|
||||||
|
done: ctx.Done(),
|
||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
connectedAt: time.Now(),
|
connectedAt: time.Now(),
|
||||||
sendQueue: make(chan pkt, perClientSendQueueDepth),
|
sendQueue: make(chan pkt, perClientSendQueueDepth),
|
||||||
@ -248,8 +252,6 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *sclient) run() error {
|
func (c *sclient) run() error {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// Atomically close+remove send queue, so racing writers don't
|
// Atomically close+remove send queue, so racing writers don't
|
||||||
// send to closed channel.
|
// send to closed channel.
|
||||||
@ -259,7 +261,7 @@ func (c *sclient) run() error {
|
|||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go c.sender(ctx)
|
go c.sender()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
ft, fl, err := readFrameHeader(c.br)
|
ft, fl, err := readFrameHeader(c.br)
|
||||||
@ -270,9 +272,9 @@ func (c *sclient) run() error {
|
|||||||
case frameNotePreferred:
|
case frameNotePreferred:
|
||||||
err = c.handleFrameNotePreferred(ft, fl)
|
err = c.handleFrameNotePreferred(ft, fl)
|
||||||
case frameSendPacket:
|
case frameSendPacket:
|
||||||
err = c.handleFrameSendPacket(ctx, ft, fl)
|
err = c.handleFrameSendPacket(ft, fl)
|
||||||
default:
|
default:
|
||||||
err = c.handleUnknownFrame(ctx, ft, fl)
|
err = c.handleUnknownFrame(ft, fl)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -280,7 +282,7 @@ func (c *sclient) run() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sclient) handleUnknownFrame(ctx context.Context, ft frameType, fl uint32) error {
|
func (c *sclient) handleUnknownFrame(ft frameType, fl uint32) error {
|
||||||
_, err := io.CopyN(ioutil.Discard, c.br, int64(fl))
|
_, err := io.CopyN(ioutil.Discard, c.br, int64(fl))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -297,10 +299,10 @@ func (c *sclient) handleFrameNotePreferred(ft frameType, fl uint32) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sclient) handleFrameSendPacket(ctx context.Context, ft frameType, fl uint32) error {
|
func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
|
||||||
s := c.s
|
s := c.s
|
||||||
|
|
||||||
dstKey, contents, err := s.recvPacket(ctx, c.br, fl)
|
dstKey, contents, err := s.recvPacket(c.br, fl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("client %x: recvPacket: %v", c.key, err)
|
return fmt.Errorf("client %x: recvPacket: %v", c.key, err)
|
||||||
}
|
}
|
||||||
@ -446,7 +448,7 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *cl
|
|||||||
return clientKey, info, nil
|
return clientKey, info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) recvPacket(ctx context.Context, br *bufio.Reader, frameLen uint32) (dstKey key.Public, contents []byte, err error) {
|
func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.Public, contents []byte, err error) {
|
||||||
if frameLen < keyLen {
|
if frameLen < keyLen {
|
||||||
return key.Public{}, nil, errors.New("short send packet frame")
|
return key.Public{}, nil, errors.New("short send packet frame")
|
||||||
}
|
}
|
||||||
@ -476,6 +478,7 @@ type sclient struct {
|
|||||||
key key.Public
|
key key.Public
|
||||||
info clientInfo
|
info clientInfo
|
||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
|
done <-chan struct{} // closed when connection closes
|
||||||
remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String()
|
remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String()
|
||||||
|
|
||||||
// Owned by run, not thread-safe.
|
// Owned by run, not thread-safe.
|
||||||
@ -521,16 +524,16 @@ func (c *sclient) setPreferred(v bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sclient) sender(ctx context.Context) {
|
func (c *sclient) sender() {
|
||||||
// If the sender shuts down unilaterally due to an error, close so
|
// If the sender shuts down unilaterally due to an error, close so
|
||||||
// that the receive loop unblocks and cleans up the rest.
|
// that the receive loop unblocks and cleans up the rest.
|
||||||
defer c.nc.Close()
|
defer c.nc.Close()
|
||||||
if err := c.sendLoop(ctx); err != nil {
|
if err := c.sendLoop(); err != nil {
|
||||||
c.logf("sender failed: %v", err)
|
c.logf("sender failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sclient) sendLoop(ctx context.Context) error {
|
func (c *sclient) sendLoop() error {
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
queue := c.sendQueue
|
queue := c.sendQueue
|
||||||
c.mu.RUnlock()
|
c.mu.RUnlock()
|
||||||
@ -566,7 +569,7 @@ func (c *sclient) sendLoop(ctx context.Context) error {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-c.done:
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case pkt, ok := <-queue:
|
case pkt, ok := <-queue:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user