// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package rsop

import (
	"errors"
	"sync"
	"sync/atomic"

	"tailscale.com/util/syspolicy/internal"
	"tailscale.com/util/syspolicy/setting"
	"tailscale.com/util/syspolicy/source"
)

// ErrAlreadyConsumed is the error returned by [StoreRegistration.ReplaceStore]
// and [StoreRegistration.Unregister] when the registration has already been
// consumed by an earlier successful call to either of those methods.
// A call that fails with a different error does not consume the registration.
var ErrAlreadyConsumed = errors.New("the store registration is no longer valid")

// StoreRegistration is a [source.Store] registered for use in the specified scope.
// It can be used to unregister the store, or replace it with another one.
//
// A registration is single-use: once [StoreRegistration.ReplaceStore] or
// [StoreRegistration.Unregister] succeeds, further calls to either method
// return [ErrAlreadyConsumed].
//
// It contains a [sync.Mutex] and therefore must not be copied after first use.
type StoreRegistration struct {
	// source is the policy source that was created for the registered store.
	source   *source.Source
	m        sync.Mutex  // protects the [StoreRegistration.consumeSlow] path
	consumed atomic.Bool // can be read without holding m, but must be written with m held
}

// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope].
//
// On success, it returns a [StoreRegistration] that can later be used to
// unregister the store or replace it with another one. If the underlying
// source cannot be registered, it returns a nil registration and the error.
func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
	return newStoreRegistration(name, scope, store)
}

// RegisterStoreForTest behaves like [RegisterStore], and additionally arranges
// for the store to be unregistered once tb and all of its subtests complete.
//
// If registration fails, no cleanup is scheduled and the result of
// [RegisterStore] is returned unchanged. During cleanup, [ErrAlreadyConsumed]
// is tolerated, so a test that has already consumed the registration does not
// fail; any other unregistration error is reported via tb.Fatalf.
func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
	reg, err := RegisterStore(name, scope, store)
	if err != nil {
		// Nothing was registered, so there is nothing to clean up.
		return reg, err
	}
	tb.Cleanup(func() {
		unregErr := reg.Unregister()
		if unregErr == nil || errors.Is(unregErr, ErrAlreadyConsumed) {
			return
		}
		tb.Fatalf("Unregister failed: %v", unregErr)
	})
	return reg, nil
}

// newStoreRegistration wraps store in a new [source.Source] with the given
// name and scope, registers that source, and returns a [StoreRegistration]
// referring to it. If registering the source fails, it returns a nil
// registration and the error.
func newStoreRegistration(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
	src := source.NewSource(name, scope, store)
	err := registerSource(src)
	if err != nil {
		return nil, err
	}
	reg := &StoreRegistration{source: src}
	return reg, nil
}

// ReplaceStore swaps the store behind r for new, keeping the name and scope
// of the existing source.
//
// On success, r is consumed and a fresh [StoreRegistration] for the new store
// is returned. If r has already been consumed, it returns a nil registration
// and [ErrAlreadyConsumed]. If the replacement itself fails, it returns a nil
// registration and that error, and r is not consumed.
func (r *StoreRegistration) ReplaceStore(new source.Store) (*StoreRegistration, error) {
	var replacement *StoreRegistration
	swap := func() error {
		next := source.NewSource(r.source.Name(), r.source.Scope(), new)
		err := replaceSource(r.source, next)
		if err == nil {
			// Only hand out a new registration once the swap has succeeded.
			replacement = &StoreRegistration{source: next}
		}
		return err
	}
	err := r.consume(swap)
	return replacement, err
}

// Unregister reverts the registration by unregistering the underlying source.
//
// On success, r is consumed. If r has already been consumed, it returns
// [ErrAlreadyConsumed]. If unregistering the source fails, that error is
// returned and r is not consumed, so the call may be retried.
func (r *StoreRegistration) Unregister() error {
	return r.consume(func() error { return unregisterSource(r.source) })
}

// consume runs fn and, if fn returns nil, marks r as consumed.
// Once r has been consumed, every later call returns [ErrAlreadyConsumed]
// without invoking fn.
//
// The check here is performed without taking r.m; all remaining work is
// delegated to [StoreRegistration.consumeSlow], which re-checks with r.m held.
func (r *StoreRegistration) consume(fn func() error) (err error) {
	if alreadyUsed := r.consumed.Load(); !alreadyUsed {
		return r.consumeSlow(fn)
	}
	return ErrAlreadyConsumed
}

// consumeSlow is the locked part of [StoreRegistration.consume].
// With r.m held, it re-checks whether r has been consumed, returning
// [ErrAlreadyConsumed] if so. Otherwise it invokes fn and marks r as consumed
// only when fn returns nil; a non-nil error from fn is returned as is and
// leaves r unconsumed.
func (r *StoreRegistration) consumeSlow(fn func() error) (err error) {
	r.m.Lock()
	defer r.m.Unlock()

	// Another caller may have consumed r between the unlocked check and
	// acquiring the lock.
	if r.consumed.Load() {
		return ErrAlreadyConsumed
	}
	if err = fn(); err != nil {
		return err
	}
	r.consumed.Store(true)
	return nil
}