// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

//go:build integrationtest
// +build integrationtest

package tailssh

import (
	"bufio"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/netip"
	"os"
	"os/exec"
	"runtime"
	"strings"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/pkg/sftp"
	gossh "github.com/tailscale/golang-x-crypto/ssh"
	"golang.org/x/crypto/ssh"
	"tailscale.com/net/tsdial"
	"tailscale.com/tailcfg"
	"tailscale.com/types/key"
	"tailscale.com/types/netmap"
)

// This file contains integration tests of the SSH functionality. These tests
// exercise everything except for the authentication logic.
//
// The tests make the following assumptions about the environment:
//
// - OS is macOS or Linux
// - Test is being run as root (e.g. go test -tags integrationtest -c . && sudo ./tailssh.test -test.run TestIntegration)
// - TAILSCALED_PATH environment variable points at tailscaled binary
// - User "testuser" exists
// - "testuser" is in groups "groupone" and "grouptwo"

// TestMain creates /tmp/tailscalessh.log if it does not exist and streams its
// contents into this process's log output via `tail -f` while the tests run.
// It then runs the tests and exits with their result.
//
// Setup failures are fatal. Returning early from TestMain without calling
// m.Run would make the test binary exit successfully without running any
// tests.
func TestMain(m *testing.M) {
	const logPath = "/tmp/tailscalessh.log"

	// Create our log file.
	file, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY, 0666)
	if err != nil {
		log.Fatal(err)
	}
	file.Close()

	// Tail our log file.
	cmd := exec.Command("tail", "-f", logPath)
	r, err := cmd.StdoutPipe()
	if err != nil {
		log.Fatalf("creating tail stdout pipe: %v", err)
	}
	if err := cmd.Start(); err != nil {
		log.Fatalf("starting tail: %v", err)
	}
	go func() {
		scanner := bufio.NewScanner(r)
		for scanner.Scan() {
			log.Println(scanner.Text())
		}
	}()

	code := m.Run()

	// tail -f never exits on its own. Stop and reap it so it does not outlive
	// the test binary. Errors are irrelevant at this point.
	_ = cmd.Process.Kill()
	_ = cmd.Wait()

	os.Exit(code)
}

// TestIntegrationSSH runs simple commands against the in-process SSH server
// and checks that their output contains the expected substrings. Each command
// runs in two subtests: "<cmd>_no_shell" with no PTY and "<cmd>_shell" with a
// PTY requested before the command starts.
func TestIntegrationSSH(t *testing.T) {
	debugTest.Store(true)
	t.Cleanup(func() { debugTest.Store(false) })

	homeDir := "/home/testuser"
	if runtime.GOOS == "darwin" {
		homeDir = "/Users/testuser"
	}

	cases := []struct {
		cmd  string
		want []string
	}{
		{cmd: "id", want: []string{"testuser", "groupone", "grouptwo"}},
		{cmd: "pwd", want: []string{homeDir}},
	}
	variants := []struct {
		name    string
		withPty bool
	}{
		{name: "no_shell", withPty: false},
		{name: "shell", withPty: true},
	}

	for _, tc := range cases {
		for _, v := range variants {
			t.Run(fmt.Sprintf("%s_%s", tc.cmd, v.name), func(t *testing.T) {
				s := testSession(t)

				if v.withPty {
					termModes := ssh.TerminalModes{
						ssh.ECHO:          1,
						ssh.TTY_OP_ISPEED: 14400,
						ssh.TTY_OP_OSPEED: 14400,
					}
					if err := s.RequestPty("xterm", 40, 80, termModes); err != nil {
						t.Fatalf("unable to request shell: %s", err)
					}
				}

				out := s.run(t, tc.cmd)
				for _, want := range tc.want {
					if !strings.Contains(out, want) {
						t.Errorf("%q does not contain %q", out, want)
					}
				}
			})
		}
	}
}

// TestIntegrationSFTP writes a file over SFTP and reads it back. It then runs
// `ls -l` over an SSH session on the same connection to check that the file
// is owned by testuser (and, on Linux, by testuser's primary group).
func TestIntegrationSFTP(t *testing.T) {
	debugTest.Store(true)
	t.Cleanup(func() {
		debugTest.Store(false)
	})

	filePath := "/tmp/sftptest.dat"
	wantText := "hello world"

	cl := testClient(t)
	scl, err := sftp.NewClient(cl)
	if err != nil {
		t.Fatalf("can't get sftp client: %s", err)
	}
	defer scl.Close()

	file, err := scl.Create(filePath)
	if err != nil {
		t.Fatalf("can't create file: %s", err)
	}
	_, err = file.Write([]byte(wantText))
	if err != nil {
		t.Fatalf("can't write to file: %s", err)
	}
	err = file.Close()
	if err != nil {
		t.Fatalf("can't close file: %s", err)
	}

	file, err = scl.OpenFile(filePath, os.O_RDONLY)
	if err != nil {
		t.Fatalf("can't open file: %s", err)
	}
	defer file.Close()
	gotText, err := io.ReadAll(file)
	if err != nil {
		t.Fatalf("can't read file: %s", err)
	}
	if diff := cmp.Diff(string(gotText), wantText); diff != "" {
		t.Fatalf("unexpected file contents (-got +want):\n%s", diff)
	}

	s := testSessionFor(t, cl)
	got := s.run(t, "ls -l "+filePath)

	// ls -l prints: mode, link count, owner, group, size, date..., name.
	fields := strings.Fields(got)
	if len(fields) < 4 {
		t.Fatalf("unexpected ls output: %q", got)
	}
	if owner := fields[2]; owner != "testuser" {
		t.Fatalf("unexpected file owner user %q: %s", owner, got)
	}

	// On Linux a newly created file takes the creating process's effective
	// group, i.e. testuser's primary group (assuming /tmp is not setgid).
	// BSD-derived systems such as macOS instead use the parent directory's
	// group, so the group is only checked on Linux.
	if runtime.GOOS == "linux" {
		wantGroup := strings.TrimSpace(testSessionFor(t, cl).run(t, "id -gn"))
		if group := fields[3]; group != wantGroup {
			t.Fatalf("unexpected file owner group %q, want %q: %s", group, wantGroup, got)
		}
	}
}

// session wraps an *ssh.Session together with the pipe ends that testSessionFor
// wires to the remote command's stdin, stdout and stderr. The embedded
// *ssh.Session supplies methods such as Start, RequestPty and Close.
type session struct {
	*ssh.Session

	stdin  io.WriteCloser // write end of the pipe set as the session's Stdin
	stdout io.ReadCloser  // read end of the pipe that receives the session's Stdout
	stderr io.ReadCloser  // read end of the pipe that receives the session's Stderr
}

// run starts cmdString on the session and returns what the command writes to
// stdout. It blocks until the first output arrives. It then keeps collecting
// until stdout has been idle for 25ms or has ended.
//
// An ssh.Session can only be started once, so run can be called at most once
// per session.
func (s *session) run(t *testing.T, cmdString string) string {
	t.Helper()

	if err := s.Start(cmdString); err != nil {
		t.Fatalf("unable to start command: %s", err)
	}

	// ch carries chunks of stdout. The reader closes it when stdout returns
	// any error, including io.EOF.
	ch := make(chan []byte)
	// done is closed when run returns. After that the reader keeps draining
	// (and discarding) stdout so writers into the pipe are never blocked.
	done := make(chan struct{})
	defer close(done)

	go func() {
		defer close(ch)
		buf := make([]byte, 4096)
		for {
			n, err := s.stdout.Read(buf)
			if n > 0 {
				chunk := make([]byte, n)
				copy(chunk, buf[:n])
				select {
				case ch <- chunk:
				case <-done:
				}
			}
			if err != nil {
				return
			}
		}
	}()

	// Block until the command produces its first output (or stdout ends).
	got, ok := <-ch
	if !ok {
		return ""
	}

	// Then collect until stdout has been idle for 25ms.
	for {
		select {
		case b, ok := <-ch:
			if !ok {
				return string(got)
			}
			got = append(got, b...)
		case <-time.After(25 * time.Millisecond):
			return string(got)
		}
	}
}

// testClient starts the in-process SSH server, backed by a testBackend for
// user "testuser", on a loopback listener. It connects to the server and
// returns the client. The server accepts exactly one connection and sees it
// through addressFakingConn. The listener and client are closed when the test
// finishes.
func testClient(t *testing.T) *ssh.Client {
	t.Helper()

	username := "testuser"
	srv := &server{
		lb:             &testBackend{localUser: username},
		logf:           log.Printf,
		tailscaledPath: os.Getenv("TAILSCALED_PATH"),
		timeNow:        time.Now,
	}

	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { l.Close() })

	go func() {
		conn, err := l.Accept()
		if err == nil {
			go srv.HandleSSHConn(&addressFakingConn{conn})
		}
	}()

	cl, err := ssh.Dial("tcp", l.Addr().String(), &ssh.ClientConfig{
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	})
	if err != nil {
		// t.Fatalf (not log.Fatal) so cleanups run and only this test fails.
		t.Fatalf("unable to dial SSH server: %s", err)
	}
	t.Cleanup(func() { cl.Close() })

	return cl
}

// testSession connects a fresh client to a new in-process SSH server (see
// testClient) and opens a single session on it.
func testSession(t *testing.T) *session {
	t.Helper()

	cl := testClient(t)
	return testSessionFor(t, cl)
}

// testSessionFor opens a new session on cl and connects its stdin, stdout and
// stderr to in-memory pipes exposed through the returned *session. Stdout and
// stderr are also copied to this process's stdout and stderr. When the test
// finishes, the pipes' write ends and then the session are closed.
func testSessionFor(t *testing.T, cl *ssh.Client) *session {
	t.Helper()

	s, err := cl.NewSession()
	if err != nil {
		// t.Fatalf (not log.Fatal) so cleanups run and only this test fails.
		t.Fatalf("unable to create session: %s", err)
	}
	t.Cleanup(func() { s.Close() })

	stdinReader, stdinWriter := io.Pipe()
	stdoutReader, stdoutWriter := io.Pipe()
	stderrReader, stderrWriter := io.Pipe()
	// Closing the write ends gives EOF to anything still reading the pipes.
	// It also releases any write blocked on a pipe nobody reads; nothing in
	// this file reads stderr.
	t.Cleanup(func() {
		stdinWriter.Close()
		stdoutWriter.Close()
		stderrWriter.Close()
	})

	s.Stdin = stdinReader
	s.Stdout = io.MultiWriter(stdoutWriter, os.Stdout)
	s.Stderr = io.MultiWriter(stderrWriter, os.Stderr)
	return &session{
		Session: s,
		stdin:   stdinWriter,
		stdout:  stdoutReader,
		stderr:  stderrReader,
	}
}

// testBackend implements ipnLocalBackend for the integration tests. Its
// NetMap policy accepts every principal and maps every requested SSH user to
// localUser.
type testBackend struct {
	// localUser is the local account every SSH user is mapped to (NetMap's
	// SSHUsers). It is also the prefix of the login name reported by WhoIs.
	localUser string
}

// GetSSH_HostKeys generates new host keys on every call and returns them as
// signers, in this order: ed25519, ECDSA P-256, and 2048-bit RSA. Each key is
// round-tripped through PKCS#8 PEM so it is parsed the same way a key file
// would be.
func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) {
	generators := []func() (any, error){
		func() (any, error) {
			_, pk, err := ed25519.GenerateKey(rand.Reader)
			return pk, err
		},
		func() (any, error) {
			return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		},
		func() (any, error) {
			return rsa.GenerateKey(rand.Reader, 2048)
		},
	}

	signers := make([]gossh.Signer, 0, len(generators))
	for _, generate := range generators {
		pk, err := generate()
		if err != nil {
			return nil, err
		}
		der, err := x509.MarshalPKCS8PrivateKey(pk)
		if err != nil {
			return nil, err
		}
		pemBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
		signer, err := gossh.ParsePrivateKey(pemBytes)
		if err != nil {
			return nil, err
		}
		signers = append(signers, signer)
	}
	return signers, nil
}

// ShouldRunSSH always reports true, so the SSH server is enabled unconditionally.
func (tb *testBackend) ShouldRunSSH() bool {
	const enabled = true
	return enabled
}

// NetMap returns a freshly built network map whose SSH policy has one rule.
// The rule accepts any principal and maps every requested SSH user ("*") to
// tb.localUser.
func (tb *testBackend) NetMap() *netmap.NetworkMap {
	acceptAll := &tailcfg.SSHRule{
		Principals: []*tailcfg.SSHPrincipal{{Any: true}},
		Action:     &tailcfg.SSHAction{Accept: true},
		SSHUsers:   map[string]string{"*": tb.localUser},
	}
	policy := &tailcfg.SSHPolicy{
		Rules: []*tailcfg.SSHRule{acceptAll},
	}
	return &netmap.NetworkMap{SSHPolicy: policy}
}

// WhoIs ignores ipp and always succeeds. It returns a view of an empty node
// and a profile whose login name is tb.localUser + "@example.com".
func (tb *testBackend) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
	node := &tailcfg.Node{}
	profile := tailcfg.UserProfile{LoginName: tb.localUser + "@example.com"}
	return node.View(), profile, true
}

// DoNoiseRequest is not supported by testBackend and always returns an error.
// Returning (nil, nil) would let a caller that checks only err dereference a
// nil *http.Response.
func (tb *testBackend) DoNoiseRequest(req *http.Request) (*http.Response, error) {
	return nil, fmt.Errorf("testBackend: DoNoiseRequest is not supported")
}

// Dialer returns a nil *tsdial.Dialer; testBackend provides no dialer.
func (tb *testBackend) Dialer() *tsdial.Dialer {
	var d *tsdial.Dialer
	return d
}

// TailscaleVarRoot returns the empty string; testBackend has no var root.
func (tb *testBackend) TailscaleVarRoot() string {
	var root string
	return root
}

// NodeKey returns the zero key.NodePublic.
func (tb *testBackend) NodeKey() key.NodePublic {
	var k key.NodePublic
	return k
}

type addressFakingConn struct {
	net.Conn
}

func (conn *addressFakingConn) LocalAddr() net.Addr {
	return &net.TCPAddr{
		IP:   net.ParseIP("100.100.100.101"),
		Port: 22,
	}
}

func (conn *addressFakingConn) RemoteAddr() net.Addr {
	return &net.TCPAddr{
		IP:   net.ParseIP("100.100.100.102"),
		Port: 10002,
	}
}