From 03caa95bf21c1c59272c860c586eb0ff37a5a712 Mon Sep 17 00:00:00 2001
From: Brad Fitzpatrick <bradfitz@tailscale.com>
Date: Fri, 18 Feb 2022 19:07:04 -0800
Subject: [PATCH] ssh/tailssh: get login shell when running as non-root

And also reject attempts to use other users.

Updates #3802

Change-Id: Iddc85f6ea2dba17d12be66a50408d24c1f92833e
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
---
 ssh/tailssh/tailssh.go | 40 +++++++++++++++++++++++++++++++++++-----
 1 file changed, 35 insertions(+), 5 deletions(-)

diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go
index 13fd68732..a17af5cee 100644
--- a/ssh/tailssh/tailssh.go
+++ b/ssh/tailssh/tailssh.go
@@ -16,6 +16,9 @@ import (
 	"net"
 	"os"
 	"os/exec"
+	"os/user"
+	"runtime"
+	"strings"
 	"syscall"
 	"time"
 	"unsafe"
@@ -100,9 +103,9 @@ func (srv *server) handleSSH(s ssh.Session) {
 	lb := srv.lb
 	logf := srv.logf
 
-	user := s.User()
+	sshUser := s.User()
 	addr := s.RemoteAddr()
-	logf("Handling SSH from %v for user %v", addr, user)
+	logf("Handling SSH from %v for user %v", addr, sshUser)
 	ta, ok := addr.(*net.TCPAddr)
 	if !ok {
 		logf("tsshd: rejecting non-TCP addr %T %v", addr, addr)
@@ -140,7 +143,7 @@ func (srv *server) handleSSH(s ssh.Session) {
 	srcIP := srcIPP.IP()
 	sctx := &sshContext{
 		now:     time.Now(),
-		sshUser: s.User(),
+		sshUser: sshUser,
 		srcIP:   srcIP,
 		node:    node,
 		uprof:   &uprof,
@@ -165,8 +168,19 @@ func (srv *server) handleSSH(s ssh.Session) {
 		return
 	}
 	var cmd *exec.Cmd
-	if os.Getuid() != 0 || localUser == "root" {
-		cmd = exec.Command("/bin/bash")
+	if os.Getuid() != 0 {
+		u, err := user.Current()
+		if err != nil {
+			logf("failed to get current user: %v", err)
+			s.Exit(1)
+			return
+		}
+		if u.Username != localUser {
+			fmt.Fprintf(s, "can't switch user\n")
+			s.Exit(1)
+			return
+		}
+		cmd = exec.Command(loginShell(u.Uid))
 	} else {
 		cmd = exec.Command("/usr/bin/env", "su", "-", localUser)
 	}
@@ -297,3 +311,19 @@ func matchesPrincipal(ps []*tailcfg.SSHPrincipal, sctx *sshContext) bool {
 	}
 	return false
 }
+
+func loginShell(uid string) string {
+	switch runtime.GOOS {
+	case "linux":
+		out, _ := exec.Command("getent", "passwd", uid).Output()
+		// out is "root:x:0:0:root:/root:/bin/bash"
+		f := strings.SplitN(string(out), ":", 10)
+		if len(f) > 6 {
+			return f[6] // shell
+		}
+	}
+	if e := os.Getenv("SHELL"); e != "" {
+		return e
+	}
+	return "/bin/bash"
+}