mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 04:55:31 +00:00
ssh: Add logic to set accepted environment variables in SSH session (#13559)
Add logic to set environment variables that match the SSH rule's `acceptEnv` settings in the SSH session's environment. Updates https://github.com/tailscale/corp/issues/22775 Signed-off-by: Mario Minardi <mario@tailscale.com>
This commit is contained in:
parent
dd6b808acf
commit
8f44ba1cd6
@ -4,6 +4,7 @@
|
||||
package tailssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
@ -17,27 +18,35 @@
|
||||
//
|
||||
// acceptEnv values may contain * and ? wildcard characters which match against
|
||||
// zero or one or more characters and a single character respectively.
|
||||
func filterEnv(acceptEnv []string, environ []string) []string {
|
||||
func filterEnv(acceptEnv []string, environ []string) ([]string, error) {
|
||||
var acceptedPairs []string
|
||||
|
||||
// Quick return if we have an empty list.
|
||||
if acceptEnv == nil || len(acceptEnv) == 0 {
|
||||
return acceptedPairs, nil
|
||||
}
|
||||
|
||||
for _, envPair := range environ {
|
||||
envVar := strings.Split(envPair, "=")[0]
|
||||
variableName, _, ok := strings.Cut(envPair, "=")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`invalid environment variable: %q. Variables must be in "KEY=VALUE" format`, envPair)
|
||||
}
|
||||
|
||||
// Short circuit if we have a direct match between the environment
|
||||
// variable and an AcceptEnv value.
|
||||
if slices.Contains(acceptEnv, envVar) {
|
||||
if slices.Contains(acceptEnv, variableName) {
|
||||
acceptedPairs = append(acceptedPairs, envPair)
|
||||
continue
|
||||
}
|
||||
|
||||
// Otherwise check if we have a wildcard pattern that matches.
|
||||
if matchAcceptEnv(acceptEnv, envVar) {
|
||||
if matchAcceptEnv(acceptEnv, variableName) {
|
||||
acceptedPairs = append(acceptedPairs, envPair)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return acceptedPairs
|
||||
return acceptedPairs, nil
|
||||
}
|
||||
|
||||
// matchAcceptEnv is a convenience function that wraps calling matchAcceptEnvPattern
|
||||
|
@ -108,6 +108,7 @@ func TestFilterEnv(t *testing.T) {
|
||||
acceptEnv []string
|
||||
environ []string
|
||||
expectedFiltered []string
|
||||
wantErrMessage string
|
||||
}{
|
||||
{
|
||||
name: "simple direct matches",
|
||||
@ -127,11 +128,26 @@ func TestFilterEnv(t *testing.T) {
|
||||
environ: []string{"FOO=BAR", "FOO2=BAZ", "FOO_3=123", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC", "FO57=BAR2"},
|
||||
expectedFiltered: []string{"FOO=BAR", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC"},
|
||||
},
|
||||
{
|
||||
name: "environ format invalid",
|
||||
acceptEnv: []string{"FO?", "FOOO*", "FO*5?7"},
|
||||
environ: []string{"FOOBAR"},
|
||||
expectedFiltered: nil,
|
||||
wantErrMessage: `invalid environment variable: "FOOBAR". Variables must be in "KEY=VALUE" format`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
filtered := filterEnv(tc.acceptEnv, tc.environ)
|
||||
filtered, err := filterEnv(tc.acceptEnv, tc.environ)
|
||||
if err == nil && tc.wantErrMessage != "" {
|
||||
t.Errorf("wanted error with message %q but error was nil", tc.wantErrMessage)
|
||||
}
|
||||
|
||||
if err != nil && err.Error() != tc.wantErrMessage {
|
||||
t.Errorf("err = %v; want %v", err, tc.wantErrMessage)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.expectedFiltered, filtered); diff != "" {
|
||||
t.Errorf("unexpected filter result (-got,+want): \n%s", diff)
|
||||
}
|
||||
|
@ -12,6 +12,7 @@
|
||||
package tailssh
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
@ -154,6 +155,22 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err
|
||||
incubatorArgs = append(incubatorArgs, "--cmd="+ss.RawCommand())
|
||||
}
|
||||
|
||||
allowSendEnv := nm.HasCap(tailcfg.NodeAttrSSHEnvironmentVariables)
|
||||
if allowSendEnv {
|
||||
env, err := filterEnv(ss.conn.acceptEnv, ss.Session.Environ())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(env) > 0 {
|
||||
encoded, err := json.Marshal(env)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode environment: %w", err)
|
||||
}
|
||||
incubatorArgs = append(incubatorArgs, fmt.Sprintf("--encoded-env=%q", encoded))
|
||||
}
|
||||
}
|
||||
|
||||
return exec.CommandContext(ss.ctx, ss.conn.srv.tailscaledPath, incubatorArgs...), nil
|
||||
}
|
||||
|
||||
@ -192,6 +209,9 @@ type incubatorArgs struct {
|
||||
forceV1Behavior bool
|
||||
debugTest bool
|
||||
isSELinuxEnforcing bool
|
||||
encodedEnv string
|
||||
allowListEnvKeys string
|
||||
forwardedEnviron []string
|
||||
}
|
||||
|
||||
func parseIncubatorArgs(args []string) (incubatorArgs, error) {
|
||||
@ -215,6 +235,7 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) {
|
||||
flags.BoolVar(&ia.forceV1Behavior, "force-v1-behavior", false, "allow falling back to the su command if login is unavailable")
|
||||
flags.BoolVar(&ia.debugTest, "debug-test", false, "should debug in test mode")
|
||||
flags.BoolVar(&ia.isSELinuxEnforcing, "is-selinux-enforcing", false, "whether SELinux is in enforcing mode")
|
||||
flags.StringVar(&ia.encodedEnv, "encoded-env", "", "JSON encoded array of environment variables in '['key=value']' format")
|
||||
flags.Parse(args)
|
||||
|
||||
for _, g := range strings.Split(groups, ",") {
|
||||
@ -225,6 +246,30 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) {
|
||||
ia.gids = append(ia.gids, gid)
|
||||
}
|
||||
|
||||
ia.forwardedEnviron = os.Environ()
|
||||
// pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding
|
||||
ia.allowListEnvKeys = "SSH_AUTH_SOCK"
|
||||
|
||||
if ia.encodedEnv != "" {
|
||||
unquoted, err := strconv.Unquote(ia.encodedEnv)
|
||||
if err != nil {
|
||||
return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err)
|
||||
}
|
||||
|
||||
var extraEnviron []string
|
||||
|
||||
err = json.Unmarshal([]byte(unquoted), &extraEnviron)
|
||||
if err != nil {
|
||||
return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err)
|
||||
}
|
||||
|
||||
ia.forwardedEnviron = append(ia.forwardedEnviron, extraEnviron...)
|
||||
|
||||
for _, v := range extraEnviron {
|
||||
ia.allowListEnvKeys = fmt.Sprintf("%s,%s", ia.allowListEnvKeys, strings.Split(v, "=")[0])
|
||||
}
|
||||
}
|
||||
|
||||
return ia, nil
|
||||
}
|
||||
|
||||
@ -406,7 +451,7 @@ func tryExecLogin(dlogf logger.Logf, ia incubatorArgs) error {
|
||||
dlogf("logging in with %+v", loginArgs)
|
||||
|
||||
// If Exec works, the Go code will not proceed past this:
|
||||
err = unix.Exec(loginCmdPath, loginArgs, os.Environ())
|
||||
err = unix.Exec(loginCmdPath, loginArgs, ia.forwardedEnviron)
|
||||
|
||||
// If we made it here, Exec failed.
|
||||
return err
|
||||
@ -441,7 +486,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) {
|
||||
|
||||
loginArgs := []string{
|
||||
su,
|
||||
"-w", "SSH_AUTH_SOCK", // pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding
|
||||
"-w", ia.allowListEnvKeys,
|
||||
"-l",
|
||||
ia.localUser,
|
||||
}
|
||||
@ -453,7 +498,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) {
|
||||
dlogf("logging in with %+v", loginArgs)
|
||||
|
||||
// If Exec works, the Go code will not proceed past this:
|
||||
err = unix.Exec(su, loginArgs, os.Environ())
|
||||
err = unix.Exec(su, loginArgs, ia.forwardedEnviron)
|
||||
|
||||
// If we made it here, Exec failed.
|
||||
return true, err
|
||||
@ -482,11 +527,11 @@ func findSU(dlogf logger.Logf, ia incubatorArgs) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// First try to execute su -w SSH_AUTH_SOCK -l <user> -c true
|
||||
// First try to execute su -w <allow listed env> -l <user> -c true
|
||||
// to make sure su supports the necessary arguments.
|
||||
err = exec.Command(
|
||||
su,
|
||||
"-w", "SSH_AUTH_SOCK",
|
||||
"-w", ia.allowListEnvKeys,
|
||||
"-l",
|
||||
ia.localUser,
|
||||
"-c", "true",
|
||||
@ -515,7 +560,7 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error {
|
||||
|
||||
args := shellArgs(ia.isShell, ia.cmd)
|
||||
dlogf("running %s %q", ia.loginShell, args)
|
||||
cmd := newCommand(ia.hasTTY, ia.loginShell, args)
|
||||
cmd := newCommand(ia.hasTTY, ia.loginShell, ia.forwardedEnviron, args)
|
||||
err := cmd.Run()
|
||||
if ee, ok := err.(*exec.ExitError); ok {
|
||||
ps := ee.ProcessState
|
||||
@ -532,12 +577,12 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func newCommand(hasTTY bool, cmdPath string, cmdArgs []string) *exec.Cmd {
|
||||
func newCommand(hasTTY bool, cmdPath string, cmdEnviron []string, cmdArgs []string) *exec.Cmd {
|
||||
cmd := exec.Command(cmdPath, cmdArgs...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Env = cmdEnviron
|
||||
|
||||
if hasTTY {
|
||||
// If we were launched with a tty then we should mark that as the ctty
|
||||
|
@ -238,6 +238,7 @@ type conn struct {
|
||||
localUser *userMeta // set by doPolicyAuth
|
||||
userGroupIDs []string // set by doPolicyAuth
|
||||
pubKey gossh.PublicKey // set by doPolicyAuth
|
||||
acceptEnv []string
|
||||
|
||||
// mu protects the following fields.
|
||||
//
|
||||
@ -377,7 +378,7 @@ func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error {
|
||||
c.logf("failed to get conninfo: %v", err)
|
||||
return errDenied
|
||||
}
|
||||
a, localUser, err := c.evaluatePolicy(pubKey)
|
||||
a, localUser, acceptEnv, err := c.evaluatePolicy(pubKey)
|
||||
if err != nil {
|
||||
if pubKey == nil && c.havePubKeyPolicy() {
|
||||
return errPubKeyRequired
|
||||
@ -387,6 +388,7 @@ func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error {
|
||||
c.action0 = a
|
||||
c.currentAction = a
|
||||
c.pubKey = pubKey
|
||||
c.acceptEnv = acceptEnv
|
||||
if a.Message != "" {
|
||||
if err := ctx.SendAuthBanner(a.Message); err != nil {
|
||||
return fmt.Errorf("SendBanner: %w", err)
|
||||
@ -619,16 +621,16 @@ func (c *conn) setInfo(ctx ssh.Context) error {
|
||||
|
||||
// evaluatePolicy returns the SSHAction and localUser after evaluating
|
||||
// the SSHPolicy for this conn. The pubKey may be nil for "none" auth.
|
||||
func (c *conn) evaluatePolicy(pubKey gossh.PublicKey) (_ *tailcfg.SSHAction, localUser string, _ error) {
|
||||
func (c *conn) evaluatePolicy(pubKey gossh.PublicKey) (_ *tailcfg.SSHAction, localUser string, acceptEnv []string, _ error) {
|
||||
pol, ok := c.sshPolicy()
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("tailssh: rejecting connection; no SSH policy")
|
||||
return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no SSH policy")
|
||||
}
|
||||
a, localUser, ok := c.evalSSHPolicy(pol, pubKey)
|
||||
a, localUser, acceptEnv, ok := c.evalSSHPolicy(pol, pubKey)
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("tailssh: rejecting connection; no matching policy")
|
||||
return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no matching policy")
|
||||
}
|
||||
return a, localUser, nil
|
||||
return a, localUser, acceptEnv, nil
|
||||
}
|
||||
|
||||
// pubKeyCacheEntry is the cache value for an HTTPS URL of public keys (like
|
||||
@ -892,7 +894,7 @@ func (c *conn) newSSHSession(s ssh.Session) *sshSession {
|
||||
|
||||
// isStillValid reports whether the conn is still valid.
|
||||
func (c *conn) isStillValid() bool {
|
||||
a, localUser, err := c.evaluatePolicy(c.pubKey)
|
||||
a, localUser, _, err := c.evaluatePolicy(c.pubKey)
|
||||
c.vlogf("stillValid: %+v %v %v", a, localUser, err)
|
||||
if err != nil {
|
||||
return false
|
||||
@ -1275,13 +1277,13 @@ func (c *conn) ruleExpired(r *tailcfg.SSHRule) bool {
|
||||
return r.RuleExpires.Before(c.srv.now())
|
||||
}
|
||||
|
||||
func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, ok bool) {
|
||||
func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, ok bool) {
|
||||
for _, r := range pol.Rules {
|
||||
if a, localUser, err := c.matchRule(r, pubKey); err == nil {
|
||||
return a, localUser, true
|
||||
if a, localUser, acceptEnv, err := c.matchRule(r, pubKey); err == nil {
|
||||
return a, localUser, acceptEnv, true
|
||||
}
|
||||
}
|
||||
return nil, "", false
|
||||
return nil, "", nil, false
|
||||
}
|
||||
|
||||
// internal errors for testing; they don't escape to callers or logs.
|
||||
@ -1294,26 +1296,26 @@ func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a
|
||||
errInvalidConn = errors.New("invalid connection state")
|
||||
)
|
||||
|
||||
func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, err error) {
|
||||
func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, err error) {
|
||||
defer func() {
|
||||
c.vlogf("matchRule(%+v): %v", r, err)
|
||||
}()
|
||||
|
||||
if c == nil {
|
||||
return nil, "", errInvalidConn
|
||||
return nil, "", nil, errInvalidConn
|
||||
}
|
||||
if c.info == nil {
|
||||
c.logf("invalid connection state")
|
||||
return nil, "", errInvalidConn
|
||||
return nil, "", nil, errInvalidConn
|
||||
}
|
||||
if r == nil {
|
||||
return nil, "", errNilRule
|
||||
return nil, "", nil, errNilRule
|
||||
}
|
||||
if r.Action == nil {
|
||||
return nil, "", errNilAction
|
||||
return nil, "", nil, errNilAction
|
||||
}
|
||||
if c.ruleExpired(r) {
|
||||
return nil, "", errRuleExpired
|
||||
return nil, "", nil, errRuleExpired
|
||||
}
|
||||
if !r.Action.Reject {
|
||||
// For all but Reject rules, SSHUsers is required.
|
||||
@ -1321,15 +1323,15 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg
|
||||
// empty string anyway.
|
||||
localUser = mapLocalUser(r.SSHUsers, c.info.sshUser)
|
||||
if localUser == "" {
|
||||
return nil, "", errUserMatch
|
||||
return nil, "", nil, errUserMatch
|
||||
}
|
||||
}
|
||||
if ok, err := c.anyPrincipalMatches(r.Principals, pubKey); err != nil {
|
||||
return nil, "", err
|
||||
return nil, "", nil, err
|
||||
} else if !ok {
|
||||
return nil, "", errPrincipalMatch
|
||||
return nil, "", nil, errPrincipalMatch
|
||||
}
|
||||
return r.Action, localUser, nil
|
||||
return r.Action, localUser, r.AcceptEnv, nil
|
||||
}
|
||||
|
||||
func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser string) {
|
||||
|
@ -108,6 +108,7 @@ func TestIntegrationSSH(t *testing.T) {
|
||||
want []string
|
||||
forceV1Behavior bool
|
||||
skip bool
|
||||
allowSendEnv bool
|
||||
}{
|
||||
{
|
||||
cmd: "id",
|
||||
@ -131,6 +132,18 @@ func TestIntegrationSSH(t *testing.T) {
|
||||
skip: os.Getenv("SKIP_FILE_OPS") == "1" || !fallbackToSUAvailable(),
|
||||
forceV1Behavior: false,
|
||||
},
|
||||
{
|
||||
cmd: `echo "${GIT_ENV_VAR:-unset1} ${EXACT_MATCH:-unset2} ${TESTING:-unset3} ${NOT_ALLOWED:-unset4}"`,
|
||||
want: []string{"working1 working2 working3 unset4"},
|
||||
forceV1Behavior: false,
|
||||
allowSendEnv: true,
|
||||
},
|
||||
{
|
||||
cmd: `echo "${GIT_ENV_VAR:-unset1} ${EXACT_MATCH:-unset2} ${TESTING:-unset3} ${NOT_ALLOWED:-unset4}"`,
|
||||
want: []string{"unset1 unset2 unset3 unset4"},
|
||||
forceV1Behavior: false,
|
||||
allowSendEnv: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
@ -151,7 +164,13 @@ func TestIntegrationSSH(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run(fmt.Sprintf("%s_%s_%s", test.cmd, shellQualifier, versionQualifier), func(t *testing.T) {
|
||||
s := testSession(t, test.forceV1Behavior)
|
||||
sendEnv := map[string]string{
|
||||
"GIT_ENV_VAR": "working1",
|
||||
"EXACT_MATCH": "working2",
|
||||
"TESTING": "working3",
|
||||
"NOT_ALLOWED": "working4",
|
||||
}
|
||||
s := testSession(t, test.forceV1Behavior, test.allowSendEnv, sendEnv)
|
||||
|
||||
if shell {
|
||||
err := s.RequestPty("xterm", 40, 80, ssh.TerminalModes{
|
||||
@ -201,7 +220,7 @@ func TestIntegrationSFTP(t *testing.T) {
|
||||
}
|
||||
wantText := "hello world"
|
||||
|
||||
cl := testClient(t, forceV1Behavior)
|
||||
cl := testClient(t, forceV1Behavior, false)
|
||||
scl, err := sftp.NewClient(cl)
|
||||
if err != nil {
|
||||
t.Fatalf("can't get sftp client: %s", err)
|
||||
@ -233,7 +252,7 @@ func TestIntegrationSFTP(t *testing.T) {
|
||||
t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
s := testSessionFor(t, cl)
|
||||
s := testSessionFor(t, cl, nil)
|
||||
got := s.run(t, "ls -l "+filePath, false)
|
||||
if !strings.Contains(got, "testuser") {
|
||||
t.Fatalf("unexpected file owner user: %s", got)
|
||||
@ -262,7 +281,7 @@ func TestIntegrationSCP(t *testing.T) {
|
||||
}
|
||||
wantText := "hello world"
|
||||
|
||||
cl := testClient(t, forceV1Behavior)
|
||||
cl := testClient(t, forceV1Behavior, false)
|
||||
scl, err := scp.NewClientBySSH(cl)
|
||||
if err != nil {
|
||||
t.Fatalf("can't get sftp client: %s", err)
|
||||
@ -291,7 +310,7 @@ func TestIntegrationSCP(t *testing.T) {
|
||||
t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
s := testSessionFor(t, cl)
|
||||
s := testSessionFor(t, cl, nil)
|
||||
got := s.run(t, "ls -l "+filePath, false)
|
||||
if !strings.Contains(got, "testuser") {
|
||||
t.Fatalf("unexpected file owner user: %s", got)
|
||||
@ -349,7 +368,7 @@ func TestSSHAgentForwarding(t *testing.T) {
|
||||
|
||||
// Run tailscale SSH server and connect to it
|
||||
username := "testuser"
|
||||
tailscaleAddr := testServer(t, username, false)
|
||||
tailscaleAddr := testServer(t, username, false, false)
|
||||
tcl, err := ssh.Dial("tcp", tailscaleAddr, &ssh.ClientConfig{
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
})
|
||||
@ -465,11 +484,11 @@ func (s *session) read() string {
|
||||
return string(_got)
|
||||
}
|
||||
|
||||
func testClient(t *testing.T, forceV1Behavior bool, authMethods ...ssh.AuthMethod) *ssh.Client {
|
||||
func testClient(t *testing.T, forceV1Behavior bool, allowSendEnv bool, authMethods ...ssh.AuthMethod) *ssh.Client {
|
||||
t.Helper()
|
||||
|
||||
username := "testuser"
|
||||
addr := testServer(t, username, forceV1Behavior)
|
||||
addr := testServer(t, username, forceV1Behavior, allowSendEnv)
|
||||
|
||||
cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
@ -483,9 +502,9 @@ func testClient(t *testing.T, forceV1Behavior bool, authMethods ...ssh.AuthMetho
|
||||
return cl
|
||||
}
|
||||
|
||||
func testServer(t *testing.T, username string, forceV1Behavior bool) string {
|
||||
func testServer(t *testing.T, username string, forceV1Behavior bool, allowSendEnv bool) string {
|
||||
srv := &server{
|
||||
lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior},
|
||||
lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior, allowSendEnv: allowSendEnv},
|
||||
logf: log.Printf,
|
||||
tailscaledPath: os.Getenv("TAILSCALED_PATH"),
|
||||
timeNow: time.Now,
|
||||
@ -509,16 +528,20 @@ func testServer(t *testing.T, username string, forceV1Behavior bool) string {
|
||||
return l.Addr().String()
|
||||
}
|
||||
|
||||
func testSession(t *testing.T, forceV1Behavior bool) *session {
|
||||
cl := testClient(t, forceV1Behavior)
|
||||
return testSessionFor(t, cl)
|
||||
func testSession(t *testing.T, forceV1Behavior bool, allowSendEnv bool, sendEnv map[string]string) *session {
|
||||
cl := testClient(t, forceV1Behavior, allowSendEnv)
|
||||
return testSessionFor(t, cl, sendEnv)
|
||||
}
|
||||
|
||||
func testSessionFor(t *testing.T, cl *ssh.Client) *session {
|
||||
func testSessionFor(t *testing.T, cl *ssh.Client, sendEnv map[string]string) *session {
|
||||
s, err := cl.NewSession()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for k, v := range sendEnv {
|
||||
s.Setenv(k, v)
|
||||
}
|
||||
|
||||
t.Cleanup(func() { s.Close() })
|
||||
|
||||
stdinReader, stdinWriter := io.Pipe()
|
||||
@ -564,6 +587,7 @@ func generateClientKey(t *testing.T, privateKeyFile string) (ssh.Signer, *rsa.Pr
|
||||
type testBackend struct {
|
||||
localUser string
|
||||
forceV1Behavior bool
|
||||
allowSendEnv bool
|
||||
}
|
||||
|
||||
func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) {
|
||||
@ -597,6 +621,9 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap {
|
||||
if tb.forceV1Behavior {
|
||||
capMap[tailcfg.NodeAttrSSHBehaviorV1] = struct{}{}
|
||||
}
|
||||
if tb.allowSendEnv {
|
||||
capMap[tailcfg.NodeAttrSSHEnvironmentVariables] = struct{}{}
|
||||
}
|
||||
return &netmap.NetworkMap{
|
||||
SSHPolicy: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
@ -604,6 +631,7 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap {
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true},
|
||||
SSHUsers: map[string]string{"*": tb.localUser},
|
||||
AcceptEnv: []string{"GIT_*", "EXACT_MATCH", "TEST?NG"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -24,6 +24,7 @@
|
||||
"os/user"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -61,6 +62,7 @@ func TestMatchRule(t *testing.T) {
|
||||
ci *sshConnInfo
|
||||
wantErr error
|
||||
wantUser string
|
||||
wantAcceptEnv []string
|
||||
}{
|
||||
{
|
||||
name: "invalid-conn",
|
||||
@ -153,6 +155,21 @@ func TestMatchRule(t *testing.T) {
|
||||
ci: &sshConnInfo{sshUser: "alice"},
|
||||
wantUser: "thealice",
|
||||
},
|
||||
{
|
||||
name: "ok-with-accept-env",
|
||||
rule: &tailcfg.SSHRule{
|
||||
Action: someAction,
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{
|
||||
"*": "ubuntu",
|
||||
"alice": "thealice",
|
||||
},
|
||||
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
|
||||
},
|
||||
ci: &sshConnInfo{sshUser: "alice"},
|
||||
wantUser: "thealice",
|
||||
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
|
||||
},
|
||||
{
|
||||
name: "no-users-for-reject",
|
||||
rule: &tailcfg.SSHRule{
|
||||
@ -210,7 +227,7 @@ func TestMatchRule(t *testing.T) {
|
||||
info: tt.ci,
|
||||
srv: &server{logf: t.Logf},
|
||||
}
|
||||
got, gotUser, err := c.matchRule(tt.rule, nil)
|
||||
got, gotUser, gotAcceptEnv, err := c.matchRule(tt.rule, nil)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err = %v; want %v", err, tt.wantErr)
|
||||
}
|
||||
@ -220,6 +237,128 @@ func TestMatchRule(t *testing.T) {
|
||||
if err == nil && got == nil {
|
||||
t.Errorf("expected non-nil action on success")
|
||||
}
|
||||
if !slices.Equal(gotAcceptEnv, tt.wantAcceptEnv) {
|
||||
t.Errorf("acceptEnv = %v; want %v", gotAcceptEnv, tt.wantAcceptEnv)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalSSHPolicy(t *testing.T) {
|
||||
someAction := new(tailcfg.SSHAction)
|
||||
tests := []struct {
|
||||
name string
|
||||
policy *tailcfg.SSHPolicy
|
||||
ci *sshConnInfo
|
||||
wantMatch bool
|
||||
wantUser string
|
||||
wantAcceptEnv []string
|
||||
}{
|
||||
{
|
||||
name: "multiple-matches-picks-first-match",
|
||||
policy: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Action: someAction,
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{
|
||||
"other": "other1",
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: someAction,
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{
|
||||
"*": "ubuntu",
|
||||
"alice": "thealice",
|
||||
},
|
||||
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
|
||||
},
|
||||
{
|
||||
Action: someAction,
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{
|
||||
"other2": "other3",
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: someAction,
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{
|
||||
"*": "ubuntu",
|
||||
"alice": "thealice",
|
||||
"mark": "markthe",
|
||||
},
|
||||
AcceptEnv: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
ci: &sshConnInfo{sshUser: "alice"},
|
||||
wantUser: "thealice",
|
||||
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
|
||||
wantMatch: true,
|
||||
},
|
||||
{
|
||||
name: "no-matches-returns-failure",
|
||||
policy: &tailcfg.SSHPolicy{
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Action: someAction,
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{
|
||||
"other": "other1",
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: someAction,
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{
|
||||
"fedora": "ubuntu",
|
||||
},
|
||||
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
|
||||
},
|
||||
{
|
||||
Action: someAction,
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{
|
||||
"other2": "other3",
|
||||
},
|
||||
},
|
||||
{
|
||||
Action: someAction,
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
SSHUsers: map[string]string{
|
||||
"mark": "markthe",
|
||||
},
|
||||
AcceptEnv: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
ci: &sshConnInfo{sshUser: "alice"},
|
||||
wantUser: "",
|
||||
wantAcceptEnv: nil,
|
||||
wantMatch: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &conn{
|
||||
info: tt.ci,
|
||||
srv: &server{logf: t.Logf},
|
||||
}
|
||||
got, gotUser, gotAcceptEnv, match := c.evalSSHPolicy(tt.policy, nil)
|
||||
if match != tt.wantMatch {
|
||||
t.Errorf("match = %v; want %v", match, tt.wantMatch)
|
||||
}
|
||||
if gotUser != tt.wantUser {
|
||||
t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
|
||||
}
|
||||
if tt.wantMatch == true && got == nil {
|
||||
t.Errorf("expected non-nil action on success")
|
||||
}
|
||||
if !slices.Equal(gotAcceptEnv, tt.wantAcceptEnv) {
|
||||
t.Errorf("acceptEnv = %v; want %v", gotAcceptEnv, tt.wantAcceptEnv)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user