소스 검색

Improve customizing OIDC UserInfo endpoint

Closes gh-785
Daniel Garnier-Moiroux 2 년 전
부모
커밋
8d7f8b3420

+ 18 - 2
docs/src/docs/asciidoc/protocol-endpoints.adoc

@@ -285,21 +285,37 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h
 		.oidc(oidc ->
 			oidc
 				.userInfoEndpoint(userInfoEndpoint ->
-					userInfoEndpoint.userInfoMapper(userInfoMapper)   <1>
+					userInfoEndpoint
+						.userInfoRequestConverter(userInfoRequestConverter) <1>
+						.userInfoRequestConverters(userInfoRequestConvertersConsumer) <2>
+						.authenticationProvider(authenticationProvider) <3>
+						.authenticationProviders(authenticationProvidersConsumer) <4>
+						.userInfoResponseHandler(userInfoResponseHandler) <5>
+						.errorResponseHandler(errorResponseHandler) <6>
+						.userInfoMapper(userInfoMapper) <7>
 				)
 		);
 
 	return http.build();
 }
 ----
-<1> `userInfoMapper()`: The `Function` used to extract claims from `OidcUserInfoAuthenticationContext` to an instance of `OidcUserInfo`.
+<1> `userInfoRequestConverter()`: Adds an `AuthenticationConverter` (_pre-processor_) used when attempting to extract an https://openid.net/specs/openid-connect-core-1_0.html#UserInfoRequest[UserInfo request] from `HttpServletRequest` to an instance of `OidcUserInfoAuthenticationToken`.
+<2> `userInfoRequestConverters()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationConverter``'s allowing the ability to add, remove, or customize a specific `AuthenticationConverter`.
+<3> `authenticationProvider()`: Adds an `AuthenticationProvider` (_main processor_) used for authenticating the `OidcUserInfoAuthenticationToken`.
+<4> `authenticationProviders()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationProvider``'s allowing the ability to add, remove, or customize a specific `AuthenticationProvider`.
+<5> `revocationResponseHandler()`: The `AuthenticationSuccessHandler` (_post-processor_) used for handling an "`authenticated`" `OidcUserInfoAuthenticationToken` and returning the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse[UserInfo response].
+<6> `userInfoResponseHandler()`: The `AuthenticationFailureHandler` (_post-processor_) used for handling an `OAuth2AuthenticationException` and returning the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoError[UserInfo Error response].
+<7> `userInfoMapper()`: The `Function` used to extract claims from `OidcUserInfoAuthenticationContext` to an instance of `OidcUserInfo`.
 
 `OidcUserInfoEndpointConfigurer` configures the `OidcUserInfoEndpointFilter` and registers it with the OAuth2 authorization server `SecurityFilterChain` `@Bean`.
 `OidcUserInfoEndpointFilter` is the `Filter` that processes https://openid.net/specs/openid-connect-core-1_0.html#UserInfoRequest[UserInfo requests] and returns the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse[OidcUserInfo response].
 
 `OidcUserInfoEndpointFilter` is configured with the following defaults:
 
+* `*AuthenticationConverter*` -- An internal implementation that obtains the `Authentication` from the `SecurityContext` and wraps the principal in an `OidcUserInfoAuthenticationToken`.
 * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OidcUserInfoAuthenticationProvider`, which is associated with an internal implementation of `userInfoMapper` that extracts https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims[standard claims] from the https://openid.net/specs/openid-connect-core-1_0.html#IDToken[ID Token] based on the https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims[scopes requested] during authorization.
+* `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OidcUserInfoAuthenticationToken` and returns the UserInfo response.
+* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response.
 
 [TIP]
 You can customize the ID Token by providing an xref:core-model-components.adoc#oauth2-token-customizer[`OAuth2TokenCustomizer<JwtEncodingContext>`] `@Bean`.

+ 154 - 8
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoEndpointConfigurer.java

@@ -15,13 +15,23 @@
  */
 package org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Consumer;
 import java.util.function.Function;
 
+import javax.servlet.http.HttpServletRequest;
+
 import org.springframework.http.HttpMethod;
 import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.config.annotation.ObjectPostProcessor;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext;
@@ -29,21 +39,33 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.web.OidcUserInfoEndpointFilter;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
+import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter;
 import org.springframework.security.web.access.intercept.FilterSecurityInterceptor;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
+import org.springframework.util.Assert;
 
 /**
  * Configurer for OpenID Connect 1.0 UserInfo Endpoint.
  *
  * @author Steve Riesenberg
+ * @author Daniel Garnier-Moiroux
  * @since 0.2.1
  * @see OidcConfigurer#userInfoEndpoint
  * @see OidcUserInfoEndpointFilter
  */
 public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configurer {
 	private RequestMatcher requestMatcher;
+	private final List<AuthenticationConverter> userInfoRequestConverters = new ArrayList<>();
+	private Consumer<List<AuthenticationConverter>> userInfoRequestConvertersConsumer = (authenticationConverters) -> {};
+	private final List<AuthenticationProvider> authenticationProviders = new ArrayList<>();
+	private Consumer<List<AuthenticationProvider>> authenticationProvidersConsumer = (authenticationProviders) -> {};
+	private AuthenticationSuccessHandler userInfoResponseHandler;
+	private AuthenticationFailureHandler errorResponseHandler;
 	private Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper;
 
 	/**
@@ -53,6 +75,91 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur
 		super(objectPostProcessor);
 	}
 
+	/**
+	 * Sets the {@link AuthenticationConverter} used when attempting to extract the OAuth2 Access Token from {@link HttpServletRequest}
+	 * to an instance of {@link OidcUserInfoAuthenticationToken} used for authenticating the User Info request.
+	 *
+	 * @param userInfoRequestConverter the {@link AuthenticationConverter} used when attempting to extract an OIDC User Info from {@link HttpServletRequest}
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer userInfoRequestConverter(AuthenticationConverter userInfoRequestConverter) {
+		Assert.notNull(userInfoRequestConverter, "userInfoRequestConverter cannot be null");
+		this.userInfoRequestConverters.add(userInfoRequestConverter);
+		return this;
+	}
+
+	/**
+	 * Sets the {@code Consumer} providing access to the {@code List} of default
+	 * and (optionally) added {@link #userInfoRequestConverter(AuthenticationConverter) AuthenticationConverter}'s
+	 * allowing the ability to add, remove, or customize a specific {@link AuthenticationConverter}.
+	 *
+	 * @param userInfoRequestConvertersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationConverter}'s
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer userInfoRequestConverters(
+			Consumer<List<AuthenticationConverter>> userInfoRequestConvertersConsumer) {
+		Assert.notNull(userInfoRequestConvertersConsumer, "userInfoRequestConvertersConsumer cannot be null");
+		this.userInfoRequestConvertersConsumer = userInfoRequestConvertersConsumer;
+		return this;
+	}
+
+	/**
+	 * Adds an {@link AuthenticationProvider} used for authenticating a type of {@link OidcUserInfoAuthenticationToken}.
+	 *
+	 * @param authenticationProvider a {@link AuthenticationProvider} used for authenticating a type of {@link OidcUserInfoAuthenticationToken}
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer authenticationProvider(AuthenticationProvider authenticationProvider) {
+		Assert.notNull(authenticationProvider, "authenticationProvider cannot be null");
+		this.authenticationProviders.add(authenticationProvider);
+		return this;
+	}
+
+	/**
+	 * Sets the {@code Consumer} providing access to the {@code List} of default
+	 * and (optionally) added {@link #authenticationProvider(AuthenticationProvider) AuthenticationProvider}'s
+	 * allowing the ability to add, remove, or customize a specific {@link AuthenticationProvider}.
+	 *
+	 * @param authenticationProvidersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationProvider}'s
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer authenticationProviders(
+			Consumer<List<AuthenticationProvider>> authenticationProvidersConsumer) {
+		Assert.notNull(authenticationProvidersConsumer, "authenticationProvidersConsumer cannot be null");
+		this.authenticationProvidersConsumer = authenticationProvidersConsumer;
+		return this;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken} and
+	 * returning the {@link OidcUserInfo User Info Response}.
+	 *
+	 * @param userInfoResponseHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken}
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer userInfoResponseHandler(AuthenticationSuccessHandler userInfoResponseHandler) {
+		this.userInfoResponseHandler = userInfoResponseHandler;
+		return this;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} and
+	 * returning the {@link OAuth2Error Error Response}.
+	 *
+	 * @param errorResponseHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer errorResponseHandler(AuthenticationFailureHandler errorResponseHandler) {
+		this.errorResponseHandler = errorResponseHandler;
+		return this;
+	}
+
 	/**
 	 * Sets the {@link Function} used to extract claims from {@link OidcUserInfoAuthenticationContext}
 	 * to an instance of {@link OidcUserInfo} for the UserInfo response.
@@ -67,9 +174,9 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur
 	 * </ul>
 	 *
 	 * @param userInfoMapper the {@link Function} used to extract claims from {@link OidcUserInfoAuthenticationContext} to an instance of {@link OidcUserInfo}
-	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
 	 */
-	public OidcUserInfoEndpointConfigurer userInfoMapper(Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper) {
+	public OidcUserInfoEndpointConfigurer userInfoMapper(
+			Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper) {
 		this.userInfoMapper = userInfoMapper;
 		return this;
 	}
@@ -82,13 +189,15 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur
 				new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.GET.name()),
 				new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.POST.name()));
 
-		OidcUserInfoAuthenticationProvider oidcUserInfoAuthenticationProvider =
-				new OidcUserInfoAuthenticationProvider(
-						OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity));
-		if (this.userInfoMapper != null) {
-			oidcUserInfoAuthenticationProvider.setUserInfoMapper(this.userInfoMapper);
+		List<AuthenticationProvider> authenticationProviders = createDefaultAuthenticationProviders(httpSecurity);
+
+		if (!this.authenticationProviders.isEmpty()) {
+			authenticationProviders.addAll(0, this.authenticationProviders);
 		}
-		httpSecurity.authenticationProvider(postProcess(oidcUserInfoAuthenticationProvider));
+		this.authenticationProvidersConsumer.accept(authenticationProviders);
+
+		authenticationProviders.forEach(authenticationProvider ->
+				httpSecurity.authenticationProvider(postProcess(authenticationProvider)));
 	}
 
 	@Override
@@ -100,6 +209,19 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur
 				new OidcUserInfoEndpointFilter(
 						authenticationManager,
 						authorizationServerSettings.getOidcUserInfoEndpoint());
+		List<AuthenticationConverter> authenticationConverters = createDefaultAuthenticationConverters();
+		if (!this.userInfoRequestConverters.isEmpty()) {
+			authenticationConverters.addAll(0, this.userInfoRequestConverters);
+		}
+		this.userInfoRequestConvertersConsumer.accept(authenticationConverters);
+		oidcUserInfoEndpointFilter.setAuthenticationConverter(
+				new DelegatingAuthenticationConverter(authenticationConverters));
+		if (this.userInfoResponseHandler != null) {
+			oidcUserInfoEndpointFilter.setAuthenticationSuccessHandler(this.userInfoResponseHandler);
+		}
+		if (this.errorResponseHandler != null) {
+			oidcUserInfoEndpointFilter.setAuthenticationFailureHandler(this.errorResponseHandler);
+		}
 		httpSecurity.addFilterAfter(postProcess(oidcUserInfoEndpointFilter), FilterSecurityInterceptor.class);
 	}
 
@@ -108,4 +230,28 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur
 		return this.requestMatcher;
 	}
 
+	private static List<AuthenticationConverter> createDefaultAuthenticationConverters() {
+		List<AuthenticationConverter> authenticationConverters = new ArrayList<>();
+		authenticationConverters.add(
+				(request) -> {
+					Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+					return new OidcUserInfoAuthenticationToken(authentication);
+				}
+		);
+		return authenticationConverters;
+	}
+
+	private List<AuthenticationProvider> createDefaultAuthenticationProviders(HttpSecurity httpSecurity) {
+		List<AuthenticationProvider> authenticationProviders = new ArrayList<>();
+
+		OidcUserInfoAuthenticationProvider oidcUserInfoAuthenticationProvider = new OidcUserInfoAuthenticationProvider(
+				OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity));
+		if (this.userInfoMapper != null) {
+			oidcUserInfoAuthenticationProvider.setUserInfoMapper(this.userInfoMapper);
+		}
+		authenticationProviders.add(oidcUserInfoAuthenticationProvider);
+
+		return authenticationProviders;
+	}
+
 }

+ 63 - 10
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java

@@ -28,6 +28,7 @@ import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -36,6 +37,9 @@ import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMe
 import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcUserInfoHttpMessageConverter;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -61,11 +65,16 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {
 	private final AuthenticationManager authenticationManager;
 	private final RequestMatcher userInfoEndpointMatcher;
 
+	private AuthenticationConverter authenticationConverter = this::createAuthentication;
+
 	private final HttpMessageConverter<OidcUserInfo> userInfoHttpMessageConverter =
 			new OidcUserInfoHttpMessageConverter();
 	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
 			new OAuth2ErrorHttpMessageConverter();
 
+	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendUserInfoResponse;
+	private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
+
 	/**
 	 * Constructs an {@code OidcUserInfoEndpointFilter} using the provided parameters.
 	 *
@@ -100,34 +109,77 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {
 		}
 
 		try {
-			Authentication principal = SecurityContextHolder.getContext().getAuthentication();
-
-			OidcUserInfoAuthenticationToken userInfoAuthentication = new OidcUserInfoAuthenticationToken(principal);
+			Authentication userInfoAuthentication = this.authenticationConverter.convert(request);
 
 			OidcUserInfoAuthenticationToken userInfoAuthenticationResult =
 					(OidcUserInfoAuthenticationToken) this.authenticationManager.authenticate(userInfoAuthentication);
 
-			sendUserInfoResponse(response, userInfoAuthenticationResult.getUserInfo());
-
+			this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, userInfoAuthenticationResult);
 		} catch (OAuth2AuthenticationException ex) {
-			sendErrorResponse(response, ex.getError());
+			this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
 		} catch (Exception ex) {
 			OAuth2Error error = new OAuth2Error(
 					OAuth2ErrorCodes.INVALID_REQUEST,
 					"OpenID Connect 1.0 UserInfo Error: " + ex.getMessage(),
 					"https://openid.net/specs/openid-connect-core-1_0.html#UserInfoError");
-			sendErrorResponse(response, error);
+			this.authenticationFailureHandler.onAuthenticationFailure(request, response,
+					new OAuth2AuthenticationException(error));
 		} finally {
 			SecurityContextHolder.clearContext();
 		}
 	}
 
-	private void sendUserInfoResponse(HttpServletResponse response, OidcUserInfo userInfo) throws IOException {
+	/**
+	 * Sets the {@link AuthenticationConverter} used when attempting to extract the OAuth2 Access Token from {@link HttpServletRequest}
+	 * to an instance of {@link OidcUserInfoAuthenticationToken} used for authenticating the User Info request.
+	 *
+	 * @param authenticationConverter the {@link AuthenticationConverter} used when attempting to extract an OIDC User Info from {@link HttpServletRequest}
+	 * @since 0.4.0
+	 */
+	public void setAuthenticationConverter(AuthenticationConverter authenticationConverter) {
+		Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
+		this.authenticationConverter = authenticationConverter;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken} and
+	 * returning the {@link OidcUserInfo OIDC User Info}.
+	 *
+	 * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} for handling an {@link OidcUserInfoAuthenticationToken}
+	 * @since 0.4.0
+	 */
+	public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
+		Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
+		this.authenticationSuccessHandler = authenticationSuccessHandler;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * and returning the {@link OAuth2Error Error Response}.
+	 *
+	 * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * @since 0.4.0
+	 */
+	public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
+		Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
+		this.authenticationFailureHandler = authenticationFailureHandler;
+	}
+
+	private Authentication createAuthentication(HttpServletRequest request) {
+		Authentication principal = SecurityContextHolder.getContext().getAuthentication();
+		return new OidcUserInfoAuthenticationToken(principal);
+	}
+
+	private void sendUserInfoResponse(HttpServletRequest request, HttpServletResponse response,
+			Authentication authentication) throws IOException {
+		OidcUserInfoAuthenticationToken userInfoAuthenticationToken = (OidcUserInfoAuthenticationToken) authentication;
 		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
-		this.userInfoHttpMessageConverter.write(userInfo, null, httpResponse);
+		this.userInfoHttpMessageConverter.write(userInfoAuthenticationToken.getUserInfo(), null, httpResponse);
 	}
 
-	private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException {
+	private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
+			AuthenticationException authenticationException) throws IOException {
+		OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
 		HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
 		if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) {
 			httpStatus = HttpStatus.UNAUTHORIZED;
@@ -138,4 +190,5 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {
 		httpResponse.setStatusCode(httpStatus);
 		this.errorHttpResponseConverter.write(error, null, httpResponse);
 	}
+
 }

+ 128 - 11
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoTests.java

@@ -19,9 +19,13 @@ import java.time.Instant;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
+import java.util.function.Consumer;
 import java.util.function.Function;
 
+import javax.servlet.http.HttpServletResponse;
+
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
 import com.nimbusds.jose.jwk.source.JWKSource;
@@ -30,10 +34,13 @@ import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
+import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.config.Customizer;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
@@ -61,11 +68,15 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext;
+import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
 import org.springframework.security.oauth2.server.authorization.test.SpringTestRule;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
 import org.springframework.security.web.SecurityFilterChain;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -74,8 +85,15 @@ import org.springframework.test.web.servlet.MvcResult;
 import org.springframework.test.web.servlet.ResultMatcher;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
@@ -99,17 +117,48 @@ public class OidcUserInfoTests {
 	@Autowired
 	private JwtEncoder jwtEncoder;
 
+	@Autowired
+	private JwtDecoder jwtDecoder;
+
 	@Autowired
 	private OAuth2AuthorizationService authorizationService;
 
+	private static AuthenticationConverter authenticationConverter;
+
+	private static Consumer<List<AuthenticationConverter>> authenticationConvertersConsumer;
+
+	private static AuthenticationProvider authenticationProvider;
+
+	private static Consumer<List<AuthenticationProvider>> authenticationProvidersConsumer;
+
+	private static AuthenticationSuccessHandler authenticationSuccessHandler;
+
+	private static AuthenticationFailureHandler authenticationFailureHandler;
+
+	private static Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper;
+
 	@BeforeClass
 	public static void init() {
 		securityContextRepository = spy(new HttpSessionSecurityContextRepository());
+		authenticationConverter = mock(AuthenticationConverter.class);
+		authenticationConvertersConsumer = mock(Consumer.class);
+		authenticationProvider = mock(AuthenticationProvider.class);
+		authenticationProvidersConsumer = mock(Consumer.class);
+		authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class);
+		authenticationFailureHandler = mock(AuthenticationFailureHandler.class);
+		userInfoMapper = mock(Function.class);
 	}
 
 	@Before
 	public void setup() {
 		reset(securityContextRepository);
+		reset(authenticationConverter);
+		reset(authenticationConvertersConsumer);
+		reset(authenticationProvider);
+		reset(authenticationProvidersConsumer);
+		reset(authenticationSuccessHandler);
+		reset(authenticationFailureHandler);
+		reset(userInfoMapper);
 	}
 
 	@Test
@@ -145,19 +194,89 @@ public class OidcUserInfoTests {
 	}
 
 	@Test
-	public void requestWhenSignedJwtAndCustomUserInfoMapperThenMapJwtClaimsToUserInfoResponse() throws Exception {
+	public void requestWhenUserInfoEndpointCustomizedThenUsed() throws Exception {
 		this.spring.register(CustomUserInfoConfiguration.class).autowire();
 
 		OAuth2Authorization authorization = createAuthorization();
 		this.authorizationService.save(authorization);
 
+		when(userInfoMapper.apply(any())).thenReturn(createUserInfo());
+
 		OAuth2AccessToken accessToken = authorization.getAccessToken().getToken();
 		// @formatter:off
 		this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
 				.header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
-				.andExpect(status().is2xxSuccessful())
-				.andExpectAll(userInfoResponse());
+				.andExpect(status().is2xxSuccessful());
 		// @formatter:on
+		verify(userInfoMapper).apply(any());
+		verify(authenticationConverter).convert(any());
+		verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any());
+		verifyNoInteractions(authenticationFailureHandler);
+
+		ArgumentCaptor<List<AuthenticationProvider>> authenticationProvidersCaptor = ArgumentCaptor.forClass(List.class);
+		verify(authenticationProvidersConsumer).accept(authenticationProvidersCaptor.capture());
+		List<AuthenticationProvider> authenticationProviders = authenticationProvidersCaptor.getValue();
+		assertThat(authenticationProviders).hasSize(2).allMatch(provider ->
+				provider == authenticationProvider ||
+						provider instanceof OidcUserInfoAuthenticationProvider
+				);
+
+		ArgumentCaptor<List<AuthenticationConverter>> authenticationConvertersCaptor = ArgumentCaptor.forClass(List.class);
+		verify(authenticationConvertersConsumer).accept(authenticationConvertersCaptor.capture());
+		List<AuthenticationConverter> authenticationConverters = authenticationConvertersCaptor.getValue();
+		assertThat(authenticationConverters).hasSize(2).allMatch(AuthenticationConverter.class::isInstance);
+	}
+
+	@Test
+	public void requestWhenUserInfoEndpointCustomizedThenAuthenticationProviderUsed() throws Exception {
+		this.spring.register(CustomUserInfoConfiguration.class).autowire();
+
+		OAuth2Authorization authorization = createAuthorization();
+		this.authorizationService.save(authorization);
+
+		when(authenticationProvider.supports(eq(OidcUserInfoAuthenticationToken.class))).thenReturn(true);
+		String tokenValue = authorization.getAccessToken().getToken().getTokenValue();
+		Jwt jwt = this.jwtDecoder.decode(tokenValue);
+		OidcUserInfoAuthenticationToken oidcUserInfoAuthentication = new OidcUserInfoAuthenticationToken(
+				new JwtAuthenticationToken(jwt), createUserInfo());
+		when(authenticationProvider.authenticate(any())).thenReturn(oidcUserInfoAuthentication);
+
+		OAuth2AccessToken accessToken = authorization.getAccessToken().getToken();
+		// @formatter:off
+		this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
+						.header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
+				.andExpect(status().is2xxSuccessful());
+		// @formatter:on
+		verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any());
+		verify(authenticationProvider).authenticate(any());
+		verifyNoInteractions(authenticationFailureHandler);
+		verifyNoInteractions(userInfoMapper);
+	}
+
+	@Test
+	public void requestWhenUserInfoEndpointCustomizedAndErrorThenUsed() throws Exception {
+		this.spring.register(CustomUserInfoConfiguration.class).autowire();
+		when(userInfoMapper.apply(any())).thenReturn(createUserInfo());
+		doAnswer(
+				invocation -> {
+					HttpServletResponse response = invocation.getArgument(1);
+					response.setStatus(HttpStatus.UNAUTHORIZED.value());
+					response.getWriter().write("unauthorized");
+					return null;
+				}
+		).when(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any());
+
+		OAuth2AccessToken accessToken = createAuthorization().getAccessToken().getToken();
+
+
+		// @formatter:off
+		this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
+				.header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
+				.andExpect(status().is4xxClientError());
+		// @formatter:on
+		verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any());
+		verifyNoInteractions(authenticationSuccessHandler);
+		verifyNoInteractions(userInfoMapper);
 	}
 
 	// gh-482
@@ -271,14 +390,6 @@ public class OidcUserInfoTests {
 			RequestMatcher endpointsMatcher = authorizationServerConfigurer
 					.getEndpointsMatcher();
 
-			// Custom User Info Mapper that retrieves claims from a signed JWT
-			Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper = context -> {
-				OidcUserInfoAuthenticationToken authentication = context.getAuthentication();
-				JwtAuthenticationToken principal = (JwtAuthenticationToken) authentication.getPrincipal();
-
-				return new OidcUserInfo(principal.getToken().getClaims());
-			};
-
 			// @formatter:off
 			http
 				.requestMatcher(endpointsMatcher)
@@ -290,6 +401,12 @@ public class OidcUserInfoTests {
 				.apply(authorizationServerConfigurer)
 					.oidc(oidc -> oidc
 						.userInfoEndpoint(userInfo -> userInfo
+							.userInfoRequestConverter(authenticationConverter)
+							.userInfoRequestConverters(authenticationConvertersConsumer)
+							.authenticationProvider(authenticationProvider)
+							.authenticationProviders(authenticationProvidersConsumer)
+							.userInfoResponseHandler(authenticationSuccessHandler)
+							.errorResponseHandler(authenticationFailureHandler)
 							.userInfoMapper(userInfoMapper)
 						)
 					);

+ 110 - 3
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java

@@ -44,6 +44,9 @@ import org.springframework.security.oauth2.jwt.JoseHeaderNames;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -84,6 +87,27 @@ public class OidcUserInfoEndpointFilterTests {
 				.withMessage("userInfoEndpointUri cannot be empty");
 	}
 
+	@Test
+	public void setAuthenticationConverterNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.filter.setAuthenticationConverter(null))
+				.withMessage("authenticationConverter cannot be null");
+	}
+
+	@Test
+	public void setAuthenticationSuccessHandlerNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null))
+				.withMessage("authenticationSuccessHandler cannot be null");
+	}
+
+	@Test
+	public void setAuthenticationFailureHandlerNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null))
+				.withMessage("authenticationFailureHandler cannot be null");
+	}
+
 	@Test
 	public void doFilterWhenNotUserInfoRequestThenNotProcessed() throws Exception {
 		String requestUri = "/path";
@@ -145,11 +169,21 @@ public class OidcUserInfoEndpointFilterTests {
 
 	@Test
 	public void doFilterWhenUserInfoRequestInvalidTokenThenUnauthorizedError() throws Exception {
+		doFilterWhenAuthenticationExceptionThenError(OAuth2ErrorCodes.INVALID_TOKEN, HttpStatus.UNAUTHORIZED);
+	}
+
+	@Test
+	public void doFilterWhenUserInfoRequestInsufficientScopeThenForbiddenError() throws Exception {
+		doFilterWhenAuthenticationExceptionThenError(OAuth2ErrorCodes.INSUFFICIENT_SCOPE, HttpStatus.FORBIDDEN);
+	}
+
+	private void doFilterWhenAuthenticationExceptionThenError(String oauth2ErrorCode, HttpStatus httpStatus)
+			throws Exception {
 		Authentication principal = new TestingAuthenticationToken("principal", "credentials");
 		SecurityContextHolder.getContext().setAuthentication(principal);
 
 		when(this.authenticationManager.authenticate(any()))
-				.thenThrow(new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN));
+				.thenThrow(new OAuth2AuthenticationException(oauth2ErrorCode));
 
 		String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
@@ -161,9 +195,82 @@ public class OidcUserInfoEndpointFilterTests {
 
 		verifyNoInteractions(filterChain);
 
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
+		assertThat(response.getStatus()).isEqualTo(httpStatus.value());
 		OAuth2Error error = readError(response);
-		assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN);
+		assertThat(error.getErrorCode()).isEqualTo(oauth2ErrorCode);
+	}
+
+	@Test
+	public void doFilterWhenCustomAuthenticationConverterThenUses() throws Exception {
+		Authentication principal = new TestingAuthenticationToken("principal", "credentials");
+		OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal);
+		AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class);
+		this.filter.setAuthenticationConverter(authenticationConverter);
+
+		when(authenticationConverter.convert(any())).thenReturn(authentication);
+		when(this.authenticationManager.authenticate(any())).thenReturn(
+				new OidcUserInfoAuthenticationToken(principal, createUserInfo())
+		);
+
+		String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(authenticationConverter).convert(request);
+		verify(this.authenticationManager).authenticate(authentication);
+		assertUserInfoResponse(response.getContentAsString());
+	}
+
+	@Test
+	public void doFilterWhenCustomAuthenticationSuccessHandlerThenUses() throws Exception {
+		AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class);
+		this.filter.setAuthenticationSuccessHandler(successHandler);
+
+		Authentication principal = new TestingAuthenticationToken("principal", "credentials");
+		SecurityContextHolder.getContext().setAuthentication(principal);
+
+		OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal, createUserInfo());
+		when(this.authenticationManager.authenticate(any())).thenReturn(authentication);
+
+		String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+		verify(successHandler).onAuthenticationSuccess(request, response, authentication);
+	}
+
+	@Test
+	public void doFilterWhenCustomFailureHandlerThenUses() throws Exception {
+		AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
+		this.filter.setAuthenticationFailureHandler(failureHandler);
+
+		Authentication principal = new TestingAuthenticationToken("principal", "credentials");
+		SecurityContextHolder.getContext().setAuthentication(principal);
+
+		OAuth2AuthenticationException authenticationException =
+				new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
+		when(this.authenticationManager.authenticate(any())).thenThrow(authenticationException);
+
+		String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		verify(failureHandler).onAuthenticationFailure(request, response, authenticationException);
 	}
 
 	private OAuth2Error readError(MockHttpServletResponse response) throws Exception {