ssh: Add logic to set accepted environment variables in SSH session (#13559)

Add logic to set environment variables that match the SSH rule's
`acceptEnv` settings in the SSH session's environment.

Updates https://github.com/tailscale/corp/issues/22775

Signed-off-by: Mario Minardi <mario@tailscale.com>
This commit is contained in:
Mario Minardi 2024-09-30 21:47:45 -06:00 committed by GitHub
parent dd6b808acf
commit 8f44ba1cd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 294 additions and 55 deletions

View File

@ -4,6 +4,7 @@
package tailssh
import (
"fmt"
"slices"
"strings"
)
@ -17,27 +18,35 @@
//
// acceptEnv values may contain * and ? wildcard characters which match against
// zero or one or more characters and a single character respectively.
func filterEnv(acceptEnv []string, environ []string) []string {
func filterEnv(acceptEnv []string, environ []string) ([]string, error) {
var acceptedPairs []string
// Quick return if we have an empty list.
if acceptEnv == nil || len(acceptEnv) == 0 {
return acceptedPairs, nil
}
for _, envPair := range environ {
envVar := strings.Split(envPair, "=")[0]
variableName, _, ok := strings.Cut(envPair, "=")
if !ok {
return nil, fmt.Errorf(`invalid environment variable: %q. Variables must be in "KEY=VALUE" format`, envPair)
}
// Short circuit if we have a direct match between the environment
// variable and an AcceptEnv value.
if slices.Contains(acceptEnv, envVar) {
if slices.Contains(acceptEnv, variableName) {
acceptedPairs = append(acceptedPairs, envPair)
continue
}
// Otherwise check if we have a wildcard pattern that matches.
if matchAcceptEnv(acceptEnv, envVar) {
if matchAcceptEnv(acceptEnv, variableName) {
acceptedPairs = append(acceptedPairs, envPair)
continue
}
}
return acceptedPairs
return acceptedPairs, nil
}
// matchAcceptEnv is a convenience function that wraps calling matchAcceptEnvPattern

View File

@ -108,6 +108,7 @@ func TestFilterEnv(t *testing.T) {
acceptEnv []string
environ []string
expectedFiltered []string
wantErrMessage string
}{
{
name: "simple direct matches",
@ -127,11 +128,26 @@ func TestFilterEnv(t *testing.T) {
environ: []string{"FOO=BAR", "FOO2=BAZ", "FOO_3=123", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC", "FO57=BAR2"},
expectedFiltered: []string{"FOO=BAR", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC"},
},
{
name: "environ format invalid",
acceptEnv: []string{"FO?", "FOOO*", "FO*5?7"},
environ: []string{"FOOBAR"},
expectedFiltered: nil,
wantErrMessage: `invalid environment variable: "FOOBAR". Variables must be in "KEY=VALUE" format`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
filtered := filterEnv(tc.acceptEnv, tc.environ)
filtered, err := filterEnv(tc.acceptEnv, tc.environ)
if err == nil && tc.wantErrMessage != "" {
t.Errorf("wanted error with message %q but error was nil", tc.wantErrMessage)
}
if err != nil && err.Error() != tc.wantErrMessage {
t.Errorf("err = %v; want %v", err, tc.wantErrMessage)
}
if diff := cmp.Diff(tc.expectedFiltered, filtered); diff != "" {
t.Errorf("unexpected filter result (-got,+want): \n%s", diff)
}

View File

@ -12,6 +12,7 @@
package tailssh
import (
"encoding/json"
"errors"
"flag"
"fmt"
@ -154,6 +155,22 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err
incubatorArgs = append(incubatorArgs, "--cmd="+ss.RawCommand())
}
allowSendEnv := nm.HasCap(tailcfg.NodeAttrSSHEnvironmentVariables)
if allowSendEnv {
env, err := filterEnv(ss.conn.acceptEnv, ss.Session.Environ())
if err != nil {
return nil, err
}
if len(env) > 0 {
encoded, err := json.Marshal(env)
if err != nil {
return nil, fmt.Errorf("failed to encode environment: %w", err)
}
incubatorArgs = append(incubatorArgs, fmt.Sprintf("--encoded-env=%q", encoded))
}
}
return exec.CommandContext(ss.ctx, ss.conn.srv.tailscaledPath, incubatorArgs...), nil
}
@ -192,6 +209,9 @@ type incubatorArgs struct {
forceV1Behavior bool
debugTest bool
isSELinuxEnforcing bool
encodedEnv string
allowListEnvKeys string
forwardedEnviron []string
}
func parseIncubatorArgs(args []string) (incubatorArgs, error) {
@ -215,6 +235,7 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) {
flags.BoolVar(&ia.forceV1Behavior, "force-v1-behavior", false, "allow falling back to the su command if login is unavailable")
flags.BoolVar(&ia.debugTest, "debug-test", false, "should debug in test mode")
flags.BoolVar(&ia.isSELinuxEnforcing, "is-selinux-enforcing", false, "whether SELinux is in enforcing mode")
flags.StringVar(&ia.encodedEnv, "encoded-env", "", "JSON encoded array of environment variables in '['key=value']' format")
flags.Parse(args)
for _, g := range strings.Split(groups, ",") {
@ -225,6 +246,30 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) {
ia.gids = append(ia.gids, gid)
}
ia.forwardedEnviron = os.Environ()
// pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding
ia.allowListEnvKeys = "SSH_AUTH_SOCK"
if ia.encodedEnv != "" {
unquoted, err := strconv.Unquote(ia.encodedEnv)
if err != nil {
return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err)
}
var extraEnviron []string
err = json.Unmarshal([]byte(unquoted), &extraEnviron)
if err != nil {
return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err)
}
ia.forwardedEnviron = append(ia.forwardedEnviron, extraEnviron...)
for _, v := range extraEnviron {
ia.allowListEnvKeys = fmt.Sprintf("%s,%s", ia.allowListEnvKeys, strings.Split(v, "=")[0])
}
}
return ia, nil
}
@ -406,7 +451,7 @@ func tryExecLogin(dlogf logger.Logf, ia incubatorArgs) error {
dlogf("logging in with %+v", loginArgs)
// If Exec works, the Go code will not proceed past this:
err = unix.Exec(loginCmdPath, loginArgs, os.Environ())
err = unix.Exec(loginCmdPath, loginArgs, ia.forwardedEnviron)
// If we made it here, Exec failed.
return err
@ -441,7 +486,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) {
loginArgs := []string{
su,
"-w", "SSH_AUTH_SOCK", // pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding
"-w", ia.allowListEnvKeys,
"-l",
ia.localUser,
}
@ -453,7 +498,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) {
dlogf("logging in with %+v", loginArgs)
// If Exec works, the Go code will not proceed past this:
err = unix.Exec(su, loginArgs, os.Environ())
err = unix.Exec(su, loginArgs, ia.forwardedEnviron)
// If we made it here, Exec failed.
return true, err
@ -482,11 +527,11 @@ func findSU(dlogf logger.Logf, ia incubatorArgs) string {
return ""
}
// First try to execute su -w SSH_AUTH_SOCK -l <user> -c true
// First try to execute su -w <allow listed env> -l <user> -c true
// to make sure su supports the necessary arguments.
err = exec.Command(
su,
"-w", "SSH_AUTH_SOCK",
"-w", ia.allowListEnvKeys,
"-l",
ia.localUser,
"-c", "true",
@ -515,7 +560,7 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error {
args := shellArgs(ia.isShell, ia.cmd)
dlogf("running %s %q", ia.loginShell, args)
cmd := newCommand(ia.hasTTY, ia.loginShell, args)
cmd := newCommand(ia.hasTTY, ia.loginShell, ia.forwardedEnviron, args)
err := cmd.Run()
if ee, ok := err.(*exec.ExitError); ok {
ps := ee.ProcessState
@ -532,12 +577,12 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error {
return err
}
func newCommand(hasTTY bool, cmdPath string, cmdArgs []string) *exec.Cmd {
func newCommand(hasTTY bool, cmdPath string, cmdEnviron []string, cmdArgs []string) *exec.Cmd {
cmd := exec.Command(cmdPath, cmdArgs...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = os.Environ()
cmd.Env = cmdEnviron
if hasTTY {
// If we were launched with a tty then we should mark that as the ctty

View File

@ -238,6 +238,7 @@ type conn struct {
localUser *userMeta // set by doPolicyAuth
userGroupIDs []string // set by doPolicyAuth
pubKey gossh.PublicKey // set by doPolicyAuth
acceptEnv []string
// mu protects the following fields.
//
@ -377,7 +378,7 @@ func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error {
c.logf("failed to get conninfo: %v", err)
return errDenied
}
a, localUser, err := c.evaluatePolicy(pubKey)
a, localUser, acceptEnv, err := c.evaluatePolicy(pubKey)
if err != nil {
if pubKey == nil && c.havePubKeyPolicy() {
return errPubKeyRequired
@ -387,6 +388,7 @@ func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error {
c.action0 = a
c.currentAction = a
c.pubKey = pubKey
c.acceptEnv = acceptEnv
if a.Message != "" {
if err := ctx.SendAuthBanner(a.Message); err != nil {
return fmt.Errorf("SendBanner: %w", err)
@ -619,16 +621,16 @@ func (c *conn) setInfo(ctx ssh.Context) error {
// evaluatePolicy returns the SSHAction and localUser after evaluating
// the SSHPolicy for this conn. The pubKey may be nil for "none" auth.
func (c *conn) evaluatePolicy(pubKey gossh.PublicKey) (_ *tailcfg.SSHAction, localUser string, _ error) {
func (c *conn) evaluatePolicy(pubKey gossh.PublicKey) (_ *tailcfg.SSHAction, localUser string, acceptEnv []string, _ error) {
pol, ok := c.sshPolicy()
if !ok {
return nil, "", fmt.Errorf("tailssh: rejecting connection; no SSH policy")
return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no SSH policy")
}
a, localUser, ok := c.evalSSHPolicy(pol, pubKey)
a, localUser, acceptEnv, ok := c.evalSSHPolicy(pol, pubKey)
if !ok {
return nil, "", fmt.Errorf("tailssh: rejecting connection; no matching policy")
return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no matching policy")
}
return a, localUser, nil
return a, localUser, acceptEnv, nil
}
// pubKeyCacheEntry is the cache value for an HTTPS URL of public keys (like
@ -892,7 +894,7 @@ func (c *conn) newSSHSession(s ssh.Session) *sshSession {
// isStillValid reports whether the conn is still valid.
func (c *conn) isStillValid() bool {
a, localUser, err := c.evaluatePolicy(c.pubKey)
a, localUser, _, err := c.evaluatePolicy(c.pubKey)
c.vlogf("stillValid: %+v %v %v", a, localUser, err)
if err != nil {
return false
@ -1275,13 +1277,13 @@ func (c *conn) ruleExpired(r *tailcfg.SSHRule) bool {
return r.RuleExpires.Before(c.srv.now())
}
func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, ok bool) {
func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, ok bool) {
for _, r := range pol.Rules {
if a, localUser, err := c.matchRule(r, pubKey); err == nil {
return a, localUser, true
if a, localUser, acceptEnv, err := c.matchRule(r, pubKey); err == nil {
return a, localUser, acceptEnv, true
}
}
return nil, "", false
return nil, "", nil, false
}
// internal errors for testing; they don't escape to callers or logs.
@ -1294,26 +1296,26 @@ func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a
errInvalidConn = errors.New("invalid connection state")
)
func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, err error) {
func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, err error) {
defer func() {
c.vlogf("matchRule(%+v): %v", r, err)
}()
if c == nil {
return nil, "", errInvalidConn
return nil, "", nil, errInvalidConn
}
if c.info == nil {
c.logf("invalid connection state")
return nil, "", errInvalidConn
return nil, "", nil, errInvalidConn
}
if r == nil {
return nil, "", errNilRule
return nil, "", nil, errNilRule
}
if r.Action == nil {
return nil, "", errNilAction
return nil, "", nil, errNilAction
}
if c.ruleExpired(r) {
return nil, "", errRuleExpired
return nil, "", nil, errRuleExpired
}
if !r.Action.Reject {
// For all but Reject rules, SSHUsers is required.
@ -1321,15 +1323,15 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg
// empty string anyway.
localUser = mapLocalUser(r.SSHUsers, c.info.sshUser)
if localUser == "" {
return nil, "", errUserMatch
return nil, "", nil, errUserMatch
}
}
if ok, err := c.anyPrincipalMatches(r.Principals, pubKey); err != nil {
return nil, "", err
return nil, "", nil, err
} else if !ok {
return nil, "", errPrincipalMatch
return nil, "", nil, errPrincipalMatch
}
return r.Action, localUser, nil
return r.Action, localUser, r.AcceptEnv, nil
}
func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser string) {

View File

@ -108,6 +108,7 @@ func TestIntegrationSSH(t *testing.T) {
want []string
forceV1Behavior bool
skip bool
allowSendEnv bool
}{
{
cmd: "id",
@ -131,6 +132,18 @@ func TestIntegrationSSH(t *testing.T) {
skip: os.Getenv("SKIP_FILE_OPS") == "1" || !fallbackToSUAvailable(),
forceV1Behavior: false,
},
{
cmd: `echo "${GIT_ENV_VAR:-unset1} ${EXACT_MATCH:-unset2} ${TESTING:-unset3} ${NOT_ALLOWED:-unset4}"`,
want: []string{"working1 working2 working3 unset4"},
forceV1Behavior: false,
allowSendEnv: true,
},
{
cmd: `echo "${GIT_ENV_VAR:-unset1} ${EXACT_MATCH:-unset2} ${TESTING:-unset3} ${NOT_ALLOWED:-unset4}"`,
want: []string{"unset1 unset2 unset3 unset4"},
forceV1Behavior: false,
allowSendEnv: false,
},
}
for _, test := range tests {
@ -151,7 +164,13 @@ func TestIntegrationSSH(t *testing.T) {
}
t.Run(fmt.Sprintf("%s_%s_%s", test.cmd, shellQualifier, versionQualifier), func(t *testing.T) {
s := testSession(t, test.forceV1Behavior)
sendEnv := map[string]string{
"GIT_ENV_VAR": "working1",
"EXACT_MATCH": "working2",
"TESTING": "working3",
"NOT_ALLOWED": "working4",
}
s := testSession(t, test.forceV1Behavior, test.allowSendEnv, sendEnv)
if shell {
err := s.RequestPty("xterm", 40, 80, ssh.TerminalModes{
@ -201,7 +220,7 @@ func TestIntegrationSFTP(t *testing.T) {
}
wantText := "hello world"
cl := testClient(t, forceV1Behavior)
cl := testClient(t, forceV1Behavior, false)
scl, err := sftp.NewClient(cl)
if err != nil {
t.Fatalf("can't get sftp client: %s", err)
@ -233,7 +252,7 @@ func TestIntegrationSFTP(t *testing.T) {
t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
}
s := testSessionFor(t, cl)
s := testSessionFor(t, cl, nil)
got := s.run(t, "ls -l "+filePath, false)
if !strings.Contains(got, "testuser") {
t.Fatalf("unexpected file owner user: %s", got)
@ -262,7 +281,7 @@ func TestIntegrationSCP(t *testing.T) {
}
wantText := "hello world"
cl := testClient(t, forceV1Behavior)
cl := testClient(t, forceV1Behavior, false)
scl, err := scp.NewClientBySSH(cl)
if err != nil {
t.Fatalf("can't get sftp client: %s", err)
@ -291,7 +310,7 @@ func TestIntegrationSCP(t *testing.T) {
t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
}
s := testSessionFor(t, cl)
s := testSessionFor(t, cl, nil)
got := s.run(t, "ls -l "+filePath, false)
if !strings.Contains(got, "testuser") {
t.Fatalf("unexpected file owner user: %s", got)
@ -349,7 +368,7 @@ func TestSSHAgentForwarding(t *testing.T) {
// Run tailscale SSH server and connect to it
username := "testuser"
tailscaleAddr := testServer(t, username, false)
tailscaleAddr := testServer(t, username, false, false)
tcl, err := ssh.Dial("tcp", tailscaleAddr, &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
@ -465,11 +484,11 @@ func (s *session) read() string {
return string(_got)
}
func testClient(t *testing.T, forceV1Behavior bool, authMethods ...ssh.AuthMethod) *ssh.Client {
func testClient(t *testing.T, forceV1Behavior bool, allowSendEnv bool, authMethods ...ssh.AuthMethod) *ssh.Client {
t.Helper()
username := "testuser"
addr := testServer(t, username, forceV1Behavior)
addr := testServer(t, username, forceV1Behavior, allowSendEnv)
cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
@ -483,9 +502,9 @@ func testClient(t *testing.T, forceV1Behavior bool, authMethods ...ssh.AuthMetho
return cl
}
func testServer(t *testing.T, username string, forceV1Behavior bool) string {
func testServer(t *testing.T, username string, forceV1Behavior bool, allowSendEnv bool) string {
srv := &server{
lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior},
lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior, allowSendEnv: allowSendEnv},
logf: log.Printf,
tailscaledPath: os.Getenv("TAILSCALED_PATH"),
timeNow: time.Now,
@ -509,16 +528,20 @@ func testServer(t *testing.T, username string, forceV1Behavior bool) string {
return l.Addr().String()
}
func testSession(t *testing.T, forceV1Behavior bool) *session {
cl := testClient(t, forceV1Behavior)
return testSessionFor(t, cl)
func testSession(t *testing.T, forceV1Behavior bool, allowSendEnv bool, sendEnv map[string]string) *session {
cl := testClient(t, forceV1Behavior, allowSendEnv)
return testSessionFor(t, cl, sendEnv)
}
func testSessionFor(t *testing.T, cl *ssh.Client) *session {
func testSessionFor(t *testing.T, cl *ssh.Client, sendEnv map[string]string) *session {
s, err := cl.NewSession()
if err != nil {
t.Fatal(err)
}
for k, v := range sendEnv {
s.Setenv(k, v)
}
t.Cleanup(func() { s.Close() })
stdinReader, stdinWriter := io.Pipe()
@ -564,6 +587,7 @@ func generateClientKey(t *testing.T, privateKeyFile string) (ssh.Signer, *rsa.Pr
type testBackend struct {
localUser string
forceV1Behavior bool
allowSendEnv bool
}
func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) {
@ -597,6 +621,9 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap {
if tb.forceV1Behavior {
capMap[tailcfg.NodeAttrSSHBehaviorV1] = struct{}{}
}
if tb.allowSendEnv {
capMap[tailcfg.NodeAttrSSHEnvironmentVariables] = struct{}{}
}
return &netmap.NetworkMap{
SSHPolicy: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
@ -604,6 +631,7 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap {
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true},
SSHUsers: map[string]string{"*": tb.localUser},
AcceptEnv: []string{"GIT_*", "EXACT_MATCH", "TEST?NG"},
},
},
},

View File

@ -24,6 +24,7 @@
"os/user"
"reflect"
"runtime"
"slices"
"strconv"
"strings"
"sync"
@ -56,11 +57,12 @@
func TestMatchRule(t *testing.T) {
someAction := new(tailcfg.SSHAction)
tests := []struct {
name string
rule *tailcfg.SSHRule
ci *sshConnInfo
wantErr error
wantUser string
name string
rule *tailcfg.SSHRule
ci *sshConnInfo
wantErr error
wantUser string
wantAcceptEnv []string
}{
{
name: "invalid-conn",
@ -153,6 +155,21 @@ func TestMatchRule(t *testing.T) {
ci: &sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
},
{
name: "ok-with-accept-env",
rule: &tailcfg.SSHRule{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"*": "ubuntu",
"alice": "thealice",
},
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
ci: &sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
{
name: "no-users-for-reject",
rule: &tailcfg.SSHRule{
@ -210,7 +227,7 @@ func TestMatchRule(t *testing.T) {
info: tt.ci,
srv: &server{logf: t.Logf},
}
got, gotUser, err := c.matchRule(tt.rule, nil)
got, gotUser, gotAcceptEnv, err := c.matchRule(tt.rule, nil)
if err != tt.wantErr {
t.Errorf("err = %v; want %v", err, tt.wantErr)
}
@ -220,6 +237,128 @@ func TestMatchRule(t *testing.T) {
if err == nil && got == nil {
t.Errorf("expected non-nil action on success")
}
if !slices.Equal(gotAcceptEnv, tt.wantAcceptEnv) {
t.Errorf("acceptEnv = %v; want %v", gotAcceptEnv, tt.wantAcceptEnv)
}
})
}
}
func TestEvalSSHPolicy(t *testing.T) {
someAction := new(tailcfg.SSHAction)
tests := []struct {
name string
policy *tailcfg.SSHPolicy
ci *sshConnInfo
wantMatch bool
wantUser string
wantAcceptEnv []string
}{
{
name: "multiple-matches-picks-first-match",
policy: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"other": "other1",
},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"*": "ubuntu",
"alice": "thealice",
},
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"other2": "other3",
},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"*": "ubuntu",
"alice": "thealice",
"mark": "markthe",
},
AcceptEnv: []string{"*"},
},
},
},
ci: &sshConnInfo{sshUser: "alice"},
wantUser: "thealice",
wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
wantMatch: true,
},
{
name: "no-matches-returns-failure",
policy: &tailcfg.SSHPolicy{
Rules: []*tailcfg.SSHRule{
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"other": "other1",
},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"fedora": "ubuntu",
},
AcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"other2": "other3",
},
},
{
Action: someAction,
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
SSHUsers: map[string]string{
"mark": "markthe",
},
AcceptEnv: []string{"*"},
},
},
},
ci: &sshConnInfo{sshUser: "alice"},
wantUser: "",
wantAcceptEnv: nil,
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &conn{
info: tt.ci,
srv: &server{logf: t.Logf},
}
got, gotUser, gotAcceptEnv, match := c.evalSSHPolicy(tt.policy, nil)
if match != tt.wantMatch {
t.Errorf("match = %v; want %v", match, tt.wantMatch)
}
if gotUser != tt.wantUser {
t.Errorf("user = %q; want %q", gotUser, tt.wantUser)
}
if tt.wantMatch == true && got == nil {
t.Errorf("expected non-nil action on success")
}
if !slices.Equal(gotAcceptEnv, tt.wantAcceptEnv) {
t.Errorf("acceptEnv = %v; want %v", gotAcceptEnv, tt.wantAcceptEnv)
}
})
}
}