cmd/tailscale: write fewer known_hosts, resolve ssh host to FQDN early

Updates #3802

Change-Id: Ic44fa2e6661a9c046e725c04fa6b8213d3d4d2b2
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2022-03-25 14:27:22 -07:00
committed by Brad Fitzpatrick
parent df93158aac
commit 753f1bfad4
2 changed files with 61 additions and 32 deletions

View File

@@ -20,6 +20,7 @@ import (
"github.com/alessio/shellescape"
"github.com/peterbourgon/ff/v3/ffcli"
"inet.af/netaddr"
"tailscale.com/client/tailscale"
"tailscale.com/envknob"
"tailscale.com/ipn/ipnstate"
@@ -46,6 +47,20 @@ func runSSH(ctx context.Context, args []string) error {
}
username = lu.Username
}
st, err := tailscale.Status(ctx)
if err != nil {
return err
}
// hostForSSH is the hostname we'll tell OpenSSH we're
// connecting to, so we have to maintain fewer entries in the
// known_hosts files.
hostForSSH := host
if v, ok := nodeDNSNameFromArg(st, host); ok {
hostForSSH = v
}
ssh, err := exec.LookPath("ssh")
if err != nil {
// TODO(bradfitz): use Go's crypto/ssh client instead
@@ -56,10 +71,6 @@ func runSSH(ctx context.Context, args []string) error {
if err != nil {
return err
}
st, err := tailscale.Status(ctx)
if err != nil {
return err
}
knownHostsFile, err := writeKnownHosts(st)
if err != nil {
return err
@@ -86,7 +97,7 @@ func runSSH(ctx context.Context, args []string) error {
// to use a different one, we'll later be making stock ssh
// work well by default too. (doing things like automatically
// setting known_hosts, etc)
username + "@" + host,
username + "@" + hostForSSH,
}, argRest...)
if runtime.GOOS == "windows" {
@@ -135,28 +146,41 @@ func genKnownHosts(st *ipnstate.Status) []byte {
var buf bytes.Buffer
for _, k := range st.Peers() {
ps := st.Peer[k]
if len(ps.SSH_HostKeys) == 0 {
continue
}
// addEntries adds one line per each of p's host keys.
addEntries := func(host string) {
for _, hk := range ps.SSH_HostKeys {
hostKey := strings.TrimSpace(hk)
if strings.ContainsAny(hostKey, "\n\r") { // invalid
continue
}
fmt.Fprintf(&buf, "%s %s\n", host, hostKey)
for _, hk := range ps.SSH_HostKeys {
hostKey := strings.TrimSpace(hk)
if strings.ContainsAny(hostKey, "\n\r") { // invalid
continue
}
}
if ps.DNSName != "" {
addEntries(ps.DNSName)
}
if base, _, ok := strings.Cut(ps.DNSName, "."); ok {
addEntries(base)
}
for _, ip := range st.TailscaleIPs {
addEntries(ip.String())
fmt.Fprintf(&buf, "%s %s\n", ps.DNSName, hostKey)
}
}
return buf.Bytes()
}
// nodeDNSNameFromArg returns the PeerStatus.DNSName value from a peer
// in st that matches the input arg which can be a base name, full
// DNS name, or an IP.
func nodeDNSNameFromArg(st *ipnstate.Status, arg string) (dnsName string, ok bool) {
if arg == "" {
return
}
argIP, _ := netaddr.ParseIP(arg)
for _, ps := range st.Peer {
dnsName = ps.DNSName
if !argIP.IsZero() {
for _, ip := range ps.TailscaleIPs {
if ip == argIP {
return dnsName, true
}
}
continue
}
if strings.EqualFold(strings.TrimSuffix(arg, "."), strings.TrimSuffix(dnsName, ".")) {
return dnsName, true
}
if base, _, ok := strings.Cut(ps.DNSName, "."); ok && strings.EqualFold(base, arg) {
return dnsName, true
}
}
return "", false
}