// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Package race contains a helper to "race" two functions, returning the first
// successful result. The second function is started after a configurable
// delay, or immediately if the first function returns an error, so a failing
// primary does not leave the caller waiting out the full delay.
package race

import (
	"context"
	"errors"
	"time"
)

// resultType records which of the two raced functions produced a result.
type resultType int

// Origins of a queryResult.
const (
	first  resultType = 0 // produced by func1, which starts immediately
	second resultType = 1 // produced by func2, the delayed fallback
)

// queryResult is an internal type for storing the result of a function call.
// Each raced goroutine sends exactly one of these back to Start.
//
// Start builds values of this type with unkeyed composite literals, so the
// field order (ty, res, err) is significant and must not change.
type queryResult[T any] struct {
	ty  resultType // which function (first or second) produced this result
	res T          // the returned value; only used by Start when err is nil
	err error      // the returned error, or ctx.Err() if func2 was never called
}

// Func is the signature of a function that can take part in a race. It
// receives a context derived from the one passed to Start and returns either
// a result or an error.
type Func[T any] func(ctx context.Context) (T, error)

// Race allows running two functions concurrently and returning the first
// non-error result returned. func1 starts immediately; func2 is held back
// until the configured delay elapses or func1 fails.
//
// A Race holds no per-call state, so Start may be called more than once,
// including concurrently.
type Race[T any] struct {
	// func1 is the primary function; func2 is the fallback.
	func1, func2 Func[T]
	// d is how long func2 is delayed unless func1 fails first.
	d time.Duration
}

// New creates a new Race that, when Start is called, will immediately call
// func1 to obtain a result. After the timeout d, or as soon as func1 returns
// an error, func2 will be called.
func New[T any](d time.Duration, func1, func2 Func[T]) *Race[T] {
	return &Race[T]{
		func1: func1,
		func2: func2,
		d:     d,
	}
}

// Start will start the "race" process, returning the first non-error result
// or, if both functions fail, the zero value of T and their errors joined with
// errors.Join in the order they were received.
//
// Start returns as soon as either function succeeds; otherwise it waits for
// both. The context given to func1 and func2 is derived from ctx and is
// canceled when Start returns. If that context is done before func2 has been
// started, func2 is not called and the context's error is reported in its
// place.
func (rh *Race[T]) Start(ctx context.Context) (T, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// These channels are per call rather than stored on the Race so that:
	//   - a losing goroutine from an earlier call can never deliver a stale
	//     result into a later call;
	//   - startFallback is never closed twice, which would panic;
	//   - concurrent calls do not see each other's results.
	// results has room for both sends, so neither goroutine blocks (and
	// leaks) after Start has returned.
	results := make(chan queryResult[T], 2)
	startFallback := make(chan struct{})

	// func1 is started immediately.
	go func() {
		ret, err := rh.func1(ctx)
		results <- queryResult[T]{first, ret, err}
	}()

	// func2 is started after the delay, or earlier if func1 fails.
	go func() {
		wait := time.NewTimer(rh.d)
		defer wait.Stop()

		// Wait for our timeout, trigger, or context to finish.
		select {
		case <-ctx.Done():
			// Start has returned or the caller gave up; don't call func2.
			var zero T
			results <- queryResult[T]{second, zero, ctx.Err()}
			return
		case <-startFallback:
		case <-wait.C:
		}

		ret, err := rh.func2(ctx)
		results <- queryResult[T]{second, ret, err}
	}()

	// Exactly two results will arrive, one per goroutine.
	var errs []error
	for range 2 {
		res := <-results
		if res.err == nil {
			// Got a valid response! Return it.
			return res.res, nil
		}

		// Store the error and hope the other function gives us something.
		errs = append(errs, res.err)

		// If func1 failed, release func2 now rather than waiting out the
		// delay. At most one result per call has ty == first, so this close
		// runs at most once per channel.
		if res.ty == first {
			close(startFallback)
		}
	}

	// Both raced functions failed.
	var zero T
	return zero, errors.Join(errs...)
}