diff --git a/cmd/generate.go b/cmd/generate.go index 8e7f69a..c94bbd7 100644 --- a/cmd/generate.go +++ b/cmd/generate.go @@ -13,6 +13,8 @@ import ( "github.com/spf13/cobra" ) +var generateArgMax = 1 + func init() { rootCmd.AddCommand(generateCmd) } @@ -117,6 +119,29 @@ func newGenerateCmd() *cobra.Command { return err }, + PreRunE: func(cmd *cobra.Command, args []string) error { + if len(args) > generateArgMax { + return errTooManyArguments + } + + // validate flags + _, err := cmd.Flags().GetString("output") + if err != nil { + return err + } + + formatString, err := cmd.Flags().GetString("format") + if err != nil { + return errors.Join(errMissingFormat, err) + } + + _, err = ParseFmtString(formatString) + if err != nil && formatString != "" { + return errors.Join(errInvalidFormat, err) + } + + return nil + }, } res.Version = VERSION diff --git a/cmd/generate_test.go b/cmd/generate_test.go new file mode 100644 index 0000000..3531466 --- /dev/null +++ b/cmd/generate_test.go @@ -0,0 +1,43 @@ +//go:build unit + +package cmd + +import ( + "errors" + "testing" +) + +func TestRunEGenerate(t *testing.T) { + type runTestGen struct { + flags []string + arguments []string + expectError error + } + + var testGenArgs = []runTestGen{ + { + flags: []string{}, + arguments: []string{"too", "many", "args"}, + expectError: errTooManyArguments, + }, + { + flags: []string{"--format", "bad_fmt"}, + arguments: []string{}, + expectError: errInvalidFormat, + }, + { + flags: []string{"-F", "bad_fmt"}, + arguments: []string{}, + expectError: errInvalidFormat, + }, + } + cmd := newGenerateCmd() + + for i, test := range testGenArgs { + cmd.ParseFlags(test.flags) + err := cmd.PreRunE(cmd, test.arguments) + if !errors.Is(err, test.expectError) { + t.Fatalf("case: %d, expectError: %v does not match err: %v", i, test.expectError, err) + } + } +}