mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-12 05:37:32 +00:00
net/connstats: enforce maximum number of connections (#6760)
The Tailscale logging service has a hard limit on the maximum log message size that can be accepted. We want to ensure that netlog messages never exceed this limit; otherwise, a client cannot transmit logs. Move the goroutine for periodically dumping netlog messages from wgengine/netlog to net/connstats. This allows net/connstats to manage when it dumps messages, either based on time or by size. Updates tailscale/corp#8427 Signed-off-by: Joe Tsai <joetsai@digital-static.net>
This commit is contained in:
@@ -7,9 +7,12 @@
|
||||
package connstats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/netlogtype"
|
||||
)
|
||||
@@ -18,11 +21,64 @@ import (
|
||||
// All methods are safe for concurrent use.
|
||||
// The zero value is ready for use.
|
||||
type Statistics struct {
|
||||
mu sync.Mutex
|
||||
maxConns int // immutable once set
|
||||
|
||||
mu sync.Mutex
|
||||
connCnts
|
||||
|
||||
connCntsCh chan connCnts
|
||||
shutdownCtx context.Context
|
||||
shutdown context.CancelFunc
|
||||
group errgroup.Group
|
||||
}
|
||||
|
||||
type connCnts struct {
|
||||
start time.Time
|
||||
end time.Time
|
||||
virtual map[netlogtype.Connection]netlogtype.Counts
|
||||
physical map[netlogtype.Connection]netlogtype.Counts
|
||||
}
|
||||
|
||||
// NewStatistics creates a data structure for tracking connection statistics
|
||||
// that periodically dumps the virtual and physical connection counts
|
||||
// depending on whether the maxPeriod or maxConns is exceeded.
|
||||
// The dump function is called from a single goroutine.
|
||||
// Shutdown must be called to cleanup resources.
|
||||
func NewStatistics(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) *Statistics {
|
||||
s := &Statistics{maxConns: maxConns}
|
||||
s.connCntsCh = make(chan connCnts, 256)
|
||||
s.shutdownCtx, s.shutdown = context.WithCancel(context.Background())
|
||||
s.group.Go(func() error {
|
||||
// TODO(joetsai): Using a ticker is problematic on mobile platforms
|
||||
// where waking up a process every maxPeriod when there is no activity
|
||||
// is a drain on battery life. Switch this instead to instead use
|
||||
// a time.Timer that is triggered upon network activity.
|
||||
ticker := new(time.Ticker)
|
||||
if maxPeriod > 0 {
|
||||
ticker := time.NewTicker(maxPeriod)
|
||||
defer ticker.Stop()
|
||||
}
|
||||
|
||||
for {
|
||||
var cc connCnts
|
||||
select {
|
||||
case cc = <-s.connCntsCh:
|
||||
case <-ticker.C:
|
||||
cc = s.extract()
|
||||
case <-s.shutdownCtx.Done():
|
||||
cc = s.extract()
|
||||
}
|
||||
if len(cc.virtual)+len(cc.physical) > 0 && dump != nil {
|
||||
dump(cc.start, cc.end, cc.virtual, cc.physical)
|
||||
}
|
||||
if s.shutdownCtx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
return s
|
||||
}
|
||||
|
||||
// UpdateTxVirtual updates the counters for a transmitted IP packet
|
||||
// The source and destination of the packet directly correspond with
|
||||
// the source and destination in netlogtype.Connection.
|
||||
@@ -47,10 +103,10 @@ func (s *Statistics) updateVirtual(b []byte, receive bool) {
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.virtual == nil {
|
||||
s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
cnts, found := s.virtual[conn]
|
||||
if !found && !s.preInsertConn() {
|
||||
return
|
||||
}
|
||||
cnts := s.virtual[conn]
|
||||
if receive {
|
||||
cnts.RxPackets++
|
||||
cnts.RxBytes += uint64(len(b))
|
||||
@@ -82,10 +138,10 @@ func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, r
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.physical == nil {
|
||||
s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
cnts, found := s.physical[conn]
|
||||
if !found && !s.preInsertConn() {
|
||||
return
|
||||
}
|
||||
cnts := s.physical[conn]
|
||||
if receive {
|
||||
cnts.RxPackets++
|
||||
cnts.RxBytes += uint64(n)
|
||||
@@ -96,14 +152,57 @@ func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, r
|
||||
s.physical[conn] = cnts
|
||||
}
|
||||
|
||||
// Extract extracts and resets the counters for all active connections.
|
||||
// It must be called periodically otherwise the memory used is unbounded.
|
||||
func (s *Statistics) Extract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
|
||||
// preInsertConn updates the maps to handle insertion of a new connection.
|
||||
// It reports false if insertion is not allowed (i.e., after shutdown).
|
||||
func (s *Statistics) preInsertConn() bool {
|
||||
// Check whether insertion of a new connection will exceed maxConns.
|
||||
if len(s.virtual)+len(s.physical) == s.maxConns && s.maxConns > 0 {
|
||||
// Extract the current statistics and send it to the serializer.
|
||||
// Avoid blocking the network packet handling path.
|
||||
select {
|
||||
case s.connCntsCh <- s.extractLocked():
|
||||
default:
|
||||
// TODO(joetsai): Log that we are dropping an entire connCounts.
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the maps if nil.
|
||||
if s.virtual == nil && s.physical == nil {
|
||||
s.start = time.Now().UTC()
|
||||
s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
}
|
||||
|
||||
return s.shutdownCtx.Err() == nil
|
||||
}
|
||||
|
||||
func (s *Statistics) extract() connCnts {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
virtual = s.virtual
|
||||
s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
physical = s.physical
|
||||
s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
return virtual, physical
|
||||
return s.extractLocked()
|
||||
}
|
||||
|
||||
func (s *Statistics) extractLocked() connCnts {
|
||||
if len(s.virtual)+len(s.physical) == 0 {
|
||||
return connCnts{}
|
||||
}
|
||||
s.end = time.Now().UTC()
|
||||
cc := s.connCnts
|
||||
s.connCnts = connCnts{}
|
||||
return cc
|
||||
}
|
||||
|
||||
// TestExtract synchronously extracts the current network statistics map
|
||||
// and resets the counters. This should only be used for testing purposes.
|
||||
func (s *Statistics) TestExtract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
|
||||
cc := s.extract()
|
||||
return cc.virtual, cc.physical
|
||||
}
|
||||
|
||||
// Shutdown performs a final flush of statistics.
|
||||
// Statistics for any subsequent calls to Update will be dropped.
|
||||
// It is safe to call Shutdown concurrently and repeatedly.
|
||||
func (s *Statistics) Shutdown(context.Context) error {
|
||||
s.shutdown()
|
||||
return s.group.Wait()
|
||||
}
|
||||
|
@@ -5,6 +5,7 @@
|
||||
package connstats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
@@ -47,7 +48,20 @@ func testPacketV4(proto ipproto.Proto, srcAddr, dstAddr [4]byte, srcPort, dstPor
|
||||
func TestConcurrent(t *testing.T) {
|
||||
c := qt.New(t)
|
||||
|
||||
var stats Statistics
|
||||
const maxPeriod = 10 * time.Millisecond
|
||||
const maxConns = 10
|
||||
virtualAggregate := make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
stats := NewStatistics(maxPeriod, maxConns, func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
|
||||
c.Assert(start.IsZero(), qt.IsFalse)
|
||||
c.Assert(end.IsZero(), qt.IsFalse)
|
||||
c.Assert(end.Before(start), qt.IsFalse)
|
||||
c.Assert(len(virtual) > 0 && len(virtual) <= maxConns, qt.IsTrue)
|
||||
c.Assert(len(physical) == 0, qt.IsTrue)
|
||||
for conn, cnts := range virtual {
|
||||
virtualAggregate[conn] = virtualAggregate[conn].Add(cnts)
|
||||
}
|
||||
})
|
||||
defer stats.Shutdown(context.Background())
|
||||
var wants []map[netlogtype.Connection]netlogtype.Counts
|
||||
gots := make([]map[netlogtype.Connection]netlogtype.Counts, runtime.NumCPU())
|
||||
var group sync.WaitGroup
|
||||
@@ -95,14 +109,9 @@ func TestConcurrent(t *testing.T) {
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
for range gots {
|
||||
virtual, _ := stats.Extract()
|
||||
wants = append(wants, virtual)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
group.Wait()
|
||||
virtual, _ := stats.Extract()
|
||||
wants = append(wants, virtual)
|
||||
c.Assert(stats.Shutdown(context.Background()), qt.IsNil)
|
||||
wants = append(wants, virtualAggregate)
|
||||
|
||||
got := make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
want := make(map[netlogtype.Connection]netlogtype.Counts)
|
||||
@@ -126,7 +135,7 @@ func Benchmark(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s Statistics
|
||||
s := NewStatistics(0, 0, nil)
|
||||
for j := 0; j < 1e3; j++ {
|
||||
s.UpdateTxVirtual(p)
|
||||
}
|
||||
@@ -137,7 +146,7 @@ func Benchmark(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s Statistics
|
||||
s := NewStatistics(0, 0, nil)
|
||||
for j := 0; j < 1e3; j++ {
|
||||
binary.BigEndian.PutUint32(p[20:], uint32(j)) // unique port combination
|
||||
s.UpdateTxVirtual(p)
|
||||
@@ -149,7 +158,7 @@ func Benchmark(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s Statistics
|
||||
s := NewStatistics(0, 0, nil)
|
||||
var group sync.WaitGroup
|
||||
for j := 0; j < runtime.NumCPU(); j++ {
|
||||
group.Add(1)
|
||||
@@ -171,7 +180,7 @@ func Benchmark(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var s Statistics
|
||||
s := NewStatistics(0, 0, nil)
|
||||
var group sync.WaitGroup
|
||||
for j := 0; j < runtime.NumCPU(); j++ {
|
||||
group.Add(1)
|
||||
|
@@ -6,15 +6,17 @@ package tstun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/tailscale/wireguard-go/tun/tuntest"
|
||||
"go4.org/mem"
|
||||
"go4.org/netipx"
|
||||
@@ -337,7 +339,8 @@ func TestFilter(t *testing.T) {
|
||||
}()
|
||||
|
||||
var buf [MaxPacketSize]byte
|
||||
stats := new(connstats.Statistics)
|
||||
stats := connstats.NewStatistics(0, 0, nil)
|
||||
defer stats.Shutdown(context.Background())
|
||||
tun.SetStatistics(stats)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -346,7 +349,7 @@ func TestFilter(t *testing.T) {
|
||||
var filtered bool
|
||||
sizes := make([]int, 1)
|
||||
|
||||
tunStats, _ := stats.Extract()
|
||||
tunStats, _ := stats.TestExtract()
|
||||
if len(tunStats) > 0 {
|
||||
t.Errorf("connstats.Statistics.Extract = %v, want {}", stats)
|
||||
}
|
||||
@@ -381,7 +384,7 @@ func TestFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
got, _ := stats.Extract()
|
||||
got, _ := stats.TestExtract()
|
||||
want := map[netlogtype.Connection]netlogtype.Counts{}
|
||||
if !tt.drop {
|
||||
var p packet.Parsed
|
||||
@@ -395,8 +398,8 @@ func TestFilter(t *testing.T) {
|
||||
want[conn] = netlogtype.Counts{TxPackets: 1, TxBytes: uint64(len(tt.data))}
|
||||
}
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("tun.ExtractStatistics = %v, want %v", got, want)
|
||||
if diff := cmp.Diff(got, want, cmpopts.EquateEmpty()); diff != "" {
|
||||
t.Errorf("stats.TestExtract (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user