diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index 9cf776661..3d851780d 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -13,7 +13,9 @@ import ( "net/netip" "sync" + "tailscale.com/envknob" "tailscale.com/feature" + "tailscale.com/ipn" "tailscale.com/ipn/ipnext" "tailscale.com/ipn/ipnlocal" "tailscale.com/net/udprelay" @@ -21,6 +23,7 @@ import ( "tailscale.com/tsd" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/ptr" "tailscale.com/util/httpm" ) @@ -46,10 +49,17 @@ func newExtension(logf logger.Logf, _ *tsd.System) (ipnext.Extension, error) { type extension struct { logf logger.Logf - mu sync.Mutex // guards the following fields - shutdown bool - port int - server *udprelay.Server // lazily initialized + mu sync.Mutex // guards the following fields + shutdown bool + port *int // ipn.Prefs.RelayServerPort, nil if disabled + hasNodeAttrRelayServer bool // tailcfg.NodeAttrRelayServer + server relayServer // lazily initialized +} + +// relayServer is the interface of [udprelay.Server]. +type relayServer interface { + AllocateEndpoint(discoA key.DiscoPublic, discoB key.DiscoPublic) (udprelay.ServerEndpoint, error) + Close() error } // Name implements [ipnext.Extension]. @@ -59,10 +69,32 @@ func (e *extension) Name() string { // Init implements [ipnext.Extension] by registering callbacks and providers // for the duration of the extension's lifetime. -func (e *extension) Init(_ ipnext.Host) error { +func (e *extension) Init(host ipnext.Host) error { + profile, prefs := host.Profiles().CurrentProfileState() + e.profileStateChanged(profile, prefs, false) + host.Profiles().RegisterProfileStateChangeCallback(e.profileStateChanged) + // TODO(jwhited): callback for netmap/nodeattr changes (e.hasNodeAttrRelayServer) return nil } +func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + e.mu.Lock() + defer e.mu.Unlock() + newPort, ok := prefs.RelayServerPort().GetOk() + enableOrDisableServer := ok != (e.port != nil) + portChanged := ok && e.port != nil && newPort != *e.port + if enableOrDisableServer || portChanged || !sameNode { + if e.server != nil { + e.server.Close() + e.server = nil + } + e.port = nil + if ok { + e.port = ptr.To(newPort) + } + } +} + // Shutdown implements [ipnlocal.Extension]. func (e *extension) Shutdown() error { e.mu.Lock() @@ -75,16 +107,7 @@ func (e *extension) Shutdown() error { return nil } -func (e *extension) shouldRunRelayServer() bool { - // TODO(jwhited): consider: - // 1. tailcfg.NodeAttrRelayServer - // 2. ipn.Prefs.RelayServerPort - // 3. envknob.UseWIPCode() - // 4. e.shutdown - return false -} - -func (e *extension) relayServerOrInit() (*udprelay.Server, error) { +func (e *extension) relayServerOrInit() (relayServer, error) { e.mu.Lock() defer e.mu.Unlock() if e.shutdown { @@ -93,8 +116,17 @@ func (e *extension) relayServerOrInit() (*udprelay.Server, error) { if e.server != nil { return e.server, nil } + if e.port == nil { + return nil, errors.New("relay server is not configured") + } + if !e.hasNodeAttrRelayServer { + return nil, errors.New("no relay:server node attribute") + } + if !envknob.UseWIPCode() { + return nil, errors.New("TAILSCALE_USE_WIP_CODE envvar is not set") + } var err error - e.server, _, err = udprelay.NewServer(e.port, []netip.Addr{netip.MustParseAddr("127.0.0.1")}) + e.server, _, err = udprelay.NewServer(*e.port, []netip.Addr{netip.MustParseAddr("127.0.0.1")}) if err != nil { return nil, err } @@ -102,25 +134,24 @@ func (e *extension) relayServerOrInit() (*udprelay.Server, error) { } func handlePeerAPIRelayAllocateEndpoint(h ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) { - // TODO(jwhited): log errors e, ok := h.LocalBackend().FindExtensionByName(featureName).(*extension) if !ok { http.Error(w, "relay failed to initialize", http.StatusServiceUnavailable) return } - if !e.shouldRunRelayServer() { - http.Error(w, "relay not enabled", http.StatusNotFound) - return + httpErrAndLog := func(message string, code int) { + http.Error(w, message, code) + e.logf("peerapi: request from %v returned code %d: %s", h.RemoteAddr(), code, message) } if !h.PeerCaps().HasCapability(tailcfg.PeerCapabilityRelay) { - http.Error(w, "relay not permitted", http.StatusForbidden) + httpErrAndLog("relay not permitted", http.StatusForbidden) return } if r.Method != httpm.POST { - http.Error(w, "only POST method is allowed", http.StatusMethodNotAllowed) + httpErrAndLog("only POST method is allowed", http.StatusMethodNotAllowed) return } @@ -129,26 +160,26 @@ func handlePeerAPIRelayAllocateEndpoint(h ipnlocal.PeerAPIHandler, w http.Respon } err := json.NewDecoder(io.LimitReader(r.Body, 512)).Decode(&allocateEndpointReq) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + httpErrAndLog(err.Error(), http.StatusBadRequest) return } if len(allocateEndpointReq.DiscoKeys) != 2 { - http.Error(w, "2 disco public keys must be supplied", http.StatusBadRequest) + httpErrAndLog("2 disco public keys must be supplied", http.StatusBadRequest) return } rs, err := e.relayServerOrInit() if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + httpErrAndLog(err.Error(), http.StatusServiceUnavailable) return } ep, err := rs.AllocateEndpoint(allocateEndpointReq.DiscoKeys[0], allocateEndpointReq.DiscoKeys[1]) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + httpErrAndLog(err.Error(), http.StatusInternalServerError) return } err = json.NewEncoder(w).Encode(&ep) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + httpErrAndLog(err.Error(), http.StatusInternalServerError) } } diff --git a/feature/relayserver/relayserver_test.go b/feature/relayserver/relayserver_test.go new file mode 100644 index 000000000..af4d11df0 --- /dev/null +++ b/feature/relayserver/relayserver_test.go @@ -0,0 +1,126 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package relayserver + +import ( + "errors" + "testing" + + "tailscale.com/ipn" + "tailscale.com/net/udprelay" + "tailscale.com/types/key" + "tailscale.com/types/ptr" +) + +type fakeRelayServer struct{} + +func (f *fakeRelayServer) Close() error { return nil } + +func (f *fakeRelayServer) AllocateEndpoint(_, _ key.DiscoPublic) (udprelay.ServerEndpoint, error) { + return udprelay.ServerEndpoint{}, errors.New("fake relay server") +} + +func Test_extension_profileStateChanged(t *testing.T) { + prefsWithPortOne := ipn.Prefs{RelayServerPort: ptr.To(1)} + prefsWithNilPort := ipn.Prefs{RelayServerPort: nil} + + type fields struct { + server relayServer + port *int + } + type args struct { + prefs ipn.PrefsView + sameNode bool + } + tests := []struct { + name string + fields fields + args args + wantPort *int + wantNilServer bool + }{ + { + name: "no changes non-nil server", + fields: fields{ + server: &fakeRelayServer{}, + port: ptr.To(1), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: true, + }, + wantPort: ptr.To(1), + wantNilServer: false, + }, + { + name: "prefs port nil", + fields: fields{ + server: &fakeRelayServer{}, + port: ptr.To(1), + }, + args: args{ + prefs: prefsWithNilPort.View(), + sameNode: true, + }, + wantPort: nil, + wantNilServer: true, + }, + { + name: "prefs port changed", + fields: fields{ + server: &fakeRelayServer{}, + port: ptr.To(2), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: true, + }, + wantPort: ptr.To(1), + wantNilServer: true, + }, + { + name: "sameNode false", + fields: fields{ + server: &fakeRelayServer{}, + port: ptr.To(1), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: false, + }, + wantPort: ptr.To(1), + wantNilServer: true, + }, + { + name: "prefs port non-nil extension port nil", + fields: fields{ + server: nil, + port: nil, + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: false, + }, + wantPort: ptr.To(1), + wantNilServer: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &extension{ + port: tt.fields.port, + server: tt.fields.server, + } + e.profileStateChanged(ipn.LoginProfileView{}, tt.args.prefs, tt.args.sameNode) + if tt.wantNilServer != (e.server == nil) { + t.Errorf("wantNilServer: %v != (e.server == nil): %v", tt.wantNilServer, e.server == nil) + } + if (tt.wantPort == nil) != (e.port == nil) { + t.Errorf("(tt.wantPort == nil): %v != (e.port == nil): %v", tt.wantPort == nil, e.port == nil) + } else if tt.wantPort != nil && *tt.wantPort != *e.port { + t.Errorf("wantPort: %d != *e.port: %d", *tt.wantPort, *e.port) + } + }) + } +}