diff --git a/cmd/stunstamp/stunstamp.go b/cmd/stunstamp/stunstamp.go index 950fdc2cd..32f4e7ff0 100644 --- a/cmd/stunstamp/stunstamp.go +++ b/cmd/stunstamp/stunstamp.go @@ -460,7 +460,7 @@ type connAndMeasureFn struct { // newConnAndMeasureFn returns a connAndMeasureFn or an error. It may return // nil for both if some combination of the supplied timestampSource, protocol, // or connStability is unsupported. -func newConnAndMeasureFn(source timestampSource, protocol protocol, stable connStability) (*connAndMeasureFn, error) { +func newConnAndMeasureFn(forDst netip.Addr, source timestampSource, protocol protocol, stable connStability) (*connAndMeasureFn, error) { info := getProtocolSupportInfo(protocol) if !info.stableConn && bool(stable) { return nil, nil @@ -493,8 +493,14 @@ func newConnAndMeasureFn(source timestampSource, protocol protocol, stable connS }, nil } case protocolICMP: - // TODO(jwhited): implement - return nil, nil + conn, err := getICMPConn(forDst, source) + if err != nil { + return nil, err + } + return &connAndMeasureFn{ + conn: conn, + fn: mkICMPRTTFn(source), + }, nil case protocolHTTPS: localPort := 0 if stable { @@ -558,7 +564,7 @@ func getConns( if !ok { for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} { var cf *connAndMeasureFn - cf, err = newConnAndMeasureFn(source, protocol, stableConn) + cf, err = newConnAndMeasureFn(addr, source, protocol, stableConn) if err != nil { return } @@ -569,7 +575,7 @@ func getConns( for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} { var cf *connAndMeasureFn - cf, err = newConnAndMeasureFn(source, protocol, unstableConn) + cf, err = newConnAndMeasureFn(addr, source, protocol, unstableConn) if err != nil { return } @@ -953,13 +959,6 @@ func main() { log.Fatal("nothing to probe") } - // TODO(jwhited): remove protocol restriction - for k := range portsByProtocol { - if k != protocolSTUN && k != protocolHTTPS && k != protocolTCP { - log.Fatal("ICMP is not yet supported") - } - } - if len(*flagDERPMap) < 1 { log.Fatal("derp-map flag is unset") } diff --git a/cmd/stunstamp/stunstamp_default.go b/cmd/stunstamp/stunstamp_default.go index 36afdbb8f..5eb765134 100644 --- a/cmd/stunstamp/stunstamp_default.go +++ b/cmd/stunstamp/stunstamp_default.go @@ -40,10 +40,26 @@ func getProtocolSupportInfo(p protocol) protocolSupportInfo { userspaceTS: false, stableConn: true, } + case protocolICMP: + return protocolSupportInfo{ + kernelTS: false, + userspaceTS: false, + stableConn: false, + } } return protocolSupportInfo{} } +func getICMPConn(forDst netip.Addr, source timestampSource) (io.ReadWriteCloser, error) { + return nil, errors.New("platform unsupported") +} + +func mkICMPRTTFn(source timestampSource) func(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) { + return func(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) { + return 0, errors.New("platform unsupported") + } +} + func setSOReuseAddr(fd uintptr) error { return nil } diff --git a/cmd/stunstamp/stunstamp_linux.go b/cmd/stunstamp/stunstamp_linux.go index e73d1ee3c..038c01ffa 100644 --- a/cmd/stunstamp/stunstamp_linux.go +++ b/cmd/stunstamp/stunstamp_linux.go @@ -10,17 +10,22 @@ "errors" "fmt" "io" + "math" + "math/rand/v2" "net/netip" "syscall" "time" "github.com/mdlayher/socket" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "golang.org/x/sys/unix" "tailscale.com/net/stun" ) const ( - flags = unix.SOF_TIMESTAMPING_TX_SOFTWARE | // tx timestamp generation in device driver + timestampingFlags = unix.SOF_TIMESTAMPING_TX_SOFTWARE | // tx timestamp generation in device driver unix.SOF_TIMESTAMPING_RX_SOFTWARE | // rx timestamp generation in the kernel unix.SOF_TIMESTAMPING_SOFTWARE // report software timestamps ) @@ -35,7 +40,7 @@ func getUDPConnKernelTimestamp() (io.ReadWriteCloser, error) { if err != nil { return nil, err } - err = sconn.SetsockoptInt(unix.SOL_SOCKET, unix.SO_TIMESTAMPING_NEW, flags) + err = sconn.SetsockoptInt(unix.SOL_SOCKET, unix.SO_TIMESTAMPING_NEW, timestampingFlags) if err != nil { return nil, err } @@ -57,7 +62,128 @@ func parseTimestampFromCmsgs(oob []byte) (time.Time, error) { return time.Time{}, errors.New("failed to parse timestamp from cmsgs") } -func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) { +func mkICMPRTTFn(source timestampSource) func(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) { + return func(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) { + return measureICMPRTT(source, conn, hostname, dst) + } +} + +func measureICMPRTT(source timestampSource, conn io.ReadWriteCloser, _ string, dst netip.AddrPort) (rtt time.Duration, err error) { + sconn, ok := conn.(*socket.Conn) + if !ok { + return 0, fmt.Errorf("conn of unexpected type: %T", conn) + } + txBody := &icmp.Echo{ + // The kernel overrides this and routes appropriately so there is no + // point in setting or verifying. + ID: 0, + // Make this sufficiently random so that we do not account a late + // arriving reply in a future probe window. + Seq: int(rand.Int32N(math.MaxUint16)), + // Fingerprint ourselves. + Data: []byte("stunstamp"), + } + txMsg := icmp.Message{ + Body: txBody, + } + var to unix.Sockaddr + if dst.Addr().Is4() { + txMsg.Type = ipv4.ICMPTypeEcho + to = &unix.SockaddrInet4{} + copy(to.(*unix.SockaddrInet4).Addr[:], dst.Addr().AsSlice()) + } else { + txMsg.Type = ipv6.ICMPTypeEchoRequest + to = &unix.SockaddrInet6{} + copy(to.(*unix.SockaddrInet6).Addr[:], dst.Addr().AsSlice()) + } + txBuf, err := txMsg.Marshal(nil) + if err != nil { + return 0, err + } + txAt := time.Now() + err = sconn.Sendto(context.Background(), txBuf, 0, to) + if err != nil { + return 0, fmt.Errorf("sendto error: %v", err) + } + + if source == timestampSourceKernel { + txCtx, txCancel := context.WithTimeout(context.Background(), time.Second*2) + defer txCancel() + + buf := make([]byte, 1024) + oob := make([]byte, 1024) + + for { + n, oobn, _, _, err := sconn.Recvmsg(txCtx, buf, oob, unix.MSG_ERRQUEUE) + if err != nil { + return 0, fmt.Errorf("recvmsg (MSG_ERRQUEUE) error: %v", err) // don't wrap + } + + buf = buf[:n] + // Spin until we find the message we sent. We get the full packet + // looped including eth header so match against the tail. + if n < len(txBuf) { + continue + } + txLoopedMsg, err := icmp.ParseMessage(txMsg.Type.Protocol(), buf[len(buf)-len(txBuf):]) + if err != nil { + continue + } + txLoopedBody, ok := txLoopedMsg.Body.(*icmp.Echo) + if !ok || txLoopedBody.Seq != txBody.Seq || txLoopedMsg.Code != txMsg.Code || + txLoopedMsg.Type != txLoopedMsg.Type || !bytes.Equal(txLoopedBody.Data, txBody.Data) { + continue + } + txAt, err = parseTimestampFromCmsgs(oob[:oobn]) + if err != nil { + return 0, fmt.Errorf("failed to get tx timestamp: %v", err) // don't wrap + } + break + } + } + + rxCtx, txCancel := context.WithTimeout(context.Background(), time.Second*2) + defer txCancel() + + rxBuf := make([]byte, 1024) + oob := make([]byte, 1024) + for { + n, oobn, _, _, err := sconn.Recvmsg(rxCtx, rxBuf, oob, 0) + if err != nil { + return 0, fmt.Errorf("recvmsg error: %w", err) + } + rxAt := time.Now() + rxMsg, err := icmp.ParseMessage(txMsg.Type.Protocol(), rxBuf[:n]) + if err != nil { + continue + } + if txMsg.Type == ipv4.ICMPTypeEcho { + if rxMsg.Type != ipv4.ICMPTypeEchoReply { + continue + } + } else { + if rxMsg.Type != ipv6.ICMPTypeEchoReply { + continue + } + } + if rxMsg.Code != txMsg.Code { + continue + } + rxBody, ok := rxMsg.Body.(*icmp.Echo) + if !ok || rxBody.Seq != txBody.Seq || !bytes.Equal(rxBody.Data, txBody.Data) { + continue + } + if source == timestampSourceKernel { + rxAt, err = parseTimestampFromCmsgs(oob[:oobn]) + if err != nil { + return 0, fmt.Errorf("failed to get rx timestamp: %v", err) + } + } + return rxAt.Sub(txAt), nil + } +} + +func measureSTUNRTTKernel(conn io.ReadWriteCloser, _ string, dst netip.AddrPort) (rtt time.Duration, err error) { sconn, ok := conn.(*socket.Conn) if !ok { return 0, fmt.Errorf("conn of unexpected type: %T", conn) @@ -138,6 +264,23 @@ func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.Ad } +func getICMPConn(forDst netip.Addr, source timestampSource) (io.ReadWriteCloser, error) { + domain := unix.AF_INET + proto := unix.IPPROTO_ICMP + if forDst.Is6() { + domain = unix.AF_INET6 + proto = unix.IPPROTO_ICMPV6 + } + conn, err := socket.Socket(domain, unix.SOCK_DGRAM, proto, "icmp", nil) + if err != nil { + return nil, err + } + if source == timestampSourceKernel { + err = conn.SetsockoptInt(unix.SOL_SOCKET, unix.SO_TIMESTAMPING_NEW, timestampingFlags) + } + return conn, err +} + func getProtocolSupportInfo(p protocol) protocolSupportInfo { switch p { case protocolSTUN: @@ -158,7 +301,12 @@ func getProtocolSupportInfo(p protocol) protocolSupportInfo { userspaceTS: false, stableConn: true, } - // TODO(jwhited): add ICMP + case protocolICMP: + return protocolSupportInfo{ + kernelTS: true, + userspaceTS: true, + stableConn: false, + } } return protocolSupportInfo{} }