From 9ff0227e2e0b8fdf00cb87af1bfbeea0e915ac72 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Wed, 16 Dec 2020 16:35:15 -0800 Subject: [PATCH] Handle empty expires_on (#599) * Handle empty expires_on ADFS can return an empty expires_on, handle it. Return a TokenRefreshError when failing to parse a token response. * refactor --- autorest/adal/token.go | 12 ++++--- autorest/adal/token_test.go | 72 +++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 4 deletions(-) diff --git a/autorest/adal/token.go b/autorest/adal/token.go index 96c62efc3..addc91099 100644 --- a/autorest/adal/token.go +++ b/autorest/adal/token.go @@ -1006,13 +1006,17 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource Resource string `json:"resource"` Type string `json:"token_type"` }{} + // return a TokenRefreshError in the follow error cases as the token is in an unexpected format err = json.Unmarshal(rb, &token) if err != nil { - return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)) + return newTokenRefreshError(fmt.Sprintf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)), resp) } - expiresOn, err := parseExpiresOn(token.ExpiresOn) - if err != nil { - return err + expiresOn := json.Number("") + // ADFS doesn't include the expires_on field + if token.ExpiresOn != "" { + if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil { + return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp) + } } spt.inner.Token.AccessToken = token.AccessToken spt.inner.Token.RefreshToken = token.RefreshToken diff --git a/autorest/adal/token_test.go b/autorest/adal/token_test.go index 1367a81d2..25f6cf392 100644 --- a/autorest/adal/token_test.go +++ b/autorest/adal/token_test.go @@ -386,6 +386,69 @@ func TestServicePrincipalTokenFromASE(t *testing.T) { } } +func TestServicePrincipalTokenFromADFS(t *testing.T) { + os.Setenv("MSI_ENDPOINT", "http://localhost") + os.Setenv("MSI_SECRET", "super") + defer func() { + os.Unsetenv("MSI_ENDPOINT") + os.Unsetenv("MSI_SECRET") + }() + resource := "https://resource" + endpoint, _ := GetMSIEndpoint() + spt, err := NewServicePrincipalTokenFromMSI(endpoint, resource) + if err != nil { + t.Fatalf("Failed to get MSI SPT: %v", err) + } + spt.MaxMSIRefreshAttempts = 1 + const expiresIn = 3600 + body := mocks.NewBody(newADFSTokenJSON(expiresIn)) + resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK") + + c := mocks.NewSender() + s := DecorateSender(c, + (func() SendDecorator { + return func(s Sender) Sender { + return SenderFunc(func(r *http.Request) (*http.Response, error) { + if r.Method != "GET" { + t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method) + } + if h := r.Header.Get(metadataHeader); h != "" { + t.Fatalf("adal: ServicePrincipalToken#Refresh incorrectly set Metadata header for ASE") + } + if s := r.Header.Get(secretHeader); s != "super" { + t.Fatalf("adal: unexpected secret header value %s", s) + } + if r.URL.Host != "localhost" { + t.Fatalf("adal: unexpected host %s", r.URL.Host) + } + qp := r.URL.Query() + if api := qp.Get("api-version"); api != appServiceAPIVersion { + t.Fatalf("adal: unexpected api-version %s", api) + } + return resp, nil + }) + } + })()) + spt.SetSender(s) + err = spt.Refresh() + if err != nil { + t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err) + } + i, err := spt.inner.Token.ExpiresIn.Int64() + if err != nil { + t.Fatalf("unexpected parsing of expires_in: %v", err) + } + if i != expiresIn { + t.Fatalf("unexpected expires_in %d", i) + } + if spt.inner.Token.ExpiresOn.String() != "" { + t.Fatal("expected empty expires_on") + } + if body.IsOpen() { + t.Fatalf("the response was not closed!") + } +} + func TestServicePrincipalTokenFromMSIRefreshCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) endpoint, _ := GetMSIVMEndpoint() @@ -1284,6 +1347,15 @@ func newTokenJSON(expiresIn, expiresOn, resource string) string { expiresIn, expiresOn, nb, resource) } +func newADFSTokenJSON(expiresIn int) string { + return fmt.Sprintf(`{ + "access_token" : "accessToken", + "expires_in" : %d, + "token_type" : "Bearer" + }`, + expiresIn) +} + func newTokenExpiresIn(expireIn time.Duration) *Token { t := newToken() return setTokenToExpireIn(&t, expireIn)