diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 91da77cf0..fb37b5922 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -41,7 +41,11 @@ func Handle(logf logger.Logf, lb *ipnlocal.LocalBackend, c net.Conn) error { if err != nil { return err } - srv := &server{lb, logf, tsd} + srv := &server{ + lb: lb, + logf: logf, + tailscaledPath: tsd, + } ss, err := srv.newSSHServer() if err != nil { return err @@ -55,7 +59,13 @@ func (srv *server) newSSHServer() (*ssh.Server, error) { Handler: srv.handleSSH, RequestHandlers: map[string]ssh.RequestHandler{}, SubsystemHandlers: map[string]ssh.SubsystemHandler{}, - ChannelHandlers: map[string]ssh.ChannelHandler{}, + // Note: the direct-tcpip channel handler and LocalPortForwardingCallback + // only adds support for forwarding ports from the local machine. + // TODO(maisem/bradfitz): add remote port forwarding support. + ChannelHandlers: map[string]ssh.ChannelHandler{ + "direct-tcpip": ssh.DirectTCPIPHandler, + }, + LocalPortForwardingCallback: srv.portForward, } for k, v := range ssh.DefaultRequestHandlers { ss.RequestHandlers[k] = v @@ -80,10 +90,21 @@ type server struct { lb *ipnlocal.LocalBackend logf logger.Logf tailscaledPath string + + // mu protects activeSessions. + mu sync.Mutex + activeSessions map[string]bool } var debugPolicyFile = envknob.String("TS_DEBUG_SSH_POLICY_FILE") +// portForward reports whether the ctx should be allowed to port forward +// to the specified host and port. +// TODO(bradfitz/maisem): should we have more checks on host/port? +func (srv *server) portForward(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { + return srv.isActiveSession(ctx) +} + // sshPolicy returns the SSHPolicy for current node. // If there is no SSHPolicy in the netmap, it returns a debugPolicy // if one is defined. @@ -231,6 +252,31 @@ func (srv *server) handleSessionTermination(ctx context.Context, s ssh.Session, }) } +// isActiveSession reports whether the ssh.Context corresponds +// to an active session. +func (srv *server) isActiveSession(sctx ssh.Context) bool { + srv.mu.Lock() + defer srv.mu.Unlock() + return srv.activeSessions[sctx.SessionID()] +} + +// startSession registers s as an active session. +func (srv *server) startSession(s ssh.Session) { + srv.mu.Lock() + defer srv.mu.Unlock() + if srv.activeSessions == nil { + srv.activeSessions = make(map[string]bool) + } + srv.activeSessions[s.Context().(ssh.Context).SessionID()] = true +} + +// endSession unregisters s from the list of active sessions. +func (srv *server) endSession(s ssh.Session) { + srv.mu.Lock() + defer srv.mu.Unlock() + delete(srv.activeSessions, s.Context().(ssh.Context).SessionID()) +} + // handleAcceptedSSH handles s once it's been accepted and determined // that it should run as local system user lu. // @@ -238,6 +284,8 @@ func (srv *server) handleSessionTermination(ctx context.Context, s ssh.Session, // is an SSHTerminationError, its SSHTerminationMessage is sent to the // user. func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *sshConnInfo, lu *user.User) { + srv.startSession(s) + defer srv.endSession(s) logf := srv.logf localUser := lu.Username diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index ad86fc57d..4c569c9b9 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -189,7 +189,10 @@ func TestSSH(t *testing.T) { dir := t.TempDir() lb.SetVarRoot(dir) - srv := &server{lb, logf, ""} + srv := &server{ + lb: lb, + logf: logf, + } ss, err := srv.newSSHServer() if err != nil { t.Fatal(err)