From f06d8fad42ae0a460e6e53f1b57cbbf1d4d14757 Mon Sep 17 00:00:00 2001 From: Dmitriy Dubson Date: Mon, 30 Oct 2023 17:15:46 -0400 Subject: [PATCH] Add OAuth2TokenEndpointAuthenticationSuccessHandler Fixes gh-925 --- .../ROOT/pages/protocol-endpoints.adoc | 2 +- ...Auth2AccessTokenAuthenticationContext.java | 106 ++++++++++ .../web/OAuth2TokenEndpointFilter.java | 43 +--- ...nResponseAuthenticationSuccessHandler.java | 125 +++++++++++ ...2AccessTokenAuthenticationContextTest.java | 70 +++++++ ...onseAuthenticationSuccessHandlerTests.java | 195 ++++++++++++++++++ 6 files changed, 500 insertions(+), 41 deletions(-) create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationContext.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AccessTokenResponseAuthenticationSuccessHandler.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationContextTest.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AccessTokenResponseAuthenticationSuccessHandlerTests.java diff --git a/docs/modules/ROOT/pages/protocol-endpoints.adoc b/docs/modules/ROOT/pages/protocol-endpoints.adoc index 2e6404852b..13e3e3f8da 100644 --- a/docs/modules/ROOT/pages/protocol-endpoints.adoc +++ b/docs/modules/ROOT/pages/protocol-endpoints.adoc @@ -263,7 +263,7 @@ The supported https://datatracker.ietf.org/doc/html/rfc6749#section-1.3[authoriz * `*AuthenticationConverter*` -- A `DelegatingAuthenticationConverter` composed of `OAuth2AuthorizationCodeAuthenticationConverter`, `OAuth2RefreshTokenAuthenticationConverter`, `OAuth2ClientCredentialsAuthenticationConverter`, and `OAuth2DeviceCodeAuthenticationConverter`. * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OAuth2AuthorizationCodeAuthenticationProvider`, `OAuth2RefreshTokenAuthenticationProvider`, `OAuth2ClientCredentialsAuthenticationProvider`, and `OAuth2DeviceCodeAuthenticationProvider`. -* `*AuthenticationSuccessHandler*` -- An internal implementation that handles an `OAuth2AccessTokenAuthenticationToken` and returns the `OAuth2AccessTokenResponse`. +* `*AuthenticationSuccessHandler*` -- An `OAuth2AccessTokenResponseAuthenticationSuccessHandler`. * `*AuthenticationFailureHandler*` -- An `OAuth2ErrorAuthenticationFailureHandler`. [[oauth2-token-endpoint-customizing-client-credentials-grant-request-validation]] diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationContext.java new file mode 100644 index 0000000000..a160795489 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationContext.java @@ -0,0 +1,106 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AccessTokenResponseAuthenticationSuccessHandler; +import org.springframework.util.Assert; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +/** + * An {@link OAuth2AuthenticationContext} that holds an {@link OAuth2AccessTokenResponse.Builder} + * and is used when customizing the building of the {@link OAuth2AccessTokenResponse}. + * + * @author Dmitriy Dubson + * @see OAuth2AuthenticationContext + * @see OAuth2AccessTokenResponse + * @see OAuth2AccessTokenResponseAuthenticationSuccessHandler#setAccessTokenResponseCustomizer(Consumer) + * @since 1.3 + */ +public class OAuth2AccessTokenAuthenticationContext implements OAuth2AuthenticationContext { + private final Map context; + + private OAuth2AccessTokenAuthenticationContext(Map context) { + this.context = Collections.unmodifiableMap(new HashMap<>(context)); + } + + @SuppressWarnings("unchecked") + @Nullable + @Override + public V get(Object key) { + return hasKey(key) ? (V) this.context.get(key) : null; + } + + @Override + public boolean hasKey(Object key) { + Assert.notNull(key, "key cannot be null"); + return this.context.containsKey(key); + } + + /** + * Returns the {@link OAuth2AccessTokenResponse.Builder} access token response builder + * @return the {@link OAuth2AccessTokenResponse.Builder} + */ + public OAuth2AccessTokenResponse.Builder getAccessTokenResponse() { + return get(OAuth2AccessTokenResponse.Builder.class); + } + + /** + * Constructs a new {@link Builder} with the provided {@link OAuth2AccessTokenAuthenticationToken}. + * + * @param authentication the {@link OAuth2AccessTokenAuthenticationToken} + * @return the {@link Builder} + */ + public static OAuth2AccessTokenAuthenticationContext.Builder with(OAuth2AccessTokenAuthenticationToken authentication) { + return new OAuth2AccessTokenAuthenticationContext.Builder(authentication); + } + + /** + * A builder for {@link OAuth2AccessTokenAuthenticationContext} + */ + public static final class Builder extends AbstractBuilder { + private Builder(OAuth2AccessTokenAuthenticationToken authentication) { + super(authentication); + put(OAuth2AccessTokenAuthenticationToken.class, authentication); + } + + /** + * Sets the {@link OAuth2AccessTokenResponse.Builder} access token response builder + * @param accessTokenResponse the {@link OAuth2AccessTokenResponse.Builder} + * @return the {@link Builder} for further configuration + */ + public Builder accessTokenResponse(OAuth2AccessTokenResponse.Builder accessTokenResponse) { + return put(OAuth2AccessTokenResponse.Builder.class, accessTokenResponse); + } + + /** + * Builds a new {@link OAuth2AccessTokenAuthenticationContext}. + * + * @return the {@link OAuth2AccessTokenAuthenticationContext} + */ + public OAuth2AccessTokenAuthenticationContext build() { + Assert.notNull(get(OAuth2AccessTokenResponse.Builder.class), "accessTokenResponse cannot be null"); + Assert.notNull(get(OAuth2AccessTokenAuthenticationToken.class), "accessTokenAuthenticationToken cannot be null"); + + return new OAuth2AccessTokenAuthenticationContext(getContext()); + } + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java index 655162c329..c653d5a663 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java @@ -16,32 +16,24 @@ package org.springframework.security.oauth2.server.authorization.web; import java.io.IOException; -import java.time.temporal.ChronoUnit; import java.util.Arrays; -import java.util.Map; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; - import org.springframework.core.log.LogMessage; import org.springframework.http.HttpMethod; -import org.springframework.http.converter.HttpMessageConverter; -import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationGrantAuthenticationToken; @@ -54,6 +46,7 @@ import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2DeviceCodeAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2RefreshTokenAuthenticationConverter; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AccessTokenResponseAuthenticationSuccessHandler; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; @@ -61,7 +54,6 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; import org.springframework.web.filter.OncePerRequestFilter; /** @@ -86,6 +78,7 @@ * @author Joe Grandja * @author Madhu Bhat * @author Daniel Garnier-Moiroux + * @author Dmitriy Dubson * @since 0.0.1 * @see AuthenticationManager * @see OAuth2AuthorizationCodeAuthenticationProvider @@ -103,12 +96,10 @@ public final class OAuth2TokenEndpointFilter extends OncePerRequestFilter { private static final String DEFAULT_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2"; private final AuthenticationManager authenticationManager; private final RequestMatcher tokenEndpointMatcher; - private final HttpMessageConverter accessTokenHttpResponseConverter = - new OAuth2AccessTokenResponseHttpMessageConverter(); private AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); private AuthenticationConverter authenticationConverter; - private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAccessTokenResponse; + private AuthenticationSuccessHandler authenticationSuccessHandler = new OAuth2AccessTokenResponseAuthenticationSuccessHandler(); private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler(); /** @@ -218,34 +209,6 @@ public void setAuthenticationFailureHandler(AuthenticationFailureHandler authent this.authenticationFailureHandler = authenticationFailureHandler; } - private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResponse response, - Authentication authentication) throws IOException { - - OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = - (OAuth2AccessTokenAuthenticationToken) authentication; - - OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken(); - OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken(); - Map additionalParameters = accessTokenAuthentication.getAdditionalParameters(); - - OAuth2AccessTokenResponse.Builder builder = - OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue()) - .tokenType(accessToken.getTokenType()) - .scopes(accessToken.getScopes()); - if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) { - builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt())); - } - if (refreshToken != null) { - builder.refreshToken(refreshToken.getTokenValue()); - } - if (!CollectionUtils.isEmpty(additionalParameters)) { - builder.additionalParameters(additionalParameters); - } - OAuth2AccessTokenResponse accessTokenResponse = builder.build(); - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse); - } - private static void throwError(String errorCode, String parameterName) { OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, DEFAULT_ERROR_URI); throw new OAuth2AuthenticationException(error); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AccessTokenResponseAuthenticationSuccessHandler.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AccessTokenResponseAuthenticationSuccessHandler.java new file mode 100644 index 0000000000..70891df876 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AccessTokenResponseAuthenticationSuccessHandler.java @@ -0,0 +1,125 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web.authentication; + +import java.io.IOException; +import java.time.temporal.ChronoUnit; +import java.util.Map; +import java.util.function.Consumer; + +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.*; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationContext; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +/** + * An implementation of an {@link AuthenticationSuccessHandler} used for handling an {@link OAuth2AccessTokenAuthenticationToken} + * and returning the {@link OAuth2AccessTokenResponse Access Token Response}. + * + * @author Dmitriy Dubson + * @see AuthenticationSuccessHandler + * @see OAuth2AccessTokenResponseHttpMessageConverter + * @since 1.3 + */ +public final class OAuth2AccessTokenResponseAuthenticationSuccessHandler implements AuthenticationSuccessHandler { + private final Log logger = LogFactory.getLog(getClass()); + + private HttpMessageConverter accessTokenResponseConverter = + new OAuth2AccessTokenResponseHttpMessageConverter(); + + private Consumer accessTokenResponseCustomizer; + + @Override + public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException { + if (!(authentication instanceof OAuth2AccessTokenAuthenticationToken accessTokenAuthentication)) { + if (this.logger.isErrorEnabled()) { + this.logger.error(Authentication.class.getSimpleName() + " must be of type " + + OAuth2AccessTokenAuthenticationToken.class.getName() + + " but was " + authentication.getClass().getName()); + } + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, "Unable to process the access token response.", null); + throw new OAuth2AuthenticationException(error); + } + + OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken(); + OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken(); + Map additionalParameters = accessTokenAuthentication.getAdditionalParameters(); + + OAuth2AccessTokenResponse.Builder builder = + OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue()) + .tokenType(accessToken.getTokenType()) + .scopes(accessToken.getScopes()); + if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) { + builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt())); + } + if (refreshToken != null) { + builder.refreshToken(refreshToken.getTokenValue()); + } + if (!CollectionUtils.isEmpty(additionalParameters)) { + builder.additionalParameters(additionalParameters); + } + + if (this.accessTokenResponseCustomizer != null) { + // @formatter:off + OAuth2AccessTokenAuthenticationContext accessTokenAuthenticationContext = + OAuth2AccessTokenAuthenticationContext.with(accessTokenAuthentication) + .accessTokenResponse(builder) + .build(); + // @formatter:on + this.accessTokenResponseCustomizer.accept(accessTokenAuthenticationContext); + if (this.logger.isTraceEnabled()) { + this.logger.trace("Customized access token response"); + } + } + + OAuth2AccessTokenResponse accessTokenResponse = builder.build(); + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + this.accessTokenResponseConverter.write(accessTokenResponse, null, httpResponse); + } + + /** + * Sets the {@link HttpMessageConverter} used for converting an {@link OAuth2AccessTokenResponse} to an HTTP response. + * + * @param accessTokenResponseConverter the {@link HttpMessageConverter} used for converting an {@link OAuth2AccessTokenResponse} to an HTTP response + */ + public void setAccessTokenResponseConverter(HttpMessageConverter accessTokenResponseConverter) { + Assert.notNull(accessTokenResponseConverter, "accessTokenHttpResponseConverter cannot be null"); + this.accessTokenResponseConverter = accessTokenResponseConverter; + } + + /** + * Sets the {@code Consumer} providing access to the {@link OAuth2AccessTokenAuthenticationContext} + * containing an {@link OAuth2AccessTokenResponse.Builder} and additional context information. + * + * @param accessTokenResponseCustomizer the {@code Consumer} providing access to the {@link OAuth2AccessTokenAuthenticationContext} containing an {@link OAuth2AccessTokenResponse.Builder} + */ + public void setAccessTokenResponseCustomizer(Consumer accessTokenResponseCustomizer) { + Assert.notNull(accessTokenResponseCustomizer, "accessTokenResponseCustomizer cannot be null"); + this.accessTokenResponseCustomizer = accessTokenResponseCustomizer; + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationContextTest.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationContextTest.java new file mode 100644 index 0000000000..dbf48dfdd8 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationContextTest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.authentication; + + +import org.junit.jupiter.api.Test; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import java.security.Principal; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2AccessTokenAuthenticationContext} + * + * @author Dmitriy Dubson + */ +public class OAuth2AccessTokenAuthenticationContextTest { + private final RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + private final OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(this.registeredClient).build(); + private final Authentication principal = this.authorization.getAttribute(Principal.class.getName()); + OAuth2AccessTokenAuthenticationToken accessTokenAuthenticationToken = new OAuth2AccessTokenAuthenticationToken(registeredClient, principal, + authorization.getAccessToken().getToken(), authorization.getRefreshToken().getToken()); + + @Test + public void withWhenAuthenticationNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AccessTokenAuthenticationContext.with(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authentication cannot be null"); + } + + @Test + public void setWhenValueNullThenThrowIllegalArgumentException() { + OAuth2AccessTokenAuthenticationContext.Builder builder = + OAuth2AccessTokenAuthenticationContext.with(accessTokenAuthenticationToken); + + assertThatThrownBy(() -> builder.accessTokenResponse(null)) + .isInstanceOf(IllegalArgumentException.class).hasMessage("value cannot be null"); + } + @Test + public void buildWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AccessTokenResponse.Builder accessTokenResponseBuilder = OAuth2AccessTokenResponse.withToken(accessTokenAuthenticationToken.getAccessToken().getTokenValue()); + OAuth2AccessTokenAuthenticationContext context = + OAuth2AccessTokenAuthenticationContext.with(accessTokenAuthenticationToken) + .accessTokenResponse(accessTokenResponseBuilder) + .build(); + + assertThat(context.getAuthentication()).isEqualTo(accessTokenAuthenticationToken); + assertThat(context.getAccessTokenResponse()).isEqualTo(accessTokenResponseBuilder); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AccessTokenResponseAuthenticationSuccessHandlerTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AccessTokenResponseAuthenticationSuccessHandlerTests.java new file mode 100644 index 0000000000..406a4b2261 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AccessTokenResponseAuthenticationSuccessHandlerTests.java @@ -0,0 +1,195 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.web.authentication; + +import jakarta.servlet.ServletException; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationContext; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.within; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link OAuth2AccessTokenResponseAuthenticationSuccessHandler}. + * + * @author Dmitriy Dubson + */ +public class OAuth2AccessTokenResponseAuthenticationSuccessHandlerTests { + RegisteredClient testClient = TestRegisteredClients.registeredClient().build(); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + testClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, testClient.getClientSecret()); + + private final OAuth2AccessTokenResponseAuthenticationSuccessHandler authenticationSuccessHandler = new OAuth2AccessTokenResponseAuthenticationSuccessHandler(); + + @Test + public void setAccessTokenHttpResponseConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authenticationSuccessHandler.setAccessTokenResponseConverter(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("accessTokenHttpResponseConverter cannot be null"); + // @formatter:on + } + + @Test + public void setAccessTokenResponseCustomizerWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authenticationSuccessHandler.setAccessTokenResponseCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("accessTokenResponseCustomizer cannot be null"); + // @formatter:on + } + + @Test + public void onAuthenticationSuccessWritesAccessTokenToHttpResponse() throws ServletException, IOException { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plusSeconds(300); + OAuth2Authorization testAuthorization = TestOAuth2Authorizations.authorization(testClient).build(); + Map additionalParameters = Collections.singletonMap("param1", "value1"); + Authentication authentication = new OAuth2AccessTokenAuthenticationToken(testClient, clientPrincipal, + testAuthorization.getAccessToken().getToken(), testAuthorization.getRefreshToken().getToken(), + additionalParameters); + + HttpMessageConverter responseConverter = mock(HttpMessageConverter.class); + authenticationSuccessHandler.setAccessTokenResponseConverter(responseConverter); + authenticationSuccessHandler.onAuthenticationSuccess(request, response, authentication); + + ArgumentCaptor accessTokenResponseCaptor = ArgumentCaptor.forClass(OAuth2AccessTokenResponse.class); + ArgumentCaptor servletServerHttpResponseArgumentCaptor = ArgumentCaptor.forClass(ServletServerHttpResponse.class); + verify(responseConverter).write(accessTokenResponseCaptor.capture(), isNull(), servletServerHttpResponseArgumentCaptor.capture()); + + OAuth2AccessTokenResponse actualAccessTokenResponse = accessTokenResponseCaptor.getValue(); + assertThat(actualAccessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token"); + assertThat(actualAccessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(actualAccessTokenResponse.getAccessToken().getIssuedAt()).isCloseTo(issuedAt, within(2, ChronoUnit.SECONDS)); + assertThat(actualAccessTokenResponse.getAccessToken().getExpiresAt()).isCloseTo(expiresAt, within(2, ChronoUnit.SECONDS)); + assertThat(actualAccessTokenResponse.getRefreshToken()).isNotNull(); + assertThat(actualAccessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo("refresh-token"); + assertThat(actualAccessTokenResponse.getAdditionalParameters()).containsExactlyEntriesOf(additionalParameters); + + assertThat(servletServerHttpResponseArgumentCaptor.getValue().getServletResponse()).isEqualTo(response); + } + + @Test + public void onAuthenticationSuccessAuthenticationIsNotInstanceOfOAuth2AccessTokenAuthenticationToken() { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + HttpMessageConverter responseConverter = mock(HttpMessageConverter.class); + authenticationSuccessHandler.setAccessTokenResponseConverter(responseConverter); + + assertThatThrownBy(() -> + authenticationSuccessHandler.onAuthenticationSuccess(request, response, new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, Set.of(), Map.of()))) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + + verifyNoInteractions(responseConverter); + } + + @Test + public void onAuthenticationSuccessAccessTokenResponseIsCustomizedViaAccessTokenResponseCustomizer() throws ServletException, IOException { + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + OAuth2AuthorizationService authorizationService = new InMemoryOAuth2AuthorizationService(); + OAuth2Authorization testAuthorization = TestOAuth2Authorizations.authorization(testClient).build(); + authorizationService.save(testAuthorization); + + Instant issuedAt = Instant.now(); + Instant expiresAt = issuedAt.plusSeconds(300); + var accessToken = testAuthorization.getAccessToken().getToken(); + var refreshToken = testAuthorization.getRefreshToken().getToken(); + Map additionalParameters = Collections.singletonMap("param1", "value1"); + Authentication authentication = new OAuth2AccessTokenAuthenticationToken(testClient, clientPrincipal, accessToken, refreshToken, additionalParameters); + + Consumer accessTokenResponseCustomizer = (OAuth2AccessTokenAuthenticationContext authenticationContext) -> { + OAuth2AccessTokenAuthenticationToken authenticationToken = authenticationContext.getAuthentication(); + OAuth2AccessTokenResponse.Builder accessTokenResponse = authenticationContext.getAccessTokenResponse(); + OAuth2Authorization authorization = authorizationService.findByToken( + authenticationToken.getAccessToken().getTokenValue(), + OAuth2TokenType.ACCESS_TOKEN + ); + Map customParams = Map.of( + "authorization_id", authorization.getId(), + "registered_client_id", authorization.getRegisteredClientId() + ); + Map allParams = new HashMap<>(authenticationToken.getAdditionalParameters()); + allParams.putAll(customParams); + accessTokenResponse.additionalParameters(allParams); + }; + + HttpMessageConverter responseConverter = mock(HttpMessageConverter.class); + authenticationSuccessHandler.setAccessTokenResponseConverter(responseConverter); + authenticationSuccessHandler.setAccessTokenResponseCustomizer(accessTokenResponseCustomizer); + authenticationSuccessHandler.onAuthenticationSuccess(request, response, authentication); + + ArgumentCaptor accessTokenResponseCaptor = ArgumentCaptor.forClass(OAuth2AccessTokenResponse.class); + ArgumentCaptor servletServerHttpResponseArgumentCaptor = ArgumentCaptor.forClass(ServletServerHttpResponse.class); + verify(responseConverter).write(accessTokenResponseCaptor.capture(), isNull(), servletServerHttpResponseArgumentCaptor.capture()); + + OAuth2AccessTokenResponse actualAccessTokenResponse = accessTokenResponseCaptor.getValue(); + assertThat(actualAccessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token"); + assertThat(actualAccessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(actualAccessTokenResponse.getAccessToken().getIssuedAt()).isCloseTo(issuedAt, within(2, ChronoUnit.SECONDS)); + assertThat(actualAccessTokenResponse.getAccessToken().getExpiresAt()).isCloseTo(expiresAt, within(2, ChronoUnit.SECONDS)); + assertThat(actualAccessTokenResponse.getRefreshToken()).isNotNull(); + assertThat(actualAccessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo("refresh-token"); + assertThat(actualAccessTokenResponse.getAdditionalParameters()).containsExactlyInAnyOrderEntriesOf( + Map.of("param1", "value1", "authorization_id", "id", "registered_client_id", "registration-1") + ); + + assertThat(servletServerHttpResponseArgumentCaptor.getValue().getServletResponse()).isEqualTo(response); + } +}