diff --git a/util/syspolicy/internal/metrics/metrics.go b/util/syspolicy/internal/metrics/metrics.go index 2ea02278a..0a2aa1192 100644 --- a/util/syspolicy/internal/metrics/metrics.go +++ b/util/syspolicy/internal/metrics/metrics.go @@ -284,7 +284,7 @@ func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) { } func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric { - name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_") + name := strings.ReplaceAll(string(key), string(setting.KeyPathSeparator), "_") return newMetric([]string{name, metricScopeName(scope), suffix}, typ) } diff --git a/util/syspolicy/setting/key.go b/util/syspolicy/setting/key.go index 406fde132..aa7606d36 100644 --- a/util/syspolicy/setting/key.go +++ b/util/syspolicy/setting/key.go @@ -10,4 +10,4 @@ type Key string // KeyPathSeparator allows logical grouping of policy settings into categories. -const KeyPathSeparator = "/" +const KeyPathSeparator = '/' diff --git a/util/syspolicy/source/env_policy_store.go b/util/syspolicy/source/env_policy_store.go new file mode 100644 index 000000000..2f07fffca --- /dev/null +++ b/util/syspolicy/source/env_policy_store.go @@ -0,0 +1,159 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "fmt" + "os" + "strconv" + "strings" + "unicode/utf8" + + "github.com/pkg/errors" + "tailscale.com/util/syspolicy/setting" +) + +var lookupEnv = os.LookupEnv // test hook + +var _ Store = (*EnvPolicyStore)(nil) + +// EnvPolicyStore is a [Store] that reads policy settings from environment variables. +type EnvPolicyStore struct{} + +// ReadString implements [Store]. +func (s *EnvPolicyStore) ReadString(key setting.Key) (string, error) { + _, str, err := s.lookupSettingVariable(key) + if err != nil { + return "", err + } + return str, nil +} + +// ReadUInt64 implements [Store]. +func (s *EnvPolicyStore) ReadUInt64(key setting.Key) (uint64, error) { + name, str, err := s.lookupSettingVariable(key) + if err != nil { + return 0, err + } + if str == "" { + return 0, setting.ErrNotConfigured + } + value, err := strconv.ParseUint(str, 0, 64) + if err != nil { + return 0, fmt.Errorf("%s: %w: %q is not a valid uint64", name, setting.ErrTypeMismatch, str) + } + return value, nil +} + +// ReadBoolean implements [Store]. +func (s *EnvPolicyStore) ReadBoolean(key setting.Key) (bool, error) { + name, str, err := s.lookupSettingVariable(key) + if err != nil { + return false, err + } + if str == "" { + return false, setting.ErrNotConfigured + } + value, err := strconv.ParseBool(str) + if err != nil { + return false, fmt.Errorf("%s: %w: %q is not a valid bool", name, setting.ErrTypeMismatch, str) + } + return value, nil +} + +// ReadStringArray implements [Store]. +func (s *EnvPolicyStore) ReadStringArray(key setting.Key) ([]string, error) { + _, str, err := s.lookupSettingVariable(key) + if err != nil || str == "" { + return nil, err + } + var dst int + res := strings.Split(str, ",") + for src := range res { + res[dst] = strings.TrimSpace(res[src]) + if res[dst] != "" { + dst++ + } + } + return res[0:dst], nil +} + +func (s *EnvPolicyStore) lookupSettingVariable(key setting.Key) (name, value string, err error) { + name, err = keyToEnvVarName(key) + if err != nil { + return "", "", err + } + value, ok := lookupEnv(name) + if !ok { + return name, "", setting.ErrNotConfigured + } + return name, value, nil +} + +var ( + errEmptyKey = errors.New("key must not be empty") + errInvalidKey = errors.New("key must consist of alphanumeric characters and slashes") +) + +// keyToEnvVarName returns the environment variable name for a given policy +// setting key, or an error if the key is invalid. It converts CamelCase keys into +// underscore-separated words and prepends the variable name with the TS prefix. +// For example: AuthKey => TS_AUTH_KEY, ExitNodeAllowLANAccess => TS_EXIT_NODE_ALLOW_LAN_ACCESS, etc. +// +// It's fine to use this in [EnvPolicyStore] without caching variable names since it's not a hot path. +// [EnvPolicyStore] is not a [Changeable] policy store, so the conversion will only happen once. +func keyToEnvVarName(key setting.Key) (string, error) { + if len(key) == 0 { + return "", errEmptyKey + } + + isLower := func(c byte) bool { return 'a' <= c && c <= 'z' } + isUpper := func(c byte) bool { return 'A' <= c && c <= 'Z' } + isLetter := func(c byte) bool { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') } + isDigit := func(c byte) bool { return '0' <= c && c <= '9' } + + words := make([]string, 0, 8) + words = append(words, "TS") + var currentWord strings.Builder + for i := 0; i < len(key); i++ { + c := key[i] + if c >= utf8.RuneSelf { + return "", errInvalidKey + } + + var split bool + switch { + case isLower(c): + c -= 'a' - 'A' // make upper + split = currentWord.Len() > 0 && !isLetter(key[i-1]) + case isUpper(c): + if currentWord.Len() > 0 { + prevUpper := isUpper(key[i-1]) + nextLower := i < len(key)-1 && isLower(key[i+1]) + split = !prevUpper || nextLower // split on case transition + } + case isDigit(c): + split = currentWord.Len() > 0 && !isDigit(key[i-1]) + case c == setting.KeyPathSeparator: + words = append(words, currentWord.String()) + currentWord.Reset() + continue + default: + return "", errInvalidKey + } + + if split { + words = append(words, currentWord.String()) + currentWord.Reset() + } + + currentWord.WriteByte(c) + } + + if currentWord.Len() > 0 { + words = append(words, currentWord.String()) + } + + return strings.Join(words, "_"), nil +} diff --git a/util/syspolicy/source/env_policy_store_test.go b/util/syspolicy/source/env_policy_store_test.go new file mode 100644 index 000000000..364a6104d --- /dev/null +++ b/util/syspolicy/source/env_policy_store_test.go @@ -0,0 +1,354 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "cmp" + "errors" + "math" + "reflect" + "strconv" + "testing" + + "tailscale.com/util/syspolicy/setting" +) + +func TestKeyToVariableName(t *testing.T) { + tests := []struct { + name string + key setting.Key + want string + wantErr error + }{ + { + name: "empty", + key: "", + wantErr: errEmptyKey, + }, + { + name: "lowercase", + key: "tailnet", + want: "TS_TAILNET", + }, + { + name: "CamelCase", + key: "AuthKey", + want: "TS_AUTH_KEY", + }, + { + name: "LongerCamelCase", + key: "ManagedByOrganizationName", + want: "TS_MANAGED_BY_ORGANIZATION_NAME", + }, + { + name: "UPPERCASE", + key: "UPPERCASE", + want: "TS_UPPERCASE", + }, + { + name: "WithAbbrev/Front", + key: "DNSServer", + want: "TS_DNS_SERVER", + }, + { + name: "WithAbbrev/Middle", + key: "ExitNodeAllowLANAccess", + want: "TS_EXIT_NODE_ALLOW_LAN_ACCESS", + }, + { + name: "WithAbbrev/Back", + key: "ExitNodeID", + want: "TS_EXIT_NODE_ID", + }, + { + name: "WithDigits/Single/Front", + key: "0TestKey", + want: "TS_0_TEST_KEY", + }, + { + name: "WithDigits/Multi/Front", + key: "64TestKey", + want: "TS_64_TEST_KEY", + }, + { + name: "WithDigits/Single/Middle", + key: "Test0Key", + want: "TS_TEST_0_KEY", + }, + { + name: "WithDigits/Multi/Middle", + key: "Test64Key", + want: "TS_TEST_64_KEY", + }, + { + name: "WithDigits/Single/Back", + key: "TestKey0", + want: "TS_TEST_KEY_0", + }, + { + name: "WithDigits/Multi/Back", + key: "TestKey64", + want: "TS_TEST_KEY_64", + }, + { + name: "WithDigits/Multi/Back", + key: "TestKey64", + want: "TS_TEST_KEY_64", + }, + { + name: "WithPathSeparators/Single", + key: "Key/Subkey", + want: "TS_KEY_SUBKEY", + }, + { + name: "WithPathSeparators/Multi", + key: "Root/Level1/Level2", + want: "TS_ROOT_LEVEL_1_LEVEL_2", + }, + { + name: "Mixed", + key: "Network/DNSServer/IPAddress", + want: "TS_NETWORK_DNS_SERVER_IP_ADDRESS", + }, + { + name: "Non-Alphanumeric/NonASCII/1", + key: "ж", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/NonASCII/2", + key: "KeyжName", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/Space", + key: "Key Name", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/Punct", + key: "Key!Name", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/Backslash", + key: `Key\Name`, + wantErr: errInvalidKey, + }, + } + for _, tt := range tests { + t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) { + got, err := keyToEnvVarName(tt.key) + checkError(t, err, tt.wantErr, true) + + if got != tt.want { + t.Fatalf("got %q; want %q", got, tt.want) + } + }) + } +} + +func TestEnvPolicyStore(t *testing.T) { + blankEnv := func(string) (string, bool) { return "", false } + makeEnv := func(wantName, value string) func(string) (string, bool) { + return func(gotName string) (string, bool) { + if gotName != wantName { + return "", false + } + return value, true + } + } + tests := []struct { + name string + key setting.Key + lookup func(string) (string, bool) + want any + wantErr error + }{ + { + name: "NotConfigured/String", + key: "AuthKey", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: "", + }, + { + name: "Configured/String/Empty", + key: "AuthKey", + lookup: makeEnv("TS_AUTH_KEY", ""), + want: "", + }, + { + name: "Configured/String/NonEmpty", + key: "AuthKey", + lookup: makeEnv("TS_AUTH_KEY", "ABC123"), + want: "ABC123", + }, + { + name: "NotConfigured/UInt64", + key: "IntegerSetting", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: uint64(0), + }, + { + name: "Configured/UInt64/Empty", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", ""), + wantErr: setting.ErrNotConfigured, + want: uint64(0), + }, + { + name: "Configured/UInt64/Zero", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", "0"), + want: uint64(0), + }, + { + name: "Configured/UInt64/NonZero", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", "12345"), + want: uint64(12345), + }, + { + name: "Configured/UInt64/MaxUInt64", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", strconv.FormatUint(math.MaxUint64, 10)), + want: uint64(math.MaxUint64), + }, + { + name: "Configured/UInt64/Negative", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", "-1"), + wantErr: setting.ErrTypeMismatch, + want: uint64(0), + }, + { + name: "Configured/UInt64/Hex", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", "0xDEADBEEF"), + want: uint64(0xDEADBEEF), + }, + { + name: "NotConfigured/Bool", + key: "LogSCMInteractions", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: false, + }, + { + name: "Configured/Bool/Empty", + key: "LogSCMInteractions", + lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", ""), + wantErr: setting.ErrNotConfigured, + want: false, + }, + { + name: "Configured/Bool/True", + key: "LogSCMInteractions", + lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "true"), + want: true, + }, + { + name: "Configured/Bool/False", + key: "LogSCMInteractions", + lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "False"), + want: false, + }, + { + name: "Configured/Bool/1", + key: "LogSCMInteractions", + lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "1"), + want: true, + }, + { + name: "Configured/Bool/0", + key: "LogSCMInteractions", + lookup: makeEnv("TS_LOG_SCM_INTERACTIONS", "0"), + want: false, + }, + { + name: "Configured/Bool/Invalid", + key: "IntegerSetting", + lookup: makeEnv("TS_INTEGER_SETTING", "NotABool"), + wantErr: setting.ErrTypeMismatch, + want: false, + }, + { + name: "NotConfigured/StringArray", + key: "AllowedSuggestedExitNodes", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: []string(nil), + }, + { + name: "Configured/StringArray/Empty", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", ""), + want: []string(nil), + }, + { + name: "Configured/StringArray/Spaces", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", " \t "), + want: []string{}, + }, + { + name: "Configured/StringArray/Single", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", "NodeA"), + want: []string{"NodeA"}, + }, + { + name: "Configured/StringArray/Multi", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,NodeB,NodeC"), + want: []string{"NodeA", "NodeB", "NodeC"}, + }, + { + name: "Configured/StringArray/WithBlank", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("TS_ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,\t,, ,NodeB"), + want: []string{"NodeA", "NodeB"}, + }, + } + for _, tt := range tests { + t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) { + oldLookupEnv := lookupEnv + t.Cleanup(func() { lookupEnv = oldLookupEnv }) + lookupEnv = tt.lookup + + var got any + var err error + var store EnvPolicyStore + switch tt.want.(type) { + case string: + got, err = store.ReadString(tt.key) + case uint64: + got, err = store.ReadUInt64(tt.key) + case bool: + got, err = store.ReadBoolean(tt.key) + case []string: + got, err = store.ReadStringArray(tt.key) + } + checkError(t, err, tt.wantErr, false) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} + +func checkError(tb testing.TB, got, want error, fatal bool) { + tb.Helper() + f := tb.Errorf + if fatal { + f = tb.Fatalf + } + if (want == nil && got != nil) || + (want != nil && got == nil) || + (want != nil && got != nil && !errors.Is(got, want) && want.Error() != got.Error()) { + f("gotErr: %v; wantErr: %v", got, want) + } +} diff --git a/util/syspolicy/source/policy_store_windows.go b/util/syspolicy/source/policy_store_windows.go index f526b4ce1..86e2254e0 100644 --- a/util/syspolicy/source/policy_store_windows.go +++ b/util/syspolicy/source/policy_store_windows.go @@ -319,9 +319,9 @@ func(key registry.Key, valueName string) ([]string, error) { // If there are no [setting.KeyPathSeparator]s in the key, the policy setting value // is meant to be stored directly under {HKLM,HKCU}\Software\Policies\Tailscale. func splitSettingKey(key setting.Key) (path, valueName string) { - if idx := strings.LastIndex(string(key), setting.KeyPathSeparator); idx != -1 { - path = strings.ReplaceAll(string(key[:idx]), setting.KeyPathSeparator, `\`) - valueName = string(key[idx+len(setting.KeyPathSeparator):]) + if idx := strings.LastIndexByte(string(key), setting.KeyPathSeparator); idx != -1 { + path = strings.ReplaceAll(string(key[:idx]), string(setting.KeyPathSeparator), `\`) + valueName = string(key[idx+1:]) return path, valueName } return "", string(key)