// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build linux
// +build linux

package netstack

import (
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"os"
	"os/exec"
	"syscall"
	"unsafe"

	"github.com/creack/pty"
	"github.com/gliderlabs/ssh"
	gossh "golang.org/x/crypto/ssh"
	"inet.af/netaddr"
	"tailscale.com/envknob"
	"tailscale.com/net/tsaddr"
)

func init() {
	sshDemo = sshDemoImpl
}

func sshDemoImpl(ns *Impl, c net.Conn) error {
	hostKey, err := ioutil.ReadFile("/etc/ssh/ssh_host_ed25519_key")
	if err != nil {
		return err
	}
	signer, err := gossh.ParsePrivateKey(hostKey)
	if err != nil {
		return err
	}
	srv := &ssh.Server{
		Handler:           ns.handleSSH,
		RequestHandlers:   map[string]ssh.RequestHandler{},
		SubsystemHandlers: map[string]ssh.SubsystemHandler{},
		ChannelHandlers:   map[string]ssh.ChannelHandler{},
	}
	for k, v := range ssh.DefaultRequestHandlers {
		srv.RequestHandlers[k] = v
	}
	for k, v := range ssh.DefaultChannelHandlers {
		srv.ChannelHandlers[k] = v
	}
	for k, v := range ssh.DefaultSubsystemHandlers {
		srv.SubsystemHandlers[k] = v
	}
	srv.AddHostKey(signer)

	srv.HandleConn(c)
	return nil
}

func (ns *Impl) handleSSH(s ssh.Session) {
	lb := ns.lb
	user := s.User()
	addr := s.RemoteAddr()
	log.Printf("Handling SSH from %v for user %v", addr, user)
	ta, ok := addr.(*net.TCPAddr)
	if !ok {
		log.Printf("tsshd: rejecting non-TCP addr %T %v", addr, addr)
		s.Exit(1)
		return
	}
	tanetaddr, ok := netaddr.FromStdIP(ta.IP)
	if !ok {
		log.Printf("tsshd: rejecting unparseable addr %v", ta.IP)
		s.Exit(1)
		return
	}
	if !tsaddr.IsTailscaleIP(tanetaddr) {
		log.Printf("tsshd: rejecting non-Tailscale addr %v", ta.IP)
		s.Exit(1)
		return
	}

	ptyReq, winCh, isPty := s.Pty()
	if !isPty {
		fmt.Fprintf(s, "TODO scp etc")
		s.Exit(1)
		return
	}
	srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
	node, uprof, ok := lb.WhoIs(srcIPP)
	if !ok {
		fmt.Fprintf(s, "Hello, %v. I don't know who you are.\n", srcIPP)
		s.Exit(0)
		return
	}
	allow := envknob.String("TS_SSH_ALLOW_LOGIN")
	if allow == "" || uprof.LoginName != allow {
		log.Printf("ssh: access denied for %q (only allowing %q)", uprof.LoginName, allow)
		jnode, _ := json.Marshal(node)
		jprof, _ := json.Marshal(uprof)
		fmt.Fprintf(s, "Access denied.\n\nYou are node: %s\n\nYour profile: %s\n\nYou wanted %+v\n", jnode, jprof, ptyReq)
		s.Exit(1)
		return
	}

	var cmd *exec.Cmd
	sshUser := s.User()
	if os.Getuid() != 0 || sshUser == "root" {
		cmd = exec.Command("/bin/bash")
	} else {
		cmd = exec.Command("/usr/bin/env", "su", "-", sshUser)
	}
	cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
	f, err := pty.Start(cmd)
	if err != nil {
		log.Printf("running shell: %v", err)
		s.Exit(1)
		return
	}
	defer f.Close()
	go func() {
		for win := range winCh {
			setWinsize(f, win.Width, win.Height)
		}
	}()
	go func() {
		io.Copy(f, s) // stdin
	}()
	io.Copy(s, f) // stdout
	cmd.Process.Kill()
	if err := cmd.Wait(); err != nil {
		s.Exit(1)
	}
	s.Exit(0)
	return
}

func setWinsize(f *os.File, w, h int) {
	syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
}