util/syspolicy/source: add package for reading policy settings from external stores

We add package defining interfaces for policy stores, enabling creation of policy sources
and reading settings from them. It includes a Windows-specific PlatformPolicyStore for GP and MDM
policies stored in the Registry, and an in-memory TestStore for testing purposes.

We also include an internal package that tracks and reports policy usage metrics when a policy setting
is read from a store. Initially, it will be used only on Windows and Android, as macOS, iOS, and tvOS
report their own metrics. However, we plan to use it across all platforms eventually.

Updates #12687

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl
2024-08-12 22:07:45 -05:00
committed by Nick Khyl
parent e865a0e2b0
commit aeb15dea30
11 changed files with 3009 additions and 2 deletions

View File

@@ -0,0 +1,394 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"errors"
"fmt"
"io"
"slices"
"sort"
"sync"
"time"
"tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy/internal/loggerx"
"tailscale.com/util/syspolicy/internal/metrics"
"tailscale.com/util/syspolicy/setting"
)
// Reader reads all configured policy settings from a given [Store].
// It registers a change callback with the [Store] and maintains the current version
// of the [setting.Snapshot] by lazily re-reading policy settings from the [Store]
// whenever a new settings snapshot is requested with [Reader.GetSettings].
// It is safe for concurrent use.
type Reader struct {
store Store
origin *setting.Origin
settings []*setting.Definition
unregisterChangeNotifier func()
doneCh chan struct{} // closed when [Reader] is closed.
mu sync.Mutex
closing bool
upToDate bool
lastPolicy *setting.Snapshot
sessions set.HandleSet[*ReadingSession]
}
// newReader returns a new [Reader] that reads policy settings from a given [Store].
// The returned reader takes ownership of the store. If the store implements [io.Closer],
// the returned reader will close the store when it is closed.
func newReader(store Store, origin *setting.Origin) (*Reader, error) {
settings, err := setting.Definitions()
if err != nil {
return nil, err
}
if expirable, ok := store.(Expirable); ok {
select {
case <-expirable.Done():
return nil, ErrStoreClosed
default:
}
}
reader := &Reader{store: store, origin: origin, settings: settings, doneCh: make(chan struct{})}
if changeable, ok := store.(Changeable); ok {
// We should subscribe to policy change notifications first before reading
// the policy settings from the store. This way we won't miss any notifications.
if reader.unregisterChangeNotifier, err = changeable.RegisterChangeCallback(reader.onPolicyChange); err != nil {
// Errors registering policy change callbacks are non-fatal.
// TODO(nickkhyl): implement a background policy refresh every X minutes?
loggerx.Errorf("failed to register %v policy change callback: %v", origin, err)
}
}
if _, err := reader.reload(true); err != nil {
if reader.unregisterChangeNotifier != nil {
reader.unregisterChangeNotifier()
}
return nil, err
}
if expirable, ok := store.(Expirable); ok {
if waitCh := expirable.Done(); waitCh != nil {
go func() {
select {
case <-waitCh:
reader.Close()
case <-reader.doneCh:
}
}()
}
}
return reader, nil
}
// GetSettings returns the current [*setting.Snapshot],
// re-reading it from from the underlying [Store] only if the policy
// has changed since it was read last. It never fails and returns
// the previous version of the policy settings if a read attempt fails.
func (r *Reader) GetSettings() *setting.Snapshot {
r.mu.Lock()
upToDate, lastPolicy := r.upToDate, r.lastPolicy
r.mu.Unlock()
if upToDate {
return lastPolicy
}
policy, err := r.reload(false)
if err != nil {
// If the policy fails to reload completely, log an error and return the last cached version.
// However, errors related to individual policy items are always
// propagated to callers when they fetch those settings.
loggerx.Errorf("failed to reload %v policy: %v", r.origin, err)
}
return policy
}
// ReadSettings reads policy settings from the underlying [Store] even if no
// changes were detected. It returns the new [*setting.Snapshot],nil on
// success or an undefined snapshot (possibly `nil`) along with a non-`nil`
// error in case of failure.
func (r *Reader) ReadSettings() (*setting.Snapshot, error) {
return r.reload(true)
}
// reload is like [Reader.ReadSettings], but allows specifying whether to re-read
// an unchanged policy, and returns the last [*setting.Snapshot] if the read fails.
func (r *Reader) reload(force bool) (*setting.Snapshot, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.upToDate && !force {
return r.lastPolicy, nil
}
if lockable, ok := r.store.(Lockable); ok {
if err := lockable.Lock(); err != nil {
return r.lastPolicy, err
}
defer lockable.Unlock()
}
r.upToDate = true
metrics.Reset(r.origin)
var m map[setting.Key]setting.RawItem
if lastPolicyCount := r.lastPolicy.Len(); lastPolicyCount > 0 {
m = make(map[setting.Key]setting.RawItem, lastPolicyCount)
}
for _, s := range r.settings {
if !r.origin.Scope().IsConfigurableSetting(s) {
// Skip settings that cannot be configured in the current scope.
continue
}
val, err := readPolicySettingValue(r.store, s)
if err != nil && (errors.Is(err, setting.ErrNoSuchKey) || errors.Is(err, setting.ErrNotConfigured)) {
metrics.ReportNotConfigured(r.origin, s)
continue
}
if err == nil {
metrics.ReportConfigured(r.origin, s, val)
} else {
metrics.ReportError(r.origin, s, err)
}
// If there's an error reading a single policy, such as a value type mismatch,
// we'll wrap the error to preserve its text and return it
// whenever someone attempts to fetch the value.
// Otherwise, the errorText will be nil.
errorText := setting.MaybeErrorText(err)
item := setting.RawItemWith(val, errorText, r.origin)
mak.Set(&m, s.Key(), item)
}
newPolicy := setting.NewSnapshot(m, setting.SummaryWith(r.origin))
if r.lastPolicy == nil || !newPolicy.EqualItems(r.lastPolicy) {
r.lastPolicy = newPolicy
}
return r.lastPolicy, nil
}
// ReadingSession is like [Reader], but with a channel that's written
// to when there's a policy change, and closed when the session is terminated.
type ReadingSession struct {
reader *Reader
policyChangedCh chan struct{} // 1-buffered channel
handle set.Handle // in the reader.sessions
closeInternal func()
}
// OpenSession opens and returns a new session to r, allowing the caller
// to get notified whenever a policy change is reported by the [source.Store],
// or an [ErrStoreClosed] if the reader has already been closed.
func (r *Reader) OpenSession() (*ReadingSession, error) {
session := &ReadingSession{
reader: r,
policyChangedCh: make(chan struct{}, 1),
}
session.closeInternal = sync.OnceFunc(func() { close(session.policyChangedCh) })
r.mu.Lock()
defer r.mu.Unlock()
if r.closing {
return nil, ErrStoreClosed
}
session.handle = r.sessions.Add(session)
return session, nil
}
// GetSettings is like [Reader.GetSettings].
func (s *ReadingSession) GetSettings() *setting.Snapshot {
return s.reader.GetSettings()
}
// ReadSettings is like [Reader.ReadSettings].
func (s *ReadingSession) ReadSettings() (*setting.Snapshot, error) {
return s.reader.ReadSettings()
}
// PolicyChanged returns a channel that's written to when
// there's a policy change, closed when the session is terminated.
func (s *ReadingSession) PolicyChanged() <-chan struct{} {
return s.policyChangedCh
}
// Close unregisters this session with the [Reader].
func (s *ReadingSession) Close() {
s.reader.mu.Lock()
delete(s.reader.sessions, s.handle)
s.closeInternal()
s.reader.mu.Unlock()
}
// onPolicyChange handles a policy change notification from the [Store],
// invalidating the current [setting.Snapshot] in r,
// and notifying the active [ReadingSession]s.
func (r *Reader) onPolicyChange() {
r.mu.Lock()
defer r.mu.Unlock()
r.upToDate = false
for _, s := range r.sessions {
select {
case s.policyChangedCh <- struct{}{}:
// Notified.
default:
// 1-buffered channel is full, meaning that another policy change
// notification is already en route.
}
}
}
// Close closes the store reader and the underlying store.
func (r *Reader) Close() error {
r.mu.Lock()
if r.closing {
r.mu.Unlock()
return nil
}
r.closing = true
r.mu.Unlock()
if r.unregisterChangeNotifier != nil {
r.unregisterChangeNotifier()
r.unregisterChangeNotifier = nil
}
if closer, ok := r.store.(io.Closer); ok {
if err := closer.Close(); err != nil {
return err
}
}
r.store = nil
close(r.doneCh)
r.mu.Lock()
defer r.mu.Unlock()
for _, c := range r.sessions {
c.closeInternal()
}
r.sessions = nil
return nil
}
// Done returns a channel that is closed when the reader is closed.
func (r *Reader) Done() <-chan struct{} {
return r.doneCh
}
// ReadableSource is a [Source] open for reading.
type ReadableSource struct {
*Source
*ReadingSession
}
// Close closes the underlying [ReadingSession].
func (s ReadableSource) Close() {
s.ReadingSession.Close()
}
// ReadableSources is a slice of [ReadableSource].
type ReadableSources []ReadableSource
// Contains reports whether s contains the specified source.
func (s ReadableSources) Contains(source *Source) bool {
return s.IndexOf(source) != -1
}
// IndexOf returns position of the specified source in s, or -1
// if the source does not exist.
func (s ReadableSources) IndexOf(source *Source) int {
return slices.IndexFunc(s, func(rs ReadableSource) bool {
return rs.Source == source
})
}
// InsertionIndexOf returns the position at which source can be inserted
// to maintain the sorted order of the readableSources.
// The return value is unspecified if s is not sorted on entry to InsertionIndexOf.
func (s ReadableSources) InsertionIndexOf(source *Source) int {
// Insert new sources after any existing sources with the same precedence,
// and just before the first source with higher precedence.
// Just like stable sort, but for insertion.
// It's okay to use linear search as insertions are rare
// and we never have more than just a few policy sources.
higherPrecedence := func(rs ReadableSource) bool { return rs.Compare(source) > 0 }
if i := slices.IndexFunc(s, higherPrecedence); i != -1 {
return i
}
return len(s)
}
// StableSort sorts [ReadableSource] in s by precedence, so that policy
// settings from sources with higher precedence (e.g., [DeviceScope])
// will be read and merged last, overriding any policy settings with
// the same keys configured in sources with lower precedence
// (e.g., [CurrentUserScope]).
func (s *ReadableSources) StableSort() {
sort.SliceStable(*s, func(i, j int) bool {
return (*s)[i].Source.Compare((*s)[j].Source) < 0
})
}
// DeleteAt closes and deletes the i-th source from s.
func (s *ReadableSources) DeleteAt(i int) {
(*s)[i].Close()
*s = slices.Delete(*s, i, i+1)
}
// Close closes and deletes all sources in s.
func (s *ReadableSources) Close() {
for _, s := range *s {
s.Close()
}
*s = nil
}
func readPolicySettingValue(store Store, s *setting.Definition) (value any, err error) {
switch key := s.Key(); s.Type() {
case setting.BooleanValue:
return store.ReadBoolean(key)
case setting.IntegerValue:
return store.ReadUInt64(key)
case setting.StringValue:
return store.ReadString(key)
case setting.StringListValue:
return store.ReadStringArray(key)
case setting.PreferenceOptionValue:
s, err := store.ReadString(key)
if err == nil {
var value setting.PreferenceOption
if err = value.UnmarshalText([]byte(s)); err == nil {
return value, nil
}
}
return setting.ShowChoiceByPolicy, err
case setting.VisibilityValue:
s, err := store.ReadString(key)
if err == nil {
var value setting.Visibility
if err = value.UnmarshalText([]byte(s)); err == nil {
return value, nil
}
}
return setting.VisibleByPolicy, err
case setting.DurationValue:
s, err := store.ReadString(key)
if err == nil {
var value time.Duration
if value, err = time.ParseDuration(s); err == nil {
return value, nil
}
}
return nil, err
default:
return nil, fmt.Errorf("%w: unsupported setting type: %v", setting.ErrTypeMismatch, s.Type())
}
}

View File

@@ -0,0 +1,291 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"cmp"
"testing"
"time"
"tailscale.com/util/must"
"tailscale.com/util/syspolicy/setting"
)
func TestReaderLifecycle(t *testing.T) {
tests := []struct {
name string
origin *setting.Origin
definitions []*setting.Definition
wantReads []TestExpectedReads
initStrings []TestSetting[string]
initUInt64s []TestSetting[uint64]
initWant *setting.Snapshot
addStrings []TestSetting[string]
addStringLists []TestSetting[[]string]
newWant *setting.Snapshot
}{
{
name: "read-all-settings-once",
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
definitions: []*setting.Definition{
setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
},
wantReads: []TestExpectedReads{
{Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
{Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
{Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
{Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
},
initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
},
{
name: "re-read-all-settings-when-the-policy-changes",
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
definitions: []*setting.Definition{
setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
},
wantReads: []TestExpectedReads{
{Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
{Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
{Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
{Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
},
initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
addStrings: []TestSetting[string]{TestSettingOf("StringValue", "S1")},
addStringLists: []TestSetting[[]string]{TestSettingOf("StringListValue", []string{"S1", "S2", "S3"})},
newWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"StringValue": setting.RawItemWith("S1", nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
"StringListValue": setting.RawItemWith([]string{"S1", "S2", "S3"}, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
}, setting.NewNamedOrigin("Test", setting.DeviceScope)),
},
{
name: "read-settings-if-in-scope/device",
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
definitions: []*setting.Definition{
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
},
wantReads: []TestExpectedReads{
{Key: "DeviceSetting", Type: setting.StringValue, NumTimes: 1},
{Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
},
},
{
name: "read-settings-if-in-scope/profile",
origin: setting.NewNamedOrigin("Test", setting.CurrentProfileScope),
definitions: []*setting.Definition{
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
},
wantReads: []TestExpectedReads{
// Device settings cannot be configured at the profile scope and should not be read.
{Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
},
},
{
name: "read-settings-if-in-scope/user",
origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
definitions: []*setting.Definition{
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
},
wantReads: []TestExpectedReads{
// Device and profile settings cannot be configured at the profile scope and should not be read.
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
},
},
{
name: "read-stringy-settings",
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
definitions: []*setting.Definition{
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
},
wantReads: []TestExpectedReads{
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
},
initStrings: []TestSetting[string]{
TestSettingOf("DurationValue", "2h30m"),
TestSettingOf("PreferenceOptionValue", "always"),
TestSettingOf("VisibilityValue", "show"),
},
initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"DurationValue": setting.RawItemWith(must.Get(time.ParseDuration("2h30m")), nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
"PreferenceOptionValue": setting.RawItemWith(setting.AlwaysByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
"VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
}, setting.NewNamedOrigin("Test", setting.DeviceScope)),
},
{
name: "read-erroneous-stringy-settings",
origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
definitions: []*setting.Definition{
setting.NewDefinition("DurationValue1", setting.UserSetting, setting.DurationValue),
setting.NewDefinition("DurationValue2", setting.UserSetting, setting.DurationValue),
setting.NewDefinition("PreferenceOptionValue", setting.UserSetting, setting.PreferenceOptionValue),
setting.NewDefinition("VisibilityValue", setting.UserSetting, setting.VisibilityValue),
},
wantReads: []TestExpectedReads{
{Key: "DurationValue1", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
{Key: "DurationValue2", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
},
initStrings: []TestSetting[string]{
TestSettingOf("DurationValue1", "soon"),
TestSettingWithError[string]("DurationValue2", setting.NewErrorText("bang!")),
TestSettingOf("PreferenceOptionValue", "sometimes"),
},
initUInt64s: []TestSetting[uint64]{
TestSettingOf[uint64]("VisibilityValue", 42), // type mismatch
},
initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"DurationValue1": setting.RawItemWith(nil, setting.NewErrorText("time: invalid duration \"soon\""), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
"DurationValue2": setting.RawItemWith(nil, setting.NewErrorText("bang!"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
"PreferenceOptionValue": setting.RawItemWith(setting.ShowChoiceByPolicy, nil, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
"VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, setting.NewErrorText("type mismatch in ReadString: got uint64"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
}, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setting.SetDefinitionsForTest(t, tt.definitions...)
store := NewTestStore(t)
store.SetStrings(tt.initStrings...)
store.SetUInt64s(tt.initUInt64s...)
reader, err := newReader(store, tt.origin)
if err != nil {
t.Fatalf("newReader failed: %v", err)
}
if got := reader.GetSettings(); tt.initWant != nil && !got.Equal(tt.initWant) {
t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
}
if tt.wantReads != nil {
store.ReadsMustEqual(tt.wantReads...)
}
// Should not result in new reads as there were no changes.
N := 100
for range N {
reader.GetSettings()
}
if tt.wantReads != nil {
store.ReadsMustEqual(tt.wantReads...)
}
store.ResetCounters()
got, err := reader.ReadSettings()
if err != nil {
t.Fatalf("ReadSettings failed: %v", err)
}
if tt.initWant != nil && !got.Equal(tt.initWant) {
t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
}
if tt.wantReads != nil {
store.ReadsMustEqual(tt.wantReads...)
}
store.ResetCounters()
if len(tt.addStrings) != 0 || len(tt.addStringLists) != 0 {
store.SetStrings(tt.addStrings...)
store.SetStringLists(tt.addStringLists...)
// As the settings have changed, GetSettings needs to re-read them.
if got, want := reader.GetSettings(), cmp.Or(tt.newWant, tt.initWant); !got.Equal(want) {
t.Errorf("New Settings do not match: got %v, want %v", got, want)
}
if tt.wantReads != nil {
store.ReadsMustEqual(tt.wantReads...)
}
}
select {
case <-reader.Done():
t.Fatalf("the reader is closed")
default:
}
store.Close()
<-reader.Done()
})
}
}
func TestReadingSession(t *testing.T) {
setting.SetDefinitionsForTest(t, setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue))
store := NewTestStore(t)
origin := setting.NewOrigin(setting.DeviceScope)
reader, err := newReader(store, origin)
if err != nil {
t.Fatalf("newReader failed: %v", err)
}
session, err := reader.OpenSession()
if err != nil {
t.Fatalf("failed to open a reading session: %v", err)
}
t.Cleanup(session.Close)
if got, want := session.GetSettings(), setting.NewSnapshot(nil, origin); !got.Equal(want) {
t.Errorf("Settings do not match: got %v, want %v", got, want)
}
select {
case _, ok := <-session.PolicyChanged():
if ok {
t.Fatalf("the policy changed notification was sent prematurely")
} else {
t.Fatalf("the session was closed prematurely")
}
default:
}
store.SetStrings(TestSettingOf("StringValue", "S1"))
_, ok := <-session.PolicyChanged()
if !ok {
t.Fatalf("the session was closed prematurely")
}
want := setting.NewSnapshot(map[setting.Key]setting.RawItem{
"StringValue": setting.RawItemWith("S1", nil, origin),
}, origin)
if got := session.GetSettings(); !got.Equal(want) {
t.Errorf("Settings do not match: got %v, want %v", got, want)
}
store.Close()
if _, ok = <-session.PolicyChanged(); ok {
t.Fatalf("the session must be closed")
}
}

View File

@@ -0,0 +1,146 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package source defines interfaces for policy stores,
// facilitates the creation of policy sources, and provides
// functionality for reading policy settings from these sources.
package source
import (
"cmp"
"errors"
"fmt"
"io"
"tailscale.com/types/lazy"
"tailscale.com/util/syspolicy/setting"
)
// ErrStoreClosed is an error returned when attempting to use a [Store] after it
// has been closed.
var ErrStoreClosed = errors.New("the policy store has been closed")
// Store provides methods to read system policy settings from OS-specific storage.
// Implementations must be concurrency-safe, and may also implement
// [Lockable], [Changeable], [Expirable] and [io.Closer].
//
// If a [Store] implementation also implements [io.Closer],
// it will be called by the package to release the resources
// when the store is no longer needed.
type Store interface {
// ReadString returns the value of a [setting.StringValue] with the specified key,
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
// an error on failure.
ReadString(key setting.Key) (string, error)
// ReadUInt64 returns the value of a [setting.IntegerValue] with the specified key,
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
// an error on failure.
ReadUInt64(key setting.Key) (uint64, error)
// ReadBoolean returns the value of a [setting.BooleanValue] with the specified key,
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
// an error on failure.
ReadBoolean(key setting.Key) (bool, error)
// ReadStringArray returns the value of a [setting.StringListValue] with the specified key,
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
// an error on failure.
ReadStringArray(key setting.Key) ([]string, error)
}
// Lockable is an optional interface that [Store] implementations may support.
// Locking a [Store] is not mandatory as [Store] must be concurrency-safe,
// but is recommended to avoid issues where consecutive read calls for related
// policies might return inconsistent results if a policy change occurs between
// the calls. Implementations may use locking to pre-read policies or for
// similar performance optimizations.
type Lockable interface {
// Lock acquires a read lock on the policy store,
// ensuring the store's state remains unchanged while locked.
// Multiple readers can hold the lock simultaneously.
// It returns an error if the store cannot be locked.
Lock() error
// Unlock unlocks the policy store.
// It is a run-time error if the store is not locked on entry to Unlock.
Unlock()
}
// Changeable is an optional interface that [Store] implementations may support
// if the policy settings they contain can be externally changed after being initially read.
type Changeable interface {
// RegisterChangeCallback adds a function that will be called
// whenever there's a policy change in the [Store].
// The returned function can be used to unregister the callback.
RegisterChangeCallback(callback func()) (unregister func(), err error)
}
// Expirable is an optional interface that [Store] implementations may support
// if they can be externally closed or otherwise become invalid while in use.
type Expirable interface {
// Done returns a channel that is closed when the policy [Store] should no longer be used.
// It should return nil if the store never expires.
Done() <-chan struct{}
}
// Source represents a named source of policy settings for a given [setting.PolicyScope].
type Source struct {
name string
scope setting.PolicyScope
store Store
origin *setting.Origin
lazyReader lazy.SyncValue[*Reader]
}
// NewSource returns a new [Source] with the specified name, scope, and store.
func NewSource(name string, scope setting.PolicyScope, store Store) *Source {
return &Source{name: name, scope: scope, store: store, origin: setting.NewNamedOrigin(name, scope)}
}
// Name reports the name of the policy source.
func (s *Source) Name() string {
return s.name
}
// Scope reports the management scope of the policy source.
func (s *Source) Scope() setting.PolicyScope {
return s.scope
}
// Reader returns a [Reader] that reads from this source's [Store].
func (s *Source) Reader() (*Reader, error) {
return s.lazyReader.GetErr(func() (*Reader, error) {
return newReader(s.store, s.origin)
})
}
// Description returns a formatted string with the scope and name of this policy source.
// It can be used for logging or display purposes.
func (s *Source) Description() string {
if s.name != "" {
return fmt.Sprintf("%s (%v)", s.name, s.Scope())
}
return s.Scope().String()
}
// Compare returns an integer comparing s and s2
// by their precedence, following the "last-wins" model.
// The result will be:
//
// -1 if policy settings from s should be processed before policy settings from s2;
// +1 if policy settings from s should be processed after policy settings from s2, overriding s2;
// 0 if the relative processing order of policy settings in s and s2 is unspecified.
func (s *Source) Compare(s2 *Source) int {
return cmp.Compare(s2.Scope().Kind(), s.Scope().Kind())
}
// Close closes the [Source] and the underlying [Store].
func (s *Source) Close() error {
// The [Reader], if any, owns the [Store].
if reader, _ := s.lazyReader.GetErr(func() (*Reader, error) { return nil, ErrStoreClosed }); reader != nil {
return reader.Close()
}
// Otherwise, it is our responsibility to close it.
if closer, ok := s.store.(io.Closer); ok {
return closer.Close()
}
return nil
}

View File

@@ -0,0 +1,450 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"errors"
"fmt"
"strings"
"sync"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/winutil/gp"
)
const (
softwareKeyName = `Software`
tsPoliciesSubkey = `Policies\Tailscale`
tsIPNSubkey = `Tailscale IPN` // the legacy key we need to fallback to
)
var (
_ Store = (*PlatformPolicyStore)(nil)
_ Lockable = (*PlatformPolicyStore)(nil)
_ Changeable = (*PlatformPolicyStore)(nil)
_ Expirable = (*PlatformPolicyStore)(nil)
)
// PlatformPolicyStore implements [Store] by providing read access to
// Registry-based Tailscale policies, such as those configured via Group Policy or MDM.
// For better performance and consistency, it is recommended to lock it when
// reading multiple policy settings sequentially.
// It also allows subscribing to policy change notifications.
type PlatformPolicyStore struct {
scope gp.Scope // [gp.MachinePolicy] or [gp.UserPolicy]
// The softwareKey can be HKLM\Software, HKCU\Software, or
// HKU\{SID}\Software. Anything below the Software subkey, including
// Software\Policies, may not yet exist or could be deleted throughout the
// [PlatformPolicyStore]'s lifespan, invalidating the handle. We also prefer
// to always use a real registry key (rather than a predefined HKLM or HKCU)
// to simplify bookkeeping (predefined keys should never be closed).
// Finally, this will allow us to watch for any registry changes directly
// should we need this in the future in addition to gp.ChangeWatcher.
softwareKey registry.Key
watcher *gp.ChangeWatcher
done chan struct{} // done is closed when Close call completes
// The policyLock can be locked by the caller when reading multiple policy settings
// to prevent the Group Policy Client service from modifying policies while
// they are being read.
//
// When both policyLock and mu need to be taken, mu must be taken before policyLock.
policyLock *gp.PolicyLock
mu sync.Mutex
tsKeys []registry.Key // or nil if the [PlatformPolicyStore] hasn't been locked.
cbs set.HandleSet[func()] // policy change callbacks
lockCnt int
locked sync.WaitGroup
closing bool
closed bool
}
type registryValueGetter[T any] func(key registry.Key, name string) (T, error)
// NewMachinePlatformPolicyStore returns a new [PlatformPolicyStore] for the machine.
func NewMachinePlatformPolicyStore() (*PlatformPolicyStore, error) {
softwareKey, err := registry.OpenKey(registry.LOCAL_MACHINE, softwareKeyName, windows.KEY_READ)
if err != nil {
return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
}
return newPlatformPolicyStore(gp.MachinePolicy, softwareKey, gp.NewMachinePolicyLock()), nil
}
// NewUserPlatformPolicyStore returns a new [PlatformPolicyStore] for the user specified by its token.
// User's profile must be loaded, and the token handle must have [windows.TOKEN_QUERY]
// and [windows.TOKEN_DUPLICATE] access. The caller retains ownership of the token.
func NewUserPlatformPolicyStore(token windows.Token) (*PlatformPolicyStore, error) {
var err error
var softwareKey registry.Key
if token != 0 {
var user *windows.Tokenuser
if user, err = token.GetTokenUser(); err != nil {
return nil, fmt.Errorf("failed to get token user: %w", err)
}
userSid := user.User.Sid
softwareKey, err = registry.OpenKey(registry.USERS, userSid.String()+`\`+softwareKeyName, windows.KEY_READ)
} else {
softwareKey, err = registry.OpenKey(registry.CURRENT_USER, softwareKeyName, windows.KEY_READ)
}
if err != nil {
return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
}
policyLock, err := gp.NewUserPolicyLock(token)
if err != nil {
return nil, fmt.Errorf("failed to create a user policy lock: %w", err)
}
return newPlatformPolicyStore(gp.UserPolicy, softwareKey, policyLock), nil
}
func newPlatformPolicyStore(scope gp.Scope, softwareKey registry.Key, policyLock *gp.PolicyLock) *PlatformPolicyStore {
return &PlatformPolicyStore{
scope: scope,
softwareKey: softwareKey,
done: make(chan struct{}),
policyLock: policyLock,
}
}
// Lock locks the policy store, preventing the system from modifying the policies
// while they are being read. It is a read lock that may be acquired by multiple goroutines.
// Each Lock call must be balanced by exactly one Unlock call.
func (ps *PlatformPolicyStore) Lock() (err error) {
ps.mu.Lock()
defer ps.mu.Unlock()
if ps.closing {
return ErrStoreClosed
}
ps.lockCnt += 1
if ps.lockCnt != 1 {
return nil
}
defer func() {
if err != nil {
ps.lockCnt -= 1
}
}()
// Ensure ps remains open while the lock is held.
ps.locked.Add(1)
defer func() {
if err != nil {
ps.locked.Done()
}
}()
// Acquire the GP lock to prevent the system from modifying policy settings
// while they are being read.
if err := ps.policyLock.Lock(); err != nil {
if errors.Is(err, gp.ErrInvalidLockState) {
// The policy store is being closed and we've lost the race.
return ErrStoreClosed
}
return err
}
defer func() {
if err != nil {
ps.policyLock.Unlock()
}
}()
// Keep the Tailscale's registry keys open for the duration of the lock.
keyNames := tailscaleKeyNamesFor(ps.scope)
ps.tsKeys = make([]registry.Key, 0, len(keyNames))
for _, keyName := range keyNames {
var tsKey registry.Key
tsKey, err = registry.OpenKey(ps.softwareKey, keyName, windows.KEY_READ)
if err != nil {
if err == registry.ErrNotExist {
continue
}
return err
}
ps.tsKeys = append(ps.tsKeys, tsKey)
}
return nil
}
// Unlock decrements the lock counter and unlocks the policy store once the counter reaches 0.
// It panics if ps is not locked on entry to Unlock.
func (ps *PlatformPolicyStore) Unlock() {
ps.mu.Lock()
defer ps.mu.Unlock()
ps.lockCnt -= 1
if ps.lockCnt < 0 {
panic("negative lockCnt")
} else if ps.lockCnt != 0 {
return
}
for _, key := range ps.tsKeys {
key.Close()
}
ps.tsKeys = nil
ps.policyLock.Unlock()
ps.locked.Done()
}
// RegisterChangeCallback adds a function that will be called whenever there's a policy change.
// It returns a function that can be used to unregister the specified callback or an error.
// The error is [ErrStoreClosed] if ps has already been closed.
func (ps *PlatformPolicyStore) RegisterChangeCallback(cb func()) (unregister func(), err error) {
ps.mu.Lock()
defer ps.mu.Unlock()
if ps.closing {
return nil, ErrStoreClosed
}
handle := ps.cbs.Add(cb)
if len(ps.cbs) == 1 {
if ps.watcher, err = gp.NewChangeWatcher(ps.scope, ps.onChange); err != nil {
return nil, err
}
}
return func() {
ps.mu.Lock()
defer ps.mu.Unlock()
delete(ps.cbs, handle)
if len(ps.cbs) == 0 {
if ps.watcher != nil {
ps.watcher.Close()
ps.watcher = nil
}
}
}, nil
}
func (ps *PlatformPolicyStore) onChange() {
ps.mu.Lock()
defer ps.mu.Unlock()
if ps.closing {
return
}
for _, callback := range ps.cbs {
go callback()
}
}
// ReadString retrieves a string policy with the specified key.
// It returns [setting.ErrNotConfigured] if the policy setting does not exist.
func (ps *PlatformPolicyStore) ReadString(key setting.Key) (val string, err error) {
return getPolicyValue(ps, key,
func(key registry.Key, valueName string) (string, error) {
val, _, err := key.GetStringValue(valueName)
return val, err
})
}
// ReadUInt64 retrieves an integer policy with the specified key.
// It returns [setting.ErrNotConfigured] if the policy setting does not exist.
func (ps *PlatformPolicyStore) ReadUInt64(key setting.Key) (uint64, error) {
return getPolicyValue(ps, key,
func(key registry.Key, valueName string) (uint64, error) {
val, _, err := key.GetIntegerValue(valueName)
return val, err
})
}
// ReadBoolean retrieves a boolean policy with the specified key.
// It returns [setting.ErrNotConfigured] if the policy setting does not exist.
func (ps *PlatformPolicyStore) ReadBoolean(key setting.Key) (bool, error) {
return getPolicyValue(ps, key,
func(key registry.Key, valueName string) (bool, error) {
val, _, err := key.GetIntegerValue(valueName)
if err != nil {
return false, err
}
return val != 0, nil
})
}
// ReadString retrieves a multi-string policy with the specified key.
// It returns [setting.ErrNotConfigured] if the policy setting does not exist.
func (ps *PlatformPolicyStore) ReadStringArray(key setting.Key) ([]string, error) {
return getPolicyValue(ps, key,
func(key registry.Key, valueName string) ([]string, error) {
val, _, err := key.GetStringsValue(valueName)
if err != registry.ErrNotExist {
return val, err // the err may be nil or non-nil
}
// The idiomatic way to store multiple string values in Group Policy
// and MDM for Windows is to have multiple REG_SZ (or REG_EXPAND_SZ)
// values under a subkey rather than in a single REG_MULTI_SZ value.
//
// See the Group Policy: Registry Extension Encoding specification,
// and specifically the ListElement and ListBox types.
// https://web.archive.org/web/20240721033657/https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-GPREG/%5BMS-GPREG%5D.pdf
valKey, err := registry.OpenKey(key, valueName, windows.KEY_READ)
if err != nil {
return nil, err
}
valNames, err := valKey.ReadValueNames(0)
if err != nil {
return nil, err
}
val = make([]string, 0, len(valNames))
for _, name := range valNames {
switch item, _, err := valKey.GetStringValue(name); {
case err == registry.ErrNotExist:
continue
case err != nil:
return nil, err
default:
val = append(val, item)
}
}
return val, nil
})
}
// splitSettingKey extracts the registry key name and value name from a [setting.Key].
// The [setting.Key] format allows grouping settings into nested categories using one
// or more [setting.KeyPathSeparator]s in the path. How individual policy settings are
// stored is an implementation detail of each [Store]. In the [PlatformPolicyStore]
// for Windows, we map nested policy categories onto the Registry key hierarchy.
// The last component after a [setting.KeyPathSeparator] is treated as the value name,
// while everything preceding it is considered a subpath (relative to the {HKLM,HKCU}\Software\Policies\Tailscale key).
// If there are no [setting.KeyPathSeparator]s in the key, the policy setting value
// is meant to be stored directly under {HKLM,HKCU}\Software\Policies\Tailscale.
func splitSettingKey(key setting.Key) (path, valueName string) {
if idx := strings.LastIndex(string(key), setting.KeyPathSeparator); idx != -1 {
path = strings.ReplaceAll(string(key[:idx]), setting.KeyPathSeparator, `\`)
valueName = string(key[idx+len(setting.KeyPathSeparator):])
return path, valueName
}
return "", string(key)
}
func getPolicyValue[T any](ps *PlatformPolicyStore, key setting.Key, getter registryValueGetter[T]) (T, error) {
var zero T
ps.mu.Lock()
defer ps.mu.Unlock()
if ps.closed {
return zero, ErrStoreClosed
}
path, valueName := splitSettingKey(key)
getValue := func(key registry.Key) (T, error) {
var err error
if path != "" {
key, err = registry.OpenKey(key, path, windows.KEY_READ)
if err != nil {
return zero, err
}
defer key.Close()
}
return getter(key, valueName)
}
if ps.tsKeys != nil {
// A non-nil tsKeys indicates that ps has been locked.
// The slice may be empty if Tailscale policy keys do not exist.
for _, tsKey := range ps.tsKeys {
val, err := getValue(tsKey)
if err == nil || err != registry.ErrNotExist {
return val, err
}
}
return zero, setting.ErrNotConfigured
}
// The ps has not been locked, so we don't have any pre-opened keys.
for _, tsKeyName := range tailscaleKeyNamesFor(ps.scope) {
var tsKey registry.Key
tsKey, err := registry.OpenKey(ps.softwareKey, tsKeyName, windows.KEY_READ)
if err != nil {
if err == registry.ErrNotExist {
continue
}
return zero, err
}
val, err := getValue(tsKey)
tsKey.Close()
if err == nil || err != registry.ErrNotExist {
return val, err
}
}
return zero, setting.ErrNotConfigured
}
// Close closes the policy store and releases any associated resources.
// It cancels pending locks and prevents any new lock attempts,
// but waits for existing locks to be released.
func (ps *PlatformPolicyStore) Close() error {
// Request to close the Group Policy read lock.
// Existing held locks will remain valid, but any new or pending locks
// will fail. In certain scenarios, the corresponding write lock may be held
// by the Group Policy service for extended periods (minutes rather than
// seconds or milliseconds). In such cases, we prefer not to wait that long
// if the ps is being closed anyway.
if ps.policyLock != nil {
ps.policyLock.Close()
}
// Mark ps as closing to fast-fail any new lock attempts.
// Callers that have already locked it can finish their reading.
ps.mu.Lock()
if ps.closing {
ps.mu.Unlock()
return nil
}
ps.closing = true
if ps.watcher != nil {
ps.watcher.Close()
ps.watcher = nil
}
ps.mu.Unlock()
// Signal to the external code that ps should no longer be used.
close(ps.done)
// Wait for any outstanding locks to be released.
ps.locked.Wait()
// Deny any further read attempts and release remaining resources.
ps.mu.Lock()
defer ps.mu.Unlock()
ps.cbs = nil
ps.policyLock = nil
ps.closed = true
if ps.softwareKey != 0 {
ps.softwareKey.Close()
ps.softwareKey = 0
}
return nil
}
// Done returns a channel that is closed when the Close method is called.
func (ps *PlatformPolicyStore) Done() <-chan struct{} {
return ps.done
}
func tailscaleKeyNamesFor(scope gp.Scope) []string {
switch scope {
case gp.MachinePolicy:
// If a computer-side policy value does not exist under Software\Policies\Tailscale,
// we need to fallback and use the legacy Software\Tailscale IPN key.
return []string{tsPoliciesSubkey, tsIPNSubkey}
case gp.UserPolicy:
// However, we've never used the legacy key with user-side policies,
// and we should never do so. Unlike HKLM\Software\Tailscale IPN,
// its HKCU counterpart is user-writable.
return []string{tsPoliciesSubkey}
default:
panic("unreachable")
}
}

View File

@@ -0,0 +1,398 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"tailscale.com/tstest"
"tailscale.com/util/cibuild"
"tailscale.com/util/mak"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/winutil"
"tailscale.com/util/winutil/gp"
)
// subkeyStrings is a test type indicating that a string slice should be written
// to the registry as multiple REG_SZ values under the setting's key,
// rather than as a single REG_MULTI_SZ value under the group key.
// This is the same format as ADMX use for string lists.
type subkeyStrings []string
type testPolicyValue struct {
name setting.Key
value any
}
func TestLockUnlockPolicyStore(t *testing.T) {
// Make sure we don't leak goroutines
tstest.ResourceCheck(t)
store, err := NewMachinePlatformPolicyStore()
if err != nil {
t.Fatalf("NewMachinePolicyStore failed: %v", err)
}
t.Run("One-Goroutine", func(t *testing.T) {
if err := store.Lock(); err != nil {
t.Errorf("store.Lock(): got %v; want nil", err)
return
}
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
}
store.Unlock()
})
// Lock the store N times from different goroutines.
const N = 100
var unlocked atomic.Int32
t.Run("N-Goroutines", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(N)
for range N {
go func() {
if err := store.Lock(); err != nil {
t.Errorf("store.Lock(): got %v; want nil", err)
return
}
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
}
wg.Done()
time.Sleep(10 * time.Millisecond)
unlocked.Add(1)
store.Unlock()
}()
}
// Wait until the store is locked N times.
wg.Wait()
})
// Close the store. The call should wait for all held locks to be released.
if err := store.Close(); err != nil {
t.Fatalf("(*PolicyStore).Close failed: %v", err)
}
if locked := unlocked.Load(); locked != N {
t.Errorf("locked.Load(): got %v; want %v", locked, N)
}
// Any further attempts to lock it should fail.
if err = store.Lock(); err == nil || !errors.Is(err, ErrStoreClosed) {
t.Errorf("store.Lock(): got %v; want %v", err, ErrStoreClosed)
}
}
func TestReadPolicyStore(t *testing.T) {
if !winutil.IsCurrentProcessElevated() {
t.Skipf("test requires running as elevated user")
}
tests := []struct {
name setting.Key
newValue any
legacyValue any
want any
}{
{name: "LegacyPolicy", legacyValue: "LegacyValue", want: "LegacyValue"},
{name: "StringPolicy", legacyValue: "LegacyValue", newValue: "Value", want: "Value"},
{name: "StringPolicy_Empty", legacyValue: "LegacyValue", newValue: "", want: ""},
{name: "BoolPolicy_True", newValue: true, want: true},
{name: "BoolPolicy_False", newValue: false, want: false},
{name: "UIntPolicy_1", newValue: uint32(10), want: uint64(10)}, // uint32 values should be returned as uint64
{name: "UIntPolicy_2", newValue: uint64(1 << 37), want: uint64(1 << 37)},
{name: "StringListPolicy", newValue: []string{"Value1", "Value2"}, want: []string{"Value1", "Value2"}},
{name: "StringListPolicy_Empty", newValue: []string{}, want: []string{}},
{name: "StringListPolicy_SubKey", newValue: subkeyStrings{"Value1", "Value2"}, want: []string{"Value1", "Value2"}},
{name: "StringListPolicy_SubKey_Empty", newValue: subkeyStrings{}, want: []string{}},
}
runTests := func(t *testing.T, userStore bool, token windows.Token) {
var hive registry.Key
if userStore {
hive = registry.CURRENT_USER
} else {
hive = registry.LOCAL_MACHINE
}
// Write policy values to the registry.
newValues := make([]testPolicyValue, 0, len(tests))
for _, tt := range tests {
if tt.newValue != nil {
newValues = append(newValues, testPolicyValue{name: tt.name, value: tt.newValue})
}
}
policiesKeyName := softwareKeyName + `\` + tsPoliciesSubkey
cleanup, err := createTestPolicyValues(hive, policiesKeyName, newValues)
if err != nil {
t.Fatalf("createTestPolicyValues failed: %v", err)
}
t.Cleanup(cleanup)
// Write legacy policy values to the registry.
legacyValues := make([]testPolicyValue, 0, len(tests))
for _, tt := range tests {
if tt.legacyValue != nil {
legacyValues = append(legacyValues, testPolicyValue{name: tt.name, value: tt.legacyValue})
}
}
legacyKeyName := softwareKeyName + `\` + tsIPNSubkey
cleanup, err = createTestPolicyValues(hive, legacyKeyName, legacyValues)
if err != nil {
t.Fatalf("createTestPolicyValues failed: %v", err)
}
t.Cleanup(cleanup)
var store *PlatformPolicyStore
if userStore {
store, err = NewUserPlatformPolicyStore(token)
} else {
store, err = NewMachinePlatformPolicyStore()
}
if err != nil {
t.Fatalf("NewXPolicyStore failed: %v", err)
}
t.Cleanup(func() {
if err := store.Close(); err != nil {
t.Errorf("(*PolicyStore).Close failed: %v", err)
}
})
// testReadValues checks that [PolicyStore] returns the same values we wrote directly to the registry.
testReadValues := func(t *testing.T, withLocks bool) {
for _, tt := range tests {
t.Run(string(tt.name), func(t *testing.T) {
if userStore && tt.newValue == nil {
t.Skip("there is no legacy policies for users")
}
t.Parallel()
if withLocks {
if err := store.Lock(); err != nil {
t.Errorf("failed to acquire the lock: %v", err)
}
defer store.Unlock()
}
var got any
var err error
switch tt.want.(type) {
case string:
got, err = store.ReadString(tt.name)
case uint64:
got, err = store.ReadUInt64(tt.name)
case bool:
got, err = store.ReadBoolean(tt.name)
case []string:
got, err = store.ReadStringArray(tt.name)
}
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("got %v; want %v", got, tt.want)
}
})
}
}
t.Run("NoLock", func(t *testing.T) {
testReadValues(t, false)
})
t.Run("WithLock", func(t *testing.T) {
testReadValues(t, true)
})
}
t.Run("MachineStore", func(t *testing.T) {
runTests(t, false, 0)
})
t.Run("CurrentUserStore", func(t *testing.T) {
runTests(t, true, 0)
})
t.Run("UserStoreWithToken", func(t *testing.T) {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil {
t.Fatalf("OpenProcessToken: %v", err)
}
defer token.Close()
runTests(t, true, token)
})
}
func TestPolicyStoreChangeNotifications(t *testing.T) {
if cibuild.On() {
t.Skipf("test requires running on a real Windows environment")
}
store, err := NewMachinePlatformPolicyStore()
if err != nil {
t.Fatalf("NewMachinePolicyStore failed: %v", err)
}
t.Cleanup(func() {
if err := store.Close(); err != nil {
t.Errorf("(*PolicyStore).Close failed: %v", err)
}
})
done := make(chan struct{})
unregister, err := store.RegisterChangeCallback(func() { close(done) })
if err != nil {
t.Fatalf("RegisterChangeCallback failed: %v", err)
}
t.Cleanup(unregister)
// RefreshMachinePolicy is a non-blocking call.
if err := gp.RefreshMachinePolicy(true); err != nil {
t.Fatalf("RefreshMachinePolicy failed: %v", err)
}
// We should receive a policy change notification when
// the Group Policy service completes policy processing.
// Otherwise, the test will eventually time out.
<-done
}
func TestSplitSettingKey(t *testing.T) {
tests := []struct {
name string
key setting.Key
wantPath string
wantValue string
}{
{
name: "empty",
key: "",
wantPath: ``,
wantValue: "",
},
{
name: "explicit-empty-path",
key: "/ValueName",
wantPath: ``,
wantValue: "ValueName",
},
{
name: "empty-value",
key: "Root/Sub/",
wantPath: `Root\Sub`,
wantValue: "",
},
{
name: "with-path",
key: "Root/Sub/ValueName",
wantPath: `Root\Sub`,
wantValue: "ValueName",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotPath, gotValue := splitSettingKey(tt.key)
if gotPath != tt.wantPath {
t.Errorf("Path: got %q, want %q", gotPath, tt.wantPath)
}
if gotValue != tt.wantValue {
t.Errorf("Value: got %q, want %q", gotValue, tt.wantPath)
}
})
}
}
func createTestPolicyValues(hive registry.Key, keyName string, values []testPolicyValue) (cleanup func(), err error) {
key, existing, err := registry.CreateKey(hive, keyName, registry.ALL_ACCESS)
if err != nil {
return nil, err
}
var valuesToDelete map[string][]string
doCleanup := func() {
for path, values := range valuesToDelete {
if len(values) == 0 {
registry.DeleteKey(key, path)
continue
}
key, err := registry.OpenKey(key, path, windows.KEY_ALL_ACCESS)
if err != nil {
continue
}
defer key.Close()
for _, value := range values {
key.DeleteValue(value)
}
}
key.Close()
if !existing {
registry.DeleteKey(hive, keyName)
}
}
defer func() {
if err != nil {
doCleanup()
}
}()
for _, v := range values {
key, existing := key, existing
path, valueName := splitSettingKey(v.name)
if path != "" {
if key, existing, err = registry.CreateKey(key, valueName, windows.KEY_ALL_ACCESS); err != nil {
return nil, err
}
defer key.Close()
}
if values, ok := valuesToDelete[path]; len(values) > 0 || (!ok && existing) {
values = append(values, valueName)
mak.Set(&valuesToDelete, path, values)
} else if !ok {
mak.Set(&valuesToDelete, path, nil)
}
switch value := v.value.(type) {
case string:
err = key.SetStringValue(valueName, value)
case uint32:
err = key.SetDWordValue(valueName, value)
case uint64:
err = key.SetQWordValue(valueName, value)
case bool:
if value {
err = key.SetDWordValue(valueName, 1)
} else {
err = key.SetDWordValue(valueName, 0)
}
case []string:
err = key.SetStringsValue(valueName, value)
case subkeyStrings:
key, _, err := registry.CreateKey(key, valueName, windows.KEY_ALL_ACCESS)
if err != nil {
return nil, err
}
defer key.Close()
mak.Set(&valuesToDelete, strings.Trim(path+`\`+valueName, `\`), nil)
for i, value := range value {
if err := key.SetStringValue(strconv.Itoa(i), value); err != nil {
return nil, err
}
}
default:
err = fmt.Errorf("unsupported value: %v (%T), name: %q", value, value, v.name)
}
if err != nil {
return nil, err
}
}
return doCleanup, nil
}

View File

@@ -0,0 +1,451 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"fmt"
"sync"
"sync/atomic"
xmaps "golang.org/x/exp/maps"
"tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/setting"
)
var (
_ Store = (*TestStore)(nil)
_ Lockable = (*TestStore)(nil)
_ Changeable = (*TestStore)(nil)
_ Expirable = (*TestStore)(nil)
)
// TestValueType is a constraint that allows types supported by [TestStore].
type TestValueType interface {
bool | uint64 | string | []string
}
// TestSetting is a policy setting in a [TestStore].
type TestSetting[T TestValueType] struct {
// Key is the setting's unique identifier.
Key setting.Key
// Error is the error to be returned by the [TestStore] when reading
// a policy setting with the specified key.
Error error
// Value is the value to be returned by the [TestStore] when reading
// a policy setting with the specified key.
// It is only used if the Error is nil.
Value T
}
// TestSettingOf returns a [TestSetting] representing a policy setting
// configured with the specified key and value.
func TestSettingOf[T TestValueType](key setting.Key, value T) TestSetting[T] {
return TestSetting[T]{Key: key, Value: value}
}
// TestSettingWithError returns a [TestSetting] representing a policy setting
// with the specified key and error.
func TestSettingWithError[T TestValueType](key setting.Key, err error) TestSetting[T] {
return TestSetting[T]{Key: key, Error: err}
}
// testReadOperation describes a single policy setting read operation.
type testReadOperation struct {
// Key is the setting's unique identifier.
Key setting.Key
// Type is a value type of a read operation.
// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
Type setting.Type
}
// TestExpectedReads is the number of read operations with the specified details.
type TestExpectedReads struct {
// Key is the setting's unique identifier.
Key setting.Key
// Type is a value type of a read operation.
// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
Type setting.Type
// NumTimes is how many times a setting with the specified key and type should have been read.
NumTimes int
}
func (r TestExpectedReads) operation() testReadOperation {
return testReadOperation{r.Key, r.Type}
}
// TestStore is a [Store] that can be used in tests.
type TestStore struct {
tb internal.TB
done chan struct{}
storeLock sync.RWMutex // its RLock is exposed via [Store.Lock]/[Store.Unlock].
storeLockCount atomic.Int32
mu sync.RWMutex
suspendCount int // change callback are suspended if > 0
mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended.
cbs set.HandleSet[func()]
readsMu sync.Mutex
reads map[testReadOperation]int // how many times a policy setting was read
}
// NewTestStore returns a new [TestStore].
// The tb will be used to report coding errors detected by the [TestStore].
func NewTestStore(tb internal.TB) *TestStore {
m := make(map[setting.Key]any)
return &TestStore{
tb: tb,
done: make(chan struct{}),
mr: m,
mw: m,
}
}
// NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans],
// [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists].
func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore {
m := make(map[setting.Key]any)
store := &TestStore{
tb: tb,
done: make(chan struct{}),
mr: m,
mw: m,
}
switch settings := any(settings).(type) {
case []TestSetting[bool]:
store.SetBooleans(settings...)
case []TestSetting[uint64]:
store.SetUInt64s(settings...)
case []TestSetting[string]:
store.SetStrings(settings...)
case []TestSetting[[]string]:
store.SetStringLists(settings...)
}
return store
}
// Lock implements [Lockable].
func (s *TestStore) Lock() error {
s.storeLock.RLock()
s.storeLockCount.Add(1)
return nil
}
// Unlock implements [Lockable].
func (s *TestStore) Unlock() {
if s.storeLockCount.Add(-1) < 0 {
s.tb.Fatal("negative storeLockCount")
}
s.storeLock.RUnlock()
}
// RegisterChangeCallback implements [Changeable].
func (s *TestStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
s.mu.Lock()
defer s.mu.Unlock()
handle := s.cbs.Add(callback)
return func() {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.cbs, handle)
}, nil
}
// ReadString implements [Store].
func (s *TestStore) ReadString(key setting.Key) (string, error) {
defer s.recordRead(key, setting.StringValue)
s.mu.RLock()
defer s.mu.RUnlock()
v, ok := s.mr[key]
if !ok {
return "", setting.ErrNotConfigured
}
if err, ok := v.(error); ok {
return "", err
}
str, ok := v.(string)
if !ok {
return "", fmt.Errorf("%w in ReadString: got %T", setting.ErrTypeMismatch, v)
}
return str, nil
}
// ReadUInt64 implements [Store].
func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) {
defer s.recordRead(key, setting.IntegerValue)
s.mu.RLock()
defer s.mu.RUnlock()
v, ok := s.mr[key]
if !ok {
return 0, setting.ErrNotConfigured
}
if err, ok := v.(error); ok {
return 0, err
}
u64, ok := v.(uint64)
if !ok {
return 0, fmt.Errorf("%w in ReadUInt64: got %T", setting.ErrTypeMismatch, v)
}
return u64, nil
}
// ReadBoolean implements [Store].
func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) {
defer s.recordRead(key, setting.BooleanValue)
s.mu.RLock()
defer s.mu.RUnlock()
v, ok := s.mr[key]
if !ok {
return false, setting.ErrNotConfigured
}
if err, ok := v.(error); ok {
return false, err
}
b, ok := v.(bool)
if !ok {
return false, fmt.Errorf("%w in ReadBoolean: got %T", setting.ErrTypeMismatch, v)
}
return b, nil
}
// ReadStringArray implements [Store].
func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) {
defer s.recordRead(key, setting.StringListValue)
s.mu.RLock()
defer s.mu.RUnlock()
v, ok := s.mr[key]
if !ok {
return nil, setting.ErrNotConfigured
}
if err, ok := v.(error); ok {
return nil, err
}
slice, ok := v.([]string)
if !ok {
return nil, fmt.Errorf("%w in ReadStringArray: got %T", setting.ErrTypeMismatch, v)
}
return slice, nil
}
func (s *TestStore) recordRead(key setting.Key, typ setting.Type) {
s.readsMu.Lock()
op := testReadOperation{key, typ}
num := s.reads[op]
num++
mak.Set(&s.reads, op, num)
s.readsMu.Unlock()
}
func (s *TestStore) ResetCounters() {
s.readsMu.Lock()
clear(s.reads)
s.readsMu.Unlock()
}
// ReadsMustEqual fails the test if the actual reads differs from the specified reads.
func (s *TestStore) ReadsMustEqual(reads ...TestExpectedReads) {
s.tb.Helper()
s.readsMu.Lock()
defer s.readsMu.Unlock()
s.readsMustContainLocked(reads...)
s.readMustNoExtraLocked(reads...)
}
// ReadsMustContain fails the test if the specified reads have not been made,
// or have been made a different number of times. It permits other values to be
// read in addition to the ones being tested.
func (s *TestStore) ReadsMustContain(reads ...TestExpectedReads) {
s.tb.Helper()
s.readsMu.Lock()
defer s.readsMu.Unlock()
s.readsMustContainLocked(reads...)
}
func (s *TestStore) readsMustContainLocked(reads ...TestExpectedReads) {
s.tb.Helper()
for _, r := range reads {
if numTimes := s.reads[r.operation()]; numTimes != r.NumTimes {
s.tb.Errorf("%q (%v) reads: got %v, want %v", r.Key, r.Type, numTimes, r.NumTimes)
}
}
}
func (s *TestStore) readMustNoExtraLocked(reads ...TestExpectedReads) {
s.tb.Helper()
rs := make(set.Set[testReadOperation])
for i := range reads {
rs.Add(reads[i].operation())
}
for ro, num := range s.reads {
if !rs.Contains(ro) {
s.tb.Errorf("%q (%v) reads: got %v, want 0", ro.Key, ro.Type, num)
}
}
}
// Suspend suspends the store, batching changes and notifications
// until [TestStore.Resume] is called the same number of times as Suspend.
func (s *TestStore) Suspend() {
s.mu.Lock()
defer s.mu.Unlock()
if s.suspendCount++; s.suspendCount == 1 {
s.mw = xmaps.Clone(s.mr)
}
}
// Resume resumes the store, applying the changes and invoking
// the change callbacks.
func (s *TestStore) Resume() {
s.storeLock.Lock()
s.mu.Lock()
switch s.suspendCount--; {
case s.suspendCount == 0:
s.mr = s.mw
s.mu.Unlock()
s.storeLock.Unlock()
s.notifyPolicyChanged()
case s.suspendCount < 0:
s.tb.Fatal("negative suspendCount")
default:
s.mu.Unlock()
s.storeLock.Unlock()
}
}
// SetBooleans sets the specified boolean settings in s.
func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) {
s.storeLock.Lock()
for _, setting := range settings {
if setting.Key == "" {
s.tb.Fatal("empty keys disallowed")
}
s.mu.Lock()
if setting.Error != nil {
mak.Set(&s.mw, setting.Key, any(setting.Error))
} else {
mak.Set(&s.mw, setting.Key, any(setting.Value))
}
s.mu.Unlock()
}
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
// SetUInt64s sets the specified integer settings in s.
func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) {
s.storeLock.Lock()
for _, setting := range settings {
if setting.Key == "" {
s.tb.Fatal("empty keys disallowed")
}
s.mu.Lock()
if setting.Error != nil {
mak.Set(&s.mw, setting.Key, any(setting.Error))
} else {
mak.Set(&s.mw, setting.Key, any(setting.Value))
}
s.mu.Unlock()
}
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
// SetStrings sets the specified string settings in s.
func (s *TestStore) SetStrings(settings ...TestSetting[string]) {
s.storeLock.Lock()
for _, setting := range settings {
if setting.Key == "" {
s.tb.Fatal("empty keys disallowed")
}
s.mu.Lock()
if setting.Error != nil {
mak.Set(&s.mw, setting.Key, any(setting.Error))
} else {
mak.Set(&s.mw, setting.Key, any(setting.Value))
}
s.mu.Unlock()
}
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
// SetStrings sets the specified string list settings in s.
func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) {
s.storeLock.Lock()
for _, setting := range settings {
if setting.Key == "" {
s.tb.Fatal("empty keys disallowed")
}
s.mu.Lock()
if setting.Error != nil {
mak.Set(&s.mw, setting.Key, any(setting.Error))
} else {
mak.Set(&s.mw, setting.Key, any(setting.Value))
}
s.mu.Unlock()
}
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
// Delete deletes the specified settings from s.
func (s *TestStore) Delete(keys ...setting.Key) {
s.storeLock.Lock()
for _, key := range keys {
s.mu.Lock()
delete(s.mw, key)
s.mu.Unlock()
}
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
// Clear deletes all settings from s.
func (s *TestStore) Clear() {
s.storeLock.Lock()
s.mu.Lock()
clear(s.mw)
s.mu.Unlock()
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
func (s *TestStore) notifyPolicyChanged() {
s.mu.RLock()
if s.suspendCount != 0 {
s.mu.RUnlock()
return
}
cbs := xmaps.Values(s.cbs)
s.mu.RUnlock()
var wg sync.WaitGroup
wg.Add(len(cbs))
for _, cb := range cbs {
go func() {
defer wg.Done()
cb()
}()
}
wg.Wait()
}
// Close closes s, notifying its users that it has expired.
func (s *TestStore) Close() {
s.mu.Lock()
defer s.mu.Unlock()
if s.done != nil {
close(s.done)
s.done = nil
}
}
// Done implements [Expirable].
func (s *TestStore) Done() <-chan struct{} {
return s.done
}