diff --git a/client/tailscale/localclient.go b/client/tailscale/localclient.go index 32ab74e84..7062dbf3a 100644 --- a/client/tailscale/localclient.go +++ b/client/tailscale/localclient.go @@ -778,6 +778,17 @@ func (lc *LocalClient) SetDNS(ctx context.Context, name, value string) error { // // The ctx is only used for the duration of the call, not the lifetime of the net.Conn. func (lc *LocalClient) DialTCP(ctx context.Context, host string, port uint16) (net.Conn, error) { + return lc.UserDial(ctx, "tcp", host, port) +} + +// UserDial connects to the host's port via Tailscale for the given network. +// +// The host may be a base DNS name (resolved from the netmap inside tailscaled), +// a FQDN, or an IP address. +// +// The ctx is only used for the duration of the call, not the lifetime of the +// net.Conn. +func (lc *LocalClient) UserDial(ctx context.Context, network, host string, port uint16) (net.Conn, error) { connCh := make(chan net.Conn, 1) trace := httptrace.ClientTrace{ GotConn: func(info httptrace.GotConnInfo) { @@ -790,10 +801,11 @@ func (lc *LocalClient) DialTCP(ctx context.Context, host string, port uint16) (n return nil, err } req.Header = http.Header{ - "Upgrade": []string{"ts-dial"}, - "Connection": []string{"upgrade"}, - "Dial-Host": []string{host}, - "Dial-Port": []string{fmt.Sprint(port)}, + "Upgrade": []string{"ts-dial"}, + "Connection": []string{"upgrade"}, + "Dial-Host": []string{host}, + "Dial-Port": []string{fmt.Sprint(port)}, + "Dial-Network": []string{network}, } res, err := lc.DoLocalRequest(req) if err != nil { diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index 77a595dac..713a8d441 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -548,14 +548,25 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logID logid.PublicID return ok } dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { - // Note: don't just return ns.DialContextTCP or we'll - // return an interface containing a nil pointer. + // Note: don't just return ns.DialContextTCP or we'll return + // *gonet.TCPConn(nil) instead of a nil interface which trips up + // callers. tcpConn, err := ns.DialContextTCP(ctx, dst) if err != nil { return nil, err } return tcpConn, nil } + dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { + // Note: don't just return ns.DialContextUDP or we'll return + // *gonet.UDPConn(nil) instead of a nil interface which trips up + // callers. + udpConn, err := ns.DialContextUDP(ctx, dst) + if err != nil { + return nil, err + } + return udpConn, nil + } } if socksListener != nil || httpProxyListener != nil { var addrs []string diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index a822aad69..1ea4743ac 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -6,6 +6,7 @@ package localapi import ( "bytes" + "cmp" "context" "crypto/sha256" "encoding/hex" @@ -1939,8 +1940,10 @@ func (h *Handler) serveDial(w http.ResponseWriter, r *http.Request) { return } + network := cmp.Or(r.Header.Get("Dial-Network"), "tcp") + addr := net.JoinHostPort(hostStr, portStr) - outConn, err := h.b.Dialer().UserDial(r.Context(), "tcp", addr) + outConn, err := h.b.Dialer().UserDial(r.Context(), network, addr) if err != nil { http.Error(w, "dial failure: "+err.Error(), http.StatusBadGateway) return diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index 42433b871..d69075318 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -59,6 +59,10 @@ type Dialer struct { // If nil, it's not used. NetstackDialTCP func(context.Context, netip.AddrPort) (net.Conn, error) + // NetstackDialUDP dials the provided IPPort using netstack. + // If nil, it's not used. + NetstackDialUDP func(context.Context, netip.AddrPort) (net.Conn, error) + peerClientOnce sync.Once peerClient *http.Client @@ -403,9 +407,12 @@ func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, return nil, err } if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) { - if d.NetstackDialTCP == nil { + if d.NetstackDialTCP == nil || d.NetstackDialUDP == nil { return nil, errors.New("Dialer not initialized correctly") } + if strings.HasPrefix(network, "udp") { + return d.NetstackDialUDP(ctx, ipp) + } return d.NetstackDialTCP(ctx, ipp) } diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index a200ecbec..b7973a233 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -562,14 +562,25 @@ func (s *Server) start() (reterr error) { return ok } s.dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { - // Note: don't just return ns.DialContextTCP or we'll - // return an interface containing a nil pointer. + // Note: don't just return ns.DialContextTCP or we'll return + // *gonet.TCPConn(nil) instead of a nil interface which trips up + // callers. tcpConn, err := ns.DialContextTCP(ctx, dst) if err != nil { return nil, err } return tcpConn, nil } + s.dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { + // Note: don't just return ns.DialContextUDP or we'll return + // *gonet.UDPConn(nil) instead of a nil interface which trips up + // callers. + udpConn, err := ns.DialContextUDP(ctx, dst) + if err != nil { + return nil, err + } + return udpConn, nil + } if s.Store == nil { stateFile := filepath.Join(s.rootPath, "tailscaled.state") @@ -908,6 +919,34 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) { return s.listen(network, addr, listenOnTailnet) } +// ListenPacket announces on the Tailscale network. +// +// The network must be "udp", "udp4" or "udp6". The addr must be of the form +// "ip:port" (or "[ip]:port") where ip is a valid IPv4 or IPv6 address +// corresponding to "udp4" or "udp6" respectively. IP must be specified. +// +// If s has not been started yet, it will be started. +func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) { + ap, err := resolveListenAddr(network, addr) + if err != nil { + return nil, err + } + if !ap.Addr().IsValid() { + return nil, fmt.Errorf("tsnet.ListenPacket(%q, %q): address must be a valid IP", network, addr) + } + if network == "udp" { + if ap.Addr().Is4() { + network = "udp4" + } else { + network = "udp6" + } + } + if err := s.Start(); err != nil { + return nil, err + } + return s.netstack.ListenPacket(network, ap.String()) +} + // ListenTLS announces only on the Tailscale network. // It returns a TLS listener wrapping the tsnet listener. // It will start the server if it has not been started yet. @@ -1070,50 +1109,65 @@ const ( listenOnBoth = listenOn("listen-on-both") ) -func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, error) { - switch network { - case "", "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": - default: - return nil, errors.New("unsupported network type") - } +// resolveListenAddr resolves a network and address into a netip.AddrPort. The +// returned netip.AddrPort.Addr will be the zero value if the address is empty. +// The port must be a valid port number. The caller is responsible for checking +// the network and address are valid. +// +// It resolves well-known port names and validates the address is a valid IP +// literal for the network. +func resolveListenAddr(network, addr string) (netip.AddrPort, error) { + var zero netip.AddrPort host, portStr, err := net.SplitHostPort(addr) if err != nil { - return nil, fmt.Errorf("tsnet: %w", err) + return zero, fmt.Errorf("tsnet: %w", err) } port, err := net.LookupPort(network, portStr) if err != nil || port < 0 || port > math.MaxUint16 { // LookupPort returns an error on out of range values so the bounds // checks on port should be unnecessary, but harmless. If they do // match, worst case this error message says "invalid port: ". - return nil, fmt.Errorf("invalid port: %w", err) + return zero, fmt.Errorf("invalid port: %w", err) } - var bindHostOrZero netip.Addr - if host != "" { - bindHostOrZero, err = netip.ParseAddr(host) - if err != nil { - return nil, fmt.Errorf("invalid Listen addr %q; host part must be empty or IP literal", host) - } - if strings.HasSuffix(network, "4") && !bindHostOrZero.Is4() { - return nil, fmt.Errorf("invalid non-IPv4 addr %v for network %q", host, network) - } - if strings.HasSuffix(network, "6") && !bindHostOrZero.Is6() { - return nil, fmt.Errorf("invalid non-IPv6 addr %v for network %q", host, network) - } + if host == "" { + return netip.AddrPortFrom(netip.Addr{}, uint16(port)), nil } + bindHostOrZero, err := netip.ParseAddr(host) + if err != nil { + return zero, fmt.Errorf("invalid Listen addr %q; host part must be empty or IP literal", host) + } + if strings.HasSuffix(network, "4") && !bindHostOrZero.Is4() { + return zero, fmt.Errorf("invalid non-IPv4 addr %v for network %q", host, network) + } + if strings.HasSuffix(network, "6") && !bindHostOrZero.Is6() { + return zero, fmt.Errorf("invalid non-IPv6 addr %v for network %q", host, network) + } + return netip.AddrPortFrom(bindHostOrZero, uint16(port)), nil +} + +func (s *Server) listen(network, addr string, lnOn listenOn) (*listener, error) { + switch network { + case "", "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": + default: + return nil, errors.New("unsupported network type") + } + host, err := resolveListenAddr(network, addr) + if err != nil { + return nil, err + } if err := s.Start(); err != nil { return nil, err } - var keys []listenKey switch lnOn { case listenOnTailnet: - keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), false}) + keys = append(keys, listenKey{network, host.Addr(), host.Port(), false}) case listenOnFunnel: - keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), true}) + keys = append(keys, listenKey{network, host.Addr(), host.Port(), true}) case listenOnBoth: - keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), false}) - keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), true}) + keys = append(keys, listenKey{network, host.Addr(), host.Port(), false}) + keys = append(keys, listenKey{network, host.Addr(), host.Port(), true}) } ln := &listener{ diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index e6408b06b..9589b4796 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -745,3 +745,73 @@ func TestCapturePcap(t *testing.T) { t.Errorf("s2 pcap file size = %d, want > pcapHeaderSize(%d)", got, pcapHeaderSize) } } + +func TestUDPConn(t *testing.T) { + tstest.ResourceCheck(t) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + controlURL, _ := startControl(t) + s1, s1ip, _ := startServer(t, ctx, controlURL, "s1") + s2, s2ip, _ := startServer(t, ctx, controlURL, "s2") + + lc2, err := s2.LocalClient() + if err != nil { + t.Fatal(err) + } + + // ping to make sure the connection is up. + res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) + if err != nil { + t.Fatal(err) + } + t.Logf("ping success: %#+v", res) + + pc := must.Get(s1.ListenPacket("udp", fmt.Sprintf("%s:8081", s1ip))) + defer pc.Close() + + // Dial to s1 from s2 + w, err := s2.Dial(ctx, "udp", fmt.Sprintf("%s:8081", s1ip)) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + // Send a packet from s2 to s1 + want := "hello" + if _, err := io.WriteString(w, want); err != nil { + t.Fatal(err) + } + + // Receive the packet on s1 + got := make([]byte, 1024) + n, from, err := pc.ReadFrom(got) + if err != nil { + t.Fatal(err) + } + got = got[:n] + t.Logf("got: %q", got) + if string(got) != want { + t.Errorf("got %q, want %q", got, want) + } + if from.(*net.UDPAddr).AddrPort().Addr() != s2ip { + t.Errorf("got from %v, want %v", from, s2ip) + } + + // Write a response back to s2 + if _, err := pc.WriteTo([]byte("world"), from); err != nil { + t.Fatal(err) + } + + // Receive the response on s2 + got = make([]byte, 1024) + n, err = w.Read(got) + if err != nil { + t.Fatal(err) + } + got = got[:n] + t.Logf("got: %q", got) + if string(got) != "world" { + t.Errorf("got %q, want world", got) + } +} diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 0b1b736eb..fbbcce3a9 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -1326,6 +1326,50 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet. return } +// ListenPacket listens for incoming packets for the given network and address. +// Address must be of the form "ip:port" or "[ip]:port". +// +// As of 2024-05-18, only udp4 and udp6 are supported. +func (ns *Impl) ListenPacket(network, address string) (net.PacketConn, error) { + ap, err := netip.ParseAddrPort(address) + if err != nil { + return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %v", address, err) + } + + var networkProto tcpip.NetworkProtocolNumber + switch network { + case "udp": + return nil, fmt.Errorf("netstack: udp not supported; use udp4 or udp6") + case "udp4": + networkProto = ipv4.ProtocolNumber + if !ap.Addr().Is4() { + return nil, fmt.Errorf("netstack: udp4 requires an IPv4 address") + } + case "udp6": + networkProto = ipv6.ProtocolNumber + if !ap.Addr().Is6() { + return nil, fmt.Errorf("netstack: udp6 requires an IPv6 address") + } + default: + return nil, fmt.Errorf("netstack: unsupported network %q", network) + } + var wq waiter.Queue + ep, nserr := ns.ipstack.NewEndpoint(udp.ProtocolNumber, networkProto, &wq) + if nserr != nil { + return nil, fmt.Errorf("netstack: NewEndpoint: %v", nserr) + } + localAddress := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(ap.Addr().AsSlice()), + Port: ap.Port(), + } + if err := ep.Bind(localAddress); err != nil { + ep.Close() + return nil, fmt.Errorf("netstack: Bind(%v): %v", localAddress, err) + } + return gonet.NewUDPConn(&wq, ep), nil +} + func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { sess := r.ID() if debugNetstack() {