mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 13:05:46 +00:00
ssh/tailssh: break a method into half in prep for testing
And add a private context type in the process. Updates #3802 Change-Id: I257187f4cfb0f2248d95b81c1dfe0911ef203b60 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
1b5bb2e81d
commit
e2ed06c53c
61
ssh/tailssh/context.go
Normal file
61
ssh/tailssh/context.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package tailssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// sshContext is the context.Context implementation we use for SSH
|
||||||
|
// that adds a CloseWithError method. Otherwise it's just a normalish
|
||||||
|
// Context.
|
||||||
|
type sshContext struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
done chan struct{}
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSSHContext() *sshContext {
|
||||||
|
return &sshContext{done: make(chan struct{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *sshContext) CloseWithError(err error) {
|
||||||
|
ctx.mu.Lock()
|
||||||
|
defer ctx.mu.Unlock()
|
||||||
|
if ctx.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx.closed = true
|
||||||
|
ctx.err = err
|
||||||
|
close(ctx.done)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *sshContext) Err() error {
|
||||||
|
ctx.mu.Lock()
|
||||||
|
defer ctx.mu.Unlock()
|
||||||
|
return ctx.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *sshContext) Done() <-chan struct{} { return ctx.done }
|
||||||
|
func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return }
|
||||||
|
func (ctx *sshContext) Value(interface{}) interface{} { return nil }
|
||||||
|
|
||||||
|
// userVisibleError is a wrapper around an error that implements
|
||||||
|
// SSHTerminationError, so msg is written to their session.
|
||||||
|
type userVisibleError struct {
|
||||||
|
msg string
|
||||||
|
error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg }
|
||||||
|
|
||||||
|
// SSHTerminationError is implemented by errors that terminate an SSH
|
||||||
|
// session and should be written to user's sessions.
|
||||||
|
type SSHTerminationError interface {
|
||||||
|
error
|
||||||
|
SSHTerminationMessage() string
|
||||||
|
}
|
@ -9,6 +9,7 @@
|
|||||||
package tailssh
|
package tailssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -102,7 +103,6 @@ func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) {
|
|||||||
func (srv *server) handleSSH(s ssh.Session) {
|
func (srv *server) handleSSH(s ssh.Session) {
|
||||||
lb := srv.lb
|
lb := srv.lb
|
||||||
logf := srv.logf
|
logf := srv.logf
|
||||||
|
|
||||||
sshUser := s.User()
|
sshUser := s.User()
|
||||||
addr := s.RemoteAddr()
|
addr := s.RemoteAddr()
|
||||||
logf("Handling SSH from %v for user %v", addr, sshUser)
|
logf("Handling SSH from %v for user %v", addr, sshUser)
|
||||||
@ -131,7 +131,6 @@ func (srv *server) handleSSH(s ssh.Session) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ptyReq, winCh, isPty := s.Pty()
|
|
||||||
srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
|
srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
|
||||||
node, uprof, ok := lb.WhoIs(srcIPP)
|
node, uprof, ok := lb.WhoIs(srcIPP)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -167,7 +166,34 @@ func (srv *server) handleSSH(s ssh.Session) {
|
|||||||
s.Exit(1)
|
s.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", srcIP, uprof.LoginName, sshUser, localUser, s.Command(), s.Environ())
|
var ctx context.Context = context.Background()
|
||||||
|
if action.SesssionDuration != 0 {
|
||||||
|
sctx := newSSHContext()
|
||||||
|
ctx = sctx
|
||||||
|
t := time.AfterFunc(action.SesssionDuration, func() {
|
||||||
|
sctx.CloseWithError(userVisibleError{
|
||||||
|
fmt.Sprintf("Session timeout of %v elapsed.", action.SesssionDuration),
|
||||||
|
context.DeadlineExceeded,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
defer t.Stop()
|
||||||
|
}
|
||||||
|
srv.handleAcceptedSSH(ctx, s, ci, lu)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAcceptedSSH handles s once it's been accepted and determined
|
||||||
|
// that it should run as local system user lu.
|
||||||
|
//
|
||||||
|
// When ctx is done, the session is forcefully terminated. If its Err
|
||||||
|
// is an SSHTerminationError, its SSHTerminationMessage is sent to the
|
||||||
|
// user.
|
||||||
|
func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *sshConnInfo, lu *user.User) {
|
||||||
|
logf := srv.logf
|
||||||
|
localUser := lu.Username
|
||||||
|
|
||||||
|
var err error
|
||||||
|
ptyReq, winCh, isPty := s.Pty()
|
||||||
|
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", ci.srcIP, ci.uprof.LoginName, ci.sshUser, localUser, s.Command(), s.Environ())
|
||||||
var cmd *exec.Cmd
|
var cmd *exec.Cmd
|
||||||
if euid := os.Geteuid(); euid != 0 {
|
if euid := os.Geteuid(); euid != 0 {
|
||||||
if lu.Uid != fmt.Sprint(euid) {
|
if lu.Uid != fmt.Sprint(euid) {
|
||||||
@ -223,12 +249,24 @@ func (srv *server) handleSSH(s ssh.Session) {
|
|||||||
go func() { io.Copy(s.Stderr(), stderr) }()
|
go func() { io.Copy(s.Stderr(), stderr) }()
|
||||||
}
|
}
|
||||||
|
|
||||||
if action.SesssionDuration != 0 {
|
if ctx.Done() != nil {
|
||||||
t := time.AfterFunc(action.SesssionDuration, func() {
|
done := make(chan struct{})
|
||||||
logf("terminating SSH session from %v after max duration", srcIP)
|
defer close(done)
|
||||||
cmd.Process.Kill()
|
go func() {
|
||||||
})
|
select {
|
||||||
defer t.Stop()
|
case <-done:
|
||||||
|
case <-ctx.Done():
|
||||||
|
err := ctx.Err()
|
||||||
|
if serr, ok := err.(SSHTerminationError); ok {
|
||||||
|
msg := serr.SSHTerminationMessage()
|
||||||
|
if msg != "" {
|
||||||
|
io.WriteString(s.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logf("terminating SSH session from %v: %v", ci.srcIP, err)
|
||||||
|
cmd.Process.Kill()
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
Loading…
Reference in New Issue
Block a user