Skip to content

Commit

Permalink
feat: handle graceful shutdown
Browse files Browse the repository at this point in the history
Trap termination signals in the main program using context.Context.

If a SIGINT or SIGTERM is received, send a SIGTERM to the sub-programs
and wait for them to shut down gracefully.

If a sub-program is still running after 5 minutes, force-kill it.
  • Loading branch information
zimbatm committed Nov 18, 2024
1 parent ccc5e1c commit 26856b7
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 10 deletions.
46 changes: 39 additions & 7 deletions command/cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func CmdCleanup(c *cli.Context) (err error) {

var toUndeploy []string

ctx := context.Background()
ctx := contextWithHandler()
ghCli := githubClient(ctx, c)
owner, repo := githubSlug(c)

Expand All @@ -37,7 +37,7 @@ func CmdCleanup(c *cli.Context) (err error) {
// undeploy all closed pull requests
var deployed []string

deployed, err = listDeployedPullRequests(listScript)
deployed, err = listDeployedPullRequests(ctx, listScript)
if err != nil {
return err
}
Expand Down Expand Up @@ -86,13 +86,36 @@ func CmdCleanup(c *cli.Context) (err error) {
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

err = cmd.Run()
err = cmd.Start()
if err != nil {
log.Println("undeploy error: ", err)

lastErr = err

continue
select {
case <-ctx.Done():
log.Println("undeploy cancelled: ", ctx.Err())

return lastErr
default:
continue
}
}

err = waitOrStop(ctx, cmd)
if err != nil {
log.Println("undeploy error: ", err)

lastErr = err

select {
case <-ctx.Done():
log.Println("undeploy cancelled: ", ctx.Err())

return lastErr
default:
continue
}
}

destroyGitHubDeployments(ctx, ghCli, owner, repo, pullRequestID, ignoreMissing)
Expand All @@ -112,13 +135,22 @@ func contains(item string, list []string) bool {
}

// Get the list of deployed Pull request based on given script.
func listDeployedPullRequests(listScript string) ([]string, error) {
var stdout strings.Builder
func listDeployedPullRequests(ctx context.Context, listScript string) ([]string, error) {
var (
stdout strings.Builder
err error
)

cmd := exec.Command(listScript)
cmd.Stdout = &stdout

if err := cmd.Run(); err != nil {
err = cmd.Start()
if err != nil {
return nil, err
}

err = waitOrStop(ctx, cmd)
if err != nil {
return nil, err
}

Expand Down
5 changes: 2 additions & 3 deletions command/please.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package command

import (
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -89,7 +88,7 @@ func CmdPlease(c *cli.Context) (err error) {
environment = fmt.Sprintf("pr-%d", pr)
}

ctx := context.Background()
ctx := contextWithHandler()
ghCli := githubClient(ctx, c)

log.Println("deploy ref", ref)
Expand Down Expand Up @@ -168,7 +167,7 @@ func CmdPlease(c *cli.Context) (err error) {
}

// Wait on the deploy to finish
err = cmd.Wait()
err = waitOrStop(ctx, cmd)
if err != nil {
err2 := updateStatus(StateFailure, "")
if err2 != nil {
Expand Down
89 changes: 89 additions & 0 deletions command/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ package command
import (
"context"
"errors"
"log"
"os"
"os/exec"
"os/signal"
"regexp"
"syscall"
"time"

"github.com/google/go-github/github"
secretvalue "github.com/zimbatm/go-secretvalue"
Expand Down Expand Up @@ -61,3 +66,87 @@ func refString(str string) *string {
func refStringList(l []string) *[]string {
return &l
}

// DefaultKillDelay is the grace period waitOrStop gives an interrupted command
// to exit on its own before the process is force-killed.
//
// A value of zero or less disables the forced kill: the command is only
// interrupted and waitOrStop keeps waiting for it to exit.
var DefaultKillDelay = 5 * time.Minute

// waitOrStop waits for the already-started command cmd by calling its Wait method.
//
// If cmd does not return before ctx is done, waitOrStop sends it os.Interrupt.
// It then waits DefaultKillDelay for Wait to return before sending os.Kill.
//
// When the command exits without having been interrupted, the result of Wait
// is returned. Once the command has been interrupted, ctx.Err() is returned
// instead so that callers can attribute the exit to the cancellation. If the
// interrupt could not be delivered and the command had to be killed (or the
// kill delay is disabled), the error from the Signal call is returned.
//
// waitOrStop panics if cmd has not been started.
//
// This function is copied from the one added to x/playground/internal in
// http://golang.org/cl/228438.
func waitOrStop(ctx context.Context, cmd *exec.Cmd) error {
	if cmd.Process == nil {
		panic("waitOrStop called with a nil cmd.Process — missing Start call?")
	}

	// errc is unbuffered on purpose: a send can only complete once the code
	// below has returned from cmd.Wait and is receiving, so "the send went
	// through" doubles as "the process has been reaped".
	errc := make(chan error)
	go func() {
		select {
		case errc <- nil:
			// Wait returned before ctx was done; nothing to interrupt.
			return
		case <-ctx.Done():
		}

		err := cmd.Process.Signal(os.Interrupt)
		switch {
		case err == nil:
			err = ctx.Err() // Report ctx.Err() as the reason we interrupted.
		case errors.Is(err, os.ErrProcessDone):
			// The process exited on its own before the signal could be
			// delivered: let the caller see the plain Wait result.
			errc <- nil

			return
		}

		if DefaultKillDelay > 0 {
			timer := time.NewTimer(DefaultKillDelay)
			select {
			// Report ctx.Err() as the reason we interrupted the process...
			case errc <- ctx.Err():
				timer.Stop()

				return
			// ...but after the kill delay has elapsed, fall back to a stronger signal.
			case <-timer.C:
			}

			// Wait still hasn't returned.
			// Kill the process harder to make sure that it exits.
			//
			// Ignore any error: if cmd.Process has already terminated, we still
			// want to send ctx.Err() (or the error from the Interrupt call)
			// to properly attribute the signal that may have terminated it.
			_ = cmd.Process.Kill()
		}

		errc <- err
	}()

	waitErr := cmd.Wait()

	// The goroutine always sends exactly one value, so this receive both
	// collects the interrupt outcome and lets the goroutine finish.
	if interruptErr := <-errc; interruptErr != nil {
		return interruptErr
	}

	return waitErr
}

// contextWithHandler builds a context derived from context.Background that is
// canceled as soon as the program receives its first SIGINT or SIGTERM. The
// received signal is logged before the context is canceled.
//
// The signal registration is never removed and the watcher goroutine only
// handles a single signal, so:
//
// !! Only call this function once per program.
func contextWithHandler() context.Context {
	// Buffer of one so that a signal arriving before the goroutine is
	// scheduled is not dropped.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

	ctx, cancel := context.WithCancel(context.Background())

	go func(stop context.CancelFunc, incoming <-chan os.Signal) {
		received := <-incoming
		log.Printf("Received signal %s, stopping", received)
		stop()
	}(cancel, signals)

	return ctx
}

0 comments on commit 26856b7

Please sign in to comment.