ssh/tailssh: always use current time for policy evaluation

Whenever the SSH policy changes we revaluate all open connections to
make sure they still have access. This check was using the wrong
timestamp and would match against expired policies, however this really
isn't a problem today as we don't have policy that would be impacted by
this check. Fixing it for future use.

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2022-06-27 12:13:50 -07:00 committed by Maisem Ali
parent a7d2024e35
commit c434e47f2d
2 changed files with 6 additions and 10 deletions

View File

@ -69,7 +69,7 @@ type server struct {
} }
func (srv *server) now() time.Time { func (srv *server) now() time.Time {
if srv.timeNow != nil { if srv != nil && srv.timeNow != nil {
return srv.timeNow() return srv.timeNow()
} }
return time.Now() return time.Now()
@ -152,10 +152,6 @@ type conn struct {
insecureSkipTailscaleAuth bool // used by tests. insecureSkipTailscaleAuth bool // used by tests.
// now is the time to consider the present moment for the
// purposes of rule evaluation.
now time.Time
connID string // ID that's shared with control connID string // ID that's shared with control
action0 *tailcfg.SSHAction // first matching action action0 *tailcfg.SSHAction // first matching action
srv *server srv *server
@ -278,8 +274,9 @@ func (srv *server) newConn() (*conn, error) {
return nil, gossh.ErrDenied return nil, gossh.ErrDenied
} }
srv.mu.Unlock() srv.mu.Unlock()
c := &conn{srv: srv, now: srv.now()} c := &conn{srv: srv}
c.connID = fmt.Sprintf("conn-%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5)) now := srv.now()
c.connID = fmt.Sprintf("conn-%s-%02x", now.UTC().Format("20060102T150405"), randBytes(5))
c.Server = &ssh.Server{ c.Server = &ssh.Server{
Version: "Tailscale", Version: "Tailscale",
Handler: c.handleSessionPostSSHAuth, Handler: c.handleSessionPostSSHAuth,
@ -751,7 +748,7 @@ func (ss *sshSession) vlogf(format string, args ...interface{}) {
} }
func (c *conn) newSSHSession(s ssh.Session) *sshSession { func (c *conn) newSSHSession(s ssh.Session) *sshSession {
sharedID := fmt.Sprintf("sess-%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5)) sharedID := fmt.Sprintf("sess-%s-%02x", c.srv.now().UTC().Format("20060102T150405"), randBytes(5))
c.logf("starting session: %v", sharedID) c.logf("starting session: %v", sharedID)
return &sshSession{ return &sshSession{
Session: s, Session: s,
@ -1087,7 +1084,7 @@ func (c *conn) ruleExpired(r *tailcfg.SSHRule) bool {
if r.RuleExpires == nil { if r.RuleExpires == nil {
return false return false
} }
return r.RuleExpires.Before(c.now) return r.RuleExpires.Before(c.srv.now())
} }
func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, ok bool) { func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, ok bool) {

View File

@ -179,7 +179,6 @@ func TestMatchRule(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &conn{ c := &conn{
now: time.Unix(200, 0),
info: tt.ci, info: tt.ci,
} }
got, gotUser, err := c.matchRule(tt.rule, nil) got, gotUser, err := c.matchRule(tt.rule, nil)