// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package nettest import ( "bytes" "context" "fmt" "io" "log" "net" "os" "sync" "time" ) const debugPipe = false // Pipe implements an in-memory FIFO with timeouts. type Pipe struct { name string maxBuf int mu sync.Mutex cnd *sync.Cond blocked bool closed bool buf bytes.Buffer readTimeout time.Time writeTimeout time.Time cancelReadTimer func() cancelWriteTimer func() } // NewPipe creates a Pipe with a buffer size fixed at maxBuf. func NewPipe(name string, maxBuf int) *Pipe { p := &Pipe{ name: name, maxBuf: maxBuf, } p.cnd = sync.NewCond(&p.mu) return p } // readOrBlock attempts to read from the buffer, if the buffer is empty and // the connection hasn't been closed it will block until there is a change. func (p *Pipe) readOrBlock(b []byte) (int, error) { p.mu.Lock() defer p.mu.Unlock() if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) { return 0, os.ErrDeadlineExceeded } if p.blocked { p.cnd.Wait() return 0, nil } n, err := p.buf.Read(b) // err will either be nil or io.EOF. if err == io.EOF { if p.closed { return n, err } // Wait for something to change. p.cnd.Wait() } return n, nil } // Read implements io.Reader. // Once the buffer is drained (i.e. after Close), subsequent calls will // return io.EOF. func (p *Pipe) Read(b []byte) (n int, err error) { if debugPipe { orig := b defer func() { log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err) }() } for n == 0 { n2, err := p.readOrBlock(b) if err != nil { return n2, err } n += n2 } p.cnd.Signal() return n, nil } // writeOrBlock attempts to write to the buffer, if the buffer is full it will // block until there is a change. func (p *Pipe) writeOrBlock(b []byte) (int, error) { p.mu.Lock() defer p.mu.Unlock() if p.closed { return 0, net.ErrClosed } if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) { return 0, os.ErrDeadlineExceeded } if p.blocked { p.cnd.Wait() return 0, nil } // Optimistically we want to write the entire slice. n := len(b) if limit := p.maxBuf - p.buf.Len(); limit < n { // However, we don't have enough capacity to write everything. n = limit } if n == 0 { // Wait for something to change. p.cnd.Wait() return 0, nil } p.buf.Write(b[:n]) p.cnd.Signal() return n, nil } // Write implements io.Writer. func (p *Pipe) Write(b []byte) (n int, err error) { if debugPipe { orig := b defer func() { log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) }() } for len(b) > 0 { n2, err := p.writeOrBlock(b) if err != nil { return n + n2, err } n += n2 b = b[n2:] } return n, nil } // Close closes the pipe. func (p *Pipe) Close() error { p.mu.Lock() defer p.mu.Unlock() p.closed = true p.blocked = false if p.cancelWriteTimer != nil { p.cancelWriteTimer() p.cancelWriteTimer = nil } if p.cancelReadTimer != nil { p.cancelReadTimer() p.cancelReadTimer = nil } p.cnd.Broadcast() return nil } func (p *Pipe) deadlineTimer(t time.Time) func() { if t.IsZero() { return nil } if t.Before(time.Now()) { p.cnd.Broadcast() return nil } ctx, cancel := context.WithDeadline(context.Background(), t) go func() { <-ctx.Done() if ctx.Err() == context.DeadlineExceeded { p.cnd.Broadcast() } }() return cancel } // SetReadDeadline sets the deadline for future Read calls. func (p *Pipe) SetReadDeadline(t time.Time) error { p.mu.Lock() defer p.mu.Unlock() p.readTimeout = t // If we already have a deadline, cancel it and create a new one. if p.cancelReadTimer != nil { p.cancelReadTimer() p.cancelReadTimer = nil } p.cancelReadTimer = p.deadlineTimer(t) return nil } // SetWriteDeadline sets the deadline for future Write calls. func (p *Pipe) SetWriteDeadline(t time.Time) error { p.mu.Lock() defer p.mu.Unlock() p.writeTimeout = t // If we already have a deadline, cancel it and create a new one. if p.cancelWriteTimer != nil { p.cancelWriteTimer() p.cancelWriteTimer = nil } p.cancelWriteTimer = p.deadlineTimer(t) return nil } // Block will cause all calls to Read and Write to block until they either // timeout, are unblocked or the pipe is closed. func (p *Pipe) Block() error { p.mu.Lock() defer p.mu.Unlock() closed := p.closed blocked := p.blocked p.blocked = true if closed { return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name) } if blocked { return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name) } p.cnd.Broadcast() return nil } // Unblock will cause all blocked Read/Write calls to continue execution. func (p *Pipe) Unblock() error { p.mu.Lock() defer p.mu.Unlock() closed := p.closed blocked := p.blocked p.blocked = false if closed { return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name) } if !blocked { return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name) } p.cnd.Broadcast() return nil }