tailscale/portlist/portlist_test.go
Marwan Sulaiman e32e5c0d0c portlist: add Poller.IncludeLocalhost option
This PR parameterizes receiving loopback updates from the portlist package.
Callers can now include services bound to localhost if they want.
Note that this option is off by default still.

Fixes #8171

Signed-off-by: Marwan Sulaiman <marwan@tailscale.com>
2023-05-24 13:26:16 -04:00

263 lines
5.2 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package portlist
import (
"context"
"net"
"sync"
"testing"
"time"
"tailscale.com/tstest"
)
// TestGetList fetches the current port list once from a zero-value Poller
// and logs every entry along with the list's String form. It fails only if
// getList returns an error.
func TestGetList(t *testing.T) {
	tstest.ResourceCheck(t)

	var poller Poller
	list, err := poller.getList()
	if err != nil {
		t.Fatal(err)
	}
	for idx, port := range list {
		t.Logf("[%d] %+v", idx, port)
	}
	t.Logf("As String: %v", list.String())
}
// TestIgnoreLocallyBoundPorts binds a TCP listener to 127.0.0.1 and checks
// that a zero-value Poller (IncludeLocalhost unset) does not report that
// port in its list.
func TestIgnoreLocallyBoundPorts(t *testing.T) {
	tstest.ResourceCheck(t)

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Skipf("failed to bind: %v", err)
	}
	defer ln.Close()
	want := ln.Addr().(*net.TCPAddr).Port

	var poller Poller
	list, err := poller.getList()
	if err != nil {
		t.Fatal(err)
	}
	for _, entry := range list {
		if entry.Proto != "tcp" || int(entry.Port) != want {
			continue
		}
		t.Fatal("didn't expect to find test's localhost ephemeral port")
	}
}
// TestChangesOverTime takes three snapshots from a Poller with
// IncludeLocalhost enabled. The first is taken before a loopback TCP
// listener is opened, the second while it is open, and the third after it
// is closed. The listener's port must appear only in the second snapshot.
func TestChangesOverTime(t *testing.T) {
	var poller Poller
	poller.IncludeLocalhost = true

	// snapshot returns a private copy of the current list. NOTE(review):
	// presumably copied because getList may reuse its backing storage
	// across calls; confirm against getList.
	snapshot := func(t *testing.T) []Port {
		t.Helper()
		list, err := poller.getList()
		if err != nil {
			t.Fatal(err)
		}
		return append([]Port(nil), list...)
	}

	before := snapshot(t)

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Skipf("failed to bind: %v", err)
	}
	defer ln.Close()
	port := uint16(ln.Addr().(*net.TCPAddr).Port)

	// hasPort reports whether list contains the listener's TCP port.
	hasPort := func(list List) bool {
		for _, entry := range list {
			if entry.Proto == "tcp" && entry.Port == port {
				return true
			}
		}
		return false
	}

	if hasPort(before) {
		t.Error("unexpectedly found ephemeral port in p1, before it was opened", port)
	}

	during := snapshot(t)
	if !hasPort(during) {
		t.Error("didn't find ephemeral port in p2", port)
	}

	ln.Close()
	after := snapshot(t)
	if hasPort(after) {
		t.Error("unexpectedly found ephemeral port in p3, after it was closed", port)
	}
}
// TestEqualLessThan checks Port.lessThan against a table of expected
// orderings. For each case it also verifies two properties:
//   - lessThan is antisymmetric: a and b are never both less than each other.
//   - Port.equal agrees with it: two ports are equal exactly when neither is
//     less than the other.
//
// As the table shows, the comparison order is Port first, then Proto, then
// Process.
func TestEqualLessThan(t *testing.T) {
	tests := []struct {
		name string
		a, b Port
		want bool
	}{
		{
			"Port a < b",
			Port{Proto: "tcp", Port: 100, Process: "proc1"},
			Port{Proto: "tcp", Port: 101, Process: "proc1"},
			true,
		},
		{
			"Port a > b",
			Port{Proto: "tcp", Port: 101, Process: "proc1"},
			Port{Proto: "tcp", Port: 100, Process: "proc1"},
			false,
		},
		{
			"Proto a < b",
			Port{Proto: "tcp", Port: 100, Process: "proc1"},
			Port{Proto: "udp", Port: 100, Process: "proc1"},
			true,
		},
		{
			"Proto a > b",
			Port{Proto: "udp", Port: 100, Process: "proc1"},
			Port{Proto: "tcp", Port: 100, Process: "proc1"},
			false,
		},
		{
			"Process a < b",
			Port{Proto: "tcp", Port: 100, Process: "proc1"},
			Port{Proto: "tcp", Port: 100, Process: "proc2"},
			true,
		},
		{
			"Process a > b",
			Port{Proto: "tcp", Port: 100, Process: "proc2"},
			Port{Proto: "tcp", Port: 100, Process: "proc1"},
			false,
		},
		{
			// A lower Port wins even though Proto and Process are greater.
			"Port evaluated first",
			Port{Proto: "udp", Port: 100, Process: "proc2"},
			Port{Proto: "tcp", Port: 101, Process: "proc1"},
			true,
		},
		{
			// With equal Ports, a lower Proto wins even though Process is greater.
			"Proto evaluated second",
			Port{Proto: "tcp", Port: 100, Process: "proc2"},
			Port{Proto: "udp", Port: 100, Process: "proc1"},
			true,
		},
		{
			// Process only decides when Port and Proto are equal.
			"Process evaluated third",
			Port{Proto: "tcp", Port: 100, Process: "proc1"},
			Port{Proto: "tcp", Port: 100, Process: "proc2"},
			true,
		},
		{
			"equal",
			Port{Proto: "tcp", Port: 100, Process: "proc1"},
			Port{Proto: "tcp", Port: 100, Process: "proc1"},
			false,
		},
	}
	for _, tt := range tests {
		got := tt.a.lessThan(&tt.b)
		if got != tt.want {
			t.Errorf("%s: lessThan = %v; want %v", tt.name, got, tt.want)
		}
		lessBack := tt.b.lessThan(&tt.a)
		if got && lessBack {
			t.Errorf("%s: both a and b report being less than each other", tt.name)
		}
		// Neither being less than the other is the definition of equality
		// that Port.equal must agree with.
		wantEqual := !got && !lessBack
		gotEqual := tt.a.equal(&tt.b)
		if gotEqual != wantEqual {
			t.Errorf("%s: equal = %v; want %v", tt.name, gotEqual, wantEqual)
		}
	}
}
// TestPoller drives a Poller through runWithTickChan using a manually fed
// tick channel. It first does ten rounds, each opening a new listener,
// sending one tick and waiting for an update. It then sends a burst of ticks
// without waiting, to give the race detector more chances to catch
// unsynchronized access to the slices delivered on Updates.
func TestPoller(t *testing.T) {
	p, err := NewPoller()
	if err != nil {
		t.Skipf("not running test: %v", err)
	}

	var wg sync.WaitGroup
	// On every exit path, including t.Fatal below, close the poller and then
	// wait for both goroutines to finish. Without the wait, the runner
	// goroutine could call t.Error after TestPoller has returned. That
	// panics and hides the original failure.
	defer func() {
		p.Close()
		wg.Wait()
	}()

	wg.Add(2)
	gotUpdate := make(chan bool, 16)
	go func() {
		defer wg.Done()
		for pl := range p.Updates() {
			// Look at all the pl slice memory to maximize
			// chance of race detector seeing violations.
			for _, v := range pl {
				if v == (Port{}) {
					// Force use
					panic("empty port")
				}
			}
			// Non-blocking send: the test does not consume every update.
			select {
			case gotUpdate <- true:
			default:
			}
		}
	}()
	tick := make(chan time.Time, 16)
	go func() {
		defer wg.Done()
		if err := p.runWithTickChan(context.Background(), tick); err != nil {
			t.Error("runWithTickChan:", err)
		}
	}()

	for i := 0; i < 10; i++ {
		// Bind on all interfaces rather than 127.0.0.1. NOTE(review):
		// presumably so the port is not filtered out as localhost-only,
		// since IncludeLocalhost is not set here. Confirm against NewPoller.
		ln, err := net.Listen("tcp", ":0")
		if err != nil {
			t.Fatal(err)
		}
		// Deferred on purpose rather than closed per iteration: every
		// listener stays open until the test returns.
		defer ln.Close()
		tick <- time.Time{}
		select {
		case <-gotUpdate:
		case <-time.After(5 * time.Second):
			t.Fatal("timed out waiting for update")
		}
	}
	// And a bunch of ticks without waiting for updates,
	// to make race tests more likely to fail, if any present.
	for i := 0; i < 10; i++ {
		tick <- time.Time{}
	}
	if err := p.Close(); err != nil {
		t.Fatal(err)
	}
	wg.Wait()
}
// BenchmarkGetList measures getList on a Poller that never has a previous
// result stored between iterations.
func BenchmarkGetList(b *testing.B) {
	const incremental = false
	benchmarkGetList(b, incremental)
}
// BenchmarkGetListIncremental measures getList on a Poller whose previous
// result is stored after each iteration.
func BenchmarkGetListIncremental(b *testing.B) {
	const incremental = true
	benchmarkGetList(b, incremental)
}
// benchmarkGetList calls getList b.N times on a single zero-value Poller,
// reporting allocations. When incremental is true, each result is stored in
// the Poller's prev field before the next call; otherwise prev is never set.
func benchmarkGetList(b *testing.B, incremental bool) {
	b.ReportAllocs()

	var poller Poller
	for n := 0; n < b.N; n++ {
		list, err := poller.getList()
		if err != nil {
			b.Fatal(err)
		}
		if !incremental {
			continue
		}
		poller.prev = list
	}
}