mirror of
https://github.com/tailscale/tailscale.git
synced 2025-07-28 14:53:44 +00:00
ssh/tailssh: add more SSH tests, blend in env from ssh session
Updates #3802 Change-Id: I568c661cacbb0524afcd8be9577457ddba611f19 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
4686224e5a
commit
4b50977422
@ -232,11 +232,11 @@ func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *ssh
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
cmd.Dir = lu.HomeDir
|
cmd.Dir = lu.HomeDir
|
||||||
|
cmd.Env = append(cmd.Env, s.Environ()...)
|
||||||
cmd.Env = append(cmd.Env, envForUser(lu)...)
|
cmd.Env = append(cmd.Env, envForUser(lu)...)
|
||||||
if ptyReq.Term != "" {
|
if ptyReq.Term != "" {
|
||||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||||
}
|
}
|
||||||
// TODO(bradfitz,maisem): also blend in user's s.Environ()
|
|
||||||
logf("Running: %q", cmd.Args)
|
logf("Running: %q", cmd.Args)
|
||||||
var toCmd io.WriteCloser
|
var toCmd io.WriteCloser
|
||||||
var fromCmd io.ReadCloser
|
var fromCmd io.ReadCloser
|
||||||
|
@ -8,12 +8,15 @@
|
|||||||
package tailssh
|
package tailssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -25,6 +28,7 @@ import (
|
|||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/tstest"
|
"tailscale.com/tstest"
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
|
"tailscale.com/util/lineread"
|
||||||
"tailscale.com/wgengine"
|
"tailscale.com/wgengine"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -231,12 +235,78 @@ func TestSSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
got, err := exec.Command("ssh",
|
execSSH := func(args ...string) *exec.Cmd {
|
||||||
"-p", fmt.Sprint(port),
|
cmd := exec.Command("ssh",
|
||||||
"-o", "StrictHostKeyChecking=no",
|
"-p", fmt.Sprint(port),
|
||||||
"user@127.0.0.1", "env").CombinedOutput()
|
"-o", "StrictHostKeyChecking=no",
|
||||||
if err != nil {
|
"user@127.0.0.1")
|
||||||
t.Fatal(err)
|
cmd.Args = append(cmd.Args, args...)
|
||||||
|
return cmd
|
||||||
}
|
}
|
||||||
t.Logf("Got: %s", got)
|
|
||||||
|
t.Run("env", func(t *testing.T) {
|
||||||
|
cmd := execSSH("env")
|
||||||
|
cmd.Env = append(os.Environ(), "LANG=foo")
|
||||||
|
got, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
m := parseEnv(got)
|
||||||
|
if got := m["USER"]; got == "" || got != u.Username {
|
||||||
|
t.Errorf("USER = %q; want %q", got, u.Username)
|
||||||
|
}
|
||||||
|
if got := m["HOME"]; got == "" || got != u.HomeDir {
|
||||||
|
t.Errorf("HOME = %q; want %q", got, u.HomeDir)
|
||||||
|
}
|
||||||
|
if got := m["PWD"]; got == "" || got != u.HomeDir {
|
||||||
|
t.Errorf("PWD = %q; want %q", got, u.HomeDir)
|
||||||
|
}
|
||||||
|
if got := m["SHELL"]; got == "" {
|
||||||
|
t.Errorf("no SHELL")
|
||||||
|
}
|
||||||
|
if got, want := m["LANG"], "foo"; got != want {
|
||||||
|
t.Errorf("LANG = %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
t.Logf("got: %+v", m)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stdout_stderr", func(t *testing.T) {
|
||||||
|
cmd := execSSH("sh", "-c", "echo foo; echo bar >&2")
|
||||||
|
var outBuf, errBuf bytes.Buffer
|
||||||
|
cmd.Stdout = &outBuf
|
||||||
|
cmd.Stderr = &errBuf
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Logf("Got: %q and %q", outBuf.Bytes(), errBuf.Bytes())
|
||||||
|
// TODO: figure out why these aren't right. should be
|
||||||
|
// "foo\n" and "bar\n", not "\n" and "bar\n".
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stdin", func(t *testing.T) {
|
||||||
|
cmd := execSSH("cat")
|
||||||
|
var outBuf bytes.Buffer
|
||||||
|
cmd.Stdout = &outBuf
|
||||||
|
const str = "foo\nbar\n"
|
||||||
|
cmd.Stdin = strings.NewReader(str)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got := outBuf.String(); got != str {
|
||||||
|
t.Errorf("got %q; want %q", got, str)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEnv(out []byte) map[string]string {
|
||||||
|
e := map[string]string{}
|
||||||
|
lineread.Reader(bytes.NewReader(out), func(line []byte) error {
|
||||||
|
i := bytes.IndexByte(line, '=')
|
||||||
|
if i == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
e[string(line[:i])] = string(line[i+1:])
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return e
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user