types/opt: add BoolFlag for setting Bool value as a flag

Updates tailscale/corp#22578

Signed-off-by: Will Norris <will@tailscale.com>
This commit is contained in:
Will Norris 2024-08-26 10:17:45 -07:00 committed by Will Norris
parent 8af50fa97c
commit cccacff564
2 changed files with 63 additions and 0 deletions

View File

@ -105,3 +105,29 @@ func (b *Bool) UnmarshalJSON(j []byte) error {
} }
return nil return nil
} }
// BoolFlag is a wrapper for Bool that implements [flag.Value].
type BoolFlag struct {
*Bool
}
// Set the value of b, using any value supported by [strconv.ParseBool].
func (b *BoolFlag) Set(s string) error {
v, err := strconv.ParseBool(s)
if err != nil {
return err
}
b.Bool.Set(v)
return nil
}
// String returns "true" or "false" if the value is set, or an empty string otherwise.
func (b *BoolFlag) String() string {
if b == nil || b.Bool == nil {
return ""
}
if v, ok := b.Bool.Get(); ok {
return strconv.FormatBool(v)
}
return ""
}

View File

@ -5,7 +5,9 @@
import ( import (
"encoding/json" "encoding/json"
"flag"
"reflect" "reflect"
"strings"
"testing" "testing"
) )
@ -127,3 +129,38 @@ func TestUnmarshalAlloc(t *testing.T) {
t.Errorf("got %v allocs, want 0", n) t.Errorf("got %v allocs, want 0", n)
} }
} }
func TestBoolFlag(t *testing.T) {
tests := []struct {
arguments string
wantParseError bool // expect flag.Parse to error
want Bool
}{
{"", false, Bool("")},
{"-test", true, Bool("")},
{`-test=""`, true, Bool("")},
{"-test invalid", true, Bool("")},
{"-test true", false, NewBool(true)},
{"-test 1", false, NewBool(true)},
{"-test false", false, NewBool(false)},
{"-test 0", false, NewBool(false)},
}
for _, tt := range tests {
var got Bool
fs := flag.NewFlagSet(t.Name(), flag.ContinueOnError)
fs.Var(&BoolFlag{&got}, "test", "test flag")
arguments := strings.Split(tt.arguments, " ")
err := fs.Parse(arguments)
if (err != nil) != tt.wantParseError {
t.Errorf("flag.Parse(%q) returned error %v, want %v", arguments, err, tt.wantParseError)
}
if got != tt.want {
t.Errorf("flag.Parse(%q) got %q, want %q", arguments, got, tt.want)
}
}
}