net/flowtrack: add new package to specialize groupcache/lru key type

Reduces allocs.

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2021-01-11 12:46:45 -08:00
parent f69e46175d
commit 4d15e954bd
3 changed files with 189 additions and 15 deletions

View File

@ -0,0 +1,99 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// Original implementation (from same author) from which this was derived was:
// https://github.com/golang/groupcache/blob/5b532d6fd5efaf7fa130d4e859a2fde0fc3a9e1b/lru/lru.go
// ... which was Apache licensed:
// https://github.com/golang/groupcache/blob/master/LICENSE
// Package flowtrack contains types for tracking TCP/UDP flows by 4-tuples.
package flowtrack
import (
"container/list"
"inet.af/netaddr"
)
// Tuple is a 4-tuple of source and destination IP and port.
type Tuple struct {
Src netaddr.IPPort
Dst netaddr.IPPort
}
// Cache is an LRU cache keyed by Tuple.
//
// The zero value is valid to use.
//
// It is not safe for concurrent access.
type Cache struct {
// MaxEntries is the maximum number of cache entries before
// an item is evicted. Zero means no limit.
MaxEntries int
ll *list.List
m map[Tuple]*list.Element // of *entry
}
// entry is the container/list element type.
type entry struct {
key Tuple
value interface{}
}
// Add adds a value to the cache, set or updating its assoicated
// value.
//
// If MaxEntries is non-zero and the length of the cache is greater
// after any addition, the least recently used value is evicted.
func (c *Cache) Add(key Tuple, value interface{}) {
if c.m == nil {
c.m = make(map[Tuple]*list.Element)
c.ll = list.New()
}
if ee, ok := c.m[key]; ok {
c.ll.MoveToFront(ee)
ee.Value.(*entry).value = value
return
}
ele := c.ll.PushFront(&entry{key, value})
c.m[key] = ele
if c.MaxEntries != 0 && c.Len() > c.MaxEntries {
c.RemoveOldest()
}
}
// Get looks up a key's value from the cache, also reporting
// whether it was present.
func (c *Cache) Get(key Tuple) (value interface{}, ok bool) {
if ele, hit := c.m[key]; hit {
c.ll.MoveToFront(ele)
return ele.Value.(*entry).value, true
}
return nil, false
}
// Remove removes the provided key from the cache if it was present.
func (c *Cache) Remove(key Tuple) {
if ele, hit := c.m[key]; hit {
c.removeElement(ele)
}
}
// RemoveOldest removes the oldest item from the cache, if any.
func (c *Cache) RemoveOldest() {
if c.ll != nil {
if ele := c.ll.Back(); ele != nil {
c.removeElement(ele)
}
}
}
func (c *Cache) removeElement(e *list.Element) {
c.ll.Remove(e)
delete(c.m, e.Value.(*entry).key)
}
// Len returns the number of items in the cache.
func (c *Cache) Len() int { return len(c.m) }

View File

@ -0,0 +1,82 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package flowtrack
import (
"testing"
"inet.af/netaddr"
)
func TestCache(t *testing.T) {
c := &Cache{MaxEntries: 2}
k1 := Tuple{Src: netaddr.MustParseIPPort("1.1.1.1:1"), Dst: netaddr.MustParseIPPort("1.1.1.1:1")}
k2 := Tuple{Src: netaddr.MustParseIPPort("1.1.1.1:1"), Dst: netaddr.MustParseIPPort("2.2.2.2:2")}
k3 := Tuple{Src: netaddr.MustParseIPPort("1.1.1.1:1"), Dst: netaddr.MustParseIPPort("3.3.3.3:3")}
k4 := Tuple{Src: netaddr.MustParseIPPort("1.1.1.1:1"), Dst: netaddr.MustParseIPPort("4.4.4.4:4")}
wantLen := func(want int) {
t.Helper()
if got := c.Len(); got != want {
t.Fatalf("Len = %d; want %d", got, want)
}
}
wantVal := func(key Tuple, want interface{}) {
t.Helper()
got, ok := c.Get(key)
if !ok {
t.Fatalf("Get(%q) failed; want value %v", key, want)
}
if got != want {
t.Fatalf("Get(%q) = %v; want %v", key, got, want)
}
}
wantMissing := func(key Tuple) {
t.Helper()
if got, ok := c.Get(key); ok {
t.Fatalf("Get(%q) = %v; want absent from cache", key, got)
}
}
wantLen(0)
c.RemoveOldest() // shouldn't panic
c.Remove(k4) // shouldn't panic
c.Add(k1, 1)
wantLen(1)
c.Add(k2, 2)
wantLen(2)
c.Add(k3, 3)
wantLen(2) // hit the max
wantMissing(k1)
c.Remove(k1)
wantLen(2) // no change; k1 should've been the deleted one per LRU
wantVal(k3, 3)
wantVal(k2, 2)
c.Remove(k2)
wantLen(1)
wantMissing(k2)
c.Add(k3, 30)
wantVal(k3, 30)
wantLen(1)
allocs := int(testing.AllocsPerRun(1000, func() {
got, ok := c.Get(k3)
if !ok {
t.Fatal("missing k3")
}
if got != 30 {
t.Fatalf("got = %d; want 30", got)
}
}))
if allocs != 0 {
t.Errorf("allocs = %v; want 0", allocs)
}
}

View File

@ -10,9 +10,9 @@
"sync" "sync"
"time" "time"
"github.com/golang/groupcache/lru"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/net/flowtrack"
"tailscale.com/net/packet" "tailscale.com/net/packet"
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
@ -41,17 +41,10 @@ type Filter struct {
state *filterState state *filterState
} }
// tuple is a 4-tuple of source and destination IP and port. It's used
// as a lookup key in filterState.
type tuple struct {
Src netaddr.IPPort
Dst netaddr.IPPort
}
// filterState is a state cache of past seen packets. // filterState is a state cache of past seen packets.
type filterState struct { type filterState struct {
mu sync.Mutex mu sync.Mutex
lru *lru.Cache // of tuple lru *flowtrack.Cache // from flowtrack.Tuple -> nil
} }
// lruMax is the size of the LRU cache in filterState. // lruMax is the size of the LRU cache in filterState.
@ -141,7 +134,7 @@ func New(matches []Match, localNets []netaddr.IPPrefix, shareStateWith *Filter,
state = shareStateWith.state state = shareStateWith.state
} else { } else {
state = &filterState{ state = &filterState{
lru: lru.New(lruMax), lru: &flowtrack.Cache{MaxEntries: lruMax},
} }
} }
f := &Filter{ f := &Filter{
@ -334,7 +327,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
return Accept, "tcp ok" return Accept, "tcp ok"
} }
case packet.UDP: case packet.UDP:
t := tuple{q.Src, q.Dst} t := flowtrack.Tuple{Src: q.Src, Dst: q.Dst}
f.state.mu.Lock() f.state.mu.Lock()
_, ok := f.state.lru.Get(t) _, ok := f.state.lru.Get(t)
@ -389,7 +382,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
return Accept, "tcp ok" return Accept, "tcp ok"
} }
case packet.UDP: case packet.UDP:
t := tuple{q.Src, q.Dst} t := flowtrack.Tuple{Src: q.Src, Dst: q.Dst}
f.state.mu.Lock() f.state.mu.Lock()
_, ok := f.state.lru.Get(t) _, ok := f.state.lru.Get(t)
@ -413,10 +406,10 @@ func (f *Filter) runOut(q *packet.Parsed) (r Response, why string) {
return Accept, "ok out" return Accept, "ok out"
} }
t := tuple{q.Dst, q.Src} tuple := flowtrack.Tuple{Src: q.Dst, Dst: q.Src} // src/dst reversed
var ti interface{} = t // allocate once, rather than twice inside mutex
f.state.mu.Lock() f.state.mu.Lock()
f.state.lru.Add(ti, ti) f.state.lru.Add(tuple, nil)
f.state.mu.Unlock() f.state.mu.Unlock()
return Accept, "ok out" return Accept, "ok out"
} }