mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-30 05:25:35 +00:00
169 lines
4.7 KiB
Go
169 lines
4.7 KiB
Go
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||
|
|
||
|
package magicsock
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"net"
|
||
|
"net/netip"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
|
||
|
"golang.org/x/net/ipv6"
|
||
|
"tailscale.com/net/netaddr"
|
||
|
"tailscale.com/types/nettype"
|
||
|
)
|
||
|
|
||
|
// RebindingUDPConn is a UDP socket that can be re-bound.
|
||
|
// Unix has no notion of re-binding a socket, so we swap it out for a new one.
|
||
|
type RebindingUDPConn struct {
|
||
|
// pconnAtomic is a pointer to the value stored in pconn, but doesn't
|
||
|
// require acquiring mu. It's used for reads/writes and only upon failure
|
||
|
// do the reads/writes then check pconn (after acquiring mu) to see if
|
||
|
// there's been a rebind meanwhile.
|
||
|
// pconn isn't really needed, but makes some of the code simpler
|
||
|
// to keep it distinct.
|
||
|
// Neither is expected to be nil, sockets are bound on creation.
|
||
|
pconnAtomic atomic.Pointer[nettype.PacketConn]
|
||
|
|
||
|
mu sync.Mutex // held while changing pconn (and pconnAtomic)
|
||
|
pconn nettype.PacketConn
|
||
|
port uint16
|
||
|
}
|
||
|
|
||
|
// setConnLocked sets the provided nettype.PacketConn. It should be called only
|
||
|
// after acquiring RebindingUDPConn.mu. It upgrades the provided
|
||
|
// nettype.PacketConn to a *batchingUDPConn when appropriate. This upgrade
|
||
|
// is intentionally pushed closest to where read/write ops occur in order to
|
||
|
// avoid disrupting surrounding code that assumes nettype.PacketConn is a
|
||
|
// *net.UDPConn.
|
||
|
func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) {
|
||
|
upc := tryUpgradeToBatchingUDPConn(p, network, batchSize)
|
||
|
c.pconn = upc
|
||
|
c.pconnAtomic.Store(&upc)
|
||
|
c.port = uint16(c.localAddrLocked().Port)
|
||
|
}
|
||
|
|
||
|
// currentConn returns c's current pconn, acquiring c.mu in the process.
|
||
|
func (c *RebindingUDPConn) currentConn() nettype.PacketConn {
|
||
|
c.mu.Lock()
|
||
|
defer c.mu.Unlock()
|
||
|
return c.pconn
|
||
|
}
|
||
|
|
||
|
func (c *RebindingUDPConn) readFromWithInitPconn(pconn nettype.PacketConn, b []byte) (int, netip.AddrPort, error) {
|
||
|
for {
|
||
|
n, addr, err := pconn.ReadFromUDPAddrPort(b)
|
||
|
if err != nil && pconn != c.currentConn() {
|
||
|
pconn = *c.pconnAtomic.Load()
|
||
|
continue
|
||
|
}
|
||
|
return n, addr, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ReadFromUDPAddrPort reads a packet from c into b.
|
||
|
// It returns the number of bytes copied and the source address.
|
||
|
func (c *RebindingUDPConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, error) {
|
||
|
return c.readFromWithInitPconn(*c.pconnAtomic.Load(), b)
|
||
|
}
|
||
|
|
||
|
// WriteBatchTo writes buffs to addr.
|
||
|
func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error {
|
||
|
for {
|
||
|
pconn := *c.pconnAtomic.Load()
|
||
|
b, ok := pconn.(*batchingUDPConn)
|
||
|
if !ok {
|
||
|
for _, buf := range buffs {
|
||
|
_, err := c.writeToUDPAddrPortWithInitPconn(pconn, buf, addr)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
err := b.WriteBatchTo(buffs, addr)
|
||
|
if err != nil {
|
||
|
if pconn != c.currentConn() {
|
||
|
continue
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ReadBatch reads messages from c into msgs. It returns the number of messages
|
||
|
// the caller should evaluate for nonzero len, as a zero len message may fall
|
||
|
// on either side of a nonzero.
|
||
|
func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) {
|
||
|
for {
|
||
|
pconn := *c.pconnAtomic.Load()
|
||
|
b, ok := pconn.(*batchingUDPConn)
|
||
|
if !ok {
|
||
|
n, ap, err := c.readFromWithInitPconn(pconn, msgs[0].Buffers[0])
|
||
|
if err == nil {
|
||
|
msgs[0].N = n
|
||
|
msgs[0].Addr = net.UDPAddrFromAddrPort(netaddr.Unmap(ap))
|
||
|
return 1, nil
|
||
|
}
|
||
|
return 0, err
|
||
|
}
|
||
|
n, err := b.ReadBatch(msgs, flags)
|
||
|
if err != nil && pconn != c.currentConn() {
|
||
|
continue
|
||
|
}
|
||
|
return n, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (c *RebindingUDPConn) Port() uint16 {
|
||
|
c.mu.Lock()
|
||
|
defer c.mu.Unlock()
|
||
|
return c.port
|
||
|
}
|
||
|
|
||
|
func (c *RebindingUDPConn) LocalAddr() *net.UDPAddr {
|
||
|
c.mu.Lock()
|
||
|
defer c.mu.Unlock()
|
||
|
return c.localAddrLocked()
|
||
|
}
|
||
|
|
||
|
func (c *RebindingUDPConn) localAddrLocked() *net.UDPAddr {
|
||
|
return c.pconn.LocalAddr().(*net.UDPAddr)
|
||
|
}
|
||
|
|
||
|
// errNilPConn is returned by RebindingUDPConn.Close when there is no current pconn.
|
||
|
// It is for internal use only and should not be returned to users.
|
||
|
var errNilPConn = errors.New("nil pconn")
|
||
|
|
||
|
func (c *RebindingUDPConn) Close() error {
|
||
|
c.mu.Lock()
|
||
|
defer c.mu.Unlock()
|
||
|
return c.closeLocked()
|
||
|
}
|
||
|
|
||
|
func (c *RebindingUDPConn) closeLocked() error {
|
||
|
if c.pconn == nil {
|
||
|
return errNilPConn
|
||
|
}
|
||
|
c.port = 0
|
||
|
return c.pconn.Close()
|
||
|
}
|
||
|
|
||
|
func (c *RebindingUDPConn) writeToUDPAddrPortWithInitPconn(pconn nettype.PacketConn, b []byte, addr netip.AddrPort) (int, error) {
|
||
|
for {
|
||
|
n, err := pconn.WriteToUDPAddrPort(b, addr)
|
||
|
if err != nil && pconn != c.currentConn() {
|
||
|
pconn = *c.pconnAtomic.Load()
|
||
|
continue
|
||
|
}
|
||
|
return n, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
|
||
|
return c.writeToUDPAddrPortWithInitPconn(*c.pconnAtomic.Load(), b, addr)
|
||
|
}
|