ssh/tailssh: replace incubator process with su instead of running su as child

This allows the SSH_AUTH_SOCK environment variable to work inside of
su and agent forwarding to succeed.

Fixes #12467

Signed-off-by: Percy Wegmann <percy@tailscale.com>
This commit is contained in:
Percy Wegmann
2024-06-14 13:28:39 -05:00
committed by Percy Wegmann
parent 24976b5bfd
commit 730f0368d0
5 changed files with 202 additions and 51 deletions

View File

@@ -8,14 +8,13 @@ package tailssh
import (
"bufio"
"bytes"
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"log"
@@ -24,6 +23,7 @@ import (
"net/netip"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"testing"
@@ -34,8 +34,10 @@ import (
"github.com/pkg/sftp"
gossh "github.com/tailscale/golang-x-crypto/ssh"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
glider "tailscale.com/tempfork/gliderlabs/ssh"
"tailscale.com/types/key"
"tailscale.com/types/netmap"
"tailscale.com/util/set"
@@ -300,6 +302,95 @@ func TestIntegrationSCP(t *testing.T) {
}
}
func TestSSHAgentForwarding(t *testing.T) {
debugTest.Store(true)
t.Cleanup(func() {
debugTest.Store(false)
})
// Create a client SSH key
tmpDir, err := os.MkdirTemp("", "")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
_ = os.RemoveAll(tmpDir)
})
pkFile := filepath.Join(tmpDir, "pk")
clientKey, clientKeyRSA := generateClientKey(t, pkFile)
// Start upstream SSH server
l, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
t.Fatalf("unable to listen for SSH: %s", err)
}
t.Cleanup(func() {
_ = l.Close()
})
// Run an SSH server that accepts connections from that client SSH key.
gs := glider.Server{
Handler: func(s glider.Session) {
io.WriteString(s, "Hello world\n")
},
PublicKeyHandler: func(ctx glider.Context, key glider.PublicKey) error {
// Note - this is not meant to be cryptographically secure, it's
// just checking that SSH agent forwarding is forwarding the right
// key.
a := key.Marshal()
b := clientKey.PublicKey().Marshal()
if !bytes.Equal(a, b) {
return errors.New("key mismatch")
}
return nil
},
}
go gs.Serve(l)
// Run tailscale SSH server and connect to it
username := "testuser"
tailscaleAddr := testServer(t, username, false) // TODO: make this false to use V2 behavior
tcl, err := ssh.Dial("tcp", tailscaleAddr, &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { tcl.Close() })
s, err := tcl.NewSession()
if err != nil {
t.Fatal(err)
}
// Set up SSH agent forwarding on the client
err = agent.RequestAgentForwarding(s)
if err != nil {
t.Fatal(err)
}
keyring := agent.NewKeyring()
keyring.Add(agent.AddedKey{
PrivateKey: clientKeyRSA,
})
err = agent.ForwardToAgent(tcl, keyring)
if err != nil {
t.Fatal(err)
}
// Attempt to SSH to the upstream test server using the forwarded SSH key
// and run the "true" command.
upstreamHost, upstreamPort, err := net.SplitHostPort(l.Addr().String())
if err != nil {
t.Fatal(err)
}
o, err := s.CombinedOutput(fmt.Sprintf(`ssh -T -o StrictHostKeyChecking=no -p %s upstreamuser@%s "true"`, upstreamPort, upstreamHost))
if err != nil {
t.Fatalf("unable to call true command: %s\n%s", err, o)
}
}
func fallbackToSUAvailable() bool {
if runtime.GOOS != "linux" {
return false
@@ -374,10 +465,25 @@ readLoop:
return string(_got)
}
func testClient(t *testing.T, forceV1Behavior bool) *ssh.Client {
func testClient(t *testing.T, forceV1Behavior bool, authMethods ...ssh.AuthMethod) *ssh.Client {
t.Helper()
username := "testuser"
addr := testServer(t, username, forceV1Behavior)
cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Auth: authMethods,
})
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { cl.Close() })
return cl
}
func testServer(t *testing.T, username string, forceV1Behavior bool) string {
srv := &server{
lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior},
logf: log.Printf,
@@ -392,21 +498,15 @@ func testClient(t *testing.T, forceV1Behavior bool) *ssh.Client {
t.Cleanup(func() { l.Close() })
go func() {
conn, err := l.Accept()
if err == nil {
go srv.HandleSSHConn(&addressFakingConn{conn})
for {
conn, err := l.Accept()
if err == nil {
go srv.HandleSSHConn(&addressFakingConn{conn})
}
}
}()
cl, err := ssh.Dial("tcp", l.Addr().String(), &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
if err != nil {
log.Fatal(err)
}
t.Cleanup(func() { cl.Close() })
return cl
return l.Addr().String()
}
func testSession(t *testing.T, forceV1Behavior bool) *session {
@@ -417,7 +517,7 @@ func testSession(t *testing.T, forceV1Behavior bool) *session {
func testSessionFor(t *testing.T, cl *ssh.Client) *session {
s, err := cl.NewSession()
if err != nil {
log.Fatal(err)
t.Fatal(err)
}
t.Cleanup(func() { s.Close() })
@@ -435,6 +535,31 @@ func testSessionFor(t *testing.T, cl *ssh.Client) *session {
}
}
func generateClientKey(t *testing.T, privateKeyFile string) (ssh.Signer, *rsa.PrivateKey) {
t.Helper()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
mk, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
t.Fatal(err)
}
privateKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk})
if privateKey == nil {
t.Fatal("failed to encoded private key")
}
err = os.WriteFile(privateKeyFile, privateKey, 0600)
if err != nil {
t.Fatal(err)
}
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
t.Fatal(err)
}
return signer, priv
}
// testBackend implements ipnLocalBackend
type testBackend struct {
localUser string
@@ -443,33 +568,23 @@ type testBackend struct {
func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) {
var result []gossh.Signer
for _, typ := range []string{"ed25519", "ecdsa", "rsa"} {
var priv any
var err error
switch typ {
case "ed25519":
_, priv, err = ed25519.GenerateKey(rand.Reader)
case "ecdsa":
curve := elliptic.P256()
priv, err = ecdsa.GenerateKey(curve, rand.Reader)
case "rsa":
const keySize = 2048
priv, err = rsa.GenerateKey(rand.Reader, keySize)
}
if err != nil {
return nil, err
}
mk, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return nil, err
}
hostKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk})
signer, err := gossh.ParsePrivateKey(hostKey)
if err != nil {
return nil, err
}
result = append(result, signer)
var priv any
var err error
const keySize = 2048
priv, err = rsa.GenerateKey(rand.Reader, keySize)
if err != nil {
return nil, err
}
mk, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return nil, err
}
hostKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk})
signer, err := gossh.ParsePrivateKey(hostKey)
if err != nil {
return nil, err
}
result = append(result, signer)
return result, nil
}
@@ -487,7 +602,7 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap {
Rules: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
Action: &tailcfg.SSHAction{Accept: true},
Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true},
SSHUsers: map[string]string{"*": tb.localUser},
},
},