diff --git a/cmd/tailscale/cli/serve.go b/cmd/tailscale/cli/serve.go index efae1d178..7b1be20cd 100644 --- a/cmd/tailscale/cli/serve.go +++ b/cmd/tailscale/cli/serve.go @@ -30,9 +30,9 @@ import ( "tailscale.com/version" ) -var serveCmd = newServeCommand(&serveEnv{}) +var serveCmd = newServeCommand(&serveEnv{lc: &localClient}) -// newServeCommand returns a new "serve" subcommand using e as its environmment. +// newServeCommand returns a new "serve" subcommand using e as its environment. func newServeCommand(e *serveEnv) *ffcli.Command { return &ffcli.Command{ Name: "serve", @@ -127,8 +127,21 @@ func (e *serveEnv) newFlags(name string, setup func(fs *flag.FlagSet)) *flag.Fla return fs } +// localServeClient is an interface conforming to the subset of +// tailscale.LocalClient. It includes only the methods used by the +// serve command. +// +// The purpose of this interface is to allow tests to provide a mock. +type localServeClient interface { + Status(context.Context) (*ipnstate.Status, error) + GetServeConfig(context.Context) (*ipn.ServeConfig, error) + SetServeConfig(context.Context, *ipn.ServeConfig) error +} + // serveEnv is the environment the serve command runs within. All I/O should be // done via serveEnv methods so that it can be faked out for tests. +// Calls to localClient should be done via the lc field, which is an interface +// that can be faked out for tests. // // It also contains the flags, as registered with newServeCommand. type serveEnv struct { @@ -138,12 +151,11 @@ type serveEnv struct { remove bool // remove a serve config json bool // output JSON (status only for now) + lc localServeClient // localClient interface, specific to serve + // optional stuff for tests: - testFlagOut io.Writer - testGetServeConfig func(context.Context) (*ipn.ServeConfig, error) - testSetServeConfig func(context.Context, *ipn.ServeConfig) error - testGetLocalClientStatus func(context.Context) (*ipnstate.Status, error) - testStdout io.Writer + testFlagOut io.Writer + testStdout io.Writer } // getSelfDNSName returns the DNS name of the current node. @@ -157,15 +169,12 @@ func (e *serveEnv) getSelfDNSName(ctx context.Context) (string, error) { return strings.TrimSuffix(st.Self.DNSName, "."), nil } -// getLocalClientStatus calls LocalClient.Status, checks if -// Status is ready. +// getLocalClientStatus returns the Status of the local client. // Returns error if unable to reach tailscaled or if self node is nil. +// // Exits if status is not running or starting. func (e *serveEnv) getLocalClientStatus(ctx context.Context) (*ipnstate.Status, error) { - if e.testGetLocalClientStatus != nil { - return e.testGetLocalClientStatus(ctx) - } - st, err := localClient.Status(ctx) + st, err := e.lc.Status(ctx) if err != nil { return nil, fixTailscaledConnectError(err) } @@ -180,20 +189,6 @@ func (e *serveEnv) getLocalClientStatus(ctx context.Context) (*ipnstate.Status, return st, nil } -func (e *serveEnv) getServeConfig(ctx context.Context) (*ipn.ServeConfig, error) { - if e.testGetServeConfig != nil { - return e.testGetServeConfig(ctx) - } - return localClient.GetServeConfig(ctx) -} - -func (e *serveEnv) setServeConfig(ctx context.Context, c *ipn.ServeConfig) error { - if e.testSetServeConfig != nil { - return e.testSetServeConfig(ctx, c) - } - return localClient.SetServeConfig(ctx, c) -} - // validateServePort returns --serve-port flag value, // or an error if the port is not a valid port to serve on. func (e *serveEnv) validateServePort() (port uint16, err error) { @@ -232,7 +227,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error { if err := json.Unmarshal(valb, sc); err != nil { return fmt.Errorf("invalid JSON: %w", err) } - return localClient.SetServeConfig(ctx, sc) + return e.lc.SetServeConfig(ctx, sc) } if !(len(args) == 3 || (e.remove && len(args) >= 1)) { @@ -294,7 +289,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error { return flag.ErrHelp } - cursc, err := e.getServeConfig(ctx) + cursc, err := e.lc.GetServeConfig(ctx) if err != nil { return err } @@ -337,7 +332,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error { } if !reflect.DeepEqual(cursc, sc) { - if err := e.setServeConfig(ctx, sc); err != nil { + if err := e.lc.SetServeConfig(ctx, sc); err != nil { return err } } @@ -351,7 +346,7 @@ func (e *serveEnv) handleWebServeRemove(ctx context.Context, mount string) error return err } srvPortStr := strconv.Itoa(int(srvPort)) - sc, err := e.getServeConfig(ctx) + sc, err := e.lc.GetServeConfig(ctx) if err != nil { return err } @@ -382,7 +377,7 @@ func (e *serveEnv) handleWebServeRemove(ctx context.Context, mount string) error if len(sc.TCP) == 0 { sc.TCP = nil } - if err := e.setServeConfig(ctx, sc); err != nil { + if err := e.lc.SetServeConfig(ctx, sc); err != nil { return err } return nil @@ -453,7 +448,7 @@ func allNumeric(s string) bool { // - tailscale status // - tailscale status --json func (e *serveEnv) runServeStatus(ctx context.Context, args []string) error { - sc, err := e.getServeConfig(ctx) + sc, err := e.lc.GetServeConfig(ctx) if err != nil { return err } @@ -603,7 +598,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error { fmt.Fprintf(os.Stderr, "error: invalid port %q\n\n", portStr) } - cursc, err := e.getServeConfig(ctx) + cursc, err := e.lc.GetServeConfig(ctx) if err != nil { return err } @@ -628,7 +623,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error { if len(sc.TCP) == 0 { sc.TCP = nil } - return e.setServeConfig(ctx, sc) + return e.lc.SetServeConfig(ctx, sc) } return errors.New("error: serve config does not exist") } @@ -644,7 +639,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error { } if !reflect.DeepEqual(cursc, sc) { - if err := e.setServeConfig(ctx, sc); err != nil { + if err := e.lc.SetServeConfig(ctx, sc); err != nil { return err } } @@ -674,7 +669,7 @@ func (e *serveEnv) runServeFunnel(ctx context.Context, args []string) error { default: return flag.ErrHelp } - sc, err := e.getServeConfig(ctx) + sc, err := e.lc.GetServeConfig(ctx) if err != nil { return err } @@ -703,7 +698,7 @@ func (e *serveEnv) runServeFunnel(ctx context.Context, args []string) error { sc.AllowFunnel = nil } } - if err := e.setServeConfig(ctx, sc); err != nil { + if err := e.lc.SetServeConfig(ctx, sc); err != nil { return err } return nil diff --git a/cmd/tailscale/cli/serve_test.go b/cmd/tailscale/cli/serve_test.go index 6f3fe50cb..d4aec3708 100644 --- a/cmd/tailscale/cli/serve_test.go +++ b/cmd/tailscale/cli/serve_test.go @@ -144,10 +144,10 @@ func TestServeConfigMutations(t *testing.T) { }, }, }) - add(step{ + add(step{ // invalid port command: cmd("--serve-port=9999 /abc proxy 3001"), wantErr: anyErr(), - }) // invalid port + }) add(step{ command: cmd("--serve-port=8443 /abc proxy 3001"), want: &ipn.ServeConfig{ @@ -606,12 +606,12 @@ func TestServeConfigMutations(t *testing.T) { wantErr: anyErr(), }) + lc := &fakeLocalServeClient{} // And now run the steps above. - var current *ipn.ServeConfig for i, st := range steps { if st.reset { t.Logf("Executing step #%d, line %v: [reset]", i, st.line) - current = nil + lc.config = nil } if st.command == nil { continue @@ -620,26 +620,12 @@ func TestServeConfigMutations(t *testing.T) { var stdout bytes.Buffer var flagOut bytes.Buffer - var newState *ipn.ServeConfig e := &serveEnv{ + lc: lc, testFlagOut: &flagOut, testStdout: &stdout, - testGetLocalClientStatus: func(context.Context) (*ipnstate.Status, error) { - return &ipnstate.Status{ - Self: &ipnstate.PeerStatus{ - DNSName: "foo.test.ts.net", - Capabilities: []string{tailcfg.NodeAttrFunnel}, - }, - }, nil - }, - testGetServeConfig: func(context.Context) (*ipn.ServeConfig, error) { - return current, nil - }, - testSetServeConfig: func(_ context.Context, c *ipn.ServeConfig) error { - newState = c - return nil - }, } + lastCount := lc.setCount cmd := newServeCommand(e) err := cmd.ParseAndRun(context.Background(), st.command) if flagOut.Len() > 0 { @@ -655,23 +641,61 @@ func TestServeConfigMutations(t *testing.T) { continue } if st.wantErr != nil { - t.Fatalf("step #%d, line %v: got success (saved=%v), but wanted an error", i, st.line, newState != nil) + t.Fatalf("step #%d, line %v: got success (saved=%v), but wanted an error", i, st.line, lc.config != nil) } - if !reflect.DeepEqual(newState, st.want) { + var got *ipn.ServeConfig = nil + if lc.setCount > lastCount { + got = lc.config + } + if !reflect.DeepEqual(got, st.want) { t.Fatalf("[%d] %v: bad state. got:\n%s\n\nwant:\n%s\n", - i, st.command, asJSON(newState), asJSON(st.want)) + i, st.command, asJSON(got), asJSON(st.want)) // NOTE: asJSON will omit empty fields, which might make // result in bad state got/want diffs being the same, even // though the actual state is different. Use below to debug: // t.Fatalf("[%d] %v: bad state. got:\n%+v\n\nwant:\n%+v\n", - // i, st.command, newState, st.want) - } - if newState != nil { - current = newState + // i, st.command, got, st.want) } } } +// fakeLocalServeClient is a fake tailscale.LocalClient for tests. +// It's not a full implementation, just enough to test the serve command. +// +// The fake client is stateful, and is used to test manipulating +// ServeConfig state. This implementation cannot be used concurrently. +type fakeLocalServeClient struct { + config *ipn.ServeConfig + setCount int // counts calls to SetServeConfig +} + +// fakeStatus is a fake ipnstate.Status value for tests. +// It's not a full implementation, just enough to test the serve command. +// +// It returns a state that's running, with a fake DNSName and the Funnel +// node attribute capability. +var fakeStatus = &ipnstate.Status{ + BackendState: ipn.Running.String(), + Self: &ipnstate.PeerStatus{ + DNSName: "foo.test.ts.net", + Capabilities: []string{tailcfg.NodeAttrFunnel}, + }, +} + +func (lc *fakeLocalServeClient) Status(ctx context.Context) (*ipnstate.Status, error) { + return fakeStatus, nil +} + +func (lc *fakeLocalServeClient) GetServeConfig(ctx context.Context) (*ipn.ServeConfig, error) { + return lc.config.Clone(), nil +} + +func (lc *fakeLocalServeClient) SetServeConfig(ctx context.Context, config *ipn.ServeConfig) error { + lc.setCount += 1 + lc.config = config.Clone() + return nil +} + // exactError returns an error checker that wants exactly the provided want error. // If optName is non-empty, it's used in the error message. func exactErr(want error, optName ...string) func(error) string {