mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-30 05:25:35 +00:00
262 lines
4.9 KiB
Go
262 lines
4.9 KiB
Go
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package nettest
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"log"
|
||
|
"sync"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
const debugPipe = false
|
||
|
|
||
|
// Pipe implements an in-memory FIFO with timeouts.
|
||
|
type Pipe struct {
|
||
|
name string
|
||
|
maxBuf int
|
||
|
rCh chan struct{}
|
||
|
wCh chan struct{}
|
||
|
|
||
|
mu sync.Mutex
|
||
|
closed bool
|
||
|
blocked bool
|
||
|
buf []byte
|
||
|
readTimeout time.Time
|
||
|
writeTimeout time.Time
|
||
|
cancelReadTimer func()
|
||
|
cancelWriteTimer func()
|
||
|
}
|
||
|
|
||
|
// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
|
||
|
func NewPipe(name string, maxBuf int) *Pipe {
|
||
|
return &Pipe{
|
||
|
name: name,
|
||
|
maxBuf: maxBuf,
|
||
|
rCh: make(chan struct{}, 1),
|
||
|
wCh: make(chan struct{}, 1),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
ErrTimeout = errors.New("timeout")
|
||
|
ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout)
|
||
|
ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout)
|
||
|
)
|
||
|
|
||
|
// Read implements io.Reader.
|
||
|
func (p *Pipe) Read(b []byte) (n int, err error) {
|
||
|
if debugPipe {
|
||
|
orig := b
|
||
|
defer func() {
|
||
|
log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
|
||
|
}()
|
||
|
}
|
||
|
for {
|
||
|
p.mu.Lock()
|
||
|
closed := p.closed
|
||
|
timedout := !p.readTimeout.IsZero() && time.Now().After(p.readTimeout)
|
||
|
blocked := p.blocked
|
||
|
if !closed && !timedout && len(p.buf) > 0 {
|
||
|
n2 := copy(b, p.buf)
|
||
|
p.buf = p.buf[n2:]
|
||
|
b = b[n2:]
|
||
|
n += n2
|
||
|
}
|
||
|
p.mu.Unlock()
|
||
|
|
||
|
if closed {
|
||
|
return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
||
|
}
|
||
|
if timedout {
|
||
|
return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout)
|
||
|
}
|
||
|
if blocked {
|
||
|
<-p.rCh
|
||
|
continue
|
||
|
}
|
||
|
if n > 0 {
|
||
|
p.signalWrite()
|
||
|
return n, nil
|
||
|
}
|
||
|
<-p.rCh
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Write implements io.Writer.
|
||
|
func (p *Pipe) Write(b []byte) (n int, err error) {
|
||
|
if debugPipe {
|
||
|
orig := b
|
||
|
defer func() {
|
||
|
log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
|
||
|
}()
|
||
|
}
|
||
|
for {
|
||
|
p.mu.Lock()
|
||
|
closed := p.closed
|
||
|
timedout := !p.writeTimeout.IsZero() && time.Now().After(p.writeTimeout)
|
||
|
blocked := p.blocked
|
||
|
if !closed && !timedout {
|
||
|
n2 := len(b)
|
||
|
if limit := p.maxBuf - len(p.buf); limit < n2 {
|
||
|
n2 = limit
|
||
|
}
|
||
|
p.buf = append(p.buf, b[:n2]...)
|
||
|
b = b[n2:]
|
||
|
n += n2
|
||
|
}
|
||
|
p.mu.Unlock()
|
||
|
|
||
|
if closed {
|
||
|
return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
||
|
}
|
||
|
if timedout {
|
||
|
return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout)
|
||
|
}
|
||
|
if blocked {
|
||
|
<-p.wCh
|
||
|
continue
|
||
|
}
|
||
|
if n > 0 {
|
||
|
p.signalRead()
|
||
|
}
|
||
|
if len(b) == 0 {
|
||
|
return n, nil
|
||
|
}
|
||
|
<-p.wCh
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Close implements io.Closer.
|
||
|
func (p *Pipe) Close() error {
|
||
|
p.mu.Lock()
|
||
|
closed := p.closed
|
||
|
p.closed = true
|
||
|
if p.cancelWriteTimer != nil {
|
||
|
p.cancelWriteTimer()
|
||
|
p.cancelWriteTimer = nil
|
||
|
}
|
||
|
if p.cancelReadTimer != nil {
|
||
|
p.cancelReadTimer()
|
||
|
p.cancelReadTimer = nil
|
||
|
}
|
||
|
p.mu.Unlock()
|
||
|
|
||
|
if closed {
|
||
|
return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name)
|
||
|
}
|
||
|
|
||
|
p.signalRead()
|
||
|
p.signalWrite()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// SetReadDeadline sets the deadline for future Read calls.
|
||
|
func (p *Pipe) SetReadDeadline(t time.Time) error {
|
||
|
p.mu.Lock()
|
||
|
p.readTimeout = t
|
||
|
if p.cancelReadTimer != nil {
|
||
|
p.cancelReadTimer()
|
||
|
p.cancelReadTimer = nil
|
||
|
}
|
||
|
if d := time.Until(t); !t.IsZero() && d > 0 {
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
p.cancelReadTimer = cancel
|
||
|
go func() {
|
||
|
t := time.NewTimer(d)
|
||
|
defer t.Stop()
|
||
|
select {
|
||
|
case <-t.C:
|
||
|
p.signalRead()
|
||
|
case <-ctx.Done():
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
p.mu.Unlock()
|
||
|
|
||
|
p.signalRead()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// SetWriteDeadline sets the deadline for future Write calls.
|
||
|
func (p *Pipe) SetWriteDeadline(t time.Time) error {
|
||
|
p.mu.Lock()
|
||
|
p.writeTimeout = t
|
||
|
if p.cancelWriteTimer != nil {
|
||
|
p.cancelWriteTimer()
|
||
|
p.cancelWriteTimer = nil
|
||
|
}
|
||
|
if d := time.Until(t); !t.IsZero() && d > 0 {
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
p.cancelWriteTimer = cancel
|
||
|
go func() {
|
||
|
t := time.NewTimer(d)
|
||
|
defer t.Stop()
|
||
|
select {
|
||
|
case <-t.C:
|
||
|
p.signalWrite()
|
||
|
case <-ctx.Done():
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
p.mu.Unlock()
|
||
|
|
||
|
p.signalWrite()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (p *Pipe) Block() error {
|
||
|
p.mu.Lock()
|
||
|
closed := p.closed
|
||
|
blocked := p.blocked
|
||
|
p.blocked = true
|
||
|
p.mu.Unlock()
|
||
|
|
||
|
if closed {
|
||
|
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||
|
}
|
||
|
if blocked {
|
||
|
return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
|
||
|
}
|
||
|
p.signalRead()
|
||
|
p.signalWrite()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (p *Pipe) Unblock() error {
|
||
|
p.mu.Lock()
|
||
|
closed := p.closed
|
||
|
blocked := p.blocked
|
||
|
p.blocked = false
|
||
|
p.mu.Unlock()
|
||
|
|
||
|
if closed {
|
||
|
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||
|
}
|
||
|
if !blocked {
|
||
|
return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name)
|
||
|
}
|
||
|
p.signalRead()
|
||
|
p.signalWrite()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (p *Pipe) signalRead() {
|
||
|
select {
|
||
|
case p.rCh <- struct{}{}:
|
||
|
default:
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (p *Pipe) signalWrite() {
|
||
|
select {
|
||
|
case p.wCh <- struct{}{}:
|
||
|
default:
|
||
|
}
|
||
|
}
|