net/nettest: make nettest.NewConn pass x/net/nettest.TestConn.

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2021-04-02 20:04:52 -07:00 committed by Maisem Ali
parent e0e677a8f6
commit 57756ef673
6 changed files with 284 additions and 154 deletions

View File

@ -5,21 +5,13 @@
package nettest package nettest
import ( import (
"io" "net"
"time" "time"
) )
// Conn is a bi-directional in-memory stream that looks like a TCP net.Conn. // Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked.
type Conn interface { type Conn interface {
io.Reader net.Conn
io.Writer
io.Closer
// The *Deadline methods follow the semantics of net.Conn.
SetDeadline(t time.Time) error
SetReadDeadline(t time.Time) error
SetWriteDeadline(t time.Time) error
// SetReadBlock blocks or unblocks the Read method of this Conn. // SetReadBlock blocks or unblocks the Read method of this Conn.
// It reports an error if the existing value matches the new value, // It reports an error if the existing value matches the new value,
@ -40,24 +32,37 @@ func NewConn(name string, maxBuf int) (Conn, Conn) {
return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} return &connHalf{r: r, w: w}, &connHalf{r: w, w: r}
} }
type connAddr string
func (a connAddr) Network() string { return "mem" }
func (a connAddr) String() string { return string(a) }
type connHalf struct { type connHalf struct {
r, w *Pipe r, w *Pipe
} }
func (c *connHalf) LocalAddr() net.Addr {
return connAddr(c.r.name)
}
func (c *connHalf) RemoteAddr() net.Addr {
return connAddr(c.w.name)
}
func (c *connHalf) Read(b []byte) (n int, err error) { func (c *connHalf) Read(b []byte) (n int, err error) {
return c.r.Read(b) return c.r.Read(b)
} }
func (c *connHalf) Write(b []byte) (n int, err error) { func (c *connHalf) Write(b []byte) (n int, err error) {
return c.w.Write(b) return c.w.Write(b)
} }
func (c *connHalf) Close() error { func (c *connHalf) Close() error {
err1 := c.r.Close() if err := c.w.Close(); err != nil {
err2 := c.w.Close() return err
if err1 != nil {
return err1
} }
return err2 return c.r.Close()
} }
func (c *connHalf) SetDeadline(t time.Time) error { func (c *connHalf) SetDeadline(t time.Time) error {
err1 := c.SetReadDeadline(t) err1 := c.SetReadDeadline(t)
err2 := c.SetWriteDeadline(t) err2 := c.SetWriteDeadline(t)
@ -72,6 +77,7 @@ func (c *connHalf) SetReadDeadline(t time.Time) error {
func (c *connHalf) SetWriteDeadline(t time.Time) error { func (c *connHalf) SetWriteDeadline(t time.Time) error {
return c.w.SetWriteDeadline(t) return c.w.SetWriteDeadline(t)
} }
func (c *connHalf) SetReadBlock(b bool) error { func (c *connHalf) SetReadBlock(b bool) error {
if b { if b {
return c.r.Block() return c.r.Block()

22
net/nettest/conn_test.go Normal file
View File

@ -0,0 +1,22 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest
import (
"net"
"testing"
"golang.org/x/net/nettest"
)
func TestConn(t *testing.T) {
nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
c1, c2 = NewConn("test", bufferSize)
return c1, c2, func() {
c1.Close()
c2.Close()
}, nil
})
}

83
net/nettest/listener.go Normal file
View File

@ -0,0 +1,83 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest
import (
"context"
"net"
"strings"
"sync"
)
const (
bufferSize = 256 * 1024
)
// Listener is a net.Listener using using NewConn to create pairs of network
// connections connected in memory using a buffered pipe. It also provides a
// Dial method to establish new connections.
type Listener struct {
addr connAddr
ch chan Conn
closeOnce sync.Once
closed chan struct{}
}
// Listen returns a new Listener for the provided address.
func Listen(addr string) *Listener {
return &Listener{
addr: connAddr(addr),
ch: make(chan Conn),
closed: make(chan struct{}),
}
}
// Addr implements net.Listener.Addr.
func (l *Listener) Addr() net.Addr {
return l.addr
}
// Close closes the pipe listener.
func (l *Listener) Close() error {
l.closeOnce.Do(func() {
close(l.closed)
})
return nil
}
// Accept blocks until a new connection is available or the listener is closed.
func (l *Listener) Accept() (net.Conn, error) {
select {
case c := <-l.ch:
return c, nil
case <-l.closed:
return nil, net.ErrClosed
}
}
// Dial connects to the listener using the provided context.
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Once successfully connected
// any expiration of the context will not affect the connection.
func (l *Listener) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
if !strings.HasSuffix(network, "tcp") {
return nil, net.UnknownNetworkError(network)
}
if connAddr(addr) != l.addr {
return nil, &net.AddrError{
Err: "invalid address",
Addr: addr,
}
}
c, s := NewConn(addr, bufferSize)
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-l.closed:
return nil, net.ErrClosed
case l.ch <- s:
return c, nil
}
}

View File

@ -0,0 +1,34 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest
import (
"context"
"testing"
)
func TestListener(t *testing.T) {
l := Listen("srv.local")
defer l.Close()
go func() {
c, err := l.Accept()
if err != nil {
t.Error(err)
return
}
defer c.Close()
}()
if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil {
c.Close()
t.Fatalf("dial to invalid address succeeded")
}
c, err := l.Dial(context.Background(), "tcp", "srv.local")
if err != nil {
t.Fatalf("dial failed: %v", err)
return
}
c.Close()
}

View File

@ -5,11 +5,13 @@
package nettest package nettest
import ( import (
"bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"net"
"os"
"sync" "sync"
"time" "time"
) )
@ -20,13 +22,12 @@
type Pipe struct { type Pipe struct {
name string name string
maxBuf int maxBuf int
rCh chan struct{} mu sync.Mutex
wCh chan struct{} cnd *sync.Cond
mu sync.Mutex
closed bool
blocked bool blocked bool
buf []byte closed bool
buf bytes.Buffer
readTimeout time.Time readTimeout time.Time
writeTimeout time.Time writeTimeout time.Time
cancelReadTimer func() cancelReadTimer func()
@ -35,21 +36,42 @@ type Pipe struct {
// NewPipe creates a Pipe with a buffer size fixed at maxBuf. // NewPipe creates a Pipe with a buffer size fixed at maxBuf.
func NewPipe(name string, maxBuf int) *Pipe { func NewPipe(name string, maxBuf int) *Pipe {
return &Pipe{ p := &Pipe{
name: name, name: name,
maxBuf: maxBuf, maxBuf: maxBuf,
rCh: make(chan struct{}, 1),
wCh: make(chan struct{}, 1),
} }
p.cnd = sync.NewCond(&p.mu)
return p
} }
var ( // readOrBlock attempts to read from the buffer, if the buffer is empty and
ErrTimeout = errors.New("timeout") // the connection hasn't been closed it will block until there is a change.
ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout) func (p *Pipe) readOrBlock(b []byte) (int, error) {
ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout) p.mu.Lock()
) defer p.mu.Unlock()
if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) {
return 0, os.ErrDeadlineExceeded
}
if p.blocked {
p.cnd.Wait()
return 0, nil
}
n, err := p.buf.Read(b)
// err will either be nil or io.EOF.
if err == io.EOF {
if p.closed {
return n, err
}
// Wait for something to change.
p.cnd.Wait()
}
return n, nil
}
// Read implements io.Reader. // Read implements io.Reader.
// Once the buffer is drained (i.e. after Close), subsequent calls will
// return io.EOF.
func (p *Pipe) Read(b []byte) (n int, err error) { func (p *Pipe) Read(b []byte) (n int, err error) {
if debugPipe { if debugPipe {
orig := b orig := b
@ -57,35 +79,48 @@ func (p *Pipe) Read(b []byte) (n int, err error) {
log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err) log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
}() }()
} }
for { for n == 0 {
p.mu.Lock() n2, err := p.readOrBlock(b)
closed := p.closed if err != nil {
timedout := !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) return n2, err
blocked := p.blocked
if !closed && !timedout && len(p.buf) > 0 {
n2 := copy(b, p.buf)
p.buf = p.buf[n2:]
b = b[n2:]
n += n2
} }
p.mu.Unlock() n += n2
if closed {
return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
}
if timedout {
return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout)
}
if blocked {
<-p.rCh
continue
}
if n > 0 {
p.signalWrite()
return n, nil
}
<-p.rCh
} }
p.cnd.Signal()
return n, nil
}
// writeOrBlock attempts to write to the buffer, if the buffer is full it will
// block until there is a change.
func (p *Pipe) writeOrBlock(b []byte) (int, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return 0, net.ErrClosed
}
if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) {
return 0, os.ErrDeadlineExceeded
}
if p.blocked {
p.cnd.Wait()
return 0, nil
}
// Optimistically we want to write the entire slice.
n := len(b)
if limit := p.maxBuf - p.buf.Len(); limit < n {
// However, we don't have enough capacity to write everything.
n = limit
}
if n == 0 {
// Wait for something to change.
p.cnd.Wait()
return 0, nil
}
p.buf.Write(b[:n])
p.cnd.Signal()
return n, nil
} }
// Write implements io.Writer. // Write implements io.Writer.
@ -96,47 +131,23 @@ func (p *Pipe) Write(b []byte) (n int, err error) {
log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
}() }()
} }
for { for len(b) > 0 {
p.mu.Lock() n2, err := p.writeOrBlock(b)
closed := p.closed if err != nil {
timedout := !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) return n + n2, err
blocked := p.blocked
if !closed && !timedout {
n2 := len(b)
if limit := p.maxBuf - len(p.buf); limit < n2 {
n2 = limit
}
p.buf = append(p.buf, b[:n2]...)
b = b[n2:]
n += n2
} }
p.mu.Unlock() n += n2
b = b[n2:]
if closed {
return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
}
if timedout {
return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout)
}
if blocked {
<-p.wCh
continue
}
if n > 0 {
p.signalRead()
}
if len(b) == 0 {
return n, nil
}
<-p.wCh
} }
return n, nil
} }
// Close implements io.Closer. // Close closes the pipe.
func (p *Pipe) Close() error { func (p *Pipe) Close() error {
p.mu.Lock() p.mu.Lock()
closed := p.closed defer p.mu.Unlock()
p.closed = true p.closed = true
p.blocked = false
if p.cancelWriteTimer != nil { if p.cancelWriteTimer != nil {
p.cancelWriteTimer() p.cancelWriteTimer()
p.cancelWriteTimer = nil p.cancelWriteTimer = nil
@ -145,77 +156,65 @@ func (p *Pipe) Close() error {
p.cancelReadTimer() p.cancelReadTimer()
p.cancelReadTimer = nil p.cancelReadTimer = nil
} }
p.mu.Unlock() p.cnd.Broadcast()
if closed {
return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name)
}
p.signalRead()
p.signalWrite()
return nil return nil
} }
func (p *Pipe) deadlineTimer(t time.Time) func() {
if t.IsZero() {
return nil
}
if t.Before(time.Now()) {
p.cnd.Broadcast()
return nil
}
ctx, cancel := context.WithDeadline(context.Background(), t)
go func() {
<-ctx.Done()
if ctx.Err() == context.DeadlineExceeded {
p.cnd.Broadcast()
}
}()
return cancel
}
// SetReadDeadline sets the deadline for future Read calls. // SetReadDeadline sets the deadline for future Read calls.
func (p *Pipe) SetReadDeadline(t time.Time) error { func (p *Pipe) SetReadDeadline(t time.Time) error {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock()
p.readTimeout = t p.readTimeout = t
// If we already have a deadline, cancel it and create a new one.
if p.cancelReadTimer != nil { if p.cancelReadTimer != nil {
p.cancelReadTimer() p.cancelReadTimer()
p.cancelReadTimer = nil p.cancelReadTimer = nil
} }
if d := time.Until(t); !t.IsZero() && d > 0 { p.cancelReadTimer = p.deadlineTimer(t)
ctx, cancel := context.WithCancel(context.Background())
p.cancelReadTimer = cancel
go func() {
t := time.NewTimer(d)
defer t.Stop()
select {
case <-t.C:
p.signalRead()
case <-ctx.Done():
}
}()
}
p.mu.Unlock()
p.signalRead()
return nil return nil
} }
// SetWriteDeadline sets the deadline for future Write calls. // SetWriteDeadline sets the deadline for future Write calls.
func (p *Pipe) SetWriteDeadline(t time.Time) error { func (p *Pipe) SetWriteDeadline(t time.Time) error {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock()
p.writeTimeout = t p.writeTimeout = t
// If we already have a deadline, cancel it and create a new one.
if p.cancelWriteTimer != nil { if p.cancelWriteTimer != nil {
p.cancelWriteTimer() p.cancelWriteTimer()
p.cancelWriteTimer = nil p.cancelWriteTimer = nil
} }
if d := time.Until(t); !t.IsZero() && d > 0 { p.cancelWriteTimer = p.deadlineTimer(t)
ctx, cancel := context.WithCancel(context.Background())
p.cancelWriteTimer = cancel
go func() {
t := time.NewTimer(d)
defer t.Stop()
select {
case <-t.C:
p.signalWrite()
case <-ctx.Done():
}
}()
}
p.mu.Unlock()
p.signalWrite()
return nil return nil
} }
// Block will cause all calls to Read and Write to block until they either
// timeout, are unblocked or the pipe is closed.
func (p *Pipe) Block() error { func (p *Pipe) Block() error {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock()
closed := p.closed closed := p.closed
blocked := p.blocked blocked := p.blocked
p.blocked = true p.blocked = true
p.mu.Unlock()
if closed { if closed {
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name) return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
@ -223,17 +222,17 @@ func (p *Pipe) Block() error {
if blocked { if blocked {
return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name) return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
} }
p.signalRead() p.cnd.Broadcast()
p.signalWrite()
return nil return nil
} }
// Unblock will cause all blocked Read/Write calls to continue execution.
func (p *Pipe) Unblock() error { func (p *Pipe) Unblock() error {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock()
closed := p.closed closed := p.closed
blocked := p.blocked blocked := p.blocked
p.blocked = false p.blocked = false
p.mu.Unlock()
if closed { if closed {
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name) return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
@ -241,21 +240,6 @@ func (p *Pipe) Unblock() error {
if !blocked { if !blocked {
return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name) return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name)
} }
p.signalRead() p.cnd.Broadcast()
p.signalWrite()
return nil return nil
} }
func (p *Pipe) signalRead() {
select {
case p.rCh <- struct{}{}:
default:
}
}
func (p *Pipe) signalWrite() {
select {
case p.wCh <- struct{}{}:
default:
}
}

View File

@ -7,6 +7,7 @@
import ( import (
"errors" "errors"
"fmt" "fmt"
"os"
"testing" "testing"
"time" "time"
) )
@ -35,7 +36,7 @@ func TestPipeTimeout(t *testing.T) {
p := NewPipe("p1", 1<<16) p := NewPipe("p1", 1<<16)
p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) p.SetWriteDeadline(time.Now().Add(-1 * time.Second))
n, err := p.Write([]byte{'h'}) n, err := p.Write([]byte{'h'})
if !errors.Is(err, ErrWriteTimeout) || !errors.Is(err, ErrTimeout) { if !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("missing write timeout got err: %v", err) t.Errorf("missing write timeout got err: %v", err)
} }
if n != 0 { if n != 0 {
@ -49,7 +50,7 @@ func TestPipeTimeout(t *testing.T) {
p.SetReadDeadline(time.Now().Add(-1 * time.Second)) p.SetReadDeadline(time.Now().Add(-1 * time.Second))
b := make([]byte, 1) b := make([]byte, 1)
n, err := p.Read(b) n, err := p.Read(b)
if !errors.Is(err, ErrReadTimeout) || !errors.Is(err, ErrTimeout) { if !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("missing read timeout got err: %v", err) t.Errorf("missing read timeout got err: %v", err)
} }
if n != 0 { if n != 0 {
@ -65,7 +66,7 @@ func TestPipeTimeout(t *testing.T) {
if err := p.Block(); err != nil { if err := p.Block(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if _, err := p.Write([]byte{'h'}); !errors.Is(err, ErrWriteTimeout) { if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Fatalf("want write timeout got: %v", err) t.Fatalf("want write timeout got: %v", err)
} }
}) })
@ -80,7 +81,7 @@ func TestPipeTimeout(t *testing.T) {
if err := p.Block(); err != nil { if err := p.Block(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if _, err := p.Read(b); !errors.Is(err, ErrReadTimeout) { if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Fatalf("want read timeout got: %v", err) t.Fatalf("want read timeout got: %v", err)
} }
}) })