tailscale/util/workgraph/workgraph_test.go
Andrew Dunham 8290d287d0 util/workgraph: add package for concurrent execution of DAGs
This package is intended to be used for cleaning up the mess of
concurrent things happening in the netcheck package.

Updates #10972

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I609b61a7838b84a74b74bdef66d1d4c4014705e1
2024-06-25 15:58:17 -04:00

354 lines
8.3 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package workgraph
import (
"context"
"errors"
"fmt"
"runtime"
"sync/atomic"
"testing"
"tailscale.com/util/must"
"tailscale.com/util/set"
)
// debugGraph snapshots the Graphviz rendering of g immediately, and registers
// a cleanup hook that, only if the test has failed, logs both that initial
// snapshot and a fresh rendering taken at cleanup time.
func debugGraph(tb testing.TB, g *WorkGraph) {
	initial := g.Graphviz()

	dumpOnFailure := func() {
		if tb.Failed() {
			// Render the final state before logging anything, so both
			// snapshots are in hand when the output is produced.
			final := g.Graphviz()
			tb.Logf("graphviz at start of test:\n%s", initial)
			tb.Logf("graphviz at end of test:\n%s", final)
		}
	}
	tb.Cleanup(dumpOnFailure)
}
// makeTestGraph builds the six-node graph shared by several tests:
//
//	one, two      (no dependencies)
//	three -> one
//	four  -> one, two
//	five  -> one
//	six   -> four, five
//
// Every node's function logs "<name> called" through tb and returns nil.
// Any error from AddNode is passed to must.Do.
func makeTestGraph(tb testing.TB) *WorkGraph {
	// newLoggingNode creates a node whose only effect is to log that it ran.
	newLoggingNode := func(name string) Node {
		msg := name + " called"
		return NodeFunc(name, func(_ context.Context) error {
			tb.Log(msg)
			return nil
		})
	}

	g := NewWorkGraph()

	// Use the current GOMAXPROCS setting for Concurrency, but never less
	// than 2.
	workers := runtime.GOMAXPROCS(-1)
	if workers < 2 {
		workers = 2
	}
	g.Concurrency = workers

	// Nodes are registered in this order; every dependency is listed before
	// the nodes that name it.
	layout := []struct {
		id   string
		deps []string
	}{
		{id: "one"}, // can execute first
		{id: "two"}, // can execute first
		{id: "three", deps: []string{"one"}},
		{id: "four", deps: []string{"one", "two"}},
		{id: "five", deps: []string{"one"}},
		{id: "six", deps: []string{"four", "five"}},
	}
	for _, entry := range layout {
		// Nodes without dependencies are added with nil options.
		var opts *AddNodeOpts
		if len(entry.deps) > 0 {
			opts = &AddNodeOpts{Dependencies: entry.deps}
		}
		must.Do(g.AddNode(newLoggingNode(entry.id), opts))
	}
	return g
}
// TestWorkGraph runs the shared test graph to completion and fails if Run
// reports an error.
func TestWorkGraph(t *testing.T) {
	graph := makeTestGraph(t)
	debugGraph(t, graph)

	err := graph.Run(context.Background())
	if err != nil {
		t.Fatal(err)
	}
}
// TestWorkGroup_Error checks that an error returned by a node is reported by
// Run (matching under errors.Is). The nodes that depend on the failing node
// panic if they are ever invoked.
func TestWorkGroup_Error(t *testing.T) {
	terr := errors.New("test error")
	fails := func(_ context.Context) error { return terr }
	mustNotRun := func(_ context.Context) error { panic("unused") }

	g := NewWorkGraph()
	must.Do(g.AddNode(NodeFunc("one", fails), nil))
	must.Do(g.AddNode(
		NodeFunc("two", mustNotRun),
		&AddNodeOpts{Dependencies: []string{"one"}},
	))
	must.Do(g.AddNode(
		NodeFunc("three", mustNotRun),
		&AddNodeOpts{Dependencies: []string{"one", "two"}},
	))

	switch err := g.Run(context.Background()); {
	case err == nil:
		t.Fatal("wanted non-nil error")
	case !errors.Is(err, terr):
		t.Errorf("got %v, want %v", err, terr)
	}
}
// TestWorkGroup_HandlesPanic checks that when a node panics with an error
// value, Run returns a non-nil error that matches that value under errors.Is.
func TestWorkGroup_HandlesPanic(t *testing.T) {
	terr := errors.New("test error")
	panics := func(_ context.Context) error { panic(terr) }

	g := NewWorkGraph()
	must.Do(g.AddNode(NodeFunc("one", panics), nil))

	switch err := g.Run(context.Background()); {
	case err == nil:
		t.Fatal("wanted non-nil error")
	case !errors.Is(err, terr):
		t.Errorf("got %v, want %v", err, terr)
	}
}
// TestWorkGroup_Cancellation checks that cancelling the context passed to Run
// while node functions are blocked causes Run to return an error matching
// context.Canceled under errors.Is.
func TestWorkGroup_Cancellation(t *testing.T) {
	g := NewWorkGraph()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// running counts how many node functions have started executing.
	var running atomic.Int64
	blocks := func(ctx context.Context) error {
		running.Add(1)
		<-ctx.Done()
		return ctx.Err()
	}

	n1 := NodeFunc("one", blocks)
	n2 := NodeFunc("two", blocks)
	n3 := NodeFunc("three", blocks)

	must.Do(g.AddNode(n1, nil))
	must.Do(g.AddNode(n2, nil))

	// Ensure that we have a node with dependencies that's also waiting
	// since we want to verify that the queue publisher also properly
	// handles context cancellation.
	must.Do(g.AddNode(n3, &AddNodeOpts{Dependencies: []string{"one", "two"}}))

	// Call Run in a goroutine since it blocks. The channel is buffered so
	// the goroutine can always deliver its result and exit.
	errCh := make(chan error, 1)
	go func() {
		errCh <- g.Run(ctx)
	}()

	// Wait until both dependency-free nodes are running, then cancel the
	// context to unblock them. While waiting, yield the processor instead
	// of spinning hot, and watch errCh so that an unexpected early return
	// from Run fails the test immediately rather than looping until the
	// global test timeout.
	for running.Load() != 2 {
		select {
		case err := <-errCh:
			t.Fatalf("Run returned before cancellation: %v", err)
		default:
			runtime.Gosched()
		}
	}
	cancel()

	err := <-errCh
	if err == nil {
		t.Fatal("wanted non-nil error")
	}
	if !errors.Is(err, context.Canceled) {
		t.Errorf("got %v, want %v", err, context.Canceled)
	}
}
// TestTopoSortDFS runs topoSortDFS on the shared test graph, logs the
// resulting order, and validates it with validateTopologicalSortDFS.
func TestTopoSortDFS(t *testing.T) {
	graph := makeTestGraph(t)
	debugGraph(t, graph)

	order, err := graph.topoSortDFS()
	if err != nil {
		t.Fatal(err)
	}

	t.Logf("DFS topological sort: %v", order)
	validateTopologicalSortDFS(t, graph, order)
}
// validateTopologicalSortDFS reports a test error for every pair of entries
// in order where g.edges.Exists(later, earlier) is true, i.e. where the graph
// has an edge from an entry to one that precedes it in the list.
//
// NOTE(review): this presumably means edges run from a dependency to the node
// that depends on it; confirm against the edges implementation.
func validateTopologicalSortDFS(tb testing.TB, g *WorkGraph, order []string) {
	for pos, later := range order {
		// Compare this entry against everything that precedes it.
		for prevPos, earlier := range order[:pos] {
			if !g.edges.Exists(later, earlier) {
				continue
			}
			tb.Errorf("invalid edge: %v [%d] -> %v [%d]", later, pos, earlier, prevPos)
		}
	}
}
// TestTopoSortKahn runs topoSortKahn on the shared test graph, logs the
// resulting groups, and validates them with validateTopologicalSortKahn.
func TestTopoSortKahn(t *testing.T) {
	graph := makeTestGraph(t)
	debugGraph(t, graph)

	layers, err := graph.topoSortKahn()
	if err != nil {
		t.Fatal(err)
	}

	t.Logf("grouped topological sort: %v", layers)
	validateTopologicalSortKahn(t, graph, layers)
}
// validateTopologicalSortKahn checks a grouped topological ordering of g.
//
// It walks the groups in order and reports a test error whenever
// g.edges.Exists(node, m) is true for a node m that was visited before node
// (in an earlier group, or earlier in the iteration of the same group). It
// then reports an error for every key of g.nodes that appears in no group.
func validateTopologicalSortKahn(tb testing.TB, g *WorkGraph, groups []set.Set[string]) {
	visited := make(map[string]bool)

	for groupIdx, members := range groups {
		for current := range members {
			for earlier := range visited {
				if !g.edges.Exists(current, earlier) {
					continue
				}
				tb.Errorf("group[%d]: invalid edge: %v -> %v", groupIdx, current, earlier)
			}
			// Record the node only after it has been compared against
			// everything visited before it.
			visited[current] = true
		}
	}

	// Verify that our topologically sorted groups contain all nodes.
	for id := range g.nodes {
		if visited[id] {
			continue
		}
		tb.Errorf("topological sort missing node %v", id)
	}
}
// FuzzTopSortKahn fuzzes topoSortKahn.
//
// A fuzz function cannot take a map[string][]string (or similar), so a graph
// is described by a node count plus a flat byte slice read as consecutive
// (node, dependency) pairs. createGraphFromFuzzInput filters that raw input
// down to edges that can legally be added, since the fuzzer cannot tell
// meaningless input apart from input that exposes a real bug; when it returns
// nil the input is skipped.
func FuzzTopSortKahn(f *testing.F) {
	const seedNodes = 10 // number of nodes
	seedEdges := []byte{
		1, 0, // 1 depends on 0
		6, 2, // 6 depends on 2
		9, 8, // 9 depends on 8
	}
	f.Add(seedNodes, seedEdges)

	f.Fuzz(func(t *testing.T, numNodes int, edges []byte) {
		graph := createGraphFromFuzzInput(t, numNodes, edges)
		if graph == nil {
			return
		}

		// Sorting a graph built from filtered input should not error.
		layers, err := graph.topoSortKahn()
		if err != nil {
			t.Fatal(err)
		}
		validateTopologicalSortKahn(t, graph, layers)
	})
}
// FuzzTopSortDFS fuzzes topoSortDFS.
//
// A fuzz function cannot take a map[string][]string (or similar), so a graph
// is described by a node count plus a flat byte slice read as consecutive
// (node, dependency) pairs. createGraphFromFuzzInput filters that raw input
// down to edges that can legally be added, since the fuzzer cannot tell
// meaningless input apart from input that exposes a real bug; when it returns
// nil the input is skipped.
func FuzzTopSortDFS(f *testing.F) {
	const seedNodes = 10 // number of nodes
	seedEdges := []byte{
		1, 0, // 1 depends on 0
		6, 2, // 6 depends on 2
		9, 8, // 9 depends on 8
	}
	f.Add(seedNodes, seedEdges)

	f.Fuzz(func(t *testing.T, numNodes int, edges []byte) {
		graph := createGraphFromFuzzInput(t, numNodes, edges)
		if graph == nil {
			return
		}

		// Sorting a graph built from filtered input should not error.
		order, err := graph.topoSortDFS()
		if err != nil {
			t.Fatal(err)
		}
		validateTopologicalSortDFS(t, graph, order)
	})
}
// createGraphFromFuzzInput turns raw fuzz input into a WorkGraph.
//
// numNodes is the number of nodes to create, named "node-0" through
// "node-<numNodes-1>". edges is read as consecutive (node, dependency) byte
// pairs. Pairs that reference a node index >= numNodes, that are
// self-referential, or whose dependency index is greater than the node index
// are dropped.
//
// It returns nil when the input is unusable: numNodes outside [1, 1000], an
// odd number of edge bytes, or no pair surviving the filter. It also returns
// nil when tb has failed by the time all nodes were added; every AddNode error
// is reported through tb.Error.
func createGraphFromFuzzInput(tb testing.TB, numNodes int, edges []byte) *WorkGraph {
	// Constrain the number of nodes.
	if numNodes <= 0 || numNodes > 1000 {
		return nil
	}
	// Edges must come as complete (from, to) pairs.
	if len(edges)%2 != 0 {
		return nil
	}

	label := func(idx int) string {
		return fmt.Sprintf("node-%d", idx)
	}

	// Convert the list of byte pairs into per-node dependency lists.
	depsByNode := make(map[string][]string)
	for off := 0; off < len(edges); off += 2 {
		from, to := int(edges[off]), int(edges[off+1])

		// Nodes are added in increasing index order (0, 1, 2, ...), so a
		// dependency must have a strictly smaller index than the node that
		// names it. Requiring from < numNodes and to < from therefore also
		// rejects out-of-range dependencies, self-edges, and edges that
		// point 'forward'.
		if from >= numNodes || to >= from {
			continue
		}

		key := label(from)
		depsByNode[key] = append(depsByNode[key], label(to))
	}
	if len(depsByNode) == 0 {
		return nil
	}

	// Actually create the graph.
	graph := NewWorkGraph()
	noop := func(context.Context) error { return nil }
	for idx := 0; idx < numNodes; idx++ {
		id := label(idx)
		opts := &AddNodeOpts{Dependencies: depsByNode[id]}
		err := graph.AddNode(NodeFunc(id, noop), opts)
		if err != nil {
			tb.Error(err) // shouldn't error after we filtered out bad edges above
		}
	}

	if tb.Failed() {
		return nil
	}
	return graph
}