// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Package tstest provides utilities for use in unit tests.
package tstest

import (
	"context"
	"os"
	"strconv"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"tailscale.com/envknob"
	"tailscale.com/logtail/backoff"
	"tailscale.com/types/logger"
	"tailscale.com/util/cibuild"
)

// Replace stores val in *target for the remainder of the test.
// A cleanup registered on t writes the previous value back once the
// test (and its subtests) finish.
// A nil target fails the test immediately.
func Replace[T any](t testing.TB, target *T, val T) {
	t.Helper()
	if target == nil {
		t.Fatalf("Replace: nil pointer")
		panic("unreachable") // pacify staticcheck
	}

	// Snapshot the current value before overwriting it so the cleanup
	// closure can put it back.
	saved := *target
	*target = val
	t.Cleanup(func() { *target = saved })
}

// WaitFor calls try repeatedly, backing off between failed attempts, for up
// to maxWait.
// It returns nil as soon as try returns nil.
// If maxWait passes without success, it returns try's last error.
//
// try is always called at least once, even when maxWait is zero or negative,
// so a nil result always means that try actually succeeded.
func WaitFor(maxWait time.Duration, try func() error) error {
	bo := backoff.NewBackoff("wait-for", logger.Discard, maxWait/4)
	deadline := time.Now().Add(maxWait)
	for {
		err := try()
		if err == nil {
			return nil
		}
		bo.BackOff(context.Background(), err)
		// The deadline is only consulted after an attempt has been made.
		// Checking it up front would let a non-positive maxWait return nil
		// without try ever having run.
		if !time.Now().Before(deadline) {
			return err
		}
	}
}

// testNum counts the calls to Shard in this process that reached the
// sharding decision (that is, calls made with a valid TS_TEST_SHARD).
var testNum atomic.Int32

// Shard skips t unless it belongs to the test shard selected by the
// TS_TEST_SHARD environment variable.
//
// TS_TEST_SHARD has the form "n/m", meaning "run shard n of m". Each call to
// Shard in the process is assigned a shard round-robin, in call order, and t
// is skipped if its assigned shard is not n. That is, to run with 4 shards,
// set TS_TEST_SHARD=1/4, ..., TS_TEST_SHARD=4/4 for the four jobs.
//
// If TS_TEST_SHARD is unset or is not a valid "n/m" with 1 <= n <= m, Shard
// does nothing and the test runs.
func Shard(t testing.TB) {
	t.Helper()
	e := os.Getenv("TS_TEST_SHARD")
	a, b, ok := strings.Cut(e, "/")
	if !ok {
		return
	}
	wantShard, err := strconv.ParseInt(a, 10, 32)
	if err != nil {
		return
	}
	shards, err := strconv.ParseInt(b, 10, 32)
	if err != nil {
		return
	}
	// Treat out-of-range values like any other malformed value: no sharding.
	// Otherwise a spec such as "5/4" or "-1/4" could never match an assigned
	// shard and would silently skip every test in the process.
	if wantShard < 1 || shards < 1 || wantShard > shards {
		return
	}

	// Shards are numbered from 1; the counter is zero-based.
	shard := ((testNum.Add(1) - 1) % int32(shards)) + 1
	if shard != int32(wantShard) {
		t.Skipf("skipping shard %d/%d (process has TS_TEST_SHARD=%q)", shard, shards, e)
	}
}

// SkipOnUnshardedCI skips t when cibuild.On reports true and the
// TS_TEST_SHARD environment variable is empty or unset.
func SkipOnUnshardedCI(t testing.TB) {
	if !cibuild.On() {
		return
	}
	if spec := os.Getenv("TS_TEST_SHARD"); spec == "" {
		t.Skip("skipping on CI without TS_TEST_SHARD")
	}
}

// serializeParallel is the TS_SERIAL_TESTS boolean knob, registered through
// envknob.RegisterBool. Parallel calls it to decide whether to skip
// t.Parallel.
var serializeParallel = envknob.RegisterBool("TS_SERIAL_TESTS")

// Parallel marks t as parallel via t.Parallel, except when the
// TS_SERIAL_TESTS knob is set true, in which case it does nothing.
func Parallel(t *testing.T) {
	if serializeParallel() {
		return
	}
	t.Parallel()
}