diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index f4077b5f9..b90a62345 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -6,7 +6,6 @@ package relayserver import ( - "errors" "sync" "tailscale.com/disco" @@ -48,16 +47,12 @@ type extension struct { logf logger.Logf bus *eventbus.Bus - mu sync.Mutex // guards the following fields - eventClient *eventbus.Client // closed to stop consumeEventbusTopics - reqSub *eventbus.Subscriber[magicsock.UDPRelayAllocReq] // receives endpoint alloc requests from magicsock - respPub *eventbus.Publisher[magicsock.UDPRelayAllocResp] // publishes endpoint alloc responses to magicsock + mu sync.Mutex // guards the following fields shutdown bool port *int // ipn.Prefs.RelayServerPort, nil if disabled - busDoneCh chan struct{} // non-nil if port is non-nil, closed when consumeEventbusTopics returns + disconnectFromBusCh chan struct{} // non-nil if consumeEventbusTopics is running, closed to signal it to return + busDoneCh chan struct{} // non-nil if consumeEventbusTopics is running, closed when it returns hasNodeAttrDisableRelayServer bool // tailcfg.NodeAttrDisableRelayServer - server relayServer // lazily initialized - } // relayServer is the interface of [udprelay.Server]. @@ -81,26 +76,27 @@ func (e *extension) Init(host ipnext.Host) error { return nil } -// initBusConnection initializes the [*eventbus.Client], [*eventbus.Subscriber], -// [*eventbus.Publisher], and [chan struct{}] used to publish/receive endpoint -// allocation messages to/from the [*eventbus.Bus]. It also starts -// consumeEventbusTopics in a separate goroutine. -func (e *extension) initBusConnection() { - e.eventClient = e.bus.Client("relayserver.extension") - e.reqSub = eventbus.Subscribe[magicsock.UDPRelayAllocReq](e.eventClient) - e.respPub = eventbus.Publish[magicsock.UDPRelayAllocResp](e.eventClient) +// handleBusLifetimeLocked handles the lifetime of consumeEventbusTopics. +func (e *extension) handleBusLifetimeLocked() { + busShouldBeRunning := !e.shutdown && e.port != nil && !e.hasNodeAttrDisableRelayServer + if !busShouldBeRunning { + e.disconnectFromBusLocked() + return + } + if e.busDoneCh != nil { + return // already running + } + port := *e.port + e.disconnectFromBusCh = make(chan struct{}) e.busDoneCh = make(chan struct{}) - go e.consumeEventbusTopics() + go e.consumeEventbusTopics(port) } func (e *extension) selfNodeViewChanged(nodeView tailcfg.NodeView) { e.mu.Lock() defer e.mu.Unlock() e.hasNodeAttrDisableRelayServer = nodeView.HasCap(tailcfg.NodeAttrDisableRelayServer) - if e.hasNodeAttrDisableRelayServer && e.server != nil { - e.server.Close() - e.server = nil - } + e.handleBusLifetimeLocked() } func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { @@ -110,43 +106,52 @@ func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsV enableOrDisableServer := ok != (e.port != nil) portChanged := ok && e.port != nil && newPort != *e.port if enableOrDisableServer || portChanged || !sameNode { - if e.server != nil { - e.server.Close() - e.server = nil - } - if e.port != nil { - e.eventClient.Close() - <-e.busDoneCh - } + e.disconnectFromBusLocked() e.port = nil if ok { e.port = ptr.To(newPort) - e.initBusConnection() } } + e.handleBusLifetimeLocked() } -func (e *extension) consumeEventbusTopics() { +func (e *extension) consumeEventbusTopics(port int) { defer close(e.busDoneCh) + eventClient := e.bus.Client("relayserver.extension") + reqSub := eventbus.Subscribe[magicsock.UDPRelayAllocReq](eventClient) + respPub := eventbus.Publish[magicsock.UDPRelayAllocResp](eventClient) + defer eventClient.Close() + + var rs relayServer // lazily initialized + defer func() { + if rs != nil { + rs.Close() + } + }() for { select { - case <-e.reqSub.Done(): + case <-e.disconnectFromBusCh: + return + case <-reqSub.Done(): // If reqSub is done, the eventClient has been closed, which is a // signal to return. return - case req := <-e.reqSub.Events(): - rs, err := e.relayServerOrInit() - if err != nil { - e.logf("error initializing server: %v", err) - continue + case req := <-reqSub.Events(): + if rs == nil { + var err error + rs, err = udprelay.NewServer(e.logf, port, nil) + if err != nil { + e.logf("error initializing server: %v", err) + continue + } } se, err := rs.AllocateEndpoint(req.Message.ClientDisco[0], req.Message.ClientDisco[1]) if err != nil { e.logf("error allocating endpoint: %v", err) continue } - e.respPub.Publish(magicsock.UDPRelayAllocResp{ + respPub.Publish(magicsock.UDPRelayAllocResp{ ReqRxFromNodeKey: req.RxFromNodeKey, ReqRxFromDiscoKey: req.RxFromDiscoKey, Message: &disco.AllocateUDPRelayEndpointResponse{ @@ -164,44 +169,22 @@ func (e *extension) consumeEventbusTopics() { }) } } +} +func (e *extension) disconnectFromBusLocked() { + if e.busDoneCh != nil { + close(e.disconnectFromBusCh) + <-e.busDoneCh + e.busDoneCh = nil + e.disconnectFromBusCh = nil + } } // Shutdown implements [ipnlocal.Extension]. func (e *extension) Shutdown() error { e.mu.Lock() defer e.mu.Unlock() + e.disconnectFromBusLocked() e.shutdown = true - if e.server != nil { - e.server.Close() - e.server = nil - } - if e.port != nil { - e.eventClient.Close() - <-e.busDoneCh - } return nil } - -func (e *extension) relayServerOrInit() (relayServer, error) { - e.mu.Lock() - defer e.mu.Unlock() - if e.shutdown { - return nil, errors.New("relay server is shutdown") - } - if e.server != nil { - return e.server, nil - } - if e.port == nil { - return nil, errors.New("relay server is not configured") - } - if e.hasNodeAttrDisableRelayServer { - return nil, errors.New("disable-relay-server node attribute is present") - } - var err error - e.server, err = udprelay.NewServer(e.logf, *e.port, nil) - if err != nil { - return nil, err - } - return e.server, nil -} diff --git a/feature/relayserver/relayserver_test.go b/feature/relayserver/relayserver_test.go index 84158188e..d3fc36a83 100644 --- a/feature/relayserver/relayserver_test.go +++ b/feature/relayserver/relayserver_test.go @@ -4,107 +4,91 @@ package relayserver import ( - "errors" "testing" "tailscale.com/ipn" - "tailscale.com/net/udprelay/endpoint" "tailscale.com/tsd" - "tailscale.com/types/key" "tailscale.com/types/ptr" + "tailscale.com/util/eventbus" ) -type fakeRelayServer struct{} - -func (f *fakeRelayServer) Close() error { return nil } - -func (f *fakeRelayServer) AllocateEndpoint(_, _ key.DiscoPublic) (endpoint.ServerEndpoint, error) { - return endpoint.ServerEndpoint{}, errors.New("fake relay server") -} - func Test_extension_profileStateChanged(t *testing.T) { prefsWithPortOne := ipn.Prefs{RelayServerPort: ptr.To(1)} prefsWithNilPort := ipn.Prefs{RelayServerPort: nil} type fields struct { - server relayServer - port *int + port *int } type args struct { prefs ipn.PrefsView sameNode bool } tests := []struct { - name string - fields fields - args args - wantPort *int - wantNilServer bool + name string + fields fields + args args + wantPort *int + wantBusRunning bool }{ { - name: "no changes non-nil server", + name: "no changes non-nil port", fields: fields{ - server: &fakeRelayServer{}, - port: ptr.To(1), + port: ptr.To(1), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: true, }, - wantPort: ptr.To(1), - wantNilServer: false, + wantPort: ptr.To(1), + wantBusRunning: true, }, { name: "prefs port nil", fields: fields{ - server: &fakeRelayServer{}, - port: ptr.To(1), + port: ptr.To(1), }, args: args{ prefs: prefsWithNilPort.View(), sameNode: true, }, - wantPort: nil, - wantNilServer: true, + wantPort: nil, + wantBusRunning: false, }, { name: "prefs port changed", fields: fields{ - server: &fakeRelayServer{}, - port: ptr.To(2), + port: ptr.To(2), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: true, }, - wantPort: ptr.To(1), - wantNilServer: true, + wantPort: ptr.To(1), + wantBusRunning: true, }, { name: "sameNode false", fields: fields{ - server: &fakeRelayServer{}, - port: ptr.To(1), + port: ptr.To(1), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: false, }, - wantPort: ptr.To(1), - wantNilServer: true, + wantPort: ptr.To(1), + wantBusRunning: true, }, { name: "prefs port non-nil extension port nil", fields: fields{ - server: nil, - port: nil, + port: nil, }, args: args{ prefs: prefsWithPortOne.View(), sameNode: false, }, - wantPort: ptr.To(1), - wantNilServer: true, + wantPort: ptr.To(1), + wantBusRunning: true, }, } for _, tt := range tests { @@ -112,19 +96,13 @@ func Test_extension_profileStateChanged(t *testing.T) { sys := tsd.NewSystem() bus := sys.Bus.Get() e := &extension{ - port: tt.fields.port, - server: tt.fields.server, - bus: bus, - } - if e.port != nil { - // Entering profileStateChanged with a non-nil port requires - // bus init, which is called in profileStateChanged when - // transitioning port from nil to non-nil. - e.initBusConnection() + port: tt.fields.port, + bus: bus, } + defer e.disconnectFromBusLocked() e.profileStateChanged(ipn.LoginProfileView{}, tt.args.prefs, tt.args.sameNode) - if tt.wantNilServer != (e.server == nil) { - t.Errorf("wantNilServer: %v != (e.server == nil): %v", tt.wantNilServer, e.server == nil) + if tt.wantBusRunning != (e.busDoneCh != nil) { + t.Errorf("wantBusRunning: %v != (e.busDoneCh != nil): %v", tt.wantBusRunning, e.busDoneCh != nil) } if (tt.wantPort == nil) != (e.port == nil) { t.Errorf("(tt.wantPort == nil): %v != (e.port == nil): %v", tt.wantPort == nil, e.port == nil) @@ -134,3 +112,59 @@ func Test_extension_profileStateChanged(t *testing.T) { }) } } + +func Test_extension_handleBusLifetimeLocked(t *testing.T) { + tests := []struct { + name string + shutdown bool + port *int + busDoneCh chan struct{} + hasNodeAttrDisableRelayServer bool + wantBusRunning bool + }{ + { + name: "want running", + shutdown: false, + port: ptr.To(1), + hasNodeAttrDisableRelayServer: false, + wantBusRunning: true, + }, + { + name: "shutdown true", + shutdown: true, + port: ptr.To(1), + hasNodeAttrDisableRelayServer: false, + wantBusRunning: false, + }, + { + name: "port nil", + shutdown: false, + port: nil, + hasNodeAttrDisableRelayServer: false, + wantBusRunning: false, + }, + { + name: "hasNodeAttrDisableRelayServer true", + shutdown: false, + port: nil, + hasNodeAttrDisableRelayServer: true, + wantBusRunning: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &extension{ + bus: eventbus.New(), + shutdown: tt.shutdown, + port: tt.port, + busDoneCh: tt.busDoneCh, + hasNodeAttrDisableRelayServer: tt.hasNodeAttrDisableRelayServer, + } + e.handleBusLifetimeLocked() + defer e.disconnectFromBusLocked() + if tt.wantBusRunning != (e.busDoneCh != nil) { + t.Errorf("wantBusRunning: %v != (e.busDoneCh != nil): %v", tt.wantBusRunning, e.busDoneCh != nil) + } + }) + } +}