diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 061987305..943e5297f 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -141,14 +141,14 @@ func (srv *server) handleSSH(s ssh.Session) { } srcIP := srcIPP.IP() - sctx := &sshContext{ + ci := &sshConnInfo{ now: time.Now(), sshUser: sshUser, srcIP: srcIP, node: node, uprof: &uprof, } - action, localUser, ok := evalSSHPolicy(pol, sctx) + action, localUser, ok := evalSSHPolicy(pol, ci) if ok && action.Message != "" { io.WriteString(s.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1)) } @@ -264,7 +264,7 @@ func setWinsize(f *os.File, w, h int) { uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0}))) } -type sshContext struct { +type sshConnInfo struct { // now is the time to consider the present moment for the // purposes of rule evaluation. now time.Time @@ -282,9 +282,9 @@ type sshContext struct { uprof *tailcfg.UserProfile } -func evalSSHPolicy(pol *tailcfg.SSHPolicy, sctx *sshContext) (a *tailcfg.SSHAction, localUser string, ok bool) { +func evalSSHPolicy(pol *tailcfg.SSHPolicy, ci *sshConnInfo) (a *tailcfg.SSHAction, localUser string, ok bool) { for _, r := range pol.Rules { - if a, localUser, err := matchRule(r, sctx); err == nil { + if a, localUser, err := matchRule(r, ci); err == nil { return a, localUser, true } } @@ -300,21 +300,21 @@ var ( errUserMatch = errors.New("user didn't match") ) -func matchRule(r *tailcfg.SSHRule, sctx *sshContext) (a *tailcfg.SSHAction, localUser string, err error) { +func matchRule(r *tailcfg.SSHRule, ci *sshConnInfo) (a *tailcfg.SSHAction, localUser string, err error) { if r == nil { return nil, "", errNilRule } if r.Action == nil { return nil, "", errNilAction } - if r.RuleExpires != nil && sctx.now.After(*r.RuleExpires) { + if r.RuleExpires != nil && ci.now.After(*r.RuleExpires) { return nil, "", errRuleExpired } - if !matchesPrincipal(r.Principals, sctx) { + if !matchesPrincipal(r.Principals, ci) { return nil, "", errPrincipalMatch } if !r.Action.Reject || r.SSHUsers != nil { - localUser = mapLocalUser(r.SSHUsers, sctx.sshUser) + localUser = mapLocalUser(r.SSHUsers, ci.sshUser) if localUser == "" { return nil, "", errUserMatch } @@ -329,7 +329,7 @@ func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser return ruleSSHUsers["*"] } -func matchesPrincipal(ps []*tailcfg.SSHPrincipal, sctx *sshContext) bool { +func matchesPrincipal(ps []*tailcfg.SSHPrincipal, ci *sshConnInfo) bool { for _, p := range ps { if p == nil { continue @@ -337,15 +337,15 @@ func matchesPrincipal(ps []*tailcfg.SSHPrincipal, sctx *sshContext) bool { if p.Any { return true } - if !p.Node.IsZero() && sctx.node != nil && p.Node == sctx.node.StableID { + if !p.Node.IsZero() && ci.node != nil && p.Node == ci.node.StableID { return true } if p.NodeIP != "" { - if ip, _ := netaddr.ParseIP(p.NodeIP); ip == sctx.srcIP { + if ip, _ := netaddr.ParseIP(p.NodeIP); ip == ci.srcIP { return true } } - if p.UserLogin != "" && sctx.uprof != nil && sctx.uprof.LoginName == p.UserLogin { + if p.UserLogin != "" && ci.uprof != nil && ci.uprof.LoginName == p.UserLogin { return true } } diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 8becf2e2c..eeb71bc34 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -20,7 +20,7 @@ func TestMatchRule(t *testing.T) { tests := []struct { name string rule *tailcfg.SSHRule - ctx *sshContext + ci *sshConnInfo wantErr error wantUser string }{ @@ -40,7 +40,7 @@ func TestMatchRule(t *testing.T) { Action: someAction, RuleExpires: timePtr(time.Unix(100, 0)), }, - ctx: &sshContext{now: time.Unix(200, 0)}, + ci: &sshConnInfo{now: time.Unix(200, 0)}, wantErr: errRuleExpired, }, { @@ -56,7 +56,7 @@ func TestMatchRule(t *testing.T) { Action: someAction, Principals: []*tailcfg.SSHPrincipal{{Any: true}}, }, - ctx: &sshContext{sshUser: "alice"}, + ci: &sshConnInfo{sshUser: "alice"}, wantErr: errUserMatch, }, { @@ -68,7 +68,7 @@ func TestMatchRule(t *testing.T) { "*": "ubuntu", }, }, - ctx: &sshContext{sshUser: "alice"}, + ci: &sshConnInfo{sshUser: "alice"}, wantUser: "ubuntu", }, { @@ -83,7 +83,7 @@ func TestMatchRule(t *testing.T) { "*": "ubuntu", }, }, - ctx: &sshContext{sshUser: "alice"}, + ci: &sshConnInfo{sshUser: "alice"}, wantUser: "ubuntu", }, { @@ -96,7 +96,7 @@ func TestMatchRule(t *testing.T) { "alice": "thealice", }, }, - ctx: &sshContext{sshUser: "alice"}, + ci: &sshConnInfo{sshUser: "alice"}, wantUser: "thealice", }, { @@ -105,7 +105,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{Any: true}}, Action: &tailcfg.SSHAction{Reject: true}, }, - ctx: &sshContext{sshUser: "alice"}, + ci: &sshConnInfo{sshUser: "alice"}, }, { name: "match-principal-node-ip", @@ -114,7 +114,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{NodeIP: "1.2.3.4"}}, SSHUsers: map[string]string{"*": "ubuntu"}, }, - ctx: &sshContext{srcIP: netaddr.MustParseIP("1.2.3.4")}, + ci: &sshConnInfo{srcIP: netaddr.MustParseIP("1.2.3.4")}, wantUser: "ubuntu", }, { @@ -124,7 +124,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{Node: "some-node-ID"}}, SSHUsers: map[string]string{"*": "ubuntu"}, }, - ctx: &sshContext{node: &tailcfg.Node{StableID: "some-node-ID"}}, + ci: &sshConnInfo{node: &tailcfg.Node{StableID: "some-node-ID"}}, wantUser: "ubuntu", }, { @@ -134,13 +134,13 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{UserLogin: "foo@bar.com"}}, SSHUsers: map[string]string{"*": "ubuntu"}, }, - ctx: &sshContext{uprof: &tailcfg.UserProfile{LoginName: "foo@bar.com"}}, + ci: &sshConnInfo{uprof: &tailcfg.UserProfile{LoginName: "foo@bar.com"}}, wantUser: "ubuntu", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, gotUser, err := matchRule(tt.rule, tt.ctx) + got, gotUser, err := matchRule(tt.rule, tt.ci) if err != tt.wantErr { t.Errorf("err = %v; want %v", err, tt.wantErr) }