// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Package wsconn contains an adapter type that turns
// a websocket connection into a net.Conn. It a temporary fork of the
// netconn.go file from the nhooyr.io/websocket package while we wait for
// https://github.com/nhooyr/websocket/pull/350 to be merged.
package wsconn

import (
	"context"
	"fmt"
	"io"
	"math"
	"net"
	"os"
	"sync"
	"sync/atomic"
	"time"

	"nhooyr.io/websocket"
)

// NetConn converts a *websocket.Conn into a net.Conn.
//
// It's for tunneling arbitrary protocols over WebSockets.
// Few users of the library will need this but it's tricky to implement
// correctly and so provided in the library.
// See https://github.com/nhooyr/websocket/issues/100.
//
// Every Write to the net.Conn will correspond to a message write of
// the given type on *websocket.Conn.
//
// The passed ctx bounds the lifetime of the net.Conn. If cancelled,
// all reads and writes on the net.Conn will be cancelled.
//
// If a message is read that is not of the correct type, the connection
// will be closed with StatusUnsupportedData and an error will be returned.
//
// Close will close the *websocket.Conn with StatusNormalClosure.
//
// When a deadline is hit, the connection will be closed. This is
// different from most net.Conn implementations where only the
// reading/writing goroutines are interrupted but the connection is kept alive.
//
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
// and "websocket/unknown-addr" for String.
//
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
// io.EOF when reading.
//
// The given remoteAddr will be the value of the returned conn's
// RemoteAddr().String(). For best compatibility with consumers of
// conns, the string should be an ip:port if available, but in the
// absence of that it can be any string that describes the remote
// endpoint, or the empty string to makes RemoteAddr() return a place
// holder value.
func NetConn(ctx context.Context, c *websocket.Conn, msgType websocket.MessageType, remoteAddr string) net.Conn {
	nc := &netConn{
		c:          c,
		msgType:    msgType,
		remoteAddr: remoteAddr,
	}

	var writeCancel context.CancelFunc
	nc.writeContext, writeCancel = context.WithCancel(ctx)
	nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
		nc.afterWriteDeadline.Store(true)
		if nc.writing.Load() {
			writeCancel()
		}
	})
	if !nc.writeTimer.Stop() {
		<-nc.writeTimer.C
	}

	var readCancel context.CancelFunc
	nc.readContext, readCancel = context.WithCancel(ctx)
	nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
		nc.afterReadDeadline.Store(true)
		if nc.reading.Load() {
			readCancel()
		}
	})
	if !nc.readTimer.Stop() {
		<-nc.readTimer.C
	}

	return nc
}

type netConn struct {
	c          *websocket.Conn
	msgType    websocket.MessageType
	remoteAddr string

	writeTimer         *time.Timer
	writeContext       context.Context
	writing            atomic.Bool
	afterWriteDeadline atomic.Bool

	readTimer         *time.Timer
	readContext       context.Context
	reading           atomic.Bool
	afterReadDeadline atomic.Bool

	readMu sync.Mutex
	eofed  bool
	reader io.Reader
}

var _ net.Conn = &netConn{}

func (c *netConn) Close() error {
	return c.c.Close(websocket.StatusNormalClosure, "")
}

func (c *netConn) Write(p []byte) (int, error) {
	if c.afterWriteDeadline.Load() {
		return 0, os.ErrDeadlineExceeded
	}

	if swapped := c.writing.CompareAndSwap(false, true); !swapped {
		panic("Concurrent writes not allowed")
	}
	defer c.writing.Store(false)

	err := c.c.Write(c.writeContext, c.msgType, p)
	if err != nil {
		return 0, err
	}

	return len(p), nil
}

func (c *netConn) Read(p []byte) (int, error) {
	if c.afterReadDeadline.Load() {
		return 0, os.ErrDeadlineExceeded
	}

	c.readMu.Lock()
	defer c.readMu.Unlock()
	if swapped := c.reading.CompareAndSwap(false, true); !swapped {
		panic("Concurrent reads not allowed")
	}
	defer c.reading.Store(false)

	if c.eofed {
		return 0, io.EOF
	}

	if c.reader == nil {
		typ, r, err := c.c.Reader(c.readContext)
		if err != nil {
			switch websocket.CloseStatus(err) {
			case websocket.StatusNormalClosure, websocket.StatusGoingAway:
				c.eofed = true
				return 0, io.EOF
			}
			return 0, err
		}
		if typ != c.msgType {
			err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ)
			c.c.Close(websocket.StatusUnsupportedData, err.Error())
			return 0, err
		}
		c.reader = r
	}

	n, err := c.reader.Read(p)
	if err == io.EOF {
		c.reader = nil
		err = nil
	}
	return n, err
}

type websocketAddr struct {
	addr string
}

func (a websocketAddr) Network() string {
	return "websocket"
}

func (a websocketAddr) String() string {
	if a.addr != "" {
		return a.addr
	}
	return "websocket/unknown-addr"
}

func (c *netConn) RemoteAddr() net.Addr {
	return websocketAddr{c.remoteAddr}
}

func (c *netConn) LocalAddr() net.Addr {
	return websocketAddr{""}
}

func (c *netConn) SetDeadline(t time.Time) error {
	c.SetWriteDeadline(t)
	c.SetReadDeadline(t)
	return nil
}

func (c *netConn) SetWriteDeadline(t time.Time) error {
	if t.IsZero() {
		c.writeTimer.Stop()
	} else {
		c.writeTimer.Reset(time.Until(t))
	}
	c.afterWriteDeadline.Store(false)
	return nil
}

func (c *netConn) SetReadDeadline(t time.Time) error {
	if t.IsZero() {
		c.readTimer.Stop()
	} else {
		c.readTimer.Reset(time.Until(t))
	}
	c.afterReadDeadline.Store(false)
	return nil
}