Kaynağa Gözat

Add OAuth2Authorization success/failure handlers

Fixes gh-7840
Joe Grandja 5 yıl önce
ebeveyn
işleme
69156b741d
15 değiştirilmiş dosya ile 1349 ekleme ve 101 silme
  1. 80 5
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java
  2. 48 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationFailureHandler.java
  3. 47 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationSuccessHandler.java
  4. 19 5
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java
  5. 19 5
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java
  6. 19 5
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java
  7. 19 5
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java
  8. 16 6
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java
  9. 101 17
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java
  10. 169 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java
  11. 77 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/SaveAuthorizedClientOAuth2AuthorizationSuccessHandler.java
  12. 306 24
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java
  13. 87 6
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.java
  14. 95 4
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java
  15. 247 19
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

+ 80 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,6 +19,11 @@ import org.springframework.lang.Nullable;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
+import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientOAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.client.web.SaveAuthorizedClientOAuth2AuthorizationSuccessHandler;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
@@ -31,20 +36,50 @@ import java.util.function.Function;
 
 /**
  * An implementation of an {@link OAuth2AuthorizedClientManager}
- * that is capable of operating outside of a {@code HttpServletRequest} context,
+ * that is capable of operating outside of the context of a {@code HttpServletRequest},
  * e.g. in a scheduled/background thread and/or in the service-tier.
  *
+ * <p>
+ * (When operating <em>within</em> the context of a {@code HttpServletRequest},
+ * use {@link DefaultOAuth2AuthorizedClientManager} instead.)
+ *
+ * <h2>Authorized Client Persistence</h2>
+ *
+ * <p>
+ * This manager utilizes an {@link OAuth2AuthorizedClientService}
+ * to persist {@link OAuth2AuthorizedClient}s.
+ *
+ * <p>
+ * By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient}
+ * will be saved in the {@link OAuth2AuthorizedClientService}.
+ * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationSuccessHandler}
+ * via {@link #setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler)}.
+ *
+ * <p>
+ * By default, when an authorization attempt fails due to an
+ * {@value OAuth2ErrorCodes#INVALID_GRANT} error,
+ * the previously saved {@link OAuth2AuthorizedClient}
+ * will be removed from the {@link OAuth2AuthorizedClientService}.
+ * (The {@value OAuth2ErrorCodes#INVALID_GRANT} error can occur
+ * when a refresh token that is no longer valid is used to retrieve a new access token.)
+ * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationFailureHandler}
+ * via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}.
+ *
  * @author Joe Grandja
  * @since 5.2
  * @see OAuth2AuthorizedClientManager
  * @see OAuth2AuthorizedClientProvider
  * @see OAuth2AuthorizedClientService
+ * @see OAuth2AuthorizationSuccessHandler
+ * @see OAuth2AuthorizationFailureHandler
  */
 public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager {
 	private final ClientRegistrationRepository clientRegistrationRepository;
 	private final OAuth2AuthorizedClientService authorizedClientService;
 	private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null;
-	private Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper = new DefaultContextAttributesMapper();
+	private Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper;
+	private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
+	private OAuth2AuthorizationFailureHandler authorizationFailureHandler;
 
 	/**
 	 * Constructs an {@code AuthorizedClientServiceOAuth2AuthorizedClientManager} using the provided parameters.
@@ -58,6 +93,9 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen
 		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
 		this.clientRegistrationRepository = clientRegistrationRepository;
 		this.authorizedClientService = authorizedClientService;
+		this.contextAttributesMapper = new DefaultContextAttributesMapper();
+		this.authorizationSuccessHandler = new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(authorizedClientService);
+		this.authorizationFailureHandler = new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(authorizedClientService);
 	}
 
 	@Nullable
@@ -92,9 +130,16 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen
 				})
 				.build();
 
-		authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
+		try {
+			authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
+		} catch (OAuth2AuthorizationException ex) {
+			this.authorizationFailureHandler.onAuthorizationFailure(ex, principal, Collections.emptyMap());
+			throw ex;
+		}
+
 		if (authorizedClient != null) {
-			this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+			this.authorizationSuccessHandler.onAuthorizationSuccess(
+					authorizedClient, principal, Collections.emptyMap());
 		} else {
 			// In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported.
 			// For these cases, return the provided `authorizationContext.authorizedClient`.
@@ -128,6 +173,36 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen
 		this.contextAttributesMapper = contextAttributesMapper;
 	}
 
+	/**
+	 * Sets the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations.
+	 *
+	 * <p>
+	 * A {@link SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} is used by default.
+	 *
+	 * @param authorizationSuccessHandler the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations
+	 * @see SaveAuthorizedClientOAuth2AuthorizationSuccessHandler
+	 * @since 5.3
+	 */
+	public void setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler authorizationSuccessHandler) {
+		Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null");
+		this.authorizationSuccessHandler = authorizationSuccessHandler;
+	}
+
+	/**
+	 * Sets the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures.
+	 *
+	 * <p>
+	 * A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} is used by default.
+	 *
+	 * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures
+	 * @see RemoveAuthorizedClientOAuth2AuthorizationFailureHandler
+	 * @since 5.3
+	 */
+	public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) {
+		Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
+		this.authorizationFailureHandler = authorizationFailureHandler;
+	}
+
 	/**
 	 * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
 	 */

+ 48 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationFailureHandler.java

@@ -0,0 +1,48 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+
+import java.util.Map;
+
+/**
+ * Handles when an OAuth 2.0 Client fails to authorize (or re-authorize)
+ * via the Authorization Server or Resource Server.
+ *
+ * @author Joe Grandja
+ * @since 5.3
+ * @see OAuth2AuthorizedClient
+ * @see OAuth2AuthorizedClientManager
+ */
+@FunctionalInterface
+public interface OAuth2AuthorizationFailureHandler {
+
+	/**
+	 * Called when an OAuth 2.0 Client fails to authorize (or re-authorize)
+	 * via the Authorization Server or Resource Server.
+	 *
+	 * @param authorizationException the exception that contains details about what failed
+	 * @param principal the {@code Principal} associated with the attempted authorization
+	 * @param attributes an immutable {@code Map} of (optional) attributes present under certain conditions.
+	 *                   For example, this might contain a {@code javax.servlet.http.HttpServletRequest}
+	 *                   and {@code javax.servlet.http.HttpServletResponse} if the authorization was performed
+	 *                   within the context of a {@code javax.servlet.ServletContext}.
+	 */
+	void onAuthorizationFailure(OAuth2AuthorizationException authorizationException,
+			Authentication principal, Map<String, Object> attributes);
+}

+ 47 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationSuccessHandler.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.springframework.security.core.Authentication;
+
+import java.util.Map;
+
+/**
+ * Handles when an OAuth 2.0 Client has been successfully
+ * authorized (or re-authorized) via the Authorization Server.
+ *
+ * @author Joe Grandja
+ * @since 5.3
+ * @see OAuth2AuthorizedClient
+ * @see OAuth2AuthorizedClientManager
+ */
+@FunctionalInterface
+public interface OAuth2AuthorizationSuccessHandler {
+
+	/**
+	 * Called when an OAuth 2.0 Client has been successfully
+	 * authorized (or re-authorized) via the Authorization Server.
+	 *
+	 * @param authorizedClient the client that was successfully authorized (or re-authorized)
+	 * @param principal the {@code Principal} associated with the authorized client
+	 * @param attributes an immutable {@code Map} of (optional) attributes present under certain conditions.
+	 *                   For example, this might contain a {@code javax.servlet.http.HttpServletRequest}
+	 *                   and {@code javax.servlet.http.HttpServletResponse} if the authorization was performed
+	 *                   within the context of a {@code javax.servlet.ServletContext}.
+	 */
+	void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient,
+			Authentication principal, Map<String, Object> attributes);
+}

+ 19 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultAuthorizationCodeTokenResponseClient.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity;
 import org.springframework.http.ResponseEntity;
 import org.springframework.http.converter.FormHttpMessageConverter;
 import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
@@ -30,6 +30,7 @@ import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 import org.springframework.web.client.ResponseErrorHandler;
 import org.springframework.web.client.RestClientException;
+import org.springframework.web.client.RestClientResponseException;
 import org.springframework.web.client.RestOperations;
 import org.springframework.web.client.RestTemplate;
 
@@ -74,9 +75,22 @@ public final class DefaultAuthorizationCodeTokenResponseClient implements OAuth2
 		try {
 			response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class);
 		} catch (RestClientException ex) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
-					"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null);
-			throw new OAuth2AuthorizationException(oauth2Error, ex);
+			int statusCode = 500;
+			if (ex instanceof RestClientResponseException) {
+				statusCode = ((RestClientResponseException) ex).getRawStatusCode();
+			}
+			OAuth2Error oauth2Error = new OAuth2Error(
+					INVALID_TOKEN_RESPONSE_ERROR_CODE,
+					"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(),
+					null);
+			String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s",
+					statusCode,
+					oauth2Error);
+			throw new ClientAuthorizationException(
+					oauth2Error,
+					authorizationCodeGrantRequest.getClientRegistration().getRegistrationId(),
+					message,
+					ex);
 		}
 
 		OAuth2AccessTokenResponse tokenResponse = response.getBody();

+ 19 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity;
 import org.springframework.http.ResponseEntity;
 import org.springframework.http.converter.FormHttpMessageConverter;
 import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
@@ -30,6 +30,7 @@ import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 import org.springframework.web.client.ResponseErrorHandler;
 import org.springframework.web.client.RestClientException;
+import org.springframework.web.client.RestClientResponseException;
 import org.springframework.web.client.RestOperations;
 import org.springframework.web.client.RestTemplate;
 
@@ -74,9 +75,22 @@ public final class DefaultClientCredentialsTokenResponseClient implements OAuth2
 		try {
 			response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class);
 		} catch (RestClientException ex) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
-					"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null);
-			throw new OAuth2AuthorizationException(oauth2Error, ex);
+			int statusCode = 500;
+			if (ex instanceof RestClientResponseException) {
+				statusCode = ((RestClientResponseException) ex).getRawStatusCode();
+			}
+			OAuth2Error oauth2Error = new OAuth2Error(
+					INVALID_TOKEN_RESPONSE_ERROR_CODE,
+					"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(),
+					null);
+			String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s",
+					statusCode,
+					oauth2Error);
+			throw new ClientAuthorizationException(
+					oauth2Error,
+					clientCredentialsGrantRequest.getClientRegistration().getRegistrationId(),
+					message,
+					ex);
 		}
 
 		OAuth2AccessTokenResponse tokenResponse = response.getBody();

+ 19 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultPasswordTokenResponseClient.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity;
 import org.springframework.http.ResponseEntity;
 import org.springframework.http.converter.FormHttpMessageConverter;
 import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
@@ -30,6 +30,7 @@ import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 import org.springframework.web.client.ResponseErrorHandler;
 import org.springframework.web.client.RestClientException;
+import org.springframework.web.client.RestClientResponseException;
 import org.springframework.web.client.RestOperations;
 import org.springframework.web.client.RestTemplate;
 
@@ -74,9 +75,22 @@ public final class DefaultPasswordTokenResponseClient implements OAuth2AccessTok
 		try {
 			response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class);
 		} catch (RestClientException ex) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
-					"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null);
-			throw new OAuth2AuthorizationException(oauth2Error, ex);
+			int statusCode = 500;
+			if (ex instanceof RestClientResponseException) {
+				statusCode = ((RestClientResponseException) ex).getRawStatusCode();
+			}
+			OAuth2Error oauth2Error = new OAuth2Error(
+					INVALID_TOKEN_RESPONSE_ERROR_CODE,
+					"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(),
+					null);
+			String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s",
+					statusCode,
+					oauth2Error);
+			throw new ClientAuthorizationException(
+					oauth2Error,
+					passwordGrantRequest.getClientRegistration().getRegistrationId(),
+					message,
+					ex);
 		}
 
 		OAuth2AccessTokenResponse tokenResponse = response.getBody();

+ 19 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,9 +20,9 @@ import org.springframework.http.RequestEntity;
 import org.springframework.http.ResponseEntity;
 import org.springframework.http.converter.FormHttpMessageConverter;
 import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
@@ -30,6 +30,7 @@ import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 import org.springframework.web.client.ResponseErrorHandler;
 import org.springframework.web.client.RestClientException;
+import org.springframework.web.client.RestClientResponseException;
 import org.springframework.web.client.RestOperations;
 import org.springframework.web.client.RestTemplate;
 
@@ -73,9 +74,22 @@ public final class DefaultRefreshTokenTokenResponseClient implements OAuth2Acces
 		try {
 			response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class);
 		} catch (RestClientException ex) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
-					"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null);
-			throw new OAuth2AuthorizationException(oauth2Error, ex);
+			int statusCode = 500;
+			if (ex instanceof RestClientResponseException) {
+				statusCode = ((RestClientResponseException) ex).getRawStatusCode();
+			}
+			OAuth2Error oauth2Error = new OAuth2Error(
+					INVALID_TOKEN_RESPONSE_ERROR_CODE,
+					"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(),
+					null);
+			String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s",
+					statusCode,
+					oauth2Error);
+			throw new ClientAuthorizationException(
+					oauth2Error,
+					refreshTokenGrantRequest.getClientRegistration().getRegistrationId(),
+					message,
+					ex);
 		}
 
 		OAuth2AccessTokenResponse tokenResponse = response.getBody();

+ 16 - 6
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -31,10 +31,10 @@ import com.nimbusds.oauth2.sdk.auth.Secret;
 import com.nimbusds.oauth2.sdk.http.HTTPRequest;
 import com.nimbusds.oauth2.sdk.id.ClientID;
 import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@@ -100,9 +100,19 @@ public class NimbusAuthorizationCodeTokenResponseClient implements OAuth2AccessT
 			httpRequest.setReadTimeout(30000);
 			tokenResponse = com.nimbusds.oauth2.sdk.TokenResponse.parse(httpRequest.send());
 		} catch (ParseException | IOException ex) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
-					"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null);
-			throw new OAuth2AuthorizationException(oauth2Error, ex);
+			int statusCode = 500;
+			OAuth2Error oauth2Error = new OAuth2Error(
+					INVALID_TOKEN_RESPONSE_ERROR_CODE,
+					"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(),
+					null);
+			String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s",
+					statusCode,
+					oauth2Error);
+			throw new ClientAuthorizationException(
+					oauth2Error,
+					clientRegistration.getRegistrationId(),
+					message,
+					ex);
 		}
 
 		if (!tokenResponse.indicatesSuccess()) {
@@ -117,7 +127,7 @@ public class NimbusAuthorizationCodeTokenResponseClient implements OAuth2AccessT
 						errorObject.getDescription(),
 						errorObject.getURI() != null ? errorObject.getURI().toString() : null);
 			}
-			throw new OAuth2AuthorizationException(oauth2Error);
+			throw new ClientAuthorizationException(oauth2Error, clientRegistration.getRegistrationId());
 		}
 
 		AccessTokenResponse accessTokenResponse = (AccessTokenResponse) tokenResponse;

+ 101 - 17
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java

@@ -15,22 +15,20 @@
  */
 package org.springframework.security.oauth2.client.web;
 
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.function.Function;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
 import org.springframework.lang.Nullable;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
+import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.client.OAuth2AuthorizationSuccessHandler;
 import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
@@ -39,19 +37,57 @@ import org.springframework.web.context.request.RequestAttributes;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
 
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Function;
+
 /**
- * The default implementation of an {@link OAuth2AuthorizedClientManager}.
+ * The default implementation of an {@link OAuth2AuthorizedClientManager}
+ * for use within the context of a {@code HttpServletRequest}.
+ *
+ * <p>
+ * (When operating <em>outside</em> of the context of a {@code HttpServletRequest},
+ * use {@link AuthorizedClientServiceOAuth2AuthorizedClientManager} instead.)
+ *
+ * <h2>Authorized Client Persistence</h2>
+ *
+ * <p>
+ * This manager utilizes an {@link OAuth2AuthorizedClientRepository}
+ * to persist {@link OAuth2AuthorizedClient}s.
+ *
+ * <p>
+ * By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient}
+ * will be saved in the {@link OAuth2AuthorizedClientRepository}.
+ * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationSuccessHandler}
+ * via {@link #setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler)}.
+ *
+ * <p>
+ * By default, when an authorization attempt fails due to an
+ * {@value OAuth2ErrorCodes#INVALID_GRANT} error,
+ * the previously saved {@link OAuth2AuthorizedClient}
+ * will be removed from the {@link OAuth2AuthorizedClientRepository}.
+ * (The {@value OAuth2ErrorCodes#INVALID_GRANT} error can occur
+ * when a refresh token that is no longer valid is used to retrieve a new access token.)
+ * This functionality can be changed by configuring a custom {@link OAuth2AuthorizationFailureHandler}
+ * via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}.
  *
  * @author Joe Grandja
  * @since 5.2
  * @see OAuth2AuthorizedClientManager
  * @see OAuth2AuthorizedClientProvider
+ * @see OAuth2AuthorizationSuccessHandler
+ * @see OAuth2AuthorizationFailureHandler
  */
 public final class DefaultOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager {
 	private final ClientRegistrationRepository clientRegistrationRepository;
 	private final OAuth2AuthorizedClientRepository authorizedClientRepository;
 	private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null;
-	private Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper = new DefaultContextAttributesMapper();
+	private Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper;
+	private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
+	private OAuth2AuthorizationFailureHandler authorizationFailureHandler;
 
 	/**
 	 * Constructs a {@code DefaultOAuth2AuthorizedClientManager} using the provided parameters.
@@ -65,6 +101,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
 		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
 		this.clientRegistrationRepository = clientRegistrationRepository;
 		this.authorizedClientRepository = authorizedClientRepository;
+		this.contextAttributesMapper = new DefaultContextAttributesMapper();
+		this.authorizationSuccessHandler = new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(authorizedClientRepository);
+		this.authorizationFailureHandler = new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(authorizedClientRepository);
 	}
 
 	@Nullable
@@ -105,9 +144,17 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
 				})
 				.build();
 
-		authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
+		try {
+			authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
+		} catch (OAuth2AuthorizationException ex) {
+			this.authorizationFailureHandler.onAuthorizationFailure(
+					ex, principal, createAttributes(servletRequest, servletResponse));
+			throw ex;
+		}
+
 		if (authorizedClient != null) {
-			this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, servletRequest, servletResponse);
+			this.authorizationSuccessHandler.onAuthorizationSuccess(
+					authorizedClient, principal, createAttributes(servletRequest, servletResponse));
 		} else {
 			// In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported.
 			// For these cases, return the provided `authorizationContext.authorizedClient`.
@@ -119,12 +166,19 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
 		return authorizedClient;
 	}
 
+	private static Map<String, Object> createAttributes(HttpServletRequest servletRequest, HttpServletResponse servletResponse) {
+		Map<String, Object> attributes = new HashMap<>();
+		attributes.put(HttpServletRequest.class.getName(), servletRequest);
+		attributes.put(HttpServletResponse.class.getName(), servletResponse);
+		return attributes;
+	}
+
 	private static HttpServletRequest getHttpServletRequestOrDefault(Map<String, Object> attributes) {
 		HttpServletRequest servletRequest = (HttpServletRequest) attributes.get(HttpServletRequest.class.getName());
 		if (servletRequest == null) {
-			RequestAttributes context = RequestContextHolder.getRequestAttributes();
-			if (context instanceof ServletRequestAttributes) {
-				servletRequest = ((ServletRequestAttributes) context).getRequest();
+			RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
+			if (requestAttributes instanceof ServletRequestAttributes) {
+				servletRequest = ((ServletRequestAttributes) requestAttributes).getRequest();
 			}
 		}
 		return servletRequest;
@@ -133,9 +187,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
 	private static HttpServletResponse getHttpServletResponseOrDefault(Map<String, Object> attributes) {
 		HttpServletResponse servletResponse = (HttpServletResponse) attributes.get(HttpServletResponse.class.getName());
 		if (servletResponse == null) {
-			RequestAttributes context = RequestContextHolder.getRequestAttributes();
-			if (context instanceof ServletRequestAttributes) {
-				servletResponse =  ((ServletRequestAttributes) context).getResponse();
+			RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
+			if (requestAttributes instanceof ServletRequestAttributes) {
+				servletResponse =  ((ServletRequestAttributes) requestAttributes).getResponse();
 			}
 		}
 		return servletResponse;
@@ -163,6 +217,36 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
 		this.contextAttributesMapper = contextAttributesMapper;
 	}
 
+	/**
+	 * Sets the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations.
+	 *
+	 * <p>
+	 * A {@link SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} is used by default.
+	 *
+	 * @param authorizationSuccessHandler the {@link OAuth2AuthorizationSuccessHandler} that handles successful authorizations
+	 * @see SaveAuthorizedClientOAuth2AuthorizationSuccessHandler
+	 * @since 5.3
+	 */
+	public void setAuthorizationSuccessHandler(OAuth2AuthorizationSuccessHandler authorizationSuccessHandler) {
+		Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null");
+		this.authorizationSuccessHandler = authorizationSuccessHandler;
+	}
+
+	/**
+	 * Sets the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures.
+	 *
+	 * <p>
+	 * A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} is used by default.
+	 *
+	 * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authorization failures
+	 * @see RemoveAuthorizedClientOAuth2AuthorizationFailureHandler
+	 * @since 5.3
+	 */
+	public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) {
+		Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
+		this.authorizationFailureHandler = authorizationFailureHandler;
+	}
+
 	/**
 	 * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
 	 */

+ 169 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/RemoveAuthorizedClientOAuth2AuthorizationFailureHandler.java

@@ -0,0 +1,169 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
+import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.util.Assert;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * An {@link OAuth2AuthorizationFailureHandler} that removes an {@link OAuth2AuthorizedClient}
+ * from an {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService}
+ * for a specific set of OAuth 2.0 error codes.
+ *
+ * @author Joe Grandja
+ * @since 5.3
+ * @see OAuth2AuthorizedClient
+ * @see OAuth2AuthorizedClientRepository
+ * @see OAuth2AuthorizedClientService
+ */
+public class RemoveAuthorizedClientOAuth2AuthorizationFailureHandler implements OAuth2AuthorizationFailureHandler {
+
+	/**
+	 * The default OAuth 2.0 error codes that will trigger removal of an {@link OAuth2AuthorizedClient}.
+	 * @see OAuth2ErrorCodes
+	 */
+	public static final Set<String>	DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
+			/*
+			 * Returned from Resource Servers when an access token provided is expired, revoked,
+			 * malformed, or invalid for other reasons.
+			 *
+			 * Note that this is needed because ServletOAuth2AuthorizedClientExchangeFilterFunction
+			 * delegates this type of failure received from a Resource Server
+			 * to this failure handler.
+			 */
+			OAuth2ErrorCodes.INVALID_TOKEN,
+
+			/*
+			 * Returned from Authorization Servers when the authorization grant or refresh token is invalid, expired, revoked,
+			 * does not match the redirection URI used in the authorization request, or was issued to another client.
+			 */
+			OAuth2ErrorCodes.INVALID_GRANT
+	)));
+
+	/**
+	 * The OAuth 2.0 error codes which will trigger removal of an {@link OAuth2AuthorizedClient}.
+	 * @see OAuth2ErrorCodes
+	 */
+	private final Set<String> removeAuthorizedClientErrorCodes;
+
+	/**
+	 * A delegate that removes an {@link OAuth2AuthorizedClient} from a
+	 * {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService}
+	 * if the error code is one of the {@link #removeAuthorizedClientErrorCodes}.
+	 */
+	private final OAuth2AuthorizedClientRemover delegate;
+
+	@FunctionalInterface
+	private interface OAuth2AuthorizedClientRemover {
+		void removeAuthorizedClient(String clientRegistrationId, Authentication principal, Map<String, Object> attributes);
+	}
+
+	/**
+	 * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters.
+	 *
+	 * @param authorizedClientRepository the repository from which authorized clients will be removed
+	 *                                   if the error code is one of the {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}.
+	 */
+	public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(OAuth2AuthorizedClientRepository authorizedClientRepository) {
+		this(authorizedClientRepository, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES);
+	}
+
+	/**
+	 * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters.
+	 *
+	 * @param authorizedClientRepository the repository from which authorized clients will be removed
+	 *                                   if the error code is one of the {@link #removeAuthorizedClientErrorCodes}.
+	 * @param removeAuthorizedClientErrorCodes the OAuth 2.0 error codes which will trigger removal of an authorized client.
+	 * @see OAuth2ErrorCodes
+	 */
+	public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(
+			OAuth2AuthorizedClientRepository authorizedClientRepository,
+			Set<String> removeAuthorizedClientErrorCodes) {
+		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
+		Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null");
+		this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes));
+		this.delegate = (clientRegistrationId, principal, attributes) ->
+				authorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal,
+						(HttpServletRequest) attributes.get(HttpServletRequest.class.getName()),
+						(HttpServletResponse) attributes.get(HttpServletResponse.class.getName()));
+	}
+
+	/**
+	 * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters.
+	 *
+	 * @param authorizedClientService the service from which authorized clients will be removed
+	 *                                if the error code is one of the {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}.
+	 */
+	public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(OAuth2AuthorizedClientService authorizedClientService) {
+		this(authorizedClientService, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES);
+	}
+
+	/**
+	 * Constructs a {@code RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} using the provided parameters.
+	 *
+	 * @param authorizedClientService the service from which authorized clients will be removed
+	 *                                if the error code is one of the {@link #removeAuthorizedClientErrorCodes}.
+	 * @param removeAuthorizedClientErrorCodes the OAuth 2.0 error codes which will trigger removal of an authorized client.
+	 * @see OAuth2ErrorCodes
+	 */
+	public RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(
+			OAuth2AuthorizedClientService authorizedClientService,
+			Set<String> removeAuthorizedClientErrorCodes) {
+		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
+		Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null");
+		this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes));
+		this.delegate = (clientRegistrationId, principal, attributes) ->
+				authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName());
+	}
+
+	@Override
+	public void onAuthorizationFailure(OAuth2AuthorizationException authorizationException,
+			Authentication principal, Map<String, Object> attributes) {
+
+		if (authorizationException instanceof ClientAuthorizationException &&
+				hasRemovalErrorCode(authorizationException)) {
+			ClientAuthorizationException clientAuthorizationException = (ClientAuthorizationException) authorizationException;
+			this.delegate.removeAuthorizedClient(
+					clientAuthorizationException.getClientRegistrationId(), principal, attributes);
+		}
+	}
+
+	/**
+	 * Returns true if the given exception has an error code that
+	 * indicates that the authorized client should be removed.
+	 *
+	 * @param authorizationException the exception that caused the authorization failure
+	 * @return true if the given exception has an error code that
+	 * 		   indicates that the authorized client should be removed.
+	 */
+	private boolean hasRemovalErrorCode(OAuth2AuthorizationException authorizationException) {
+		return this.removeAuthorizedClientErrorCodes.contains(authorizationException.getError().getErrorCode());
+	}
+}

+ 77 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/SaveAuthorizedClientOAuth2AuthorizationSuccessHandler.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.OAuth2AuthorizationSuccessHandler;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+import org.springframework.util.Assert;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.util.Map;
+
+/**
+ * An {@link OAuth2AuthorizationSuccessHandler} that saves an {@link OAuth2AuthorizedClient}
+ * in an {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService}.
+ *
+ * @author Joe Grandja
+ * @since 5.3
+ * @see OAuth2AuthorizedClient
+ * @see OAuth2AuthorizedClientRepository
+ * @see OAuth2AuthorizedClientService
+ */
+public class SaveAuthorizedClientOAuth2AuthorizationSuccessHandler implements OAuth2AuthorizationSuccessHandler {
+
+	/**
+	 * A delegate that saves an {@link OAuth2AuthorizedClient} in an
+	 * {@link OAuth2AuthorizedClientRepository} or {@link OAuth2AuthorizedClientService}.
+	 */
+	private final OAuth2AuthorizationSuccessHandler delegate;
+
+	/**
+	 * Constructs a {@code SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} using the provided parameters.
+	 *
+	 * @param authorizedClientRepository The repository in which authorized clients will be saved.
+	 */
+	public SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(
+			OAuth2AuthorizedClientRepository authorizedClientRepository) {
+		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
+		this.delegate = (authorizedClient, principal, attributes) ->
+				authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal,
+						(HttpServletRequest) attributes.get(HttpServletRequest.class.getName()),
+						(HttpServletResponse) attributes.get(HttpServletResponse.class.getName()));
+	}
+
+	/**
+	 * Constructs a {@code SaveAuthorizedClientOAuth2AuthorizationSuccessHandler} using the provided parameters.
+	 *
+	 * @param authorizedClientService The service in which authorized clients will be saved.
+	 */
+	public SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(
+			OAuth2AuthorizedClientService authorizedClientService) {
+		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
+		this.delegate = (authorizedClient, principal, attributes) ->
+				authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+	}
+
+	@Override
+	public void onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient,
+			Authentication principal, Map<String, Object> attributes) {
+		this.delegate.onAuthorizationSuccess(authorizedClient, principal, attributes);
+	}
+}

+ 306 - 24
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -13,15 +13,18 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.springframework.security.oauth2.client.web.reactive.function.client;
 
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
 import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
@@ -35,7 +38,13 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientOAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
 import org.springframework.web.context.request.RequestAttributes;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
@@ -44,6 +53,7 @@ import org.springframework.web.reactive.function.client.ClientResponse;
 import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
 import org.springframework.web.reactive.function.client.ExchangeFunction;
 import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
 import reactor.core.publisher.Mono;
 import reactor.core.scheduler.Schedulers;
 import reactor.util.context.Context;
@@ -52,18 +62,25 @@ import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import java.time.Duration;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
- * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the
- * token as a Bearer Token. It also provides mechanisms for looking up the {@link OAuth2AuthorizedClient}. This class is
- * intended to be used in a servlet environment.
+ * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth 2.0 requests
+ * by including the {@link OAuth2AuthorizedClient#getAccessToken() access token} as a bearer token.
+ *
+ * <p>
+ * <b>NOTE:</b>This class is intended to be used in a {@code Servlet} environment.
  *
+ * <p>
  * Example usage:
  *
  * <pre>
- * ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository);
+ * ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager);
  * WebClient webClient = WebClient.builder()
  *    .apply(oauth2.oauth2Configuration())
  *    .build();
@@ -76,23 +93,35 @@ import java.util.function.Consumer;
  *    .bodyToMono(String.class);
  * </pre>
  *
- * An attempt to automatically refresh the token will be made if all of the following
- * are true:
+ * <h3>Authentication and Authorization Failures</h3>
+ *
+ * <p>
+ * Since 5.3, this filter function has the ability to forward authentication (HTTP 401 Unauthorized)
+ * and authorization (HTTP 403 Forbidden) failures from an OAuth 2.0 Resource Server
+ * to a {@link OAuth2AuthorizationFailureHandler}.
+ * A {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler} can be used
+ * to remove the cached {@link OAuth2AuthorizedClient}, so that future requests will result
+ * in a new token being retrieved from an Authorization Server, and sent to the Resource Server.
+ *
+ * <p>
+ * If the {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)}
+ * constructor is used, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler}
+ * will be configured automatically.
  *
- * <ul>
- * <li>The {@link OAuth2AuthorizedClientManager} is not null</li>
- * <li>A refresh token is present on the {@link OAuth2AuthorizedClient}</li>
- * <li>The access token is expired</li>
- * <li>The {@link SecurityContextHolder} will be used to attempt to save
- * the token. If it is empty, then the principal name on the {@link OAuth2AuthorizedClient}
- * will be used to create an Authentication for saving.</li>
- * </ul>
+ * <p>
+ * If the {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)}
+ * constructor is used, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler}
+ * will <em>NOT</em> be configured automatically.
+ * It is recommended that you configure one via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}.
  *
  * @author Rob Winch
  * @author Joe Grandja
  * @author Roman Matiushchenko
  * @since 5.1
  * @see OAuth2AuthorizedClientManager
+ * @see DefaultOAuth2AuthorizedClientManager
+ * @see OAuth2AuthorizedClientProvider
+ * @see OAuth2AuthorizedClientProviderBuilder
  */
 public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
 
@@ -103,6 +132,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 	 * The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
 	 */
 	private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
+
 	private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID");
 	private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName();
 	private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
@@ -125,35 +155,75 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 
 	private String defaultClientRegistrationId;
 
+	private ClientResponseHandler clientResponseHandler;
+
+	@FunctionalInterface
+	private interface ClientResponseHandler {
+		Mono<ClientResponse> handleResponse(ClientRequest request, Mono<ClientResponse> response);
+	}
+
 	public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
 	}
 
 	/**
 	 * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters.
 	 *
+	 * <p>
+	 * When this constructor is used, authentication (HTTP 401) and authorization (HTTP 403)
+	 * failures returned from an OAuth 2.0 Resource Server will <em>NOT</em> be forwarded to an
+	 * {@link OAuth2AuthorizationFailureHandler}.
+	 * Therefore, future requests to the Resource Server will most likely use the same (likely invalid) token,
+	 * resulting in the same errors returned from the Resource Server.
+	 * It is recommended to configure a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler}
+	 * via {@link #setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler)}
+	 * so that authentication and authorization failures returned from a Resource Server
+	 * will result in removing the authorized client, so that a new token is retrieved for future requests.
+	 *
 	 * @since 5.2
 	 * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s)
 	 */
 	public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager authorizedClientManager) {
 		Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
 		this.authorizedClientManager = authorizedClientManager;
+		this.clientResponseHandler =  (request, responseMono) -> responseMono;
 	}
 
 	/**
 	 * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters.
 	 *
+	 * <p>
+	 * Since 5.3, when this constructor is used, authentication (HTTP 401)
+	 * and authorization (HTTP 403) failures returned from an OAuth 2.0 Resource Server
+	 * will be forwarded to a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler},
+	 * which will potentially remove the {@link OAuth2AuthorizedClient} from the given
+	 * {@link OAuth2AuthorizedClientRepository}, depending on the OAuth 2.0 error code returned.
+	 * Authentication failures returned from an OAuth 2.0 Resource Server typically indicate
+	 * that the token is invalid, and should not be used in future requests.
+	 * Removing the authorized client from the repository will ensure that the existing
+	 * token will not be sent for future requests to the Resource Server,
+	 * and a new token is retrieved from the Authorization Server and used for
+	 * future requests to the Resource Server.
+	 *
 	 * @param clientRegistrationRepository the repository of client registrations
 	 * @param authorizedClientRepository the repository of authorized clients
 	 */
 	public ServletOAuth2AuthorizedClientExchangeFilterFunction(
 			ClientRegistrationRepository clientRegistrationRepository,
 			OAuth2AuthorizedClientRepository authorizedClientRepository) {
-		this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository);
+
+		OAuth2AuthorizationFailureHandler authorizationFailureHandler =
+				new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(authorizedClientRepository);
+
+		this.authorizedClientManager = createDefaultAuthorizedClientManager(
+				clientRegistrationRepository, authorizedClientRepository, authorizationFailureHandler);
 		this.defaultAuthorizedClientManager = true;
+		this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
 	}
 
 	private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManager(
-			ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) {
+			ClientRegistrationRepository clientRegistrationRepository,
+			OAuth2AuthorizedClientRepository authorizedClientRepository,
+			OAuth2AuthorizationFailureHandler authorizationFailureHandler) {
 
 		OAuth2AuthorizedClientProvider authorizedClientProvider =
 				OAuth2AuthorizedClientProviderBuilder.builder()
@@ -165,6 +235,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 		DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
 				clientRegistrationRepository, authorizedClientRepository);
 		authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
+		authorizedClientManager.setAuthorizationFailureHandler(authorizationFailureHandler);
 
 		return authorizedClientManager;
 	}
@@ -333,19 +404,47 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 		updateDefaultAuthorizedClientManager();
 	}
 
+	/**
+	 * Sets the {@link OAuth2AuthorizationFailureHandler} that handles
+	 * authentication and authorization failures when communicating
+	 * to the OAuth 2.0 Resource Server.
+	 *
+	 * <p>
+	 * For example, a {@link RemoveAuthorizedClientOAuth2AuthorizationFailureHandler}
+	 * is typically used to remove the cached {@link OAuth2AuthorizedClient},
+	 * so that the same token is no longer used in future requests to the Resource Server.
+	 *
+	 * <p>
+	 * The failure handler used by default depends on which constructor was used
+	 * to construct this {@link ServletOAuth2AuthorizedClientExchangeFilterFunction}.
+	 * See the constructors for more details.
+	 *
+	 * @param authorizationFailureHandler the {@link OAuth2AuthorizationFailureHandler} that handles authentication and authorization failures
+	 * @since 5.3
+	 */
+	public void setAuthorizationFailureHandler(OAuth2AuthorizationFailureHandler authorizationFailureHandler) {
+		Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
+		this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
+	}
+
 	@Override
 	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
 		return mergeRequestAttributesIfNecessary(request)
 				.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
-				.flatMap(req -> authorizedClient(getOAuth2AuthorizedClient(req.attributes()), req))
+				.flatMap(req -> reauthorizeClient(getOAuth2AuthorizedClient(req.attributes()), req))
 				.switchIfEmpty(Mono.defer(() ->
 						mergeRequestAttributesIfNecessary(request)
 								.filter(req -> resolveClientRegistrationId(req) != null)
 								.flatMap(req -> authorizeClient(resolveClientRegistrationId(req), req))
 				))
 				.map(authorizedClient -> bearer(request, authorizedClient))
-				.flatMap(next::exchange)
-				.switchIfEmpty(Mono.defer(() -> next.exchange(request)));
+				.flatMap(requestWithBearer -> exchangeAndHandleResponse(requestWithBearer, next))
+				.switchIfEmpty(Mono.defer(() -> exchangeAndHandleResponse(request, next)));
+	}
+
+	private Mono<ClientResponse> exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) {
+		return next.exchange(request)
+				.transform(responseMono -> this.clientResponseHandler.handleResponse(request, responseMono));
 	}
 
 	private Mono<ClientRequest> mergeRequestAttributesIfNecessary(ClientRequest request) {
@@ -443,13 +542,14 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 		});
 		OAuth2AuthorizeRequest authorizeRequest = builder.build();
 
-		// NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.boundedElastic())
-		// NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.boundedElastic())
+		// NOTE:
+		// 'authorizedClientManager.authorize()' needs to be executed
+		// on a dedicated thread via subscribeOn(Schedulers.boundedElastic())
 		// since it performs a blocking I/O operation using RestTemplate internally
 		return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(authorizeRequest)).subscribeOn(Schedulers.boundedElastic());
 	}
 
-	private Mono<OAuth2AuthorizedClient> authorizedClient(OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
+	private Mono<OAuth2AuthorizedClient> reauthorizeClient(OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
 		if (this.authorizedClientManager == null) {
 			return Mono.just(authorizedClient);
 		}
@@ -472,7 +572,9 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 		});
 		OAuth2AuthorizeRequest reauthorizeRequest = builder.build();
 
-		// NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.boundedElastic())
+		// NOTE:
+		// 'authorizedClientManager.authorize()' needs to be executed
+		// on a dedicated thread via subscribeOn(Schedulers.boundedElastic())
 		// since it performs a blocking I/O operation using RestTemplate internally
 		return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest)).subscribeOn(Schedulers.boundedElastic());
 	}
@@ -480,6 +582,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 	private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
 		return ClientRequest.from(request)
 					.headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
+					.attributes(oauth2AuthorizedClient(authorizedClient))
 					.build();
 	}
 
@@ -550,4 +653,183 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 			return new UnsupportedOperationException("Not Supported");
 		}
 	}
+
+	/**
+	 * Forwards authentication and authorization failures to an
+	 * {@link OAuth2AuthorizationFailureHandler}.
+	 *
+	 * @since 5.3
+	 */
+	private static class AuthorizationFailureForwarder implements ClientResponseHandler {
+
+		/**
+		 * A map of HTTP status code to OAuth 2.0 error code for
+		 * HTTP status codes that should be interpreted as
+		 * authentication or authorization failures.
+		 */
+		private final Map<Integer, String> httpStatusToOAuth2ErrorCodeMap;
+
+		/**
+		 * The {@link OAuth2AuthorizationFailureHandler} to notify
+		 * when an authentication/authorization failure occurs.
+		 */
+		private final OAuth2AuthorizationFailureHandler authorizationFailureHandler;
+
+		private AuthorizationFailureForwarder(OAuth2AuthorizationFailureHandler authorizationFailureHandler) {
+			Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
+			this.authorizationFailureHandler = authorizationFailureHandler;
+
+			Map<Integer, String> httpStatusToOAuth2Error = new HashMap<>();
+			httpStatusToOAuth2Error.put(HttpStatus.UNAUTHORIZED.value(), OAuth2ErrorCodes.INVALID_TOKEN);
+			httpStatusToOAuth2Error.put(HttpStatus.FORBIDDEN.value(), OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
+			this.httpStatusToOAuth2ErrorCodeMap = Collections.unmodifiableMap(httpStatusToOAuth2Error);
+		}
+
+		@Override
+		public Mono<ClientResponse> handleResponse(ClientRequest request, Mono<ClientResponse> responseMono) {
+			return responseMono
+					.flatMap(response -> handleResponse(request, response)
+							.thenReturn(response))
+					.onErrorResume(WebClientResponseException.class, e -> handleWebClientResponseException(request, e)
+							.then(Mono.error(e)))
+					.onErrorResume(OAuth2AuthorizationException.class, e -> handleAuthorizationException(request, e)
+							.then(Mono.error(e)));
+		}
+
+		private Mono<Void> handleResponse(ClientRequest request, ClientResponse response) {
+			return Mono.justOrEmpty(resolveErrorIfPossible(response))
+					.flatMap(oauth2Error -> {
+						Map<String, Object> attrs = request.attributes();
+						OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
+						if (authorizedClient == null) {
+							return Mono.empty();
+						}
+
+						ClientAuthorizationException authorizationException = new ClientAuthorizationException(
+								oauth2Error, authorizedClient.getClientRegistration().getRegistrationId());
+
+						Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName());
+						HttpServletRequest servletRequest = getRequest(attrs);
+						HttpServletResponse servletResponse = getResponse(attrs);
+
+						return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse);
+					});
+		}
+
+		private OAuth2Error resolveErrorIfPossible(ClientResponse response) {
+			// Try to resolve from 'WWW-Authenticate' header
+			if (!response.headers().header(HttpHeaders.WWW_AUTHENTICATE).isEmpty()) {
+				String wwwAuthenticateHeader = response.headers().header(HttpHeaders.WWW_AUTHENTICATE).get(0);
+				Map<String, String> authParameters = parseAuthParameters(wwwAuthenticateHeader);
+				if (authParameters.containsKey(OAuth2ParameterNames.ERROR)) {
+					return new OAuth2Error(
+							authParameters.get(OAuth2ParameterNames.ERROR),
+							authParameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION),
+							authParameters.get(OAuth2ParameterNames.ERROR_URI));
+				}
+			}
+			return resolveErrorIfPossible(response.rawStatusCode());
+		}
+
+		private OAuth2Error resolveErrorIfPossible(int statusCode) {
+			if (this.httpStatusToOAuth2ErrorCodeMap.containsKey(statusCode)) {
+				return new OAuth2Error(
+						this.httpStatusToOAuth2ErrorCodeMap.get(statusCode),
+						null,
+						"https://tools.ietf.org/html/rfc6750#section-3.1");
+			}
+			return null;
+		}
+
+		private Map<String, String> parseAuthParameters(String wwwAuthenticateHeader) {
+			return Stream.of(wwwAuthenticateHeader)
+					.filter(header -> !StringUtils.isEmpty(header))
+					.filter(header -> header.toLowerCase().startsWith("bearer"))
+					.map(header -> header.substring("bearer".length()))
+					.map(header -> header.split(","))
+					.flatMap(Stream::of)
+					.map(parameter -> parameter.split("="))
+					.filter(parameter -> parameter.length > 1)
+					.collect(Collectors.toMap(
+							parameters -> parameters[0].trim(),
+							parameters -> parameters[1].trim().replace("\"", "")));
+		}
+
+		/**
+		 * Handles the given http status code returned from a resource server
+		 * by notifying the authorization failure handler if the http status
+		 * code is in the {@link #httpStatusToOAuth2ErrorCodeMap}.
+		 *
+		 * @param request the request being processed
+		 * @param exception The root cause exception for the failure
+		 * @return a {@link Mono} that completes empty after the authorization failure handler completes
+		 */
+		private Mono<Void> handleWebClientResponseException(ClientRequest request, WebClientResponseException exception) {
+			return Mono.justOrEmpty(resolveErrorIfPossible(exception.getRawStatusCode()))
+					.flatMap(oauth2Error -> {
+						Map<String, Object> attrs = request.attributes();
+						OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
+						if (authorizedClient == null) {
+							return Mono.empty();
+						}
+
+						ClientAuthorizationException authorizationException = new ClientAuthorizationException(
+								oauth2Error, authorizedClient.getClientRegistration().getRegistrationId(), exception);
+
+						Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName());
+						HttpServletRequest servletRequest = getRequest(attrs);
+						HttpServletResponse servletResponse = getResponse(attrs);
+
+						return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse);
+					});
+		}
+
+		/**
+		 * Handles the given {@link OAuth2AuthorizationException} that occurred downstream
+		 * by notifying the authorization failure handler.
+		 *
+		 * @param request the request being processed
+		 * @param authorizationException the authorization exception to include in the failure event
+		 * @return a {@link Mono} that completes empty after the authorization failure handler completes
+		 */
+		private Mono<Void> handleAuthorizationException(ClientRequest request, OAuth2AuthorizationException authorizationException) {
+			return Mono.justOrEmpty(request)
+					.flatMap(req -> {
+						Map<String, Object> attrs = req.attributes();
+						OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
+						if (authorizedClient == null) {
+							return Mono.empty();
+						}
+
+						Authentication principal = new PrincipalNameAuthentication(authorizedClient.getPrincipalName());
+						HttpServletRequest servletRequest = getRequest(attrs);
+						HttpServletResponse servletResponse = getResponse(attrs);
+
+						return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse);
+					});
+		}
+
+		/**
+		 * Delegates the failed authorization to the {@link OAuth2AuthorizationFailureHandler}.
+		 *
+		 * @param exception the {@link OAuth2AuthorizationException} to include in the failure event
+		 * @param principal the principal associated with the failed authorization attempt
+		 * @param servletRequest the currently active {@code HttpServletRequest}
+		 * @param servletResponse the currently active {@code HttpServletResponse}
+		 * @return a {@link Mono} that completes empty after the {@link OAuth2AuthorizationFailureHandler} completes
+		 */
+		private Mono<Void> handleAuthorizationFailure(OAuth2AuthorizationException exception,
+				Authentication principal, HttpServletRequest servletRequest, HttpServletResponse servletResponse) {
+			Runnable runnable = () -> this.authorizationFailureHandler.onAuthorizationFailure(
+					exception, principal, createAttributes(servletRequest, servletResponse));
+			return Mono.fromRunnable(runnable).subscribeOn(Schedulers.boundedElastic()).then();
+		}
+
+		private static Map<String, Object> createAttributes(HttpServletRequest servletRequest, HttpServletResponse servletResponse) {
+			Map<String, Object> attributes = new HashMap<>();
+			attributes.put(HttpServletRequest.class.getName(), servletRequest);
+			attributes.put(HttpServletResponse.class.getName(), servletResponse);
+			return attributes;
+		}
+	}
 }

+ 87 - 6
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,10 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientOAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.client.web.SaveAuthorizedClientOAuth2AuthorizationSuccessHandler;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -30,10 +34,16 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import java.util.function.Function;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
 
 /**
  * Tests for {@link AuthorizedClientServiceOAuth2AuthorizedClientManager}.
@@ -45,6 +55,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
 	private OAuth2AuthorizedClientService authorizedClientService;
 	private OAuth2AuthorizedClientProvider authorizedClientProvider;
 	private Function contextAttributesMapper;
+	private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
+	private OAuth2AuthorizationFailureHandler authorizationFailureHandler;
 	private AuthorizedClientServiceOAuth2AuthorizedClientManager authorizedClientManager;
 	private ClientRegistration clientRegistration;
 	private Authentication principal;
@@ -58,10 +70,14 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
 		this.authorizedClientService = mock(OAuth2AuthorizedClientService.class);
 		this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class);
 		this.contextAttributesMapper = mock(Function.class);
+		this.authorizationSuccessHandler = spy(new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(this.authorizedClientService));
+		this.authorizationFailureHandler = spy(new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(this.authorizedClientService));
 		this.authorizedClientManager = new AuthorizedClientServiceOAuth2AuthorizedClientManager(
 				this.clientRegistrationRepository, this.authorizedClientService);
 		this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider);
 		this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper);
+		this.authorizedClientManager.setAuthorizationSuccessHandler(this.authorizationSuccessHandler);
+		this.authorizedClientManager.setAuthorizationFailureHandler(this.authorizationFailureHandler);
 		this.clientRegistration = TestClientRegistrations.clientRegistration().build();
 		this.principal = new TestingAuthenticationToken("principal", "password");
 		this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
@@ -97,6 +113,20 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
 				.hasMessage("contextAttributesMapper cannot be null");
 	}
 
+	@Test
+	public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationSuccessHandler cannot be null");
+	}
+
+	@Test
+	public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationFailureHandler cannot be null");
+	}
+
 	@Test
 	public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.authorizedClientManager.authorize(null))
@@ -134,8 +164,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isNull();
-		verify(this.authorizedClientService, never()).saveAuthorizedClient(
-				any(OAuth2AuthorizedClient.class), eq(this.principal));
+		verifyNoInteractions(this.authorizationSuccessHandler);
+		verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any());
 	}
 
 	@SuppressWarnings("unchecked")
@@ -160,6 +190,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isSameAs(this.authorizedClient);
+		verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
+				eq(this.authorizedClient), eq(this.principal), any());
 		verify(this.authorizedClientService).saveAuthorizedClient(
 				eq(this.authorizedClient), eq(this.principal));
 	}
@@ -192,6 +224,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isSameAs(reauthorizedClient);
+		verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
+				eq(reauthorizedClient), eq(this.principal), any());
 		verify(this.authorizedClientService).saveAuthorizedClient(
 				eq(reauthorizedClient), eq(this.principal));
 	}
@@ -213,8 +247,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isSameAs(this.authorizedClient);
-		verify(this.authorizedClientService, never()).saveAuthorizedClient(
-				any(OAuth2AuthorizedClient.class), eq(this.principal));
+		verifyNoInteractions(this.authorizationSuccessHandler);
+		verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any());
 	}
 
 	@SuppressWarnings("unchecked")
@@ -240,6 +274,8 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isSameAs(reauthorizedClient);
+		verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
+				eq(reauthorizedClient), eq(this.principal), any());
 		verify(this.authorizedClientService).saveAuthorizedClient(
 				eq(reauthorizedClient), eq(this.principal));
 	}
@@ -274,7 +310,52 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
 		assertThat(requestScopeAttribute).contains("read", "write");
 
 		assertThat(authorizedClient).isSameAs(reauthorizedClient);
+		verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
+				eq(reauthorizedClient), eq(this.principal), any());
 		verify(this.authorizedClientService).saveAuthorizedClient(
 				eq(reauthorizedClient), eq(this.principal));
 	}
+
+	@Test
+	public void reauthorizeWhenErrorCodeMatchThenRemoveAuthorizedClient() {
+		ClientAuthorizationException authorizationException = new ClientAuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null),
+				this.clientRegistration.getRegistrationId());
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
+				.thenThrow(authorizationException);
+
+		OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
+				.principal(this.principal)
+				.build();
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest))
+				.isEqualTo(authorizationException);
+
+		verify(this.authorizationFailureHandler).onAuthorizationFailure(
+				eq(authorizationException), eq(this.principal), any());
+		verify(this.authorizedClientService).removeAuthorizedClient(
+				eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()));
+	}
+
+	@Test
+	public void reauthorizeWhenErrorCodeDoesNotMatchThenDoNotRemoveAuthorizedClient() {
+		ClientAuthorizationException authorizationException = new ClientAuthorizationException(
+				new OAuth2Error("non-matching-error-code", null, null),
+				this.clientRegistration.getRegistrationId());
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
+				.thenThrow(authorizationException);
+
+		OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
+				.principal(this.principal)
+				.build();
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest))
+				.isEqualTo(authorizationException);
+
+		verify(this.authorizationFailureHandler).onAuthorizationFailure(
+				eq(authorizationException), eq(this.principal), any());
+		verifyNoInteractions(this.authorizedClientService);
+	}
 }

+ 95 - 4
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,13 +22,18 @@ import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
+import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.client.OAuth2AuthorizationSuccessHandler;
 import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -41,8 +46,16 @@ import java.util.Map;
 import java.util.function.Function;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
 
 /**
  * Tests for {@link DefaultOAuth2AuthorizedClientManager}.
@@ -54,6 +67,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
 	private OAuth2AuthorizedClientRepository authorizedClientRepository;
 	private OAuth2AuthorizedClientProvider authorizedClientProvider;
 	private Function contextAttributesMapper;
+	private OAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
+	private OAuth2AuthorizationFailureHandler authorizationFailureHandler;
 	private DefaultOAuth2AuthorizedClientManager authorizedClientManager;
 	private ClientRegistration clientRegistration;
 	private Authentication principal;
@@ -69,10 +84,14 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
 		this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
 		this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class);
 		this.contextAttributesMapper = mock(Function.class);
+		this.authorizationSuccessHandler = spy(new SaveAuthorizedClientOAuth2AuthorizationSuccessHandler(this.authorizedClientRepository));
+		this.authorizationFailureHandler = spy(new RemoveAuthorizedClientOAuth2AuthorizationFailureHandler(this.authorizedClientRepository));
 		this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
 				this.clientRegistrationRepository, this.authorizedClientRepository);
 		this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider);
 		this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper);
+		this.authorizedClientManager.setAuthorizationSuccessHandler(this.authorizationSuccessHandler);
+		this.authorizedClientManager.setAuthorizationFailureHandler(this.authorizationFailureHandler);
 		this.clientRegistration = TestClientRegistrations.clientRegistration().build();
 		this.principal = new TestingAuthenticationToken("principal", "password");
 		this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
@@ -110,6 +129,20 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
 				.hasMessage("contextAttributesMapper cannot be null");
 	}
 
+	@Test
+	public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationSuccessHandler cannot be null");
+	}
+
+	@Test
+	public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationFailureHandler cannot be null");
+	}
+
 	@Test
 	public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.authorizedClientManager.authorize(null))
@@ -176,8 +209,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isNull();
-		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
-				any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.request), eq(this.response));
+		verifyNoInteractions(this.authorizationSuccessHandler);
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any(), any());
 	}
 
 	@SuppressWarnings("unchecked")
@@ -206,6 +239,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isSameAs(this.authorizedClient);
+		verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
+				eq(this.authorizedClient), eq(this.principal), any());
 		verify(this.authorizedClientRepository).saveAuthorizedClient(
 				eq(this.authorizedClient), eq(this.principal), eq(this.request), eq(this.response));
 	}
@@ -242,6 +277,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isSameAs(reauthorizedClient);
+		verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
+				eq(reauthorizedClient), eq(this.principal), any());
 		verify(this.authorizedClientRepository).saveAuthorizedClient(
 				eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response));
 	}
@@ -308,6 +345,7 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isSameAs(this.authorizedClient);
+		verifyNoInteractions(this.authorizationSuccessHandler);
 		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
 				any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.request), eq(this.response));
 	}
@@ -339,6 +377,8 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isSameAs(reauthorizedClient);
+		verify(this.authorizationSuccessHandler).onAuthorizationSuccess(
+				eq(reauthorizedClient), eq(this.principal), any());
 		verify(this.authorizedClientRepository).saveAuthorizedClient(
 				eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response));
 	}
@@ -372,4 +412,55 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
 		String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
 		assertThat(requestScopeAttribute).contains("read", "write");
 	}
+
+	@Test
+	public void reauthorizeWhenErrorCodeMatchThenRemoveAuthorizedClient() {
+		ClientAuthorizationException authorizationException = new ClientAuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null),
+				this.clientRegistration.getRegistrationId());
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
+				.thenThrow(authorizationException);
+
+		OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
+				.principal(this.principal)
+				.attributes(attrs -> {
+					attrs.put(HttpServletRequest.class.getName(), this.request);
+					attrs.put(HttpServletResponse.class.getName(), this.response);
+				})
+				.build();
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest))
+				.isEqualTo(authorizationException);
+
+		verify(this.authorizationFailureHandler).onAuthorizationFailure(
+				eq(authorizationException), eq(this.principal), any());
+		verify(this.authorizedClientRepository).removeAuthorizedClient(
+				eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.request), eq(this.response));
+	}
+
+	@Test
+	public void reauthorizeWhenErrorCodeDoesNotMatchThenDoNotRemoveAuthorizedClient() {
+		ClientAuthorizationException authorizationException = new ClientAuthorizationException(
+				new OAuth2Error("non-matching-error-code", null, null),
+				this.clientRegistration.getRegistrationId());
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
+				.thenThrow(authorizationException);
+
+		OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
+				.principal(this.principal)
+				.attributes(attrs -> {
+					attrs.put(HttpServletRequest.class.getName(), this.request);
+					attrs.put(HttpServletResponse.class.getName(), this.response);
+				})
+				.build();
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(reauthorizeRequest))
+				.isEqualTo(authorizationException);
+
+		verify(this.authorizationFailureHandler).onAuthorizationFailure(
+				eq(authorizationException), eq(this.principal), any());
+		verifyNoInteractions(this.authorizedClientRepository);
+	}
 }

+ 247 - 19
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,18 +15,6 @@
  */
 package org.springframework.security.oauth2.client.web.reactive.function.client;
 
-import java.net.URI;
-import java.time.Duration;
-import java.time.Instant;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.function.Consumer;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -35,8 +23,6 @@ import org.mockito.ArgumentCaptor;
 import org.mockito.Captor;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
-import reactor.util.context.Context;
-
 import org.springframework.core.codec.ByteBufferEncoder;
 import org.springframework.core.codec.CharSequenceEncoder;
 import org.springframework.http.HttpHeaders;
@@ -60,7 +46,9 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
+import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
@@ -78,6 +66,9 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr
 import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -89,16 +80,37 @@ import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
 import org.springframework.web.reactive.function.BodyInserter;
 import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.ClientResponse;
+import org.springframework.web.reactive.function.client.ExchangeFunction;
 import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
+import reactor.core.publisher.Mono;
+import reactor.util.context.Context;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Consumer;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.Assertions.entry;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
 import static org.springframework.http.HttpMethod.GET;
 import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
@@ -128,6 +140,14 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 	@Mock
 	private OAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> passwordTokenResponseClient;
 	@Mock
+	private OAuth2AuthorizationFailureHandler authorizationFailureHandler;
+	@Captor
+	private ArgumentCaptor<OAuth2AuthorizationException> authorizationExceptionCaptor;
+	@Captor
+	private ArgumentCaptor<Authentication> authenticationCaptor;
+	@Captor
+	private ArgumentCaptor<Map<String, Object>> attributesCaptor;
+	@Mock
 	private WebClient.RequestHeadersSpec<?> spec;
 	@Captor
 	private ArgumentCaptor<Consumer<Map<String, Object>>> attrs;
@@ -167,7 +187,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
 				this.clientRegistrationRepository, this.authorizedClientRepository);
 		this.authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager);
 	}
 
 	@After
@@ -233,7 +253,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		SecurityContextHolder.getContext().setAuthentication(this.authentication);
 		Map<String, Object> attrs = getDefaultRequestAttributes();
 		assertThat(getAuthentication(attrs)).isEqualTo(this.authentication);
-		verifyZeroInteractions(this.authorizedClientRepository);
+		verifyNoInteractions(this.authorizedClientRepository);
 	}
 
 	private Map<String, Object> getDefaultRequestAttributes() {
@@ -647,6 +667,215 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		assertThat(getBody(request)).isEmpty();
 	}
 
+	@Test
+	public void filterWhenUnauthorizedThenInvokeFailureHandler() {
+		assertHttpStatusInvokesFailureHandler(HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN);
+	}
+
+	@Test
+	public void filterWhenForbiddenThenInvokeFailureHandler() {
+		assertHttpStatusInvokesFailureHandler(HttpStatus.FORBIDDEN, OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
+	}
+
+	private void assertHttpStatusInvokesFailureHandler(HttpStatus httpStatus, String expectedErrorCode) {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration, "principalName", this.accessToken);
+		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
+		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.attributes(httpServletRequest(servletRequest))
+				.attributes(httpServletResponse(servletResponse))
+				.build();
+
+		when(this.exchange.getResponse().rawStatusCode()).thenReturn(httpStatus.value());
+		when(this.exchange.getResponse().headers()).thenReturn(mock(ClientResponse.Headers.class));
+		this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler);
+
+		this.function.filter(request, this.exchange).block();
+
+		verify(this.authorizationFailureHandler).onAuthorizationFailure(
+				this.authorizationExceptionCaptor.capture(),
+				this.authenticationCaptor.capture(),
+				this.attributesCaptor.capture());
+
+		assertThat(this.authorizationExceptionCaptor.getValue())
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> {
+					assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId());
+					assertThat(e.getError().getErrorCode()).isEqualTo(expectedErrorCode);
+					assertThat(e).hasNoCause();
+					assertThat(e).hasMessageContaining(expectedErrorCode);
+				});
+		assertThat(this.authenticationCaptor.getValue().getName())
+				.isEqualTo(authorizedClient.getPrincipalName());
+		assertThat(this.attributesCaptor.getValue())
+				.containsExactly(
+						entry(HttpServletRequest.class.getName(), servletRequest),
+						entry(HttpServletResponse.class.getName(), servletResponse));
+	}
+
+	@Test
+	public void filterWhenWWWAuthenticateHeaderIncludesErrorThenInvokeFailureHandler() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration, "principalName", this.accessToken);
+		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
+		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.attributes(httpServletRequest(servletRequest))
+				.attributes(httpServletResponse(servletResponse))
+				.build();
+
+		String wwwAuthenticateHeader = "Bearer error=\"insufficient_scope\", " +
+				"error_description=\"The request requires higher privileges than provided by the access token.\", " +
+				"error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\"";
+		ClientResponse.Headers headers = mock(ClientResponse.Headers.class);
+		when(headers.header(eq(HttpHeaders.WWW_AUTHENTICATE)))
+				.thenReturn(Collections.singletonList(wwwAuthenticateHeader));
+		when(this.exchange.getResponse().headers()).thenReturn(headers);
+		this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler);
+
+		this.function.filter(request, this.exchange).block();
+
+		verify(this.authorizationFailureHandler).onAuthorizationFailure(
+				this.authorizationExceptionCaptor.capture(),
+				this.authenticationCaptor.capture(),
+				this.attributesCaptor.capture());
+
+		assertThat(this.authorizationExceptionCaptor.getValue())
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> {
+					assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId());
+					assertThat(e.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
+					assertThat(e.getError().getDescription()).isEqualTo("The request requires higher privileges than provided by the access token.");
+					assertThat(e.getError().getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1");
+					assertThat(e).hasNoCause();
+					assertThat(e).hasMessageContaining(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
+				});
+		assertThat(this.authenticationCaptor.getValue().getName())
+				.isEqualTo(authorizedClient.getPrincipalName());
+		assertThat(this.attributesCaptor.getValue())
+				.containsExactly(
+						entry(HttpServletRequest.class.getName(), servletRequest),
+						entry(HttpServletResponse.class.getName(), servletResponse));
+	}
+
+	@Test
+	public void filterWhenUnauthorizedWithWebClientExceptionThenInvokeFailureHandler() {
+		assertHttpStatusWithWebClientExceptionInvokesFailureHandler(
+				HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN);
+	}
+
+	@Test
+	public void filterWhenForbiddenWithWebClientExceptionThenInvokeFailureHandler() {
+		assertHttpStatusWithWebClientExceptionInvokesFailureHandler(
+				HttpStatus.FORBIDDEN, OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
+	}
+
+	private void assertHttpStatusWithWebClientExceptionInvokesFailureHandler(
+			HttpStatus httpStatus, String expectedErrorCode) {
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration, "principalName", this.accessToken);
+		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
+		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.attributes(httpServletRequest(servletRequest))
+				.attributes(httpServletResponse(servletResponse))
+				.build();
+
+		WebClientResponseException exception = WebClientResponseException.create(
+				httpStatus.value(),
+				httpStatus.getReasonPhrase(),
+				HttpHeaders.EMPTY,
+				new byte[0],
+				StandardCharsets.UTF_8);
+		ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception);
+		this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler);
+
+		assertThatCode(() -> this.function.filter(request, throwingExchangeFunction).block())
+				.isEqualTo(exception);
+
+		verify(this.authorizationFailureHandler).onAuthorizationFailure(
+				this.authorizationExceptionCaptor.capture(),
+				this.authenticationCaptor.capture(),
+				this.attributesCaptor.capture());
+
+		assertThat(this.authorizationExceptionCaptor.getValue())
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> {
+					assertThat(e.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId());
+					assertThat(e.getError().getErrorCode()).isEqualTo(expectedErrorCode);
+					assertThat(e).hasCause(exception);
+					assertThat(e).hasMessageContaining(expectedErrorCode);
+				});
+		assertThat(this.authenticationCaptor.getValue().getName())
+				.isEqualTo(authorizedClient.getPrincipalName());
+		assertThat(this.attributesCaptor.getValue())
+				.containsExactly(
+						entry(HttpServletRequest.class.getName(), servletRequest),
+						entry(HttpServletResponse.class.getName(), servletResponse));
+	}
+
+	@Test
+	public void filterWhenAuthorizationExceptionThenInvokeFailureHandler() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration, "principalName", this.accessToken);
+		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
+		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.attributes(httpServletRequest(servletRequest))
+				.attributes(httpServletResponse(servletResponse))
+				.build();
+
+		OAuth2AuthorizationException authorizationException = new OAuth2AuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN));
+		ExchangeFunction throwingExchangeFunction = r -> Mono.error(authorizationException);
+		this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler);
+
+		assertThatCode(() -> this.function.filter(request, throwingExchangeFunction).block())
+				.isEqualTo(authorizationException);
+
+		verify(this.authorizationFailureHandler).onAuthorizationFailure(
+				this.authorizationExceptionCaptor.capture(),
+				this.authenticationCaptor.capture(),
+				this.attributesCaptor.capture());
+
+		assertThat(this.authorizationExceptionCaptor.getValue())
+				.isInstanceOfSatisfying(OAuth2AuthorizationException.class, e -> {
+					assertThat(e.getError().getErrorCode()).isEqualTo(authorizationException.getError().getErrorCode());
+					assertThat(e).hasNoCause();
+					assertThat(e).hasMessageContaining(OAuth2ErrorCodes.INVALID_TOKEN);
+				});
+		assertThat(this.authenticationCaptor.getValue().getName())
+				.isEqualTo(authorizedClient.getPrincipalName());
+		assertThat(this.attributesCaptor.getValue())
+				.containsExactly(
+						entry(HttpServletRequest.class.getName(), servletRequest),
+						entry(HttpServletResponse.class.getName(), servletResponse));
+	}
+
+	@Test
+	public void filterWhenOtherHttpStatusThenDoesNotInvokeFailureHandler() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration, "principalName", this.accessToken);
+		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
+		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.attributes(httpServletRequest(servletRequest))
+				.attributes(httpServletResponse(servletResponse))
+				.build();
+
+		when(this.exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.BAD_REQUEST.value());
+		when(this.exchange.getResponse().headers()).thenReturn(mock(ClientResponse.Headers.class));
+		this.function.setAuthorizationFailureHandler(this.authorizationFailureHandler);
+
+		this.function.filter(request, this.exchange).block();
+
+		verifyNoInteractions(this.authorizationFailureHandler);
+	}
+
 	private Context context(HttpServletRequest servletRequest, HttpServletResponse servletResponse, Authentication authentication) {
 		Map<Object, Object> contextAttributes = new HashMap<>();
 		contextAttributes.put(HttpServletRequest.class, servletRequest);
@@ -688,5 +917,4 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		request.body().insert(body, context).block();
 		return body.getBodyAsString().block();
 	}
-
 }