Bläddra i källkod

Add ability to customize the access token response

Issue gh-925

Closes gh-1429
Dmitriy Dubson 1 år sedan
förälder
incheckning
d4ae69bfa8

+ 1 - 1
docs/modules/ROOT/pages/protocol-endpoints.adoc

@@ -263,7 +263,7 @@ The supported https://datatracker.ietf.org/doc/html/rfc6749#section-1.3[authoriz
 
 * `*AuthenticationConverter*` -- A `DelegatingAuthenticationConverter` composed of `OAuth2AuthorizationCodeAuthenticationConverter`, `OAuth2RefreshTokenAuthenticationConverter`, `OAuth2ClientCredentialsAuthenticationConverter`, and `OAuth2DeviceCodeAuthenticationConverter`.
 * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OAuth2AuthorizationCodeAuthenticationProvider`, `OAuth2RefreshTokenAuthenticationProvider`, `OAuth2ClientCredentialsAuthenticationProvider`, and `OAuth2DeviceCodeAuthenticationProvider`.
-* `*AuthenticationSuccessHandler*` -- An internal implementation that handles an `OAuth2AccessTokenAuthenticationToken` and returns the `OAuth2AccessTokenResponse`.
+* `*AuthenticationSuccessHandler*` -- An `OAuth2AccessTokenResponseAuthenticationSuccessHandler`.
 * `*AuthenticationFailureHandler*` -- An `OAuth2ErrorAuthenticationFailureHandler`.
 
 [[oauth2-token-endpoint-customizing-client-credentials-grant-request-validation]]

+ 104 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationContext.java

@@ -0,0 +1,104 @@
+/*
+ * Copyright 2020-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AccessTokenResponseAuthenticationSuccessHandler;
+import org.springframework.util.Assert;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Consumer;
+
+/**
+ * An {@link OAuth2AuthenticationContext} that holds an {@link OAuth2AccessTokenResponse.Builder}
+ * and is used when customizing the building of the {@link OAuth2AccessTokenResponse}.
+ *
+ * @author Dmitriy Dubson
+ * @see OAuth2AuthenticationContext
+ * @see OAuth2AccessTokenResponse
+ * @see OAuth2AccessTokenResponseAuthenticationSuccessHandler#setAccessTokenResponseCustomizer(Consumer)
+ * @since 1.3
+ */
+public final class OAuth2AccessTokenAuthenticationContext implements OAuth2AuthenticationContext {
+	private final Map<Object, Object> context;
+
+	private OAuth2AccessTokenAuthenticationContext(Map<Object, Object> context) {
+		this.context = Collections.unmodifiableMap(new HashMap<>(context));
+	}
+
+	@SuppressWarnings("unchecked")
+	@Nullable
+	@Override
+	public <V> V get(Object key) {
+		return hasKey(key) ? (V) this.context.get(key) : null;
+	}
+
+	@Override
+	public boolean hasKey(Object key) {
+		Assert.notNull(key, "key cannot be null");
+		return this.context.containsKey(key);
+	}
+
+	/**
+	 * Returns the {@link OAuth2AccessTokenResponse.Builder} access token response builder
+	 * @return the {@link OAuth2AccessTokenResponse.Builder}
+	 */
+	public OAuth2AccessTokenResponse.Builder getAccessTokenResponse() {
+		return get(OAuth2AccessTokenResponse.Builder.class);
+	}
+
+	/**
+	 * Constructs a new {@link Builder} with the provided {@link OAuth2AccessTokenAuthenticationToken}.
+	 *
+	 * @param authentication the {@link OAuth2AccessTokenAuthenticationToken}
+	 * @return the {@link Builder}
+	 */
+	public static OAuth2AccessTokenAuthenticationContext.Builder with(OAuth2AccessTokenAuthenticationToken authentication) {
+		return new OAuth2AccessTokenAuthenticationContext.Builder(authentication);
+	}
+
+	/**
+	 * A builder for {@link OAuth2AccessTokenAuthenticationContext}
+	 */
+	public static final class Builder extends AbstractBuilder<OAuth2AccessTokenAuthenticationContext, Builder> {
+		private Builder(OAuth2AccessTokenAuthenticationToken authentication) {
+			super(authentication);
+		}
+
+		/**
+		 * Sets the {@link OAuth2AccessTokenResponse.Builder} access token response builder
+		 * @param accessTokenResponse the {@link OAuth2AccessTokenResponse.Builder}
+		 * @return the {@link Builder} for further configuration
+		 */
+		public Builder accessTokenResponse(OAuth2AccessTokenResponse.Builder accessTokenResponse) {
+			return put(OAuth2AccessTokenResponse.Builder.class, accessTokenResponse);
+		}
+
+		/**
+		 * Builds a new {@link OAuth2AccessTokenAuthenticationContext}.
+		 *
+		 * @return the {@link OAuth2AccessTokenAuthenticationContext}
+		 */
+		public OAuth2AccessTokenAuthenticationContext build() {
+			Assert.notNull(get(OAuth2AccessTokenResponse.Builder.class), "accessTokenResponse cannot be null");
+
+			return new OAuth2AccessTokenAuthenticationContext(getContext());
+		}
+	}
+}

+ 4 - 41
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2023 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,32 +16,24 @@
 package org.springframework.security.oauth2.server.authorization.web;
 
 import java.io.IOException;
-import java.time.temporal.ChronoUnit;
 import java.util.Arrays;
-import java.util.Map;
 
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.ServletException;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
-
 import org.springframework.core.log.LogMessage;
 import org.springframework.http.HttpMethod;
-import org.springframework.http.converter.HttpMessageConverter;
-import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
-import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
-import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationGrantAuthenticationToken;
@@ -54,6 +46,7 @@ import org.springframework.security.oauth2.server.authorization.web.authenticati
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2DeviceCodeAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2RefreshTokenAuthenticationConverter;
+import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AccessTokenResponseAuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
@@ -61,7 +54,6 @@ import org.springframework.security.web.authentication.WebAuthenticationDetailsS
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
-import org.springframework.util.CollectionUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
 
 /**
@@ -86,6 +78,7 @@ import org.springframework.web.filter.OncePerRequestFilter;
  * @author Joe Grandja
  * @author Madhu Bhat
  * @author Daniel Garnier-Moiroux
+ * @author Dmitriy Dubson
  * @since 0.0.1
  * @see AuthenticationManager
  * @see OAuth2AuthorizationCodeAuthenticationProvider
@@ -103,12 +96,10 @@ public final class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 	private static final String DEFAULT_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
 	private final AuthenticationManager authenticationManager;
 	private final RequestMatcher tokenEndpointMatcher;
-	private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
-			new OAuth2AccessTokenResponseHttpMessageConverter();
 	private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource =
 			new WebAuthenticationDetailsSource();
 	private AuthenticationConverter authenticationConverter;
-	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAccessTokenResponse;
+	private AuthenticationSuccessHandler authenticationSuccessHandler = new OAuth2AccessTokenResponseAuthenticationSuccessHandler();
 	private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();
 
 	/**
@@ -218,34 +209,6 @@ public final class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 		this.authenticationFailureHandler = authenticationFailureHandler;
 	}
 
-	private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResponse response,
-			Authentication authentication) throws IOException {
-
-		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
-				(OAuth2AccessTokenAuthenticationToken) authentication;
-
-		OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
-		OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
-		Map<String, Object> additionalParameters = accessTokenAuthentication.getAdditionalParameters();
-
-		OAuth2AccessTokenResponse.Builder builder =
-				OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
-						.tokenType(accessToken.getTokenType())
-						.scopes(accessToken.getScopes());
-		if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) {
-			builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()));
-		}
-		if (refreshToken != null) {
-			builder.refreshToken(refreshToken.getTokenValue());
-		}
-		if (!CollectionUtils.isEmpty(additionalParameters)) {
-			builder.additionalParameters(additionalParameters);
-		}
-		OAuth2AccessTokenResponse accessTokenResponse = builder.build();
-		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
-		this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse);
-	}
-
 	private static void throwError(String errorCode, String parameterName) {
 		OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, DEFAULT_ERROR_URI);
 		throw new OAuth2AuthenticationException(error);

+ 115 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AccessTokenResponseAuthenticationSuccessHandler.java

@@ -0,0 +1,115 @@
+/*
+ * Copyright 2020-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.web.authentication;
+
+import java.io.IOException;
+import java.time.temporal.ChronoUnit;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import jakarta.servlet.ServletException;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.http.server.ServletServerHttpResponse;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.*;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationContext;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+
+/**
+ * An implementation of an {@link AuthenticationSuccessHandler} used for handling an {@link OAuth2AccessTokenAuthenticationToken}
+ * and returning the {@link OAuth2AccessTokenResponse Access Token Response}.
+ *
+ * @author Dmitriy Dubson
+ * @see AuthenticationSuccessHandler
+ * @see OAuth2AccessTokenResponseHttpMessageConverter
+ * @since 1.3
+ */
+public final class OAuth2AccessTokenResponseAuthenticationSuccessHandler implements AuthenticationSuccessHandler {
+	private final Log logger = LogFactory.getLog(getClass());
+
+	private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenResponseConverter =
+			new OAuth2AccessTokenResponseHttpMessageConverter();
+
+	private Consumer<OAuth2AccessTokenAuthenticationContext> accessTokenResponseCustomizer;
+
+	@Override
+	public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException {
+		if (!(authentication instanceof OAuth2AccessTokenAuthenticationToken accessTokenAuthentication)) {
+			if (this.logger.isErrorEnabled()) {
+				this.logger.error(Authentication.class.getSimpleName() + " must be of type " +
+						OAuth2AccessTokenAuthenticationToken.class.getName() +
+						" but was " + authentication.getClass().getName());
+			}
+			OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, "Unable to process the access token response.", null);
+			throw new OAuth2AuthenticationException(error);
+		}
+
+		OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
+		OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
+		Map<String, Object> additionalParameters = accessTokenAuthentication.getAdditionalParameters();
+
+		OAuth2AccessTokenResponse.Builder builder =
+				OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
+						.tokenType(accessToken.getTokenType())
+						.scopes(accessToken.getScopes());
+		if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) {
+			builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()));
+		}
+		if (refreshToken != null) {
+			builder.refreshToken(refreshToken.getTokenValue());
+		}
+		if (!CollectionUtils.isEmpty(additionalParameters)) {
+			builder.additionalParameters(additionalParameters);
+		}
+
+		if (this.accessTokenResponseCustomizer != null) {
+			// @formatter:off
+			OAuth2AccessTokenAuthenticationContext accessTokenAuthenticationContext =
+					OAuth2AccessTokenAuthenticationContext.with(accessTokenAuthentication)
+						.accessTokenResponse(builder)
+						.build();
+			// @formatter:on
+			this.accessTokenResponseCustomizer.accept(accessTokenAuthenticationContext);
+			if (this.logger.isTraceEnabled()) {
+				this.logger.trace("Customized access token response");
+			}
+		}
+
+		OAuth2AccessTokenResponse accessTokenResponse = builder.build();
+		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
+		this.accessTokenResponseConverter.write(accessTokenResponse, null, httpResponse);
+	}
+
+	/**
+	 * Sets the {@code Consumer} providing access to the {@link OAuth2AccessTokenAuthenticationContext}
+	 * containing an {@link OAuth2AccessTokenResponse.Builder} and additional context information.
+	 *
+	 * @param accessTokenResponseCustomizer the {@code Consumer} providing access to the {@link OAuth2AccessTokenAuthenticationContext} containing an {@link OAuth2AccessTokenResponse.Builder}
+	 */
+	public void setAccessTokenResponseCustomizer(Consumer<OAuth2AccessTokenAuthenticationContext> accessTokenResponseCustomizer) {
+		Assert.notNull(accessTokenResponseCustomizer, "accessTokenResponseCustomizer cannot be null");
+		this.accessTokenResponseCustomizer = accessTokenResponseCustomizer;
+	}
+}

+ 71 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationContextTest.java

@@ -0,0 +1,71 @@
+/*
+ * Copyright 2020-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+
+import org.junit.jupiter.api.Test;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import java.security.Principal;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link OAuth2AccessTokenAuthenticationContext}
+ *
+ * @author Dmitriy Dubson
+ */
+public class OAuth2AccessTokenAuthenticationContextTest {
+	private final RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+	private final OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(this.registeredClient).build();
+	private final Authentication principal = this.authorization.getAttribute(Principal.class.getName());
+	private final OAuth2AccessTokenAuthenticationToken accessTokenAuthenticationToken = new OAuth2AccessTokenAuthenticationToken(registeredClient, principal,
+			authorization.getAccessToken().getToken(), authorization.getRefreshToken().getToken());
+
+	@Test
+	public void withWhenAuthenticationNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2AccessTokenAuthenticationContext.with(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authentication cannot be null");
+	}
+
+	@Test
+	public void setWhenValueNullThenThrowIllegalArgumentException() {
+		OAuth2AccessTokenAuthenticationContext.Builder builder =
+				OAuth2AccessTokenAuthenticationContext.with(this.accessTokenAuthenticationToken);
+
+		assertThatThrownBy(() -> builder.accessTokenResponse(null))
+				.isInstanceOf(IllegalArgumentException.class).hasMessage("value cannot be null");
+	}
+
+	@Test
+	public void buildWhenAllValuesProvidedThenAllValuesAreSet() {
+		OAuth2AccessTokenResponse.Builder accessTokenResponseBuilder = OAuth2AccessTokenResponse.withToken(this.accessTokenAuthenticationToken.getAccessToken().getTokenValue());
+		OAuth2AccessTokenAuthenticationContext context =
+				OAuth2AccessTokenAuthenticationContext.with(this.accessTokenAuthenticationToken)
+						.accessTokenResponse(accessTokenResponseBuilder)
+						.build();
+
+		assertThat(context.<Authentication>getAuthentication()).isEqualTo(this.accessTokenAuthenticationToken);
+		assertThat(context.getAccessTokenResponse()).isEqualTo(accessTokenResponseBuilder);
+	}
+}

+ 173 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AccessTokenResponseAuthenticationSuccessHandlerTests.java

@@ -0,0 +1,173 @@
+/*
+ * Copyright 2020-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.web.authentication;
+
+import org.junit.jupiter.api.Test;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.mock.http.client.MockClientHttpResponse;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
+import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
+import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationContext;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.Assertions.within;
+
+/**
+ * Tests for {@link OAuth2AccessTokenResponseAuthenticationSuccessHandler}.
+ *
+ * @author Dmitriy Dubson
+ */
+public class OAuth2AccessTokenResponseAuthenticationSuccessHandlerTests {
+	private final RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+
+	private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
+			new OAuth2AccessTokenResponseHttpMessageConverter();
+
+	private final OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+			this.registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, this.registeredClient.getClientSecret());
+
+	private final OAuth2AccessTokenResponseAuthenticationSuccessHandler authenticationSuccessHandler = new OAuth2AccessTokenResponseAuthenticationSuccessHandler();
+
+	@Test
+	public void setAccessTokenResponseCustomizerWhenNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> this.authenticationSuccessHandler.setAccessTokenResponseCustomizer(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("accessTokenResponseCustomizer cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	public void onAuthenticationSuccessWhenProvidedRequestResponseAndAuthThenWritesAccessTokenToHttpResponse() throws Exception {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plusSeconds(300);
+		OAuth2Authorization testAuthorization = TestOAuth2Authorizations.authorization(this.registeredClient).build();
+		Map<String, Object> additionalParameters = Collections.singletonMap("param1", "value1");
+		Authentication authentication = new OAuth2AccessTokenAuthenticationToken(this.registeredClient, clientPrincipal,
+				testAuthorization.getAccessToken().getToken(), testAuthorization.getRefreshToken().getToken(),
+				additionalParameters);
+
+		this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authentication);
+
+		OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
+		assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token");
+		assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER);
+		assertThat(accessTokenResponse.getAccessToken().getIssuedAt()).isCloseTo(issuedAt, within(2, ChronoUnit.SECONDS));
+		assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isCloseTo(expiresAt, within(2, ChronoUnit.SECONDS));
+		assertThat(accessTokenResponse.getRefreshToken()).isNotNull();
+		assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo("refresh-token");
+		assertThat(accessTokenResponse.getAdditionalParameters()).containsExactlyInAnyOrderEntriesOf(
+				Map.of("param1", "value1")
+		);
+	}
+
+	@Test
+	public void onAuthenticationSuccessWhenAuthenticationIsNotInstanceOfOAuth2AccessTokenAuthenticationTokenThenThrowOAuth2AuthenticationException() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+
+		assertThatThrownBy(() ->
+				this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, Set.of(), Map.of())))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
+	}
+
+	@Test
+	public void onAuthenticationSuccessWhenAccessTokenResponseIsCustomizedViaAccessTokenResponseCustomizerThenResponseHasCustomizedFields() throws Exception {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		OAuth2AuthorizationService authorizationService = new InMemoryOAuth2AuthorizationService();
+		OAuth2Authorization testAuthorization = TestOAuth2Authorizations.authorization(this.registeredClient).build();
+		authorizationService.save(testAuthorization);
+
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plusSeconds(300);
+		OAuth2AccessToken accessToken = testAuthorization.getAccessToken().getToken();
+		OAuth2RefreshToken refreshToken = testAuthorization.getRefreshToken().getToken();
+		Map<String, Object> additionalParameters = Collections.singletonMap("param1", "value1");
+		Authentication authentication = new OAuth2AccessTokenAuthenticationToken(this.registeredClient, clientPrincipal, accessToken, refreshToken, additionalParameters);
+
+		Consumer<OAuth2AccessTokenAuthenticationContext> accessTokenResponseCustomizer = (OAuth2AccessTokenAuthenticationContext authenticationContext) -> {
+			OAuth2AccessTokenAuthenticationToken authenticationToken = authenticationContext.getAuthentication();
+			OAuth2AccessTokenResponse.Builder accessTokenResponse = authenticationContext.getAccessTokenResponse();
+			OAuth2Authorization authorization = authorizationService.findByToken(
+					authenticationToken.getAccessToken().getTokenValue(),
+					OAuth2TokenType.ACCESS_TOKEN
+			);
+			Map<String, Object> customParams = Map.of(
+					"authorization_id", authorization.getId(),
+					"registered_client_id", authorization.getRegisteredClientId()
+			);
+			Map<String, Object> allParams = new HashMap<>(authenticationToken.getAdditionalParameters());
+			allParams.putAll(customParams);
+			accessTokenResponse.additionalParameters(allParams);
+		};
+
+		this.authenticationSuccessHandler.setAccessTokenResponseCustomizer(accessTokenResponseCustomizer);
+		this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authentication);
+
+		OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
+		assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token");
+		assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER);
+		assertThat(accessTokenResponse.getAccessToken().getIssuedAt()).isCloseTo(issuedAt, within(2, ChronoUnit.SECONDS));
+		assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isCloseTo(expiresAt, within(2, ChronoUnit.SECONDS));
+		assertThat(accessTokenResponse.getRefreshToken()).isNotNull();
+		assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo("refresh-token");
+		assertThat(accessTokenResponse.getAdditionalParameters()).containsExactlyInAnyOrderEntriesOf(
+				Map.of("param1", "value1", "authorization_id", "id", "registered_client_id", "registration-1")
+		);
+	}
+
+	private OAuth2AccessTokenResponse readAccessTokenResponse(MockHttpServletResponse response) throws Exception {
+		MockClientHttpResponse httpResponse = new MockClientHttpResponse(
+				response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));
+		return this.accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
+	}
+
+}