ssh/tailssh: fall back to using su when no TTY available on Linux

This allows pam authentication to run for ssh sessions, triggering
automation like pam_mkhomedir.

Updates #11854

Signed-off-by: Percy Wegmann <percy@tailscale.com>
This commit is contained in:
Percy Wegmann
2024-05-29 12:51:50 -05:00
committed by Percy Wegmann
parent f1d10c12ac
commit 08a9551a73
9 changed files with 632 additions and 260 deletions

View File

@@ -8,6 +8,7 @@ package tailssh
import (
"bufio"
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
@@ -28,6 +29,7 @@ import (
"testing"
"time"
"github.com/bramvdbogaerde/go-scp"
"github.com/google/go-cmp/cmp"
"github.com/pkg/sftp"
gossh "github.com/tailscale/golang-x-crypto/ssh"
@@ -36,6 +38,7 @@ import (
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/netmap"
"tailscale.com/util/set"
)
// This file contains integration tests of the SSH functionality. These tests
@@ -58,7 +61,7 @@ func TestMain(m *testing.M) {
file.Close()
// Tail our log file.
cmd := exec.Command("tail", "-f", "/tmp/tailscalessh.log")
cmd := exec.Command("tail", "-F", "/tmp/tailscalessh.log")
r, err := cmd.StdoutPipe()
if err != nil {
@@ -77,6 +80,12 @@ func TestMain(m *testing.M) {
if err != nil {
return
}
defer func() {
// tail -f has a default sleep interval of 1 second, so it takes a
// moment for it to finish reading our log file after we've terminated.
// So, wait a bit to let it catch up.
time.Sleep(2 * time.Second)
}()
m.Run()
}
@@ -93,20 +102,40 @@ func TestIntegrationSSH(t *testing.T) {
}
tests := []struct {
cmd string
want []string
cmd string
want []string
forceV1Behavior bool
skip bool
}{
{
cmd: "id",
want: []string{"testuser", "groupone", "grouptwo"},
cmd: "id",
want: []string{"testuser", "groupone", "grouptwo"},
forceV1Behavior: false,
},
{
cmd: "pwd",
want: []string{homeDir},
cmd: "id",
want: []string{"testuser", "groupone", "grouptwo"},
forceV1Behavior: true,
},
{
cmd: "pwd",
want: []string{homeDir},
skip: !fallbackToSUAvailable(),
forceV1Behavior: false,
},
{
cmd: "echo 'hello'",
want: []string{"hello"},
skip: !fallbackToSUAvailable(),
forceV1Behavior: false,
},
}
for _, test := range tests {
if test.skip {
continue
}
// run every test both without and with a shell
for _, shell := range []bool{false, true} {
shellQualifier := "no_shell"
@@ -114,8 +143,13 @@ func TestIntegrationSSH(t *testing.T) {
shellQualifier = "shell"
}
t.Run(fmt.Sprintf("%s_%s", test.cmd, shellQualifier), func(t *testing.T) {
s := testSession(t)
versionQualifier := "v2"
if test.forceV1Behavior {
versionQualifier = "v1"
}
t.Run(fmt.Sprintf("%s_%s_%s", test.cmd, shellQualifier, versionQualifier), func(t *testing.T) {
s := testSession(t, test.forceV1Behavior)
if shell {
err := s.RequestPty("xterm", 40, 80, ssh.TerminalModes{
@@ -123,12 +157,20 @@ func TestIntegrationSSH(t *testing.T) {
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
})
if err != nil {
t.Fatalf("unable to request PTY: %s", err)
}
err = s.Shell()
if err != nil {
t.Fatalf("unable to request shell: %s", err)
}
// Read the shell prompt
s.read()
}
got := s.run(t, test.cmd)
got := s.run(t, test.cmd, shell)
for _, want := range test.want {
if !strings.Contains(got, want) {
t.Errorf("%q does not contain %q", got, want)
@@ -145,48 +187,133 @@ func TestIntegrationSFTP(t *testing.T) {
debugTest.Store(false)
})
filePath := "/tmp/sftptest.dat"
wantText := "hello world"
for _, forceV1Behavior := range []bool{false, true} {
name := "v2"
if forceV1Behavior {
name = "v1"
}
t.Run(name, func(t *testing.T) {
filePath := "/home/testuser/sftptest.dat"
if forceV1Behavior || !fallbackToSUAvailable() {
filePath = "/tmp/sftptest.dat"
}
wantText := "hello world"
cl := testClient(t)
scl, err := sftp.NewClient(cl)
if err != nil {
t.Fatalf("can't get sftp client: %s", err)
cl := testClient(t, forceV1Behavior)
scl, err := sftp.NewClient(cl)
if err != nil {
t.Fatalf("can't get sftp client: %s", err)
}
file, err := scl.Create(filePath)
if err != nil {
t.Fatalf("can't create file: %s", err)
}
_, err = file.Write([]byte(wantText))
if err != nil {
t.Fatalf("can't write to file: %s", err)
}
err = file.Close()
if err != nil {
t.Fatalf("can't close file: %s", err)
}
file, err = scl.OpenFile(filePath, os.O_RDONLY)
if err != nil {
t.Fatalf("can't open file: %s", err)
}
defer file.Close()
gotText, err := io.ReadAll(file)
if err != nil {
t.Fatalf("can't read file: %s", err)
}
if diff := cmp.Diff(string(gotText), wantText); diff != "" {
t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
}
s := testSessionFor(t, cl)
got := s.run(t, "ls -l "+filePath, false)
if !strings.Contains(got, "testuser") {
t.Fatalf("unexpected file owner user: %s", got)
} else if !strings.Contains(got, "testuser") {
t.Fatalf("unexpected file owner group: %s", got)
}
})
}
}
func TestIntegrationSCP(t *testing.T) {
debugTest.Store(true)
t.Cleanup(func() {
debugTest.Store(false)
})
for _, forceV1Behavior := range []bool{false, true} {
name := "v2"
if forceV1Behavior {
name = "v1"
}
t.Run(name, func(t *testing.T) {
filePath := "/home/testuser/scptest.dat"
if !fallbackToSUAvailable() {
filePath = "/tmp/scptest.dat"
}
wantText := "hello world"
cl := testClient(t, forceV1Behavior)
scl, err := scp.NewClientBySSH(cl)
if err != nil {
t.Fatalf("can't get sftp client: %s", err)
}
err = scl.Copy(context.Background(), strings.NewReader(wantText), filePath, "0644", int64(len(wantText)))
if err != nil {
t.Fatalf("can't create file: %s", err)
}
outfile, err := os.CreateTemp("", "")
if err != nil {
t.Fatalf("can't create temp file: %s", err)
}
err = scl.CopyFromRemote(context.Background(), outfile, filePath)
if err != nil {
t.Fatalf("can't copy file from remote: %s", err)
}
outfile.Close()
gotText, err := os.ReadFile(outfile.Name())
if err != nil {
t.Fatalf("can't read file: %s", err)
}
if diff := cmp.Diff(string(gotText), wantText); diff != "" {
t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
}
s := testSessionFor(t, cl)
got := s.run(t, "ls -l "+filePath, false)
if !strings.Contains(got, "testuser") {
t.Fatalf("unexpected file owner user: %s", got)
} else if !strings.Contains(got, "testuser") {
t.Fatalf("unexpected file owner group: %s", got)
}
})
}
}
func fallbackToSUAvailable() bool {
if runtime.GOOS != "linux" {
return false
}
file, err := scl.Create(filePath)
_, err := exec.LookPath("su")
if err != nil {
t.Fatalf("can't create file: %s", err)
}
_, err = file.Write([]byte(wantText))
if err != nil {
t.Fatalf("can't write to file: %s", err)
}
err = file.Close()
if err != nil {
t.Fatalf("can't close file: %s", err)
return false
}
file, err = scl.OpenFile(filePath, os.O_RDONLY)
if err != nil {
t.Fatalf("can't open file: %s", err)
}
defer file.Close()
gotText, err := io.ReadAll(file)
if err != nil {
t.Fatalf("can't read file: %s", err)
}
if diff := cmp.Diff(string(gotText), wantText); diff != "" {
t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
}
s := testSessionFor(t, cl)
got := s.run(t, "ls -l "+filePath)
if !strings.Contains(got, "testuser") {
t.Fatalf("unexpected file owner user: %s", got)
} else if !strings.Contains(got, "testuser") {
t.Fatalf("unexpected file owner group: %s", got)
}
// Some operating systems like Fedora seem to require login to be present
// in order for su to work.
_, err = exec.LookPath("login")
return err == nil
}
type session struct {
@@ -197,14 +324,25 @@ type session struct {
stderr io.ReadCloser
}
func (s *session) run(t *testing.T, cmdString string) string {
func (s *session) run(t *testing.T, cmdString string, shell bool) string {
t.Helper()
err := s.Start(cmdString)
if err != nil {
t.Fatalf("unable to start command: %s", err)
if shell {
_, err := s.stdin.Write([]byte(fmt.Sprintf("%s\n", cmdString)))
if err != nil {
t.Fatalf("unable to send command to shell: %s", err)
}
} else {
err := s.Start(cmdString)
if err != nil {
t.Fatalf("unable to start command: %s", err)
}
}
return s.read()
}
func (s *session) read() string {
ch := make(chan []byte)
go func() {
for {
@@ -228,7 +366,7 @@ readLoop:
select {
case b := <-ch:
_got = append(_got, b...)
case <-time.After(25 * time.Millisecond):
case <-time.After(1 * time.Second):
break readLoop
}
}
@@ -236,12 +374,12 @@ readLoop:
return string(_got)
}
func testClient(t *testing.T) *ssh.Client {
func testClient(t *testing.T, forceV1Behavior bool) *ssh.Client {
t.Helper()
username := "testuser"
srv := &server{
lb: &testBackend{localUser: username},
lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior},
logf: log.Printf,
tailscaledPath: os.Getenv("TAILSCALED_PATH"),
timeNow: time.Now,
@@ -271,8 +409,8 @@ func testClient(t *testing.T) *ssh.Client {
return cl
}
func testSession(t *testing.T) *session {
cl := testClient(t)
func testSession(t *testing.T, forceV1Behavior bool) *session {
cl := testClient(t, forceV1Behavior)
return testSessionFor(t, cl)
}
@@ -299,7 +437,8 @@ func testSessionFor(t *testing.T, cl *ssh.Client) *session {
// testBackend implements ipnLocalBackend
type testBackend struct {
localUser string
localUser string
forceV1Behavior bool
}
func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) {
@@ -339,16 +478,21 @@ func (tb *testBackend) ShouldRunSSH() bool {
}
func (tb *testBackend) NetMap() *netmap.NetworkMap {
capMap := make(set.Set[tailcfg.NodeCapability])
if tb.forceV1Behavior {
capMap[tailcfg.NodeAttrSSHBehaviorV1] = struct{}{}
}
return &netmap.NetworkMap{
SSHPolicy: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
&tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
Action: &tailcfg.SSHAction{Accept: true},
SSHUsers: map[string]string{"*": tb.localUser},
},
},
},
AllCaps: capMap,
}
}