From 6632504f45246b716cdfb2e41cbd5c6f4d3e11f4 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Fri, 16 Sep 2022 19:52:02 -0700 Subject: [PATCH] cmd/tailscale/cli: [up] move lose-ssh check after other validations The check was happening too early and in the case of error would wait 5 s and then error out. This makes it so that it does validations before the SSH check. Signed-off-by: Maisem Ali --- cmd/tailscale/cli/cli_test.go | 159 +++++++++++++++++++++++++++++++++- cmd/tailscale/cli/down.go | 8 +- cmd/tailscale/cli/risks.go | 26 +++--- cmd/tailscale/cli/up.go | 29 ++++--- 4 files changed, 195 insertions(+), 27 deletions(-) diff --git a/cmd/tailscale/cli/cli_test.go b/cmd/tailscale/cli/cli_test.go index f6825bf17..5afa199e0 100644 --- a/cmd/tailscale/cli/cli_test.go +++ b/cmd/tailscale/cli/cli_test.go @@ -789,6 +789,10 @@ func TestUpdatePrefs(t *testing.T) { curPrefs *ipn.Prefs env upCheckEnv // empty goos means "linux" + // sshOverTailscale specifies if the cmd being run over SSH over Tailscale. + // It is used to test the --accept-risks flag. + sshOverTailscale bool + // checkUpdatePrefsMutations, if non-nil, is run with the new prefs after // updatePrefs might've mutated them (from applyImplicitPrefs). checkUpdatePrefsMutations func(t *testing.T, newPrefs *ipn.Prefs) @@ -916,15 +920,159 @@ func TestUpdatePrefs(t *testing.T) { } }, }, + { + name: "enable_ssh", + flags: []string{"--ssh"}, + curPrefs: &ipn.Prefs{ + ControlURL: "https://login.tailscale.com", + Persist: &persist.Persist{LoginName: "crawshaw.github"}, + AllowSingleHosts: true, + CorpDNS: true, + NetfilterMode: preftype.NetfilterOn, + }, + wantJustEditMP: &ipn.MaskedPrefs{ + RunSSHSet: true, + WantRunningSet: true, + }, + checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) { + if !newPrefs.RunSSH { + t.Errorf("RunSSH not set to true") + } + }, + env: upCheckEnv{backendState: "Running"}, + }, + { + name: "disable_ssh", + flags: []string{"--ssh=false"}, + curPrefs: &ipn.Prefs{ + ControlURL: "https://login.tailscale.com", + Persist: &persist.Persist{LoginName: "crawshaw.github"}, + AllowSingleHosts: true, + CorpDNS: true, + RunSSH: true, + NetfilterMode: preftype.NetfilterOn, + }, + wantJustEditMP: &ipn.MaskedPrefs{ + RunSSHSet: true, + WantRunningSet: true, + }, + checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) { + if newPrefs.RunSSH { + t.Errorf("RunSSH not set to false") + } + }, + env: upCheckEnv{backendState: "Running", upArgs: upArgsT{ + runSSH: true, + }}, + }, + { + name: "disable_ssh_over_ssh_no_risk", + flags: []string{"--ssh=false"}, + sshOverTailscale: true, + curPrefs: &ipn.Prefs{ + ControlURL: "https://login.tailscale.com", + Persist: &persist.Persist{LoginName: "crawshaw.github"}, + AllowSingleHosts: true, + CorpDNS: true, + NetfilterMode: preftype.NetfilterOn, + RunSSH: true, + }, + wantJustEditMP: &ipn.MaskedPrefs{ + RunSSHSet: true, + WantRunningSet: true, + }, + checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) { + if !newPrefs.RunSSH { + t.Errorf("RunSSH not set to true") + } + }, + env: upCheckEnv{backendState: "Running"}, + wantErrSubtr: "aborted, no changes made", + }, + { + name: "enable_ssh_over_ssh_no_risk", + flags: []string{"--ssh=true"}, + sshOverTailscale: true, + curPrefs: &ipn.Prefs{ + ControlURL: "https://login.tailscale.com", + Persist: &persist.Persist{LoginName: "crawshaw.github"}, + AllowSingleHosts: true, + CorpDNS: true, + NetfilterMode: preftype.NetfilterOn, + }, + wantJustEditMP: &ipn.MaskedPrefs{ + RunSSHSet: true, + WantRunningSet: true, + }, + checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) { + if !newPrefs.RunSSH { + t.Errorf("RunSSH not set to true") + } + }, + env: upCheckEnv{backendState: "Running"}, + wantErrSubtr: "aborted, no changes made", + }, + { + name: "enable_ssh_over_ssh", + flags: []string{"--ssh=true", "--accept-risk=lose-ssh"}, + sshOverTailscale: true, + curPrefs: &ipn.Prefs{ + ControlURL: "https://login.tailscale.com", + Persist: &persist.Persist{LoginName: "crawshaw.github"}, + AllowSingleHosts: true, + CorpDNS: true, + NetfilterMode: preftype.NetfilterOn, + }, + wantJustEditMP: &ipn.MaskedPrefs{ + RunSSHSet: true, + WantRunningSet: true, + }, + checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) { + if !newPrefs.RunSSH { + t.Errorf("RunSSH not set to true") + } + }, + env: upCheckEnv{backendState: "Running"}, + }, + { + name: "disable_ssh_over_ssh", + flags: []string{"--ssh=false", "--accept-risk=lose-ssh"}, + sshOverTailscale: true, + curPrefs: &ipn.Prefs{ + ControlURL: "https://login.tailscale.com", + Persist: &persist.Persist{LoginName: "crawshaw.github"}, + AllowSingleHosts: true, + CorpDNS: true, + RunSSH: true, + NetfilterMode: preftype.NetfilterOn, + }, + wantJustEditMP: &ipn.MaskedPrefs{ + RunSSHSet: true, + WantRunningSet: true, + }, + checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) { + if newPrefs.RunSSH { + t.Errorf("RunSSH not set to false") + } + }, + env: upCheckEnv{backendState: "Running"}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + if tt.sshOverTailscale { + old := getSSHClientEnvVar + getSSHClientEnvVar = func() string { return "100.100.100.100 1 1" } + t.Cleanup(func() { getSSHClientEnvVar = old }) + } if tt.env.goos == "" { tt.env.goos = "linux" } tt.env.flagSet = newUpFlagSet(tt.env.goos, &tt.env.upArgs) flags := CleanUpArgs(tt.flags) - tt.env.flagSet.Parse(flags) + if err := tt.env.flagSet.Parse(flags); err != nil { + t.Fatal(err) + } newPrefs, err := prefsFromUpArgs(tt.env.upArgs, t.Logf, new(ipnstate.Status), tt.env.goos) if err != nil { @@ -939,6 +1087,8 @@ func TestUpdatePrefs(t *testing.T) { return } t.Fatal(err) + } else if tt.wantErrSubtr != "" { + t.Fatalf("want error %q, got nil", tt.wantErrSubtr) } if tt.checkUpdatePrefsMutations != nil { tt.checkUpdatePrefsMutations(t, newPrefs) @@ -952,13 +1102,18 @@ func TestUpdatePrefs(t *testing.T) { justEditMP.Prefs = ipn.Prefs{} // uninteresting } if !reflect.DeepEqual(justEditMP, tt.wantJustEditMP) { - t.Logf("justEditMP != wantJustEditMP; following diff omits the Prefs field, which was %+v", oldEditPrefs) + t.Logf("justEditMP != wantJustEditMP; following diff omits the Prefs field, which was \n%v", asJSON(oldEditPrefs)) t.Fatalf("justEditMP: %v\n\n: ", cmp.Diff(justEditMP, tt.wantJustEditMP, cmpIP)) } }) } } +func asJSON(v any) string { + b, _ := json.MarshalIndent(v, "", "\t") + return string(b) +} + var cmpIP = cmp.Comparer(func(a, b netip.Addr) bool { return a == b }) diff --git a/cmd/tailscale/cli/down.go b/cmd/tailscale/cli/down.go index 9a139f315..2105f76b7 100644 --- a/cmd/tailscale/cli/down.go +++ b/cmd/tailscale/cli/down.go @@ -22,9 +22,13 @@ FlagSet: newDownFlagSet(), } +var downArgs struct { + acceptedRisks string +} + func newDownFlagSet() *flag.FlagSet { downf := newFlagSet("down") - registerAcceptRiskFlag(downf) + registerAcceptRiskFlag(downf, &downArgs.acceptedRisks) return downf } @@ -34,7 +38,7 @@ func runDown(ctx context.Context, args []string) error { } if isSSHOverTailscale() { - if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will disable Tailscale and result in your session disconnecting.`); err != nil { + if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will disable Tailscale and result in your session disconnecting.`, downArgs.acceptedRisks); err != nil { return err } } diff --git a/cmd/tailscale/cli/risks.go b/cmd/tailscale/cli/risks.go index d09bfce65..5decdd5d2 100644 --- a/cmd/tailscale/cli/risks.go +++ b/cmd/tailscale/cli/risks.go @@ -16,9 +16,8 @@ ) var ( - riskTypes []string - acceptedRisks string - riskLoseSSH = registerRiskType("lose-ssh") + riskTypes []string + riskLoseSSH = registerRiskType("lose-ssh") ) func registerRiskType(riskType string) string { @@ -28,12 +27,13 @@ func registerRiskType(riskType string) string { // registerAcceptRiskFlag registers the --accept-risk flag. Accepted risks are accounted for // in presentRiskToUser. -func registerAcceptRiskFlag(f *flag.FlagSet) { - f.StringVar(&acceptedRisks, "accept-risk", "", "accept risk and skip confirmation for risk types: "+strings.Join(riskTypes, ",")) +func registerAcceptRiskFlag(f *flag.FlagSet, acceptedRisks *string) { + f.StringVar(acceptedRisks, "accept-risk", "", "accept risk and skip confirmation for risk types: "+strings.Join(riskTypes, ",")) } -// riskAccepted reports whether riskType is in acceptedRisks. -func riskAccepted(riskType string) bool { +// isRiskAccepted reports whether riskType is in the comma-separated list of +// risks in acceptedRisks. +func isRiskAccepted(riskType, acceptedRisks string) bool { for _, r := range strings.Split(acceptedRisks, ",") { if r == riskType { return true @@ -49,12 +49,16 @@ func riskAccepted(riskType string) bool { // It is used by the presentRiskToUser function below. const riskAbortTimeSeconds = 5 -// presentRiskToUser displays the risk message and waits for the user to -// cancel. It returns errorAborted if the user aborts. -func presentRiskToUser(riskType, riskMessage string) error { - if riskAccepted(riskType) { +// presentRiskToUser displays the risk message and waits for the user to cancel. +// It returns errorAborted if the user aborts. In tests it returns errAborted +// immediately unless the risk has been explicitly accepted. +func presentRiskToUser(riskType, riskMessage, acceptedRisks string) error { + if isRiskAccepted(riskType, acceptedRisks) { return nil } + if inTest() { + return errAborted + } outln(riskMessage) printf("To skip this warning, use --accept-risk=%s\n", riskType) diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index 5deaf96d7..bb081a27b 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -116,7 +116,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT) *flag.FlagSet { upf.BoolVar(&upArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)") } upf.DurationVar(&upArgs.timeout, "timeout", 0, "maximum amount of time to wait for tailscaled to enter a Running state; default (0s) blocks forever") - registerAcceptRiskFlag(upf) + registerAcceptRiskFlag(upf, &upArgs.acceptedRisks) return upf } @@ -150,6 +150,7 @@ type upArgsT struct { opUser string json bool timeout time.Duration + acceptedRisks string } func (a upArgsT) getAuthKey() (string, error) { @@ -376,6 +377,21 @@ func updatePrefs(prefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, jus return false, nil, fmt.Errorf("can't change --login-server without --force-reauth") } + // Do this after validations to avoid the 5s delay if we're going to error + // out anyway. + wantSSH, haveSSH := env.upArgs.runSSH, curPrefs.RunSSH + fmt.Println("wantSSH", wantSSH, "haveSSH", haveSSH) + if wantSSH != haveSSH && isSSHOverTailscale() { + if wantSSH { + err = presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will reroute SSH traffic to Tailscale SSH and will result in your session disconnecting.`, env.upArgs.acceptedRisks) + } else { + err = presentRiskToUser(riskLoseSSH, `You are connected using Tailscale SSH; this action will result in your session disconnecting.`, env.upArgs.acceptedRisks) + } + if err != nil { + return false, nil, err + } + } + tagsChanged := !reflect.DeepEqual(curPrefs.AdvertiseTags, prefs.AdvertiseTags) simpleUp = env.flagSet.NFlag() == 0 && @@ -475,17 +491,6 @@ func runUp(ctx context.Context, args []string) (retErr error) { curExitNodeIP: exitNodeIP(curPrefs, st), } - if upArgs.runSSH != curPrefs.RunSSH && isSSHOverTailscale() { - if upArgs.runSSH { - err = presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will reroute SSH traffic to Tailscale SSH and will result in your session disconnecting.`) - } else { - err = presentRiskToUser(riskLoseSSH, `You are connected using Tailscale SSH; this action will result in your session disconnecting.`) - } - if err != nil { - return err - } - } - defer func() { if retErr == nil { checkSSHUpWarnings(ctx)