net/tcpinfo: add package to allow fetching TCP information

This package contains platform-independent abstractions for fetching
information about an open TCP connection.

Updates #8413

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I236657b1060d7e6a45efc7a2f6aacf474547a2fe
This commit is contained in:
Andrew Dunham 2023-06-22 12:41:55 -04:00
parent 243ce6ccc1
commit d9eca20ee2
5 changed files with 196 additions and 0 deletions

51
net/tcpinfo/tcpinfo.go Normal file
View File

@ -0,0 +1,51 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package tcpinfo provides platform-agnostic accessors to information about a
// TCP connection (e.g. RTT, MSS, etc.).
package tcpinfo
import (
"errors"
"net"
"time"
)
var (
ErrNotTCP = errors.New("tcpinfo: not a TCP conn")
ErrUnimplemented = errors.New("tcpinfo: unimplemented")
)
// RTT returns the RTT for the given net.Conn.
//
// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then
// ErrNotTCP will be returned. If retrieving the RTT is not supported on the
// current platform, ErrUnimplemented will be returned.
func RTT(conn net.Conn) (time.Duration, error) {
tcpConn, err := unwrap(conn)
if err != nil {
return 0, err
}
return rttImpl(tcpConn)
}
// netConner is implemented by crypto/tls.Conn to unwrap into an underlying
// net.Conn.
type netConner interface {
NetConn() net.Conn
}
// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn
func unwrap(nc net.Conn) (*net.TCPConn, error) {
for {
switch v := nc.(type) {
case *net.TCPConn:
return v, nil
case netConner:
nc = v.NetConn()
default:
return nil, ErrNotTCP
}
}
}

View File

@ -0,0 +1,33 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package tcpinfo
import (
"net"
"time"
"golang.org/x/sys/unix"
)
func rttImpl(conn *net.TCPConn) (time.Duration, error) {
rawConn, err := conn.SyscallConn()
if err != nil {
return 0, err
}
var (
tcpInfo *unix.TCPConnectionInfo
sysErr error
)
err = rawConn.Control(func(fd uintptr) {
tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO)
})
if err != nil {
return 0, err
} else if sysErr != nil {
return 0, sysErr
}
return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil
}

View File

@ -0,0 +1,33 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package tcpinfo
import (
"net"
"time"
"golang.org/x/sys/unix"
)
func rttImpl(conn *net.TCPConn) (time.Duration, error) {
rawConn, err := conn.SyscallConn()
if err != nil {
return 0, err
}
var (
tcpInfo *unix.TCPInfo
sysErr error
)
err = rawConn.Control(func(fd uintptr) {
tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO)
})
if err != nil {
return 0, err
} else if sysErr != nil {
return 0, sysErr
}
return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil
}

View File

@ -0,0 +1,15 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !linux && !darwin
package tcpinfo
import (
"net"
"time"
)
func rttImpl(conn *net.TCPConn) (time.Duration, error) {
return 0, ErrUnimplemented
}

View File

@ -0,0 +1,64 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package tcpinfo
import (
"bytes"
"io"
"net"
"runtime"
"testing"
)
func TestRTT(t *testing.T) {
switch runtime.GOOS {
case "linux", "darwin":
default:
t.Skipf("not currently supported on %s", runtime.GOOS)
}
ln, err := net.Listen("tcp4", "localhost:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
go func() {
for {
c, err := ln.Accept()
if err != nil {
return
}
t.Cleanup(func() { c.Close() })
// Copy from the client to nowhere
go io.Copy(io.Discard, c)
}
}()
conn, err := net.Dial("tcp4", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
// Write a bunch of data to the conn to force TCP session establishment
// and a few packets.
junkData := bytes.Repeat([]byte("hello world\n"), 1024*1024)
for i := 0; i < 10; i++ {
if _, err := conn.Write(junkData); err != nil {
t.Fatalf("error writing junk data [%d]: %v", i, err)
}
}
// Get the RTT now
rtt, err := RTT(conn)
if err != nil {
t.Fatalf("error getting RTT: %v", err)
}
if rtt == 0 {
t.Errorf("expected RTT > 0")
}
t.Logf("TCP rtt: %v", rtt)
}