cmd/tailscale/cli: use mock impl of LocalClient for serve cmd (#6422)

Create an interface and mock implementation of tailscale.LocalClient for
serve command tests.

Updates #6304
Closes #6372

Signed-off-by: Shayne Sweeney <shayne@tailscale.com>
This commit is contained in:
shayne 2023-01-20 12:36:25 -05:00 committed by GitHub
parent fec888581a
commit 73399f784b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 84 additions and 65 deletions

View File

@ -30,9 +30,9 @@
"tailscale.com/version"
)
var serveCmd = newServeCommand(&serveEnv{})
var serveCmd = newServeCommand(&serveEnv{lc: &localClient})
// newServeCommand returns a new "serve" subcommand using e as its environmment.
// newServeCommand returns a new "serve" subcommand using e as its environment.
func newServeCommand(e *serveEnv) *ffcli.Command {
return &ffcli.Command{
Name: "serve",
@ -127,8 +127,21 @@ func (e *serveEnv) newFlags(name string, setup func(fs *flag.FlagSet)) *flag.Fla
return fs
}
// localServeClient is an interface conforming to the subset of
// tailscale.LocalClient. It includes only the methods used by the
// serve command.
//
// The purpose of this interface is to allow tests to provide a mock.
type localServeClient interface {
Status(context.Context) (*ipnstate.Status, error)
GetServeConfig(context.Context) (*ipn.ServeConfig, error)
SetServeConfig(context.Context, *ipn.ServeConfig) error
}
// serveEnv is the environment the serve command runs within. All I/O should be
// done via serveEnv methods so that it can be faked out for tests.
// Calls to localClient should be done via the lc field, which is an interface
// that can be faked out for tests.
//
// It also contains the flags, as registered with newServeCommand.
type serveEnv struct {
@ -138,12 +151,11 @@ type serveEnv struct {
remove bool // remove a serve config
json bool // output JSON (status only for now)
lc localServeClient // localClient interface, specific to serve
// optional stuff for tests:
testFlagOut io.Writer
testGetServeConfig func(context.Context) (*ipn.ServeConfig, error)
testSetServeConfig func(context.Context, *ipn.ServeConfig) error
testGetLocalClientStatus func(context.Context) (*ipnstate.Status, error)
testStdout io.Writer
testFlagOut io.Writer
testStdout io.Writer
}
// getSelfDNSName returns the DNS name of the current node.
@ -157,15 +169,12 @@ func (e *serveEnv) getSelfDNSName(ctx context.Context) (string, error) {
return strings.TrimSuffix(st.Self.DNSName, "."), nil
}
// getLocalClientStatus calls LocalClient.Status, checks if
// Status is ready.
// getLocalClientStatus returns the Status of the local client.
// Returns error if unable to reach tailscaled or if self node is nil.
//
// Exits if status is not running or starting.
func (e *serveEnv) getLocalClientStatus(ctx context.Context) (*ipnstate.Status, error) {
if e.testGetLocalClientStatus != nil {
return e.testGetLocalClientStatus(ctx)
}
st, err := localClient.Status(ctx)
st, err := e.lc.Status(ctx)
if err != nil {
return nil, fixTailscaledConnectError(err)
}
@ -180,20 +189,6 @@ func (e *serveEnv) getLocalClientStatus(ctx context.Context) (*ipnstate.Status,
return st, nil
}
func (e *serveEnv) getServeConfig(ctx context.Context) (*ipn.ServeConfig, error) {
if e.testGetServeConfig != nil {
return e.testGetServeConfig(ctx)
}
return localClient.GetServeConfig(ctx)
}
func (e *serveEnv) setServeConfig(ctx context.Context, c *ipn.ServeConfig) error {
if e.testSetServeConfig != nil {
return e.testSetServeConfig(ctx, c)
}
return localClient.SetServeConfig(ctx, c)
}
// validateServePort returns --serve-port flag value,
// or an error if the port is not a valid port to serve on.
func (e *serveEnv) validateServePort() (port uint16, err error) {
@ -232,7 +227,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error {
if err := json.Unmarshal(valb, sc); err != nil {
return fmt.Errorf("invalid JSON: %w", err)
}
return localClient.SetServeConfig(ctx, sc)
return e.lc.SetServeConfig(ctx, sc)
}
if !(len(args) == 3 || (e.remove && len(args) >= 1)) {
@ -294,7 +289,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error {
return flag.ErrHelp
}
cursc, err := e.getServeConfig(ctx)
cursc, err := e.lc.GetServeConfig(ctx)
if err != nil {
return err
}
@ -337,7 +332,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error {
}
if !reflect.DeepEqual(cursc, sc) {
if err := e.setServeConfig(ctx, sc); err != nil {
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
return err
}
}
@ -351,7 +346,7 @@ func (e *serveEnv) handleWebServeRemove(ctx context.Context, mount string) error
return err
}
srvPortStr := strconv.Itoa(int(srvPort))
sc, err := e.getServeConfig(ctx)
sc, err := e.lc.GetServeConfig(ctx)
if err != nil {
return err
}
@ -382,7 +377,7 @@ func (e *serveEnv) handleWebServeRemove(ctx context.Context, mount string) error
if len(sc.TCP) == 0 {
sc.TCP = nil
}
if err := e.setServeConfig(ctx, sc); err != nil {
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
return err
}
return nil
@ -453,7 +448,7 @@ func allNumeric(s string) bool {
// - tailscale status
// - tailscale status --json
func (e *serveEnv) runServeStatus(ctx context.Context, args []string) error {
sc, err := e.getServeConfig(ctx)
sc, err := e.lc.GetServeConfig(ctx)
if err != nil {
return err
}
@ -603,7 +598,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error {
fmt.Fprintf(os.Stderr, "error: invalid port %q\n\n", portStr)
}
cursc, err := e.getServeConfig(ctx)
cursc, err := e.lc.GetServeConfig(ctx)
if err != nil {
return err
}
@ -628,7 +623,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error {
if len(sc.TCP) == 0 {
sc.TCP = nil
}
return e.setServeConfig(ctx, sc)
return e.lc.SetServeConfig(ctx, sc)
}
return errors.New("error: serve config does not exist")
}
@ -644,7 +639,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error {
}
if !reflect.DeepEqual(cursc, sc) {
if err := e.setServeConfig(ctx, sc); err != nil {
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
return err
}
}
@ -674,7 +669,7 @@ func (e *serveEnv) runServeFunnel(ctx context.Context, args []string) error {
default:
return flag.ErrHelp
}
sc, err := e.getServeConfig(ctx)
sc, err := e.lc.GetServeConfig(ctx)
if err != nil {
return err
}
@ -703,7 +698,7 @@ func (e *serveEnv) runServeFunnel(ctx context.Context, args []string) error {
sc.AllowFunnel = nil
}
}
if err := e.setServeConfig(ctx, sc); err != nil {
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
return err
}
return nil

View File

@ -144,10 +144,10 @@ type step struct {
},
},
})
add(step{
add(step{ // invalid port
command: cmd("--serve-port=9999 /abc proxy 3001"),
wantErr: anyErr(),
}) // invalid port
})
add(step{
command: cmd("--serve-port=8443 /abc proxy 3001"),
want: &ipn.ServeConfig{
@ -606,12 +606,12 @@ type step struct {
wantErr: anyErr(),
})
lc := &fakeLocalServeClient{}
// And now run the steps above.
var current *ipn.ServeConfig
for i, st := range steps {
if st.reset {
t.Logf("Executing step #%d, line %v: [reset]", i, st.line)
current = nil
lc.config = nil
}
if st.command == nil {
continue
@ -620,26 +620,12 @@ type step struct {
var stdout bytes.Buffer
var flagOut bytes.Buffer
var newState *ipn.ServeConfig
e := &serveEnv{
lc: lc,
testFlagOut: &flagOut,
testStdout: &stdout,
testGetLocalClientStatus: func(context.Context) (*ipnstate.Status, error) {
return &ipnstate.Status{
Self: &ipnstate.PeerStatus{
DNSName: "foo.test.ts.net",
Capabilities: []string{tailcfg.NodeAttrFunnel},
},
}, nil
},
testGetServeConfig: func(context.Context) (*ipn.ServeConfig, error) {
return current, nil
},
testSetServeConfig: func(_ context.Context, c *ipn.ServeConfig) error {
newState = c
return nil
},
}
lastCount := lc.setCount
cmd := newServeCommand(e)
err := cmd.ParseAndRun(context.Background(), st.command)
if flagOut.Len() > 0 {
@ -655,23 +641,61 @@ type step struct {
continue
}
if st.wantErr != nil {
t.Fatalf("step #%d, line %v: got success (saved=%v), but wanted an error", i, st.line, newState != nil)
t.Fatalf("step #%d, line %v: got success (saved=%v), but wanted an error", i, st.line, lc.config != nil)
}
if !reflect.DeepEqual(newState, st.want) {
var got *ipn.ServeConfig = nil
if lc.setCount > lastCount {
got = lc.config
}
if !reflect.DeepEqual(got, st.want) {
t.Fatalf("[%d] %v: bad state. got:\n%s\n\nwant:\n%s\n",
i, st.command, asJSON(newState), asJSON(st.want))
i, st.command, asJSON(got), asJSON(st.want))
// NOTE: asJSON will omit empty fields, which might make
// result in bad state got/want diffs being the same, even
// though the actual state is different. Use below to debug:
// t.Fatalf("[%d] %v: bad state. got:\n%+v\n\nwant:\n%+v\n",
// i, st.command, newState, st.want)
}
if newState != nil {
current = newState
// i, st.command, got, st.want)
}
}
}
// fakeLocalServeClient is a fake tailscale.LocalClient for tests.
// It's not a full implementation, just enough to test the serve command.
//
// The fake client is stateful, and is used to test manipulating
// ServeConfig state. This implementation cannot be used concurrently.
type fakeLocalServeClient struct {
config *ipn.ServeConfig
setCount int // counts calls to SetServeConfig
}
// fakeStatus is a fake ipnstate.Status value for tests.
// It's not a full implementation, just enough to test the serve command.
//
// It returns a state that's running, with a fake DNSName and the Funnel
// node attribute capability.
var fakeStatus = &ipnstate.Status{
BackendState: ipn.Running.String(),
Self: &ipnstate.PeerStatus{
DNSName: "foo.test.ts.net",
Capabilities: []string{tailcfg.NodeAttrFunnel},
},
}
func (lc *fakeLocalServeClient) Status(ctx context.Context) (*ipnstate.Status, error) {
return fakeStatus, nil
}
func (lc *fakeLocalServeClient) GetServeConfig(ctx context.Context) (*ipn.ServeConfig, error) {
return lc.config.Clone(), nil
}
func (lc *fakeLocalServeClient) SetServeConfig(ctx context.Context, config *ipn.ServeConfig) error {
lc.setCount += 1
lc.config = config.Clone()
return nil
}
// exactError returns an error checker that wants exactly the provided want error.
// If optName is non-empty, it's used in the error message.
func exactErr(want error, optName ...string) func(error) string {