// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package magicsock

import (
	"errors"
	"net"
	"net/netip"
	"sync"
	"sync/atomic"
	"syscall"

	"golang.org/x/net/ipv6"
	"tailscale.com/net/netaddr"
	"tailscale.com/types/nettype"
)

// RebindingUDPConn is a UDP socket that can be re-bound.
// Unix has no notion of re-binding a socket, so we swap it out for a new one.
type RebindingUDPConn struct {
	// pconnAtomic holds a pointer to the same value that is stored in pconn,
	// but it can be loaded without acquiring mu. Reads and writes use it
	// directly; only when one of them fails do they acquire mu and compare
	// against pconn to see whether a rebind happened in the meantime.
	//
	// pconn isn't strictly needed, but keeping it as a distinct field makes
	// some of the code simpler.
	//
	// Neither is expected to be nil; sockets are bound on creation.
	pconnAtomic atomic.Pointer[nettype.PacketConn]

	mu    sync.Mutex // held while changing pconn (and pconnAtomic)
	// pconn is the current underlying conn. It is read with mu held.
	pconn nettype.PacketConn
	// port is the local port of pconn as recorded by setConnLocked.
	// closeLocked resets it to 0.
	port  uint16
}

// setConnLocked installs p as c's current conn and records its local port.
// The caller must hold c.mu.
//
// Before being stored, p is passed through tryUpgradeToBatchingConn, which
// upgrades it to a batchingConn when appropriate. The upgrade deliberately
// happens here, as close as possible to where read/write ops occur, so that
// surrounding code that assumes a nettype.PacketConn is a *net.UDPConn is
// not disrupted.
func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) {
	conn := tryUpgradeToBatchingConn(p, network, batchSize)

	// Publish the new conn to both the locked and the lock-free views.
	c.pconn = conn
	c.pconnAtomic.Store(&conn)

	// localAddrLocked reads c.pconn, so it must run after the assignment above.
	local := c.localAddrLocked()
	c.port = uint16(local.Port)
}

// currentConn returns the conn currently stored in c.pconn. It acquires and
// releases c.mu itself, so the caller must not already hold it.
func (c *RebindingUDPConn) currentConn() nettype.PacketConn {
	c.mu.Lock()
	cur := c.pconn
	c.mu.Unlock()
	return cur
}

// readFromWithInitPconn reads a packet into b, starting with pconn.
//
// If a read fails and pconn is no longer c's current conn (that is, a rebind
// happened meanwhile), the read is retried on the conn loaded from
// c.pconnAtomic. Otherwise the result of the read is returned as is.
func (c *RebindingUDPConn) readFromWithInitPconn(pconn nettype.PacketConn, b []byte) (int, netip.AddrPort, error) {
	for {
		n, src, err := pconn.ReadFromUDPAddrPort(b)
		if err == nil {
			return n, src, nil
		}
		if pconn == c.currentConn() {
			// No rebind happened; the failure is genuine.
			return n, src, err
		}
		// The conn was swapped out from under us; try again on the new one.
		pconn = *c.pconnAtomic.Load()
	}
}

// ReadFromUDPAddrPort reads a packet from c into b.
// It returns the number of bytes copied and the source address.
// The read starts on the conn currently published in c.pconnAtomic and is
// retried by readFromWithInitPconn if a rebind is detected after a failure.
func (c *RebindingUDPConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, error) {
	pconn := *c.pconnAtomic.Load()
	return c.readFromWithInitPconn(pconn, b)
}

// WriteBatchTo writes buffs to addr.
//
// When the current conn is a batchingConn, the whole batch is handed to its
// WriteBatchTo; if that fails and a rebind happened meanwhile, the batch is
// retried on the new conn. When the current conn is not a batchingConn, each
// buffer is written individually and the first error stops the loop.
func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error {
	for {
		pconn := *c.pconnAtomic.Load()

		if bc, isBatching := pconn.(batchingConn); isBatching {
			err := bc.WriteBatchTo(buffs, addr)
			if err != nil && pconn != c.currentConn() {
				// The conn was rebound during the write; start over.
				continue
			}
			return err
		}

		// Fallback: no batching support, send one datagram per buffer.
		for i := range buffs {
			if _, err := c.writeToUDPAddrPortWithInitPconn(pconn, buffs[i], addr); err != nil {
				return err
			}
		}
		return nil
	}
}

// ReadBatch reads messages from c into msgs. The returned count is the number
// of messages the caller should evaluate for nonzero len, as a zero len
// message may fall on either side of a nonzero.
//
// When the current conn is a batchingConn, the read is delegated to its
// ReadBatch and retried if it fails and a rebind happened meanwhile.
// Otherwise a single packet is read into msgs[0].Buffers[0], so on that path
// msgs must hold at least one message with at least one buffer.
func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) {
	for {
		pconn := *c.pconnAtomic.Load()

		if bc, isBatching := pconn.(batchingConn); isBatching {
			n, err := bc.ReadBatch(msgs, flags)
			if err == nil || pconn == c.currentConn() {
				return n, err
			}
			// The read failed and the conn was rebound; start over.
			continue
		}

		// Fallback: no batching support, read exactly one packet.
		first := &msgs[0]
		n, src, err := c.readFromWithInitPconn(pconn, first.Buffers[0])
		if err != nil {
			return 0, err
		}
		first.N = n
		first.Addr = net.UDPAddrFromAddrPort(netaddr.Unmap(src))
		return 1, nil
	}
}

// Port returns the port recorded for c's current conn. It is the value stored
// by setConnLocked, or 0 after closeLocked has run. c.mu is held while the
// value is read.
func (c *RebindingUDPConn) Port() uint16 {
	c.mu.Lock()
	port := c.port
	c.mu.Unlock()
	return port
}

// LocalAddr returns the local address of c's current conn as a *net.UDPAddr.
// It holds c.mu for the duration of the call.
func (c *RebindingUDPConn) LocalAddr() *net.UDPAddr {
	c.mu.Lock()
	// Deferred so that mu is released even if localAddrLocked panics.
	defer c.mu.Unlock()
	return c.localAddrLocked()
}

// localAddrLocked returns the local address of c.pconn as a *net.UDPAddr.
// The caller must hold c.mu. The type assertion is unchecked, so it panics
// if the conn reports an address of any other type.
func (c *RebindingUDPConn) localAddrLocked() *net.UDPAddr {
	addr := c.pconn.LocalAddr()
	return addr.(*net.UDPAddr)
}

// errNilPConn is the error closeLocked (and therefore RebindingUDPConn.Close)
// returns when there is no current pconn to close.
// It is for internal use only and should not be returned to users.
// NOTE(review): Close currently passes the closeLocked result through
// unchanged, so this error can reach callers of Close; confirm that is
// intended.
var errNilPConn = errors.New("nil pconn")

// Close closes c's current conn while holding c.mu. It returns whatever
// closeLocked returns: errNilPConn if there is no current conn, otherwise
// the result of closing it.
func (c *RebindingUDPConn) Close() error {
	c.mu.Lock()
	// Deferred so that mu is released even if closing the conn panics.
	defer c.mu.Unlock()
	return c.closeLocked()
}

// closeLocked closes c.pconn and resets the recorded port to 0. The caller
// must hold c.mu. If c.pconn is nil it returns errNilPConn and leaves the
// port untouched. c.pconn itself is not cleared.
func (c *RebindingUDPConn) closeLocked() error {
	pc := c.pconn
	if pc == nil {
		return errNilPConn
	}
	c.port = 0
	return pc.Close()
}

// writeToUDPAddrPortWithInitPconn writes b to addr, starting with pconn.
//
// If a write fails and pconn is no longer c's current conn (that is, a rebind
// happened meanwhile), the write is retried on the conn loaded from
// c.pconnAtomic. Otherwise the result of the write is returned as is.
func (c *RebindingUDPConn) writeToUDPAddrPortWithInitPconn(pconn nettype.PacketConn, b []byte, addr netip.AddrPort) (int, error) {
	for {
		n, err := pconn.WriteToUDPAddrPort(b, addr)
		if err == nil {
			return n, nil
		}
		if pconn == c.currentConn() {
			// No rebind happened; the failure is genuine.
			return n, err
		}
		// The conn was swapped out from under us; try again on the new one.
		pconn = *c.pconnAtomic.Load()
	}
}

// WriteToUDPAddrPort writes b to addr and returns the byte count and error
// reported by the underlying conn. The write starts on the conn currently
// published in c.pconnAtomic and is retried by
// writeToUDPAddrPortWithInitPconn if a rebind is detected after a failure.
func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
	pconn := *c.pconnAtomic.Load()
	return c.writeToUDPAddrPortWithInitPconn(pconn, b, addr)
}

// SyscallConn returns the syscall.RawConn of c's current conn. It holds c.mu
// for the duration of the call and returns errUnsupportedConnType if the
// current conn does not implement syscall.Conn.
func (c *RebindingUDPConn) SyscallConn() (syscall.RawConn, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if sc, ok := c.pconn.(syscall.Conn); ok {
		return sc.SyscallConn()
	}
	return nil, errUnsupportedConnType
}