tsnet,wgengine/netstack: add ListenPacket and tests

This adds a new ListenPacket function on tsnet.Server
which acts mostly like `net.ListenPacket`.

Unlike `Server.Listen`, this requires listening on a
specific IP and does not automatically listen on both
V4 and V6 addresses of the Server when the IP is unspecified.

To test this, it also adds UDP support to tsdial.Dialer.UserDial
and plumbs it through the localapi. Then an associated test
to make sure the UDP functionality works from both sides.

Updates #12182

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2024-05-18 14:37:37 -07:00 committed by Maisem Ali
parent bcb55fdeb6
commit 42cfbf427c
7 changed files with 236 additions and 35 deletions

View File

@ -778,6 +778,17 @@ func (lc *LocalClient) SetDNS(ctx context.Context, name, value string) error {
// //
// The ctx is only used for the duration of the call, not the lifetime of the net.Conn. // The ctx is only used for the duration of the call, not the lifetime of the net.Conn.
func (lc *LocalClient) DialTCP(ctx context.Context, host string, port uint16) (net.Conn, error) { func (lc *LocalClient) DialTCP(ctx context.Context, host string, port uint16) (net.Conn, error) {
return lc.UserDial(ctx, "tcp", host, port)
}
// UserDial connects to the host's port via Tailscale for the given network.
//
// The host may be a base DNS name (resolved from the netmap inside tailscaled),
// a FQDN, or an IP address.
//
// The ctx is only used for the duration of the call, not the lifetime of the
// net.Conn.
func (lc *LocalClient) UserDial(ctx context.Context, network, host string, port uint16) (net.Conn, error) {
connCh := make(chan net.Conn, 1) connCh := make(chan net.Conn, 1)
trace := httptrace.ClientTrace{ trace := httptrace.ClientTrace{
GotConn: func(info httptrace.GotConnInfo) { GotConn: func(info httptrace.GotConnInfo) {
@ -794,6 +805,7 @@ func (lc *LocalClient) DialTCP(ctx context.Context, host string, port uint16) (n
"Connection": []string{"upgrade"}, "Connection": []string{"upgrade"},
"Dial-Host": []string{host}, "Dial-Host": []string{host},
"Dial-Port": []string{fmt.Sprint(port)}, "Dial-Port": []string{fmt.Sprint(port)},
"Dial-Network": []string{network},
} }
res, err := lc.DoLocalRequest(req) res, err := lc.DoLocalRequest(req)
if err != nil { if err != nil {

View File

@ -548,14 +548,25 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logID logid.PublicID
return ok return ok
} }
dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
// Note: don't just return ns.DialContextTCP or we'll // Note: don't just return ns.DialContextTCP or we'll return
// return an interface containing a nil pointer. // *gonet.TCPConn(nil) instead of a nil interface which trips up
// callers.
tcpConn, err := ns.DialContextTCP(ctx, dst) tcpConn, err := ns.DialContextTCP(ctx, dst)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tcpConn, nil return tcpConn, nil
} }
dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
// Note: don't just return ns.DialContextUDP or we'll return
// *gonet.UDPConn(nil) instead of a nil interface which trips up
// callers.
udpConn, err := ns.DialContextUDP(ctx, dst)
if err != nil {
return nil, err
}
return udpConn, nil
}
} }
if socksListener != nil || httpProxyListener != nil { if socksListener != nil || httpProxyListener != nil {
var addrs []string var addrs []string

View File

@ -6,6 +6,7 @@
import ( import (
"bytes" "bytes"
"cmp"
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
@ -1939,8 +1940,10 @@ func (h *Handler) serveDial(w http.ResponseWriter, r *http.Request) {
return return
} }
network := cmp.Or(r.Header.Get("Dial-Network"), "tcp")
addr := net.JoinHostPort(hostStr, portStr) addr := net.JoinHostPort(hostStr, portStr)
outConn, err := h.b.Dialer().UserDial(r.Context(), "tcp", addr) outConn, err := h.b.Dialer().UserDial(r.Context(), network, addr)
if err != nil { if err != nil {
http.Error(w, "dial failure: "+err.Error(), http.StatusBadGateway) http.Error(w, "dial failure: "+err.Error(), http.StatusBadGateway)
return return

View File

@ -59,6 +59,10 @@ type Dialer struct {
// If nil, it's not used. // If nil, it's not used.
NetstackDialTCP func(context.Context, netip.AddrPort) (net.Conn, error) NetstackDialTCP func(context.Context, netip.AddrPort) (net.Conn, error)
// NetstackDialUDP dials the provided IPPort using netstack.
// If nil, it's not used.
NetstackDialUDP func(context.Context, netip.AddrPort) (net.Conn, error)
peerClientOnce sync.Once peerClientOnce sync.Once
peerClient *http.Client peerClient *http.Client
@ -403,9 +407,12 @@ func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn,
return nil, err return nil, err
} }
if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) { if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) {
if d.NetstackDialTCP == nil { if d.NetstackDialTCP == nil || d.NetstackDialUDP == nil {
return nil, errors.New("Dialer not initialized correctly") return nil, errors.New("Dialer not initialized correctly")
} }
if strings.HasPrefix(network, "udp") {
return d.NetstackDialUDP(ctx, ipp)
}
return d.NetstackDialTCP(ctx, ipp) return d.NetstackDialTCP(ctx, ipp)
} }

View File

@ -562,14 +562,25 @@ func (s *Server) start() (reterr error) {
return ok return ok
} }
s.dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { s.dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
// Note: don't just return ns.DialContextTCP or we'll // Note: don't just return ns.DialContextTCP or we'll return
// return an interface containing a nil pointer. // *gonet.TCPConn(nil) instead of a nil interface which trips up
// callers.
tcpConn, err := ns.DialContextTCP(ctx, dst) tcpConn, err := ns.DialContextTCP(ctx, dst)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tcpConn, nil return tcpConn, nil
} }
s.dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
// Note: don't just return ns.DialContextUDP or we'll return
// *gonet.UDPConn(nil) instead of a nil interface which trips up
// callers.
udpConn, err := ns.DialContextUDP(ctx, dst)
if err != nil {
return nil, err
}
return udpConn, nil
}
if s.Store == nil { if s.Store == nil {
stateFile := filepath.Join(s.rootPath, "tailscaled.state") stateFile := filepath.Join(s.rootPath, "tailscaled.state")
@ -908,6 +919,34 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
return s.listen(network, addr, listenOnTailnet) return s.listen(network, addr, listenOnTailnet)
} }
// ListenPacket announces on the Tailscale network.
//
// The network must be "udp", "udp4" or "udp6". The addr must be of the form
// "ip:port" (or "[ip]:port") where ip is a valid IPv4 or IPv6 address
// corresponding to "udp4" or "udp6" respectively. IP must be specified.
//
// If s has not been started yet, it will be started.
func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) {
ap, err := resolveListenAddr(network, addr)
if err != nil {
return nil, err
}
if !ap.Addr().IsValid() {
return nil, fmt.Errorf("tsnet.ListenPacket(%q, %q): address must be a valid IP", network, addr)
}
if network == "udp" {
if ap.Addr().Is4() {
network = "udp4"
} else {
network = "udp6"
}
}
if err := s.Start(); err != nil {
return nil, err
}
return s.netstack.ListenPacket(network, ap.String())
}
// ListenTLS announces only on the Tailscale network. // ListenTLS announces only on the Tailscale network.
// It returns a TLS listener wrapping the tsnet listener. // It returns a TLS listener wrapping the tsnet listener.
// It will start the server if it has not been started yet. // It will start the server if it has not been started yet.
@ -1070,50 +1109,65 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L
listenOnBoth = listenOn("listen-on-both") listenOnBoth = listenOn("listen-on-both")
) )
func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, error) { // resolveListenAddr resolves a network and address into a netip.AddrPort. The
switch network { // returned netip.AddrPort.Addr will be the zero value if the address is empty.
case "", "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": // The port must be a valid port number. The caller is responsible for checking
default: // the network and address are valid.
return nil, errors.New("unsupported network type") //
} // It resolves well-known port names and validates the address is a valid IP
// literal for the network.
func resolveListenAddr(network, addr string) (netip.AddrPort, error) {
var zero netip.AddrPort
host, portStr, err := net.SplitHostPort(addr) host, portStr, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("tsnet: %w", err) return zero, fmt.Errorf("tsnet: %w", err)
} }
port, err := net.LookupPort(network, portStr) port, err := net.LookupPort(network, portStr)
if err != nil || port < 0 || port > math.MaxUint16 { if err != nil || port < 0 || port > math.MaxUint16 {
// LookupPort returns an error on out of range values so the bounds // LookupPort returns an error on out of range values so the bounds
// checks on port should be unnecessary, but harmless. If they do // checks on port should be unnecessary, but harmless. If they do
// match, worst case this error message says "invalid port: <nil>". // match, worst case this error message says "invalid port: <nil>".
return nil, fmt.Errorf("invalid port: %w", err) return zero, fmt.Errorf("invalid port: %w", err)
}
var bindHostOrZero netip.Addr
if host != "" {
bindHostOrZero, err = netip.ParseAddr(host)
if err != nil {
return nil, fmt.Errorf("invalid Listen addr %q; host part must be empty or IP literal", host)
}
if strings.HasSuffix(network, "4") && !bindHostOrZero.Is4() {
return nil, fmt.Errorf("invalid non-IPv4 addr %v for network %q", host, network)
}
if strings.HasSuffix(network, "6") && !bindHostOrZero.Is6() {
return nil, fmt.Errorf("invalid non-IPv6 addr %v for network %q", host, network)
} }
if host == "" {
return netip.AddrPortFrom(netip.Addr{}, uint16(port)), nil
} }
bindHostOrZero, err := netip.ParseAddr(host)
if err != nil {
return zero, fmt.Errorf("invalid Listen addr %q; host part must be empty or IP literal", host)
}
if strings.HasSuffix(network, "4") && !bindHostOrZero.Is4() {
return zero, fmt.Errorf("invalid non-IPv4 addr %v for network %q", host, network)
}
if strings.HasSuffix(network, "6") && !bindHostOrZero.Is6() {
return zero, fmt.Errorf("invalid non-IPv6 addr %v for network %q", host, network)
}
return netip.AddrPortFrom(bindHostOrZero, uint16(port)), nil
}
func (s *Server) listen(network, addr string, lnOn listenOn) (*listener, error) {
switch network {
case "", "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
default:
return nil, errors.New("unsupported network type")
}
host, err := resolveListenAddr(network, addr)
if err != nil {
return nil, err
}
if err := s.Start(); err != nil { if err := s.Start(); err != nil {
return nil, err return nil, err
} }
var keys []listenKey var keys []listenKey
switch lnOn { switch lnOn {
case listenOnTailnet: case listenOnTailnet:
keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), false}) keys = append(keys, listenKey{network, host.Addr(), host.Port(), false})
case listenOnFunnel: case listenOnFunnel:
keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), true}) keys = append(keys, listenKey{network, host.Addr(), host.Port(), true})
case listenOnBoth: case listenOnBoth:
keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), false}) keys = append(keys, listenKey{network, host.Addr(), host.Port(), false})
keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), true}) keys = append(keys, listenKey{network, host.Addr(), host.Port(), true})
} }
ln := &listener{ ln := &listener{

View File

@ -745,3 +745,73 @@ func TestCapturePcap(t *testing.T) {
t.Errorf("s2 pcap file size = %d, want > pcapHeaderSize(%d)", got, pcapHeaderSize) t.Errorf("s2 pcap file size = %d, want > pcapHeaderSize(%d)", got, pcapHeaderSize)
} }
} }
func TestUDPConn(t *testing.T) {
tstest.ResourceCheck(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
controlURL, _ := startControl(t)
s1, s1ip, _ := startServer(t, ctx, controlURL, "s1")
s2, s2ip, _ := startServer(t, ctx, controlURL, "s2")
lc2, err := s2.LocalClient()
if err != nil {
t.Fatal(err)
}
// ping to make sure the connection is up.
res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP)
if err != nil {
t.Fatal(err)
}
t.Logf("ping success: %#+v", res)
pc := must.Get(s1.ListenPacket("udp", fmt.Sprintf("%s:8081", s1ip)))
defer pc.Close()
// Dial to s1 from s2
w, err := s2.Dial(ctx, "udp", fmt.Sprintf("%s:8081", s1ip))
if err != nil {
t.Fatal(err)
}
defer w.Close()
// Send a packet from s2 to s1
want := "hello"
if _, err := io.WriteString(w, want); err != nil {
t.Fatal(err)
}
// Receive the packet on s1
got := make([]byte, 1024)
n, from, err := pc.ReadFrom(got)
if err != nil {
t.Fatal(err)
}
got = got[:n]
t.Logf("got: %q", got)
if string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
if from.(*net.UDPAddr).AddrPort().Addr() != s2ip {
t.Errorf("got from %v, want %v", from, s2ip)
}
// Write a response back to s2
if _, err := pc.WriteTo([]byte("world"), from); err != nil {
t.Fatal(err)
}
// Receive the response on s2
got = make([]byte, 1024)
n, err = w.Read(got)
if err != nil {
t.Fatal(err)
}
got = got[:n]
t.Logf("got: %q", got)
if string(got) != "world" {
t.Errorf("got %q, want world", got)
}
}

View File

@ -1326,6 +1326,50 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.
return return
} }
// ListenPacket listens for incoming packets for the given network and address.
// Address must be of the form "ip:port" or "[ip]:port".
//
// As of 2024-05-18, only udp4 and udp6 are supported.
func (ns *Impl) ListenPacket(network, address string) (net.PacketConn, error) {
ap, err := netip.ParseAddrPort(address)
if err != nil {
return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %v", address, err)
}
var networkProto tcpip.NetworkProtocolNumber
switch network {
case "udp":
return nil, fmt.Errorf("netstack: udp not supported; use udp4 or udp6")
case "udp4":
networkProto = ipv4.ProtocolNumber
if !ap.Addr().Is4() {
return nil, fmt.Errorf("netstack: udp4 requires an IPv4 address")
}
case "udp6":
networkProto = ipv6.ProtocolNumber
if !ap.Addr().Is6() {
return nil, fmt.Errorf("netstack: udp6 requires an IPv6 address")
}
default:
return nil, fmt.Errorf("netstack: unsupported network %q", network)
}
var wq waiter.Queue
ep, nserr := ns.ipstack.NewEndpoint(udp.ProtocolNumber, networkProto, &wq)
if nserr != nil {
return nil, fmt.Errorf("netstack: NewEndpoint: %v", nserr)
}
localAddress := tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.AddrFromSlice(ap.Addr().AsSlice()),
Port: ap.Port(),
}
if err := ep.Bind(localAddress); err != nil {
ep.Close()
return nil, fmt.Errorf("netstack: Bind(%v): %v", localAddress, err)
}
return gonet.NewUDPConn(&wq, ep), nil
}
func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
sess := r.ID() sess := r.ID()
if debugNetstack() { if debugNetstack() {