mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 04:55:31 +00:00
net/nettest: make nettest.NewConn pass x/net/nettest.TestConn.
Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
e0e677a8f6
commit
57756ef673
@ -5,21 +5,13 @@
|
|||||||
package nettest
|
package nettest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Conn is a bi-directional in-memory stream that looks like a TCP net.Conn.
|
// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked.
|
||||||
type Conn interface {
|
type Conn interface {
|
||||||
io.Reader
|
net.Conn
|
||||||
io.Writer
|
|
||||||
io.Closer
|
|
||||||
|
|
||||||
// The *Deadline methods follow the semantics of net.Conn.
|
|
||||||
|
|
||||||
SetDeadline(t time.Time) error
|
|
||||||
SetReadDeadline(t time.Time) error
|
|
||||||
SetWriteDeadline(t time.Time) error
|
|
||||||
|
|
||||||
// SetReadBlock blocks or unblocks the Read method of this Conn.
|
// SetReadBlock blocks or unblocks the Read method of this Conn.
|
||||||
// It reports an error if the existing value matches the new value,
|
// It reports an error if the existing value matches the new value,
|
||||||
@ -40,24 +32,37 @@ func NewConn(name string, maxBuf int) (Conn, Conn) {
|
|||||||
return &connHalf{r: r, w: w}, &connHalf{r: w, w: r}
|
return &connHalf{r: r, w: w}, &connHalf{r: w, w: r}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type connAddr string
|
||||||
|
|
||||||
|
func (a connAddr) Network() string { return "mem" }
|
||||||
|
func (a connAddr) String() string { return string(a) }
|
||||||
|
|
||||||
type connHalf struct {
|
type connHalf struct {
|
||||||
r, w *Pipe
|
r, w *Pipe
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *connHalf) LocalAddr() net.Addr {
|
||||||
|
return connAddr(c.r.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connHalf) RemoteAddr() net.Addr {
|
||||||
|
return connAddr(c.w.name)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *connHalf) Read(b []byte) (n int, err error) {
|
func (c *connHalf) Read(b []byte) (n int, err error) {
|
||||||
return c.r.Read(b)
|
return c.r.Read(b)
|
||||||
}
|
}
|
||||||
func (c *connHalf) Write(b []byte) (n int, err error) {
|
func (c *connHalf) Write(b []byte) (n int, err error) {
|
||||||
return c.w.Write(b)
|
return c.w.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *connHalf) Close() error {
|
func (c *connHalf) Close() error {
|
||||||
err1 := c.r.Close()
|
if err := c.w.Close(); err != nil {
|
||||||
err2 := c.w.Close()
|
return err
|
||||||
if err1 != nil {
|
|
||||||
return err1
|
|
||||||
}
|
}
|
||||||
return err2
|
return c.r.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *connHalf) SetDeadline(t time.Time) error {
|
func (c *connHalf) SetDeadline(t time.Time) error {
|
||||||
err1 := c.SetReadDeadline(t)
|
err1 := c.SetReadDeadline(t)
|
||||||
err2 := c.SetWriteDeadline(t)
|
err2 := c.SetWriteDeadline(t)
|
||||||
@ -72,6 +77,7 @@ func (c *connHalf) SetReadDeadline(t time.Time) error {
|
|||||||
func (c *connHalf) SetWriteDeadline(t time.Time) error {
|
func (c *connHalf) SetWriteDeadline(t time.Time) error {
|
||||||
return c.w.SetWriteDeadline(t)
|
return c.w.SetWriteDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *connHalf) SetReadBlock(b bool) error {
|
func (c *connHalf) SetReadBlock(b bool) error {
|
||||||
if b {
|
if b {
|
||||||
return c.r.Block()
|
return c.r.Block()
|
||||||
|
22
net/nettest/conn_test.go
Normal file
22
net/nettest/conn_test.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package nettest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/net/nettest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConn(t *testing.T) {
|
||||||
|
nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
|
||||||
|
c1, c2 = NewConn("test", bufferSize)
|
||||||
|
return c1, c2, func() {
|
||||||
|
c1.Close()
|
||||||
|
c2.Close()
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
}
|
83
net/nettest/listener.go
Normal file
83
net/nettest/listener.go
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package nettest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
bufferSize = 256 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
// Listener is a net.Listener using using NewConn to create pairs of network
|
||||||
|
// connections connected in memory using a buffered pipe. It also provides a
|
||||||
|
// Dial method to establish new connections.
|
||||||
|
type Listener struct {
|
||||||
|
addr connAddr
|
||||||
|
ch chan Conn
|
||||||
|
closeOnce sync.Once
|
||||||
|
closed chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen returns a new Listener for the provided address.
|
||||||
|
func Listen(addr string) *Listener {
|
||||||
|
return &Listener{
|
||||||
|
addr: connAddr(addr),
|
||||||
|
ch: make(chan Conn),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr implements net.Listener.Addr.
|
||||||
|
func (l *Listener) Addr() net.Addr {
|
||||||
|
return l.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the pipe listener.
|
||||||
|
func (l *Listener) Close() error {
|
||||||
|
l.closeOnce.Do(func() {
|
||||||
|
close(l.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept blocks until a new connection is available or the listener is closed.
|
||||||
|
func (l *Listener) Accept() (net.Conn, error) {
|
||||||
|
select {
|
||||||
|
case c := <-l.ch:
|
||||||
|
return c, nil
|
||||||
|
case <-l.closed:
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the listener using the provided context.
|
||||||
|
// The provided Context must be non-nil. If the context expires before the
|
||||||
|
// connection is complete, an error is returned. Once successfully connected
|
||||||
|
// any expiration of the context will not affect the connection.
|
||||||
|
func (l *Listener) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
if !strings.HasSuffix(network, "tcp") {
|
||||||
|
return nil, net.UnknownNetworkError(network)
|
||||||
|
}
|
||||||
|
if connAddr(addr) != l.addr {
|
||||||
|
return nil, &net.AddrError{
|
||||||
|
Err: "invalid address",
|
||||||
|
Addr: addr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c, s := NewConn(addr, bufferSize)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-l.closed:
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
case l.ch <- s:
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
}
|
34
net/nettest/listener_test.go
Normal file
34
net/nettest/listener_test.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package nettest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestListener(t *testing.T) {
|
||||||
|
l := Listen("srv.local")
|
||||||
|
defer l.Close()
|
||||||
|
go func() {
|
||||||
|
c, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil {
|
||||||
|
c.Close()
|
||||||
|
t.Fatalf("dial to invalid address succeeded")
|
||||||
|
}
|
||||||
|
c, err := l.Dial(context.Background(), "tcp", "srv.local")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Close()
|
||||||
|
}
|
@ -5,11 +5,13 @@
|
|||||||
package nettest
|
package nettest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -20,13 +22,12 @@
|
|||||||
type Pipe struct {
|
type Pipe struct {
|
||||||
name string
|
name string
|
||||||
maxBuf int
|
maxBuf int
|
||||||
rCh chan struct{}
|
mu sync.Mutex
|
||||||
wCh chan struct{}
|
cnd *sync.Cond
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
closed bool
|
|
||||||
blocked bool
|
blocked bool
|
||||||
buf []byte
|
closed bool
|
||||||
|
buf bytes.Buffer
|
||||||
readTimeout time.Time
|
readTimeout time.Time
|
||||||
writeTimeout time.Time
|
writeTimeout time.Time
|
||||||
cancelReadTimer func()
|
cancelReadTimer func()
|
||||||
@ -35,21 +36,42 @@ type Pipe struct {
|
|||||||
|
|
||||||
// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
|
// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
|
||||||
func NewPipe(name string, maxBuf int) *Pipe {
|
func NewPipe(name string, maxBuf int) *Pipe {
|
||||||
return &Pipe{
|
p := &Pipe{
|
||||||
name: name,
|
name: name,
|
||||||
maxBuf: maxBuf,
|
maxBuf: maxBuf,
|
||||||
rCh: make(chan struct{}, 1),
|
|
||||||
wCh: make(chan struct{}, 1),
|
|
||||||
}
|
}
|
||||||
|
p.cnd = sync.NewCond(&p.mu)
|
||||||
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
// readOrBlock attempts to read from the buffer, if the buffer is empty and
|
||||||
ErrTimeout = errors.New("timeout")
|
// the connection hasn't been closed it will block until there is a change.
|
||||||
ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout)
|
func (p *Pipe) readOrBlock(b []byte) (int, error) {
|
||||||
ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout)
|
p.mu.Lock()
|
||||||
)
|
defer p.mu.Unlock()
|
||||||
|
if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) {
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
if p.blocked {
|
||||||
|
p.cnd.Wait()
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := p.buf.Read(b)
|
||||||
|
// err will either be nil or io.EOF.
|
||||||
|
if err == io.EOF {
|
||||||
|
if p.closed {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
// Wait for something to change.
|
||||||
|
p.cnd.Wait()
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Read implements io.Reader.
|
// Read implements io.Reader.
|
||||||
|
// Once the buffer is drained (i.e. after Close), subsequent calls will
|
||||||
|
// return io.EOF.
|
||||||
func (p *Pipe) Read(b []byte) (n int, err error) {
|
func (p *Pipe) Read(b []byte) (n int, err error) {
|
||||||
if debugPipe {
|
if debugPipe {
|
||||||
orig := b
|
orig := b
|
||||||
@ -57,35 +79,48 @@ func (p *Pipe) Read(b []byte) (n int, err error) {
|
|||||||
log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
|
log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
for {
|
for n == 0 {
|
||||||
p.mu.Lock()
|
n2, err := p.readOrBlock(b)
|
||||||
closed := p.closed
|
if err != nil {
|
||||||
timedout := !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout)
|
return n2, err
|
||||||
blocked := p.blocked
|
|
||||||
if !closed && !timedout && len(p.buf) > 0 {
|
|
||||||
n2 := copy(b, p.buf)
|
|
||||||
p.buf = p.buf[n2:]
|
|
||||||
b = b[n2:]
|
|
||||||
n += n2
|
|
||||||
}
|
}
|
||||||
p.mu.Unlock()
|
n += n2
|
||||||
|
|
||||||
if closed {
|
|
||||||
return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
|
||||||
}
|
|
||||||
if timedout {
|
|
||||||
return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout)
|
|
||||||
}
|
|
||||||
if blocked {
|
|
||||||
<-p.rCh
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if n > 0 {
|
|
||||||
p.signalWrite()
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
<-p.rCh
|
|
||||||
}
|
}
|
||||||
|
p.cnd.Signal()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeOrBlock attempts to write to the buffer, if the buffer is full it will
|
||||||
|
// block until there is a change.
|
||||||
|
func (p *Pipe) writeOrBlock(b []byte) (int, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
if p.closed {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) {
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
if p.blocked {
|
||||||
|
p.cnd.Wait()
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optimistically we want to write the entire slice.
|
||||||
|
n := len(b)
|
||||||
|
if limit := p.maxBuf - p.buf.Len(); limit < n {
|
||||||
|
// However, we don't have enough capacity to write everything.
|
||||||
|
n = limit
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
// Wait for something to change.
|
||||||
|
p.cnd.Wait()
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.buf.Write(b[:n])
|
||||||
|
p.cnd.Signal()
|
||||||
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write implements io.Writer.
|
// Write implements io.Writer.
|
||||||
@ -96,47 +131,23 @@ func (p *Pipe) Write(b []byte) (n int, err error) {
|
|||||||
log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
|
log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
for {
|
for len(b) > 0 {
|
||||||
p.mu.Lock()
|
n2, err := p.writeOrBlock(b)
|
||||||
closed := p.closed
|
if err != nil {
|
||||||
timedout := !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout)
|
return n + n2, err
|
||||||
blocked := p.blocked
|
|
||||||
if !closed && !timedout {
|
|
||||||
n2 := len(b)
|
|
||||||
if limit := p.maxBuf - len(p.buf); limit < n2 {
|
|
||||||
n2 = limit
|
|
||||||
}
|
|
||||||
p.buf = append(p.buf, b[:n2]...)
|
|
||||||
b = b[n2:]
|
|
||||||
n += n2
|
|
||||||
}
|
}
|
||||||
p.mu.Unlock()
|
n += n2
|
||||||
|
b = b[n2:]
|
||||||
if closed {
|
|
||||||
return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
|
||||||
}
|
|
||||||
if timedout {
|
|
||||||
return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout)
|
|
||||||
}
|
|
||||||
if blocked {
|
|
||||||
<-p.wCh
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if n > 0 {
|
|
||||||
p.signalRead()
|
|
||||||
}
|
|
||||||
if len(b) == 0 {
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
<-p.wCh
|
|
||||||
}
|
}
|
||||||
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close implements io.Closer.
|
// Close closes the pipe.
|
||||||
func (p *Pipe) Close() error {
|
func (p *Pipe) Close() error {
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
closed := p.closed
|
defer p.mu.Unlock()
|
||||||
p.closed = true
|
p.closed = true
|
||||||
|
p.blocked = false
|
||||||
if p.cancelWriteTimer != nil {
|
if p.cancelWriteTimer != nil {
|
||||||
p.cancelWriteTimer()
|
p.cancelWriteTimer()
|
||||||
p.cancelWriteTimer = nil
|
p.cancelWriteTimer = nil
|
||||||
@ -145,77 +156,65 @@ func (p *Pipe) Close() error {
|
|||||||
p.cancelReadTimer()
|
p.cancelReadTimer()
|
||||||
p.cancelReadTimer = nil
|
p.cancelReadTimer = nil
|
||||||
}
|
}
|
||||||
p.mu.Unlock()
|
p.cnd.Broadcast()
|
||||||
|
|
||||||
if closed {
|
|
||||||
return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.signalRead()
|
|
||||||
p.signalWrite()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Pipe) deadlineTimer(t time.Time) func() {
|
||||||
|
if t.IsZero() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if t.Before(time.Now()) {
|
||||||
|
p.cnd.Broadcast()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithDeadline(context.Background(), t)
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
p.cnd.Broadcast()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return cancel
|
||||||
|
}
|
||||||
|
|
||||||
// SetReadDeadline sets the deadline for future Read calls.
|
// SetReadDeadline sets the deadline for future Read calls.
|
||||||
func (p *Pipe) SetReadDeadline(t time.Time) error {
|
func (p *Pipe) SetReadDeadline(t time.Time) error {
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
p.readTimeout = t
|
p.readTimeout = t
|
||||||
|
// If we already have a deadline, cancel it and create a new one.
|
||||||
if p.cancelReadTimer != nil {
|
if p.cancelReadTimer != nil {
|
||||||
p.cancelReadTimer()
|
p.cancelReadTimer()
|
||||||
p.cancelReadTimer = nil
|
p.cancelReadTimer = nil
|
||||||
}
|
}
|
||||||
if d := time.Until(t); !t.IsZero() && d > 0 {
|
p.cancelReadTimer = p.deadlineTimer(t)
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
p.cancelReadTimer = cancel
|
|
||||||
go func() {
|
|
||||||
t := time.NewTimer(d)
|
|
||||||
defer t.Stop()
|
|
||||||
select {
|
|
||||||
case <-t.C:
|
|
||||||
p.signalRead()
|
|
||||||
case <-ctx.Done():
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
p.mu.Unlock()
|
|
||||||
|
|
||||||
p.signalRead()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWriteDeadline sets the deadline for future Write calls.
|
// SetWriteDeadline sets the deadline for future Write calls.
|
||||||
func (p *Pipe) SetWriteDeadline(t time.Time) error {
|
func (p *Pipe) SetWriteDeadline(t time.Time) error {
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
p.writeTimeout = t
|
p.writeTimeout = t
|
||||||
|
// If we already have a deadline, cancel it and create a new one.
|
||||||
if p.cancelWriteTimer != nil {
|
if p.cancelWriteTimer != nil {
|
||||||
p.cancelWriteTimer()
|
p.cancelWriteTimer()
|
||||||
p.cancelWriteTimer = nil
|
p.cancelWriteTimer = nil
|
||||||
}
|
}
|
||||||
if d := time.Until(t); !t.IsZero() && d > 0 {
|
p.cancelWriteTimer = p.deadlineTimer(t)
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
p.cancelWriteTimer = cancel
|
|
||||||
go func() {
|
|
||||||
t := time.NewTimer(d)
|
|
||||||
defer t.Stop()
|
|
||||||
select {
|
|
||||||
case <-t.C:
|
|
||||||
p.signalWrite()
|
|
||||||
case <-ctx.Done():
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
p.mu.Unlock()
|
|
||||||
|
|
||||||
p.signalWrite()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Block will cause all calls to Read and Write to block until they either
|
||||||
|
// timeout, are unblocked or the pipe is closed.
|
||||||
func (p *Pipe) Block() error {
|
func (p *Pipe) Block() error {
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
closed := p.closed
|
closed := p.closed
|
||||||
blocked := p.blocked
|
blocked := p.blocked
|
||||||
p.blocked = true
|
p.blocked = true
|
||||||
p.mu.Unlock()
|
|
||||||
|
|
||||||
if closed {
|
if closed {
|
||||||
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||||||
@ -223,17 +222,17 @@ func (p *Pipe) Block() error {
|
|||||||
if blocked {
|
if blocked {
|
||||||
return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
|
return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
|
||||||
}
|
}
|
||||||
p.signalRead()
|
p.cnd.Broadcast()
|
||||||
p.signalWrite()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unblock will cause all blocked Read/Write calls to continue execution.
|
||||||
func (p *Pipe) Unblock() error {
|
func (p *Pipe) Unblock() error {
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
closed := p.closed
|
closed := p.closed
|
||||||
blocked := p.blocked
|
blocked := p.blocked
|
||||||
p.blocked = false
|
p.blocked = false
|
||||||
p.mu.Unlock()
|
|
||||||
|
|
||||||
if closed {
|
if closed {
|
||||||
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||||||
@ -241,21 +240,6 @@ func (p *Pipe) Unblock() error {
|
|||||||
if !blocked {
|
if !blocked {
|
||||||
return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name)
|
return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name)
|
||||||
}
|
}
|
||||||
p.signalRead()
|
p.cnd.Broadcast()
|
||||||
p.signalWrite()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pipe) signalRead() {
|
|
||||||
select {
|
|
||||||
case p.rCh <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Pipe) signalWrite() {
|
|
||||||
select {
|
|
||||||
case p.wCh <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -35,7 +36,7 @@ func TestPipeTimeout(t *testing.T) {
|
|||||||
p := NewPipe("p1", 1<<16)
|
p := NewPipe("p1", 1<<16)
|
||||||
p.SetWriteDeadline(time.Now().Add(-1 * time.Second))
|
p.SetWriteDeadline(time.Now().Add(-1 * time.Second))
|
||||||
n, err := p.Write([]byte{'h'})
|
n, err := p.Write([]byte{'h'})
|
||||||
if !errors.Is(err, ErrWriteTimeout) || !errors.Is(err, ErrTimeout) {
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
t.Errorf("missing write timeout got err: %v", err)
|
t.Errorf("missing write timeout got err: %v", err)
|
||||||
}
|
}
|
||||||
if n != 0 {
|
if n != 0 {
|
||||||
@ -49,7 +50,7 @@ func TestPipeTimeout(t *testing.T) {
|
|||||||
p.SetReadDeadline(time.Now().Add(-1 * time.Second))
|
p.SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||||
b := make([]byte, 1)
|
b := make([]byte, 1)
|
||||||
n, err := p.Read(b)
|
n, err := p.Read(b)
|
||||||
if !errors.Is(err, ErrReadTimeout) || !errors.Is(err, ErrTimeout) {
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
t.Errorf("missing read timeout got err: %v", err)
|
t.Errorf("missing read timeout got err: %v", err)
|
||||||
}
|
}
|
||||||
if n != 0 {
|
if n != 0 {
|
||||||
@ -65,7 +66,7 @@ func TestPipeTimeout(t *testing.T) {
|
|||||||
if err := p.Block(); err != nil {
|
if err := p.Block(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if _, err := p.Write([]byte{'h'}); !errors.Is(err, ErrWriteTimeout) {
|
if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
t.Fatalf("want write timeout got: %v", err)
|
t.Fatalf("want write timeout got: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -80,7 +81,7 @@ func TestPipeTimeout(t *testing.T) {
|
|||||||
if err := p.Block(); err != nil {
|
if err := p.Block(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if _, err := p.Read(b); !errors.Is(err, ErrReadTimeout) {
|
if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
t.Fatalf("want read timeout got: %v", err)
|
t.Fatalf("want read timeout got: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
Loading…
Reference in New Issue
Block a user