mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-07 08:07:42 +00:00
cmd/tailscale/cli: add lose-ssh risk
This makes it so that the user is notified that the action they are about to take may result in them getting disconnected from the machine. It then waits for 5s for the user to maybe Ctrl+C out of it. It also introduces a `--accept-risk=lose-ssh` flag for automation, which allows the caller to pre-acknowledge the risk. The two actions that cause this are: - updating `--ssh` from `true` to `false` - running `tailscale down` Updates #3802 Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
1336fb740b
commit
67325d334e
@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/peterbourgon/ff/v3/ffcli"
|
"github.com/peterbourgon/ff/v3/ffcli"
|
||||||
@ -17,7 +18,14 @@
|
|||||||
ShortUsage: "down",
|
ShortUsage: "down",
|
||||||
ShortHelp: "Disconnect from Tailscale",
|
ShortHelp: "Disconnect from Tailscale",
|
||||||
|
|
||||||
Exec: runDown,
|
Exec: runDown,
|
||||||
|
FlagSet: newDownFlagSet(),
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDownFlagSet() *flag.FlagSet {
|
||||||
|
downf := newFlagSet("down")
|
||||||
|
registerAcceptRiskFlag(downf)
|
||||||
|
return downf
|
||||||
}
|
}
|
||||||
|
|
||||||
func runDown(ctx context.Context, args []string) error {
|
func runDown(ctx context.Context, args []string) error {
|
||||||
@ -25,6 +33,12 @@ func runDown(ctx context.Context, args []string) error {
|
|||||||
return fmt.Errorf("too many non-flag arguments: %q", args)
|
return fmt.Errorf("too many non-flag arguments: %q", args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isSSHOverTailscale() {
|
||||||
|
if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will disable Tailscale and result in your session disconnecting.`); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
st, err := localClient.Status(ctx)
|
st, err := localClient.Status(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error fetching current status: %w", err)
|
return fmt.Errorf("error fetching current status: %w", err)
|
||||||
|
78
cmd/tailscale/cli/risks.go
Normal file
78
cmd/tailscale/cli/risks.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
riskTypes []string
|
||||||
|
acceptedRisks string
|
||||||
|
riskLoseSSH = registerRiskType("lose-ssh")
|
||||||
|
)
|
||||||
|
|
||||||
|
func registerRiskType(riskType string) string {
|
||||||
|
riskTypes = append(riskTypes, riskType)
|
||||||
|
return riskType
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerAcceptRiskFlag registers the --accept-risk flag. Accepted risks are accounted for
|
||||||
|
// in presentRiskToUser.
|
||||||
|
func registerAcceptRiskFlag(f *flag.FlagSet) {
|
||||||
|
f.StringVar(&acceptedRisks, "accept-risk", "", "accept risk and skip confirmation for risk types: "+strings.Join(riskTypes, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
// riskAccepted reports whether riskType is in acceptedRisks.
|
||||||
|
func riskAccepted(riskType string) bool {
|
||||||
|
for _, r := range strings.Split(acceptedRisks, ",") {
|
||||||
|
if r == riskType {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var errAborted = errors.New("aborted, no changes made")
|
||||||
|
|
||||||
|
// riskAbortTimeSeconds is the number of seconds to wait after displaying the
|
||||||
|
// risk message before continuing with the operation.
|
||||||
|
// It is used by the presentRiskToUser function below.
|
||||||
|
const riskAbortTimeSeconds = 5
|
||||||
|
|
||||||
|
// presentRiskToUser displays the risk message and waits for the user to
|
||||||
|
// cancel. It returns errorAborted if the user aborts.
|
||||||
|
func presentRiskToUser(riskType, riskMessage string) error {
|
||||||
|
if riskAccepted(riskType) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
fmt.Println(riskMessage)
|
||||||
|
fmt.Printf("To skip this warning, use --accept-risk=%s\n", riskType)
|
||||||
|
|
||||||
|
interrupt := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(interrupt, syscall.SIGINT)
|
||||||
|
var msgLen int
|
||||||
|
for left := riskAbortTimeSeconds; left > 0; left-- {
|
||||||
|
msg := fmt.Sprintf("\rContinuing in %d seconds...", left)
|
||||||
|
msgLen = len(msg)
|
||||||
|
fmt.Print(msg)
|
||||||
|
select {
|
||||||
|
case <-interrupt:
|
||||||
|
fmt.Printf("\r%s\r", strings.Repeat(" ", msgLen+1))
|
||||||
|
return errAborted
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Printf("\r%s\r", strings.Repeat(" ", msgLen))
|
||||||
|
return errAborted
|
||||||
|
}
|
@ -21,6 +21,7 @@
|
|||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
"tailscale.com/envknob"
|
"tailscale.com/envknob"
|
||||||
"tailscale.com/ipn/ipnstate"
|
"tailscale.com/ipn/ipnstate"
|
||||||
|
"tailscale.com/net/tsaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
var sshCmd = &ffcli.Command{
|
var sshCmd = &ffcli.Command{
|
||||||
@ -179,3 +180,28 @@ func nodeDNSNameFromArg(st *ipnstate.Status, arg string) (dnsName string, ok boo
|
|||||||
}
|
}
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getSSHClientEnvVar returns the "SSH_CLIENT" environment variable
|
||||||
|
// for the current process group, if any.
|
||||||
|
var getSSHClientEnvVar = func() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// isSSHOverTailscale checks if the invocation is in a SSH session over Tailscale.
|
||||||
|
// It is used to detect if the user is about to take an action that might result in them
|
||||||
|
// disconnecting from the machine (e.g. disabling SSH)
|
||||||
|
func isSSHOverTailscale() bool {
|
||||||
|
sshClient := getSSHClientEnvVar()
|
||||||
|
if sshClient == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ipStr, _, ok := strings.Cut(sshClient, " ")
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ip, err := netaddr.ParseIP(ipStr)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return tsaddr.IsTailscaleIP(ip)
|
||||||
|
}
|
||||||
|
51
cmd/tailscale/cli/ssh_unix.go
Normal file
51
cmd/tailscale/cli/ssh_unix.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
//go:build !js && !windows
|
||||||
|
// +build !js,!windows
|
||||||
|
|
||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
getSSHClientEnvVar = func() string {
|
||||||
|
if os.Getenv("SUDO_USER") == "" {
|
||||||
|
// No sudo, just check the env.
|
||||||
|
return os.Getenv("SSH_CLIENT")
|
||||||
|
}
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
// TODO(maisem): implement this for other platforms. It's not clear
|
||||||
|
// if there is a way to get the environment for a given process on
|
||||||
|
// darwin and bsd.
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// SID is the session ID of the user's login session.
|
||||||
|
// It is also the process ID of the original shell that the user logged in with.
|
||||||
|
// We only need to check the environment of that process.
|
||||||
|
sid, err := unix.Getsid(os.Getpid())
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
b, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(sid), "environ"))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
prefix := []byte("SSH_CLIENT=")
|
||||||
|
for _, env := range bytes.Split(b, []byte{0}) {
|
||||||
|
if bytes.HasPrefix(env, prefix) {
|
||||||
|
return string(env[len(prefix):])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
@ -114,6 +114,8 @@ func newUpFlagSet(goos string, upArgs *upArgsT) *flag.FlagSet {
|
|||||||
case "windows":
|
case "windows":
|
||||||
upf.BoolVar(&upArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)")
|
upf.BoolVar(&upArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
registerAcceptRiskFlag(upf)
|
||||||
return upf
|
return upf
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -465,6 +467,18 @@ func runUp(ctx context.Context, args []string) error {
|
|||||||
backendState: st.BackendState,
|
backendState: st.BackendState,
|
||||||
curExitNodeIP: exitNodeIP(curPrefs, st),
|
curExitNodeIP: exitNodeIP(curPrefs, st),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if upArgs.runSSH != curPrefs.RunSSH && isSSHOverTailscale() {
|
||||||
|
if upArgs.runSSH {
|
||||||
|
err = presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will reroute SSH traffic to Tailscale SSH and will result in your session disconnecting.`)
|
||||||
|
} else {
|
||||||
|
err = presentRiskToUser(riskLoseSSH, `You are connected using Tailscale SSH; this action will result in your session disconnecting.`)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
simpleUp, justEditMP, err := updatePrefs(prefs, curPrefs, env)
|
simpleUp, justEditMP, err := updatePrefs(prefs, curPrefs, env)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fatalf("%s", err)
|
fatalf("%s", err)
|
||||||
@ -705,7 +719,7 @@ func addPrefFlagMapping(flagName string, prefNames ...string) {
|
|||||||
// correspond to an ipn.Pref.
|
// correspond to an ipn.Pref.
|
||||||
func preflessFlag(flagName string) bool {
|
func preflessFlag(flagName string) bool {
|
||||||
switch flagName {
|
switch flagName {
|
||||||
case "auth-key", "force-reauth", "reset", "qr", "json":
|
case "auth-key", "force-reauth", "reset", "qr", "json", "accept-risk":
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
Loading…
x
Reference in New Issue
Block a user