Browse Source

Merge branch 0.4.x into main

The following commits are merged using the default merge strategy.

8d7f8b3420b77ee44061591604f3ccc99be90441 Improve customizing OIDC UserInfo endpoint
2ba711c83a08607a19963bc2fa40dbb4afcd8f21 Polish gh-929
efbfdc234c7feb63383c0987be91d9874aa0e7e7 Improve customizing OIDC Client Registration endpoint
bfd7a09c3b4e70c51d10eb9b4b53fb0daf81df2b Polish gh-946
11ce8ef201eda14a8e8e4d271deb07cf7addbebd Polish gh-929
356d669a78ea860f9e018791e21fc600eb30e4b5 Fix URL encoding for authorization request state parameter
4eb25c163f0e26403a9381f23e9863f2317ecd13 Polish gh-920
6dc3944eef3d73d9d76c863114f24ca7d1b1337c Add OidcClientRegistrationAuthenticationProvider.setRegisteredClientConverter()
Joe Grandja 2 years ago
parent
commit
4adc3766ea
16 changed files with 1075 additions and 130 deletions
  1. 41 7
      docs/src/docs/asciidoc/protocol-endpoints.adoc
  2. 162 15
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationEndpointConfigurer.java
  3. 155 7
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoEndpointConfigurer.java
  4. 2 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientConfigurationAuthenticationProvider.java
  5. 16 4
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java
  6. 1 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/RegisteredClientOidcClientRegistrationConverter.java
  7. 64 19
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java
  8. 66 13
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java
  9. 18 4
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java
  10. 1 1
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java
  11. 29 9
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java
  12. 147 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationTests.java
  13. 130 11
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoTests.java
  14. 7 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java
  15. 126 35
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java
  16. 110 3
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java

+ 41 - 7
docs/src/docs/asciidoc/protocol-endpoints.adoc

@@ -269,9 +269,9 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h
 == OpenID Connect 1.0 UserInfo Endpoint
 == OpenID Connect 1.0 UserInfo Endpoint
 
 
 `OidcUserInfoEndpointConfigurer` provides the ability to customize the https://openid.net/specs/openid-connect-core-1_0.html#UserInfo[OpenID Connect 1.0 UserInfo endpoint].
 `OidcUserInfoEndpointConfigurer` provides the ability to customize the https://openid.net/specs/openid-connect-core-1_0.html#UserInfo[OpenID Connect 1.0 UserInfo endpoint].
-It defines extension points that let you customize the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse[UserInfo response].
+It defines extension points that let you customize the pre-processing, main processing, and post-processing logic for https://openid.net/specs/openid-connect-core-1_0.html#UserInfoRequest[UserInfo requests].
 
 
-`OidcUserInfoEndpointConfigurer` provides the following configuration option:
+`OidcUserInfoEndpointConfigurer` provides the following configuration options:
 
 
 [source,java]
 [source,java]
 ----
 ----
@@ -285,21 +285,37 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h
 		.oidc(oidc ->
 		.oidc(oidc ->
 			oidc
 			oidc
 				.userInfoEndpoint(userInfoEndpoint ->
 				.userInfoEndpoint(userInfoEndpoint ->
-					userInfoEndpoint.userInfoMapper(userInfoMapper)   <1>
+					userInfoEndpoint
+						.userInfoRequestConverter(userInfoRequestConverter) <1>
+						.userInfoRequestConverters(userInfoRequestConvertersConsumer) <2>
+						.authenticationProvider(authenticationProvider) <3>
+						.authenticationProviders(authenticationProvidersConsumer) <4>
+						.userInfoResponseHandler(userInfoResponseHandler) <5>
+						.errorResponseHandler(errorResponseHandler) <6>
+						.userInfoMapper(userInfoMapper) <7>
 				)
 				)
 		);
 		);
 
 
 	return http.build();
 	return http.build();
 }
 }
 ----
 ----
-<1> `userInfoMapper()`: The `Function` used to extract claims from `OidcUserInfoAuthenticationContext` to an instance of `OidcUserInfo`.
+<1> `userInfoRequestConverter()`: Adds an `AuthenticationConverter` (_pre-processor_) used when attempting to extract an https://openid.net/specs/openid-connect-core-1_0.html#UserInfoRequest[UserInfo request] from `HttpServletRequest` to an instance of `OidcUserInfoAuthenticationToken`.
+<2> `userInfoRequestConverters()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationConverter``'s allowing the ability to add, remove, or customize a specific `AuthenticationConverter`.
+<3> `authenticationProvider()`: Adds an `AuthenticationProvider` (_main processor_) used for authenticating the `OidcUserInfoAuthenticationToken`.
+<4> `authenticationProviders()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationProvider``'s allowing the ability to add, remove, or customize a specific `AuthenticationProvider`.
+<5> `userInfoResponseHandler()`: The `AuthenticationSuccessHandler` (_post-processor_) used for handling an "`authenticated`" `OidcUserInfoAuthenticationToken` and returning the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse[UserInfo response].
+<6> `errorResponseHandler()`: The `AuthenticationFailureHandler` (_post-processor_) used for handling an `OAuth2AuthenticationException` and returning the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoError[UserInfo Error response].
+<7> `userInfoMapper()`: The `Function` used to extract claims from `OidcUserInfoAuthenticationContext` to an instance of `OidcUserInfo`.
 
 
 `OidcUserInfoEndpointConfigurer` configures the `OidcUserInfoEndpointFilter` and registers it with the OAuth2 authorization server `SecurityFilterChain` `@Bean`.
 `OidcUserInfoEndpointConfigurer` configures the `OidcUserInfoEndpointFilter` and registers it with the OAuth2 authorization server `SecurityFilterChain` `@Bean`.
 `OidcUserInfoEndpointFilter` is the `Filter` that processes https://openid.net/specs/openid-connect-core-1_0.html#UserInfoRequest[UserInfo requests] and returns the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse[OidcUserInfo response].
 `OidcUserInfoEndpointFilter` is the `Filter` that processes https://openid.net/specs/openid-connect-core-1_0.html#UserInfoRequest[UserInfo requests] and returns the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse[OidcUserInfo response].
 
 
 `OidcUserInfoEndpointFilter` is configured with the following defaults:
 `OidcUserInfoEndpointFilter` is configured with the following defaults:
 
 
+* `*AuthenticationConverter*` -- An internal implementation that obtains the `Authentication` from the `SecurityContext` and creates an `OidcUserInfoAuthenticationToken` with the principal.
 * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OidcUserInfoAuthenticationProvider`, which is associated with an internal implementation of `userInfoMapper` that extracts https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims[standard claims] from the https://openid.net/specs/openid-connect-core-1_0.html#IDToken[ID Token] based on the https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims[scopes requested] during authorization.
 * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OidcUserInfoAuthenticationProvider`, which is associated with an internal implementation of `userInfoMapper` that extracts https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims[standard claims] from the https://openid.net/specs/openid-connect-core-1_0.html#IDToken[ID Token] based on the https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims[scopes requested] during authorization.
+* `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OidcUserInfoAuthenticationToken` and returns the `OidcUserInfo` response.
+* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response.
 
 
 [TIP]
 [TIP]
 You can customize the ID Token by providing an xref:core-model-components.adoc#oauth2-token-customizer[`OAuth2TokenCustomizer<JwtEncodingContext>`] `@Bean`.
 You can customize the ID Token by providing an xref:core-model-components.adoc#oauth2-token-customizer[`OAuth2TokenCustomizer<JwtEncodingContext>`] `@Bean`.
@@ -337,8 +353,10 @@ The guide xref:guides/how-to-userinfo.adoc#how-to-userinfo[How-to: Customize the
 [[oidc-client-registration-endpoint]]
 [[oidc-client-registration-endpoint]]
 == OpenID Connect 1.0 Client Registration Endpoint
 == OpenID Connect 1.0 Client Registration Endpoint
 
 
-`OidcClientRegistrationEndpointConfigurer` configures the https://openid.net/specs/openid-connect-registration-1_0.html#ClientRegistration[OpenID Connect 1.0 Client Registration endpoint].
-The following example shows how to enable (disabled by default) the OpenID Connect 1.0 Client Registration endpoint:
+`OidcClientRegistrationEndpointConfigurer` provides the ability to customize the https://openid.net/specs/openid-connect-registration-1_0.html#ClientRegistration[OpenID Connect 1.0 Client Registration endpoint].
+It defines extension points that let you customize the pre-processing, main processing, and post-processing logic for https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationRequest[Client Registration requests] or https://openid.net/specs/openid-connect-registration-1_0.html#ReadRequest[Client Read requests].
+
+`OidcClientRegistrationEndpointConfigurer` provides the following configuration options:
 
 
 [source,java]
 [source,java]
 ----
 ----
@@ -351,12 +369,26 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h
 	authorizationServerConfigurer
 	authorizationServerConfigurer
 		.oidc(oidc ->
 		.oidc(oidc ->
 			oidc
 			oidc
-				.clientRegistrationEndpoint(Customizer.withDefaults())
+				.clientRegistrationEndpoint(clientRegistrationEndpoint ->
+					clientRegistrationEndpoint
+						.clientRegistrationRequestConverter(clientRegistrationRequestConverter) <1>
+						.clientRegistrationRequestConverters(clientRegistrationRequestConvertersConsumers) <2>
+						.authenticationProvider(authenticationProvider) <3>
+						.authenticationProviders(authenticationProvidersConsumer) <4>
+						.clientRegistrationResponseHandler(clientRegistrationResponseHandler) <5>
+						.errorResponseHandler(errorResponseHandler) <6>
+				)
 		);
 		);
 
 
 	return http.build();
 	return http.build();
 }
 }
 ----
 ----
+<1> `clientRegistrationRequestConverter()`: Adds an `AuthenticationConverter` (_pre-processor_) used when attempting to extract a https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationRequest[Client Registration request] or https://openid.net/specs/openid-connect-registration-1_0.html#ReadRequest[Client Read request] from `HttpServletRequest` to an instance of `OidcClientRegistrationAuthenticationToken`.
+<2> `clientRegistrationRequestConverters()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationConverter``'s allowing the ability to add, remove, or customize a specific `AuthenticationConverter`.
+<3> `authenticationProvider()`: Adds an `AuthenticationProvider` (_main processor_) used for authenticating the `OidcClientRegistrationAuthenticationToken`.
+<4> `authenticationProviders()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationProvider``'s allowing the ability to add, remove, or customize a specific `AuthenticationProvider`.
+<5> `clientRegistrationResponseHandler()`: The `AuthenticationSuccessHandler` (_post-processor_) used for handling an "`authenticated`" `OidcClientRegistrationAuthenticationToken` and returning the https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationResponse[Client Registration response] or https://openid.net/specs/openid-connect-registration-1_0.html#ReadResponse[Client Read response].
+<6> `errorResponseHandler()`: The `AuthenticationFailureHandler` (_post-processor_) used for handling an `OAuth2AuthenticationException` and returning the https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError[Client Registration Error response] or https://openid.net/specs/openid-connect-registration-1_0.html#ReadError[Client Read Error response].
 
 
 [NOTE]
 [NOTE]
 The OpenID Connect 1.0 Client Registration endpoint is disabled by default because many deployments do not require dynamic client registration.
 The OpenID Connect 1.0 Client Registration endpoint is disabled by default because many deployments do not require dynamic client registration.
@@ -371,6 +403,8 @@ The OpenID Connect 1.0 Client Registration endpoint is disabled by default becau
 
 
 * `*AuthenticationConverter*` -- An `OidcClientRegistrationAuthenticationConverter`.
 * `*AuthenticationConverter*` -- An `OidcClientRegistrationAuthenticationConverter`.
 * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OidcClientRegistrationAuthenticationProvider` and `OidcClientConfigurationAuthenticationProvider`.
 * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OidcClientRegistrationAuthenticationProvider` and `OidcClientConfigurationAuthenticationProvider`.
+* `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OidcClientRegistrationAuthenticationToken` and returns the `OidcClientRegistration` response.
+* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response.
 
 
 The OpenID Connect 1.0 Client Registration endpoint is an https://openid.net/specs/openid-connect-registration-1_0.html#ClientRegistration[OAuth2 protected resource], which *REQUIRES* an access token to be sent as a bearer token in the Client Registration (or Client Read) request.
 The OpenID Connect 1.0 Client Registration endpoint is an https://openid.net/specs/openid-connect-registration-1_0.html#ClientRegistration[OAuth2 protected resource], which *REQUIRES* an access token to be sent as a bearer token in the Client Registration (or Client Read) request.
 
 

+ 162 - 15
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationEndpointConfigurer.java

@@ -15,29 +15,53 @@
  */
  */
 package org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers;
 package org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers;
 
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Consumer;
+
+import jakarta.servlet.http.HttpServletRequest;
+
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpMethod;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.config.annotation.ObjectPostProcessor;
 import org.springframework.security.config.annotation.ObjectPostProcessor;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientConfigurationAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientConfigurationAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationProvider;
+import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.web.OidcClientRegistrationEndpointFilter;
 import org.springframework.security.oauth2.server.authorization.oidc.web.OidcClientRegistrationEndpointFilter;
+import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
+import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter;
 import org.springframework.security.web.access.intercept.AuthorizationFilter;
 import org.springframework.security.web.access.intercept.AuthorizationFilter;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
+import org.springframework.util.Assert;
 
 
 /**
 /**
- * Configurer for OpenID Connect Dynamic Client Registration 1.0 Endpoint.
+ * Configurer for OpenID Connect 1.0 Dynamic Client Registration Endpoint.
  *
  *
  * @author Joe Grandja
  * @author Joe Grandja
+ * @author Daniel Garnier-Moiroux
  * @since 0.2.0
  * @since 0.2.0
  * @see OidcConfigurer#clientRegistrationEndpoint
  * @see OidcConfigurer#clientRegistrationEndpoint
  * @see OidcClientRegistrationEndpointFilter
  * @see OidcClientRegistrationEndpointFilter
  */
  */
 public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAuth2Configurer {
 public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAuth2Configurer {
 	private RequestMatcher requestMatcher;
 	private RequestMatcher requestMatcher;
+	private final List<AuthenticationConverter> clientRegistrationRequestConverters = new ArrayList<>();
+	private Consumer<List<AuthenticationConverter>> clientRegistrationRequestConvertersConsumer = (clientRegistrationRequestConverters) -> {};
+	private final List<AuthenticationProvider> authenticationProviders = new ArrayList<>();
+	private Consumer<List<AuthenticationProvider>> authenticationProvidersConsumer = (authenticationProviders) -> {};
+	private AuthenticationSuccessHandler clientRegistrationResponseHandler;
+	private AuthenticationFailureHandler errorResponseHandler;
 
 
 	/**
 	/**
 	 * Restrict for internal use only.
 	 * Restrict for internal use only.
@@ -46,26 +70,108 @@ public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAut
 		super(objectPostProcessor);
 		super(objectPostProcessor);
 	}
 	}
 
 
+	/**
+	 * Adds an {@link AuthenticationConverter} used when attempting to extract a Client Registration Request from {@link HttpServletRequest}
+	 * to an instance of {@link OidcClientRegistrationAuthenticationToken} used for authenticating the request.
+	 *
+	 * @param clientRegistrationRequestConverter an {@link AuthenticationConverter} used when attempting to extract a Client Registration Request from {@link HttpServletRequest}
+	 * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcClientRegistrationEndpointConfigurer clientRegistrationRequestConverter(
+			AuthenticationConverter clientRegistrationRequestConverter) {
+		Assert.notNull(clientRegistrationRequestConverter, "clientRegistrationRequestConverter cannot be null");
+		this.clientRegistrationRequestConverters.add(clientRegistrationRequestConverter);
+		return this;
+	}
+
+	/**
+	 * Sets the {@code Consumer} providing access to the {@code List} of default
+	 * and (optionally) added {@link #clientRegistrationRequestConverter(AuthenticationConverter) AuthenticationConverter}'s
+	 * allowing the ability to add, remove, or customize a specific {@link AuthenticationConverter}.
+	 *
+	 * @param clientRegistrationRequestConvertersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationConverter}'s
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcClientRegistrationEndpointConfigurer clientRegistrationRequestConverters(
+			Consumer<List<AuthenticationConverter>> clientRegistrationRequestConvertersConsumer) {
+		Assert.notNull(clientRegistrationRequestConvertersConsumer, "clientRegistrationRequestConvertersConsumer cannot be null");
+		this.clientRegistrationRequestConvertersConsumer = clientRegistrationRequestConvertersConsumer;
+		return this;
+	}
+
+	/**
+	 * Adds an {@link AuthenticationProvider} used for authenticating an {@link OidcClientRegistrationAuthenticationToken}.
+	 *
+	 * @param authenticationProvider an {@link AuthenticationProvider} used for authenticating an {@link OidcClientRegistrationAuthenticationToken}
+	 * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcClientRegistrationEndpointConfigurer authenticationProvider(AuthenticationProvider authenticationProvider) {
+		Assert.notNull(authenticationProvider, "authenticationProvider cannot be null");
+		this.authenticationProviders.add(authenticationProvider);
+		return this;
+	}
+
+	/**
+	 * Sets the {@code Consumer} providing access to the {@code List} of default
+	 * and (optionally) added {@link #authenticationProvider(AuthenticationProvider) AuthenticationProvider}'s
+	 * allowing the ability to add, remove, or customize a specific {@link AuthenticationProvider}.
+	 *
+	 * @param authenticationProvidersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationProvider}'s
+	 * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcClientRegistrationEndpointConfigurer authenticationProviders(
+			Consumer<List<AuthenticationProvider>> authenticationProvidersConsumer) {
+		Assert.notNull(authenticationProvidersConsumer, "authenticationProvidersConsumer cannot be null");
+		this.authenticationProvidersConsumer = authenticationProvidersConsumer;
+		return this;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken}
+	 * and returning the {@link OidcClientRegistration Client Registration Response}.
+	 *
+	 * @param clientRegistrationResponseHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken}
+	 * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcClientRegistrationEndpointConfigurer clientRegistrationResponseHandler(AuthenticationSuccessHandler clientRegistrationResponseHandler) {
+		this.clientRegistrationResponseHandler = clientRegistrationResponseHandler;
+		return this;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * and returning the {@link OAuth2Error Error Response}.
+	 *
+	 * @param errorResponseHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcClientRegistrationEndpointConfigurer errorResponseHandler(AuthenticationFailureHandler errorResponseHandler) {
+		this.errorResponseHandler = errorResponseHandler;
+		return this;
+	}
+
 	@Override
 	@Override
 	void init(HttpSecurity httpSecurity) {
 	void init(HttpSecurity httpSecurity) {
 		AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity);
 		AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity);
+		String clientRegistrationEndpointUri = authorizationServerSettings.getOidcClientRegistrationEndpoint();
 		this.requestMatcher = new OrRequestMatcher(
 		this.requestMatcher = new OrRequestMatcher(
-				new AntPathRequestMatcher(authorizationServerSettings.getOidcClientRegistrationEndpoint(), HttpMethod.POST.name()),
-				new AntPathRequestMatcher(authorizationServerSettings.getOidcClientRegistrationEndpoint(), HttpMethod.GET.name())
+				new AntPathRequestMatcher(clientRegistrationEndpointUri, HttpMethod.POST.name()),
+				new AntPathRequestMatcher(clientRegistrationEndpointUri, HttpMethod.GET.name())
 		);
 		);
 
 
-		OidcClientRegistrationAuthenticationProvider oidcClientRegistrationAuthenticationProvider =
-				new OidcClientRegistrationAuthenticationProvider(
-						OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity),
-						OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity),
-						OAuth2ConfigurerUtils.getTokenGenerator(httpSecurity));
-		httpSecurity.authenticationProvider(postProcess(oidcClientRegistrationAuthenticationProvider));
-
-		OidcClientConfigurationAuthenticationProvider oidcClientConfigurationAuthenticationProvider =
-				new OidcClientConfigurationAuthenticationProvider(
-						OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity),
-						OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity));
-		httpSecurity.authenticationProvider(postProcess(oidcClientConfigurationAuthenticationProvider));
+		List<AuthenticationProvider> authenticationProviders = createDefaultAuthenticationProviders(httpSecurity);
+		if (!this.authenticationProviders.isEmpty()) {
+			authenticationProviders.addAll(0, this.authenticationProviders);
+		}
+		this.authenticationProvidersConsumer.accept(authenticationProviders);
+		authenticationProviders.forEach(authenticationProvider ->
+				httpSecurity.authenticationProvider(postProcess(authenticationProvider)));
 	}
 	}
 
 
 	@Override
 	@Override
@@ -77,6 +183,20 @@ public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAut
 				new OidcClientRegistrationEndpointFilter(
 				new OidcClientRegistrationEndpointFilter(
 						authenticationManager,
 						authenticationManager,
 						authorizationServerSettings.getOidcClientRegistrationEndpoint());
 						authorizationServerSettings.getOidcClientRegistrationEndpoint());
+		List<AuthenticationConverter> authenticationConverters = createDefaultAuthenticationConverters();
+		if (!this.clientRegistrationRequestConverters.isEmpty()) {
+			authenticationConverters.addAll(0, this.clientRegistrationRequestConverters);
+		}
+		this.clientRegistrationRequestConvertersConsumer.accept(authenticationConverters);
+		oidcClientRegistrationEndpointFilter.setAuthenticationConverter(
+				new DelegatingAuthenticationConverter(authenticationConverters));
+		if (this.clientRegistrationResponseHandler != null) {
+			oidcClientRegistrationEndpointFilter
+					.setAuthenticationSuccessHandler(this.clientRegistrationResponseHandler);
+		}
+		if (this.errorResponseHandler != null) {
+			oidcClientRegistrationEndpointFilter.setAuthenticationFailureHandler(this.errorResponseHandler);
+		}
 		httpSecurity.addFilterAfter(postProcess(oidcClientRegistrationEndpointFilter), AuthorizationFilter.class);
 		httpSecurity.addFilterAfter(postProcess(oidcClientRegistrationEndpointFilter), AuthorizationFilter.class);
 	}
 	}
 
 
@@ -85,4 +205,31 @@ public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAut
 		return this.requestMatcher;
 		return this.requestMatcher;
 	}
 	}
 
 
+	private static List<AuthenticationConverter> createDefaultAuthenticationConverters() {
+		List<AuthenticationConverter> authenticationConverters = new ArrayList<>();
+
+		authenticationConverters.add(new OidcClientRegistrationAuthenticationConverter());
+
+		return authenticationConverters;
+	}
+
+	private static List<AuthenticationProvider> createDefaultAuthenticationProviders(HttpSecurity httpSecurity) {
+		List<AuthenticationProvider> authenticationProviders = new ArrayList<>();
+
+		OidcClientRegistrationAuthenticationProvider oidcClientRegistrationAuthenticationProvider =
+				new OidcClientRegistrationAuthenticationProvider(
+						OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity),
+						OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity),
+						OAuth2ConfigurerUtils.getTokenGenerator(httpSecurity));
+		authenticationProviders.add(oidcClientRegistrationAuthenticationProvider);
+
+		OidcClientConfigurationAuthenticationProvider oidcClientConfigurationAuthenticationProvider =
+				new OidcClientConfigurationAuthenticationProvider(
+						OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity),
+						OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity));
+		authenticationProviders.add(oidcClientConfigurationAuthenticationProvider);
+
+		return authenticationProviders;
+	}
+
 }
 }

+ 155 - 7
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoEndpointConfigurer.java

@@ -15,13 +15,23 @@
  */
  */
 package org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers;
 package org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers;
 
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Function;
 
 
+import jakarta.servlet.http.HttpServletRequest;
+
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpMethod;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.config.annotation.ObjectPostProcessor;
 import org.springframework.security.config.annotation.ObjectPostProcessor;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
 import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext;
@@ -29,21 +39,33 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.web.OidcUserInfoEndpointFilter;
 import org.springframework.security.oauth2.server.authorization.oidc.web.OidcUserInfoEndpointFilter;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
+import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter;
 import org.springframework.security.web.access.intercept.AuthorizationFilter;
 import org.springframework.security.web.access.intercept.AuthorizationFilter;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
+import org.springframework.util.Assert;
 
 
 /**
 /**
  * Configurer for OpenID Connect 1.0 UserInfo Endpoint.
  * Configurer for OpenID Connect 1.0 UserInfo Endpoint.
  *
  *
  * @author Steve Riesenberg
  * @author Steve Riesenberg
+ * @author Daniel Garnier-Moiroux
  * @since 0.2.1
  * @since 0.2.1
  * @see OidcConfigurer#userInfoEndpoint
  * @see OidcConfigurer#userInfoEndpoint
  * @see OidcUserInfoEndpointFilter
  * @see OidcUserInfoEndpointFilter
  */
  */
 public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configurer {
 public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configurer {
 	private RequestMatcher requestMatcher;
 	private RequestMatcher requestMatcher;
+	private final List<AuthenticationConverter> userInfoRequestConverters = new ArrayList<>();
+	private Consumer<List<AuthenticationConverter>> userInfoRequestConvertersConsumer = (userInfoRequestConverters) -> {};
+	private final List<AuthenticationProvider> authenticationProviders = new ArrayList<>();
+	private Consumer<List<AuthenticationProvider>> authenticationProvidersConsumer = (authenticationProviders) -> {};
+	private AuthenticationSuccessHandler userInfoResponseHandler;
+	private AuthenticationFailureHandler errorResponseHandler;
 	private Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper;
 	private Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper;
 
 
 	/**
 	/**
@@ -53,6 +75,91 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur
 		super(objectPostProcessor);
 		super(objectPostProcessor);
 	}
 	}
 
 
+	/**
+	 * Adds an {@link AuthenticationConverter} used when attempting to extract an UserInfo Request from {@link HttpServletRequest}
+	 * to an instance of {@link OidcUserInfoAuthenticationToken} used for authenticating the request.
+	 *
+	 * @param userInfoRequestConverter an {@link AuthenticationConverter} used when attempting to extract an UserInfo Request from {@link HttpServletRequest}
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer userInfoRequestConverter(AuthenticationConverter userInfoRequestConverter) {
+		Assert.notNull(userInfoRequestConverter, "userInfoRequestConverter cannot be null");
+		this.userInfoRequestConverters.add(userInfoRequestConverter);
+		return this;
+	}
+
+	/**
+	 * Sets the {@code Consumer} providing access to the {@code List} of default
+	 * and (optionally) added {@link #userInfoRequestConverter(AuthenticationConverter) AuthenticationConverter}'s
+	 * allowing the ability to add, remove, or customize a specific {@link AuthenticationConverter}.
+	 *
+	 * @param userInfoRequestConvertersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationConverter}'s
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer userInfoRequestConverters(
+			Consumer<List<AuthenticationConverter>> userInfoRequestConvertersConsumer) {
+		Assert.notNull(userInfoRequestConvertersConsumer, "userInfoRequestConvertersConsumer cannot be null");
+		this.userInfoRequestConvertersConsumer = userInfoRequestConvertersConsumer;
+		return this;
+	}
+
+	/**
+	 * Adds an {@link AuthenticationProvider} used for authenticating an {@link OidcUserInfoAuthenticationToken}.
+	 *
+	 * @param authenticationProvider an {@link AuthenticationProvider} used for authenticating an {@link OidcUserInfoAuthenticationToken}
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer authenticationProvider(AuthenticationProvider authenticationProvider) {
+		Assert.notNull(authenticationProvider, "authenticationProvider cannot be null");
+		this.authenticationProviders.add(authenticationProvider);
+		return this;
+	}
+
+	/**
+	 * Sets the {@code Consumer} providing access to the {@code List} of default
+	 * and (optionally) added {@link #authenticationProvider(AuthenticationProvider) AuthenticationProvider}'s
+	 * allowing the ability to add, remove, or customize a specific {@link AuthenticationProvider}.
+	 *
+	 * @param authenticationProvidersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationProvider}'s
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer authenticationProviders(
+			Consumer<List<AuthenticationProvider>> authenticationProvidersConsumer) {
+		Assert.notNull(authenticationProvidersConsumer, "authenticationProvidersConsumer cannot be null");
+		this.authenticationProvidersConsumer = authenticationProvidersConsumer;
+		return this;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken}
+	 * and returning the {@link OidcUserInfo UserInfo Response}.
+	 *
+	 * @param userInfoResponseHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken}
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer userInfoResponseHandler(AuthenticationSuccessHandler userInfoResponseHandler) {
+		this.userInfoResponseHandler = userInfoResponseHandler;
+		return this;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * and returning the {@link OAuth2Error Error Response}.
+	 *
+	 * @param errorResponseHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
+	 * @since 0.4.0
+	 */
+	public OidcUserInfoEndpointConfigurer errorResponseHandler(AuthenticationFailureHandler errorResponseHandler) {
+		this.errorResponseHandler = errorResponseHandler;
+		return this;
+	}
+
 	/**
 	/**
 	 * Sets the {@link Function} used to extract claims from {@link OidcUserInfoAuthenticationContext}
 	 * Sets the {@link Function} used to extract claims from {@link OidcUserInfoAuthenticationContext}
 	 * to an instance of {@link OidcUserInfo} for the UserInfo response.
 	 * to an instance of {@link OidcUserInfo} for the UserInfo response.
@@ -69,7 +176,8 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur
 	 * @param userInfoMapper the {@link Function} used to extract claims from {@link OidcUserInfoAuthenticationContext} to an instance of {@link OidcUserInfo}
 	 * @param userInfoMapper the {@link Function} used to extract claims from {@link OidcUserInfoAuthenticationContext} to an instance of {@link OidcUserInfo}
 	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
 	 * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration
 	 */
 	 */
-	public OidcUserInfoEndpointConfigurer userInfoMapper(Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper) {
+	public OidcUserInfoEndpointConfigurer userInfoMapper(
+			Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper) {
 		this.userInfoMapper = userInfoMapper;
 		this.userInfoMapper = userInfoMapper;
 		return this;
 		return this;
 	}
 	}
@@ -82,13 +190,13 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur
 				new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.GET.name()),
 				new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.GET.name()),
 				new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.POST.name()));
 				new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.POST.name()));
 
 
-		OidcUserInfoAuthenticationProvider oidcUserInfoAuthenticationProvider =
-				new OidcUserInfoAuthenticationProvider(
-						OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity));
-		if (this.userInfoMapper != null) {
-			oidcUserInfoAuthenticationProvider.setUserInfoMapper(this.userInfoMapper);
+		List<AuthenticationProvider> authenticationProviders = createDefaultAuthenticationProviders(httpSecurity);
+		if (!this.authenticationProviders.isEmpty()) {
+			authenticationProviders.addAll(0, this.authenticationProviders);
 		}
 		}
-		httpSecurity.authenticationProvider(postProcess(oidcUserInfoAuthenticationProvider));
+		this.authenticationProvidersConsumer.accept(authenticationProviders);
+		authenticationProviders.forEach(authenticationProvider ->
+				httpSecurity.authenticationProvider(postProcess(authenticationProvider)));
 	}
 	}
 
 
 	@Override
 	@Override
@@ -100,6 +208,19 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur
 				new OidcUserInfoEndpointFilter(
 				new OidcUserInfoEndpointFilter(
 						authenticationManager,
 						authenticationManager,
 						authorizationServerSettings.getOidcUserInfoEndpoint());
 						authorizationServerSettings.getOidcUserInfoEndpoint());
+		List<AuthenticationConverter> authenticationConverters = createDefaultAuthenticationConverters();
+		if (!this.userInfoRequestConverters.isEmpty()) {
+			authenticationConverters.addAll(0, this.userInfoRequestConverters);
+		}
+		this.userInfoRequestConvertersConsumer.accept(authenticationConverters);
+		oidcUserInfoEndpointFilter.setAuthenticationConverter(
+				new DelegatingAuthenticationConverter(authenticationConverters));
+		if (this.userInfoResponseHandler != null) {
+			oidcUserInfoEndpointFilter.setAuthenticationSuccessHandler(this.userInfoResponseHandler);
+		}
+		if (this.errorResponseHandler != null) {
+			oidcUserInfoEndpointFilter.setAuthenticationFailureHandler(this.errorResponseHandler);
+		}
 		httpSecurity.addFilterAfter(postProcess(oidcUserInfoEndpointFilter), AuthorizationFilter.class);
 		httpSecurity.addFilterAfter(postProcess(oidcUserInfoEndpointFilter), AuthorizationFilter.class);
 	}
 	}
 
 
@@ -108,4 +229,31 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur
 		return this.requestMatcher;
 		return this.requestMatcher;
 	}
 	}
 
 
+	private static List<AuthenticationConverter> createDefaultAuthenticationConverters() {
+		List<AuthenticationConverter> authenticationConverters = new ArrayList<>();
+
+		authenticationConverters.add(
+				(request) -> {
+					Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+					return new OidcUserInfoAuthenticationToken(authentication);
+				}
+		);
+
+		return authenticationConverters;
+	}
+
+	private List<AuthenticationProvider> createDefaultAuthenticationProviders(HttpSecurity httpSecurity) {
+		List<AuthenticationProvider> authenticationProviders = new ArrayList<>();
+
+		OidcUserInfoAuthenticationProvider oidcUserInfoAuthenticationProvider =
+				new OidcUserInfoAuthenticationProvider(
+						OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity));
+		if (this.userInfoMapper != null) {
+			oidcUserInfoAuthenticationProvider.setUserInfoMapper(this.userInfoMapper);
+		}
+		authenticationProviders.add(oidcUserInfoAuthenticationProvider);
+
+		return authenticationProviders;
+	}
+
 }
 }

+ 2 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientConfigurationAuthenticationProvider.java

@@ -46,6 +46,7 @@ import org.springframework.util.StringUtils;
  * @since 0.4.0
  * @since 0.4.0
  * @see RegisteredClientRepository
  * @see RegisteredClientRepository
  * @see OAuth2AuthorizationService
  * @see OAuth2AuthorizationService
+ * @see OidcClientRegistrationAuthenticationToken
  * @see OidcClientRegistrationAuthenticationProvider
  * @see OidcClientRegistrationAuthenticationProvider
  * @see <a href="https://openid.net/specs/openid-connect-registration-1_0.html#ClientConfigurationEndpoint">4. Client Configuration Endpoint</a>
  * @see <a href="https://openid.net/specs/openid-connect-registration-1_0.html#ClientConfigurationEndpoint">4. Client Configuration Endpoint</a>
  */
  */
@@ -67,7 +68,7 @@ public final class OidcClientConfigurationAuthenticationProvider implements Auth
 		Assert.notNull(authorizationService, "authorizationService cannot be null");
 		Assert.notNull(authorizationService, "authorizationService cannot be null");
 		this.registeredClientRepository = registeredClientRepository;
 		this.registeredClientRepository = registeredClientRepository;
 		this.authorizationService = authorizationService;
 		this.authorizationService = authorizationService;
-		this.clientRegistrationConverter = new OidcClientRegistrationConverter();
+		this.clientRegistrationConverter = new RegisteredClientOidcClientRegistrationConverter();
 	}
 	}
 
 
 	@Override
 	@Override

+ 16 - 4
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java

@@ -74,6 +74,7 @@ import org.springframework.util.StringUtils;
  * @see RegisteredClientRepository
  * @see RegisteredClientRepository
  * @see OAuth2AuthorizationService
  * @see OAuth2AuthorizationService
  * @see OAuth2TokenGenerator
  * @see OAuth2TokenGenerator
+ * @see OidcClientRegistrationAuthenticationToken
  * @see OidcClientConfigurationAuthenticationProvider
  * @see OidcClientConfigurationAuthenticationProvider
  * @see <a href="https://openid.net/specs/openid-connect-registration-1_0.html#ClientRegistration">3. Client Registration Endpoint</a>
  * @see <a href="https://openid.net/specs/openid-connect-registration-1_0.html#ClientRegistration">3. Client Registration Endpoint</a>
  */
  */
@@ -84,7 +85,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 	private final OAuth2AuthorizationService authorizationService;
 	private final OAuth2AuthorizationService authorizationService;
 	private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
 	private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
 	private final Converter<RegisteredClient, OidcClientRegistration> clientRegistrationConverter;
 	private final Converter<RegisteredClient, OidcClientRegistration> clientRegistrationConverter;
-	private final Converter<OidcClientRegistration, RegisteredClient> registeredClientConverter;
+	private Converter<OidcClientRegistration, RegisteredClient> registeredClientConverter;
 
 
 	/**
 	/**
 	 * Constructs an {@code OidcClientRegistrationAuthenticationProvider} using the provided parameters.
 	 * Constructs an {@code OidcClientRegistrationAuthenticationProvider} using the provided parameters.
@@ -102,8 +103,8 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 		this.registeredClientRepository = registeredClientRepository;
 		this.registeredClientRepository = registeredClientRepository;
 		this.authorizationService = authorizationService;
 		this.authorizationService = authorizationService;
 		this.tokenGenerator = tokenGenerator;
 		this.tokenGenerator = tokenGenerator;
-		this.clientRegistrationConverter = new OidcClientRegistrationConverter();
-		this.registeredClientConverter = new RegisteredClientConverter();
+		this.clientRegistrationConverter = new RegisteredClientOidcClientRegistrationConverter();
+		this.registeredClientConverter = new OidcClientRegistrationRegisteredClientConverter();
 	}
 	}
 
 
 	@Override
 	@Override
@@ -147,6 +148,17 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 		return OidcClientRegistrationAuthenticationToken.class.isAssignableFrom(authentication);
 		return OidcClientRegistrationAuthenticationToken.class.isAssignableFrom(authentication);
 	}
 	}
 
 
+	/**
+	 * Sets the {@link Converter} used for converting an {@link OidcClientRegistration} to a {@link RegisteredClient}.
+	 *
+	 * @param registeredClientConverter the {@link Converter} used for converting an {@link OidcClientRegistration} to a {@link RegisteredClient}
+	 * @since 0.4.0
+	 */
+	public void setRegisteredClientConverter(Converter<OidcClientRegistration, RegisteredClient> registeredClientConverter) {
+		Assert.notNull(registeredClientConverter, "registeredClientConverter cannot be null");
+		this.registeredClientConverter = registeredClientConverter;
+	}
+
 	private OidcClientRegistrationAuthenticationToken registerClient(OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication,
 	private OidcClientRegistrationAuthenticationToken registerClient(OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication,
 			OAuth2Authorization authorization) {
 			OAuth2Authorization authorization) {
 
 
@@ -293,7 +305,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 		throw new OAuth2AuthenticationException(error);
 		throw new OAuth2AuthenticationException(error);
 	}
 	}
 
 
-	private static final class RegisteredClientConverter implements Converter<OidcClientRegistration, RegisteredClient> {
+	private static final class OidcClientRegistrationRegisteredClientConverter implements Converter<OidcClientRegistration, RegisteredClient> {
 		private static final StringKeyGenerator CLIENT_ID_GENERATOR = new Base64StringKeyGenerator(
 		private static final StringKeyGenerator CLIENT_ID_GENERATOR = new Base64StringKeyGenerator(
 				Base64.getUrlEncoder().withoutPadding(), 32);
 				Base64.getUrlEncoder().withoutPadding(), 32);
 		private static final StringKeyGenerator CLIENT_SECRET_GENERATOR = new Base64StringKeyGenerator(
 		private static final StringKeyGenerator CLIENT_SECRET_GENERATOR = new Base64StringKeyGenerator(

+ 1 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationConverter.java → oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/RegisteredClientOidcClientRegistrationConverter.java

@@ -31,7 +31,7 @@ import org.springframework.web.util.UriComponentsBuilder;
  * @author Joe Grandja
  * @author Joe Grandja
  * @since 0.4.0
  * @since 0.4.0
  */
  */
-final class OidcClientRegistrationConverter implements Converter<RegisteredClient, OidcClientRegistration> {
+final class RegisteredClientOidcClientRegistrationConverter implements Converter<RegisteredClient, OidcClientRegistration> {
 
 
 	@Override
 	@Override
 	public OidcClientRegistration convert(RegisteredClient registeredClient) {
 	public OidcClientRegistration convert(RegisteredClient registeredClient) {

+ 64 - 19
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java

@@ -27,6 +27,8 @@ import org.springframework.http.HttpStatus;
 import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -40,6 +42,8 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat
 import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.util.matcher.AndRequestMatcher;
 import org.springframework.security.web.util.matcher.AndRequestMatcher;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
@@ -53,6 +57,7 @@ import org.springframework.web.filter.OncePerRequestFilter;
  *
  *
  * @author Ovidiu Popa
  * @author Ovidiu Popa
  * @author Joe Grandja
  * @author Joe Grandja
+ * @author Daniel Garnier-Moiroux
  * @since 0.1.1
  * @since 0.1.1
  * @see OidcClientRegistration
  * @see OidcClientRegistration
  * @see OidcClientRegistrationAuthenticationConverter
  * @see OidcClientRegistrationAuthenticationConverter
@@ -73,7 +78,9 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi
 			new OidcClientRegistrationHttpMessageConverter();
 			new OidcClientRegistrationHttpMessageConverter();
 	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
 	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
 			new OAuth2ErrorHttpMessageConverter();
 			new OAuth2ErrorHttpMessageConverter();
-	private AuthenticationConverter authenticationConverter;
+	private AuthenticationConverter authenticationConverter = new OidcClientRegistrationAuthenticationConverter();
+	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendClientRegistrationResponse;
+	private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
 
 
 	/**
 	/**
 	 * Constructs an {@code OidcClientRegistrationEndpointFilter} using the provided parameters.
 	 * Constructs an {@code OidcClientRegistrationEndpointFilter} using the provided parameters.
@@ -99,7 +106,6 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi
 				new AntPathRequestMatcher(
 				new AntPathRequestMatcher(
 						clientRegistrationEndpointUri, HttpMethod.POST.name()),
 						clientRegistrationEndpointUri, HttpMethod.POST.name()),
 				createClientConfigurationMatcher(clientRegistrationEndpointUri));
 				createClientConfigurationMatcher(clientRegistrationEndpointUri));
-		this.authenticationConverter = new OidcClientRegistrationAuthenticationConverter();
 	}
 	}
 
 
 	private static RequestMatcher createClientConfigurationMatcher(String clientRegistrationEndpointUri) {
 	private static RequestMatcher createClientConfigurationMatcher(String clientRegistrationEndpointUri) {
@@ -124,39 +130,78 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi
 		}
 		}
 
 
 		try {
 		try {
-			OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication =
-					(OidcClientRegistrationAuthenticationToken) this.authenticationConverter.convert(request);
+			Authentication clientRegistrationAuthentication = this.authenticationConverter.convert(request);
 
 
-			OidcClientRegistrationAuthenticationToken clientRegistrationAuthenticationResult =
-					(OidcClientRegistrationAuthenticationToken) this.authenticationManager.authenticate(clientRegistrationAuthentication);
-
-			HttpStatus httpStatus = HttpStatus.OK;
-			if (clientRegistrationAuthentication.getClientRegistration() != null) {
-				httpStatus = HttpStatus.CREATED;
-			}
-
-			sendClientRegistrationResponse(response, httpStatus, clientRegistrationAuthenticationResult.getClientRegistration());
+			Authentication clientRegistrationAuthenticationResult =
+					this.authenticationManager.authenticate(clientRegistrationAuthentication);
 
 
+			this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, clientRegistrationAuthenticationResult);
 		} catch (OAuth2AuthenticationException ex) {
 		} catch (OAuth2AuthenticationException ex) {
-			sendErrorResponse(response, ex.getError());
+			this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
 		} catch (Exception ex) {
 		} catch (Exception ex) {
 			OAuth2Error error = new OAuth2Error(
 			OAuth2Error error = new OAuth2Error(
 					OAuth2ErrorCodes.INVALID_REQUEST,
 					OAuth2ErrorCodes.INVALID_REQUEST,
-					"OpenID Client Registration Error: " + ex.getMessage(),
+					"OpenID Connect 1.0 Client Registration Error: " + ex.getMessage(),
 					"https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError");
 					"https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError");
-			sendErrorResponse(response, error);
+			this.authenticationFailureHandler.onAuthenticationFailure(request, response,
+					new OAuth2AuthenticationException(error));
 		} finally {
 		} finally {
 			SecurityContextHolder.clearContext();
 			SecurityContextHolder.clearContext();
 		}
 		}
 	}
 	}
 
 
-	private void sendClientRegistrationResponse(HttpServletResponse response, HttpStatus httpStatus, OidcClientRegistration clientRegistration) throws IOException {
+	/**
+	 * Sets the {@link AuthenticationConverter} used when attempting to extract a Client Registration Request from {@link HttpServletRequest}
+	 * to an instance of {@link OidcClientRegistrationAuthenticationToken} used for authenticating the request.
+	 *
+	 * @param authenticationConverter an {@link AuthenticationConverter} used when attempting to extract a Client Registration Request from {@link HttpServletRequest}
+	 * @since 0.4.0
+	 */
+	public void setAuthenticationConverter(AuthenticationConverter authenticationConverter) {
+		Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
+		this.authenticationConverter = authenticationConverter;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken}
+	 * and returning the {@link OidcClientRegistration Client Registration Response}.
+	 *
+	 * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken}
+	 * @see 0.4.0
+	 */
+	public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
+		Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
+		this.authenticationSuccessHandler = authenticationSuccessHandler;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * and returning the {@link OAuth2Error Error Response}.
+	 *
+	 * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * @since 0.4.0
+	 */
+	public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
+		Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
+		this.authenticationFailureHandler = authenticationFailureHandler;
+	}
+
+	private void sendClientRegistrationResponse(HttpServletRequest request, HttpServletResponse response,
+			Authentication authentication) throws IOException {
+		OidcClientRegistration clientRegistration = ((OidcClientRegistrationAuthenticationToken) authentication)
+				.getClientRegistration();
 		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
 		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
-		httpResponse.setStatusCode(httpStatus);
+		if (HttpMethod.POST.name().equals(request.getMethod())) {
+			httpResponse.setStatusCode(HttpStatus.CREATED);
+		} else {
+			httpResponse.setStatusCode(HttpStatus.OK);
+		}
 		this.clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpResponse);
 		this.clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpResponse);
 	}
 	}
 
 
-	private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException {
+	private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
+			AuthenticationException authenticationException) throws IOException {
+		OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
 		HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
 		HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
 		if (OAuth2ErrorCodes.INVALID_TOKEN.equals(error.getErrorCode())) {
 		if (OAuth2ErrorCodes.INVALID_TOKEN.equals(error.getErrorCode())) {
 			httpStatus = HttpStatus.UNAUTHORIZED;
 			httpStatus = HttpStatus.UNAUTHORIZED;

+ 66 - 13
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java

@@ -28,14 +28,19 @@ import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
 import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
 import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
 import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
+import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcUserInfoHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcUserInfoHttpMessageConverter;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -47,8 +52,10 @@ import org.springframework.web.filter.OncePerRequestFilter;
  *
  *
  * @author Ido Salomon
  * @author Ido Salomon
  * @author Steve Riesenberg
  * @author Steve Riesenberg
+ * @author Daniel Garnier-Moiroux
  * @since 0.2.1
  * @since 0.2.1
  * @see OidcUserInfo
  * @see OidcUserInfo
+ * @see OidcUserInfoAuthenticationProvider
  * @see <a href="https://openid.net/specs/openid-connect-core-1_0.html#UserInfo">5.3. UserInfo Endpoint</a>
  * @see <a href="https://openid.net/specs/openid-connect-core-1_0.html#UserInfo">5.3. UserInfo Endpoint</a>
  */
  */
 public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {
 public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {
@@ -60,11 +67,13 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {
 
 
 	private final AuthenticationManager authenticationManager;
 	private final AuthenticationManager authenticationManager;
 	private final RequestMatcher userInfoEndpointMatcher;
 	private final RequestMatcher userInfoEndpointMatcher;
-
 	private final HttpMessageConverter<OidcUserInfo> userInfoHttpMessageConverter =
 	private final HttpMessageConverter<OidcUserInfo> userInfoHttpMessageConverter =
 			new OidcUserInfoHttpMessageConverter();
 			new OidcUserInfoHttpMessageConverter();
 	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
 	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
 			new OAuth2ErrorHttpMessageConverter();
 			new OAuth2ErrorHttpMessageConverter();
+	private AuthenticationConverter authenticationConverter = this::createAuthentication;
+	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendUserInfoResponse;
+	private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
 
 
 	/**
 	/**
 	 * Constructs an {@code OidcUserInfoEndpointFilter} using the provided parameters.
 	 * Constructs an {@code OidcUserInfoEndpointFilter} using the provided parameters.
@@ -100,34 +109,77 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {
 		}
 		}
 
 
 		try {
 		try {
-			Authentication principal = SecurityContextHolder.getContext().getAuthentication();
-
-			OidcUserInfoAuthenticationToken userInfoAuthentication = new OidcUserInfoAuthenticationToken(principal);
-
-			OidcUserInfoAuthenticationToken userInfoAuthenticationResult =
-					(OidcUserInfoAuthenticationToken) this.authenticationManager.authenticate(userInfoAuthentication);
+			Authentication userInfoAuthentication = this.authenticationConverter.convert(request);
 
 
-			sendUserInfoResponse(response, userInfoAuthenticationResult.getUserInfo());
+			Authentication userInfoAuthenticationResult =
+					this.authenticationManager.authenticate(userInfoAuthentication);
 
 
+			this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, userInfoAuthenticationResult);
 		} catch (OAuth2AuthenticationException ex) {
 		} catch (OAuth2AuthenticationException ex) {
-			sendErrorResponse(response, ex.getError());
+			this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
 		} catch (Exception ex) {
 		} catch (Exception ex) {
 			OAuth2Error error = new OAuth2Error(
 			OAuth2Error error = new OAuth2Error(
 					OAuth2ErrorCodes.INVALID_REQUEST,
 					OAuth2ErrorCodes.INVALID_REQUEST,
 					"OpenID Connect 1.0 UserInfo Error: " + ex.getMessage(),
 					"OpenID Connect 1.0 UserInfo Error: " + ex.getMessage(),
 					"https://openid.net/specs/openid-connect-core-1_0.html#UserInfoError");
 					"https://openid.net/specs/openid-connect-core-1_0.html#UserInfoError");
-			sendErrorResponse(response, error);
+			this.authenticationFailureHandler.onAuthenticationFailure(request, response,
+					new OAuth2AuthenticationException(error));
 		} finally {
 		} finally {
 			SecurityContextHolder.clearContext();
 			SecurityContextHolder.clearContext();
 		}
 		}
 	}
 	}
 
 
-	private void sendUserInfoResponse(HttpServletResponse response, OidcUserInfo userInfo) throws IOException {
+	/**
+	 * Sets the {@link AuthenticationConverter} used when attempting to extract an UserInfo Request from {@link HttpServletRequest}
+	 * to an instance of {@link OidcUserInfoAuthenticationToken} used for authenticating the request.
+	 *
+	 * @param authenticationConverter the {@link AuthenticationConverter} used when attempting to extract an UserInfo Request from {@link HttpServletRequest}
+	 * @since 0.4.0
+	 */
+	public void setAuthenticationConverter(AuthenticationConverter authenticationConverter) {
+		Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
+		this.authenticationConverter = authenticationConverter;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken}
+	 * and returning the {@link OidcUserInfo UserInfo Response}.
+	 *
+	 * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken}
+	 * @since 0.4.0
+	 */
+	public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
+		Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
+		this.authenticationSuccessHandler = authenticationSuccessHandler;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * and returning the {@link OAuth2Error Error Response}.
+	 *
+	 * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * @since 0.4.0
+	 */
+	public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
+		Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
+		this.authenticationFailureHandler = authenticationFailureHandler;
+	}
+
+	private Authentication createAuthentication(HttpServletRequest request) {
+		Authentication principal = SecurityContextHolder.getContext().getAuthentication();
+		return new OidcUserInfoAuthenticationToken(principal);
+	}
+
+	private void sendUserInfoResponse(HttpServletRequest request, HttpServletResponse response,
+			Authentication authentication) throws IOException {
+		OidcUserInfoAuthenticationToken userInfoAuthenticationToken = (OidcUserInfoAuthenticationToken) authentication;
 		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
 		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
-		this.userInfoHttpMessageConverter.write(userInfo, null, httpResponse);
+		this.userInfoHttpMessageConverter.write(userInfoAuthenticationToken.getUserInfo(), null, httpResponse);
 	}
 	}
 
 
-	private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException {
+	private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
+			AuthenticationException authenticationException) throws IOException {
+		OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
 		HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
 		HttpStatus httpStatus = HttpStatus.BAD_REQUEST;
 		if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) {
 		if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) {
 			httpStatus = HttpStatus.UNAUTHORIZED;
 			httpStatus = HttpStatus.UNAUTHORIZED;
@@ -138,4 +190,5 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter {
 		httpResponse.setStatusCode(httpStatus);
 		httpResponse.setStatusCode(httpStatus);
 		this.errorHttpResponseConverter.write(error, null, httpResponse);
 		this.errorHttpResponseConverter.write(error, null, httpResponse);
 	}
 	}
+
 }
 }

+ 18 - 4
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -18,7 +18,9 @@ package org.springframework.security.oauth2.server.authorization.web;
 import java.io.IOException;
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.HashSet;
+import java.util.Map;
 import java.util.Set;
 import java.util.Set;
 
 
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.FilterChain;
@@ -287,10 +289,16 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
 		UriComponentsBuilder uriBuilder = UriComponentsBuilder
 		UriComponentsBuilder uriBuilder = UriComponentsBuilder
 				.fromUriString(authorizationCodeRequestAuthentication.getRedirectUri())
 				.fromUriString(authorizationCodeRequestAuthentication.getRedirectUri())
 				.queryParam(OAuth2ParameterNames.CODE, authorizationCodeRequestAuthentication.getAuthorizationCode().getTokenValue());
 				.queryParam(OAuth2ParameterNames.CODE, authorizationCodeRequestAuthentication.getAuthorizationCode().getTokenValue());
+		String redirectUri;
 		if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) {
 		if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) {
-			uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState());
+			uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}");
+			Map<String, String> queryParams = new HashMap<>();
+			queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState());
+			redirectUri = uriBuilder.build(queryParams).toString();
+		} else {
+			redirectUri = uriBuilder.toUriString();
 		}
 		}
-		this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
+		this.redirectStrategy.sendRedirect(request, response, redirectUri);
 	}
 	}
 
 
 	private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
 	private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
@@ -317,10 +325,16 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
 		if (StringUtils.hasText(error.getUri())) {
 		if (StringUtils.hasText(error.getUri())) {
 			uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri());
 			uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri());
 		}
 		}
+		String redirectUri;
 		if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) {
 		if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) {
-			uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState());
+			uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}");
+			Map<String, String> queryParams = new HashMap<>();
+			queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState());
+			redirectUri = uriBuilder.build(queryParams).toString();
+		} else {
+			redirectUri = uriBuilder.toUriString();
 		}
 		}
-		this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
+		this.redirectStrategy.sendRedirect(request, response, redirectUri);
 	}
 	}
 
 
 	/**
 	/**

+ 1 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java

@@ -84,7 +84,7 @@ public class TestOAuth2Authorizations {
 				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
 				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
 				.authorizedScopes(authorizationRequest.getScopes())
 				.authorizedScopes(authorizationRequest.getScopes())
 				.token(authorizationCode)
 				.token(authorizationCode)
-				.attribute(OAuth2ParameterNames.STATE, "state")
+				.attribute(OAuth2ParameterNames.STATE, "consent-state")
 				.attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest)
 				.attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest)
 				.attribute(Principal.class.getName(),
 				.attribute(Principal.class.getName(),
 						new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B"));
 						new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B"));

+ 29 - 9
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java

@@ -70,6 +70,7 @@ import org.springframework.security.crypto.password.PasswordEncoder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
@@ -160,6 +161,9 @@ public class OAuth2AuthorizationCodeGrantTests {
 	private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
 	private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
 	private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
 	private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
 	private static final String AUTHORITIES_CLAIM = "authorities";
 	private static final String AUTHORITIES_CLAIM = "authorities";
+	private static final String STATE_URL_UNENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ+004pwm9j55li7BoydXYysH4enZMF21Q";
+	private static final String STATE_URL_ENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ%2B004pwm9j55li7BoydXYysH4enZMF21Q";
+
 	private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
 	private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
 	private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);
 	private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);
 
 
@@ -291,7 +295,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.andExpect(status().is3xxRedirection())
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
 				.andReturn();
 		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
 		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
-		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state");
+		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED);
 
 
 		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
 		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
 		OAuth2Authorization authorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
 		OAuth2Authorization authorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
@@ -383,7 +387,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.andExpect(status().is3xxRedirection())
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
 				.andReturn();
 		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
 		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
-		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state");
+		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED);
 
 
 		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
 		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
 		OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
 		OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
@@ -427,7 +431,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.andExpect(status().is3xxRedirection())
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
 				.andReturn();
 		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
 		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
-		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state");
+		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED);
 
 
 		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
 		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
 		OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
 		OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
@@ -503,19 +507,27 @@ public class OAuth2AuthorizationCodeGrantTests {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 				.principalName("user")
 				.principalName("user")
 				.build();
 				.build();
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
+		OAuth2AuthorizationRequest updatedAuthorizationRequest =
+				OAuth2AuthorizationRequest.from(authorizationRequest)
+						.state(STATE_URL_UNENCODED)
+						.build();
+		authorization = OAuth2Authorization.from(authorization)
+				.attribute(OAuth2AuthorizationRequest.class.getName(), updatedAuthorizationRequest)
+				.build();
 		this.authorizationService.save(authorization);
 		this.authorizationService.save(authorization);
 
 
 		MvcResult mvcResult = this.mvc.perform(post(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
 		MvcResult mvcResult = this.mvc.perform(post(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
 				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
 				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
 				.param(OAuth2ParameterNames.SCOPE, "message.read")
 				.param(OAuth2ParameterNames.SCOPE, "message.read")
 				.param(OAuth2ParameterNames.SCOPE, "message.write")
 				.param(OAuth2ParameterNames.SCOPE, "message.write")
-				.param(OAuth2ParameterNames.STATE, "state")
+				.param(OAuth2ParameterNames.STATE, authorization.<String>getAttribute(OAuth2ParameterNames.STATE))
 				.with(user("user")))
 				.with(user("user")))
 				.andExpect(status().is3xxRedirection())
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
 				.andReturn();
 
 
 		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
 		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
-		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state");
+		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED);
 
 
 		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
 		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
 		OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
 		OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
@@ -583,18 +595,26 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 				.build();
 				.build();
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
+		OAuth2AuthorizationRequest updatedAuthorizationRequest =
+				OAuth2AuthorizationRequest.from(authorizationRequest)
+						.state(STATE_URL_UNENCODED)
+						.build();
+		authorization = OAuth2Authorization.from(authorization)
+				.attribute(OAuth2AuthorizationRequest.class.getName(), updatedAuthorizationRequest)
+				.build();
 		this.authorizationService.save(authorization);
 		this.authorizationService.save(authorization);
 
 
 		MvcResult mvcResult = this.mvc.perform(post(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
 		MvcResult mvcResult = this.mvc.perform(post(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
 				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
 				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
 				.param("authority", "authority-1 authority-2")
 				.param("authority", "authority-1 authority-2")
-				.param(OAuth2ParameterNames.STATE, "state")
+				.param(OAuth2ParameterNames.STATE, authorization.<String>getAttribute(OAuth2ParameterNames.STATE))
 				.with(user("principal")))
 				.with(user("principal")))
 				.andExpect(status().is3xxRedirection())
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
 				.andReturn();
 
 
 		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
 		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
-		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state");
+		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED);
 
 
 		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
 		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
 		OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
 		OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
@@ -632,7 +652,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult =
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult =
 				new OAuth2AuthorizationCodeRequestAuthenticationToken(
 				new OAuth2AuthorizationCodeRequestAuthenticationToken(
 						"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, authorizationCode,
 						"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, authorizationCode,
-						registeredClient.getRedirectUris().iterator().next(), "state", registeredClient.getScopes());
+						registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes());
 		when(authorizationRequestConverter.convert(any())).thenReturn(authorizationCodeRequestAuthenticationResult);
 		when(authorizationRequestConverter.convert(any())).thenReturn(authorizationCodeRequestAuthenticationResult);
 		when(authorizationRequestAuthenticationProvider.supports(eq(OAuth2AuthorizationCodeRequestAuthenticationToken.class))).thenReturn(true);
 		when(authorizationRequestAuthenticationProvider.supports(eq(OAuth2AuthorizationCodeRequestAuthenticationToken.class))).thenReturn(true);
 		when(authorizationRequestAuthenticationProvider.authenticate(any())).thenReturn(authorizationCodeRequestAuthenticationResult);
 		when(authorizationRequestAuthenticationProvider.authenticate(any())).thenReturn(authorizationCodeRequestAuthenticationResult);
@@ -718,7 +738,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 		parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
 		parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
 		parameters.set(OAuth2ParameterNames.SCOPE,
 		parameters.set(OAuth2ParameterNames.SCOPE,
 				StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
 				StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
-		parameters.set(OAuth2ParameterNames.STATE, "state");
+		parameters.set(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED);
 		return parameters;
 		return parameters;
 	}
 	}
 
 

+ 147 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationTests.java

@@ -18,6 +18,10 @@ package org.springframework.security.oauth2.server.authorization.config.annotati
 import java.time.Instant;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.time.temporal.ChronoUnit;
 import java.util.Collections;
 import java.util.Collections;
+import java.util.List;
+import java.util.function.Consumer;
+
+import jakarta.servlet.http.HttpServletResponse;
 
 
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.jwk.source.JWKSource;
@@ -30,6 +34,7 @@ import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 
 
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Bean;
@@ -38,6 +43,7 @@ import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.http.MediaType;
 import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.JdbcTemplate;
 import org.springframework.jdbc.core.JdbcTemplate;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
@@ -46,6 +52,7 @@ import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
 import org.springframework.mock.http.MockHttpOutputMessage;
 import org.springframework.mock.http.MockHttpOutputMessage;
 import org.springframework.mock.http.client.MockClientHttpResponse;
 import org.springframework.mock.http.client.MockClientHttpResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.config.Customizer;
 import org.springframework.security.config.Customizer;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
@@ -55,6 +62,7 @@ import org.springframework.security.crypto.password.PasswordEncoder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -77,11 +85,18 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
 import org.springframework.security.oauth2.server.authorization.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
 import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration;
 import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration;
+import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientConfigurationAuthenticationProvider;
+import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationProvider;
+import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter;
+import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
 import org.springframework.security.oauth2.server.authorization.settings.ClientSettings;
 import org.springframework.security.oauth2.server.authorization.settings.ClientSettings;
 import org.springframework.security.oauth2.server.authorization.test.SpringTestRule;
 import org.springframework.security.oauth2.server.authorization.test.SpringTestRule;
 import org.springframework.security.web.SecurityFilterChain;
 import org.springframework.security.web.SecurityFilterChain;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MvcResult;
 import org.springframework.test.web.servlet.MvcResult;
@@ -89,6 +104,14 @@ import org.springframework.web.util.UriComponentsBuilder;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.containsString;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
+import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header;
@@ -128,6 +151,18 @@ public class OidcClientRegistrationTests {
 	@Autowired
 	@Autowired
 	private AuthorizationServerSettings authorizationServerSettings;
 	private AuthorizationServerSettings authorizationServerSettings;
 
 
+	private static AuthenticationConverter authenticationConverter;
+
+	private static Consumer<List<AuthenticationConverter>> authenticationConvertersConsumer;
+
+	private static AuthenticationProvider authenticationProvider;
+
+	private static Consumer<List<AuthenticationProvider>> authenticationProvidersConsumer;
+
+	private static AuthenticationSuccessHandler authenticationSuccessHandler;
+
+	private static AuthenticationFailureHandler authenticationFailureHandler;
+
 	private MockWebServer server;
 	private MockWebServer server;
 	private String clientJwkSetUrl;
 	private String clientJwkSetUrl;
 
 
@@ -145,6 +180,12 @@ public class OidcClientRegistrationTests {
 				.addScript("org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql")
 				.addScript("org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql")
 				.addScript("org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql")
 				.addScript("org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql")
 				.build();
 				.build();
+		authenticationConverter = mock(AuthenticationConverter.class);
+		authenticationConvertersConsumer = mock(Consumer.class);
+		authenticationProvider = mock(AuthenticationProvider.class);
+		authenticationProvidersConsumer = mock(Consumer.class);
+		authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class);
+		authenticationFailureHandler = mock(AuthenticationFailureHandler.class);
 	}
 	}
 
 
 	@Before
 	@Before
@@ -158,6 +199,7 @@ public class OidcClientRegistrationTests {
 				.setBody(clientJwkSet.toString());
 				.setBody(clientJwkSet.toString());
 		// @formatter:on
 		// @formatter:on
 		this.server.enqueue(response);
 		this.server.enqueue(response);
+		when(authenticationProvider.supports(OidcClientRegistrationAuthenticationToken.class)).thenReturn(true);
 	}
 	}
 
 
 	@After
 	@After
@@ -165,6 +207,12 @@ public class OidcClientRegistrationTests {
 		this.server.shutdown();
 		this.server.shutdown();
 		jdbcOperations.update("truncate table oauth2_authorization");
 		jdbcOperations.update("truncate table oauth2_authorization");
 		jdbcOperations.update("truncate table oauth2_registered_client");
 		jdbcOperations.update("truncate table oauth2_registered_client");
+		reset(authenticationConverter);
+		reset(authenticationConvertersConsumer);
+		reset(authenticationProvider);
+		reset(authenticationProvidersConsumer);
+		reset(authenticationSuccessHandler);
+		reset(authenticationFailureHandler);
 	}
 	}
 
 
 	@AfterClass
 	@AfterClass
@@ -261,6 +309,67 @@ public class OidcClientRegistrationTests {
 		assertThat(clientConfigurationResponse.getRegistrationAccessToken()).isNull();
 		assertThat(clientConfigurationResponse.getRegistrationAccessToken()).isNull();
 	}
 	}
 
 
+	@Test
+	public void requestWhenClientRegistrationEndpointCustomizedThenUsed() throws Exception {
+		this.spring.register(CustomClientRegistrationConfiguration.class).autowire();
+
+		// @formatter:off
+		OidcClientRegistration clientRegistration = OidcClientRegistration.builder()
+				.clientName("client-name")
+				.redirectUri("https://client.example.com")
+				.grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())
+				.grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
+				.scope("scope1")
+				.scope("scope2")
+				.build();
+		// @formatter:on
+
+		doAnswer(invocation -> {
+			HttpServletResponse response = invocation.getArgument(1, HttpServletResponse.class);
+			ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
+			httpResponse.setStatusCode(HttpStatus.CREATED);
+			new OidcClientRegistrationHttpMessageConverter().write(clientRegistration, null, httpResponse);
+			return null;
+		}).when(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any());
+
+		registerClient(clientRegistration);
+
+		verify(authenticationConverter).convert(any());
+		ArgumentCaptor<List<AuthenticationConverter>> authenticationConvertersCaptor =
+				ArgumentCaptor.forClass(List.class);
+		verify(authenticationConvertersConsumer).accept(authenticationConvertersCaptor.capture());
+		List<AuthenticationConverter> authenticationConverters = authenticationConvertersCaptor.getValue();
+		assertThat(authenticationConverters).hasSize(2)
+				.allMatch(converter -> converter == authenticationConverter
+						|| converter instanceof OidcClientRegistrationAuthenticationConverter);
+
+		verify(authenticationProvider).authenticate(any());
+		ArgumentCaptor<List<AuthenticationProvider>> authenticationProvidersCaptor =
+				ArgumentCaptor.forClass(List.class);
+		verify(authenticationProvidersConsumer).accept(authenticationProvidersCaptor.capture());
+		List<AuthenticationProvider> authenticationProviders = authenticationProvidersCaptor.getValue();
+		assertThat(authenticationProviders).hasSize(3)
+				.allMatch(provider -> provider == authenticationProvider
+						|| provider instanceof OidcClientRegistrationAuthenticationProvider
+						|| provider instanceof OidcClientConfigurationAuthenticationProvider);
+
+		verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any());
+		verifyNoInteractions(authenticationFailureHandler);
+	}
+
+	@Test
+	public void requestWhenClientRegistrationEndpointCustomizedWithAuthenticationFailureHandlerThenUsed() throws Exception {
+		this.spring.register(CustomClientRegistrationConfiguration.class).autowire();
+
+		when(authenticationProvider.authenticate(any())).thenThrow(new OAuth2AuthenticationException("error"));
+
+		this.mvc.perform(get(DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI)
+				.param(OAuth2ParameterNames.CLIENT_ID, "invalid").with(jwt()));
+
+		verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any());
+		verifyNoInteractions(authenticationSuccessHandler);
+	}
+
 	private OidcClientRegistration registerClient(OidcClientRegistration clientRegistration) throws Exception {
 	private OidcClientRegistration registerClient(OidcClientRegistration clientRegistration) throws Exception {
 		// ***** (1) Obtain the "initial" access token used for registering the client
 		// ***** (1) Obtain the "initial" access token used for registering the client
 
 
@@ -353,6 +462,44 @@ public class OidcClientRegistrationTests {
 		return clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, httpResponse);
 		return clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, httpResponse);
 	}
 	}
 
 
+	@EnableWebSecurity
+	@Configuration(proxyBeanMethods = false)
+	static class CustomClientRegistrationConfiguration extends AuthorizationServerConfiguration {
+
+		// @formatter:off
+		@Bean
+		@Override
+		public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception {
+			OAuth2AuthorizationServerConfigurer authorizationServerConfigurer =
+					new OAuth2AuthorizationServerConfigurer();
+			authorizationServerConfigurer
+				.oidc(oidc ->
+					oidc
+						.clientRegistrationEndpoint(clientRegistration ->
+							clientRegistration
+								.clientRegistrationRequestConverter(authenticationConverter)
+								.clientRegistrationRequestConverters(authenticationConvertersConsumer)
+								.authenticationProvider(authenticationProvider)
+								.authenticationProviders(authenticationProvidersConsumer)
+								.clientRegistrationResponseHandler(authenticationSuccessHandler)
+								.errorResponseHandler(authenticationFailureHandler)
+						)
+				);
+			RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher();
+
+			http
+					.securityMatcher(endpointsMatcher)
+					.authorizeHttpRequests(authorize ->
+							authorize.anyRequest().authenticated()
+					)
+					.csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher))
+					.oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt)
+					.apply(authorizationServerConfigurer);
+			return http.build();
+		}
+		// @formatter:on
+	}
+
 	@EnableWebSecurity
 	@EnableWebSecurity
 	@Configuration(proxyBeanMethods = false)
 	@Configuration(proxyBeanMethods = false)
 	static class AuthorizationServerConfiguration {
 	static class AuthorizationServerConfiguration {

+ 130 - 11
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoTests.java

@@ -19,9 +19,13 @@ import java.time.Instant;
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
 import java.util.Set;
+import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Function;
 
 
+import jakarta.servlet.http.HttpServletResponse;
+
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
 import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.jwk.source.JWKSource;
@@ -30,11 +34,14 @@ import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 
 
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
+import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.config.Customizer;
 import org.springframework.security.config.Customizer;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
@@ -62,11 +69,15 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
 import org.springframework.security.oauth2.server.authorization.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext;
+import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
 import org.springframework.security.oauth2.server.authorization.test.SpringTestRule;
 import org.springframework.security.oauth2.server.authorization.test.SpringTestRule;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
 import org.springframework.security.web.SecurityFilterChain;
 import org.springframework.security.web.SecurityFilterChain;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -75,8 +86,15 @@ import org.springframework.test.web.servlet.MvcResult;
 import org.springframework.test.web.servlet.ResultMatcher;
 import org.springframework.test.web.servlet.ResultMatcher;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
@@ -100,17 +118,48 @@ public class OidcUserInfoTests {
 	@Autowired
 	@Autowired
 	private JwtEncoder jwtEncoder;
 	private JwtEncoder jwtEncoder;
 
 
+	@Autowired
+	private JwtDecoder jwtDecoder;
+
 	@Autowired
 	@Autowired
 	private OAuth2AuthorizationService authorizationService;
 	private OAuth2AuthorizationService authorizationService;
 
 
+	private static AuthenticationConverter authenticationConverter;
+
+	private static Consumer<List<AuthenticationConverter>> authenticationConvertersConsumer;
+
+	private static AuthenticationProvider authenticationProvider;
+
+	private static Consumer<List<AuthenticationProvider>> authenticationProvidersConsumer;
+
+	private static AuthenticationSuccessHandler authenticationSuccessHandler;
+
+	private static AuthenticationFailureHandler authenticationFailureHandler;
+
+	private static Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper;
+
 	@BeforeClass
 	@BeforeClass
 	public static void init() {
 	public static void init() {
 		securityContextRepository = spy(new HttpSessionSecurityContextRepository());
 		securityContextRepository = spy(new HttpSessionSecurityContextRepository());
+		authenticationConverter = mock(AuthenticationConverter.class);
+		authenticationConvertersConsumer = mock(Consumer.class);
+		authenticationProvider = mock(AuthenticationProvider.class);
+		authenticationProvidersConsumer = mock(Consumer.class);
+		authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class);
+		authenticationFailureHandler = mock(AuthenticationFailureHandler.class);
+		userInfoMapper = mock(Function.class);
 	}
 	}
 
 
 	@Before
 	@Before
 	public void setup() {
 	public void setup() {
 		reset(securityContextRepository);
 		reset(securityContextRepository);
+		reset(authenticationConverter);
+		reset(authenticationConvertersConsumer);
+		reset(authenticationProvider);
+		reset(authenticationProvidersConsumer);
+		reset(authenticationSuccessHandler);
+		reset(authenticationFailureHandler);
+		reset(userInfoMapper);
 	}
 	}
 
 
 	@Test
 	@Test
@@ -146,19 +195,91 @@ public class OidcUserInfoTests {
 	}
 	}
 
 
 	@Test
 	@Test
-	public void requestWhenSignedJwtAndCustomUserInfoMapperThenMapJwtClaimsToUserInfoResponse() throws Exception {
+	public void requestWhenUserInfoEndpointCustomizedThenUsed() throws Exception {
 		this.spring.register(CustomUserInfoConfiguration.class).autowire();
 		this.spring.register(CustomUserInfoConfiguration.class).autowire();
 
 
 		OAuth2Authorization authorization = createAuthorization();
 		OAuth2Authorization authorization = createAuthorization();
 		this.authorizationService.save(authorization);
 		this.authorizationService.save(authorization);
 
 
+		when(userInfoMapper.apply(any())).thenReturn(createUserInfo());
+
 		OAuth2AccessToken accessToken = authorization.getAccessToken().getToken();
 		OAuth2AccessToken accessToken = authorization.getAccessToken().getToken();
 		// @formatter:off
 		// @formatter:off
 		this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
 		this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
 				.header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
 				.header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
-				.andExpect(status().is2xxSuccessful())
-				.andExpectAll(userInfoResponse());
+				.andExpect(status().is2xxSuccessful());
+		// @formatter:on
+
+		verify(userInfoMapper).apply(any());
+		verify(authenticationConverter).convert(any());
+		verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any());
+		verifyNoInteractions(authenticationFailureHandler);
+
+		ArgumentCaptor<List<AuthenticationProvider>> authenticationProvidersCaptor = ArgumentCaptor.forClass(List.class);
+		verify(authenticationProvidersConsumer).accept(authenticationProvidersCaptor.capture());
+		List<AuthenticationProvider> authenticationProviders = authenticationProvidersCaptor.getValue();
+		assertThat(authenticationProviders).hasSize(2).allMatch(provider ->
+				provider == authenticationProvider ||
+						provider instanceof OidcUserInfoAuthenticationProvider
+				);
+
+		ArgumentCaptor<List<AuthenticationConverter>> authenticationConvertersCaptor = ArgumentCaptor.forClass(List.class);
+		verify(authenticationConvertersConsumer).accept(authenticationConvertersCaptor.capture());
+		List<AuthenticationConverter> authenticationConverters = authenticationConvertersCaptor.getValue();
+		assertThat(authenticationConverters).hasSize(2).allMatch(AuthenticationConverter.class::isInstance);
+	}
+
+	@Test
+	public void requestWhenUserInfoEndpointCustomizedWithAuthenticationProviderThenUsed() throws Exception {
+		this.spring.register(CustomUserInfoConfiguration.class).autowire();
+
+		OAuth2Authorization authorization = createAuthorization();
+		this.authorizationService.save(authorization);
+
+		when(authenticationProvider.supports(eq(OidcUserInfoAuthenticationToken.class))).thenReturn(true);
+		String tokenValue = authorization.getAccessToken().getToken().getTokenValue();
+		Jwt jwt = this.jwtDecoder.decode(tokenValue);
+		OidcUserInfoAuthenticationToken oidcUserInfoAuthentication = new OidcUserInfoAuthenticationToken(
+				new JwtAuthenticationToken(jwt), createUserInfo());
+		when(authenticationProvider.authenticate(any())).thenReturn(oidcUserInfoAuthentication);
+
+		OAuth2AccessToken accessToken = authorization.getAccessToken().getToken();
+		// @formatter:off
+		this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
+						.header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
+				.andExpect(status().is2xxSuccessful());
+		// @formatter:on
+
+		verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any());
+		verify(authenticationProvider).authenticate(any());
+		verifyNoInteractions(authenticationFailureHandler);
+		verifyNoInteractions(userInfoMapper);
+	}
+
+	@Test
+	public void requestWhenUserInfoEndpointCustomizedWithAuthenticationFailureHandlerThenUsed() throws Exception {
+		this.spring.register(CustomUserInfoConfiguration.class).autowire();
+
+		when(userInfoMapper.apply(any())).thenReturn(createUserInfo());
+		doAnswer(
+				invocation -> {
+					HttpServletResponse response = invocation.getArgument(1);
+					response.setStatus(HttpStatus.UNAUTHORIZED.value());
+					response.getWriter().write("unauthorized");
+					return null;
+				}
+		).when(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any());
+
+		OAuth2AccessToken accessToken = createAuthorization().getAccessToken().getToken();
+		// @formatter:off
+		this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
+				.header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
+				.andExpect(status().is4xxClientError());
 		// @formatter:on
 		// @formatter:on
+
+		verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any());
+		verifyNoInteractions(authenticationSuccessHandler);
+		verifyNoInteractions(userInfoMapper);
 	}
 	}
 
 
 	// gh-482
 	// gh-482
@@ -273,14 +394,6 @@ public class OidcUserInfoTests {
 			RequestMatcher endpointsMatcher = authorizationServerConfigurer
 			RequestMatcher endpointsMatcher = authorizationServerConfigurer
 					.getEndpointsMatcher();
 					.getEndpointsMatcher();
 
 
-			// Custom User Info Mapper that retrieves claims from a signed JWT
-			Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper = context -> {
-				OidcUserInfoAuthenticationToken authentication = context.getAuthentication();
-				JwtAuthenticationToken principal = (JwtAuthenticationToken) authentication.getPrincipal();
-
-				return new OidcUserInfo(principal.getToken().getClaims());
-			};
-
 			// @formatter:off
 			// @formatter:off
 			http
 			http
 				.securityMatcher(endpointsMatcher)
 				.securityMatcher(endpointsMatcher)
@@ -292,6 +405,12 @@ public class OidcUserInfoTests {
 				.apply(authorizationServerConfigurer)
 				.apply(authorizationServerConfigurer)
 					.oidc(oidc -> oidc
 					.oidc(oidc -> oidc
 						.userInfoEndpoint(userInfo -> userInfo
 						.userInfoEndpoint(userInfo -> userInfo
+							.userInfoRequestConverter(authenticationConverter)
+							.userInfoRequestConverters(authenticationConvertersConsumer)
+							.authenticationProvider(authenticationProvider)
+							.authenticationProviders(authenticationProvidersConsumer)
+							.userInfoResponseHandler(authenticationSuccessHandler)
+							.errorResponseHandler(authenticationFailureHandler)
 							.userInfoMapper(userInfoMapper)
 							.userInfoMapper(userInfoMapper)
 						)
 						)
 					);
 					);

+ 7 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java

@@ -134,6 +134,13 @@ public class OidcClientRegistrationAuthenticationProviderTests {
 				.withMessage("tokenGenerator cannot be null");
 				.withMessage("tokenGenerator cannot be null");
 	}
 	}
 
 
+	@Test
+	public void setRegisteredClientConverterWhenNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.authenticationProvider.setRegisteredClientConverter(null))
+				.withMessage("registeredClientConverter cannot be null");
+	}
+
 	@Test
 	@Test
 	public void supportsWhenTypeOidcClientRegistrationAuthenticationTokenThenReturnTrue() {
 	public void supportsWhenTypeOidcClientRegistrationAuthenticationTokenThenReturnTrue() {
 		assertThat(this.authenticationProvider.supports(OidcClientRegistrationAuthenticationToken.class)).isTrue();
 		assertThat(this.authenticationProvider.supports(OidcClientRegistrationAuthenticationToken.class)).isTrue();

+ 126 - 35
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java

@@ -15,10 +15,12 @@
  */
  */
 package org.springframework.security.oauth2.server.authorization.oidc.web;
 package org.springframework.security.oauth2.server.authorization.oidc.web;
 
 
+import java.io.IOException;
 import java.time.Instant;
 import java.time.Instant;
 import java.util.Collections;
 import java.util.Collections;
 
 
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.FilterChain;
+import jakarta.servlet.ServletException;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
 import jakarta.servlet.http.HttpServletResponse;
 
 
@@ -33,6 +35,8 @@ import org.springframework.mock.http.client.MockClientHttpResponse;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
@@ -54,10 +58,14 @@ import org.springframework.security.oauth2.server.authorization.oidc.OidcClientR
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.verifyNoInteractions;
@@ -68,6 +76,7 @@ import static org.mockito.Mockito.when;
  *
  *
  * @author Ovidiu Popa
  * @author Ovidiu Popa
  * @author Joe Grandja
  * @author Joe Grandja
+ * @author Daniel Garnier-Moiroux
  */
  */
 public class OidcClientRegistrationEndpointFilterTests {
 public class OidcClientRegistrationEndpointFilterTests {
 	private static final String DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI = "/connect/register";
 	private static final String DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI = "/connect/register";
@@ -103,6 +112,27 @@ public class OidcClientRegistrationEndpointFilterTests {
 				.withMessage("clientRegistrationEndpointUri cannot be empty");
 				.withMessage("clientRegistrationEndpointUri cannot be empty");
 	}
 	}
 
 
+	@Test
+	public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.filter.setAuthenticationConverter(null))
+				.withMessage("authenticationConverter cannot be null");
+	}
+
+	@Test
+	public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null))
+				.withMessage("authenticationSuccessHandler cannot be null");
+	}
+
+	@Test
+	public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null))
+				.withMessage("authenticationFailureHandler cannot be null");
+	}
+
 	@Test
 	@Test
 	public void doFilterWhenNotClientRegistrationRequestThenNotProcessed() throws Exception {
 	public void doFilterWhenNotClientRegistrationRequestThenNotProcessed() throws Exception {
 		String requestUri = "/path";
 		String requestUri = "/path";
@@ -203,25 +233,13 @@ public class OidcClientRegistrationEndpointFilterTests {
 	@Test
 	@Test
 	public void doFilterWhenClientRegistrationRequestValidThenSuccessResponse() throws Exception {
 	public void doFilterWhenClientRegistrationRequestValidThenSuccessResponse() throws Exception {
 		// @formatter:off
 		// @formatter:off
-		OidcClientRegistration.Builder clientRegistrationBuilder = OidcClientRegistration.builder()
-				.clientName("client-name")
-				.redirectUri("https://client.example.com")
-				.grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())
-				.grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
-				.scope("scope1")
-				.scope("scope2");
-
-		OidcClientRegistration clientRegistrationRequest = clientRegistrationBuilder.build();
+		OidcClientRegistration expectedClientRegistrationResponse = createClientRegistration();
 
 
-		OidcClientRegistration expectedClientRegistrationResponse = clientRegistrationBuilder
-				.clientId("client-id")
-				.clientIdIssuedAt(Instant.now())
-				.clientSecret("client-secret")
-				.tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue())
-				.responseType(OAuth2AuthorizationResponseType.CODE.getValue())
-				.idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName())
-				.registrationAccessToken("registration-access-token")
-				.registrationClientUrl("https://auth-server:9000/connect/register?client_id=client-id")
+		OidcClientRegistration clientRegistrationRequest = OidcClientRegistration.builder()
+				.clientName(expectedClientRegistrationResponse.getClientName())
+				.redirectUris(redirectUris -> redirectUris.addAll(expectedClientRegistrationResponse.getRedirectUris()))
+				.grantTypes(grantTypes -> grantTypes.addAll(expectedClientRegistrationResponse.getGrantTypes()))
+				.scopes(scopes -> scopes.addAll(expectedClientRegistrationResponse.getScopes()))
 				.build();
 				.build();
 		// @formatter:on
 		// @formatter:on
 
 
@@ -384,23 +402,7 @@ public class OidcClientRegistrationEndpointFilterTests {
 
 
 	@Test
 	@Test
 	public void doFilterWhenClientConfigurationRequestValidThenSuccessResponse() throws Exception {
 	public void doFilterWhenClientConfigurationRequestValidThenSuccessResponse() throws Exception {
-		// @formatter:off
-		OidcClientRegistration expectedClientRegistrationResponse = OidcClientRegistration.builder()
-				.clientId("client-id")
-				.clientIdIssuedAt(Instant.now())
-				.clientSecret("client-secret")
-				.clientName("client-name")
-				.redirectUri("https://client.example.com")
-				.grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())
-				.grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
-				.tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue())
-				.responseType(OAuth2AuthorizationResponseType.CODE.getValue())
-				.idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName())
-				.scope("scope1")
-				.scope("scope2")
-				.registrationClientUrl("https://auth-server:9000/connect/register?client_id=client-id")
-				.build();
-		// @formatter:on
+		OidcClientRegistration expectedClientRegistrationResponse = createClientRegistration();
 
 
 		Jwt jwt = createJwt("client.read");
 		Jwt jwt = createJwt("client.read");
 		JwtAuthenticationToken principal = new JwtAuthenticationToken(
 		JwtAuthenticationToken principal = new JwtAuthenticationToken(
@@ -452,6 +454,74 @@ public class OidcClientRegistrationEndpointFilterTests {
 				.isEqualTo(expectedClientRegistrationResponse.getRegistrationClientUrl());
 				.isEqualTo(expectedClientRegistrationResponse.getRegistrationClientUrl());
 	}
 	}
 
 
+	@Test
+	public void doFilterWhenCustomAuthenticationConverterThenUsed() throws ServletException, IOException {
+		AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class);
+		this.filter.setAuthenticationConverter(authenticationConverter);
+
+		String requestUri = DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client-id");
+
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(authenticationConverter).convert(request);
+	}
+
+	@Test
+	public void doFilterWhenCustomAuthenticationSuccessHandlerThenUsed() throws Exception {
+		OidcClientRegistration expectedClientRegistrationResponse = createClientRegistration();
+		Authentication principal = new TestingAuthenticationToken("principal", "Credentials");
+
+		OidcClientRegistrationAuthenticationToken clientRegistrationAuthenticationResult =
+				new OidcClientRegistrationAuthenticationToken(principal, expectedClientRegistrationResponse);
+
+		when(this.authenticationManager.authenticate(any())).thenReturn(clientRegistrationAuthenticationResult);
+		AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class);
+		this.filter.setAuthenticationSuccessHandler(successHandler);
+
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		securityContext.setAuthentication(principal);
+		SecurityContextHolder.setContext(securityContext);
+
+		String requestUri = DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId());
+
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(successHandler).onAuthenticationSuccess(request, response, clientRegistrationAuthenticationResult);
+	}
+
+	@Test
+	public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exception {
+		AuthenticationFailureHandler authenticationFailureHandler = mock(AuthenticationFailureHandler.class);
+		this.filter.setAuthenticationFailureHandler(authenticationFailureHandler);
+
+		when(this.authenticationManager.authenticate(any()))
+				.thenThrow(new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN));
+
+		String requestUri = DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(authenticationFailureHandler).onAuthenticationFailure(eq(request), eq(response),
+				any(OAuth2AuthenticationException.class));
+	}
+
 	private OAuth2Error readError(MockHttpServletResponse response) throws Exception {
 	private OAuth2Error readError(MockHttpServletResponse response) throws Exception {
 		MockClientHttpResponse httpResponse = new MockClientHttpResponse(
 		MockClientHttpResponse httpResponse = new MockClientHttpResponse(
 				response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));
 				response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));
@@ -471,6 +541,27 @@ public class OidcClientRegistrationEndpointFilterTests {
 		return this.clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, httpResponse);
 		return this.clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, httpResponse);
 	}
 	}
 
 
+	private static OidcClientRegistration createClientRegistration() {
+		// @formatter:off
+		OidcClientRegistration clientRegistration = OidcClientRegistration.builder()
+				.clientId("client-id")
+				.clientIdIssuedAt(Instant.now())
+				.clientSecret("client-secret")
+				.clientName("client-name")
+				.redirectUri("https://client.example.com")
+				.grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())
+				.grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
+				.tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue())
+				.responseType(OAuth2AuthorizationResponseType.CODE.getValue())
+				.idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName())
+				.scope("scope1")
+				.scope("scope2")
+				.registrationClientUrl("https://auth-server:9000/connect/register?client_id=client-id")
+				.build();
+		return clientRegistration;
+		// @formatter:on
+	}
+
 	private static Jwt createJwt(String scope) {
 	private static Jwt createJwt(String scope) {
 		// @formatter:off
 		// @formatter:off
 		JwsHeader jwsHeader = TestJwsHeaders.jwsHeader()
 		JwsHeader jwsHeader = TestJwsHeaders.jwsHeader()

+ 110 - 3
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java

@@ -44,6 +44,9 @@ import org.springframework.security.oauth2.jwt.JoseHeaderNames;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -84,6 +87,27 @@ public class OidcUserInfoEndpointFilterTests {
 				.withMessage("userInfoEndpointUri cannot be empty");
 				.withMessage("userInfoEndpointUri cannot be empty");
 	}
 	}
 
 
+	@Test
+	public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.filter.setAuthenticationConverter(null))
+				.withMessage("authenticationConverter cannot be null");
+	}
+
+	@Test
+	public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null))
+				.withMessage("authenticationSuccessHandler cannot be null");
+	}
+
+	@Test
+	public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null))
+				.withMessage("authenticationFailureHandler cannot be null");
+	}
+
 	@Test
 	@Test
 	public void doFilterWhenNotUserInfoRequestThenNotProcessed() throws Exception {
 	public void doFilterWhenNotUserInfoRequestThenNotProcessed() throws Exception {
 		String requestUri = "/path";
 		String requestUri = "/path";
@@ -145,11 +169,21 @@ public class OidcUserInfoEndpointFilterTests {
 
 
 	@Test
 	@Test
 	public void doFilterWhenUserInfoRequestInvalidTokenThenUnauthorizedError() throws Exception {
 	public void doFilterWhenUserInfoRequestInvalidTokenThenUnauthorizedError() throws Exception {
+		doFilterWhenAuthenticationExceptionThenError(OAuth2ErrorCodes.INVALID_TOKEN, HttpStatus.UNAUTHORIZED);
+	}
+
+	@Test
+	public void doFilterWhenUserInfoRequestInsufficientScopeThenForbiddenError() throws Exception {
+		doFilterWhenAuthenticationExceptionThenError(OAuth2ErrorCodes.INSUFFICIENT_SCOPE, HttpStatus.FORBIDDEN);
+	}
+
+	private void doFilterWhenAuthenticationExceptionThenError(String oauth2ErrorCode, HttpStatus httpStatus)
+			throws Exception {
 		Authentication principal = new TestingAuthenticationToken("principal", "credentials");
 		Authentication principal = new TestingAuthenticationToken("principal", "credentials");
 		SecurityContextHolder.getContext().setAuthentication(principal);
 		SecurityContextHolder.getContext().setAuthentication(principal);
 
 
 		when(this.authenticationManager.authenticate(any()))
 		when(this.authenticationManager.authenticate(any()))
-				.thenThrow(new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN));
+				.thenThrow(new OAuth2AuthenticationException(oauth2ErrorCode));
 
 
 		String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
 		String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
@@ -161,9 +195,82 @@ public class OidcUserInfoEndpointFilterTests {
 
 
 		verifyNoInteractions(filterChain);
 		verifyNoInteractions(filterChain);
 
 
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
+		assertThat(response.getStatus()).isEqualTo(httpStatus.value());
 		OAuth2Error error = readError(response);
 		OAuth2Error error = readError(response);
-		assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN);
+		assertThat(error.getErrorCode()).isEqualTo(oauth2ErrorCode);
+	}
+
+	@Test
+	public void doFilterWhenCustomAuthenticationConverterThenUsed() throws Exception {
+		Authentication principal = new TestingAuthenticationToken("principal", "credentials");
+		OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal);
+		AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class);
+		this.filter.setAuthenticationConverter(authenticationConverter);
+
+		when(authenticationConverter.convert(any())).thenReturn(authentication);
+		when(this.authenticationManager.authenticate(any())).thenReturn(
+				new OidcUserInfoAuthenticationToken(principal, createUserInfo())
+		);
+
+		String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+		verify(authenticationConverter).convert(request);
+		verify(this.authenticationManager).authenticate(authentication);
+		assertUserInfoResponse(response.getContentAsString());
+	}
+
+	@Test
+	public void doFilterWhenCustomAuthenticationSuccessHandlerThenUsed() throws Exception {
+		AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class);
+		this.filter.setAuthenticationSuccessHandler(successHandler);
+
+		Authentication principal = new TestingAuthenticationToken("principal", "credentials");
+		SecurityContextHolder.getContext().setAuthentication(principal);
+
+		OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal, createUserInfo());
+		when(this.authenticationManager.authenticate(any())).thenReturn(authentication);
+
+		String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+		verify(successHandler).onAuthenticationSuccess(request, response, authentication);
+	}
+
+	@Test
+	public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exception {
+		AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
+		this.filter.setAuthenticationFailureHandler(failureHandler);
+
+		Authentication principal = new TestingAuthenticationToken("principal", "credentials");
+		SecurityContextHolder.getContext().setAuthentication(principal);
+
+		OAuth2AuthenticationException authenticationException =
+				new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
+		when(this.authenticationManager.authenticate(any())).thenThrow(authenticationException);
+
+		String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+		verify(failureHandler).onAuthenticationFailure(request, response, authenticationException);
 	}
 	}
 
 
 	private OAuth2Error readError(MockHttpServletResponse response) throws Exception {
 	private OAuth2Error readError(MockHttpServletResponse response) throws Exception {