// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package speedtest

import (
	"crypto/rand"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"time"
)

// Serve accepts connections from l and runs a speed test on each of them.
// Connections are handled one at a time on the calling goroutine: the next
// connection is not accepted until the current one has been fully handled.
//
// Serve returns nil once the listener has been closed. It returns a non-nil
// error as soon as accepting a connection fails for any other reason, or as
// soon as handling a connection reports an error.
func Serve(l net.Listener) error {
	for {
		conn, acceptErr := l.Accept()
		switch {
		case errors.Is(acceptErr, net.ErrClosed):
			// The listener was shut down; this is the normal way to stop serving.
			return nil
		case acceptErr != nil:
			return acceptErr
		}

		if testErr := handleConnection(conn); testErr != nil {
			return testErr
		}
	}
}

// handleConnection performs the initial exchange with a client on conn and
// then runs the speed test. conn is always closed before it returns.
//
// It decodes a JSON config message from conn. If decoding fails, or if the
// client's version differs from the server's, the error text is sent back to
// the client in a configResponse and the error is returned. Otherwise an
// empty configResponse is sent to acknowledge the config and doTest is run
// with the client's direction reversed.
//
// The returned error is the first failure encountered: a decode error, a
// version mismatch, a failure to send the acknowledgement, or an error from
// doTest.
func handleConnection(conn net.Conn) error {
	defer conn.Close()
	var conf config

	decoder := json.NewDecoder(conn)
	err := decoder.Decode(&conf)
	encoder := json.NewEncoder(conn)

	// Both return and encode errors that occurred before the test started.
	if err != nil {
		// Best effort: the decode error is what gets reported to the caller,
		// so a failure to notify the client is deliberately ignored.
		_ = encoder.Encode(configResponse{Error: err.Error()})
		return err
	}

	// The server should always be doing the opposite of what the client is doing.
	conf.Direction.Reverse()

	if conf.Version != version {
		err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version)
		// Best effort, as above: the mismatch itself is the error returned.
		_ = encoder.Encode(configResponse{Error: err.Error()})
		return err
	}

	// Acknowledge the config. If the acknowledgement cannot be sent, the
	// client never learns that the test is starting, so do not run it.
	if err := encoder.Encode(configResponse{}); err != nil {
		return fmt.Errorf("sending config response: %w", err)
	}

	// Start the test
	_, err = doTest(conn, conf)
	return err
}

// TODO: include code to detect whether the connection is direct vs DERP

// doTest runs one direction of a speed test over conn and returns the
// measured results. conf.Direction selects what is done:
//
//   - Download: blocks of blockSize bytes are read from conn until the peer
//     closes the connection. A read deadline of conf.TestDuration plus five
//     seconds is set on conn before reading starts.
//   - Otherwise: a buffer of blockSize random bytes is written to conn
//     repeatedly. For Upload, writing stops once conf.TestDuration has
//     elapsed since the start of the test.
//
// A Result is recorded each time at least increment has elapsed since the
// previous one. After the loop, a Result for the trailing partial interval
// and a final Result with Total set (covering the whole test) are appended,
// each only if it spans more than minInterval.
//
// On failure doTest returns nil results and a non-nil error.
func doTest(conn net.Conn, conf config) ([]Result, error) {
	bufferData := make([]byte, blockSize)

	intervalBytes := 0
	totalBytes := 0

	// currentTime is deliberately left as the zero time until the first
	// block has been transferred. If the loop ends before that, the Sub
	// calls below saturate at the minimum Duration and no Result is emitted.
	var currentTime time.Time
	var results []Result

	if conf.Direction == Download {
		deadline := time.Now().Add(conf.TestDuration).Add(5 * time.Second)
		if err := conn.SetReadDeadline(deadline); err != nil {
			// Without a deadline a stalled peer could block the reads
			// below indefinitely, so do not start the test.
			return nil, fmt.Errorf("setting read deadline: %w", err)
		}
	} else {
		// Fill the outgoing block with random data once; it is reused for
		// every write.
		if _, err := rand.Read(bufferData); err != nil {
			return nil, err
		}
	}

	startTime := time.Now()
	lastCalculated := startTime

SpeedTestLoop:
	for {
		var n int
		var err error
		// finished is set when this iteration transferred the last bytes
		// of the test and the loop must stop after accounting for them.
		finished := false

		if conf.Direction == Download {
			n, err = io.ReadFull(conn, bufferData)
			switch {
			case err == nil:
				// successful read of a full block
			case errors.Is(err, io.EOF):
				// The peer closed the connection and no bytes were read.
				break SpeedTestLoop
			case errors.Is(err, io.ErrUnexpectedEOF):
				// The peer closed the connection part-way through a block.
				// The n bytes that did arrive still count towards the result.
				finished = true
			default:
				return nil, fmt.Errorf("unexpected error has occurred: %w", err)
			}
		} else {
			n, err = conn.Write(bufferData)
			if err != nil {
				// If the write failed, there is most likely something wrong with the connection.
				return nil, fmt.Errorf("upload failed: %w", err)
			}
		}
		intervalBytes += n

		currentTime = time.Now()
		// Close the current interval once at least increment has elapsed
		// since the previous interval ended.
		if currentTime.Sub(lastCalculated) >= increment {
			results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
			lastCalculated = currentTime
			totalBytes += intervalBytes
			intervalBytes = 0
		}

		if finished {
			break SpeedTestLoop
		}
		if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration {
			break SpeedTestLoop
		}
	}

	// get last segment
	if currentTime.Sub(lastCalculated) > minInterval {
		results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false})
	}

	// get total
	totalBytes += intervalBytes
	if currentTime.Sub(startTime) > minInterval {
		results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true})
	}

	return results, nil
}