tailscale/net/nettest/pipe.go
David Crawshaw 41ac4a79d6 net/nettest: new package with net-like testing primitives
This is a lot like wiring up a local UDP socket: read and write
deadlines work. The big difference is the Block feature, which
lets you stop the packet flow without breaking the connection.
This lets you emulate broken sockets and test that timeouts actually
work.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-03-12 14:42:48 -04:00

262 lines
4.9 KiB
Go

// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest
import (
"context"
"errors"
"fmt"
"io"
"log"
"sync"
"time"
)
// debugPipe, when set to true, makes every Pipe.Read and Pipe.Write log its
// arguments and results via the log package.
const debugPipe bool = false
// Pipe implements an in-memory FIFO with timeouts.
//
// Bytes passed to Write are buffered (never more than maxBuf at a time) until
// Read consumes them. Close, the read/write deadlines and Block/Unblock control
// whether data flows.
type Pipe struct {
	name   string // label used in error and debug messages
	maxBuf int    // upper bound on len(buf) enforced by Write

	// Wakeup channels for goroutines waiting in Read (rCh) and Write (wCh).
	rCh, wCh chan struct{}

	// mu guards every field below.
	mu               sync.Mutex
	closed, blocked  bool
	buf              []byte    // bytes written but not yet read
	readTimeout      time.Time // read deadline; the zero Time means none
	writeTimeout     time.Time // write deadline; the zero Time means none
	cancelReadTimer  func()    // stops the pending read-deadline timer, if any
	cancelWriteTimer func()    // stops the pending write-deadline timer, if any
}
// NewPipe creates an open, unblocked Pipe with a buffer size fixed at maxBuf:
// Write never lets more than maxBuf unread bytes accumulate. name appears in
// the pipe's error and debug messages.
func NewPipe(name string, maxBuf int) *Pipe {
	p := new(Pipe)
	p.name = name
	p.maxBuf = maxBuf
	// One-slot channels let signalRead/signalWrite leave a pending wakeup
	// even when no goroutine is waiting yet.
	p.rCh = make(chan struct{}, 1)
	p.wCh = make(chan struct{}, 1)
	return p
}
// Sentinel errors for Pipe deadlines. Check them with errors.Is: the error
// from a timed-out Read matches both ErrReadTimeout and ErrTimeout.
var (
	// ErrTimeout is the common cause wrapped by ErrReadTimeout and ErrWriteTimeout.
	ErrTimeout = errors.New("timeout")

	// ErrReadTimeout is wrapped by the error Read returns once its deadline passes.
	ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout)

	// ErrWriteTimeout is wrapped by the error Write returns once its deadline passes.
	ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout)
)
// Read implements io.Reader.
//
// Read waits until buffered data is available, then copies as much of it as
// fits into b and returns the number of bytes copied. It returns 0 and an
// error wrapping io.EOF once the pipe is closed (any unread data is
// abandoned). Once the read deadline has passed, it returns 0 and an error
// wrapping ErrReadTimeout. While the pipe is blocked (see Block), Read
// consumes nothing and keeps waiting. A zero-length b returns 0, nil without
// waiting for data.
func (p *Pipe) Read(b []byte) (n int, err error) {
	if debugPipe {
		orig := b
		defer func() {
			log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
		}()
	}
	for {
		p.mu.Lock()
		closed := p.closed
		timedout := !p.readTimeout.IsZero() && time.Now().After(p.readTimeout)
		blocked := p.blocked
		// Only take bytes out of the buffer when this call is certain to
		// return them. Consuming data while blocked, and then hitting a close
		// or timeout, would silently drop it.
		if !closed && !timedout && !blocked {
			n = copy(b, p.buf)
			p.buf = p.buf[n:]
		}
		p.mu.Unlock()

		switch {
		case closed:
			return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
		case timedout:
			return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout)
		case blocked:
			// Wait for a wakeup: Unblock, Close, or a deadline being set or expiring.
		case n > 0:
			// Buffer space was freed; let a waiting writer continue.
			p.signalWrite()
			return n, nil
		case len(b) == 0:
			// io.Reader permits 0, nil for an empty b; don't wait forever.
			return 0, nil
		}
		<-p.rCh
	}
}
// Write implements io.Writer.
//
// Write appends b to the pipe's buffer without ever letting it exceed maxBuf
// bytes. When the buffer is full, it waits for a reader to make room. It
// returns the number of bytes buffered, together with an error wrapping
// io.EOF if the pipe is closed, or ErrWriteTimeout if the write deadline
// passes. While the pipe is blocked, Write may still buffer data but does not
// return until Unblock, Close, or the deadline.
func (p *Pipe) Write(b []byte) (n int, err error) {
	if debugPipe {
		full := b
		defer func() {
			log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(full), n, err)
		}()
	}
	rest := b
	for {
		p.mu.Lock()
		isClosed := p.closed
		pastDeadline := !p.writeTimeout.IsZero() && time.Now().After(p.writeTimeout)
		isBlocked := p.blocked
		if !(isClosed || pastDeadline) {
			// Accept only as much as fits in the remaining buffer space.
			room := p.maxBuf - len(p.buf)
			chunk := len(rest)
			if room < chunk {
				chunk = room
			}
			p.buf = append(p.buf, rest[:chunk]...)
			rest = rest[chunk:]
			n += chunk
		}
		p.mu.Unlock()

		switch {
		case isClosed:
			return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
		case pastDeadline:
			return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout)
		case isBlocked:
			<-p.wCh
			continue
		}
		if n > 0 {
			// There is data to read; wake a waiting reader.
			p.signalRead()
		}
		if len(rest) == 0 {
			return n, nil
		}
		// Buffer full: wait for a reader (or a state change) to wake us.
		<-p.wCh
	}
}
// Close implements io.Closer.
//
// Close marks the pipe closed, cancels any pending deadline timers, and wakes
// waiting readers and writers so they observe the closure. Calling Close on an
// already-closed pipe returns an error.
func (p *Pipe) Close() error {
	stop := func(cancel *func()) {
		if *cancel != nil {
			(*cancel)()
			*cancel = nil
		}
	}

	p.mu.Lock()
	wasClosed := p.closed
	p.closed = true
	stop(&p.cancelWriteTimer)
	stop(&p.cancelReadTimer)
	p.mu.Unlock()

	if wasClosed {
		return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name)
	}
	p.signalRead()
	p.signalWrite()
	return nil
}
// SetReadDeadline sets the deadline for future Read calls.
//
// A zero t clears the deadline. Any previously armed read timer is cancelled.
// If t is in the future, a new timer is armed that wakes a waiting reader when
// it expires, so the reader can observe the timeout. Waiting readers are also
// woken immediately so they re-check the new deadline. It always returns nil.
func (p *Pipe) SetReadDeadline(t time.Time) error {
	p.mu.Lock()
	p.readTimeout = t
	if stop := p.cancelReadTimer; stop != nil {
		stop()
		p.cancelReadTimer = nil
	}
	if wait := time.Until(t); !t.IsZero() && wait > 0 {
		ctx, cancel := context.WithCancel(context.Background())
		p.cancelReadTimer = cancel
		go func() {
			timer := time.NewTimer(wait)
			defer timer.Stop()
			select {
			case <-ctx.Done():
				// Deadline replaced or pipe closed; nothing to signal.
			case <-timer.C:
				p.signalRead()
			}
		}()
	}
	p.mu.Unlock()

	p.signalRead()
	return nil
}
// SetWriteDeadline sets the deadline for future Write calls.
//
// A zero t clears the deadline. Any previously armed write timer is cancelled.
// If t is in the future, a new timer is armed that wakes a waiting writer when
// it expires, so the writer can observe the timeout. Waiting writers are also
// woken immediately so they re-check the new deadline. It always returns nil.
func (p *Pipe) SetWriteDeadline(t time.Time) error {
	p.mu.Lock()
	p.writeTimeout = t
	if stop := p.cancelWriteTimer; stop != nil {
		stop()
		p.cancelWriteTimer = nil
	}
	if wait := time.Until(t); !t.IsZero() && wait > 0 {
		ctx, cancel := context.WithCancel(context.Background())
		p.cancelWriteTimer = cancel
		go func() {
			timer := time.NewTimer(wait)
			defer timer.Stop()
			select {
			case <-ctx.Done():
				// Deadline replaced or pipe closed; nothing to signal.
			case <-timer.C:
				p.signalWrite()
			}
		}()
	}
	p.mu.Unlock()

	p.signalWrite()
	return nil
}
// Block stops data flowing through the pipe without closing it. While the
// pipe is blocked, Read and Write do not return until Unblock is called,
// unless the pipe is closed or their deadline passes.
//
// Block returns an error if the pipe is closed or already blocked. The
// blocked flag is set in every case.
func (p *Pipe) Block() error {
	p.mu.Lock()
	wasClosed, wasBlocked := p.closed, p.blocked
	p.blocked = true
	p.mu.Unlock()

	switch {
	case wasClosed:
		return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
	case wasBlocked:
		return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
	}
	// Wake any waiting Read or Write so it notices the new state.
	p.signalRead()
	p.signalWrite()
	return nil
}
// Unblock resumes data flow on a pipe previously stopped with Block. It wakes
// any waiting Read or Write.
//
// Unblock returns an error if the pipe is closed or is not currently blocked.
// The blocked flag is cleared in every case.
func (p *Pipe) Unblock() error {
	p.mu.Lock()
	closed := p.closed
	blocked := p.blocked
	p.blocked = false
	p.mu.Unlock()
	if closed {
		return fmt.Errorf("nettest.Pipe(%q).Unblock: closed", p.name)
	}
	if !blocked {
		return fmt.Errorf("nettest.Pipe(%q).Unblock: already unblocked", p.name)
	}
	p.signalRead()
	p.signalWrite()
	return nil
}
// signalRead posts a wakeup for a goroutine waiting in Read without ever
// blocking the caller. If the send cannot proceed because a wakeup is already
// pending in rCh, the new one is dropped. One pending token is enough to make
// the waiter re-check the pipe state.
func (p *Pipe) signalRead() {
	var wake struct{}
	select {
	default:
		// A wakeup is already queued; nothing more to do.
	case p.rCh <- wake:
	}
}
// signalWrite posts a wakeup for a goroutine waiting in Write without ever
// blocking the caller. If the send cannot proceed because a wakeup is already
// pending in wCh, the new one is dropped. One pending token is enough to make
// the waiter re-check the pipe state.
func (p *Pipe) signalWrite() {
	var wake struct{}
	select {
	default:
		// A wakeup is already queued; nothing more to do.
	case p.wCh <- wake:
	}
}