// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tstun

import (
	"bytes"
	"sync/atomic"
	"testing"
	"unsafe"

	"github.com/tailscale/wireguard-go/tun/tuntest"
	"tailscale.com/types/logger"
	"tailscale.com/wgengine/filter"
	"tailscale.com/wgengine/packet"
)

func udp(src, dst packet.IP, sport, dport uint16) []byte {
	header := &packet.UDPHeader{
		IPHeader: packet.IPHeader{
			SrcIP: src,
			DstIP: dst,
			IPID:  0,
		},
		SrcPort: sport,
		DstPort: dport,
	}
	return packet.Generate(header, []byte("udp_payload"))
}

func filterNet(ip, mask packet.IP) filter.Net {
	return filter.Net{IP: ip, Mask: mask}
}

func nets(ips []packet.IP) []filter.Net {
	out := make([]filter.Net, 0, len(ips))
	for _, ip := range ips {
		out = append(out, filterNet(ip, filter.Netmask(32)))
	}
	return out
}

func ippr(ip packet.IP, start, end uint16) []filter.NetPortRange {
	return []filter.NetPortRange{
		filter.NetPortRange{
			Net:   filterNet(ip, filter.Netmask(32)),
			Ports: filter.PortRange{First: start, Last: end},
		},
	}
}

func setfilter(logf logger.Logf, tun *TUN) {
	matches := filter.Matches{
		{Srcs: nets([]packet.IP{0x05060708}), Dsts: ippr(0x01020304, 89, 90)},
		{Srcs: nets([]packet.IP{0x01020304}), Dsts: ippr(0x05060708, 98, 98)},
	}
	localNets := []filter.Net{
		filterNet(packet.IP(0x01020304), filter.Netmask(16)),
	}
	tun.SetFilter(filter.New(matches, localNets, nil, logf))
}

func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *TUN) {
	chtun := tuntest.NewChannelTUN()
	tun := WrapTUN(logf, chtun.TUN())
	if secure {
		setfilter(logf, tun)
	} else {
		tun.disableFilter = true
	}
	return chtun, tun
}

func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *TUN) {
	ftun := NewFakeTUN()
	tun := WrapTUN(logf, ftun)
	if secure {
		setfilter(logf, tun)
	} else {
		tun.disableFilter = true
	}
	return ftun.(*fakeTUN), tun
}

func TestReadAndInject(t *testing.T) {
	chtun, tun := newChannelTUN(t.Logf, false)
	defer tun.Close()

	const size = 2 // all payloads have this size
	written := []string{"w0", "w1"}
	injected := []string{"i0", "i1"}

	go func() {
		for _, packet := range written {
			payload := []byte(packet)
			chtun.Outbound <- payload
		}
	}()

	for _, packet := range injected {
		go func(packet string) {
			payload := []byte(packet)
			err := tun.InjectOutbound(payload)
			if err != nil {
				t.Errorf("%s: error: %v", packet, err)
			}
		}(packet)
	}

	var buf [MaxPacketSize]byte
	var seen = make(map[string]bool)
	// We expect the same packets back, in no particular order.
	for i := 0; i < len(written)+len(injected); i++ {
		n, err := tun.Read(buf[:], 0)
		if err != nil {
			t.Errorf("read %d: error: %v", i, err)
		}
		if n != size {
			t.Errorf("read %d: got size %d; want %d", i, n, size)
		}
		got := string(buf[:n])
		t.Logf("read %d: got %s", i, got)
		seen[got] = true
	}

	for _, packet := range written {
		if !seen[packet] {
			t.Errorf("%s not received", packet)
		}
	}
	for _, packet := range injected {
		if !seen[packet] {
			t.Errorf("%s not received", packet)
		}
	}
}

func TestWriteAndInject(t *testing.T) {
	chtun, tun := newChannelTUN(t.Logf, false)
	defer tun.Close()

	const size = 2 // all payloads have this size
	written := []string{"w0", "w1"}
	injected := []string{"i0", "i1"}

	go func() {
		for _, packet := range written {
			payload := []byte(packet)
			n, err := tun.Write(payload, 0)
			if err != nil {
				t.Errorf("%s: error: %v", packet, err)
			}
			if n != size {
				t.Errorf("%s: got size %d; want %d", packet, n, size)
			}
		}
	}()

	for _, packet := range injected {
		go func(packet string) {
			payload := []byte(packet)
			err := tun.InjectInboundCopy(payload)
			if err != nil {
				t.Errorf("%s: error: %v", packet, err)
			}
		}(packet)
	}

	seen := make(map[string]bool)
	// We expect the same packets back, in no particular order.
	for i := 0; i < len(written)+len(injected); i++ {
		packet := <-chtun.Inbound
		got := string(packet)
		t.Logf("read %d: got %s", i, got)
		seen[got] = true
	}

	for _, packet := range written {
		if !seen[packet] {
			t.Errorf("%s not received", packet)
		}
	}
	for _, packet := range injected {
		if !seen[packet] {
			t.Errorf("%s not received", packet)
		}
	}
}

func TestFilter(t *testing.T) {
	chtun, tun := newChannelTUN(t.Logf, true)
	defer tun.Close()

	type direction int

	const (
		in direction = iota
		out
	)

	tests := []struct {
		name string
		dir  direction
		drop bool
		data []byte
	}{
		{"junk_in", in, true, []byte("\x45not a valid IPv4 packet")},
		{"junk_out", out, true, []byte("\x45not a valid IPv4 packet")},
		{"bad_port_in", in, true, udp(0x05060708, 0x01020304, 22, 22)},
		{"bad_port_out", out, false, udp(0x01020304, 0x05060708, 22, 22)},
		{"bad_ip_in", in, true, udp(0x08010101, 0x01020304, 89, 89)},
		{"bad_ip_out", out, false, udp(0x01020304, 0x08010101, 98, 98)},
		{"good_packet_in", in, false, udp(0x05060708, 0x01020304, 89, 89)},
		{"good_packet_out", out, false, udp(0x01020304, 0x05060708, 98, 98)},
	}

	// A reader on the other end of the TUN.
	go func() {
		var recvbuf []byte
		for {
			select {
			case <-tun.closed:
				return
			case recvbuf = <-chtun.Inbound:
				// continue
			}
			for _, tt := range tests {
				if tt.drop && bytes.Equal(recvbuf, tt.data) {
					t.Errorf("did not drop %s", tt.name)
				}
			}
		}
	}()

	var buf [MaxPacketSize]byte
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var n int
			var err error
			var filtered bool

			if tt.dir == in {
				_, err = tun.Write(tt.data, 0)
				if err == ErrFiltered {
					filtered = true
					err = nil
				}
			} else {
				chtun.Outbound <- tt.data
				n, err = tun.Read(buf[:], 0)
				// In the read direction, errors are fatal, so we return n = 0 instead.
				filtered = (n == 0)
			}

			if err != nil {
				t.Errorf("got err %v; want nil", err)
			}

			if filtered {
				if !tt.drop {
					t.Errorf("got drop; want accept")
				}
			} else {
				if tt.drop {
					t.Errorf("got accept; want drop")
				}
			}
		})
	}
}

func TestAllocs(t *testing.T) {
	ftun, tun := newFakeTUN(t.Logf, false)
	defer tun.Close()

	go func() {
		var buf []byte
		for {
			select {
			case <-tun.closed:
				return
			case buf = <-ftun.datachan:
				// continue
			}

			select {
			case <-tun.closed:
				return
			case ftun.datachan <- buf:
				// continue
			}
		}
	}()

	buf := []byte{0x00}
	allocs := testing.AllocsPerRun(100, func() {
		_, err := tun.Write(buf, 0)
		if err != nil {
			t.Errorf("write: error: %v", err)
			return
		}

		_, err = tun.Read(buf, 0)
		if err != nil {
			t.Errorf("read: error: %v", err)
			return
		}

	})

	if allocs > 0 {
		t.Errorf("read allocs = %v; want 0", allocs)
	}
}

func BenchmarkWrite(b *testing.B) {
	ftun, tun := newFakeTUN(b.Logf, true)
	defer tun.Close()

	go func() {
		for {
			select {
			case <-tun.closed:
				return
			case <-ftun.datachan:
				// continue
			}
		}
	}()

	packet := udp(0x05060708, 0x01020304, 89, 89)
	for i := 0; i < b.N; i++ {
		_, err := tun.Write(packet, 0)
		if err != nil {
			b.Errorf("err = %v; want nil", err)
		}
	}
}

func BenchmarkRead(b *testing.B) {
	ftun, tun := newFakeTUN(b.Logf, true)
	defer tun.Close()

	packet := udp(0x05060708, 0x01020304, 89, 89)
	go func() {
		for {
			select {
			case <-tun.closed:
				return
			case ftun.datachan <- packet:
				// continue
			}
		}
	}()

	var buf [128]byte
	for i := 0; i < b.N; i++ {
		_, err := tun.Read(buf[:], 0)
		if err != nil {
			b.Errorf("err = %v; want nil", err)
		}
	}
}

func TestAtomic64Alignment(t *testing.T) {
	off := unsafe.Offsetof(TUN{}.lastActivityAtomic)
	if off%8 != 0 {
		t.Errorf("offset %v not 8-byte aligned", off)
	}

	c := new(TUN)
	atomic.StoreInt64(&c.lastActivityAtomic, 123)
}