net/art: make each strideTable track the IP prefix it represents

This is a prerequisite for path compression, so that insert/delete
can determine when compression occurred.

Updates #7781

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2023-04-06 11:20:33 -07:00
committed by Dave Anderson
parent 45b5d0983c
commit 486195edf0
3 changed files with 59 additions and 8 deletions

View File

@@ -18,17 +18,27 @@ import (
"io"
"net/netip"
"strings"
"sync"
)
// Table is an IPv4 and IPv6 routing table.
type Table[T any] struct {
v4 strideTable[T]
v6 strideTable[T]
v4 strideTable[T]
v6 strideTable[T]
initOnce sync.Once
}
func (t *Table[T]) init() {
t.initOnce.Do(func() {
t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
})
}
// Get does a route lookup for addr and returns the associated value, or nil if
// no route matched.
func (t *Table[T]) Get(addr netip.Addr) *T {
t.init()
st := &t.v4
if addr.Is6() {
st = &t.v6
@@ -58,6 +68,7 @@ func (t *Table[T]) Get(addr netip.Addr) *T {
// Insert adds pfx to the table, with value val.
// If pfx is already present in the table, its value is set to val.
func (t *Table[T]) Insert(pfx netip.Prefix, val *T) {
t.init()
if val == nil {
panic("Table.Insert called with nil value")
}
@@ -85,6 +96,7 @@ func (t *Table[T]) Insert(pfx netip.Prefix, val *T) {
// Delete removes pfx from the table, if it is present.
func (t *Table[T]) Delete(pfx netip.Prefix) {
t.init()
st := &t.v4
if pfx.Addr().Is6() {
st = &t.v6
@@ -141,6 +153,7 @@ func (t *Table[T]) Delete(pfx netip.Prefix) {
// debugSummary prints the tree of allocated strideTables in t, with each
// strideTable's refcount.
func (t *Table[T]) debugSummary() string {
t.init()
var ret bytes.Buffer
fmt.Fprintf(&ret, "v4: ")
strideSummary(&ret, &t.v4, 0)
@@ -150,7 +163,7 @@ func (t *Table[T]) debugSummary() string {
}
func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) {
fmt.Fprintf(w, "%d refs\n", st.refs)
fmt.Fprintf(w, "%s: %d refs\n", st.prefix, st.refs)
indent += 2
for i := firstHostIndex; i <= lastHostIndex; i++ {
if child := st.entries[i].child; child != nil {