mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-14 06:57:31 +00:00
ssh/tailssh: replace incubator process with su instead of running su as child
This allows the SSH_AUTH_SOCK environment variable to work inside of su and agent forwarding to succeed. Fixes #12467 Signed-off-by: Percy Wegmann <percy@tailscale.com>
This commit is contained in:

committed by
Percy Wegmann

parent
24976b5bfd
commit
730f0368d0
@@ -8,14 +8,13 @@ package tailssh
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -24,6 +23,7 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -34,8 +34,10 @@ import (
|
||||
"github.com/pkg/sftp"
|
||||
gossh "github.com/tailscale/golang-x-crypto/ssh"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tailcfg"
|
||||
glider "tailscale.com/tempfork/gliderlabs/ssh"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/util/set"
|
||||
@@ -300,6 +302,95 @@ func TestIntegrationSCP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHAgentForwarding(t *testing.T) {
|
||||
debugTest.Store(true)
|
||||
t.Cleanup(func() {
|
||||
debugTest.Store(false)
|
||||
})
|
||||
|
||||
// Create a client SSH key
|
||||
tmpDir, err := os.MkdirTemp("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
})
|
||||
pkFile := filepath.Join(tmpDir, "pk")
|
||||
clientKey, clientKeyRSA := generateClientKey(t, pkFile)
|
||||
|
||||
// Start upstream SSH server
|
||||
l, err := net.Listen("tcp", "127.0.0.1:")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to listen for SSH: %s", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = l.Close()
|
||||
})
|
||||
|
||||
// Run an SSH server that accepts connections from that client SSH key.
|
||||
gs := glider.Server{
|
||||
Handler: func(s glider.Session) {
|
||||
io.WriteString(s, "Hello world\n")
|
||||
},
|
||||
PublicKeyHandler: func(ctx glider.Context, key glider.PublicKey) error {
|
||||
// Note - this is not meant to be cryptographically secure, it's
|
||||
// just checking that SSH agent forwarding is forwarding the right
|
||||
// key.
|
||||
a := key.Marshal()
|
||||
b := clientKey.PublicKey().Marshal()
|
||||
if !bytes.Equal(a, b) {
|
||||
return errors.New("key mismatch")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
go gs.Serve(l)
|
||||
|
||||
// Run tailscale SSH server and connect to it
|
||||
username := "testuser"
|
||||
tailscaleAddr := testServer(t, username, false) // TODO: make this false to use V2 behavior
|
||||
tcl, err := ssh.Dial("tcp", tailscaleAddr, &ssh.ClientConfig{
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { tcl.Close() })
|
||||
|
||||
s, err := tcl.NewSession()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Set up SSH agent forwarding on the client
|
||||
err = agent.RequestAgentForwarding(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
keyring := agent.NewKeyring()
|
||||
keyring.Add(agent.AddedKey{
|
||||
PrivateKey: clientKeyRSA,
|
||||
})
|
||||
err = agent.ForwardToAgent(tcl, keyring)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Attempt to SSH to the upstream test server using the forwarded SSH key
|
||||
// and run the "true" command.
|
||||
upstreamHost, upstreamPort, err := net.SplitHostPort(l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
o, err := s.CombinedOutput(fmt.Sprintf(`ssh -T -o StrictHostKeyChecking=no -p %s upstreamuser@%s "true"`, upstreamPort, upstreamHost))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to call true command: %s\n%s", err, o)
|
||||
}
|
||||
}
|
||||
|
||||
func fallbackToSUAvailable() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
return false
|
||||
@@ -374,10 +465,25 @@ readLoop:
|
||||
return string(_got)
|
||||
}
|
||||
|
||||
func testClient(t *testing.T, forceV1Behavior bool) *ssh.Client {
|
||||
func testClient(t *testing.T, forceV1Behavior bool, authMethods ...ssh.AuthMethod) *ssh.Client {
|
||||
t.Helper()
|
||||
|
||||
username := "testuser"
|
||||
addr := testServer(t, username, forceV1Behavior)
|
||||
|
||||
cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Auth: authMethods,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { cl.Close() })
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
func testServer(t *testing.T, username string, forceV1Behavior bool) string {
|
||||
srv := &server{
|
||||
lb: &testBackend{localUser: username, forceV1Behavior: forceV1Behavior},
|
||||
logf: log.Printf,
|
||||
@@ -392,21 +498,15 @@ func testClient(t *testing.T, forceV1Behavior bool) *ssh.Client {
|
||||
t.Cleanup(func() { l.Close() })
|
||||
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
if err == nil {
|
||||
go srv.HandleSSHConn(&addressFakingConn{conn})
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err == nil {
|
||||
go srv.HandleSSHConn(&addressFakingConn{conn})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
cl, err := ssh.Dial("tcp", l.Addr().String(), &ssh.ClientConfig{
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { cl.Close() })
|
||||
|
||||
return cl
|
||||
return l.Addr().String()
|
||||
}
|
||||
|
||||
func testSession(t *testing.T, forceV1Behavior bool) *session {
|
||||
@@ -417,7 +517,7 @@ func testSession(t *testing.T, forceV1Behavior bool) *session {
|
||||
func testSessionFor(t *testing.T, cl *ssh.Client) *session {
|
||||
s, err := cl.NewSession()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { s.Close() })
|
||||
|
||||
@@ -435,6 +535,31 @@ func testSessionFor(t *testing.T, cl *ssh.Client) *session {
|
||||
}
|
||||
}
|
||||
|
||||
func generateClientKey(t *testing.T, privateKeyFile string) (ssh.Signer, *rsa.PrivateKey) {
|
||||
t.Helper()
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mk, err := x509.MarshalPKCS8PrivateKey(priv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
privateKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk})
|
||||
if privateKey == nil {
|
||||
t.Fatal("failed to encoded private key")
|
||||
}
|
||||
err = os.WriteFile(privateKeyFile, privateKey, 0600)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return signer, priv
|
||||
}
|
||||
|
||||
// testBackend implements ipnLocalBackend
|
||||
type testBackend struct {
|
||||
localUser string
|
||||
@@ -443,33 +568,23 @@ type testBackend struct {
|
||||
|
||||
func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) {
|
||||
var result []gossh.Signer
|
||||
for _, typ := range []string{"ed25519", "ecdsa", "rsa"} {
|
||||
var priv any
|
||||
var err error
|
||||
switch typ {
|
||||
case "ed25519":
|
||||
_, priv, err = ed25519.GenerateKey(rand.Reader)
|
||||
case "ecdsa":
|
||||
curve := elliptic.P256()
|
||||
priv, err = ecdsa.GenerateKey(curve, rand.Reader)
|
||||
case "rsa":
|
||||
const keySize = 2048
|
||||
priv, err = rsa.GenerateKey(rand.Reader, keySize)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mk, err := x509.MarshalPKCS8PrivateKey(priv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hostKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk})
|
||||
signer, err := gossh.ParsePrivateKey(hostKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, signer)
|
||||
var priv any
|
||||
var err error
|
||||
const keySize = 2048
|
||||
priv, err = rsa.GenerateKey(rand.Reader, keySize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mk, err := x509.MarshalPKCS8PrivateKey(priv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hostKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk})
|
||||
signer, err := gossh.ParsePrivateKey(hostKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, signer)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -487,7 +602,7 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap {
|
||||
Rules: []*tailcfg.SSHRule{
|
||||
{
|
||||
Principals: []*tailcfg.SSHPrincipal{{Any: true}},
|
||||
Action: &tailcfg.SSHAction{Accept: true},
|
||||
Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true},
|
||||
SSHUsers: map[string]string{"*": tb.localUser},
|
||||
},
|
||||
},
|
||||
|
Reference in New Issue
Block a user