// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

//go:build !plan9

package ws

import (
	"encoding/binary"
	"fmt"
	"reflect"
	"testing"
	"time"

	"math/rand"

	"go.uber.org/zap"
	"golang.org/x/net/websocket"
)

// Test_msg_Parse feeds hand-built websocket frames to message.Parse and checks
// the resulting finalized flag (FIN bit), stream ID (first payload byte) and
// payload, including frames with a mask, with 16-bit and 64-bit extended
// payload lengths, and continuation fragments whose payload is appended to a
// payload accumulated from earlier fragments.
func Test_msg_Parse(t *testing.T) {
	zl, err := zap.NewDevelopment()
	if err != nil {
		t.Fatalf("error creating a test logger: %v", err)
	}
	testMask := [4]byte{1, 2, 3, 4}
	bs126, bs126Len := bytesSlice2ByteLen(t)
	bs127, bs127Len := byteSlice8ByteLen(t)
	tests := []struct {
		name            string
		b               []byte
		initialPayload  []byte
		wantPayload     []byte
		wantIsFinalized bool
		wantStreamID    uint32
		wantErr         bool
	}{
		{
			name:            "single_fragment_stdout_stream_no_payload_no_mask",
			b:               []byte{0x82, 0x1, 0x1},
			wantPayload:     nil,
			wantIsFinalized: true,
			wantStreamID:    1,
		},
		{
			name:            "single_fragment_stderr_stream_no_payload_has_mask",
			b:               append([]byte{0x82, 0x81, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x2})...),
			wantPayload:     nil,
			wantIsFinalized: true,
			wantStreamID:    2,
		},
		{
			name:            "single_fragment_stdout_stream_no_mask_has_payload",
			b:               []byte{0x82, 0x3, 0x1, 0x7, 0x8},
			wantPayload:     []byte{0x7, 0x8},
			wantIsFinalized: true,
			wantStreamID:    1,
		},
		{
			name:            "single_fragment_stdout_stream_has_mask_has_payload",
			b:               append([]byte{0x82, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
			wantPayload:     []byte{0x7, 0x8},
			wantIsFinalized: true,
			wantStreamID:    1,
		},
		{
			name:         "initial_fragment_stdout_stream_no_mask_has_payload",
			b:            []byte{0x2, 0x3, 0x1, 0x7, 0x8},
			wantPayload:  []byte{0x7, 0x8},
			wantStreamID: 1,
		},
		{
			name:         "initial_fragment_stdout_stream_has_mask_has_payload",
			b:            append([]byte{0x2, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
			wantPayload:  []byte{0x7, 0x8},
			wantStreamID: 1,
		},
		{
			name:           "subsequent_fragment_stdout_stream_no_mask_has_payload",
			b:              []byte{0x0, 0x3, 0x1, 0x7, 0x8},
			initialPayload: []byte{0x1, 0x2, 0x3},
			wantPayload:    []byte{0x1, 0x2, 0x3, 0x7, 0x8},
			wantStreamID:   1,
		},
		{
			name:           "subsequent_fragment_stdout_stream_has_mask_has_payload",
			b:              append([]byte{0x0, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
			initialPayload: []byte{0x1, 0x2, 0x3},
			wantPayload:    []byte{0x1, 0x2, 0x3, 0x7, 0x8},
			wantStreamID:   1,
		},
		{
			name:            "final_fragment_stdout_stream_no_mask_has_payload",
			b:               []byte{0x80, 0x3, 0x1, 0x7, 0x8},
			initialPayload:  []byte{0x1, 0x2, 0x3},
			wantIsFinalized: true,
			wantPayload:     []byte{0x1, 0x2, 0x3, 0x7, 0x8},
			wantStreamID:    1,
		},
		{
			name:            "final_fragment_stdout_stream_has_mask_has_payload",
			b:               append([]byte{0x80, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
			initialPayload:  []byte{0x1, 0x2, 0x3},
			wantIsFinalized: true,
			wantPayload:     []byte{0x1, 0x2, 0x3, 0x7, 0x8},
			wantStreamID:    1,
		},
		// NOTE(review): the two large-fragment cases below use 0x80 (FIN +
		// continuation opcode) rather than 0x82 (FIN + binary opcode) even
		// though they are named "single"; confirm this is intended.
		{
			name:            "single_large_fragment_no_mask_length_hint_126",
			b:               append(append([]byte{0x80, 0x7e}, bs126Len...), append([]byte{0x1}, bs126...)...),
			wantIsFinalized: true,
			wantPayload:     bs126,
			wantStreamID:    1,
		},
		{
			name:            "single_large_fragment_no_mask_length_hint_127",
			b:               append(append([]byte{0x80, 0x7f}, bs127Len...), append([]byte{0x1}, bs127...)...),
			wantIsFinalized: true,
			wantPayload:     bs127,
			wantStreamID:    1,
		},
		{
			name:    "zero_length_bytes",
			b:       []byte{},
			wantErr: true,
		},
	}

	// describePayload keeps failure output readable: the length-hint-127 case
	// can carry a payload of many megabytes, which must not be dumped in full.
	// nil is reported distinctly because reflect.DeepEqual distinguishes a nil
	// slice from an empty one.
	describePayload := func(b []byte) string {
		const limit = 16
		switch {
		case b == nil:
			return "nil"
		case len(b) <= limit:
			return fmt.Sprintf("%v (len %d)", b, len(b))
		default:
			return fmt.Sprintf("%v... (len %d)", b[:limit], len(b))
		}
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			msg := &message{
				typ:     binaryMessage,
				payload: tt.initialPayload,
			}
			if _, err := msg.Parse(tt.b, zl.Sugar()); (err != nil) != tt.wantErr {
				t.Errorf("msg.Parse() = %v, wantsErr: %t", err, tt.wantErr)
			}
			if msg.isFinalized != tt.wantIsFinalized {
				t.Errorf("wants message to be finalized: %t, got: %t", tt.wantIsFinalized, msg.isFinalized)
			}
			if msg.streamID.Load() != tt.wantStreamID {
				t.Errorf("wants stream ID: %d, got: %d", tt.wantStreamID, msg.streamID.Load())
			}
			if !reflect.DeepEqual(msg.payload, tt.wantPayload) {
				t.Errorf("unexpected message payload after Parse, wants %s got %s", describePayload(tt.wantPayload), describePayload(msg.payload))
			}
		})
	}
}

// Test_msg_Parse_Rand calls Parse with randomly generated inputs to verify
// that it doesn't panic. The random seed is logged and included in the
// failure message so that a panicking input can be regenerated.
func Test_msg_Parse_Rand(t *testing.T) {
	zl, err := zap.NewDevelopment()
	if err != nil {
		t.Fatalf("error creating a test logger: %v", err)
	}
	seed := time.Now().UnixNano()
	t.Logf("Test_msg_Parse_Rand: creating random source with seed %v", seed)
	r := rand.New(rand.NewSource(seed))
	for i := range 100 {
		n := r.Intn(4096)
		b := make([]byte, n)
		if _, err := r.Read(b); err != nil {
			t.Fatalf("error generating random byte slice: %v", err)
		}
		msg := message{typ: binaryMessage}
		f := func() {
			// The result is deliberately ignored: random bytes are usually not a
			// valid frame, and this test only checks that Parse does not panic.
			msg.Parse(b, zl.Sugar())
		}
		// Report the seed rather than the *rand.Rand itself; formatting the
		// generator only prints opaque pointers, while the seed plus the
		// iteration index is enough to reproduce the input.
		testPanic(t, f, fmt.Sprintf("[%d] Parse panicked running with byte slice of length %d (random seed %d)", i, n, seed))
	}
}

// bytesSlice2ByteLen generates test input for a websocket frame whose payload
// length uses the 126 length hint, i.e. a 2-byte extended payload length field.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
//
// It returns two slices: first, n arbitrary payload bytes; second, the 2-byte
// big-endian encoding of n+1, where the extra byte accounts for the 1-byte
// stream ID that callers prepend to the payload.
//
// n+1 is kept within [126, 65535]. RFC 6455 requires the minimal length
// encoding, so lengths below 126 must not use the 126 hint. The lower bound
// also guarantees a non-empty payload.
func bytesSlice2ByteLen(t *testing.T) ([]byte, []byte) {
	t.Helper()
	const (
		minLen = 126   // smallest frame payload length that may use the 2-byte encoding
		maxLen = 65535 // largest value representable in the 2-byte length field
	)
	seed := time.Now().UnixNano()
	t.Logf("bytesSlice2ByteLen: creating random source with seed %v", seed)
	r := rand.New(rand.NewSource(seed))
	// Draw the length from the seeded source, not the package-level one, so
	// that the logged seed reproduces the whole input. n excludes the stream ID
	// byte, so it lies in [minLen-1, maxLen-1].
	n := uint16(minLen - 1 + r.Intn(maxLen-minLen+1))
	b := make([]byte, n)
	if _, err := r.Read(b); err != nil {
		t.Fatalf("error generating random byte slice: %v ", err)
	}
	bb := make([]byte, 2)
	binary.BigEndian.PutUint16(bb, n+1) // + stream ID
	return b, bb
}

// byteSlice8ByteLen generates test input for a websocket frame whose payload
// length uses the 127 length hint, i.e. an 8-byte extended payload length field.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
//
// It returns two slices: first, n arbitrary payload bytes; second, the 8-byte
// big-endian encoding of n+1, where the extra byte accounts for the 1-byte
// stream ID that callers prepend to the payload.
//
// n+1 is kept within [65536, websocket.DefaultMaxPayloadBytes-1]. RFC 6455
// requires the minimal length encoding, so lengths that fit in the 2-byte
// field must not use the 127 hint.
func byteSlice8ByteLen(t *testing.T) ([]byte, []byte) {
	t.Helper()
	const minLen = 65536 // smallest frame payload length that requires the 8-byte encoding
	nanos := time.Now().UnixNano()
	t.Logf("Creating random source with seed %v", nanos)
	r := rand.New(rand.NewSource(nanos))
	// Draw the length from the seeded source, not the package-level one, so
	// that the logged seed reproduces the generated length as well as the bytes.
	// n excludes the stream ID byte.
	n := uint64(minLen - 1 + r.Intn(websocket.DefaultMaxPayloadBytes-minLen))
	t.Logf("byteSlice8ByteLen: generating message payload of length %d", n)
	b := make([]byte, n)
	if _, err := r.Read(b); err != nil {
		t.Fatalf("error generating random byte slice: %v ", err)
	}
	bb := make([]byte, 8)
	binary.BigEndian.PutUint64(bb, n+1) // + stream ID
	return b, bb
}

// maskedBytes runs maskBytes over b with the given masking key and hands the
// same slice back. This lets a masked payload be written inline in a test
// table. b is presumably masked in place by maskBytes (its result is not
// used); callers should pass a slice they own.
func maskedBytes(mask [4]byte, b []byte) []byte {
	out := b
	maskBytes(mask, out)
	return out
}