Skip to content

Commit

Permalink
fix(execute): use fresh context for post-run hooks
Browse files Browse the repository at this point in the history
This change ensures that command post-run hooks receive a fresh context
instead of the original context passed to the `Execute` functions.

This is needed since often the context passed to the `Execute` functions
is canceled by the time the post-run hooks are ran - either by the OS
signal, deadlines, or other reasons.
  • Loading branch information
arikkfir committed Jul 12, 2024
1 parent a49ab26 commit 1b6608a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
3 changes: 2 additions & 1 deletion execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ func ExecuteWithContext(ctx context.Context, w io.Writer, root *Command, args []
// Ensure we invoke post-run hooks before we return
chain := cmd.getChain()
defer func() {
postHooksCtx := context.Background()
for i := len(chain) - 1; i >= 0; i-- {
c := chain[i]
for j := len(c.postRunHooks) - 1; j >= 0; j-- {
h := c.postRunHooks[j]
if err := h.PostRun(ctx, actionError, exitCode); err != nil {
if err := h.PostRun(postHooksCtx, actionError, exitCode); err != nil {
_, _ = fmt.Fprintln(w, err)
exitCode = ExitCodeError
}
Expand Down
28 changes: 26 additions & 2 deletions execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import (

type TrackingAction struct {
callTime *time.Time
providedCtx context.Context
errorToReturnOnCall error
}

func (a *TrackingAction) Run(_ context.Context) error {
func (a *TrackingAction) Run(ctx context.Context) error {
a.callTime = ptrOf(time.Now())
a.providedCtx = ctx
time.Sleep(100 * time.Millisecond)
return a.errorToReturnOnCall
}
Expand All @@ -36,13 +38,15 @@ func (a *TrackingPreRunHook) PreRun(_ context.Context) error {

type TrackingPostRunHook struct {
callTime *time.Time
providedCtx context.Context
providedActionError error
providedExitCode ExitCode
errorToReturnOnCall error
}

func (a *TrackingPostRunHook) PostRun(_ context.Context, actionError error, exitCode ExitCode) error {
func (a *TrackingPostRunHook) PostRun(ctx context.Context, actionError error, exitCode ExitCode) error {
a.callTime = ptrOf(time.Now())
a.providedCtx = ctx
a.providedActionError = actionError
a.providedExitCode = exitCode
time.Sleep(100 * time.Millisecond)
Expand Down Expand Up @@ -252,4 +256,24 @@ Flags:
With(t).Verify(action.TrackingAction.callTime).Will(Not(BeNil())).OrFail()
With(t).Verify(b.String()).Will(BeEmpty()).OrFail()
})

t.Run("ensure post-hooks use fresh context", func(t *testing.T) {
//nolint:all
executionCtx := context.WithValue(context.Background(), "k", "v")

action := &TrackingAction{}
root := MustNew("cmd", "desc", "long desc", action, []any{&PostRunHookWithConfig{}})

exitCode := ExecuteWithContext(executionCtx, os.Stderr, root, nil, nil)
With(t).Verify(exitCode).Will(EqualTo(ExitCodeSuccess)).OrFail()

if action.providedCtx != executionCtx {
t.Fatalf("incorrect context passed to action: %+v", action.providedCtx)
}

rootPostRunHook := root.postRunHooks[0].(*PostRunHookWithConfig)
if rootPostRunHook.providedCtx == executionCtx {
t.Fatalf("incorrect context passed to posthook: %+v", rootPostRunHook.providedCtx)
}
})
}

0 comments on commit 1b6608a

Please sign in to comment.