diff --git a/types/opt/bool.go b/types/opt/bool.go index 2a9efe31b..0a3ee67ad 100644 --- a/types/opt/bool.go +++ b/types/opt/bool.go @@ -105,3 +105,29 @@ func (b *Bool) UnmarshalJSON(j []byte) error { } return nil } + +// BoolFlag is a wrapper for Bool that implements [flag.Value]. +type BoolFlag struct { + *Bool +} + +// Set the value of b, using any value supported by [strconv.ParseBool]. +func (b *BoolFlag) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + b.Bool.Set(v) + return nil +} + +// String returns "true" or "false" if the value is set, or an empty string otherwise. +func (b *BoolFlag) String() string { + if b == nil || b.Bool == nil { + return "" + } + if v, ok := b.Bool.Get(); ok { + return strconv.FormatBool(v) + } + return "" +} diff --git a/types/opt/bool_test.go b/types/opt/bool_test.go index 92ba275e1..dddbcfc19 100644 --- a/types/opt/bool_test.go +++ b/types/opt/bool_test.go @@ -5,7 +5,9 @@ import ( "encoding/json" + "flag" "reflect" + "strings" "testing" ) @@ -127,3 +129,38 @@ func TestUnmarshalAlloc(t *testing.T) { t.Errorf("got %v allocs, want 0", n) } } + +func TestBoolFlag(t *testing.T) { + tests := []struct { + arguments string + wantParseError bool // expect flag.Parse to error + want Bool + }{ + {"", false, Bool("")}, + {"-test", true, Bool("")}, + {`-test=""`, true, Bool("")}, + {"-test invalid", true, Bool("")}, + + {"-test true", false, NewBool(true)}, + {"-test 1", false, NewBool(true)}, + + {"-test false", false, NewBool(false)}, + {"-test 0", false, NewBool(false)}, + } + + for _, tt := range tests { + var got Bool + fs := flag.NewFlagSet(t.Name(), flag.ContinueOnError) + fs.Var(&BoolFlag{&got}, "test", "test flag") + + arguments := strings.Split(tt.arguments, " ") + err := fs.Parse(arguments) + if (err != nil) != tt.wantParseError { + t.Errorf("flag.Parse(%q) returned error %v, want %v", arguments, err, tt.wantParseError) + } + + if got != tt.want { + t.Errorf("flag.Parse(%q) got %q, want %q", arguments, got, tt.want) + } + } +}