// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package nettest

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"sync"
	"time"
)

// debugPipe, when set to true, makes every Pipe.Read and Pipe.Write log its
// payload, byte count and error via log.Printf.
const debugPipe bool = false

// Pipe implements an in-memory FIFO with timeouts.
//
// Bytes written are buffered (up to maxBuf) until read. Read and Write honor
// deadlines set with SetReadDeadline and SetWriteDeadline, and Block/Unblock
// can pause the pipe.
type Pipe struct {
	// name identifies the pipe in error and debug messages.
	name string
	// maxBuf is the most bytes buf may hold; Write waits once it is reached.
	maxBuf int
	// rCh carries wakeups for a waiting Read.
	rCh chan struct{}
	// wCh carries wakeups for a waiting Write.
	wCh chan struct{}

	// mu guards every field below.
	mu sync.Mutex
	// closed is set by Close.
	closed bool
	// blocked is set by Block and cleared by Unblock.
	blocked bool
	// buf holds bytes written but not yet read.
	buf []byte
	// readTimeout is the Read deadline; the zero Time means none.
	readTimeout time.Time
	// writeTimeout is the Write deadline; the zero Time means none.
	writeTimeout time.Time
	// cancelReadTimer stops the pending read-deadline wakeup, or is nil.
	cancelReadTimer func()
	// cancelWriteTimer stops the pending write-deadline wakeup, or is nil.
	cancelWriteTimer func()
}

// NewPipe creates an open, unblocked Pipe whose buffer holds at most maxBuf
// bytes. name appears in the pipe's error and debug messages.
//
// NOTE(review): with maxBuf <= 0 a non-empty Write can never buffer anything
// (and a negative value makes Write panic on slicing); callers presumably
// always pass a positive size — confirm.
func NewPipe(name string, maxBuf int) *Pipe {
	p := new(Pipe)
	p.name = name
	p.maxBuf = maxBuf
	// One-slot channels let a wakeup stay pending for a goroutine that has
	// not started waiting yet, while extra wakeups are coalesced.
	p.rCh = make(chan struct{}, 1)
	p.wCh = make(chan struct{}, 1)
	return p
}

// ErrTimeout is wrapped by both ErrReadTimeout and ErrWriteTimeout, so
// errors.Is(err, ErrTimeout) detects either kind of deadline expiry.
var ErrTimeout = errors.New("timeout")

// ErrReadTimeout is wrapped by the error Read returns after its deadline.
var ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout)

// ErrWriteTimeout is wrapped by the error Write returns after its deadline.
var ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout)

// Read implements io.Reader.
//
// Read waits until the buffer holds data and the pipe is not blocked. It
// then copies as much as fits into b and returns the count.
//
// Read returns 0 and an error instead in two cases:
//   - the pipe has been closed; the error wraps io.EOF.
//   - the read deadline has passed; the error wraps ErrReadTimeout.
//
// Data still in the buffer at that point is not returned. A zero-length b
// returns 0, nil without waiting.
//
// NOTE(review): the io.EOF documentation asks Read to return io.EOF itself
// rather than a wrapped error, because callers such as io.Copy and
// io.ReadAll compare with ==. Callers of this Pipe must use errors.Is.
func (p *Pipe) Read(b []byte) (n int, err error) {
	if debugPipe {
		orig := b
		defer func() {
			log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
		}()
	}
	for {
		p.mu.Lock()
		closed := p.closed
		timedout := !p.readTimeout.IsZero() && time.Now().After(p.readTimeout)
		blocked := p.blocked
		// Only take bytes out of the buffer when this call will return them.
		// Consuming while blocked would drop them if Close or the deadline
		// arrived before Unblock.
		if !closed && !timedout && !blocked {
			n = copy(b, p.buf)
			p.buf = p.buf[n:]
		}
		p.mu.Unlock()

		if closed {
			return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
		}
		if timedout {
			return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout)
		}
		if len(b) == 0 {
			// Nothing can ever be copied into an empty b. Waiting would only
			// hang until Close or the deadline.
			return 0, nil
		}
		if n > 0 {
			// Space was freed; let a Write waiting on a full buffer proceed.
			p.signalWrite()
			return n, nil
		}
		// The buffer is empty or the pipe is blocked. Sleep until a Write,
		// Close, Block/Unblock, or deadline change sends on rCh.
		<-p.rCh
	}
}

// Write implements io.Writer.
//
// Write moves b into the pipe's buffer in chunks, never letting the buffer
// exceed maxBuf bytes. Whenever the buffer is full, it waits for Read to
// make room. It returns len(b), nil once everything is buffered and the
// pipe is not blocked.
//
// If the pipe is closed or the write deadline passes, Write returns the
// number of bytes buffered so far. The error wraps io.EOF or
// ErrWriteTimeout respectively.
func (p *Pipe) Write(b []byte) (n int, err error) {
	if debugPipe {
		defer func() {
			log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(b), n, err)
		}()
	}
	pending := b
	for {
		p.mu.Lock()
		isClosed := p.closed
		expired := !p.writeTimeout.IsZero() && time.Now().After(p.writeTimeout)
		paused := p.blocked
		if !isClosed && !expired {
			chunk := pending
			if room := p.maxBuf - len(p.buf); len(chunk) > room {
				chunk = chunk[:room]
			}
			p.buf = append(p.buf, chunk...)
			pending = pending[len(chunk):]
			n += len(chunk)
		}
		p.mu.Unlock()

		switch {
		case isClosed:
			return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
		case expired:
			return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout)
		case paused:
			// Hold the caller until Unblock, Close, or a deadline change
			// sends on wCh.
			<-p.wCh
			continue
		}

		if n > 0 {
			p.signalRead()
		}
		if len(pending) == 0 {
			return n, nil
		}
		// The buffer is full: wait for a Read to drain some of it.
		<-p.wCh
	}
}

// Close implements io.Closer.
//
// Close marks the pipe closed and cancels any pending deadline timers. It
// then sends one wakeup each to the read and write channels, so waiting
// calls observe the closure. Closing a pipe that is already closed
// returns an error.
func (p *Pipe) Close() error {
	// stopTimer cancels and forgets a deadline timer, if one is set.
	stopTimer := func(cancel *func()) {
		if *cancel != nil {
			(*cancel)()
			*cancel = nil
		}
	}

	p.mu.Lock()
	alreadyClosed := p.closed
	p.closed = true
	stopTimer(&p.cancelWriteTimer)
	stopTimer(&p.cancelReadTimer)
	p.mu.Unlock()

	if alreadyClosed {
		return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name)
	}
	p.signalRead()
	p.signalWrite()
	return nil
}

// SetReadDeadline sets the deadline for future Read calls.
//
// A zero t removes the deadline. Any timer left by a previous call is
// cancelled. If t is still in the future, a goroutine sends a read wakeup
// when it arrives, so a waiting Read can report ErrReadTimeout. A read
// wakeup is also sent immediately so a waiting Read re-checks the new
// deadline. It always returns nil.
func (p *Pipe) SetReadDeadline(t time.Time) error {
	p.mu.Lock()
	p.readTimeout = t
	if stop := p.cancelReadTimer; stop != nil {
		p.cancelReadTimer = nil
		stop()
	}
	if !t.IsZero() {
		if wait := time.Until(t); wait > 0 {
			ctx, cancel := context.WithCancel(context.Background())
			p.cancelReadTimer = cancel
			go func() {
				timer := time.NewTimer(wait)
				defer timer.Stop()
				select {
				case <-timer.C:
					p.signalRead()
				case <-ctx.Done():
					// Superseded by a newer deadline or by Close.
				}
			}()
		}
	}
	p.mu.Unlock()

	p.signalRead()
	return nil
}

// SetWriteDeadline sets the deadline for future Write calls.
//
// A zero t removes the deadline. Any timer left by a previous call is
// cancelled. If t is still in the future, a goroutine sends a write wakeup
// when it arrives, so a waiting Write can report ErrWriteTimeout. A write
// wakeup is also sent immediately so a waiting Write re-checks the new
// deadline. It always returns nil.
func (p *Pipe) SetWriteDeadline(t time.Time) error {
	p.mu.Lock()
	p.writeTimeout = t
	if stop := p.cancelWriteTimer; stop != nil {
		p.cancelWriteTimer = nil
		stop()
	}
	if !t.IsZero() {
		if wait := time.Until(t); wait > 0 {
			ctx, cancel := context.WithCancel(context.Background())
			p.cancelWriteTimer = cancel
			go func() {
				timer := time.NewTimer(wait)
				defer timer.Stop()
				select {
				case <-timer.C:
					p.signalWrite()
				case <-ctx.Done():
					// Superseded by a newer deadline or by Close.
				}
			}()
		}
	}
	p.mu.Unlock()

	p.signalWrite()
	return nil
}

// Block pauses the pipe until Unblock is called. While blocked, Read does
// not deliver data and Write does not return success; both still fail on
// Close or deadline expiry.
//
// Block returns an error if the pipe is closed or already blocked. The
// blocked flag is set even in those cases. On success it sends a wakeup
// to the read and write channels so waiting calls notice the change.
func (p *Pipe) Block() error {
	p.mu.Lock()
	wasClosed, wasBlocked := p.closed, p.blocked
	p.blocked = true
	p.mu.Unlock()

	switch {
	case wasClosed:
		return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
	case wasBlocked:
		return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
	}
	p.signalRead()
	p.signalWrite()
	return nil
}

// Unblock reverses a prior Block so Read and Write can make progress again.
//
// It returns an error if the pipe is closed or not currently blocked. The
// blocked flag is cleared even in those cases. On success it sends a wakeup
// to the read and write channels so waiting calls re-check the pipe state.
func (p *Pipe) Unblock() error {
	p.mu.Lock()
	closed := p.closed
	blocked := p.blocked
	p.blocked = false
	p.mu.Unlock()

	// The messages name Unblock; they previously said "Block" by copy-paste
	// mistake, which made failures misleading to diagnose.
	if closed {
		return fmt.Errorf("nettest.Pipe(%q).Unblock: closed", p.name)
	}
	if !blocked {
		return fmt.Errorf("nettest.Pipe(%q).Unblock: already unblocked", p.name)
	}
	p.signalRead()
	p.signalWrite()
	return nil
}

// signalRead posts a wakeup on rCh for a waiting Read. The send never
// blocks: if the channel's buffer (one slot when created by NewPipe) is
// already full, the wakeup is dropped in favor of the pending one.
func (p *Pipe) signalRead() {
	var wake struct{}
	select {
	case p.rCh <- wake:
	default:
		// A wakeup is already pending; nothing more to do.
	}
}

// signalWrite posts a wakeup on wCh for a waiting Write. The send never
// blocks: if the channel's buffer (one slot when created by NewPipe) is
// already full, the wakeup is dropped in favor of the pending one.
func (p *Pipe) signalWrite() {
	var wake struct{}
	select {
	case p.wCh <- wake:
	default:
		// A wakeup is already pending; nothing more to do.
	}
}