Skip to content

Commit

Permalink
merge go/src/flag v1.16.3
Browse files Browse the repository at this point in the history
  • Loading branch information
jnovack committed Apr 2, 2021
1 parent 04dca4a commit 802519a
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 17 deletions.
50 changes: 39 additions & 11 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,12 @@ func (d *durationValue) Get() interface{} { return time.Duration(*d) }

func (d *durationValue) String() string { return (*time.Duration)(d).String() }

type funcValue func(string) error

func (f funcValue) Set(s string) error { return f(s) }

func (f funcValue) String() string { return "" }

// Value is the interface to the dynamic value stored in a flag.
// (The default value is represented as a string.)
//
Expand All @@ -296,7 +302,7 @@ type Value interface {
// Getter is an interface that allows the contents of a Value to be retrieved.
// It wraps the Value interface, rather than being part of it, because it
// appeared after Go 1 and its compatibility rules. All Value types provided
// by this package satisfy the Getter interface.
// by this package satisfy the Getter interface, except the type used by Func.
type Getter interface {
Value
Get() interface{}
Expand Down Expand Up @@ -831,24 +837,44 @@ func Duration(name string, value time.Duration, usage string) *time.Duration {
return CommandLine.Duration(name, value, usage)
}

// Func defines a flag with the specified name and usage string.
// Each time the flag is seen, fn is called with the value of the flag.
// If fn returns a non-nil error, it will be treated as a flag value parsing error.
func (f *FlagSet) Func(name, usage string, fn func(string) error) {
f.Var(funcValue(fn), name, usage)
}

// Func defines a flag with the specified name and usage string.
// Each time the flag is seen, fn is called with the value of the flag.
// If fn returns a non-nil error, it will be treated as a flag value parsing error.
func Func(name, usage string, fn func(string) error) {
CommandLine.Func(name, usage, fn)
}

// Var defines a flag with the specified name and usage string. The type and
// value of the flag are represented by the first argument, of type Value, which
// typically holds a user-defined implementation of Value. For instance, the
// caller could create a flag that turns a comma-separated string into a slice
// of strings by giving the slice the methods of Value; in particular, Set would
// decompose the comma-separated string into the slice.
func (f *FlagSet) Var(value Value, name string, usage string) {
// Flag must not begin "-" or contain "=".
if strings.HasPrefix(name, "-") {
panic(f.sprintf("flag %q begins with -", name))
} else if strings.Contains(name, "=") {
panic(f.sprintf("flag %q contains =", name))
}

// Remember the default value as a string; it won't change.
flag := &Flag{name, usage, value, value.String()}
_, alreadythere := f.formal[name]
if alreadythere {
var msg string
if f.name == "" {
msg = fmt.Sprintf("flag redefined: %s", name)
msg = f.sprintf("flag redefined: %s", name)
} else {
msg = fmt.Sprintf("%s flag redefined: %s", f.name, name)
msg = f.sprintf("%s flag redefined: %s", f.name, name)
}
fmt.Fprintln(f.Output(), msg)
panic(msg) // Happens only if flags are declared with identical names
}
if f.formal == nil {
Expand All @@ -867,24 +893,26 @@ func Var(value Value, name string, usage string) {
CommandLine.Var(value, name, usage)
}

// sprintf formats the message, prints it to output, and returns it.
func (f *FlagSet) sprintf(format string, a ...interface{}) string {
msg := fmt.Sprintf(format, a...)
fmt.Fprintln(f.Output(), msg)
return msg
}

// failf prints to standard error a formatted error and usage message and
// returns the error.
func (f *FlagSet) failf(format string, a ...interface{}) error {
err := fmt.Errorf(format, a...)
fmt.Fprintln(f.Output(), err)
msg := f.sprintf(format, a...)
f.usage()
return err
return errors.New(msg)
}

// usage calls the Usage method for the flag set if one is specified,
// or the appropriate default usage function otherwise.
func (f *FlagSet) usage() {
if f.Usage == nil {
if f == CommandLine {
Usage()
} else {
f.defaultUsage()
}
} else {
f.Usage()
}
Expand Down
204 changes: 198 additions & 6 deletions flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"sort"
"strconv"
Expand Down Expand Up @@ -36,6 +35,7 @@ func TestEverything(t *testing.T) {
String("test_string", "0", "string value")
Float64("test_float64", 0, "float64 value")
Duration("test_duration", 0, "time.Duration value")
Func("test_func", "func value", func(string) error { return nil })

m := make(map[string]*Flag)
desired := "0"
Expand All @@ -50,14 +50,16 @@ func TestEverything(t *testing.T) {
ok = true
case f.Name == "test_duration" && f.Value.String() == desired+"s":
ok = true
case f.Name == "test_func" && f.Value.String() == "":
ok = true
}
if !ok {
t.Error("Visit: bad value", f.Value.String(), "for", f.Name)
}
}
}
VisitAll(visitor)
if len(m) != 8 {
if len(m) != 9 {
t.Error("VisitAll misses some flags")
for k, v := range m {
t.Log(k, *v)
Expand All @@ -80,9 +82,10 @@ func TestEverything(t *testing.T) {
Set("test_string", "1")
Set("test_float64", "1")
Set("test_duration", "1s")
Set("test_func", "1")
desired = "1"
Visit(visitor)
if len(m) != 8 {
if len(m) != 9 {
t.Error("Visit fails after set")
for k, v := range m {
t.Log(k, *v)
Expand Down Expand Up @@ -255,6 +258,48 @@ func TestUserDefined(t *testing.T) {
}
}

func TestUserDefinedFunc(t *testing.T) {
var flags FlagSet
flags.Init("test", ContinueOnError)
var ss []string
flags.Func("v", "usage", func(s string) error {
ss = append(ss, s)
return nil
})
if err := flags.Parse([]string{"-v", "1", "-v", "2", "-v=3"}); err != nil {
t.Error(err)
}
if len(ss) != 3 {
t.Fatal("expected 3 args; got ", len(ss))
}
expect := "[1 2 3]"
if got := fmt.Sprint(ss); got != expect {
t.Errorf("expected value %q got %q", expect, got)
}
// test usage
var buf strings.Builder
flags.SetOutput(&buf)
flags.Parse([]string{"-h"})
if usage := buf.String(); !strings.Contains(usage, "usage") {
t.Errorf("usage string not included: %q", usage)
}
// test Func error
flags = *NewFlagSet("test", ContinueOnError)
flags.Func("v", "usage", func(s string) error {
return fmt.Errorf("test error")
})
// flag not set, so no error
if err := flags.Parse(nil); err != nil {
t.Error(err)
}
// flag set, expect error
if err := flags.Parse([]string{"-v", "1"}); err == nil {
t.Error("expected error; got none")
} else if errMsg := err.Error(); !strings.Contains(errMsg, "test error") {
t.Errorf(`error should contain "test error"; got %q`, errMsg)
}
}

func TestUserDefinedForCommandLine(t *testing.T) {
const help = "HELP"
var result string
Expand Down Expand Up @@ -497,7 +542,7 @@ func TestGetters(t *testing.T) {
func TestParseError(t *testing.T) {
for _, typ := range []string{"bool", "int", "int64", "uint", "uint64", "float64", "duration"} {
fs := NewFlagSet("parse error test", ContinueOnError)
fs.SetOutput(ioutil.Discard)
fs.SetOutput(io.Discard)
_ = fs.Bool("bool", false, "")
_ = fs.Int("int", 0, "")
_ = fs.Int64("int64", 0, "")
Expand Down Expand Up @@ -528,7 +573,7 @@ func TestRangeError(t *testing.T) {
}
for _, arg := range bad {
fs := NewFlagSet("parse error test", ContinueOnError)
fs.SetOutput(ioutil.Discard)
fs.SetOutput(io.Discard)
_ = fs.Int("int", 0, "")
_ = fs.Int64("int64", 0, "")
_ = fs.Uint("uint", 0, "")
Expand All @@ -546,4 +591,151 @@ func TestRangeError(t *testing.T) {
}
}

// /* jnovack/flag cannot import TextExitCode because it relies on internal/testenv */
// jnovack/flag cannot import TextExitCode because it relies on internal/testenv
/*
func TestExitCode(t *testing.T) {
testenv.MustHaveExec(t)
magic := 123
if os.Getenv("GO_CHILD_FLAG") != "" {
fs := NewFlagSet("test", ExitOnError)
if os.Getenv("GO_CHILD_FLAG_HANDLE") != "" {
var b bool
fs.BoolVar(&b, os.Getenv("GO_CHILD_FLAG_HANDLE"), false, "")
}
fs.Parse([]string{os.Getenv("GO_CHILD_FLAG")})
os.Exit(magic)
}
tests := []struct {
flag string
flagHandle string
expectExit int
}{
{
flag: "-h",
expectExit: 0,
},
{
flag: "-help",
expectExit: 0,
},
{
flag: "-undefined",
expectExit: 2,
},
{
flag: "-h",
flagHandle: "h",
expectExit: magic,
},
{
flag: "-help",
flagHandle: "help",
expectExit: magic,
},
}
for _, test := range tests {
cmd := exec.Command(os.Args[0], "-test.run=TestExitCode")
cmd.Env = append(
os.Environ(),
"GO_CHILD_FLAG="+test.flag,
"GO_CHILD_FLAG_HANDLE="+test.flagHandle,
)
cmd.Run()
got := cmd.ProcessState.ExitCode()
// ExitCode is either 0 or 1 on Plan 9.
if runtime.GOOS == "plan9" && test.expectExit != 0 {
test.expectExit = 1
}
if got != test.expectExit {
t.Errorf("unexpected exit code for test case %+v \n: got %d, expect %d",
test, got, test.expectExit)
}
}
}
func mustPanic(t *testing.T, testName string, expected string, f func()) {
t.Helper()
defer func() {
switch msg := recover().(type) {
case nil:
t.Errorf("%s\n: expected panic(%q), but did not panic", testName, expected)
case string:
if msg != expected {
t.Errorf("%s\n: expected panic(%q), but got panic(%q)", testName, expected, msg)
}
default:
t.Errorf("%s\n: expected panic(%q), but got panic(%T%v)", testName, expected, msg, msg)
}
}()
f()
}
func TestInvalidFlags(t *testing.T) {
tests := []struct {
flag string
errorMsg string
}{
{
flag: "-foo",
errorMsg: "flag \"-foo\" begins with -",
},
{
flag: "foo=bar",
errorMsg: "flag \"foo=bar\" contains =",
},
}
for _, test := range tests {
testName := fmt.Sprintf("FlagSet.Var(&v, %q, \"\")", test.flag)
fs := NewFlagSet("", ContinueOnError)
buf := bytes.NewBuffer(nil)
fs.SetOutput(buf)
mustPanic(t, testName, test.errorMsg, func() {
var v flagVar
fs.Var(&v, test.flag, "")
})
if msg := test.errorMsg + "\n"; msg != buf.String() {
t.Errorf("%s\n: unexpected output: expected %q, bug got %q", testName, msg, buf)
}
}
}
func TestRedefinedFlags(t *testing.T) {
tests := []struct {
flagSetName string
errorMsg string
}{
{
flagSetName: "",
errorMsg: "flag redefined: foo",
},
{
flagSetName: "fs",
errorMsg: "fs flag redefined: foo",
},
}
for _, test := range tests {
testName := fmt.Sprintf("flag redefined in FlagSet(%q)", test.flagSetName)
fs := NewFlagSet(test.flagSetName, ContinueOnError)
buf := bytes.NewBuffer(nil)
fs.SetOutput(buf)
var v flagVar
fs.Var(&v, "foo", "")
mustPanic(t, testName, test.errorMsg, func() {
fs.Var(&v, "foo", "")
})
if msg := test.errorMsg + "\n"; msg != buf.String() {
t.Errorf("%s\n: unexpected output: expected %q, bug got %q", testName, msg, buf)
}
}
}
*/

0 comments on commit 802519a

Please sign in to comment.