// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package memnet

import (
	"net"
	"net/netip"
	"time"
)

// Conn is a net.Conn that can additionally have its reads and writes
// blocked and unblocked independently of each other.
type Conn interface {
	net.Conn

	// SetReadBlock blocks (true) or unblocks (false) the Read method of
	// this Conn.
	// It reports an error if the existing value matches the new value,
	// or if the Conn has been Closed.
	SetReadBlock(bool) error

	// SetWriteBlock blocks (true) or unblocks (false) the Write method of
	// this Conn.
	// It reports an error if the existing value matches the new value,
	// or if the Conn has been Closed.
	SetWriteBlock(bool) error
}

// NewConn returns the two ends of an in-memory connection.
//
// Two pipes are created, named name+"|0" and name+"|1", each given maxBuf.
// The first returned Conn reads from the "|0" pipe and writes to the "|1"
// pipe; the second returned Conn is wired the opposite way, so whatever one
// end writes the other end reads.
func NewConn(name string, maxBuf int) (Conn, Conn) {
	pipe0 := NewPipe(name+"|0", maxBuf)
	pipe1 := NewPipe(name+"|1", maxBuf)

	first := &connHalf{r: pipe0, w: pipe1}
	second := &connHalf{r: pipe1, w: pipe0}
	return first, second
}

// NewTCPConn returns the two ends of an in-memory connection whose
// LocalAddr and RemoteAddr report TCP addresses.
//
// Two pipes are created, named after src and dst, each given maxBuf.
// The local Conn reads from the src-named pipe and writes to the dst-named
// pipe, with src as its local address and dst as its remote address.
// The remote Conn is the mirror image: pipes and addresses are swapped.
func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) {
	srcPipe := NewPipe(src.String(), maxBuf)
	dstPipe := NewPipe(dst.String(), maxBuf)

	srcAddr := net.TCPAddrFromAddrPort(src)
	dstAddr := net.TCPAddrFromAddrPort(dst)

	local = &connHalf{
		r:      srcPipe,
		w:      dstPipe,
		local:  srcAddr,
		remote: dstAddr,
	}
	remote = &connHalf{
		r:      dstPipe,
		w:      srcPipe,
		local:  dstAddr,
		remote: srcAddr,
	}
	return local, remote
}

// connAddr is a string-backed address with the Network and String
// methods of a net.Addr. Its network is always "mem" and its string
// form is the underlying string itself.
type connAddr string

// Network returns the fixed network name "mem".
func (connAddr) Network() string {
	return "mem"
}

// String returns the address as the plain string it was built from.
func (a connAddr) String() string {
	return string(a)
}

// connHalf is one end of a connection made from two Pipes.
//
// r is the pipe this end reads from and w is the pipe it writes to; the
// peer end holds the same two pipes with the roles swapped.
//
// local and remote are optional address overrides. When non-nil they are
// returned by LocalAddr and RemoteAddr respectively; when nil, an address
// derived from the name of r (local) or w (remote) is returned instead.
type connHalf struct {
	local, remote net.Addr
	r, w          *Pipe
}

// LocalAddr returns the explicitly configured local address if one was
// set. Otherwise it returns a connAddr built from the read pipe's name.
func (c *connHalf) LocalAddr() net.Addr {
	if c.local == nil {
		return connAddr(c.r.name)
	}
	return c.local
}

// RemoteAddr returns the explicitly configured remote address if one was
// set. Otherwise it returns a connAddr built from the write pipe's name.
func (c *connHalf) RemoteAddr() net.Addr {
	if c.remote == nil {
		return connAddr(c.w.name)
	}
	return c.remote
}

// Read reads into b from this half's read pipe and returns the pipe's
// byte count and error unchanged.
func (c *connHalf) Read(b []byte) (n int, err error) {
	n, err = c.r.Read(b)
	return n, err
}
// Write writes b to this half's write pipe and returns the pipe's byte
// count and error unchanged.
func (c *connHalf) Write(b []byte) (n int, err error) {
	n, err = c.w.Write(b)
	return n, err
}

// Close closes both pipes of this half: the write pipe first, then the
// read pipe.
//
// Both pipes are always closed, even if closing the write pipe reports an
// error, so a failure on one side cannot leave the other side open. The
// first error encountered is returned (the write pipe's error takes
// precedence), mirroring how SetDeadline reports errors.
func (c *connHalf) Close() error {
	writeErr := c.w.Close()
	readErr := c.r.Close()
	if writeErr != nil {
		return writeErr
	}
	return readErr
}

// SetDeadline applies t as both the read and the write deadline.
//
// Both deadlines are always attempted. If setting the read deadline fails,
// that error is returned; otherwise the result of setting the write
// deadline is returned.
func (c *connHalf) SetDeadline(t time.Time) error {
	readErr := c.SetReadDeadline(t)
	writeErr := c.SetWriteDeadline(t)
	if readErr == nil {
		return writeErr
	}
	return readErr
}
// SetReadDeadline forwards t to the read pipe's SetReadDeadline and
// returns its result.
func (c *connHalf) SetReadDeadline(t time.Time) error {
	readPipe := c.r
	return readPipe.SetReadDeadline(t)
}
// SetWriteDeadline forwards t to the write pipe's SetWriteDeadline and
// returns its result.
func (c *connHalf) SetWriteDeadline(t time.Time) error {
	writePipe := c.w
	return writePipe.SetWriteDeadline(t)
}

// SetReadBlock blocks reads on this half when b is true, by calling Block
// on the read pipe, and unblocks them when b is false, by calling Unblock.
// The pipe's error is returned unchanged.
func (c *connHalf) SetReadBlock(b bool) error {
	if !b {
		return c.r.Unblock()
	}
	return c.r.Block()
}
// SetWriteBlock blocks writes on this half when b is true, by calling
// Block on the write pipe, and unblocks them when b is false, by calling
// Unblock. The pipe's error is returned unchanged.
func (c *connHalf) SetWriteBlock(b bool) error {
	if !b {
		return c.w.Unblock()
	}
	return c.w.Block()
}