mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-12 05:37:32 +00:00
net/connstats: invert network logging data flow (#6272)
Previously, tstun.Wrapper and magicsock.Conn managed their own statistics data structure and relied on an external call to Extract to extract (and reset) the statistics. This makes it difficult to ensure a maximum size on the statistics as the caller has no introspection into whether the number of unique connections is getting too large. Invert the control flow such that a *connstats.Statistics is registered with tstun.Wrapper and magicsock.Conn. Methods on non-nil *connstats.Statistics are called for every packet. This allows the implementation of connstats.Statistics (in the future) to better control when it needs to flush to ensure bounds on maximum sizes. The value registered into tstun.Wrapper and magicsock.Conn could be an interface, but that has two performance detriments: 1. Method calls on interface values are more expensive since they must go through a virtual method dispatch. 2. The implementation would need a sync.Mutex to protect the statistics value instead of using an atomic.Pointer. Given that methods on constats.Statistics are called for every packet, we want reduce the CPU cost on this hot path. Signed-off-by: Joe Tsai <joetsai@digital-static.net>
This commit is contained in:
109
net/connstats/stats.go
Normal file
109
net/connstats/stats.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package connstats maintains statistics about connections
|
||||
// flowing through a TUN device (which operate at the IP layer).
|
||||
package connstats
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/netlogtype"
|
||||
)
|
||||
|
||||
// Statistics maintains counters for every connection.
|
||||
// All methods are safe for concurrent use.
|
||||
// The zero value is ready for use.
|
||||
type Statistics struct {
|
||||
mu sync.Mutex
|
||||
virtual map[netlogtype.Connection]netlogtype.Counts
|
||||
physical map[netlogtype.Connection]netlogtype.Counts
|
||||
}
|
||||
|
||||
// UpdateTxVirtual updates the counters for a transmitted IP packet
|
||||
// The source and destination of the packet directly correspond with
|
||||
// the source and destination in netlogtype.Connection.
|
||||
func (s *Statistics) UpdateTxVirtual(b []byte) {
|
||||
s.updateVirtual(b, false)
|
||||
}
|
||||
|
||||
// UpdateRxVirtual updates the counters for a received IP packet.
|
||||
// The source and destination of the packet are inverted with respect to
|
||||
// the source and destination in netlogtype.Connection.
|
||||
func (s *Statistics) UpdateRxVirtual(b []byte) {
|
||||
s.updateVirtual(b, true)
|
||||
}
|
||||
|
||||
func (s *Statistics) updateVirtual(b []byte, receive bool) {
|
||||
var p packet.Parsed
|
||||
p.Decode(b)
|
||||
conn := netlogtype.Connection{Proto: p.IPProto, Src: p.Src, Dst: p.Dst}
|
||||
if receive {
|
||||
conn.Src, conn.Dst = conn.Dst, conn.Src
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.virtual == nil {
|
||||
s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
}
|
||||
cnts := s.virtual[conn]
|
||||
if receive {
|
||||
cnts.RxPackets++
|
||||
cnts.RxBytes += uint64(len(b))
|
||||
} else {
|
||||
cnts.TxPackets++
|
||||
cnts.TxBytes += uint64(len(b))
|
||||
}
|
||||
s.virtual[conn] = cnts
|
||||
}
|
||||
|
||||
// UpdateTxPhysical updates the counters for a transmitted wireguard packet
|
||||
// The src is always a Tailscale IP address, representing some remote peer.
|
||||
// The dst is a remote IP address and port that corresponds
|
||||
// with some physical peer backing the Tailscale IP address.
|
||||
func (s *Statistics) UpdateTxPhysical(src netip.Addr, dst netip.AddrPort, n int) {
|
||||
s.updatePhysical(src, dst, n, false)
|
||||
}
|
||||
|
||||
// UpdateRxPhysical updates the counters for a received wireguard packet.
|
||||
// The src is always a Tailscale IP address, representing some remote peer.
|
||||
// The dst is a remote IP address and port that corresponds
|
||||
// with some physical peer backing the Tailscale IP address.
|
||||
func (s *Statistics) UpdateRxPhysical(src netip.Addr, dst netip.AddrPort, n int) {
|
||||
s.updatePhysical(src, dst, n, true)
|
||||
}
|
||||
|
||||
func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, receive bool) {
|
||||
conn := netlogtype.Connection{Src: netip.AddrPortFrom(src, 0), Dst: dst}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.physical == nil {
|
||||
s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
}
|
||||
cnts := s.physical[conn]
|
||||
if receive {
|
||||
cnts.RxPackets++
|
||||
cnts.RxBytes += uint64(n)
|
||||
} else {
|
||||
cnts.TxPackets++
|
||||
cnts.TxBytes += uint64(n)
|
||||
}
|
||||
s.physical[conn] = cnts
|
||||
}
|
||||
|
||||
// Extract extracts and resets the counters for all active connections.
|
||||
// It must be called periodically otherwise the memory used is unbounded.
|
||||
func (s *Statistics) Extract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
virtual = s.virtual
|
||||
s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
physical = s.physical
|
||||
s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
return virtual, physical
|
||||
}
|
@@ -2,7 +2,7 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tunstats
|
||||
package connstats
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
@@ -82,13 +82,13 @@ func TestConcurrent(t *testing.T) {
|
||||
|
||||
cnts := gots[i][t2]
|
||||
if receive {
|
||||
stats.UpdateRx(p)
|
||||
stats.UpdateRxVirtual(p)
|
||||
cnts.RxPackets++
|
||||
cnts.RxBytes += uint64(len(p))
|
||||
} else {
|
||||
cnts.TxPackets++
|
||||
cnts.TxBytes += uint64(len(p))
|
||||
stats.UpdateTx(p)
|
||||
stats.UpdateTxVirtual(p)
|
||||
}
|
||||
gots[i][t2] = cnts
|
||||
time.Sleep(time.Duration(rn.Intn(1 + delay)))
|
||||
@@ -96,11 +96,13 @@ func TestConcurrent(t *testing.T) {
|
||||
}(i)
|
||||
}
|
||||
for range gots {
|
||||
wants = append(wants, stats.Extract())
|
||||
virtual, _ := stats.Extract()
|
||||
wants = append(wants, virtual)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
group.Wait()
|
||||
wants = append(wants, stats.Extract())
|
||||
virtual, _ := stats.Extract()
|
||||
wants = append(wants, virtual)
|
||||
|
||||
got := make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
want := make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
@@ -126,7 +128,7 @@ func Benchmark(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s Statistics
|
||||
for j := 0; j < 1e3; j++ {
|
||||
s.UpdateTx(p)
|
||||
s.UpdateTxVirtual(p)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -138,7 +140,7 @@ func Benchmark(b *testing.B) {
|
||||
var s Statistics
|
||||
for j := 0; j < 1e3; j++ {
|
||||
binary.BigEndian.PutUint32(p[20:], uint32(j)) // unique port combination
|
||||
s.UpdateTx(p)
|
||||
s.UpdateTxVirtual(p)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -154,7 +156,7 @@ func Benchmark(b *testing.B) {
|
||||
go func() {
|
||||
defer group.Done()
|
||||
for k := 0; k < 1e3; k++ {
|
||||
s.UpdateTx(p)
|
||||
s.UpdateTxVirtual(p)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -179,7 +181,7 @@ func Benchmark(b *testing.B) {
|
||||
j *= 1e3
|
||||
for k := 0; k < 1e3; k++ {
|
||||
binary.BigEndian.PutUint32(p[20:], uint32(j+k)) // unique port combination
|
||||
s.UpdateTx(p)
|
||||
s.UpdateTxVirtual(p)
|
||||
}
|
||||
}(j)
|
||||
}
|
@@ -22,15 +22,14 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"tailscale.com/disco"
|
||||
"tailscale.com/net/connstats"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/net/tunstats"
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/tstime/mono"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/netlogtype"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/wgengine/filter"
|
||||
)
|
||||
@@ -170,10 +169,7 @@ type Wrapper struct {
|
||||
disableTSMPRejected bool
|
||||
|
||||
// stats maintains per-connection counters.
|
||||
stats struct {
|
||||
enabled atomic.Bool
|
||||
tunstats.Statistics
|
||||
}
|
||||
stats atomic.Pointer[connstats.Statistics]
|
||||
}
|
||||
|
||||
// tunReadResult is the result of a TUN read, or an injected result pretending to be a TUN read.
|
||||
@@ -568,8 +564,8 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if t.stats.enabled.Load() {
|
||||
t.stats.UpdateTx(buf[offset:][:n])
|
||||
if stats := t.stats.Load(); stats != nil {
|
||||
stats.UpdateTxVirtual(buf[offset:][:n])
|
||||
}
|
||||
t.noteActivity()
|
||||
return n, nil
|
||||
@@ -701,8 +697,8 @@ func (t *Wrapper) Write(buf []byte, offset int) (int, error) {
|
||||
}
|
||||
|
||||
func (t *Wrapper) tdevWrite(buf []byte, offset int) (int, error) {
|
||||
if t.stats.enabled.Load() {
|
||||
t.stats.UpdateRx(buf[offset:])
|
||||
if stats := t.stats.Load(); stats != nil {
|
||||
stats.UpdateRxVirtual(buf[offset:])
|
||||
}
|
||||
if t.isTAP {
|
||||
return t.tapWrite(buf, offset)
|
||||
@@ -843,18 +839,10 @@ func (t *Wrapper) Unwrap() tun.Device {
|
||||
return t.tdev
|
||||
}
|
||||
|
||||
// SetStatisticsEnabled enables per-connections packet counters.
|
||||
// Disabling statistics gathering does not reset the counters.
|
||||
// ExtractStatistics must be called to reset the counters and
|
||||
// be periodically called while enabled to avoid unbounded memory use.
|
||||
func (t *Wrapper) SetStatisticsEnabled(enable bool) {
|
||||
t.stats.enabled.Store(enable)
|
||||
}
|
||||
|
||||
// ExtractStatistics extracts and resets the counters for all active connections.
|
||||
// It must be called periodically otherwise the memory used is unbounded.
|
||||
func (t *Wrapper) ExtractStatistics() map[netlogtype.Connection]netlogtype.Counts {
|
||||
return t.stats.Extract()
|
||||
// SetStatistics specifies a per-connection statistics aggregator.
|
||||
// Nil may be specified to disable statistics gathering.
|
||||
func (t *Wrapper) SetStatistics(stats *connstats.Statistics) {
|
||||
t.stats.Store(stats)
|
||||
}
|
||||
|
||||
var (
|
||||
|
@@ -19,6 +19,7 @@ import (
|
||||
"go4.org/netipx"
|
||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||
"tailscale.com/disco"
|
||||
"tailscale.com/net/connstats"
|
||||
"tailscale.com/net/netaddr"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/tstest"
|
||||
@@ -283,11 +284,6 @@ func TestWriteAndInject(t *testing.T) {
|
||||
t.Errorf("%s not received", packet)
|
||||
}
|
||||
}
|
||||
|
||||
// Statistics gathering is disabled by default.
|
||||
if stats := tun.ExtractStatistics(); len(stats) > 0 {
|
||||
t.Errorf("tun.ExtractStatistics = %v, want {}", stats)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
@@ -336,15 +332,17 @@ func TestFilter(t *testing.T) {
|
||||
}()
|
||||
|
||||
var buf [MaxPacketSize]byte
|
||||
tun.SetStatisticsEnabled(true)
|
||||
stats := new(connstats.Statistics)
|
||||
tun.SetStatistics(stats)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n int
|
||||
var err error
|
||||
var filtered bool
|
||||
|
||||
if stats := tun.ExtractStatistics(); len(stats) > 0 {
|
||||
t.Errorf("tun.ExtractStatistics = %v, want {}", stats)
|
||||
tunStats, _ := stats.Extract()
|
||||
if len(tunStats) > 0 {
|
||||
t.Errorf("connstats.Statistics.Extract = %v, want {}", stats)
|
||||
}
|
||||
|
||||
if tt.dir == in {
|
||||
@@ -377,7 +375,7 @@ func TestFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
got := tun.ExtractStatistics()
|
||||
got, _ := stats.Extract()
|
||||
want := map[netlogtype.Connection]netlogtype.Counts{}
|
||||
if !tt.drop {
|
||||
var p packet.Parsed
|
||||
|
@@ -1,70 +0,0 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package tunstats maintains statistics about connections
|
||||
// flowing through a TUN device (which operate at the IP layer).
|
||||
package tunstats
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/netlogtype"
|
||||
)
|
||||
|
||||
// Statistics maintains counters for every connection.
|
||||
// All methods are safe for concurrent use.
|
||||
// The zero value is ready for use.
|
||||
type Statistics struct {
|
||||
mu sync.Mutex
|
||||
m map[netlogtype.Connection]netlogtype.Counts
|
||||
}
|
||||
|
||||
// UpdateTx updates the counters for a transmitted IP packet
|
||||
// The source and destination of the packet directly correspond with
|
||||
// the source and destination in netlogtype.Connection.
|
||||
func (s *Statistics) UpdateTx(b []byte) {
|
||||
s.update(b, false)
|
||||
}
|
||||
|
||||
// UpdateRx updates the counters for a received IP packet.
|
||||
// The source and destination of the packet are inverted with respect to
|
||||
// the source and destination in netlogtype.Connection.
|
||||
func (s *Statistics) UpdateRx(b []byte) {
|
||||
s.update(b, true)
|
||||
}
|
||||
|
||||
func (s *Statistics) update(b []byte, receive bool) {
|
||||
var p packet.Parsed
|
||||
p.Decode(b)
|
||||
conn := netlogtype.Connection{Proto: p.IPProto, Src: p.Src, Dst: p.Dst}
|
||||
if receive {
|
||||
conn.Src, conn.Dst = conn.Dst, conn.Src
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.m == nil {
|
||||
s.m = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
}
|
||||
cnts := s.m[conn]
|
||||
if receive {
|
||||
cnts.RxPackets++
|
||||
cnts.RxBytes += uint64(len(b))
|
||||
} else {
|
||||
cnts.TxPackets++
|
||||
cnts.TxBytes += uint64(len(b))
|
||||
}
|
||||
s.m[conn] = cnts
|
||||
}
|
||||
|
||||
// Extract extracts and resets the counters for all active connections.
|
||||
// It must be called periodically otherwise the memory used is unbounded.
|
||||
func (s *Statistics) Extract() map[netlogtype.Connection]netlogtype.Counts {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
m := s.m
|
||||
s.m = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
return m
|
||||
}
|
Reference in New Issue
Block a user