mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 13:05:46 +00:00
ssh/tailssh: break a method into half in prep for testing
And add a private context type in the process. Updates #3802 Change-Id: I257187f4cfb0f2248d95b81c1dfe0911ef203b60 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
1b5bb2e81d
commit
e2ed06c53c
61
ssh/tailssh/context.go
Normal file
61
ssh/tailssh/context.go
Normal file
@ -0,0 +1,61 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tailssh
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// sshContext is the context.Context implementation we use for SSH
|
||||
// that adds a CloseWithError method. Otherwise it's just a normalish
|
||||
// Context.
|
||||
type sshContext struct {
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
done chan struct{}
|
||||
err error
|
||||
}
|
||||
|
||||
func newSSHContext() *sshContext {
|
||||
return &sshContext{done: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (ctx *sshContext) CloseWithError(err error) {
|
||||
ctx.mu.Lock()
|
||||
defer ctx.mu.Unlock()
|
||||
if ctx.closed {
|
||||
return
|
||||
}
|
||||
ctx.closed = true
|
||||
ctx.err = err
|
||||
close(ctx.done)
|
||||
}
|
||||
|
||||
func (ctx *sshContext) Err() error {
|
||||
ctx.mu.Lock()
|
||||
defer ctx.mu.Unlock()
|
||||
return ctx.err
|
||||
}
|
||||
|
||||
func (ctx *sshContext) Done() <-chan struct{} { return ctx.done }
|
||||
func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return }
|
||||
func (ctx *sshContext) Value(interface{}) interface{} { return nil }
|
||||
|
||||
// userVisibleError is a wrapper around an error that implements
|
||||
// SSHTerminationError, so msg is written to their session.
|
||||
type userVisibleError struct {
|
||||
msg string
|
||||
error
|
||||
}
|
||||
|
||||
func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg }
|
||||
|
||||
// SSHTerminationError is implemented by errors that terminate an SSH
|
||||
// session and should be written to user's sessions.
|
||||
type SSHTerminationError interface {
|
||||
error
|
||||
SSHTerminationMessage() string
|
||||
}
|
@ -9,6 +9,7 @@
|
||||
package tailssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -102,7 +103,6 @@ func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) {
|
||||
func (srv *server) handleSSH(s ssh.Session) {
|
||||
lb := srv.lb
|
||||
logf := srv.logf
|
||||
|
||||
sshUser := s.User()
|
||||
addr := s.RemoteAddr()
|
||||
logf("Handling SSH from %v for user %v", addr, sshUser)
|
||||
@ -131,7 +131,6 @@ func (srv *server) handleSSH(s ssh.Session) {
|
||||
return
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := s.Pty()
|
||||
srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
|
||||
node, uprof, ok := lb.WhoIs(srcIPP)
|
||||
if !ok {
|
||||
@ -167,7 +166,34 @@ func (srv *server) handleSSH(s ssh.Session) {
|
||||
s.Exit(1)
|
||||
}
|
||||
|
||||
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", srcIP, uprof.LoginName, sshUser, localUser, s.Command(), s.Environ())
|
||||
var ctx context.Context = context.Background()
|
||||
if action.SesssionDuration != 0 {
|
||||
sctx := newSSHContext()
|
||||
ctx = sctx
|
||||
t := time.AfterFunc(action.SesssionDuration, func() {
|
||||
sctx.CloseWithError(userVisibleError{
|
||||
fmt.Sprintf("Session timeout of %v elapsed.", action.SesssionDuration),
|
||||
context.DeadlineExceeded,
|
||||
})
|
||||
})
|
||||
defer t.Stop()
|
||||
}
|
||||
srv.handleAcceptedSSH(ctx, s, ci, lu)
|
||||
}
|
||||
|
||||
// handleAcceptedSSH handles s once it's been accepted and determined
|
||||
// that it should run as local system user lu.
|
||||
//
|
||||
// When ctx is done, the session is forcefully terminated. If its Err
|
||||
// is an SSHTerminationError, its SSHTerminationMessage is sent to the
|
||||
// user.
|
||||
func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *sshConnInfo, lu *user.User) {
|
||||
logf := srv.logf
|
||||
localUser := lu.Username
|
||||
|
||||
var err error
|
||||
ptyReq, winCh, isPty := s.Pty()
|
||||
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", ci.srcIP, ci.uprof.LoginName, ci.sshUser, localUser, s.Command(), s.Environ())
|
||||
var cmd *exec.Cmd
|
||||
if euid := os.Geteuid(); euid != 0 {
|
||||
if lu.Uid != fmt.Sprint(euid) {
|
||||
@ -223,12 +249,24 @@ func (srv *server) handleSSH(s ssh.Session) {
|
||||
go func() { io.Copy(s.Stderr(), stderr) }()
|
||||
}
|
||||
|
||||
if action.SesssionDuration != 0 {
|
||||
t := time.AfterFunc(action.SesssionDuration, func() {
|
||||
logf("terminating SSH session from %v after max duration", srcIP)
|
||||
cmd.Process.Kill()
|
||||
})
|
||||
defer t.Stop()
|
||||
if ctx.Done() != nil {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
err := ctx.Err()
|
||||
if serr, ok := err.(SSHTerminationError); ok {
|
||||
msg := serr.SSHTerminationMessage()
|
||||
if msg != "" {
|
||||
io.WriteString(s.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
|
||||
}
|
||||
}
|
||||
logf("terminating SSH session from %v: %v", ci.srcIP, err)
|
||||
cmd.Process.Kill()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
Loading…
Reference in New Issue
Block a user