uti/syspolicy: user policy support, auto-refresh and initial preparation for policy structs

This updates the syspolicy package to support multiple policy sources in the
three policy scopes: user, profile, and device, and provides a merged resultant
policy. A policy source is a syspolicy/source.Store that has a name and provides
access to policy settings for a given scope. It can be registered with
syspolicy/rsop.RegisterStore. Policy sources and policy stores can be either
platform-specific or platform-agnostic. On Windows, we have the Registry-based,
platform-specific policy store implemented as
syspolicy/source.PlatformPolicyStore. This store provides access to the Group
Policy and MDM policy settings stored in the Registry. On other platforms, we
currently provide a wrapper that converts a syspolicy.Handler into a
syspolicy/source.Store. However, we should update them in follow-up PRs. An
example of a platform-agnostic policy store would be a policy deployed from the
control, a local policy config file, or even environment variables.

We maintain the current, most recent version of the resultant policy for each
scope in an rsop.Policy. This is done by reading and merging the policy settings
from the registered stores the first time the resultant policy is requested,
then re-reading and re-merging them if a store implements the source.Changeable
interface and reports a policy change. Policy change notifications are debounced
to avoid re-reading policy settings multiple times if there are several changes
within a short period. The rsop.Policy can notify clients if the resultant
policy has changed. However, we do not currently expose this via the syspolicy
package and plan to do so differently along with a struct-based policy hierarchy
in the next PR.

To facilitate this, all policy settings should be registered with the
setting.Register function. The syspolicy package does this automatically for all
policy settings defined in policy_keys.go.

The new functionality is available through the existing syspolicy.Read* set of
functions. However, we plan to expose it via a struct-based policy hierarchy,
along with policy change notifications that other subsystems can use, in the
next PR. We also plan to send the resultant policy back from tailscaled to the
clients via the LocalAPI.

This is primarily a foundational PR to facilitate future changes, but the
immediate observable changes on Windows include:
- The service will use the current policy setting values instead of those read
  at OS boot time.
- The GUI has access to policy settings configured on a per-user basis.
On Android:
- We now report policy setting usage via clientmetrics.

Updates #12687

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl 2024-08-02 19:18:42 -05:00
parent 655b4f8fc5
commit cab0e1a6f7
44 changed files with 7320 additions and 752 deletions

View File

@ -10,7 +10,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw
W 💣 github.com/dblohm7/wingoes from tailscale.com/util/winutil
github.com/fxamacker/cbor/v2 from tailscale.com/tka
github.com/go-json-experiment/json from tailscale.com/types/opt
github.com/go-json-experiment/json from tailscale.com/types/opt+
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+
@ -146,9 +146,11 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
tailscale.com/util/cloudenv from tailscale.com/hostinfo+
W tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy
tailscale.com/util/ctxkey from tailscale.com/tsweb+
💣 tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting
L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics
tailscale.com/util/dnsname from tailscale.com/hostinfo+
tailscale.com/util/fastuuid from tailscale.com/tsweb
💣 tailscale.com/util/hashx from tailscale.com/util/deephash
tailscale.com/util/httpm from tailscale.com/client/tailscale
tailscale.com/util/lineread from tailscale.com/hostinfo+
L tailscale.com/util/linuxfw from tailscale.com/net/netns
@ -159,8 +161,17 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
tailscale.com/util/singleflight from tailscale.com/net/dnscache
tailscale.com/util/slicesx from tailscale.com/cmd/derper+
tailscale.com/util/syspolicy from tailscale.com/ipn
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
tailscale.com/util/testenv from tailscale.com/util/syspolicy+
tailscale.com/util/vizerror from tailscale.com/tailcfg+
W 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+
W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
tailscale.com/version from tailscale.com/derp+
tailscale.com/version/distro from tailscale.com/envknob+
@ -180,6 +191,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
W golang.org/x/exp/constraints from tailscale.com/util/winutil
golang.org/x/exp/maps from tailscale.com/util/syspolicy/internal/metrics+
L golang.org/x/net/bpf from github.com/mdlayher/netlink+
golang.org/x/net/dns/dnsmessage from net+
golang.org/x/net/http/httpguts from net/http
@ -240,7 +252,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
encoding/pem from crypto/tls+
errors from bufio+
expvar from github.com/prometheus/client_golang/prometheus+
flag from tailscale.com/cmd/derper
flag from tailscale.com/cmd/derper+
fmt from compress/flate+
go/token from google.golang.org/protobuf/internal/strs
hash from crypto+
@ -273,7 +285,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
os from crypto/rand+
os/exec from github.com/coreos/go-iptables/iptables+
os/signal from tailscale.com/cmd/derper
W os/user from tailscale.com/util/winutil
W os/user from tailscale.com/util/winutil+
path from github.com/prometheus/client_golang/prometheus/internal+
path/filepath from crypto/x509+
reflect from crypto/x509+

View File

@ -96,7 +96,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
💣 github.com/fsnotify/fsnotify from sigs.k8s.io/controller-runtime/pkg/certwatcher
github.com/fxamacker/cbor/v2 from tailscale.com/tka
github.com/gaissmai/bart from tailscale.com/net/ipset+
github.com/go-json-experiment/json from tailscale.com/types/opt
github.com/go-json-experiment/json from tailscale.com/types/opt+
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json/jsontext+
@ -803,6 +803,13 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/util/singleflight from tailscale.com/control/controlclient+
tailscale.com/util/slicesx from tailscale.com/appc+
tailscale.com/util/syspolicy from tailscale.com/control/controlclient+
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock
tailscale.com/util/systemd from tailscale.com/control/controlclient+
tailscale.com/util/testenv from tailscale.com/control/controlclient+
@ -811,7 +818,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/util/vizerror from tailscale.com/tailcfg+
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+
W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
tailscale.com/util/zstdframe from tailscale.com/control/controlclient+

View File

@ -9,7 +9,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/pe+
W 💣 github.com/dblohm7/wingoes/pe from tailscale.com/util/winutil/authenticode
github.com/fxamacker/cbor/v2 from tailscale.com/tka
github.com/go-json-experiment/json from tailscale.com/types/opt
github.com/go-json-experiment/json from tailscale.com/types/opt+
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+
@ -152,9 +152,11 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
tailscale.com/util/cloudenv from tailscale.com/net/dnscache+
tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy+
tailscale.com/util/ctxkey from tailscale.com/types/logger
💣 tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting
L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics
tailscale.com/util/dnsname from tailscale.com/cmd/tailscale/cli+
tailscale.com/util/groupmember from tailscale.com/client/web
💣 tailscale.com/util/hashx from tailscale.com/util/deephash
tailscale.com/util/httpm from tailscale.com/client/tailscale+
tailscale.com/util/lineread from tailscale.com/hostinfo+
L tailscale.com/util/linuxfw from tailscale.com/net/netns
@ -167,11 +169,19 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
tailscale.com/util/singleflight from tailscale.com/net/dnscache+
tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+
tailscale.com/util/syspolicy from tailscale.com/ipn
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli+
tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli
tailscale.com/util/vizerror from tailscale.com/tailcfg+
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate
W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
tailscale.com/version from tailscale.com/client/web+
tailscale.com/version/distro from tailscale.com/client/web+
@ -191,7 +201,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
golang.org/x/crypto/pbkdf2 from software.sslmate.com/src/go-pkcs12
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+
golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli
golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli+
golang.org/x/net/bpf from github.com/mdlayher/netlink+
golang.org/x/net/dns/dnsmessage from net+
golang.org/x/net/http/httpguts from net/http+

View File

@ -90,7 +90,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
💣 github.com/djherbis/times from tailscale.com/drive/driveimpl
github.com/fxamacker/cbor/v2 from tailscale.com/tka
github.com/gaissmai/bart from tailscale.com/net/tstun+
github.com/go-json-experiment/json from tailscale.com/types/opt
github.com/go-json-experiment/json from tailscale.com/types/opt+
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json/jsontext+
@ -395,6 +395,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/util/singleflight from tailscale.com/control/controlclient+
tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+
tailscale.com/util/syspolicy from tailscale.com/cmd/tailscaled+
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock
tailscale.com/util/systemd from tailscale.com/control/controlclient+
tailscale.com/util/testenv from tailscale.com/ipn/ipnlocal+
@ -403,7 +410,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/util/vizerror from tailscale.com/tailcfg+
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+
W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
tailscale.com/util/zstdframe from tailscale.com/control/controlclient+

View File

@ -52,6 +52,8 @@
"tailscale.com/util/must"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy"
"tailscale.com/util/syspolicy/rsop"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/wgengine"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/wgcfg"
@ -2546,6 +2548,14 @@ func TestPreferencePolicyInfo(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
definitions := make([]*setting.Definition, 0, len(preferencePolicies)+1)
definitions = append(definitions, must.Get(syspolicy.WellKnownSettingDefinition(syspolicy.ControlURL)))
for _, pp := range preferencePolicies {
definitions = append(definitions, must.Get(syspolicy.WellKnownSettingDefinition(pp.key)))
}
if err := setting.SetDefinitionsForTest(t, definitions...); err != nil {
t.Fatalf("SetDefinitionsForTest failed: %v", err)
}
for _, pp := range preferencePolicies {
t.Run(string(pp.key), func(t *testing.T) {
var h syspolicy.Handler
@ -2572,7 +2582,7 @@ func TestPreferencePolicyInfo(t *testing.T) {
msh.stringPolicies[pp.key] = &tt.policyValue
h = msh
}
syspolicy.SetHandlerForTest(t, h)
rsop.RegisterStoreForTest(t, tt.name, setting.DeviceScope, syspolicy.WrapHandler(h))
prefs := defaultPrefs.AsStruct()
pp.set(prefs, tt.initialValue)

View File

@ -1,122 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"errors"
"sync"
)
// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested
// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached,
// otherwise the actual error is returned and the next read for that key will retry using the handler.
type CachingHandler struct {
mu sync.Mutex
strings map[string]string
uint64s map[string]uint64
bools map[string]bool
strArrs map[string][]string
notFound map[string]bool
handler Handler
}
// NewCachingHandler creates a CachingHandler given a handler.
func NewCachingHandler(handler Handler) *CachingHandler {
return &CachingHandler{
handler: handler,
strings: make(map[string]string),
uint64s: make(map[string]uint64),
bools: make(map[string]bool),
strArrs: make(map[string][]string),
notFound: make(map[string]bool),
}
}
// ReadString reads the policy settings value string given the key.
// ReadString first reads from the handler's cache before resorting to using the handler.
func (ch *CachingHandler) ReadString(key string) (string, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
if val, ok := ch.strings[key]; ok {
return val, nil
}
if notFound := ch.notFound[key]; notFound {
return "", ErrNoSuchKey
}
val, err := ch.handler.ReadString(key)
if errors.Is(err, ErrNoSuchKey) {
ch.notFound[key] = true
return "", err
} else if err != nil {
return "", err
}
ch.strings[key] = val
return val, nil
}
// ReadUInt64 reads the policy settings uint64 value given the key.
// ReadUInt64 first reads from the handler's cache before resorting to using the handler.
func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
if val, ok := ch.uint64s[key]; ok {
return val, nil
}
if notFound := ch.notFound[key]; notFound {
return 0, ErrNoSuchKey
}
val, err := ch.handler.ReadUInt64(key)
if errors.Is(err, ErrNoSuchKey) {
ch.notFound[key] = true
return 0, err
} else if err != nil {
return 0, err
}
ch.uint64s[key] = val
return val, nil
}
// ReadBoolean reads the policy settings boolean value given the key.
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
func (ch *CachingHandler) ReadBoolean(key string) (bool, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
if val, ok := ch.bools[key]; ok {
return val, nil
}
if notFound := ch.notFound[key]; notFound {
return false, ErrNoSuchKey
}
val, err := ch.handler.ReadBoolean(key)
if errors.Is(err, ErrNoSuchKey) {
ch.notFound[key] = true
return false, err
} else if err != nil {
return false, err
}
ch.bools[key] = val
return val, nil
}
// ReadBoolean reads the policy settings boolean value given the key.
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
func (ch *CachingHandler) ReadStringArray(key string) ([]string, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
if val, ok := ch.strArrs[key]; ok {
return val, nil
}
if notFound := ch.notFound[key]; notFound {
return nil, ErrNoSuchKey
}
val, err := ch.handler.ReadStringArray(key)
if errors.Is(err, ErrNoSuchKey) {
ch.notFound[key] = true
return nil, err
} else if err != nil {
return nil, err
}
ch.strArrs[key] = val
return val, nil
}

View File

@ -1,262 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"testing"
)
func TestHandlerReadString(t *testing.T) {
tests := []struct {
name string
key string
handlerKey Key
handlerValue string
handlerError error
preserveHandler bool
wantValue string
wantErr error
strings map[string]string
expectedCalls int
}{
{
name: "read existing cached values",
key: "test",
handlerKey: "do not read",
strings: map[string]string{"test": "foo"},
wantValue: "foo",
expectedCalls: 0,
},
{
name: "read existing values not cached",
key: "test",
handlerKey: "test",
handlerValue: "foo",
wantValue: "foo",
expectedCalls: 1,
},
{
name: "error no such key",
key: "test",
handlerKey: "test",
handlerError: ErrNoSuchKey,
wantErr: ErrNoSuchKey,
expectedCalls: 1,
},
{
name: "other error",
key: "test",
handlerKey: "test",
handlerError: someOtherError,
wantErr: someOtherError,
preserveHandler: true,
expectedCalls: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := &testHandler{
t: t,
key: tt.handlerKey,
s: tt.handlerValue,
err: tt.handlerError,
}
cache := NewCachingHandler(testHandler)
if tt.strings != nil {
cache.strings = tt.strings
}
got, err := cache.ReadString(tt.key)
if err != tt.wantErr {
t.Errorf("err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("got %v want %v", got, cache.strings[tt.key])
}
if !tt.preserveHandler {
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
}
got, err = cache.ReadString(tt.key)
if err != tt.wantErr {
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
}
if testHandler.calls != tt.expectedCalls {
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
}
})
}
}
func TestHandlerReadUint64(t *testing.T) {
tests := []struct {
name string
key string
handlerKey Key
handlerValue uint64
handlerError error
preserveHandler bool
wantValue uint64
wantErr error
uint64s map[string]uint64
expectedCalls int
}{
{
name: "read existing cached values",
key: "test",
handlerKey: "do not read",
uint64s: map[string]uint64{"test": 1},
wantValue: 1,
expectedCalls: 0,
},
{
name: "read existing values not cached",
key: "test",
handlerKey: "test",
handlerValue: 1,
wantValue: 1,
expectedCalls: 1,
},
{
name: "error no such key",
key: "test",
handlerKey: "test",
handlerError: ErrNoSuchKey,
wantErr: ErrNoSuchKey,
expectedCalls: 1,
},
{
name: "other error",
key: "test",
handlerKey: "test",
handlerError: someOtherError,
wantErr: someOtherError,
preserveHandler: true,
expectedCalls: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := &testHandler{
t: t,
key: tt.handlerKey,
u64: tt.handlerValue,
err: tt.handlerError,
}
cache := NewCachingHandler(testHandler)
if tt.uint64s != nil {
cache.uint64s = tt.uint64s
}
got, err := cache.ReadUInt64(tt.key)
if err != tt.wantErr {
t.Errorf("err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("got %v want %v", got, cache.strings[tt.key])
}
if !tt.preserveHandler {
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
}
got, err = cache.ReadUInt64(tt.key)
if err != tt.wantErr {
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
}
if testHandler.calls != tt.expectedCalls {
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
}
})
}
}
func TestHandlerReadBool(t *testing.T) {
tests := []struct {
name string
key string
handlerKey Key
handlerValue bool
handlerError error
preserveHandler bool
wantValue bool
wantErr error
bools map[string]bool
expectedCalls int
}{
{
name: "read existing cached values",
key: "test",
handlerKey: "do not read",
bools: map[string]bool{"test": true},
wantValue: true,
expectedCalls: 0,
},
{
name: "read existing values not cached",
key: "test",
handlerKey: "test",
handlerValue: true,
wantValue: true,
expectedCalls: 1,
},
{
name: "error no such key",
key: "test",
handlerKey: "test",
handlerError: ErrNoSuchKey,
wantErr: ErrNoSuchKey,
expectedCalls: 1,
},
{
name: "other error",
key: "test",
handlerKey: "test",
handlerError: someOtherError,
wantErr: someOtherError,
preserveHandler: true,
expectedCalls: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testHandler := &testHandler{
t: t,
key: tt.handlerKey,
b: tt.handlerValue,
err: tt.handlerError,
}
cache := NewCachingHandler(testHandler)
if tt.bools != nil {
cache.bools = tt.bools
}
got, err := cache.ReadBoolean(tt.key)
if err != tt.wantErr {
t.Errorf("err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("got %v want %v", got, cache.strings[tt.key])
}
if !tt.preserveHandler {
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
}
got, err = cache.ReadBoolean(tt.key)
if err != tt.wantErr {
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
}
if got != tt.wantValue {
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
}
if testHandler.calls != tt.expectedCalls {
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
}
})
}
}

View File

@ -4,16 +4,15 @@
package syspolicy
import (
"errors"
"sync/atomic"
)
var (
handlerUsed atomic.Bool
handler Handler = defaultHandler{}
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/rsop"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
)
// Handler reads system policies from OS-specific storage.
//
// Deprecated: implementing a [Store] should be preferred.
type Handler interface {
// ReadString reads the policy setting's string value for the given key.
// It should return ErrNoSuchKey if the key does not have a value set.
@ -29,55 +28,81 @@ type Handler interface {
ReadStringArray(key string) ([]string, error)
}
// ErrNoSuchKey is returned by a Handler when the specified key does not have a
// value set.
var ErrNoSuchKey = errors.New("no such key")
// defaultHandler is the catch all syspolicy type for anything that isn't windows or apple.
type defaultHandler struct{}
func (defaultHandler) ReadString(_ string) (string, error) {
return "", ErrNoSuchKey
}
func (defaultHandler) ReadUInt64(_ string) (uint64, error) {
return 0, ErrNoSuchKey
}
func (defaultHandler) ReadBoolean(_ string) (bool, error) {
return false, ErrNoSuchKey
}
func (defaultHandler) ReadStringArray(_ string) ([]string, error) {
return nil, ErrNoSuchKey
}
// markHandlerInUse is called before handler methods are called.
func markHandlerInUse() {
handlerUsed.Store(true)
}
// RegisterHandler initializes the policy handler and ensures registration will happen once.
// RegisterHandler wraps and registers the specified handler as the device's
// policy [Store] for the program's lifetime.
//
// Deprecated: using [RegisterStore] should be preferred.
func RegisterHandler(h Handler) {
// Technically this assignment is not concurrency safe, but in the
// event that there was any risk of a data race, we will panic due to
// the CompareAndSwap failing.
handler = h
if !handlerUsed.CompareAndSwap(false, true) {
panic("handler was already used before registration")
}
rsop.RegisterStore("DeviceHandler", setting.DeviceScope, WrapHandler(h))
}
// TB is a subset of testing.TB that we use to set up test helpers.
// It's defined here to avoid pulling in the testing package.
type TB interface {
Helper()
Cleanup(func())
type TB = internal.TB
// SetHandlerForTest wraps and sets the specified handler as the device's policy
// [Store] for the duration of tb.
//
// Deprecated: using [resultant.RegisterStoreForTest] should be preferred.
func SetHandlerForTest(tb TB, h Handler) {
if err := setWellKnownSettingsForTest(tb); err != nil {
tb.Fatalf("setWellKnownSettingsForTest failed: %v", err)
}
rsop.RegisterStoreForTest(tb, "DeviceHandler-TestOnly", setting.CurrentScope(), WrapHandler(h))
}
func SetHandlerForTest(tb TB, h Handler) {
tb.Helper()
oldHandler := handler
handler = h
tb.Cleanup(func() { handler = oldHandler })
var _ source.Store = (*handlerStore)(nil)
// handlerStore is a [source.Store] that calls the underlying [Handler].
// TODO(nickkhyl): remove it when the corp and android repos are updated.
type handlerStore struct {
h Handler
}
// WrapHandler returns a [source.Store] that wraps the specified [Handler].
func WrapHandler(h Handler) source.Store {
return handlerStore{h}
}
func (s handlerStore) Lock() error {
if lockable, ok := s.h.(source.Lockable); ok {
return lockable.Lock()
}
return nil
}
func (s handlerStore) Unlock() {
if lockable, ok := s.h.(source.Lockable); ok {
lockable.Unlock()
}
}
func (s handlerStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
if lockable, ok := s.h.(source.Changeable); ok {
return lockable.RegisterChangeCallback(callback)
}
return func() {}, nil
}
func (s handlerStore) ReadString(key setting.Key) (string, error) {
return s.h.ReadString(string(key))
}
func (s handlerStore) ReadUInt64(key setting.Key) (uint64, error) {
return s.h.ReadUInt64(string(key))
}
func (s handlerStore) ReadBoolean(key setting.Key) (bool, error) {
return s.h.ReadBoolean(string(key))
}
func (s handlerStore) ReadStringArray(key setting.Key) ([]string, error) {
return s.h.ReadStringArray(string(key))
}
func (s handlerStore) Done() <-chan struct{} {
if expirable, ok := s.h.(source.Expirable); ok {
return expirable.Done()
}
return nil
}

View File

@ -1,19 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import "testing"
func TestDefaultHandlerReadValues(t *testing.T) {
var h defaultHandler
got, err := h.ReadString(string(AdminConsoleVisibility))
if got != "" || err != ErrNoSuchKey {
t.Fatalf("got %v err %v", got, err)
}
result, err := h.ReadUInt64(string(LogSCMInteractions))
if result != 0 || err != ErrNoSuchKey {
t.Fatalf("got %v err %v", result, err)
}
}

View File

@ -1,105 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"errors"
"fmt"
"tailscale.com/util/clientmetric"
"tailscale.com/util/winutil"
)
var (
windowsErrors = clientmetric.NewCounter("windows_syspolicy_errors")
windowsAny = clientmetric.NewGauge("windows_syspolicy_any")
)
type windowsHandler struct{}
func init() {
RegisterHandler(NewCachingHandler(windowsHandler{}))
keyList := []struct {
isSet func(Key) bool
keys []Key
}{
{
isSet: func(k Key) bool {
_, err := handler.ReadString(string(k))
return err == nil
},
keys: stringKeys,
},
{
isSet: func(k Key) bool {
_, err := handler.ReadBoolean(string(k))
return err == nil
},
keys: boolKeys,
},
{
isSet: func(k Key) bool {
_, err := handler.ReadUInt64(string(k))
return err == nil
},
keys: uint64Keys,
},
}
var anySet bool
for _, l := range keyList {
for _, k := range l.keys {
if !l.isSet(k) {
continue
}
clientmetric.NewGauge(fmt.Sprintf("windows_syspolicy_%s", k)).Set(1)
anySet = true
}
}
if anySet {
windowsAny.Set(1)
}
}
func (windowsHandler) ReadString(key string) (string, error) {
s, err := winutil.GetPolicyString(key)
if errors.Is(err, winutil.ErrNoValue) {
err = ErrNoSuchKey
} else if err != nil {
windowsErrors.Add(1)
}
return s, err
}
func (windowsHandler) ReadUInt64(key string) (uint64, error) {
value, err := winutil.GetPolicyInteger(key)
if errors.Is(err, winutil.ErrNoValue) {
err = ErrNoSuchKey
} else if err != nil {
windowsErrors.Add(1)
}
return value, err
}
func (windowsHandler) ReadBoolean(key string) (bool, error) {
value, err := winutil.GetPolicyInteger(key)
if errors.Is(err, winutil.ErrNoValue) {
err = ErrNoSuchKey
} else if err != nil {
windowsErrors.Add(1)
}
return value != 0, err
}
func (windowsHandler) ReadStringArray(key string) ([]string, error) {
value, err := winutil.GetPolicyStringArray(key)
if errors.Is(err, winutil.ErrNoValue) {
err = ErrNoSuchKey
} else if err != nil {
windowsErrors.Add(1)
}
return value, err
}

View File

@ -0,0 +1,63 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package internal contains miscellaneous functions and types
// that are internal to the syspolicy packages.
package internal
import (
"bytes"
"github.com/go-json-experiment/json/jsontext"
"tailscale.com/types/lazy"
"tailscale.com/version"
)
// OSForTesting is the operating system override used for testing.
// It follows the same naming convention as [version.OS].
var OSForTesting lazy.SyncValue[string]
// OS is like [version.OS], but supports a test hook.
func OS() string {
return OSForTesting.Get(version.OS)
}
// TB is a subset of testing.TB that we use to set up test helpers.
// It's defined here to avoid pulling in the testing package.
type TB interface {
Helper()
Cleanup(func())
Logf(format string, args ...any)
Error(args ...any)
Errorf(format string, args ...any)
Fatal(args ...any)
Fatalf(format string, args ...any)
}
// EqualJSONForTest compares the JSON in j1 and j2 for semantic equality.
// It returns "", "", true if j1 and j2 are equal. Otherwise, it returns
// indented versions of j1 and j2 and false.
func EqualJSONForTest(tb TB, j1, j2 jsontext.Value) (s1, s2 string, equal bool) {
tb.Helper()
j1 = j1.Clone()
j2 = j2.Clone()
// Canonicalize JSON values for comparison.
if err := j1.Canonicalize(); err != nil {
tb.Error(err)
}
if err := j2.Canonicalize(); err != nil {
tb.Error(err)
}
// Check and return true if the two values are structurally equal.
if bytes.Equal(j1, j2) {
return "", "", true
}
// Otherwise, format the values for display and return false.
if err := j1.Indent("", "\t"); err != nil {
tb.Fatal(err)
}
if err := j2.Indent("", "\t"); err != nil {
tb.Fatal(err)
}
return j1.String(), j2.String(), false
}

View File

@ -0,0 +1,84 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// The lazyinit package facilitates deferred package initialization.
package lazyinit
import (
"sync"
"sync/atomic"
)
var packageInit deferredOnce
// Defer defers the specified action until [Do] is called.
// It returns a boolean indicating whether [Do] has already been called.
func Defer(action func() error) bool {
return packageInit.Defer(action)
}
// DeferWithCleanup is like [Defer], but the action function returns a cleanup
// function to be called in case of an error.
func DeferWithCleanup(action func() (cleanup func(), err error)) bool {
return packageInit.DeferWithCleanup(action)
}
// Do runs all deferred functions and returns an error if any of them fail.
func Do() error {
return packageInit.Do()
}
type deferredOnce struct {
done atomic.Uint32
err error
m sync.Mutex
funcs []func() (cleanup func(), err error)
}
func (o *deferredOnce) Defer(action func() error) bool {
return o.DeferWithCleanup(func() (cleanup func(), err error) {
return nil, action()
})
}
func (o *deferredOnce) DeferWithCleanup(action func() (cleanup func(), err error)) bool {
o.m.Lock()
defer o.m.Unlock()
if o.done.Load() != 0 {
return false
}
o.funcs = append(o.funcs, action)
return true
}
func (o *deferredOnce) Do() error {
if o.done.Load() == 0 {
o.doSlow()
}
return o.err
}
func (o *deferredOnce) doSlow() (err error) {
o.m.Lock()
defer o.m.Unlock()
if o.done.Load() == 0 {
defer func() {
o.done.Store(1)
o.err = err
}()
for _, f := range o.funcs {
cleanup, err := f()
if err != nil {
return err
}
if cleanup != nil {
defer func() {
if err != nil {
cleanup()
}
}()
}
}
}
return o.err
}

View File

@ -0,0 +1,46 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package loggerx provides logging functions to the rest of the syspolicy packages.
package loggerx
import (
"log"
"tailscale.com/types/lazy"
"tailscale.com/types/logger"
"tailscale.com/util/syspolicy/internal"
)
const (
errorPrefix = "syspolicy: "
verbosePrefix = "syspolicy: [v2] "
)
var (
lazyErrorf lazy.SyncValue[logger.Logf]
lazyVerbosef lazy.SyncValue[logger.Logf]
)
// Errorf formats and writes an error message to the log.
func Errorf(format string, args ...any) {
errorf := lazyErrorf.Get(func() logger.Logf {
return logger.WithPrefix(log.Printf, errorPrefix)
})
errorf(format, args...)
}
// Verbosef formats and writes an optional, verbose message to the log.
func Verbosef(format string, args ...any) {
verbosef := lazyVerbosef.Get(func() logger.Logf {
return logger.WithPrefix(log.Printf, verbosePrefix)
})
verbosef(format, args...)
}
// SetForTest sets the specified errorf and verbosef functions for the duration
// of tb and its subtests.
func SetForTest(tb internal.TB, errorf, verbosef logger.Logf) {
lazyErrorf.SetForTest(tb, errorf, nil)
lazyVerbosef.SetForTest(tb, verbosef, nil)
}

View File

@ -0,0 +1,315 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package metrics provides logging and reporting for policy settings and scopes.
package metrics
import (
"strings"
"sync"
xmaps "golang.org/x/exp/maps"
"tailscale.com/syncs"
"tailscale.com/types/lazy"
"tailscale.com/util/clientmetric"
"tailscale.com/util/mak"
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/internal/loggerx"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/testenv"
)
var lazyReportMetrics lazy.SyncValue[bool] // used as a test hook
// ShouldReport reports whether metrics should be reported on the current environment.
func ShouldReport() bool {
return lazyReportMetrics.Get(func() bool {
// macOS, iOS and tvOS create their own metrics,
// and we don't have syspolicy on any other platforms.
return setting.PlatformList{"android", "windows"}.HasCurrent()
})
}
// Reset metrics for the specified policy origin.
func Reset(origin *setting.Origin) {
scopeMetrics(origin).Reset()
}
// ReportConfigured updates metrics and logs that the specified setting is
// configured with the given value in the origin.
func ReportConfigured(origin *setting.Origin, setting *setting.Definition, value any) {
settingMetricsFor(setting).ReportValue(origin, value)
}
// ReportError updates metrics and logs that the specified setting has an error
// in the origin.
func ReportError(origin *setting.Origin, setting *setting.Definition, err error) {
settingMetricsFor(setting).ReportError(origin, err)
}
// ReportNotConfigured updates metrics and logs that the specified setting is
// not configured in the origin.
func ReportNotConfigured(origin *setting.Origin, setting *setting.Definition) {
settingMetricsFor(setting).Reset(origin)
}
// metric is an interface implemented by [clientmetric.Metric] and [funcMetric].
type metric interface {
Add(v int64)
Set(v int64)
}
// policyScopeMetrics are metrics that apply to an entire policy scope rather
// than a specific policy setting.
type policyScopeMetrics struct {
hasAny metric
numErrored metric
}
func newScopeMetrics(scope setting.Scope) *policyScopeMetrics {
prefix := metricScopeName(scope)
if prefix != "" {
prefix += "_"
}
// {os}_syspolicy_{scope_unless_device}_any
// Example: windows_syspolicy_any or windows_syspolicy_user_any.
hasAny := newMetric(prefix+"any", clientmetric.TypeGauge)
// {os}_syspolicy_{scope_unless_device}_errors
// Example: windows_syspolicy_errors or windows_syspolicy_user_errors.
//
// TODO(nickkhyl): maybe make the `{os}_syspolicy_errors` metric a gauge rather than a counter?
// It was a counter prior to https://github.com/tailscale/tailscale/issues/12687, so I kept it as such.
// But I think a gauge makes more sense: syspolicy errors indicate a mismatch between the expected
// policy value type or format and the actual value read from the underlying store (like the Windows Registry).
// We'll encounter the same error every time we re-read the policy setting from the backing store
// until the policy value is corrected by the user, or until we fix the bug in the code or ADMX.
// There's probably no reason to count and accumulate them over time.
numErrored := newMetric(prefix+"errors", clientmetric.TypeCounter)
return &policyScopeMetrics{hasAny, numErrored}
}
// ReportHasSettings is called when there's any configured policy setting in the scope.
func (m *policyScopeMetrics) ReportHasSettings() {
if m != nil {
m.hasAny.Set(1)
}
}
// ReportError is called when there's any errored policy setting in the scope.
func (m *policyScopeMetrics) ReportError() {
if m != nil {
m.numErrored.Add(1)
}
}
// Reset is called to reset the policy scope metrics, such as when the policy scope
// is about to be reloaded.
func (m *policyScopeMetrics) Reset() {
if m != nil {
m.hasAny.Set(0)
// numErrored is a counter and cannot be (re-)set.
}
}
// settingMetrics are metrics for a single policy setting in one or more scopes.
type settingMetrics struct {
definition *setting.Definition
isSet []metric // by scope
hasErrors []metric // by scope
}
// ReportValue is called when the policy setting is found to be configured in the specified source.
func (m *settingMetrics) ReportValue(origin *setting.Origin, v any) {
if m == nil {
return
}
if scope := origin.Scope().Kind(); int(scope) < len(m.isSet) {
m.isSet[scope].Set(1)
m.hasErrors[scope].Set(0)
}
scopeMetrics(origin).ReportHasSettings()
loggerx.Verbosef("%v(%q) = %v\n", origin, m.definition.Key(), v)
}
// ReportError is called when there's an error with the policy setting in the specified source.
func (m *settingMetrics) ReportError(origin *setting.Origin, err error) {
if m == nil {
return
}
if scope := origin.Scope().Kind(); int(scope) < len(m.hasErrors) {
m.isSet[scope].Set(0)
m.hasErrors[scope].Set(1)
}
scopeMetrics(origin).ReportError()
loggerx.Errorf("%v(%q): %v\n", origin, m.definition.Key(), err)
}
// Reset is called to reset the policy setting's metrics, such as when
// the policy setting does not exist or the source containing the policy
// is about to be reloaded.
func (m *settingMetrics) Reset(origin *setting.Origin) {
if m == nil {
return
}
if scope := origin.Scope().Kind(); int(scope) < len(m.isSet) {
m.isSet[scope].Set(0)
m.hasErrors[scope].Set(0)
}
}
// metricFn is a function that adds or sets a metric value.
type metricFn = func(name string, typ clientmetric.Type, v int64)
// funcMetric implements [metric] by calling the specified add and set functions.
// Used for testing, and with nil functions on platforms that do not support
// syspolicy, and on platforms that report policy metrics from the GUI.
type funcMetric struct {
name string
typ clientmetric.Type
add, set metricFn
}
func (m funcMetric) Add(v int64) {
if m.add != nil {
m.add(m.name, m.typ, v)
}
}
func (m funcMetric) Set(v int64) {
if m.set != nil {
m.set(m.name, m.typ, v)
}
}
var (
lazyDeviceMetrics lazy.SyncValue[*policyScopeMetrics]
lazyProfileMetrics lazy.SyncValue[*policyScopeMetrics]
lazyUserMetrics lazy.SyncValue[*policyScopeMetrics]
)
func scopeMetrics(origin *setting.Origin) *policyScopeMetrics {
switch origin.Scope().Kind() {
case setting.DeviceSetting:
return lazyDeviceMetrics.Get(func() *policyScopeMetrics {
return newScopeMetrics(setting.DeviceSetting)
})
case setting.ProfileSetting:
return lazyProfileMetrics.Get(func() *policyScopeMetrics {
return newScopeMetrics(setting.ProfileSetting)
})
case setting.UserSetting:
return lazyUserMetrics.Get(func() *policyScopeMetrics {
return newScopeMetrics(setting.UserSetting)
})
default:
panic("unreachable")
}
}
var (
settingMetricsMu sync.RWMutex
settingMetricsMap map[setting.Key]*settingMetrics
)
func settingMetricsFor(setting *setting.Definition) *settingMetrics {
settingMetricsMu.RLock()
if metrics, ok := settingMetricsMap[setting.Key()]; ok {
settingMetricsMu.RUnlock()
return metrics
}
settingMetricsMu.RUnlock()
return settingMetricsForSlow(setting)
}
func settingMetricsForSlow(d *setting.Definition) *settingMetrics {
settingMetricsMu.Lock()
defer settingMetricsMu.Unlock()
if metrics, ok := settingMetricsMap[d.Key()]; ok {
return metrics
}
isSet := make([]metric, d.Scope()+1)
hasErrors := make([]metric, d.Scope()+1)
for i := range isSet {
scope := setting.Scope(i)
// {os}_syspolicy_{key}_{scope_unless_device}
// Example: windows_syspolicy_AdminConsole or windows_syspolicy_AdminConsole_user.
isSet[i] = newSettingMetric(d.Key(), scope, "", clientmetric.TypeGauge)
// {os}_syspolicy_{key}_{scope_unless_device}_error
// Example: windows_syspolicy_AdminConsole_error or windows_syspolicy_TestSetting01_user_error.
hasErrors[i] = newSettingMetric(d.Key(), scope, "error", clientmetric.TypeGauge)
}
metrics := &settingMetrics{d, isSet, hasErrors}
mak.Set(&settingMetricsMap, d.Key(), metrics)
return metrics
}
// hooks for testing
var addMetricTestHook, setMetricTestHook syncs.AtomicValue[metricFn]
// SetHooksForTest sets the specified addMetric and setMetric functions
// as the metric functions for the duration of tb and all its subtests.
func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) {
oldAddMetric := addMetricTestHook.Swap(addMetric)
oldSetMetric := setMetricTestHook.Swap(setMetric)
tb.Cleanup(func() {
addMetricTestHook.Store(oldAddMetric)
setMetricTestHook.Store(oldSetMetric)
})
settingMetricsMu.Lock()
oldSettingMetricsMap := xmaps.Clone(settingMetricsMap)
clear(settingMetricsMap)
settingMetricsMu.Unlock()
tb.Cleanup(func() {
settingMetricsMu.Lock()
settingMetricsMap = oldSettingMetricsMap
settingMetricsMu.Unlock()
})
// (re-)set the scope metrics to use the test hooks for the duration of tb.
lazyDeviceMetrics.SetForTest(tb, newScopeMetrics(setting.DeviceSetting), nil)
lazyProfileMetrics.SetForTest(tb, newScopeMetrics(setting.ProfileSetting), nil)
lazyUserMetrics.SetForTest(tb, newScopeMetrics(setting.UserSetting), nil)
}
func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric {
name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_")
if tag := metricScopeName(scope); tag != "" {
name += "_" + tag
}
if suffix != "" {
name += "_" + suffix
}
return newMetric(name, typ)
}
func newMetric(name string, typ clientmetric.Type) metric {
name = internal.OS() + "_syspolicy_" + name
switch {
case !ShouldReport():
return &funcMetric{name: name, typ: typ}
case testenv.InTest():
return &funcMetric{name, typ, addMetricTestHook.Load(), setMetricTestHook.Load()}
case typ == clientmetric.TypeCounter:
return clientmetric.NewCounter(name)
case typ == clientmetric.TypeGauge:
return clientmetric.NewGauge(name)
default:
panic("unreachable")
}
}
func metricScopeName(scope setting.Scope) string {
switch scope {
case setting.DeviceSetting:
return ""
case setting.ProfileSetting:
return "profile"
case setting.UserSetting:
return "user"
default:
panic("unreachable")
}
}

View File

@ -0,0 +1,423 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package metrics
import (
"errors"
"testing"
"tailscale.com/types/lazy"
"tailscale.com/util/clientmetric"
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/setting"
)
func TestSettingMetricNames(t *testing.T) {
tests := []struct {
name string
key setting.Key
scope setting.Scope
suffix string
typ clientmetric.Type
osOverride string
wantMetricName string
}{
{
name: "windows-device-no-suffix",
key: "AdminConsole",
scope: setting.DeviceSetting,
suffix: "",
typ: clientmetric.TypeCounter,
osOverride: "windows",
wantMetricName: "windows_syspolicy_AdminConsole",
},
{
name: "windows-user-no-suffix",
key: "AdminConsole",
scope: setting.UserSetting,
suffix: "",
typ: clientmetric.TypeCounter,
osOverride: "windows",
wantMetricName: "windows_syspolicy_AdminConsole_user",
},
{
name: "windows-profile-no-suffix",
key: "AdminConsole",
scope: setting.ProfileSetting,
suffix: "",
typ: clientmetric.TypeCounter,
osOverride: "windows",
wantMetricName: "windows_syspolicy_AdminConsole_profile",
},
{
name: "windows-profile-err",
key: "AdminConsole",
scope: setting.ProfileSetting,
suffix: "error",
typ: clientmetric.TypeCounter,
osOverride: "windows",
wantMetricName: "windows_syspolicy_AdminConsole_profile_error",
},
{
name: "android-device-no-suffix",
key: "AdminConsole",
scope: setting.DeviceSetting,
suffix: "",
typ: clientmetric.TypeCounter,
osOverride: "android",
wantMetricName: "android_syspolicy_AdminConsole",
},
{
name: "key-path",
key: "category/subcategory/setting",
scope: setting.DeviceSetting,
suffix: "",
typ: clientmetric.TypeCounter,
osOverride: "fakeos",
wantMetricName: "fakeos_syspolicy_category_subcategory_setting",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
metric, ok := newSettingMetric(tt.key, tt.scope, tt.suffix, tt.typ).(*funcMetric)
if !ok {
t.Fatal("metric is not a funcMetric")
}
if metric.name != tt.wantMetricName {
t.Errorf("got %q, want %q", metric.name, tt.wantMetricName)
}
})
}
}
func TestScopeMetrics(t *testing.T) {
tests := []struct {
name string
scope setting.Scope
osOverride string
wantHasAnyName string
wantNumErroredName string
wantHasAnyType clientmetric.Type
wantNumErroredType clientmetric.Type
}{
{
name: "windows-device",
scope: setting.DeviceSetting,
osOverride: "windows",
wantHasAnyName: "windows_syspolicy_any",
wantHasAnyType: clientmetric.TypeGauge,
wantNumErroredName: "windows_syspolicy_errors",
wantNumErroredType: clientmetric.TypeCounter,
},
{
name: "windows-profile",
scope: setting.ProfileSetting,
osOverride: "windows",
wantHasAnyName: "windows_syspolicy_profile_any",
wantHasAnyType: clientmetric.TypeGauge,
wantNumErroredName: "windows_syspolicy_profile_errors",
wantNumErroredType: clientmetric.TypeCounter,
},
{
name: "windows-user",
scope: setting.UserSetting,
osOverride: "windows",
wantHasAnyName: "windows_syspolicy_user_any",
wantHasAnyType: clientmetric.TypeGauge,
wantNumErroredName: "windows_syspolicy_user_errors",
wantNumErroredType: clientmetric.TypeCounter,
},
{
name: "android-device",
scope: setting.DeviceSetting,
osOverride: "android",
wantHasAnyName: "android_syspolicy_any",
wantHasAnyType: clientmetric.TypeGauge,
wantNumErroredName: "android_syspolicy_errors",
wantNumErroredType: clientmetric.TypeCounter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
metrics := newScopeMetrics(tt.scope)
hasAny, ok := metrics.hasAny.(*funcMetric)
if !ok {
t.Fatal("hasAny is not a funcMetric")
}
numErrored, ok := metrics.numErrored.(*funcMetric)
if !ok {
t.Fatal("numErrored is not a funcMetric")
}
if hasAny.name != tt.wantHasAnyName {
t.Errorf("hasAny.Name: got %q, want %q", hasAny.name, tt.wantHasAnyName)
}
if hasAny.typ != tt.wantHasAnyType {
t.Errorf("hasAny.Type: got %q, want %q", hasAny.typ, tt.wantHasAnyType)
}
if numErrored.name != tt.wantNumErroredName {
t.Errorf("numErrored.Name: got %q, want %q", numErrored.name, tt.wantNumErroredName)
}
if numErrored.typ != tt.wantNumErroredType {
t.Errorf("hasAny.Type: got %q, want %q", numErrored.typ, tt.wantNumErroredType)
}
})
}
}
type testSettingDetails struct {
definition *setting.Definition
origin *setting.Origin
value any
err error
}
func TestReportMetrics(t *testing.T) {
tests := []struct {
name string
osOverride string
useMetrics bool
settings []testSettingDetails
wantMetrics []TestState
wantResetMetrics []TestState
}{
{
name: "none",
osOverride: "windows",
settings: []testSettingDetails{},
wantMetrics: []TestState{},
},
{
name: "single-value",
osOverride: "windows",
settings: []testSettingDetails{
{
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
value: 42,
},
},
wantMetrics: []TestState{
{"windows_syspolicy_any", 1},
{"windows_syspolicy_TestSetting01", 1},
},
wantResetMetrics: []TestState{
{"windows_syspolicy_any", 0},
{"windows_syspolicy_TestSetting01", 0},
},
},
{
name: "single-error",
osOverride: "windows",
settings: []testSettingDetails{
{
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
err: errors.New("bang!"),
},
},
wantMetrics: []TestState{
{"windows_syspolicy_errors", 1},
{"windows_syspolicy_TestSetting02_error", 1},
},
wantResetMetrics: []TestState{
{"windows_syspolicy_errors", 1},
{"windows_syspolicy_TestSetting02_error", 0},
},
},
{
name: "value-and-error",
osOverride: "windows",
settings: []testSettingDetails{
{
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
value: 42,
},
{
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
err: errors.New("bang!"),
},
},
wantMetrics: []TestState{
{"windows_syspolicy_any", 1},
{"windows_syspolicy_errors", 1},
{"windows_syspolicy_TestSetting01", 1},
{"windows_syspolicy_TestSetting02_error", 1},
},
wantResetMetrics: []TestState{
{"windows_syspolicy_any", 0},
{"windows_syspolicy_errors", 1},
{"windows_syspolicy_TestSetting01", 0},
{"windows_syspolicy_TestSetting02_error", 0},
},
},
{
name: "two-values",
osOverride: "windows",
settings: []testSettingDetails{
{
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
value: 42,
},
{
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
value: 17,
},
},
wantMetrics: []TestState{
{"windows_syspolicy_any", 1},
{"windows_syspolicy_TestSetting01", 1},
{"windows_syspolicy_TestSetting02", 1},
},
wantResetMetrics: []TestState{
{"windows_syspolicy_any", 0},
{"windows_syspolicy_TestSetting01", 0},
{"windows_syspolicy_TestSetting02", 0},
},
},
{
name: "two-errors",
osOverride: "windows",
settings: []testSettingDetails{
{
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
err: errors.New("bang!"),
},
{
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
err: errors.New("bang!"),
},
},
wantMetrics: []TestState{
{"windows_syspolicy_errors", 2},
{"windows_syspolicy_TestSetting01_error", 1},
{"windows_syspolicy_TestSetting02_error", 1},
},
wantResetMetrics: []TestState{
{"windows_syspolicy_errors", 2},
{"windows_syspolicy_TestSetting01_error", 0},
{"windows_syspolicy_TestSetting02_error", 0},
},
},
{
name: "multi-scope",
osOverride: "windows",
settings: []testSettingDetails{
{
definition: setting.NewDefinition("TestSetting01", setting.ProfileSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
value: 42,
},
{
definition: setting.NewDefinition("TestSetting02", setting.ProfileSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.CurrentProfileScope),
err: errors.New("bang!"),
},
{
definition: setting.NewDefinition("TestSetting03", setting.UserSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.CurrentUserScope),
value: 17,
},
},
wantMetrics: []TestState{
{"windows_syspolicy_any", 1},
{"windows_syspolicy_profile_errors", 1},
{"windows_syspolicy_user_any", 1},
{"windows_syspolicy_TestSetting01", 1},
{"windows_syspolicy_TestSetting02_profile_error", 1},
{"windows_syspolicy_TestSetting03_user", 1},
},
wantResetMetrics: []TestState{
{"windows_syspolicy_any", 0},
{"windows_syspolicy_profile_errors", 1},
{"windows_syspolicy_user_any", 0},
{"windows_syspolicy_TestSetting01", 0},
{"windows_syspolicy_TestSetting02_profile_error", 0},
{"windows_syspolicy_TestSetting03_user", 0},
},
},
{
name: "report-metrics-on-android",
osOverride: "android",
settings: []testSettingDetails{
{
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
value: 42,
},
},
wantMetrics: []TestState{
{"android_syspolicy_any", 1},
{"android_syspolicy_TestSetting01", 1},
},
wantResetMetrics: []TestState{
{"android_syspolicy_any", 0},
{"android_syspolicy_TestSetting01", 0},
},
},
{
name: "do-not-report-metrics-on-macos",
osOverride: "macos",
settings: []testSettingDetails{
{
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
value: 42,
},
},
wantMetrics: []TestState{}, // none reported
},
{
name: "do-not-report-metrics-on-ios",
osOverride: "ios",
settings: []testSettingDetails{
{
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
origin: setting.NewOrigin(setting.DeviceScope),
value: 42,
},
},
wantMetrics: []TestState{}, // none reported
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset the lazy value so it'll be re-evaluated with the osOverride.
lazyReportMetrics = lazy.SyncValue[bool]{}
t.Cleanup(func() {
// Also reset it during the cleanup.
lazyReportMetrics = lazy.SyncValue[bool]{}
})
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
h := NewTestHandler(t)
SetHooksForTest(t, h.AddMetric, h.SetMetric)
for _, s := range tt.settings {
if s.err != nil {
ReportError(s.origin, s.definition, s.err)
} else {
ReportConfigured(s.origin, s.definition, s.value)
}
}
h.MustEqual(tt.wantMetrics...)
for _, s := range tt.settings {
Reset(s.origin)
ReportNotConfigured(s.origin, s.definition)
}
h.MustEqual(tt.wantResetMetrics...)
})
}
}

View File

@ -0,0 +1,88 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package metrics
import (
"strings"
"tailscale.com/util/clientmetric"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy/internal"
)
// TestState represents a metric name and its expected value.
type TestState struct {
Name string // `$os` in the name will be replaced by the actual operating system name.`
Value int64
}
// TestHandler facilitates testing of the code that uses metrics.
type TestHandler struct {
t internal.TB
m map[string]int64
}
// NewTestHandler returns a new TestHandler.
func NewTestHandler(t internal.TB) *TestHandler {
return &TestHandler{t, make(map[string]int64)}
}
// AddMetric increments the metric with the specified name and type by delta d.
func (h *TestHandler) AddMetric(name string, typ clientmetric.Type, d int64) {
h.t.Helper()
if typ == clientmetric.TypeCounter && d < 0 {
h.t.Fatalf("an attempt was made to decrement a counter metric %q", name)
}
if v, ok := h.m[name]; ok || d != 0 {
h.m[name] = v + d
}
}
// SetMetric sets the metric with the specified name and type to the value v.
func (h *TestHandler) SetMetric(name string, typ clientmetric.Type, v int64) {
h.t.Helper()
if typ == clientmetric.TypeCounter {
h.t.Fatalf("an attempt was made to set a counter metric %q", name)
}
if _, ok := h.m[name]; ok || v != 0 {
h.m[name] = v
}
}
// MustEqual fails the test if the actual metric state differs from the specified state.
func (h *TestHandler) MustEqual(metrics ...TestState) {
h.t.Helper()
h.MustContain(metrics...)
h.mustNoExtra(metrics...)
}
// MustContain fails the test if the specified metrics are not set or have
// different values than specified. It permits other metrics to be set in
// addition to the ones being tested.
func (h *TestHandler) MustContain(metrics ...TestState) {
h.t.Helper()
for _, m := range metrics {
name := strings.ReplaceAll(m.Name, "$os", internal.OS())
v, ok := h.m[name]
if !ok {
h.t.Errorf("%q: got (none), want %v", name, m.Value)
} else if v != m.Value {
h.t.Fatalf("%q: got %v, want %v", name, v, m.Value)
}
}
}
func (h *TestHandler) mustNoExtra(metrics ...TestState) {
h.t.Helper()
s := make(set.Set[string])
for i := range metrics {
s.Add(strings.ReplaceAll(metrics[i].Name, "$os", internal.OS()))
}
for n, v := range h.m {
if !s.Contains(n) {
h.t.Errorf("%q: got %v, want (none)", n, v)
}
}
}

View File

@ -3,7 +3,21 @@
package syspolicy
type Key string
import (
"tailscale.com/types/lazy"
"tailscale.com/util/syspolicy/internal/lazyinit"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/testenv"
)
type Key = setting.Key
// The const block below lists known policy keys.
// When adding a key to this list, remember to add a corresponding
// [setting.Definition] to [implicitDefinitions] below.
// Otherwise, the [TestKnownKeysRegistered] test will fail as a reminder.
// Preferably, use a strongly typed policy hierarchy, such as [Policy],
// instead of adding each key to the list below.
const (
// Keys with a string value
@ -96,3 +110,83 @@
// AllowedSuggestedExitNodes's string array value is a list of exit node IDs that restricts which exit nodes are considered when generating suggestions for exit nodes.
AllowedSuggestedExitNodes Key = "AllowedSuggestedExitNodes"
)
// implicitDefinitions is a list of [setting.Definition] that will be registered
// automatically by [settingDefinitions] as soon as the package needs to ready a policy.
var implicitDefinitions = []*setting.Definition{
// Device policy settings
setting.NewDefinition(AllowedSuggestedExitNodes, setting.DeviceSetting, setting.StringListValue),
setting.NewDefinition(ApplyUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(CheckUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(ControlURL, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(DeviceSerialNumber, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(EnableIncomingConnections, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(EnableRunExitNode, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(EnableServerMode, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(EnableTailscaleDNS, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(EnableTailscaleSubnets, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(ExitNodeAllowLANAccess, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(ExitNodeID, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(ExitNodeIP, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(FlushDNSOnSessionUnlock, setting.DeviceSetting, setting.BooleanValue),
setting.NewDefinition(LogSCMInteractions, setting.DeviceSetting, setting.BooleanValue),
setting.NewDefinition(LogTarget, setting.DeviceSetting, setting.StringValue),
setting.NewDefinition(PostureChecking, setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition(Tailnet, setting.DeviceSetting, setting.StringValue),
// User policy settings
setting.NewDefinition(AdminConsoleVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(AutoUpdateVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(ExitNodeMenuVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(KeyExpirationNoticeTime, setting.UserSetting, setting.DurationValue),
setting.NewDefinition(ManagedByCaption, setting.UserSetting, setting.StringValue),
setting.NewDefinition(ManagedByOrganizationName, setting.UserSetting, setting.StringValue),
setting.NewDefinition(ManagedByURL, setting.UserSetting, setting.StringValue),
setting.NewDefinition(NetworkDevicesVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(PreferencesMenuVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(ResetToDefaultsVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(RunExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(SuggestedExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(TestMenuVisibility, setting.UserSetting, setting.VisibilityValue),
setting.NewDefinition(UpdateMenuVisibility, setting.UserSetting, setting.VisibilityValue),
}
func init() {
lazyinit.Defer(func() error {
// Avoid implicit [SettingDefinition] registration during tests.
// Each test should control which policy settings to register.
// Use [setting.SetDefinitionsForTest] to specify necessary definitions,
// or [setWellKnownSettingsForTest] to set implicit definitions for the test duration.
if testenv.InTest() {
return nil
}
for _, d := range implicitDefinitions {
setting.RegisterDefinition(d)
}
return nil
})
}
var implicitDefinitionMap lazy.SyncValue[setting.DefinitionMap]
// WellKnownSettingDefinition returns a well-known, implicit setting definition by its key,
// or an [ErrNoSuchKey] if a policy setting with the specified key does not exist
// among implicit policy definitions.
func WellKnownSettingDefinition(k Key) (*setting.Definition, error) {
m, err := implicitDefinitionMap.GetErr(func() (setting.DefinitionMap, error) {
return setting.DefinitionMapOf(implicitDefinitions)
})
if err != nil {
return nil, err
}
if d, ok := m[k]; ok {
return d, nil
}
return nil, ErrNoSuchKey
}
// setWellKnownSettingsForTest registers all implicit setting definitions
// for the duration of the test.
func setWellKnownSettingsForTest(tb lazy.TB) error {
return setting.SetDefinitionsForTest(tb, implicitDefinitions...)
}

View File

@ -0,0 +1,95 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"go/types"
"os"
"reflect"
"strconv"
"testing"
"tailscale.com/util/syspolicy/setting"
)
func TestKnownKeysRegistered(t *testing.T) {
keyConsts, err := listStringConsts[Key]("policy_keys.go")
if err != nil {
t.Fatalf("listStringConsts failed: %v", err)
}
m, err := setting.DefinitionMapOf(implicitDefinitions)
if err != nil {
t.Fatalf("definitionMapOf failed: %v", err)
}
for _, key := range keyConsts {
t.Run(string(key), func(t *testing.T) {
d := m[key]
if d == nil {
t.Fatalf("%q was not registered", key)
}
if d.Key() != key {
t.Fatalf("d.Key got: %s, want %s", d.Key(), key)
}
})
}
}
func TestNotAWellKnownSetting(t *testing.T) {
d, err := WellKnownSettingDefinition("TestSettingDoesNotExist")
if d != nil || err == nil {
t.Fatalf("got %v, %v; want nil, %v", d, err, ErrNoSuchKey)
}
}
func listStringConsts[T ~string](filename string) (map[string]T, error) {
fset := token.NewFileSet()
src, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
f, err := parser.ParseFile(fset, filename, src, 0)
if err != nil {
return nil, err
}
consts := make(map[string]T)
typeName := reflect.TypeFor[T]().Name()
for _, d := range f.Decls {
g, ok := d.(*ast.GenDecl)
if !ok || g.Tok != token.CONST {
continue
}
for _, s := range g.Specs {
vs, ok := s.(*ast.ValueSpec)
if !ok || len(vs.Names) != len(vs.Values) {
continue
}
if typ, ok := vs.Type.(*ast.Ident); !ok || typ.Name != typeName {
continue
}
for i, n := range vs.Names {
lit, ok := vs.Values[i].(*ast.BasicLit)
if !ok {
return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, types.ExprString(vs.Values[i]))
}
val, err := strconv.Unquote(lit.Value)
if err != nil {
return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, lit.Value)
}
consts[n.Name] = T(val)
}
}
}
return consts, nil
}

View File

@ -1,38 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
var stringKeys = []Key{
ControlURL,
LogTarget,
Tailnet,
ExitNodeID,
ExitNodeIP,
EnableIncomingConnections,
EnableServerMode,
ExitNodeAllowLANAccess,
EnableTailscaleDNS,
EnableTailscaleSubnets,
AdminConsoleVisibility,
NetworkDevicesVisibility,
TestMenuVisibility,
UpdateMenuVisibility,
RunExitNodeVisibility,
PreferencesMenuVisibility,
ExitNodeMenuVisibility,
AutoUpdateVisibility,
ResetToDefaultsVisibility,
KeyExpirationNoticeTime,
PostureChecking,
ManagedByOrganizationName,
ManagedByCaption,
ManagedByURL,
}
var boolKeys = []Key{
LogSCMInteractions,
FlushDNSOnSessionUnlock,
}
var uint64Keys = []Key{}

View File

@ -0,0 +1,109 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package rsop
import (
"reflect"
"slices"
"sync"
"time"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy/internal/loggerx"
"tailscale.com/util/syspolicy/setting"
)
// Change represents a change from the Old to the New value of type T.
type Change[T any] struct {
New, Old T
}
// PolicyChangeCallback is a function called whenever a policy changes.
type PolicyChangeCallback func(*PolicyChange)
// PolicyChange describes a policy change.
type PolicyChange struct {
snapshots Change[*setting.Snapshot]
}
// New returns the [setting.Snapshot] after the change.
func (c PolicyChange) New() *setting.Snapshot {
return c.snapshots.New
}
// Old returns the [setting.Snapshot] before the change.
func (c PolicyChange) Old() *setting.Snapshot {
return c.snapshots.Old
}
// HasChanged reports whether a policy setting with the specified [setting.Key], has changed.
func (c PolicyChange) HasChanged(key setting.Key) bool {
new, newErr := c.snapshots.New.GetErr(key)
old, oldErr := c.snapshots.Old.GetErr(key)
if newErr != nil && oldErr != nil {
return false
}
if newErr != nil || oldErr != nil {
return true
}
switch newVal := new.(type) {
case bool, uint64, string, setting.Visibility, setting.PreferenceOption, time.Duration:
return newVal != old
case []string:
if oldVal, ok := old.([]string); ok {
return slices.Equal(newVal, oldVal)
}
return false
default:
loggerx.Errorf("%q has an unsupported value type: %T", newVal)
return reflect.DeepEqual(new, old)
}
}
// policyChangeCallbacks are the callbacks to invoke when the resultant policy changes.
// It is safe for concurrent use.
type policyChangeCallbacks struct {
mu sync.RWMutex
cbs set.HandleSet[PolicyChangeCallback]
}
// Register adds the specified callback to be invoked whenever the policy changes.
func (c *policyChangeCallbacks) Register(callback PolicyChangeCallback) (unregister func()) {
c.mu.Lock()
handle := c.cbs.Add(callback)
c.mu.Unlock()
return func() {
c.mu.Lock()
delete(c.cbs, handle)
c.mu.Unlock()
}
}
// Invoke calls the registered callback functions with the specified policy change info.
func (c *policyChangeCallbacks) Invoke(snapshots Change[*setting.Snapshot]) {
var wg sync.WaitGroup
defer wg.Wait()
c.mu.RLock()
defer c.mu.RUnlock()
wg.Add(len(c.cbs))
change := &PolicyChange{snapshots: snapshots}
for _, cb := range c.cbs {
go func() {
defer wg.Done()
cb(change)
}()
}
}
// Close awaits the completion of active callbacks and prevents any further invocations.
func (c *policyChangeCallbacks) Close() {
c.mu.Lock()
defer c.mu.Unlock()
if c.cbs != nil {
clear(c.cbs)
c.cbs = nil
}
}

View File

@ -0,0 +1,698 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package rsop facilitates [source.Store] registration via [RegisterStore]
// and provides access to the resultant policy merged from all registered sources
// via [PolicyFor].
package rsop
import (
"errors"
"fmt"
"reflect"
"slices"
"sync"
"sync/atomic"
"time"
"tailscale.com/syncs"
"tailscale.com/types/lazy"
"tailscale.com/util/slicesx"
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/internal/lazyinit"
"tailscale.com/util/syspolicy/internal/loggerx"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
)
var errResultantPolicyClosed = errors.New("resultant policy closed")
// The minimum and maximum wait times after detecting a policy change
// before reloading the policy.
// Policy changes occurring within [policyReloadMinDelay] of each other
// will be batched together, resulting in a single policy reload
// no later than [policyReloadMaxDelay] after the first detected change.
// In other words, the resultant policy will be reloaded no more often than once
// every 5 seconds, but at most 15 seconds after an underlying [source.Store]
// has issued a policy change callback.
// See [Policy.watchReload].
const (
defaultPolicyReloadMinDelay = 5 * time.Second
defaultPolicyReloadMaxDelay = 15 * time.Second
)
// policyReloadMinDelay and policyReloadMaxDelay are test hooks.
// Their values default to [defaultPolicyReloadMinDelay] and [defaultPolicyReloadMaxDelay].
var (
policyReloadMinDelay, policyReloadMaxDelay lazy.SyncValue[time.Duration]
)
// Policy provides access to the current resultant [setting.Snapshot] for a given
// scope and allows to reload it from the underlying [source.Store]s. It also allows to
// subscribe and receive a callback whenever the resultant [setting.Snapshot] is
// changed. It is safe for concurrent use.
type Policy struct {
scope setting.PolicyScope
reloadCh chan reloadRequest // 1-buffered; written to when a policy reload is required
changeSourceCh chan sourceChangeRequest // written to to add a new or remove an existing source
closeCh chan struct{} // closed to signal that the Policy is being closed
doneCh chan struct{} // closed by closeInternal when watchReload exits
// resultant is the most recent version of the [setting.Snapshot] containing policy settings
// merged from all applicable sources.
resultant atomic.Pointer[setting.Snapshot]
changeCallbacks policyChangeCallbacks
mu sync.RWMutex
sources source.ReadableSources
closing bool // Close was called (even if we're still closing)
}
// newPolicy returns a new [Policy] for the specified [setting.PolicyScope]
// that tracks changes and merges policy settings read from the specified sources.
func newPolicy(scope setting.PolicyScope, sources ...*source.Source) (p *Policy, err error) {
readableSources := source.ReadableSources(make([]source.ReadableSource, len(sources)))
for i, s := range sources {
reader, err := s.Reader()
if err != nil {
return nil, fmt.Errorf("failed to get a store reader: %v", err)
}
session, err := reader.OpenSession()
if err != nil {
return nil, fmt.Errorf("failed to open a reading session: %v", err)
}
readableSource := source.ReadableSource{
Source: s,
ReadingSession: session,
}
readableSources[i] = readableSource
defer func() {
if err != nil {
readableSource.Close()
}
}()
}
// Sort policy sources by their precedence from lower to higher.
// For example, {UserPolicy},{ProfilePolicy},{DevicePolicy}.
readableSources.StableSort()
p = &Policy{
scope: scope,
sources: readableSources,
reloadCh: make(chan reloadRequest, 1),
changeSourceCh: make(chan sourceChangeRequest),
closeCh: make(chan struct{}),
doneCh: make(chan struct{}),
}
if err := p.start(); err != nil {
return nil, err
}
return p, nil
}
// IsValid reports whether p is in a valid state and has not been closed.
func (p *Policy) IsValid() bool {
select {
case <-p.closeCh:
return false
default:
return true
}
}
// Scope returns the [setting.PolicyScope] that this resultant policy applies to.
func (p *Policy) Scope() setting.PolicyScope {
return p.scope
}
// Get returns the most recent resultant [setting.Snapshot].
func (p *Policy) Get() *setting.Snapshot {
return p.resultant.Load()
}
// RegisterChangeCallback adds a function to be called whenever the resultant
// policy changes. The returned function can be used to unregister the callback.
func (p *Policy) RegisterChangeCallback(callback PolicyChangeCallback) (unregister func()) {
return p.changeCallbacks.Register(callback)
}
// Reload synchronously re-reads policy settings from the underlying policy
// [source.Store], constructing a new merged [setting.Snapshot] even if the policy remains
// unchanged. In most scenarios, there's no need to re-read the policy manually.
// Instead, it is recommended to register a policy change callback, or to use
// the most recent [setting.Snapshot] returned by the [Policy.Get] method.
func (p *Policy) Reload() (*setting.Snapshot, error) {
return p.reload(true)
}
// reload is like Reload, but allows to specify whether to re-read policy settings
// from unchanged policy sources.
func (p *Policy) reload(force bool) (*setting.Snapshot, error) {
respCh := make(chan reloadResponse, 1)
select {
case p.reloadCh <- reloadRequest{force: force, respCh: respCh}:
// continue
case <-p.closeCh:
return nil, errResultantPolicyClosed
}
select {
case resp := <-respCh:
return resp.policy, resp.err
case <-p.closeCh:
return nil, errResultantPolicyClosed
}
}
// Done returns a channel that is closed when the [Policy] is closed.
func (p *Policy) Done() <-chan struct{} {
return p.doneCh
}
func (p *Policy) start() error {
if _, err := p.reloadNow(false); err != nil {
return err
}
go p.watchPolicyChanges()
go p.watchReload()
return nil
}
// readAndMerge reads and merges policy settings from the underlying sources,
// returning a [setting.Snapshot] with the merged result.
// If the force parameter is true, it re-reads policy settings from each store
// even if no policy change was observed, and returns an error if the read
// operation fails.
func (p *Policy) readAndMerge(force bool) (*setting.Snapshot, error) {
p.mu.RLock()
defer p.mu.RUnlock()
// Start with an empty policy in the target scope.
resultant := setting.NewSnapshot(nil, setting.SummaryWith(p.scope))
// Then merge policy settings from all sources.
// Policy sources with the highest precedence (e.g., the device policy) are merged last,
// overriding any conflicting policy settings with lower precedence.
for _, s := range p.sources {
var policy *setting.Snapshot
if force {
var err error
if policy, err = s.ReadSettings(); err != nil {
return nil, err
}
} else {
policy = s.GetSettings()
}
resultant = setting.MergeSnapshots(resultant, policy)
}
return resultant, nil
}
// reloadAsync requests an asynchronous background policy reload.
// The policy will be reloaded no later than in [policyReloadMaxDelay].
func (p *Policy) reloadAsync() {
select {
case p.reloadCh <- reloadRequest{}:
// Sent.
default:
// A reload request is already en route.
}
}
// reloadNow loads and merges policies from all sources, updating the resultant policy.
// If the force parameter is true, it forcibly reloads policies
// from the underlying policy store, even if no policy changes were detected.
//
// Except for the initial policy reload during the [Policy] creation,
// this method should only be called from the [Policy.watchReload] goroutine.
func (p *Policy) reloadNow(force bool) (*setting.Snapshot, error) {
new, err := p.readAndMerge(force)
if err != nil {
return nil, err
}
old := p.resultant.Swap(new)
// A nil old value indicates the initial policy load rather than a policy change.
// Additionally, we should not invoke the policy change callbacks unless the
// policy items have actually changed.
if old != nil && !old.EqualItems(new) {
snapshots := Change[*setting.Snapshot]{New: new, Old: old}
p.changeCallbacks.Invoke(snapshots)
}
return new, nil
}
// AddSource adds the specified source to the list of sources used by p,
// and triggers a synchronous policy refresh. It returns an error
// if the source is not a valid source for this resultant policy,
// or if the resultant policy is being closed,
// or if policy refresh fails with an error.
func (p *Policy) AddSource(source *source.Source) error {
return p.changeSource(source, nil)
}
// RemoveSource removes the specified source from the list of sources used by p,
// and triggers a synchronous policy refresh. It returns an error if the
// resultant policy is being closed, or if policy refresh fails with an error.
func (p *Policy) RemoveSource(source *source.Source) error {
return p.changeSource(nil, source)
}
// ReplaceSource replaces the old source with the new source atomically,
// and triggers a synchronous policy refresh. It returns an error
// if the source is not a valid source for this resultant policy,
// or if the resultant policy is being closed,
// or if policy refresh fails with an error.
func (p *Policy) ReplaceSource(old, new *source.Source) error {
return p.changeSource(new, old)
}
func (p *Policy) changeSource(toAdd, toRemove *source.Source) error {
if toAdd == toRemove {
return nil
}
if toAdd != nil && !p.scope.IsWithinOf(toAdd.Scope()) {
return errors.New("scope mismatch")
}
respCh := make(chan error, 1)
req := sourceChangeRequest{toAdd, toRemove, respCh}
select {
case p.changeSourceCh <- req:
return <-respCh
case <-p.closeCh:
return errResultantPolicyClosed
}
}
// watchPolicyChanges awaits a policy change notification from any of the sources
// and calls reloadAsync whenever a notification is received.
func (p *Policy) watchPolicyChanges() {
const (
closeIdx = iota
changeSourceIdx
policyChangedOffset
)
// The cases are Close, ChangeSource, PolicyChanged[0],...,PolicyChanged[N-1].
p.mu.RLock()
cases := make([]reflect.SelectCase, len(p.sources)+policyChangedOffset)
// Add the PolicyChanged[N] cases.
for i, source := range p.sources {
cases[i+policyChangedOffset] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(source.PolicyChanged())}
}
// Add the Close case.
cases[closeIdx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(p.closeCh)}
// Add the ChangeSource case.
cases[changeSourceIdx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(p.changeSourceCh)}
p.mu.RUnlock()
for {
switch chosen, recv, ok := reflect.Select(cases); chosen {
case closeIdx: // Close
// Exit the watch as the closeCh was closed, indicating that
// the [Policy] is being closed.
return
case changeSourceIdx: // ChangeSource
// We've received a source change request from one of the AddSource,
// RemoveSource, or ReplaceSource methods, meaning that we need to:
// - Open a reader session if a new source is being added;
// - Update the p.sources slice;
// - Update the cases slice;
// - Trigger a synchronous policy reload;
// - Report an error, if any, back to the caller.
req := recv.Interface().(sourceChangeRequest)
needClose, err := func() (close bool, err error) {
p.mu.Lock()
defer p.mu.Unlock()
if req.toAdd != nil {
if !p.sources.Contains(req.toAdd) {
reader, err := req.toAdd.Reader()
if err != nil {
return false, fmt.Errorf("failed to get a store reader: %v", err)
}
session, err := reader.OpenSession()
if err != nil {
return false, fmt.Errorf("failed to open a reading session: %v", err)
}
addAt := p.sources.InsertionIndexOf(req.toAdd)
toAdd := source.ReadableSource{
Source: req.toAdd,
ReadingSession: session,
}
p.sources = slices.Insert(p.sources, addAt, toAdd)
newCase := reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(toAdd.PolicyChanged())}
caseIndex := addAt + policyChangedOffset
cases = slices.Insert(cases, caseIndex, newCase)
}
}
if req.toDelete != nil {
if deleteAt := p.sources.IndexOf(req.toDelete); deleteAt != -1 {
p.sources.DeleteAt(deleteAt)
caseIndex := deleteAt + policyChangedOffset
cases = slices.Delete(cases, caseIndex, caseIndex+1)
}
}
return len(p.sources) == 0, nil
}()
if err == nil {
if needClose {
// Close the resultant policy if the last policy source was deleted.
p.Close()
} else {
// Otherwise, reload the policy synchronously.
_, err = p.reload(false)
}
}
req.respCh <- err
default: // PolicyChanged[N]
if !ok {
// One of the PolicyChanged channels was closed, indicating that
// the corresponding [source.Source] is no longer valid.
// We can no longer keep this [Policy] up to date
// and should close it.
p.Close()
return
}
// One of the PolicyChanged channels was signaled.
// We should request an asynchronous policy reload.
p.reloadAsync()
}
}
}
// watchReload processes incoming synchronous and asynchronous policy reload requests.
// Synchronous requests (with a non-nil respCh) are served immediately.
// Asynchronous requests are debounced and throttled: they are executed at least
// [policyReloadMinDelay] after the last request, but no later than [policyReloadMaxDelay]
// after the first request in a batch.
func (p *Policy) watchReload() {
force := false // whether a forced refresh was requested
var delayCh, timeoutCh <-chan time.Time
reload := func(respCh chan<- reloadResponse) {
delayCh, timeoutCh = nil, nil
policy, err := p.reloadNow(force)
if err != nil {
loggerx.Errorf("%v policy reload failed: %v\n", p.scope, err)
}
if respCh != nil {
respCh <- reloadResponse{policy: policy, err: err}
}
force = false
}
loop:
for {
select {
case req := <-p.reloadCh:
if req.force {
force = true
}
if req.respCh != nil {
reload(req.respCh)
continue
}
if delayCh == nil {
timeoutCh = time.After(policyReloadMaxDelay.Get(func() time.Duration { return defaultPolicyReloadMaxDelay }))
}
delayCh = time.After(policyReloadMinDelay.Get(func() time.Duration { return defaultPolicyReloadMinDelay }))
case <-delayCh:
reload(nil)
case <-timeoutCh:
reload(nil)
case <-p.closeCh:
break loop
}
}
p.closeInternal()
}
func (p *Policy) closeInternal() {
p.mu.Lock()
defer p.mu.Unlock()
p.sources.Close()
p.changeCallbacks.Close()
close(p.doneCh)
}
// Close initiates the closing of the resultant policy.
// The actual closing is performed by closeInternal when watchReload exits,
// and the Done() channel is closed when closeInternal finishes.
func (p *Policy) Close() {
p.mu.Lock()
defer p.mu.Unlock()
if p.closing {
return
}
p.closing = true
close(p.closeCh)
}
// sourceChangeRequest is a request to add and/or remove source from a [Policy].
type sourceChangeRequest struct {
toAdd, toDelete *source.Source
respCh chan<- error
}
// reloadRequest describes a policy reload request.
type reloadRequest struct {
// force triggers an immediate synchronous policy reload,
// reloading the policy regardless of whether a policy change was detected.
force bool
// respCh is an optional channel. If non-nil, it makes the reload request
// synchronous and receives the result.
respCh chan<- reloadResponse
}
type reloadResponse struct {
policy *setting.Snapshot
err error
}
var (
policyMu sync.RWMutex
policySources []*source.Source
resultantPolicies []*Policy
resultantPolicyLRU [setting.MaxSettingScope + 1]syncs.AtomicValue[*Policy] // by [Scope.Kind]
)
// registerSource registers the specified [source.Source] to be used by the package.
// It updates existing [Policy]s returned by [PolicyFor] to use this source if
// they are within the source's [setting.PolicyScope].
func registerSource(source *source.Source) error {
policyMu.Lock()
defer policyMu.Unlock()
if slices.Contains(policySources, source) {
return nil
}
policySources = append(policySources, source)
return forEachResultantPolicyLocked(func(policy *Policy) error {
if !policy.Scope().IsWithinOf(source.Scope()) {
return nil
}
return policy.AddSource(source)
})
}
// replaceSource is like [unregisterSource](old) followed by [registerSource](new),
// but is atomic from the perspective of each [Policy].
func replaceSource(old, new *source.Source) error {
policyMu.Lock()
defer policyMu.Unlock()
oldIndex := slices.Index(policySources, old)
if oldIndex == -1 {
return fmt.Errorf("the source is not registered: %v", old)
}
policySources[oldIndex] = new
return forEachResultantPolicyLocked(func(policy *Policy) error {
if policy.Scope().IsWithinOf(old.Scope()) || policy.Scope().IsWithinOf(new.Scope()) {
return nil
}
return policy.ReplaceSource(old, new)
})
}
// unregisterSource unregisters the specified [source.Source],
// so that it won't be used by any new or existing [Policy].
func unregisterSource(source *source.Source) error {
policyMu.Lock()
defer policyMu.Unlock()
index := slices.Index(policySources, source)
if index == -1 {
return nil
}
policySources = slices.Delete(policySources, index, index+1)
return forEachResultantPolicyLocked(func(policy *Policy) error {
if !policy.Scope().IsWithinOf(source.Scope()) {
return nil
}
return policy.RemoveSource(source)
})
}
// forEachResultantPolicyLocked calls fn for every [Policy] in [resultantPolicies].
// It accumulates the returned errors, except for [errResultantPolicyClosed],
// and returns an error that wraps all errors returned by fn.
// The [policyMu] mutex must be held while this function is executed.
func forEachResultantPolicyLocked(fn func(p *Policy) error) error {
var errs []error
for _, policy := range resultantPolicies {
err := fn(policy)
if err != nil && !errors.Is(err, errResultantPolicyClosed) {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
// PolicyFor returns the [Policy] for the specified scope,
// creating one from the registered [source.Store]s if it does not exist.
func PolicyFor(scope setting.PolicyScope) (*Policy, error) {
if err := lazyinit.Do(); err != nil {
return nil, err
}
if policy := resultantPolicyLRU[scope.Kind()].Load(); policy != nil && policy.Scope() == scope && policy.IsValid() {
return policy, nil
}
return policyForSlow(scope)
}
func policyForSlow(scope setting.PolicyScope) (policy *Policy, err error) {
defer func() {
if policy != nil {
resultantPolicyLRU[scope.Kind()].Store(policy)
}
}()
policyMu.RLock()
if policy, ok := findPolicyByScopeLocked(scope); ok {
policyMu.RUnlock()
return policy, nil
}
policyMu.RUnlock()
policyMu.Lock()
defer policyMu.Unlock()
if policy, ok := findPolicyByScopeLocked(scope); ok {
return policy, nil
}
sources := slicesx.Filter(nil, policySources, func(source *source.Source) bool {
return scope.IsWithinOf(source.Scope())
})
policy, err = newPolicy(scope, sources...)
if err != nil {
return nil, err
}
resultantPolicies = append(resultantPolicies, policy)
go func() {
<-policy.Done()
deletePolicy(policy)
}()
return policy, nil
}
// findPolicyByScopeLocked returns a policy with the specified scope and true if
// one exists, otherwise it returns nil, false.
// [policyMu] must be held.
func findPolicyByScopeLocked(target setting.PolicyScope) (policy *Policy, ok bool) {
for _, policy := range resultantPolicies {
if policy.Scope() == target && policy.IsValid() {
return policy, true
}
}
return nil, false
}
// deletePolicy deletes the specified resultant policy from the [resultantPolicies] list.
func deletePolicy(policy *Policy) {
policyMu.Lock()
if i := slices.Index(resultantPolicies, policy); i != -1 {
resultantPolicies = slices.Delete(resultantPolicies, i, i+1)
}
resultantPolicyLRU[policy.Scope().Kind()].CompareAndSwap(policy, nil)
policyMu.Unlock()
}
// ErrAlreadyConsumed is the error returned when [StoreRegistration.ReplaceStore]
// or [StoreRegistration.Unregister] is called more than once.
var ErrAlreadyConsumed = errors.New("the store registration is no longer valid")
// StoreRegistration is a [source.Store] registered for use in the specified scope.
// It can be used to unregister the store, or replace it with another one.
type StoreRegistration struct {
source *source.Source
consumed atomic.Uint32
m sync.Mutex
}
// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope].
func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
return newStoreRegistration(name, scope, store)
}
// RegisterStoreForTest is like [RegisterStore], but unregisters the store when
// tb and all its subtests complete.
func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
reg, err := RegisterStore(name, scope, store)
if err == nil {
tb.Cleanup(func() {
if err := reg.Unregister(); err != nil && !errors.Is(err, ErrAlreadyConsumed) {
tb.Fatalf("Unregister failed: %v", err)
}
})
}
return reg, err // may be nil or non-nil
}
func newStoreRegistration(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
source := source.NewSource(name, scope, store)
if err := registerSource(source); err != nil {
return nil, err
}
return &StoreRegistration{source: source}, nil
}
// ReplaceStore replaces the registered store with the new one,
// returning a new [StoreRegistration] or an error.
func (r *StoreRegistration) ReplaceStore(new source.Store) (*StoreRegistration, error) {
var res *StoreRegistration
err := r.consume(func() error {
newSource := source.NewSource(r.source.Name(), r.source.Scope(), new)
if err := replaceSource(r.source, newSource); err != nil {
return err
}
res = &StoreRegistration{source: newSource}
return nil
})
return res, err
}
// Unregister reverts the registration.
func (r *StoreRegistration) Unregister() error {
return r.consume(func() error { return unregisterSource(r.source) })
}
// consume invokes fn, consuming r if no error is returned.
// It returns [ErrAlreadyConsumed] on subsequent calls after the first successful call.
func (r *StoreRegistration) consume(fn func() error) (err error) {
if r.consumed.Load() != 0 {
return ErrAlreadyConsumed
}
return r.consumeSlow(fn)
}
func (r *StoreRegistration) consumeSlow(fn func() error) (err error) {
r.m.Lock()
defer r.m.Unlock()
if r.consumed.Load() != 0 {
return ErrAlreadyConsumed
}
if err = fn(); err == nil {
r.consumed.Store(1)
}
return err // may be nil or non-nil
}

View File

@ -0,0 +1,368 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package rsop
import (
"slices"
"sort"
"testing"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
)
func TestRegisterSourceAndGetResultantPolicy(t *testing.T) {
type sourceConfig struct {
name string
scope setting.PolicyScope
settingKey setting.Key
settingValue string
wantEffective bool
}
tests := []struct {
name string
scope setting.PolicyScope
initialSources []sourceConfig
additionalSources []sourceConfig
wantSnapshot *setting.Snapshot
}{
{
name: "DevicePolicy/NoSources",
scope: setting.DeviceScope,
wantSnapshot: setting.NewSnapshot(nil, setting.DeviceScope),
},
{
name: "UserScope/NoSources",
scope: setting.CurrentUserScope,
wantSnapshot: setting.NewSnapshot(nil, setting.CurrentUserScope),
},
{
name: "DevicePolicy/OneInitialSource",
scope: setting.DeviceScope,
initialSources: []sourceConfig{
{
name: "TestSourceA",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueA",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
}, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
},
{
name: "DevicePolicy/OneAdditionalSource",
scope: setting.DeviceScope,
additionalSources: []sourceConfig{
{
name: "TestSourceA",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueA",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
}, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
},
{
name: "DevicePolicy/ManyInitialSources/NoConflicts",
scope: setting.DeviceScope,
initialSources: []sourceConfig{
{
name: "TestSourceA",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueA",
wantEffective: true,
},
{
name: "TestSourceB",
scope: setting.DeviceScope,
settingKey: "TestKeyB",
settingValue: "TestValueB",
wantEffective: true,
},
{
name: "TestSourceC",
scope: setting.DeviceScope,
settingKey: "TestKeyC",
settingValue: "TestValueC",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
"TestKeyC": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)),
}, setting.DeviceScope),
},
{
name: "DevicePolicy/ManyInitialSources/Conflicts",
scope: setting.DeviceScope,
initialSources: []sourceConfig{
{
name: "TestSourceA",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueA",
wantEffective: true,
},
{
name: "TestSourceB",
scope: setting.DeviceScope,
settingKey: "TestKeyB",
settingValue: "TestValueB",
wantEffective: true,
},
{
name: "TestSourceC",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueC",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)),
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
}, setting.DeviceScope),
},
{
name: "DevicePolicy/MixedSources/Conflicts",
scope: setting.DeviceScope,
initialSources: []sourceConfig{
{
name: "TestSourceA",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueA",
wantEffective: true,
},
{
name: "TestSourceB",
scope: setting.DeviceScope,
settingKey: "TestKeyB",
settingValue: "TestValueB",
wantEffective: true,
},
{
name: "TestSourceC",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueC",
wantEffective: true,
},
},
additionalSources: []sourceConfig{
{
name: "TestSourceD",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueD",
wantEffective: true,
},
{
name: "TestSourceE",
scope: setting.DeviceScope,
settingKey: "TestKeyC",
settingValue: "TestValueE",
wantEffective: true,
},
{
name: "TestSourceF",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "TestValueF",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("TestValueF", nil, setting.NewNamedOrigin("TestSourceF", setting.DeviceScope)),
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
"TestKeyC": setting.RawItemWith("TestValueE", nil, setting.NewNamedOrigin("TestSourceE", setting.DeviceScope)),
}, setting.DeviceScope),
},
{
name: "UserScope/Init-DeviceSource",
scope: setting.CurrentUserScope,
initialSources: []sourceConfig{
{
name: "TestSourceDevice",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "DeviceValue",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
}, setting.CurrentUserScope, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
},
{
name: "UserScope/Init-DeviceSource/Add-UserSource",
scope: setting.CurrentUserScope,
initialSources: []sourceConfig{
{
name: "TestSourceDevice",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "DeviceValue",
wantEffective: true,
},
},
additionalSources: []sourceConfig{
{
name: "TestSourceUser",
scope: setting.CurrentUserScope,
settingKey: "TestKeyB",
settingValue: "UserValue",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
"TestKeyB": setting.RawItemWith("UserValue", nil, setting.NewNamedOrigin("TestSourceUser", setting.CurrentUserScope)),
}, setting.CurrentUserScope),
},
{
name: "UserScope/Init-DeviceSource/Add-UserSource-and-ProfileSource",
scope: setting.CurrentUserScope,
initialSources: []sourceConfig{
{
name: "TestSourceDevice",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "DeviceValue",
wantEffective: true,
},
},
additionalSources: []sourceConfig{
{
name: "TestSourceProfile",
scope: setting.CurrentProfileScope,
settingKey: "TestKeyB",
settingValue: "ProfileValue",
wantEffective: true,
},
{
name: "TestSourceUser",
scope: setting.CurrentUserScope,
settingKey: "TestKeyB",
settingValue: "UserValue",
wantEffective: true,
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
"TestKeyB": setting.RawItemWith("ProfileValue", nil, setting.NewNamedOrigin("TestSourceProfile", setting.CurrentProfileScope)),
}, setting.CurrentUserScope),
},
{
name: "DevicePolicy/User-Source-does-not-apply",
scope: setting.DeviceScope,
initialSources: []sourceConfig{
{
name: "TestSourceDevice",
scope: setting.DeviceScope,
settingKey: "TestKeyA",
settingValue: "DeviceValue",
wantEffective: true,
},
},
additionalSources: []sourceConfig{
{
name: "TestSourceUser",
scope: setting.CurrentUserScope,
settingKey: "TestKeyA",
settingValue: "UserValue",
wantEffective: false, // Registering a user source should have no impact on the device policy.
},
},
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
}, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Register all settings that we use in this test.
var definitions []*setting.Definition
for _, source := range slices.Concat(tt.initialSources, tt.additionalSources) {
definitions = append(definitions, setting.NewDefinition(source.settingKey, tt.scope.Kind(), setting.StringValue))
}
if err := setting.SetDefinitionsForTest(t, definitions...); err != nil {
t.Fatalf("SetDefinitionsForTest failed: %v", err)
}
// Add the initial policy sources.
var wantSources []*source.Source
for _, s := range tt.initialSources {
store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue))
source := source.NewSource(s.name, s.scope, store)
if err := registerSource(source); err != nil {
t.Fatalf("failed to register policy source: %v", source)
}
if s.wantEffective {
wantSources = append(wantSources, source)
}
t.Cleanup(func() { unregisterSource(source) })
}
// Retrieve the resultant policy.
policy, err := resultantPolicyForTest(t, tt.scope)
if err != nil {
t.Fatalf("failed to get resultant policy for %v", tt.scope)
}
// Add additional setting sources one by one, and check the policy settings at each step.
for _, s := range tt.additionalSources {
store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue))
source := source.NewSource(s.name, s.scope, store)
if err := registerSource(source); err != nil {
t.Fatalf("failed to register additional policy source: %v", source)
}
if s.wantEffective {
wantSources = append(wantSources, source)
}
t.Cleanup(func() { unregisterSource(source) })
}
sort.SliceStable(wantSources, func(i, j int) bool {
return wantSources[i].Compare(wantSources[j]) < 0
})
gotSources := make([]*source.Source, len(policy.sources))
for i, s := range policy.sources {
gotSources[i] = s.Source
}
if !slices.Equal(gotSources, wantSources) {
t.Errorf("Sources: got %v; want %v", gotSources, wantSources)
}
// Verify the final resultant settings snapshots.
if got := policy.Get(); !got.Equal(tt.wantSnapshot) {
t.Errorf("Snapshot: got %v; want %v", got, tt.wantSnapshot)
}
})
}
}
// resultantPolicyForTest is like [resultantPolicyFor], but it deletes the policy
// when tb and all its subtests complete.
func resultantPolicyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) {
policy, err := PolicyFor(target)
if err != nil {
return nil, err
}
tb.Cleanup(func() {
policy.Close()
<-policy.Done()
deletePolicy(policy)
})
return policy, nil
}

View File

@ -0,0 +1,60 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
import "errors"
var (
// ErrNotConfigured is returned when the requested policy setting is not configured.
ErrNotConfigured = errors.New("not configured")
// ErrTypeMismatch is returned when there's a type mismatch between the actual type
// of the setting value and the expected type.
ErrTypeMismatch = errors.New("type mismatch")
// ErrNoSuchKey is returned by [DefinitionOf] when no policy setting
// has been registered with the specified key.
//
// Until 2024-08-02, this error was also returned by a [Handler] when the specified
// key did not have a value set. While the package maintains compatibility with this
// usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer
// [source.Store] implementations.
ErrNoSuchKey = errors.New("no such key")
)
// Error is an error when reading or parsing a policy setting.
type Error struct {
text string
}
// NewError returns a [Error] with the specified error message.
func NewError(text string) *Error {
return &Error{text}
}
// WrapError returns an [Error] with the text of the specified error,
// or nil if err is nil, [ErrNotConfigured], or [ErrNoSuchKey].
func WrapError(err error) *Error {
if err == nil || errors.Is(err, ErrNotConfigured) || errors.Is(err, ErrNoSuchKey) {
return nil
}
if err, ok := err.(*Error); ok {
return err
}
return &Error{err.Error()}
}
// Error implements error.
func (e Error) Error() string {
return e.text
}
// MarshalText implements [encoding.TextMarshaler].
func (e Error) MarshalText() (text []byte, err error) {
return []byte(e.Error()), nil
}
// UnmarshalText implements [encoding.TextUnmarshaler].
func (e *Error) UnmarshalText(text []byte) error {
e.text = string(text)
return nil
}

View File

@ -0,0 +1,13 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
// Key is a string that uniquely identifies a policy and must remain unchanged
// once established and documented for a given policy setting. It may contain
// alphanumeric characters and zero or more [KeyPathSeparator]s to group
// individual policy settings into categories.
type Key string
// KeyPathSeparator allows logical grouping of policy settings into categories.
const KeyPathSeparator = "/"

View File

@ -0,0 +1,71 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
import (
"fmt"
jsonv2 "github.com/go-json-experiment/json"
"github.com/go-json-experiment/json/jsontext"
)
// Origin describes where a policy or a policy setting is configured.
type Origin struct {
data settingOrigin
}
// settingOrigin is the marshallable data of a [Origin].
type settingOrigin struct {
Name string `json:",omitzero"`
Scope PolicyScope
}
// NewOrigin returns a new [Origin] with the specified scope.
func NewOrigin(scope PolicyScope) *Origin {
return NewNamedOrigin("", scope)
}
// NewNamedOrigin returns a new [Origin] with the specified scope and name.
func NewNamedOrigin(name string, scope PolicyScope) *Origin {
return &Origin{settingOrigin{name, scope}}
}
// Scope reports the policy [PolicyScope] where the setting is configured.
func (s Origin) Scope() PolicyScope {
return s.data.Scope
}
// Name returns the name of the policy source where the setting is configured,
// or "" if not available.
func (s Origin) Name() string {
return s.data.Name
}
// String implements [fmt.Stringer].
func (s Origin) String() string {
if s.Name() != "" {
return fmt.Sprintf("%s (%v)", s.Name(), s.Scope())
}
return s.Scope().String()
}
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
func (s Origin) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
return jsonv2.MarshalEncode(out, &s.data, opts)
}
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
func (s *Origin) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
return jsonv2.UnmarshalDecode(in, &s.data, opts)
}
// MarshalJSON implements [json.Marshaler].
func (s Origin) MarshalJSON() ([]byte, error) {
return jsonv2.Marshal(s) // uses MarshalJSONV2
}
// UnmarshalJSON implements [json.Unmarshaler].
func (s *Origin) UnmarshalJSON(b []byte) error {
return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2
}

View File

@ -0,0 +1,195 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
import (
"fmt"
"strings"
"tailscale.com/types/lazy"
"tailscale.com/util/syspolicy/internal/lazyinit"
)
var (
lazyCurrentScope lazy.SyncValue[PolicyScope]
// DeviceScope indicates a scope containing device-global policies.
DeviceScope = PolicyScope{kind: DeviceSetting}
// CurrentProfileScope indicates a scope containing policies that apply to the
// currently active Tailscale profile.
CurrentProfileScope = PolicyScope{kind: ProfileSetting}
// CurrentUserScope indicates a scope containing policies that apply to the
// current user, for whatever that means on the current platform and
// in the current application context.
CurrentUserScope = PolicyScope{kind: UserSetting}
)
// PolicyScope is a management scope.
type PolicyScope struct {
kind Scope
userID string
profileID string
}
// CurrentScope returns the default [PolicyScope] that the package will use to return
// the policy settings for unless a different scope is explicitly requested.
// This defaults to [DeviceScope], unless the process runs as a user (rather than LocalSystem)
// on Windows, in which case it returns the [CurrentUserScope].
func CurrentScope() PolicyScope {
// Allow deferred package init functions to override the default scope.
lazyinit.Do()
return lazyCurrentScope.Get(func() PolicyScope { return DeviceScope })
}
// SetCurrentScope attempts to set the specified scope as the current scope,
// and reports whether it succeeds.
// It can be called only once and must be during lazy package initialization.
func SetCurrentScope(scope PolicyScope) bool {
return lazyCurrentScope.Set(scope)
}
// UserScopeOf returns a policy [PolicyScope] of the specified user.
func UserScopeOf(uid string) PolicyScope {
return PolicyScope{kind: UserSetting, userID: uid}
}
// Kind reports the base [Scope] of s.
func (s PolicyScope) Kind() Scope {
return s.kind
}
// IsApplicableSetting reports whether the specified setting applies to
// and can be retrieved for this scope. Policy settings are applicable
// to their own scopes as well as more specific scopes. For example,
// device settings are applicable to device, profile and user scopes,
// but user settings are only applicable to user scopes.
// For instance, a menu visibility setting is inherently a user setting
// and only makes sense in the context of a specific user.
func (s PolicyScope) IsApplicableSetting(setting *Definition) bool {
return setting != nil && setting.Scope() <= s.Kind()
}
// IsConfigurableSetting reports whether the specified setting can be configured
// by a policy at this scope. Policy settings are configurable at their own scopes
// as well as broader scopes. For example, [UserSetting]s are configurable in
// user, profile, and device scopes, but [DeviceSetting]s are only configurable
// in the [DeviceScope]. For instance, the InstallUpdates policy setting
// can only be configured in the device scope, as it controls whether updates
// will be installed automatically on the device, rather than for specific users.
func (s PolicyScope) IsConfigurableSetting(setting *Definition) bool {
return setting != nil && setting.Scope() >= s.Kind()
}
// IsWithinOf reports whether policy settings that apply to s2 also apply to s.
// For example, policy settings that apply to the [DeviceScope] also apply to
// the [CurrentUserScope].
func (s PolicyScope) IsWithinOf(s2 PolicyScope) bool {
if s2.Kind() > s.Kind() {
return false
}
switch s2.Kind() {
case DeviceSetting:
return true
case ProfileSetting:
return s.profileID == s2.profileID
case UserSetting:
return s.userID == s2.userID
default:
panic("unreachable")
}
}
// IsStrictlyWithinOf is like [IsWithinOf], except it returns false
// when s and s2 is the same scope.
func (s PolicyScope) IsStrictlyWithinOf(s2 PolicyScope) bool {
return s != s2 && s.IsWithinOf(s2)
}
// String implements [fmt.Stringer].
func (s PolicyScope) String() string {
if s.profileID == "" && s.userID == "" {
return s.kind.String()
}
return s.stringSlow()
}
// MarshalText implements [encoding.TextMarshaler].
func (s PolicyScope) MarshalText() ([]byte, error) {
return []byte(s.String()), nil
}
// MarshalText implements [encoding.TextUnmarshaler].
func (s *PolicyScope) UnmarshalText(b []byte) error {
*s = PolicyScope{}
parts := strings.SplitN(string(b), "/", 2)
if len(parts) == 0 {
return fmt.Errorf("%s is not a valid scope", b)
}
for i, part := range parts {
kind, id, err := parseScopeAndID(part)
if err != nil {
return err
}
if i > 0 && kind <= s.kind {
return fmt.Errorf("invalid scope hierarchy: %s", b)
}
s.kind = kind
switch kind {
case DeviceSetting:
if id != "" {
return fmt.Errorf("the device scope must not have an ID: %s", b)
}
case ProfileSetting:
s.profileID = id
case UserSetting:
s.userID = id
}
}
return nil
}
func (s PolicyScope) stringSlow() string {
var sb strings.Builder
writeScopeWithID := func(s Scope, id string) {
sb.WriteString(s.String())
if id != "" {
sb.WriteRune('(')
sb.WriteString(id)
sb.WriteRune(')')
}
}
if s.kind == ProfileSetting || s.profileID != "" {
writeScopeWithID(ProfileSetting, s.profileID)
if s.kind != ProfileSetting {
sb.WriteRune('/')
}
}
if s.kind == UserSetting {
writeScopeWithID(UserSetting, s.userID)
}
return sb.String()
}
func parseScopeAndID(s string) (scope Scope, id string, err error) {
name, params, ok := extractScopeAndParams(s)
if !ok {
return 0, "", fmt.Errorf("%q is not a valid scope string", s)
}
if err := scope.UnmarshalText([]byte(name)); err != nil {
return 0, "", err
}
return scope, params, nil
}
func extractScopeAndParams(s string) (name, params string, ok bool) {
paramsStart := strings.Index(s, "(")
if paramsStart == -1 {
return s, "", true
}
paramsEnd := strings.LastIndex(s, ")")
if paramsEnd < paramsStart {
return "", "", false
}
return s[0:paramsStart], s[paramsStart+1 : paramsEnd], true
}

View File

@ -0,0 +1,550 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
import (
"reflect"
"testing"
jsonv2 "github.com/go-json-experiment/json"
)
func TestPolicyScopeIsApplicableSetting(t *testing.T) {
tests := []struct {
name string
scope PolicyScope
setting *Definition
wantApplicable bool
}{
{
name: "DeviceScope/DeviceSetting",
scope: DeviceScope,
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
wantApplicable: true,
},
{
name: "DeviceScope/ProfileSetting",
scope: DeviceScope,
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
wantApplicable: false,
},
{
name: "DeviceScope/UserSetting",
scope: DeviceScope,
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
wantApplicable: false,
},
{
name: "ProfileScope/DeviceSetting",
scope: CurrentProfileScope,
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
wantApplicable: true,
},
{
name: "ProfileScope/ProfileSetting",
scope: CurrentProfileScope,
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
wantApplicable: true,
},
{
name: "ProfileScope/UserSetting",
scope: CurrentProfileScope,
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
wantApplicable: false,
},
{
name: "UserScope/DeviceSetting",
scope: CurrentUserScope,
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
wantApplicable: true,
},
{
name: "UserScope/ProfileSetting",
scope: CurrentUserScope,
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
wantApplicable: true,
},
{
name: "UserScope/UserSetting",
scope: CurrentUserScope,
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
wantApplicable: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotApplicable := tt.scope.IsApplicableSetting(tt.setting)
if gotApplicable != tt.wantApplicable {
t.Fatalf("got %v, want %v", gotApplicable, tt.wantApplicable)
}
})
}
}
func TestPolicyScopeIsConfigurableSetting(t *testing.T) {
tests := []struct {
name string
scope PolicyScope
setting *Definition
wantConfigurable bool
}{
{
name: "DeviceScope/DeviceSetting",
scope: DeviceScope,
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
wantConfigurable: true,
},
{
name: "DeviceScope/ProfileSetting",
scope: DeviceScope,
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
wantConfigurable: true,
},
{
name: "DeviceScope/UserSetting",
scope: DeviceScope,
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
wantConfigurable: true,
},
{
name: "ProfileScope/DeviceSetting",
scope: CurrentProfileScope,
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
wantConfigurable: false,
},
{
name: "ProfileScope/ProfileSetting",
scope: CurrentProfileScope,
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
wantConfigurable: true,
},
{
name: "ProfileScope/UserSetting",
scope: CurrentProfileScope,
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
wantConfigurable: true,
},
{
name: "UserScope/DeviceSetting",
scope: CurrentUserScope,
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
wantConfigurable: false,
},
{
name: "UserScope/ProfileSetting",
scope: CurrentUserScope,
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
wantConfigurable: false,
},
{
name: "UserScope/UserSetting",
scope: CurrentUserScope,
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
wantConfigurable: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotConfigurable := tt.scope.IsConfigurableSetting(tt.setting)
if gotConfigurable != tt.wantConfigurable {
t.Fatalf("got %v, want %v", gotConfigurable, tt.wantConfigurable)
}
})
}
}
func TestPolicyScopeIsWithinOf(t *testing.T) {
tests := []struct {
name string
scopeA PolicyScope
scopeB PolicyScope
wantBWithinOfA bool
wantBStrictlyWithinOfA bool
}{
{
name: "DeviceScope/DeviceScope",
scopeA: DeviceScope,
scopeB: DeviceScope,
wantBWithinOfA: true,
wantBStrictlyWithinOfA: false,
},
{
name: "DeviceScope/CurrentProfileScope",
scopeA: DeviceScope,
scopeB: CurrentProfileScope,
wantBWithinOfA: true,
wantBStrictlyWithinOfA: true,
},
{
name: "DeviceScope/UserScope",
scopeA: DeviceScope,
scopeB: CurrentUserScope,
wantBWithinOfA: true,
wantBStrictlyWithinOfA: true,
},
{
name: "ProfileScope/DeviceScope",
scopeA: CurrentProfileScope,
scopeB: DeviceScope,
wantBWithinOfA: false,
wantBStrictlyWithinOfA: false,
},
{
name: "ProfileScope/ProfileScope",
scopeA: CurrentProfileScope,
scopeB: CurrentProfileScope,
wantBWithinOfA: true,
wantBStrictlyWithinOfA: false,
},
{
name: "ProfileScope/UserScope",
scopeA: CurrentProfileScope,
scopeB: CurrentUserScope,
wantBWithinOfA: true,
wantBStrictlyWithinOfA: true,
},
{
name: "UserScope/DeviceScope",
scopeA: CurrentUserScope,
scopeB: DeviceScope,
wantBWithinOfA: false,
wantBStrictlyWithinOfA: false,
},
{
name: "UserScope/ProfileScope",
scopeA: CurrentUserScope,
scopeB: CurrentProfileScope,
wantBWithinOfA: false,
wantBStrictlyWithinOfA: false,
},
{
name: "UserScope/UserScope",
scopeA: CurrentUserScope,
scopeB: CurrentUserScope,
wantBWithinOfA: true,
wantBStrictlyWithinOfA: false,
},
{
name: "UserScope(1234)/UserScope(1234)",
scopeA: UserScopeOf("1234"),
scopeB: UserScopeOf("1234"),
wantBWithinOfA: true,
wantBStrictlyWithinOfA: false,
},
{
name: "UserScope(1234)/UserScope(5678)",
scopeA: UserScopeOf("1234"),
scopeB: UserScopeOf("5678"),
wantBWithinOfA: false,
wantBStrictlyWithinOfA: false,
},
{
name: "ProfileScope(A)/UserScope(A/1234)",
scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"},
scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"},
wantBWithinOfA: true,
wantBStrictlyWithinOfA: true,
},
{
name: "ProfileScope(A)/UserScope(B/1234)",
scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"},
scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "B"},
wantBWithinOfA: false,
wantBStrictlyWithinOfA: false,
},
{
name: "UserScope(1234)/UserScope(A/1234)",
scopeA: PolicyScope{kind: UserSetting, userID: "1234"},
scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"},
wantBWithinOfA: true,
wantBStrictlyWithinOfA: true,
},
{
name: "UserScope(1234)/UserScope(A/5678)",
scopeA: PolicyScope{kind: UserSetting, userID: "1234"},
scopeB: PolicyScope{kind: UserSetting, userID: "5678", profileID: "A"},
wantBWithinOfA: false,
wantBStrictlyWithinOfA: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotWithinOf := tt.scopeB.IsWithinOf(tt.scopeA)
if gotWithinOf != tt.wantBWithinOfA {
t.Fatalf("WithinOf: got %v, want %v", gotWithinOf, tt.wantBWithinOfA)
}
gotStrictlyWithinOf := tt.scopeB.IsStrictlyWithinOf(tt.scopeA)
if gotStrictlyWithinOf != tt.wantBStrictlyWithinOfA {
t.Fatalf("StrictlyWithinOf: got %v, want %v", gotStrictlyWithinOf, tt.wantBStrictlyWithinOfA)
}
})
}
}
func TestPolicyScopeMarshalUnmarshal(t *testing.T) {
tests := []struct {
name string
in any
wantJSON string
wantError bool
}{
{
name: "null-scope",
in: &struct {
Scope PolicyScope
}{},
wantJSON: `{"Scope":"Device"}`,
},
{
name: "null-scope-omit-zero",
in: &struct {
Scope PolicyScope `json:",omitzero"`
}{},
wantJSON: `{}`,
},
{
name: "device-scope",
in: &struct {
Scope PolicyScope
}{DeviceScope},
wantJSON: `{"Scope":"Device"}`,
},
{
name: "current-profile-scope",
in: &struct {
Scope PolicyScope
}{CurrentProfileScope},
wantJSON: `{"Scope":"Profile"}`,
},
{
name: "current-user-scope",
in: &struct {
Scope PolicyScope
}{CurrentUserScope},
wantJSON: `{"Scope":"User"}`,
},
{
name: "specific-user-scope",
in: &struct {
Scope PolicyScope
}{UserScopeOf("_")},
wantJSON: `{"Scope":"User(_)"}`,
},
{
name: "specific-user-scope",
in: &struct {
Scope PolicyScope
}{UserScopeOf("S-1-5-21-3698941153-1525015703-2649197413-1001")},
wantJSON: `{"Scope":"User(S-1-5-21-3698941153-1525015703-2649197413-1001)"}`,
},
{
name: "specific-profile-scope",
in: &struct {
Scope PolicyScope
}{PolicyScope{kind: ProfileSetting, profileID: "1234"}},
wantJSON: `{"Scope":"Profile(1234)"}`,
},
{
name: "specific-profile-and-user-scope",
in: &struct {
Scope PolicyScope
}{PolicyScope{
kind: UserSetting,
profileID: "1234",
userID: "S-1-5-21-3698941153-1525015703-2649197413-1001",
}},
wantJSON: `{"Scope":"Profile(1234)/User(S-1-5-21-3698941153-1525015703-2649197413-1001)"}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotJSON, err := jsonv2.Marshal(tt.in)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(gotJSON) != tt.wantJSON {
t.Fatalf("Marshal got %s, want %s", gotJSON, tt.wantJSON)
}
wantBack := tt.in
gotBack := reflect.New(reflect.TypeOf(tt.in).Elem()).Interface()
err = jsonv2.Unmarshal(gotJSON, gotBack)
if err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if !reflect.DeepEqual(gotBack, wantBack) {
t.Fatalf("Unmarshal got %+v, want %+v", gotBack, wantBack)
}
})
}
}
func TestPolicyScopeUnmarshalSpecial(t *testing.T) {
tests := []struct {
name string
json string
want any
wantError bool
}{
{
name: "empty",
json: "{}",
want: &struct {
Scope PolicyScope
}{},
},
{
name: "too-many-scopes",
json: `{"Scope":"Device/Profile/User"}`,
wantError: true,
},
{
name: "user/profile", // incorrect order
json: `{"Scope":"User/Profile"}`,
wantError: true,
},
{
name: "profile-user-no-params",
json: `{"Scope":"Profile/User"}`,
want: &struct {
Scope PolicyScope
}{CurrentUserScope},
},
{
name: "unknown-scope",
json: `{"Scope":"Unknown"}`,
wantError: true,
},
{
name: "unknown-scope/unknown-scope",
json: `{"Scope":"Unknown/Unknown"}`,
wantError: true,
},
{
name: "device-scope/unknown-scope",
json: `{"Scope":"Device/Unknown"}`,
wantError: true,
},
{
name: "unknown-scope/device-scope",
json: `{"Scope":"Unknown/Device"}`,
wantError: true,
},
{
name: "slash",
json: `{"Scope":"/"}`,
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &struct {
Scope PolicyScope
}{}
err := jsonv2.Unmarshal([]byte(tt.json), got)
if (err != nil) != tt.wantError {
t.Errorf("Marshal error: got %v, want %v", err, tt.wantError)
}
if err != nil {
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Fatalf("Unmarshal got %+v, want %+v", got, tt.want)
}
})
}
}
func TestExtractScopeAndParams(t *testing.T) {
tests := []struct {
name string
s string
scope string
params string
wantOk bool
}{
{
name: "empty",
s: "",
wantOk: true,
},
{
name: "scope-only",
s: "device",
scope: "device",
wantOk: true,
},
{
name: "scope-with-params",
s: "user(1234)",
scope: "user",
params: "1234",
wantOk: true,
},
{
name: "params-empty-scope",
s: "(1234)",
scope: "",
params: "1234",
wantOk: true,
},
{
name: "params-with-brackets",
s: "test()())))())",
scope: "test",
params: ")())))()",
wantOk: true,
},
{
name: "no-closing-bracket",
s: "user(1234",
scope: "",
params: "",
wantOk: false,
},
{
name: "open-before-close",
s: ")user(1234",
scope: "",
params: "",
wantOk: false,
},
{
name: "brackets-only",
s: ")(",
scope: "",
params: "",
wantOk: false,
},
{
name: "closing-bracket",
s: ")",
scope: "",
params: "",
wantOk: false,
},
{
name: "opening-bracket",
s: ")",
scope: "",
params: "",
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scope, params, ok := extractScopeAndParams(tt.s)
if ok != tt.wantOk {
t.Logf("OK: got %v; want %v", ok, tt.wantOk)
}
if scope != tt.scope {
t.Logf("Scope: got %q; want %q", scope, tt.scope)
}
if params != tt.params {
t.Logf("Params: got %v; want %v", params, tt.params)
}
})
}
}

View File

@ -0,0 +1,47 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
// RawItem contains a raw policy setting as read from a policy store, or an
// error if the requested setting could not be read from the store. As a special
// case, it may also hold a value of the [Visibility], [PreferenceOption],
// or [time.Duration] types. While the policy store interface does not support
// these types natively, and the values of these types have to be unmarshalled
// or converted from strings, these setting types predate the typed policy
// hierarchies, and must be supported at this layer.
type RawItem struct {
value any
err *Error
origin *Origin // or nil
}
// RawItemOf returns [RawItem] with the specified value.
func RawItemOf(value any) RawItem {
return RawItemWith(value, nil, nil)
}
// RawItemWith returns an [RawItem] with the specified value, error and origin.
func RawItemWith(value any, err *Error, origin *Origin) RawItem {
return RawItem{value, err, origin}
}
// Value returns the value of an untyped policy setting,
// or nil if the policy setting is not configured.
func (i RawItem) Value() any {
return i.value
}
// Error returns the error that occurred when reading the policy setting,
// or nil if no error occurred.
func (i RawItem) Error() error {
if i.err != nil {
return i.err
}
return nil
}
// Origin returns an optional [Origin] indicating the policy settings is configured.
func (i RawItem) Origin() *Origin {
return i.origin
}

View File

@ -0,0 +1,352 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package setting contain types for policy settings.
package setting
import (
"fmt"
"slices"
"strings"
"sync"
"time"
"tailscale.com/types/lazy"
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/internal/lazyinit"
)
// Scope indicates the broadest scope at which a policy setting may apply,
// and the narrowest scope at which it may be configured.
type Scope int8
const (
// DeviceSetting indicates a policy setting that applies to a device, regardless of
// which OS user or Tailscale profile is currently active, if any.
// It can only be configured at a [DeviceScope].
DeviceSetting Scope = iota
// ProfileSetting indicates a policy setting that applies to a Tailscale profile.
// It can only be configured for a specific profile or at a [DeviceScope],
// in which case it applies to all profiles on the device.
ProfileSetting
// UserSetting indicates a policy setting that applies to users.
// It can be configured for a user, profile, or the entire device.
UserSetting
// MaxSettingScope is the maximum possible [Scope] value.
MaxSettingScope = UserSetting
)
// String implements [fmt.Stringer].
func (s Scope) String() string {
switch s {
case DeviceSetting:
return "Device"
case ProfileSetting:
return "Profile"
case UserSetting:
return "User"
default:
panic("unreachable")
}
}
// MarshalText implements [encoding.TextMarshaler].
func (s Scope) MarshalText() (text []byte, err error) {
return []byte(s.String()), nil
}
// UnmarshalText implements [encoding.TextUnmarshaler].
func (s *Scope) UnmarshalText(text []byte) error {
switch strings.ToLower(string(text)) {
case "device":
*s = DeviceSetting
case "profile":
*s = ProfileSetting
case "user":
*s = UserSetting
default:
return fmt.Errorf("%q is not a valid scope", string(text))
}
return nil
}
// Type is a policy setting value type.
// Except for [InvalidValue], which represents an invalid policy setting type,
// and [PreferenceOptionValue], [VisibilityValue], and [DurationValue],
// which have special handling due to their legacy status in the package,
// SettingTypes represent the raw value types readable from policy stores.
type Type int
const (
// InvalidValue indicates an invalid policy setting value type.
InvalidValue Type = iota
// BooleanValue indicates a policy setting whose underlying type in the
// [source.Store] is a bool.
BooleanValue
// IntegerValue indicates a policy setting whose underlying type in the
// [source.Store] is a uint64.
IntegerValue
// StringValue indicates a policy setting whose underlying type in the
// [source.Store] is a string.
StringValue
// StringListValue indicates a policy setting whose underlying type in the
// [source.Store] is a []string.
StringListValue
// PreferenceOptionValue indicates a three-state policy setting whose
// underlying type in the [source.Store] is a string, but the actual value
// is a [PreferenceOption].
PreferenceOptionValue
// VisibilityValue indicates a two-state boolean-like policy setting whose
// underlying type in the [source.Store] is a string, but the actual value
// is a [Visibility].
VisibilityValue
// DurationValue indicates an interval/period/duration policy setting whose
// underlying type in the [source.Store] is a string, but the actual value
// is a [time.Duration].
DurationValue
)
// String returns a string representation of t.
func (t Type) String() string {
switch t {
case InvalidValue:
return "Invalid"
case BooleanValue:
return "Boolean"
case IntegerValue:
return "Integer"
case StringValue:
return "String"
case StringListValue:
return "StringList"
case PreferenceOptionValue:
return "PreferenceOption"
case VisibilityValue:
return "Visibility"
case DurationValue:
return "Duration"
default:
panic("unreachable")
}
}
// ValueType is a constraint that allows Go types corresponding to [Type].
type ValueType interface {
bool | uint64 | string | []string | Visibility | PreferenceOption | time.Duration
}
// Definition defines policy key, scope and value type.
type Definition struct {
key Key
scope Scope
typ Type
platforms PlatformList
}
// NewDefinition returns a new [Definition] with the specified
// key, scope, type and supported platforms (see [PlatformList]).
func NewDefinition(k Key, s Scope, t Type, platforms ...string) *Definition {
return &Definition{key: k, scope: s, typ: t, platforms: platforms}
}
// Key returns a policy setting's identifier.
func (d *Definition) Key() Key {
if d == nil {
return ""
}
return d.key
}
// Scope reports the broadest [Scope] the policy setting may apply to.
func (d *Definition) Scope() Scope {
if d == nil {
return 0
}
return d.scope
}
// Type reports the underlying value type of the policy setting.
func (d *Definition) Type() Type {
if d == nil {
return InvalidValue
}
return d.typ
}
// IsSupported reports whether the policy setting is supported on the current OS.
func (d *Definition) IsSupported() bool {
if d == nil {
return false
}
return d.platforms.HasCurrent()
}
// SupportedPlatforms reports platforms on which the policy setting is supported.
// An empty [PlatformList] indicates that s is available on all platforms.
func (d *Definition) SupportedPlatforms() PlatformList {
if d == nil {
return nil
}
return d.platforms
}
// String implements [fmt.Stringer].
func (d *Definition) String() string {
if d == nil {
return "(nil)"
}
return fmt.Sprintf("%v(%q, %v)", d.scope, d.key, d.typ)
}
// Equal reports whether d and d2 have the same key, type and scope.
// It does not check whether both s and s2 are supported on the same platforms.
func (d *Definition) Equal(d2 *Definition) bool {
if d == d2 {
return true
}
if d == nil || d2 == nil {
return false
}
return d.key == d2.key && d.typ == d2.typ && d.scope == d2.scope
}
// DefinitionMap is a map of setting [Definition] by [Key].
type DefinitionMap map[Key]*Definition
var (
definitions lazy.SyncValue[DefinitionMap]
definitionsMu sync.Mutex
definitionsList []*Definition
definitionsUsed bool
)
// Register registers a policy setting with the specified key, scope, and value type.
// All policy settings must be registered before any of them can be used.
// Register panics if called after invoking any syspolicy functions that use the
// registered policy definitions, such as functions that read the policy.
func Register(k Key, s Scope, t Type, platforms ...string) {
RegisterDefinition(NewDefinition(k, s, t, platforms...))
}
// RegisterDefinition is like [Register], but accepts a [Definition].
func RegisterDefinition(d *Definition) {
definitionsMu.Lock()
defer definitionsMu.Unlock()
registerLocked(d)
}
func registerLocked(d *Definition) {
if definitionsUsed {
panic("policy definitions are already in use")
}
definitionsList = append(definitionsList, d)
}
func settingDefinitions() (DefinitionMap, error) {
return definitions.GetErr(func() (DefinitionMap, error) {
lazyinit.Do()
definitionsMu.Lock()
defer definitionsMu.Unlock()
definitionsUsed = true
return DefinitionMapOf(definitionsList)
})
}
// DefinitionMapOf returns a [DefinitionMap] with the specified settings,
// or an error if any settings have the same key but different type or scope.
func DefinitionMapOf(settings []*Definition) (DefinitionMap, error) {
m := make(DefinitionMap, len(settings))
for _, s := range settings {
if existing, exists := m[s.key]; exists {
if existing.Equal(s) {
// Ignore duplicate setting definitions if they match. It is acceptable
// if the same policy setting was registered more than once
// (e.g. by the syspolicy package itself and by iOS/Android code).
existing.platforms.mergeFrom(s.platforms)
continue
}
return nil, fmt.Errorf("duplicate policy definition: %q", s.key)
}
m[s.key] = s
}
return m, nil
}
// SetDefinitionsForTest allows to register the specified setting definitions
// for the test duration. It is not concurrency-safe, but unlike [Register],
// it does not panic and can be called anytime.
// It returns an error if ds contains two different settings with the same [Key].
func SetDefinitionsForTest(tb lazy.TB, ds ...*Definition) error {
m, err := DefinitionMapOf(ds)
if err != nil {
return err
}
definitions.SetForTest(tb, m, err)
return nil
}
// DefinitionOf returns a setting definition by key,
// or [ErrNoSuchKey] if the specified key does not exist,
// or an error if there are conflicting policy definitions.
func DefinitionOf(k Key) (*Definition, error) {
ds, err := settingDefinitions()
if err != nil {
return nil, err
}
if d, ok := ds[k]; ok {
return d, nil
}
return nil, ErrNoSuchKey
}
// Definitions returns all registered setting definitions,
// or an error if different policies were registered under the same name.
func Definitions() ([]*Definition, error) {
ds, err := settingDefinitions()
if err != nil {
return nil, err
}
res := make([]*Definition, 0, len(ds))
for _, d := range ds {
res = append(res, d)
}
return res, nil
}
// PlatformList is a list of OSes.
// An empty list indicates that all possible platforms are supported.
type PlatformList []string
// Has reports whether the list contains the target platform.
func (l PlatformList) Has(target string) bool {
if len(l) == 0 {
return true
}
return slices.ContainsFunc(l, func(os string) bool {
return strings.EqualFold(os, target)
})
}
// HasCurrent is like Has, but for the current platform.
func (l PlatformList) HasCurrent() bool {
return l.Has(internal.OS())
}
// mergeFrom merges l2 into l. Since an empty list indicates no platform restrictions,
// if either l or l2 is empty, the merged result in l will also be empty.
func (l *PlatformList) mergeFrom(l2 PlatformList) {
switch {
case len(*l) == 0:
// No-op. An empty list indicates no platform restrictions.
case len(l2) == 0:
// Merging with an empty list results in an empty list.
*l = l2
default:
// Append, sort and dedup.
*l = append(*l, l2...)
slices.Sort(*l)
*l = slices.Compact(*l)
}
}

View File

@ -0,0 +1,344 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
import (
"slices"
"strings"
"testing"
"tailscale.com/types/lazy"
"tailscale.com/types/ptr"
"tailscale.com/util/syspolicy/internal"
)
func TestSettingDefinition(t *testing.T) {
tests := []struct {
name string
setting *Definition
osOverride string
wantKey Key
wantScope Scope
wantType Type
wantIsSupported bool
wantSupportedPlatforms PlatformList
wantString string
}{
{
name: "Nil",
setting: nil,
wantKey: "",
wantScope: 0,
wantType: InvalidValue,
wantIsSupported: false,
wantString: "(nil)",
},
{
name: "Device/Invalid",
setting: NewDefinition("TestDevicePolicySetting", DeviceSetting, InvalidValue),
wantKey: "TestDevicePolicySetting",
wantScope: DeviceSetting,
wantType: InvalidValue,
wantIsSupported: true,
wantString: `Device("TestDevicePolicySetting", Invalid)`,
},
{
name: "Device/Integer",
setting: NewDefinition("TestDevicePolicySetting", DeviceSetting, IntegerValue),
wantKey: "TestDevicePolicySetting",
wantScope: DeviceSetting,
wantType: IntegerValue,
wantIsSupported: true,
wantString: `Device("TestDevicePolicySetting", Integer)`,
},
{
name: "Profile/String",
setting: NewDefinition("TestProfilePolicySetting", ProfileSetting, StringValue),
wantKey: "TestProfilePolicySetting",
wantScope: ProfileSetting,
wantType: StringValue,
wantIsSupported: true,
wantString: `Profile("TestProfilePolicySetting", String)`,
},
{
name: "Device/StringList",
setting: NewDefinition("AllowedSuggestedExitNodes", DeviceSetting, StringListValue),
wantKey: "AllowedSuggestedExitNodes",
wantScope: DeviceSetting,
wantType: StringListValue,
wantIsSupported: true,
wantString: `Device("AllowedSuggestedExitNodes", StringList)`,
},
{
name: "Device/PreferenceOption",
setting: NewDefinition("AdvertiseExitNode", DeviceSetting, PreferenceOptionValue),
wantKey: "AdvertiseExitNode",
wantScope: DeviceSetting,
wantType: PreferenceOptionValue,
wantIsSupported: true,
wantString: `Device("AdvertiseExitNode", PreferenceOption)`,
},
{
name: "User/Boolean",
setting: NewDefinition("TestUserPolicySetting", UserSetting, BooleanValue),
wantKey: "TestUserPolicySetting",
wantScope: UserSetting,
wantType: BooleanValue,
wantIsSupported: true,
wantString: `User("TestUserPolicySetting", Boolean)`,
},
{
name: "User/Visibility",
setting: NewDefinition("AdminConsole", UserSetting, VisibilityValue),
wantKey: "AdminConsole",
wantScope: UserSetting,
wantType: VisibilityValue,
wantIsSupported: true,
wantString: `User("AdminConsole", Visibility)`,
},
{
name: "User/Duration",
setting: NewDefinition("KeyExpirationNotice", UserSetting, DurationValue),
wantKey: "KeyExpirationNotice",
wantScope: UserSetting,
wantType: DurationValue,
wantIsSupported: true,
wantString: `User("KeyExpirationNotice", Duration)`,
},
{
name: "SupportedSetting",
setting: NewDefinition("DesktopPolicySetting", DeviceSetting, StringValue, "macos", "windows"),
osOverride: "windows",
wantKey: "DesktopPolicySetting",
wantScope: DeviceSetting,
wantType: StringValue,
wantIsSupported: true,
wantSupportedPlatforms: PlatformList{"macos", "windows"},
wantString: `Device("DesktopPolicySetting", String)`,
},
{
name: "UnsupportedSetting",
setting: NewDefinition("AndroidPolicySetting", DeviceSetting, StringValue, "android"),
osOverride: "macos",
wantKey: "AndroidPolicySetting",
wantScope: DeviceSetting,
wantType: StringValue,
wantIsSupported: false,
wantSupportedPlatforms: PlatformList{"android"},
wantString: `Device("AndroidPolicySetting", String)`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.osOverride != "" {
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
}
if !tt.setting.Equal(tt.setting) {
t.Errorf("the setting should be equal to itself")
}
if tt.setting != nil && !tt.setting.Equal(ptr.To(*tt.setting)) {
t.Errorf("the setting should be equal to its shallow copy")
}
if gotKey := tt.setting.Key(); gotKey != tt.wantKey {
t.Errorf("Key: got %q, want %q", gotKey, tt.wantKey)
}
if gotScope := tt.setting.Scope(); gotScope != tt.wantScope {
t.Errorf("Scope: got %v, want %v", gotScope, tt.wantScope)
}
if gotType := tt.setting.Type(); gotType != tt.wantType {
t.Errorf("Type: got %v, want %v", gotType, tt.wantType)
}
if gotIsSupported := tt.setting.IsSupported(); gotIsSupported != tt.wantIsSupported {
t.Errorf("IsSupported: got %v, want %v", gotIsSupported, tt.wantIsSupported)
}
if gotSupportedPlatforms := tt.setting.SupportedPlatforms(); !slices.Equal(gotSupportedPlatforms, tt.wantSupportedPlatforms) {
t.Errorf("SupportedPlatforms: got %v, want %v", gotSupportedPlatforms, tt.wantSupportedPlatforms)
}
if gotString := tt.setting.String(); gotString != tt.wantString {
t.Errorf("String: got %v, want %v", gotString, tt.wantString)
}
})
}
}
func TestRegisterSettingDefinition(t *testing.T) {
const testPolicySettingKey Key = "TestPolicySetting"
tests := []struct {
name string
key Key
wantEq *Definition
wantErr error
}{
{
name: "GetRegistered",
key: "TestPolicySetting",
wantEq: NewDefinition(testPolicySettingKey, DeviceSetting, StringValue),
},
{
name: "GetNonRegistered",
key: "OtherPolicySetting",
wantEq: nil,
wantErr: ErrNoSuchKey,
},
}
resetSettingDefinitions(t)
Register(testPolicySettingKey, DeviceSetting, StringValue)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, gotErr := DefinitionOf(tt.key)
if gotErr != tt.wantErr {
t.Errorf("gotErr %v, wantErr %v", gotErr, tt.wantErr)
}
if !got.Equal(tt.wantEq) {
t.Errorf("got %v, want %v", got, tt.wantEq)
}
})
}
}
func TestRegisterAfterUsePanics(t *testing.T) {
resetSettingDefinitions(t)
Register("TestPolicySetting", DeviceSetting, StringValue)
DefinitionOf("TestPolicySetting")
func() {
defer func() {
if gotPanic, wantPanic := recover(), "policy definitions are already in use"; gotPanic != wantPanic {
t.Errorf("gotPanic: %q, wantPanic: %q", gotPanic, wantPanic)
}
}()
Register("TestPolicySetting", DeviceSetting, StringValue)
}()
}
func TestRegisterDuplicateSettings(t *testing.T) {
tests := []struct {
name string
settings []*Definition
wantEq *Definition
wantErrStr string
}{
{
name: "NoConflict/Exact",
settings: []*Definition{
NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
},
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
},
{
name: "NoConflict/MergeOS-First",
settings: []*Definition{
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "android", "macos"),
NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
},
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
},
{
name: "NoConflict/MergeOS-Second",
settings: []*Definition{
NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "android", "macos"),
},
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
},
{
name: "NoConflict/MergeOS-Both",
settings: []*Definition{
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "macos"),
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "windows"),
},
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "macos", "windows"),
},
{
name: "Conflict/Scope",
settings: []*Definition{
NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
NewDefinition("TestPolicySetting", UserSetting, StringValue),
},
wantEq: nil,
wantErrStr: `duplicate policy definition: "TestPolicySetting"`,
},
{
name: "Conflict/Type",
settings: []*Definition{
NewDefinition("TestPolicySetting", UserSetting, StringValue),
NewDefinition("TestPolicySetting", UserSetting, IntegerValue),
},
wantEq: nil,
wantErrStr: `duplicate policy definition: "TestPolicySetting"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resetSettingDefinitions(t)
for _, s := range tt.settings {
Register(s.Key(), s.Scope(), s.Type(), s.SupportedPlatforms()...)
}
got, err := DefinitionOf("TestPolicySetting")
var gotErrStr string
if err != nil {
gotErrStr = err.Error()
}
if gotErrStr != tt.wantErrStr {
t.Fatalf("ErrStr: got %q, want %q", gotErrStr, tt.wantErrStr)
}
if !got.Equal(tt.wantEq) {
t.Errorf("Definition got %v, want %v", got, tt.wantEq)
}
if !slices.Equal(got.SupportedPlatforms(), tt.wantEq.SupportedPlatforms()) {
t.Errorf("SupportedPlatforms got %v, want %v", got.SupportedPlatforms(), tt.wantEq.SupportedPlatforms())
}
})
}
}
func TestListSettingDefinitions(t *testing.T) {
definitions := []*Definition{
NewDefinition("TestDevicePolicySetting", DeviceSetting, IntegerValue),
NewDefinition("TestProfilePolicySetting", ProfileSetting, StringValue),
NewDefinition("TestUserPolicySetting", UserSetting, BooleanValue),
NewDefinition("TestStringListPolicySetting", DeviceSetting, StringListValue),
}
if err := SetDefinitionsForTest(t, definitions...); err != nil {
t.Fatalf("SetDefinitionsForTest failed: %v", err)
}
cmp := func(l, r *Definition) int {
return strings.Compare(string(l.Key()), string(r.Key()))
}
want := append([]*Definition{}, definitions...)
slices.SortFunc(want, cmp)
got, err := Definitions()
if err != nil {
t.Fatalf("Definitions failed: %v", err)
}
slices.SortFunc(got, cmp)
if !slices.Equal(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}
func resetSettingDefinitions(t *testing.T) {
t.Cleanup(func() {
definitionsMu.Lock()
definitionsList = nil
definitions = lazy.SyncValue[DefinitionMap]{}
definitionsUsed = false
definitionsMu.Unlock()
})
definitionsMu.Lock()
definitionsList = nil
definitions = lazy.SyncValue[DefinitionMap]{}
definitionsUsed = false
definitionsMu.Unlock()
}

View File

@ -0,0 +1,153 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
import (
xmaps "golang.org/x/exp/maps"
"tailscale.com/util/deephash"
)
// Snapshot is an immutable collection of [RawItem]s, representing
// a set of policy settings applied at a specific moment in time.
// A nil pointer to [Snapshot] is valid.
type Snapshot struct {
m map[Key]RawItem
sig deephash.Sum // of m
summary Summary
}
// NewSnapshot returns a new [Snapshot] with the specified items and options.
func NewSnapshot(items map[Key]RawItem, opts ...SummaryOption) *Snapshot {
return &Snapshot{m: items, sig: deephash.Hash(&items), summary: SummaryWith(opts...)}
}
type keyItemPair struct {
Key Key
Item RawItem
}
// All returns an iterator over [[Key], [RawItem]] key-value pairs in b. The
// iteration order is not specified and is not guaranteed to be the same from
// one call to the next.
func (s *Snapshot) All() []keyItemPair {
if s == nil {
return nil
}
// TODO(nickkhyl): return iter.Seq2[[Key], [RawItem]] in Go 1.23,
// and remove [keyItemPair].
items := make([]keyItemPair, 0, len(s.m))
for k, i := range s.m {
items = append(items, keyItemPair{k, i})
}
return items
}
// Get returns the value of the policy setting with the specified key
// or nil if it does not exist or could not be read.
func (s *Snapshot) Get(k Key) any {
v, _ := s.GetErr(k)
return v
}
// GetErr returns the value of the policy setting with the specified key,
// [ErrNotConfigured] if it does not exist, or an error returned by
// the policy Store if the policy setting could not be read.
func (s *Snapshot) GetErr(k Key) (any, error) {
if s != nil {
if s, ok := s.m[k]; ok {
return s.Value(), s.Error()
}
}
return nil, ErrNotConfigured
}
// GetSetting returns the untyped policy setting with the specified key and true
// if a policy setting with such key has been configured;
// otherwise, it returns zero, false.
func (s *Snapshot) GetSetting(k Key) (setting RawItem, ok bool) {
setting, ok = s.m[k]
return setting, ok
}
// Equal reports whether s and s2 are equal.
func (s *Snapshot) Equal(s2 *Snapshot) bool {
if !s.EqualItems(s2) {
return false
}
return s.Summary() == s2.Summary()
}
// EqualItems reports whether items in s and s2 are equal.
func (s *Snapshot) EqualItems(s2 *Snapshot) bool {
if s == s2 {
return true
}
if s.Len() != s2.Len() {
return false
}
if s.Len() == 0 {
return true
}
return s.sig == s2.sig
}
// Keys return an iterator over keys in s. The iteration order is not specified
// and is not guaranteed to be the same from one call to the next.
func (s *Snapshot) Keys() []Key {
if s.m == nil {
return nil
}
// TODO(nickkhyl): return iter.Seq[Key] in Go 1.23.
return xmaps.Keys(s.m)
}
// Len reports the number of [RawItem]s in s.
func (s *Snapshot) Len() int {
if s == nil {
return 0
}
return len(s.m)
}
// Summary returns information about s as a whole rather than about specific [RawItem]s in it.
func (s *Snapshot) Summary() Summary {
if s == nil {
return Summary{}
}
return s.summary
}
// MergeSnapshots returns a [Snapshot] that contains all [RawItem]s
// from snapshot1 and snapshot2 and the [Summary] with the narrower [PolicyScope].
// If there's a conflict between policy settings in the two snapshots,
// the policy settings from the snapshot with the broader scope take precedence.
// In other words, policy settings configured for the [DeviceScope] win
// over policy settings configured for a user scope.
func MergeSnapshots(snapshot1, snapshot2 *Snapshot) *Snapshot {
scope1, ok1 := snapshot1.Summary().Scope().GetOk()
scope2, ok2 := snapshot2.Summary().Scope().GetOk()
if ok1 && ok2 && scope2.IsStrictlyWithinOf(scope1) {
// Swap snapshots if snapshot1 has higher precedence than snapshot2.
snapshot1, snapshot2 = snapshot2, snapshot1
}
if snapshot2.Len() == 0 {
return snapshot1
}
summaryOpts := make([]SummaryOption, 0, 2)
if scope, ok := snapshot1.Summary().Scope().GetOk(); ok {
// Use the scope from snapshot1, if present, which is the more specific snapshot.
summaryOpts = append(summaryOpts, scope)
}
if snapshot1.Len() == 0 {
if origin, ok := snapshot2.Summary().Origin().GetOk(); ok {
// Use the origin from snapshot2 if snapshot1 is empty.
summaryOpts = append(summaryOpts, origin)
}
return &Snapshot{snapshot2.m, snapshot2.sig, SummaryWith(summaryOpts...)}
}
m := make(map[Key]RawItem, snapshot1.Len()+snapshot2.Len())
xmaps.Copy(m, snapshot1.m)
xmaps.Copy(m, snapshot2.m) // snapshot2 has higher precedence
return &Snapshot{m, deephash.Hash(&m), SummaryWith(summaryOpts...)}
}

View File

@ -0,0 +1,372 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
import (
"testing"
"time"
)
func TestMergeSnapshots(t *testing.T) {
tests := []struct {
name string
s1, s2 *Snapshot
want *Snapshot
}{
{
name: "both-nil",
s1: nil,
s2: nil,
want: NewSnapshot(map[Key]RawItem{}),
},
{
name: "both-empty",
s1: NewSnapshot(map[Key]RawItem{}),
s2: NewSnapshot(map[Key]RawItem{}),
want: NewSnapshot(map[Key]RawItem{}),
},
{
name: "first-nil",
s1: nil,
s2: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
want: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
},
{
name: "first-empty",
s1: NewSnapshot(map[Key]RawItem{}),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
want: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
},
{
name: "second-nil",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
s2: nil,
want: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
},
{
name: "second-empty",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
s2: NewSnapshot(map[Key]RawItem{}),
want: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
},
{
name: "no-conflicts",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
s2: NewSnapshot(map[Key]RawItem{
"Setting4": {value: 2 * time.Hour},
"Setting5": {value: VisibleByPolicy},
"Setting6": {value: ShowChoiceByPolicy},
}),
want: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
"Setting5": {value: VisibleByPolicy},
"Setting6": {value: ShowChoiceByPolicy},
}),
},
{
name: "with-conflicts",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 456},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
}),
want: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 456},
"Setting2": {value: "String"},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
}),
},
{
name: "with-scope-first-wins",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}, DeviceScope),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 456},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
}, CurrentUserScope),
want: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
"Setting4": {value: 2 * time.Hour},
}, CurrentUserScope),
},
{
name: "with-scope-second-wins",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}, CurrentUserScope),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 456},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
}, DeviceScope),
want: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 456},
"Setting2": {value: "String"},
"Setting3": {value: false},
"Setting4": {value: 2 * time.Hour},
}, CurrentUserScope),
},
{
name: "with-scope-both-empty",
s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
s2: NewSnapshot(map[Key]RawItem{}, DeviceScope),
want: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
},
{
name: "with-scope-first-empty",
s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
s2: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true}}, DeviceScope),
want: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}, CurrentUserScope),
},
{
name: "with-scope-second-empty",
s1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}, CurrentUserScope),
s2: NewSnapshot(map[Key]RawItem{}, DeviceScope),
want: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}, CurrentUserScope),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MergeSnapshots(tt.s1, tt.s2)
if !got.Equal(tt.want) {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func TestSnapshotEqual(t *testing.T) {
tests := []struct {
name string
b1, b2 *Snapshot
wantEqual bool
wantEqualItems bool
}{
{
name: "nil-nil",
b1: nil,
b2: nil,
wantEqual: true,
wantEqualItems: true,
},
{
name: "nil-empty",
b1: nil,
b2: NewSnapshot(map[Key]RawItem{}),
wantEqual: true,
wantEqualItems: true,
},
{
name: "empty-nil",
b1: NewSnapshot(map[Key]RawItem{}),
b2: nil,
wantEqual: true,
wantEqualItems: true,
},
{
name: "empty-empty",
b1: NewSnapshot(map[Key]RawItem{}),
b2: NewSnapshot(map[Key]RawItem{}),
wantEqual: true,
wantEqualItems: true,
},
{
name: "first-nil",
b1: nil,
b2: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
wantEqual: false,
wantEqualItems: false,
},
{
name: "first-empty",
b1: NewSnapshot(map[Key]RawItem{}),
b2: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
wantEqual: false,
wantEqualItems: false,
},
{
name: "second-nil",
b1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: true},
}),
b2: nil,
wantEqual: false,
wantEqualItems: false,
},
{
name: "second-empty",
b1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
b2: NewSnapshot(map[Key]RawItem{}),
wantEqual: false,
wantEqualItems: false,
},
{
name: "same-items-same-order-no-scope",
b1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
b2: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}),
wantEqual: true,
wantEqualItems: true,
},
{
name: "same-items-same-order-same-scope",
b1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, DeviceScope),
b2: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, DeviceScope),
wantEqual: true,
wantEqualItems: true,
},
{
name: "same-items-different-order-same-scope",
b1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, DeviceScope),
b2: NewSnapshot(map[Key]RawItem{
"Setting3": {value: false},
"Setting1": {value: 123},
"Setting2": {value: "String"},
}, DeviceScope),
wantEqual: true,
wantEqualItems: true,
},
{
name: "same-items-same-order-different-scope",
b1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, DeviceScope),
b2: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, CurrentUserScope),
wantEqual: false,
wantEqualItems: true,
},
{
name: "different-items-same-scope",
b1: NewSnapshot(map[Key]RawItem{
"Setting1": {value: 123},
"Setting2": {value: "String"},
"Setting3": {value: false},
}, DeviceScope),
b2: NewSnapshot(map[Key]RawItem{
"Setting4": {value: 2 * time.Hour},
"Setting5": {value: VisibleByPolicy},
"Setting6": {value: ShowChoiceByPolicy},
}, DeviceScope),
wantEqual: false,
wantEqualItems: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotEqual := tt.b1.Equal(tt.b2); gotEqual != tt.wantEqual {
t.Errorf("WantEqual: got %v, want %v", gotEqual, tt.wantEqual)
}
if gotEqualItems := tt.b1.EqualItems(tt.b2); gotEqualItems != tt.wantEqualItems {
t.Errorf("WantEqualItems: got %v, want %v", gotEqualItems, tt.wantEqualItems)
}
})
}
}

View File

@ -0,0 +1,84 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
import (
jsonv2 "github.com/go-json-experiment/json"
"github.com/go-json-experiment/json/jsontext"
"tailscale.com/types/opt"
)
// Summary is an immutable [PolicyScope] and [Origin].
type Summary struct {
data summary
}
type summary struct {
Scope opt.Value[PolicyScope] `json:",omitzero"`
Origin opt.Value[Origin] `json:",omitzero"`
}
// SummaryWith returns a [Summary] with the specified options.
func SummaryWith(opts ...SummaryOption) Summary {
var summary Summary
for _, o := range opts {
o.applySummaryOption(&summary)
}
return summary
}
// Scope reports the [PolicyScope] in s.
func (s Summary) Scope() opt.Value[PolicyScope] {
return s.data.Scope
}
// Origin reports the [Origin] in s.
func (s Summary) Origin() opt.Value[Origin] {
return s.data.Origin
}
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
func (s Summary) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
return jsonv2.MarshalEncode(out, &s.data, opts)
}
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
func (s *Summary) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
return jsonv2.UnmarshalDecode(in, &s.data, opts)
}
// MarshalJSON implements [json.Marshaler].
func (s Summary) MarshalJSON() ([]byte, error) {
return jsonv2.Marshal(s) // uses MarshalJSONV2
}
// UnmarshalJSON implements [json.Unmarshaler].
func (s *Summary) UnmarshalJSON(b []byte) error {
return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2
}
// SummaryOption is an option that configures [Summary]
// The following are allowed options:
//
// - [Summary]
// - [PolicyScope]
// - [Origin]
type SummaryOption interface {
applySummaryOption(summary *Summary)
}
func (s PolicyScope) applySummaryOption(summary *Summary) {
summary.data.Scope.Set(s)
}
func (o Origin) applySummaryOption(summary *Summary) {
summary.data.Origin.Set(o)
if !summary.data.Scope.IsSet() {
summary.data.Scope.Set(o.Scope())
}
}
func (s Summary) applySummaryOption(summary *Summary) {
*summary = s
}

View File

@ -0,0 +1,132 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package setting
import (
"encoding"
)
// PreferenceOption is a policy that governs whether a boolean variable
// is forcibly assigned an administrator-defined value, or allowed to receive
// a user-defined value.
type PreferenceOption int
const (
ShowChoiceByPolicy PreferenceOption = iota
NeverByPolicy
AlwaysByPolicy
)
// Show returns if the UI option that controls the choice administered by this
// policy should be shown. Currently this is true if and only if the policy is
// [ShowChoiceByPolicy].
func (p PreferenceOption) Show() bool {
return p == ShowChoiceByPolicy
}
// ShouldEnable checks if the choice administered by this policy should be
// enabled. If the administrator has chosen a setting, the administrator's
// setting is returned, otherwise userChoice is returned.
func (p PreferenceOption) ShouldEnable(userChoice bool) bool {
switch p {
case NeverByPolicy:
return false
case AlwaysByPolicy:
return true
default:
return userChoice
}
}
// IsAlways reports whether the preference should always be enabled.
func (p PreferenceOption) IsAlways() bool {
return p == AlwaysByPolicy
}
// IsNever reports whether the preference should always be disabled.
func (p PreferenceOption) IsNever() bool {
return p == NeverByPolicy
}
// WillOverride checks if the choice administered by the policy is different
// from the user's choice.
func (p PreferenceOption) WillOverride(userChoice bool) bool {
return p.ShouldEnable(userChoice) != userChoice
}
// String returns a string representation of p.
func (p PreferenceOption) String() string {
switch p {
case AlwaysByPolicy:
return "always"
case NeverByPolicy:
return "never"
default:
return "user-decides"
}
}
// MarshalText implements [encoding.TextMarshaler].
func (p *PreferenceOption) MarshalText() (text []byte, err error) {
return []byte(p.String()), nil
}
// UnmarshalText implements [encoding.TextUnmarshaler].
func (p *PreferenceOption) UnmarshalText(text []byte) error {
switch string(text) {
case "always":
*p = AlwaysByPolicy
case "never":
*p = NeverByPolicy
default:
*p = ShowChoiceByPolicy
}
return nil
}
// Visibility is a policy that controls whether or not a particular
// component of a user interface is to be shown.
type Visibility byte
var (
_ encoding.TextMarshaler = (*Visibility)(nil)
_ encoding.TextUnmarshaler = (*Visibility)(nil)
)
const (
VisibleByPolicy Visibility = 'v'
HiddenByPolicy Visibility = 'h'
)
// Show reports whether the UI option administered by this policy should be shown.
// Currently this is true if the policy is not [hiddenByPolicy].
func (p Visibility) Show() bool {
return p != HiddenByPolicy
}
// String returns a string representation of p.
func (p Visibility) String() string {
switch p {
case 'h':
return "hide"
default:
return "show"
}
}
// MarshalText implements [encoding.TextMarshaler].
func (p Visibility) MarshalText() (text []byte, err error) {
return []byte(p.String()), nil
}
// UnmarshalText implements [encoding.TextUnmarshaler].
func (p *Visibility) UnmarshalText(text []byte) error {
switch string(text) {
case "hide":
*p = HiddenByPolicy
default:
*p = VisibleByPolicy
}
return nil
}

View File

@ -0,0 +1,393 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"errors"
"fmt"
"io"
"slices"
"sort"
"sync"
"time"
"tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy/internal/loggerx"
"tailscale.com/util/syspolicy/internal/metrics"
"tailscale.com/util/syspolicy/setting"
)
// Reader reads all configured policy settings from a given [Store].
// It registers a change callback with the [Store] and maintains the current version
// of the [setting.Snapshot] by lazily re-reading policy settings from the [Store]
// whenever a new snapshot is requested
// It is safe for concurrent use.
type Reader struct {
store Store
origin *setting.Origin
settings []*setting.Definition
unregisterChangeNotifier func()
doneCh chan struct{} // closed when policyCache is closed.
mu sync.RWMutex
closing bool
upToDate bool
lastPolicy *setting.Snapshot
sessions set.HandleSet[*ReadingSession]
}
// newReader returns a new [Reader] that reads policy settings from a given [Store].
// The returned reader takes ownership of the store. If the store implements [io.Closer],
// the returned reader will close the store when it is closed.
func newReader(store Store, origin *setting.Origin) (*Reader, error) {
settings, err := setting.Definitions()
if err != nil {
return nil, err
}
if expirable, ok := store.(Expirable); ok {
select {
case <-expirable.Done():
return nil, ErrStoreClosed
default:
}
}
reader := &Reader{store: store, origin: origin, settings: settings, doneCh: make(chan struct{})}
if changeable, ok := store.(Changeable); ok {
// We should subscribe to policy change notifications first before reading
// the policy settings from the store. This way we won't miss any notifications.
if reader.unregisterChangeNotifier, err = changeable.RegisterChangeCallback(reader.onPolicyChange); err != nil {
// Errors registering policy change callbacks are non-fatal.
// TODO(nickkhyl): implement a background policy refresh every X minutes?
loggerx.Errorf("failed to register %v policy change callback: %v\n", origin, err)
}
}
if _, err := reader.reload(true); err != nil {
if reader.unregisterChangeNotifier != nil {
reader.unregisterChangeNotifier()
}
return nil, err
}
if expirable, ok := store.(Expirable); ok {
if waitCh := expirable.Done(); waitCh != nil {
go func() {
select {
case <-waitCh:
reader.Close()
case <-reader.doneCh:
}
}()
}
}
return reader, nil
}
// GetSettings returns the current [*setting.Snapshot],
// re-reading it from from the underlying [Store] only if the policy
// has changed since it was read last. It never fails and returns
// the previous version of the policy settings if a read attempt fails.
func (r *Reader) GetSettings() *setting.Snapshot {
r.mu.RLock()
if r.upToDate {
r.mu.RUnlock()
return r.lastPolicy
}
r.mu.RUnlock()
policy, err := r.reload(false)
if err != nil {
// If the policy could not be reloaded at all, we'll return the last cached version of it.
// On the contrary, errors specific to individual policy items are always propagated to the callers.
loggerx.Errorf("failed to reload %v policy: %v\n", r.origin, err)
}
return policy
}
// ReadSettings reads policy settings from the underlying [Store] even if no
// changes were detected. It returns the new [*setting.Snapshot], nil on
// success, or nil, error in case of failure.
func (r *Reader) ReadSettings() (*setting.Snapshot, error) {
b, err := r.reload(true)
if err != nil {
return nil, err
}
return b, nil
}
// reload is like [Reader.ReadSettings], but allows specifying whether to re-read
// an unchanged policy, and returns the last [*setting.Snapshot] if the read fails.
func (r *Reader) reload(force bool) (*setting.Snapshot, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.upToDate && !force {
return r.lastPolicy, nil
}
if lockable, ok := r.store.(Lockable); ok {
if err := lockable.Lock(); err != nil {
return r.lastPolicy, err
}
defer lockable.Unlock()
}
r.upToDate = true
metrics.Reset(r.origin)
var m map[setting.Key]setting.RawItem
if lastPolicyCount := r.lastPolicy.Len(); lastPolicyCount > 0 {
m = make(map[setting.Key]setting.RawItem, lastPolicyCount)
}
for _, s := range r.settings {
if !r.origin.Scope().IsConfigurableSetting(s) {
// Skip settings that cannot be configured in the current scope.
continue
}
val, err := readPolicySettingValue(r.store, s)
if err != nil && (errors.Is(err, setting.ErrNoSuchKey) || errors.Is(err, setting.ErrNotConfigured)) {
metrics.ReportNotConfigured(r.origin, s)
continue
}
if err == nil {
metrics.ReportConfigured(r.origin, s, val)
} else {
metrics.ReportError(r.origin, s, err)
}
// If there's an error reading a single policy, such as a value type mismatch,
// we'll wrap the error to preserve its text and return it
// whenever someone attempts to fetch the value.
mak.Set(&m, s.Key(), setting.RawItemWith(val, setting.WrapError(err), r.origin))
}
newPolicy := setting.NewSnapshot(m, setting.SummaryWith(r.origin))
if r.lastPolicy == nil || !newPolicy.EqualItems(r.lastPolicy) {
r.lastPolicy = newPolicy
}
return r.lastPolicy, nil
}
// ReadingSession is like [Reader], but with a channel that's written
// to when there's a policy change, and closed when the session is terminated.
type ReadingSession struct {
reader *Reader
policyChangedCh chan struct{} // 1-buffered channel
handle set.Handle // in the reader.sessions
closeInternal func()
}
// OpenSession opens and returns a new session to r, allowing the caller
// to get notified whenever a policy change is reported by the [source.Store],
// or an [ErrStoreClosed] if the reader has already been closed.
func (r *Reader) OpenSession() (*ReadingSession, error) {
session := &ReadingSession{
reader: r,
policyChangedCh: make(chan struct{}, 1),
}
session.closeInternal = sync.OnceFunc(func() { close(session.policyChangedCh) })
r.mu.Lock()
if !r.closing {
session.handle = r.sessions.Add(session)
r.mu.Unlock()
return session, nil
}
r.mu.Unlock()
return nil, ErrStoreClosed
}
// GetSettings is like [Reader.GetSettings].
func (s *ReadingSession) GetSettings() *setting.Snapshot {
return s.reader.GetSettings()
}
// ReadSettings is like [Reader.ReadSettings].
func (s *ReadingSession) ReadSettings() (*setting.Snapshot, error) {
return s.reader.ReadSettings()
}
// PolicyChanged returns a channel that's written to when
// there's a policy change, closed when the session is terminated.
func (s *ReadingSession) PolicyChanged() <-chan struct{} {
return s.policyChangedCh
}
// Close unregisters this session with the [Reader].
func (s *ReadingSession) Close() {
s.reader.mu.Lock()
delete(s.reader.sessions, s.handle)
s.closeInternal()
s.reader.mu.Unlock()
}
// onPolicyChange handles a policy change notification from the [Store],
// invalidating the current [setting.Snapshot] in r,
// and notifying the active [ReadingSession]s.
func (r *Reader) onPolicyChange() {
r.mu.Lock()
defer r.mu.Unlock()
r.upToDate = false
for _, s := range r.sessions {
select {
case s.policyChangedCh <- struct{}{}:
// Notified.
default:
// 1-buffered channel is full, meaning that another policy change
// notification is already en route.
}
}
}
// Close closes the store reader and the underlying store.
func (r *Reader) Close() error {
r.mu.Lock()
if r.closing {
r.mu.Unlock()
return nil
}
r.closing = true
r.mu.Unlock()
if r.unregisterChangeNotifier != nil {
r.unregisterChangeNotifier()
r.unregisterChangeNotifier = nil
}
if closer, ok := r.store.(io.Closer); ok {
if err := closer.Close(); err != nil {
return err
}
}
r.store = nil
close(r.doneCh)
r.mu.Lock()
defer r.mu.Unlock()
for _, c := range r.sessions {
c.closeInternal()
}
r.sessions = nil
return nil
}
// Done returns a channel that is closed when the reader is closed.
func (r *Reader) Done() <-chan struct{} {
return r.doneCh
}
// ReadableSource is a [Source] open for reading.
type ReadableSource struct {
*Source
*ReadingSession
}
// Close closes the underlying [ReadingSession].
func (s ReadableSource) Close() {
s.ReadingSession.Close()
}
// ReadableSources is a slice of [ReadableSource].
type ReadableSources []ReadableSource
// Contains reports whether s contains the specified source.
func (s ReadableSources) Contains(source *Source) bool {
return s.IndexOf(source) != -1
}
// IndexOf returns position of the specified source in s, or -1
// if the source does not exist.
func (s ReadableSources) IndexOf(source *Source) int {
return slices.IndexFunc(s, func(rs ReadableSource) bool {
return rs.Source == source
})
}
// InsertionIndexOf returns the position at which source can be inserted
// to maintain the sorted order of the readableSources.
// The return value is unspecified if s is not sorted on entry to InsertionIndexOf.
func (s ReadableSources) InsertionIndexOf(source *Source) int {
low, high := 0, len(s)
for low < high {
mid := (low + high) / 2
if s[mid].Compare(source) <= 0 {
low = mid + 1
} else {
high = mid
}
}
return low
}
// StableSort sorts the readableSources by the precedence, so that policy settings
// from sources with higher precedence (e.g., [DeviceScope]) will be merged last,
// overriding any policy settings with the same keys configured in sources with
// lower precedence (e.g., [CurrentUserScope]).
func (s *ReadableSources) StableSort() {
sort.SliceStable(*s, func(i, j int) bool {
return (*s)[i].Source.Compare((*s)[j].Source) < 0
})
}
// DeleteAt closes and deletes the i-th source from s.
func (s *ReadableSources) DeleteAt(i int) {
(*s)[i].Close()
*s = slices.Delete(*s, i, i+1)
}
// Close closes and deletes all sources in s.
func (s *ReadableSources) Close() {
for _, s := range *s {
s.Close()
}
*s = nil
}
func readPolicySettingValue(store Store, s *setting.Definition) (value any, err error) {
switch key := s.Key(); s.Type() {
case setting.BooleanValue:
return store.ReadBoolean(key)
case setting.IntegerValue:
return store.ReadUInt64(key)
case setting.StringValue:
return store.ReadString(key)
case setting.StringListValue:
return store.ReadStringArray(key)
case setting.PreferenceOptionValue:
s, err := store.ReadString(key)
if err == nil {
var value setting.PreferenceOption
if err = value.UnmarshalText([]byte(s)); err == nil {
return value, nil
}
}
return setting.ShowChoiceByPolicy, err
case setting.VisibilityValue:
s, err := store.ReadString(key)
if err == nil {
var value setting.Visibility
if err = value.UnmarshalText([]byte(s)); err == nil {
return value, nil
}
}
return setting.VisibleByPolicy, err
case setting.DurationValue:
s, err := store.ReadString(key)
if err == nil {
var value time.Duration
if value, err = time.ParseDuration(s); err == nil {
return value, nil
}
}
return nil, err
default:
return nil, fmt.Errorf("%w: unsupported setting type: %v", setting.ErrTypeMismatch, s.Type())
}
}

View File

@ -0,0 +1,291 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"cmp"
"testing"
"time"
"tailscale.com/util/must"
"tailscale.com/util/syspolicy/setting"
)
func TestReaderLifecycle(t *testing.T) {
tests := []struct {
name string
origin *setting.Origin
definitions []*setting.Definition
wantReads []TestExpectedReads
initStrings []TestSetting[string]
initUInt64s []TestSetting[uint64]
initWant *setting.Snapshot
addStrings []TestSetting[string]
addStringLists []TestSetting[[]string]
newWant *setting.Snapshot
}{
{
name: "read-all-settings-once",
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
definitions: []*setting.Definition{
setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
},
wantReads: []TestExpectedReads{
{Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
{Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
{Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
{Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
},
initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
},
{
name: "re-read-all-settings-when-the-policy-changes",
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
definitions: []*setting.Definition{
setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
},
wantReads: []TestExpectedReads{
{Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
{Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
{Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
{Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
},
initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
addStrings: []TestSetting[string]{TestSettingOf("StringValue", "S1")},
addStringLists: []TestSetting[[]string]{TestSettingOf("StringListValue", []string{"S1", "S2", "S3"})},
newWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"StringValue": setting.RawItemWith("S1", nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
"StringListValue": setting.RawItemWith([]string{"S1", "S2", "S3"}, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
}, setting.NewNamedOrigin("Test", setting.DeviceScope)),
},
{
name: "read-settings-if-in-scope/device",
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
definitions: []*setting.Definition{
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
},
wantReads: []TestExpectedReads{
{Key: "DeviceSetting", Type: setting.StringValue, NumTimes: 1},
{Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
},
},
{
name: "read-settings-if-in-scope/profile",
origin: setting.NewNamedOrigin("Test", setting.CurrentProfileScope),
definitions: []*setting.Definition{
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
},
wantReads: []TestExpectedReads{
// Device settings cannot be configured at the profile scope and should not be read.
{Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
},
},
{
name: "read-settings-if-in-scope/user",
origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
definitions: []*setting.Definition{
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
},
wantReads: []TestExpectedReads{
// Device and profile settings cannot be configured at the profile scope and should not be read.
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
},
},
{
name: "read-stringy-settings",
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
definitions: []*setting.Definition{
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
},
wantReads: []TestExpectedReads{
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
},
initStrings: []TestSetting[string]{
TestSettingOf("DurationValue", "2h30m"),
TestSettingOf("PreferenceOptionValue", "always"),
TestSettingOf("VisibilityValue", "show"),
},
initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"DurationValue": setting.RawItemWith(must.Get(time.ParseDuration("2h30m")), nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
"PreferenceOptionValue": setting.RawItemWith(setting.AlwaysByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
"VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
}, setting.NewNamedOrigin("Test", setting.DeviceScope)),
},
{
name: "read-erroneous-stringy-settings",
origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
definitions: []*setting.Definition{
setting.NewDefinition("DurationValue1", setting.UserSetting, setting.DurationValue),
setting.NewDefinition("DurationValue2", setting.UserSetting, setting.DurationValue),
setting.NewDefinition("PreferenceOptionValue", setting.UserSetting, setting.PreferenceOptionValue),
setting.NewDefinition("VisibilityValue", setting.UserSetting, setting.VisibilityValue),
},
wantReads: []TestExpectedReads{
{Key: "DurationValue1", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
{Key: "DurationValue2", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
},
initStrings: []TestSetting[string]{
TestSettingOf("DurationValue1", "soon"),
TestSettingWithError[string]("DurationValue2", setting.NewError("bang!")),
TestSettingOf("PreferenceOptionValue", "sometimes"),
},
initUInt64s: []TestSetting[uint64]{
TestSettingOf[uint64]("VisibilityValue", 42), // type mismatch
},
initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
"DurationValue1": setting.RawItemWith(nil, setting.NewError("time: invalid duration \"soon\""), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
"DurationValue2": setting.RawItemWith(nil, setting.NewError("bang!"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
"PreferenceOptionValue": setting.RawItemWith(setting.ShowChoiceByPolicy, nil, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
"VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, setting.NewError("type mismatch in ReadString: got uint64"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
}, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setting.SetDefinitionsForTest(t, tt.definitions...)
store := NewTestStore(t)
store.SetStrings(tt.initStrings...)
store.SetUInt64s(tt.initUInt64s...)
reader, err := newReader(store, tt.origin)
if err != nil {
t.Fatalf("newReader failed: %v", err)
}
if got := reader.GetSettings(); tt.initWant != nil && !got.Equal(tt.initWant) {
t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
}
if tt.wantReads != nil {
store.ReadsMustEqual(tt.wantReads...)
}
// Should not result in new reads as there were no changes.
N := 100
for range N {
reader.GetSettings()
}
if tt.wantReads != nil {
store.ReadsMustEqual(tt.wantReads...)
}
store.ResetCounters()
got, err := reader.ReadSettings()
if err != nil {
t.Fatalf("ReadSettings failed: %v", err)
}
if tt.initWant != nil && !got.Equal(tt.initWant) {
t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
}
if tt.wantReads != nil {
store.ReadsMustEqual(tt.wantReads...)
}
store.ResetCounters()
if len(tt.addStrings) != 0 || len(tt.addStringLists) != 0 {
store.SetStrings(tt.addStrings...)
store.SetStringLists(tt.addStringLists...)
// As the settings have changed, GetSettings needs to re-read them.
if got, want := reader.GetSettings(), cmp.Or(tt.newWant, tt.initWant); !got.Equal(want) {
t.Errorf("New Settings do not match: got %v, want %v", got, want)
}
if tt.wantReads != nil {
store.ReadsMustEqual(tt.wantReads...)
}
}
select {
case <-reader.Done():
t.Fatalf("the reader is closed")
default:
}
store.Close()
<-reader.Done()
})
}
}
func TestReadingSession(t *testing.T) {
setting.SetDefinitionsForTest(t, setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue))
store := NewTestStore(t)
origin := setting.NewOrigin(setting.DeviceScope)
reader, err := newReader(store, origin)
if err != nil {
t.Fatalf("newReader failed: %v", err)
}
session, err := reader.OpenSession()
if err != nil {
t.Fatalf("failed to open a reading session: %v", err)
}
t.Cleanup(session.Close)
if got, want := session.GetSettings(), setting.NewSnapshot(nil, origin); !got.Equal(want) {
t.Errorf("Settings do not match: got %v, want %v", got, want)
}
select {
case _, ok := <-session.PolicyChanged():
if ok {
t.Fatalf("the policy changed notification was sent prematurely")
} else {
t.Fatalf("the session was closed prematurely")
}
default:
}
store.SetStrings(TestSettingOf("StringValue", "S1"))
_, ok := <-session.PolicyChanged()
if !ok {
t.Fatalf("the session was closed prematurely")
}
want := setting.NewSnapshot(map[setting.Key]setting.RawItem{
"StringValue": setting.RawItemWith("S1", nil, origin),
}, origin)
if got := session.GetSettings(); !got.Equal(want) {
t.Errorf("Settings do not match: got %v, want %v", got, want)
}
store.Close()
if _, ok = <-session.PolicyChanged(); ok {
t.Fatalf("the session must be closed")
}
}

View File

@ -0,0 +1,146 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"cmp"
"errors"
"fmt"
"io"
"tailscale.com/types/lazy"
"tailscale.com/util/syspolicy/setting"
)
// ErrStoreClosed is an error returned when attempting to use a [Store] after it
// has been closed.
var ErrStoreClosed = errors.New("the policy store has been closed")
// Store provides methods to read system policy settings from OS-specific storage.
// Implementations must be concurrency-safe, and may also implement
// [Lockable], [Changeable], [Expirable] and [io.Closer].
//
// If a [Store] implementation also implements [io.Closer],
// it will be called by the package to release the resources
// when the store is no longer needed.
type Store interface {
// ReadString returns the value of a [setting.StringValue] with the specified key,
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
// an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
ReadString(key setting.Key) (string, error)
// ReadUInt64 returns the value of a [setting.IntegerValue] with the specified key,
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
// an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
ReadUInt64(key setting.Key) (uint64, error)
// ReadBoolean returns the value of a [setting.BooleanValue] with the specified key,
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
// an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
ReadBoolean(key setting.Key) (bool, error)
// ReadStringArray returns the value of a [setting.StringListValue] with the specified key,
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
// an [setting.ErrTypeMismatch] if the policy setting is not of a string list type.
ReadStringArray(key setting.Key) ([]string, error)
}
// Lockable is an optional interface that [Store] implementations may support.
// Locking a [Store] is not mandatory as [Store] must be concurrency-safe,
// but is recommended to avoid issues where consecutive read calls for related
// policies might return inconsistent results if a policy change occurs between
// the calls.
type Lockable interface {
// Lock acquires a read lock on the policy store,
// ensuring the store's state remains unchanged while locked.
// Multiple readers can hold the lock simultaneously.
// It should return nil if the store does not support locking,
// or an error if the store cannot be locked.
Lock() error
// Unlock unlocks the policy store.
// It is a runtime error if the store is not locked on entry to Unlock.
Unlock()
}
// Changeable is an optional interface that [Store] implementations may support.
type Changeable interface {
// RegisterChangeCallback adds a function that will be called
// whenever there's a policy change in the [Store].
// The returned function can be used to unregister the callback.
RegisterChangeCallback(callback func()) (unregister func(), err error)
}
// Expirable is an optional interface that [Store] implementations may support.
type Expirable interface {
// Done returns a channel that is closed when the policy [Store] should no longer be used.
// It should return nil if the store never expires.
Done() <-chan struct{}
}
// Source represents a named source of policy settings for a given scope.
type Source struct {
name string
scope setting.PolicyScope
store Store
origin *setting.Origin
lazyReader lazy.SyncValue[*Reader]
}
// NewSource returns a new [Source] with the specified name, scope, and store.
func NewSource(name string, scope setting.PolicyScope, store Store) *Source {
return &Source{name: name, scope: scope, store: store, origin: setting.NewNamedOrigin(name, scope)}
}
// Name reports the name of the policy source.
func (s *Source) Name() string {
return s.name
}
// Scope reports the management scope of the policy source.
func (s *Source) Scope() setting.PolicyScope {
return s.scope
}
// Store returns the [Store] that can be used to read policy settings from this source.
func (s *Source) Store() Store {
return s.store
}
// Reader returns a [Reader] that reads from this source's [Store].
func (s *Source) Reader() (*Reader, error) {
return s.lazyReader.GetErr(func() (*Reader, error) {
return newReader(s.store, s.origin)
})
}
// String implements [fmt.Stringer].
func (s *Source) String() string {
if s.Name() != "" {
return fmt.Sprintf("%s (%v)", s.Name(), s.Scope())
}
return s.Scope().String()
}
// Compare returns an integer comparing [Source] s and s2
// by their precedence, following the "last-wins" model.
// The result will be:
//
// -1 if policy settings from s should be processed before policy settings from s2;
// +1 if policy settings from s should be processed after policy settings from s2, overriding s2;
// 0 if the relative processing order of policy settings in s and s2 is unspecified.
func (s *Source) Compare(s2 *Source) int {
return cmp.Compare(s2.Scope().Kind(), s.Scope().Kind())
}
// Close closes the [Source] and the underlying [Store].
func (s *Source) Close() error {
// The [Reader], if any, owns the [Store].
if reader, _ := s.lazyReader.GetErr(func() (*Reader, error) { return nil, ErrStoreClosed }); reader != nil {
return reader.Close()
}
// Otherwise, it is our responsibility to close it.
if closer, ok := s.store.(io.Closer); ok {
return closer.Close()
}
return nil
}

View File

@ -0,0 +1,438 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"errors"
"fmt"
"strings"
"sync"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/winutil/gp"
)
const (
softwareKeyName = "Software"
tsPoliciesSubkey = `Policies\Tailscale`
tsIPNSubkey = "Tailscale IPN" // the legacy key we need to fallback to
)
var (
// [PlatformPolicyStore] implements [Store].
_ Store = (*PlatformPolicyStore)(nil)
)
// PlatformPolicyStore implements [Store] by providing read access to the Registry-based
// Tailscale policies, such as those configured via Group Policy or MDM. It is
// recommended to lock it when reading multiple policy values in a row. It also
// allows subscribing to notifications when there's a policy change.
type PlatformPolicyStore struct {
scope gp.Scope // [gp.MachinePolicy] or [gp.UserPolicy]
// The softwareKey can be HKLM\Software, HKCU\Software, or
// HKU\{SID}\Software. Anything below the Software subkey, including
// Software\Policies, may not yet exist or could be deleted throughout the
// [PlatformPolicyStore]'s lifespan, invalidating the handle. We also prefer
// to always use a real registry key (rather than a predefined HKLM or HKCU)
// to simplify bookkeeping (predefined keys should never be closed).
// Finally, this will allow us to watch for any registry changes directly
// should we need this in the future in addition to gp.ChangeWatcher.
softwareKey registry.Key
watcher *gp.ChangeWatcher
done chan struct{} // done is closed when Close call completes
// The policyLock can be locked by the caller when reading multiple policy settings
// to prevent the Group Policy Client service from modifying policies while
// they are being read.
//
// When both policyLock and mu need to be taken, mu must be taken before policyLock.
policyLock *gp.PolicyLock
mu sync.RWMutex
tsKeys []registry.Key // or nil if the [PlatformPolicyStore] hasn't been locked.
cbs set.HandleSet[func()] // policy change callbacks
lockCnt int
locked sync.WaitGroup
closing bool
readable bool
}
type registryValueGetter[T any] func(key registry.Key, name setting.Key) (T, error)
// NewMachinePlatformPolicyStore returns a new [PlatformPolicyStore] for the machine.
func NewMachinePlatformPolicyStore() (*PlatformPolicyStore, error) {
softwareKey, err := registry.OpenKey(registry.LOCAL_MACHINE, softwareKeyName, windows.KEY_READ)
if err != nil {
return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
}
return newPlatformPolicyStore(gp.MachinePolicy, softwareKey, 0)
}
// NewUserPlatformPolicyStore returns a new [PlatformPolicyStore] for the user specified by its token.
// User's profile must be loaded, and the token handle must have [windows.TOKEN_QUERY]
// access. The caller retains ownership of the token.
func NewUserPlatformPolicyStore(token windows.Token) (*PlatformPolicyStore, error) {
var err error
var softwareKey registry.Key
if token != 0 {
var user *windows.Tokenuser
if user, err = token.GetTokenUser(); err != nil {
return nil, fmt.Errorf("failed to get token user: %w", err)
}
userSid := user.User.Sid
softwareKey, err = registry.OpenKey(registry.USERS, userSid.String()+`\`+softwareKeyName, windows.KEY_READ)
} else {
softwareKey, err = registry.OpenKey(registry.CURRENT_USER, softwareKeyName, windows.KEY_READ)
}
if err != nil {
return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
}
return newPlatformPolicyStore(gp.UserPolicy, softwareKey, token)
}
func newPlatformPolicyStore(scope gp.Scope, softwareKey registry.Key, token windows.Token) (_ *PlatformPolicyStore, err error) {
store := &PlatformPolicyStore{
scope: scope,
softwareKey: softwareKey,
done: make(chan struct{}),
readable: true,
}
defer func() {
if err != nil {
store.Close()
}
}()
switch scope {
case gp.MachinePolicy:
store.policyLock = gp.NewMachinePolicyLock()
case gp.UserPolicy:
if store.policyLock, err = gp.NewUserPolicyLock(token); err != nil {
return nil, fmt.Errorf("failed to create a user policy lock: %w", err)
}
default:
panic("unreachable")
}
return store, nil
}
// Lock locks the policy store, preventing the system from modifying the policies
// while they are being read. It is a read lock that may be acquired by multiple goroutines.
// Each Lock call must be balanced by exactly one Unlock call.
func (ps *PlatformPolicyStore) Lock() (err error) {
ps.mu.Lock()
defer ps.mu.Unlock()
if ps.closing {
return ErrStoreClosed
}
ps.lockCnt += 1
if ps.lockCnt != 1 {
return nil
}
defer func() {
if err != nil {
ps.lockCnt -= 1
}
}()
// Ensure ps remains open while the lock is held.
ps.locked.Add(1)
defer func() {
if err != nil {
ps.locked.Done()
}
}()
// Acquire the GP lock to prevent the system from modifying policy settings
// while they are being read.
if err := ps.policyLock.Lock(); err != nil {
if errors.Is(err, gp.ErrInvalidLockState) {
return ErrStoreClosed
}
return err
}
defer func() {
if err != nil {
ps.policyLock.Unlock()
}
}()
// Keep the Tailscale's registry keys open for the duration of the lock.
keyNames := tailscaleKeyNamesFor(ps.scope)
ps.tsKeys = make([]registry.Key, 0, len(keyNames))
for _, keyName := range keyNames {
var tsKey registry.Key
tsKey, err = registry.OpenKey(ps.softwareKey, keyName, windows.KEY_READ)
if err != nil {
if err == registry.ErrNotExist {
continue
}
return err
}
ps.tsKeys = append(ps.tsKeys, tsKey)
}
return nil
}
// Unlock decrements the lock counter and unlocks the policy store once the counter reaches 0.
// It panics if ps is not locked on entry to Unlock.
func (ps *PlatformPolicyStore) Unlock() {
ps.mu.Lock()
defer ps.mu.Unlock()
ps.lockCnt -= 1
if ps.lockCnt < 0 {
panic("negative lockCnt")
} else if ps.lockCnt != 0 {
return
}
for _, key := range ps.tsKeys {
key.Close()
}
ps.tsKeys = nil
ps.policyLock.Unlock()
ps.locked.Done()
}
// RegisterChangeCallback adds a function that will be called whenever there's a policy change.
// It returns a function that needs to be called to unregister the specified callback or an error.
// The error is [ErrStoreClosed] if ps has already been closed.
func (ps *PlatformPolicyStore) RegisterChangeCallback(cb func()) (unregister func(), err error) {
ps.mu.Lock()
defer ps.mu.Unlock()
if ps.closing {
return nil, ErrStoreClosed
}
handle := ps.cbs.Add(cb)
if len(ps.cbs) == 1 {
if ps.watcher, err = gp.NewChangeWatcher(ps.scope, ps.onChange); err != nil {
return nil, err
}
}
return func() {
ps.mu.Lock()
defer ps.mu.Unlock()
delete(ps.cbs, handle)
if len(ps.cbs) == 0 {
if ps.watcher != nil {
ps.watcher.Close()
ps.watcher = nil
}
}
}, nil
}
func (ps *PlatformPolicyStore) onChange() {
ps.mu.RLock()
defer ps.mu.RUnlock()
if ps.closing {
return
}
for _, callback := range ps.cbs {
go callback()
}
}
// ReadString retrieves a string policy with the specified name.
// It returns [ErrNotConfigured] if the policy setting does not exist.
func (ps *PlatformPolicyStore) ReadString(name setting.Key) (val string, err error) {
return getPolicyValue(ps, canonicalizeValueName(name),
func(key registry.Key, name setting.Key) (string, error) {
val, _, err := key.GetStringValue(string(name))
return val, err
})
}
// ReadUInt64 retrieves an integer policy with the specified name.
// It returns [ErrNotConfigured] if the policy setting does not exist.
func (ps *PlatformPolicyStore) ReadUInt64(name setting.Key) (uint64, error) {
return getPolicyValue(ps, canonicalizeValueName(name),
func(key registry.Key, name setting.Key) (uint64, error) {
val, _, err := key.GetIntegerValue(string(name))
return val, err
})
}
// ReadBoolean retrieves a boolean policy with the specified name.
// It returns [ErrNotConfigured] if the policy setting does not exist.
func (ps *PlatformPolicyStore) ReadBoolean(name setting.Key) (bool, error) {
return getPolicyValue(ps, canonicalizeValueName(name),
func(key registry.Key, name setting.Key) (bool, error) {
val, _, err := key.GetIntegerValue(string(name))
if err != nil {
return false, err
}
return val != 0, nil
})
}
// ReadString retrieves a multi-string policy with the specified name.
// It returns [ErrNotConfigured] if the policy setting does not exist.
func (ps *PlatformPolicyStore) ReadStringArray(name setting.Key) ([]string, error) {
return getPolicyValue(ps, name,
func(key registry.Key, name setting.Key) ([]string, error) {
val, _, err := key.GetStringsValue(string(canonicalizeValueName(name)))
if err != registry.ErrNotExist {
return val, err
}
// The idiomatic way to store multiple string values in Group Policy
// and MDM for Windows is to have multiple REG_SZ (or REG_EXPAND_SZ)
// values under a subkey rather than in a single REG_MULTI_SZ value.
//
// See the Group Policy: Registry Extension Encoding specification,
// and specifically the ListElement and ListBox types.
// https://web.archive.org/web/20240721033657/https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-GPREG/%5BMS-GPREG%5D.pdf
valKey, err := registry.OpenKey(key, string(canonicalizeKeyName(name)), windows.KEY_READ)
if err != nil {
return nil, err
}
valNames, err := valKey.ReadValueNames(0)
if err != nil {
return nil, err
}
val = make([]string, 0, len(valNames))
for _, name := range valNames {
switch item, _, err := valKey.GetStringValue(name); {
case err == registry.ErrNotExist:
continue
case err != nil:
return nil, err
default:
val = append(val, item)
}
}
return val, nil
})
}
func canonicalizeKeyName(name setting.Key) setting.Key {
return setting.Key(strings.ReplaceAll(string(name), setting.KeyPathSeparator, `\`))
}
func canonicalizeValueName(name setting.Key) setting.Key {
return setting.Key(strings.ReplaceAll(string(name), setting.KeyPathSeparator, `_`))
}
func getPolicyValue[T any](ps *PlatformPolicyStore, name setting.Key, getter registryValueGetter[T]) (T, error) {
var zero T
ps.mu.RLock()
defer ps.mu.RUnlock()
if !ps.readable {
return zero, setting.ErrNotConfigured
}
if ps.tsKeys != nil {
// A non-nil tsKeys indicates that ps has been locked.
// It may be empty if Tailscale policy keys do not exist.
for _, tsKey := range ps.tsKeys {
val, err := getter(tsKey, name)
if err == nil || err != registry.ErrNotExist {
return val, err
}
}
return zero, setting.ErrNotConfigured
}
// The ps has not been locked, so we don't have any pre-opened keys.
for _, tsKeyName := range tailscaleKeyNamesFor(ps.scope) {
var tsKey registry.Key
tsKey, err := registry.OpenKey(ps.softwareKey, tsKeyName, windows.KEY_READ)
if err != nil {
if err == registry.ErrNotExist {
continue
}
return zero, err
}
defer tsKey.Close()
val, err := getter(tsKey, name)
if err == nil || err != registry.ErrNotExist {
return val, err
}
}
return zero, setting.ErrNotConfigured
}
// Close closes the policy store and releases any associated resources.
// It cancels pending locks and prevents any new lock attempts,
// but waits for existing locks to be released.
func (ps *PlatformPolicyStore) Close() error {
// Request to close the Group Policy read lock.
// Existing held locks will remain valid, but any new or pending locks
// will fail. In certain scenarios, the corresponding write lock may be held
// by the Group Policy service for extended periods (minutes rather than
// seconds or milliseconds). In such cases, we prefer not to wait that long
// if the ps is being closed anyway.
if ps.policyLock != nil {
ps.policyLock.Close()
}
// Signal to the external code that ps should no longer be used.
close(ps.done)
// Mark ps as closing to fast-fail any new lock attempts.
// Callers that have already locked it can finish their reading.
ps.mu.Lock()
if ps.closing {
ps.mu.Unlock()
return nil
}
ps.closing = true
if ps.watcher != nil {
ps.watcher.Close()
ps.watcher = nil
}
ps.mu.Unlock()
// Wait for any outstanding locks to be released.
ps.locked.Wait()
// Deny any further read attempts and release remaining resources.
ps.mu.Lock()
defer ps.mu.Unlock()
ps.cbs = nil
ps.policyLock = nil
ps.readable = false
if ps.softwareKey != 0 {
ps.softwareKey.Close()
ps.softwareKey = 0
}
return nil
}
// Done returns a channel that is closed when the Close method is called.
func (ps *PlatformPolicyStore) Done() <-chan struct{} {
return ps.done
}
func tailscaleKeyNamesFor(scope gp.Scope) []string {
switch scope {
case gp.MachinePolicy:
// If a computer-side policy value does not exist under Software\Policies\Tailscale,
// we need to fallback and use the legacy Software\Tailscale IPN key.
return []string{tsPoliciesSubkey, tsIPNSubkey}
case gp.UserPolicy:
// However, we've never used the legacy key with user-side policies,
// and we should never do so. Unlike HKLM\Software\Tailscale IPN,
// its HKCU counterpart is user-writable.
return []string{tsPoliciesSubkey}
default:
panic("unreachable")
}
}

View File

@ -0,0 +1,298 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"errors"
"fmt"
"reflect"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"tailscale.com/util/cibuild"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/winutil"
"tailscale.com/util/winutil/gp"
)
type testPolicyValue struct {
name setting.Key
value any
}
func TestLockUnlockPolicyStore(t *testing.T) {
store, err := NewMachinePlatformPolicyStore()
if err != nil {
t.Fatalf("NewMachinePolicyStore failed: %v", err)
}
t.Run("One-Goroutine", func(t *testing.T) {
if err := store.Lock(); err != nil {
t.Errorf("store.Lock(): got %v; want nil", err)
return
}
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
}
store.Unlock()
})
// Lock the store N times from different goroutines.
const N = 100
var unlocked atomic.Int32
t.Run("N-Goroutines", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(N)
for range N {
go func() {
if err := store.Lock(); err != nil {
t.Errorf("store.Lock(): got %v; want nil", err)
return
}
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
}
wg.Done()
time.Sleep(10 * time.Millisecond)
unlocked.Add(1)
store.Unlock()
}()
}
// Wait until the store is locked N times.
wg.Wait()
})
// Close the store. The call should wait for all held locks to be released.
if err := store.Close(); err != nil {
t.Fatalf("(*PolicyStore).Close failed: %v", err)
}
if locked := unlocked.Load(); locked != N {
t.Errorf("locked.Load(): got %v; want %v", locked, N)
}
// Any further attempts to lock it should fail.
if err = store.Lock(); err == nil || !errors.Is(err, ErrStoreClosed) {
t.Errorf("store.Lock(): got %v; want %v", err, ErrStoreClosed)
}
}
func TestReadPolicyStore(t *testing.T) {
if !winutil.IsCurrentProcessElevated() {
t.Skipf("test requires running as elevated user")
}
tests := []struct {
name setting.Key
newValue any
legacyValue any
want any
}{
{name: "LegacyPolicy", legacyValue: "LegacyValue", want: "LegacyValue"},
{name: "StringPolicy", legacyValue: "LegacyValue", newValue: "Value", want: "Value"},
{name: "StringPolicy_Empty", legacyValue: "LegacyValue", newValue: "", want: ""},
{name: "BoolPolicy_True", newValue: true, want: true},
{name: "BoolPolicy_False", newValue: false, want: false},
{name: "UIntPolicy_1", newValue: uint32(10), want: uint64(10)}, // uint32 values should be returned as uint64
{name: "UIntPolicy_2", newValue: uint64(1 << 37), want: uint64(1 << 37)},
{name: "StringListPolicy", newValue: []string{"Value1", "Value2"}, want: []string{"Value1", "Value2"}},
{name: "StringListPolicy_Empty", newValue: []string{}, want: []string{}},
}
runTests := func(t *testing.T, userStore bool, token windows.Token) {
var hive registry.Key
if userStore {
hive = registry.CURRENT_USER
} else {
hive = registry.LOCAL_MACHINE
}
// Write policy values to the registry.
newValues := make([]testPolicyValue, 0, len(tests))
for _, tt := range tests {
if tt.newValue != nil {
newValues = append(newValues, testPolicyValue{name: tt.name, value: tt.newValue})
}
}
policiesKeyName := softwareKeyName + `\` + tsPoliciesSubkey
cleanup, err := createTestPolicyValues(hive, policiesKeyName, newValues)
if err != nil {
t.Fatalf("createTestPolicyValues failed: %v", err)
}
t.Cleanup(cleanup)
// Write legacy policy values to the registry.
legacyValues := make([]testPolicyValue, 0, len(tests))
for _, tt := range tests {
if tt.legacyValue != nil {
legacyValues = append(legacyValues, testPolicyValue{name: tt.name, value: tt.legacyValue})
}
}
legacyKeyName := softwareKeyName + `\` + tsIPNSubkey
cleanup, err = createTestPolicyValues(hive, legacyKeyName, legacyValues)
if err != nil {
t.Fatalf("createTestPolicyValues failed: %v", err)
}
t.Cleanup(cleanup)
var store *PlatformPolicyStore
if userStore {
store, err = NewUserPlatformPolicyStore(token)
} else {
store, err = NewMachinePlatformPolicyStore()
}
if err != nil {
t.Fatalf("NewXPolicyStore failed: %v", err)
}
t.Cleanup(func() {
if err := store.Close(); err != nil {
t.Errorf("(*PolicyStore).Close failed: %v", err)
}
})
// testReadValues checks that [PolicyStore] returns the same values we wrote directly to the registry.
testReadValues := func(t *testing.T, withLocks bool) {
for _, tt := range tests {
t.Run(string(tt.name), func(t *testing.T) {
if userStore && tt.newValue == nil {
t.Skip("there is no legacy policies for users")
}
t.Parallel()
if withLocks {
if err := store.Lock(); err != nil {
t.Errorf("failed to acquire the lock: %v", err)
}
defer store.Unlock()
}
var got any
var err error
switch tt.want.(type) {
case string:
got, err = store.ReadString(tt.name)
case uint64:
got, err = store.ReadUInt64(tt.name)
case bool:
got, err = store.ReadBoolean(tt.name)
case []string:
got, err = store.ReadStringArray(tt.name)
}
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("got %v; want %v", got, tt.want)
}
})
}
}
t.Run("NoLock", func(t *testing.T) {
testReadValues(t, false)
})
t.Run("WithLock", func(t *testing.T) {
testReadValues(t, true)
})
}
t.Run("MachineStore", func(t *testing.T) {
runTests(t, false, 0)
})
t.Run("CurrentUserStore", func(t *testing.T) {
runTests(t, true, 0)
})
t.Run("UserStoreWithToken", func(t *testing.T) {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil {
t.Fatalf("OpenProcessToken: %v", err)
}
defer token.Close()
runTests(t, true, token)
})
}
func TestPolicyStoreChangeNotifications(t *testing.T) {
if cibuild.On() {
t.Skipf("test requires running on a real Windows environment")
}
store, err := NewMachinePlatformPolicyStore()
if err != nil {
t.Fatalf("NewMachinePolicyStore failed: %v", err)
}
t.Cleanup(func() {
if err := store.Close(); err != nil {
t.Errorf("(*PolicyStore).Close failed: %v", err)
}
})
done := make(chan struct{})
unregister, err := store.RegisterChangeCallback(func() { close(done) })
if err != nil {
t.Fatalf("RegisterChangeCallback failed: %v", err)
}
t.Cleanup(unregister)
// RefreshMachinePolicy is a non-blocking call.
if err := gp.RefreshMachinePolicy(true); err != nil {
t.Fatalf("RefreshMachinePolicy failed: %v", err)
}
// We should receive a policy change notification when
// the Group Policy service completes policy processing.
// Otherwise, the test will eventually time out.
<-done
}
func createTestPolicyValues(hive registry.Key, keyName string, values []testPolicyValue) (cleanup func(), err error) {
key, existing, err := registry.CreateKey(hive, keyName, registry.ALL_ACCESS)
if err != nil {
return nil, err
}
doCleanup := func() {
for _, v := range values {
key.DeleteValue(string(v.name))
}
key.Close()
if !existing {
registry.DeleteKey(hive, keyName)
}
}
defer func() {
if err != nil {
doCleanup()
}
}()
for _, v := range values {
switch value := v.value.(type) {
case string:
err = key.SetStringValue(string(v.name), value)
case uint32:
err = key.SetDWordValue(string(v.name), value)
case uint64:
err = key.SetQWordValue(string(v.name), value)
case bool:
if value {
err = key.SetDWordValue(string(v.name), 1)
} else {
err = key.SetDWordValue(string(v.name), 0)
}
case []string:
err = key.SetStringsValue(string(v.name), value)
default:
err = fmt.Errorf("unsupported value: %v (%T), name: %q", value, value, v.name)
}
if err != nil {
return nil, err
}
}
return doCleanup, nil
}

View File

@ -0,0 +1,446 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package source
import (
"fmt"
"sync"
"sync/atomic"
xmaps "golang.org/x/exp/maps"
"tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/setting"
)
var _ Store = (*TestStore)(nil)
// TestValueType is a constraint that allows types supported by [TestStore].
type TestValueType interface {
bool | uint64 | string | []string
}
// TestSetting is a policy setting in a [TestStore].
type TestSetting[T TestValueType] struct {
// Key is the setting's unique identifier.
Key setting.Key
// Error is the error to be returned by the [TestStore] when reading
// a policy setting with the specified key.
Error error
// Value is the value to be returned by the [TestStore] when reading
// a policy setting with the specified key.
// It is only used if the Error is nil.
Value T
}
// TestSettingOf returns a [TestSetting] representing a policy setting
// configured with the specified key and value.
func TestSettingOf[T TestValueType](key setting.Key, value T) TestSetting[T] {
return TestSetting[T]{Key: key, Value: value}
}
// TestSettingWithError returns a [TestSetting] representing a policy setting
// with the specified key and error.
func TestSettingWithError[T TestValueType](key setting.Key, err error) TestSetting[T] {
return TestSetting[T]{Key: key, Error: err}
}
// testReadOperation describes a single policy setting read operation.
type testReadOperation struct {
// Key is the setting's unique identifier.
Key setting.Key
// Type is a value type of a read operation.
// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
Type setting.Type
}
// TestExpectedReads is the number of read operations with the specified details.
type TestExpectedReads struct {
// Key is the setting's unique identifier.
Key setting.Key
// Type is a value type of a read operation.
// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
Type setting.Type
// NumTimes is how many times a setting with the specified key and type should have been read.
NumTimes int
}
func (r TestExpectedReads) operation() testReadOperation {
return testReadOperation{r.Key, r.Type}
}
// TestStore is a [Store] that can be used in tests.
type TestStore struct {
tb internal.TB
done chan struct{}
storeLock sync.RWMutex // its RLock is exposed via [Store.Lock]/[Store.Unlock].
storeLockCount atomic.Int32
mu sync.RWMutex
suspendCount int // change callback are suspended if > 0
mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended.
cbs set.HandleSet[func()]
readsMu sync.Mutex
reads map[testReadOperation]int // how many times a policy setting was read
}
// NewTestStore returns a new [TestStore].
// The tb will be used to report coding errors detected by the [TestStore].
func NewTestStore(tb internal.TB) *TestStore {
m := make(map[setting.Key]any)
return &TestStore{
tb: tb,
done: make(chan struct{}),
mr: m,
mw: m,
}
}
// NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans],
// [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists].
func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore {
m := make(map[setting.Key]any)
store := &TestStore{
tb: tb,
done: make(chan struct{}),
mr: m,
mw: m,
}
switch settings := any(settings).(type) {
case []TestSetting[bool]:
store.SetBooleans(settings...)
case []TestSetting[uint64]:
store.SetUInt64s(settings...)
case []TestSetting[string]:
store.SetStrings(settings...)
case []TestSetting[[]string]:
store.SetStringLists(settings...)
}
return store
}
// Lock implements [Store].
func (s *TestStore) Lock() error {
s.storeLock.RLock()
s.storeLockCount.Add(1)
return nil
}
// Unlock implements [Store].
func (s *TestStore) Unlock() {
if s.storeLockCount.Add(-1) < 0 {
s.tb.Fatal("negative storeLockCount")
}
s.storeLock.RUnlock()
}
// RegisterChangeCallback implements [Store].
func (s *TestStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
s.mu.Lock()
defer s.mu.Unlock()
handle := s.cbs.Add(callback)
return func() {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.cbs, handle)
}, nil
}
// ReadString implements [Store].
func (s *TestStore) ReadString(key setting.Key) (string, error) {
defer s.recordRead(key, setting.StringValue)
s.mu.RLock()
defer s.mu.RUnlock()
v, ok := s.mr[key]
if !ok {
return "", setting.ErrNotConfigured
}
if err, ok := v.(error); ok {
return "", err
}
str, ok := v.(string)
if !ok {
return "", fmt.Errorf("%w in ReadString: got %T", setting.ErrTypeMismatch, v)
}
return str, nil
}
// ReadUInt64 implements [Store].
func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) {
defer s.recordRead(key, setting.IntegerValue)
s.mu.RLock()
defer s.mu.RUnlock()
v, ok := s.mr[key]
if !ok {
return 0, setting.ErrNotConfigured
}
if err, ok := v.(error); ok {
return 0, err
}
u64, ok := v.(uint64)
if !ok {
return 0, fmt.Errorf("%w in ReadUInt64: got %T", setting.ErrTypeMismatch, v)
}
return u64, nil
}
// ReadBoolean implements [Store].
func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) {
defer s.recordRead(key, setting.BooleanValue)
s.mu.RLock()
defer s.mu.RUnlock()
v, ok := s.mr[key]
if !ok {
return false, setting.ErrNotConfigured
}
if err, ok := v.(error); ok {
return false, err
}
b, ok := v.(bool)
if !ok {
return false, fmt.Errorf("%w in ReadBoolean: got %T", setting.ErrTypeMismatch, v)
}
return b, nil
}
// ReadStringArray implements [Store].
func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) {
defer s.recordRead(key, setting.StringListValue)
s.mu.RLock()
defer s.mu.RUnlock()
v, ok := s.mr[key]
if !ok {
return nil, setting.ErrNotConfigured
}
if err, ok := v.(error); ok {
return nil, err
}
slice, ok := v.([]string)
if !ok {
return nil, fmt.Errorf("%w in ReadStringArray: got %T", setting.ErrTypeMismatch, v)
}
return slice, nil
}
func (s *TestStore) recordRead(key setting.Key, typ setting.Type) {
s.readsMu.Lock()
op := testReadOperation{key, typ}
num := s.reads[op]
num++
mak.Set(&s.reads, op, num)
s.readsMu.Unlock()
}
func (s *TestStore) ResetCounters() {
s.readsMu.Lock()
clear(s.reads)
s.readsMu.Unlock()
}
// ReadsMustEqual fails the test if the actual reads differs from the specified reads.
func (s *TestStore) ReadsMustEqual(reads ...TestExpectedReads) {
s.tb.Helper()
s.readsMu.Lock()
defer s.readsMu.Unlock()
s.readsMustContainLocked(reads...)
s.readMustNoExtraLocked(reads...)
}
// ReadsMustContain fails the test if the specified reads have not been made,
// or have been made a different number of times. It permits other values to be
// read in addition to the ones being tested.
func (s *TestStore) ReadsMustContain(reads ...TestExpectedReads) {
s.tb.Helper()
s.readsMu.Lock()
defer s.readsMu.Unlock()
s.readsMustContainLocked(reads...)
}
func (s *TestStore) readsMustContainLocked(reads ...TestExpectedReads) {
s.tb.Helper()
for _, r := range reads {
if numTimes := s.reads[r.operation()]; numTimes != r.NumTimes {
s.tb.Errorf("%q (%v) reads: got %v, want %v", r.Key, r.Type, numTimes, r.NumTimes)
}
}
}
func (s *TestStore) readMustNoExtraLocked(reads ...TestExpectedReads) {
s.tb.Helper()
rs := make(set.Set[testReadOperation])
for i := range reads {
rs.Add(reads[i].operation())
}
for ro, num := range s.reads {
if !rs.Contains(ro) {
s.tb.Errorf("%q (%v) reads: got %v, want 0", ro.Key, ro.Type, num)
}
}
}
// Suspend suspends the store, batching changes and notifications
// until [TestStore.Resume] is called the same number of times as Suspend.
func (s *TestStore) Suspend() {
s.mu.Lock()
defer s.mu.Unlock()
if s.suspendCount++; s.suspendCount == 1 {
s.mw = xmaps.Clone(s.mr)
}
}
// Resume resumes the store, applying the changes and invoking
// the change callbacks.
func (s *TestStore) Resume() {
s.storeLock.Lock()
s.mu.Lock()
switch s.suspendCount--; {
case s.suspendCount == 0:
s.mr = s.mw
s.mu.Unlock()
s.storeLock.Unlock()
s.notifyPolicyChanged()
case s.suspendCount < 0:
s.tb.Fatal("negative suspendCount")
default:
s.mu.Unlock()
s.storeLock.Unlock()
}
}
// SetBooleans sets the specified boolean settings in s.
func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) {
s.storeLock.Lock()
for _, setting := range settings {
if setting.Key == "" {
s.tb.Fatal("empty keys disallowed")
}
s.mu.Lock()
if setting.Error != nil {
mak.Set(&s.mw, setting.Key, any(setting.Error))
} else {
mak.Set(&s.mw, setting.Key, any(setting.Value))
}
s.mu.Unlock()
}
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
// SetUInt64s sets the specified integer settings in s.
func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) {
s.storeLock.Lock()
for _, setting := range settings {
if setting.Key == "" {
s.tb.Fatal("empty keys disallowed")
}
s.mu.Lock()
if setting.Error != nil {
mak.Set(&s.mw, setting.Key, any(setting.Error))
} else {
mak.Set(&s.mw, setting.Key, any(setting.Value))
}
s.mu.Unlock()
}
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
// SetStrings sets the specified string settings in s.
func (s *TestStore) SetStrings(settings ...TestSetting[string]) {
s.storeLock.Lock()
for _, setting := range settings {
if setting.Key == "" {
s.tb.Fatal("empty keys disallowed")
}
s.mu.Lock()
if setting.Error != nil {
mak.Set(&s.mw, setting.Key, any(setting.Error))
} else {
mak.Set(&s.mw, setting.Key, any(setting.Value))
}
s.mu.Unlock()
}
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
// SetStrings sets the specified string list settings in s.
func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) {
s.storeLock.Lock()
for _, setting := range settings {
if setting.Key == "" {
s.tb.Fatal("empty keys disallowed")
}
s.mu.Lock()
if setting.Error != nil {
mak.Set(&s.mw, setting.Key, any(setting.Error))
} else {
mak.Set(&s.mw, setting.Key, any(setting.Value))
}
s.mu.Unlock()
}
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
// Delete deletes the specified settings from s.
func (s *TestStore) Delete(keys ...setting.Key) {
s.storeLock.Lock()
for _, key := range keys {
s.mu.Lock()
delete(s.mw, key)
s.mu.Unlock()
}
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
// Clear deletes all settings from s.
func (s *TestStore) Clear() {
s.storeLock.Lock()
s.mu.Lock()
clear(s.mw)
s.mu.Unlock()
s.storeLock.Unlock()
s.notifyPolicyChanged()
}
func (s *TestStore) notifyPolicyChanged() {
s.mu.RLock()
if s.suspendCount != 0 {
s.mu.RUnlock()
return
}
cbs := xmaps.Values(s.cbs)
s.mu.RUnlock()
var wg sync.WaitGroup
wg.Add(len(cbs))
for _, cb := range cbs {
go func() {
defer wg.Done()
cb()
}()
}
wg.Wait()
}
// Close closes s, notifying its users that it has expired.
func (s *TestStore) Close() {
s.mu.Lock()
defer s.mu.Unlock()
if s.done != nil {
close(s.done)
s.done = nil
}
}
// Done implements [Store].
func (s *TestStore) Done() <-chan struct{} {
return s.done
}

View File

@ -1,122 +1,83 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package syspolicy provides functions to retrieve system settings of a device.
// Package syspolicy facilitates retrieval of the current policy settings
// applied to the device or user and receiving notifications when the policy
// changes.
//
// It provides functions that return specific policy settings by their unique
// [setting.Key]s, such as [GetBoolean], [GetUint64], [GetString],
// [GetStringArray], [GetPreferenceOption], [GetVisibility] and [GetDuration].
package syspolicy
import (
"errors"
"fmt"
"reflect"
"time"
"tailscale.com/util/syspolicy/rsop"
"tailscale.com/util/syspolicy/setting"
)
var (
// ErrNotConfigured is returned when the requested policy setting is not configured.
ErrNotConfigured = setting.ErrNotConfigured
// ErrTypeMismatch is returned when there's a type mismatch between the actual type
// of the setting value and the expected type.
ErrTypeMismatch = setting.ErrTypeMismatch
// ErrNoSuchKey is returned by [setting.DefinitionOf] when no policy setting
// has been registered with the specified key.
//
// Until 2024-08-02, this error was also returned by a [Handler] when the specified
// key did not have a value set. While the package maintains compatibility with this
// usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer
// [source.Store] implementations.
ErrNoSuchKey = setting.ErrNoSuchKey
)
// GetString returns a string policy setting with the specified key,
// or defaultValue if it does not exist.
func GetString(key Key, defaultValue string) (string, error) {
markHandlerInUse()
v, err := handler.ReadString(string(key))
if errors.Is(err, ErrNoSuchKey) {
return defaultValue, nil
}
return v, err
return getCurrentPolicySettingValue(key, defaultValue)
}
// GetUint64 returns a numeric policy setting with the specified key,
// or defaultValue if it does not exist.
func GetUint64(key Key, defaultValue uint64) (uint64, error) {
markHandlerInUse()
v, err := handler.ReadUInt64(string(key))
if errors.Is(err, ErrNoSuchKey) {
return defaultValue, nil
}
return v, err
return getCurrentPolicySettingValue(key, defaultValue)
}
// GetBoolean returns a boolean policy setting with the specified key,
// or defaultValue if it does not exist.
func GetBoolean(key Key, defaultValue bool) (bool, error) {
markHandlerInUse()
v, err := handler.ReadBoolean(string(key))
if errors.Is(err, ErrNoSuchKey) {
return defaultValue, nil
}
return v, err
return getCurrentPolicySettingValue(key, defaultValue)
}
// GetStringArray returns a multi-string policy setting with the specified key,
// or defaultValue if it does not exist.
func GetStringArray(key Key, defaultValue []string) ([]string, error) {
markHandlerInUse()
v, err := handler.ReadStringArray(string(key))
if errors.Is(err, ErrNoSuchKey) {
return defaultValue, nil
}
return v, err
return getCurrentPolicySettingValue(key, defaultValue)
}
// PreferenceOption is a policy that governs whether a boolean variable
// is forcibly assigned an administrator-defined value, or allowed to receive
// a user-defined value.
type PreferenceOption int
const (
showChoiceByPolicy PreferenceOption = iota
neverByPolicy
alwaysByPolicy
type (
// PreferenceOption is a policy that governs whether a boolean variable
// is forcibly assigned an administrator-defined value, or allowed to receive
// a user-defined value.
PreferenceOption = setting.PreferenceOption
// Visibility is a policy that controls whether or not a particular
// component of a user interface is to be shown.
Visibility = setting.Visibility
)
// Show returns if the UI option that controls the choice administered by this
// policy should be shown. Currently this is true if and only if the policy is
// showChoiceByPolicy.
func (p PreferenceOption) Show() bool {
return p == showChoiceByPolicy
}
// ShouldEnable checks if the choice administered by this policy should be
// enabled. If the administrator has chosen a setting, the administrator's
// setting is returned, otherwise userChoice is returned.
func (p PreferenceOption) ShouldEnable(userChoice bool) bool {
switch p {
case neverByPolicy:
return false
case alwaysByPolicy:
return true
default:
return userChoice
}
}
// WillOverride checks if the choice administered by the policy is different
// from the user's choice.
func (p PreferenceOption) WillOverride(userChoice bool) bool {
return p.ShouldEnable(userChoice) != userChoice
}
// GetPreferenceOption loads a policy from the registry that can be
// managed by an enterprise policy management system and allows administrative
// overrides of users' choices in a way that we do not want tailcontrol to have
// the authority to set. It describes user-decides/always/never options, where
// "always" and "never" remove the user's ability to make a selection. If not
// present or set to a different value, "user-decides" is the default.
func GetPreferenceOption(name Key) (PreferenceOption, error) {
opt, err := GetString(name, "user-decides")
if err != nil {
return showChoiceByPolicy, err
}
switch opt {
case "always":
return alwaysByPolicy, nil
case "never":
return neverByPolicy, nil
default:
return showChoiceByPolicy, nil
}
}
// Visibility is a policy that controls whether or not a particular
// component of a user interface is to be shown.
type Visibility byte
const (
visibleByPolicy Visibility = 'v'
hiddenByPolicy Visibility = 'h'
)
// Show reports whether the UI option administered by this policy should be shown.
// Currently this is true if and only if the policy is visibleByPolicy.
func (p Visibility) Show() bool {
return p == visibleByPolicy
func GetPreferenceOption(name Key) (setting.PreferenceOption, error) {
return getCurrentPolicySettingValue(name, setting.ShowChoiceByPolicy)
}
// GetVisibility loads a policy from the registry that can be managed
@ -124,17 +85,8 @@ func (p Visibility) Show() bool {
// for UI elements. The registry value should be a string set to "show" (return
// true) or "hide" (return true). If not present or set to a different value,
// "show" (return false) is the default.
func GetVisibility(name Key) (Visibility, error) {
opt, err := GetString(name, "show")
if err != nil {
return visibleByPolicy, err
}
switch opt {
case "hide":
return hiddenByPolicy, nil
default:
return visibleByPolicy, nil
}
func GetVisibility(name Key) (setting.Visibility, error) {
return getCurrentPolicySettingValue(name, setting.VisibleByPolicy)
}
// GetDuration loads a policy from the registry that can be managed
@ -143,15 +95,48 @@ func GetVisibility(name Key) (Visibility, error) {
// understands. If the registry value is "" or can not be processed,
// defaultValue is returned instead.
func GetDuration(name Key, defaultValue time.Duration) (time.Duration, error) {
opt, err := GetString(name, "")
if opt == "" || err != nil {
return defaultValue, err
d, err := getCurrentPolicySettingValue(name, defaultValue)
if err != nil {
return d, err
}
v, err := time.ParseDuration(opt)
if err != nil || v < 0 {
if d < 0 {
return defaultValue, nil
}
return v, nil
return d, nil
}
// getCurrentPolicySettingValue returns the value of the policy setting
// specified by its key from the [rsop.Policy] of the [CurrentScope]. It
// returns def if the policy setting is not configured, or an error if it has
// an error or could not be converted to the specified type T.
func getCurrentPolicySettingValue[T setting.ValueType](key Key, def T) (T, error) {
resultant, err := rsop.PolicyFor(setting.CurrentScope())
if err != nil {
return def, err
}
value, err := resultant.Get().GetErr(key)
if err != nil {
if errors.Is(err, setting.ErrNotConfigured) || errors.Is(err, setting.ErrNoSuchKey) {
return def, nil
}
return def, err
}
if res, ok := value.(T); ok {
return res, nil
}
return convertPolicySettingValueTo(value, def)
}
func convertPolicySettingValueTo[T setting.ValueType](value any, def T) (T, error) {
// Convert [PreferenceOption], [Visibility], or [time.Duration] back to a string
// if someone requests a string instead of the actual setting's value.
// TODO(nickkhyl): check if this behavior is relied upon anywhere besides the old tests.
if reflect.TypeFor[T]().Kind() == reflect.String {
if str, ok := value.(fmt.Stringer); ok {
return any(str.String()).(T), nil
}
}
return def, fmt.Errorf("%w: got %T, want %T", setting.ErrTypeMismatch, value, def)
}
// SelectControlURL returns the ControlURL to use based on a value in

View File

@ -5,16 +5,24 @@
import (
"errors"
"fmt"
"slices"
"testing"
"time"
"tailscale.com/types/logger"
"tailscale.com/util/syspolicy/internal/loggerx"
"tailscale.com/util/syspolicy/internal/metrics"
"tailscale.com/util/syspolicy/rsop"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
)
// testHandler encompasses all data types returned when testing any of the syspolicy
// methods that involve getting a policy value.
// For keys and the corresponding values, check policy_keys.go.
type testHandler struct {
t *testing.T
t testing.TB
key Key
s string
u64 uint64
@ -28,7 +36,10 @@ type testHandler struct {
func (th *testHandler) ReadString(key string) (string, error) {
if key != string(th.key) {
th.t.Errorf("ReadString(%q) want %q", key, th.key)
// The syspolicy package now reads and caches all registered policy settings.
// Therefore, it is expected to call the handler requesting all policies
// rather than just the specific ones we asked for.
return "", ErrNotConfigured
}
th.calls++
return th.s, th.err
@ -36,7 +47,10 @@ func (th *testHandler) ReadString(key string) (string, error) {
func (th *testHandler) ReadUInt64(key string) (uint64, error) {
if key != string(th.key) {
th.t.Errorf("ReadUint64(%q) want %q", key, th.key)
// The syspolicy package now reads and caches all registered policy settings.
// Therefore, it is expected to call the handler requesting all policies
// rather than just the specific ones we asked for.
return 0, ErrNotConfigured
}
th.calls++
return th.u64, th.err
@ -44,7 +58,10 @@ func (th *testHandler) ReadUInt64(key string) (uint64, error) {
func (th *testHandler) ReadBoolean(key string) (bool, error) {
if key != string(th.key) {
th.t.Errorf("ReadBool(%q) want %q", key, th.key)
// The syspolicy package now reads and caches all registered policy settings.
// Therefore, it is expected to call the handler requesting all policies
// rather than just the specific ones we asked for.
return false, ErrNotConfigured
}
th.calls++
return th.b, th.err
@ -52,7 +69,10 @@ func (th *testHandler) ReadBoolean(key string) (bool, error) {
func (th *testHandler) ReadStringArray(key string) ([]string, error) {
if key != string(th.key) {
th.t.Errorf("ReadStringArray(%q) want %q", key, th.key)
// The syspolicy package now reads and caches all registered policy settings.
// Therefore, it is expected to call the handler requesting all policies
// rather than just the specific ones we asked for.
return nil, ErrNotConfigured
}
th.calls++
return th.sArr, th.err
@ -67,23 +87,28 @@ func TestGetString(t *testing.T) {
defaultValue string
wantValue string
wantError error
wantMetrics []metrics.TestState
}{
{
name: "read existing value",
key: AdminConsoleVisibility,
handlerValue: "hide",
wantValue: "hide",
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AdminConsole", Value: 1},
},
},
{
name: "read non-existing value",
key: EnableServerMode,
handlerError: ErrNoSuchKey,
handlerError: ErrNotConfigured,
wantError: nil,
},
{
name: "read non-existing value, non-blank default",
key: EnableServerMode,
handlerError: ErrNoSuchKey,
handlerError: ErrNotConfigured,
defaultValue: "test",
wantValue: "test",
wantError: nil,
@ -93,11 +118,17 @@ func TestGetString(t *testing.T) {
key: NetworkDevicesVisibility,
handlerError: someOtherError,
wantError: someOtherError,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_NetworkDevices_error", Value: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@ -105,12 +136,21 @@ func TestGetString(t *testing.T) {
err: tt.handlerError,
})
value, err := GetString(tt.key, tt.defaultValue)
if err != tt.wantError {
if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if value != tt.wantValue {
t.Errorf("value=%v, want %v", value, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-08-02, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
@ -127,7 +167,7 @@ func TestGetUint64(t *testing.T) {
}{
{
name: "read existing value",
key: KeyExpirationNoticeTime,
key: LogSCMInteractions,
handlerValue: 1,
wantValue: 1,
},
@ -135,14 +175,14 @@ func TestGetUint64(t *testing.T) {
name: "read non-existing value",
key: LogSCMInteractions,
handlerValue: 0,
handlerError: ErrNoSuchKey,
handlerError: ErrNotConfigured,
wantValue: 0,
},
{
name: "read non-existing value, non-zero default",
key: LogSCMInteractions,
defaultValue: 2,
handlerError: ErrNoSuchKey,
handlerError: ErrNotConfigured,
wantValue: 2,
},
{
@ -155,14 +195,21 @@ func TestGetUint64(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
SetHandlerForTest(t, &testHandler{
// None of the policy settings tested here are integers.
// In fact, we don't have any integer policies as of 2024-07-29.
// However, we can register each of them as an integer policy setting
// for the duration of the test, providing us with something to test against.
if err := setting.SetDefinitionsForTest(t, setting.NewDefinition(tt.key, setting.DeviceSetting, setting.IntegerValue)); err != nil {
t.Fatalf("SetDefinitionsForTest failed: %v", err)
}
rsop.RegisterStoreForTest(t, tt.name, setting.DeviceScope, WrapHandler(&testHandler{
t: t,
key: tt.key,
u64: tt.handlerValue,
err: tt.handlerError,
})
}))
value, err := GetUint64(tt.key, tt.defaultValue)
if err != tt.wantError {
if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if value != tt.wantValue {
@ -181,32 +228,43 @@ func TestGetBoolean(t *testing.T) {
defaultValue bool
wantValue bool
wantError error
wantMetrics []metrics.TestState
}{
{
name: "read existing value",
key: FlushDNSOnSessionUnlock,
handlerValue: true,
wantValue: true,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_FlushDNSOnSessionUnlock", Value: 1},
},
},
{
name: "read non-existing value",
key: LogSCMInteractions,
handlerValue: false,
handlerError: ErrNoSuchKey,
handlerError: ErrNotConfigured,
wantValue: false,
},
{
name: "reading value returns other error",
key: FlushDNSOnSessionUnlock,
handlerError: someOtherError,
wantError: someOtherError,
wantError: someOtherError, // expect error...
defaultValue: true,
wantValue: false,
wantValue: true, // ...AND default value if the handler fails.
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_FlushDNSOnSessionUnlock_error", Value: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@ -214,12 +272,21 @@ func TestGetBoolean(t *testing.T) {
err: tt.handlerError,
})
value, err := GetBoolean(tt.key, tt.defaultValue)
if err != tt.wantError {
if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if value != tt.wantValue {
t.Errorf("value=%v, want %v", value, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-08-02, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
@ -232,42 +299,61 @@ func TestGetPreferenceOption(t *testing.T) {
handlerError error
wantValue PreferenceOption
wantError error
wantMetrics []metrics.TestState
}{
{
name: "always by policy",
key: EnableIncomingConnections,
handlerValue: "always",
wantValue: alwaysByPolicy,
wantValue: setting.AlwaysByPolicy,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
},
},
{
name: "never by policy",
key: EnableIncomingConnections,
handlerValue: "never",
wantValue: neverByPolicy,
wantValue: setting.NeverByPolicy,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
},
},
{
name: "use default",
key: EnableIncomingConnections,
handlerValue: "",
wantValue: showChoiceByPolicy,
wantValue: setting.ShowChoiceByPolicy,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
},
},
{
name: "read non-existing value",
key: EnableIncomingConnections,
handlerError: ErrNoSuchKey,
wantValue: showChoiceByPolicy,
handlerError: ErrNotConfigured,
wantValue: setting.ShowChoiceByPolicy,
},
{
name: "other error is returned",
key: EnableIncomingConnections,
handlerError: someOtherError,
wantValue: showChoiceByPolicy,
wantValue: setting.ShowChoiceByPolicy,
wantError: someOtherError,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_AllowIncomingConnections_error", Value: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@ -275,12 +361,21 @@ func TestGetPreferenceOption(t *testing.T) {
err: tt.handlerError,
})
option, err := GetPreferenceOption(tt.key)
if err != tt.wantError {
if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if option != tt.wantValue {
t.Errorf("option=%v, want %v", option, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-08-02, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
@ -293,38 +388,53 @@ func TestGetVisibility(t *testing.T) {
handlerError error
wantValue Visibility
wantError error
wantMetrics []metrics.TestState
}{
{
name: "hidden by policy",
key: AdminConsoleVisibility,
handlerValue: "hide",
wantValue: hiddenByPolicy,
wantValue: setting.HiddenByPolicy,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AdminConsole", Value: 1},
},
},
{
name: "visibility default",
key: AdminConsoleVisibility,
handlerValue: "show",
wantValue: visibleByPolicy,
wantValue: setting.VisibleByPolicy,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AdminConsole", Value: 1},
},
},
{
name: "read non-existing value",
key: AdminConsoleVisibility,
handlerValue: "show",
handlerError: ErrNoSuchKey,
wantValue: visibleByPolicy,
handlerError: ErrNotConfigured,
wantValue: setting.VisibleByPolicy,
},
{
name: "other error is returned",
key: AdminConsoleVisibility,
handlerValue: "show",
handlerError: someOtherError,
wantValue: visibleByPolicy,
wantValue: setting.VisibleByPolicy,
wantError: someOtherError,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_AdminConsole_error", Value: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@ -332,12 +442,21 @@ func TestGetVisibility(t *testing.T) {
err: tt.handlerError,
})
visibility, err := GetVisibility(tt.key)
if err != tt.wantError {
if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if visibility != tt.wantValue {
t.Errorf("visibility=%v, want %v", visibility, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-08-02, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
@ -351,6 +470,7 @@ func TestGetDuration(t *testing.T) {
defaultValue time.Duration
wantValue time.Duration
wantError error
wantMetrics []metrics.TestState
}{
{
name: "read existing value",
@ -358,25 +478,34 @@ func TestGetDuration(t *testing.T) {
handlerValue: "2h",
wantValue: 2 * time.Hour,
defaultValue: 24 * time.Hour,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_KeyExpirationNotice", Value: 1},
},
},
{
name: "invalid duration value",
key: KeyExpirationNoticeTime,
handlerValue: "-20",
wantValue: 24 * time.Hour,
wantError: errors.New(`time: missing unit in duration "-20"`),
defaultValue: 24 * time.Hour,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
},
},
{
name: "read non-existing value",
key: KeyExpirationNoticeTime,
handlerError: ErrNoSuchKey,
handlerError: ErrNotConfigured,
wantValue: 24 * time.Hour,
defaultValue: 24 * time.Hour,
},
{
name: "read non-existing value different default",
key: KeyExpirationNoticeTime,
handlerError: ErrNoSuchKey,
handlerError: ErrNotConfigured,
wantValue: 0 * time.Second,
defaultValue: 0 * time.Second,
},
@ -387,11 +516,17 @@ func TestGetDuration(t *testing.T) {
wantValue: 24 * time.Hour,
wantError: someOtherError,
defaultValue: 24 * time.Hour,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@ -399,12 +534,21 @@ func TestGetDuration(t *testing.T) {
err: tt.handlerError,
})
duration, err := GetDuration(tt.key, tt.defaultValue)
if err != tt.wantError {
if fmt.Sprint(err) != fmt.Sprint(tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if duration != tt.wantValue {
t.Errorf("duration=%v, want %v", duration, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-08-02, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
@ -418,23 +562,28 @@ func TestGetStringArray(t *testing.T) {
defaultValue []string
wantValue []string
wantError error
wantMetrics []metrics.TestState
}{
{
name: "read existing value",
key: AllowedSuggestedExitNodes,
handlerValue: []string{"foo", "bar"},
wantValue: []string{"foo", "bar"},
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_any", Value: 1},
{Name: "$os_syspolicy_AllowedSuggestedExitNodes", Value: 1},
},
},
{
name: "read non-existing value",
key: AllowedSuggestedExitNodes,
handlerError: ErrNoSuchKey,
handlerError: ErrNotConfigured,
wantError: nil,
},
{
name: "read non-existing value, non nil default",
key: AllowedSuggestedExitNodes,
handlerError: ErrNoSuchKey,
handlerError: ErrNotConfigured,
defaultValue: []string{"foo", "bar"},
wantValue: []string{"foo", "bar"},
wantError: nil,
@ -444,11 +593,17 @@ func TestGetStringArray(t *testing.T) {
key: AllowedSuggestedExitNodes,
handlerError: someOtherError,
wantError: someOtherError,
wantMetrics: []metrics.TestState{
{Name: "$os_syspolicy_errors", Value: 1},
{Name: "$os_syspolicy_AllowedSuggestedExitNodes_error", Value: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := metrics.NewTestHandler(t)
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
SetHandlerForTest(t, &testHandler{
t: t,
key: tt.key,
@ -456,16 +611,47 @@ func TestGetStringArray(t *testing.T) {
err: tt.handlerError,
})
value, err := GetStringArray(tt.key, tt.defaultValue)
if err != tt.wantError {
if !errorsMatchForTest(err, tt.wantError) {
t.Errorf("err=%q, want %q", err, tt.wantError)
}
if !slices.Equal(tt.wantValue, value) {
t.Errorf("value=%v, want %v", value, tt.wantValue)
}
wantMetrics := tt.wantMetrics
if !metrics.ShouldReport() {
// Check that metrics are not reported on platforms
// where they shouldn't be reported.
// As of 2024-08-02, syspolicy only reports metrics
// on Windows and Android.
wantMetrics = nil
}
h.MustEqual(wantMetrics...)
})
}
}
func BenchmarkGetString(b *testing.B) {
loggerx.SetForTest(b, logger.Discard, logger.Discard)
setWellKnownSettingsForTest(b)
store := source.NewTestStore(b)
wantControlURL := "https://login.tailscale.com"
store.SetStrings(source.TestSetting[string]{Key: ControlURL, Value: wantControlURL})
_, err := rsop.RegisterStoreForTest(b, "Test Store", setting.DeviceScope, store)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
gotControlURL, _ := GetString(ControlURL, "https://controlplane.tailscale.com")
if gotControlURL != wantControlURL {
b.Fatalf("got %v; want %v", gotControlURL, wantControlURL)
}
}
}
func TestSelectControlURL(t *testing.T) {
tests := []struct {
reg, disk, want string
@ -497,3 +683,13 @@ func TestSelectControlURL(t *testing.T) {
}
}
}
func errorsMatchForTest(got, want error) bool {
if got == nil && want == nil {
return true
}
if got == nil || want == nil {
return false
}
return errors.Is(got, want) || got.Error() == want.Error()
}

View File

@ -0,0 +1,93 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package syspolicy
import (
"errors"
"fmt"
"os/user"
"tailscale.com/util/syspolicy/internal"
"tailscale.com/util/syspolicy/internal/lazyinit"
"tailscale.com/util/syspolicy/rsop"
"tailscale.com/util/syspolicy/setting"
"tailscale.com/util/syspolicy/source"
"tailscale.com/util/testenv"
)
func init() {
// On Windows, we should automatically register the Registry-based policy
// store for the device. If we are running in a user's security context
// (e.g., we're the GUI), we should also register the Registry policy store for
// the user. In the future, we should register (and unregister) user policy
// stores whenever a user connects to the local backend. This ensures the
// backend is aware of the user's policy settings and can send them to the
// GUI/CLI/Web clients on demand or whenever they change.
//
// Other platforms, such as macOS, iOS and Android, should register their
// platform-specific policy stores via [RegisterStore] (or [RegisterHandler]
// until they implement the [Store] interface).
//
// External code, such as the ipnlocal package, may choose to register
// additional policy stores, such as config files and policies received from
// the control plane.
lazyinit.Defer(func() error {
// Do not register or use default policy stores during tests.
// Each test should set up its own necessary configurations.
if testenv.InTest() {
return nil
}
return configureSyspolicy(nil)
})
}
// configureSyspolicy configures syspolicy for use on Windows,
// either in test or regular builds depending on whether tb has a non-nil value.
func configureSyspolicy(tb internal.TB) error {
const localSystemSID = "S-1-5-18"
// Always create and register a machine policy store that reads
// policy settings from the HKEY_LOCAL_MACHINE registry hive.
machineStore, err := source.NewMachinePlatformPolicyStore()
if err != nil {
return fmt.Errorf("failed to create the machine policy store: %v", err)
}
if tb == nil {
_, err = rsop.RegisterStore("Platform", setting.DeviceScope, machineStore)
} else {
_, err = rsop.RegisterStoreForTest(tb, "Platform", setting.DeviceScope, machineStore)
}
if err != nil {
return err
}
// Check whether the current process is running as Local System or not.
u, err := user.Current()
if err != nil {
return err
}
if u.Uid == localSystemSID {
return nil
}
// If it's not a Local System's process (e.g., the GUI and not the tailscaled service),
// we should create and use a policy store for the current user that reads
// policy settings from that user's registry hive (HKEY_CURRENT_USER).
userStore, err := source.NewUserPlatformPolicyStore(0)
if err != nil {
return fmt.Errorf("failed to create the current user's policy store: %v", err)
}
if tb == nil {
_, err = rsop.RegisterStore("Platform", setting.CurrentUserScope, userStore)
} else {
_, err = rsop.RegisterStoreForTest(tb, "Platform", setting.CurrentUserScope, userStore)
}
if err != nil {
return err
}
// And also set [CurrentUserScope] as the [CurrentScope], so [GetString],
// [GetVisibility] and similar functions would be returning a merged result
// of the machine's and user's policies.
if !setting.SetCurrentScope(setting.CurrentUserScope) {
return errors.New("current scope already set")
}
return nil
}

View File

@ -189,6 +189,7 @@ func (l *PolicyLock) lockSlow() (err error) {
select {
case resultCh <- policyLockResult{handle, err}:
// lockSlow has received the result.
break send_result
default:
select {
case <-closing: