diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 7ed0e7d73..7b235c91a 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -609,10 +609,11 @@ func isAlpha(b byte) bool { // // We might relax these rules later. func CheckTag(tag string) error { - if !strings.HasPrefix(tag, "tag:") { + var ok bool + tag, ok = strings.CutPrefix(tag, "tag:") + if !ok { return errors.New("tags must start with 'tag:'") } - tag = tag[4:] if tag == "" { return errors.New("tag names must not be empty") } diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go index 94b4b50d1..0820a6d5c 100644 --- a/tailcfg/tailcfg_test.go +++ b/tailcfg/tailcfg_test.go @@ -865,3 +865,27 @@ func TestDeps(t *testing.T) { }, }.Check(t) } + +func TestCheckTag(t *testing.T) { + tests := []struct { + name string + tag string + want bool + }{ + {"empty", "", false}, + {"good", "tag:foo", true}, + {"bad", "tag:", false}, + {"no_leading_num", "tag:1foo", false}, + {"no_punctuation", "tag:foa@bar", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckTag(tt.tag) + if err == nil && !tt.want { + t.Errorf("got nil; want error") + } else if err != nil && tt.want { + t.Errorf("got %v; want nil", err) + } + }) + } +}