From 091ea4a4a54a1a68cc86ea0d859b4660455a7361 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 22 Mar 2022 15:37:17 -0700 Subject: [PATCH] ssh/tailssh: support placeholders in SSHAction.HoldAndDelegate URL Updates #3802 Change-Id: I60f9827409d14fd4f4824d102ba11db49bf0d365 Signed-off-by: Brad Fitzpatrick --- ssh/tailssh/tailssh.go | 109 ++++++++++++++++++++++++------------ ssh/tailssh/tailssh_test.go | 3 +- 2 files changed, 74 insertions(+), 38 deletions(-) diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index a4859196e..df94d7dd7 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -214,39 +214,6 @@ func (srv *server) handleSSH(s ssh.Session) { return } - // Loop processing/fetching Actions until one reaches a - // terminal state (Accept, Reject, or invalid Action), or - // until fetchSSHAction times out due to the context being - // done (client disconnect) or its 30 minute timeout passes. - // (Which is a long time for somebody to see login - // instructions and go to a URL to do something.) -ProcessAction: - for { - if action.Message != "" { - io.WriteString(s.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1)) - } - if action.Reject { - logf("ssh: access denied for %q from %v", ci.uprof.LoginName, ci.src.IP()) - s.Exit(1) - return - } - if action.Accept { - break ProcessAction - } - url := action.HoldAndDelegate - if url == "" { - logf("ssh: access denied; SSHAction has neither Reject, Accept, or next step URL") - s.Exit(1) - return - } - action, err = srv.fetchSSHAction(s.Context(), url) - if err != nil { - logf("ssh: fetching SSAction from %s: %v", url, err) - s.Exit(1) - return - } - } - lu, err := user.Lookup(localUser) if err != nil { logf("ssh: user Lookup %q: %v", localUser, err) @@ -254,10 +221,71 @@ func (srv *server) handleSSH(s ssh.Session) { return } - ss := srv.newSSHSession(s, ci, lu, action) + ss := srv.newSSHSession(s, ci, lu) + action, err = ss.resolveTerminalAction(action) + if err != nil { + logf("ssh: resolveTerminalAction: %v", err) + io.WriteString(s.Stderr(), "Access denied: failed to resolve SSHAction.\n") + s.Exit(1) + return + } + if action.Reject || !action.Accept { + logf("ssh: access denied for %q from %v", ci.uprof.LoginName, ci.src.IP()) + s.Exit(1) + return + } + + ss.action = action ss.run() } +// resolveTerminalAction either returns action (if it's Accept or Reject) or else +// loops, fetching new SSHActions from the control plane. +// +// Any action with a Message in the chain will be printed to ss. +// +// The returned SSHAction will be either Reject or Accept. +func (ss *sshSession) resolveTerminalAction(action *tailcfg.SSHAction) (*tailcfg.SSHAction, error) { + // Loop processing/fetching Actions until one reaches a + // terminal state (Accept, Reject, or invalid Action), or + // until fetchSSHAction times out due to the context being + // done (client disconnect) or its 30 minute timeout passes. + // (Which is a long time for somebody to see login + // instructions and go to a URL to do something.) + for { + if action.Message != "" { + io.WriteString(ss.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1)) + } + if action.Accept || action.Reject { + return action, nil + } + url := action.HoldAndDelegate + if url == "" { + return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate") + } + url = ss.expandDelegateURL(url) + var err error + action, err = ss.srv.fetchSSHAction(ss.Context(), url) + if err != nil { + return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err) + } + } +} + +func (ss *sshSession) expandDelegateURL(url string) string { + nm := ss.srv.lb.NetMap() + var dstNodeID string + if nm != nil { + dstNodeID = fmt.Sprint(int64(nm.SelfNode.ID)) + } + return strings.NewReplacer( + "$SRC_NODE_ID", fmt.Sprint(int64(ss.connInfo.node.ID)), + "$DST_NODE_ID", dstNodeID, + "$SSH_USER", ss.connInfo.sshUser, + "$LOCAL_USER", ss.localUser.Username, + ).Replace(url) +} + // sshSession is an accepted Tailscale SSH session. type sshSession struct { ssh.Session @@ -284,7 +312,7 @@ type sshSession struct { exitOnce sync.Once } -func (srv *server) newSSHSession(s ssh.Session, ci *sshConnInfo, lu *user.User, action *tailcfg.SSHAction) *sshSession { +func (srv *server) newSSHSession(s ssh.Session, ci *sshConnInfo, lu *user.User) *sshSession { sharedID := fmt.Sprintf("%s-%02x", ci.now.UTC().Format("20060102T150405"), randBytes(5)) return &sshSession{ Session: s, @@ -292,7 +320,6 @@ func (srv *server) newSSHSession(s ssh.Session, ci *sshConnInfo, lu *user.User, sharedID: sharedID, ctx: newSSHContext(), srv: srv, - action: action, localUser: lu, connInfo: ci, logf: logger.WithPrefix(srv.logf, "ssh-session("+sharedID+"): "), @@ -317,12 +344,20 @@ func (srv *server) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSH continue } if res.StatusCode != 200 { + body, _ := io.ReadAll(res.Body) res.Body.Close() + if len(body) > 1<<10 { + body = body[:1<<10] + } + srv.logf("fetch of %v: %s, %s", url, res.Status, body) bo.BackOff(ctx, fmt.Errorf("unexpected status: %v", res.Status)) continue } a := new(tailcfg.SSHAction) - if err := json.NewDecoder(res.Body).Decode(a); err != nil { + err = json.NewDecoder(res.Body).Decode(a) + res.Body.Close() + if err != nil { + srv.logf("invalid next SSHAction JSON from %v: %v", url, err) bo.BackOff(ctx, err) continue } diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 9a442f7fd..57767d6ab 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -224,7 +224,8 @@ func TestSSH(t *testing.T) { } ss.Handler = func(s ssh.Session) { - ss := srv.newSSHSession(s, ci, u, &tailcfg.SSHAction{Accept: true}) + ss := srv.newSSHSession(s, ci, u) + ss.action = &tailcfg.SSHAction{Accept: true} ss.run() }