sockstats: add validation for TCP socket stats

We can use the TCP_CONNECTION_INFO getsockopt() on Darwin to get
OS-collected tx/rx bytes for TCP sockets. Since this API is not available
for UDP sockets (or on Linux/Android), we can't rely on it for actual
stats gathering.

However, we can use it to validate the stats that we collect ourselves
using read/write hooks, so that we can be more confident in them. Collecting
the validation counts requires additional TCP connection hooks in Tailscale's
fork of the Go standard library (added in tailscale/go#59).

Updates tailscale/corp#9230
Updates #3363

Signed-off-by: Mihai Parparita <mihai@tailscale.com>
This commit is contained in:
Mihai Parparita 2023-03-07 16:29:41 -08:00 committed by Mihai Parparita
parent 6eca47b16c
commit f4f8ed98d9
5 changed files with 106 additions and 19 deletions

View File

@ -1 +1 @@
fb11c0df588717a3ee13b09dacae1e7093279d67 db4dc9046c93dde2c0e534ca7d529bd690ad09c9

View File

@ -876,6 +876,7 @@ func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Req
fmt.Fprintf(w, "<th>Tx (%s)</th>", html.EscapeString(iface)) fmt.Fprintf(w, "<th>Tx (%s)</th>", html.EscapeString(iface))
fmt.Fprintf(w, "<th>Rx (%s)</th>", html.EscapeString(iface)) fmt.Fprintf(w, "<th>Rx (%s)</th>", html.EscapeString(iface))
} }
fmt.Fprintln(w, "<th>Validation</th>")
fmt.Fprintln(w, "</thead>") fmt.Fprintln(w, "</thead>")
fmt.Fprintln(w, "<tbody>") fmt.Fprintln(w, "<tbody>")
@ -887,10 +888,10 @@ func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Req
return a.String() < b.String() return a.String() < b.String()
}) })
txTotal := int64(0) txTotal := uint64(0)
rxTotal := int64(0) rxTotal := uint64(0)
txTotalByInterface := map[string]int64{} txTotalByInterface := map[string]uint64{}
rxTotalByInterface := map[string]int64{} rxTotalByInterface := map[string]uint64{}
for _, label := range labels { for _, label := range labels {
stat := stats.Stats[label] stat := stats.Stats[label]
@ -908,6 +909,17 @@ func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Req
txTotalByInterface[iface] += stat.TxBytesByInterface[iface] txTotalByInterface[iface] += stat.TxBytesByInterface[iface]
rxTotalByInterface[iface] += stat.RxBytesByInterface[iface] rxTotalByInterface[iface] += stat.RxBytesByInterface[iface]
} }
if stat.ValidationRxBytes > 0 || stat.ValidationTxBytes > 0 {
fmt.Fprintf(w, "<td>Tx=%d (%+d) Rx=%d (%+d)</td>",
stat.ValidationTxBytes,
int64(stat.ValidationTxBytes)-int64(stat.TxBytes),
stat.ValidationRxBytes,
int64(stat.ValidationRxBytes)-int64(stat.RxBytes))
} else {
fmt.Fprintln(w, "<td></td>")
}
fmt.Fprintln(w, "</tr>") fmt.Fprintln(w, "</tr>")
} }
fmt.Fprintln(w, "</tbody>") fmt.Fprintln(w, "</tbody>")
@ -920,6 +932,7 @@ func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Req
fmt.Fprintf(w, "<th>%d</th>", txTotalByInterface[iface]) fmt.Fprintf(w, "<th>%d</th>", txTotalByInterface[iface])
fmt.Fprintf(w, "<th>%d</th>", rxTotalByInterface[iface]) fmt.Fprintf(w, "<th>%d</th>", rxTotalByInterface[iface])
} }
fmt.Fprintln(w, "<th></th>")
fmt.Fprintln(w, "</tfoot>") fmt.Fprintln(w, "</tfoot>")
fmt.Fprintln(w, "</table>") fmt.Fprintln(w, "</table>")

View File

@ -42,10 +42,14 @@ type SockStats struct {
) )
type SockStat struct { type SockStat struct {
TxBytes int64 TxBytes uint64
RxBytes int64 RxBytes uint64
TxBytesByInterface map[string]int64 TxBytesByInterface map[string]uint64
RxBytesByInterface map[string]int64 RxBytesByInterface map[string]uint64
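// NOCOMMIT: the two fields below hold OS-reported byte counts for TCP
// sockets and exist only to cross-check TxBytes and RxBytes.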
ValidationTxBytes uint64
ValidationRxBytes uint64
} }
func WithSockStats(ctx context.Context, label Label) context.Context { func WithSockStats(ctx context.Context, label Label) context.Context {

View File

@ -11,6 +11,7 @@
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall"
"tailscale.com/net/interfaces" "tailscale.com/net/interfaces"
) )
@ -18,6 +19,12 @@
type sockStatCounters struct { type sockStatCounters struct {
txBytes, rxBytes atomic.Uint64 txBytes, rxBytes atomic.Uint64
rxBytesByInterface, txBytesByInterface map[int]*atomic.Uint64 rxBytesByInterface, txBytesByInterface map[int]*atomic.Uint64
// Validate counts for TCP sockets by using the TCP_CONNECTION_INFO
// getsockopt. We get current counts, as well as save final values when
// sockets are closed.
validationConn atomic.Pointer[syscall.RawConn]
validationTxBytes, validationRxBytes atomic.Uint64
} }
var sockStats = struct { var sockStats = struct {
@ -53,6 +60,23 @@ func withSockStats(ctx context.Context, label Label) context.Context {
sockStats.countersByLabel[label] = counters sockStats.countersByLabel[label] = counters
} }
didCreateTCPConn := func(c syscall.RawConn) {
counters.validationConn.Store(&c)
}
willCloseTCPConn := func(c syscall.RawConn) {
tx, rx := tcpConnStats(c)
counters.validationTxBytes.Add(tx)
counters.validationRxBytes.Add(rx)
}
// Don't bother adding these hooks if we can't get stats that they end up
// collecting.
if tcpConnStats == nil {
willCloseTCPConn = nil
didCreateTCPConn = nil
}
didRead := func(n int) { didRead := func(n int) {
counters.rxBytes.Add(uint64(n)) counters.rxBytes.Add(uint64(n))
if currentInterface := int(sockStats.currentInterface.Load()); currentInterface != 0 { if currentInterface := int(sockStats.currentInterface.Load()); currentInterface != 0 {
@ -74,12 +98,19 @@ func withSockStats(ctx context.Context, label Label) context.Context {
} }
return net.WithSockTrace(ctx, &net.SockTrace{ return net.WithSockTrace(ctx, &net.SockTrace{
DidCreateTCPConn: didCreateTCPConn,
DidRead: didRead, DidRead: didRead,
DidWrite: didWrite, DidWrite: didWrite,
WillOverwrite: willOverwrite, WillOverwrite: willOverwrite,
WillCloseTCPConn: willCloseTCPConn,
}) })
} }
// tcpConnStats returns the number of bytes sent and received on the
// given TCP socket. Its implementation is platform-dependent (or it may not
// be available at all).
//
// It stays nil unless a platform-specific file assigns an implementation,
// so callers must check it for nil before calling.
var tcpConnStats func(c syscall.RawConn) (tx, rx uint64)
func get() *SockStats { func get() *SockStats {
sockStats.mu.Lock() sockStats.mu.Lock()
defer sockStats.mu.Unlock() defer sockStats.mu.Unlock()
@ -93,20 +124,29 @@ func get() *SockStats {
} }
for label, counters := range sockStats.countersByLabel { for label, counters := range sockStats.countersByLabel {
r.Stats[label] = SockStat{ s := SockStat{
TxBytes: int64(counters.txBytes.Load()), TxBytes: counters.txBytes.Load(),
RxBytes: int64(counters.rxBytes.Load()), RxBytes: counters.rxBytes.Load(),
TxBytesByInterface: make(map[string]int64), TxBytesByInterface: make(map[string]uint64),
RxBytesByInterface: make(map[string]int64), RxBytesByInterface: make(map[string]uint64),
ValidationTxBytes: counters.validationTxBytes.Load(),
ValidationRxBytes: counters.validationRxBytes.Load(),
}
if c := counters.validationConn.Load(); c != nil && tcpConnStats != nil {
tx, rx := tcpConnStats(*c)
s.ValidationTxBytes += tx
s.ValidationRxBytes += rx
} }
for iface, a := range counters.rxBytesByInterface { for iface, a := range counters.rxBytesByInterface {
ifName := sockStats.knownInterfaces[iface] ifName := sockStats.knownInterfaces[iface]
r.Stats[label].RxBytesByInterface[ifName] = int64(a.Load()) s.RxBytesByInterface[ifName] = a.Load()
} }
for iface, a := range counters.txBytesByInterface { for iface, a := range counters.txBytesByInterface {
ifName := sockStats.knownInterfaces[iface] ifName := sockStats.knownInterfaces[iface]
r.Stats[label].TxBytesByInterface[ifName] = int64(a.Load()) s.TxBytesByInterface[ifName] = a.Load()
} }
r.Stats[label] = s
} }
return r return r

View File

@ -0,0 +1,30 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build tailscale_go && (darwin || ios)
package sockstats
import (
"syscall"
"golang.org/x/sys/unix"
)
// init installs the Darwin implementation as the package's tcpConnStats
// hook. This file is only built with the tailscale_go tag on darwin or ios.
func init() {
tcpConnStats = darwinTcpConnStats
}
// darwinTcpConnStats reports the OS-maintained transmit and receive byte
// counts for the TCP socket behind c, read with the TCP_CONNECTION_INFO
// socket option.
//
// Collection is best-effort: if the getsockopt call fails, both results stay
// zero. The error returned by c.Control is not inspected.
func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) {
	c.Control(func(fd uintptr) {
		info, err := unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO)
		if err != nil {
			// Leave tx and rx at zero; the caller treats zero as "no data".
			return
		}
		tx, rx = uint64(info.Txbytes), uint64(info.Rxbytes)
	})
	return tx, rx
}