cmd/tailscale: generate known_hosts file for 'tailscale ssh'

Updates #3802

Change-Id: I7a0052392f000ee44fc8e719f6666756aab91f3d
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-03-25 12:36:46 -07:00 committed by Brad Fitzpatrick
parent cceacda5eb
commit df93158aac

View File

@ -5,6 +5,7 @@
package cli
import (
"bytes"
"context"
"errors"
"fmt"
@ -12,13 +13,16 @@
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strings"
"syscall"
"github.com/alessio/shellescape"
"github.com/peterbourgon/ff/v3/ffcli"
"tailscale.com/client/tailscale"
"tailscale.com/envknob"
"tailscale.com/ipn/ipnstate"
)
var sshCmd = &ffcli.Command{
@ -52,9 +56,21 @@ func runSSH(ctx context.Context, args []string) error {
if err != nil {
return err
}
st, err := tailscale.Status(ctx)
if err != nil {
return err
}
knownHostsFile, err := writeKnownHosts(st)
if err != nil {
return err
}
argv := append([]string{
ssh,
"-o", fmt.Sprintf("UserKnownHostsFile %s",
shellescape.Quote(knownHostsFile),
),
"-o", fmt.Sprintf("ProxyCommand %s --socket=%s nc %%h %%p",
shellescape.Quote(tailscaleBin),
shellescape.Quote(rootArgs.socket),
@ -95,3 +111,52 @@ func runSSH(ctx context.Context, args []string) error {
}
return errors.New("unreachable")
}
func writeKnownHosts(st *ipnstate.Status) (knownHostsFile string, err error) {
confDir, err := os.UserConfigDir()
if err != nil {
return "", err
}
tsConfDir := filepath.Join(confDir, "tailscale")
if err := os.MkdirAll(tsConfDir, 0700); err != nil {
return "", err
}
knownHostsFile = filepath.Join(tsConfDir, "ssh_known_hosts")
want := genKnownHosts(st)
if cur, err := os.ReadFile(knownHostsFile); err != nil || !bytes.Equal(cur, want) {
if err := os.WriteFile(knownHostsFile, want, 0644); err != nil {
return "", err
}
}
return knownHostsFile, nil
}
func genKnownHosts(st *ipnstate.Status) []byte {
var buf bytes.Buffer
for _, k := range st.Peers() {
ps := st.Peer[k]
if len(ps.SSH_HostKeys) == 0 {
continue
}
// addEntries adds one line per each of p's host keys.
addEntries := func(host string) {
for _, hk := range ps.SSH_HostKeys {
hostKey := strings.TrimSpace(hk)
if strings.ContainsAny(hostKey, "\n\r") { // invalid
continue
}
fmt.Fprintf(&buf, "%s %s\n", host, hostKey)
}
}
if ps.DNSName != "" {
addEntries(ps.DNSName)
}
if base, _, ok := strings.Cut(ps.DNSName, "."); ok {
addEntries(base)
}
for _, ip := range st.TailscaleIPs {
addEntries(ip.String())
}
}
return buf.Bytes()
}