Skip to content

Commit

Permalink
feat(presubcommandrun): pre-run hook for parent commands (#4)
Browse files Browse the repository at this point in the history
This feature allows parent commands to execute a hook just before a
sub-command is run. The hooks are run in out-to-inner order - like so:

Hook run order: root -> sub1 -> sub2
Command run: sub2 only
  • Loading branch information
arikkfir authored May 16, 2024
1 parent b624931 commit d71d3a0
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 50 deletions.
119 changes: 75 additions & 44 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"flag"
"fmt"
"io"
"os"
"reflect"
"regexp"
"slices"
Expand All @@ -18,10 +17,7 @@ import (
var Version = "0.0.0-unknown"

var (
tokenRE = regexp.MustCompile(`^([^=]+)=(.*)$`)
builtinConfig = &BuiltinConfig{
Help: false,
}
tokenRE = regexp.MustCompile(`^([^=]+)=(.*)$`)
)

func New(parent *Command, spec Spec) *Command {
Expand All @@ -33,12 +29,14 @@ func New(parent *Command, spec Spec) *Command {
ShortDescription: spec.ShortDescription,
LongDescription: spec.LongDescription,
Config: spec.Config,
PreSubCommandRun: spec.OnSubCommandRun,
Run: spec.Run,
Parent: parent,
builtinConfig: &BuiltinConfig{Help: false},
parent: parent,
createdByNewCommand: true,
}
if cmd.Parent != nil {
cmd.Parent.subCommands = append(cmd.Parent.subCommands, cmd)
if cmd.parent != nil {
cmd.parent.subCommands = append(cmd.parent.subCommands, cmd)
}
return cmd
}
Expand All @@ -48,17 +46,20 @@ type Spec struct {
ShortDescription string
LongDescription string
Config any
OnSubCommandRun func(ctx context.Context, config any, usagePrinter UsagePrinter) error
Run func(ctx context.Context, config any, usagePrinter UsagePrinter) error
}

type Command struct {
Name string
ShortDescription string
LongDescription string
Parent *Command
subCommands []*Command
Config any
PreSubCommandRun func(ctx context.Context, config any, usagePrinter UsagePrinter) error
Run func(ctx context.Context, config any, usagePrinter UsagePrinter) error
builtinConfig any
parent *Command
subCommands []*Command
createdByNewCommand bool
envVarsMapping map[string]reflect.Value
flagSet *flag.FlagSet
Expand Down Expand Up @@ -91,12 +92,12 @@ func (c *Command) initializeFlagSet() error {

// Create a flag set
name := c.Name
for parent := c.Parent; parent != nil; parent = parent.Parent {
for parent := c.parent; parent != nil; parent = parent.parent {
name = parent.Name + " " + name
}
c.flagSet = flag.NewFlagSet(name, flag.ContinueOnError)
c.flagSet.SetOutput(io.Discard)
if err := c.initializeFlagSetFromStruct(reflect.ValueOf(builtinConfig).Elem()); err != nil {
if err := c.initializeFlagSetFromStruct(reflect.ValueOf(c.builtinConfig).Elem()); err != nil {
return fmt.Errorf("failed to process builtin configuration fields: %w", err)
}

Expand Down Expand Up @@ -261,19 +262,17 @@ func (c *Command) applyEnvironmentVariables(envVars map[string]string) error {
return nil
}

func (c *Command) configure(envVars map[string]string, args []string) error {
func (c *Command) applyCLIArguments(args []string) error {

// Apply environment variables first
if err := c.applyEnvironmentVariables(envVars); err != nil {
return fmt.Errorf("failed to apply environment variables: %w", err)
}

// Override with CLI arguments
// Update config with CLI arguments
if err := c.flagSet.Parse(args); err != nil {
return fmt.Errorf("failed to apply CLI arguments: %w", err)
}

// Ensure all required flags have been provided via either CLI or via environment variables
return nil
}

func (c *Command) validateRequiredFlagsWereProvided(envVars map[string]string) error {
var missingRequiredFlags []string
copy(missingRequiredFlags, c.requiredFlags)
c.flagSet.Visit(func(f *flag.Flag) {
Expand All @@ -289,20 +288,40 @@ func (c *Command) configure(envVars map[string]string, args []string) error {
})
}
}
if len(missingRequiredFlags) > 0 {
return fmt.Errorf("these required flags have not set via either CLI nor environment variables: %v", missingRequiredFlags)
}
return nil
}

func (c *Command) configure(envVars map[string]string, args []string) error {

// Initialize the flagSet for the chosen command
if err := c.initializeFlagSet(); err != nil {
panic(fmt.Sprintf("failed to initialize flag set for command '%s': %v", c.Name, err))
}

// Apply environment variables first
if err := c.applyEnvironmentVariables(envVars); err != nil {
return fmt.Errorf("failed to apply environment variables: %w", err)
}

// Override with CLI arguments
if err := c.flagSet.Parse(args); err != nil {
return fmt.Errorf("failed to apply CLI arguments: %w", err)
}

// Apply positional arguments
if c.positionalArgsTarget != nil {
*c.positionalArgsTarget = c.flagSet.Args()
}
if len(missingRequiredFlags) > 0 {
return fmt.Errorf("these required flags have not set via either CLI nor environment variables: %v", missingRequiredFlags)
}

return nil
}

func (c *Command) printCommandUsage(w io.Writer, short bool) {
cmdChain := c.Name
for cmd := c.Parent; cmd != nil; cmd = cmd.Parent {
for cmd := c.parent; cmd != nil; cmd = cmd.parent {
cmdChain = cmd.Name + " " + cmdChain
}

Expand Down Expand Up @@ -392,10 +411,9 @@ func (c *Command) printCommandUsage(w io.Writer, short bool) {
}
}

//goland:noinspection GoUnusedExportedFunction
func Execute(root *Command, args []string, envVars map[string]string) {
func Execute(ctx context.Context, w io.Writer, root *Command, args []string, envVars map[string]string) (exitCode int) {
if !root.createdByNewCommand {
panic("illegal root command was specified - was it created by 'command.New(...)'?")
panic("invalid root command given, indicating it may not have been created by 'command.New(...)'")
}

// Iterate CLI args, separate them to flags & positional args, but also infer the command to execute from the given
Expand All @@ -411,31 +429,44 @@ func Execute(root *Command, args []string, envVars map[string]string) {
// positional args: [something, sub3, a, b, c]: no "cmd1", "sub1" and "sub2" as they are commands in the hierarchy
cmd, flagArgs, positionalArgs := inferCommandFlagsAndPositionals(root, args)

// Initialize the flagSet for the chosen command
if err := cmd.initializeFlagSet(); err != nil {
panic(fmt.Sprintf("failed to initialize flag set for command '%s': %v", cmd.Name, err))
// Build the command chain from top-to-bottom (so index 0 is the root)
commandChain := []*Command{cmd}
parent := cmd.parent
for parent != nil {
commandChain = append([]*Command{parent}, commandChain...)
parent = parent.parent
}

// Parse the arguments as returned in the parsing step
if err := cmd.configure(envVars, append(flagArgs, positionalArgs...)); err != nil {
cmd.PrintShortUsage(os.Stderr)
os.Exit(1)
} else if cmd.flagSet.Lookup("help").Value.String() == "true" {
cmd.PrintFullUsage(os.Stderr)
os.Exit(0)
// Configure commands up the chain, in order to invoke their "PreSubCommandRun" function
for _, current := range commandChain {
if err := current.configure(envVars, append(flagArgs, positionalArgs...)); err != nil {
current.PrintShortUsage(w)
return 1
}

if err := current.PreSubCommandRun(ctx, current.Config, current); err != nil {
_, _ = fmt.Fprintln(w, err.Error())
return 1
}
}

// If "--help" was provided, show usage and exit immediately
if cmd.flagSet.Lookup("help").Value.String() == "true" {
cmd.PrintFullUsage(w)
return 0
}

// If command has no "Run" function, it's an intermediate probably - just print its usage and exit successfully
if cmd.Run == nil {
cmd.PrintFullUsage(os.Stderr)
cmd.PrintFullUsage(w)
return 0
}

// Run the command with a fresh context
ctx, cancel := context.WithCancel(SetupSignalHandler())
defer cancel()
// Run the command
if err := cmd.Run(ctx, cmd.Config, cmd); err != nil {
cancel() // os.Exit might not invoke the deferred cancel call
_, _ = fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
_, _ = fmt.Fprintln(w, err.Error())
return 1
}

return 0
}
Loading

0 comments on commit d71d3a0

Please sign in to comment.