// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package eventbustest import ( "errors" "fmt" "reflect" "testing" "time" "tailscale.com/util/eventbus" ) // NewBus constructs an [eventbus.Bus] that will be shut automatically when // its controlling test ends. func NewBus(t *testing.T) *eventbus.Bus { bus := eventbus.New() t.Cleanup(bus.Close) return bus } // NewTestWatcher constructs a [Watcher] that can be used to check the stream of // events generated by code under test. After construction the caller may use // [Expect] and [ExpectExactly], to verify that the desired events were captured. func NewWatcher(t *testing.T, bus *eventbus.Bus) *Watcher { tw := &Watcher{ mon: bus.Debugger().WatchBus(), TimeOut: 5 * time.Second, chDone: make(chan bool, 1), events: make(chan any, 100), } if deadline, ok := t.Deadline(); ok { tw.TimeOut = deadline.Sub(time.Now()) } t.Cleanup(tw.done) go tw.watch() return tw } // Watcher monitors and holds events for test expectations. type Watcher struct { mon *eventbus.Subscriber[eventbus.RoutedEvent] events chan any chDone chan bool // TimeOut defines when the Expect* functions should stop looking for events // coming from the Watcher. The value is set by [NewWatcher] and defaults to // the deadline passed in by [testing.T]. If looking to verify the absence // of an event, the TimeOut can be set to a lower value after creating the // Watcher. TimeOut time.Duration } // Type is a helper representing the expectation to see an event of type T, without // caring about the content of the event. // It makes it possible to use helpers like: // // eventbustest.ExpectFilter(tw, eventbustest.Type[EventFoo]()) func Type[T any]() func(T) { return func(T) {} } // Expect verifies that the given events are a subsequence of the events // observed by tw. That is, tw must contain at least one event matching the type // of each argument in the given order, other event types are allowed to occur in // between without error. The given events are represented by a function // that must have one of the following forms: // // // Tests for the event type only // func(e ExpectedType) // // // Tests for event type and whatever is defined in the body. // // If return is false, the test will look for other events of that type // // If return is true, the test will look for the next given event // // if a list is given // func(e ExpectedType) bool // // // Tests for event type and whatever is defined in the body. // // The boolean return works as above. // // The if error != nil, the test helper will return that error immediately. // func(e ExpectedType) (bool, error) // // If the list of events must match exactly with no extra events, // use [ExpectExactly]. func Expect(tw *Watcher, filters ...any) error { if len(filters) == 0 { return errors.New("no event filters were provided") } eventCount := 0 head := 0 for head < len(filters) { eventFunc := eventFilter(filters[head]) select { case event := <-tw.events: eventCount++ if ok, err := eventFunc(event); err != nil { return err } else if ok { head++ } case <-time.After(tw.TimeOut): return fmt.Errorf( "timed out waiting for event, saw %d events, %d was expected", eventCount, head) case <-tw.chDone: return errors.New("watcher closed while waiting for events") } } return nil } // ExpectExactly checks for some number of events showing up on the event bus // in a given order, returning an error if the events does not match the given list // exactly. The given events are represented by a function as described in // [Expect]. Use [Expect] if other events are allowed. func ExpectExactly(tw *Watcher, filters ...any) error { if len(filters) == 0 { return errors.New("no event filters were provided") } eventCount := 0 for pos, next := range filters { eventFunc := eventFilter(next) fnType := reflect.TypeOf(next) argType := fnType.In(0) select { case event := <-tw.events: eventCount++ typeEvent := reflect.TypeOf(event) if typeEvent != argType { return fmt.Errorf( "expected event type %s, saw %s, at index %d", argType, typeEvent, pos) } else if ok, err := eventFunc(event); err != nil { return err } else if !ok { return fmt.Errorf( "expected test ok for type %s, at index %d", argType, pos) } case <-time.After(tw.TimeOut): return fmt.Errorf( "timed out waiting for event, saw %d events, %d was expected", eventCount, pos) case <-tw.chDone: return errors.New("watcher closed while waiting for events") } } return nil } func (tw *Watcher) watch() { for { select { case event := <-tw.mon.Events(): tw.events <- event.Event case <-tw.chDone: tw.mon.Close() return } } } // done tells the watcher to stop monitoring for new events. func (tw *Watcher) done() { close(tw.chDone) } type filter = func(any) (bool, error) func eventFilter(f any) filter { ft := reflect.TypeOf(f) if ft.Kind() != reflect.Func { panic("filter is not a function") } else if ft.NumIn() != 1 { panic(fmt.Sprintf("function takes %d arguments, want 1", ft.NumIn())) } var fixup func([]reflect.Value) []reflect.Value switch ft.NumOut() { case 0: fixup = func([]reflect.Value) []reflect.Value { return []reflect.Value{reflect.ValueOf(true), reflect.Zero(reflect.TypeFor[error]())} } case 1: if ft.Out(0) != reflect.TypeFor[bool]() { panic(fmt.Sprintf("result is %T, want bool", ft.Out(0))) } fixup = func(vals []reflect.Value) []reflect.Value { return append(vals, reflect.Zero(reflect.TypeFor[error]())) } case 2: if ft.Out(0) != reflect.TypeFor[bool]() || ft.Out(1) != reflect.TypeFor[error]() { panic(fmt.Sprintf("results are %T, %T; want bool, error", ft.Out(0), ft.Out(1))) } fixup = func(vals []reflect.Value) []reflect.Value { return vals } default: panic(fmt.Sprintf("function returns %d values", ft.NumOut())) } fv := reflect.ValueOf(f) return reflect.MakeFunc(reflect.TypeFor[filter](), func(args []reflect.Value) []reflect.Value { if !args[0].IsValid() || args[0].Elem().Type() != ft.In(0) { return []reflect.Value{reflect.ValueOf(false), reflect.Zero(reflect.TypeFor[error]())} } return fixup(fv.Call([]reflect.Value{args[0].Elem()})) }).Interface().(filter) }