diff --git a/util/eventbus/bus_test.go b/util/eventbus/bus_test.go index 7782634ae..67f68cd4a 100644 --- a/util/eventbus/bus_test.go +++ b/util/eventbus/bus_test.go @@ -257,8 +257,8 @@ func TestMonitor(t *testing.T) { cli := bus.Client("test client") // The monitored goroutine runs until the client or test subscription ends. + sub := eventbus.Subscribe[string](cli) m := cli.Monitor(func(c *eventbus.Client) { - sub := eventbus.Subscribe[string](cli) select { case <-c.Done(): t.Log("client closed") @@ -294,6 +294,43 @@ func TestMonitor(t *testing.T) { t.Run("Wait", testMon(t, func(c *eventbus.Client, m eventbus.Monitor) { c.Close(); m.Wait() })) } +func TestRegression(t *testing.T) { + bus := eventbus.New() + t.Cleanup(bus.Close) + + t.Run("SubscribeClosed", func(t *testing.T) { + c := bus.Client("test sub client") + c.Close() + + var v any + func() { + defer func() { v = recover() }() + eventbus.Subscribe[string](c) + }() + if v == nil { + t.Fatal("Expected a panic from Subscribe on a closed client") + } else { + t.Logf("Got expected panic: %v", v) + } + }) + + t.Run("PublishClosed", func(t *testing.T) { + c := bus.Client("test pub client") + c.Close() + + var v any + func() { + defer func() { v = recover() }() + eventbus.Publish[string](c) + }() + if v == nil { + t.Fatal("expected a panic from Publish on a closed client") + } else { + t.Logf("Got expected panic: %v", v) + } + }) +} + type queueChecker struct { t *testing.T want []any diff --git a/util/eventbus/client.go b/util/eventbus/client.go index 176b6f2bc..9b4119865 100644 --- a/util/eventbus/client.go +++ b/util/eventbus/client.go @@ -51,6 +51,8 @@ func (c *Client) Close() { c.stop.Stop() } +func (c *Client) isClosed() bool { return c.pub == nil && c.sub == nil } + // Done returns a channel that is closed when [Client.Close] is called. // The channel is closed after all the publishers and subscribers governed by // the client have been closed. @@ -83,6 +85,10 @@ func (c *Client) subscribeTypes() []reflect.Type { func (c *Client) subscribeState() *subscribeState { c.mu.Lock() defer c.mu.Unlock() + return c.subscribeStateLocked() +} + +func (c *Client) subscribeStateLocked() *subscribeState { if c.sub == nil { c.sub = newSubscribeState(c) } @@ -92,6 +98,9 @@ func (c *Client) subscribeState() *subscribeState { func (c *Client) addPublisher(pub publisher) { c.mu.Lock() defer c.mu.Unlock() + if c.isClosed() { + panic("cannot Publish on a closed client") + } c.pub.Add(pub) } @@ -117,17 +126,29 @@ func (c *Client) shouldPublish(t reflect.Type) bool { return c.publishDebug.active() || c.bus.shouldPublish(t) } -// Subscribe requests delivery of events of type T through the given -// Queue. Panics if the queue already has a subscriber for T. +// Subscribe requests delivery of events of type T through the given client. +// It panics if c already has a subscriber for type T, or if c is closed. func Subscribe[T any](c *Client) *Subscriber[T] { - r := c.subscribeState() + // Hold the client lock throughout the subscription process so that a caller + // attempting to subscribe on a closed client will get a useful diagnostic + // instead of a random panic from inside the subscriber plumbing. + c.mu.Lock() + defer c.mu.Unlock() + + // The caller should not race subscriptions with close, give them a useful + // diagnostic at the call site. + if c.isClosed() { + panic("cannot Subscribe on a closed client") + } + + r := c.subscribeStateLocked() s := newSubscriber[T](r) r.addSubscriber(s) return s } -// Publish returns a publisher for event type T using the given -// client. +// Publish returns a publisher for event type T using the given client. +// It panics if c is closed. func Publish[T any](c *Client) *Publisher[T] { p := newPublisher[T](c) c.addPublisher(p)