mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-22 08:51:41 +00:00

I added yet another one in 6d117d64a256234 but that new one is at the best place in the dependency graph and has the best name, so let's use that one for everything possible. types/lazy can't use it for circular dependency reasons, so unexport that copy at least. Updates #cleanup Change-Id: I25db6b6a0d81dbb8e89a0a9080c7f15cbf7aa770 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
450 lines
11 KiB
Go
450 lines
11 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package source
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
xmaps "golang.org/x/exp/maps"
|
|
"tailscale.com/util/mak"
|
|
"tailscale.com/util/set"
|
|
"tailscale.com/util/slicesx"
|
|
"tailscale.com/util/syspolicy/setting"
|
|
"tailscale.com/util/testenv"
|
|
)
|
|
|
|
var (
|
|
_ Store = (*TestStore)(nil)
|
|
_ Lockable = (*TestStore)(nil)
|
|
_ Changeable = (*TestStore)(nil)
|
|
_ Expirable = (*TestStore)(nil)
|
|
)
|
|
|
|
// TestValueType is the type constraint satisfied by the value types a
// [TestStore] can hold: booleans, unsigned 64-bit integers, strings and
// string lists.
type TestValueType interface {
	string | []string | bool | uint64
}
|
|
|
|
// TestSetting is a policy setting in a [TestStore].
// T is the setting's value type; it selects which [TestStore] setter
// (SetBooleans, SetUInt64s, SetStrings or SetStringLists) accepts the setting.
type TestSetting[T TestValueType] struct {
	// Key is the setting's unique identifier.
	// The [TestStore] setters reject an empty Key.
	Key setting.Key
	// Error is the error to be returned by the [TestStore] when reading
	// a policy setting with the specified key.
	// When non-nil, it is stored instead of Value.
	Error error
	// Value is the value to be returned by the [TestStore] when reading
	// a policy setting with the specified key.
	// It is only used if the Error is nil.
	Value T
}
|
|
|
|
// TestSettingOf returns a [TestSetting] representing a policy setting
|
|
// configured with the specified key and value.
|
|
func TestSettingOf[T TestValueType](key setting.Key, value T) TestSetting[T] {
|
|
return TestSetting[T]{Key: key, Value: value}
|
|
}
|
|
|
|
// TestSettingWithError returns a [TestSetting] representing a policy setting
|
|
// with the specified key and error.
|
|
func TestSettingWithError[T TestValueType](key setting.Key, err error) TestSetting[T] {
|
|
return TestSetting[T]{Key: key, Error: err}
|
|
}
|
|
|
|
// testReadOperation describes a single policy setting read operation.
// It is the map key of the read counters kept in [TestStore].
type testReadOperation struct {
	// Key is the setting's unique identifier.
	Key setting.Key
	// Type is the value type of the read operation:
	// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
	Type setting.Type
}
|
|
|
|
// TestExpectedReads describes how many times a policy setting with the
// specified key is expected to have been read as the specified type.
// It is passed to [TestStore.ReadsMustEqual] and [TestStore.ReadsMustContain].
type TestExpectedReads struct {
	// Key is the setting's unique identifier.
	Key setting.Key
	// Type is the value type of the read operation:
	// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
	Type setting.Type
	// NumTimes is how many times a setting with the specified key and type should have been read.
	NumTimes int
}
|
|
|
|
func (r TestExpectedReads) operation() testReadOperation {
|
|
return testReadOperation{r.Key, r.Type}
|
|
}
|
|
|
|
// TestStore is a [Store] that can be used in tests.
// It also implements [Lockable], [Changeable] and [Expirable].
type TestStore struct {
	tb testenv.TB // reports coding errors; NewTestStore registers Close with its Cleanup

	done chan struct{} // closed by [TestStore.Close]; returned by [TestStore.Done]

	// storeLock is write-locked by the setters, Delete, Clear and Resume.
	storeLock      sync.RWMutex // its RLock is exposed via [TestStore.Lock]/[TestStore.Unlock].
	storeLockCount atomic.Int32 // number of Lock calls not yet matched by Unlock

	// mu guards the fields below it in this group.
	mu           sync.RWMutex
	suspendCount int                 // change callbacks are suspended if > 0
	mr, mw       map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended.
	cbs          set.HandleSet[func()]
	closed       bool

	// readsMu guards reads.
	readsMu sync.Mutex
	reads   map[testReadOperation]int // how many times a policy setting was read
}
|
|
|
|
// NewTestStore returns a new [TestStore].
|
|
// The tb will be used to report coding errors detected by the [TestStore].
|
|
func NewTestStore(tb testenv.TB) *TestStore {
|
|
m := make(map[setting.Key]any)
|
|
store := &TestStore{
|
|
tb: tb,
|
|
done: make(chan struct{}),
|
|
mr: m,
|
|
mw: m,
|
|
}
|
|
tb.Cleanup(store.Close)
|
|
return store
|
|
}
|
|
|
|
// NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans],
|
|
// [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists].
|
|
func NewTestStoreOf[T TestValueType](tb testenv.TB, settings ...TestSetting[T]) *TestStore {
|
|
store := NewTestStore(tb)
|
|
switch settings := any(settings).(type) {
|
|
case []TestSetting[bool]:
|
|
store.SetBooleans(settings...)
|
|
case []TestSetting[uint64]:
|
|
store.SetUInt64s(settings...)
|
|
case []TestSetting[string]:
|
|
store.SetStrings(settings...)
|
|
case []TestSetting[[]string]:
|
|
store.SetStringLists(settings...)
|
|
}
|
|
return store
|
|
}
|
|
|
|
// Lock implements [Lockable].
|
|
func (s *TestStore) Lock() error {
|
|
s.storeLock.RLock()
|
|
s.storeLockCount.Add(1)
|
|
return nil
|
|
}
|
|
|
|
// Unlock implements [Lockable].
|
|
func (s *TestStore) Unlock() {
|
|
if s.storeLockCount.Add(-1) < 0 {
|
|
s.tb.Fatal("negative storeLockCount")
|
|
}
|
|
s.storeLock.RUnlock()
|
|
}
|
|
|
|
// RegisterChangeCallback implements [Changeable].
|
|
func (s *TestStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
handle := s.cbs.Add(callback)
|
|
return func() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
delete(s.cbs, handle)
|
|
}, nil
|
|
}
|
|
|
|
// ReadString implements [Store].
|
|
func (s *TestStore) ReadString(key setting.Key) (string, error) {
|
|
defer s.recordRead(key, setting.StringValue)
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
v, ok := s.mr[key]
|
|
if !ok {
|
|
return "", setting.ErrNotConfigured
|
|
}
|
|
if err, ok := v.(error); ok {
|
|
return "", err
|
|
}
|
|
str, ok := v.(string)
|
|
if !ok {
|
|
return "", fmt.Errorf("%w in ReadString: got %T", setting.ErrTypeMismatch, v)
|
|
}
|
|
return str, nil
|
|
}
|
|
|
|
// ReadUInt64 implements [Store].
|
|
func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) {
|
|
defer s.recordRead(key, setting.IntegerValue)
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
v, ok := s.mr[key]
|
|
if !ok {
|
|
return 0, setting.ErrNotConfigured
|
|
}
|
|
if err, ok := v.(error); ok {
|
|
return 0, err
|
|
}
|
|
u64, ok := v.(uint64)
|
|
if !ok {
|
|
return 0, fmt.Errorf("%w in ReadUInt64: got %T", setting.ErrTypeMismatch, v)
|
|
}
|
|
return u64, nil
|
|
}
|
|
|
|
// ReadBoolean implements [Store].
|
|
func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) {
|
|
defer s.recordRead(key, setting.BooleanValue)
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
v, ok := s.mr[key]
|
|
if !ok {
|
|
return false, setting.ErrNotConfigured
|
|
}
|
|
if err, ok := v.(error); ok {
|
|
return false, err
|
|
}
|
|
b, ok := v.(bool)
|
|
if !ok {
|
|
return false, fmt.Errorf("%w in ReadBoolean: got %T", setting.ErrTypeMismatch, v)
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
// ReadStringArray implements [Store].
|
|
func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) {
|
|
defer s.recordRead(key, setting.StringListValue)
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
v, ok := s.mr[key]
|
|
if !ok {
|
|
return nil, setting.ErrNotConfigured
|
|
}
|
|
if err, ok := v.(error); ok {
|
|
return nil, err
|
|
}
|
|
slice, ok := v.([]string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("%w in ReadStringArray: got %T", setting.ErrTypeMismatch, v)
|
|
}
|
|
return slice, nil
|
|
}
|
|
|
|
func (s *TestStore) recordRead(key setting.Key, typ setting.Type) {
|
|
s.readsMu.Lock()
|
|
op := testReadOperation{key, typ}
|
|
num := s.reads[op]
|
|
num++
|
|
mak.Set(&s.reads, op, num)
|
|
s.readsMu.Unlock()
|
|
}
|
|
|
|
func (s *TestStore) ResetCounters() {
|
|
s.readsMu.Lock()
|
|
clear(s.reads)
|
|
s.readsMu.Unlock()
|
|
}
|
|
|
|
// ReadsMustEqual fails the test if the actual reads differs from the specified reads.
|
|
func (s *TestStore) ReadsMustEqual(reads ...TestExpectedReads) {
|
|
s.tb.Helper()
|
|
s.readsMu.Lock()
|
|
defer s.readsMu.Unlock()
|
|
s.readsMustContainLocked(reads...)
|
|
s.readMustNoExtraLocked(reads...)
|
|
}
|
|
|
|
// ReadsMustContain fails the test if the specified reads have not been made,
|
|
// or have been made a different number of times. It permits other values to be
|
|
// read in addition to the ones being tested.
|
|
func (s *TestStore) ReadsMustContain(reads ...TestExpectedReads) {
|
|
s.tb.Helper()
|
|
s.readsMu.Lock()
|
|
defer s.readsMu.Unlock()
|
|
s.readsMustContainLocked(reads...)
|
|
}
|
|
|
|
func (s *TestStore) readsMustContainLocked(reads ...TestExpectedReads) {
|
|
s.tb.Helper()
|
|
for _, r := range reads {
|
|
if numTimes := s.reads[r.operation()]; numTimes != r.NumTimes {
|
|
s.tb.Errorf("%q (%v) reads: got %v, want %v", r.Key, r.Type, numTimes, r.NumTimes)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *TestStore) readMustNoExtraLocked(reads ...TestExpectedReads) {
|
|
s.tb.Helper()
|
|
rs := make(set.Set[testReadOperation])
|
|
for i := range reads {
|
|
rs.Add(reads[i].operation())
|
|
}
|
|
for ro, num := range s.reads {
|
|
if !rs.Contains(ro) {
|
|
s.tb.Errorf("%q (%v) reads: got %v, want 0", ro.Key, ro.Type, num)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Suspend suspends the store, batching changes and notifications
|
|
// until [TestStore.Resume] is called the same number of times as Suspend.
|
|
func (s *TestStore) Suspend() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.suspendCount++; s.suspendCount == 1 {
|
|
s.mw = xmaps.Clone(s.mr)
|
|
}
|
|
}
|
|
|
|
// Resume resumes the store, applying the changes and invoking
|
|
// the change callbacks.
|
|
func (s *TestStore) Resume() {
|
|
s.storeLock.Lock()
|
|
s.mu.Lock()
|
|
switch s.suspendCount--; {
|
|
case s.suspendCount == 0:
|
|
s.mr = s.mw
|
|
s.mu.Unlock()
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
case s.suspendCount < 0:
|
|
s.tb.Fatal("negative suspendCount")
|
|
default:
|
|
s.mu.Unlock()
|
|
s.storeLock.Unlock()
|
|
}
|
|
}
|
|
|
|
// SetBooleans sets the specified boolean settings in s.
|
|
func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) {
|
|
s.storeLock.Lock()
|
|
for _, setting := range settings {
|
|
if setting.Key == "" {
|
|
s.tb.Fatal("empty keys disallowed")
|
|
}
|
|
s.mu.Lock()
|
|
if setting.Error != nil {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
|
} else {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
// SetUInt64s sets the specified integer settings in s.
|
|
func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) {
|
|
s.storeLock.Lock()
|
|
for _, setting := range settings {
|
|
if setting.Key == "" {
|
|
s.tb.Fatal("empty keys disallowed")
|
|
}
|
|
s.mu.Lock()
|
|
if setting.Error != nil {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
|
} else {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
// SetStrings sets the specified string settings in s.
|
|
func (s *TestStore) SetStrings(settings ...TestSetting[string]) {
|
|
s.storeLock.Lock()
|
|
for _, setting := range settings {
|
|
if setting.Key == "" {
|
|
s.tb.Fatal("empty keys disallowed")
|
|
}
|
|
s.mu.Lock()
|
|
if setting.Error != nil {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
|
} else {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
// SetStrings sets the specified string list settings in s.
|
|
func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) {
|
|
s.storeLock.Lock()
|
|
for _, setting := range settings {
|
|
if setting.Key == "" {
|
|
s.tb.Fatal("empty keys disallowed")
|
|
}
|
|
s.mu.Lock()
|
|
if setting.Error != nil {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
|
} else {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
// Delete deletes the specified settings from s.
|
|
func (s *TestStore) Delete(keys ...setting.Key) {
|
|
s.storeLock.Lock()
|
|
for _, key := range keys {
|
|
s.mu.Lock()
|
|
delete(s.mw, key)
|
|
s.mu.Unlock()
|
|
}
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
// Clear deletes all settings from s.
|
|
func (s *TestStore) Clear() {
|
|
s.storeLock.Lock()
|
|
s.mu.Lock()
|
|
clear(s.mw)
|
|
s.mu.Unlock()
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
func (s *TestStore) NotifyPolicyChanged() {
|
|
s.mu.RLock()
|
|
if s.suspendCount != 0 {
|
|
s.mu.RUnlock()
|
|
return
|
|
}
|
|
cbs := slicesx.MapValues(s.cbs)
|
|
s.mu.RUnlock()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(cbs))
|
|
for _, cb := range cbs {
|
|
go func() {
|
|
defer wg.Done()
|
|
cb()
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
// Close closes s, notifying its users that it has expired.
|
|
func (s *TestStore) Close() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if !s.closed {
|
|
close(s.done)
|
|
s.closed = true
|
|
}
|
|
}
|
|
|
|
// Done implements [Expirable].
|
|
func (s *TestStore) Done() <-chan struct{} {
|
|
return s.done
|
|
}
|