tailscale/tempfork/sshtest/ssh/transport.go
Brad Fitzpatrick 2f98197857 tempfork/sshtest/ssh: add fork of golang.org/x/crypto/ssh for testing only
This forks golang.org/x/crypto/ssh (at upstream x/crypto git rev e47973b1c1)
into tailscale.com/tempfork/sshtest/ssh so we can hack up the client in weird
ways to simulate other SSH clients seen in the wild.

Two changes were made to the files when they were copied from x/crypto:

* internal/poly1305 imports were replaced by the non-internal version;
  no code changes otherwise. It didn't need the internal one.
* all decode-with-passphrase funcs were deleted, to avoid
  using the internal package x/crypto/ssh/internal/bcrypt_pbkdf

Then the tests passed.

Updates #14969

Change-Id: Ibf1abebfe608c75fef4da0255314f65e54ce5077
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2025-02-11 07:45:06 -08:00

381 lines
9.8 KiB
Go

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"bufio"
"bytes"
"errors"
"io"
"log"
)
// debugTransport, if set, makes the transport print the type byte of
// each packet as it goes over the wire. No message decoding is done,
// to minimize the impact on timing.
const debugTransport = false
// Names of individual SSH cipher algorithms.
// NOTE(review): presumably these are referenced by the cipher tables
// defined elsewhere in the package — confirm.
const (
gcm128CipherID = "aes128-gcm@openssh.com"
gcm256CipherID = "aes256-gcm@openssh.com"
aes128cbcID = "aes128-cbc"
tripledescbcID = "3des-cbc"
)
// packetConn represents a transport that implements packet based
// operations.
type packetConn interface {
// writePacket encrypts packet and sends it to the remote peer.
writePacket(packet []byte) error
// readPacket reads a packet from the connection. The read is
// blocking; if the returned error is nil, the returned byte slice
// is always non-empty.
readPacket() ([]byte, error)
// Close closes the write-side of the connection.
Close() error
}
// transport is the keyingTransport that implements the SSH packet
// protocol.
type transport struct {
// reader and writer hold the cipher, sequence number and key
// direction for the incoming and outgoing half of the connection.
reader connectionState
writer connectionState
// bufReader and bufWriter are buffered wrappers around the
// underlying connection; packets are read from and written to them.
bufReader *bufio.Reader
bufWriter *bufio.Writer
// rand is handed to the write-side cipher on every writePacket.
rand io.Reader
// isClient records which side of the connection this transport is.
isClient bool
// Closer closes the underlying connection.
io.Closer
// strictMode is turned on by setStrictMode. While set, sequence
// numbers are reset to zero whenever msgNewKeys switches ciphers.
strictMode bool
// initialKEXDone is set by setInitialKEXDone. In strict mode,
// readPacket hands msgIgnore and msgDebug packets to the caller
// until this is set; otherwise it silently skips them.
initialKEXDone bool
}
// packetCipher represents a combination of SSH encryption/MAC
// protocol. A single instance should be used for one direction only.
type packetCipher interface {
// writeCipherPacket encrypts the packet and writes it to w. seqnum
// is the sequence number of the packet and rand is the random
// source supplied by the transport. The contents of the packet are
// generally scrambled by the call.
writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error
// readCipherPacket reads and decrypts one packet of data from r.
// The returned packet may be overwritten by future calls, so
// callers that need to keep it must copy it.
readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error)
}
// connectionState represents one side (read or write) of the
// connection. This is necessary because each direction has its own
// keys, and can even have its own algorithms.
type connectionState struct {
// packetCipher is the cipher currently in effect for this direction.
packetCipher
// seqNum is the sequence number passed to the cipher for the next
// packet; it advances as packets are read or written.
seqNum uint32
// dir holds the key-derivation tags for this direction.
dir direction
// pendingKeyChange carries the cipher prepared by prepareKeyChange
// until a msgNewKeys packet in this direction puts it into use.
pendingKeyChange chan packetCipher
}
// setStrictMode switches the transport into strict KEX mode. It returns
// an error, and leaves the mode unchanged, unless the read-side
// sequence number is exactly 1 at the time of the call.
func (t *transport) setStrictMode() error {
	if seen := t.reader.seqNum; seen != 1 {
		return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
	}
	t.strictMode = true
	return nil
}
// setInitialKEXDone marks the initial key exchange as done. After
// this, readPacket no longer passes msgIgnore and msgDebug packets
// through to the caller in strict mode.
func (t *transport) setInitialKEXDone() {
t.initialKEXDone = true
}
// prepareKeyChange sets up key material for a keychange. The key changes in
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
//
// algs supplies the negotiated algorithms for the read (algs.r) and write
// (algs.w) directions, and kexResult the output of the key exchange that the
// keys are derived from. An error is returned if either cipher cannot be
// constructed; in that case nothing is queued.
//
// Both ciphers are built before either one is queued. Queuing the read-side
// cipher first and then failing on the write side would leave a cipher
// stranded in reader.pendingKeyChange: a later msgNewKeys from the peer would
// silently put it into use, and the next prepareKeyChange would block
// forever on the already-full channel.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
	readCipher, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
	if err != nil {
		return err
	}
	writeCipher, err := newPacketCipher(t.writer.dir, algs.w, kexResult)
	if err != nil {
		return err
	}

	// Each send blocks if a previously prepared cipher for that direction
	// has not yet been consumed by a msgNewKeys packet.
	t.reader.pendingKeyChange <- readCipher
	t.writer.pendingKeyChange <- writeCipher
	return nil
}
// printPacket logs the operation ("read" or "write"), the local role
// ("client" or "server") and the type byte of p. Empty packets are
// not logged.
func (t *transport) printPacket(p []byte, write bool) {
	if len(p) == 0 {
		return
	}

	role, op := "server", "read"
	if t.isClient {
		role = "client"
	}
	if write {
		op = "write"
	}
	log.Println(op, role, p[0])
}
// readPacket reads and decrypts the next packet the caller should see.
// msgIgnore and msgDebug packets are consumed silently, except in
// strict mode before the initial key exchange has completed, where
// every packet is returned.
func (t *transport) readPacket() (p []byte, err error) {
	for {
		p, err = t.reader.readPacket(t.bufReader, t.strictMode)
		if err != nil || len(p) == 0 {
			break
		}
		// In strict mode, DEBUG and IGNORE packets are passed through
		// during the initial KEX instead of being dropped.
		if t.strictMode && !t.initialKEXDone {
			break
		}
		if kind := p[0]; kind != msgIgnore && kind != msgDebug {
			break
		}
	}

	if debugTransport {
		t.printPacket(p, false)
	}
	return p, err
}
// readPacket reads one packet from r using the current cipher and
// advances the sequence number. It interprets two message types
// itself: msgNewKeys switches to the pending cipher (resetting the
// sequence number in strict mode), and msgDisconnect is turned into
// an error. The returned slice is a private copy of the packet.
func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
	raw, err := s.packetCipher.readCipherPacket(s.seqNum, r)
	s.seqNum++

	switch {
	case len(raw) == 0:
		if err == nil {
			err = errors.New("ssh: zero length packet")
		}

	case raw[0] == msgNewKeys:
		select {
		case next := <-s.pendingKeyChange:
			s.packetCipher = next
			if strictMode {
				s.seqNum = 0
			}
		default:
			return nil, errors.New("ssh: got bogus newkeys message")
		}

	case raw[0] == msgDisconnect:
		// This is the lowest level at which message types are
		// interpreted, so converting a disconnect message into an
		// error here means nothing above has to handle it.
		var msg disconnectMsg
		if err := Unmarshal(raw, &msg); err != nil {
			return nil, err
		}
		return nil, &msg
	}

	// raw may alias a buffer owned by the cipher, so hand the caller
	// its own copy.
	out := make([]byte, len(raw))
	copy(out, raw)
	return out, err
}
// writePacket encrypts packet with the write-side cipher and sends it
// to the peer. When debugTransport is set, the packet type is logged
// first.
func (t *transport) writePacket(packet []byte) error {
if debugTransport {
t.printPacket(packet, true)
}
return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
}
// writePacket encrypts packet with the current cipher, writes it to w
// and flushes w. On success the sequence number advances. If the
// packet was msgNewKeys, the pending cipher is then put into use (and
// the sequence number reset in strict mode). It panics if msgNewKeys
// is written while no cipher is pending.
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
	// Inspect the type before encrypting: the cipher is allowed to
	// scramble the contents of packet.
	isNewKeys := len(packet) > 0 && packet[0] == msgNewKeys

	if err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet); err != nil {
		return err
	}
	if err := w.Flush(); err != nil {
		return err
	}
	s.seqNum++

	if !isNewKeys {
		return nil
	}
	select {
	case next := <-s.pendingKeyChange:
		s.packetCipher = next
		if strictMode {
			s.seqNum = 0
		}
	default:
		panic("ssh: no key material for msgNewKeys")
	}
	return nil
}
// newTransport wraps rwc in a transport. Both directions start out
// with the "none" cipher and an empty one-slot queue for the next
// cipher. rand is kept for use by the write-side cipher, and isClient
// selects which set of key tags is used for reading and for writing.
func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
	// A server reads with the client's keys and writes with its own;
	// a client does the opposite.
	readDir, writeDir := clientKeys, serverKeys
	if isClient {
		readDir, writeDir = serverKeys, clientKeys
	}

	return &transport{
		reader: connectionState{
			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
			dir:              readDir,
			pendingKeyChange: make(chan packetCipher, 1),
		},
		writer: connectionState{
			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
			dir:              writeDir,
			pendingKeyChange: make(chan packetCipher, 1),
		},
		bufReader: bufio.NewReader(rwc),
		bufWriter: bufio.NewWriter(rwc),
		rand:      rand,
		isClient:  isClient,
		Closer:    rwc,
	}
}
// direction holds the tags fed into key derivation for one direction
// of the connection: one each for the IV, the cipher key and the MAC
// key.
type direction struct {
	ivTag     []byte
	keyTag    []byte
	macKeyTag []byte
}

// clientKeys tags the client->server keys and serverKeys tags the
// server->client keys.
var (
	clientKeys = direction{
		ivTag:     []byte{'A'},
		keyTag:    []byte{'C'},
		macKeyTag: []byte{'E'},
	}
	serverKeys = direction{
		ivTag:     []byte{'B'},
		keyTag:    []byte{'D'},
		macKeyTag: []byte{'F'},
	}
)
// newPacketCipher derives the IV, the cipher key and, for ciphers that
// are not AEAD, the MAC key from the key exchange result (see RFC 4253,
// section 7.2), and builds the packetCipher for the negotiated
// algorithms. d should be serverKeys (to set up server->client keys) or
// clientKeys (for client->server keys).
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
	mode := cipherModes[algs.Cipher]

	iv := make([]byte, mode.ivSize)
	key := make([]byte, mode.keySize)
	generateKeyMaterial(iv, d.ivTag, kex)
	generateKeyMaterial(key, d.keyTag, kex)

	// AEAD ciphers get no separate MAC key, so macKey stays nil for them.
	var macKey []byte
	if needsMAC := !aeadCiphers[algs.Cipher]; needsMAC {
		macKey = make([]byte, macModes[algs.MAC].keySize)
		generateKeyMaterial(macKey, d.macKeyTag, kex)
	}

	return mode.create(key, iv, macKey, algs)
}
// generateKeyMaterial fills out with key material generated from tag,
// K, H and sessionId, as specified in RFC 4253, section 7.2.
//
// The first block is HASH(K || H || tag || session_id). While out is
// not yet full, each further block is HASH(K || H || all blocks so far).
func generateKeyMaterial(out, tag []byte, r *kexResult) {
	hasher := r.Hash.New()

	var produced []byte // concatenation of the blocks generated so far
	for len(out) > 0 {
		hasher.Reset()
		hasher.Write(r.K)
		hasher.Write(r.H)
		if len(produced) > 0 {
			hasher.Write(produced)
		} else {
			hasher.Write(tag)
			hasher.Write(r.SessionID)
		}

		block := hasher.Sum(nil)
		filled := copy(out, block)
		out = out[filled:]
		if len(out) == 0 {
			break
		}
		produced = append(produced, block...)
	}
}
// packageVersion is the version line identifying this package.
// NOTE(review): presumably the default handed to exchangeVersions when
// the caller configures no version of its own — confirm against callers.
const packageVersion = "SSH-2.0-Go"
// exchangeVersions sends versionLine, followed by CR LF, to the peer
// and then reads the peer's version line, which it returns.
//
// versionLine should be US ASCII, start with "SSH-2.0-", and must not
// include a newline. Only control characters (bytes below 32, which
// covers NUL, CR and LF) are actually rejected here.
func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
	for i := range versionLine {
		if versionLine[i] < 32 {
			return nil, errors.New("ssh: junk character in version line")
		}
	}

	if _, err = rw.Write(append(versionLine, "\r\n"...)); err != nil {
		return nil, err
	}
	return readVersion(rw)
}
// maxVersionStringBytes is the maximum number of bytes that we'll
// accept as a version string. RFC 4253 section 4.2 limits this at 255
// chars.
const maxVersionStringBytes = 255

// readVersion reads the peer's version string as specified by RFC 4253,
// section 4.2, one byte at a time so that nothing past the terminating
// newline is consumed from r.
//
// Lines that do not start with "SSH-" are skipped. The line may end in
// "\r\n" as the RFC requires, or in a bare "\n" as several servers send;
// the terminator is not part of the result. At most
// maxVersionStringBytes bytes are read in total, skipped lines
// included; if no version line is complete by then an error is returned.
func readVersion(r io.Reader) ([]byte, error) {
	line := make([]byte, 0, 64)
	var one [1]byte

	for consumed := 0; consumed < maxVersionStringBytes; consumed++ {
		if _, err := io.ReadFull(r, one[:]); err != nil {
			return nil, err
		}

		c := one[0]
		if c != '\n' {
			// Non-ASCII bytes are disallowed by the RFC, but we are
			// lenient. Any comment after a space is kept too, because
			// the whole line goes into the session hash.
			line = append(line, c)
			continue
		}

		if bytes.HasPrefix(line, []byte("SSH-")) {
			// Drop the '\r' of a "\r\n" terminator, if there is one.
			return bytes.TrimSuffix(line, []byte{'\r'}), nil
		}
		// Not the version line: discard it and keep reading.
		line = line[:0]
	}

	return nil, errors.New("ssh: overflow reading version string")
}