mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 04:55:31 +00:00
cmd/tailscale/cli: use mock impl of LocalClient for serve cmd (#6422)
Create an interface and mock implementation of tailscale.LocalClient for serve command tests. Updates #6304 Closes #6372 Signed-off-by: Shayne Sweeney <shayne@tailscale.com>
This commit is contained in:
parent
fec888581a
commit
73399f784b
@ -30,9 +30,9 @@
|
||||
"tailscale.com/version"
|
||||
)
|
||||
|
||||
var serveCmd = newServeCommand(&serveEnv{})
|
||||
var serveCmd = newServeCommand(&serveEnv{lc: &localClient})
|
||||
|
||||
// newServeCommand returns a new "serve" subcommand using e as its environmment.
|
||||
// newServeCommand returns a new "serve" subcommand using e as its environment.
|
||||
func newServeCommand(e *serveEnv) *ffcli.Command {
|
||||
return &ffcli.Command{
|
||||
Name: "serve",
|
||||
@ -127,8 +127,21 @@ func (e *serveEnv) newFlags(name string, setup func(fs *flag.FlagSet)) *flag.Fla
|
||||
return fs
|
||||
}
|
||||
|
||||
// localServeClient is an interface conforming to the subset of
|
||||
// tailscale.LocalClient. It includes only the methods used by the
|
||||
// serve command.
|
||||
//
|
||||
// The purpose of this interface is to allow tests to provide a mock.
|
||||
type localServeClient interface {
|
||||
Status(context.Context) (*ipnstate.Status, error)
|
||||
GetServeConfig(context.Context) (*ipn.ServeConfig, error)
|
||||
SetServeConfig(context.Context, *ipn.ServeConfig) error
|
||||
}
|
||||
|
||||
// serveEnv is the environment the serve command runs within. All I/O should be
|
||||
// done via serveEnv methods so that it can be faked out for tests.
|
||||
// Calls to localClient should be done via the lc field, which is an interface
|
||||
// that can be faked out for tests.
|
||||
//
|
||||
// It also contains the flags, as registered with newServeCommand.
|
||||
type serveEnv struct {
|
||||
@ -138,12 +151,11 @@ type serveEnv struct {
|
||||
remove bool // remove a serve config
|
||||
json bool // output JSON (status only for now)
|
||||
|
||||
lc localServeClient // localClient interface, specific to serve
|
||||
|
||||
// optional stuff for tests:
|
||||
testFlagOut io.Writer
|
||||
testGetServeConfig func(context.Context) (*ipn.ServeConfig, error)
|
||||
testSetServeConfig func(context.Context, *ipn.ServeConfig) error
|
||||
testGetLocalClientStatus func(context.Context) (*ipnstate.Status, error)
|
||||
testStdout io.Writer
|
||||
testFlagOut io.Writer
|
||||
testStdout io.Writer
|
||||
}
|
||||
|
||||
// getSelfDNSName returns the DNS name of the current node.
|
||||
@ -157,15 +169,12 @@ func (e *serveEnv) getSelfDNSName(ctx context.Context) (string, error) {
|
||||
return strings.TrimSuffix(st.Self.DNSName, "."), nil
|
||||
}
|
||||
|
||||
// getLocalClientStatus calls LocalClient.Status, checks if
|
||||
// Status is ready.
|
||||
// getLocalClientStatus returns the Status of the local client.
|
||||
// Returns error if unable to reach tailscaled or if self node is nil.
|
||||
//
|
||||
// Exits if status is not running or starting.
|
||||
func (e *serveEnv) getLocalClientStatus(ctx context.Context) (*ipnstate.Status, error) {
|
||||
if e.testGetLocalClientStatus != nil {
|
||||
return e.testGetLocalClientStatus(ctx)
|
||||
}
|
||||
st, err := localClient.Status(ctx)
|
||||
st, err := e.lc.Status(ctx)
|
||||
if err != nil {
|
||||
return nil, fixTailscaledConnectError(err)
|
||||
}
|
||||
@ -180,20 +189,6 @@ func (e *serveEnv) getLocalClientStatus(ctx context.Context) (*ipnstate.Status,
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func (e *serveEnv) getServeConfig(ctx context.Context) (*ipn.ServeConfig, error) {
|
||||
if e.testGetServeConfig != nil {
|
||||
return e.testGetServeConfig(ctx)
|
||||
}
|
||||
return localClient.GetServeConfig(ctx)
|
||||
}
|
||||
|
||||
func (e *serveEnv) setServeConfig(ctx context.Context, c *ipn.ServeConfig) error {
|
||||
if e.testSetServeConfig != nil {
|
||||
return e.testSetServeConfig(ctx, c)
|
||||
}
|
||||
return localClient.SetServeConfig(ctx, c)
|
||||
}
|
||||
|
||||
// validateServePort returns --serve-port flag value,
|
||||
// or an error if the port is not a valid port to serve on.
|
||||
func (e *serveEnv) validateServePort() (port uint16, err error) {
|
||||
@ -232,7 +227,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error {
|
||||
if err := json.Unmarshal(valb, sc); err != nil {
|
||||
return fmt.Errorf("invalid JSON: %w", err)
|
||||
}
|
||||
return localClient.SetServeConfig(ctx, sc)
|
||||
return e.lc.SetServeConfig(ctx, sc)
|
||||
}
|
||||
|
||||
if !(len(args) == 3 || (e.remove && len(args) >= 1)) {
|
||||
@ -294,7 +289,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error {
|
||||
return flag.ErrHelp
|
||||
}
|
||||
|
||||
cursc, err := e.getServeConfig(ctx)
|
||||
cursc, err := e.lc.GetServeConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -337,7 +332,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error {
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(cursc, sc) {
|
||||
if err := e.setServeConfig(ctx, sc); err != nil {
|
||||
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -351,7 +346,7 @@ func (e *serveEnv) handleWebServeRemove(ctx context.Context, mount string) error
|
||||
return err
|
||||
}
|
||||
srvPortStr := strconv.Itoa(int(srvPort))
|
||||
sc, err := e.getServeConfig(ctx)
|
||||
sc, err := e.lc.GetServeConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -382,7 +377,7 @@ func (e *serveEnv) handleWebServeRemove(ctx context.Context, mount string) error
|
||||
if len(sc.TCP) == 0 {
|
||||
sc.TCP = nil
|
||||
}
|
||||
if err := e.setServeConfig(ctx, sc); err != nil {
|
||||
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@ -453,7 +448,7 @@ func allNumeric(s string) bool {
|
||||
// - tailscale status
|
||||
// - tailscale status --json
|
||||
func (e *serveEnv) runServeStatus(ctx context.Context, args []string) error {
|
||||
sc, err := e.getServeConfig(ctx)
|
||||
sc, err := e.lc.GetServeConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -603,7 +598,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error {
|
||||
fmt.Fprintf(os.Stderr, "error: invalid port %q\n\n", portStr)
|
||||
}
|
||||
|
||||
cursc, err := e.getServeConfig(ctx)
|
||||
cursc, err := e.lc.GetServeConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -628,7 +623,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error {
|
||||
if len(sc.TCP) == 0 {
|
||||
sc.TCP = nil
|
||||
}
|
||||
return e.setServeConfig(ctx, sc)
|
||||
return e.lc.SetServeConfig(ctx, sc)
|
||||
}
|
||||
return errors.New("error: serve config does not exist")
|
||||
}
|
||||
@ -644,7 +639,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error {
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(cursc, sc) {
|
||||
if err := e.setServeConfig(ctx, sc); err != nil {
|
||||
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -674,7 +669,7 @@ func (e *serveEnv) runServeFunnel(ctx context.Context, args []string) error {
|
||||
default:
|
||||
return flag.ErrHelp
|
||||
}
|
||||
sc, err := e.getServeConfig(ctx)
|
||||
sc, err := e.lc.GetServeConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -703,7 +698,7 @@ func (e *serveEnv) runServeFunnel(ctx context.Context, args []string) error {
|
||||
sc.AllowFunnel = nil
|
||||
}
|
||||
}
|
||||
if err := e.setServeConfig(ctx, sc); err != nil {
|
||||
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
@ -144,10 +144,10 @@ type step struct {
|
||||
},
|
||||
},
|
||||
})
|
||||
add(step{
|
||||
add(step{ // invalid port
|
||||
command: cmd("--serve-port=9999 /abc proxy 3001"),
|
||||
wantErr: anyErr(),
|
||||
}) // invalid port
|
||||
})
|
||||
add(step{
|
||||
command: cmd("--serve-port=8443 /abc proxy 3001"),
|
||||
want: &ipn.ServeConfig{
|
||||
@ -606,12 +606,12 @@ type step struct {
|
||||
wantErr: anyErr(),
|
||||
})
|
||||
|
||||
lc := &fakeLocalServeClient{}
|
||||
// And now run the steps above.
|
||||
var current *ipn.ServeConfig
|
||||
for i, st := range steps {
|
||||
if st.reset {
|
||||
t.Logf("Executing step #%d, line %v: [reset]", i, st.line)
|
||||
current = nil
|
||||
lc.config = nil
|
||||
}
|
||||
if st.command == nil {
|
||||
continue
|
||||
@ -620,26 +620,12 @@ type step struct {
|
||||
|
||||
var stdout bytes.Buffer
|
||||
var flagOut bytes.Buffer
|
||||
var newState *ipn.ServeConfig
|
||||
e := &serveEnv{
|
||||
lc: lc,
|
||||
testFlagOut: &flagOut,
|
||||
testStdout: &stdout,
|
||||
testGetLocalClientStatus: func(context.Context) (*ipnstate.Status, error) {
|
||||
return &ipnstate.Status{
|
||||
Self: &ipnstate.PeerStatus{
|
||||
DNSName: "foo.test.ts.net",
|
||||
Capabilities: []string{tailcfg.NodeAttrFunnel},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
testGetServeConfig: func(context.Context) (*ipn.ServeConfig, error) {
|
||||
return current, nil
|
||||
},
|
||||
testSetServeConfig: func(_ context.Context, c *ipn.ServeConfig) error {
|
||||
newState = c
|
||||
return nil
|
||||
},
|
||||
}
|
||||
lastCount := lc.setCount
|
||||
cmd := newServeCommand(e)
|
||||
err := cmd.ParseAndRun(context.Background(), st.command)
|
||||
if flagOut.Len() > 0 {
|
||||
@ -655,23 +641,61 @@ type step struct {
|
||||
continue
|
||||
}
|
||||
if st.wantErr != nil {
|
||||
t.Fatalf("step #%d, line %v: got success (saved=%v), but wanted an error", i, st.line, newState != nil)
|
||||
t.Fatalf("step #%d, line %v: got success (saved=%v), but wanted an error", i, st.line, lc.config != nil)
|
||||
}
|
||||
if !reflect.DeepEqual(newState, st.want) {
|
||||
var got *ipn.ServeConfig = nil
|
||||
if lc.setCount > lastCount {
|
||||
got = lc.config
|
||||
}
|
||||
if !reflect.DeepEqual(got, st.want) {
|
||||
t.Fatalf("[%d] %v: bad state. got:\n%s\n\nwant:\n%s\n",
|
||||
i, st.command, asJSON(newState), asJSON(st.want))
|
||||
i, st.command, asJSON(got), asJSON(st.want))
|
||||
// NOTE: asJSON will omit empty fields, which might make
|
||||
// result in bad state got/want diffs being the same, even
|
||||
// though the actual state is different. Use below to debug:
|
||||
// t.Fatalf("[%d] %v: bad state. got:\n%+v\n\nwant:\n%+v\n",
|
||||
// i, st.command, newState, st.want)
|
||||
}
|
||||
if newState != nil {
|
||||
current = newState
|
||||
// i, st.command, got, st.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fakeLocalServeClient is a fake tailscale.LocalClient for tests.
|
||||
// It's not a full implementation, just enough to test the serve command.
|
||||
//
|
||||
// The fake client is stateful, and is used to test manipulating
|
||||
// ServeConfig state. This implementation cannot be used concurrently.
|
||||
type fakeLocalServeClient struct {
|
||||
config *ipn.ServeConfig
|
||||
setCount int // counts calls to SetServeConfig
|
||||
}
|
||||
|
||||
// fakeStatus is a fake ipnstate.Status value for tests.
|
||||
// It's not a full implementation, just enough to test the serve command.
|
||||
//
|
||||
// It returns a state that's running, with a fake DNSName and the Funnel
|
||||
// node attribute capability.
|
||||
var fakeStatus = &ipnstate.Status{
|
||||
BackendState: ipn.Running.String(),
|
||||
Self: &ipnstate.PeerStatus{
|
||||
DNSName: "foo.test.ts.net",
|
||||
Capabilities: []string{tailcfg.NodeAttrFunnel},
|
||||
},
|
||||
}
|
||||
|
||||
func (lc *fakeLocalServeClient) Status(ctx context.Context) (*ipnstate.Status, error) {
|
||||
return fakeStatus, nil
|
||||
}
|
||||
|
||||
func (lc *fakeLocalServeClient) GetServeConfig(ctx context.Context) (*ipn.ServeConfig, error) {
|
||||
return lc.config.Clone(), nil
|
||||
}
|
||||
|
||||
func (lc *fakeLocalServeClient) SetServeConfig(ctx context.Context, config *ipn.ServeConfig) error {
|
||||
lc.setCount += 1
|
||||
lc.config = config.Clone()
|
||||
return nil
|
||||
}
|
||||
|
||||
// exactError returns an error checker that wants exactly the provided want error.
|
||||
// If optName is non-empty, it's used in the error message.
|
||||
func exactErr(want error, optName ...string) func(error) string {
|
||||
|
Loading…
Reference in New Issue
Block a user