mirror of
https://github.com/tailscale/tailscale.git
synced 2025-03-31 21:42:24 +00:00
cmd/tailscale/cli: [up] move lose-ssh check after other validations
The check was happening too early and in the case of error would wait 5 s and then error out. This makes it so that it does validations before the SSH check. Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
054ef4de56
commit
6632504f45
cmd/tailscale/cli
@ -789,6 +789,10 @@ func TestUpdatePrefs(t *testing.T) {
|
|||||||
curPrefs *ipn.Prefs
|
curPrefs *ipn.Prefs
|
||||||
env upCheckEnv // empty goos means "linux"
|
env upCheckEnv // empty goos means "linux"
|
||||||
|
|
||||||
|
// sshOverTailscale specifies if the cmd being run over SSH over Tailscale.
|
||||||
|
// It is used to test the --accept-risks flag.
|
||||||
|
sshOverTailscale bool
|
||||||
|
|
||||||
// checkUpdatePrefsMutations, if non-nil, is run with the new prefs after
|
// checkUpdatePrefsMutations, if non-nil, is run with the new prefs after
|
||||||
// updatePrefs might've mutated them (from applyImplicitPrefs).
|
// updatePrefs might've mutated them (from applyImplicitPrefs).
|
||||||
checkUpdatePrefsMutations func(t *testing.T, newPrefs *ipn.Prefs)
|
checkUpdatePrefsMutations func(t *testing.T, newPrefs *ipn.Prefs)
|
||||||
@ -916,15 +920,159 @@ func TestUpdatePrefs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "enable_ssh",
|
||||||
|
flags: []string{"--ssh"},
|
||||||
|
curPrefs: &ipn.Prefs{
|
||||||
|
ControlURL: "https://login.tailscale.com",
|
||||||
|
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||||
|
AllowSingleHosts: true,
|
||||||
|
CorpDNS: true,
|
||||||
|
NetfilterMode: preftype.NetfilterOn,
|
||||||
|
},
|
||||||
|
wantJustEditMP: &ipn.MaskedPrefs{
|
||||||
|
RunSSHSet: true,
|
||||||
|
WantRunningSet: true,
|
||||||
|
},
|
||||||
|
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||||
|
if !newPrefs.RunSSH {
|
||||||
|
t.Errorf("RunSSH not set to true")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
env: upCheckEnv{backendState: "Running"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disable_ssh",
|
||||||
|
flags: []string{"--ssh=false"},
|
||||||
|
curPrefs: &ipn.Prefs{
|
||||||
|
ControlURL: "https://login.tailscale.com",
|
||||||
|
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||||
|
AllowSingleHosts: true,
|
||||||
|
CorpDNS: true,
|
||||||
|
RunSSH: true,
|
||||||
|
NetfilterMode: preftype.NetfilterOn,
|
||||||
|
},
|
||||||
|
wantJustEditMP: &ipn.MaskedPrefs{
|
||||||
|
RunSSHSet: true,
|
||||||
|
WantRunningSet: true,
|
||||||
|
},
|
||||||
|
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||||
|
if newPrefs.RunSSH {
|
||||||
|
t.Errorf("RunSSH not set to false")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
env: upCheckEnv{backendState: "Running", upArgs: upArgsT{
|
||||||
|
runSSH: true,
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disable_ssh_over_ssh_no_risk",
|
||||||
|
flags: []string{"--ssh=false"},
|
||||||
|
sshOverTailscale: true,
|
||||||
|
curPrefs: &ipn.Prefs{
|
||||||
|
ControlURL: "https://login.tailscale.com",
|
||||||
|
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||||
|
AllowSingleHosts: true,
|
||||||
|
CorpDNS: true,
|
||||||
|
NetfilterMode: preftype.NetfilterOn,
|
||||||
|
RunSSH: true,
|
||||||
|
},
|
||||||
|
wantJustEditMP: &ipn.MaskedPrefs{
|
||||||
|
RunSSHSet: true,
|
||||||
|
WantRunningSet: true,
|
||||||
|
},
|
||||||
|
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||||
|
if !newPrefs.RunSSH {
|
||||||
|
t.Errorf("RunSSH not set to true")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
env: upCheckEnv{backendState: "Running"},
|
||||||
|
wantErrSubtr: "aborted, no changes made",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enable_ssh_over_ssh_no_risk",
|
||||||
|
flags: []string{"--ssh=true"},
|
||||||
|
sshOverTailscale: true,
|
||||||
|
curPrefs: &ipn.Prefs{
|
||||||
|
ControlURL: "https://login.tailscale.com",
|
||||||
|
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||||
|
AllowSingleHosts: true,
|
||||||
|
CorpDNS: true,
|
||||||
|
NetfilterMode: preftype.NetfilterOn,
|
||||||
|
},
|
||||||
|
wantJustEditMP: &ipn.MaskedPrefs{
|
||||||
|
RunSSHSet: true,
|
||||||
|
WantRunningSet: true,
|
||||||
|
},
|
||||||
|
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||||
|
if !newPrefs.RunSSH {
|
||||||
|
t.Errorf("RunSSH not set to true")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
env: upCheckEnv{backendState: "Running"},
|
||||||
|
wantErrSubtr: "aborted, no changes made",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enable_ssh_over_ssh",
|
||||||
|
flags: []string{"--ssh=true", "--accept-risk=lose-ssh"},
|
||||||
|
sshOverTailscale: true,
|
||||||
|
curPrefs: &ipn.Prefs{
|
||||||
|
ControlURL: "https://login.tailscale.com",
|
||||||
|
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||||
|
AllowSingleHosts: true,
|
||||||
|
CorpDNS: true,
|
||||||
|
NetfilterMode: preftype.NetfilterOn,
|
||||||
|
},
|
||||||
|
wantJustEditMP: &ipn.MaskedPrefs{
|
||||||
|
RunSSHSet: true,
|
||||||
|
WantRunningSet: true,
|
||||||
|
},
|
||||||
|
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||||
|
if !newPrefs.RunSSH {
|
||||||
|
t.Errorf("RunSSH not set to true")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
env: upCheckEnv{backendState: "Running"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disable_ssh_over_ssh",
|
||||||
|
flags: []string{"--ssh=false", "--accept-risk=lose-ssh"},
|
||||||
|
sshOverTailscale: true,
|
||||||
|
curPrefs: &ipn.Prefs{
|
||||||
|
ControlURL: "https://login.tailscale.com",
|
||||||
|
Persist: &persist.Persist{LoginName: "crawshaw.github"},
|
||||||
|
AllowSingleHosts: true,
|
||||||
|
CorpDNS: true,
|
||||||
|
RunSSH: true,
|
||||||
|
NetfilterMode: preftype.NetfilterOn,
|
||||||
|
},
|
||||||
|
wantJustEditMP: &ipn.MaskedPrefs{
|
||||||
|
RunSSHSet: true,
|
||||||
|
WantRunningSet: true,
|
||||||
|
},
|
||||||
|
checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) {
|
||||||
|
if newPrefs.RunSSH {
|
||||||
|
t.Errorf("RunSSH not set to false")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
env: upCheckEnv{backendState: "Running"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.sshOverTailscale {
|
||||||
|
old := getSSHClientEnvVar
|
||||||
|
getSSHClientEnvVar = func() string { return "100.100.100.100 1 1" }
|
||||||
|
t.Cleanup(func() { getSSHClientEnvVar = old })
|
||||||
|
}
|
||||||
if tt.env.goos == "" {
|
if tt.env.goos == "" {
|
||||||
tt.env.goos = "linux"
|
tt.env.goos = "linux"
|
||||||
}
|
}
|
||||||
tt.env.flagSet = newUpFlagSet(tt.env.goos, &tt.env.upArgs)
|
tt.env.flagSet = newUpFlagSet(tt.env.goos, &tt.env.upArgs)
|
||||||
flags := CleanUpArgs(tt.flags)
|
flags := CleanUpArgs(tt.flags)
|
||||||
tt.env.flagSet.Parse(flags)
|
if err := tt.env.flagSet.Parse(flags); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
newPrefs, err := prefsFromUpArgs(tt.env.upArgs, t.Logf, new(ipnstate.Status), tt.env.goos)
|
newPrefs, err := prefsFromUpArgs(tt.env.upArgs, t.Logf, new(ipnstate.Status), tt.env.goos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -939,6 +1087,8 @@ func TestUpdatePrefs(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
} else if tt.wantErrSubtr != "" {
|
||||||
|
t.Fatalf("want error %q, got nil", tt.wantErrSubtr)
|
||||||
}
|
}
|
||||||
if tt.checkUpdatePrefsMutations != nil {
|
if tt.checkUpdatePrefsMutations != nil {
|
||||||
tt.checkUpdatePrefsMutations(t, newPrefs)
|
tt.checkUpdatePrefsMutations(t, newPrefs)
|
||||||
@ -952,13 +1102,18 @@ func TestUpdatePrefs(t *testing.T) {
|
|||||||
justEditMP.Prefs = ipn.Prefs{} // uninteresting
|
justEditMP.Prefs = ipn.Prefs{} // uninteresting
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(justEditMP, tt.wantJustEditMP) {
|
if !reflect.DeepEqual(justEditMP, tt.wantJustEditMP) {
|
||||||
t.Logf("justEditMP != wantJustEditMP; following diff omits the Prefs field, which was %+v", oldEditPrefs)
|
t.Logf("justEditMP != wantJustEditMP; following diff omits the Prefs field, which was \n%v", asJSON(oldEditPrefs))
|
||||||
t.Fatalf("justEditMP: %v\n\n: ", cmp.Diff(justEditMP, tt.wantJustEditMP, cmpIP))
|
t.Fatalf("justEditMP: %v\n\n: ", cmp.Diff(justEditMP, tt.wantJustEditMP, cmpIP))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func asJSON(v any) string {
|
||||||
|
b, _ := json.MarshalIndent(v, "", "\t")
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
var cmpIP = cmp.Comparer(func(a, b netip.Addr) bool {
|
var cmpIP = cmp.Comparer(func(a, b netip.Addr) bool {
|
||||||
return a == b
|
return a == b
|
||||||
})
|
})
|
||||||
|
@ -22,9 +22,13 @@ var downCmd = &ffcli.Command{
|
|||||||
FlagSet: newDownFlagSet(),
|
FlagSet: newDownFlagSet(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var downArgs struct {
|
||||||
|
acceptedRisks string
|
||||||
|
}
|
||||||
|
|
||||||
func newDownFlagSet() *flag.FlagSet {
|
func newDownFlagSet() *flag.FlagSet {
|
||||||
downf := newFlagSet("down")
|
downf := newFlagSet("down")
|
||||||
registerAcceptRiskFlag(downf)
|
registerAcceptRiskFlag(downf, &downArgs.acceptedRisks)
|
||||||
return downf
|
return downf
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,7 +38,7 @@ func runDown(ctx context.Context, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if isSSHOverTailscale() {
|
if isSSHOverTailscale() {
|
||||||
if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will disable Tailscale and result in your session disconnecting.`); err != nil {
|
if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will disable Tailscale and result in your session disconnecting.`, downArgs.acceptedRisks); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,9 +16,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
riskTypes []string
|
riskTypes []string
|
||||||
acceptedRisks string
|
riskLoseSSH = registerRiskType("lose-ssh")
|
||||||
riskLoseSSH = registerRiskType("lose-ssh")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func registerRiskType(riskType string) string {
|
func registerRiskType(riskType string) string {
|
||||||
@ -28,12 +27,13 @@ func registerRiskType(riskType string) string {
|
|||||||
|
|
||||||
// registerAcceptRiskFlag registers the --accept-risk flag. Accepted risks are accounted for
|
// registerAcceptRiskFlag registers the --accept-risk flag. Accepted risks are accounted for
|
||||||
// in presentRiskToUser.
|
// in presentRiskToUser.
|
||||||
func registerAcceptRiskFlag(f *flag.FlagSet) {
|
func registerAcceptRiskFlag(f *flag.FlagSet, acceptedRisks *string) {
|
||||||
f.StringVar(&acceptedRisks, "accept-risk", "", "accept risk and skip confirmation for risk types: "+strings.Join(riskTypes, ","))
|
f.StringVar(acceptedRisks, "accept-risk", "", "accept risk and skip confirmation for risk types: "+strings.Join(riskTypes, ","))
|
||||||
}
|
}
|
||||||
|
|
||||||
// riskAccepted reports whether riskType is in acceptedRisks.
|
// isRiskAccepted reports whether riskType is in the comma-separated list of
|
||||||
func riskAccepted(riskType string) bool {
|
// risks in acceptedRisks.
|
||||||
|
func isRiskAccepted(riskType, acceptedRisks string) bool {
|
||||||
for _, r := range strings.Split(acceptedRisks, ",") {
|
for _, r := range strings.Split(acceptedRisks, ",") {
|
||||||
if r == riskType {
|
if r == riskType {
|
||||||
return true
|
return true
|
||||||
@ -49,12 +49,16 @@ var errAborted = errors.New("aborted, no changes made")
|
|||||||
// It is used by the presentRiskToUser function below.
|
// It is used by the presentRiskToUser function below.
|
||||||
const riskAbortTimeSeconds = 5
|
const riskAbortTimeSeconds = 5
|
||||||
|
|
||||||
// presentRiskToUser displays the risk message and waits for the user to
|
// presentRiskToUser displays the risk message and waits for the user to cancel.
|
||||||
// cancel. It returns errorAborted if the user aborts.
|
// It returns errorAborted if the user aborts. In tests it returns errAborted
|
||||||
func presentRiskToUser(riskType, riskMessage string) error {
|
// immediately unless the risk has been explicitly accepted.
|
||||||
if riskAccepted(riskType) {
|
func presentRiskToUser(riskType, riskMessage, acceptedRisks string) error {
|
||||||
|
if isRiskAccepted(riskType, acceptedRisks) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if inTest() {
|
||||||
|
return errAborted
|
||||||
|
}
|
||||||
outln(riskMessage)
|
outln(riskMessage)
|
||||||
printf("To skip this warning, use --accept-risk=%s\n", riskType)
|
printf("To skip this warning, use --accept-risk=%s\n", riskType)
|
||||||
|
|
||||||
|
@ -116,7 +116,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT) *flag.FlagSet {
|
|||||||
upf.BoolVar(&upArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)")
|
upf.BoolVar(&upArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)")
|
||||||
}
|
}
|
||||||
upf.DurationVar(&upArgs.timeout, "timeout", 0, "maximum amount of time to wait for tailscaled to enter a Running state; default (0s) blocks forever")
|
upf.DurationVar(&upArgs.timeout, "timeout", 0, "maximum amount of time to wait for tailscaled to enter a Running state; default (0s) blocks forever")
|
||||||
registerAcceptRiskFlag(upf)
|
registerAcceptRiskFlag(upf, &upArgs.acceptedRisks)
|
||||||
return upf
|
return upf
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,6 +150,7 @@ type upArgsT struct {
|
|||||||
opUser string
|
opUser string
|
||||||
json bool
|
json bool
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
|
acceptedRisks string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a upArgsT) getAuthKey() (string, error) {
|
func (a upArgsT) getAuthKey() (string, error) {
|
||||||
@ -376,6 +377,21 @@ func updatePrefs(prefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, jus
|
|||||||
return false, nil, fmt.Errorf("can't change --login-server without --force-reauth")
|
return false, nil, fmt.Errorf("can't change --login-server without --force-reauth")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do this after validations to avoid the 5s delay if we're going to error
|
||||||
|
// out anyway.
|
||||||
|
wantSSH, haveSSH := env.upArgs.runSSH, curPrefs.RunSSH
|
||||||
|
fmt.Println("wantSSH", wantSSH, "haveSSH", haveSSH)
|
||||||
|
if wantSSH != haveSSH && isSSHOverTailscale() {
|
||||||
|
if wantSSH {
|
||||||
|
err = presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will reroute SSH traffic to Tailscale SSH and will result in your session disconnecting.`, env.upArgs.acceptedRisks)
|
||||||
|
} else {
|
||||||
|
err = presentRiskToUser(riskLoseSSH, `You are connected using Tailscale SSH; this action will result in your session disconnecting.`, env.upArgs.acceptedRisks)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tagsChanged := !reflect.DeepEqual(curPrefs.AdvertiseTags, prefs.AdvertiseTags)
|
tagsChanged := !reflect.DeepEqual(curPrefs.AdvertiseTags, prefs.AdvertiseTags)
|
||||||
|
|
||||||
simpleUp = env.flagSet.NFlag() == 0 &&
|
simpleUp = env.flagSet.NFlag() == 0 &&
|
||||||
@ -475,17 +491,6 @@ func runUp(ctx context.Context, args []string) (retErr error) {
|
|||||||
curExitNodeIP: exitNodeIP(curPrefs, st),
|
curExitNodeIP: exitNodeIP(curPrefs, st),
|
||||||
}
|
}
|
||||||
|
|
||||||
if upArgs.runSSH != curPrefs.RunSSH && isSSHOverTailscale() {
|
|
||||||
if upArgs.runSSH {
|
|
||||||
err = presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will reroute SSH traffic to Tailscale SSH and will result in your session disconnecting.`)
|
|
||||||
} else {
|
|
||||||
err = presentRiskToUser(riskLoseSSH, `You are connected using Tailscale SSH; this action will result in your session disconnecting.`)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if retErr == nil {
|
if retErr == nil {
|
||||||
checkSSHUpWarnings(ctx)
|
checkSSHUpWarnings(ctx)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user