Skip to content

Commit

Permalink
feat: add context-cancel support in callback-wait (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
spy16 authored Sep 26, 2022
1 parent ba6734b commit 5280baa
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
9 changes: 7 additions & 2 deletions oidc/cobra.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package oidc

import (
"context"
"os/signal"
"syscall"

"github.com/spf13/cobra"
"golang.org/x/oauth2"
Expand All @@ -12,16 +14,19 @@ func LoginCmd(cfg *oauth2.Config, aud, keyFilePath string, onTokenOrErr func(t *
Use: "login",
Short: "Login with your Google account.",
Run: func(cmd *cobra.Command, args []string) {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()

var ts oauth2.TokenSource
if keyFilePath != "" {
var err error
ts, err = NewGoogleServiceAccountTokenSource(context.Background(), keyFilePath, aud)
ts, err = NewGoogleServiceAccountTokenSource(ctx, keyFilePath, aud)
if err != nil {
onTokenOrErr(nil, err)
return
}
} else {
ts = NewTokenSource(context.Background(), cfg, aud)
ts = NewTokenSource(ctx, cfg, aud)
}
onTokenOrErr(ts.Token())
},
Expand Down
2 changes: 1 addition & 1 deletion oidc/source_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (source *authHandlerSource) Token() (*oauth2.Token, error) {
oauth2.SetAuthURLParam(codeChallengeMethodKey, challengeMethod),
)

code, receivedState, err := browserAuthzHandler(source.config.RedirectURL, url)
code, receivedState, err := browserAuthzHandler(source.ctx, source.config.RedirectURL, url)
if err != nil {
return nil, err
} else if receivedState != actualState {
Expand Down
17 changes: 12 additions & 5 deletions oidc/utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package oidc

import (
"context"
"crypto/rand"
_ "embed" // for embedded html
"encoding/base64"
Expand Down Expand Up @@ -56,7 +57,7 @@ func randomBytes(length int) ([]byte, error) {
}
}

func browserAuthzHandler(redirectURL, authCodeURL string) (code string, state string, err error) {
func browserAuthzHandler(ctx context.Context, redirectURL, authCodeURL string) (code string, state string, err error) {
if err := openURL(authCodeURL); err != nil {
return "", "", err
}
Expand All @@ -66,14 +67,14 @@ func browserAuthzHandler(redirectURL, authCodeURL string) (code string, state st
return "", "", err
}

code, state, err = waitForCallback(fmt.Sprintf(":%s", u.Port()))
code, state, err = waitForCallback(ctx, fmt.Sprintf(":%s", u.Port()))
if err != nil {
return "", "", err
}
return code, state, nil
}

func waitForCallback(addr string) (code, state string, err error) {
func waitForCallback(ctx context.Context, addr string) (code, state string, err error) {
var cb struct {
code string
state string
Expand Down Expand Up @@ -101,8 +102,14 @@ func waitForCallback(addr string) (code, state string, err error) {
}

go func() {
<-stopCh
_ = srv.Close()
select {
case <-stopCh:
_ = srv.Close()

case <-ctx.Done():
cb.err = ctx.Err()
_ = srv.Close()
}
}()

if serveErr := srv.ListenAndServe(); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
Expand Down

0 comments on commit 5280baa

Please sign in to comment.