mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-07 16:17:41 +00:00
uti/syspolicy: user policy support, auto-refresh and initial preparation for policy structs
This updates the syspolicy package to support multiple policy sources in the three policy scopes: user, profile, and device, and provides a merged resultant policy. A policy source is a syspolicy/source.Store that has a name and provides access to policy settings for a given scope. It can be registered with syspolicy/rsop.RegisterStore. Policy sources and policy stores can be either platform-specific or platform-agnostic. On Windows, we have the Registry-based, platform-specific policy store implemented as syspolicy/source.PlatformPolicyStore. This store provides access to the Group Policy and MDM policy settings stored in the Registry. On other platforms, we currently provide a wrapper that converts a syspolicy.Handler into a syspolicy/source.Store. However, we should update them in follow-up PRs. An example of a platform-agnostic policy store would be a policy deployed from the control, a local policy config file, or even environment variables. We maintain the current, most recent version of the resultant policy for each scope in an rsop.Policy. This is done by reading and merging the policy settings from the registered stores the first time the resultant policy is requested, then re-reading and re-merging them if a store implements the source.Changeable interface and reports a policy change. Policy change notifications are debounced to avoid re-reading policy settings multiple times if there are several changes within a short period. The rsop.Policy can notify clients if the resultant policy has changed. However, we do not currently expose this via the syspolicy package and plan to do so differently along with a struct-based policy hierarchy in the next PR. To facilitate this, all policy settings should be registered with the setting.Register function. The syspolicy package does this automatically for all policy settings defined in policy_keys.go. The new functionality is available through the existing syspolicy.Read* set of functions. However, we plan to expose it via a struct-based policy hierarchy, along with policy change notifications that other subsystems can use, in the next PR. We also plan to send the resultant policy back from tailscaled to the clients via the LocalAPI. This is primarily a foundational PR to facilitate future changes, but the immediate observable changes on Windows include: - The service will use the current policy setting values instead of those read at OS boot time. - The GUI has access to policy settings configured on a per-user basis. On Android: - We now report policy setting usage via clientmetrics. Updates #12687 Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
parent
655b4f8fc5
commit
cab0e1a6f7
@ -10,7 +10,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw
|
||||
W 💣 github.com/dblohm7/wingoes from tailscale.com/util/winutil
|
||||
github.com/fxamacker/cbor/v2 from tailscale.com/tka
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt+
|
||||
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+
|
||||
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+
|
||||
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+
|
||||
@ -146,9 +146,11 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
tailscale.com/util/cloudenv from tailscale.com/hostinfo+
|
||||
W tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy
|
||||
tailscale.com/util/ctxkey from tailscale.com/tsweb+
|
||||
💣 tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting
|
||||
L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics
|
||||
tailscale.com/util/dnsname from tailscale.com/hostinfo+
|
||||
tailscale.com/util/fastuuid from tailscale.com/tsweb
|
||||
💣 tailscale.com/util/hashx from tailscale.com/util/deephash
|
||||
tailscale.com/util/httpm from tailscale.com/client/tailscale
|
||||
tailscale.com/util/lineread from tailscale.com/hostinfo+
|
||||
L tailscale.com/util/linuxfw from tailscale.com/net/netns
|
||||
@ -159,8 +161,17 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
tailscale.com/util/singleflight from tailscale.com/net/dnscache
|
||||
tailscale.com/util/slicesx from tailscale.com/cmd/derper+
|
||||
tailscale.com/util/syspolicy from tailscale.com/ipn
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/testenv from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
W 💣 tailscale.com/util/winutil from tailscale.com/hostinfo+
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/version from tailscale.com/derp+
|
||||
tailscale.com/version/distro from tailscale.com/envknob+
|
||||
@ -180,6 +191,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box
|
||||
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
|
||||
W golang.org/x/exp/constraints from tailscale.com/util/winutil
|
||||
golang.org/x/exp/maps from tailscale.com/util/syspolicy/internal/metrics+
|
||||
L golang.org/x/net/bpf from github.com/mdlayher/netlink+
|
||||
golang.org/x/net/dns/dnsmessage from net+
|
||||
golang.org/x/net/http/httpguts from net/http
|
||||
@ -240,7 +252,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
encoding/pem from crypto/tls+
|
||||
errors from bufio+
|
||||
expvar from github.com/prometheus/client_golang/prometheus+
|
||||
flag from tailscale.com/cmd/derper
|
||||
flag from tailscale.com/cmd/derper+
|
||||
fmt from compress/flate+
|
||||
go/token from google.golang.org/protobuf/internal/strs
|
||||
hash from crypto+
|
||||
@ -273,7 +285,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
|
||||
os from crypto/rand+
|
||||
os/exec from github.com/coreos/go-iptables/iptables+
|
||||
os/signal from tailscale.com/cmd/derper
|
||||
W os/user from tailscale.com/util/winutil
|
||||
W os/user from tailscale.com/util/winutil+
|
||||
path from github.com/prometheus/client_golang/prometheus/internal+
|
||||
path/filepath from crypto/x509+
|
||||
reflect from crypto/x509+
|
||||
|
@ -96,7 +96,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
|
||||
💣 github.com/fsnotify/fsnotify from sigs.k8s.io/controller-runtime/pkg/certwatcher
|
||||
github.com/fxamacker/cbor/v2 from tailscale.com/tka
|
||||
github.com/gaissmai/bart from tailscale.com/net/ipset+
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt+
|
||||
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+
|
||||
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+
|
||||
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json/jsontext+
|
||||
@ -803,6 +803,13 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
|
||||
tailscale.com/util/singleflight from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/slicesx from tailscale.com/appc+
|
||||
tailscale.com/util/syspolicy from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock
|
||||
tailscale.com/util/systemd from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/testenv from tailscale.com/control/controlclient+
|
||||
@ -811,7 +818,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+
|
||||
W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/util/zstdframe from tailscale.com/control/controlclient+
|
||||
|
@ -9,7 +9,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
W 💣 github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/pe+
|
||||
W 💣 github.com/dblohm7/wingoes/pe from tailscale.com/util/winutil/authenticode
|
||||
github.com/fxamacker/cbor/v2 from tailscale.com/tka
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt+
|
||||
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+
|
||||
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+
|
||||
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+
|
||||
@ -152,9 +152,11 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
tailscale.com/util/cloudenv from tailscale.com/net/dnscache+
|
||||
tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy+
|
||||
tailscale.com/util/ctxkey from tailscale.com/types/logger
|
||||
💣 tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting
|
||||
L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics
|
||||
tailscale.com/util/dnsname from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/util/groupmember from tailscale.com/client/web
|
||||
💣 tailscale.com/util/hashx from tailscale.com/util/deephash
|
||||
tailscale.com/util/httpm from tailscale.com/client/tailscale+
|
||||
tailscale.com/util/lineread from tailscale.com/hostinfo+
|
||||
L tailscale.com/util/linuxfw from tailscale.com/net/netns
|
||||
@ -167,11 +169,19 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
tailscale.com/util/singleflight from tailscale.com/net/dnscache+
|
||||
tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+
|
||||
tailscale.com/util/syspolicy from tailscale.com/ipn
|
||||
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli+
|
||||
tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/version from tailscale.com/client/web+
|
||||
tailscale.com/version/distro from tailscale.com/client/web+
|
||||
@ -191,7 +201,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
||||
golang.org/x/crypto/pbkdf2 from software.sslmate.com/src/go-pkcs12
|
||||
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
|
||||
W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+
|
||||
golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli
|
||||
golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli+
|
||||
golang.org/x/net/bpf from github.com/mdlayher/netlink+
|
||||
golang.org/x/net/dns/dnsmessage from net+
|
||||
golang.org/x/net/http/httpguts from net/http+
|
||||
|
@ -90,7 +90,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
💣 github.com/djherbis/times from tailscale.com/drive/driveimpl
|
||||
github.com/fxamacker/cbor/v2 from tailscale.com/tka
|
||||
github.com/gaissmai/bart from tailscale.com/net/tstun+
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt
|
||||
github.com/go-json-experiment/json from tailscale.com/types/opt+
|
||||
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+
|
||||
github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+
|
||||
github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json/jsontext+
|
||||
@ -395,6 +395,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/util/singleflight from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+
|
||||
tailscale.com/util/syspolicy from tailscale.com/cmd/tailscaled+
|
||||
tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/lazyinit from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+
|
||||
tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source
|
||||
tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy
|
||||
tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+
|
||||
tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock
|
||||
tailscale.com/util/systemd from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/testenv from tailscale.com/ipn/ipnlocal+
|
||||
@ -403,7 +410,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/util/vizerror from tailscale.com/tailcfg+
|
||||
💣 tailscale.com/util/winutil from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns
|
||||
W 💣 tailscale.com/util/winutil/gp from tailscale.com/net/dns+
|
||||
W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal
|
||||
W 💣 tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+
|
||||
tailscale.com/util/zstdframe from tailscale.com/control/controlclient+
|
||||
|
@ -52,6 +52,8 @@
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/wgengine"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/wgcfg"
|
||||
@ -2546,6 +2548,14 @@ func TestPreferencePolicyInfo(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
definitions := make([]*setting.Definition, 0, len(preferencePolicies)+1)
|
||||
definitions = append(definitions, must.Get(syspolicy.WellKnownSettingDefinition(syspolicy.ControlURL)))
|
||||
for _, pp := range preferencePolicies {
|
||||
definitions = append(definitions, must.Get(syspolicy.WellKnownSettingDefinition(pp.key)))
|
||||
}
|
||||
if err := setting.SetDefinitionsForTest(t, definitions...); err != nil {
|
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err)
|
||||
}
|
||||
for _, pp := range preferencePolicies {
|
||||
t.Run(string(pp.key), func(t *testing.T) {
|
||||
var h syspolicy.Handler
|
||||
@ -2572,7 +2582,7 @@ func TestPreferencePolicyInfo(t *testing.T) {
|
||||
msh.stringPolicies[pp.key] = &tt.policyValue
|
||||
h = msh
|
||||
}
|
||||
syspolicy.SetHandlerForTest(t, h)
|
||||
rsop.RegisterStoreForTest(t, tt.name, setting.DeviceScope, syspolicy.WrapHandler(h))
|
||||
|
||||
prefs := defaultPrefs.AsStruct()
|
||||
pp.set(prefs, tt.initialValue)
|
||||
|
@ -1,122 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested
|
||||
// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached,
|
||||
// otherwise the actual error is returned and the next read for that key will retry using the handler.
|
||||
type CachingHandler struct {
|
||||
mu sync.Mutex
|
||||
strings map[string]string
|
||||
uint64s map[string]uint64
|
||||
bools map[string]bool
|
||||
strArrs map[string][]string
|
||||
notFound map[string]bool
|
||||
handler Handler
|
||||
}
|
||||
|
||||
// NewCachingHandler creates a CachingHandler given a handler.
|
||||
func NewCachingHandler(handler Handler) *CachingHandler {
|
||||
return &CachingHandler{
|
||||
handler: handler,
|
||||
strings: make(map[string]string),
|
||||
uint64s: make(map[string]uint64),
|
||||
bools: make(map[string]bool),
|
||||
strArrs: make(map[string][]string),
|
||||
notFound: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ReadString reads the policy settings value string given the key.
|
||||
// ReadString first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadString(key string) (string, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.strings[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return "", ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadString(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return "", err
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ch.strings[key] = val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ReadUInt64 reads the policy settings uint64 value given the key.
|
||||
// ReadUInt64 first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.uint64s[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return 0, ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadUInt64(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return 0, err
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
ch.uint64s[key] = val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ReadBoolean reads the policy settings boolean value given the key.
|
||||
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadBoolean(key string) (bool, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.bools[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return false, ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadBoolean(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return false, err
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
ch.bools[key] = val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ReadBoolean reads the policy settings boolean value given the key.
|
||||
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadStringArray(key string) ([]string, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.strArrs[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return nil, ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadStringArray(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return nil, err
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch.strArrs[key] = val
|
||||
return val, nil
|
||||
}
|
@ -1,262 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandlerReadString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue string
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue string
|
||||
wantErr error
|
||||
strings map[string]string
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
strings: map[string]string{"test": "foo"},
|
||||
wantValue: "foo",
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: "foo",
|
||||
wantValue: "foo",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
s: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.strings != nil {
|
||||
cache.strings = tt.strings
|
||||
}
|
||||
got, err := cache.ReadString(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadString(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerReadUint64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue uint64
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue uint64
|
||||
wantErr error
|
||||
uint64s map[string]uint64
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
uint64s: map[string]uint64{"test": 1},
|
||||
wantValue: 1,
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: 1,
|
||||
wantValue: 1,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
u64: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.uint64s != nil {
|
||||
cache.uint64s = tt.uint64s
|
||||
}
|
||||
got, err := cache.ReadUInt64(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadUInt64(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestHandlerReadBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue bool
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue bool
|
||||
wantErr error
|
||||
bools map[string]bool
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
bools: map[string]bool{"test": true},
|
||||
wantValue: true,
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: true,
|
||||
wantValue: true,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
b: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.bools != nil {
|
||||
cache.bools = tt.bools
|
||||
}
|
||||
got, err := cache.ReadBoolean(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadBoolean(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
@ -4,16 +4,15 @@
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var (
|
||||
handlerUsed atomic.Bool
|
||||
handler Handler = defaultHandler{}
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
)
|
||||
|
||||
// Handler reads system policies from OS-specific storage.
|
||||
//
|
||||
// Deprecated: implementing a [Store] should be preferred.
|
||||
type Handler interface {
|
||||
// ReadString reads the policy setting's string value for the given key.
|
||||
// It should return ErrNoSuchKey if the key does not have a value set.
|
||||
@ -29,55 +28,81 @@ type Handler interface {
|
||||
ReadStringArray(key string) ([]string, error)
|
||||
}
|
||||
|
||||
// ErrNoSuchKey is returned by a Handler when the specified key does not have a
|
||||
// value set.
|
||||
var ErrNoSuchKey = errors.New("no such key")
|
||||
|
||||
// defaultHandler is the catch all syspolicy type for anything that isn't windows or apple.
|
||||
type defaultHandler struct{}
|
||||
|
||||
func (defaultHandler) ReadString(_ string) (string, error) {
|
||||
return "", ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (defaultHandler) ReadUInt64(_ string) (uint64, error) {
|
||||
return 0, ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (defaultHandler) ReadBoolean(_ string) (bool, error) {
|
||||
return false, ErrNoSuchKey
|
||||
}
|
||||
|
||||
func (defaultHandler) ReadStringArray(_ string) ([]string, error) {
|
||||
return nil, ErrNoSuchKey
|
||||
}
|
||||
|
||||
// markHandlerInUse is called before handler methods are called.
|
||||
func markHandlerInUse() {
|
||||
handlerUsed.Store(true)
|
||||
}
|
||||
|
||||
// RegisterHandler initializes the policy handler and ensures registration will happen once.
|
||||
// RegisterHandler wraps and registers the specified handler as the device's
|
||||
// policy [Store] for the program's lifetime.
|
||||
//
|
||||
// Deprecated: using [RegisterStore] should be preferred.
|
||||
func RegisterHandler(h Handler) {
|
||||
// Technically this assignment is not concurrency safe, but in the
|
||||
// event that there was any risk of a data race, we will panic due to
|
||||
// the CompareAndSwap failing.
|
||||
handler = h
|
||||
if !handlerUsed.CompareAndSwap(false, true) {
|
||||
panic("handler was already used before registration")
|
||||
}
|
||||
rsop.RegisterStore("DeviceHandler", setting.DeviceScope, WrapHandler(h))
|
||||
}
|
||||
|
||||
// TB is a subset of testing.TB that we use to set up test helpers.
|
||||
// It's defined here to avoid pulling in the testing package.
|
||||
type TB interface {
|
||||
Helper()
|
||||
Cleanup(func())
|
||||
type TB = internal.TB
|
||||
|
||||
// SetHandlerForTest wraps and sets the specified handler as the device's policy
|
||||
// [Store] for the duration of tb.
|
||||
//
|
||||
// Deprecated: using [resultant.RegisterStoreForTest] should be preferred.
|
||||
func SetHandlerForTest(tb TB, h Handler) {
|
||||
if err := setWellKnownSettingsForTest(tb); err != nil {
|
||||
tb.Fatalf("setWellKnownSettingsForTest failed: %v", err)
|
||||
}
|
||||
rsop.RegisterStoreForTest(tb, "DeviceHandler-TestOnly", setting.CurrentScope(), WrapHandler(h))
|
||||
}
|
||||
|
||||
func SetHandlerForTest(tb TB, h Handler) {
|
||||
tb.Helper()
|
||||
oldHandler := handler
|
||||
handler = h
|
||||
tb.Cleanup(func() { handler = oldHandler })
|
||||
var _ source.Store = (*handlerStore)(nil)
|
||||
|
||||
// handlerStore is a [source.Store] that calls the underlying [Handler].
|
||||
// TODO(nickkhyl): remove it when the corp and android repos are updated.
|
||||
type handlerStore struct {
|
||||
h Handler
|
||||
}
|
||||
|
||||
// WrapHandler returns a [source.Store] that wraps the specified [Handler].
|
||||
func WrapHandler(h Handler) source.Store {
|
||||
return handlerStore{h}
|
||||
}
|
||||
|
||||
func (s handlerStore) Lock() error {
|
||||
if lockable, ok := s.h.(source.Lockable); ok {
|
||||
return lockable.Lock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s handlerStore) Unlock() {
|
||||
if lockable, ok := s.h.(source.Lockable); ok {
|
||||
lockable.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s handlerStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
|
||||
if lockable, ok := s.h.(source.Changeable); ok {
|
||||
return lockable.RegisterChangeCallback(callback)
|
||||
}
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
func (s handlerStore) ReadString(key setting.Key) (string, error) {
|
||||
return s.h.ReadString(string(key))
|
||||
}
|
||||
|
||||
func (s handlerStore) ReadUInt64(key setting.Key) (uint64, error) {
|
||||
return s.h.ReadUInt64(string(key))
|
||||
}
|
||||
|
||||
func (s handlerStore) ReadBoolean(key setting.Key) (bool, error) {
|
||||
return s.h.ReadBoolean(string(key))
|
||||
}
|
||||
|
||||
func (s handlerStore) ReadStringArray(key setting.Key) ([]string, error) {
|
||||
return s.h.ReadStringArray(string(key))
|
||||
}
|
||||
|
||||
func (s handlerStore) Done() <-chan struct{} {
|
||||
if expirable, ok := s.h.(source.Expirable); ok {
|
||||
return expirable.Done()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -1,19 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultHandlerReadValues(t *testing.T) {
|
||||
var h defaultHandler
|
||||
|
||||
got, err := h.ReadString(string(AdminConsoleVisibility))
|
||||
if got != "" || err != ErrNoSuchKey {
|
||||
t.Fatalf("got %v err %v", got, err)
|
||||
}
|
||||
result, err := h.ReadUInt64(string(LogSCMInteractions))
|
||||
if result != 0 || err != ErrNoSuchKey {
|
||||
t.Fatalf("got %v err %v", result, err)
|
||||
}
|
||||
}
|
@ -1,105 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/winutil"
|
||||
)
|
||||
|
||||
var (
|
||||
windowsErrors = clientmetric.NewCounter("windows_syspolicy_errors")
|
||||
windowsAny = clientmetric.NewGauge("windows_syspolicy_any")
|
||||
)
|
||||
|
||||
type windowsHandler struct{}
|
||||
|
||||
func init() {
|
||||
RegisterHandler(NewCachingHandler(windowsHandler{}))
|
||||
|
||||
keyList := []struct {
|
||||
isSet func(Key) bool
|
||||
keys []Key
|
||||
}{
|
||||
{
|
||||
isSet: func(k Key) bool {
|
||||
_, err := handler.ReadString(string(k))
|
||||
return err == nil
|
||||
},
|
||||
keys: stringKeys,
|
||||
},
|
||||
{
|
||||
isSet: func(k Key) bool {
|
||||
_, err := handler.ReadBoolean(string(k))
|
||||
return err == nil
|
||||
},
|
||||
keys: boolKeys,
|
||||
},
|
||||
{
|
||||
isSet: func(k Key) bool {
|
||||
_, err := handler.ReadUInt64(string(k))
|
||||
return err == nil
|
||||
},
|
||||
keys: uint64Keys,
|
||||
},
|
||||
}
|
||||
|
||||
var anySet bool
|
||||
for _, l := range keyList {
|
||||
for _, k := range l.keys {
|
||||
if !l.isSet(k) {
|
||||
continue
|
||||
}
|
||||
clientmetric.NewGauge(fmt.Sprintf("windows_syspolicy_%s", k)).Set(1)
|
||||
anySet = true
|
||||
}
|
||||
}
|
||||
if anySet {
|
||||
windowsAny.Set(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadString(key string) (string, error) {
|
||||
s, err := winutil.GetPolicyString(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
|
||||
return s, err
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadUInt64(key string) (uint64, error) {
|
||||
value, err := winutil.GetPolicyInteger(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadBoolean(key string) (bool, error) {
|
||||
value, err := winutil.GetPolicyInteger(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
return value != 0, err
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadStringArray(key string) ([]string, error) {
|
||||
value, err := winutil.GetPolicyStringArray(key)
|
||||
if errors.Is(err, winutil.ErrNoValue) {
|
||||
err = ErrNoSuchKey
|
||||
} else if err != nil {
|
||||
windowsErrors.Add(1)
|
||||
}
|
||||
return value, err
|
||||
}
|
63
util/syspolicy/internal/internal.go
Normal file
63
util/syspolicy/internal/internal.go
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package internal contains miscellaneous functions and types
|
||||
// that are internal to the syspolicy packages.
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/go-json-experiment/json/jsontext"
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/version"
|
||||
)
|
||||
|
||||
// OSForTesting is the operating system override used for testing.
|
||||
// It follows the same naming convention as [version.OS].
|
||||
var OSForTesting lazy.SyncValue[string]
|
||||
|
||||
// OS is like [version.OS], but supports a test hook.
|
||||
func OS() string {
|
||||
return OSForTesting.Get(version.OS)
|
||||
}
|
||||
|
||||
// TB is a subset of testing.TB that we use to set up test helpers.
|
||||
// It's defined here to avoid pulling in the testing package.
|
||||
type TB interface {
|
||||
Helper()
|
||||
Cleanup(func())
|
||||
Logf(format string, args ...any)
|
||||
Error(args ...any)
|
||||
Errorf(format string, args ...any)
|
||||
Fatal(args ...any)
|
||||
Fatalf(format string, args ...any)
|
||||
}
|
||||
|
||||
// EqualJSONForTest compares the JSON in j1 and j2 for semantic equality.
|
||||
// It returns "", "", true if j1 and j2 are equal. Otherwise, it returns
|
||||
// indented versions of j1 and j2 and false.
|
||||
func EqualJSONForTest(tb TB, j1, j2 jsontext.Value) (s1, s2 string, equal bool) {
|
||||
tb.Helper()
|
||||
j1 = j1.Clone()
|
||||
j2 = j2.Clone()
|
||||
// Canonicalize JSON values for comparison.
|
||||
if err := j1.Canonicalize(); err != nil {
|
||||
tb.Error(err)
|
||||
}
|
||||
if err := j2.Canonicalize(); err != nil {
|
||||
tb.Error(err)
|
||||
}
|
||||
// Check and return true if the two values are structurally equal.
|
||||
if bytes.Equal(j1, j2) {
|
||||
return "", "", true
|
||||
}
|
||||
// Otherwise, format the values for display and return false.
|
||||
if err := j1.Indent("", "\t"); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if err := j2.Indent("", "\t"); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
return j1.String(), j2.String(), false
|
||||
}
|
84
util/syspolicy/internal/lazyinit/lazyinit.go
Normal file
84
util/syspolicy/internal/lazyinit/lazyinit.go
Normal file
@ -0,0 +1,84 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// The lazyinit package facilitates deferred package initialization.
|
||||
package lazyinit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var packageInit deferredOnce
|
||||
|
||||
// Defer defers the specified action until [Do] is called.
|
||||
// It returns a boolean indicating whether [Do] has already been called.
|
||||
func Defer(action func() error) bool {
|
||||
return packageInit.Defer(action)
|
||||
}
|
||||
|
||||
// DeferWithCleanup is like [Defer], but the action function returns a cleanup
|
||||
// function to be called in case of an error.
|
||||
func DeferWithCleanup(action func() (cleanup func(), err error)) bool {
|
||||
return packageInit.DeferWithCleanup(action)
|
||||
}
|
||||
|
||||
// Do runs all deferred functions and returns an error if any of them fail.
|
||||
func Do() error {
|
||||
return packageInit.Do()
|
||||
}
|
||||
|
||||
type deferredOnce struct {
|
||||
done atomic.Uint32
|
||||
err error
|
||||
m sync.Mutex
|
||||
funcs []func() (cleanup func(), err error)
|
||||
}
|
||||
|
||||
func (o *deferredOnce) Defer(action func() error) bool {
|
||||
return o.DeferWithCleanup(func() (cleanup func(), err error) {
|
||||
return nil, action()
|
||||
})
|
||||
}
|
||||
|
||||
func (o *deferredOnce) DeferWithCleanup(action func() (cleanup func(), err error)) bool {
|
||||
o.m.Lock()
|
||||
defer o.m.Unlock()
|
||||
if o.done.Load() != 0 {
|
||||
return false
|
||||
}
|
||||
o.funcs = append(o.funcs, action)
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *deferredOnce) Do() error {
|
||||
if o.done.Load() == 0 {
|
||||
o.doSlow()
|
||||
}
|
||||
return o.err
|
||||
}
|
||||
|
||||
func (o *deferredOnce) doSlow() (err error) {
|
||||
o.m.Lock()
|
||||
defer o.m.Unlock()
|
||||
if o.done.Load() == 0 {
|
||||
defer func() {
|
||||
o.done.Store(1)
|
||||
o.err = err
|
||||
}()
|
||||
for _, f := range o.funcs {
|
||||
cleanup, err := f()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cleanup != nil {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cleanup()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
return o.err
|
||||
}
|
46
util/syspolicy/internal/loggerx/logger.go
Normal file
46
util/syspolicy/internal/loggerx/logger.go
Normal file
@ -0,0 +1,46 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package loggerx provides logging functions to the rest of the syspolicy packages.
|
||||
package loggerx
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
)
|
||||
|
||||
const (
|
||||
errorPrefix = "syspolicy: "
|
||||
verbosePrefix = "syspolicy: [v2] "
|
||||
)
|
||||
|
||||
var (
|
||||
lazyErrorf lazy.SyncValue[logger.Logf]
|
||||
lazyVerbosef lazy.SyncValue[logger.Logf]
|
||||
)
|
||||
|
||||
// Errorf formats and writes an error message to the log.
|
||||
func Errorf(format string, args ...any) {
|
||||
errorf := lazyErrorf.Get(func() logger.Logf {
|
||||
return logger.WithPrefix(log.Printf, errorPrefix)
|
||||
})
|
||||
errorf(format, args...)
|
||||
}
|
||||
|
||||
// Verbosef formats and writes an optional, verbose message to the log.
|
||||
func Verbosef(format string, args ...any) {
|
||||
verbosef := lazyVerbosef.Get(func() logger.Logf {
|
||||
return logger.WithPrefix(log.Printf, verbosePrefix)
|
||||
})
|
||||
verbosef(format, args...)
|
||||
}
|
||||
|
||||
// SetForTest sets the specified errorf and verbosef functions for the duration
|
||||
// of tb and its subtests.
|
||||
func SetForTest(tb internal.TB, errorf, verbosef logger.Logf) {
|
||||
lazyErrorf.SetForTest(tb, errorf, nil)
|
||||
lazyVerbosef.SetForTest(tb, verbosef, nil)
|
||||
}
|
315
util/syspolicy/internal/metrics/metrics.go
Normal file
315
util/syspolicy/internal/metrics/metrics.go
Normal file
@ -0,0 +1,315 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package metrics provides logging and reporting for policy settings and scopes.
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/testenv"
|
||||
)
|
||||
|
||||
var lazyReportMetrics lazy.SyncValue[bool] // used as a test hook
|
||||
|
||||
// ShouldReport reports whether metrics should be reported on the current environment.
|
||||
func ShouldReport() bool {
|
||||
return lazyReportMetrics.Get(func() bool {
|
||||
// macOS, iOS and tvOS create their own metrics,
|
||||
// and we don't have syspolicy on any other platforms.
|
||||
return setting.PlatformList{"android", "windows"}.HasCurrent()
|
||||
})
|
||||
}
|
||||
|
||||
// Reset metrics for the specified policy origin.
|
||||
func Reset(origin *setting.Origin) {
|
||||
scopeMetrics(origin).Reset()
|
||||
}
|
||||
|
||||
// ReportConfigured updates metrics and logs that the specified setting is
|
||||
// configured with the given value in the origin.
|
||||
func ReportConfigured(origin *setting.Origin, setting *setting.Definition, value any) {
|
||||
settingMetricsFor(setting).ReportValue(origin, value)
|
||||
}
|
||||
|
||||
// ReportError updates metrics and logs that the specified setting has an error
|
||||
// in the origin.
|
||||
func ReportError(origin *setting.Origin, setting *setting.Definition, err error) {
|
||||
settingMetricsFor(setting).ReportError(origin, err)
|
||||
}
|
||||
|
||||
// ReportNotConfigured updates metrics and logs that the specified setting is
|
||||
// not configured in the origin.
|
||||
func ReportNotConfigured(origin *setting.Origin, setting *setting.Definition) {
|
||||
settingMetricsFor(setting).Reset(origin)
|
||||
}
|
||||
|
||||
// metric is an interface implemented by [clientmetric.Metric] and [funcMetric].
|
||||
type metric interface {
|
||||
Add(v int64)
|
||||
Set(v int64)
|
||||
}
|
||||
|
||||
// policyScopeMetrics are metrics that apply to an entire policy scope rather
|
||||
// than a specific policy setting.
|
||||
type policyScopeMetrics struct {
|
||||
hasAny metric
|
||||
numErrored metric
|
||||
}
|
||||
|
||||
func newScopeMetrics(scope setting.Scope) *policyScopeMetrics {
|
||||
prefix := metricScopeName(scope)
|
||||
if prefix != "" {
|
||||
prefix += "_"
|
||||
}
|
||||
// {os}_syspolicy_{scope_unless_device}_any
|
||||
// Example: windows_syspolicy_any or windows_syspolicy_user_any.
|
||||
hasAny := newMetric(prefix+"any", clientmetric.TypeGauge)
|
||||
// {os}_syspolicy_{scope_unless_device}_errors
|
||||
// Example: windows_syspolicy_errors or windows_syspolicy_user_errors.
|
||||
//
|
||||
// TODO(nickkhyl): maybe make the `{os}_syspolicy_errors` metric a gauge rather than a counter?
|
||||
// It was a counter prior to https://github.com/tailscale/tailscale/issues/12687, so I kept it as such.
|
||||
// But I think a gauge makes more sense: syspolicy errors indicate a mismatch between the expected
|
||||
// policy value type or format and the actual value read from the underlying store (like the Windows Registry).
|
||||
// We'll encounter the same error every time we re-read the policy setting from the backing store
|
||||
// until the policy value is corrected by the user, or until we fix the bug in the code or ADMX.
|
||||
// There's probably no reason to count and accumulate them over time.
|
||||
numErrored := newMetric(prefix+"errors", clientmetric.TypeCounter)
|
||||
return &policyScopeMetrics{hasAny, numErrored}
|
||||
}
|
||||
|
||||
// ReportHasSettings is called when there's any configured policy setting in the scope.
|
||||
func (m *policyScopeMetrics) ReportHasSettings() {
|
||||
if m != nil {
|
||||
m.hasAny.Set(1)
|
||||
}
|
||||
}
|
||||
|
||||
// ReportError is called when there's any errored policy setting in the scope.
|
||||
func (m *policyScopeMetrics) ReportError() {
|
||||
if m != nil {
|
||||
m.numErrored.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Reset is called to reset the policy scope metrics, such as when the policy scope
|
||||
// is about to be reloaded.
|
||||
func (m *policyScopeMetrics) Reset() {
|
||||
if m != nil {
|
||||
m.hasAny.Set(0)
|
||||
// numErrored is a counter and cannot be (re-)set.
|
||||
}
|
||||
}
|
||||
|
||||
// settingMetrics are metrics for a single policy setting in one or more scopes.
|
||||
type settingMetrics struct {
|
||||
definition *setting.Definition
|
||||
isSet []metric // by scope
|
||||
hasErrors []metric // by scope
|
||||
}
|
||||
|
||||
// ReportValue is called when the policy setting is found to be configured in the specified source.
|
||||
func (m *settingMetrics) ReportValue(origin *setting.Origin, v any) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if scope := origin.Scope().Kind(); int(scope) < len(m.isSet) {
|
||||
m.isSet[scope].Set(1)
|
||||
m.hasErrors[scope].Set(0)
|
||||
}
|
||||
scopeMetrics(origin).ReportHasSettings()
|
||||
loggerx.Verbosef("%v(%q) = %v\n", origin, m.definition.Key(), v)
|
||||
}
|
||||
|
||||
// ReportError is called when there's an error with the policy setting in the specified source.
|
||||
func (m *settingMetrics) ReportError(origin *setting.Origin, err error) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if scope := origin.Scope().Kind(); int(scope) < len(m.hasErrors) {
|
||||
m.isSet[scope].Set(0)
|
||||
m.hasErrors[scope].Set(1)
|
||||
}
|
||||
scopeMetrics(origin).ReportError()
|
||||
loggerx.Errorf("%v(%q): %v\n", origin, m.definition.Key(), err)
|
||||
}
|
||||
|
||||
// Reset is called to reset the policy setting's metrics, such as when
|
||||
// the policy setting does not exist or the source containing the policy
|
||||
// is about to be reloaded.
|
||||
func (m *settingMetrics) Reset(origin *setting.Origin) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if scope := origin.Scope().Kind(); int(scope) < len(m.isSet) {
|
||||
m.isSet[scope].Set(0)
|
||||
m.hasErrors[scope].Set(0)
|
||||
}
|
||||
}
|
||||
|
||||
// metricFn is a function that adds or sets a metric value.
|
||||
type metricFn = func(name string, typ clientmetric.Type, v int64)
|
||||
|
||||
// funcMetric implements [metric] by calling the specified add and set functions.
|
||||
// Used for testing, and with nil functions on platforms that do not support
|
||||
// syspolicy, and on platforms that report policy metrics from the GUI.
|
||||
type funcMetric struct {
|
||||
name string
|
||||
typ clientmetric.Type
|
||||
add, set metricFn
|
||||
}
|
||||
|
||||
func (m funcMetric) Add(v int64) {
|
||||
if m.add != nil {
|
||||
m.add(m.name, m.typ, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (m funcMetric) Set(v int64) {
|
||||
if m.set != nil {
|
||||
m.set(m.name, m.typ, v)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
lazyDeviceMetrics lazy.SyncValue[*policyScopeMetrics]
|
||||
lazyProfileMetrics lazy.SyncValue[*policyScopeMetrics]
|
||||
lazyUserMetrics lazy.SyncValue[*policyScopeMetrics]
|
||||
)
|
||||
|
||||
func scopeMetrics(origin *setting.Origin) *policyScopeMetrics {
|
||||
switch origin.Scope().Kind() {
|
||||
case setting.DeviceSetting:
|
||||
return lazyDeviceMetrics.Get(func() *policyScopeMetrics {
|
||||
return newScopeMetrics(setting.DeviceSetting)
|
||||
})
|
||||
case setting.ProfileSetting:
|
||||
return lazyProfileMetrics.Get(func() *policyScopeMetrics {
|
||||
return newScopeMetrics(setting.ProfileSetting)
|
||||
})
|
||||
case setting.UserSetting:
|
||||
return lazyUserMetrics.Get(func() *policyScopeMetrics {
|
||||
return newScopeMetrics(setting.UserSetting)
|
||||
})
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
settingMetricsMu sync.RWMutex
|
||||
settingMetricsMap map[setting.Key]*settingMetrics
|
||||
)
|
||||
|
||||
func settingMetricsFor(setting *setting.Definition) *settingMetrics {
|
||||
settingMetricsMu.RLock()
|
||||
if metrics, ok := settingMetricsMap[setting.Key()]; ok {
|
||||
settingMetricsMu.RUnlock()
|
||||
return metrics
|
||||
}
|
||||
settingMetricsMu.RUnlock()
|
||||
return settingMetricsForSlow(setting)
|
||||
}
|
||||
|
||||
func settingMetricsForSlow(d *setting.Definition) *settingMetrics {
|
||||
settingMetricsMu.Lock()
|
||||
defer settingMetricsMu.Unlock()
|
||||
if metrics, ok := settingMetricsMap[d.Key()]; ok {
|
||||
return metrics
|
||||
}
|
||||
|
||||
isSet := make([]metric, d.Scope()+1)
|
||||
hasErrors := make([]metric, d.Scope()+1)
|
||||
for i := range isSet {
|
||||
scope := setting.Scope(i)
|
||||
// {os}_syspolicy_{key}_{scope_unless_device}
|
||||
// Example: windows_syspolicy_AdminConsole or windows_syspolicy_AdminConsole_user.
|
||||
isSet[i] = newSettingMetric(d.Key(), scope, "", clientmetric.TypeGauge)
|
||||
// {os}_syspolicy_{key}_{scope_unless_device}_error
|
||||
// Example: windows_syspolicy_AdminConsole_error or windows_syspolicy_TestSetting01_user_error.
|
||||
hasErrors[i] = newSettingMetric(d.Key(), scope, "error", clientmetric.TypeGauge)
|
||||
}
|
||||
metrics := &settingMetrics{d, isSet, hasErrors}
|
||||
mak.Set(&settingMetricsMap, d.Key(), metrics)
|
||||
return metrics
|
||||
}
|
||||
|
||||
// hooks for testing
|
||||
var addMetricTestHook, setMetricTestHook syncs.AtomicValue[metricFn]
|
||||
|
||||
// SetHooksForTest sets the specified addMetric and setMetric functions
|
||||
// as the metric functions for the duration of tb and all its subtests.
|
||||
func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) {
|
||||
oldAddMetric := addMetricTestHook.Swap(addMetric)
|
||||
oldSetMetric := setMetricTestHook.Swap(setMetric)
|
||||
tb.Cleanup(func() {
|
||||
addMetricTestHook.Store(oldAddMetric)
|
||||
setMetricTestHook.Store(oldSetMetric)
|
||||
})
|
||||
|
||||
settingMetricsMu.Lock()
|
||||
oldSettingMetricsMap := xmaps.Clone(settingMetricsMap)
|
||||
clear(settingMetricsMap)
|
||||
settingMetricsMu.Unlock()
|
||||
tb.Cleanup(func() {
|
||||
settingMetricsMu.Lock()
|
||||
settingMetricsMap = oldSettingMetricsMap
|
||||
settingMetricsMu.Unlock()
|
||||
})
|
||||
|
||||
// (re-)set the scope metrics to use the test hooks for the duration of tb.
|
||||
lazyDeviceMetrics.SetForTest(tb, newScopeMetrics(setting.DeviceSetting), nil)
|
||||
lazyProfileMetrics.SetForTest(tb, newScopeMetrics(setting.ProfileSetting), nil)
|
||||
lazyUserMetrics.SetForTest(tb, newScopeMetrics(setting.UserSetting), nil)
|
||||
}
|
||||
|
||||
func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric {
|
||||
name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_")
|
||||
if tag := metricScopeName(scope); tag != "" {
|
||||
name += "_" + tag
|
||||
}
|
||||
if suffix != "" {
|
||||
name += "_" + suffix
|
||||
}
|
||||
return newMetric(name, typ)
|
||||
}
|
||||
|
||||
func newMetric(name string, typ clientmetric.Type) metric {
|
||||
name = internal.OS() + "_syspolicy_" + name
|
||||
switch {
|
||||
case !ShouldReport():
|
||||
return &funcMetric{name: name, typ: typ}
|
||||
case testenv.InTest():
|
||||
return &funcMetric{name, typ, addMetricTestHook.Load(), setMetricTestHook.Load()}
|
||||
case typ == clientmetric.TypeCounter:
|
||||
return clientmetric.NewCounter(name)
|
||||
case typ == clientmetric.TypeGauge:
|
||||
return clientmetric.NewGauge(name)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func metricScopeName(scope setting.Scope) string {
|
||||
switch scope {
|
||||
case setting.DeviceSetting:
|
||||
return ""
|
||||
case setting.ProfileSetting:
|
||||
return "profile"
|
||||
case setting.UserSetting:
|
||||
return "user"
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
423
util/syspolicy/internal/metrics/metrics_test.go
Normal file
423
util/syspolicy/internal/metrics/metrics_test.go
Normal file
@ -0,0 +1,423 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
func TestSettingMetricNames(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key setting.Key
|
||||
scope setting.Scope
|
||||
suffix string
|
||||
typ clientmetric.Type
|
||||
osOverride string
|
||||
wantMetricName string
|
||||
}{
|
||||
{
|
||||
name: "windows-device-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.DeviceSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole",
|
||||
},
|
||||
{
|
||||
name: "windows-user-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.UserSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole_user",
|
||||
},
|
||||
{
|
||||
name: "windows-profile-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.ProfileSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole_profile",
|
||||
},
|
||||
{
|
||||
name: "windows-profile-err",
|
||||
key: "AdminConsole",
|
||||
scope: setting.ProfileSetting,
|
||||
suffix: "error",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "windows",
|
||||
wantMetricName: "windows_syspolicy_AdminConsole_profile_error",
|
||||
},
|
||||
{
|
||||
name: "android-device-no-suffix",
|
||||
key: "AdminConsole",
|
||||
scope: setting.DeviceSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "android",
|
||||
wantMetricName: "android_syspolicy_AdminConsole",
|
||||
},
|
||||
{
|
||||
name: "key-path",
|
||||
key: "category/subcategory/setting",
|
||||
scope: setting.DeviceSetting,
|
||||
suffix: "",
|
||||
typ: clientmetric.TypeCounter,
|
||||
osOverride: "fakeos",
|
||||
wantMetricName: "fakeos_syspolicy_category_subcategory_setting",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
|
||||
metric, ok := newSettingMetric(tt.key, tt.scope, tt.suffix, tt.typ).(*funcMetric)
|
||||
if !ok {
|
||||
t.Fatal("metric is not a funcMetric")
|
||||
}
|
||||
if metric.name != tt.wantMetricName {
|
||||
t.Errorf("got %q, want %q", metric.name, tt.wantMetricName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopeMetrics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scope setting.Scope
|
||||
osOverride string
|
||||
wantHasAnyName string
|
||||
wantNumErroredName string
|
||||
wantHasAnyType clientmetric.Type
|
||||
wantNumErroredType clientmetric.Type
|
||||
}{
|
||||
{
|
||||
name: "windows-device",
|
||||
scope: setting.DeviceSetting,
|
||||
osOverride: "windows",
|
||||
wantHasAnyName: "windows_syspolicy_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "windows_syspolicy_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
{
|
||||
name: "windows-profile",
|
||||
scope: setting.ProfileSetting,
|
||||
osOverride: "windows",
|
||||
wantHasAnyName: "windows_syspolicy_profile_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "windows_syspolicy_profile_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
{
|
||||
name: "windows-user",
|
||||
scope: setting.UserSetting,
|
||||
osOverride: "windows",
|
||||
wantHasAnyName: "windows_syspolicy_user_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "windows_syspolicy_user_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
{
|
||||
name: "android-device",
|
||||
scope: setting.DeviceSetting,
|
||||
osOverride: "android",
|
||||
wantHasAnyName: "android_syspolicy_any",
|
||||
wantHasAnyType: clientmetric.TypeGauge,
|
||||
wantNumErroredName: "android_syspolicy_errors",
|
||||
wantNumErroredType: clientmetric.TypeCounter,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
|
||||
metrics := newScopeMetrics(tt.scope)
|
||||
hasAny, ok := metrics.hasAny.(*funcMetric)
|
||||
if !ok {
|
||||
t.Fatal("hasAny is not a funcMetric")
|
||||
}
|
||||
numErrored, ok := metrics.numErrored.(*funcMetric)
|
||||
if !ok {
|
||||
t.Fatal("numErrored is not a funcMetric")
|
||||
}
|
||||
if hasAny.name != tt.wantHasAnyName {
|
||||
t.Errorf("hasAny.Name: got %q, want %q", hasAny.name, tt.wantHasAnyName)
|
||||
}
|
||||
if hasAny.typ != tt.wantHasAnyType {
|
||||
t.Errorf("hasAny.Type: got %q, want %q", hasAny.typ, tt.wantHasAnyType)
|
||||
}
|
||||
if numErrored.name != tt.wantNumErroredName {
|
||||
t.Errorf("numErrored.Name: got %q, want %q", numErrored.name, tt.wantNumErroredName)
|
||||
}
|
||||
if numErrored.typ != tt.wantNumErroredType {
|
||||
t.Errorf("hasAny.Type: got %q, want %q", numErrored.typ, tt.wantNumErroredType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testSettingDetails struct {
|
||||
definition *setting.Definition
|
||||
origin *setting.Origin
|
||||
value any
|
||||
err error
|
||||
}
|
||||
|
||||
func TestReportMetrics(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
osOverride string
|
||||
useMetrics bool
|
||||
settings []testSettingDetails
|
||||
wantMetrics []TestState
|
||||
wantResetMetrics []TestState
|
||||
}{
|
||||
{
|
||||
name: "none",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{},
|
||||
wantMetrics: []TestState{},
|
||||
},
|
||||
{
|
||||
name: "single-value",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-error",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "value-and-error",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
},
|
||||
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_errors", 1},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
{"windows_syspolicy_TestSetting02_error", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two-values",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 17,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
{"windows_syspolicy_TestSetting02", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
{"windows_syspolicy_TestSetting02", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two-errors",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 2},
|
||||
{"windows_syspolicy_TestSetting01_error", 1},
|
||||
{"windows_syspolicy_TestSetting02_error", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_errors", 2},
|
||||
{"windows_syspolicy_TestSetting01_error", 0},
|
||||
{"windows_syspolicy_TestSetting02_error", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-scope",
|
||||
osOverride: "windows",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.ProfileSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting02", setting.ProfileSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.CurrentProfileScope),
|
||||
err: errors.New("bang!"),
|
||||
},
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting03", setting.UserSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.CurrentUserScope),
|
||||
value: 17,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 1},
|
||||
{"windows_syspolicy_profile_errors", 1},
|
||||
{"windows_syspolicy_user_any", 1},
|
||||
{"windows_syspolicy_TestSetting01", 1},
|
||||
{"windows_syspolicy_TestSetting02_profile_error", 1},
|
||||
{"windows_syspolicy_TestSetting03_user", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"windows_syspolicy_any", 0},
|
||||
{"windows_syspolicy_profile_errors", 1},
|
||||
{"windows_syspolicy_user_any", 0},
|
||||
{"windows_syspolicy_TestSetting01", 0},
|
||||
{"windows_syspolicy_TestSetting02_profile_error", 0},
|
||||
{"windows_syspolicy_TestSetting03_user", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "report-metrics-on-android",
|
||||
osOverride: "android",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
wantMetrics: []TestState{
|
||||
{"android_syspolicy_any", 1},
|
||||
{"android_syspolicy_TestSetting01", 1},
|
||||
},
|
||||
wantResetMetrics: []TestState{
|
||||
{"android_syspolicy_any", 0},
|
||||
{"android_syspolicy_TestSetting01", 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "do-not-report-metrics-on-macos",
|
||||
osOverride: "macos",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
|
||||
wantMetrics: []TestState{}, // none reported
|
||||
},
|
||||
{
|
||||
name: "do-not-report-metrics-on-ios",
|
||||
osOverride: "ios",
|
||||
settings: []testSettingDetails{
|
||||
{
|
||||
definition: setting.NewDefinition("TestSetting01", setting.DeviceSetting, setting.IntegerValue),
|
||||
origin: setting.NewOrigin(setting.DeviceScope),
|
||||
value: 42,
|
||||
},
|
||||
},
|
||||
|
||||
wantMetrics: []TestState{}, // none reported
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset the lazy value so it'll be re-evaluated with the osOverride.
|
||||
lazyReportMetrics = lazy.SyncValue[bool]{}
|
||||
t.Cleanup(func() {
|
||||
// Also reset it during the cleanup.
|
||||
lazyReportMetrics = lazy.SyncValue[bool]{}
|
||||
})
|
||||
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
|
||||
|
||||
h := NewTestHandler(t)
|
||||
SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
|
||||
for _, s := range tt.settings {
|
||||
if s.err != nil {
|
||||
ReportError(s.origin, s.definition, s.err)
|
||||
} else {
|
||||
ReportConfigured(s.origin, s.definition, s.value)
|
||||
}
|
||||
}
|
||||
h.MustEqual(tt.wantMetrics...)
|
||||
|
||||
for _, s := range tt.settings {
|
||||
Reset(s.origin)
|
||||
ReportNotConfigured(s.origin, s.definition)
|
||||
}
|
||||
h.MustEqual(tt.wantResetMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
88
util/syspolicy/internal/metrics/test_handler.go
Normal file
88
util/syspolicy/internal/metrics/test_handler.go
Normal file
@ -0,0 +1,88 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
)
|
||||
|
||||
// TestState represents a metric name and its expected value.
|
||||
type TestState struct {
|
||||
Name string // `$os` in the name will be replaced by the actual operating system name.`
|
||||
Value int64
|
||||
}
|
||||
|
||||
// TestHandler facilitates testing of the code that uses metrics.
|
||||
type TestHandler struct {
|
||||
t internal.TB
|
||||
|
||||
m map[string]int64
|
||||
}
|
||||
|
||||
// NewTestHandler returns a new TestHandler.
|
||||
func NewTestHandler(t internal.TB) *TestHandler {
|
||||
return &TestHandler{t, make(map[string]int64)}
|
||||
}
|
||||
|
||||
// AddMetric increments the metric with the specified name and type by delta d.
|
||||
func (h *TestHandler) AddMetric(name string, typ clientmetric.Type, d int64) {
|
||||
h.t.Helper()
|
||||
if typ == clientmetric.TypeCounter && d < 0 {
|
||||
h.t.Fatalf("an attempt was made to decrement a counter metric %q", name)
|
||||
}
|
||||
if v, ok := h.m[name]; ok || d != 0 {
|
||||
h.m[name] = v + d
|
||||
}
|
||||
}
|
||||
|
||||
// SetMetric sets the metric with the specified name and type to the value v.
|
||||
func (h *TestHandler) SetMetric(name string, typ clientmetric.Type, v int64) {
|
||||
h.t.Helper()
|
||||
if typ == clientmetric.TypeCounter {
|
||||
h.t.Fatalf("an attempt was made to set a counter metric %q", name)
|
||||
}
|
||||
if _, ok := h.m[name]; ok || v != 0 {
|
||||
h.m[name] = v
|
||||
}
|
||||
}
|
||||
|
||||
// MustEqual fails the test if the actual metric state differs from the specified state.
|
||||
func (h *TestHandler) MustEqual(metrics ...TestState) {
|
||||
h.t.Helper()
|
||||
h.MustContain(metrics...)
|
||||
h.mustNoExtra(metrics...)
|
||||
}
|
||||
|
||||
// MustContain fails the test if the specified metrics are not set or have
|
||||
// different values than specified. It permits other metrics to be set in
|
||||
// addition to the ones being tested.
|
||||
func (h *TestHandler) MustContain(metrics ...TestState) {
|
||||
h.t.Helper()
|
||||
for _, m := range metrics {
|
||||
name := strings.ReplaceAll(m.Name, "$os", internal.OS())
|
||||
v, ok := h.m[name]
|
||||
if !ok {
|
||||
h.t.Errorf("%q: got (none), want %v", name, m.Value)
|
||||
} else if v != m.Value {
|
||||
h.t.Fatalf("%q: got %v, want %v", name, v, m.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TestHandler) mustNoExtra(metrics ...TestState) {
|
||||
h.t.Helper()
|
||||
s := make(set.Set[string])
|
||||
for i := range metrics {
|
||||
s.Add(strings.ReplaceAll(metrics[i].Name, "$os", internal.OS()))
|
||||
}
|
||||
for n, v := range h.m {
|
||||
if !s.Contains(n) {
|
||||
h.t.Errorf("%q: got %v, want (none)", n, v)
|
||||
}
|
||||
}
|
||||
}
|
@ -3,7 +3,21 @@
|
||||
|
||||
package syspolicy
|
||||
|
||||
type Key string
|
||||
import (
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/syspolicy/internal/lazyinit"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/testenv"
|
||||
)
|
||||
|
||||
type Key = setting.Key
|
||||
|
||||
// The const block below lists known policy keys.
|
||||
// When adding a key to this list, remember to add a corresponding
|
||||
// [setting.Definition] to [implicitDefinitions] below.
|
||||
// Otherwise, the [TestKnownKeysRegistered] test will fail as a reminder.
|
||||
// Preferably, use a strongly typed policy hierarchy, such as [Policy],
|
||||
// instead of adding each key to the list below.
|
||||
|
||||
const (
|
||||
// Keys with a string value
|
||||
@ -96,3 +110,83 @@
|
||||
// AllowedSuggestedExitNodes's string array value is a list of exit node IDs that restricts which exit nodes are considered when generating suggestions for exit nodes.
|
||||
AllowedSuggestedExitNodes Key = "AllowedSuggestedExitNodes"
|
||||
)
|
||||
|
||||
// implicitDefinitions is a list of [setting.Definition] that will be registered
|
||||
// automatically by [settingDefinitions] as soon as the package needs to ready a policy.
|
||||
var implicitDefinitions = []*setting.Definition{
|
||||
// Device policy settings
|
||||
setting.NewDefinition(AllowedSuggestedExitNodes, setting.DeviceSetting, setting.StringListValue),
|
||||
setting.NewDefinition(ApplyUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(CheckUpdates, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(ControlURL, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(DeviceSerialNumber, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(EnableIncomingConnections, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableRunExitNode, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableServerMode, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableTailscaleDNS, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(EnableTailscaleSubnets, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(ExitNodeAllowLANAccess, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(ExitNodeID, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(ExitNodeIP, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(FlushDNSOnSessionUnlock, setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition(LogSCMInteractions, setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition(LogTarget, setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition(PostureChecking, setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition(Tailnet, setting.DeviceSetting, setting.StringValue),
|
||||
|
||||
// User policy settings
|
||||
setting.NewDefinition(AdminConsoleVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(AutoUpdateVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(ExitNodeMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(KeyExpirationNoticeTime, setting.UserSetting, setting.DurationValue),
|
||||
setting.NewDefinition(ManagedByCaption, setting.UserSetting, setting.StringValue),
|
||||
setting.NewDefinition(ManagedByOrganizationName, setting.UserSetting, setting.StringValue),
|
||||
setting.NewDefinition(ManagedByURL, setting.UserSetting, setting.StringValue),
|
||||
setting.NewDefinition(NetworkDevicesVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(PreferencesMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(ResetToDefaultsVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(RunExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(SuggestedExitNodeVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(TestMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
setting.NewDefinition(UpdateMenuVisibility, setting.UserSetting, setting.VisibilityValue),
|
||||
}
|
||||
|
||||
func init() {
|
||||
lazyinit.Defer(func() error {
|
||||
// Avoid implicit [SettingDefinition] registration during tests.
|
||||
// Each test should control which policy settings to register.
|
||||
// Use [setting.SetDefinitionsForTest] to specify necessary definitions,
|
||||
// or [setWellKnownSettingsForTest] to set implicit definitions for the test duration.
|
||||
if testenv.InTest() {
|
||||
return nil
|
||||
}
|
||||
for _, d := range implicitDefinitions {
|
||||
setting.RegisterDefinition(d)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var implicitDefinitionMap lazy.SyncValue[setting.DefinitionMap]
|
||||
|
||||
// WellKnownSettingDefinition returns a well-known, implicit setting definition by its key,
|
||||
// or an [ErrNoSuchKey] if a policy setting with the specified key does not exist
|
||||
// among implicit policy definitions.
|
||||
func WellKnownSettingDefinition(k Key) (*setting.Definition, error) {
|
||||
m, err := implicitDefinitionMap.GetErr(func() (setting.DefinitionMap, error) {
|
||||
return setting.DefinitionMapOf(implicitDefinitions)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d, ok := m[k]; ok {
|
||||
return d, nil
|
||||
}
|
||||
return nil, ErrNoSuchKey
|
||||
}
|
||||
|
||||
// setWellKnownSettingsForTest registers all implicit setting definitions
|
||||
// for the duration of the test.
|
||||
func setWellKnownSettingsForTest(tb lazy.TB) error {
|
||||
return setting.SetDefinitionsForTest(tb, implicitDefinitions...)
|
||||
}
|
||||
|
95
util/syspolicy/policy_keys_test.go
Normal file
95
util/syspolicy/policy_keys_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
func TestKnownKeysRegistered(t *testing.T) {
|
||||
keyConsts, err := listStringConsts[Key]("policy_keys.go")
|
||||
if err != nil {
|
||||
t.Fatalf("listStringConsts failed: %v", err)
|
||||
}
|
||||
|
||||
m, err := setting.DefinitionMapOf(implicitDefinitions)
|
||||
if err != nil {
|
||||
t.Fatalf("definitionMapOf failed: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range keyConsts {
|
||||
t.Run(string(key), func(t *testing.T) {
|
||||
d := m[key]
|
||||
if d == nil {
|
||||
t.Fatalf("%q was not registered", key)
|
||||
}
|
||||
if d.Key() != key {
|
||||
t.Fatalf("d.Key got: %s, want %s", d.Key(), key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotAWellKnownSetting(t *testing.T) {
|
||||
d, err := WellKnownSettingDefinition("TestSettingDoesNotExist")
|
||||
if d != nil || err == nil {
|
||||
t.Fatalf("got %v, %v; want nil, %v", d, err, ErrNoSuchKey)
|
||||
}
|
||||
}
|
||||
|
||||
func listStringConsts[T ~string](filename string) (map[string]T, error) {
|
||||
fset := token.NewFileSet()
|
||||
src, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := parser.ParseFile(fset, filename, src, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
consts := make(map[string]T)
|
||||
typeName := reflect.TypeFor[T]().Name()
|
||||
for _, d := range f.Decls {
|
||||
g, ok := d.(*ast.GenDecl)
|
||||
if !ok || g.Tok != token.CONST {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, s := range g.Specs {
|
||||
vs, ok := s.(*ast.ValueSpec)
|
||||
if !ok || len(vs.Names) != len(vs.Values) {
|
||||
continue
|
||||
}
|
||||
if typ, ok := vs.Type.(*ast.Ident); !ok || typ.Name != typeName {
|
||||
continue
|
||||
}
|
||||
|
||||
for i, n := range vs.Names {
|
||||
lit, ok := vs.Values[i].(*ast.BasicLit)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, types.ExprString(vs.Values[i]))
|
||||
}
|
||||
val, err := strconv.Unquote(lit.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, lit.Value)
|
||||
}
|
||||
consts[n.Name] = T(val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return consts, nil
|
||||
}
|
@ -1,38 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
var stringKeys = []Key{
|
||||
ControlURL,
|
||||
LogTarget,
|
||||
Tailnet,
|
||||
ExitNodeID,
|
||||
ExitNodeIP,
|
||||
EnableIncomingConnections,
|
||||
EnableServerMode,
|
||||
ExitNodeAllowLANAccess,
|
||||
EnableTailscaleDNS,
|
||||
EnableTailscaleSubnets,
|
||||
AdminConsoleVisibility,
|
||||
NetworkDevicesVisibility,
|
||||
TestMenuVisibility,
|
||||
UpdateMenuVisibility,
|
||||
RunExitNodeVisibility,
|
||||
PreferencesMenuVisibility,
|
||||
ExitNodeMenuVisibility,
|
||||
AutoUpdateVisibility,
|
||||
ResetToDefaultsVisibility,
|
||||
KeyExpirationNoticeTime,
|
||||
PostureChecking,
|
||||
ManagedByOrganizationName,
|
||||
ManagedByCaption,
|
||||
ManagedByURL,
|
||||
}
|
||||
|
||||
var boolKeys = []Key{
|
||||
LogSCMInteractions,
|
||||
FlushDNSOnSessionUnlock,
|
||||
}
|
||||
|
||||
var uint64Keys = []Key{}
|
109
util/syspolicy/rsop/change_callbacks.go
Normal file
109
util/syspolicy/rsop/change_callbacks.go
Normal file
@ -0,0 +1,109 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rsop
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
// Change represents a change from the Old to the New value of type T.
|
||||
type Change[T any] struct {
|
||||
New, Old T
|
||||
}
|
||||
|
||||
// PolicyChangeCallback is a function called whenever a policy changes.
|
||||
type PolicyChangeCallback func(*PolicyChange)
|
||||
|
||||
// PolicyChange describes a policy change.
|
||||
type PolicyChange struct {
|
||||
snapshots Change[*setting.Snapshot]
|
||||
}
|
||||
|
||||
// New returns the [setting.Snapshot] after the change.
|
||||
func (c PolicyChange) New() *setting.Snapshot {
|
||||
return c.snapshots.New
|
||||
}
|
||||
|
||||
// Old returns the [setting.Snapshot] before the change.
|
||||
func (c PolicyChange) Old() *setting.Snapshot {
|
||||
return c.snapshots.Old
|
||||
}
|
||||
|
||||
// HasChanged reports whether a policy setting with the specified [setting.Key], has changed.
|
||||
func (c PolicyChange) HasChanged(key setting.Key) bool {
|
||||
new, newErr := c.snapshots.New.GetErr(key)
|
||||
old, oldErr := c.snapshots.Old.GetErr(key)
|
||||
if newErr != nil && oldErr != nil {
|
||||
return false
|
||||
}
|
||||
if newErr != nil || oldErr != nil {
|
||||
return true
|
||||
}
|
||||
switch newVal := new.(type) {
|
||||
case bool, uint64, string, setting.Visibility, setting.PreferenceOption, time.Duration:
|
||||
return newVal != old
|
||||
case []string:
|
||||
if oldVal, ok := old.([]string); ok {
|
||||
return slices.Equal(newVal, oldVal)
|
||||
}
|
||||
return false
|
||||
default:
|
||||
loggerx.Errorf("%q has an unsupported value type: %T", newVal)
|
||||
return reflect.DeepEqual(new, old)
|
||||
}
|
||||
}
|
||||
|
||||
// policyChangeCallbacks are the callbacks to invoke when the resultant policy changes.
|
||||
// It is safe for concurrent use.
|
||||
type policyChangeCallbacks struct {
|
||||
mu sync.RWMutex
|
||||
cbs set.HandleSet[PolicyChangeCallback]
|
||||
}
|
||||
|
||||
// Register adds the specified callback to be invoked whenever the policy changes.
|
||||
func (c *policyChangeCallbacks) Register(callback PolicyChangeCallback) (unregister func()) {
|
||||
c.mu.Lock()
|
||||
handle := c.cbs.Add(callback)
|
||||
c.mu.Unlock()
|
||||
return func() {
|
||||
c.mu.Lock()
|
||||
delete(c.cbs, handle)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke calls the registered callback functions with the specified policy change info.
|
||||
func (c *policyChangeCallbacks) Invoke(snapshots Change[*setting.Snapshot]) {
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
wg.Add(len(c.cbs))
|
||||
change := &PolicyChange{snapshots: snapshots}
|
||||
for _, cb := range c.cbs {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cb(change)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Close awaits the completion of active callbacks and prevents any further invocations.
|
||||
func (c *policyChangeCallbacks) Close() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.cbs != nil {
|
||||
clear(c.cbs)
|
||||
c.cbs = nil
|
||||
}
|
||||
}
|
698
util/syspolicy/rsop/resultant_policy.go
Normal file
698
util/syspolicy/rsop/resultant_policy.go
Normal file
@ -0,0 +1,698 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package rsop facilitates [source.Store] registration via [RegisterStore]
|
||||
// and provides access to the resultant policy merged from all registered sources
|
||||
// via [PolicyFor].
|
||||
package rsop
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/slicesx"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/internal/lazyinit"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
)
|
||||
|
||||
var errResultantPolicyClosed = errors.New("resultant policy closed")
|
||||
|
||||
// The minimum and maximum wait times after detecting a policy change
|
||||
// before reloading the policy.
|
||||
// Policy changes occurring within [policyReloadMinDelay] of each other
|
||||
// will be batched together, resulting in a single policy reload
|
||||
// no later than [policyReloadMaxDelay] after the first detected change.
|
||||
// In other words, the resultant policy will be reloaded no more often than once
|
||||
// every 5 seconds, but at most 15 seconds after an underlying [source.Store]
|
||||
// has issued a policy change callback.
|
||||
// See [Policy.watchReload].
|
||||
const (
|
||||
defaultPolicyReloadMinDelay = 5 * time.Second
|
||||
defaultPolicyReloadMaxDelay = 15 * time.Second
|
||||
)
|
||||
|
||||
// policyReloadMinDelay and policyReloadMaxDelay are test hooks.
|
||||
// Their values default to [defaultPolicyReloadMinDelay] and [defaultPolicyReloadMaxDelay].
|
||||
var (
|
||||
policyReloadMinDelay, policyReloadMaxDelay lazy.SyncValue[time.Duration]
|
||||
)
|
||||
|
||||
// Policy provides access to the current resultant [setting.Snapshot] for a given
|
||||
// scope and allows to reload it from the underlying [source.Store]s. It also allows to
|
||||
// subscribe and receive a callback whenever the resultant [setting.Snapshot] is
|
||||
// changed. It is safe for concurrent use.
|
||||
type Policy struct {
|
||||
scope setting.PolicyScope
|
||||
|
||||
reloadCh chan reloadRequest // 1-buffered; written to when a policy reload is required
|
||||
changeSourceCh chan sourceChangeRequest // written to to add a new or remove an existing source
|
||||
closeCh chan struct{} // closed to signal that the Policy is being closed
|
||||
doneCh chan struct{} // closed by closeInternal when watchReload exits
|
||||
|
||||
// resultant is the most recent version of the [setting.Snapshot] containing policy settings
|
||||
// merged from all applicable sources.
|
||||
resultant atomic.Pointer[setting.Snapshot]
|
||||
|
||||
changeCallbacks policyChangeCallbacks
|
||||
|
||||
mu sync.RWMutex
|
||||
sources source.ReadableSources
|
||||
closing bool // Close was called (even if we're still closing)
|
||||
}
|
||||
|
||||
// newPolicy returns a new [Policy] for the specified [setting.PolicyScope]
|
||||
// that tracks changes and merges policy settings read from the specified sources.
|
||||
func newPolicy(scope setting.PolicyScope, sources ...*source.Source) (p *Policy, err error) {
|
||||
readableSources := source.ReadableSources(make([]source.ReadableSource, len(sources)))
|
||||
for i, s := range sources {
|
||||
reader, err := s.Reader()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get a store reader: %v", err)
|
||||
}
|
||||
session, err := reader.OpenSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open a reading session: %v", err)
|
||||
}
|
||||
|
||||
readableSource := source.ReadableSource{
|
||||
Source: s,
|
||||
ReadingSession: session,
|
||||
}
|
||||
readableSources[i] = readableSource
|
||||
defer func() {
|
||||
if err != nil {
|
||||
readableSource.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Sort policy sources by their precedence from lower to higher.
|
||||
// For example, {UserPolicy},{ProfilePolicy},{DevicePolicy}.
|
||||
readableSources.StableSort()
|
||||
|
||||
p = &Policy{
|
||||
scope: scope,
|
||||
sources: readableSources,
|
||||
reloadCh: make(chan reloadRequest, 1),
|
||||
changeSourceCh: make(chan sourceChangeRequest),
|
||||
closeCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
if err := p.start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// IsValid reports whether p is in a valid state and has not been closed.
|
||||
func (p *Policy) IsValid() bool {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Scope returns the [setting.PolicyScope] that this resultant policy applies to.
|
||||
func (p *Policy) Scope() setting.PolicyScope {
|
||||
return p.scope
|
||||
}
|
||||
|
||||
// Get returns the most recent resultant [setting.Snapshot].
|
||||
func (p *Policy) Get() *setting.Snapshot {
|
||||
return p.resultant.Load()
|
||||
}
|
||||
|
||||
// RegisterChangeCallback adds a function to be called whenever the resultant
|
||||
// policy changes. The returned function can be used to unregister the callback.
|
||||
func (p *Policy) RegisterChangeCallback(callback PolicyChangeCallback) (unregister func()) {
|
||||
return p.changeCallbacks.Register(callback)
|
||||
}
|
||||
|
||||
// Reload synchronously re-reads policy settings from the underlying policy
|
||||
// [source.Store], constructing a new merged [setting.Snapshot] even if the policy remains
|
||||
// unchanged. In most scenarios, there's no need to re-read the policy manually.
|
||||
// Instead, it is recommended to register a policy change callback, or to use
|
||||
// the most recent [setting.Snapshot] returned by the [Policy.Get] method.
|
||||
func (p *Policy) Reload() (*setting.Snapshot, error) {
|
||||
return p.reload(true)
|
||||
}
|
||||
|
||||
// reload is like Reload, but allows to specify whether to re-read policy settings
|
||||
// from unchanged policy sources.
|
||||
func (p *Policy) reload(force bool) (*setting.Snapshot, error) {
|
||||
respCh := make(chan reloadResponse, 1)
|
||||
select {
|
||||
case p.reloadCh <- reloadRequest{force: force, respCh: respCh}:
|
||||
// continue
|
||||
case <-p.closeCh:
|
||||
return nil, errResultantPolicyClosed
|
||||
}
|
||||
select {
|
||||
case resp := <-respCh:
|
||||
return resp.policy, resp.err
|
||||
case <-p.closeCh:
|
||||
return nil, errResultantPolicyClosed
|
||||
}
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the [Policy] is closed.
|
||||
func (p *Policy) Done() <-chan struct{} {
|
||||
return p.doneCh
|
||||
}
|
||||
|
||||
func (p *Policy) start() error {
|
||||
if _, err := p.reloadNow(false); err != nil {
|
||||
return err
|
||||
}
|
||||
go p.watchPolicyChanges()
|
||||
go p.watchReload()
|
||||
return nil
|
||||
}
|
||||
|
||||
// readAndMerge reads and merges policy settings from the underlying sources,
|
||||
// returning a [setting.Snapshot] with the merged result.
|
||||
// If the force parameter is true, it re-reads policy settings from each store
|
||||
// even if no policy change was observed, and returns an error if the read
|
||||
// operation fails.
|
||||
func (p *Policy) readAndMerge(force bool) (*setting.Snapshot, error) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
// Start with an empty policy in the target scope.
|
||||
resultant := setting.NewSnapshot(nil, setting.SummaryWith(p.scope))
|
||||
// Then merge policy settings from all sources.
|
||||
// Policy sources with the highest precedence (e.g., the device policy) are merged last,
|
||||
// overriding any conflicting policy settings with lower precedence.
|
||||
for _, s := range p.sources {
|
||||
var policy *setting.Snapshot
|
||||
if force {
|
||||
var err error
|
||||
if policy, err = s.ReadSettings(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
policy = s.GetSettings()
|
||||
}
|
||||
resultant = setting.MergeSnapshots(resultant, policy)
|
||||
}
|
||||
return resultant, nil
|
||||
}
|
||||
|
||||
// reloadAsync requests an asynchronous background policy reload.
|
||||
// The policy will be reloaded no later than in [policyReloadMaxDelay].
|
||||
func (p *Policy) reloadAsync() {
|
||||
select {
|
||||
case p.reloadCh <- reloadRequest{}:
|
||||
// Sent.
|
||||
default:
|
||||
// A reload request is already en route.
|
||||
}
|
||||
}
|
||||
|
||||
// reloadNow loads and merges policies from all sources, updating the resultant policy.
|
||||
// If the force parameter is true, it forcibly reloads policies
|
||||
// from the underlying policy store, even if no policy changes were detected.
|
||||
//
|
||||
// Except for the initial policy reload during the [Policy] creation,
|
||||
// this method should only be called from the [Policy.watchReload] goroutine.
|
||||
func (p *Policy) reloadNow(force bool) (*setting.Snapshot, error) {
|
||||
new, err := p.readAndMerge(force)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
old := p.resultant.Swap(new)
|
||||
// A nil old value indicates the initial policy load rather than a policy change.
|
||||
// Additionally, we should not invoke the policy change callbacks unless the
|
||||
// policy items have actually changed.
|
||||
if old != nil && !old.EqualItems(new) {
|
||||
snapshots := Change[*setting.Snapshot]{New: new, Old: old}
|
||||
p.changeCallbacks.Invoke(snapshots)
|
||||
}
|
||||
return new, nil
|
||||
}
|
||||
|
||||
// AddSource adds the specified source to the list of sources used by p,
|
||||
// and triggers a synchronous policy refresh. It returns an error
|
||||
// if the source is not a valid source for this resultant policy,
|
||||
// or if the resultant policy is being closed,
|
||||
// or if policy refresh fails with an error.
|
||||
func (p *Policy) AddSource(source *source.Source) error {
|
||||
return p.changeSource(source, nil)
|
||||
}
|
||||
|
||||
// RemoveSource removes the specified source from the list of sources used by p,
|
||||
// and triggers a synchronous policy refresh. It returns an error if the
|
||||
// resultant policy is being closed, or if policy refresh fails with an error.
|
||||
func (p *Policy) RemoveSource(source *source.Source) error {
|
||||
return p.changeSource(nil, source)
|
||||
}
|
||||
|
||||
// ReplaceSource replaces the old source with the new source atomically,
|
||||
// and triggers a synchronous policy refresh. It returns an error
|
||||
// if the source is not a valid source for this resultant policy,
|
||||
// or if the resultant policy is being closed,
|
||||
// or if policy refresh fails with an error.
|
||||
func (p *Policy) ReplaceSource(old, new *source.Source) error {
|
||||
return p.changeSource(new, old)
|
||||
}
|
||||
|
||||
func (p *Policy) changeSource(toAdd, toRemove *source.Source) error {
|
||||
if toAdd == toRemove {
|
||||
return nil
|
||||
}
|
||||
if toAdd != nil && !p.scope.IsWithinOf(toAdd.Scope()) {
|
||||
return errors.New("scope mismatch")
|
||||
}
|
||||
respCh := make(chan error, 1)
|
||||
req := sourceChangeRequest{toAdd, toRemove, respCh}
|
||||
select {
|
||||
case p.changeSourceCh <- req:
|
||||
return <-respCh
|
||||
case <-p.closeCh:
|
||||
return errResultantPolicyClosed
|
||||
}
|
||||
}
|
||||
|
||||
// watchPolicyChanges awaits a policy change notification from any of the sources
|
||||
// and calls reloadAsync whenever a notification is received.
|
||||
func (p *Policy) watchPolicyChanges() {
|
||||
const (
|
||||
closeIdx = iota
|
||||
changeSourceIdx
|
||||
policyChangedOffset
|
||||
)
|
||||
|
||||
// The cases are Close, ChangeSource, PolicyChanged[0],...,PolicyChanged[N-1].
|
||||
p.mu.RLock()
|
||||
cases := make([]reflect.SelectCase, len(p.sources)+policyChangedOffset)
|
||||
// Add the PolicyChanged[N] cases.
|
||||
for i, source := range p.sources {
|
||||
cases[i+policyChangedOffset] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(source.PolicyChanged())}
|
||||
}
|
||||
// Add the Close case.
|
||||
cases[closeIdx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(p.closeCh)}
|
||||
// Add the ChangeSource case.
|
||||
cases[changeSourceIdx] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(p.changeSourceCh)}
|
||||
p.mu.RUnlock()
|
||||
|
||||
for {
|
||||
switch chosen, recv, ok := reflect.Select(cases); chosen {
|
||||
case closeIdx: // Close
|
||||
// Exit the watch as the closeCh was closed, indicating that
|
||||
// the [Policy] is being closed.
|
||||
return
|
||||
case changeSourceIdx: // ChangeSource
|
||||
// We've received a source change request from one of the AddSource,
|
||||
// RemoveSource, or ReplaceSource methods, meaning that we need to:
|
||||
// - Open a reader session if a new source is being added;
|
||||
// - Update the p.sources slice;
|
||||
// - Update the cases slice;
|
||||
// - Trigger a synchronous policy reload;
|
||||
// - Report an error, if any, back to the caller.
|
||||
req := recv.Interface().(sourceChangeRequest)
|
||||
needClose, err := func() (close bool, err error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if req.toAdd != nil {
|
||||
if !p.sources.Contains(req.toAdd) {
|
||||
reader, err := req.toAdd.Reader()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get a store reader: %v", err)
|
||||
}
|
||||
session, err := reader.OpenSession()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to open a reading session: %v", err)
|
||||
}
|
||||
|
||||
addAt := p.sources.InsertionIndexOf(req.toAdd)
|
||||
toAdd := source.ReadableSource{
|
||||
Source: req.toAdd,
|
||||
ReadingSession: session,
|
||||
}
|
||||
p.sources = slices.Insert(p.sources, addAt, toAdd)
|
||||
newCase := reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(toAdd.PolicyChanged())}
|
||||
caseIndex := addAt + policyChangedOffset
|
||||
cases = slices.Insert(cases, caseIndex, newCase)
|
||||
}
|
||||
}
|
||||
if req.toDelete != nil {
|
||||
if deleteAt := p.sources.IndexOf(req.toDelete); deleteAt != -1 {
|
||||
p.sources.DeleteAt(deleteAt)
|
||||
caseIndex := deleteAt + policyChangedOffset
|
||||
cases = slices.Delete(cases, caseIndex, caseIndex+1)
|
||||
}
|
||||
}
|
||||
return len(p.sources) == 0, nil
|
||||
}()
|
||||
if err == nil {
|
||||
if needClose {
|
||||
// Close the resultant policy if the last policy source was deleted.
|
||||
p.Close()
|
||||
} else {
|
||||
// Otherwise, reload the policy synchronously.
|
||||
_, err = p.reload(false)
|
||||
}
|
||||
}
|
||||
req.respCh <- err
|
||||
default: // PolicyChanged[N]
|
||||
if !ok {
|
||||
// One of the PolicyChanged channels was closed, indicating that
|
||||
// the corresponding [source.Source] is no longer valid.
|
||||
// We can no longer keep this [Policy] up to date
|
||||
// and should close it.
|
||||
p.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// One of the PolicyChanged channels was signaled.
|
||||
// We should request an asynchronous policy reload.
|
||||
p.reloadAsync()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// watchReload processes incoming synchronous and asynchronous policy reload requests.
|
||||
// Synchronous requests (with a non-nil respCh) are served immediately.
|
||||
// Asynchronous requests are debounced and throttled: they are executed at least
|
||||
// [policyReloadMinDelay] after the last request, but no later than [policyReloadMaxDelay]
|
||||
// after the first request in a batch.
|
||||
func (p *Policy) watchReload() {
|
||||
force := false // whether a forced refresh was requested
|
||||
var delayCh, timeoutCh <-chan time.Time
|
||||
reload := func(respCh chan<- reloadResponse) {
|
||||
delayCh, timeoutCh = nil, nil
|
||||
policy, err := p.reloadNow(force)
|
||||
if err != nil {
|
||||
loggerx.Errorf("%v policy reload failed: %v\n", p.scope, err)
|
||||
}
|
||||
if respCh != nil {
|
||||
respCh <- reloadResponse{policy: policy, err: err}
|
||||
}
|
||||
force = false
|
||||
}
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case req := <-p.reloadCh:
|
||||
if req.force {
|
||||
force = true
|
||||
}
|
||||
if req.respCh != nil {
|
||||
reload(req.respCh)
|
||||
continue
|
||||
}
|
||||
if delayCh == nil {
|
||||
timeoutCh = time.After(policyReloadMaxDelay.Get(func() time.Duration { return defaultPolicyReloadMaxDelay }))
|
||||
}
|
||||
delayCh = time.After(policyReloadMinDelay.Get(func() time.Duration { return defaultPolicyReloadMinDelay }))
|
||||
case <-delayCh:
|
||||
reload(nil)
|
||||
case <-timeoutCh:
|
||||
reload(nil)
|
||||
case <-p.closeCh:
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
p.closeInternal()
|
||||
}
|
||||
|
||||
func (p *Policy) closeInternal() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.sources.Close()
|
||||
p.changeCallbacks.Close()
|
||||
close(p.doneCh)
|
||||
}
|
||||
|
||||
// Close initiates the closing of the resultant policy.
|
||||
// The actual closing is performed by closeInternal when watchReload exits,
|
||||
// and the Done() channel is closed when closeInternal finishes.
|
||||
func (p *Policy) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.closing {
|
||||
return
|
||||
}
|
||||
p.closing = true
|
||||
close(p.closeCh)
|
||||
}
|
||||
|
||||
// sourceChangeRequest is a request to add and/or remove source from a [Policy].
|
||||
type sourceChangeRequest struct {
|
||||
toAdd, toDelete *source.Source
|
||||
respCh chan<- error
|
||||
}
|
||||
|
||||
// reloadRequest describes a policy reload request.
|
||||
type reloadRequest struct {
|
||||
// force triggers an immediate synchronous policy reload,
|
||||
// reloading the policy regardless of whether a policy change was detected.
|
||||
force bool
|
||||
// respCh is an optional channel. If non-nil, it makes the reload request
|
||||
// synchronous and receives the result.
|
||||
respCh chan<- reloadResponse
|
||||
}
|
||||
|
||||
type reloadResponse struct {
|
||||
policy *setting.Snapshot
|
||||
err error
|
||||
}
|
||||
|
||||
var (
|
||||
policyMu sync.RWMutex
|
||||
policySources []*source.Source
|
||||
resultantPolicies []*Policy
|
||||
|
||||
resultantPolicyLRU [setting.MaxSettingScope + 1]syncs.AtomicValue[*Policy] // by [Scope.Kind]
|
||||
)
|
||||
|
||||
// registerSource registers the specified [source.Source] to be used by the package.
|
||||
// It updates existing [Policy]s returned by [PolicyFor] to use this source if
|
||||
// they are within the source's [setting.PolicyScope].
|
||||
func registerSource(source *source.Source) error {
|
||||
policyMu.Lock()
|
||||
defer policyMu.Unlock()
|
||||
if slices.Contains(policySources, source) {
|
||||
return nil
|
||||
}
|
||||
policySources = append(policySources, source)
|
||||
return forEachResultantPolicyLocked(func(policy *Policy) error {
|
||||
if !policy.Scope().IsWithinOf(source.Scope()) {
|
||||
return nil
|
||||
}
|
||||
return policy.AddSource(source)
|
||||
})
|
||||
}
|
||||
|
||||
// replaceSource is like [unregisterSource](old) followed by [registerSource](new),
|
||||
// but is atomic from the perspective of each [Policy].
|
||||
func replaceSource(old, new *source.Source) error {
|
||||
policyMu.Lock()
|
||||
defer policyMu.Unlock()
|
||||
oldIndex := slices.Index(policySources, old)
|
||||
if oldIndex == -1 {
|
||||
return fmt.Errorf("the source is not registered: %v", old)
|
||||
}
|
||||
policySources[oldIndex] = new
|
||||
return forEachResultantPolicyLocked(func(policy *Policy) error {
|
||||
if policy.Scope().IsWithinOf(old.Scope()) || policy.Scope().IsWithinOf(new.Scope()) {
|
||||
return nil
|
||||
}
|
||||
return policy.ReplaceSource(old, new)
|
||||
})
|
||||
}
|
||||
|
||||
// unregisterSource unregisters the specified [source.Source],
|
||||
// so that it won't be used by any new or existing [Policy].
|
||||
func unregisterSource(source *source.Source) error {
|
||||
policyMu.Lock()
|
||||
defer policyMu.Unlock()
|
||||
index := slices.Index(policySources, source)
|
||||
if index == -1 {
|
||||
return nil
|
||||
}
|
||||
policySources = slices.Delete(policySources, index, index+1)
|
||||
return forEachResultantPolicyLocked(func(policy *Policy) error {
|
||||
if !policy.Scope().IsWithinOf(source.Scope()) {
|
||||
return nil
|
||||
}
|
||||
return policy.RemoveSource(source)
|
||||
})
|
||||
}
|
||||
|
||||
// forEachResultantPolicyLocked calls fn for every [Policy] in [resultantPolicies].
|
||||
// It accumulates the returned errors, except for [errResultantPolicyClosed],
|
||||
// and returns an error that wraps all errors returned by fn.
|
||||
// The [policyMu] mutex must be held while this function is executed.
|
||||
func forEachResultantPolicyLocked(fn func(p *Policy) error) error {
|
||||
var errs []error
|
||||
for _, policy := range resultantPolicies {
|
||||
err := fn(policy)
|
||||
if err != nil && !errors.Is(err, errResultantPolicyClosed) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// PolicyFor returns the [Policy] for the specified scope,
|
||||
// creating one from the registered [source.Store]s if it does not exist.
|
||||
func PolicyFor(scope setting.PolicyScope) (*Policy, error) {
|
||||
if err := lazyinit.Do(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if policy := resultantPolicyLRU[scope.Kind()].Load(); policy != nil && policy.Scope() == scope && policy.IsValid() {
|
||||
return policy, nil
|
||||
}
|
||||
return policyForSlow(scope)
|
||||
}
|
||||
|
||||
func policyForSlow(scope setting.PolicyScope) (policy *Policy, err error) {
|
||||
defer func() {
|
||||
if policy != nil {
|
||||
resultantPolicyLRU[scope.Kind()].Store(policy)
|
||||
}
|
||||
}()
|
||||
|
||||
policyMu.RLock()
|
||||
if policy, ok := findPolicyByScopeLocked(scope); ok {
|
||||
policyMu.RUnlock()
|
||||
return policy, nil
|
||||
}
|
||||
policyMu.RUnlock()
|
||||
|
||||
policyMu.Lock()
|
||||
defer policyMu.Unlock()
|
||||
if policy, ok := findPolicyByScopeLocked(scope); ok {
|
||||
return policy, nil
|
||||
}
|
||||
sources := slicesx.Filter(nil, policySources, func(source *source.Source) bool {
|
||||
return scope.IsWithinOf(source.Scope())
|
||||
})
|
||||
policy, err = newPolicy(scope, sources...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resultantPolicies = append(resultantPolicies, policy)
|
||||
go func() {
|
||||
<-policy.Done()
|
||||
deletePolicy(policy)
|
||||
}()
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
// findPolicyByScopeLocked returns a policy with the specified scope and true if
|
||||
// one exists, otherwise it returns nil, false.
|
||||
// [policyMu] must be held.
|
||||
func findPolicyByScopeLocked(target setting.PolicyScope) (policy *Policy, ok bool) {
|
||||
for _, policy := range resultantPolicies {
|
||||
if policy.Scope() == target && policy.IsValid() {
|
||||
return policy, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// deletePolicy deletes the specified resultant policy from the [resultantPolicies] list.
|
||||
func deletePolicy(policy *Policy) {
|
||||
policyMu.Lock()
|
||||
if i := slices.Index(resultantPolicies, policy); i != -1 {
|
||||
resultantPolicies = slices.Delete(resultantPolicies, i, i+1)
|
||||
}
|
||||
resultantPolicyLRU[policy.Scope().Kind()].CompareAndSwap(policy, nil)
|
||||
policyMu.Unlock()
|
||||
}
|
||||
|
||||
// ErrAlreadyConsumed is the error returned when [StoreRegistration.ReplaceStore]
|
||||
// or [StoreRegistration.Unregister] is called more than once.
|
||||
var ErrAlreadyConsumed = errors.New("the store registration is no longer valid")
|
||||
|
||||
// StoreRegistration is a [source.Store] registered for use in the specified scope.
|
||||
// It can be used to unregister the store, or replace it with another one.
|
||||
type StoreRegistration struct {
|
||||
source *source.Source
|
||||
consumed atomic.Uint32
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope].
|
||||
func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
|
||||
return newStoreRegistration(name, scope, store)
|
||||
}
|
||||
|
||||
// RegisterStoreForTest is like [RegisterStore], but unregisters the store when
|
||||
// tb and all its subtests complete.
|
||||
func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
|
||||
reg, err := RegisterStore(name, scope, store)
|
||||
if err == nil {
|
||||
tb.Cleanup(func() {
|
||||
if err := reg.Unregister(); err != nil && !errors.Is(err, ErrAlreadyConsumed) {
|
||||
tb.Fatalf("Unregister failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
return reg, err // may be nil or non-nil
|
||||
}
|
||||
|
||||
func newStoreRegistration(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) {
|
||||
source := source.NewSource(name, scope, store)
|
||||
if err := registerSource(source); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &StoreRegistration{source: source}, nil
|
||||
}
|
||||
|
||||
// ReplaceStore replaces the registered store with the new one,
|
||||
// returning a new [StoreRegistration] or an error.
|
||||
func (r *StoreRegistration) ReplaceStore(new source.Store) (*StoreRegistration, error) {
|
||||
var res *StoreRegistration
|
||||
err := r.consume(func() error {
|
||||
newSource := source.NewSource(r.source.Name(), r.source.Scope(), new)
|
||||
if err := replaceSource(r.source, newSource); err != nil {
|
||||
return err
|
||||
}
|
||||
res = &StoreRegistration{source: newSource}
|
||||
return nil
|
||||
})
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Unregister reverts the registration.
|
||||
func (r *StoreRegistration) Unregister() error {
|
||||
return r.consume(func() error { return unregisterSource(r.source) })
|
||||
}
|
||||
|
||||
// consume invokes fn, consuming r if no error is returned.
|
||||
// It returns [ErrAlreadyConsumed] on subsequent calls after the first successful call.
|
||||
func (r *StoreRegistration) consume(fn func() error) (err error) {
|
||||
if r.consumed.Load() != 0 {
|
||||
return ErrAlreadyConsumed
|
||||
}
|
||||
return r.consumeSlow(fn)
|
||||
}
|
||||
|
||||
func (r *StoreRegistration) consumeSlow(fn func() error) (err error) {
|
||||
r.m.Lock()
|
||||
defer r.m.Unlock()
|
||||
if r.consumed.Load() != 0 {
|
||||
return ErrAlreadyConsumed
|
||||
}
|
||||
if err = fn(); err == nil {
|
||||
r.consumed.Store(1)
|
||||
}
|
||||
return err // may be nil or non-nil
|
||||
}
|
368
util/syspolicy/rsop/resultant_policy_test.go
Normal file
368
util/syspolicy/rsop/resultant_policy_test.go
Normal file
@ -0,0 +1,368 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rsop
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
)
|
||||
|
||||
func TestRegisterSourceAndGetResultantPolicy(t *testing.T) {
|
||||
type sourceConfig struct {
|
||||
name string
|
||||
scope setting.PolicyScope
|
||||
settingKey setting.Key
|
||||
settingValue string
|
||||
wantEffective bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
scope setting.PolicyScope
|
||||
initialSources []sourceConfig
|
||||
additionalSources []sourceConfig
|
||||
wantSnapshot *setting.Snapshot
|
||||
}{
|
||||
{
|
||||
name: "DevicePolicy/NoSources",
|
||||
scope: setting.DeviceScope,
|
||||
wantSnapshot: setting.NewSnapshot(nil, setting.DeviceScope),
|
||||
},
|
||||
{
|
||||
name: "UserScope/NoSources",
|
||||
scope: setting.CurrentUserScope,
|
||||
wantSnapshot: setting.NewSnapshot(nil, setting.CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/OneInitialSource",
|
||||
scope: setting.DeviceScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceA",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueA",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/OneAdditionalSource",
|
||||
scope: setting.DeviceScope,
|
||||
additionalSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceA",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueA",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/ManyInitialSources/NoConflicts",
|
||||
scope: setting.DeviceScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceA",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueA",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceB",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "TestValueB",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceC",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyC",
|
||||
settingValue: "TestValueC",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)),
|
||||
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
|
||||
"TestKeyC": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)),
|
||||
}, setting.DeviceScope),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/ManyInitialSources/Conflicts",
|
||||
scope: setting.DeviceScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceA",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueA",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceB",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "TestValueB",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceC",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueC",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)),
|
||||
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
|
||||
}, setting.DeviceScope),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/MixedSources/Conflicts",
|
||||
scope: setting.DeviceScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceA",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueA",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceB",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "TestValueB",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceC",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueC",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
additionalSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceD",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueD",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceE",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyC",
|
||||
settingValue: "TestValueE",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceF",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "TestValueF",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("TestValueF", nil, setting.NewNamedOrigin("TestSourceF", setting.DeviceScope)),
|
||||
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)),
|
||||
"TestKeyC": setting.RawItemWith("TestValueE", nil, setting.NewNamedOrigin("TestSourceE", setting.DeviceScope)),
|
||||
}, setting.DeviceScope),
|
||||
},
|
||||
{
|
||||
name: "UserScope/Init-DeviceSource",
|
||||
scope: setting.CurrentUserScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceDevice",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "DeviceValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
}, setting.CurrentUserScope, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "UserScope/Init-DeviceSource/Add-UserSource",
|
||||
scope: setting.CurrentUserScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceDevice",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "DeviceValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
additionalSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceUser",
|
||||
scope: setting.CurrentUserScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "UserValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
"TestKeyB": setting.RawItemWith("UserValue", nil, setting.NewNamedOrigin("TestSourceUser", setting.CurrentUserScope)),
|
||||
}, setting.CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "UserScope/Init-DeviceSource/Add-UserSource-and-ProfileSource",
|
||||
scope: setting.CurrentUserScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceDevice",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "DeviceValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
additionalSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceProfile",
|
||||
scope: setting.CurrentProfileScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "ProfileValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
{
|
||||
name: "TestSourceUser",
|
||||
scope: setting.CurrentUserScope,
|
||||
settingKey: "TestKeyB",
|
||||
settingValue: "UserValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
"TestKeyB": setting.RawItemWith("ProfileValue", nil, setting.NewNamedOrigin("TestSourceProfile", setting.CurrentProfileScope)),
|
||||
}, setting.CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "DevicePolicy/User-Source-does-not-apply",
|
||||
scope: setting.DeviceScope,
|
||||
initialSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceDevice",
|
||||
scope: setting.DeviceScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "DeviceValue",
|
||||
wantEffective: true,
|
||||
},
|
||||
},
|
||||
additionalSources: []sourceConfig{
|
||||
{
|
||||
name: "TestSourceUser",
|
||||
scope: setting.CurrentUserScope,
|
||||
settingKey: "TestKeyA",
|
||||
settingValue: "UserValue",
|
||||
wantEffective: false, // Registering a user source should have no impact on the device policy.
|
||||
},
|
||||
},
|
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Register all settings that we use in this test.
|
||||
var definitions []*setting.Definition
|
||||
for _, source := range slices.Concat(tt.initialSources, tt.additionalSources) {
|
||||
definitions = append(definitions, setting.NewDefinition(source.settingKey, tt.scope.Kind(), setting.StringValue))
|
||||
}
|
||||
if err := setting.SetDefinitionsForTest(t, definitions...); err != nil {
|
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err)
|
||||
}
|
||||
|
||||
// Add the initial policy sources.
|
||||
var wantSources []*source.Source
|
||||
for _, s := range tt.initialSources {
|
||||
store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue))
|
||||
source := source.NewSource(s.name, s.scope, store)
|
||||
if err := registerSource(source); err != nil {
|
||||
t.Fatalf("failed to register policy source: %v", source)
|
||||
}
|
||||
if s.wantEffective {
|
||||
wantSources = append(wantSources, source)
|
||||
}
|
||||
t.Cleanup(func() { unregisterSource(source) })
|
||||
}
|
||||
|
||||
// Retrieve the resultant policy.
|
||||
policy, err := resultantPolicyForTest(t, tt.scope)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get resultant policy for %v", tt.scope)
|
||||
}
|
||||
|
||||
// Add additional setting sources one by one, and check the policy settings at each step.
|
||||
for _, s := range tt.additionalSources {
|
||||
store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue))
|
||||
source := source.NewSource(s.name, s.scope, store)
|
||||
if err := registerSource(source); err != nil {
|
||||
t.Fatalf("failed to register additional policy source: %v", source)
|
||||
}
|
||||
if s.wantEffective {
|
||||
wantSources = append(wantSources, source)
|
||||
}
|
||||
t.Cleanup(func() { unregisterSource(source) })
|
||||
}
|
||||
|
||||
sort.SliceStable(wantSources, func(i, j int) bool {
|
||||
return wantSources[i].Compare(wantSources[j]) < 0
|
||||
})
|
||||
gotSources := make([]*source.Source, len(policy.sources))
|
||||
for i, s := range policy.sources {
|
||||
gotSources[i] = s.Source
|
||||
}
|
||||
if !slices.Equal(gotSources, wantSources) {
|
||||
t.Errorf("Sources: got %v; want %v", gotSources, wantSources)
|
||||
}
|
||||
|
||||
// Verify the final resultant settings snapshots.
|
||||
if got := policy.Get(); !got.Equal(tt.wantSnapshot) {
|
||||
t.Errorf("Snapshot: got %v; want %v", got, tt.wantSnapshot)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// resultantPolicyForTest is like [resultantPolicyFor], but it deletes the policy
|
||||
// when tb and all its subtests complete.
|
||||
func resultantPolicyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) {
|
||||
policy, err := PolicyFor(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tb.Cleanup(func() {
|
||||
policy.Close()
|
||||
<-policy.Done()
|
||||
deletePolicy(policy)
|
||||
})
|
||||
return policy, nil
|
||||
}
|
60
util/syspolicy/setting/errors.go
Normal file
60
util/syspolicy/setting/errors.go
Normal file
@ -0,0 +1,60 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrNotConfigured is returned when the requested policy setting is not configured.
|
||||
ErrNotConfigured = errors.New("not configured")
|
||||
// ErrTypeMismatch is returned when there's a type mismatch between the actual type
|
||||
// of the setting value and the expected type.
|
||||
ErrTypeMismatch = errors.New("type mismatch")
|
||||
// ErrNoSuchKey is returned by [DefinitionOf] when no policy setting
|
||||
// has been registered with the specified key.
|
||||
//
|
||||
// Until 2024-08-02, this error was also returned by a [Handler] when the specified
|
||||
// key did not have a value set. While the package maintains compatibility with this
|
||||
// usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer
|
||||
// [source.Store] implementations.
|
||||
ErrNoSuchKey = errors.New("no such key")
|
||||
)
|
||||
|
||||
// Error is an error when reading or parsing a policy setting.
|
||||
type Error struct {
|
||||
text string
|
||||
}
|
||||
|
||||
// NewError returns a [Error] with the specified error message.
|
||||
func NewError(text string) *Error {
|
||||
return &Error{text}
|
||||
}
|
||||
|
||||
// WrapError returns an [Error] with the text of the specified error,
|
||||
// or nil if err is nil, [ErrNotConfigured], or [ErrNoSuchKey].
|
||||
func WrapError(err error) *Error {
|
||||
if err == nil || errors.Is(err, ErrNotConfigured) || errors.Is(err, ErrNoSuchKey) {
|
||||
return nil
|
||||
}
|
||||
if err, ok := err.(*Error); ok {
|
||||
return err
|
||||
}
|
||||
return &Error{err.Error()}
|
||||
}
|
||||
|
||||
// Error implements error.
|
||||
func (e Error) Error() string {
|
||||
return e.text
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (e Error) MarshalText() (text []byte, err error) {
|
||||
return []byte(e.Error()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler].
|
||||
func (e *Error) UnmarshalText(text []byte) error {
|
||||
e.text = string(text)
|
||||
return nil
|
||||
}
|
13
util/syspolicy/setting/key.go
Normal file
13
util/syspolicy/setting/key.go
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
// Key is a string that uniquely identifies a policy and must remain unchanged
|
||||
// once established and documented for a given policy setting. It may contain
|
||||
// alphanumeric characters and zero or more [KeyPathSeparator]s to group
|
||||
// individual policy settings into categories.
|
||||
type Key string
|
||||
|
||||
// KeyPathSeparator allows logical grouping of policy settings into categories.
|
||||
const KeyPathSeparator = "/"
|
71
util/syspolicy/setting/origin.go
Normal file
71
util/syspolicy/setting/origin.go
Normal file
@ -0,0 +1,71 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
jsonv2 "github.com/go-json-experiment/json"
|
||||
"github.com/go-json-experiment/json/jsontext"
|
||||
)
|
||||
|
||||
// Origin describes where a policy or a policy setting is configured.
|
||||
type Origin struct {
|
||||
data settingOrigin
|
||||
}
|
||||
|
||||
// settingOrigin is the marshallable data of a [Origin].
|
||||
type settingOrigin struct {
|
||||
Name string `json:",omitzero"`
|
||||
Scope PolicyScope
|
||||
}
|
||||
|
||||
// NewOrigin returns a new [Origin] with the specified scope.
|
||||
func NewOrigin(scope PolicyScope) *Origin {
|
||||
return NewNamedOrigin("", scope)
|
||||
}
|
||||
|
||||
// NewNamedOrigin returns a new [Origin] with the specified scope and name.
|
||||
func NewNamedOrigin(name string, scope PolicyScope) *Origin {
|
||||
return &Origin{settingOrigin{name, scope}}
|
||||
}
|
||||
|
||||
// Scope reports the policy [PolicyScope] where the setting is configured.
|
||||
func (s Origin) Scope() PolicyScope {
|
||||
return s.data.Scope
|
||||
}
|
||||
|
||||
// Name returns the name of the policy source where the setting is configured,
|
||||
// or "" if not available.
|
||||
func (s Origin) Name() string {
|
||||
return s.data.Name
|
||||
}
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (s Origin) String() string {
|
||||
if s.Name() != "" {
|
||||
return fmt.Sprintf("%s (%v)", s.Name(), s.Scope())
|
||||
}
|
||||
return s.Scope().String()
|
||||
}
|
||||
|
||||
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
|
||||
func (s Origin) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
|
||||
return jsonv2.MarshalEncode(out, &s.data, opts)
|
||||
}
|
||||
|
||||
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
|
||||
func (s *Origin) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
|
||||
return jsonv2.UnmarshalDecode(in, &s.data, opts)
|
||||
}
|
||||
|
||||
// MarshalJSON implements [json.Marshaler].
|
||||
func (s Origin) MarshalJSON() ([]byte, error) {
|
||||
return jsonv2.Marshal(s) // uses MarshalJSONV2
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements [json.Unmarshaler].
|
||||
func (s *Origin) UnmarshalJSON(b []byte) error {
|
||||
return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2
|
||||
}
|
195
util/syspolicy/setting/policy_scope.go
Normal file
195
util/syspolicy/setting/policy_scope.go
Normal file
@ -0,0 +1,195 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/syspolicy/internal/lazyinit"
|
||||
)
|
||||
|
||||
var (
|
||||
lazyCurrentScope lazy.SyncValue[PolicyScope]
|
||||
|
||||
// DeviceScope indicates a scope containing device-global policies.
|
||||
DeviceScope = PolicyScope{kind: DeviceSetting}
|
||||
// CurrentProfileScope indicates a scope containing policies that apply to the
|
||||
// currently active Tailscale profile.
|
||||
CurrentProfileScope = PolicyScope{kind: ProfileSetting}
|
||||
// CurrentUserScope indicates a scope containing policies that apply to the
|
||||
// current user, for whatever that means on the current platform and
|
||||
// in the current application context.
|
||||
CurrentUserScope = PolicyScope{kind: UserSetting}
|
||||
)
|
||||
|
||||
// PolicyScope is a management scope.
|
||||
type PolicyScope struct {
|
||||
kind Scope
|
||||
userID string
|
||||
profileID string
|
||||
}
|
||||
|
||||
// CurrentScope returns the default [PolicyScope] that the package will use to return
|
||||
// the policy settings for unless a different scope is explicitly requested.
|
||||
// This defaults to [DeviceScope], unless the process runs as a user (rather than LocalSystem)
|
||||
// on Windows, in which case it returns the [CurrentUserScope].
|
||||
func CurrentScope() PolicyScope {
|
||||
// Allow deferred package init functions to override the default scope.
|
||||
lazyinit.Do()
|
||||
return lazyCurrentScope.Get(func() PolicyScope { return DeviceScope })
|
||||
}
|
||||
|
||||
// SetCurrentScope attempts to set the specified scope as the current scope,
|
||||
// and reports whether it succeeds.
|
||||
// It can be called only once and must be during lazy package initialization.
|
||||
func SetCurrentScope(scope PolicyScope) bool {
|
||||
return lazyCurrentScope.Set(scope)
|
||||
}
|
||||
|
||||
// UserScopeOf returns a policy [PolicyScope] of the specified user.
|
||||
func UserScopeOf(uid string) PolicyScope {
|
||||
return PolicyScope{kind: UserSetting, userID: uid}
|
||||
}
|
||||
|
||||
// Kind reports the base [Scope] of s.
|
||||
func (s PolicyScope) Kind() Scope {
|
||||
return s.kind
|
||||
}
|
||||
|
||||
// IsApplicableSetting reports whether the specified setting applies to
|
||||
// and can be retrieved for this scope. Policy settings are applicable
|
||||
// to their own scopes as well as more specific scopes. For example,
|
||||
// device settings are applicable to device, profile and user scopes,
|
||||
// but user settings are only applicable to user scopes.
|
||||
// For instance, a menu visibility setting is inherently a user setting
|
||||
// and only makes sense in the context of a specific user.
|
||||
func (s PolicyScope) IsApplicableSetting(setting *Definition) bool {
|
||||
return setting != nil && setting.Scope() <= s.Kind()
|
||||
}
|
||||
|
||||
// IsConfigurableSetting reports whether the specified setting can be configured
|
||||
// by a policy at this scope. Policy settings are configurable at their own scopes
|
||||
// as well as broader scopes. For example, [UserSetting]s are configurable in
|
||||
// user, profile, and device scopes, but [DeviceSetting]s are only configurable
|
||||
// in the [DeviceScope]. For instance, the InstallUpdates policy setting
|
||||
// can only be configured in the device scope, as it controls whether updates
|
||||
// will be installed automatically on the device, rather than for specific users.
|
||||
func (s PolicyScope) IsConfigurableSetting(setting *Definition) bool {
|
||||
return setting != nil && setting.Scope() >= s.Kind()
|
||||
}
|
||||
|
||||
// IsWithinOf reports whether policy settings that apply to s2 also apply to s.
|
||||
// For example, policy settings that apply to the [DeviceScope] also apply to
|
||||
// the [CurrentUserScope].
|
||||
func (s PolicyScope) IsWithinOf(s2 PolicyScope) bool {
|
||||
if s2.Kind() > s.Kind() {
|
||||
return false
|
||||
}
|
||||
switch s2.Kind() {
|
||||
case DeviceSetting:
|
||||
return true
|
||||
case ProfileSetting:
|
||||
return s.profileID == s2.profileID
|
||||
case UserSetting:
|
||||
return s.userID == s2.userID
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// IsStrictlyWithinOf is like [IsWithinOf], except it returns false
|
||||
// when s and s2 is the same scope.
|
||||
func (s PolicyScope) IsStrictlyWithinOf(s2 PolicyScope) bool {
|
||||
return s != s2 && s.IsWithinOf(s2)
|
||||
}
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (s PolicyScope) String() string {
|
||||
if s.profileID == "" && s.userID == "" {
|
||||
return s.kind.String()
|
||||
}
|
||||
return s.stringSlow()
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (s PolicyScope) MarshalText() ([]byte, error) {
|
||||
return []byte(s.String()), nil
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextUnmarshaler].
|
||||
func (s *PolicyScope) UnmarshalText(b []byte) error {
|
||||
*s = PolicyScope{}
|
||||
parts := strings.SplitN(string(b), "/", 2)
|
||||
if len(parts) == 0 {
|
||||
return fmt.Errorf("%s is not a valid scope", b)
|
||||
}
|
||||
for i, part := range parts {
|
||||
kind, id, err := parseScopeAndID(part)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if i > 0 && kind <= s.kind {
|
||||
return fmt.Errorf("invalid scope hierarchy: %s", b)
|
||||
}
|
||||
s.kind = kind
|
||||
switch kind {
|
||||
case DeviceSetting:
|
||||
if id != "" {
|
||||
return fmt.Errorf("the device scope must not have an ID: %s", b)
|
||||
}
|
||||
case ProfileSetting:
|
||||
s.profileID = id
|
||||
case UserSetting:
|
||||
s.userID = id
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s PolicyScope) stringSlow() string {
|
||||
var sb strings.Builder
|
||||
writeScopeWithID := func(s Scope, id string) {
|
||||
sb.WriteString(s.String())
|
||||
if id != "" {
|
||||
sb.WriteRune('(')
|
||||
sb.WriteString(id)
|
||||
sb.WriteRune(')')
|
||||
}
|
||||
}
|
||||
if s.kind == ProfileSetting || s.profileID != "" {
|
||||
writeScopeWithID(ProfileSetting, s.profileID)
|
||||
if s.kind != ProfileSetting {
|
||||
sb.WriteRune('/')
|
||||
}
|
||||
}
|
||||
if s.kind == UserSetting {
|
||||
writeScopeWithID(UserSetting, s.userID)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func parseScopeAndID(s string) (scope Scope, id string, err error) {
|
||||
name, params, ok := extractScopeAndParams(s)
|
||||
if !ok {
|
||||
return 0, "", fmt.Errorf("%q is not a valid scope string", s)
|
||||
}
|
||||
if err := scope.UnmarshalText([]byte(name)); err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
return scope, params, nil
|
||||
}
|
||||
|
||||
func extractScopeAndParams(s string) (name, params string, ok bool) {
|
||||
paramsStart := strings.Index(s, "(")
|
||||
if paramsStart == -1 {
|
||||
return s, "", true
|
||||
}
|
||||
paramsEnd := strings.LastIndex(s, ")")
|
||||
if paramsEnd < paramsStart {
|
||||
return "", "", false
|
||||
}
|
||||
return s[0:paramsStart], s[paramsStart+1 : paramsEnd], true
|
||||
}
|
550
util/syspolicy/setting/policy_scope_test.go
Normal file
550
util/syspolicy/setting/policy_scope_test.go
Normal file
@ -0,0 +1,550 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
jsonv2 "github.com/go-json-experiment/json"
|
||||
)
|
||||
|
||||
func TestPolicyScopeIsApplicableSetting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scope PolicyScope
|
||||
setting *Definition
|
||||
wantApplicable bool
|
||||
}{
|
||||
{
|
||||
name: "DeviceScope/DeviceSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/ProfileSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantApplicable: false,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/UserSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantApplicable: false,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/DeviceSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/ProfileSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/UserSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantApplicable: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope/DeviceSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
{
|
||||
name: "UserScope/ProfileSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
{
|
||||
name: "UserScope/UserSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantApplicable: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotApplicable := tt.scope.IsApplicableSetting(tt.setting)
|
||||
if gotApplicable != tt.wantApplicable {
|
||||
t.Fatalf("got %v, want %v", gotApplicable, tt.wantApplicable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyScopeIsConfigurableSetting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scope PolicyScope
|
||||
setting *Definition
|
||||
wantConfigurable bool
|
||||
}{
|
||||
{
|
||||
name: "DeviceScope/DeviceSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/ProfileSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/UserSetting",
|
||||
scope: DeviceScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/DeviceSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantConfigurable: false,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/ProfileSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/UserSetting",
|
||||
scope: CurrentProfileScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
{
|
||||
name: "UserScope/DeviceSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", DeviceSetting, IntegerValue),
|
||||
wantConfigurable: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope/ProfileSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", ProfileSetting, IntegerValue),
|
||||
wantConfigurable: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope/UserSetting",
|
||||
scope: CurrentUserScope,
|
||||
setting: NewDefinition("TestSetting", UserSetting, IntegerValue),
|
||||
wantConfigurable: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotConfigurable := tt.scope.IsConfigurableSetting(tt.setting)
|
||||
if gotConfigurable != tt.wantConfigurable {
|
||||
t.Fatalf("got %v, want %v", gotConfigurable, tt.wantConfigurable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyScopeIsWithinOf(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scopeA PolicyScope
|
||||
scopeB PolicyScope
|
||||
wantBWithinOfA bool
|
||||
wantBStrictlyWithinOfA bool
|
||||
}{
|
||||
{
|
||||
name: "DeviceScope/DeviceScope",
|
||||
scopeA: DeviceScope,
|
||||
scopeB: DeviceScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/CurrentProfileScope",
|
||||
scopeA: DeviceScope,
|
||||
scopeB: CurrentProfileScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: true,
|
||||
},
|
||||
{
|
||||
name: "DeviceScope/UserScope",
|
||||
scopeA: DeviceScope,
|
||||
scopeB: CurrentUserScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/DeviceScope",
|
||||
scopeA: CurrentProfileScope,
|
||||
scopeB: DeviceScope,
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/ProfileScope",
|
||||
scopeA: CurrentProfileScope,
|
||||
scopeB: CurrentProfileScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope/UserScope",
|
||||
scopeA: CurrentProfileScope,
|
||||
scopeB: CurrentUserScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: true,
|
||||
},
|
||||
{
|
||||
name: "UserScope/DeviceScope",
|
||||
scopeA: CurrentUserScope,
|
||||
scopeB: DeviceScope,
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope/ProfileScope",
|
||||
scopeA: CurrentUserScope,
|
||||
scopeB: CurrentProfileScope,
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope/UserScope",
|
||||
scopeA: CurrentUserScope,
|
||||
scopeB: CurrentUserScope,
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope(1234)/UserScope(1234)",
|
||||
scopeA: UserScopeOf("1234"),
|
||||
scopeB: UserScopeOf("1234"),
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope(1234)/UserScope(5678)",
|
||||
scopeA: UserScopeOf("1234"),
|
||||
scopeB: UserScopeOf("5678"),
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope(A)/UserScope(A/1234)",
|
||||
scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"},
|
||||
scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"},
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: true,
|
||||
},
|
||||
{
|
||||
name: "ProfileScope(A)/UserScope(B/1234)",
|
||||
scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"},
|
||||
scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "B"},
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
{
|
||||
name: "UserScope(1234)/UserScope(A/1234)",
|
||||
scopeA: PolicyScope{kind: UserSetting, userID: "1234"},
|
||||
scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"},
|
||||
wantBWithinOfA: true,
|
||||
wantBStrictlyWithinOfA: true,
|
||||
},
|
||||
{
|
||||
name: "UserScope(1234)/UserScope(A/5678)",
|
||||
scopeA: PolicyScope{kind: UserSetting, userID: "1234"},
|
||||
scopeB: PolicyScope{kind: UserSetting, userID: "5678", profileID: "A"},
|
||||
wantBWithinOfA: false,
|
||||
wantBStrictlyWithinOfA: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotWithinOf := tt.scopeB.IsWithinOf(tt.scopeA)
|
||||
if gotWithinOf != tt.wantBWithinOfA {
|
||||
t.Fatalf("WithinOf: got %v, want %v", gotWithinOf, tt.wantBWithinOfA)
|
||||
}
|
||||
|
||||
gotStrictlyWithinOf := tt.scopeB.IsStrictlyWithinOf(tt.scopeA)
|
||||
if gotStrictlyWithinOf != tt.wantBStrictlyWithinOfA {
|
||||
t.Fatalf("StrictlyWithinOf: got %v, want %v", gotStrictlyWithinOf, tt.wantBStrictlyWithinOfA)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyScopeMarshalUnmarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in any
|
||||
wantJSON string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "null-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{},
|
||||
wantJSON: `{"Scope":"Device"}`,
|
||||
},
|
||||
{
|
||||
name: "null-scope-omit-zero",
|
||||
in: &struct {
|
||||
Scope PolicyScope `json:",omitzero"`
|
||||
}{},
|
||||
wantJSON: `{}`,
|
||||
},
|
||||
{
|
||||
name: "device-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{DeviceScope},
|
||||
wantJSON: `{"Scope":"Device"}`,
|
||||
},
|
||||
{
|
||||
name: "current-profile-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{CurrentProfileScope},
|
||||
wantJSON: `{"Scope":"Profile"}`,
|
||||
},
|
||||
{
|
||||
name: "current-user-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{CurrentUserScope},
|
||||
wantJSON: `{"Scope":"User"}`,
|
||||
},
|
||||
{
|
||||
name: "specific-user-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{UserScopeOf("_")},
|
||||
wantJSON: `{"Scope":"User(_)"}`,
|
||||
},
|
||||
{
|
||||
name: "specific-user-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{UserScopeOf("S-1-5-21-3698941153-1525015703-2649197413-1001")},
|
||||
wantJSON: `{"Scope":"User(S-1-5-21-3698941153-1525015703-2649197413-1001)"}`,
|
||||
},
|
||||
{
|
||||
name: "specific-profile-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{PolicyScope{kind: ProfileSetting, profileID: "1234"}},
|
||||
wantJSON: `{"Scope":"Profile(1234)"}`,
|
||||
},
|
||||
{
|
||||
name: "specific-profile-and-user-scope",
|
||||
in: &struct {
|
||||
Scope PolicyScope
|
||||
}{PolicyScope{
|
||||
kind: UserSetting,
|
||||
profileID: "1234",
|
||||
userID: "S-1-5-21-3698941153-1525015703-2649197413-1001",
|
||||
}},
|
||||
wantJSON: `{"Scope":"Profile(1234)/User(S-1-5-21-3698941153-1525015703-2649197413-1001)"}`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotJSON, err := jsonv2.Marshal(tt.in)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(gotJSON) != tt.wantJSON {
|
||||
t.Fatalf("Marshal got %s, want %s", gotJSON, tt.wantJSON)
|
||||
}
|
||||
wantBack := tt.in
|
||||
gotBack := reflect.New(reflect.TypeOf(tt.in).Elem()).Interface()
|
||||
err = jsonv2.Unmarshal(gotJSON, gotBack)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotBack, wantBack) {
|
||||
t.Fatalf("Unmarshal got %+v, want %+v", gotBack, wantBack)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyScopeUnmarshalSpecial(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
json string
|
||||
want any
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
json: "{}",
|
||||
want: &struct {
|
||||
Scope PolicyScope
|
||||
}{},
|
||||
},
|
||||
{
|
||||
name: "too-many-scopes",
|
||||
json: `{"Scope":"Device/Profile/User"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "user/profile", // incorrect order
|
||||
json: `{"Scope":"User/Profile"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "profile-user-no-params",
|
||||
json: `{"Scope":"Profile/User"}`,
|
||||
want: &struct {
|
||||
Scope PolicyScope
|
||||
}{CurrentUserScope},
|
||||
},
|
||||
{
|
||||
name: "unknown-scope",
|
||||
json: `{"Scope":"Unknown"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "unknown-scope/unknown-scope",
|
||||
json: `{"Scope":"Unknown/Unknown"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "device-scope/unknown-scope",
|
||||
json: `{"Scope":"Device/Unknown"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "unknown-scope/device-scope",
|
||||
json: `{"Scope":"Unknown/Device"}`,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "slash",
|
||||
json: `{"Scope":"/"}`,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := &struct {
|
||||
Scope PolicyScope
|
||||
}{}
|
||||
err := jsonv2.Unmarshal([]byte(tt.json), got)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("Marshal error: got %v, want %v", err, tt.wantError)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Fatalf("Unmarshal got %+v, want %+v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestExtractScopeAndParams(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
scope string
|
||||
params string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
s: "",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "scope-only",
|
||||
s: "device",
|
||||
scope: "device",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "scope-with-params",
|
||||
s: "user(1234)",
|
||||
scope: "user",
|
||||
params: "1234",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "params-empty-scope",
|
||||
s: "(1234)",
|
||||
scope: "",
|
||||
params: "1234",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "params-with-brackets",
|
||||
s: "test()())))())",
|
||||
scope: "test",
|
||||
params: ")())))()",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "no-closing-bracket",
|
||||
s: "user(1234",
|
||||
scope: "",
|
||||
params: "",
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "open-before-close",
|
||||
s: ")user(1234",
|
||||
scope: "",
|
||||
params: "",
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "brackets-only",
|
||||
s: ")(",
|
||||
scope: "",
|
||||
params: "",
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "closing-bracket",
|
||||
s: ")",
|
||||
scope: "",
|
||||
params: "",
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "opening-bracket",
|
||||
s: ")",
|
||||
scope: "",
|
||||
params: "",
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
scope, params, ok := extractScopeAndParams(tt.s)
|
||||
if ok != tt.wantOk {
|
||||
t.Logf("OK: got %v; want %v", ok, tt.wantOk)
|
||||
}
|
||||
if scope != tt.scope {
|
||||
t.Logf("Scope: got %q; want %q", scope, tt.scope)
|
||||
}
|
||||
if params != tt.params {
|
||||
t.Logf("Params: got %v; want %v", params, tt.params)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
47
util/syspolicy/setting/raw_item.go
Normal file
47
util/syspolicy/setting/raw_item.go
Normal file
@ -0,0 +1,47 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
// RawItem contains a raw policy setting as read from a policy store, or an
|
||||
// error if the requested setting could not be read from the store. As a special
|
||||
// case, it may also hold a value of the [Visibility], [PreferenceOption],
|
||||
// or [time.Duration] types. While the policy store interface does not support
|
||||
// these types natively, and the values of these types have to be unmarshalled
|
||||
// or converted from strings, these setting types predate the typed policy
|
||||
// hierarchies, and must be supported at this layer.
|
||||
type RawItem struct {
|
||||
value any
|
||||
err *Error
|
||||
origin *Origin // or nil
|
||||
}
|
||||
|
||||
// RawItemOf returns [RawItem] with the specified value.
|
||||
func RawItemOf(value any) RawItem {
|
||||
return RawItemWith(value, nil, nil)
|
||||
}
|
||||
|
||||
// RawItemWith returns an [RawItem] with the specified value, error and origin.
|
||||
func RawItemWith(value any, err *Error, origin *Origin) RawItem {
|
||||
return RawItem{value, err, origin}
|
||||
}
|
||||
|
||||
// Value returns the value of an untyped policy setting,
|
||||
// or nil if the policy setting is not configured.
|
||||
func (i RawItem) Value() any {
|
||||
return i.value
|
||||
}
|
||||
|
||||
// Error returns the error that occurred when reading the policy setting,
|
||||
// or nil if no error occurred.
|
||||
func (i RawItem) Error() error {
|
||||
if i.err != nil {
|
||||
return i.err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Origin returns an optional [Origin] indicating the policy settings is configured.
|
||||
func (i RawItem) Origin() *Origin {
|
||||
return i.origin
|
||||
}
|
352
util/syspolicy/setting/setting.go
Normal file
352
util/syspolicy/setting/setting.go
Normal file
@ -0,0 +1,352 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package setting contain types for policy settings.
|
||||
package setting
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/internal/lazyinit"
|
||||
)
|
||||
|
||||
// Scope indicates the broadest scope at which a policy setting may apply,
|
||||
// and the narrowest scope at which it may be configured.
|
||||
type Scope int8
|
||||
|
||||
const (
|
||||
// DeviceSetting indicates a policy setting that applies to a device, regardless of
|
||||
// which OS user or Tailscale profile is currently active, if any.
|
||||
// It can only be configured at a [DeviceScope].
|
||||
DeviceSetting Scope = iota
|
||||
// ProfileSetting indicates a policy setting that applies to a Tailscale profile.
|
||||
// It can only be configured for a specific profile or at a [DeviceScope],
|
||||
// in which case it applies to all profiles on the device.
|
||||
ProfileSetting
|
||||
// UserSetting indicates a policy setting that applies to users.
|
||||
// It can be configured for a user, profile, or the entire device.
|
||||
UserSetting
|
||||
|
||||
// MaxSettingScope is the maximum possible [Scope] value.
|
||||
MaxSettingScope = UserSetting
|
||||
)
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (s Scope) String() string {
|
||||
switch s {
|
||||
case DeviceSetting:
|
||||
return "Device"
|
||||
case ProfileSetting:
|
||||
return "Profile"
|
||||
case UserSetting:
|
||||
return "User"
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (s Scope) MarshalText() (text []byte, err error) {
|
||||
return []byte(s.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler].
|
||||
func (s *Scope) UnmarshalText(text []byte) error {
|
||||
switch strings.ToLower(string(text)) {
|
||||
case "device":
|
||||
*s = DeviceSetting
|
||||
case "profile":
|
||||
*s = ProfileSetting
|
||||
case "user":
|
||||
*s = UserSetting
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid scope", string(text))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Type is a policy setting value type.
|
||||
// Except for [InvalidValue], which represents an invalid policy setting type,
|
||||
// and [PreferenceOptionValue], [VisibilityValue], and [DurationValue],
|
||||
// which have special handling due to their legacy status in the package,
|
||||
// SettingTypes represent the raw value types readable from policy stores.
|
||||
type Type int
|
||||
|
||||
const (
|
||||
// InvalidValue indicates an invalid policy setting value type.
|
||||
InvalidValue Type = iota
|
||||
// BooleanValue indicates a policy setting whose underlying type in the
|
||||
// [source.Store] is a bool.
|
||||
BooleanValue
|
||||
// IntegerValue indicates a policy setting whose underlying type in the
|
||||
// [source.Store] is a uint64.
|
||||
IntegerValue
|
||||
// StringValue indicates a policy setting whose underlying type in the
|
||||
// [source.Store] is a string.
|
||||
StringValue
|
||||
// StringListValue indicates a policy setting whose underlying type in the
|
||||
// [source.Store] is a []string.
|
||||
StringListValue
|
||||
// PreferenceOptionValue indicates a three-state policy setting whose
|
||||
// underlying type in the [source.Store] is a string, but the actual value
|
||||
// is a [PreferenceOption].
|
||||
PreferenceOptionValue
|
||||
// VisibilityValue indicates a two-state boolean-like policy setting whose
|
||||
// underlying type in the [source.Store] is a string, but the actual value
|
||||
// is a [Visibility].
|
||||
VisibilityValue
|
||||
// DurationValue indicates an interval/period/duration policy setting whose
|
||||
// underlying type in the [source.Store] is a string, but the actual value
|
||||
// is a [time.Duration].
|
||||
DurationValue
|
||||
)
|
||||
|
||||
// String returns a string representation of t.
|
||||
func (t Type) String() string {
|
||||
switch t {
|
||||
case InvalidValue:
|
||||
return "Invalid"
|
||||
case BooleanValue:
|
||||
return "Boolean"
|
||||
case IntegerValue:
|
||||
return "Integer"
|
||||
case StringValue:
|
||||
return "String"
|
||||
case StringListValue:
|
||||
return "StringList"
|
||||
case PreferenceOptionValue:
|
||||
return "PreferenceOption"
|
||||
case VisibilityValue:
|
||||
return "Visibility"
|
||||
case DurationValue:
|
||||
return "Duration"
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// ValueType is a constraint that allows Go types corresponding to [Type].
|
||||
type ValueType interface {
|
||||
bool | uint64 | string | []string | Visibility | PreferenceOption | time.Duration
|
||||
}
|
||||
|
||||
// Definition defines policy key, scope and value type.
|
||||
type Definition struct {
|
||||
key Key
|
||||
scope Scope
|
||||
typ Type
|
||||
platforms PlatformList
|
||||
}
|
||||
|
||||
// NewDefinition returns a new [Definition] with the specified
|
||||
// key, scope, type and supported platforms (see [PlatformList]).
|
||||
func NewDefinition(k Key, s Scope, t Type, platforms ...string) *Definition {
|
||||
return &Definition{key: k, scope: s, typ: t, platforms: platforms}
|
||||
}
|
||||
|
||||
// Key returns a policy setting's identifier.
|
||||
func (d *Definition) Key() Key {
|
||||
if d == nil {
|
||||
return ""
|
||||
}
|
||||
return d.key
|
||||
}
|
||||
|
||||
// Scope reports the broadest [Scope] the policy setting may apply to.
|
||||
func (d *Definition) Scope() Scope {
|
||||
if d == nil {
|
||||
return 0
|
||||
}
|
||||
return d.scope
|
||||
}
|
||||
|
||||
// Type reports the underlying value type of the policy setting.
|
||||
func (d *Definition) Type() Type {
|
||||
if d == nil {
|
||||
return InvalidValue
|
||||
}
|
||||
return d.typ
|
||||
}
|
||||
|
||||
// IsSupported reports whether the policy setting is supported on the current OS.
|
||||
func (d *Definition) IsSupported() bool {
|
||||
if d == nil {
|
||||
return false
|
||||
}
|
||||
return d.platforms.HasCurrent()
|
||||
}
|
||||
|
||||
// SupportedPlatforms reports platforms on which the policy setting is supported.
|
||||
// An empty [PlatformList] indicates that s is available on all platforms.
|
||||
func (d *Definition) SupportedPlatforms() PlatformList {
|
||||
if d == nil {
|
||||
return nil
|
||||
}
|
||||
return d.platforms
|
||||
}
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (d *Definition) String() string {
|
||||
if d == nil {
|
||||
return "(nil)"
|
||||
}
|
||||
return fmt.Sprintf("%v(%q, %v)", d.scope, d.key, d.typ)
|
||||
}
|
||||
|
||||
// Equal reports whether d and d2 have the same key, type and scope.
|
||||
// It does not check whether both s and s2 are supported on the same platforms.
|
||||
func (d *Definition) Equal(d2 *Definition) bool {
|
||||
if d == d2 {
|
||||
return true
|
||||
}
|
||||
if d == nil || d2 == nil {
|
||||
return false
|
||||
}
|
||||
return d.key == d2.key && d.typ == d2.typ && d.scope == d2.scope
|
||||
}
|
||||
|
||||
// DefinitionMap is a map of setting [Definition] by [Key].
|
||||
type DefinitionMap map[Key]*Definition
|
||||
|
||||
var (
|
||||
definitions lazy.SyncValue[DefinitionMap]
|
||||
|
||||
definitionsMu sync.Mutex
|
||||
definitionsList []*Definition
|
||||
definitionsUsed bool
|
||||
)
|
||||
|
||||
// Register registers a policy setting with the specified key, scope, and value type.
|
||||
// All policy settings must be registered before any of them can be used.
|
||||
// Register panics if called after invoking any syspolicy functions that use the
|
||||
// registered policy definitions, such as functions that read the policy.
|
||||
func Register(k Key, s Scope, t Type, platforms ...string) {
|
||||
RegisterDefinition(NewDefinition(k, s, t, platforms...))
|
||||
}
|
||||
|
||||
// RegisterDefinition is like [Register], but accepts a [Definition].
|
||||
func RegisterDefinition(d *Definition) {
|
||||
definitionsMu.Lock()
|
||||
defer definitionsMu.Unlock()
|
||||
registerLocked(d)
|
||||
}
|
||||
|
||||
func registerLocked(d *Definition) {
|
||||
if definitionsUsed {
|
||||
panic("policy definitions are already in use")
|
||||
}
|
||||
definitionsList = append(definitionsList, d)
|
||||
}
|
||||
|
||||
func settingDefinitions() (DefinitionMap, error) {
|
||||
return definitions.GetErr(func() (DefinitionMap, error) {
|
||||
lazyinit.Do()
|
||||
definitionsMu.Lock()
|
||||
defer definitionsMu.Unlock()
|
||||
definitionsUsed = true
|
||||
return DefinitionMapOf(definitionsList)
|
||||
})
|
||||
}
|
||||
|
||||
// DefinitionMapOf returns a [DefinitionMap] with the specified settings,
|
||||
// or an error if any settings have the same key but different type or scope.
|
||||
func DefinitionMapOf(settings []*Definition) (DefinitionMap, error) {
|
||||
m := make(DefinitionMap, len(settings))
|
||||
for _, s := range settings {
|
||||
if existing, exists := m[s.key]; exists {
|
||||
if existing.Equal(s) {
|
||||
// Ignore duplicate setting definitions if they match. It is acceptable
|
||||
// if the same policy setting was registered more than once
|
||||
// (e.g. by the syspolicy package itself and by iOS/Android code).
|
||||
existing.platforms.mergeFrom(s.platforms)
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("duplicate policy definition: %q", s.key)
|
||||
}
|
||||
m[s.key] = s
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// SetDefinitionsForTest allows to register the specified setting definitions
|
||||
// for the test duration. It is not concurrency-safe, but unlike [Register],
|
||||
// it does not panic and can be called anytime.
|
||||
// It returns an error if ds contains two different settings with the same [Key].
|
||||
func SetDefinitionsForTest(tb lazy.TB, ds ...*Definition) error {
|
||||
m, err := DefinitionMapOf(ds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
definitions.SetForTest(tb, m, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefinitionOf returns a setting definition by key,
|
||||
// or [ErrNoSuchKey] if the specified key does not exist,
|
||||
// or an error if there are conflicting policy definitions.
|
||||
func DefinitionOf(k Key) (*Definition, error) {
|
||||
ds, err := settingDefinitions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if d, ok := ds[k]; ok {
|
||||
return d, nil
|
||||
}
|
||||
return nil, ErrNoSuchKey
|
||||
}
|
||||
|
||||
// Definitions returns all registered setting definitions,
|
||||
// or an error if different policies were registered under the same name.
|
||||
func Definitions() ([]*Definition, error) {
|
||||
ds, err := settingDefinitions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res := make([]*Definition, 0, len(ds))
|
||||
for _, d := range ds {
|
||||
res = append(res, d)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// PlatformList is a list of OSes.
|
||||
// An empty list indicates that all possible platforms are supported.
|
||||
type PlatformList []string
|
||||
|
||||
// Has reports whether the list contains the target platform.
|
||||
func (l PlatformList) Has(target string) bool {
|
||||
if len(l) == 0 {
|
||||
return true
|
||||
}
|
||||
return slices.ContainsFunc(l, func(os string) bool {
|
||||
return strings.EqualFold(os, target)
|
||||
})
|
||||
}
|
||||
|
||||
// HasCurrent is like Has, but for the current platform.
|
||||
func (l PlatformList) HasCurrent() bool {
|
||||
return l.Has(internal.OS())
|
||||
}
|
||||
|
||||
// mergeFrom merges l2 into l. Since an empty list indicates no platform restrictions,
|
||||
// if either l or l2 is empty, the merged result in l will also be empty.
|
||||
func (l *PlatformList) mergeFrom(l2 PlatformList) {
|
||||
switch {
|
||||
case len(*l) == 0:
|
||||
// No-op. An empty list indicates no platform restrictions.
|
||||
case len(l2) == 0:
|
||||
// Merging with an empty list results in an empty list.
|
||||
*l = l2
|
||||
default:
|
||||
// Append, sort and dedup.
|
||||
*l = append(*l, l2...)
|
||||
slices.Sort(*l)
|
||||
*l = slices.Compact(*l)
|
||||
}
|
||||
}
|
344
util/syspolicy/setting/setting_test.go
Normal file
344
util/syspolicy/setting/setting_test.go
Normal file
@ -0,0 +1,344 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
)
|
||||
|
||||
func TestSettingDefinition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setting *Definition
|
||||
osOverride string
|
||||
wantKey Key
|
||||
wantScope Scope
|
||||
wantType Type
|
||||
wantIsSupported bool
|
||||
wantSupportedPlatforms PlatformList
|
||||
wantString string
|
||||
}{
|
||||
{
|
||||
name: "Nil",
|
||||
setting: nil,
|
||||
wantKey: "",
|
||||
wantScope: 0,
|
||||
wantType: InvalidValue,
|
||||
wantIsSupported: false,
|
||||
wantString: "(nil)",
|
||||
},
|
||||
{
|
||||
name: "Device/Invalid",
|
||||
setting: NewDefinition("TestDevicePolicySetting", DeviceSetting, InvalidValue),
|
||||
wantKey: "TestDevicePolicySetting",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: InvalidValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `Device("TestDevicePolicySetting", Invalid)`,
|
||||
},
|
||||
{
|
||||
name: "Device/Integer",
|
||||
setting: NewDefinition("TestDevicePolicySetting", DeviceSetting, IntegerValue),
|
||||
wantKey: "TestDevicePolicySetting",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: IntegerValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `Device("TestDevicePolicySetting", Integer)`,
|
||||
},
|
||||
{
|
||||
name: "Profile/String",
|
||||
setting: NewDefinition("TestProfilePolicySetting", ProfileSetting, StringValue),
|
||||
wantKey: "TestProfilePolicySetting",
|
||||
wantScope: ProfileSetting,
|
||||
wantType: StringValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `Profile("TestProfilePolicySetting", String)`,
|
||||
},
|
||||
{
|
||||
name: "Device/StringList",
|
||||
setting: NewDefinition("AllowedSuggestedExitNodes", DeviceSetting, StringListValue),
|
||||
wantKey: "AllowedSuggestedExitNodes",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: StringListValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `Device("AllowedSuggestedExitNodes", StringList)`,
|
||||
},
|
||||
{
|
||||
name: "Device/PreferenceOption",
|
||||
setting: NewDefinition("AdvertiseExitNode", DeviceSetting, PreferenceOptionValue),
|
||||
wantKey: "AdvertiseExitNode",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: PreferenceOptionValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `Device("AdvertiseExitNode", PreferenceOption)`,
|
||||
},
|
||||
{
|
||||
name: "User/Boolean",
|
||||
setting: NewDefinition("TestUserPolicySetting", UserSetting, BooleanValue),
|
||||
wantKey: "TestUserPolicySetting",
|
||||
wantScope: UserSetting,
|
||||
wantType: BooleanValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `User("TestUserPolicySetting", Boolean)`,
|
||||
},
|
||||
{
|
||||
name: "User/Visibility",
|
||||
setting: NewDefinition("AdminConsole", UserSetting, VisibilityValue),
|
||||
wantKey: "AdminConsole",
|
||||
wantScope: UserSetting,
|
||||
wantType: VisibilityValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `User("AdminConsole", Visibility)`,
|
||||
},
|
||||
{
|
||||
name: "User/Duration",
|
||||
setting: NewDefinition("KeyExpirationNotice", UserSetting, DurationValue),
|
||||
wantKey: "KeyExpirationNotice",
|
||||
wantScope: UserSetting,
|
||||
wantType: DurationValue,
|
||||
wantIsSupported: true,
|
||||
wantString: `User("KeyExpirationNotice", Duration)`,
|
||||
},
|
||||
{
|
||||
name: "SupportedSetting",
|
||||
setting: NewDefinition("DesktopPolicySetting", DeviceSetting, StringValue, "macos", "windows"),
|
||||
osOverride: "windows",
|
||||
wantKey: "DesktopPolicySetting",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: StringValue,
|
||||
wantIsSupported: true,
|
||||
wantSupportedPlatforms: PlatformList{"macos", "windows"},
|
||||
wantString: `Device("DesktopPolicySetting", String)`,
|
||||
},
|
||||
{
|
||||
name: "UnsupportedSetting",
|
||||
setting: NewDefinition("AndroidPolicySetting", DeviceSetting, StringValue, "android"),
|
||||
osOverride: "macos",
|
||||
wantKey: "AndroidPolicySetting",
|
||||
wantScope: DeviceSetting,
|
||||
wantType: StringValue,
|
||||
wantIsSupported: false,
|
||||
wantSupportedPlatforms: PlatformList{"android"},
|
||||
wantString: `Device("AndroidPolicySetting", String)`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.osOverride != "" {
|
||||
internal.OSForTesting.SetForTest(t, tt.osOverride, nil)
|
||||
}
|
||||
if !tt.setting.Equal(tt.setting) {
|
||||
t.Errorf("the setting should be equal to itself")
|
||||
}
|
||||
if tt.setting != nil && !tt.setting.Equal(ptr.To(*tt.setting)) {
|
||||
t.Errorf("the setting should be equal to its shallow copy")
|
||||
}
|
||||
if gotKey := tt.setting.Key(); gotKey != tt.wantKey {
|
||||
t.Errorf("Key: got %q, want %q", gotKey, tt.wantKey)
|
||||
}
|
||||
if gotScope := tt.setting.Scope(); gotScope != tt.wantScope {
|
||||
t.Errorf("Scope: got %v, want %v", gotScope, tt.wantScope)
|
||||
}
|
||||
if gotType := tt.setting.Type(); gotType != tt.wantType {
|
||||
t.Errorf("Type: got %v, want %v", gotType, tt.wantType)
|
||||
}
|
||||
if gotIsSupported := tt.setting.IsSupported(); gotIsSupported != tt.wantIsSupported {
|
||||
t.Errorf("IsSupported: got %v, want %v", gotIsSupported, tt.wantIsSupported)
|
||||
}
|
||||
if gotSupportedPlatforms := tt.setting.SupportedPlatforms(); !slices.Equal(gotSupportedPlatforms, tt.wantSupportedPlatforms) {
|
||||
t.Errorf("SupportedPlatforms: got %v, want %v", gotSupportedPlatforms, tt.wantSupportedPlatforms)
|
||||
}
|
||||
if gotString := tt.setting.String(); gotString != tt.wantString {
|
||||
t.Errorf("String: got %v, want %v", gotString, tt.wantString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterSettingDefinition(t *testing.T) {
|
||||
const testPolicySettingKey Key = "TestPolicySetting"
|
||||
tests := []struct {
|
||||
name string
|
||||
key Key
|
||||
wantEq *Definition
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "GetRegistered",
|
||||
key: "TestPolicySetting",
|
||||
wantEq: NewDefinition(testPolicySettingKey, DeviceSetting, StringValue),
|
||||
},
|
||||
{
|
||||
name: "GetNonRegistered",
|
||||
key: "OtherPolicySetting",
|
||||
wantEq: nil,
|
||||
wantErr: ErrNoSuchKey,
|
||||
},
|
||||
}
|
||||
|
||||
resetSettingDefinitions(t)
|
||||
Register(testPolicySettingKey, DeviceSetting, StringValue)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, gotErr := DefinitionOf(tt.key)
|
||||
if gotErr != tt.wantErr {
|
||||
t.Errorf("gotErr %v, wantErr %v", gotErr, tt.wantErr)
|
||||
}
|
||||
if !got.Equal(tt.wantEq) {
|
||||
t.Errorf("got %v, want %v", got, tt.wantEq)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterAfterUsePanics(t *testing.T) {
|
||||
resetSettingDefinitions(t)
|
||||
|
||||
Register("TestPolicySetting", DeviceSetting, StringValue)
|
||||
DefinitionOf("TestPolicySetting")
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if gotPanic, wantPanic := recover(), "policy definitions are already in use"; gotPanic != wantPanic {
|
||||
t.Errorf("gotPanic: %q, wantPanic: %q", gotPanic, wantPanic)
|
||||
}
|
||||
}()
|
||||
|
||||
Register("TestPolicySetting", DeviceSetting, StringValue)
|
||||
}()
|
||||
}
|
||||
|
||||
func TestRegisterDuplicateSettings(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
settings []*Definition
|
||||
wantEq *Definition
|
||||
wantErrStr string
|
||||
}{
|
||||
{
|
||||
name: "NoConflict/Exact",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
|
||||
},
|
||||
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
|
||||
},
|
||||
{
|
||||
name: "NoConflict/MergeOS-First",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "android", "macos"),
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
|
||||
},
|
||||
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
|
||||
},
|
||||
{
|
||||
name: "NoConflict/MergeOS-Second",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "android", "macos"),
|
||||
},
|
||||
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue), // all platforms
|
||||
},
|
||||
{
|
||||
name: "NoConflict/MergeOS-Both",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "macos"),
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "windows"),
|
||||
},
|
||||
wantEq: NewDefinition("TestPolicySetting", DeviceSetting, StringValue, "macos", "windows"),
|
||||
},
|
||||
{
|
||||
name: "Conflict/Scope",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", DeviceSetting, StringValue),
|
||||
NewDefinition("TestPolicySetting", UserSetting, StringValue),
|
||||
},
|
||||
wantEq: nil,
|
||||
wantErrStr: `duplicate policy definition: "TestPolicySetting"`,
|
||||
},
|
||||
{
|
||||
name: "Conflict/Type",
|
||||
settings: []*Definition{
|
||||
NewDefinition("TestPolicySetting", UserSetting, StringValue),
|
||||
NewDefinition("TestPolicySetting", UserSetting, IntegerValue),
|
||||
},
|
||||
wantEq: nil,
|
||||
wantErrStr: `duplicate policy definition: "TestPolicySetting"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resetSettingDefinitions(t)
|
||||
for _, s := range tt.settings {
|
||||
Register(s.Key(), s.Scope(), s.Type(), s.SupportedPlatforms()...)
|
||||
}
|
||||
got, err := DefinitionOf("TestPolicySetting")
|
||||
var gotErrStr string
|
||||
if err != nil {
|
||||
gotErrStr = err.Error()
|
||||
}
|
||||
if gotErrStr != tt.wantErrStr {
|
||||
t.Fatalf("ErrStr: got %q, want %q", gotErrStr, tt.wantErrStr)
|
||||
}
|
||||
if !got.Equal(tt.wantEq) {
|
||||
t.Errorf("Definition got %v, want %v", got, tt.wantEq)
|
||||
}
|
||||
if !slices.Equal(got.SupportedPlatforms(), tt.wantEq.SupportedPlatforms()) {
|
||||
t.Errorf("SupportedPlatforms got %v, want %v", got.SupportedPlatforms(), tt.wantEq.SupportedPlatforms())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListSettingDefinitions(t *testing.T) {
|
||||
definitions := []*Definition{
|
||||
NewDefinition("TestDevicePolicySetting", DeviceSetting, IntegerValue),
|
||||
NewDefinition("TestProfilePolicySetting", ProfileSetting, StringValue),
|
||||
NewDefinition("TestUserPolicySetting", UserSetting, BooleanValue),
|
||||
NewDefinition("TestStringListPolicySetting", DeviceSetting, StringListValue),
|
||||
}
|
||||
if err := SetDefinitionsForTest(t, definitions...); err != nil {
|
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err)
|
||||
}
|
||||
|
||||
cmp := func(l, r *Definition) int {
|
||||
return strings.Compare(string(l.Key()), string(r.Key()))
|
||||
}
|
||||
want := append([]*Definition{}, definitions...)
|
||||
slices.SortFunc(want, cmp)
|
||||
|
||||
got, err := Definitions()
|
||||
if err != nil {
|
||||
t.Fatalf("Definitions failed: %v", err)
|
||||
}
|
||||
slices.SortFunc(got, cmp)
|
||||
|
||||
if !slices.Equal(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func resetSettingDefinitions(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
definitionsMu.Lock()
|
||||
definitionsList = nil
|
||||
definitions = lazy.SyncValue[DefinitionMap]{}
|
||||
definitionsUsed = false
|
||||
definitionsMu.Unlock()
|
||||
})
|
||||
|
||||
definitionsMu.Lock()
|
||||
definitionsList = nil
|
||||
definitions = lazy.SyncValue[DefinitionMap]{}
|
||||
definitionsUsed = false
|
||||
definitionsMu.Unlock()
|
||||
}
|
153
util/syspolicy/setting/snapshot.go
Normal file
153
util/syspolicy/setting/snapshot.go
Normal file
@ -0,0 +1,153 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
// Snapshot is an immutable collection of [RawItem]s, representing
|
||||
// a set of policy settings applied at a specific moment in time.
|
||||
// A nil pointer to [Snapshot] is valid.
|
||||
type Snapshot struct {
|
||||
m map[Key]RawItem
|
||||
sig deephash.Sum // of m
|
||||
summary Summary
|
||||
}
|
||||
|
||||
// NewSnapshot returns a new [Snapshot] with the specified items and options.
|
||||
func NewSnapshot(items map[Key]RawItem, opts ...SummaryOption) *Snapshot {
|
||||
return &Snapshot{m: items, sig: deephash.Hash(&items), summary: SummaryWith(opts...)}
|
||||
}
|
||||
|
||||
type keyItemPair struct {
|
||||
Key Key
|
||||
Item RawItem
|
||||
}
|
||||
|
||||
// All returns an iterator over [[Key], [RawItem]] key-value pairs in b. The
|
||||
// iteration order is not specified and is not guaranteed to be the same from
|
||||
// one call to the next.
|
||||
func (s *Snapshot) All() []keyItemPair {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
// TODO(nickkhyl): return iter.Seq2[[Key], [RawItem]] in Go 1.23,
|
||||
// and remove [keyItemPair].
|
||||
items := make([]keyItemPair, 0, len(s.m))
|
||||
for k, i := range s.m {
|
||||
items = append(items, keyItemPair{k, i})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
// Get returns the value of the policy setting with the specified key
|
||||
// or nil if it does not exist or could not be read.
|
||||
func (s *Snapshot) Get(k Key) any {
|
||||
v, _ := s.GetErr(k)
|
||||
return v
|
||||
}
|
||||
|
||||
// GetErr returns the value of the policy setting with the specified key,
|
||||
// [ErrNotConfigured] if it does not exist, or an error returned by
|
||||
// the policy Store if the policy setting could not be read.
|
||||
func (s *Snapshot) GetErr(k Key) (any, error) {
|
||||
if s != nil {
|
||||
if s, ok := s.m[k]; ok {
|
||||
return s.Value(), s.Error()
|
||||
}
|
||||
}
|
||||
return nil, ErrNotConfigured
|
||||
}
|
||||
|
||||
// GetSetting returns the untyped policy setting with the specified key and true
|
||||
// if a policy setting with such key has been configured;
|
||||
// otherwise, it returns zero, false.
|
||||
func (s *Snapshot) GetSetting(k Key) (setting RawItem, ok bool) {
|
||||
setting, ok = s.m[k]
|
||||
return setting, ok
|
||||
}
|
||||
|
||||
// Equal reports whether s and s2 are equal.
|
||||
func (s *Snapshot) Equal(s2 *Snapshot) bool {
|
||||
if !s.EqualItems(s2) {
|
||||
return false
|
||||
}
|
||||
return s.Summary() == s2.Summary()
|
||||
}
|
||||
|
||||
// EqualItems reports whether items in s and s2 are equal.
|
||||
func (s *Snapshot) EqualItems(s2 *Snapshot) bool {
|
||||
if s == s2 {
|
||||
return true
|
||||
}
|
||||
if s.Len() != s2.Len() {
|
||||
return false
|
||||
}
|
||||
if s.Len() == 0 {
|
||||
return true
|
||||
}
|
||||
return s.sig == s2.sig
|
||||
}
|
||||
|
||||
// Keys return an iterator over keys in s. The iteration order is not specified
|
||||
// and is not guaranteed to be the same from one call to the next.
|
||||
func (s *Snapshot) Keys() []Key {
|
||||
if s.m == nil {
|
||||
return nil
|
||||
}
|
||||
// TODO(nickkhyl): return iter.Seq[Key] in Go 1.23.
|
||||
return xmaps.Keys(s.m)
|
||||
}
|
||||
|
||||
// Len reports the number of [RawItem]s in s.
|
||||
func (s *Snapshot) Len() int {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return len(s.m)
|
||||
}
|
||||
|
||||
// Summary returns information about s as a whole rather than about specific [RawItem]s in it.
|
||||
func (s *Snapshot) Summary() Summary {
|
||||
if s == nil {
|
||||
return Summary{}
|
||||
}
|
||||
return s.summary
|
||||
}
|
||||
|
||||
// MergeSnapshots returns a [Snapshot] that contains all [RawItem]s
|
||||
// from snapshot1 and snapshot2 and the [Summary] with the narrower [PolicyScope].
|
||||
// If there's a conflict between policy settings in the two snapshots,
|
||||
// the policy settings from the snapshot with the broader scope take precedence.
|
||||
// In other words, policy settings configured for the [DeviceScope] win
|
||||
// over policy settings configured for a user scope.
|
||||
func MergeSnapshots(snapshot1, snapshot2 *Snapshot) *Snapshot {
|
||||
scope1, ok1 := snapshot1.Summary().Scope().GetOk()
|
||||
scope2, ok2 := snapshot2.Summary().Scope().GetOk()
|
||||
if ok1 && ok2 && scope2.IsStrictlyWithinOf(scope1) {
|
||||
// Swap snapshots if snapshot1 has higher precedence than snapshot2.
|
||||
snapshot1, snapshot2 = snapshot2, snapshot1
|
||||
}
|
||||
if snapshot2.Len() == 0 {
|
||||
return snapshot1
|
||||
}
|
||||
summaryOpts := make([]SummaryOption, 0, 2)
|
||||
if scope, ok := snapshot1.Summary().Scope().GetOk(); ok {
|
||||
// Use the scope from snapshot1, if present, which is the more specific snapshot.
|
||||
summaryOpts = append(summaryOpts, scope)
|
||||
}
|
||||
if snapshot1.Len() == 0 {
|
||||
if origin, ok := snapshot2.Summary().Origin().GetOk(); ok {
|
||||
// Use the origin from snapshot2 if snapshot1 is empty.
|
||||
summaryOpts = append(summaryOpts, origin)
|
||||
}
|
||||
return &Snapshot{snapshot2.m, snapshot2.sig, SummaryWith(summaryOpts...)}
|
||||
}
|
||||
m := make(map[Key]RawItem, snapshot1.Len()+snapshot2.Len())
|
||||
xmaps.Copy(m, snapshot1.m)
|
||||
xmaps.Copy(m, snapshot2.m) // snapshot2 has higher precedence
|
||||
return &Snapshot{m, deephash.Hash(&m), SummaryWith(summaryOpts...)}
|
||||
}
|
372
util/syspolicy/setting/snapshot_test.go
Normal file
372
util/syspolicy/setting/snapshot_test.go
Normal file
@ -0,0 +1,372 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMergeSnapshots(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s1, s2 *Snapshot
|
||||
want *Snapshot
|
||||
}{
|
||||
{
|
||||
name: "both-nil",
|
||||
s1: nil,
|
||||
s2: nil,
|
||||
want: NewSnapshot(map[Key]RawItem{}),
|
||||
},
|
||||
{
|
||||
name: "both-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{}),
|
||||
s2: NewSnapshot(map[Key]RawItem{}),
|
||||
want: NewSnapshot(map[Key]RawItem{}),
|
||||
},
|
||||
{
|
||||
name: "first-nil",
|
||||
s1: nil,
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "first-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{}),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "second-nil",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
s2: nil,
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "second-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
s2: NewSnapshot(map[Key]RawItem{}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "no-conflicts",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
"Setting5": {value: VisibleByPolicy},
|
||||
"Setting6": {value: ShowChoiceByPolicy},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
"Setting5": {value: VisibleByPolicy},
|
||||
"Setting6": {value: ShowChoiceByPolicy},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "with-conflicts",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 456},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 456},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "with-scope-first-wins",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, DeviceScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 456},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, CurrentUserScope),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "with-scope-second-wins",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 456},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, DeviceScope),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 456},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
}, CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "with-scope-both-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{}, DeviceScope),
|
||||
want: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "with-scope-first-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true}}, DeviceScope),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope),
|
||||
},
|
||||
{
|
||||
name: "with-scope-second-empty",
|
||||
s1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope),
|
||||
s2: NewSnapshot(map[Key]RawItem{}, DeviceScope),
|
||||
want: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}, CurrentUserScope),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MergeSnapshots(tt.s1, tt.s2)
|
||||
if !got.Equal(tt.want) {
|
||||
t.Errorf("got %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnapshotEqual(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
b1, b2 *Snapshot
|
||||
wantEqual bool
|
||||
wantEqualItems bool
|
||||
}{
|
||||
{
|
||||
name: "nil-nil",
|
||||
b1: nil,
|
||||
b2: nil,
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "nil-empty",
|
||||
b1: nil,
|
||||
b2: NewSnapshot(map[Key]RawItem{}),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "empty-nil",
|
||||
b1: NewSnapshot(map[Key]RawItem{}),
|
||||
b2: nil,
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "empty-empty",
|
||||
b1: NewSnapshot(map[Key]RawItem{}),
|
||||
b2: NewSnapshot(map[Key]RawItem{}),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "first-nil",
|
||||
b1: nil,
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
},
|
||||
{
|
||||
name: "first-empty",
|
||||
b1: NewSnapshot(map[Key]RawItem{}),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
},
|
||||
{
|
||||
name: "second-nil",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: true},
|
||||
}),
|
||||
b2: nil,
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
},
|
||||
{
|
||||
name: "second-empty",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
b2: NewSnapshot(map[Key]RawItem{}),
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
},
|
||||
{
|
||||
name: "same-items-same-order-no-scope",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "same-items-same-order-same-scope",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "same-items-different-order-same-scope",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting3": {value: false},
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
}, DeviceScope),
|
||||
wantEqual: true,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "same-items-same-order-different-scope",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, CurrentUserScope),
|
||||
wantEqual: false,
|
||||
wantEqualItems: true,
|
||||
},
|
||||
{
|
||||
name: "different-items-same-scope",
|
||||
b1: NewSnapshot(map[Key]RawItem{
|
||||
"Setting1": {value: 123},
|
||||
"Setting2": {value: "String"},
|
||||
"Setting3": {value: false},
|
||||
}, DeviceScope),
|
||||
b2: NewSnapshot(map[Key]RawItem{
|
||||
"Setting4": {value: 2 * time.Hour},
|
||||
"Setting5": {value: VisibleByPolicy},
|
||||
"Setting6": {value: ShowChoiceByPolicy},
|
||||
}, DeviceScope),
|
||||
wantEqual: false,
|
||||
wantEqualItems: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if gotEqual := tt.b1.Equal(tt.b2); gotEqual != tt.wantEqual {
|
||||
t.Errorf("WantEqual: got %v, want %v", gotEqual, tt.wantEqual)
|
||||
}
|
||||
if gotEqualItems := tt.b1.EqualItems(tt.b2); gotEqualItems != tt.wantEqualItems {
|
||||
t.Errorf("WantEqualItems: got %v, want %v", gotEqualItems, tt.wantEqualItems)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
84
util/syspolicy/setting/summary.go
Normal file
84
util/syspolicy/setting/summary.go
Normal file
@ -0,0 +1,84 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
jsonv2 "github.com/go-json-experiment/json"
|
||||
"github.com/go-json-experiment/json/jsontext"
|
||||
"tailscale.com/types/opt"
|
||||
)
|
||||
|
||||
// Summary is an immutable [PolicyScope] and [Origin].
|
||||
type Summary struct {
|
||||
data summary
|
||||
}
|
||||
|
||||
type summary struct {
|
||||
Scope opt.Value[PolicyScope] `json:",omitzero"`
|
||||
Origin opt.Value[Origin] `json:",omitzero"`
|
||||
}
|
||||
|
||||
// SummaryWith returns a [Summary] with the specified options.
|
||||
func SummaryWith(opts ...SummaryOption) Summary {
|
||||
var summary Summary
|
||||
for _, o := range opts {
|
||||
o.applySummaryOption(&summary)
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
// Scope reports the [PolicyScope] in s.
|
||||
func (s Summary) Scope() opt.Value[PolicyScope] {
|
||||
return s.data.Scope
|
||||
}
|
||||
|
||||
// Origin reports the [Origin] in s.
|
||||
func (s Summary) Origin() opt.Value[Origin] {
|
||||
return s.data.Origin
|
||||
}
|
||||
|
||||
// MarshalJSONV2 implements [jsonv2.MarshalerV2].
|
||||
func (s Summary) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error {
|
||||
return jsonv2.MarshalEncode(out, &s.data, opts)
|
||||
}
|
||||
|
||||
// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2].
|
||||
func (s *Summary) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error {
|
||||
return jsonv2.UnmarshalDecode(in, &s.data, opts)
|
||||
}
|
||||
|
||||
// MarshalJSON implements [json.Marshaler].
|
||||
func (s Summary) MarshalJSON() ([]byte, error) {
|
||||
return jsonv2.Marshal(s) // uses MarshalJSONV2
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements [json.Unmarshaler].
|
||||
func (s *Summary) UnmarshalJSON(b []byte) error {
|
||||
return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2
|
||||
}
|
||||
|
||||
// SummaryOption is an option that configures [Summary]
|
||||
// The following are allowed options:
|
||||
//
|
||||
// - [Summary]
|
||||
// - [PolicyScope]
|
||||
// - [Origin]
|
||||
type SummaryOption interface {
|
||||
applySummaryOption(summary *Summary)
|
||||
}
|
||||
|
||||
func (s PolicyScope) applySummaryOption(summary *Summary) {
|
||||
summary.data.Scope.Set(s)
|
||||
}
|
||||
|
||||
func (o Origin) applySummaryOption(summary *Summary) {
|
||||
summary.data.Origin.Set(o)
|
||||
if !summary.data.Scope.IsSet() {
|
||||
summary.data.Scope.Set(o.Scope())
|
||||
}
|
||||
}
|
||||
|
||||
func (s Summary) applySummaryOption(summary *Summary) {
|
||||
*summary = s
|
||||
}
|
132
util/syspolicy/setting/types.go
Normal file
132
util/syspolicy/setting/types.go
Normal file
@ -0,0 +1,132 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package setting
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
)
|
||||
|
||||
// PreferenceOption is a policy that governs whether a boolean variable
|
||||
// is forcibly assigned an administrator-defined value, or allowed to receive
|
||||
// a user-defined value.
|
||||
type PreferenceOption int
|
||||
|
||||
const (
|
||||
ShowChoiceByPolicy PreferenceOption = iota
|
||||
NeverByPolicy
|
||||
AlwaysByPolicy
|
||||
)
|
||||
|
||||
// Show returns if the UI option that controls the choice administered by this
|
||||
// policy should be shown. Currently this is true if and only if the policy is
|
||||
// [ShowChoiceByPolicy].
|
||||
func (p PreferenceOption) Show() bool {
|
||||
return p == ShowChoiceByPolicy
|
||||
}
|
||||
|
||||
// ShouldEnable checks if the choice administered by this policy should be
|
||||
// enabled. If the administrator has chosen a setting, the administrator's
|
||||
// setting is returned, otherwise userChoice is returned.
|
||||
func (p PreferenceOption) ShouldEnable(userChoice bool) bool {
|
||||
switch p {
|
||||
case NeverByPolicy:
|
||||
return false
|
||||
case AlwaysByPolicy:
|
||||
return true
|
||||
default:
|
||||
return userChoice
|
||||
}
|
||||
}
|
||||
|
||||
// IsAlways reports whether the preference should always be enabled.
|
||||
func (p PreferenceOption) IsAlways() bool {
|
||||
return p == AlwaysByPolicy
|
||||
}
|
||||
|
||||
// IsNever reports whether the preference should always be disabled.
|
||||
func (p PreferenceOption) IsNever() bool {
|
||||
return p == NeverByPolicy
|
||||
}
|
||||
|
||||
// WillOverride checks if the choice administered by the policy is different
|
||||
// from the user's choice.
|
||||
func (p PreferenceOption) WillOverride(userChoice bool) bool {
|
||||
return p.ShouldEnable(userChoice) != userChoice
|
||||
}
|
||||
|
||||
// String returns a string representation of p.
|
||||
func (p PreferenceOption) String() string {
|
||||
switch p {
|
||||
case AlwaysByPolicy:
|
||||
return "always"
|
||||
case NeverByPolicy:
|
||||
return "never"
|
||||
default:
|
||||
return "user-decides"
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (p *PreferenceOption) MarshalText() (text []byte, err error) {
|
||||
return []byte(p.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler].
|
||||
func (p *PreferenceOption) UnmarshalText(text []byte) error {
|
||||
switch string(text) {
|
||||
case "always":
|
||||
*p = AlwaysByPolicy
|
||||
case "never":
|
||||
*p = NeverByPolicy
|
||||
default:
|
||||
*p = ShowChoiceByPolicy
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Visibility is a policy that controls whether or not a particular
|
||||
// component of a user interface is to be shown.
|
||||
type Visibility byte
|
||||
|
||||
var (
|
||||
_ encoding.TextMarshaler = (*Visibility)(nil)
|
||||
_ encoding.TextUnmarshaler = (*Visibility)(nil)
|
||||
)
|
||||
|
||||
const (
|
||||
VisibleByPolicy Visibility = 'v'
|
||||
HiddenByPolicy Visibility = 'h'
|
||||
)
|
||||
|
||||
// Show reports whether the UI option administered by this policy should be shown.
|
||||
// Currently this is true if the policy is not [hiddenByPolicy].
|
||||
func (p Visibility) Show() bool {
|
||||
return p != HiddenByPolicy
|
||||
}
|
||||
|
||||
// String returns a string representation of p.
|
||||
func (p Visibility) String() string {
|
||||
switch p {
|
||||
case 'h':
|
||||
return "hide"
|
||||
default:
|
||||
return "show"
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (p Visibility) MarshalText() (text []byte, err error) {
|
||||
return []byte(p.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler].
|
||||
func (p *Visibility) UnmarshalText(text []byte) error {
|
||||
switch string(text) {
|
||||
case "hide":
|
||||
*p = HiddenByPolicy
|
||||
default:
|
||||
*p = VisibleByPolicy
|
||||
}
|
||||
return nil
|
||||
}
|
393
util/syspolicy/source/policy_reader.go
Normal file
393
util/syspolicy/source/policy_reader.go
Normal file
@ -0,0 +1,393 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/internal/metrics"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
// Reader reads all configured policy settings from a given [Store].
|
||||
// It registers a change callback with the [Store] and maintains the current version
|
||||
// of the [setting.Snapshot] by lazily re-reading policy settings from the [Store]
|
||||
// whenever a new snapshot is requested
|
||||
// It is safe for concurrent use.
|
||||
type Reader struct {
|
||||
store Store
|
||||
origin *setting.Origin
|
||||
settings []*setting.Definition
|
||||
unregisterChangeNotifier func()
|
||||
doneCh chan struct{} // closed when policyCache is closed.
|
||||
|
||||
mu sync.RWMutex
|
||||
closing bool
|
||||
upToDate bool
|
||||
lastPolicy *setting.Snapshot
|
||||
sessions set.HandleSet[*ReadingSession]
|
||||
}
|
||||
|
||||
// newReader returns a new [Reader] that reads policy settings from a given [Store].
|
||||
// The returned reader takes ownership of the store. If the store implements [io.Closer],
|
||||
// the returned reader will close the store when it is closed.
|
||||
func newReader(store Store, origin *setting.Origin) (*Reader, error) {
|
||||
settings, err := setting.Definitions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if expirable, ok := store.(Expirable); ok {
|
||||
select {
|
||||
case <-expirable.Done():
|
||||
return nil, ErrStoreClosed
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
reader := &Reader{store: store, origin: origin, settings: settings, doneCh: make(chan struct{})}
|
||||
if changeable, ok := store.(Changeable); ok {
|
||||
// We should subscribe to policy change notifications first before reading
|
||||
// the policy settings from the store. This way we won't miss any notifications.
|
||||
if reader.unregisterChangeNotifier, err = changeable.RegisterChangeCallback(reader.onPolicyChange); err != nil {
|
||||
// Errors registering policy change callbacks are non-fatal.
|
||||
// TODO(nickkhyl): implement a background policy refresh every X minutes?
|
||||
loggerx.Errorf("failed to register %v policy change callback: %v\n", origin, err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := reader.reload(true); err != nil {
|
||||
if reader.unregisterChangeNotifier != nil {
|
||||
reader.unregisterChangeNotifier()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if expirable, ok := store.(Expirable); ok {
|
||||
if waitCh := expirable.Done(); waitCh != nil {
|
||||
go func() {
|
||||
select {
|
||||
case <-waitCh:
|
||||
reader.Close()
|
||||
case <-reader.doneCh:
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// GetSettings returns the current [*setting.Snapshot],
|
||||
// re-reading it from from the underlying [Store] only if the policy
|
||||
// has changed since it was read last. It never fails and returns
|
||||
// the previous version of the policy settings if a read attempt fails.
|
||||
func (r *Reader) GetSettings() *setting.Snapshot {
|
||||
r.mu.RLock()
|
||||
if r.upToDate {
|
||||
r.mu.RUnlock()
|
||||
return r.lastPolicy
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
|
||||
policy, err := r.reload(false)
|
||||
if err != nil {
|
||||
// If the policy could not be reloaded at all, we'll return the last cached version of it.
|
||||
// On the contrary, errors specific to individual policy items are always propagated to the callers.
|
||||
loggerx.Errorf("failed to reload %v policy: %v\n", r.origin, err)
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
// ReadSettings reads policy settings from the underlying [Store] even if no
|
||||
// changes were detected. It returns the new [*setting.Snapshot], nil on
|
||||
// success, or nil, error in case of failure.
|
||||
func (r *Reader) ReadSettings() (*setting.Snapshot, error) {
|
||||
b, err := r.reload(true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// reload is like [Reader.ReadSettings], but allows specifying whether to re-read
|
||||
// an unchanged policy, and returns the last [*setting.Snapshot] if the read fails.
|
||||
func (r *Reader) reload(force bool) (*setting.Snapshot, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.upToDate && !force {
|
||||
return r.lastPolicy, nil
|
||||
}
|
||||
|
||||
if lockable, ok := r.store.(Lockable); ok {
|
||||
if err := lockable.Lock(); err != nil {
|
||||
return r.lastPolicy, err
|
||||
}
|
||||
defer lockable.Unlock()
|
||||
}
|
||||
|
||||
r.upToDate = true
|
||||
|
||||
metrics.Reset(r.origin)
|
||||
|
||||
var m map[setting.Key]setting.RawItem
|
||||
if lastPolicyCount := r.lastPolicy.Len(); lastPolicyCount > 0 {
|
||||
m = make(map[setting.Key]setting.RawItem, lastPolicyCount)
|
||||
}
|
||||
for _, s := range r.settings {
|
||||
if !r.origin.Scope().IsConfigurableSetting(s) {
|
||||
// Skip settings that cannot be configured in the current scope.
|
||||
continue
|
||||
}
|
||||
|
||||
val, err := readPolicySettingValue(r.store, s)
|
||||
if err != nil && (errors.Is(err, setting.ErrNoSuchKey) || errors.Is(err, setting.ErrNotConfigured)) {
|
||||
metrics.ReportNotConfigured(r.origin, s)
|
||||
continue
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
metrics.ReportConfigured(r.origin, s, val)
|
||||
} else {
|
||||
metrics.ReportError(r.origin, s, err)
|
||||
}
|
||||
|
||||
// If there's an error reading a single policy, such as a value type mismatch,
|
||||
// we'll wrap the error to preserve its text and return it
|
||||
// whenever someone attempts to fetch the value.
|
||||
mak.Set(&m, s.Key(), setting.RawItemWith(val, setting.WrapError(err), r.origin))
|
||||
}
|
||||
|
||||
newPolicy := setting.NewSnapshot(m, setting.SummaryWith(r.origin))
|
||||
if r.lastPolicy == nil || !newPolicy.EqualItems(r.lastPolicy) {
|
||||
r.lastPolicy = newPolicy
|
||||
}
|
||||
return r.lastPolicy, nil
|
||||
}
|
||||
|
||||
// ReadingSession is like [Reader], but with a channel that's written
|
||||
// to when there's a policy change, and closed when the session is terminated.
|
||||
type ReadingSession struct {
|
||||
reader *Reader
|
||||
policyChangedCh chan struct{} // 1-buffered channel
|
||||
handle set.Handle // in the reader.sessions
|
||||
closeInternal func()
|
||||
}
|
||||
|
||||
// OpenSession opens and returns a new session to r, allowing the caller
|
||||
// to get notified whenever a policy change is reported by the [source.Store],
|
||||
// or an [ErrStoreClosed] if the reader has already been closed.
|
||||
func (r *Reader) OpenSession() (*ReadingSession, error) {
|
||||
session := &ReadingSession{
|
||||
reader: r,
|
||||
policyChangedCh: make(chan struct{}, 1),
|
||||
}
|
||||
session.closeInternal = sync.OnceFunc(func() { close(session.policyChangedCh) })
|
||||
r.mu.Lock()
|
||||
if !r.closing {
|
||||
session.handle = r.sessions.Add(session)
|
||||
r.mu.Unlock()
|
||||
return session, nil
|
||||
}
|
||||
r.mu.Unlock()
|
||||
return nil, ErrStoreClosed
|
||||
}
|
||||
|
||||
// GetSettings is like [Reader.GetSettings].
|
||||
func (s *ReadingSession) GetSettings() *setting.Snapshot {
|
||||
return s.reader.GetSettings()
|
||||
}
|
||||
|
||||
// ReadSettings is like [Reader.ReadSettings].
|
||||
func (s *ReadingSession) ReadSettings() (*setting.Snapshot, error) {
|
||||
return s.reader.ReadSettings()
|
||||
}
|
||||
|
||||
// PolicyChanged returns a channel that's written to when
|
||||
// there's a policy change, closed when the session is terminated.
|
||||
func (s *ReadingSession) PolicyChanged() <-chan struct{} {
|
||||
return s.policyChangedCh
|
||||
}
|
||||
|
||||
// Close unregisters this session with the [Reader].
|
||||
func (s *ReadingSession) Close() {
|
||||
s.reader.mu.Lock()
|
||||
delete(s.reader.sessions, s.handle)
|
||||
s.closeInternal()
|
||||
s.reader.mu.Unlock()
|
||||
}
|
||||
|
||||
// onPolicyChange handles a policy change notification from the [Store],
|
||||
// invalidating the current [setting.Snapshot] in r,
|
||||
// and notifying the active [ReadingSession]s.
|
||||
func (r *Reader) onPolicyChange() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.upToDate = false
|
||||
for _, s := range r.sessions {
|
||||
select {
|
||||
case s.policyChangedCh <- struct{}{}:
|
||||
// Notified.
|
||||
default:
|
||||
// 1-buffered channel is full, meaning that another policy change
|
||||
// notification is already en route.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the store reader and the underlying store.
|
||||
func (r *Reader) Close() error {
|
||||
r.mu.Lock()
|
||||
if r.closing {
|
||||
r.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
r.closing = true
|
||||
r.mu.Unlock()
|
||||
|
||||
if r.unregisterChangeNotifier != nil {
|
||||
r.unregisterChangeNotifier()
|
||||
r.unregisterChangeNotifier = nil
|
||||
}
|
||||
|
||||
if closer, ok := r.store.(io.Closer); ok {
|
||||
if err := closer.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.store = nil
|
||||
|
||||
close(r.doneCh)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, c := range r.sessions {
|
||||
c.closeInternal()
|
||||
}
|
||||
r.sessions = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the reader is closed.
|
||||
func (r *Reader) Done() <-chan struct{} {
|
||||
return r.doneCh
|
||||
}
|
||||
|
||||
// ReadableSource is a [Source] open for reading.
|
||||
type ReadableSource struct {
|
||||
*Source
|
||||
*ReadingSession
|
||||
}
|
||||
|
||||
// Close closes the underlying [ReadingSession].
|
||||
func (s ReadableSource) Close() {
|
||||
s.ReadingSession.Close()
|
||||
}
|
||||
|
||||
// ReadableSources is a slice of [ReadableSource].
|
||||
type ReadableSources []ReadableSource
|
||||
|
||||
// Contains reports whether s contains the specified source.
|
||||
func (s ReadableSources) Contains(source *Source) bool {
|
||||
return s.IndexOf(source) != -1
|
||||
}
|
||||
|
||||
// IndexOf returns position of the specified source in s, or -1
|
||||
// if the source does not exist.
|
||||
func (s ReadableSources) IndexOf(source *Source) int {
|
||||
return slices.IndexFunc(s, func(rs ReadableSource) bool {
|
||||
return rs.Source == source
|
||||
})
|
||||
}
|
||||
|
||||
// InsertionIndexOf returns the position at which source can be inserted
|
||||
// to maintain the sorted order of the readableSources.
|
||||
// The return value is unspecified if s is not sorted on entry to InsertionIndexOf.
|
||||
func (s ReadableSources) InsertionIndexOf(source *Source) int {
|
||||
low, high := 0, len(s)
|
||||
for low < high {
|
||||
mid := (low + high) / 2
|
||||
if s[mid].Compare(source) <= 0 {
|
||||
low = mid + 1
|
||||
} else {
|
||||
high = mid
|
||||
}
|
||||
}
|
||||
return low
|
||||
}
|
||||
|
||||
// StableSort sorts the readableSources by the precedence, so that policy settings
|
||||
// from sources with higher precedence (e.g., [DeviceScope]) will be merged last,
|
||||
// overriding any policy settings with the same keys configured in sources with
|
||||
// lower precedence (e.g., [CurrentUserScope]).
|
||||
func (s *ReadableSources) StableSort() {
|
||||
sort.SliceStable(*s, func(i, j int) bool {
|
||||
return (*s)[i].Source.Compare((*s)[j].Source) < 0
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAt closes and deletes the i-th source from s.
|
||||
func (s *ReadableSources) DeleteAt(i int) {
|
||||
(*s)[i].Close()
|
||||
*s = slices.Delete(*s, i, i+1)
|
||||
}
|
||||
|
||||
// Close closes and deletes all sources in s.
|
||||
func (s *ReadableSources) Close() {
|
||||
for _, s := range *s {
|
||||
s.Close()
|
||||
}
|
||||
*s = nil
|
||||
}
|
||||
|
||||
func readPolicySettingValue(store Store, s *setting.Definition) (value any, err error) {
|
||||
switch key := s.Key(); s.Type() {
|
||||
case setting.BooleanValue:
|
||||
return store.ReadBoolean(key)
|
||||
case setting.IntegerValue:
|
||||
return store.ReadUInt64(key)
|
||||
case setting.StringValue:
|
||||
return store.ReadString(key)
|
||||
case setting.StringListValue:
|
||||
return store.ReadStringArray(key)
|
||||
case setting.PreferenceOptionValue:
|
||||
s, err := store.ReadString(key)
|
||||
if err == nil {
|
||||
var value setting.PreferenceOption
|
||||
if err = value.UnmarshalText([]byte(s)); err == nil {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return setting.ShowChoiceByPolicy, err
|
||||
case setting.VisibilityValue:
|
||||
s, err := store.ReadString(key)
|
||||
if err == nil {
|
||||
var value setting.Visibility
|
||||
if err = value.UnmarshalText([]byte(s)); err == nil {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return setting.VisibleByPolicy, err
|
||||
case setting.DurationValue:
|
||||
s, err := store.ReadString(key)
|
||||
if err == nil {
|
||||
var value time.Duration
|
||||
if value, err = time.ParseDuration(s); err == nil {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: unsupported setting type: %v", setting.ErrTypeMismatch, s.Type())
|
||||
}
|
||||
}
|
291
util/syspolicy/source/policy_reader_test.go
Normal file
291
util/syspolicy/source/policy_reader_test.go
Normal file
@ -0,0 +1,291 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
func TestReaderLifecycle(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
origin *setting.Origin
|
||||
definitions []*setting.Definition
|
||||
wantReads []TestExpectedReads
|
||||
initStrings []TestSetting[string]
|
||||
initUInt64s []TestSetting[uint64]
|
||||
initWant *setting.Snapshot
|
||||
addStrings []TestSetting[string]
|
||||
addStringLists []TestSetting[[]string]
|
||||
newWant *setting.Snapshot
|
||||
}{
|
||||
{
|
||||
name: "read-all-settings-once",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
|
||||
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
|
||||
{Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
|
||||
{Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
|
||||
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "re-read-all-settings-when-the-policy-changes",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("IntegerValue", setting.DeviceSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("BooleanValue", setting.DeviceSetting, setting.BooleanValue),
|
||||
setting.NewDefinition("StringListValue", setting.DeviceSetting, setting.StringListValue),
|
||||
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "StringValue", Type: setting.StringValue, NumTimes: 1},
|
||||
{Key: "IntegerValue", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "BooleanValue", Type: setting.BooleanValue, NumTimes: 1},
|
||||
{Key: "StringListValue", Type: setting.StringListValue, NumTimes: 1},
|
||||
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
addStrings: []TestSetting[string]{TestSettingOf("StringValue", "S1")},
|
||||
addStringLists: []TestSetting[[]string]{TestSettingOf("StringListValue", []string{"S1", "S2", "S3"})},
|
||||
newWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"StringValue": setting.RawItemWith("S1", nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
"StringListValue": setting.RawItemWith([]string{"S1", "S2", "S3"}, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "read-settings-if-in-scope/device",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "DeviceSetting", Type: setting.StringValue, NumTimes: 1},
|
||||
{Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read-settings-if-in-scope/profile",
|
||||
origin: setting.NewNamedOrigin("Test", setting.CurrentProfileScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
// Device settings cannot be configured at the profile scope and should not be read.
|
||||
{Key: "ProfileSetting", Type: setting.IntegerValue, NumTimes: 1},
|
||||
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read-settings-if-in-scope/user",
|
||||
origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DeviceSetting", setting.DeviceSetting, setting.StringValue),
|
||||
setting.NewDefinition("ProfileSetting", setting.ProfileSetting, setting.IntegerValue),
|
||||
setting.NewDefinition("UserSetting", setting.UserSetting, setting.BooleanValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
// Device and profile settings cannot be configured at the profile scope and should not be read.
|
||||
{Key: "UserSetting", Type: setting.BooleanValue, NumTimes: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read-stringy-settings",
|
||||
origin: setting.NewNamedOrigin("Test", setting.DeviceScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DurationValue", setting.DeviceSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.DeviceSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.DeviceSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "DurationValue", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initStrings: []TestSetting[string]{
|
||||
TestSettingOf("DurationValue", "2h30m"),
|
||||
TestSettingOf("PreferenceOptionValue", "always"),
|
||||
TestSettingOf("VisibilityValue", "show"),
|
||||
},
|
||||
initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"DurationValue": setting.RawItemWith(must.Get(time.ParseDuration("2h30m")), nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
"PreferenceOptionValue": setting.RawItemWith(setting.AlwaysByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
"VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
}, setting.NewNamedOrigin("Test", setting.DeviceScope)),
|
||||
},
|
||||
{
|
||||
name: "read-erroneous-stringy-settings",
|
||||
origin: setting.NewNamedOrigin("Test", setting.CurrentUserScope),
|
||||
definitions: []*setting.Definition{
|
||||
setting.NewDefinition("DurationValue1", setting.UserSetting, setting.DurationValue),
|
||||
setting.NewDefinition("DurationValue2", setting.UserSetting, setting.DurationValue),
|
||||
setting.NewDefinition("PreferenceOptionValue", setting.UserSetting, setting.PreferenceOptionValue),
|
||||
setting.NewDefinition("VisibilityValue", setting.UserSetting, setting.VisibilityValue),
|
||||
},
|
||||
wantReads: []TestExpectedReads{
|
||||
{Key: "DurationValue1", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "DurationValue2", Type: setting.StringValue, NumTimes: 1}, // duration is string from the [Store]'s perspective
|
||||
{Key: "PreferenceOptionValue", Type: setting.StringValue, NumTimes: 1}, // and so are [setting.PreferenceOption]s
|
||||
{Key: "VisibilityValue", Type: setting.StringValue, NumTimes: 1}, // and [setting.Visibility]
|
||||
},
|
||||
initStrings: []TestSetting[string]{
|
||||
TestSettingOf("DurationValue1", "soon"),
|
||||
TestSettingWithError[string]("DurationValue2", setting.NewError("bang!")),
|
||||
TestSettingOf("PreferenceOptionValue", "sometimes"),
|
||||
},
|
||||
initUInt64s: []TestSetting[uint64]{
|
||||
TestSettingOf[uint64]("VisibilityValue", 42), // type mismatch
|
||||
},
|
||||
initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"DurationValue1": setting.RawItemWith(nil, setting.NewError("time: invalid duration \"soon\""), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
"DurationValue2": setting.RawItemWith(nil, setting.NewError("bang!"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
"PreferenceOptionValue": setting.RawItemWith(setting.ShowChoiceByPolicy, nil, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
"VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, setting.NewError("type mismatch in ReadString: got uint64"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
}, setting.NewNamedOrigin("Test", setting.CurrentUserScope)),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setting.SetDefinitionsForTest(t, tt.definitions...)
|
||||
store := NewTestStore(t)
|
||||
store.SetStrings(tt.initStrings...)
|
||||
store.SetUInt64s(tt.initUInt64s...)
|
||||
|
||||
reader, err := newReader(store, tt.origin)
|
||||
if err != nil {
|
||||
t.Fatalf("newReader failed: %v", err)
|
||||
}
|
||||
|
||||
if got := reader.GetSettings(); tt.initWant != nil && !got.Equal(tt.initWant) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
|
||||
}
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
|
||||
// Should not result in new reads as there were no changes.
|
||||
N := 100
|
||||
for range N {
|
||||
reader.GetSettings()
|
||||
}
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
store.ResetCounters()
|
||||
|
||||
got, err := reader.ReadSettings()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadSettings failed: %v", err)
|
||||
}
|
||||
|
||||
if tt.initWant != nil && !got.Equal(tt.initWant) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, tt.initWant)
|
||||
}
|
||||
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
store.ResetCounters()
|
||||
|
||||
if len(tt.addStrings) != 0 || len(tt.addStringLists) != 0 {
|
||||
store.SetStrings(tt.addStrings...)
|
||||
store.SetStringLists(tt.addStringLists...)
|
||||
|
||||
// As the settings have changed, GetSettings needs to re-read them.
|
||||
if got, want := reader.GetSettings(), cmp.Or(tt.newWant, tt.initWant); !got.Equal(want) {
|
||||
t.Errorf("New Settings do not match: got %v, want %v", got, want)
|
||||
}
|
||||
if tt.wantReads != nil {
|
||||
store.ReadsMustEqual(tt.wantReads...)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-reader.Done():
|
||||
t.Fatalf("the reader is closed")
|
||||
default:
|
||||
}
|
||||
|
||||
store.Close()
|
||||
|
||||
<-reader.Done()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadingSession(t *testing.T) {
|
||||
setting.SetDefinitionsForTest(t, setting.NewDefinition("StringValue", setting.DeviceSetting, setting.StringValue))
|
||||
store := NewTestStore(t)
|
||||
|
||||
origin := setting.NewOrigin(setting.DeviceScope)
|
||||
reader, err := newReader(store, origin)
|
||||
if err != nil {
|
||||
t.Fatalf("newReader failed: %v", err)
|
||||
}
|
||||
session, err := reader.OpenSession()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open a reading session: %v", err)
|
||||
}
|
||||
t.Cleanup(session.Close)
|
||||
|
||||
if got, want := session.GetSettings(), setting.NewSnapshot(nil, origin); !got.Equal(want) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
select {
|
||||
case _, ok := <-session.PolicyChanged():
|
||||
if ok {
|
||||
t.Fatalf("the policy changed notification was sent prematurely")
|
||||
} else {
|
||||
t.Fatalf("the session was closed prematurely")
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
store.SetStrings(TestSettingOf("StringValue", "S1"))
|
||||
_, ok := <-session.PolicyChanged()
|
||||
if !ok {
|
||||
t.Fatalf("the session was closed prematurely")
|
||||
}
|
||||
|
||||
want := setting.NewSnapshot(map[setting.Key]setting.RawItem{
|
||||
"StringValue": setting.RawItemWith("S1", nil, origin),
|
||||
}, origin)
|
||||
if got := session.GetSettings(); !got.Equal(want) {
|
||||
t.Errorf("Settings do not match: got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
store.Close()
|
||||
if _, ok = <-session.PolicyChanged(); ok {
|
||||
t.Fatalf("the session must be closed")
|
||||
}
|
||||
}
|
146
util/syspolicy/source/policy_store.go
Normal file
146
util/syspolicy/source/policy_store.go
Normal file
@ -0,0 +1,146 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
// ErrStoreClosed is an error returned when attempting to use a [Store] after it
|
||||
// has been closed.
|
||||
var ErrStoreClosed = errors.New("the policy store has been closed")
|
||||
|
||||
// Store provides methods to read system policy settings from OS-specific storage.
|
||||
// Implementations must be concurrency-safe, and may also implement
|
||||
// [Lockable], [Changeable], [Expirable] and [io.Closer].
|
||||
//
|
||||
// If a [Store] implementation also implements [io.Closer],
|
||||
// it will be called by the package to release the resources
|
||||
// when the store is no longer needed.
|
||||
type Store interface {
|
||||
// ReadString returns the value of a [setting.StringValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
|
||||
ReadString(key setting.Key) (string, error)
|
||||
// ReadUInt64 returns the value of a [setting.IntegerValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
|
||||
ReadUInt64(key setting.Key) (uint64, error)
|
||||
// ReadBoolean returns the value of a [setting.BooleanValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an [setting.ErrTypeMismatch] if the policy setting is not of a string type.
|
||||
ReadBoolean(key setting.Key) (bool, error)
|
||||
// ReadStringArray returns the value of a [setting.StringListValue] with the specified key,
|
||||
// an [setting.ErrNotConfigured] if the policy setting is not configured, or
|
||||
// an [setting.ErrTypeMismatch] if the policy setting is not of a string list type.
|
||||
ReadStringArray(key setting.Key) ([]string, error)
|
||||
}
|
||||
|
||||
// Lockable is an optional interface that [Store] implementations may support.
|
||||
// Locking a [Store] is not mandatory as [Store] must be concurrency-safe,
|
||||
// but is recommended to avoid issues where consecutive read calls for related
|
||||
// policies might return inconsistent results if a policy change occurs between
|
||||
// the calls.
|
||||
type Lockable interface {
|
||||
|
||||
// Lock acquires a read lock on the policy store,
|
||||
// ensuring the store's state remains unchanged while locked.
|
||||
// Multiple readers can hold the lock simultaneously.
|
||||
// It should return nil if the store does not support locking,
|
||||
// or an error if the store cannot be locked.
|
||||
Lock() error
|
||||
// Unlock unlocks the policy store.
|
||||
// It is a runtime error if the store is not locked on entry to Unlock.
|
||||
Unlock()
|
||||
}
|
||||
|
||||
// Changeable is an optional interface that [Store] implementations may support.
|
||||
type Changeable interface {
|
||||
// RegisterChangeCallback adds a function that will be called
|
||||
// whenever there's a policy change in the [Store].
|
||||
// The returned function can be used to unregister the callback.
|
||||
RegisterChangeCallback(callback func()) (unregister func(), err error)
|
||||
}
|
||||
|
||||
// Expirable is an optional interface that [Store] implementations may support.
|
||||
type Expirable interface {
|
||||
// Done returns a channel that is closed when the policy [Store] should no longer be used.
|
||||
// It should return nil if the store never expires.
|
||||
Done() <-chan struct{}
|
||||
}
|
||||
|
||||
// Source represents a named source of policy settings for a given scope.
|
||||
type Source struct {
|
||||
name string
|
||||
scope setting.PolicyScope
|
||||
store Store
|
||||
origin *setting.Origin
|
||||
|
||||
lazyReader lazy.SyncValue[*Reader]
|
||||
}
|
||||
|
||||
// NewSource returns a new [Source] with the specified name, scope, and store.
|
||||
func NewSource(name string, scope setting.PolicyScope, store Store) *Source {
|
||||
return &Source{name: name, scope: scope, store: store, origin: setting.NewNamedOrigin(name, scope)}
|
||||
}
|
||||
|
||||
// Name reports the name of the policy source.
|
||||
func (s *Source) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
// Scope reports the management scope of the policy source.
|
||||
func (s *Source) Scope() setting.PolicyScope {
|
||||
return s.scope
|
||||
}
|
||||
|
||||
// Store returns the [Store] that can be used to read policy settings from this source.
|
||||
func (s *Source) Store() Store {
|
||||
return s.store
|
||||
}
|
||||
|
||||
// Reader returns a [Reader] that reads from this source's [Store].
|
||||
func (s *Source) Reader() (*Reader, error) {
|
||||
return s.lazyReader.GetErr(func() (*Reader, error) {
|
||||
return newReader(s.store, s.origin)
|
||||
})
|
||||
}
|
||||
|
||||
// String implements [fmt.Stringer].
|
||||
func (s *Source) String() string {
|
||||
if s.Name() != "" {
|
||||
return fmt.Sprintf("%s (%v)", s.Name(), s.Scope())
|
||||
}
|
||||
return s.Scope().String()
|
||||
}
|
||||
|
||||
// Compare returns an integer comparing [Source] s and s2
|
||||
// by their precedence, following the "last-wins" model.
|
||||
// The result will be:
|
||||
//
|
||||
// -1 if policy settings from s should be processed before policy settings from s2;
|
||||
// +1 if policy settings from s should be processed after policy settings from s2, overriding s2;
|
||||
// 0 if the relative processing order of policy settings in s and s2 is unspecified.
|
||||
func (s *Source) Compare(s2 *Source) int {
|
||||
return cmp.Compare(s2.Scope().Kind(), s.Scope().Kind())
|
||||
}
|
||||
|
||||
// Close closes the [Source] and the underlying [Store].
|
||||
func (s *Source) Close() error {
|
||||
// The [Reader], if any, owns the [Store].
|
||||
if reader, _ := s.lazyReader.GetErr(func() (*Reader, error) { return nil, ErrStoreClosed }); reader != nil {
|
||||
return reader.Close()
|
||||
}
|
||||
// Otherwise, it is our responsibility to close it.
|
||||
if closer, ok := s.store.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
438
util/syspolicy/source/policy_store_windows.go
Normal file
438
util/syspolicy/source/policy_store_windows.go
Normal file
@ -0,0 +1,438 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/winutil/gp"
|
||||
)
|
||||
|
||||
const (
|
||||
softwareKeyName = "Software"
|
||||
tsPoliciesSubkey = `Policies\Tailscale`
|
||||
tsIPNSubkey = "Tailscale IPN" // the legacy key we need to fallback to
|
||||
)
|
||||
|
||||
var (
|
||||
// [PlatformPolicyStore] implements [Store].
|
||||
_ Store = (*PlatformPolicyStore)(nil)
|
||||
)
|
||||
|
||||
// PlatformPolicyStore implements [Store] by providing read access to the Registry-based
|
||||
// Tailscale policies, such as those configured via Group Policy or MDM. It is
|
||||
// recommended to lock it when reading multiple policy values in a row. It also
|
||||
// allows subscribing to notifications when there's a policy change.
|
||||
type PlatformPolicyStore struct {
|
||||
scope gp.Scope // [gp.MachinePolicy] or [gp.UserPolicy]
|
||||
|
||||
// The softwareKey can be HKLM\Software, HKCU\Software, or
|
||||
// HKU\{SID}\Software. Anything below the Software subkey, including
|
||||
// Software\Policies, may not yet exist or could be deleted throughout the
|
||||
// [PlatformPolicyStore]'s lifespan, invalidating the handle. We also prefer
|
||||
// to always use a real registry key (rather than a predefined HKLM or HKCU)
|
||||
// to simplify bookkeeping (predefined keys should never be closed).
|
||||
// Finally, this will allow us to watch for any registry changes directly
|
||||
// should we need this in the future in addition to gp.ChangeWatcher.
|
||||
softwareKey registry.Key
|
||||
watcher *gp.ChangeWatcher
|
||||
|
||||
done chan struct{} // done is closed when Close call completes
|
||||
|
||||
// The policyLock can be locked by the caller when reading multiple policy settings
|
||||
// to prevent the Group Policy Client service from modifying policies while
|
||||
// they are being read.
|
||||
//
|
||||
// When both policyLock and mu need to be taken, mu must be taken before policyLock.
|
||||
policyLock *gp.PolicyLock
|
||||
|
||||
mu sync.RWMutex
|
||||
tsKeys []registry.Key // or nil if the [PlatformPolicyStore] hasn't been locked.
|
||||
cbs set.HandleSet[func()] // policy change callbacks
|
||||
lockCnt int
|
||||
locked sync.WaitGroup
|
||||
closing bool
|
||||
readable bool
|
||||
}
|
||||
|
||||
type registryValueGetter[T any] func(key registry.Key, name setting.Key) (T, error)
|
||||
|
||||
// NewMachinePlatformPolicyStore returns a new [PlatformPolicyStore] for the machine.
|
||||
func NewMachinePlatformPolicyStore() (*PlatformPolicyStore, error) {
|
||||
softwareKey, err := registry.OpenKey(registry.LOCAL_MACHINE, softwareKeyName, windows.KEY_READ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
|
||||
}
|
||||
return newPlatformPolicyStore(gp.MachinePolicy, softwareKey, 0)
|
||||
}
|
||||
|
||||
// NewUserPlatformPolicyStore returns a new [PlatformPolicyStore] for the user specified by its token.
|
||||
// User's profile must be loaded, and the token handle must have [windows.TOKEN_QUERY]
|
||||
// access. The caller retains ownership of the token.
|
||||
func NewUserPlatformPolicyStore(token windows.Token) (*PlatformPolicyStore, error) {
|
||||
var err error
|
||||
var softwareKey registry.Key
|
||||
if token != 0 {
|
||||
var user *windows.Tokenuser
|
||||
if user, err = token.GetTokenUser(); err != nil {
|
||||
return nil, fmt.Errorf("failed to get token user: %w", err)
|
||||
}
|
||||
userSid := user.User.Sid
|
||||
softwareKey, err = registry.OpenKey(registry.USERS, userSid.String()+`\`+softwareKeyName, windows.KEY_READ)
|
||||
} else {
|
||||
softwareKey, err = registry.OpenKey(registry.CURRENT_USER, softwareKeyName, windows.KEY_READ)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open the %s key: %w", softwareKeyName, err)
|
||||
}
|
||||
return newPlatformPolicyStore(gp.UserPolicy, softwareKey, token)
|
||||
}
|
||||
|
||||
func newPlatformPolicyStore(scope gp.Scope, softwareKey registry.Key, token windows.Token) (_ *PlatformPolicyStore, err error) {
|
||||
store := &PlatformPolicyStore{
|
||||
scope: scope,
|
||||
softwareKey: softwareKey,
|
||||
done: make(chan struct{}),
|
||||
readable: true,
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
store.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
switch scope {
|
||||
case gp.MachinePolicy:
|
||||
store.policyLock = gp.NewMachinePolicyLock()
|
||||
case gp.UserPolicy:
|
||||
if store.policyLock, err = gp.NewUserPolicyLock(token); err != nil {
|
||||
return nil, fmt.Errorf("failed to create a user policy lock: %w", err)
|
||||
}
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// Lock locks the policy store, preventing the system from modifying the policies
|
||||
// while they are being read. It is a read lock that may be acquired by multiple goroutines.
|
||||
// Each Lock call must be balanced by exactly one Unlock call.
|
||||
func (ps *PlatformPolicyStore) Lock() (err error) {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
if ps.closing {
|
||||
return ErrStoreClosed
|
||||
}
|
||||
|
||||
ps.lockCnt += 1
|
||||
if ps.lockCnt != 1 {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ps.lockCnt -= 1
|
||||
}
|
||||
}()
|
||||
|
||||
// Ensure ps remains open while the lock is held.
|
||||
ps.locked.Add(1)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ps.locked.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
// Acquire the GP lock to prevent the system from modifying policy settings
|
||||
// while they are being read.
|
||||
if err := ps.policyLock.Lock(); err != nil {
|
||||
if errors.Is(err, gp.ErrInvalidLockState) {
|
||||
return ErrStoreClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ps.policyLock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Keep the Tailscale's registry keys open for the duration of the lock.
|
||||
keyNames := tailscaleKeyNamesFor(ps.scope)
|
||||
ps.tsKeys = make([]registry.Key, 0, len(keyNames))
|
||||
for _, keyName := range keyNames {
|
||||
var tsKey registry.Key
|
||||
tsKey, err = registry.OpenKey(ps.softwareKey, keyName, windows.KEY_READ)
|
||||
if err != nil {
|
||||
if err == registry.ErrNotExist {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
ps.tsKeys = append(ps.tsKeys, tsKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlock decrements the lock counter and unlocks the policy store once the counter reaches 0.
|
||||
// It panics if ps is not locked on entry to Unlock.
|
||||
func (ps *PlatformPolicyStore) Unlock() {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
ps.lockCnt -= 1
|
||||
if ps.lockCnt < 0 {
|
||||
panic("negative lockCnt")
|
||||
} else if ps.lockCnt != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range ps.tsKeys {
|
||||
key.Close()
|
||||
}
|
||||
ps.tsKeys = nil
|
||||
ps.policyLock.Unlock()
|
||||
ps.locked.Done()
|
||||
}
|
||||
|
||||
// RegisterChangeCallback adds a function that will be called whenever there's a policy change.
|
||||
// It returns a function that needs to be called to unregister the specified callback or an error.
|
||||
// The error is [ErrStoreClosed] if ps has already been closed.
|
||||
func (ps *PlatformPolicyStore) RegisterChangeCallback(cb func()) (unregister func(), err error) {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
if ps.closing {
|
||||
return nil, ErrStoreClosed
|
||||
}
|
||||
|
||||
handle := ps.cbs.Add(cb)
|
||||
if len(ps.cbs) == 1 {
|
||||
if ps.watcher, err = gp.NewChangeWatcher(ps.scope, ps.onChange); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return func() {
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
delete(ps.cbs, handle)
|
||||
if len(ps.cbs) == 0 {
|
||||
if ps.watcher != nil {
|
||||
ps.watcher.Close()
|
||||
ps.watcher = nil
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ps *PlatformPolicyStore) onChange() {
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
if ps.closing {
|
||||
return
|
||||
}
|
||||
for _, callback := range ps.cbs {
|
||||
go callback()
|
||||
}
|
||||
}
|
||||
|
||||
// ReadString retrieves a string policy with the specified name.
|
||||
// It returns [ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadString(name setting.Key) (val string, err error) {
|
||||
return getPolicyValue(ps, canonicalizeValueName(name),
|
||||
func(key registry.Key, name setting.Key) (string, error) {
|
||||
val, _, err := key.GetStringValue(string(name))
|
||||
return val, err
|
||||
})
|
||||
}
|
||||
|
||||
// ReadUInt64 retrieves an integer policy with the specified name.
|
||||
// It returns [ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadUInt64(name setting.Key) (uint64, error) {
|
||||
return getPolicyValue(ps, canonicalizeValueName(name),
|
||||
func(key registry.Key, name setting.Key) (uint64, error) {
|
||||
val, _, err := key.GetIntegerValue(string(name))
|
||||
return val, err
|
||||
})
|
||||
}
|
||||
|
||||
// ReadBoolean retrieves a boolean policy with the specified name.
|
||||
// It returns [ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadBoolean(name setting.Key) (bool, error) {
|
||||
return getPolicyValue(ps, canonicalizeValueName(name),
|
||||
func(key registry.Key, name setting.Key) (bool, error) {
|
||||
val, _, err := key.GetIntegerValue(string(name))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return val != 0, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ReadString retrieves a multi-string policy with the specified name.
|
||||
// It returns [ErrNotConfigured] if the policy setting does not exist.
|
||||
func (ps *PlatformPolicyStore) ReadStringArray(name setting.Key) ([]string, error) {
|
||||
return getPolicyValue(ps, name,
|
||||
func(key registry.Key, name setting.Key) ([]string, error) {
|
||||
val, _, err := key.GetStringsValue(string(canonicalizeValueName(name)))
|
||||
if err != registry.ErrNotExist {
|
||||
return val, err
|
||||
}
|
||||
|
||||
// The idiomatic way to store multiple string values in Group Policy
|
||||
// and MDM for Windows is to have multiple REG_SZ (or REG_EXPAND_SZ)
|
||||
// values under a subkey rather than in a single REG_MULTI_SZ value.
|
||||
//
|
||||
// See the Group Policy: Registry Extension Encoding specification,
|
||||
// and specifically the ListElement and ListBox types.
|
||||
// https://web.archive.org/web/20240721033657/https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-GPREG/%5BMS-GPREG%5D.pdf
|
||||
valKey, err := registry.OpenKey(key, string(canonicalizeKeyName(name)), windows.KEY_READ)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valNames, err := valKey.ReadValueNames(0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
val = make([]string, 0, len(valNames))
|
||||
for _, name := range valNames {
|
||||
switch item, _, err := valKey.GetStringValue(name); {
|
||||
case err == registry.ErrNotExist:
|
||||
continue
|
||||
case err != nil:
|
||||
return nil, err
|
||||
default:
|
||||
val = append(val, item)
|
||||
}
|
||||
}
|
||||
return val, nil
|
||||
})
|
||||
}
|
||||
|
||||
func canonicalizeKeyName(name setting.Key) setting.Key {
|
||||
return setting.Key(strings.ReplaceAll(string(name), setting.KeyPathSeparator, `\`))
|
||||
}
|
||||
|
||||
func canonicalizeValueName(name setting.Key) setting.Key {
|
||||
return setting.Key(strings.ReplaceAll(string(name), setting.KeyPathSeparator, `_`))
|
||||
}
|
||||
|
||||
func getPolicyValue[T any](ps *PlatformPolicyStore, name setting.Key, getter registryValueGetter[T]) (T, error) {
|
||||
var zero T
|
||||
|
||||
ps.mu.RLock()
|
||||
defer ps.mu.RUnlock()
|
||||
if !ps.readable {
|
||||
return zero, setting.ErrNotConfigured
|
||||
}
|
||||
|
||||
if ps.tsKeys != nil {
|
||||
// A non-nil tsKeys indicates that ps has been locked.
|
||||
// It may be empty if Tailscale policy keys do not exist.
|
||||
for _, tsKey := range ps.tsKeys {
|
||||
val, err := getter(tsKey, name)
|
||||
if err == nil || err != registry.ErrNotExist {
|
||||
return val, err
|
||||
}
|
||||
}
|
||||
return zero, setting.ErrNotConfigured
|
||||
}
|
||||
|
||||
// The ps has not been locked, so we don't have any pre-opened keys.
|
||||
for _, tsKeyName := range tailscaleKeyNamesFor(ps.scope) {
|
||||
var tsKey registry.Key
|
||||
tsKey, err := registry.OpenKey(ps.softwareKey, tsKeyName, windows.KEY_READ)
|
||||
if err != nil {
|
||||
if err == registry.ErrNotExist {
|
||||
continue
|
||||
}
|
||||
return zero, err
|
||||
}
|
||||
defer tsKey.Close()
|
||||
|
||||
val, err := getter(tsKey, name)
|
||||
if err == nil || err != registry.ErrNotExist {
|
||||
return val, err
|
||||
}
|
||||
}
|
||||
|
||||
return zero, setting.ErrNotConfigured
|
||||
}
|
||||
|
||||
// Close closes the policy store and releases any associated resources.
|
||||
// It cancels pending locks and prevents any new lock attempts,
|
||||
// but waits for existing locks to be released.
|
||||
func (ps *PlatformPolicyStore) Close() error {
|
||||
// Request to close the Group Policy read lock.
|
||||
// Existing held locks will remain valid, but any new or pending locks
|
||||
// will fail. In certain scenarios, the corresponding write lock may be held
|
||||
// by the Group Policy service for extended periods (minutes rather than
|
||||
// seconds or milliseconds). In such cases, we prefer not to wait that long
|
||||
// if the ps is being closed anyway.
|
||||
if ps.policyLock != nil {
|
||||
ps.policyLock.Close()
|
||||
}
|
||||
|
||||
// Signal to the external code that ps should no longer be used.
|
||||
close(ps.done)
|
||||
|
||||
// Mark ps as closing to fast-fail any new lock attempts.
|
||||
// Callers that have already locked it can finish their reading.
|
||||
ps.mu.Lock()
|
||||
if ps.closing {
|
||||
ps.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
ps.closing = true
|
||||
if ps.watcher != nil {
|
||||
ps.watcher.Close()
|
||||
ps.watcher = nil
|
||||
}
|
||||
ps.mu.Unlock()
|
||||
|
||||
// Wait for any outstanding locks to be released.
|
||||
ps.locked.Wait()
|
||||
|
||||
// Deny any further read attempts and release remaining resources.
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
ps.cbs = nil
|
||||
ps.policyLock = nil
|
||||
ps.readable = false
|
||||
if ps.softwareKey != 0 {
|
||||
ps.softwareKey.Close()
|
||||
ps.softwareKey = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the Close method is called.
|
||||
func (ps *PlatformPolicyStore) Done() <-chan struct{} {
|
||||
return ps.done
|
||||
}
|
||||
|
||||
func tailscaleKeyNamesFor(scope gp.Scope) []string {
|
||||
switch scope {
|
||||
case gp.MachinePolicy:
|
||||
// If a computer-side policy value does not exist under Software\Policies\Tailscale,
|
||||
// we need to fallback and use the legacy Software\Tailscale IPN key.
|
||||
return []string{tsPoliciesSubkey, tsIPNSubkey}
|
||||
case gp.UserPolicy:
|
||||
// However, we've never used the legacy key with user-side policies,
|
||||
// and we should never do so. Unlike HKLM\Software\Tailscale IPN,
|
||||
// its HKCU counterpart is user-writable.
|
||||
return []string{tsPoliciesSubkey}
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
298
util/syspolicy/source/policy_store_windows_test.go
Normal file
298
util/syspolicy/source/policy_store_windows_test.go
Normal file
@ -0,0 +1,298 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"tailscale.com/util/cibuild"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/winutil"
|
||||
"tailscale.com/util/winutil/gp"
|
||||
)
|
||||
|
||||
type testPolicyValue struct {
|
||||
name setting.Key
|
||||
value any
|
||||
}
|
||||
|
||||
func TestLockUnlockPolicyStore(t *testing.T) {
|
||||
store, err := NewMachinePlatformPolicyStore()
|
||||
if err != nil {
|
||||
t.Fatalf("NewMachinePolicyStore failed: %v", err)
|
||||
}
|
||||
|
||||
t.Run("One-Goroutine", func(t *testing.T) {
|
||||
if err := store.Lock(); err != nil {
|
||||
t.Errorf("store.Lock(): got %v; want nil", err)
|
||||
return
|
||||
}
|
||||
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
|
||||
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
|
||||
}
|
||||
store.Unlock()
|
||||
})
|
||||
|
||||
// Lock the store N times from different goroutines.
|
||||
const N = 100
|
||||
var unlocked atomic.Int32
|
||||
t.Run("N-Goroutines", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(N)
|
||||
for range N {
|
||||
go func() {
|
||||
if err := store.Lock(); err != nil {
|
||||
t.Errorf("store.Lock(): got %v; want nil", err)
|
||||
return
|
||||
}
|
||||
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
|
||||
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
|
||||
}
|
||||
wg.Done()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
unlocked.Add(1)
|
||||
store.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait until the store is locked N times.
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
// Close the store. The call should wait for all held locks to be released.
|
||||
if err := store.Close(); err != nil {
|
||||
t.Fatalf("(*PolicyStore).Close failed: %v", err)
|
||||
}
|
||||
if locked := unlocked.Load(); locked != N {
|
||||
t.Errorf("locked.Load(): got %v; want %v", locked, N)
|
||||
}
|
||||
|
||||
// Any further attempts to lock it should fail.
|
||||
if err = store.Lock(); err == nil || !errors.Is(err, ErrStoreClosed) {
|
||||
t.Errorf("store.Lock(): got %v; want %v", err, ErrStoreClosed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPolicyStore(t *testing.T) {
|
||||
if !winutil.IsCurrentProcessElevated() {
|
||||
t.Skipf("test requires running as elevated user")
|
||||
}
|
||||
tests := []struct {
|
||||
name setting.Key
|
||||
newValue any
|
||||
legacyValue any
|
||||
want any
|
||||
}{
|
||||
{name: "LegacyPolicy", legacyValue: "LegacyValue", want: "LegacyValue"},
|
||||
{name: "StringPolicy", legacyValue: "LegacyValue", newValue: "Value", want: "Value"},
|
||||
{name: "StringPolicy_Empty", legacyValue: "LegacyValue", newValue: "", want: ""},
|
||||
{name: "BoolPolicy_True", newValue: true, want: true},
|
||||
{name: "BoolPolicy_False", newValue: false, want: false},
|
||||
{name: "UIntPolicy_1", newValue: uint32(10), want: uint64(10)}, // uint32 values should be returned as uint64
|
||||
{name: "UIntPolicy_2", newValue: uint64(1 << 37), want: uint64(1 << 37)},
|
||||
{name: "StringListPolicy", newValue: []string{"Value1", "Value2"}, want: []string{"Value1", "Value2"}},
|
||||
{name: "StringListPolicy_Empty", newValue: []string{}, want: []string{}},
|
||||
}
|
||||
|
||||
runTests := func(t *testing.T, userStore bool, token windows.Token) {
|
||||
var hive registry.Key
|
||||
if userStore {
|
||||
hive = registry.CURRENT_USER
|
||||
} else {
|
||||
hive = registry.LOCAL_MACHINE
|
||||
}
|
||||
|
||||
// Write policy values to the registry.
|
||||
newValues := make([]testPolicyValue, 0, len(tests))
|
||||
for _, tt := range tests {
|
||||
if tt.newValue != nil {
|
||||
newValues = append(newValues, testPolicyValue{name: tt.name, value: tt.newValue})
|
||||
}
|
||||
}
|
||||
policiesKeyName := softwareKeyName + `\` + tsPoliciesSubkey
|
||||
cleanup, err := createTestPolicyValues(hive, policiesKeyName, newValues)
|
||||
if err != nil {
|
||||
t.Fatalf("createTestPolicyValues failed: %v", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
// Write legacy policy values to the registry.
|
||||
legacyValues := make([]testPolicyValue, 0, len(tests))
|
||||
for _, tt := range tests {
|
||||
if tt.legacyValue != nil {
|
||||
legacyValues = append(legacyValues, testPolicyValue{name: tt.name, value: tt.legacyValue})
|
||||
}
|
||||
}
|
||||
legacyKeyName := softwareKeyName + `\` + tsIPNSubkey
|
||||
cleanup, err = createTestPolicyValues(hive, legacyKeyName, legacyValues)
|
||||
if err != nil {
|
||||
t.Fatalf("createTestPolicyValues failed: %v", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
var store *PlatformPolicyStore
|
||||
if userStore {
|
||||
store, err = NewUserPlatformPolicyStore(token)
|
||||
} else {
|
||||
store, err = NewMachinePlatformPolicyStore()
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NewXPolicyStore failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := store.Close(); err != nil {
|
||||
t.Errorf("(*PolicyStore).Close failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// testReadValues checks that [PolicyStore] returns the same values we wrote directly to the registry.
|
||||
testReadValues := func(t *testing.T, withLocks bool) {
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.name), func(t *testing.T) {
|
||||
if userStore && tt.newValue == nil {
|
||||
t.Skip("there is no legacy policies for users")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
if withLocks {
|
||||
if err := store.Lock(); err != nil {
|
||||
t.Errorf("failed to acquire the lock: %v", err)
|
||||
}
|
||||
defer store.Unlock()
|
||||
}
|
||||
|
||||
var got any
|
||||
var err error
|
||||
switch tt.want.(type) {
|
||||
case string:
|
||||
got, err = store.ReadString(tt.name)
|
||||
case uint64:
|
||||
got, err = store.ReadUInt64(tt.name)
|
||||
case bool:
|
||||
got, err = store.ReadBoolean(tt.name)
|
||||
case []string:
|
||||
got, err = store.ReadStringArray(tt.name)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("got %v; want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
t.Run("NoLock", func(t *testing.T) {
|
||||
testReadValues(t, false)
|
||||
})
|
||||
|
||||
t.Run("WithLock", func(t *testing.T) {
|
||||
testReadValues(t, true)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("MachineStore", func(t *testing.T) {
|
||||
runTests(t, false, 0)
|
||||
})
|
||||
|
||||
t.Run("CurrentUserStore", func(t *testing.T) {
|
||||
runTests(t, true, 0)
|
||||
})
|
||||
|
||||
t.Run("UserStoreWithToken", func(t *testing.T) {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil {
|
||||
t.Fatalf("OpenProcessToken: %v", err)
|
||||
}
|
||||
defer token.Close()
|
||||
runTests(t, true, token)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicyStoreChangeNotifications(t *testing.T) {
|
||||
if cibuild.On() {
|
||||
t.Skipf("test requires running on a real Windows environment")
|
||||
}
|
||||
store, err := NewMachinePlatformPolicyStore()
|
||||
if err != nil {
|
||||
t.Fatalf("NewMachinePolicyStore failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := store.Close(); err != nil {
|
||||
t.Errorf("(*PolicyStore).Close failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
unregister, err := store.RegisterChangeCallback(func() { close(done) })
|
||||
if err != nil {
|
||||
t.Fatalf("RegisterChangeCallback failed: %v", err)
|
||||
}
|
||||
t.Cleanup(unregister)
|
||||
|
||||
// RefreshMachinePolicy is a non-blocking call.
|
||||
if err := gp.RefreshMachinePolicy(true); err != nil {
|
||||
t.Fatalf("RefreshMachinePolicy failed: %v", err)
|
||||
}
|
||||
|
||||
// We should receive a policy change notification when
|
||||
// the Group Policy service completes policy processing.
|
||||
// Otherwise, the test will eventually time out.
|
||||
<-done
|
||||
}
|
||||
|
||||
func createTestPolicyValues(hive registry.Key, keyName string, values []testPolicyValue) (cleanup func(), err error) {
|
||||
key, existing, err := registry.CreateKey(hive, keyName, registry.ALL_ACCESS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
doCleanup := func() {
|
||||
for _, v := range values {
|
||||
key.DeleteValue(string(v.name))
|
||||
}
|
||||
key.Close()
|
||||
if !existing {
|
||||
registry.DeleteKey(hive, keyName)
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
doCleanup()
|
||||
}
|
||||
}()
|
||||
|
||||
for _, v := range values {
|
||||
switch value := v.value.(type) {
|
||||
case string:
|
||||
err = key.SetStringValue(string(v.name), value)
|
||||
case uint32:
|
||||
err = key.SetDWordValue(string(v.name), value)
|
||||
case uint64:
|
||||
err = key.SetQWordValue(string(v.name), value)
|
||||
case bool:
|
||||
if value {
|
||||
err = key.SetDWordValue(string(v.name), 1)
|
||||
} else {
|
||||
err = key.SetDWordValue(string(v.name), 0)
|
||||
}
|
||||
case []string:
|
||||
err = key.SetStringsValue(string(v.name), value)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported value: %v (%T), name: %q", value, value, v.name)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return doCleanup, nil
|
||||
}
|
446
util/syspolicy/source/test_store.go
Normal file
446
util/syspolicy/source/test_store.go
Normal file
@ -0,0 +1,446 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
var _ Store = (*TestStore)(nil)
|
||||
|
||||
// TestValueType is a constraint that allows types supported by [TestStore].
|
||||
type TestValueType interface {
|
||||
bool | uint64 | string | []string
|
||||
}
|
||||
|
||||
// TestSetting is a policy setting in a [TestStore].
|
||||
type TestSetting[T TestValueType] struct {
|
||||
// Key is the setting's unique identifier.
|
||||
Key setting.Key
|
||||
// Error is the error to be returned by the [TestStore] when reading
|
||||
// a policy setting with the specified key.
|
||||
Error error
|
||||
// Value is the value to be returned by the [TestStore] when reading
|
||||
// a policy setting with the specified key.
|
||||
// It is only used if the Error is nil.
|
||||
Value T
|
||||
}
|
||||
|
||||
// TestSettingOf returns a [TestSetting] representing a policy setting
|
||||
// configured with the specified key and value.
|
||||
func TestSettingOf[T TestValueType](key setting.Key, value T) TestSetting[T] {
|
||||
return TestSetting[T]{Key: key, Value: value}
|
||||
}
|
||||
|
||||
// TestSettingWithError returns a [TestSetting] representing a policy setting
|
||||
// with the specified key and error.
|
||||
func TestSettingWithError[T TestValueType](key setting.Key, err error) TestSetting[T] {
|
||||
return TestSetting[T]{Key: key, Error: err}
|
||||
}
|
||||
|
||||
// testReadOperation describes a single policy setting read operation.
|
||||
type testReadOperation struct {
|
||||
// Key is the setting's unique identifier.
|
||||
Key setting.Key
|
||||
// Type is a value type of a read operation.
|
||||
// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
|
||||
Type setting.Type
|
||||
}
|
||||
|
||||
// TestExpectedReads is the number of read operations with the specified details.
|
||||
type TestExpectedReads struct {
|
||||
// Key is the setting's unique identifier.
|
||||
Key setting.Key
|
||||
// Type is a value type of a read operation.
|
||||
// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue]
|
||||
Type setting.Type
|
||||
// NumTimes is how many times a setting with the specified key and type should have been read.
|
||||
NumTimes int
|
||||
}
|
||||
|
||||
func (r TestExpectedReads) operation() testReadOperation {
|
||||
return testReadOperation{r.Key, r.Type}
|
||||
}
|
||||
|
||||
// TestStore is a [Store] that can be used in tests.
|
||||
type TestStore struct {
|
||||
tb internal.TB
|
||||
|
||||
done chan struct{}
|
||||
|
||||
storeLock sync.RWMutex // its RLock is exposed via [Store.Lock]/[Store.Unlock].
|
||||
storeLockCount atomic.Int32
|
||||
|
||||
mu sync.RWMutex
|
||||
suspendCount int // change callback are suspended if > 0
|
||||
mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended.
|
||||
cbs set.HandleSet[func()]
|
||||
|
||||
readsMu sync.Mutex
|
||||
reads map[testReadOperation]int // how many times a policy setting was read
|
||||
}
|
||||
|
||||
// NewTestStore returns a new [TestStore].
|
||||
// The tb will be used to report coding errors detected by the [TestStore].
|
||||
func NewTestStore(tb internal.TB) *TestStore {
|
||||
m := make(map[setting.Key]any)
|
||||
return &TestStore{
|
||||
tb: tb,
|
||||
done: make(chan struct{}),
|
||||
mr: m,
|
||||
mw: m,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans],
|
||||
// [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists].
|
||||
func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore {
|
||||
m := make(map[setting.Key]any)
|
||||
store := &TestStore{
|
||||
tb: tb,
|
||||
done: make(chan struct{}),
|
||||
mr: m,
|
||||
mw: m,
|
||||
}
|
||||
switch settings := any(settings).(type) {
|
||||
case []TestSetting[bool]:
|
||||
store.SetBooleans(settings...)
|
||||
case []TestSetting[uint64]:
|
||||
store.SetUInt64s(settings...)
|
||||
case []TestSetting[string]:
|
||||
store.SetStrings(settings...)
|
||||
case []TestSetting[[]string]:
|
||||
store.SetStringLists(settings...)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
// Lock implements [Store].
|
||||
func (s *TestStore) Lock() error {
|
||||
s.storeLock.RLock()
|
||||
s.storeLockCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlock implements [Store].
|
||||
func (s *TestStore) Unlock() {
|
||||
if s.storeLockCount.Add(-1) < 0 {
|
||||
s.tb.Fatal("negative storeLockCount")
|
||||
}
|
||||
s.storeLock.RUnlock()
|
||||
}
|
||||
|
||||
// RegisterChangeCallback implements [Store].
|
||||
func (s *TestStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
handle := s.cbs.Add(callback)
|
||||
return func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.cbs, handle)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReadString implements [Store].
|
||||
func (s *TestStore) ReadString(key setting.Key) (string, error) {
|
||||
defer s.recordRead(key, setting.StringValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return "", setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return "", err
|
||||
}
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("%w in ReadString: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// ReadUInt64 implements [Store].
|
||||
func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) {
|
||||
defer s.recordRead(key, setting.IntegerValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return 0, setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return 0, err
|
||||
}
|
||||
u64, ok := v.(uint64)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("%w in ReadUInt64: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return u64, nil
|
||||
}
|
||||
|
||||
// ReadBoolean implements [Store].
|
||||
func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) {
|
||||
defer s.recordRead(key, setting.BooleanValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return false, setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return false, err
|
||||
}
|
||||
b, ok := v.(bool)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("%w in ReadBoolean: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// ReadStringArray implements [Store].
|
||||
func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) {
|
||||
defer s.recordRead(key, setting.StringListValue)
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
v, ok := s.mr[key]
|
||||
if !ok {
|
||||
return nil, setting.ErrNotConfigured
|
||||
}
|
||||
if err, ok := v.(error); ok {
|
||||
return nil, err
|
||||
}
|
||||
slice, ok := v.([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w in ReadStringArray: got %T", setting.ErrTypeMismatch, v)
|
||||
}
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
func (s *TestStore) recordRead(key setting.Key, typ setting.Type) {
|
||||
s.readsMu.Lock()
|
||||
op := testReadOperation{key, typ}
|
||||
num := s.reads[op]
|
||||
num++
|
||||
mak.Set(&s.reads, op, num)
|
||||
s.readsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *TestStore) ResetCounters() {
|
||||
s.readsMu.Lock()
|
||||
clear(s.reads)
|
||||
s.readsMu.Unlock()
|
||||
}
|
||||
|
||||
// ReadsMustEqual fails the test if the actual reads differs from the specified reads.
|
||||
func (s *TestStore) ReadsMustEqual(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
s.readsMu.Lock()
|
||||
defer s.readsMu.Unlock()
|
||||
s.readsMustContainLocked(reads...)
|
||||
s.readMustNoExtraLocked(reads...)
|
||||
}
|
||||
|
||||
// ReadsMustContain fails the test if the specified reads have not been made,
|
||||
// or have been made a different number of times. It permits other values to be
|
||||
// read in addition to the ones being tested.
|
||||
func (s *TestStore) ReadsMustContain(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
s.readsMu.Lock()
|
||||
defer s.readsMu.Unlock()
|
||||
s.readsMustContainLocked(reads...)
|
||||
}
|
||||
|
||||
func (s *TestStore) readsMustContainLocked(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
for _, r := range reads {
|
||||
if numTimes := s.reads[r.operation()]; numTimes != r.NumTimes {
|
||||
s.tb.Errorf("%q (%v) reads: got %v, want %v", r.Key, r.Type, numTimes, r.NumTimes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TestStore) readMustNoExtraLocked(reads ...TestExpectedReads) {
|
||||
s.tb.Helper()
|
||||
rs := make(set.Set[testReadOperation])
|
||||
for i := range reads {
|
||||
rs.Add(reads[i].operation())
|
||||
}
|
||||
for ro, num := range s.reads {
|
||||
if !rs.Contains(ro) {
|
||||
s.tb.Errorf("%q (%v) reads: got %v, want 0", ro.Key, ro.Type, num)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Suspend suspends the store, batching changes and notifications
|
||||
// until [TestStore.Resume] is called the same number of times as Suspend.
|
||||
func (s *TestStore) Suspend() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.suspendCount++; s.suspendCount == 1 {
|
||||
s.mw = xmaps.Clone(s.mr)
|
||||
}
|
||||
}
|
||||
|
||||
// Resume resumes the store, applying the changes and invoking
|
||||
// the change callbacks.
|
||||
func (s *TestStore) Resume() {
|
||||
s.storeLock.Lock()
|
||||
s.mu.Lock()
|
||||
switch s.suspendCount--; {
|
||||
case s.suspendCount == 0:
|
||||
s.mr = s.mw
|
||||
s.mu.Unlock()
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
case s.suspendCount < 0:
|
||||
s.tb.Fatal("negative suspendCount")
|
||||
default:
|
||||
s.mu.Unlock()
|
||||
s.storeLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// SetBooleans sets the specified boolean settings in s.
|
||||
func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// SetUInt64s sets the specified integer settings in s.
|
||||
func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// SetStrings sets the specified string settings in s.
|
||||
func (s *TestStore) SetStrings(settings ...TestSetting[string]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// SetStrings sets the specified string list settings in s.
|
||||
func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) {
|
||||
s.storeLock.Lock()
|
||||
for _, setting := range settings {
|
||||
if setting.Key == "" {
|
||||
s.tb.Fatal("empty keys disallowed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
if setting.Error != nil {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
||||
} else {
|
||||
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// Delete deletes the specified settings from s.
|
||||
func (s *TestStore) Delete(keys ...setting.Key) {
|
||||
s.storeLock.Lock()
|
||||
for _, key := range keys {
|
||||
s.mu.Lock()
|
||||
delete(s.mw, key)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
// Clear deletes all settings from s.
|
||||
func (s *TestStore) Clear() {
|
||||
s.storeLock.Lock()
|
||||
s.mu.Lock()
|
||||
clear(s.mw)
|
||||
s.mu.Unlock()
|
||||
s.storeLock.Unlock()
|
||||
s.notifyPolicyChanged()
|
||||
}
|
||||
|
||||
func (s *TestStore) notifyPolicyChanged() {
|
||||
s.mu.RLock()
|
||||
if s.suspendCount != 0 {
|
||||
s.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
cbs := xmaps.Values(s.cbs)
|
||||
s.mu.RUnlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(cbs))
|
||||
for _, cb := range cbs {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cb()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Close closes s, notifying its users that it has expired.
|
||||
func (s *TestStore) Close() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.done != nil {
|
||||
close(s.done)
|
||||
s.done = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Done implements [Store].
|
||||
func (s *TestStore) Done() <-chan struct{} {
|
||||
return s.done
|
||||
}
|
@ -1,122 +1,83 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package syspolicy provides functions to retrieve system settings of a device.
|
||||
// Package syspolicy facilitates retrieval of the current policy settings
|
||||
// applied to the device or user and receiving notifications when the policy
|
||||
// changes.
|
||||
//
|
||||
// It provides functions that return specific policy settings by their unique
|
||||
// [setting.Key]s, such as [GetBoolean], [GetUint64], [GetString],
|
||||
// [GetStringArray], [GetPreferenceOption], [GetVisibility] and [GetDuration].
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotConfigured is returned when the requested policy setting is not configured.
|
||||
ErrNotConfigured = setting.ErrNotConfigured
|
||||
// ErrTypeMismatch is returned when there's a type mismatch between the actual type
|
||||
// of the setting value and the expected type.
|
||||
ErrTypeMismatch = setting.ErrTypeMismatch
|
||||
// ErrNoSuchKey is returned by [setting.DefinitionOf] when no policy setting
|
||||
// has been registered with the specified key.
|
||||
//
|
||||
// Until 2024-08-02, this error was also returned by a [Handler] when the specified
|
||||
// key did not have a value set. While the package maintains compatibility with this
|
||||
// usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer
|
||||
// [source.Store] implementations.
|
||||
ErrNoSuchKey = setting.ErrNoSuchKey
|
||||
)
|
||||
|
||||
// GetString returns a string policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetString(key Key, defaultValue string) (string, error) {
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadString(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
}
|
||||
|
||||
// GetUint64 returns a numeric policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetUint64(key Key, defaultValue uint64) (uint64, error) {
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadUInt64(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
}
|
||||
|
||||
// GetBoolean returns a boolean policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetBoolean(key Key, defaultValue bool) (bool, error) {
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadBoolean(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
}
|
||||
|
||||
// GetStringArray returns a multi-string policy setting with the specified key,
|
||||
// or defaultValue if it does not exist.
|
||||
func GetStringArray(key Key, defaultValue []string) ([]string, error) {
|
||||
markHandlerInUse()
|
||||
v, err := handler.ReadStringArray(string(key))
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, err
|
||||
return getCurrentPolicySettingValue(key, defaultValue)
|
||||
}
|
||||
|
||||
// PreferenceOption is a policy that governs whether a boolean variable
|
||||
// is forcibly assigned an administrator-defined value, or allowed to receive
|
||||
// a user-defined value.
|
||||
type PreferenceOption int
|
||||
|
||||
const (
|
||||
showChoiceByPolicy PreferenceOption = iota
|
||||
neverByPolicy
|
||||
alwaysByPolicy
|
||||
type (
|
||||
// PreferenceOption is a policy that governs whether a boolean variable
|
||||
// is forcibly assigned an administrator-defined value, or allowed to receive
|
||||
// a user-defined value.
|
||||
PreferenceOption = setting.PreferenceOption
|
||||
// Visibility is a policy that controls whether or not a particular
|
||||
// component of a user interface is to be shown.
|
||||
Visibility = setting.Visibility
|
||||
)
|
||||
|
||||
// Show returns if the UI option that controls the choice administered by this
|
||||
// policy should be shown. Currently this is true if and only if the policy is
|
||||
// showChoiceByPolicy.
|
||||
func (p PreferenceOption) Show() bool {
|
||||
return p == showChoiceByPolicy
|
||||
}
|
||||
|
||||
// ShouldEnable checks if the choice administered by this policy should be
|
||||
// enabled. If the administrator has chosen a setting, the administrator's
|
||||
// setting is returned, otherwise userChoice is returned.
|
||||
func (p PreferenceOption) ShouldEnable(userChoice bool) bool {
|
||||
switch p {
|
||||
case neverByPolicy:
|
||||
return false
|
||||
case alwaysByPolicy:
|
||||
return true
|
||||
default:
|
||||
return userChoice
|
||||
}
|
||||
}
|
||||
|
||||
// WillOverride checks if the choice administered by the policy is different
|
||||
// from the user's choice.
|
||||
func (p PreferenceOption) WillOverride(userChoice bool) bool {
|
||||
return p.ShouldEnable(userChoice) != userChoice
|
||||
}
|
||||
|
||||
// GetPreferenceOption loads a policy from the registry that can be
|
||||
// managed by an enterprise policy management system and allows administrative
|
||||
// overrides of users' choices in a way that we do not want tailcontrol to have
|
||||
// the authority to set. It describes user-decides/always/never options, where
|
||||
// "always" and "never" remove the user's ability to make a selection. If not
|
||||
// present or set to a different value, "user-decides" is the default.
|
||||
func GetPreferenceOption(name Key) (PreferenceOption, error) {
|
||||
opt, err := GetString(name, "user-decides")
|
||||
if err != nil {
|
||||
return showChoiceByPolicy, err
|
||||
}
|
||||
switch opt {
|
||||
case "always":
|
||||
return alwaysByPolicy, nil
|
||||
case "never":
|
||||
return neverByPolicy, nil
|
||||
default:
|
||||
return showChoiceByPolicy, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Visibility is a policy that controls whether or not a particular
|
||||
// component of a user interface is to be shown.
|
||||
type Visibility byte
|
||||
|
||||
const (
|
||||
visibleByPolicy Visibility = 'v'
|
||||
hiddenByPolicy Visibility = 'h'
|
||||
)
|
||||
|
||||
// Show reports whether the UI option administered by this policy should be shown.
|
||||
// Currently this is true if and only if the policy is visibleByPolicy.
|
||||
func (p Visibility) Show() bool {
|
||||
return p == visibleByPolicy
|
||||
func GetPreferenceOption(name Key) (setting.PreferenceOption, error) {
|
||||
return getCurrentPolicySettingValue(name, setting.ShowChoiceByPolicy)
|
||||
}
|
||||
|
||||
// GetVisibility loads a policy from the registry that can be managed
|
||||
@ -124,17 +85,8 @@ func (p Visibility) Show() bool {
|
||||
// for UI elements. The registry value should be a string set to "show" (return
|
||||
// true) or "hide" (return true). If not present or set to a different value,
|
||||
// "show" (return false) is the default.
|
||||
func GetVisibility(name Key) (Visibility, error) {
|
||||
opt, err := GetString(name, "show")
|
||||
if err != nil {
|
||||
return visibleByPolicy, err
|
||||
}
|
||||
switch opt {
|
||||
case "hide":
|
||||
return hiddenByPolicy, nil
|
||||
default:
|
||||
return visibleByPolicy, nil
|
||||
}
|
||||
func GetVisibility(name Key) (setting.Visibility, error) {
|
||||
return getCurrentPolicySettingValue(name, setting.VisibleByPolicy)
|
||||
}
|
||||
|
||||
// GetDuration loads a policy from the registry that can be managed
|
||||
@ -143,15 +95,48 @@ func GetVisibility(name Key) (Visibility, error) {
|
||||
// understands. If the registry value is "" or can not be processed,
|
||||
// defaultValue is returned instead.
|
||||
func GetDuration(name Key, defaultValue time.Duration) (time.Duration, error) {
|
||||
opt, err := GetString(name, "")
|
||||
if opt == "" || err != nil {
|
||||
return defaultValue, err
|
||||
d, err := getCurrentPolicySettingValue(name, defaultValue)
|
||||
if err != nil {
|
||||
return d, err
|
||||
}
|
||||
v, err := time.ParseDuration(opt)
|
||||
if err != nil || v < 0 {
|
||||
if d < 0 {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return v, nil
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// getCurrentPolicySettingValue returns the value of the policy setting
|
||||
// specified by its key from the [rsop.Policy] of the [CurrentScope]. It
|
||||
// returns def if the policy setting is not configured, or an error if it has
|
||||
// an error or could not be converted to the specified type T.
|
||||
func getCurrentPolicySettingValue[T setting.ValueType](key Key, def T) (T, error) {
|
||||
resultant, err := rsop.PolicyFor(setting.CurrentScope())
|
||||
if err != nil {
|
||||
return def, err
|
||||
}
|
||||
value, err := resultant.Get().GetErr(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, setting.ErrNotConfigured) || errors.Is(err, setting.ErrNoSuchKey) {
|
||||
return def, nil
|
||||
}
|
||||
return def, err
|
||||
}
|
||||
if res, ok := value.(T); ok {
|
||||
return res, nil
|
||||
}
|
||||
return convertPolicySettingValueTo(value, def)
|
||||
}
|
||||
|
||||
func convertPolicySettingValueTo[T setting.ValueType](value any, def T) (T, error) {
|
||||
// Convert [PreferenceOption], [Visibility], or [time.Duration] back to a string
|
||||
// if someone requests a string instead of the actual setting's value.
|
||||
// TODO(nickkhyl): check if this behavior is relied upon anywhere besides the old tests.
|
||||
if reflect.TypeFor[T]().Kind() == reflect.String {
|
||||
if str, ok := value.(fmt.Stringer); ok {
|
||||
return any(str.String()).(T), nil
|
||||
}
|
||||
}
|
||||
return def, fmt.Errorf("%w: got %T, want %T", setting.ErrTypeMismatch, value, def)
|
||||
}
|
||||
|
||||
// SelectControlURL returns the ControlURL to use based on a value in
|
||||
|
@ -5,16 +5,24 @@
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/syspolicy/internal/loggerx"
|
||||
"tailscale.com/util/syspolicy/internal/metrics"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
)
|
||||
|
||||
// testHandler encompasses all data types returned when testing any of the syspolicy
|
||||
// methods that involve getting a policy value.
|
||||
// For keys and the corresponding values, check policy_keys.go.
|
||||
type testHandler struct {
|
||||
t *testing.T
|
||||
t testing.TB
|
||||
key Key
|
||||
s string
|
||||
u64 uint64
|
||||
@ -28,7 +36,10 @@ type testHandler struct {
|
||||
|
||||
func (th *testHandler) ReadString(key string) (string, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadString(%q) want %q", key, th.key)
|
||||
// The syspolicy package now reads and caches all registered policy settings.
|
||||
// Therefore, it is expected to call the handler requesting all policies
|
||||
// rather than just the specific ones we asked for.
|
||||
return "", ErrNotConfigured
|
||||
}
|
||||
th.calls++
|
||||
return th.s, th.err
|
||||
@ -36,7 +47,10 @@ func (th *testHandler) ReadString(key string) (string, error) {
|
||||
|
||||
func (th *testHandler) ReadUInt64(key string) (uint64, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadUint64(%q) want %q", key, th.key)
|
||||
// The syspolicy package now reads and caches all registered policy settings.
|
||||
// Therefore, it is expected to call the handler requesting all policies
|
||||
// rather than just the specific ones we asked for.
|
||||
return 0, ErrNotConfigured
|
||||
}
|
||||
th.calls++
|
||||
return th.u64, th.err
|
||||
@ -44,7 +58,10 @@ func (th *testHandler) ReadUInt64(key string) (uint64, error) {
|
||||
|
||||
func (th *testHandler) ReadBoolean(key string) (bool, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadBool(%q) want %q", key, th.key)
|
||||
// The syspolicy package now reads and caches all registered policy settings.
|
||||
// Therefore, it is expected to call the handler requesting all policies
|
||||
// rather than just the specific ones we asked for.
|
||||
return false, ErrNotConfigured
|
||||
}
|
||||
th.calls++
|
||||
return th.b, th.err
|
||||
@ -52,7 +69,10 @@ func (th *testHandler) ReadBoolean(key string) (bool, error) {
|
||||
|
||||
func (th *testHandler) ReadStringArray(key string) ([]string, error) {
|
||||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadStringArray(%q) want %q", key, th.key)
|
||||
// The syspolicy package now reads and caches all registered policy settings.
|
||||
// Therefore, it is expected to call the handler requesting all policies
|
||||
// rather than just the specific ones we asked for.
|
||||
return nil, ErrNotConfigured
|
||||
}
|
||||
th.calls++
|
||||
return th.sArr, th.err
|
||||
@ -67,23 +87,28 @@ func TestGetString(t *testing.T) {
|
||||
defaultValue string
|
||||
wantValue string
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "hide",
|
||||
wantValue: "hide",
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: EnableServerMode,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantError: nil,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value, non-blank default",
|
||||
key: EnableServerMode,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
defaultValue: "test",
|
||||
wantValue: "test",
|
||||
wantError: nil,
|
||||
@ -93,11 +118,17 @@ func TestGetString(t *testing.T) {
|
||||
key: NetworkDevicesVisibility,
|
||||
handlerError: someOtherError,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_NetworkDevices_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@ -105,12 +136,21 @@ func TestGetString(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
value, err := GetString(tt.key, tt.defaultValue)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if value != tt.wantValue {
|
||||
t.Errorf("value=%v, want %v", value, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -127,7 +167,7 @@ func TestGetUint64(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: KeyExpirationNoticeTime,
|
||||
key: LogSCMInteractions,
|
||||
handlerValue: 1,
|
||||
wantValue: 1,
|
||||
},
|
||||
@ -135,14 +175,14 @@ func TestGetUint64(t *testing.T) {
|
||||
name: "read non-existing value",
|
||||
key: LogSCMInteractions,
|
||||
handlerValue: 0,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: 0,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value, non-zero default",
|
||||
key: LogSCMInteractions,
|
||||
defaultValue: 2,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: 2,
|
||||
},
|
||||
{
|
||||
@ -155,14 +195,21 @@ func TestGetUint64(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
// None of the policy settings tested here are integers.
|
||||
// In fact, we don't have any integer policies as of 2024-07-29.
|
||||
// However, we can register each of them as an integer policy setting
|
||||
// for the duration of the test, providing us with something to test against.
|
||||
if err := setting.SetDefinitionsForTest(t, setting.NewDefinition(tt.key, setting.DeviceSetting, setting.IntegerValue)); err != nil {
|
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err)
|
||||
}
|
||||
rsop.RegisterStoreForTest(t, tt.name, setting.DeviceScope, WrapHandler(&testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
u64: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
})
|
||||
}))
|
||||
value, err := GetUint64(tt.key, tt.defaultValue)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if value != tt.wantValue {
|
||||
@ -181,32 +228,43 @@ func TestGetBoolean(t *testing.T) {
|
||||
defaultValue bool
|
||||
wantValue bool
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: FlushDNSOnSessionUnlock,
|
||||
handlerValue: true,
|
||||
wantValue: true,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_FlushDNSOnSessionUnlock", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: LogSCMInteractions,
|
||||
handlerValue: false,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: false,
|
||||
},
|
||||
{
|
||||
name: "reading value returns other error",
|
||||
key: FlushDNSOnSessionUnlock,
|
||||
handlerError: someOtherError,
|
||||
wantError: someOtherError,
|
||||
wantError: someOtherError, // expect error...
|
||||
defaultValue: true,
|
||||
wantValue: false,
|
||||
wantValue: true, // ...AND default value if the handler fails.
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_FlushDNSOnSessionUnlock_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@ -214,12 +272,21 @@ func TestGetBoolean(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
value, err := GetBoolean(tt.key, tt.defaultValue)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if value != tt.wantValue {
|
||||
t.Errorf("value=%v, want %v", value, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -232,42 +299,61 @@ func TestGetPreferenceOption(t *testing.T) {
|
||||
handlerError error
|
||||
wantValue PreferenceOption
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "always by policy",
|
||||
key: EnableIncomingConnections,
|
||||
handlerValue: "always",
|
||||
wantValue: alwaysByPolicy,
|
||||
wantValue: setting.AlwaysByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "never by policy",
|
||||
key: EnableIncomingConnections,
|
||||
handlerValue: "never",
|
||||
wantValue: neverByPolicy,
|
||||
wantValue: setting.NeverByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "use default",
|
||||
key: EnableIncomingConnections,
|
||||
handlerValue: "",
|
||||
wantValue: showChoiceByPolicy,
|
||||
wantValue: setting.ShowChoiceByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: EnableIncomingConnections,
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantValue: showChoiceByPolicy,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: setting.ShowChoiceByPolicy,
|
||||
},
|
||||
{
|
||||
name: "other error is returned",
|
||||
key: EnableIncomingConnections,
|
||||
handlerError: someOtherError,
|
||||
wantValue: showChoiceByPolicy,
|
||||
wantValue: setting.ShowChoiceByPolicy,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowIncomingConnections_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@ -275,12 +361,21 @@ func TestGetPreferenceOption(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
option, err := GetPreferenceOption(tt.key)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if option != tt.wantValue {
|
||||
t.Errorf("option=%v, want %v", option, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -293,38 +388,53 @@ func TestGetVisibility(t *testing.T) {
|
||||
handlerError error
|
||||
wantValue Visibility
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "hidden by policy",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "hide",
|
||||
wantValue: hiddenByPolicy,
|
||||
wantValue: setting.HiddenByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "visibility default",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "show",
|
||||
wantValue: visibleByPolicy,
|
||||
wantValue: setting.VisibleByPolicy,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "show",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantValue: visibleByPolicy,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: setting.VisibleByPolicy,
|
||||
},
|
||||
{
|
||||
name: "other error is returned",
|
||||
key: AdminConsoleVisibility,
|
||||
handlerValue: "show",
|
||||
handlerError: someOtherError,
|
||||
wantValue: visibleByPolicy,
|
||||
wantValue: setting.VisibleByPolicy,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_AdminConsole_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@ -332,12 +442,21 @@ func TestGetVisibility(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
visibility, err := GetVisibility(tt.key)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if visibility != tt.wantValue {
|
||||
t.Errorf("visibility=%v, want %v", visibility, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -351,6 +470,7 @@ func TestGetDuration(t *testing.T) {
|
||||
defaultValue time.Duration
|
||||
wantValue time.Duration
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
@ -358,25 +478,34 @@ func TestGetDuration(t *testing.T) {
|
||||
handlerValue: "2h",
|
||||
wantValue: 2 * time.Hour,
|
||||
defaultValue: 24 * time.Hour,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_KeyExpirationNotice", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid duration value",
|
||||
key: KeyExpirationNoticeTime,
|
||||
handlerValue: "-20",
|
||||
wantValue: 24 * time.Hour,
|
||||
wantError: errors.New(`time: missing unit in duration "-20"`),
|
||||
defaultValue: 24 * time.Hour,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: KeyExpirationNoticeTime,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: 24 * time.Hour,
|
||||
defaultValue: 24 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value different default",
|
||||
key: KeyExpirationNoticeTime,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantValue: 0 * time.Second,
|
||||
defaultValue: 0 * time.Second,
|
||||
},
|
||||
@ -387,11 +516,17 @@ func TestGetDuration(t *testing.T) {
|
||||
wantValue: 24 * time.Hour,
|
||||
wantError: someOtherError,
|
||||
defaultValue: 24 * time.Hour,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@ -399,12 +534,21 @@ func TestGetDuration(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
duration, err := GetDuration(tt.key, tt.defaultValue)
|
||||
if err != tt.wantError {
|
||||
if fmt.Sprint(err) != fmt.Sprint(tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if duration != tt.wantValue {
|
||||
t.Errorf("duration=%v, want %v", duration, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -418,23 +562,28 @@ func TestGetStringArray(t *testing.T) {
|
||||
defaultValue []string
|
||||
wantValue []string
|
||||
wantError error
|
||||
wantMetrics []metrics.TestState
|
||||
}{
|
||||
{
|
||||
name: "read existing value",
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerValue: []string{"foo", "bar"},
|
||||
wantValue: []string{"foo", "bar"},
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_any", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowedSuggestedExitNodes", Value: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read non-existing value",
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
wantError: nil,
|
||||
},
|
||||
{
|
||||
name: "read non-existing value, non nil default",
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerError: ErrNoSuchKey,
|
||||
handlerError: ErrNotConfigured,
|
||||
defaultValue: []string{"foo", "bar"},
|
||||
wantValue: []string{"foo", "bar"},
|
||||
wantError: nil,
|
||||
@ -444,11 +593,17 @@ func TestGetStringArray(t *testing.T) {
|
||||
key: AllowedSuggestedExitNodes,
|
||||
handlerError: someOtherError,
|
||||
wantError: someOtherError,
|
||||
wantMetrics: []metrics.TestState{
|
||||
{Name: "$os_syspolicy_errors", Value: 1},
|
||||
{Name: "$os_syspolicy_AllowedSuggestedExitNodes_error", Value: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := metrics.NewTestHandler(t)
|
||||
metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric)
|
||||
SetHandlerForTest(t, &testHandler{
|
||||
t: t,
|
||||
key: tt.key,
|
||||
@ -456,16 +611,47 @@ func TestGetStringArray(t *testing.T) {
|
||||
err: tt.handlerError,
|
||||
})
|
||||
value, err := GetStringArray(tt.key, tt.defaultValue)
|
||||
if err != tt.wantError {
|
||||
if !errorsMatchForTest(err, tt.wantError) {
|
||||
t.Errorf("err=%q, want %q", err, tt.wantError)
|
||||
}
|
||||
if !slices.Equal(tt.wantValue, value) {
|
||||
t.Errorf("value=%v, want %v", value, tt.wantValue)
|
||||
}
|
||||
wantMetrics := tt.wantMetrics
|
||||
if !metrics.ShouldReport() {
|
||||
// Check that metrics are not reported on platforms
|
||||
// where they shouldn't be reported.
|
||||
// As of 2024-08-02, syspolicy only reports metrics
|
||||
// on Windows and Android.
|
||||
wantMetrics = nil
|
||||
}
|
||||
h.MustEqual(wantMetrics...)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetString(b *testing.B) {
|
||||
loggerx.SetForTest(b, logger.Discard, logger.Discard)
|
||||
setWellKnownSettingsForTest(b)
|
||||
|
||||
store := source.NewTestStore(b)
|
||||
wantControlURL := "https://login.tailscale.com"
|
||||
store.SetStrings(source.TestSetting[string]{Key: ControlURL, Value: wantControlURL})
|
||||
|
||||
_, err := rsop.RegisterStoreForTest(b, "Test Store", setting.DeviceScope, store)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
gotControlURL, _ := GetString(ControlURL, "https://controlplane.tailscale.com")
|
||||
if gotControlURL != wantControlURL {
|
||||
b.Fatalf("got %v; want %v", gotControlURL, wantControlURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectControlURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
reg, disk, want string
|
||||
@ -497,3 +683,13 @@ func TestSelectControlURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func errorsMatchForTest(got, want error) bool {
|
||||
if got == nil && want == nil {
|
||||
return true
|
||||
}
|
||||
if got == nil || want == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(got, want) || got.Error() == want.Error()
|
||||
}
|
||||
|
93
util/syspolicy/syspolicy_windows.go
Normal file
93
util/syspolicy/syspolicy_windows.go
Normal file
@ -0,0 +1,93 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/user"
|
||||
|
||||
"tailscale.com/util/syspolicy/internal"
|
||||
"tailscale.com/util/syspolicy/internal/lazyinit"
|
||||
"tailscale.com/util/syspolicy/rsop"
|
||||
"tailscale.com/util/syspolicy/setting"
|
||||
"tailscale.com/util/syspolicy/source"
|
||||
"tailscale.com/util/testenv"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// On Windows, we should automatically register the Registry-based policy
|
||||
// store for the device. If we are running in a user's security context
|
||||
// (e.g., we're the GUI), we should also register the Registry policy store for
|
||||
// the user. In the future, we should register (and unregister) user policy
|
||||
// stores whenever a user connects to the local backend. This ensures the
|
||||
// backend is aware of the user's policy settings and can send them to the
|
||||
// GUI/CLI/Web clients on demand or whenever they change.
|
||||
//
|
||||
// Other platforms, such as macOS, iOS and Android, should register their
|
||||
// platform-specific policy stores via [RegisterStore] (or [RegisterHandler]
|
||||
// until they implement the [Store] interface).
|
||||
//
|
||||
// External code, such as the ipnlocal package, may choose to register
|
||||
// additional policy stores, such as config files and policies received from
|
||||
// the control plane.
|
||||
lazyinit.Defer(func() error {
|
||||
// Do not register or use default policy stores during tests.
|
||||
// Each test should set up its own necessary configurations.
|
||||
if testenv.InTest() {
|
||||
return nil
|
||||
}
|
||||
return configureSyspolicy(nil)
|
||||
})
|
||||
}
|
||||
|
||||
// configureSyspolicy configures syspolicy for use on Windows,
|
||||
// either in test or regular builds depending on whether tb has a non-nil value.
|
||||
func configureSyspolicy(tb internal.TB) error {
|
||||
const localSystemSID = "S-1-5-18"
|
||||
// Always create and register a machine policy store that reads
|
||||
// policy settings from the HKEY_LOCAL_MACHINE registry hive.
|
||||
machineStore, err := source.NewMachinePlatformPolicyStore()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create the machine policy store: %v", err)
|
||||
}
|
||||
if tb == nil {
|
||||
_, err = rsop.RegisterStore("Platform", setting.DeviceScope, machineStore)
|
||||
} else {
|
||||
_, err = rsop.RegisterStoreForTest(tb, "Platform", setting.DeviceScope, machineStore)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Check whether the current process is running as Local System or not.
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if u.Uid == localSystemSID {
|
||||
return nil
|
||||
}
|
||||
// If it's not a Local System's process (e.g., the GUI and not the tailscaled service),
|
||||
// we should create and use a policy store for the current user that reads
|
||||
// policy settings from that user's registry hive (HKEY_CURRENT_USER).
|
||||
userStore, err := source.NewUserPlatformPolicyStore(0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create the current user's policy store: %v", err)
|
||||
}
|
||||
if tb == nil {
|
||||
_, err = rsop.RegisterStore("Platform", setting.CurrentUserScope, userStore)
|
||||
} else {
|
||||
_, err = rsop.RegisterStoreForTest(tb, "Platform", setting.CurrentUserScope, userStore)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// And also set [CurrentUserScope] as the [CurrentScope], so [GetString],
|
||||
// [GetVisibility] and similar functions would be returning a merged result
|
||||
// of the machine's and user's policies.
|
||||
if !setting.SetCurrentScope(setting.CurrentUserScope) {
|
||||
return errors.New("current scope already set")
|
||||
}
|
||||
return nil
|
||||
}
|
@ -189,6 +189,7 @@ func (l *PolicyLock) lockSlow() (err error) {
|
||||
select {
|
||||
case resultCh <- policyLockResult{handle, err}:
|
||||
// lockSlow has received the result.
|
||||
break send_result
|
||||
default:
|
||||
select {
|
||||
case <-closing:
|
||||
|
Loading…
x
Reference in New Issue
Block a user