diff --git a/autorest/adal/token.go b/autorest/adal/token.go index 07efaa402..1a9c8ab53 100644 --- a/autorest/adal/token.go +++ b/autorest/adal/token.go @@ -1104,8 +1104,8 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource // AAD returns expires_in as a string, ADFS returns it as an int ExpiresIn json.Number `json:"expires_in"` - // expires_on can be in two formats, a UTC time stamp or the number of seconds. - ExpiresOn string `json:"expires_on"` + // expires_on can be in three formats, a UTC time stamp, or the number of seconds as a string *or* int. + ExpiresOn interface{} `json:"expires_on"` NotBefore json.Number `json:"not_before"` Resource string `json:"resource"` @@ -1118,7 +1118,7 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource } expiresOn := json.Number("") // ADFS doesn't include the expires_on field - if token.ExpiresOn != "" { + if token.ExpiresOn != nil { if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil { return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp) } @@ -1135,18 +1135,27 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource } // converts expires_on to the number of seconds -func parseExpiresOn(s string) (json.Number, error) { - // convert the expiration date to the number of seconds from now +func parseExpiresOn(s interface{}) (json.Number, error) { + // the JSON unmarshaler treats JSON numbers unmarshaled into an interface{} as float64 + asFloat64, ok := s.(float64) + if ok { + // this is the number of seconds as int case + return json.Number(strconv.FormatInt(int64(asFloat64), 10)), nil + } + asStr, ok := s.(string) + if !ok { + return "", fmt.Errorf("unexpected expires_on type %T", s) + } + // convert the expiration date to the number of seconds from the unix epoch timeToDuration := func(t time.Time) json.Number { - dur := t.Sub(time.Now().UTC()) - return json.Number(strconv.FormatInt(int64(dur.Round(time.Second).Seconds()), 10)) + return json.Number(strconv.FormatInt(t.UTC().Unix(), 10)) } - if _, err := strconv.ParseInt(s, 10, 64); err == nil { + if _, err := json.Number(asStr).Int64(); err == nil { // this is the number of seconds case, no conversion required - return json.Number(s), nil - } else if eo, err := time.Parse(expiresOnDateFormatPM, s); err == nil { + return json.Number(asStr), nil + } else if eo, err := time.Parse(expiresOnDateFormatPM, asStr); err == nil { return timeToDuration(eo), nil - } else if eo, err := time.Parse(expiresOnDateFormat, s); err == nil { + } else if eo, err := time.Parse(expiresOnDateFormat, asStr); err == nil { return timeToDuration(eo), nil } else { // unknown format diff --git a/autorest/adal/token_test.go b/autorest/adal/token_test.go index 264ed749d..4c9ecf342 100644 --- a/autorest/adal/token_test.go +++ b/autorest/adal/token_test.go @@ -88,8 +88,7 @@ func TestTokenWillExpireIn(t *testing.T) { } func TestParseExpiresOn(t *testing.T) { - // get current time, round to nearest second, and add one hour - n := time.Now().UTC().Round(time.Second).Add(time.Hour) + n := time.Now().UTC() amPM := "AM" if n.Hour() >= 12 { amPM = "PM" @@ -107,12 +106,12 @@ func TestParseExpiresOn(t *testing.T) { { Name: "timestamp with AM/PM", String: fmt.Sprintf("%d/%d/%d %d:%02d:%02d %s +00:00", n.Month(), n.Day(), n.Year(), n.Hour(), n.Minute(), n.Second(), amPM), - Value: 3600, + Value: n.Unix(), }, { Name: "timestamp without AM/PM", - String: fmt.Sprintf("%d/%d/%d %d:%02d:%02d +00:00", n.Month(), n.Day(), n.Year(), n.Hour(), n.Minute(), n.Second()), - Value: 3600, + String: fmt.Sprintf("%02d/%02d/%02d %02d:%02d:%02d +00:00", n.Month(), n.Day(), n.Year(), n.Hour(), n.Minute(), n.Second()), + Value: n.Unix(), }, } for _, tc := range testcases { @@ -368,7 +367,8 @@ func TestServicePrincipalTokenFromASE(t *testing.T) { } spt.MaxMSIRefreshAttempts = 1 // expires_on is sent in UTC - expiresOn := time.Now().UTC().Add(time.Hour) + nowTime := time.Now() + expiresOn := nowTime.UTC().Add(time.Hour) // use int format for expires_in body := mocks.NewBody(newTokenJSON("3600", expiresOn.Format(expiresOnDateFormat), "test")) resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK") @@ -407,10 +407,8 @@ func TestServicePrincipalTokenFromASE(t *testing.T) { if err != nil { t.Fatalf("adal: failed to get ExpiresOn %v", err) } - // depending on elapsed time it might be slightly less that one hour - const hourInSeconds = int64(time.Hour / time.Second) - if v > hourInSeconds || v < hourInSeconds-1 { - t.Fatalf("adal: expected %v, got %v", int64(time.Hour/time.Second), v) + if nowAsUnix := nowTime.Add(time.Hour).Unix(); v != nowAsUnix { + t.Fatalf("adal: expected %v, got %v", nowAsUnix, v) } if body.IsOpen() { t.Fatalf("the response was not closed!") @@ -891,6 +889,34 @@ func TestServicePrincipalTokenEnsureFreshRefreshes(t *testing.T) { } } +func TestServicePrincipalTokenEnsureFreshWithIntExpiresOn(t *testing.T) { + spt := newServicePrincipalToken() + expireToken(&spt.inner.Token) + + body := mocks.NewBody(newTokenJSONIntExpiresOn(`"3600"`, 12345, "test")) + resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK") + + f := false + c := mocks.NewSender() + s := DecorateSender(c, + (func() SendDecorator { + return func(s Sender) Sender { + return SenderFunc(func(r *http.Request) (*http.Response, error) { + f = true + return resp, nil + }) + } + })()) + spt.SetSender(s) + err := spt.EnsureFresh() + if err != nil { + t.Fatalf("adal: ServicePrincipalToken#EnsureFresh returned an unexpected error (%v)", err) + } + if !f { + t.Fatal("adal: ServicePrincipalToken#EnsureFresh failed to call Refresh for stale token") + } +} + func TestServicePrincipalTokenEnsureFreshFails1(t *testing.T) { spt := newServicePrincipalToken() expireToken(&spt.inner.Token) @@ -1461,6 +1487,19 @@ func newTokenJSON(expiresIn, expiresOn, resource string) string { expiresIn, expiresOn, nb, resource) } +func newTokenJSONIntExpiresOn(expiresIn string, expiresOn int, resource string) string { + return fmt.Sprintf(`{ + "access_token" : "accessToken", + "expires_in" : %s, + "expires_on" : %d, + "not_before" : "%d", + "resource" : "%s", + "token_type" : "Bearer", + "refresh_token": "ABC123" + }`, + expiresIn, expiresOn, expiresOn, resource) +} + func newADFSTokenJSON(expiresIn int) string { return fmt.Sprintf(`{ "access_token" : "accessToken",