diff --git a/oidc/cobra.go b/oidc/cobra.go index 37c6882..b722912 100644 --- a/oidc/cobra.go +++ b/oidc/cobra.go @@ -2,6 +2,8 @@ package oidc import ( "context" + "os/signal" + "syscall" "github.com/spf13/cobra" "golang.org/x/oauth2" @@ -12,16 +14,19 @@ func LoginCmd(cfg *oauth2.Config, aud, keyFilePath string, onTokenOrErr func(t * Use: "login", Short: "Login with your Google account.", Run: func(cmd *cobra.Command, args []string) { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + var ts oauth2.TokenSource if keyFilePath != "" { var err error - ts, err = NewGoogleServiceAccountTokenSource(context.Background(), keyFilePath, aud) + ts, err = NewGoogleServiceAccountTokenSource(ctx, keyFilePath, aud) if err != nil { onTokenOrErr(nil, err) return } } else { - ts = NewTokenSource(context.Background(), cfg, aud) + ts = NewTokenSource(ctx, cfg, aud) } onTokenOrErr(ts.Token()) }, diff --git a/oidc/source_oidc.go b/oidc/source_oidc.go index d78653c..720c4b0 100644 --- a/oidc/source_oidc.go +++ b/oidc/source_oidc.go @@ -57,7 +57,7 @@ func (source *authHandlerSource) Token() (*oauth2.Token, error) { oauth2.SetAuthURLParam(codeChallengeMethodKey, challengeMethod), ) - code, receivedState, err := browserAuthzHandler(source.config.RedirectURL, url) + code, receivedState, err := browserAuthzHandler(source.ctx, source.config.RedirectURL, url) if err != nil { return nil, err } else if receivedState != actualState { diff --git a/oidc/utils.go b/oidc/utils.go index 7362bdd..87abef1 100644 --- a/oidc/utils.go +++ b/oidc/utils.go @@ -1,6 +1,7 @@ package oidc import ( + "context" "crypto/rand" _ "embed" // for embedded html "encoding/base64" @@ -56,7 +57,7 @@ func randomBytes(length int) ([]byte, error) { } } -func browserAuthzHandler(redirectURL, authCodeURL string) (code string, state string, err error) { +func browserAuthzHandler(ctx context.Context, redirectURL, authCodeURL string) (code string, state string, err error) { if err := openURL(authCodeURL); err != nil { return "", "", err } @@ -66,14 +67,14 @@ func browserAuthzHandler(redirectURL, authCodeURL string) (code string, state st return "", "", err } - code, state, err = waitForCallback(fmt.Sprintf(":%s", u.Port())) + code, state, err = waitForCallback(ctx, fmt.Sprintf(":%s", u.Port())) if err != nil { return "", "", err } return code, state, nil } -func waitForCallback(addr string) (code, state string, err error) { +func waitForCallback(ctx context.Context, addr string) (code, state string, err error) { var cb struct { code string state string @@ -101,8 +102,14 @@ func waitForCallback(addr string) (code, state string, err error) { } go func() { - <-stopCh - _ = srv.Close() + select { + case <-stopCh: + _ = srv.Close() + + case <-ctx.Done(): + cb.err = ctx.Err() + _ = srv.Close() + } }() if serveErr := srv.ListenAndServe(); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {