mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-22 08:51:41 +00:00
cmd/tailscale/cli: use mock impl of LocalClient for serve cmd (#6422)
Create an interface and mock implementation of tailscale.LocalClient for serve command tests. Updates #6304 Closes #6372 Signed-off-by: Shayne Sweeney <shayne@tailscale.com>
This commit is contained in:
parent
fec888581a
commit
73399f784b
@ -30,9 +30,9 @@ import (
|
|||||||
"tailscale.com/version"
|
"tailscale.com/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
var serveCmd = newServeCommand(&serveEnv{})
|
var serveCmd = newServeCommand(&serveEnv{lc: &localClient})
|
||||||
|
|
||||||
// newServeCommand returns a new "serve" subcommand using e as its environmment.
|
// newServeCommand returns a new "serve" subcommand using e as its environment.
|
||||||
func newServeCommand(e *serveEnv) *ffcli.Command {
|
func newServeCommand(e *serveEnv) *ffcli.Command {
|
||||||
return &ffcli.Command{
|
return &ffcli.Command{
|
||||||
Name: "serve",
|
Name: "serve",
|
||||||
@ -127,8 +127,21 @@ func (e *serveEnv) newFlags(name string, setup func(fs *flag.FlagSet)) *flag.Fla
|
|||||||
return fs
|
return fs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// localServeClient is an interface conforming to the subset of
|
||||||
|
// tailscale.LocalClient. It includes only the methods used by the
|
||||||
|
// serve command.
|
||||||
|
//
|
||||||
|
// The purpose of this interface is to allow tests to provide a mock.
|
||||||
|
type localServeClient interface {
|
||||||
|
Status(context.Context) (*ipnstate.Status, error)
|
||||||
|
GetServeConfig(context.Context) (*ipn.ServeConfig, error)
|
||||||
|
SetServeConfig(context.Context, *ipn.ServeConfig) error
|
||||||
|
}
|
||||||
|
|
||||||
// serveEnv is the environment the serve command runs within. All I/O should be
|
// serveEnv is the environment the serve command runs within. All I/O should be
|
||||||
// done via serveEnv methods so that it can be faked out for tests.
|
// done via serveEnv methods so that it can be faked out for tests.
|
||||||
|
// Calls to localClient should be done via the lc field, which is an interface
|
||||||
|
// that can be faked out for tests.
|
||||||
//
|
//
|
||||||
// It also contains the flags, as registered with newServeCommand.
|
// It also contains the flags, as registered with newServeCommand.
|
||||||
type serveEnv struct {
|
type serveEnv struct {
|
||||||
@ -138,12 +151,11 @@ type serveEnv struct {
|
|||||||
remove bool // remove a serve config
|
remove bool // remove a serve config
|
||||||
json bool // output JSON (status only for now)
|
json bool // output JSON (status only for now)
|
||||||
|
|
||||||
|
lc localServeClient // localClient interface, specific to serve
|
||||||
|
|
||||||
// optional stuff for tests:
|
// optional stuff for tests:
|
||||||
testFlagOut io.Writer
|
testFlagOut io.Writer
|
||||||
testGetServeConfig func(context.Context) (*ipn.ServeConfig, error)
|
testStdout io.Writer
|
||||||
testSetServeConfig func(context.Context, *ipn.ServeConfig) error
|
|
||||||
testGetLocalClientStatus func(context.Context) (*ipnstate.Status, error)
|
|
||||||
testStdout io.Writer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getSelfDNSName returns the DNS name of the current node.
|
// getSelfDNSName returns the DNS name of the current node.
|
||||||
@ -157,15 +169,12 @@ func (e *serveEnv) getSelfDNSName(ctx context.Context) (string, error) {
|
|||||||
return strings.TrimSuffix(st.Self.DNSName, "."), nil
|
return strings.TrimSuffix(st.Self.DNSName, "."), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getLocalClientStatus calls LocalClient.Status, checks if
|
// getLocalClientStatus returns the Status of the local client.
|
||||||
// Status is ready.
|
|
||||||
// Returns error if unable to reach tailscaled or if self node is nil.
|
// Returns error if unable to reach tailscaled or if self node is nil.
|
||||||
|
//
|
||||||
// Exits if status is not running or starting.
|
// Exits if status is not running or starting.
|
||||||
func (e *serveEnv) getLocalClientStatus(ctx context.Context) (*ipnstate.Status, error) {
|
func (e *serveEnv) getLocalClientStatus(ctx context.Context) (*ipnstate.Status, error) {
|
||||||
if e.testGetLocalClientStatus != nil {
|
st, err := e.lc.Status(ctx)
|
||||||
return e.testGetLocalClientStatus(ctx)
|
|
||||||
}
|
|
||||||
st, err := localClient.Status(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fixTailscaledConnectError(err)
|
return nil, fixTailscaledConnectError(err)
|
||||||
}
|
}
|
||||||
@ -180,20 +189,6 @@ func (e *serveEnv) getLocalClientStatus(ctx context.Context) (*ipnstate.Status,
|
|||||||
return st, nil
|
return st, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *serveEnv) getServeConfig(ctx context.Context) (*ipn.ServeConfig, error) {
|
|
||||||
if e.testGetServeConfig != nil {
|
|
||||||
return e.testGetServeConfig(ctx)
|
|
||||||
}
|
|
||||||
return localClient.GetServeConfig(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *serveEnv) setServeConfig(ctx context.Context, c *ipn.ServeConfig) error {
|
|
||||||
if e.testSetServeConfig != nil {
|
|
||||||
return e.testSetServeConfig(ctx, c)
|
|
||||||
}
|
|
||||||
return localClient.SetServeConfig(ctx, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateServePort returns --serve-port flag value,
|
// validateServePort returns --serve-port flag value,
|
||||||
// or an error if the port is not a valid port to serve on.
|
// or an error if the port is not a valid port to serve on.
|
||||||
func (e *serveEnv) validateServePort() (port uint16, err error) {
|
func (e *serveEnv) validateServePort() (port uint16, err error) {
|
||||||
@ -232,7 +227,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error {
|
|||||||
if err := json.Unmarshal(valb, sc); err != nil {
|
if err := json.Unmarshal(valb, sc); err != nil {
|
||||||
return fmt.Errorf("invalid JSON: %w", err)
|
return fmt.Errorf("invalid JSON: %w", err)
|
||||||
}
|
}
|
||||||
return localClient.SetServeConfig(ctx, sc)
|
return e.lc.SetServeConfig(ctx, sc)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(len(args) == 3 || (e.remove && len(args) >= 1)) {
|
if !(len(args) == 3 || (e.remove && len(args) >= 1)) {
|
||||||
@ -294,7 +289,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error {
|
|||||||
return flag.ErrHelp
|
return flag.ErrHelp
|
||||||
}
|
}
|
||||||
|
|
||||||
cursc, err := e.getServeConfig(ctx)
|
cursc, err := e.lc.GetServeConfig(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -337,7 +332,7 @@ func (e *serveEnv) runServe(ctx context.Context, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(cursc, sc) {
|
if !reflect.DeepEqual(cursc, sc) {
|
||||||
if err := e.setServeConfig(ctx, sc); err != nil {
|
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -351,7 +346,7 @@ func (e *serveEnv) handleWebServeRemove(ctx context.Context, mount string) error
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
srvPortStr := strconv.Itoa(int(srvPort))
|
srvPortStr := strconv.Itoa(int(srvPort))
|
||||||
sc, err := e.getServeConfig(ctx)
|
sc, err := e.lc.GetServeConfig(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -382,7 +377,7 @@ func (e *serveEnv) handleWebServeRemove(ctx context.Context, mount string) error
|
|||||||
if len(sc.TCP) == 0 {
|
if len(sc.TCP) == 0 {
|
||||||
sc.TCP = nil
|
sc.TCP = nil
|
||||||
}
|
}
|
||||||
if err := e.setServeConfig(ctx, sc); err != nil {
|
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -453,7 +448,7 @@ func allNumeric(s string) bool {
|
|||||||
// - tailscale status
|
// - tailscale status
|
||||||
// - tailscale status --json
|
// - tailscale status --json
|
||||||
func (e *serveEnv) runServeStatus(ctx context.Context, args []string) error {
|
func (e *serveEnv) runServeStatus(ctx context.Context, args []string) error {
|
||||||
sc, err := e.getServeConfig(ctx)
|
sc, err := e.lc.GetServeConfig(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -603,7 +598,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error {
|
|||||||
fmt.Fprintf(os.Stderr, "error: invalid port %q\n\n", portStr)
|
fmt.Fprintf(os.Stderr, "error: invalid port %q\n\n", portStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
cursc, err := e.getServeConfig(ctx)
|
cursc, err := e.lc.GetServeConfig(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -628,7 +623,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error {
|
|||||||
if len(sc.TCP) == 0 {
|
if len(sc.TCP) == 0 {
|
||||||
sc.TCP = nil
|
sc.TCP = nil
|
||||||
}
|
}
|
||||||
return e.setServeConfig(ctx, sc)
|
return e.lc.SetServeConfig(ctx, sc)
|
||||||
}
|
}
|
||||||
return errors.New("error: serve config does not exist")
|
return errors.New("error: serve config does not exist")
|
||||||
}
|
}
|
||||||
@ -644,7 +639,7 @@ func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(cursc, sc) {
|
if !reflect.DeepEqual(cursc, sc) {
|
||||||
if err := e.setServeConfig(ctx, sc); err != nil {
|
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -674,7 +669,7 @@ func (e *serveEnv) runServeFunnel(ctx context.Context, args []string) error {
|
|||||||
default:
|
default:
|
||||||
return flag.ErrHelp
|
return flag.ErrHelp
|
||||||
}
|
}
|
||||||
sc, err := e.getServeConfig(ctx)
|
sc, err := e.lc.GetServeConfig(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -703,7 +698,7 @@ func (e *serveEnv) runServeFunnel(ctx context.Context, args []string) error {
|
|||||||
sc.AllowFunnel = nil
|
sc.AllowFunnel = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := e.setServeConfig(ctx, sc); err != nil {
|
if err := e.lc.SetServeConfig(ctx, sc); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -144,10 +144,10 @@ func TestServeConfigMutations(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
add(step{
|
add(step{ // invalid port
|
||||||
command: cmd("--serve-port=9999 /abc proxy 3001"),
|
command: cmd("--serve-port=9999 /abc proxy 3001"),
|
||||||
wantErr: anyErr(),
|
wantErr: anyErr(),
|
||||||
}) // invalid port
|
})
|
||||||
add(step{
|
add(step{
|
||||||
command: cmd("--serve-port=8443 /abc proxy 3001"),
|
command: cmd("--serve-port=8443 /abc proxy 3001"),
|
||||||
want: &ipn.ServeConfig{
|
want: &ipn.ServeConfig{
|
||||||
@ -606,12 +606,12 @@ func TestServeConfigMutations(t *testing.T) {
|
|||||||
wantErr: anyErr(),
|
wantErr: anyErr(),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
lc := &fakeLocalServeClient{}
|
||||||
// And now run the steps above.
|
// And now run the steps above.
|
||||||
var current *ipn.ServeConfig
|
|
||||||
for i, st := range steps {
|
for i, st := range steps {
|
||||||
if st.reset {
|
if st.reset {
|
||||||
t.Logf("Executing step #%d, line %v: [reset]", i, st.line)
|
t.Logf("Executing step #%d, line %v: [reset]", i, st.line)
|
||||||
current = nil
|
lc.config = nil
|
||||||
}
|
}
|
||||||
if st.command == nil {
|
if st.command == nil {
|
||||||
continue
|
continue
|
||||||
@ -620,26 +620,12 @@ func TestServeConfigMutations(t *testing.T) {
|
|||||||
|
|
||||||
var stdout bytes.Buffer
|
var stdout bytes.Buffer
|
||||||
var flagOut bytes.Buffer
|
var flagOut bytes.Buffer
|
||||||
var newState *ipn.ServeConfig
|
|
||||||
e := &serveEnv{
|
e := &serveEnv{
|
||||||
|
lc: lc,
|
||||||
testFlagOut: &flagOut,
|
testFlagOut: &flagOut,
|
||||||
testStdout: &stdout,
|
testStdout: &stdout,
|
||||||
testGetLocalClientStatus: func(context.Context) (*ipnstate.Status, error) {
|
|
||||||
return &ipnstate.Status{
|
|
||||||
Self: &ipnstate.PeerStatus{
|
|
||||||
DNSName: "foo.test.ts.net",
|
|
||||||
Capabilities: []string{tailcfg.NodeAttrFunnel},
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
testGetServeConfig: func(context.Context) (*ipn.ServeConfig, error) {
|
|
||||||
return current, nil
|
|
||||||
},
|
|
||||||
testSetServeConfig: func(_ context.Context, c *ipn.ServeConfig) error {
|
|
||||||
newState = c
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
lastCount := lc.setCount
|
||||||
cmd := newServeCommand(e)
|
cmd := newServeCommand(e)
|
||||||
err := cmd.ParseAndRun(context.Background(), st.command)
|
err := cmd.ParseAndRun(context.Background(), st.command)
|
||||||
if flagOut.Len() > 0 {
|
if flagOut.Len() > 0 {
|
||||||
@ -655,23 +641,61 @@ func TestServeConfigMutations(t *testing.T) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if st.wantErr != nil {
|
if st.wantErr != nil {
|
||||||
t.Fatalf("step #%d, line %v: got success (saved=%v), but wanted an error", i, st.line, newState != nil)
|
t.Fatalf("step #%d, line %v: got success (saved=%v), but wanted an error", i, st.line, lc.config != nil)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(newState, st.want) {
|
var got *ipn.ServeConfig = nil
|
||||||
|
if lc.setCount > lastCount {
|
||||||
|
got = lc.config
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, st.want) {
|
||||||
t.Fatalf("[%d] %v: bad state. got:\n%s\n\nwant:\n%s\n",
|
t.Fatalf("[%d] %v: bad state. got:\n%s\n\nwant:\n%s\n",
|
||||||
i, st.command, asJSON(newState), asJSON(st.want))
|
i, st.command, asJSON(got), asJSON(st.want))
|
||||||
// NOTE: asJSON will omit empty fields, which might make
|
// NOTE: asJSON will omit empty fields, which might make
|
||||||
// result in bad state got/want diffs being the same, even
|
// result in bad state got/want diffs being the same, even
|
||||||
// though the actual state is different. Use below to debug:
|
// though the actual state is different. Use below to debug:
|
||||||
// t.Fatalf("[%d] %v: bad state. got:\n%+v\n\nwant:\n%+v\n",
|
// t.Fatalf("[%d] %v: bad state. got:\n%+v\n\nwant:\n%+v\n",
|
||||||
// i, st.command, newState, st.want)
|
// i, st.command, got, st.want)
|
||||||
}
|
|
||||||
if newState != nil {
|
|
||||||
current = newState
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fakeLocalServeClient is a fake tailscale.LocalClient for tests.
|
||||||
|
// It's not a full implementation, just enough to test the serve command.
|
||||||
|
//
|
||||||
|
// The fake client is stateful, and is used to test manipulating
|
||||||
|
// ServeConfig state. This implementation cannot be used concurrently.
|
||||||
|
type fakeLocalServeClient struct {
|
||||||
|
config *ipn.ServeConfig
|
||||||
|
setCount int // counts calls to SetServeConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeStatus is a fake ipnstate.Status value for tests.
|
||||||
|
// It's not a full implementation, just enough to test the serve command.
|
||||||
|
//
|
||||||
|
// It returns a state that's running, with a fake DNSName and the Funnel
|
||||||
|
// node attribute capability.
|
||||||
|
var fakeStatus = &ipnstate.Status{
|
||||||
|
BackendState: ipn.Running.String(),
|
||||||
|
Self: &ipnstate.PeerStatus{
|
||||||
|
DNSName: "foo.test.ts.net",
|
||||||
|
Capabilities: []string{tailcfg.NodeAttrFunnel},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lc *fakeLocalServeClient) Status(ctx context.Context) (*ipnstate.Status, error) {
|
||||||
|
return fakeStatus, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lc *fakeLocalServeClient) GetServeConfig(ctx context.Context) (*ipn.ServeConfig, error) {
|
||||||
|
return lc.config.Clone(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lc *fakeLocalServeClient) SetServeConfig(ctx context.Context, config *ipn.ServeConfig) error {
|
||||||
|
lc.setCount += 1
|
||||||
|
lc.config = config.Clone()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// exactError returns an error checker that wants exactly the provided want error.
|
// exactError returns an error checker that wants exactly the provided want error.
|
||||||
// If optName is non-empty, it's used in the error message.
|
// If optName is non-empty, it's used in the error message.
|
||||||
func exactErr(want error, optName ...string) func(error) string {
|
func exactErr(want error, optName ...string) func(error) string {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user