mirror of
https://github.com/tailscale/tailscale.git
synced 2025-03-14 01:11:01 +00:00
ssh/tailssh: add more SSH tests, blend in env from ssh session
Updates #3802 Change-Id: I568c661cacbb0524afcd8be9577457ddba611f19 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
4686224e5a
commit
4b50977422
@ -232,11 +232,11 @@ func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *ssh
|
||||
}
|
||||
}
|
||||
cmd.Dir = lu.HomeDir
|
||||
cmd.Env = append(cmd.Env, s.Environ()...)
|
||||
cmd.Env = append(cmd.Env, envForUser(lu)...)
|
||||
if ptyReq.Term != "" {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
}
|
||||
// TODO(bradfitz,maisem): also blend in user's s.Environ()
|
||||
logf("Running: %q", cmd.Args)
|
||||
var toCmd io.WriteCloser
|
||||
var fromCmd io.ReadCloser
|
||||
|
@ -8,12 +8,15 @@
|
||||
package tailssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -25,6 +28,7 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/lineread"
|
||||
"tailscale.com/wgengine"
|
||||
)
|
||||
|
||||
@ -231,12 +235,78 @@ func TestSSH(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
got, err := exec.Command("ssh",
|
||||
"-p", fmt.Sprint(port),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"user@127.0.0.1", "env").CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
execSSH := func(args ...string) *exec.Cmd {
|
||||
cmd := exec.Command("ssh",
|
||||
"-p", fmt.Sprint(port),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"user@127.0.0.1")
|
||||
cmd.Args = append(cmd.Args, args...)
|
||||
return cmd
|
||||
}
|
||||
t.Logf("Got: %s", got)
|
||||
|
||||
t.Run("env", func(t *testing.T) {
|
||||
cmd := execSSH("env")
|
||||
cmd.Env = append(os.Environ(), "LANG=foo")
|
||||
got, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m := parseEnv(got)
|
||||
if got := m["USER"]; got == "" || got != u.Username {
|
||||
t.Errorf("USER = %q; want %q", got, u.Username)
|
||||
}
|
||||
if got := m["HOME"]; got == "" || got != u.HomeDir {
|
||||
t.Errorf("HOME = %q; want %q", got, u.HomeDir)
|
||||
}
|
||||
if got := m["PWD"]; got == "" || got != u.HomeDir {
|
||||
t.Errorf("PWD = %q; want %q", got, u.HomeDir)
|
||||
}
|
||||
if got := m["SHELL"]; got == "" {
|
||||
t.Errorf("no SHELL")
|
||||
}
|
||||
if got, want := m["LANG"], "foo"; got != want {
|
||||
t.Errorf("LANG = %q; want %q", got, want)
|
||||
}
|
||||
t.Logf("got: %+v", m)
|
||||
})
|
||||
|
||||
t.Run("stdout_stderr", func(t *testing.T) {
|
||||
cmd := execSSH("sh", "-c", "echo foo; echo bar >&2")
|
||||
var outBuf, errBuf bytes.Buffer
|
||||
cmd.Stdout = &outBuf
|
||||
cmd.Stderr = &errBuf
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes())
|
||||
// TODO: figure out why these aren't right. should be
|
||||
// "foo\n" and "bar\n", not "\n" and "bar\n".
|
||||
})
|
||||
|
||||
t.Run("stdin", func(t *testing.T) {
|
||||
cmd := execSSH("cat")
|
||||
var outBuf bytes.Buffer
|
||||
cmd.Stdout = &outBuf
|
||||
const str = "foo\nbar\n"
|
||||
cmd.Stdin = strings.NewReader(str)
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := outBuf.String(); got != str {
|
||||
t.Errorf("got %q; want %q", got, str)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func parseEnv(out []byte) map[string]string {
|
||||
e := map[string]string{}
|
||||
lineread.Reader(bytes.NewReader(out), func(line []byte) error {
|
||||
i := bytes.IndexByte(line, '=')
|
||||
if i == -1 {
|
||||
return nil
|
||||
}
|
||||
e[string(line[:i])] = string(line[i+1:])
|
||||
return nil
|
||||
})
|
||||
return e
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user