mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-16 03:31:39 +00:00
net/flowtrack,wgengine/filter: refactor Cache to use generics
Signed-off-by: Tom DNetto <tom@tailscale.com>
This commit is contained in:
parent
3becf82dd3
commit
2ac5474be1
@ -34,7 +34,7 @@ func (t Tuple) String() string {
|
|||||||
// The zero value is valid to use.
|
// The zero value is valid to use.
|
||||||
//
|
//
|
||||||
// It is not safe for concurrent access.
|
// It is not safe for concurrent access.
|
||||||
type Cache struct {
|
type Cache[Value any] struct {
|
||||||
// MaxEntries is the maximum number of cache entries before
|
// MaxEntries is the maximum number of cache entries before
|
||||||
// an item is evicted. Zero means no limit.
|
// an item is evicted. Zero means no limit.
|
||||||
MaxEntries int
|
MaxEntries int
|
||||||
@ -44,9 +44,9 @@ type Cache struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// entry is the container/list element type.
|
// entry is the container/list element type.
|
||||||
type entry struct {
|
type entry[Value any] struct {
|
||||||
key Tuple
|
key Tuple
|
||||||
value any
|
value Value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add adds a value to the cache, set or updating its associated
|
// Add adds a value to the cache, set or updating its associated
|
||||||
@ -54,17 +54,17 @@ type entry struct {
|
|||||||
//
|
//
|
||||||
// If MaxEntries is non-zero and the length of the cache is greater
|
// If MaxEntries is non-zero and the length of the cache is greater
|
||||||
// after any addition, the least recently used value is evicted.
|
// after any addition, the least recently used value is evicted.
|
||||||
func (c *Cache) Add(key Tuple, value any) {
|
func (c *Cache[Value]) Add(key Tuple, value Value) {
|
||||||
if c.m == nil {
|
if c.m == nil {
|
||||||
c.m = make(map[Tuple]*list.Element)
|
c.m = make(map[Tuple]*list.Element)
|
||||||
c.ll = list.New()
|
c.ll = list.New()
|
||||||
}
|
}
|
||||||
if ee, ok := c.m[key]; ok {
|
if ee, ok := c.m[key]; ok {
|
||||||
c.ll.MoveToFront(ee)
|
c.ll.MoveToFront(ee)
|
||||||
ee.Value.(*entry).value = value
|
ee.Value.(*entry[Value]).value = value
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ele := c.ll.PushFront(&entry{key, value})
|
ele := c.ll.PushFront(&entry[Value]{key, value})
|
||||||
c.m[key] = ele
|
c.m[key] = ele
|
||||||
if c.MaxEntries != 0 && c.Len() > c.MaxEntries {
|
if c.MaxEntries != 0 && c.Len() > c.MaxEntries {
|
||||||
c.RemoveOldest()
|
c.RemoveOldest()
|
||||||
@ -73,23 +73,23 @@ func (c *Cache) Add(key Tuple, value any) {
|
|||||||
|
|
||||||
// Get looks up a key's value from the cache, also reporting
|
// Get looks up a key's value from the cache, also reporting
|
||||||
// whether it was present.
|
// whether it was present.
|
||||||
func (c *Cache) Get(key Tuple) (value any, ok bool) {
|
func (c *Cache[Value]) Get(key Tuple) (value *Value, ok bool) {
|
||||||
if ele, hit := c.m[key]; hit {
|
if ele, hit := c.m[key]; hit {
|
||||||
c.ll.MoveToFront(ele)
|
c.ll.MoveToFront(ele)
|
||||||
return ele.Value.(*entry).value, true
|
return &ele.Value.(*entry[Value]).value, true
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove removes the provided key from the cache if it was present.
|
// Remove removes the provided key from the cache if it was present.
|
||||||
func (c *Cache) Remove(key Tuple) {
|
func (c *Cache[Value]) Remove(key Tuple) {
|
||||||
if ele, hit := c.m[key]; hit {
|
if ele, hit := c.m[key]; hit {
|
||||||
c.removeElement(ele)
|
c.removeElement(ele)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOldest removes the oldest item from the cache, if any.
|
// RemoveOldest removes the oldest item from the cache, if any.
|
||||||
func (c *Cache) RemoveOldest() {
|
func (c *Cache[Value]) RemoveOldest() {
|
||||||
if c.ll != nil {
|
if c.ll != nil {
|
||||||
if ele := c.ll.Back(); ele != nil {
|
if ele := c.ll.Back(); ele != nil {
|
||||||
c.removeElement(ele)
|
c.removeElement(ele)
|
||||||
@ -97,10 +97,10 @@ func (c *Cache) RemoveOldest() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) removeElement(e *list.Element) {
|
func (c *Cache[Value]) removeElement(e *list.Element) {
|
||||||
c.ll.Remove(e)
|
c.ll.Remove(e)
|
||||||
delete(c.m, e.Value.(*entry).key)
|
delete(c.m, e.Value.(*entry[Value]).key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len returns the number of items in the cache.
|
// Len returns the number of items in the cache.
|
||||||
func (c *Cache) Len() int { return len(c.m) }
|
func (c *Cache[Value]) Len() int { return len(c.m) }
|
||||||
|
@ -12,7 +12,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCache(t *testing.T) {
|
func TestCache(t *testing.T) {
|
||||||
c := &Cache{MaxEntries: 2}
|
c := &Cache[int]{MaxEntries: 2}
|
||||||
|
|
||||||
k1 := Tuple{Src: netip.MustParseAddrPort("1.1.1.1:1"), Dst: netip.MustParseAddrPort("1.1.1.1:1")}
|
k1 := Tuple{Src: netip.MustParseAddrPort("1.1.1.1:1"), Dst: netip.MustParseAddrPort("1.1.1.1:1")}
|
||||||
k2 := Tuple{Src: netip.MustParseAddrPort("1.1.1.1:1"), Dst: netip.MustParseAddrPort("2.2.2.2:2")}
|
k2 := Tuple{Src: netip.MustParseAddrPort("1.1.1.1:1"), Dst: netip.MustParseAddrPort("2.2.2.2:2")}
|
||||||
@ -25,13 +25,13 @@ func TestCache(t *testing.T) {
|
|||||||
t.Fatalf("Len = %d; want %d", got, want)
|
t.Fatalf("Len = %d; want %d", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
wantVal := func(key Tuple, want any) {
|
wantVal := func(key Tuple, want int) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
got, ok := c.Get(key)
|
got, ok := c.Get(key)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("Get(%q) failed; want value %v", key, want)
|
t.Fatalf("Get(%q) failed; want value %v", key, want)
|
||||||
}
|
}
|
||||||
if got != want {
|
if *got != want {
|
||||||
t.Fatalf("Get(%q) = %v; want %v", key, got, want)
|
t.Fatalf("Get(%q) = %v; want %v", key, got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -73,7 +73,7 @@ func TestCache(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("missing k3")
|
t.Fatal("missing k3")
|
||||||
}
|
}
|
||||||
if got != 30 {
|
if *got != 30 {
|
||||||
t.Fatalf("got = %d; want 30", got)
|
t.Fatalf("got = %d; want 30", got)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -59,7 +59,7 @@ type Filter struct {
|
|||||||
// filterState is a state cache of past seen packets.
|
// filterState is a state cache of past seen packets.
|
||||||
type filterState struct {
|
type filterState struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
lru *flowtrack.Cache // from flowtrack.Tuple -> nil
|
lru *flowtrack.Cache[struct{}] // from flowtrack.Tuple -> struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// lruMax is the size of the LRU cache in filterState.
|
// lruMax is the size of the LRU cache in filterState.
|
||||||
@ -176,7 +176,7 @@ func New(matches []Match, localNets *netipx.IPSet, logIPs *netipx.IPSet, shareSt
|
|||||||
state = shareStateWith.state
|
state = shareStateWith.state
|
||||||
} else {
|
} else {
|
||||||
state = &filterState{
|
state = &filterState{
|
||||||
lru: &flowtrack.Cache{MaxEntries: lruMax},
|
lru: &flowtrack.Cache[struct{}]{MaxEntries: lruMax},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
f := &Filter{
|
f := &Filter{
|
||||||
@ -517,7 +517,7 @@ func (f *Filter) runOut(q *packet.Parsed) (r Response, why string) {
|
|||||||
Src: q.Dst, Dst: q.Src, // src/dst reversed
|
Src: q.Dst, Dst: q.Src, // src/dst reversed
|
||||||
}
|
}
|
||||||
f.state.mu.Lock()
|
f.state.mu.Lock()
|
||||||
f.state.lru.Add(tuple, nil)
|
f.state.lru.Add(tuple, struct{}{})
|
||||||
f.state.mu.Unlock()
|
f.state.mu.Unlock()
|
||||||
}
|
}
|
||||||
return Accept, "ok out"
|
return Accept, "ok out"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user