// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssh

import (
	"bytes"
	"crypto/rand"
	"errors"
	"fmt"
	"io"
	"net"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"testing"
)

// testChecker is a host key callback for tests that records a description of
// every host key it accepts.
type testChecker struct {
	// calls holds one "dialAddr addr keyType keyHex" entry per accepted key.
	calls []string
}

// Check rejects the dial address "bad" and any addr that is not a non-nil
// *net.TCPAddr. Otherwise it records the call in calls and accepts the key.
func (tc *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
	if dialAddr == "bad" {
		return errors.New("dialAddr is bad")
	}

	tcpAddr, isTCP := addr.(*net.TCPAddr)
	if !isTCP || tcpAddr == nil {
		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
	}

	entry := fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())
	tc.calls = append(tc.calls, entry)
	return nil
}

// netPipe returns the two ends of a loopback TCP connection: first the
// dialing end, then the accepted end. It plays the role of net.Pipe, but
// because both ends are real sockets the writes are buffered. Both sides may
// therefore start by writing, where net.Pipe would deadlock. IPv4 loopback
// is tried first, then IPv6 loopback.
func netPipe() (net.Conn, net.Conn, error) {
	var (
		ln  net.Listener
		err error
	)
	for _, candidate := range []string{"127.0.0.1:0", "[::1]:0"} {
		if ln, err = net.Listen("tcp", candidate); err == nil {
			break
		}
	}
	if err != nil {
		return nil, nil, err
	}
	// The listener is only needed to accept this single connection.
	defer ln.Close()

	dialed, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		return nil, nil, err
	}

	accepted, err := ln.Accept()
	if err != nil {
		dialed.Close()
		return nil, nil, err
	}
	return dialed, accepted, nil
}

// noiseTransport wraps a keyingTransport and sends an ignore packet and a
// debug packet ahead of every outgoing packet. Tests use it to check that
// the read loop and the key exchange filter out these messages, or, with
// strict key exchange, that the handshake rejects them.
type noiseTransport struct {
	keyingTransport
}

// writePacket writes the two noise packets and then p to the wrapped
// transport, stopping at the first error.
func (t *noiseTransport) writePacket(p []byte) error {
	noise := [][]byte{
		{msgIgnore},
		{msgDebug, 1, 2, 3},
	}
	for _, pkt := range noise {
		if err := t.keyingTransport.writePacket(pkt); err != nil {
			return err
		}
	}
	return t.keyingTransport.writePacket(p)
}

// addNoiseTransport returns t wrapped in a noiseTransport.
func addNoiseTransport(t keyingTransport) keyingTransport {
	return &noiseTransport{keyingTransport: t}
}

// handshakePair creates two handshakeTransports connected with each other
// over a loopback TCP connection. It waits until the initial key exchange has
// completed on both sides. addr is passed to newClientTransport as the dial
// address. The server uses the "ecdsa" and "rsa" test host keys. If the noise
// argument is true, both transports will try to confuse the other side by
// sending ignore and debug messages.
//
// If either side fails to establish the session, the underlying connections
// are closed before returning. This makes the transports' background
// goroutines see an I/O error and exit instead of leaking.
func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
	a, b, err := netPipe()
	if err != nil {
		return nil, nil, err
	}

	var trC, trS keyingTransport

	trC = newTransport(a, rand.Reader, true)
	trS = newTransport(b, rand.Reader, false)
	if noise {
		trC = addNoiseTransport(trC)
		trS = addNoiseTransport(trS)
	}
	clientConf.SetDefaults()

	v := []byte("version")
	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())

	serverConf := &ServerConfig{}
	serverConf.AddHostKey(testSigners["ecdsa"])
	serverConf.AddHostKey(testSigners["rsa"])
	serverConf.SetDefaults()
	server = newServerTransport(trS, v, v, serverConf)

	// Close the raw conns rather than the handshakeTransports. That is
	// enough to make both sides' loops fail, and it does not depend on how
	// handshakeTransport.Close waits for them.
	// NOTE(review): confirm whether handshakeTransport.Close blocks.
	if err := server.waitSession(); err != nil {
		a.Close()
		b.Close()
		return nil, nil, fmt.Errorf("server.waitSession: %w", err)
	}
	if err := client.waitSession(); err != nil {
		a.Close()
		b.Close()
		return nil, nil, fmt.Errorf("client.waitSession: %w", err)
	}

	return client, server, nil
}

// TestHandshakeBasic checks that application packets the client writes
// while it requests additional key exchanges reach the server in order and
// intact. A syncChecker paces the host key checks, so packets are written
// while a key exchange is still in progress.
func TestHandshakeBasic(t *testing.T) {
	if runtime.GOOS == "plan9" {
		t.Skip("see golang.org/issue/7237")
	}

	checker := &syncChecker{
		waitCall: make(chan int, 10),
		called:   make(chan int, 10),
	}

	// Pre-approve the host key check of the initial kex.
	checker.waitCall <- 1
	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
	if err != nil {
		t.Fatalf("handshakePair: %v", err)
	}

	defer trC.Close()
	defer trS.Close()

	// Let first kex complete normally.
	<-checker.called

	clientDone := make(chan int)
	gotHalf := make(chan int)
	const N = 20
	errorCh := make(chan error, 1)

	go func() {
		defer close(clientDone)
		// Client writes a bunch of stuff, and does a key
		// change in the middle. This should not confuse the
		// handshake in progress. We do this twice (at i == 5 and
		// i == 15), so we test that the packet buffer is reset
		// correctly.
		for i := 0; i < N; i++ {
			p := []byte{msgRequestSuccess, byte(i)}
			if err := trC.writePacket(p); err != nil {
				errorCh <- err
				trC.Close()
				return
			}
			if (i % 10) == 5 {
				<-gotHalf
				// halfway through, we request a key change.
				trC.requestKeyExchange()

				// Wait until we can be sure the key
				// change has really started before we
				// write more.
				<-checker.called
			}
			if (i % 10) == 7 {
				// Packets 6 and 7 were written while the
				// host key check was held; release it now
				// so the kex completes. This tests
				// buffering of packets.
				checker.waitCall <- 1
			}
		}
		errorCh <- nil
	}()

	// Server checks that client messages come in cleanly
	i := 0
	for ; i < N; i++ {
		p, err := trS.readPacket()
		if err != nil && err != io.EOF {
			t.Fatalf("server error: %v", err)
		}
		if (i % 10) == 5 {
			gotHalf <- 1
		}

		want := []byte{msgRequestSuccess, byte(i)}
		if !bytes.Equal(p, want) {
			t.Errorf("message %d: got %v, want %v", i, p, want)
		}
	}
	<-clientDone
	if err := <-errorCh; err != nil {
		t.Fatalf("sendPacket: %v", err)
	}
	if i != N {
		t.Errorf("received %d messages, want %d", i, N)
	}

	close(checker.called)
	if _, ok := <-checker.called; ok {
		// If all went well, we registered exactly 3 host key checks,
		// all already consumed above: one for the key exchange that
		// establishes the session, and one for each of the two key
		// changes requested by the client goroutine.
		t.Fatalf("got another host key check after 3 handshakes")
	}
}

// TestForceFirstKex checks that the server rejects a peer that sends a
// non-KEX packet (a cleartext service request) before the initial key
// exchange.
func TestForceFirstKex(t *testing.T) {
	// like handshakePair, but must access the keyingTransport.
	checker := &testChecker{}
	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
	a, b, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}

	var trC, trS keyingTransport

	trC = newTransport(a, rand.Reader, true)

	// This is the disallowed packet. Fail early on a write error so the
	// test reports the real cause instead of an unrelated waitSession
	// result.
	if err := trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
		a.Close()
		b.Close()
		t.Fatalf("writePacket: %v", err)
	}

	// Rest of the setup.
	trS = newTransport(b, rand.Reader, false)
	clientConf.SetDefaults()

	v := []byte("version")
	client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())

	serverConf := &ServerConfig{}
	serverConf.AddHostKey(testSigners["ecdsa"])
	serverConf.AddHostKey(testSigners["rsa"])
	serverConf.SetDefaults()
	server := newServerTransport(trS, v, v, serverConf)

	defer client.Close()
	defer server.Close()

	// We setup the initial key exchange, but the remote side
	// tries to send serviceRequestMsg in cleartext, which is
	// disallowed.

	if err := server.waitSession(); err == nil {
		t.Errorf("server first kex init should reject unexpected packet")
	}
}

// TestHandshakeAutoRekeyWrite checks that a client configured with a small
// RekeyThreshold starts a key exchange on its own after writing enough data.
// It also checks that every packet written around the re-key reaches the
// server intact.
func TestHandshakeAutoRekeyWrite(t *testing.T) {
	checker := &syncChecker{
		called:   make(chan int, 10),
		waitCall: nil,
	}
	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
	clientConf.RekeyThreshold = 500
	trC, trS, err := handshakePair(clientConf, "addr", false)
	if err != nil {
		t.Fatalf("handshakePair: %v", err)
	}
	defer trC.Close()
	defer trS.Close()

	input := make([]byte, 251)
	input[0] = msgRequestSuccess

	done := make(chan int, 1)
	const numPacket = 5
	go func() {
		defer close(done)
		j := 0
		for ; j < numPacket; j++ {
			p, err := trS.readPacket()
			if err != nil {
				t.Errorf("readPacket: %v", err)
				break
			}
			if !bytes.Equal(input, p) {
				// Print the whole packet: indexing p[0] would
				// panic on an empty packet.
				t.Errorf("packet %d: got %x, want %x", j, p, input)
			}
		}

		if j != numPacket {
			t.Errorf("got %d, want %d messages", j, numPacket)
		}
	}()

	// Consume the host key check of the initial kex.
	<-checker.called

	for i := 0; i < numPacket; i++ {
		// writePacket may modify its argument, so send a fresh copy.
		p := make([]byte, len(input))
		copy(p, input)
		if err := trC.writePacket(p); err != nil {
			t.Errorf("writePacket: %v", err)
		}
		if i == 2 {
			// More than RekeyThreshold bytes have been written by
			// now, so a key exchange should be under way; wait for
			// its host key check.
			<-checker.called
		}
	}
	<-done
}

// syncChecker is a host key callback that lets a test observe and pace host
// key checks. Every call signals on called. If waitCall is non-nil, the call
// then blocks until the test sends on waitCall.
type syncChecker struct {
	waitCall chan int
	called   chan int
}

// Check signals the call, optionally waits for the test's go-ahead, and
// then accepts the key.
func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
	c.called <- 1
	if c.waitCall == nil {
		return nil
	}
	<-c.waitCall
	return nil
}

// TestHandshakeAutoRekeyRead checks that a client configured with a small
// RekeyThreshold starts a key exchange on its own after reading enough data,
// and that reading still works.
func TestHandshakeAutoRekeyRead(t *testing.T) {
	// Named checker rather than sync so that it does not shadow the
	// imported sync package.
	checker := &syncChecker{
		called:   make(chan int, 2),
		waitCall: nil,
	}
	clientConf := &ClientConfig{
		HostKeyCallback: checker.Check,
	}
	clientConf.RekeyThreshold = 500

	trC, trS, err := handshakePair(clientConf, "addr", false)
	if err != nil {
		t.Fatalf("handshakePair: %v", err)
	}
	defer trC.Close()
	defer trS.Close()

	// Consume the host key check of the initial kex (as
	// TestHandshakeAutoRekeyWrite does). Otherwise that buffered signal
	// would satisfy the receive at the end of the test, and the re-key
	// would never actually be verified.
	<-checker.called

	packet := make([]byte, 501)
	packet[0] = msgRequestSuccess
	if err := trS.writePacket(packet); err != nil {
		t.Fatalf("writePacket: %v", err)
	}

	// While we read out the packet, a key change will be
	// initiated.
	errorCh := make(chan error, 1)
	go func() {
		_, err := trC.readPacket()
		errorCh <- err
	}()

	if err := <-errorCh; err != nil {
		t.Fatalf("readPacket(client): %v", err)
	}

	// Wait for the host key check of the read-triggered re-key.
	<-checker.called
}

// errorKeyingTransport wraps a packetConn and makes it fail after a fixed
// number of reads and writes. readLeft and writeLeft are the remaining
// per-direction budgets. When a budget is zero, the next operation in that
// direction closes the connection and returns an error. A negative budget
// only moves further from zero, so in practice it disables error injection.
type errorKeyingTransport struct {
	packetConn
	readLeft, writeLeft int
}

// prepareKeyChange ignores the negotiated algorithms and keys.
func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
	return nil
}

// getSessionID always reports no session ID.
func (n *errorKeyingTransport) getSessionID() []byte {
	return nil
}

// writePacket forwards packet while the write budget lasts. Once the budget
// is exhausted it closes the connection and fails.
func (n *errorKeyingTransport) writePacket(packet []byte) error {
	if n.writeLeft != 0 {
		n.writeLeft--
		return n.packetConn.writePacket(packet)
	}
	n.Close()
	return errors.New("barf")
}

// readPacket forwards to the wrapped conn while the read budget lasts. Once
// the budget is exhausted it closes the connection and fails.
func (n *errorKeyingTransport) readPacket() ([]byte, error) {
	if n.readLeft != 0 {
		n.readLeft--
		return n.packetConn.readPacket()
	}
	n.Close()
	return nil, errors.New("barf")
}

// setStrictMode is a no-op that never fails.
func (n *errorKeyingTransport) setStrictMode() error { return nil }

// setInitialKEXDone is a no-op.
func (n *errorKeyingTransport) setInitialKEXDone() {}

// TestHandshakeErrorHandlingRead makes the server's reads fail after 0..19
// packets, using independent reader and writer goroutines.
func TestHandshakeErrorHandlingRead(t *testing.T) {
	for readLimit := 0; readLimit < 20; readLimit++ {
		testHandshakeErrorHandlingN(t, readLimit, -1, false)
	}
}

// TestHandshakeErrorHandlingWrite makes the server's writes fail after 0..19
// packets, using independent reader and writer goroutines.
func TestHandshakeErrorHandlingWrite(t *testing.T) {
	for writeLimit := 0; writeLimit < 20; writeLimit++ {
		testHandshakeErrorHandlingN(t, -1, writeLimit, false)
	}
}

// TestHandshakeErrorHandlingReadCoupled makes the server's reads fail after
// 0..19 packets, with each side echoing packets from a single goroutine.
func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
	for readLimit := 0; readLimit < 20; readLimit++ {
		testHandshakeErrorHandlingN(t, readLimit, -1, true)
	}
}

// TestHandshakeErrorHandlingWriteCoupled makes the server's writes fail after
// 0..19 packets, with each side echoing packets from a single goroutine.
func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
	for writeLimit := 0; writeLimit < 20; writeLimit++ {
		testHandshakeErrorHandlingN(t, -1, writeLimit, true)
	}
}

// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
// handshakeTransport deadlocks, the go runtime will detect it and
// panic.
//
// readLimit and writeLimit bound how many packets the server side's
// errorKeyingTransport may read and write before it closes the connection
// and fails; -1 disables the limit for that direction. The client side
// never injects errors. With coupled == false, each side runs independent
// writer and reader goroutines. With coupled == true, each side runs one
// goroutine that reads a packet and then writes msg back. The function
// returns once every such goroutine has stopped after an error.
func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
	if (runtime.GOOS == "js" || runtime.GOOS == "wasip1") && runtime.GOARCH == "wasm" {
		t.Skipf("skipping on %s/wasm; see golang.org/issue/32840", runtime.GOOS)
	}
	// Packet echoed back in coupled mode. Its size is a quarter of
	// minRekeyThreshold, so presumably a handful of them exceeds the
	// server's RekeyThreshold and triggers a re-key.
	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})

	a, b := memPipe()
	defer a.Close()
	defer b.Close()

	key := testSigners["ecdsa"]
	// The server re-keys at minRekeyThreshold. The client's threshold is
	// 10x larger, so presumably most re-keys are server-initiated.
	serverConf := Config{RekeyThreshold: minRekeyThreshold}
	serverConf.SetDefaults()
	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
	serverConn.hostKeys = []Signer{key}
	// The transports are built with newHandshakeTransport directly, so the
	// read and key-exchange loops are started here by hand.
	go serverConn.readLoop()
	go serverConn.kexLoop()

	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
	clientConf.SetDefaults()
	// The client side gets unlimited budgets (-1, -1) and never fails on
	// its own.
	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
	clientConn.hostKeyCallback = InsecureIgnoreHostKey()
	go clientConn.readLoop()
	go clientConn.kexLoop()

	var wg sync.WaitGroup

	for _, hs := range []packetConn{serverConn, clientConn} {
		if !coupled {
			wg.Add(2)
			// Writer: send distinct, counter-prefixed packets of
			// about the same size as msg until a write fails.
			go func(c packetConn) {
				for i := 0; ; i++ {
					str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
					err := c.writePacket(Marshal(&serviceRequestMsg{str}))
					if err != nil {
						break
					}
				}
				wg.Done()
				// Tear down this side once writing has failed;
				// presumably this also stops the reader below.
				c.Close()
			}(hs)
			// Reader: drain packets until a read fails.
			go func(c packetConn) {
				for {
					_, err := c.readPacket()
					if err != nil {
						break
					}
				}
				wg.Done()
			}(hs)
		} else {
			wg.Add(1)
			// Coupled: echo msg for every packet read, until either
			// the read or the write fails.
			go func(c packetConn) {
				for {
					_, err := c.readPacket()
					if err != nil {
						break
					}
					if err := c.writePacket(msg); err != nil {
						break
					}

				}
				wg.Done()
			}(hs)
		}
	}
	wg.Wait()
}

// TestDisconnect checks that a disconnect message from the client is
// surfaced by the server's readPacket as a *disconnectMsg error. It also
// checks that later reads keep failing.
func TestDisconnect(t *testing.T) {
	if runtime.GOOS == "plan9" {
		t.Skip("see golang.org/issue/7237")
	}
	checker := &testChecker{}
	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
	if err != nil {
		t.Fatalf("handshakePair: %v", err)
	}

	defer trC.Close()
	defer trS.Close()

	if err := trC.writePacket([]byte{msgRequestSuccess, 0, 0}); err != nil {
		t.Fatalf("writePacket 1: %v", err)
	}
	errMsg := &disconnectMsg{
		Reason:  42,
		Message: "such is life",
	}
	if err := trC.writePacket(Marshal(errMsg)); err != nil {
		t.Fatalf("writePacket 2: %v", err)
	}
	// The server must never deliver this packet because it follows the
	// disconnect. Its write result is deliberately not checked.
	trC.writePacket([]byte{msgRequestSuccess, 0, 0})

	packet, err := trS.readPacket()
	if err != nil {
		t.Fatalf("readPacket 1: %v", err)
	}
	if packet[0] != msgRequestSuccess {
		t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
	}

	_, err = trS.readPacket()
	if err == nil {
		t.Errorf("readPacket 2 succeeded")
	} else if !reflect.DeepEqual(err, errMsg) {
		t.Errorf("got error %#v, want %#v", err, errMsg)
	}

	_, err = trS.readPacket()
	if err == nil {
		t.Errorf("readPacket 3 succeeded")
	}
}

// TestHandshakeRekeyDefault checks that, with no RekeyThreshold configured,
// an aes128-ctr connection schedules a re-key after 64G in each direction.
func TestHandshakeRekeyDefault(t *testing.T) {
	clientConf := &ClientConfig{
		Config: Config{
			Ciphers: []string{"aes128-ctr"},
		},
		HostKeyCallback: InsecureIgnoreHostKey(),
	}
	trC, trS, err := handshakePair(clientConf, "addr", false)
	if err != nil {
		t.Fatalf("handshakePair: %v", err)
	}
	defer trC.Close()
	defer trS.Close()

	if err := trC.writePacket([]byte{msgRequestSuccess, 0, 0}); err != nil {
		t.Fatalf("writePacket: %v", err)
	}
	// Close before reading the byte counters so that no background
	// goroutine can race with the reads below.
	trC.Close()

	// The counters have already been charged for a little traffic,
	// presumably well under 1 KiB. Adding 1024 before the shift rounds
	// them back up to whole GiB.
	rgb := (1024 + trC.readBytesLeft) >> 30
	wgb := (1024 + trC.writeBytesLeft) >> 30

	if rgb != 64 {
		t.Errorf("got rekey after %dG read, want 64G", rgb)
	}
	if wgb != 64 {
		t.Errorf("got rekey after %dG write, want 64G", wgb)
	}
}

// TestHandshakeAEADCipherNoMAC checks that a handshake succeeds for each
// AEAD cipher even when the client configures no MAC algorithms.
func TestHandshakeAEADCipherNoMAC(t *testing.T) {
	for _, cipher := range []string{chacha20Poly1305ID, gcm128CipherID} {
		// Run each cipher as a subtest. The deferred Close calls then fire
		// when that cipher's check is done, rather than piling up until
		// the whole test returns.
		t.Run(cipher, func(t *testing.T) {
			checker := &syncChecker{
				called: make(chan int, 1),
			}
			clientConf := &ClientConfig{
				Config: Config{
					Ciphers: []string{cipher},
					MACs:    []string{},
				},
				HostKeyCallback: checker.Check,
			}
			trC, trS, err := handshakePair(clientConf, "addr", false)
			if err != nil {
				t.Fatalf("handshakePair: %v", err)
			}
			defer trC.Close()
			defer trS.Close()

			<-checker.called
		})
	}
}

// TestNoSHA2Support tests a host key Signer that is not an AlgorithmSigner and
// therefore can't do SHA-2 signatures. Ensures the server does not advertise
// support for them in this case.
func TestNoSHA2Support(t *testing.T) {
	c1, c2, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}
	defer c1.Close()
	defer c2.Close()

	serverConf := &ServerConfig{
		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
			return &Permissions{}, nil
		},
	}
	serverConf.AddHostKey(&legacyRSASigner{testSigners["rsa"]})

	// The server goroutine hands its result back instead of calling t.Error
	// itself. Otherwise it could report after the test has returned, for
	// example when the deferred Close calls make NewServerConn fail, and
	// the testing package forbids that. The channel is buffered so the
	// goroutine never blocks if the test stops early.
	serverErr := make(chan error, 1)
	go func() {
		_, _, _, err := NewServerConn(c1, serverConf)
		serverErr <- err
	}()

	clientConf := &ClientConfig{
		User:            "test",
		Auth:            []AuthMethod{Password("testpw")},
		HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()),
	}

	if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil {
		t.Fatal(err)
	}
	if err := <-serverErr; err != nil {
		t.Error(err)
	}
}

// TestMultiAlgoSignerHandshake checks that a client that only accepts
// KeyAlgoRSASHA512 host keys can connect to a server whose RSA host key was
// restricted to KeyAlgoRSASHA256 and KeyAlgoRSASHA512 with
// NewSignerWithAlgorithms.
func TestMultiAlgoSignerHandshake(t *testing.T) {
	rsaSigner, isAlgoSigner := testSigners["rsa"].(AlgorithmSigner)
	if !isAlgoSigner {
		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
	}
	hostKey, err := NewSignerWithAlgorithms(rsaSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
	if err != nil {
		t.Fatalf("unable to create multi algorithm signer: %v", err)
	}

	serverSide, clientSide, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}
	defer serverSide.Close()
	defer clientSide.Close()

	acceptAll := func(ConnMetadata, []byte) (*Permissions, error) {
		return &Permissions{}, nil
	}
	serverConf := &ServerConfig{PasswordCallback: acceptAll}
	serverConf.AddHostKey(hostKey)
	// Only the client's outcome is asserted; the server result is ignored.
	go NewServerConn(serverSide, serverConf)

	clientConf := &ClientConfig{
		User:              "test",
		Auth:              []AuthMethod{Password("testpw")},
		HostKeyCallback:   FixedHostKey(testSigners["rsa"].PublicKey()),
		HostKeyAlgorithms: []string{KeyAlgoRSASHA512},
	}

	if _, _, _, err := NewClientConn(clientSide, "", clientConf); err != nil {
		t.Fatal(err)
	}
}

// TestMultiAlgoSignerNoCommonHostKeyAlgo checks that the handshake fails
// when the server's RSA host key only offers KeyAlgoRSASHA256 and
// KeyAlgoRSASHA512, while the client only accepts KeyAlgoRSA.
func TestMultiAlgoSignerNoCommonHostKeyAlgo(t *testing.T) {
	rsaSigner, isAlgoSigner := testSigners["rsa"].(AlgorithmSigner)
	if !isAlgoSigner {
		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
	}
	sha2Only, err := NewSignerWithAlgorithms(rsaSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
	if err != nil {
		t.Fatalf("unable to create multi algorithm signer: %v", err)
	}

	serverSide, clientSide, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}
	defer serverSide.Close()
	defer clientSide.Close()

	// Server side: KeyAlgoRSA is not among the host key's algorithms.
	acceptAll := func(ConnMetadata, []byte) (*Permissions, error) {
		return &Permissions{}, nil
	}
	serverConf := &ServerConfig{PasswordCallback: acceptAll}
	serverConf.AddHostKey(sha2Only)
	go NewServerConn(serverSide, serverConf)

	// Client side: KeyAlgoRSA is the only acceptable host key algorithm.
	clientConf := &ClientConfig{
		User:              "test",
		Auth:              []AuthMethod{Password("testpw")},
		HostKeyCallback:   FixedHostKey(testSigners["rsa"].PublicKey()),
		HostKeyAlgorithms: []string{KeyAlgoRSA},
	}

	if _, _, _, err := NewClientConn(clientSide, "", clientConf); err == nil {
		t.Fatal("succeeded connecting with no common hostkey algorithm")
	}
}

// TestPickIncompatibleHostKeyAlgo checks that pickHostKey returns nil for
// KeyAlgoRSA when the only signer was restricted to KeyAlgoRSASHA256 and
// KeyAlgoRSASHA512.
func TestPickIncompatibleHostKeyAlgo(t *testing.T) {
	rsaSigner, isAlgoSigner := testSigners["rsa"].(AlgorithmSigner)
	if !isAlgoSigner {
		t.Fatal("rsa test signer does not implement the AlgorithmSigner interface")
	}
	restricted, err := NewSignerWithAlgorithms(rsaSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512})
	if err != nil {
		t.Fatalf("unable to create multi algorithm signer: %v", err)
	}

	if picked := pickHostKey([]Signer{restricted}, KeyAlgoRSA); picked != nil {
		t.Fatal("incompatible signer returned")
	}
}

// TestStrictKEXResetSeqFirstKEX checks the packet sequence numbers on both
// sides right after the initial strict key exchange.
func TestStrictKEXResetSeqFirstKEX(t *testing.T) {
	if runtime.GOOS == "plan9" {
		t.Skip("see golang.org/issue/7237")
	}

	checker := &syncChecker{
		waitCall: make(chan int, 10),
		called:   make(chan int, 10),
	}
	// Pre-approve the host key check of the initial kex.
	checker.waitCall <- 1

	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
	trC, trS, err := handshakePair(clientConf, "addr", false)
	if err != nil {
		t.Fatalf("handshakePair: %v", err)
	}
	<-checker.called

	t.Cleanup(func() {
		trC.Close()
		trS.Close()
	})

	// The server sends msgExtInfo during the handshake; discard it.
	if _, err := trC.readPacket(); err != nil {
		t.Fatalf("readPacket failed: %s", err)
	}

	// Stop both transports first so no background goroutine races with
	// the counter reads below.
	trC.Close()
	trS.Close()

	// Counters reset after msgNewKeys. The server then writes msgExtInfo
	// and the transports are closed, so the client should have read 2 and
	// written 0 packets, and the server read 1 and written 1.
	cli := trC.conn.(*transport)
	srv := trS.conn.(*transport)
	if cli.reader.seqNum != 2 || cli.writer.seqNum != 0 ||
		srv.reader.seqNum != 1 || srv.writer.seqNum != 1 {
		t.Errorf(
			"unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)",
			cli.reader.seqNum,
			cli.writer.seqNum,
			srv.reader.seqNum,
			srv.writer.seqNum,
		)
	}
}

// TestStrictKEXResetSeqSuccessiveKEX checks that a client-requested key
// exchange after some traffic resets the packet sequence numbers on both
// sides.
func TestStrictKEXResetSeqSuccessiveKEX(t *testing.T) {
	if runtime.GOOS == "plan9" {
		t.Skip("see golang.org/issue/7237")
	}

	checker := &syncChecker{
		waitCall: make(chan int, 10),
		called:   make(chan int, 10),
	}

	checker.waitCall <- 1
	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
	if err != nil {
		t.Fatalf("handshakePair: %v", err)
	}
	<-checker.called

	t.Cleanup(func() {
		trC.Close()
		trS.Close()
	})

	// Throw away the msgExtInfo packet sent during the handshake by the server
	_, err = trC.readPacket()
	if err != nil {
		t.Fatalf("readPacket failed: %s", err)
	}

	// write and read five packets on either side to bump the sequence numbers
	for i := 0; i < 5; i++ {
		if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil {
			t.Fatalf("writePacket failed: %s", err)
		}
		if _, err := trS.readPacket(); err != nil {
			t.Fatalf("readPacket failed: %s", err)
		}
		if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil {
			t.Fatalf("writePacket failed: %s", err)
		}
		if _, err := trC.readPacket(); err != nil {
			t.Fatalf("readPacket failed: %s", err)
		}
	}

	// Request a key exchange, which should cause the sequence numbers to reset
	checker.waitCall <- 1
	trC.requestKeyExchange()
	<-checker.called

	// Write a packet on the server and read it on the client, to verify the
	// key change has actually happened. The HostKeyCallback is called
	// _during_ the handshake, so it isn't actually indicative of the
	// handshake finishing.
	dummyPacket := []byte{99}
	if err := trS.writePacket(dummyPacket); err != nil {
		t.Fatalf("writePacket failed: %s", err)
	}
	if p, err := trC.readPacket(); err != nil {
		t.Fatalf("readPacket failed: %s", err)
	} else if !bytes.Equal(p, dummyPacket) {
		t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket)
	}

	// close the handshake transports before checking the sequence number to
	// avoid races.
	trC.Close()
	trS.Close()

	cli := trC.conn.(*transport)
	srv := trS.conn.(*transport)
	if cli.reader.seqNum != 2 || cli.writer.seqNum != 0 ||
		srv.reader.seqNum != 1 || srv.writer.seqNum != 1 {
		t.Errorf(
			"unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)",
			cli.reader.seqNum,
			cli.writer.seqNum,
			srv.reader.seqNum,
			srv.writer.seqNum,
		)
	}
}

// TestSeqNumIncrease checks that packet sequence numbers grow with ordinary
// traffic after the initial key exchange.
func TestSeqNumIncrease(t *testing.T) {
	if runtime.GOOS == "plan9" {
		t.Skip("see golang.org/issue/7237")
	}

	checker := &syncChecker{
		waitCall: make(chan int, 10),
		called:   make(chan int, 10),
	}
	// Pre-approve the host key check of the initial kex.
	checker.waitCall <- 1

	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
	trC, trS, err := handshakePair(clientConf, "addr", false)
	if err != nil {
		t.Fatalf("handshakePair: %v", err)
	}
	<-checker.called

	t.Cleanup(func() {
		trC.Close()
		trS.Close()
	})

	// The server sends msgExtInfo during the handshake; discard it.
	if _, err := trC.readPacket(); err != nil {
		t.Fatalf("readPacket failed: %s", err)
	}

	// roundTrip sends one packet from `from` and reads it on `to`.
	roundTrip := func(from, to *handshakeTransport) {
		if err := from.writePacket([]byte{msgRequestSuccess}); err != nil {
			t.Fatalf("writePacket failed: %s", err)
		}
		if _, err := to.readPacket(); err != nil {
			t.Fatalf("readPacket failed: %s", err)
		}
	}
	// Five exchanges in each direction to bump the sequence numbers.
	for round := 0; round < 5; round++ {
		roundTrip(trC, trS)
		roundTrip(trS, trC)
	}

	// Stop both transports first so no background goroutine races with
	// the counter reads below.
	trC.Close()
	trS.Close()

	cli := trC.conn.(*transport)
	srv := trS.conn.(*transport)
	if cli.reader.seqNum != 7 || cli.writer.seqNum != 5 ||
		srv.reader.seqNum != 6 || srv.writer.seqNum != 6 {
		t.Errorf(
			"unexpected sequence counters:\nclient: reader %d (expected 7), writer %d (expected 5)\nserver: reader %d (expected 6), writer %d (expected 6)",
			cli.reader.seqNum,
			cli.writer.seqNum,
			srv.reader.seqNum,
			srv.writer.seqNum,
		)
	}
}

// TestStrictKEXUnexpectedMsg checks that ignore/debug packets sent during a
// strict key exchange make the handshake fail. It also checks that the same
// packets are still silently ignored once the handshake has completed.
func TestStrictKEXUnexpectedMsg(t *testing.T) {
	if runtime.GOOS == "plan9" {
		t.Skip("see golang.org/issue/7237")
	}

	// Check that unexpected messages during the handshake cause failure
	_, _, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", true)
	if err == nil {
		t.Fatal("handshake should fail when there are unexpected messages during the handshake")
	}

	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", false)
	if err != nil {
		t.Fatalf("handshake failed: %s", err)
	}
	// Release the connections and background goroutines of this pair.
	defer trC.Close()
	defer trS.Close()

	// Check that ignore/debug packets are still ignored outside of the handshake
	if err := trC.writePacket([]byte{msgIgnore}); err != nil {
		t.Fatalf("writePacket failed: %s", err)
	}
	if err := trC.writePacket([]byte{msgDebug}); err != nil {
		t.Fatalf("writePacket failed: %s", err)
	}
	dummyPacket := []byte{99}
	if err := trC.writePacket(dummyPacket); err != nil {
		t.Fatalf("writePacket failed: %s", err)
	}

	if p, err := trS.readPacket(); err != nil {
		t.Fatalf("readPacket failed: %s", err)
	} else if !bytes.Equal(p, dummyPacket) {
		t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket)
	}
}

// TestStrictKEXMixed tests that we still support a mixed connection, where
// one side sends kex-strict but the other side doesn't. The client is a
// normal client transport. The server side is driven by hand: it sends a
// KEXINIT without the kex-strict extension and injects ignore/debug noise.
func TestStrictKEXMixed(t *testing.T) {
	a, b, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe failed: %s", err)
	}
	// Nothing else closes these connections: the hand-driven server
	// transport never has its loops started, so it has no Close to rely
	// on. Closing the raw conns also makes the client transport's
	// goroutines exit.
	defer a.Close()
	defer b.Close()

	var trC, trS keyingTransport

	trC = newTransport(a, rand.Reader, true)
	trS = newTransport(b, rand.Reader, false)
	trS = addNoiseTransport(trS)

	clientConf := &ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}
	clientConf.SetDefaults()

	v := []byte("version")
	client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())

	serverConf := &ServerConfig{}
	serverConf.AddHostKey(testSigners["ecdsa"])
	serverConf.AddHostKey(testSigners["rsa"])
	serverConf.SetDefaults()

	// Named srv rather than transport so that it does not shadow the
	// transport type.
	srv := newHandshakeTransport(trS, &serverConf.Config, []byte("version"), []byte("version"))
	srv.hostKeys = serverConf.hostKeys
	srv.publicKeyAuthAlgorithms = serverConf.PublicKeyAuthAlgorithms

	readOneFailure := make(chan error, 1)
	go func() {
		if _, err := srv.readOnePacket(true); err != nil {
			readOneFailure <- err
		}
	}()

	// Basically sendKexInit, but without the kex-strict extension algorithm
	msg := &kexInitMsg{
		KexAlgos:                srv.config.KeyExchanges,
		CiphersClientServer:     srv.config.Ciphers,
		CiphersServerClient:     srv.config.Ciphers,
		MACsClientServer:        srv.config.MACs,
		MACsServerClient:        srv.config.MACs,
		CompressionClientServer: supportedCompressions,
		CompressionServerClient: supportedCompressions,
		ServerHostKeyAlgos:      []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA},
	}
	packet := Marshal(msg)
	// Writing a packet destroys its contents, so push a copy and keep the
	// original for sentInitPacket.
	packetCopy := make([]byte, len(packet))
	copy(packetCopy, packet)
	if err := srv.pushPacket(packetCopy); err != nil {
		t.Fatalf("pushPacket: %s", err)
	}
	srv.sentInitMsg = msg
	srv.sentInitPacket = packet

	if err := srv.getWriteError(); err != nil {
		t.Fatalf("getWriteError failed: %s", err)
	}
	var request *pendingKex
	select {
	case err = <-readOneFailure:
		t.Fatalf("server readOnePacket failed: %s", err)
	case request = <-srv.startKex:
	}

	// The following calls must succeed even though the side that does not
	// support kex-strict sends ignore/debug packets during the handshake.
	// The strict-KEX rules only take effect when both sides advertise the
	// extension.

	if err := srv.enterKeyExchange(request.otherInit); err != nil {
		t.Fatalf("enterKeyExchange failed: %s", err)
	}
	if err := client.waitSession(); err != nil {
		t.Fatalf("client.waitSession: %v", err)
	}
}