Skip to content

Commit

Permalink
Handle empty expires_on (#599)
Browse files Browse the repository at this point in the history
* Handle empty expires_on

ADFS can return an empty expires_on, handle it.
Return a TokenRefreshError when failing to parse a token response.

* refactor
  • Loading branch information
jhendrixMSFT authored Dec 17, 2020
1 parent 27ede5f commit 9ff0227
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 4 deletions.
12 changes: 8 additions & 4 deletions autorest/adal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -1006,13 +1006,17 @@ func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource
Resource string `json:"resource"`
Type string `json:"token_type"`
}{}
// return a TokenRefreshError in the follow error cases as the token is in an unexpected format
err = json.Unmarshal(rb, &token)
if err != nil {
return fmt.Errorf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb))
return newTokenRefreshError(fmt.Sprintf("adal: Failed to unmarshal the service principal token during refresh. Error = '%v' JSON = '%s'", err, string(rb)), resp)
}
expiresOn, err := parseExpiresOn(token.ExpiresOn)
if err != nil {
return err
expiresOn := json.Number("")
// ADFS doesn't include the expires_on field
if token.ExpiresOn != "" {
if expiresOn, err = parseExpiresOn(token.ExpiresOn); err != nil {
return newTokenRefreshError(fmt.Sprintf("adal: failed to parse expires_on: %v value '%s'", err, token.ExpiresOn), resp)
}
}
spt.inner.Token.AccessToken = token.AccessToken
spt.inner.Token.RefreshToken = token.RefreshToken
Expand Down
72 changes: 72 additions & 0 deletions autorest/adal/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,69 @@ func TestServicePrincipalTokenFromASE(t *testing.T) {
}
}

func TestServicePrincipalTokenFromADFS(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "http://localhost")
os.Setenv("MSI_SECRET", "super")
defer func() {
os.Unsetenv("MSI_ENDPOINT")
os.Unsetenv("MSI_SECRET")
}()
resource := "https://resource"
endpoint, _ := GetMSIEndpoint()
spt, err := NewServicePrincipalTokenFromMSI(endpoint, resource)
if err != nil {
t.Fatalf("Failed to get MSI SPT: %v", err)
}
spt.MaxMSIRefreshAttempts = 1
const expiresIn = 3600
body := mocks.NewBody(newADFSTokenJSON(expiresIn))
resp := mocks.NewResponseWithBodyAndStatus(body, http.StatusOK, "OK")

c := mocks.NewSender()
s := DecorateSender(c,
(func() SendDecorator {
return func(s Sender) Sender {
return SenderFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != "GET" {
t.Fatalf("adal: ServicePrincipalToken#Refresh did not correctly set HTTP method -- expected %v, received %v", "GET", r.Method)
}
if h := r.Header.Get(metadataHeader); h != "" {
t.Fatalf("adal: ServicePrincipalToken#Refresh incorrectly set Metadata header for ASE")
}
if s := r.Header.Get(secretHeader); s != "super" {
t.Fatalf("adal: unexpected secret header value %s", s)
}
if r.URL.Host != "localhost" {
t.Fatalf("adal: unexpected host %s", r.URL.Host)
}
qp := r.URL.Query()
if api := qp.Get("api-version"); api != appServiceAPIVersion {
t.Fatalf("adal: unexpected api-version %s", api)
}
return resp, nil
})
}
})())
spt.SetSender(s)
err = spt.Refresh()
if err != nil {
t.Fatalf("adal: ServicePrincipalToken#Refresh returned an unexpected error (%v)", err)
}
i, err := spt.inner.Token.ExpiresIn.Int64()
if err != nil {
t.Fatalf("unexpected parsing of expires_in: %v", err)
}
if i != expiresIn {
t.Fatalf("unexpected expires_in %d", i)
}
if spt.inner.Token.ExpiresOn.String() != "" {
t.Fatal("expected empty expires_on")
}
if body.IsOpen() {
t.Fatalf("the response was not closed!")
}
}

func TestServicePrincipalTokenFromMSIRefreshCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
endpoint, _ := GetMSIVMEndpoint()
Expand Down Expand Up @@ -1284,6 +1347,15 @@ func newTokenJSON(expiresIn, expiresOn, resource string) string {
expiresIn, expiresOn, nb, resource)
}

func newADFSTokenJSON(expiresIn int) string {
return fmt.Sprintf(`{
"access_token" : "accessToken",
"expires_in" : %d,
"token_type" : "Bearer"
}`,
expiresIn)
}

func newTokenExpiresIn(expireIn time.Duration) *Token {
t := newToken()
return setTokenToExpireIn(&t, expireIn)
Expand Down

0 comments on commit 9ff0227

Please sign in to comment.