From 792e87e7bdd25431d7529ad2534ff29323ae5302 Mon Sep 17 00:00:00 2001 From: zimbatm Date: Mon, 18 Nov 2024 13:12:15 +0100 Subject: [PATCH] feat: handle graceful shutdown Trap termination signals in the main program using context.Context. If a SIGINT or SIGTERM is received, send a SIGTERM to the sub-programs and wait for them to shutdown gracefully. If the sub-program is stuck for 5 minutes, force-kill them. --- command/cleanup.go | 44 ++++++++++++++++++++---- command/please.go | 5 ++- command/utils.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 10 deletions(-) diff --git a/command/cleanup.go b/command/cleanup.go index 2fbad44..b6ddfca 100644 --- a/command/cleanup.go +++ b/command/cleanup.go @@ -26,7 +26,7 @@ func CmdCleanup(c *cli.Context) (err error) { var toUndeploy []string - ctx := context.Background() + ctx := contextWithHandler() ghCli := githubClient(ctx, c) owner, repo := githubSlug(c) @@ -37,7 +37,7 @@ func CmdCleanup(c *cli.Context) (err error) { // undeploy all closed pull requests var deployed []string - deployed, err = listDeployedPullRequests(listScript) + deployed, err = listDeployedPullRequests(ctx, listScript) if err != nil { return err } @@ -86,13 +86,34 @@ func CmdCleanup(c *cli.Context) (err error) { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - err = cmd.Run() + err = cmd.Start() if err != nil { log.Println("undeploy error: ", err) lastErr = err - continue + select { + case <-ctx.Done(): + log.Println("undeploy cancelled: ", ctx.Err()) + return lastErr + default: + continue + } + } + + err = waitOrStop(ctx, cmd) + if err != nil { + log.Println("undeploy error: ", err) + + lastErr = err + + select { + case <-ctx.Done(): + log.Println("undeploy cancelled: ", ctx.Err()) + return lastErr + default: + continue + } } destroyGitHubDeployments(ctx, ghCli, owner, repo, pullRequestID, ignoreMissing) @@ -112,13 +133,22 @@ func contains(item string, list []string) bool { } // Get the list of deployed Pull request based on given script. -func listDeployedPullRequests(listScript string) ([]string, error) { - var stdout strings.Builder +func listDeployedPullRequests(ctx context.Context, listScript string) ([]string, error) { + var ( + stdout strings.Builder + err error + ) cmd := exec.Command(listScript) cmd.Stdout = &stdout - if err := cmd.Run(); err != nil { + err = cmd.Start() + if err != nil { + return nil, err + } + + err = waitOrStop(ctx, cmd) + if err != nil { return nil, err } diff --git a/command/please.go b/command/please.go index 9c0932b..eb8d8bc 100644 --- a/command/please.go +++ b/command/please.go @@ -2,7 +2,6 @@ package command import ( "bytes" - "context" "errors" "fmt" "io" @@ -89,7 +88,7 @@ func CmdPlease(c *cli.Context) (err error) { environment = fmt.Sprintf("pr-%d", pr) } - ctx := context.Background() + ctx := contextWithHandler() ghCli := githubClient(ctx, c) log.Println("deploy ref", ref) @@ -168,7 +167,7 @@ func CmdPlease(c *cli.Context) (err error) { } // Wait on the deploy to finish - err = cmd.Wait() + err = waitOrStop(ctx, cmd) if err != nil { err2 := updateStatus(StateFailure, "") if err2 != nil { diff --git a/command/utils.go b/command/utils.go index 321f581..a8c0467 100644 --- a/command/utils.go +++ b/command/utils.go @@ -3,7 +3,12 @@ package command import ( "context" "log" + "os" + "os/exec" + "os/signal" "regexp" + "syscall" + "time" "github.com/google/go-github/github" secretvalue "github.com/zimbatm/go-secretvalue" @@ -61,3 +66,82 @@ func refString(str string) *string { func refStringList(l []string) *[]string { return &l } + +var DefaultKillDelay = 5 * time.Minute + +// waitOrStop waits for the already-started command cmd by calling its Wait method. +// +// If cmd does not return before ctx is done, waitOrStop sends it the given interrupt signal. +// waitOrStop waits DefaultKillDelay for Wait to return before sending os.Kill. +// +// This function is copied from the one added to x/playground/internal in +// http://golang.org/cl/228438. +func waitOrStop(ctx context.Context, cmd *exec.Cmd) error { + if cmd.Process == nil { + panic("waitOrStop called with a nil cmd.Process — missing Start call?") + } + + errc := make(chan error) + go func() { + select { + case errc <- nil: + return + case <-ctx.Done(): + } + + err := cmd.Process.Signal(os.Interrupt) + if err == nil { + err = ctx.Err() // Report ctx.Err() as the reason we interrupted. + } else if err.Error() == "os: process already finished" { + errc <- nil + return + } + + if DefaultKillDelay > 0 { + timer := time.NewTimer(DefaultKillDelay) + select { + // Report ctx.Err() as the reason we interrupted the process... + case errc <- ctx.Err(): + timer.Stop() + return + // ...but after killDelay has elapsed, fall back to a stronger signal. + case <-timer.C: + } + + // Wait still hasn't returned. + // Kill the process harder to make sure that it exits. + // + // Ignore any error: if cmd.Process has already terminated, we still + // want to send ctx.Err() (or the error from the Interrupt call) + // to properly attribute the signal that may have terminated it. + _ = cmd.Process.Kill() + } + + errc <- err + }() + + waitErr := cmd.Wait() + if interruptErr := <-errc; interruptErr != nil { + return interruptErr + } + return waitErr +} + +// contextWithHandler returns a context that is canceled when the program receives a SIGINT or SIGTERM. +// +// !! Only call this function once per program +func contextWithHandler() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + + signalChan := make(chan os.Signal, 1) + + signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-signalChan + log.Printf("Received signal %s, stopping", sig) + cancel() + }() + + return ctx +}