diff --git a/util/syspolicy/caching_handler.go b/util/syspolicy/caching_handler.go new file mode 100644 index 000000000..fff96de1c --- /dev/null +++ b/util/syspolicy/caching_handler.go @@ -0,0 +1,98 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syspolicy + +import ( + "errors" + "sync" +) + +// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested +// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached, +// otherwise the actual error is returned and the next read for that key will retry using the handler. +type CachingHandler struct { + mu sync.Mutex + strings map[string]string + uint64s map[string]uint64 + bools map[string]bool + notFound map[string]bool + handler Handler +} + +// NewCachingHandler creates a CachingHandler given a handler. +func NewCachingHandler(handler Handler) *CachingHandler { + return &CachingHandler{ + handler: handler, + strings: make(map[string]string), + uint64s: make(map[string]uint64), + bools: make(map[string]bool), + notFound: make(map[string]bool), + } +} + +// ReadString reads the policy settings value string given the key. +// ReadString first reads from the handler's cache before resorting to using the handler. +func (ch *CachingHandler) ReadString(key string) (string, error) { + ch.mu.Lock() + defer ch.mu.Unlock() + if val, ok := ch.strings[key]; ok { + return val, nil + } + if notFound := ch.notFound[key]; notFound { + return "", ErrNoSuchKey + } + val, err := ch.handler.ReadString(key) + if errors.Is(err, ErrNoSuchKey) { + ch.notFound[key] = true + return "", err + } else if err != nil { + return "", err + } + ch.strings[key] = val + return val, nil +} + +// ReadUInt64 reads the policy settings uint64 value given the key. +// ReadUInt64 first reads from the handler's cache before resorting to using the handler. +func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) { + ch.mu.Lock() + defer ch.mu.Unlock() + if val, ok := ch.uint64s[key]; ok { + return val, nil + } + if notFound := ch.notFound[key]; notFound { + return 0, ErrNoSuchKey + } + val, err := ch.handler.ReadUInt64(key) + if errors.Is(err, ErrNoSuchKey) { + ch.notFound[key] = true + return 0, err + } else if err != nil { + return 0, err + } + ch.uint64s[key] = val + return val, nil +} + +// ReadBoolean reads the policy settings boolean value given the key. +// ReadBoolean first reads from the handler's cache before resorting to using the handler. +func (ch *CachingHandler) ReadBoolean(key string) (bool, error) { + ch.mu.Lock() + defer ch.mu.Unlock() + if val, ok := ch.bools[key]; ok { + return val, nil + } + if notFound := ch.notFound[key]; notFound { + return false, ErrNoSuchKey + } + val, err := ch.handler.ReadBoolean(key) + if errors.Is(err, ErrNoSuchKey) { + ch.notFound[key] = true + return false, err + } else if err != nil { + return false, err + } + ch.bools[key] = val + return val, nil +} diff --git a/util/syspolicy/caching_handler_test.go b/util/syspolicy/caching_handler_test.go new file mode 100644 index 000000000..881f6ff83 --- /dev/null +++ b/util/syspolicy/caching_handler_test.go @@ -0,0 +1,262 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syspolicy + +import ( + "testing" +) + +func TestHandlerReadString(t *testing.T) { + tests := []struct { + name string + key string + handlerKey Key + handlerValue string + handlerError error + preserveHandler bool + wantValue string + wantErr error + strings map[string]string + expectedCalls int + }{ + { + name: "read existing cached values", + key: "test", + handlerKey: "do not read", + strings: map[string]string{"test": "foo"}, + wantValue: "foo", + expectedCalls: 0, + }, + { + name: "read existing values not cached", + key: "test", + handlerKey: "test", + handlerValue: "foo", + wantValue: "foo", + expectedCalls: 1, + }, + { + name: "error no such key", + key: "test", + handlerKey: "test", + handlerError: ErrNoSuchKey, + wantErr: ErrNoSuchKey, + expectedCalls: 1, + }, + { + name: "other error", + key: "test", + handlerKey: "test", + handlerError: someOtherError, + wantErr: someOtherError, + preserveHandler: true, + expectedCalls: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testHandler := &testHandler{ + t: t, + key: tt.handlerKey, + s: tt.handlerValue, + err: tt.handlerError, + } + cache := NewCachingHandler(testHandler) + if tt.strings != nil { + cache.strings = tt.strings + } + got, err := cache.ReadString(tt.key) + if err != tt.wantErr { + t.Errorf("err=%v want %v", err, tt.wantErr) + } + if got != tt.wantValue { + t.Errorf("got %v want %v", got, cache.strings[tt.key]) + } + if !tt.preserveHandler { + testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil + } + got, err = cache.ReadString(tt.key) + if err != tt.wantErr { + t.Errorf("repeat err=%v want %v", err, tt.wantErr) + } + if got != tt.wantValue { + t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) + } + if testHandler.calls != tt.expectedCalls { + t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) + } + }) + } +} + +func TestHandlerReadUint64(t *testing.T) { + tests := []struct { + name string + key string + handlerKey Key + handlerValue uint64 + handlerError error + preserveHandler bool + wantValue uint64 + wantErr error + uint64s map[string]uint64 + expectedCalls int + }{ + { + name: "read existing cached values", + key: "test", + handlerKey: "do not read", + uint64s: map[string]uint64{"test": 1}, + wantValue: 1, + expectedCalls: 0, + }, + { + name: "read existing values not cached", + key: "test", + handlerKey: "test", + handlerValue: 1, + wantValue: 1, + expectedCalls: 1, + }, + { + name: "error no such key", + key: "test", + handlerKey: "test", + handlerError: ErrNoSuchKey, + wantErr: ErrNoSuchKey, + expectedCalls: 1, + }, + { + name: "other error", + key: "test", + handlerKey: "test", + handlerError: someOtherError, + wantErr: someOtherError, + preserveHandler: true, + expectedCalls: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testHandler := &testHandler{ + t: t, + key: tt.handlerKey, + u64: tt.handlerValue, + err: tt.handlerError, + } + cache := NewCachingHandler(testHandler) + if tt.uint64s != nil { + cache.uint64s = tt.uint64s + } + got, err := cache.ReadUInt64(tt.key) + if err != tt.wantErr { + t.Errorf("err=%v want %v", err, tt.wantErr) + } + if got != tt.wantValue { + t.Errorf("got %v want %v", got, cache.strings[tt.key]) + } + if !tt.preserveHandler { + testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil + } + got, err = cache.ReadUInt64(tt.key) + if err != tt.wantErr { + t.Errorf("repeat err=%v want %v", err, tt.wantErr) + } + if got != tt.wantValue { + t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) + } + if testHandler.calls != tt.expectedCalls { + t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) + } + }) + } + +} + +func TestHandlerReadBool(t *testing.T) { + tests := []struct { + name string + key string + handlerKey Key + handlerValue bool + handlerError error + preserveHandler bool + wantValue bool + wantErr error + bools map[string]bool + expectedCalls int + }{ + { + name: "read existing cached values", + key: "test", + handlerKey: "do not read", + bools: map[string]bool{"test": true}, + wantValue: true, + expectedCalls: 0, + }, + { + name: "read existing values not cached", + key: "test", + handlerKey: "test", + handlerValue: true, + wantValue: true, + expectedCalls: 1, + }, + { + name: "error no such key", + key: "test", + handlerKey: "test", + handlerError: ErrNoSuchKey, + wantErr: ErrNoSuchKey, + expectedCalls: 1, + }, + { + name: "other error", + key: "test", + handlerKey: "test", + handlerError: someOtherError, + wantErr: someOtherError, + preserveHandler: true, + expectedCalls: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testHandler := &testHandler{ + t: t, + key: tt.handlerKey, + b: tt.handlerValue, + err: tt.handlerError, + } + cache := NewCachingHandler(testHandler) + if tt.bools != nil { + cache.bools = tt.bools + } + got, err := cache.ReadBoolean(tt.key) + if err != tt.wantErr { + t.Errorf("err=%v want %v", err, tt.wantErr) + } + if got != tt.wantValue { + t.Errorf("got %v want %v", got, cache.strings[tt.key]) + } + if !tt.preserveHandler { + testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil + } + got, err = cache.ReadBoolean(tt.key) + if err != tt.wantErr { + t.Errorf("repeat err=%v want %v", err, tt.wantErr) + } + if got != tt.wantValue { + t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) + } + if testHandler.calls != tt.expectedCalls { + t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) + } + }) + } + +} diff --git a/util/syspolicy/handler_windows.go b/util/syspolicy/handler_windows.go index c12a21fdf..c259bf96c 100644 --- a/util/syspolicy/handler_windows.go +++ b/util/syspolicy/handler_windows.go @@ -12,7 +12,7 @@ type windowsHandler struct{} func init() { - RegisterHandler(windowsHandler{}) + RegisterHandler(NewCachingHandler(windowsHandler{})) } func (windowsHandler) ReadString(key string) (string, error) { diff --git a/util/syspolicy/syspolicy_test.go b/util/syspolicy/syspolicy_test.go index 859843431..ea6749ce3 100644 --- a/util/syspolicy/syspolicy_test.go +++ b/util/syspolicy/syspolicy_test.go @@ -13,12 +13,13 @@ // methods that involve getting a policy value. // For keys and the corresponding values, check policy_keys.go. type testHandler struct { - t *testing.T - key Key - s string - u64 uint64 - b bool - err error + t *testing.T + key Key + s string + u64 uint64 + b bool + err error + calls int // used for testing reads from cache vs. handler } var someOtherError = errors.New("error other than not found") @@ -34,6 +35,7 @@ func (th *testHandler) ReadString(key string) (string, error) { if key != string(th.key) { th.t.Errorf("ReadString(%q) want %q", key, th.key) } + th.calls++ return th.s, th.err } @@ -41,6 +43,7 @@ func (th *testHandler) ReadUInt64(key string) (uint64, error) { if key != string(th.key) { th.t.Errorf("ReadUint64(%q) want %q", key, th.key) } + th.calls++ return th.u64, th.err } @@ -48,6 +51,7 @@ func (th *testHandler) ReadBoolean(key string) (bool, error) { if key != string(th.key) { th.t.Errorf("ReadBool(%q) want %q", key, th.key) } + th.calls++ return th.b, th.err }