diff --git a/cmd/stunstamp/stunstamp.go b/cmd/stunstamp/stunstamp.go index e2b034e32..e01f3ac92 100644 --- a/cmd/stunstamp/stunstamp.go +++ b/cmd/stunstamp/stunstamp.go @@ -1,13 +1,14 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// The stunstamp binary measures STUN round-trip latency with DERPs. +// The stunstamp binary measures round-trip latency with DERPs. package main import ( "bytes" "cmp" "context" + "crypto/tls" "encoding/json" "errors" "flag" @@ -31,8 +32,10 @@ "github.com/golang/snappy" "github.com/prometheus/prometheus/prompb" + "github.com/tcnksm/go-httpstat" "tailscale.com/logtail/backoff" "tailscale.com/net/stun" + "tailscale.com/net/tcpinfo" "tailscale.com/tailcfg" ) @@ -44,6 +47,7 @@ flagInstance = flag.String("instance", "", "instance label value; defaults to hostname if unspecified") flagSTUNDstPorts = flag.String("stun-dst-ports", "", "comma-separated list of STUN destination ports to monitor") flagHTTPSDstPorts = flag.String("https-dst-ports", "", "comma-separated list of HTTPS destination ports to monitor") + flagTCPDstPorts = flag.String("tcp-dst-ports", "", "comma-separated list of TCP destination ports to monitor") flagICMP = flag.Bool("icmp", false, "probe ICMP") ) @@ -97,6 +101,7 @@ func (t timestampSource) String() string { protocolSTUN protocol = "stun" protocolICMP protocol = "icmp" protocolHTTPS protocol = "https" + protocolTCP protocol = "tcp" ) // resultKey contains the stable dimensions and their values for a given @@ -115,7 +120,188 @@ type result struct { rtt *time.Duration // nil signifies failure, e.g. timeout } -func measureSTUNRTT(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) { +type lportsPool struct { + sync.Mutex + ports []int +} + +func (l *lportsPool) get() int { + l.Lock() + defer l.Unlock() + ret := l.ports[0] + l.ports = append(l.ports[:0], l.ports[1:]...) 
+ return ret +} + +func (l *lportsPool) put(i int) { + l.Lock() + defer l.Unlock() + l.ports = append(l.ports, int(i)) +} + +var ( + lports *lportsPool +) + +const ( + lportPoolSize = 16000 + lportBase = 2048 +) + +func init() { + lports = &lportsPool{ + ports: make([]int, 0, lportPoolSize), + } + for i := lportBase; i < lportBase+lportPoolSize; i++ { + lports.ports = append(lports.ports, i) + } +} + +// lportForTCPConn satisfies io.ReadWriteCloser, but is really just used to pass +// around a persistent laddr for stableConn purposes. The underlying TCP +// connection is not created until measurement time as in some cases we need to +// measure dial time. +type lportForTCPConn int + +func (l *lportForTCPConn) Close() error { + if *l == 0 { + return nil + } + lports.put(int(*l)) + return nil +} + +func (l *lportForTCPConn) Write([]byte) (int, error) { + return 0, errors.New("unimplemented") +} + +func (l *lportForTCPConn) Read([]byte) (int, error) { + return 0, errors.New("unimplemented") +} + +func addrInUse(err error, lport *lportForTCPConn) bool { + if errors.Is(err, syscall.EADDRINUSE) { + old := int(*lport) + // abandon port, don't return it to pool + *lport = lportForTCPConn(lports.get()) // get a new port + log.Printf("EADDRINUSE: %v old: %d new: %d", err, old, *lport) + return true + } + return false +} + +func tcpDial(lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) { + for { + var opErr error + dialer := &net.Dialer{ + Timeout: time.Second * 2, + LocalAddr: &net.TCPAddr{ + Port: int(*lport), + }, + Control: func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + // we may restart faster than TIME_WAIT can clear + opErr = setSOReuseAddr(fd) + }) + }, + } + if opErr != nil { + panic(opErr) + } + tcpConn, err := dialer.Dial("tcp", dst.String()) + if err != nil { + if addrInUse(err, lport) { + continue + } + return nil, err + } + return tcpConn, nil + } +} + +type tempError struct { + error +} + +func (t 
tempError) Temporary() bool { + return true +} + +func measureTCPRTT(conn io.ReadWriteCloser, _ string, dst netip.AddrPort) (rtt time.Duration, err error) { + lport, ok := conn.(*lportForTCPConn) + if !ok { + return 0, fmt.Errorf("unexpected conn type: %T", conn) + } + tcpConn, err := tcpDial(lport, dst) + if err != nil { + return 0, tempError{err} + } + defer tcpConn.Close() + rtt, err = tcpinfo.RTT(tcpConn) + if err != nil { + return 0, tempError{err} + } + return rtt, nil +} + +func measureHTTPSRTT(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) { + lport, ok := conn.(*lportForTCPConn) + if !ok { + return 0, fmt.Errorf("unexpected conn type: %T", conn) + } + var httpResult httpstat.Result + ctx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &httpResult), time.Second*3) + defer cancel() + reqURL := "https://" + dst.String() + "/derp/latency-check" + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + return 0, err + } + client := &http.Client{} + tcpConn, err := tcpDial(lport, dst) + if err != nil { + return 0, tempError{err} + } + defer tcpConn.Close() + tlsConn := tls.Client(tcpConn, &tls.Config{ + ServerName: hostname, + }) + // Mirror client/netcheck behavior, which handshakes before handing the + // tlsConn over to the http.Client via http.Transport + err = tlsConn.Handshake() + if err != nil { + return 0, tempError{err} + } + tlsConnCh := make(chan net.Conn, 1) + tlsConnCh <- tlsConn + tr := &http.Transport{ + DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { + select { + case tlsConn := <-tlsConnCh: + return tlsConn, nil + default: + return nil, errors.New("unexpected second call of DialTLSContext") + } + }, + } + client.Transport = tr + resp, err := client.Do(req) + if err != nil { + return 0, tempError{err} + } + if resp.StatusCode/100 != 2 { + return 0, tempError{fmt.Errorf("unexpected status code: %d", 
resp.StatusCode)} + } + defer resp.Body.Close() + _, err = io.Copy(io.Discard, io.LimitReader(resp.Body, 8<<10)) + if err != nil { + return 0, tempError{err} + } + httpResult.End(time.Now()) + return httpResult.ServerProcessing, nil +} + +func measureSTUNRTT(conn io.ReadWriteCloser, _ string, dst netip.AddrPort) (rtt time.Duration, err error) { uconn, ok := conn.(*net.UDPConn) if !ok { return 0, fmt.Errorf("unexpected conn type: %T", conn) @@ -167,7 +353,7 @@ type nodeMeta struct { addr netip.Addr } -type measureFn func(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) +type measureFn func(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) // probe measures round trip time for the node described by meta over // conn against dstPort using fn. It may return a nil duration and nil error in @@ -180,7 +366,7 @@ func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (* } time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx - rtt, err := fn(conn, netip.AddrPortFrom(meta.addr, uint16(dstPort))) + rtt, err := fn(conn, meta.hostname, netip.AddrPortFrom(meta.addr, uint16(dstPort))) if err != nil { if isTemporaryOrTimeoutErr(err) { log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err) @@ -251,7 +437,7 @@ func nodeMetaFromDERPMap(dm *tailcfg.DERPMap, nodeMetaByAddr map[netip.Addr]node return stale, nil } -func newConn(source timestampSource, protocol protocol) (io.ReadWriteCloser, error) { +func newConn(source timestampSource, protocol protocol, stable connStability) (io.ReadWriteCloser, error) { switch protocol { case protocolSTUN: if source == timestampSourceKernel { @@ -263,8 +449,19 @@ func newConn(source timestampSource, protocol protocol) (io.ReadWriteCloser, err // TODO(jwhited): implement return nil, errors.New("unimplemented protocol") case protocolHTTPS: - // TODO(jwhited): implement - return nil, errors.New("unimplemented protocol") + 
localPort := 0 + if stable { + localPort = lports.get() + } + ret := lportForTCPConn(localPort) + return &ret, nil + case protocolTCP: + localPort := 0 + if stable { + localPort = lports.get() + } + ret := lportForTCPConn(localPort) + return &ret, nil } return nil, errors.New("unknown protocol") } @@ -279,19 +476,20 @@ func getStableConns(stableConns map[stableConnKey][2]io.ReadWriteCloser, addr ne if !protocolSupportsStableConn(protocol) { return [2]io.ReadWriteCloser{}, nil } - conns, ok := stableConns[stableConnKey{addr, protocol, dstPort}] + key := stableConnKey{addr, protocol, dstPort} + conns, ok := stableConns[key] if ok { return conns, nil } if protocolSupportsKernelTS(protocol) { - kconn, err := newConn(timestampSourceKernel, protocol) + kconn, err := newConn(timestampSourceKernel, protocol, stableConn) if err != nil { return conns, err } conns[timestampSourceKernel] = kconn } - uconn, err := newConn(timestampSourceUserspace, protocol) + uconn, err := newConn(timestampSourceUserspace, protocol, stableConn) if err != nil { if protocolSupportsKernelTS(protocol) { conns[timestampSourceKernel].Close() @@ -299,6 +497,7 @@ func getStableConns(stableConns map[stableConnKey][2]io.ReadWriteCloser, addr ne return conns, err } conns[timestampSourceUserspace] = uconn + stableConns[key] = conns return conns, nil } @@ -338,7 +537,7 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo } if conn == nil { var err error - conn, err = newConn(source, protocol) + conn, err = newConn(source, protocol, unstableConn) if err != nil { select { case <-doneCh: @@ -361,7 +560,9 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo case protocolICMP: // TODO(jwhited): implement case protocolHTTPS: - // TODO(jwhited): implement + fn = measureHTTPSRTT + case protocolTCP: + fn = measureTCPRTT } rtt, err := probe(meta, conn, fn, dstPort) if err != nil { @@ -725,6 +926,13 @@ func main() { if len(httpsPorts) > 0 { 
portsByProtocol[protocolHTTPS] = httpsPorts } + tcpPorts, err := getPortsFromFlag(*flagTCPDstPorts) + if err != nil { + log.Fatalf("invalid tcp-dst-ports flag value: %v", err) + } + if len(tcpPorts) > 0 { + portsByProtocol[protocolTCP] = tcpPorts + } if *flagICMP { portsByProtocol[protocolICMP] = []int{0} } @@ -734,8 +942,8 @@ func main() { // TODO(jwhited): remove protocol restriction for k := range portsByProtocol { - if k != protocolSTUN { - log.Fatal("HTTPS & ICMP are not yet supported") + if k != protocolSTUN && k != protocolHTTPS && k != protocolTCP { + log.Fatal("ICMP is not yet supported") } } @@ -883,7 +1091,7 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() updatedDM, err := getDERPMap(ctx, *flagDERPMap) - if err != nil { + if err == nil { dmCh <- updatedDM } }() diff --git a/cmd/stunstamp/stunstamp_default.go b/cmd/stunstamp/stunstamp_default.go index 707035306..017af1251 100644 --- a/cmd/stunstamp/stunstamp_default.go +++ b/cmd/stunstamp/stunstamp_default.go @@ -16,10 +16,14 @@ func getUDPConnKernelTimestamp() (io.ReadWriteCloser, error) { return nil, errors.New("unimplemented") } -func measureSTUNRTTKernel(conn io.ReadWriteCloser, dst netip.AddrPort) (rtt time.Duration, err error) { +func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) { return 0, errors.New("unimplemented") } func protocolSupportsKernelTS(_ protocol) bool { return false } + +func setSOReuseAddr(fd uintptr) error { + return nil +} diff --git a/cmd/stunstamp/stunstamp_linux.go b/cmd/stunstamp/stunstamp_linux.go index 1545c067f..148e4b0ef 100644 --- a/cmd/stunstamp/stunstamp_linux.go +++ b/cmd/stunstamp/stunstamp_linux.go @@ -11,6 +11,7 @@ "fmt" "io" "net/netip" + "syscall" "time" "github.com/mdlayher/socket" @@ -56,7 +57,7 @@ func parseTimestampFromCmsgs(oob []byte) (time.Time, error) { return time.Time{}, errors.New("failed to parse timestamp from cmsgs") } -func 
// setSOReuseAddr turns on SO_REUSEADDR for the socket identified by fd and
// returns the setsockopt error, if any. The option is needed because the
// process may restart faster than TIME_WAIT can clear on a previously used
// local port.
func setSOReuseAddr(fd uintptr) error {
	const enable = 1
	sock := int(fd)
	return syscall.SetsockoptInt(sock, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, enable)
}