diff --git a/cmd/tailscale/cli/ssh.go b/cmd/tailscale/cli/ssh.go index 4ee56dd45..3dfe79853 100644 --- a/cmd/tailscale/cli/ssh.go +++ b/cmd/tailscale/cli/ssh.go @@ -20,6 +20,7 @@ import ( "github.com/alessio/shellescape" "github.com/peterbourgon/ff/v3/ffcli" + "inet.af/netaddr" "tailscale.com/client/tailscale" "tailscale.com/envknob" "tailscale.com/ipn/ipnstate" @@ -46,6 +47,20 @@ func runSSH(ctx context.Context, args []string) error { } username = lu.Username } + + st, err := tailscale.Status(ctx) + if err != nil { + return err + } + + // hostForSSH is the hostname we'll tell OpenSSH we're + // connecting to, so we have to maintain fewer entries in the + // known_hosts files. + hostForSSH := host + if v, ok := nodeDNSNameFromArg(st, host); ok { + hostForSSH = v + } + ssh, err := exec.LookPath("ssh") if err != nil { // TODO(bradfitz): use Go's crypto/ssh client instead @@ -56,10 +71,6 @@ func runSSH(ctx context.Context, args []string) error { if err != nil { return err } - st, err := tailscale.Status(ctx) - if err != nil { - return err - } knownHostsFile, err := writeKnownHosts(st) if err != nil { return err @@ -86,7 +97,7 @@ func runSSH(ctx context.Context, args []string) error { // to use a different one, we'll later be making stock ssh // work well by default too. (doing things like automatically // setting known_hosts, etc) - username + "@" + host, + username + "@" + hostForSSH, }, argRest...) if runtime.GOOS == "windows" { @@ -135,28 +146,41 @@ func genKnownHosts(st *ipnstate.Status) []byte { var buf bytes.Buffer for _, k := range st.Peers() { ps := st.Peer[k] - if len(ps.SSH_HostKeys) == 0 { - continue - } - // addEntries adds one line per each of p's host keys. - addEntries := func(host string) { - for _, hk := range ps.SSH_HostKeys { - hostKey := strings.TrimSpace(hk) - if strings.ContainsAny(hostKey, "\n\r") { // invalid - continue - } - fmt.Fprintf(&buf, "%s %s\n", host, hostKey) + for _, hk := range ps.SSH_HostKeys { + hostKey := strings.TrimSpace(hk) + if strings.ContainsAny(hostKey, "\n\r") { // invalid + continue } - } - if ps.DNSName != "" { - addEntries(ps.DNSName) - } - if base, _, ok := strings.Cut(ps.DNSName, "."); ok { - addEntries(base) - } - for _, ip := range st.TailscaleIPs { - addEntries(ip.String()) + fmt.Fprintf(&buf, "%s %s\n", ps.DNSName, hostKey) } } return buf.Bytes() } + +// nodeDNSNameFromArg returns the PeerStatus.DNSName value from a peer +// in st that matches the input arg which can be a base name, full +// DNS name, or an IP. +func nodeDNSNameFromArg(st *ipnstate.Status, arg string) (dnsName string, ok bool) { + if arg == "" { + return + } + argIP, _ := netaddr.ParseIP(arg) + for _, ps := range st.Peer { + dnsName = ps.DNSName + if !argIP.IsZero() { + for _, ip := range ps.TailscaleIPs { + if ip == argIP { + return dnsName, true + } + } + continue + } + if strings.EqualFold(strings.TrimSuffix(arg, "."), strings.TrimSuffix(dnsName, ".")) { + return dnsName, true + } + if base, _, ok := strings.Cut(ps.DNSName, "."); ok && strings.EqualFold(base, arg) { + return dnsName, true + } + } + return "", false +} diff --git a/net/tsdial/dnsmap.go b/net/tsdial/dnsmap.go index 553c2e765..7f0f23743 100644 --- a/net/tsdial/dnsmap.go +++ b/net/tsdial/dnsmap.go @@ -21,9 +21,14 @@ import ( // It must not be mutated once created. // // Example keys are "foo.domain.tld.beta.tailscale.net" and "foo", -// both without trailing dots. +// both without trailing dots, and both always lowercase. type dnsMap map[string]netaddr.IP +// canonMapKey canonicalizes its input s to be a dnsMap map key. +func canonMapKey(s string) string { + return strings.ToLower(strings.TrimSuffix(s, ".")) +} + func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { if nm == nil { return nil @@ -33,9 +38,9 @@ func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { have4 := false if nm.Name != "" && len(nm.Addresses) > 0 { ip := nm.Addresses[0].IP() - ret[strings.TrimRight(nm.Name, ".")] = ip + ret[canonMapKey(nm.Name)] = ip if dnsname.HasSuffix(nm.Name, suffix) { - ret[dnsname.TrimSuffix(nm.Name, suffix)] = ip + ret[canonMapKey(dnsname.TrimSuffix(nm.Name, suffix))] = ip } for _, a := range nm.Addresses { if a.IP().Is4() { @@ -52,9 +57,9 @@ func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { if ip.Is4() && !have4 { continue } - ret[strings.TrimRight(p.Name, ".")] = ip + ret[canonMapKey(p.Name)] = ip if dnsname.HasSuffix(p.Name, suffix) { - ret[dnsname.TrimSuffix(p.Name, suffix)] = ip + ret[canonMapKey(dnsname.TrimSuffix(p.Name, suffix))] = ip } break } @@ -67,7 +72,7 @@ func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { if err != nil { continue } - ret[strings.TrimRight(rec.Name, ".")] = ip + ret[canonMapKey(rec.Name)] = ip } return ret } @@ -106,7 +111,7 @@ func (m dnsMap) resolveMemory(ctx context.Context, network, addr string) (_ neta // Host is not an IP, so assume it's a DNS name. // Try MagicDNS first, otherwise a real DNS lookup. - ip := m[host] + ip := m[canonMapKey(host)] if !ip.IsZero() { return netaddr.IPPortFrom(ip, port), nil }