util/workgraph: add package for concurrent execution of DAGs

This package is intended to be used for cleaning up the mess of
concurrent things happening in the netcheck package.

Updates #10972

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I609b61a7838b84a74b74bdef66d1d4c4014705e1
This commit is contained in:
Andrew Dunham 2024-01-29 22:01:05 -05:00
parent d7a4f9d31c
commit 8290d287d0
2 changed files with 803 additions and 0 deletions

450
util/workgraph/workgraph.go Normal file
View File

@ -0,0 +1,450 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package workgraph contains a "workgraph": a data structure that allows
// defining individual jobs and the dependencies between them, and then
// executing all jobs to completion.
package workgraph
import (
"context"
"errors"
"fmt"
"runtime"
"slices"
"strings"
"sync"
"tailscale.com/util/set"
)
// ErrCyclic is returned when there is a cycle in the graph. Both topological
// sorts in this package report it, and WorkGraph.Run returns it before running
// any Node when its sort fails.
var ErrCyclic = errors.New("graph is cyclic")
// Node is the interface that must be implemented by a node in a WorkGraph.
type Node interface {
// ID should return a unique ID for this node. IDs for each Node in a
// WorkGraph must be unique; WorkGraph.AddNode rejects a Node whose ID
// has already been added.
ID() string
// Run is called when this node in a WorkGraph is executed; it should
// return an error if execution fails, which will cause all dependent
// Nodes to fail to execute.
//
// Run is invoked from a worker goroutine and may execute concurrently
// with the Run method of other Nodes. A panic raised by Run is
// recovered by the WorkGraph and reported as an error.
Run(context.Context) error
}
// nodeFunc is a Node implementation backed by a fixed ID and a plain
// function.
type nodeFunc struct {
	id  string                      // value reported by ID
	run func(context.Context) error // function invoked by Run
}

// ID returns the ID stored in this nodeFunc.
func (nf *nodeFunc) ID() string {
	return nf.id
}

// Run invokes the stored function with ctx and returns whatever it returns.
func (nf *nodeFunc) Run(ctx context.Context) error {
	fn := nf.run
	return fn(ctx)
}
// NodeFunc adapts a plain function into a Node: the returned Node reports id
// from its ID method and calls fn whenever its Run method is called.
func NodeFunc(id string, fn func(context.Context) error) Node {
	adapter := &nodeFunc{
		id:  id,
		run: fn,
	}
	return adapter
}
// WorkGraph is a directed acyclic graph of individual jobs to be executed,
// each of which may have dependencies on other jobs. It supports adding a job
// as a Node—a combination of a unique ID and the function to execute that
// job—and then running all added Nodes while respecting dependencies.
type WorkGraph struct {
nodes map[string]Node // keyed by Node.ID
edges edgeList[string] // keyed by Node.ID; an edge A -> B means A must run before B
// Concurrency is the number of concurrent goroutines to use to process
// jobs. If zero, runtime.GOMAXPROCS will be used.
//
// This field must not be modified after Run has been called.
Concurrency int
}
// NewWorkGraph returns a WorkGraph that contains no Nodes and no edges, with
// its internal maps initialized so that AddNode can be called immediately.
func NewWorkGraph() *WorkGraph {
	return &WorkGraph{
		nodes: make(map[string]Node),
		edges: newEdgeList[string](),
	}
}
// AddNodeOpts contains options that can be passed to AddNode. A nil
// *AddNodeOpts is accepted by AddNode and means "no options".
type AddNodeOpts struct {
// Dependencies are any Node IDs that must be completed before this
// Node is started. Every ID listed here must already have been added
// to the WorkGraph, or AddNode returns an error.
Dependencies []string
}
// AddNode adds a new Node to the WorkGraph with the provided options; opts
// may be nil.
//
// It returns an error, and leaves the WorkGraph unmodified, if n is nil, if
// the given Node.ID was already added to the WorkGraph, or if any of the
// dependencies in opts has not already been added. Because dependencies must
// already exist, a Node cannot depend on itself.
func (g *WorkGraph) AddNode(n Node, opts *AddNodeOpts) error {
	if n == nil {
		return errors.New("node is nil")
	}

	id := n.ID()
	if _, found := g.nodes[id]; found {
		return fmt.Errorf("node %q already exists", id)
	}

	var deps []string
	if opts != nil {
		deps = opts.Dependencies
	}

	// Validate every dependency before mutating anything, so that a failed
	// call never leaves a half-registered node (or a partial set of edges)
	// behind. The node itself is not in g.nodes yet, so a self-dependency
	// is rejected here as well instead of silently creating a cycle.
	for _, dep := range deps {
		if _, found := g.nodes[dep]; !found {
			return fmt.Errorf("dependency %q not found", dep)
		}
	}

	g.nodes[id] = n

	// Create an edge from each dependency pointing to this node, forcing
	// that dependency to be evaluated first.
	for _, dep := range deps {
		g.edges.Add(dep, id)
	}
	return nil
}
// queueEntry is a single unit of work handed from the publisher goroutine to
// a worker goroutine over the work queue.
type queueEntry struct {
id string // Node.ID of the node to run
done chan struct{} // closed by runEntry once the node has finished running
}
// Run will iterate through all Nodes in this WorkGraph, running them once all
// their dependencies have been satisfied, and returning any errors that occur.
//
// If the graph contains a cycle, ErrCyclic is returned and no Node is run.
// Otherwise the first error returned by (or panic recovered from) a Node is
// returned, and no further Nodes are started after it. If no Node fails but
// ctx is done by the time the workers have stopped, ctx's error is returned,
// since some Nodes may not have been run.
//
// Nodes are run by g.Concurrency worker goroutines; a value of zero or less
// selects runtime.GOMAXPROCS.
func (g *WorkGraph) Run(ctx context.Context) error {
	groups, err := g.topoSortKahn()
	if err != nil {
		return err
	}

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Create one goroutine that pushes jobs onto our queue...
	var wg sync.WaitGroup
	queue := make(chan queueEntry)
	publishCtx, publishCancel := context.WithCancel(ctx)
	defer publishCancel()

	wg.Add(1)
	go g.runPublisher(publishCtx, &wg, queue, groups)

	// Only the first error is kept; the channel has room for exactly one
	// value and later errors are dropped by the non-blocking send.
	firstErr := make(chan error, 1)
	saveErr := func(err error) {
		if err == nil {
			return
		}

		// Tell the publisher to shut down
		publishCancel()

		select {
		case firstErr <- err:
		default:
		}
	}

	// ... and N goroutines that each work on an item from the queue.
	//
	// A non-positive worker count would either start no workers or make
	// the WaitGroup counter negative (which panics), so treat it the same
	// as "unset".
	n := g.Concurrency
	if n <= 0 {
		n = runtime.GOMAXPROCS(0)
	}
	wg.Add(n)
	for i := 0; i < n; i++ {
		go g.runWorker(ctx, &wg, queue, saveErr)
	}

	wg.Wait()

	select {
	case err := <-firstErr:
		return err
	default:
	}

	// No node reported a failure. Our own cancel func has not run yet, so
	// a non-nil error here means the caller's context ended; the publisher
	// and workers stop early in that case, so report it rather than
	// claiming success.
	return ctx.Err()
}
// runPublisher feeds the work queue one parallel group at a time. Every node
// in a group is pushed onto queue, and the next group is only started once
// the done channel of every entry in the current group has been closed.
//
// It returns early if ctx is done. On return it closes queue and calls
// wg.Done.
func (g *WorkGraph) runPublisher(ctx context.Context, wg *sync.WaitGroup, queue chan queueEntry, groups []set.Set[string]) {
	defer wg.Done()
	defer close(queue)

	// pending holds the completion channels of the group currently in
	// flight; its backing storage is reused between groups.
	var pending []chan struct{}
	for _, batch := range groups {
		pending = pending[:0]

		// Hand every node of this group to the workers.
		for id := range batch {
			finished := make(chan struct{})
			pending = append(pending, finished)

			select {
			case queue <- queueEntry{id: id, done: finished}:
			case <-ctx.Done():
				return
			}
		}

		// Block until the whole group has completed before moving on.
		for i := range pending {
			select {
			case <-pending[i]:
			case <-ctx.Done():
				return
			}
		}
	}
}
// runWorker receives entries from queue and runs them one at a time until
// ctx is done, queue is closed, or an entry fails. A failure is passed to
// saveErr and ends this worker. It calls wg.Done on return.
func (g *WorkGraph) runWorker(ctx context.Context, wg *sync.WaitGroup, queue chan queueEntry, saveErr func(error)) {
	defer wg.Done()

	for {
		var (
			next queueEntry
			open bool
		)
		select {
		case <-ctx.Done():
			return
		case next, open = <-queue:
		}
		if !open {
			// The publisher closed the queue; nothing left to do.
			return
		}

		err := g.runEntry(ctx, next)
		if err == nil {
			continue
		}
		saveErr(err)
		return
	}
}
// runEntry runs the Node identified by ent.id and returns its error. A panic
// raised while running the Node is recovered and returned as an error that
// names the node. ent.done is always closed before runEntry returns.
func (g *WorkGraph) runEntry(ctx context.Context, ent queueEntry) (retErr error) {
	// Registered first so that it runs last, after the recover handler
	// below has had a chance to set retErr.
	defer close(ent.done)

	defer func() {
		r := recover()
		if r == nil {
			return
		}

		// Wrap an error value with %w so that errors.Is keeps working;
		// anything else is formatted with %v.
		if cause, isErr := r.(error); isErr {
			retErr = fmt.Errorf("node %q: caught panic: %w", ent.id, cause)
			return
		}
		retErr = fmt.Errorf("node %q: caught panic: %v", ent.id, r)
	}()

	return g.nodes[ent.id].Run(ctx)
}
// topoSortDFS performs a depth-first topological sort of the graph and
// returns the node IDs in an order where every node appears before the nodes
// it has an edge to. It returns ErrCyclic if the graph contains a cycle.
// Used in tests.
//
// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
func (g *WorkGraph) topoSortDFS() (sorted []string, err error) {
	type mark uint8
	const (
		unvisited  mark = iota // zero value: node not seen yet
		inProgress             // the "temporary mark" of the algorithm
		finished               // the "permanent mark" of the algorithm
	)
	state := make(map[string]mark) // keyed by Node.ID

	var visit func(id string) error
	visit = func(id string) error {
		switch state[id] {
		case finished:
			// Already fully processed; nothing to do.
			return nil
		case inProgress:
			// We reached a node that is still on the current DFS
			// path, so the graph has a cycle.
			return ErrCyclic
		}

		state[id] = inProgress
		for next := range g.edges.OutgoingNodes(id) {
			if err := visit(next); err != nil {
				return err
			}
		}
		state[id] = finished

		// The algorithm prepends to the result; we append here and
		// reverse once at the end instead.
		sorted = append(sorted, id)
		return nil
	}

	// Visit every node; nodes that were already reached through an
	// earlier visit return immediately.
	for id := range g.nodes {
		if err := visit(id); err != nil {
			return nil, err
		}
	}

	slices.Reverse(sorted)
	return sorted, nil
}
// topoSortKahn topologically sorts the graph using a variant of Kahn's
// algorithm that returns "groups" of nodes instead of a flat list. The nodes
// in one group have no remaining dependencies once all earlier groups have
// completed, so they can be executed concurrently. The first group (the nodes
// with no incoming edge) is always present, even when it is empty.
//
// It returns ErrCyclic if edges remain after no further nodes can be
// released, which means the graph contains a cycle.
//
// See:
//   - https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
//   - https://stackoverflow.com/a/67267597
func (g *WorkGraph) topoSortKahn() (sorted []set.Set[string], err error) {
	// Edges are consumed as nodes are released, so work on a copy.
	remaining := g.edges.Clone()

	// Start with every node that has no incoming edge.
	ready := make(set.Set[string])
	for id := range g.nodes {
		if remaining.HasIncoming(id) {
			continue
		}
		ready.Add(id)
	}
	sorted = append(sorted, ready)

	for {
		// Release the successors of the current group: drop each edge
		// leaving the group, and collect every successor that no
		// longer has any incoming edge.
		released := make(set.Set[string])
		for from := range ready {
			for to := range remaining.OutgoingNodes(from) {
				remaining.Remove(from, to)
				if remaining.HasIncoming(to) {
					continue
				}
				released.Add(to)
			}
		}

		// Nothing new became runnable, so we are done.
		if len(released) == 0 {
			break
		}
		sorted = append(sorted, released)
		ready = released
	}

	// Any edge still present could never be removed, i.e. it is part of
	// (or downstream of) a cycle.
	if remaining.Len() > 0 {
		return nil, ErrCyclic
	}
	return sorted, nil
}
// Graphviz returns a basic Graphviz (DOT) representation of the WorkGraph.
// This is primarily useful for debugging.
//
// Node IDs are emitted as quoted strings so that IDs containing characters
// that are not valid in a bare DOT identifier (such as "-") still produce
// parseable output. Nodes without any edges are emitted on their own line so
// that they are not missing from the rendering, and the output is sorted by
// node ID so that it is stable between calls.
func (g *WorkGraph) Graphviz() string {
	var buf strings.Builder
	buf.WriteString("digraph workgraph {\n")

	ids := make([]string, 0, len(g.nodes))
	for id := range g.nodes {
		ids = append(ids, id)
	}
	slices.Sort(ids)

	for _, from := range ids {
		successors := g.edges.outgoing[from]
		if len(successors) == 0 {
			// A node that appears in no edge at all must be listed
			// explicitly; one that is only a target shows up via the
			// edge that points to it.
			if !g.edges.HasIncoming(from) {
				fmt.Fprintf(&buf, "\t%q;\n", from)
			}
			continue
		}

		tos := make([]string, 0, len(successors))
		for to := range successors {
			tos = append(tos, to)
		}
		slices.Sort(tos)
		for _, to := range tos {
			fmt.Fprintf(&buf, "\t%q -> %q;\n", from, to)
		}
	}

	buf.WriteString("}")
	return buf.String()
}
// edgeList is a helper type that is used to maintain a set of edges, tracking
// both incoming and outgoing edges for a given node. Each edge is recorded in
// both maps. Use newEdgeList to create one with both maps initialized.
type edgeList[K comparable] struct {
incoming map[K]set.Set[K] // for edge A -> B, keyed by B
outgoing map[K]set.Set[K] // for edge A -> B, keyed by A
}
// newEdgeList returns an empty edgeList whose incoming and outgoing maps are
// both initialized and ready for use.
func newEdgeList[K comparable]() edgeList[K] {
	var el edgeList[K]
	el.incoming = make(map[K]set.Set[K])
	el.outgoing = make(map[K]set.Set[K])
	return el
}
// Clone returns a copy of el in which both maps and every set stored in them
// are cloned, so the copy can be modified without affecting el.
func (el *edgeList[K]) Clone() edgeList[K] {
	cloneSets := func(src map[K]set.Set[K]) map[K]set.Set[K] {
		dst := make(map[K]set.Set[K], len(src))
		for key, members := range src {
			dst[key] = members.Clone()
		}
		return dst
	}

	return edgeList[K]{
		incoming: cloneSets(el.incoming),
		outgoing: cloneSets(el.outgoing),
	}
}
// Len returns the number of edges in el, computed by summing the sizes of the
// per-node incoming sets.
func (el *edgeList[K]) Len() int {
	total := 0
	for _, sources := range el.incoming {
		total += sources.Len()
	}
	return total
}
// Add records the edge from -> to, creating the per-node sets on first use.
func (el *edgeList[K]) Add(from, to K) {
	sources, ok := el.incoming[to]
	if !ok {
		sources = make(set.Set[K])
		el.incoming[to] = sources
	}

	targets, ok := el.outgoing[from]
	if !ok {
		targets = make(set.Set[K])
		el.outgoing[from] = targets
	}

	sources.Add(from)
	targets.Add(to)
}
// Remove deletes the edge from -> to from both maps. It is a no-op if the
// edge does not exist. The per-node sets are kept even when they become
// empty.
func (el *edgeList[K]) Remove(from, to K) {
	// Deleting from a missing (nil) set is a no-op, so no existence check
	// is required.
	delete(el.incoming[to], from)
	delete(el.outgoing[from], to)
}
// HasIncoming reports whether at least one edge points to id.
func (el *edgeList[K]) HasIncoming(id K) bool {
	sources := el.incoming[id]
	return len(sources) != 0
}
// HasOutgoing reports whether at least one edge starts at id.
func (el *edgeList[K]) HasOutgoing(id K) bool {
	targets := el.outgoing[id]
	return len(targets) != 0
}
// Exists reports whether the edge from -> to is present.
func (el *edgeList[K]) Exists(from, to K) bool {
	targets, ok := el.outgoing[from]
	if !ok {
		return false
	}
	return targets.Contains(to)
}
// IncomingNodes returns the set of nodes that have an edge pointing to id.
// The returned set is the one stored inside el, not a copy, and is nil if no
// set has been created for id.
func (el *edgeList[K]) IncomingNodes(id K) set.Set[K] {
	sources := el.incoming[id]
	return sources
}
// OutgoingNodes returns the set of nodes that id has an edge to. The returned
// set is the one stored inside el, not a copy, and is nil if no set has been
// created for id.
func (el *edgeList[K]) OutgoingNodes(id K) set.Set[K] {
	targets := el.outgoing[id]
	return targets
}

View File

@ -0,0 +1,353 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package workgraph
import (
"context"
"errors"
"fmt"
"runtime"
"sync/atomic"
"testing"
"tailscale.com/util/must"
"tailscale.com/util/set"
)
// debugGraph captures the Graphviz output of g immediately and registers a
// cleanup function that, only if the test has failed, logs both that initial
// output and the output at cleanup time.
func debugGraph(tb testing.TB, g *WorkGraph) {
	initial := g.Graphviz()

	dump := func() {
		if tb.Failed() {
			final := g.Graphviz()
			tb.Logf("graphviz at start of test:\n%s", initial)
			tb.Logf("graphviz at end of test:\n%s", final)
		}
	}
	tb.Cleanup(dump)
}
// makeTestGraph builds the six-node graph shared by several tests. Each node
// only logs that it was called. "one" and "two" have no dependencies; the
// remaining nodes depend on earlier ones as listed in the table below.
func makeTestGraph(tb testing.TB) *WorkGraph {
	graph := NewWorkGraph()

	// Ensure we have at least 2 concurrent goroutines
	graph.Concurrency = runtime.GOMAXPROCS(-1)
	if graph.Concurrency < 2 {
		graph.Concurrency = 2
	}

	newNode := func(name string) Node {
		msg := name + " called"
		return NodeFunc(name, func(context.Context) error {
			tb.Log(msg)
			return nil
		})
	}

	// Order matters: a node can only be added after its dependencies.
	specs := []struct {
		id   string
		deps []string
	}{
		{id: "one"}, // can execute first
		{id: "two"}, // can execute first
		{id: "three", deps: []string{"one"}},
		{id: "four", deps: []string{"one", "two"}},
		{id: "five", deps: []string{"one"}},
		{id: "six", deps: []string{"four", "five"}},
	}
	for _, spec := range specs {
		var opts *AddNodeOpts
		if spec.deps != nil {
			opts = &AddNodeOpts{Dependencies: spec.deps}
		}
		must.Do(graph.AddNode(newNode(spec.id), opts))
	}
	return graph
}
// TestWorkGraph runs the shared test graph to completion and expects no
// error.
func TestWorkGraph(t *testing.T) {
	graph := makeTestGraph(t)
	debugGraph(t, graph)

	err := graph.Run(context.Background())
	if err != nil {
		t.Fatal(err)
	}
}
// TestWorkGroup_Error verifies that an error returned by a node is returned
// from Run, and that nodes depending on the failed node are never executed.
func TestWorkGroup_Error(t *testing.T) {
	g := NewWorkGraph()

	terr := errors.New("test error")
	returnsErr := func(_ context.Context) error { return terr }

	// A panic here would be recovered by the WorkGraph and hidden behind
	// the first error, so record an unexpected call directly on t instead.
	// This is safe because Run waits for all of its workers before
	// returning.
	notCalled := func(id string) func(context.Context) error {
		return func(_ context.Context) error {
			t.Errorf("node %q was run even though its dependency failed", id)
			return nil
		}
	}

	n1 := NodeFunc("one", returnsErr)
	n2 := NodeFunc("two", notCalled("two"))
	n3 := NodeFunc("three", notCalled("three"))

	must.Do(g.AddNode(n1, nil))
	must.Do(g.AddNode(n2, &AddNodeOpts{Dependencies: []string{"one"}}))
	must.Do(g.AddNode(n3, &AddNodeOpts{Dependencies: []string{"one", "two"}}))

	err := g.Run(context.Background())
	if err == nil {
		t.Fatal("wanted non-nil error")
	}
	if !errors.Is(err, terr) {
		t.Errorf("got %v, want %v", err, terr)
	}
}
// TestWorkGroup_HandlesPanic verifies that a node panicking with an error
// value makes Run return an error that wraps that value.
func TestWorkGroup_HandlesPanic(t *testing.T) {
	terr := errors.New("test error")
	panics := func(context.Context) error {
		panic(terr)
	}

	graph := NewWorkGraph()
	must.Do(graph.AddNode(NodeFunc("one", panics), nil))

	switch err := graph.Run(context.Background()); {
	case err == nil:
		t.Fatal("wanted non-nil error")
	case !errors.Is(err, terr):
		t.Errorf("got %v, want %v", err, terr)
	}
}
// TestWorkGroup_Cancellation verifies that cancelling the context passed to
// Run unblocks running nodes and the queue publisher, and that Run returns
// an error wrapping context.Canceled.
func TestWorkGroup_Cancellation(t *testing.T) {
	g := NewWorkGraph()

	// Both independent nodes block until cancellation, so we need two
	// workers for both of them to start. Relying on the default would leave
	// a single worker when GOMAXPROCS is 1, and the wait loop below would
	// never finish.
	g.Concurrency = 2

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var running atomic.Int64
	blocks := func(ctx context.Context) error {
		running.Add(1)
		<-ctx.Done()
		return ctx.Err()
	}

	n1 := NodeFunc("one", blocks)
	n2 := NodeFunc("two", blocks)
	n3 := NodeFunc("three", blocks)

	must.Do(g.AddNode(n1, nil))
	must.Do(g.AddNode(n2, nil))

	// Ensure that we have a node with dependencies that's also waiting
	// since we want to verify that the queue publisher also properly
	// handles context cancellation.
	must.Do(g.AddNode(n3, &AddNodeOpts{Dependencies: []string{"one", "two"}}))

	// call Run in a goroutine since it blocks
	errCh := make(chan error, 1)
	go func() {
		errCh <- g.Run(ctx)
	}()

	// after all goroutines are running, cancel the context to unblock;
	// yield while waiting so that we don't starve the workers.
	for running.Load() != 2 {
		runtime.Gosched()
	}
	cancel()

	err := <-errCh
	if err == nil {
		t.Fatal("wanted non-nil error")
	}
	if !errors.Is(err, context.Canceled) {
		t.Errorf("got %v, want %v", err, context.Canceled)
	}
}
// TestTopoSortDFS sorts the shared test graph with topoSortDFS and validates
// the resulting order.
func TestTopoSortDFS(t *testing.T) {
	graph := makeTestGraph(t)
	debugGraph(t, graph)

	order, err := graph.topoSortDFS()
	if err != nil {
		t.Fatal(err)
	}

	t.Logf("DFS topological sort: %v", order)
	validateTopologicalSortDFS(t, graph, order)
}
// validateTopologicalSortDFS reports a test error for every edge in g that
// runs "backwards" with respect to order.
//
// An edge A -> B means that A must run before B, so an ordering is valid when
// no edge leads from a node to one that appears earlier in the list.
func validateTopologicalSortDFS(tb testing.TB, g *WorkGraph, order []string) {
	for i := range order {
		later := order[i]
		for j, earlier := range order[:i] {
			if !g.edges.Exists(later, earlier) {
				continue
			}
			tb.Errorf("invalid edge: %v [%d] -> %v [%d]", later, i, earlier, j)
		}
	}
}
// TestTopoSortKahn sorts the shared test graph with topoSortKahn and
// validates the resulting groups.
func TestTopoSortKahn(t *testing.T) {
	graph := makeTestGraph(t)
	debugGraph(t, graph)

	batches, err := graph.topoSortKahn()
	if err != nil {
		t.Fatal(err)
	}

	t.Logf("grouped topological sort: %v", batches)
	validateTopologicalSortKahn(t, graph, batches)
}
// validateTopologicalSortKahn reports a test error if groups is not a valid
// grouped topological sort of g.
//
// An edge A -> B means that A must run before B. The grouping is valid when:
//   - no edge leads from a node to a node in an earlier group,
//   - no edge connects two nodes of the same group (they may run
//     concurrently, so neither may depend on the other),
//   - no node appears in more than one group, and
//   - every node of g appears in some group.
func validateTopologicalSortKahn(tb testing.TB, g *WorkGraph, groups []set.Set[string]) {
	tb.Helper()

	// prev contains the nodes of all groups before the current one.
	prev := make(map[string]bool)
	for i, group := range groups {
		for node := range group {
			if prev[node] {
				tb.Errorf("group[%d]: node %v already appeared in an earlier group", i, node)
			}
			for m := range prev {
				if g.edges.Exists(node, m) {
					tb.Errorf("group[%d]: invalid edge: %v -> %v", i, node, m)
				}
			}

			// Check every ordered pair within the group, so that the
			// result does not depend on map iteration order.
			for m := range group {
				if g.edges.Exists(node, m) {
					tb.Errorf("group[%d]: invalid edge within group: %v -> %v", i, node, m)
				}
			}
		}

		// Only mark the group as seen once all of it has been checked.
		for node := range group {
			prev[node] = true
		}
	}

	// Verify that our topologically sorted groups contain all nodes.
	for nid := range g.nodes {
		if !prev[nid] {
			tb.Errorf("topological sort missing node %v", nid)
		}
	}
}
// FuzzTopSortKahn fuzzes topoSortKahn with graphs built from fuzz input.
//
// A map[string][]string (or similar) can't be passed into a fuzz function,
// so a graph is described by a node count plus a byte slice that is read as
// (node, dependency) pairs. The pairs are filtered when the graph is built,
// because the fuzzer cannot tell "invalid fuzz input" apart from "input that
// triggers a real error".
func FuzzTopSortKahn(f *testing.F) {
	seedEdges := []byte{
		1, 0, // 1 depends on 0
		6, 2, // 6 depends on 2
		9, 8, // 9 depends on 8
	}
	f.Add(10 /* number of nodes */, seedEdges)

	check := func(t *testing.T, numNodes int, edges []byte) {
		graph := createGraphFromFuzzInput(t, numNodes, edges)
		if graph == nil {
			// Input rejected as invalid.
			return
		}

		// The graph is acyclic by construction, so sorting must
		// succeed.
		batches, err := graph.topoSortKahn()
		if err != nil {
			t.Fatal(err)
		}
		validateTopologicalSortKahn(t, graph, batches)
	}
	f.Fuzz(check)
}
// FuzzTopSortDFS fuzzes topoSortDFS with graphs built from fuzz input.
//
// A map[string][]string (or similar) can't be passed into a fuzz function,
// so a graph is described by a node count plus a byte slice that is read as
// (node, dependency) pairs. The pairs are filtered when the graph is built,
// because the fuzzer cannot tell "invalid fuzz input" apart from "input that
// triggers a real error".
func FuzzTopSortDFS(f *testing.F) {
	seedEdges := []byte{
		1, 0, // 1 depends on 0
		6, 2, // 6 depends on 2
		9, 8, // 9 depends on 8
	}
	f.Add(10 /* number of nodes */, seedEdges)

	check := func(t *testing.T, numNodes int, edges []byte) {
		graph := createGraphFromFuzzInput(t, numNodes, edges)
		if graph == nil {
			// Input rejected as invalid.
			return
		}

		// The graph is acyclic by construction, so sorting must
		// succeed.
		order, err := graph.topoSortDFS()
		if err != nil {
			t.Fatal(err)
		}
		validateTopologicalSortDFS(t, graph, order)
	}
	f.Fuzz(check)
}
// createGraphFromFuzzInput turns fuzz input into a WorkGraph with numNodes
// nodes named "node-0", "node-1", and so on, each of which does nothing when
// run. edges is read as consecutive (node, dependency) index pairs.
//
// It returns nil if the input is rejected: numNodes outside 1..1000, an odd
// number of edge bytes, or no usable edge left after filtering. It also
// returns nil (after reporting the error on tb) if adding a node fails.
func createGraphFromFuzzInput(tb testing.TB, numNodes int, edges []byte) *WorkGraph {
	// Constrain the number of nodes
	if numNodes <= 0 || numNodes > 1000 {
		return nil
	}
	// Must have pairs of edges (from, to)
	if len(edges)%2 != 0 {
		return nil
	}

	name := func(i int) string {
		return fmt.Sprintf("node-%d", i)
	}

	// Convert the list of edges into per-node dependency lists, dropping
	// every pair that could not be added to the graph.
	deps := make(map[string][]string)
	for i := 0; i < len(edges); i += 2 {
		node, dep := int(edges[i]), int(edges[i+1])
		switch {
		case node >= numNodes, dep >= numNodes:
			// invalid node
			continue
		case node == dep:
			// can't depend on self
			continue
		case dep > node:
			// Nodes are added in incrementing order (0, 1, 2, etc.),
			// so an edge can't point 'forward' or it'll fail to be
			// added.
			continue
		}
		key := name(node)
		deps[key] = append(deps[key], name(dep))
	}
	if len(deps) == 0 {
		return nil
	}

	// Actually create graph.
	graph := NewWorkGraph()
	noop := func(context.Context) error { return nil }
	for i := 0; i < numNodes; i++ {
		id := name(i)
		opts := &AddNodeOpts{Dependencies: deps[id]}
		if err := graph.AddNode(NodeFunc(id, noop), opts); err != nil {
			tb.Error(err) // shouldn't error after we filtered out bad edges above
		}
	}
	if tb.Failed() {
		return nil
	}
	return graph
}