diff --git a/api.go b/api.go index 2e707dd..3fb501a 100644 --- a/api.go +++ b/api.go @@ -165,6 +165,7 @@ func (a *API) validateClientCredentials() error { ClientSecret: a.clientSecret, TokenURL: tokenURL.String(), EndpointParams: v, + AuthStyle: oauth2.AuthStyleInHeader, } a.clientCredentialsConfig = c a.AuthenticatedClient = c.Client(context.WithValue(context.Background(), oauth2.HTTPClient, a.UnauthenticatedClient)) @@ -304,7 +305,7 @@ func (a *API) Token(ctx context.Context) (*oauth2.Token, error) { // NewWithAuthorizationCode builds an API that uses the authorization code // grant to get a token for use with the UAA API. func NewWithAuthorizationCode(target string, zoneID string, clientID string, clientSecret string, authorizationCode string, tokenFormat TokenFormat, skipSSLValidation bool) (*API, error) { - a := New(target, zoneID).WithAuthorizationCode(clientID, clientSecret, authorizationCode, tokenFormat).WithSkipSSLValidation(skipSSLValidation) + a := New(target, zoneID).WithSkipSSLValidation(skipSSLValidation).WithAuthorizationCode(clientID, clientSecret, authorizationCode, tokenFormat) err := a.Validate() if err != nil { return nil, err @@ -332,7 +333,8 @@ func (a *API) validateAuthorizationCode() error { ClientID: a.clientID, ClientSecret: a.clientSecret, Endpoint: oauth2.Endpoint{ - TokenURL: tokenURL.String(), + TokenURL: tokenURL.String(), + AuthStyle: oauth2.AuthStyleInHeader, }, } a.oauthConfig = c @@ -356,7 +358,7 @@ func (a *API) validateAuthorizationCode() error { // NewWithRefreshToken builds an API that uses the given refresh token to get an // access token for use with the UAA API. func NewWithRefreshToken(target string, zoneID string, clientID string, clientSecret string, refreshToken string, tokenFormat TokenFormat, skipSSLValidation bool) (*API, error) { - a := New(target, zoneID).WithRefreshToken(clientID, clientSecret, refreshToken, tokenFormat).WithSkipSSLValidation(skipSSLValidation) + a := New(target, zoneID).WithSkipSSLValidation(skipSSLValidation).WithRefreshToken(clientID, clientSecret, refreshToken, tokenFormat) err := a.Validate() if err != nil { return nil, err @@ -387,7 +389,8 @@ func (a *API) validateRefreshToken() error { ClientID: a.clientID, ClientSecret: a.clientSecret, Endpoint: oauth2.Endpoint{ - TokenURL: tokenURL.String(), + TokenURL: tokenURL.String(), + AuthStyle: oauth2.AuthStyleInHeader, }, } a.oauthConfig = c diff --git a/api_test.go b/api_test.go index bebbac5..97662db 100644 --- a/api_test.go +++ b/api_test.go @@ -2,17 +2,16 @@ package uaa_test import ( "context" - "encoding/json" + "fmt" "net/http" "net/http/httptest" "reflect" "testing" "time" - "io/ioutil" - uaa "github.com/cloudfoundry-community/go-uaa" . "github.com/onsi/gomega" + "github.com/onsi/gomega/ghttp" "github.com/sclevine/spec" "golang.org/x/oauth2" ) @@ -156,31 +155,23 @@ func testNew(t *testing.T, when spec.G, it spec.S) { when("the server returns tokens", func() { var ( - s *httptest.Server - returnToken bool - callCount int + s *ghttp.Server ) it.Before(func() { - returnToken = true - s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - callCount = callCount + 1 - - w.Header().Set("Content-Type", "application/json") - - t := &oauth2.Token{ - AccessToken: "test-access-token", - RefreshToken: "test-refresh-token", - TokenType: "bearer", - Expiry: time.Now().Add(60 * time.Second), - } - if !returnToken { - t = nil - } - w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(t) - Expect(err).NotTo(HaveOccurred()) - })) + s = ghttp.NewServer() + t := &oauth2.Token{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + TokenType: "bearer", + Expiry: time.Now().Add(60 * time.Second), + } + s.AppendHandlers(ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", "/oauth/token"), + ghttp.VerifyFormKV("grant_type", "client_credentials"), + ghttp.VerifyFormKV("token_format", "opaque"), + ghttp.RespondWithJSONEncoded(http.StatusOK, t), + )) }) it.After(func() { @@ -190,7 +181,7 @@ func testNew(t *testing.T, when spec.G, it spec.S) { }) it("Token() succeeds when the mode is client credentials and the client credentials are valid", func() { - api := uaa.New(s.URL, "") + api := uaa.New(s.URL(), "") Expect(api).NotTo(BeNil()) api.TargetURL = nil api = api.WithClientCredentials("client-id", "client-secret", uaa.OpaqueToken) @@ -226,222 +217,252 @@ func testNew(t *testing.T, when spec.G, it spec.S) { }) when("NewWithAuthorizationCode", func() { - var ( - s *httptest.Server - returnToken bool - reqBody []byte - callCount int - ) + var s *ghttp.Server + + stubTokenRequest := func(clientId string, clientSecret string, authCode string, tokenFormat uaa.TokenFormat, response http.HandlerFunc) { + s.AppendHandlers(ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", "/oauth/token"), + ghttp.VerifyFormKV("grant_type", "authorization_code"), + ghttp.VerifyFormKV("code", authCode), + ghttp.VerifyFormKV("token_format", tokenFormat.String()), + response, + )) + } + + stubTokenSuccess := func(clientId string, clientSecret string, authCode string, tokenFormat uaa.TokenFormat) { + t := &oauth2.Token{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + TokenType: "bearer", + Expiry: time.Now().Add(60 * time.Second), + } - it.Before(func() { - returnToken = true - s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - callCount = callCount + 1 - var err error - reqBody, err = ioutil.ReadAll(req.Body) - Expect(err).NotTo(HaveOccurred()) + stubTokenRequest(clientId, clientSecret, authCode, tokenFormat, ghttp.RespondWithJSONEncoded(http.StatusOK, t)) + } - w.Header().Set("Content-Type", "application/json") + stubMalformedTokenSuccess := func(clientId string, clientSecret string, authCode string, tokenFormat uaa.TokenFormat) { + stubTokenRequest(clientId, clientSecret, authCode, tokenFormat, ghttp.RespondWithJSONEncoded(http.StatusOK, nil)) + } - t := &oauth2.Token{ - AccessToken: "test-access-token", - RefreshToken: "test-refresh-token", - TokenType: "bearer", - Expiry: time.Now().Add(60 * time.Second), - } - if !returnToken { - t = nil - } - w.WriteHeader(http.StatusOK) - err = json.NewEncoder(w).Encode(t) - Expect(err).NotTo(HaveOccurred()) - })) + stubTokenFailure := func(clientId string, clientSecret string, authCode string, tokenFormat uaa.TokenFormat) { + stubTokenRequest(clientId, clientSecret, authCode, tokenFormat, ghttp.RespondWithJSONEncoded(http.StatusBadRequest, nil)) + } + + it.Before(func() { + s = ghttp.NewServer() }) it.After(func() { - if s != nil { - s.Close() - } + s.Close() }) - it("fails if the target url is invalid", func() { - api, err := uaa.NewWithAuthorizationCode("(*#&^@%$&%)", "", "", "", "", uaa.OpaqueToken, false) - Expect(err).To(HaveOccurred()) - Expect(api).To(BeNil()) - }) + when("success", func() { + it.Before(func() { + // Token retrieval is done as part of validateAuthorizationCode + // validateAuthorizationCode is called two times on construction + // AuthStyle is set to AuthStyleInHeader, failed token requests are not retried + // Because the first token reqest succeeds, later token attempts are skipped + // 1 token request, 1 attempt each => 1 request + stubTokenSuccess("client-id", "client-secret", "auth-code", uaa.OpaqueToken) + }) - it("returns an API with a TargetURL", func() { - api, err := uaa.NewWithAuthorizationCode(s.URL, "", "", "", "", uaa.OpaqueToken, false) - Expect(err).NotTo(HaveOccurred()) - Expect(api).NotTo(BeNil()) - Expect(api.TargetURL.String()).To(Equal(s.URL)) - Expect(callCount).To(Equal(1)) - }) + it("returns an API with a TargetURL", func() { + api, err := uaa.NewWithAuthorizationCode(s.URL(), "", "client-id", "client-secret", "auth-code", uaa.OpaqueToken, false) + Expect(err).NotTo(HaveOccurred()) + Expect(api).NotTo(BeNil()) + Expect(api.TargetURL.String()).To(Equal(s.URL())) + }) - it("returns an API with an HTTPClient", func() { - api, err := uaa.NewWithAuthorizationCode(s.URL, "", "", "", "", uaa.OpaqueToken, false) - Expect(err).NotTo(HaveOccurred()) - Expect(api).NotTo(BeNil()) - Expect(api.AuthenticatedClient).NotTo(BeNil()) + it("returns an API with an HTTPClient", func() { + api, err := uaa.NewWithAuthorizationCode(s.URL(), "", "client-id", "client-secret", "auth-code", uaa.OpaqueToken, false) + Expect(err).NotTo(HaveOccurred()) + Expect(api).NotTo(BeNil()) + Expect(api.AuthenticatedClient).NotTo(BeNil()) + }) }) - it("returns an error if the token cannot be retrieved", func() { - returnToken = false - api, err := uaa.NewWithAuthorizationCode(s.URL, "", "", "", "", uaa.OpaqueToken, false) - Expect(err).To(HaveOccurred()) - Expect(api).To(BeNil()) + when("invalid target url", func() { + it("returns an error", func() { + api, err := uaa.NewWithAuthorizationCode("(*#&^@%$&%)", "client-id", "client-secret", "auth-code", "", uaa.OpaqueToken, false) + Expect(err).To(HaveOccurred()) + Expect(api).To(BeNil()) + }) }) - it("ensure that auth code grant type params are set correctly", func() { - api, err := uaa.NewWithAuthorizationCode(s.URL, "", "", "", "", uaa.OpaqueToken, false) - Expect(err).NotTo(HaveOccurred()) - Expect(api).NotTo(BeNil()) + when("created with an invalid auth code", func() { + it.Before(func() { + // Token retrieval is done as part of validateAuthorizationCode + // validateAuthorizationCode is called two times on construction + // AuthStyle is set to AuthStyleInHeader, failed token requests are not retried + // 2 token requests, 1 attempt each => 2 requests + stubTokenFailure("client-id", "client-secret", "", uaa.JSONWebToken) + stubTokenFailure("client-id", "client-secret", "", uaa.JSONWebToken) + }) - Expect(string(reqBody)).To(ContainSubstring("token_format=opaque")) - Expect(string(reqBody)).To(ContainSubstring("response_type=token")) - Expect(string(reqBody)).To(ContainSubstring("grant_type=authorization_code")) + it("returns an error", func() { + api, err := uaa.NewWithAuthorizationCode(s.URL(), "", "client-id", "client-secret", "", uaa.JSONWebToken, false) + Expect(err).To(HaveOccurred()) + Expect(api).To(BeNil()) + }) }) - it("Token() fails when the mode is authorizationcode and the authorization code is invalid", func() { - api := uaa.New("(*#&^@%$&%)", "") - Expect(api).NotTo(BeNil()) - api.TargetURL = nil - api = api.WithAuthorizationCode("client-id", "client-secret", "", uaa.OpaqueToken) - Expect(api).NotTo(BeNil()) - t, err := api.Token(context.Background()) - Expect(err).To(HaveOccurred()) - Expect(t).To(BeNil()) + when("the token response is missing a token", func() { + it.Before(func() { + // Token retrieval is done as part of validateAuthorizationCode + // validateAuthorizationCode is called two times on construction + // AuthStyle is set to AuthStyleInHeader, failed token requests are not retried + // 2 token requests, 2 attempts each => 4 requests + stubMalformedTokenSuccess("client-id", "client-secret", "auth-code", uaa.OpaqueToken) + stubMalformedTokenSuccess("client-id", "client-secret", "auth-code", uaa.OpaqueToken) + }) + + it("returns an error", func() { + api, err := uaa.NewWithAuthorizationCode(s.URL(), "", "client-id", "client-secret", "auth-code", uaa.OpaqueToken, false) + Expect(err).To(HaveOccurred()) + Expect(api).To(BeNil()) + }) }) - it("Token() will set the UnauthenticatedClient to the default if necessary", func() { - api := uaa.New(s.URL, "") - Expect(api).NotTo(BeNil()) - api.TargetURL = nil - api = api.WithAuthorizationCode("client-id", "client-secret", "valid", uaa.OpaqueToken) - Expect(api).NotTo(BeNil()) - api.UnauthenticatedClient = nil - t, err := api.Token(context.Background()) - Expect(err).NotTo(HaveOccurred()) - Expect(t.Valid()).To(BeTrue()) + when("the UnauthenticatedClient is removed", func() { + it.Before(func() { + // Token retrieval is done as part of validateAuthorizationCode + // validateAuthorizationCode is called two times on construction + // AuthStyle is set to AuthStyleInHeader, failed token requests are not retried + // Because the first token reqest succeeds, later token attempts are skipped + // Then another token is explicitly requested + // 2 token requests, 1 attempt each => 2 requests + stubTokenSuccess("client-id", "client-secret", "auth-code", uaa.OpaqueToken) + stubTokenSuccess("client-id", "client-secret", "auth-code", uaa.OpaqueToken) + }) + + it("Token() will set the UnauthenticatedClient to the default", func() { + api, err := uaa.NewWithAuthorizationCode(s.URL(), "", "client-id", "client-secret", "auth-code", uaa.OpaqueToken, false) + Expect(err).To(BeNil()) + Expect(api).NotTo(BeNil()) + api.UnauthenticatedClient = nil + t, err := api.Token(context.Background()) + Expect(err).NotTo(HaveOccurred()) + Expect(t.Valid()).To(BeTrue()) + }) }) }) when("NewWithRefreshToken", func() { - var ( - s *httptest.Server - returnToken bool - rawQuery string - reqBody []byte - ) + var s *ghttp.Server + + stubTokenRequest := func(clientId string, clientSecret string, refreshToken string, tokenFormat uaa.TokenFormat, response http.HandlerFunc) { + s.AppendHandlers(ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", "/oauth/token", fmt.Sprintf("token_format=%s", tokenFormat)), + ghttp.VerifyFormKV("grant_type", "refresh_token"), + ghttp.VerifyFormKV("refresh_token", refreshToken), + response, + )) + } + + stubTokenSuccess := func(clientId string, clientSecret string, refreshToken string, tokenFormat uaa.TokenFormat) { + token := &oauth2.Token{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + TokenType: "bearer", + Expiry: time.Now().Add(60 * time.Second), + } - it.Before(func() { - returnToken = true + stubTokenRequest(clientId, clientSecret, refreshToken, tokenFormat, ghttp.RespondWithJSONEncoded(http.StatusOK, token)) + } - s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - var err error - rawQuery = req.URL.RawQuery - reqBody, err = ioutil.ReadAll(req.Body) - Expect(err).NotTo(HaveOccurred()) + stubMalformedTokenSuccess := func(clientId string, clientSecret string, refreshToken string, tokenFormat uaa.TokenFormat) { + stubTokenRequest(clientId, clientSecret, refreshToken, tokenFormat, ghttp.RespondWithJSONEncoded(http.StatusOK, nil)) + } - w.Header().Set("Content-Type", "application/json") - t := &oauth2.Token{ - AccessToken: "test-access-token", - RefreshToken: "test-refresh-token", - TokenType: "bearer", - Expiry: time.Now().Add(60 * time.Second), - } - if !returnToken { - t = nil - } - w.WriteHeader(http.StatusOK) - err = json.NewEncoder(w).Encode(t) - Expect(err).NotTo(HaveOccurred()) - })) + it.Before(func() { + s = ghttp.NewServer() }) it.After(func() { - if s != nil { - s.Close() - } + s.Close() }) - it("fails if the refresh token is invalid", func() { - invalidRefreshToken := "" - api, err := uaa.NewWithRefreshToken(s.URL, "", "", "", invalidRefreshToken, uaa.JSONWebToken, false) - Expect(err).To(HaveOccurred()) - Expect(err).To(MatchError("oauth2: token expired and refresh token is not set")) - Expect(api).To(BeNil()) - }) - - it("fails if the target url is invalid", func() { - api, err := uaa.NewWithRefreshToken("(*#&^@%$&%)", "", "", "", "refresh-token", uaa.JSONWebToken, false) - Expect(err).To(HaveOccurred()) - Expect(api).To(BeNil()) - }) + when("success", func() { + it.Before(func() { + // Token retrieval is done as part of validateRefreshToken + // validateRefreshToken is called two times on construction + // AuthStyle is set to AuthStyleInHeader, failed token requests are not retried + // Because the first token reqest succeeds, later token attempts are skipped + // 1 token request, 1 attempt each => 1 request + stubTokenSuccess("client-id", "client-secret", "refresh-token", uaa.JSONWebToken) + }) - it("returns an API with a TargetURL", func() { - api, err := uaa.NewWithRefreshToken(s.URL, "", "", "", "refresh-token", uaa.JSONWebToken, false) - Expect(err).NotTo(HaveOccurred()) - Expect(api).NotTo(BeNil()) - Expect(api.TargetURL.String()).To(Equal(s.URL)) - }) + it("returns an API with a TargetURL", func() { + api, err := uaa.NewWithRefreshToken(s.URL(), "", "client-id", "client-secret", "refresh-token", uaa.JSONWebToken, false) + Expect(err).NotTo(HaveOccurred()) + Expect(api).NotTo(BeNil()) + Expect(api.TargetURL.String()).To(Equal(s.URL())) + }) - it("returns an API with an HTTPClient", func() { - api, err := uaa.NewWithRefreshToken(s.URL, "", "", "", "refresh-token", uaa.JSONWebToken, false) - Expect(err).NotTo(HaveOccurred()) - Expect(api).NotTo(BeNil()) - Expect(api.AuthenticatedClient).NotTo(BeNil()) + it("returns an API with an HTTPClient", func() { + api, err := uaa.NewWithRefreshToken(s.URL(), "", "client-id", "client-secret", "refresh-token", uaa.JSONWebToken, false) + Expect(err).NotTo(HaveOccurred()) + Expect(api).NotTo(BeNil()) + Expect(api.AuthenticatedClient).NotTo(BeNil()) + }) }) - it("returns an error if the token cannot be retrieved", func() { - returnToken = false - api, err := uaa.NewWithRefreshToken(s.URL, "", "", "", "refresh-token", uaa.JSONWebToken, false) - Expect(err).To(HaveOccurred()) - Expect(err).To(MatchError("oauth2: server response missing access_token")) - Expect(api).To(BeNil()) + when("created with an invalid target url", func() { + it("returns an error", func() { + api, err := uaa.NewWithRefreshToken("(*#&^@%$&%)", "", "client-id", "client-secret", "refresh-token", uaa.JSONWebToken, false) + Expect(err).To(HaveOccurred()) + Expect(api).To(BeNil()) + }) }) - it("ensure that refresh grant type params are set correctly", func() { - api, err := uaa.NewWithRefreshToken(s.URL, "", "", "", "refresh-token", uaa.JSONWebToken, false) - Expect(err).NotTo(HaveOccurred()) - Expect(api).NotTo(BeNil()) - - Expect(rawQuery).To(Equal("token_format=jwt")) - Expect(string(reqBody)).To(ContainSubstring("grant_type=refresh_token")) - Expect(string(reqBody)).To(ContainSubstring("refresh_token=refresh-token")) + when("created with an invalid refresh token", func() { + it("returns an error", func() { + api, err := uaa.NewWithRefreshToken(s.URL(), "", "client-id", "client-secret", "", uaa.JSONWebToken, false) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("oauth2: token expired and refresh token is not set")) + Expect(api).To(BeNil()) + }) }) - it("ensure that refresh grant type params are set correctly for opaque tokens", func() { - api, err := uaa.NewWithRefreshToken(s.URL, "", "", "", "refresh-token", uaa.OpaqueToken, false) - Expect(err).NotTo(HaveOccurred()) - Expect(api).NotTo(BeNil()) + when("the token response is missing a token", func() { + it.Before(func() { + // Token retrieval is done as part of validateRefreshToken + // validateRefreshToken is called two times on construction + // AuthStyle is set to AuthStyleInHeader, failed token requests are not retried + // 2 token requests, 1 attempt each => 2 requests + stubMalformedTokenSuccess("client-id", "client-secret", "refresh-token", uaa.JSONWebToken) + stubMalformedTokenSuccess("client-id", "client-secret", "refresh-token", uaa.JSONWebToken) + }) - Expect(rawQuery).To(Equal("token_format=opaque")) - Expect(string(reqBody)).To(ContainSubstring("grant_type=refresh_token")) - Expect(string(reqBody)).To(ContainSubstring("refresh_token=refresh-token")) + it("returns an error", func() { + api, err := uaa.NewWithRefreshToken(s.URL(), "", "client-id", "client-secret", "refresh-token", uaa.JSONWebToken, false) + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("oauth2: server response missing access_token")) + Expect(api).To(BeNil()) + }) }) - it("Token() fails when the mode is refreshtoken and the refresh token is invalid", func() { - api := uaa.New("(*#&^@%$&%)", "") - Expect(api).NotTo(BeNil()) - api.TargetURL = nil - api = api.WithRefreshToken("client-id", "client-secret", "", uaa.OpaqueToken) - Expect(api).NotTo(BeNil()) - t, err := api.Token(context.Background()) - Expect(err).To(HaveOccurred()) - Expect(t).To(BeNil()) - }) + when("the UnauthenticatedClient is removed", func() { + it.Before(func() { + // Token retrieval is done as part of validateRefreshToken + // validateRefreshToken is called two times on construction + // AuthStyle is set to AuthStyleInHeader, failed token requests are not retried + // Because the first token reqest succeeds, later token attempts are skipped + // Then another token is explicitly requested + // 2 token requests, 1 attempt each => 2 requests + stubTokenSuccess("client-id", "client-secret", "refresh-token", uaa.JSONWebToken) + stubTokenSuccess("client-id", "client-secret", "refresh-token", uaa.JSONWebToken) + }) - it("Token() will set the UnauthenticatedClient to the default if necessary", func() { - api := uaa.New(s.URL, "") - Expect(api).NotTo(BeNil()) - api.TargetURL = nil - api = api.WithRefreshToken("client-id", "client-secret", "valid", uaa.OpaqueToken) - Expect(api).NotTo(BeNil()) - api.UnauthenticatedClient = nil - t, err := api.Token(context.Background()) - Expect(err).NotTo(HaveOccurred()) - Expect(t.Valid()).To(BeTrue()) + it("Token() will set the UnauthenticatedClient to the default", func() { + api, err := uaa.NewWithRefreshToken(s.URL(), "", "client-id", "client-secret", "refresh-token", uaa.JSONWebToken, false) + Expect(err).To(BeNil()) + Expect(api).NotTo(BeNil()) + api.UnauthenticatedClient = nil + t, err := api.Token(context.Background()) + Expect(err).NotTo(HaveOccurred()) + Expect(t.Valid()).To(BeTrue()) + }) }) }) }