2
0
Эх сурвалжийг харах

Introduce Reactive OAuth2Authorization success/failure handlers

All ReactiveOAuth2AuthorizedClientManagers now have authorization success/failure handlers.
A success handler is provided to save authorized clients for future requests.
A failure handler is provided to remove previously saved authorized clients.

ServerOAuth2AuthorizedClientExchangeFilterFunction also makes use of a
failure handler in the case of unauthorized or forbidden http status code.

The main use cases now handled are
- remove authorized client when an authorization server indicates that a refresh token is no longer valid (when authorization server returns invalid_grant)
- remove authorized client when a resource server indicates that an access token is no longer valid (when resource server returns invalid_token)

Introduced ClientAuthorizationException to capture details needed when removing an authorized client.
All ReactiveOAuth2AccessTokenResponseClients now throw a ClientAuthorizationException on failures.

Created AbstractWebClientReactiveOAuth2AccessTokenResponseClient to unify common logic between all ReactiveOAuth2AccessTokenResponseClients.

Fixes gh-7699
Phil Clay 5 жил өмнө
parent
commit
e5fca61810
26 өөрчлөгдсөн 2504 нэмэгдсэн , 480 устгасан
  1. 94 9
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java
  2. 89 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationException.java
  3. 4 17
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationRequiredException.java
  4. 51 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationFailureHandler.java
  5. 51 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationSuccessHandler.java
  6. 229 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java
  7. 24 56
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java
  8. 11 87
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java
  9. 16 87
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java
  10. 27 83
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java
  11. 133 43
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java
  12. 172 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java
  13. 80 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler.java
  14. 332 52
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java
  15. 233 1
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java
  16. 6 6
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java
  17. 11 5
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java
  18. 14 10
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java
  19. 12 10
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java
  20. 229 1
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java
  21. 333 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java
  22. 268 3
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java
  23. 30 3
      oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthorizationException.java
  24. 22 1
      oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java
  25. 14 3
      oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java
  26. 19 3
      oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java

+ 94 - 9
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,7 +17,12 @@ package org.springframework.security.oauth2.client;
 
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
+import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.client.web.SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
 
 import java.util.Collections;
@@ -25,17 +30,42 @@ import java.util.Map;
 import java.util.function.Function;
 
 /**
- * An implementation of an {@link ReactiveOAuth2AuthorizedClientManager}
- * that is capable of operating outside of a {@code ServerHttpRequest} context,
+ * An implementation of a {@link ReactiveOAuth2AuthorizedClientManager}
+ * that is capable of operating outside of the context of a {@link ServerWebExchange},
  * e.g. in a scheduled/background thread and/or in the service-tier.
  *
- * <p>This is a reactive equivalent of {@link org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager}</p>
+ * <p>(When operating <em>within</em> the context of a {@link ServerWebExchange},
+ * use {@link DefaultReactiveOAuth2AuthorizedClientManager} instead.)</p>
+ *
+ * <p>This is a reactive equivalent of {@link org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager}.</p>
+ *
+ * <h2>Authorized Client Persistence</h2>
+ *
+ * <p>This client manager utilizes a {@link ReactiveOAuth2AuthorizedClientService}
+ * to persist {@link OAuth2AuthorizedClient}s.</p>
+ *
+ * <p>By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient}
+ * will be saved in the authorized client service.
+ * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationSuccessHandler}
+ * via {@link #setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler)}.</p>
+ *
+ * <p>By default, when an authorization attempt fails due to an
+ * {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} error,
+ * the previously saved {@link OAuth2AuthorizedClient}
+ * will be removed from the authorized client service.
+ * (The {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT}
+ * error generally occurs when a refresh token that is no longer valid
+ * is used to retrieve a new access token.)
+ * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationFailureHandler}
+ * via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}.</p>
  *
  * @author Ankur Pathak
  * @author Phil Clay
  * @see ReactiveOAuth2AuthorizedClientManager
  * @see ReactiveOAuth2AuthorizedClientProvider
  * @see ReactiveOAuth2AuthorizedClientService
+ * @see ReactiveOAuth2AuthorizationSuccessHandler
+ * @see ReactiveOAuth2AuthorizationFailureHandler
  * @since 5.2.2
  */
 public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager
@@ -45,6 +75,8 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager
 	private final ReactiveOAuth2AuthorizedClientService authorizedClientService;
 	private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty();
 	private Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> contextAttributesMapper = new DefaultContextAttributesMapper();
+	private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
+	private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler;
 
 	/**
 	 * Constructs an {@code AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters.
@@ -59,6 +91,8 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager
 		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
 		this.clientRegistrationRepository = clientRegistrationRepository;
 		this.authorizedClientService = authorizedClientService;
+		this.authorizationSuccessHandler = new SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler(authorizedClientService);
+		this.authorizationFailureHandler = new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(authorizedClientService);
 	}
 
 	@Override
@@ -66,7 +100,7 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager
 		Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
 
 		return createAuthorizationContext(authorizeRequest)
-				.flatMap(this::authorizeAndSave);
+				.flatMap(authorizationContext -> authorize(authorizationContext, authorizeRequest.getPrincipal()));
 	}
 
 	private Mono<OAuth2AuthorizationContext> createAuthorizationContext(OAuth2AuthorizeRequest authorizeRequest) {
@@ -90,13 +124,34 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager
 						}));
 	}
 
-	private Mono<OAuth2AuthorizedClient> authorizeAndSave(OAuth2AuthorizationContext authorizationContext) {
+	/**
+	 * Performs authorization and then delegates to either the {@link #authorizationSuccessHandler}
+	 * or {@link #authorizationFailureHandler}, depending on the authorization result.
+	 *
+	 * @param authorizationContext the context to authorize
+	 * @param principal the principle to authorize
+	 * @return a {@link Mono} that emits the authorized client after the authorization attempt succeeds
+	 *         and the {@link #authorizationSuccessHandler} has completed,
+	 *         or completes with an exception after the authorization attempt fails
+	 *         and the {@link #authorizationFailureHandler} has completed
+	 */
+	private Mono<OAuth2AuthorizedClient> authorize(
+			OAuth2AuthorizationContext authorizationContext,
+			Authentication principal) {
 		return this.authorizedClientProvider.authorize(authorizationContext)
-				.flatMap(authorizedClient -> this.authorizedClientService.saveAuthorizedClient(
+				// Delegate to the authorizationSuccessHandler of the successful authorization
+				.flatMap(authorizedClient -> this.authorizationSuccessHandler.onAuthorizationSuccess(
 								authorizedClient,
-								authorizationContext.getPrincipal())
+								principal,
+								Collections.emptyMap())
 						.thenReturn(authorizedClient))
-				.switchIfEmpty(Mono.defer(()-> Mono.justOrEmpty(authorizationContext.getAuthorizedClient())));
+				// Delegate to the authorizationFailureHandler of the failed authorization
+				.onErrorResume(OAuth2AuthorizationException.class, authorizationException -> this.authorizationFailureHandler.onAuthorizationFailure(
+								authorizationException,
+								principal,
+								Collections.emptyMap())
+						.then(Mono.error(authorizationException)))
+				.switchIfEmpty(Mono.defer(() -> Mono.justOrEmpty(authorizationContext.getAuthorizedClient())));
 	}
 
 	/**
@@ -121,6 +176,36 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager
 		this.contextAttributesMapper = contextAttributesMapper;
 	}
 
+	/**
+	 * Sets the handler that handles successful authorizations.
+	 *
+	 * <p>A {@link SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler}
+	 * is used by default.</p>
+	 *
+	 * @param authorizationSuccessHandler the handler that handles successful authorizations.
+	 * @see SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler
+	 * @since 5.3
+	 */
+	public void setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler) {
+		Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null");
+		this.authorizationSuccessHandler = authorizationSuccessHandler;
+	}
+
+	/**
+	 * Sets the handler that handles authorization failures.
+	 *
+	 * <p>A {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler}
+	 * is used by default.</p>
+	 *
+	 * @param authorizationFailureHandler the handler that handles authorization failures.
+	 * @see RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler
+	 * @since 5.3
+	 */
+	public void setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) {
+		Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
+		this.authorizationFailureHandler = authorizationFailureHandler;
+	}
+
 	/**
 	 * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
 	 */

+ 89 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationException.java

@@ -0,0 +1,89 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.util.Assert;
+
+/**
+ * This exception is thrown on the client side when an attempt to authenticate
+ * or authorize an OAuth 2.0 client fails.
+ *
+ * @author Phil Clay
+ * @since 5.3
+ * @see OAuth2AuthorizedClient
+ */
+public class ClientAuthorizationException extends OAuth2AuthorizationException {
+
+	private final String clientRegistrationId;
+
+	/**
+	 * Constructs a {@code ClientAuthorizationException} using the provided parameters.
+	 *
+	 * @param error the {@link OAuth2Error OAuth 2.0 Error}
+	 * @param clientRegistrationId the identifier for the client's registration
+	 */
+	public ClientAuthorizationException(OAuth2Error error, String clientRegistrationId) {
+		this(error, clientRegistrationId, error.toString());
+	}
+	/**
+	 * Constructs a {@code ClientAuthorizationException} using the provided parameters.
+	 *
+	 * @param error the {@link OAuth2Error OAuth 2.0 Error}
+	 * @param clientRegistrationId the identifier for the client's registration
+	 * @param message the exception message
+	 */
+	public ClientAuthorizationException(OAuth2Error error, String clientRegistrationId, String message) {
+		super(error, message);
+		Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
+		this.clientRegistrationId = clientRegistrationId;
+	}
+
+	/**
+	 * Constructs a {@code ClientAuthorizationException} using the provided parameters.
+	 *
+	 * @param error the {@link OAuth2Error OAuth 2.0 Error}
+	 * @param clientRegistrationId the identifier for the client's registration
+	 * @param cause the root cause
+	 */
+	public ClientAuthorizationException(OAuth2Error error, String clientRegistrationId, Throwable cause) {
+		this(error, clientRegistrationId, error.toString(), cause);
+	}
+
+	/**
+	 * Constructs a {@code ClientAuthorizationException} using the provided parameters.
+	 *
+	 * @param error the {@link OAuth2Error OAuth 2.0 Error}
+	 * @param clientRegistrationId the identifier for the client's registration
+	 * @param message the exception message
+	 * @param cause the root cause
+	 */
+	public ClientAuthorizationException(OAuth2Error error, String clientRegistrationId, String message, Throwable cause) {
+		super(error, message, cause);
+		Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
+		this.clientRegistrationId = clientRegistrationId;
+	}
+
+	/**
+	 * Returns the identifier for the client's registration.
+	 *
+	 * @return the identifier for the client's registration
+	 */
+	public String getClientRegistrationId() {
+		return this.clientRegistrationId;
+	}
+}

+ 4 - 17
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationRequiredException.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,9 +15,7 @@
  */
 package org.springframework.security.oauth2.client;
 
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
-import org.springframework.util.Assert;
 
 /**
  * This exception is thrown when an OAuth 2.0 Client is required
@@ -27,9 +25,8 @@ import org.springframework.util.Assert;
  * @since 5.1
  * @see OAuth2AuthorizedClient
  */
-public class ClientAuthorizationRequiredException extends OAuth2AuthorizationException {
+public class ClientAuthorizationRequiredException extends ClientAuthorizationException {
 	private static final String CLIENT_AUTHORIZATION_REQUIRED_ERROR_CODE = "client_authorization_required";
-	private final String clientRegistrationId;
 
 	/**
 	 * Constructs a {@code ClientAuthorizationRequiredException} using the provided parameters.
@@ -38,17 +35,7 @@ public class ClientAuthorizationRequiredException extends OAuth2AuthorizationExc
 	 */
 	public ClientAuthorizationRequiredException(String clientRegistrationId) {
 		super(new OAuth2Error(CLIENT_AUTHORIZATION_REQUIRED_ERROR_CODE,
-				"Authorization required for Client Registration Id: " + clientRegistrationId, null));
-		Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
-		this.clientRegistrationId = clientRegistrationId;
-	}
-
-	/**
-	 * Returns the identifier for the client's registration.
-	 *
-	 * @return the identifier for the client's registration
-	 */
-	public String getClientRegistrationId() {
-		return this.clientRegistrationId;
+				"Authorization required for Client Registration Id: " + clientRegistrationId, null),
+				clientRegistrationId);
 	}
 }

+ 51 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationFailureHandler.java

@@ -0,0 +1,51 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import reactor.core.publisher.Mono;
+
+import java.util.Map;
+
+/**
+ * Handles when an OAuth 2.0 Client
+ * fails to authorize (or re-authorize)
+ * via the authorization server or resource server.
+ *
+ * @author Phil Clay
+ * @since 5.3
+ */
+@FunctionalInterface
+public interface ReactiveOAuth2AuthorizationFailureHandler {
+
+	/**
+	 * Called when an OAuth 2.0 Client
+	 * fails to authorize (or re-authorize)
+	 * via the authorization server or resource server.
+	 *
+	 * @param authorizationException the exception that contains details about what failed
+	 * @param principal the {@code Principal} that was attempted to be authorized
+	 * @param attributes an immutable {@code Map} of extra optional attributes present under certain conditions.
+	 *                   For example, this might contain a {@link org.springframework.web.server.ServerWebExchange ServerWebExchange}
+	 *                   if the authorization was performed within the context of a {@code ServerWebExchange}.
+	 * @return an empty {@link Mono} that completes after this handler has finished handling the event.
+	 */
+	Mono<Void> onAuthorizationFailure(
+			OAuth2AuthorizationException authorizationException,
+			Authentication principal,
+			Map<String, Object> attributes);
+}

+ 51 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationSuccessHandler.java

@@ -0,0 +1,51 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.springframework.security.core.Authentication;
+import reactor.core.publisher.Mono;
+
+import java.util.Map;
+
+/**
+ * Handles when an OAuth 2.0 Client
+ * has been successfully authorized (or re-authorized)
+ * via the authorization server.
+ *
+ * @author Phil Clay
+ * @since 5.3
+ */
+@FunctionalInterface
+public interface ReactiveOAuth2AuthorizationSuccessHandler {
+
+	/**
+	 * Called when an OAuth 2.0 Client
+	 * has been successfully authorized (or re-authorized)
+	 * via the authorization server.
+	 *
+	 * @param authorizedClient the client that was successfully authorized
+	 * @param principal the {@code Principal} associated with the authorized client
+	 * @param attributes an immutable {@code Map} of extra optional attributes present under certain conditions.
+	 *                   For example, this might contain a {@link org.springframework.web.server.ServerWebExchange ServerWebExchange}
+	 *                   if the authorization was performed within the context of a {@code ServerWebExchange}.
+	 * @return an empty {@link Mono} that completes after this handler has finished handling the event.
+	 */
+	Mono<Void> onAuthorizationSuccess(
+			OAuth2AuthorizedClient authorizedClient,
+			Authentication principal,
+			Map<String, Object> attributes);
+
+}

+ 229 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java

@@ -0,0 +1,229 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.endpoint;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
+import org.springframework.web.reactive.function.BodyInserters;
+import org.springframework.web.reactive.function.client.ClientResponse;
+import org.springframework.web.reactive.function.client.WebClient;
+import reactor.core.publisher.Mono;
+
+import java.util.Collections;
+import java.util.Set;
+
+import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
+
+/**
+ * Abstract base class for all of the {@code WebClientReactive*TokenResponseClient}s
+ * that communicate to the Authorization Server's Token Endpoint.
+ *
+ * <p>Submits a form request body specific to the type of grant request.</p>
+ *
+ * <p>Accepts a JSON response body containing an OAuth 2.0 Access token or error.</p>
+ *
+ * @author Phil Clay
+ * @since 5.3
+ * @param <T> type of grant request
+ * @see <a href="https://tools.ietf.org/html/rfc6749#section-3.2">RFC-6749 Token Endpoint</a>
+ * @see WebClientReactiveAuthorizationCodeTokenResponseClient
+ * @see WebClientReactiveClientCredentialsTokenResponseClient
+ * @see WebClientReactivePasswordTokenResponseClient
+ * @see WebClientReactiveRefreshTokenTokenResponseClient
+ */
+abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T extends AbstractOAuth2AuthorizationGrantRequest>
+		implements ReactiveOAuth2AccessTokenResponseClient<T> {
+
+	private WebClient webClient = WebClient.builder().build();
+
+	@Override
+	public Mono<OAuth2AccessTokenResponse> getTokenResponse(T grantRequest) {
+		Assert.notNull(grantRequest, "grantRequest cannot be null");
+		return Mono.defer(() -> this.webClient.post()
+				.uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri())
+				.headers(headers -> populateTokenRequestHeaders(grantRequest, headers))
+				.body(createTokenRequestBody(grantRequest))
+				.exchange()
+				.flatMap(response -> readTokenResponse(grantRequest, response)));
+	}
+
+	/**
+	 * Returns the {@link ClientRegistration} for the given {@code grantRequest}.
+	 *
+	 * @param grantRequest the grant request
+	 * @return the {@link ClientRegistration} for the given {@code grantRequest}.
+	 */
+	abstract ClientRegistration clientRegistration(T grantRequest);
+
+	/**
+	 * Populates the headers for the token request.
+	 *
+	 * @param grantRequest the grant request
+	 * @param headers the headers to populate
+	 */
+	private void populateTokenRequestHeaders(T grantRequest, HttpHeaders headers) {
+		ClientRegistration clientRegistration = clientRegistration(grantRequest);
+		headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
+		headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
+		if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
+			headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
+		}
+	}
+
+	/**
+	 * Creates and returns the body for the token request.
+	 *
+	 * <p>This method pre-populates the body with some standard properties,
+	 * and then delegates to {@link #populateTokenRequestBody(AbstractOAuth2AuthorizationGrantRequest, BodyInserters.FormInserter)}
+	 * for subclasses to further populate the body before returning.</p>
+	 *
+	 * @param grantRequest the grant request
+	 * @return the body for the token request.
+	 */
+	private BodyInserters.FormInserter<String> createTokenRequestBody(T grantRequest) {
+		BodyInserters.FormInserter<String> body = BodyInserters
+				.fromFormData(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue());
+		return populateTokenRequestBody(grantRequest, body);
+	}
+
+	/**
+	 * Populates the body of the token request.
+	 *
+	 * <p>By default, populates properties that are common to all grant types.
+	 * Subclasses can extend this method to populate grant type specific properties.</p>
+	 *
+	 * @param grantRequest the grant request
+	 * @param body the body to populate
+	 * @return the populated body
+	 */
+	BodyInserters.FormInserter<String> populateTokenRequestBody(T grantRequest, BodyInserters.FormInserter<String> body) {
+		ClientRegistration clientRegistration = clientRegistration(grantRequest);
+		if (!ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
+			body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
+		}
+		if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
+			body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
+		}
+		Set<String> scopes = scopes(grantRequest);
+		if (!CollectionUtils.isEmpty(scopes)) {
+			body.with(OAuth2ParameterNames.SCOPE,
+					StringUtils.collectionToDelimitedString(scopes, " "));
+		}
+		return body;
+	}
+
+	/**
+	 * Returns the scopes to include as a property in the token request.
+	 *
+	 * @param grantRequest the grant request
+	 * @return the scopes to include as a property in the token request.
+	 */
+	abstract Set<String> scopes(T grantRequest);
+
+	/**
+	 * Returns the scopes to include in the response if the authorization
+	 * server returned no scopes in the response.
+	 *
+	 * <p>As per <a href="https://tools.ietf.org/html/rfc6749#section-5.1">RFC-6749 Section 5.1 Successful Access Token Response</a>,
+	 * if AccessTokenResponse.scope is empty, then default to the scope
+	 * originally requested by the client in the Token Request.</p>
+	 *
+	 * @param grantRequest the grant request
+	 * @return the scopes to include in the response if the authorization
+	 *         server returned no scopes.
+	 */
+	Set<String> defaultScopes(T grantRequest) {
+		return scopes(grantRequest);
+	}
+
+	/**
+	 * Reads the token response from the response body.
+	 *
+	 * @param grantRequest the request for which the response was received.
+	 * @param response the client response from which to read
+	 * @return the token response from the response body.
+	 */
+	private Mono<OAuth2AccessTokenResponse> readTokenResponse(T grantRequest, ClientResponse response) {
+		return response.body(oauth2AccessTokenResponse())
+				.onErrorMap(OAuth2AuthorizationException.class, e -> createClientAuthorizationException(
+						response,
+						clientRegistration(grantRequest).getRegistrationId(),
+						e))
+				.map(tokenResponse -> populateTokenResponse(grantRequest, tokenResponse));
+	}
+
+	/**
+	 * Wraps the given {@link OAuth2AuthorizationException} in a {@link ClientAuthorizationException}
+	 * that provides response details, and a more descriptive exception message.
+	 *
+	 * @param response the token response
+	 * @param clientRegistrationId the id of the {@link ClientRegistration} for which a token is being requested
+	 * @param authorizationException the {@link OAuth2AuthorizationException} to wrap
+	 * @return the {@link ClientAuthorizationException} that wraps the given {@link OAuth2AuthorizationException}
+	 */
+	private OAuth2AuthorizationException createClientAuthorizationException(
+			ClientResponse response,
+			String clientRegistrationId,
+			OAuth2AuthorizationException authorizationException) {
+
+		String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s",
+				response.rawStatusCode(),
+				authorizationException.getError());
+
+		return new ClientAuthorizationException(
+				authorizationException.getError(),
+				clientRegistrationId,
+				message,
+				authorizationException);
+	}
+
+	/**
+	 * Populates the given {@link OAuth2AccessTokenResponse} with additional details
+	 * from the grant request.
+	 *
+	 * @param grantRequest the request for which the response was received.
+	 * @param tokenResponse the original token response
+	 * @return a token response optionally populated with additional details from the request.
+	 */
+	OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessTokenResponse tokenResponse) {
+		if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) {
+			Set<String> defaultScopes = defaultScopes(grantRequest);
+			tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse)
+					.scopes(defaultScopes)
+					.build();
+		}
+		return tokenResponse;
+	}
+
+	/**
+	 * Sets the {@link WebClient} used when requesting the OAuth 2.0 Access Token Response.
+	 *
+	 * @param webClient the {@link WebClient} used when requesting the Access Token Response
+	 */
+	public void setWebClient(WebClient webClient) {
+		Assert.notNull(webClient, "webClient cannot be null");
+		this.webClient = webClient;
+	}
+}

+ 24 - 56
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,24 +15,19 @@
  */
 package org.springframework.security.oauth2.client.endpoint;
 
-import org.springframework.http.MediaType;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.web.reactive.function.BodyInserters;
-import org.springframework.web.reactive.function.client.WebClient;
-import org.springframework.util.Assert;
-import reactor.core.publisher.Mono;
 
-import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
+import java.util.Collections;
+import java.util.Set;
 
 /**
- * An implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} that &quot;exchanges&quot;
+ * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} that &quot;exchanges&quot;
  * an authorization code credential for an access token credential
  * at the Authorization Server's Token Endpoint.
  *
@@ -41,7 +36,7 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
  *
  * @author Rob Winch
  * @since 5.1
- * @see OAuth2AccessTokenResponseClient
+ * @see ReactiveOAuth2AccessTokenResponseClient
  * @see OAuth2AuthorizationCodeGrantRequest
  * @see OAuth2AccessTokenResponse
  * @see <a target="_blank" href="https://connect2id.com/products/nimbus-oauth-openid-connect-sdk">Nimbus OAuth 2.0 SDK</a>
@@ -49,64 +44,37 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response (Authorization Code Grant)</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7636#section-4.2">Section 4.2 Client Creates the Code Challenge</a>
  */
-public class WebClientReactiveAuthorizationCodeTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {
-	private WebClient webClient = WebClient.builder()
-			.build();
+public class WebClientReactiveAuthorizationCodeTokenResponseClient extends
+		AbstractWebClientReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {
 
-	/**
-	 * @param webClient the webClient to set
-	 */
-	public void setWebClient(WebClient webClient) {
-		Assert.notNull(webClient, "webClient cannot be null");
-		this.webClient = webClient;
+	@Override
+	ClientRegistration clientRegistration(OAuth2AuthorizationCodeGrantRequest grantRequest) {
+		return grantRequest.getClientRegistration();
 	}
 
 	@Override
-	public Mono<OAuth2AccessTokenResponse> getTokenResponse(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest) {
-		return Mono.defer(() -> {
-			ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
-			OAuth2AuthorizationExchange authorizationExchange = authorizationGrantRequest.getAuthorizationExchange();
-			String tokenUri = clientRegistration.getProviderDetails().getTokenUri();
-			BodyInserters.FormInserter<String> body = body(authorizationExchange, clientRegistration);
+	Set<String> scopes(OAuth2AuthorizationCodeGrantRequest grantRequest) {
+		return Collections.emptySet();
+	}
 
-			return this.webClient.post()
-					.uri(tokenUri)
-					.accept(MediaType.APPLICATION_JSON)
-					.headers(headers -> {
-						if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
-							headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
-						}
-					})
-					.body(body)
-					.exchange()
-					.flatMap(response -> response.body(oauth2AccessTokenResponse()))
-					.map(response -> {
-						if (response.getAccessToken().getScopes().isEmpty()) {
-							response = OAuth2AccessTokenResponse.withResponse(response)
-								.scopes(authorizationExchange.getAuthorizationRequest().getScopes())
-								.build();
-						}
-						return response;
-					});
-		});
+	@Override
+	Set<String> defaultScopes(OAuth2AuthorizationCodeGrantRequest grantRequest) {
+		return grantRequest.getAuthorizationExchange().getAuthorizationRequest().getScopes();
 	}
 
-	private static BodyInserters.FormInserter<String> body(OAuth2AuthorizationExchange authorizationExchange, ClientRegistration clientRegistration) {
+	@Override
+	BodyInserters.FormInserter<String> populateTokenRequestBody(
+			OAuth2AuthorizationCodeGrantRequest grantRequest,
+			BodyInserters.FormInserter<String> body) {
+		super.populateTokenRequestBody(grantRequest, body);
+		OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange();
 		OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse();
-		BodyInserters.FormInserter<String> body = BodyInserters
-				.fromFormData(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue())
-				.with(OAuth2ParameterNames.CODE, authorizationResponse.getCode());
+		body.with(OAuth2ParameterNames.CODE, authorizationResponse.getCode());
 		String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri();
-		String codeVerifier = authorizationExchange.getAuthorizationRequest().getAttribute(PkceParameterNames.CODE_VERIFIER);
 		if (redirectUri != null) {
 			body.with(OAuth2ParameterNames.REDIRECT_URI, redirectUri);
 		}
-		if (!ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
-			body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
-		}
-		if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
-			body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
-		}
+		String codeVerifier = authorizationExchange.getAuthorizationRequest().getAttribute(PkceParameterNames.CODE_VERIFIER);
 		if (codeVerifier != null) {
 			body.with(PkceParameterNames.CODE_VERIFIER, codeVerifier);
 		}

+ 11 - 87
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,112 +15,36 @@
  */
 package org.springframework.security.oauth2.client.endpoint;
 
-import org.springframework.core.io.buffer.DataBuffer;
-import org.springframework.core.io.buffer.DataBufferUtils;
-import org.springframework.http.HttpHeaders;
-import org.springframework.http.HttpStatus;
-import org.springframework.http.MediaType;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
-import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.util.Assert;
-import org.springframework.util.CollectionUtils;
-import org.springframework.util.StringUtils;
-import org.springframework.web.reactive.function.BodyInserters;
-import org.springframework.web.reactive.function.client.WebClient;
-import org.springframework.web.reactive.function.client.WebClientResponseException;
-import reactor.core.publisher.Mono;
 
 import java.util.Set;
-import java.util.function.Consumer;
-
-import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
 
 /**
- * An implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} that &quot;exchanges&quot;
- * an authorization code credential for an access token credential
+ * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} that &quot;exchanges&quot;
+ * a client credential for an access token credential
  * at the Authorization Server's Token Endpoint.
  *
  * @author Rob Winch
  * @since 5.1
- * @see OAuth2AccessTokenResponseClient
+ * @see ReactiveOAuth2AccessTokenResponseClient
  * @see OAuth2AuthorizationCodeGrantRequest
  * @see OAuth2AccessTokenResponse
  * @see <a target="_blank" href="https://connect2id.com/products/nimbus-oauth-openid-connect-sdk">Nimbus OAuth 2.0 SDK</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request (Authorization Code Grant)</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response (Authorization Code Grant)</a>
  */
-public class WebClientReactiveClientCredentialsTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {
-	private WebClient webClient = WebClient.builder()
-			.build();
+public class WebClientReactiveClientCredentialsTokenResponseClient extends
+		AbstractWebClientReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {
 
 	@Override
-	public Mono<OAuth2AccessTokenResponse> getTokenResponse(OAuth2ClientCredentialsGrantRequest authorizationGrantRequest) {
-		return Mono.defer(() -> {
-			ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
-
-			String tokenUri = clientRegistration.getProviderDetails().getTokenUri();
-			BodyInserters.FormInserter<String> body = body(authorizationGrantRequest);
-
-			return this.webClient.post()
-					.uri(tokenUri)
-					.accept(MediaType.APPLICATION_JSON)
-					.headers(headers(clientRegistration))
-					.body(body)
-					.exchange()
-					.flatMap(response -> {
-						HttpStatus status = HttpStatus.resolve(response.rawStatusCode());
-						if (status == null || !status.is2xxSuccessful()) {
-							// extract the contents of this into a method named oauth2AccessTokenResponse but has an argument for the response
-							return response.bodyToFlux(DataBuffer.class)
-								.map(DataBufferUtils::release)
-								.then(Mono.error(WebClientResponseException.create(response.rawStatusCode(),
-											"Cannot get token, expected 2xx HTTP Status code",
-											null,
-											null,
-											null
-								)));
-						}
-						return response.body(oauth2AccessTokenResponse()); })
-					.map(response -> {
-						if (response.getAccessToken().getScopes().isEmpty()) {
-							response = OAuth2AccessTokenResponse.withResponse(response)
-								.scopes(authorizationGrantRequest.getClientRegistration().getScopes())
-								.build();
-						}
-						return response;
-					});
-		});
+	ClientRegistration clientRegistration(OAuth2ClientCredentialsGrantRequest grantRequest) {
+		return grantRequest.getClientRegistration();
 	}
 
-	private Consumer<HttpHeaders> headers(ClientRegistration clientRegistration) {
-		return headers -> {
-			headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
-			if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
-				headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
-			}
-		};
-	}
-
-	private static BodyInserters.FormInserter<String> body(OAuth2ClientCredentialsGrantRequest authorizationGrantRequest) {
-		ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
-		BodyInserters.FormInserter<String> body = BodyInserters
-				.fromFormData(OAuth2ParameterNames.GRANT_TYPE, authorizationGrantRequest.getGrantType().getValue());
-		Set<String> scopes = clientRegistration.getScopes();
-		if (!CollectionUtils.isEmpty(scopes)) {
-			String scope = StringUtils.collectionToDelimitedString(scopes, " ");
-			body.with(OAuth2ParameterNames.SCOPE, scope);
-		}
-		if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
-			body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
-			body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
-		}
-		return body;
+	@Override
+	Set<String> scopes(OAuth2ClientCredentialsGrantRequest grantRequest) {
+		return grantRequest.getClientRegistration().getScopes();
 	}
 
-	public void setWebClient(WebClient webClient) {
-		Assert.notNull(webClient, "webClient cannot be null");
-		this.webClient = webClient;
-	}
 }

+ 16 - 87
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,29 +15,14 @@
  */
 package org.springframework.security.oauth2.client.endpoint;
 
-import org.springframework.core.io.buffer.DataBuffer;
-import org.springframework.core.io.buffer.DataBufferUtils;
-import org.springframework.http.HttpHeaders;
-import org.springframework.http.HttpStatus;
-import org.springframework.http.MediaType;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
-import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.util.Assert;
-import org.springframework.util.CollectionUtils;
-import org.springframework.util.StringUtils;
 import org.springframework.web.reactive.function.BodyInserters;
 import org.springframework.web.reactive.function.client.WebClient;
-import reactor.core.publisher.Mono;
 
-import java.util.Collections;
-import java.util.function.Consumer;
-
-import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
+import java.util.Set;
 
 /**
  * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient}
@@ -53,82 +38,26 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.3.2">Section 4.3.2 Access Token Request (Resource Owner Password Credentials Grant)</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.3.3">Section 4.3.3 Access Token Response (Resource Owner Password Credentials Grant)</a>
  */
-public final class WebClientReactivePasswordTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> {
-	private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response";
-	private WebClient webClient = WebClient.builder().build();
+public final class WebClientReactivePasswordTokenResponseClient extends
+		AbstractWebClientReactiveOAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> {
 
 	@Override
-	public Mono<OAuth2AccessTokenResponse> getTokenResponse(OAuth2PasswordGrantRequest passwordGrantRequest) {
-		Assert.notNull(passwordGrantRequest, "passwordGrantRequest cannot be null");
-		return Mono.defer(() -> {
-			ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration();
-			return this.webClient.post()
-					.uri(clientRegistration.getProviderDetails().getTokenUri())
-					.headers(tokenRequestHeaders(clientRegistration))
-					.body(tokenRequestBody(passwordGrantRequest))
-					.exchange()
-					.flatMap(response -> {
-						HttpStatus status = HttpStatus.resolve(response.rawStatusCode());
-						if (status == null || !status.is2xxSuccessful()) {
-							OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
-									"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " +
-											"HTTP Status Code " + response.rawStatusCode(), null);
-							return response
-									.bodyToMono(DataBuffer.class)
-									.map(DataBufferUtils::release)
-									.then(Mono.error(new OAuth2AuthorizationException(oauth2Error)));
-						}
-						return response.body(oauth2AccessTokenResponse());
-					})
-					.map(tokenResponse -> {
-						if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) {
-							// As per spec, in Section 5.1 Successful Access Token Response
-							// https://tools.ietf.org/html/rfc6749#section-5.1
-							// If AccessTokenResponse.scope is empty, then default to the scope
-							// originally requested by the client in the Token Request
-							tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse)
-									.scopes(passwordGrantRequest.getClientRegistration().getScopes())
-									.build();
-						}
-						return tokenResponse;
-					});
-		});
+	ClientRegistration clientRegistration(OAuth2PasswordGrantRequest grantRequest) {
+		return grantRequest.getClientRegistration();
 	}
 
-	private static Consumer<HttpHeaders> tokenRequestHeaders(ClientRegistration clientRegistration) {
-		return headers -> {
-			headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
-			headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
-			if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
-				headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
-			}
-		};
+	@Override
+	Set<String> scopes(OAuth2PasswordGrantRequest grantRequest) {
+		return grantRequest.getClientRegistration().getScopes();
 	}
 
-	private static BodyInserters.FormInserter<String> tokenRequestBody(OAuth2PasswordGrantRequest passwordGrantRequest) {
-		ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration();
-		BodyInserters.FormInserter<String> body = BodyInserters.fromFormData(
-				OAuth2ParameterNames.GRANT_TYPE, passwordGrantRequest.getGrantType().getValue());
-		body.with(OAuth2ParameterNames.USERNAME, passwordGrantRequest.getUsername());
-		body.with(OAuth2ParameterNames.PASSWORD, passwordGrantRequest.getPassword());
-		if (!CollectionUtils.isEmpty(passwordGrantRequest.getClientRegistration().getScopes())) {
-			body.with(OAuth2ParameterNames.SCOPE,
-					StringUtils.collectionToDelimitedString(passwordGrantRequest.getClientRegistration().getScopes(), " "));
-		}
-		if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
-			body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
-			body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
-		}
-		return body;
+	@Override
+	BodyInserters.FormInserter<String> populateTokenRequestBody(
+			OAuth2PasswordGrantRequest grantRequest,
+			BodyInserters.FormInserter<String> body) {
+		return super.populateTokenRequestBody(grantRequest, body)
+			.with(OAuth2ParameterNames.USERNAME, grantRequest.getUsername())
+			.with(OAuth2ParameterNames.PASSWORD, grantRequest.getPassword());
 	}
 
-	/**
-	 * Sets the {@link WebClient} used when requesting the OAuth 2.0 Access Token Response.
-	 *
-	 * @param webClient the {@link WebClient} used when requesting the Access Token Response
-	 */
-	public void setWebClient(WebClient webClient) {
-		Assert.notNull(webClient, "webClient cannot be null");
-		this.webClient = webClient;
-	}
 }

+ 27 - 83
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,29 +15,15 @@
  */
 package org.springframework.security.oauth2.client.endpoint;
 
-import org.springframework.core.io.buffer.DataBuffer;
-import org.springframework.core.io.buffer.DataBufferUtils;
-import org.springframework.http.HttpHeaders;
-import org.springframework.http.HttpStatus;
-import org.springframework.http.MediaType;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
-import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
-import org.springframework.util.StringUtils;
 import org.springframework.web.reactive.function.BodyInserters;
 import org.springframework.web.reactive.function.client.WebClient;
-import reactor.core.publisher.Mono;
 
-import java.util.Collections;
-import java.util.function.Consumer;
-
-import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
+import java.util.Set;
 
 /**
  * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient}
@@ -52,66 +38,37 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
  * @see OAuth2AccessTokenResponse
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-6">Section 6 Refreshing an Access Token</a>
  */
-public final class WebClientReactiveRefreshTokenTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> {
-	private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response";
-	private WebClient webClient = WebClient.builder().build();
+public final class WebClientReactiveRefreshTokenTokenResponseClient extends
+		AbstractWebClientReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> {
 
 	@Override
-	public Mono<OAuth2AccessTokenResponse> getTokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) {
-		Assert.notNull(refreshTokenGrantRequest, "refreshTokenGrantRequest cannot be null");
-		return Mono.defer(() -> {
-			ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration();
-			return this.webClient.post()
-					.uri(clientRegistration.getProviderDetails().getTokenUri())
-					.headers(tokenRequestHeaders(clientRegistration))
-					.body(tokenRequestBody(refreshTokenGrantRequest))
-					.exchange()
-					.flatMap(response -> {
-						HttpStatus status = HttpStatus.resolve(response.rawStatusCode());
-						if (status == null || !status.is2xxSuccessful()) {
-							OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
-									"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " +
-											"HTTP Status Code " + response.rawStatusCode(), null);
-							return response
-									.bodyToMono(DataBuffer.class)
-									.map(DataBufferUtils::release)
-									.then(Mono.error(new OAuth2AuthorizationException(oauth2Error)));
-						}
-						return response.body(oauth2AccessTokenResponse());
-					})
-					.map(tokenResponse -> tokenResponse(refreshTokenGrantRequest, tokenResponse));
-		});
+	ClientRegistration clientRegistration(OAuth2RefreshTokenGrantRequest grantRequest) {
+		return grantRequest.getClientRegistration();
 	}
 
-	private static Consumer<HttpHeaders> tokenRequestHeaders(ClientRegistration clientRegistration) {
-		return headers -> {
-			headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
-			headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
-			if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
-				headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
-			}
-		};
+	@Override
+	Set<String> scopes(OAuth2RefreshTokenGrantRequest grantRequest) {
+		return grantRequest.getScopes();
 	}
 
-	private static BodyInserters.FormInserter<String> tokenRequestBody(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) {
-		ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration();
-		BodyInserters.FormInserter<String> body = BodyInserters.fromFormData(
-				OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue());
-		body.with(OAuth2ParameterNames.REFRESH_TOKEN,
-				refreshTokenGrantRequest.getRefreshToken().getTokenValue());
-		if (!CollectionUtils.isEmpty(refreshTokenGrantRequest.getScopes())) {
-			body.with(OAuth2ParameterNames.SCOPE,
-					StringUtils.collectionToDelimitedString(refreshTokenGrantRequest.getScopes(), " "));
-		}
-		if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
-			body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
-			body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
-		}
-		return body;
+	@Override
+	Set<String> defaultScopes(OAuth2RefreshTokenGrantRequest grantRequest) {
+		return grantRequest.getAccessToken().getScopes();
 	}
 
-	private static OAuth2AccessTokenResponse tokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest,
-															OAuth2AccessTokenResponse accessTokenResponse) {
+	@Override
+	BodyInserters.FormInserter<String> populateTokenRequestBody(
+			OAuth2RefreshTokenGrantRequest grantRequest,
+			BodyInserters.FormInserter<String> body) {
+		return super.populateTokenRequestBody(grantRequest, body)
+			.with(OAuth2ParameterNames.REFRESH_TOKEN, grantRequest.getRefreshToken().getTokenValue());
+	}
+
+	@Override
+	OAuth2AccessTokenResponse populateTokenResponse(
+			OAuth2RefreshTokenGrantRequest grantRequest,
+			OAuth2AccessTokenResponse accessTokenResponse) {
+
 		if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes()) &&
 				accessTokenResponse.getRefreshToken() != null) {
 			return accessTokenResponse;
@@ -119,26 +76,13 @@ public final class WebClientReactiveRefreshTokenTokenResponseClient implements R
 
 		OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse.withResponse(accessTokenResponse);
 		if (CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())) {
-			// As per spec, in Section 5.1 Successful Access Token Response
-			// https://tools.ietf.org/html/rfc6749#section-5.1
-			// If AccessTokenResponse.scope is empty, then default to the scope
-			// originally requested by the client in the Token Request
-			tokenResponseBuilder.scopes(refreshTokenGrantRequest.getAccessToken().getScopes());
+			tokenResponseBuilder.scopes(defaultScopes(grantRequest));
 		}
 		if (accessTokenResponse.getRefreshToken() == null) {
 			// Reuse existing refresh token
-			tokenResponseBuilder.refreshToken(refreshTokenGrantRequest.getRefreshToken().getTokenValue());
+			tokenResponseBuilder.refreshToken(grantRequest.getRefreshToken().getTokenValue());
 		}
 		return tokenResponseBuilder.build();
 	}
 
-	/**
-	 * Sets the {@link WebClient} used when requesting the OAuth 2.0 Access Token Response.
-	 *
-	 * @param webClient the {@link WebClient} used when requesting the Access Token Response
-	 */
-	public void setWebClient(WebClient webClient) {
-		Assert.notNull(webClient, "webClient cannot be null");
-		this.webClient = webClient;
-	}
 }

+ 133 - 43
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,11 +19,14 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
 import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationSuccessHandler;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
@@ -37,18 +40,54 @@ import java.util.Map;
 import java.util.function.Function;
 
 /**
- * The default implementation of a {@link ReactiveOAuth2AuthorizedClientManager}.
+ * The default implementation of a {@link ReactiveOAuth2AuthorizedClientManager}
+ * for use within the context of a {@link ServerWebExchange}.
+ *
+ * <p>(When operating <em>outside</em> of the context of a {@link ServerWebExchange},
+ * use {@link org.springframework.security.oauth2.client.AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} instead.)</p>
+ *
+ * <p>This is a reactive equivalent of {@link DefaultOAuth2AuthorizedClientManager}.</p>
+ *
+ * <h2>Authorized Client Persistence</h2>
+ *
+ * <p>This client manager utilizes a {@link ServerOAuth2AuthorizedClientRepository}
+ * to persist {@link OAuth2AuthorizedClient}s.</p>
+ *
+ * <p>By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient}
+ * will be saved in the authorized client repository.
+ * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationSuccessHandler}
+ * via {@link #setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler)}.</p>
+ *
+ * <p>By default, when an authorization attempt fails due to an
+ * {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} error,
+ * the previously saved {@link OAuth2AuthorizedClient}
+ * will be removed from the authorized client repository.
+ * (The {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT}
+ * error generally occurs when a refresh token that is no longer valid
+ * is used to retrieve a new access token.)
+ * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationFailureHandler}
+ * via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}.</p>
  *
  * @author Joe Grandja
+ * @author Phil Clay
  * @since 5.2
  * @see ReactiveOAuth2AuthorizedClientManager
  * @see ReactiveOAuth2AuthorizedClientProvider
+ * @see ReactiveOAuth2AuthorizationSuccessHandler
+ * @see ReactiveOAuth2AuthorizationFailureHandler
  */
 public final class DefaultReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager {
+
+	private static final Mono<ServerWebExchange> currentServerWebExchangeMono = Mono.subscriberContext()
+			.filter(c -> c.hasKey(ServerWebExchange.class))
+			.map(c -> c.get(ServerWebExchange.class));
+
 	private final ReactiveClientRegistrationRepository clientRegistrationRepository;
 	private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
 	private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty();
 	private Function<OAuth2AuthorizeRequest, Mono<Map<String, Object>>> contextAttributesMapper = new DefaultContextAttributesMapper();
+	private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
+	private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler;
 
 	/**
 	 * Constructs a {@code DefaultReactiveOAuth2AuthorizedClientManager} using the provided parameters.
@@ -62,6 +101,8 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
 		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
 		this.clientRegistrationRepository = clientRegistrationRepository;
 		this.authorizedClientRepository = authorizedClientRepository;
+		this.authorizationSuccessHandler = new SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler(authorizedClientRepository);
+		this.authorizationFailureHandler = new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(authorizedClientRepository);
 	}
 
 	@Override
@@ -70,57 +111,76 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
 
 		String clientRegistrationId = authorizeRequest.getClientRegistrationId();
 		Authentication principal = authorizeRequest.getPrincipal();
-		ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
-
-		return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
-				.switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
-				.flatMap(authorizedClient -> {
-					// Re-authorize
-					return authorizationContext(authorizeRequest, authorizedClient)
-							.flatMap(this.authorizedClientProvider::authorize)
-							.flatMap(reauthorizedClient -> saveAuthorizedClient(reauthorizedClient, principal, serverWebExchange))
-							// Default to the existing authorizedClient if the client was not re-authorized
-							.defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ?
-									authorizeRequest.getAuthorizedClient() : authorizedClient);
-				})
-				.switchIfEmpty(Mono.deferWithContext(context ->
-						// Authorize
-						this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
-								.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
-										"Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
-								.flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration))
-								.flatMap(this.authorizedClientProvider::authorize)
-								.flatMap(authorizedClient -> saveAuthorizedClient(authorizedClient, principal, serverWebExchange))
-								.subscriberContext(context)
-						)
-				);
+
+		return Mono.justOrEmpty(authorizeRequest.<ServerWebExchange>getAttribute(ServerWebExchange.class.getName()))
+				.switchIfEmpty(currentServerWebExchangeMono)
+				.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("serverWebExchange cannot be null")))
+				.flatMap(serverWebExchange -> Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
+						.switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
+						.flatMap(authorizedClient -> {
+							// Re-authorize
+							return authorizationContext(authorizeRequest, authorizedClient)
+									.flatMap(authorizationContext -> authorize(authorizationContext, principal, serverWebExchange))
+									// Default to the existing authorizedClient if the client was not re-authorized
+									.defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ?
+											authorizeRequest.getAuthorizedClient() : authorizedClient);
+						})
+						.switchIfEmpty(Mono.deferWithContext(context ->
+								// Authorize
+								this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
+										.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
+												"Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
+										.flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration))
+										.flatMap(authorizationContext -> authorize(authorizationContext, principal, serverWebExchange))
+										.subscriberContext(context)
+								)
+						));
 	}
 
 	private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId, Authentication principal, ServerWebExchange serverWebExchange) {
-		return Mono.justOrEmpty(serverWebExchange)
-				.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
-				.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("serverWebExchange cannot be null")))
-				.flatMap(exchange -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange));
+		return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange);
 	}
 
-	private Mono<OAuth2AuthorizedClient> saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, ServerWebExchange serverWebExchange) {
-		return Mono.justOrEmpty(serverWebExchange)
-				.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
-				.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("serverWebExchange cannot be null")))
-				.flatMap(exchange -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange)
-						.thenReturn(authorizedClient));
+	/**
+	 * Performs authorization and then delegates to either the {@link #authorizationSuccessHandler}
+	 * or {@link #authorizationFailureHandler}, depending on the authorization result.
+	 *
+	 * @param authorizationContext the context to authorize
+	 * @param principal the principle to authorize
+	 * @param serverWebExchange the currently active exchange
+	 * @return a {@link Mono} that emits the authorized client after the authorization attempt succeeds
+	 *         and the {@link #authorizationSuccessHandler} has completed,
+	 *         or completes with an exception after the authorization attempt fails
+	 *         and the {@link #authorizationFailureHandler} has completed
+	 */
+	private Mono<OAuth2AuthorizedClient> authorize(
+			OAuth2AuthorizationContext authorizationContext,
+			Authentication principal,
+			ServerWebExchange serverWebExchange) {
+
+		return this.authorizedClientProvider.authorize(authorizationContext)
+				// Delegate to the authorizationSuccessHandler of the successful authorization
+				.flatMap(authorizedClient -> this.authorizationSuccessHandler.onAuthorizationSuccess(
+								authorizedClient,
+								principal,
+								createAttributes(serverWebExchange))
+						.thenReturn(authorizedClient))
+				// Delegate to the authorizationFailureHandler of the failed authorization
+				.onErrorResume(OAuth2AuthorizationException.class, authorizationException -> this.authorizationFailureHandler.onAuthorizationFailure(
+								authorizationException,
+								principal,
+								createAttributes(serverWebExchange))
+						.then(Mono.error(authorizationException)));
 	}
 
-	private static Mono<ServerWebExchange> currentServerWebExchange() {
-		return Mono.subscriberContext()
-				.filter(c -> c.hasKey(ServerWebExchange.class))
-				.map(c -> c.get(ServerWebExchange.class));
+	private Map<String, Object> createAttributes(ServerWebExchange serverWebExchange) {
+		return Collections.singletonMap(ServerWebExchange.class.getName(), serverWebExchange);
 	}
 
 	private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest,
 																	OAuth2AuthorizedClient authorizedClient) {
 		return Mono.just(authorizeRequest)
-				.flatMap(this.contextAttributesMapper::apply)
+				.flatMap(this.contextAttributesMapper)
 				.map(attrs -> OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient)
 						.principal(authorizeRequest.getPrincipal())
 						.attributes(attributes -> {
@@ -134,7 +194,7 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
 	private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest,
 																	ClientRegistration clientRegistration) {
 		return Mono.just(authorizeRequest)
-				.flatMap(this.contextAttributesMapper::apply)
+				.flatMap(this.contextAttributesMapper)
 				.map(attrs -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration)
 						.principal(authorizeRequest.getPrincipal())
 						.attributes(attributes -> {
@@ -167,6 +227,36 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
 		this.contextAttributesMapper = contextAttributesMapper;
 	}
 
+	/**
+	 * Sets the handler that handles successful authorizations.
+	 *
+	 * <p>A {@link SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler}
+	 * is used by default.</p>
+	 *
+	 * @param authorizationSuccessHandler the handler that handles successful authorizations.
+	 * @see SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler
+	 * @since 5.3
+	 */
+	public void setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler) {
+		Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null");
+		this.authorizationSuccessHandler = authorizationSuccessHandler;
+	}
+
+	/**
+	 * Sets the handler that handles authorization failures.
+	 *
+	 * <p>A {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler}
+	 * is used by default.</p>
+	 *
+	 * @param authorizationFailureHandler the handler that handles authorization failures.
+	 * @see RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler
+	 * @since 5.3
+	 */
+	public void setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) {
+		Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
+		this.authorizationFailureHandler = authorizationFailureHandler;
+	}
+
 	/**
 	 * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
 	 */
@@ -176,7 +266,7 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
 		public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) {
 			ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
 			return Mono.justOrEmpty(serverWebExchange)
-					.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
+					.switchIfEmpty(currentServerWebExchangeMono)
 					.flatMap(exchange -> {
 						Map<String, Object> contextAttributes = Collections.emptyMap();
 						String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);

+ 172 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java

@@ -0,0 +1,172 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Mono;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * An authorization failure handler that removes authorized clients from a
+ * {@link ServerOAuth2AuthorizedClientRepository}
+ * or a {@link ReactiveOAuth2AuthorizedClientService}.
+ * for specific OAuth 2.0 error codes.
+ *
+ * @author Phil Clay
+ * @since 5.3
+ */
+public class RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler implements ReactiveOAuth2AuthorizationFailureHandler {
+
+	/**
+	 * The default OAuth 2.0 error codes that will trigger removal of the authorized client.
+	 * @see OAuth2ErrorCodes
+	 */
+	public static final Set<String>	DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
+			/*
+			 * Returned from resource servers when an access token provided is expired, revoked,
+			 * malformed, or invalid for other reasons.
+			 *
+			 * Note that this is needed because the ServerOAuth2AuthorizedClientExchangeFilterFunction
+			 * delegates this type of failure received from a resource server
+			 * to this failure handler.
+			 */
+			OAuth2ErrorCodes.INVALID_TOKEN,
+			/*
+			 * Returned from authorization servers when a refresh token is invalid, expired, revoked,
+			 * does not match the redirection URI used in the authorization request, or was issued to another client.
+			 */
+			OAuth2ErrorCodes.INVALID_GRANT)));
+
+	/**
+	 * A delegate that removes clients from either a
+	 * {@link ServerOAuth2AuthorizedClientRepository}
+	 * or a
+	 * {@link ReactiveOAuth2AuthorizedClientService}
+	 * if the error code is one of the {@link #removeAuthorizedClientErrorCodes}.
+	 */
+	private final OAuth2AuthorizedClientRemover delegate;
+
+	/**
+	 * The OAuth 2.0 Error Codes which will trigger removal of an authorized client.
+	 * @see OAuth2ErrorCodes
+	 */
+	private final Set<String> removeAuthorizedClientErrorCodes;
+
+	@FunctionalInterface
+	private interface OAuth2AuthorizedClientRemover {
+		Mono<Void> removeAuthorizedClient(
+				String clientRegistrationId,
+				Authentication principal,
+				Map<String, Object> attributes);
+	}
+
+	/**
+	 * @param authorizedClientRepository The repository from which authorized clients will be removed
+	 * 		  if the error code is one of the {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}.
+	 */
+	public RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+		this(authorizedClientRepository, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES);
+	}
+
+	/**
+	 * @param authorizedClientRepository The repository from which authorized clients will be removed
+	 * 		 if the error code is one of the {@code removeAuthorizedClientErrorCodes}.
+	 * @param removeAuthorizedClientErrorCodes the OAuth 2.0 Error Codes which will trigger removal of an authorized client.
+	 * @see OAuth2ErrorCodes
+	 */
+	public RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(
+			ServerOAuth2AuthorizedClientRepository authorizedClientRepository,
+			Set<String> removeAuthorizedClientErrorCodes) {
+		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
+		Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null");
+		this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes));
+		this.delegate = (clientRegistrationId, principal, attributes) ->
+				authorizedClientRepository.removeAuthorizedClient(
+						clientRegistrationId,
+						principal,
+						(ServerWebExchange) attributes.get(ServerWebExchange.class.getName()));
+	}
+
+	/**
+	 * @param authorizedClientService the service from which authorized clients will be removed
+	 * 		  if the error code is one of the {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}.
+	 */
+	public RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(ReactiveOAuth2AuthorizedClientService authorizedClientService) {
+		this(authorizedClientService, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES);
+	}
+
+	/**
+	 * @param authorizedClientService the service from which authorized clients will be removed
+	 * 		  if the error code is one of the {@code removeAuthorizedClientErrorCodes}.
+	 * @param removeAuthorizedClientErrorCodes the OAuth 2.0 Error Codes which will trigger removal of an authorized client.
+	 * @see OAuth2ErrorCodes
+	 */
+	public RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(
+			ReactiveOAuth2AuthorizedClientService authorizedClientService,
+			Set<String> removeAuthorizedClientErrorCodes) {
+		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
+		Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null");
+		this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes));
+		this.delegate = (clientRegistrationId, principal, attributes) ->
+				authorizedClientService.removeAuthorizedClient(
+						clientRegistrationId,
+						principal.getName());
+	}
+
+	@Override
+	public Mono<Void> onAuthorizationFailure(
+			OAuth2AuthorizationException authorizationException,
+			Authentication principal,
+			Map<String, Object> attributes) {
+
+		if (authorizationException instanceof ClientAuthorizationException
+				&& hasRemovalErrorCode(authorizationException)) {
+
+			ClientAuthorizationException clientAuthorizationException = (ClientAuthorizationException) authorizationException;
+			return this.delegate.removeAuthorizedClient(
+					clientAuthorizationException.getClientRegistrationId(),
+					principal,
+					attributes);
+		} else {
+			return Mono.empty();
+		}
+	}
+
+	/**
+	 * Returns true if the given exception has an error code that
+	 * indicates that the authorized client should be removed.
+	 *
+	 * @param authorizationException the exception that caused the authorization failure
+	 * @return true if the given exception has an error code that
+	 * 		   indicates that the authorized client should be removed.
+	 */
+	private boolean hasRemovalErrorCode(OAuth2AuthorizationException authorizationException) {
+		return this.removeAuthorizedClientErrorCodes.contains(authorizationException.getError().getErrorCode());
+	}
+}

+ 80 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler.java

@@ -0,0 +1,80 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationSuccessHandler;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Mono;
+
+import java.util.Map;
+
+/**
+ * An authorization success handler that saves authorized clients in a
+ * {@link ServerOAuth2AuthorizedClientRepository}
+ * or a {@link ReactiveOAuth2AuthorizedClientService}.
+ *
+ * @author Phil Clay
+ * @since 5.3
+ */
+public class SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler implements ReactiveOAuth2AuthorizationSuccessHandler {
+
+	/**
+	 * A delegate that saves clients in either a
+	 * {@link ServerOAuth2AuthorizedClientRepository}
+	 * or a
+	 * {@link ReactiveOAuth2AuthorizedClientService}.
+	 */
+	private final ReactiveOAuth2AuthorizationSuccessHandler delegate;
+
+	/**
+	 * @param authorizedClientRepository The repository in which authorized clients will be saved.
+	 */
+	public SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler(final ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
+		this.delegate = (authorizedClient, principal, attributes) ->
+				authorizedClientRepository.saveAuthorizedClient(
+						authorizedClient,
+						principal,
+						(ServerWebExchange) attributes.get(ServerWebExchange.class.getName()));
+	}
+
+	/**
+	 * @param authorizedClientService The service in which authorized clients will be saved.
+	 */
+	public SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler(final ReactiveOAuth2AuthorizedClientService authorizedClientService) {
+		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
+		this.delegate = (authorizedClient, principal, attributes) ->
+				authorizedClientService.saveAuthorizedClient(
+						authorizedClient,
+						principal);
+	}
+
+	@Override
+	public Mono<Void> onAuthorizationSuccess(
+			OAuth2AuthorizedClient authorizedClient,
+			Authentication principal,
+			Map<String, Object> attributes) {
+		return this.delegate.onAuthorizationSuccess(
+				authorizedClient,
+				principal,
+				attributes);
+	}
+}

+ 332 - 52
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,19 +16,26 @@
 
 package org.springframework.security.oauth2.client.web.reactive.function.client;
 
+import org.springframework.http.HttpStatus;
+import org.springframework.lang.Nullable;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
 import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationSuccessHandler;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
 import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.client.web.SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
@@ -37,15 +44,21 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
 import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.util.Assert;
 import org.springframework.web.reactive.function.client.ClientRequest;
 import org.springframework.web.reactive.function.client.ClientResponse;
 import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
 import org.springframework.web.reactive.function.client.ExchangeFunction;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
 import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
 
 import java.time.Duration;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.Optional;
 import java.util.function.Consumer;
@@ -54,8 +67,27 @@ import java.util.function.Consumer;
  * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the
  * token as a Bearer Token.
  *
+ * <h3>Authentication and Authorization Failures</h3>
+ *
+ * <p>Since 5.3, this filter function has the ability to forward authentication (HTTP 401 Unauthorized)
+ * and authorization (HTTP 403 Forbidden) failures from an OAuth 2.0 Resource Server to a
+ * {@link ReactiveOAuth2AuthorizationFailureHandler}.
+ * A {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} can be used
+ * to remove the cached {@link OAuth2AuthorizedClient}, so that future requests will result
+ * in a new token being retrieved from an Authorization Server, and sent to the Resource Server.</p>
+ *
+ * <p>If the {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository, ServerOAuth2AuthorizedClientRepository)}
+ * constructor is used, a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler}
+ * will be configured automatically.</p>
+ *
+ * <p>If the {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)}
+ * constructor is used, a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler}
+ * will <em>NOT</em> be configured automatically.
+ * It is recommended that you configure one via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}.</p>
+ *
  * @author Rob Winch
  * @author Joe Grandja
+ * @author Phil Clay
  * @since 5.1
  */
 public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
@@ -77,7 +109,20 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
 			AuthorityUtils.createAuthorityList("ROLE_USER"));
 
-	private ReactiveOAuth2AuthorizedClientManager authorizedClientManager;
+	private final Mono<Authentication> currentAuthenticationMono = ReactiveSecurityContextHolder.getContext()
+			.map(SecurityContext::getAuthentication)
+			.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
+
+	private final Mono<String> clientRegistrationIdMono = currentAuthenticationMono
+			.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
+			.cast(OAuth2AuthenticationToken.class)
+			.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
+
+	private final Mono<ServerWebExchange> currentServerWebExchangeMono = Mono.subscriberContext()
+			.filter(c -> c.hasKey(ServerWebExchange.class))
+			.map(c -> c.get(ServerWebExchange.class));
+
+	private final ReactiveOAuth2AuthorizedClientManager authorizedClientManager;
 
 	private boolean defaultAuthorizedClientManager;
 
@@ -91,33 +136,71 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	@Deprecated
 	private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient;
 
+	private ClientResponseHandler clientResponseHandler;
+
+	@FunctionalInterface
+	private interface ClientResponseHandler {
+		Mono<ClientResponse> handleResponse(ClientRequest request, Mono<ClientResponse> response);
+	}
 
 	/**
 	 * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters.
 	 *
+	 * <p>When this constructor is used, authentication (HTTP 401) and authorization (HTTP 403)
+	 * failures returned from a OAuth 2.0 Resource Server will <em>NOT</em> be forwarded to a
+	 * {@link ReactiveOAuth2AuthorizationFailureHandler}.
+	 * Therefore, future requests to the Resource Server will most likely use the same (most likely invalid) token,
+	 * resulting in the same errors returned from the Resource Server.
+	 * It is recommended to configure a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler}
+	 * via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}
+	 * so that authentication and authorization failures returned from a Resource Server
+	 * will result in removing the authorized client, so that a new token is retrieved for future requests.</p>
+	 *
 	 * @since 5.2
 	 * @param authorizedClientManager the {@link ReactiveOAuth2AuthorizedClientManager} which manages the authorized client(s)
 	 */
 	public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager authorizedClientManager) {
 		Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
 		this.authorizedClientManager = authorizedClientManager;
+		this.clientResponseHandler =  (request, responseMono) -> responseMono;
 	}
 
 	/**
 	 * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters.
 	 *
+	 * <p>Since 5.3, when this constructor is used, authentication (HTTP 401)
+	 * and authorization (HTTP 403) failures returned from an OAuth 2.0 Resource Server
+	 * will be forwarded to a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler},
+	 * which will potentially remove the {@link OAuth2AuthorizedClient} from the given
+	 * {@link ServerOAuth2AuthorizedClientRepository}, depending on the OAuth 2.0 error code returned.
+	 * Authentication failures returned from an OAuth 2.0 Resource Server typically indicate
+	 * that the token is invalid, and should not be used in future requests.
+	 * Removing the authorized client from the repository will ensure that the existing
+	 * token will not be sent for future requests to the Resource Server,
+	 * and a new token is retrieved from Authorization Server and used for
+	 * future requests to the Resource Server.</p>
+	 *
 	 * @param clientRegistrationRepository the repository of client registrations
 	 * @param authorizedClientRepository the repository of authorized clients
 	 */
 	public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository,
 																ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
-		this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository);
+
+		ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler =
+				new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(authorizedClientRepository);
+
+		this.authorizedClientManager = createDefaultAuthorizedClientManager(
+				clientRegistrationRepository,
+				authorizedClientRepository,
+				authorizationFailureHandler);
+		this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
 		this.defaultAuthorizedClientManager = true;
 	}
 
 	private static ReactiveOAuth2AuthorizedClientManager createDefaultAuthorizedClientManager(
 			ReactiveClientRegistrationRepository clientRegistrationRepository,
-			ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+			ServerOAuth2AuthorizedClientRepository authorizedClientRepository,
+			ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) {
 
 		ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
 				ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
@@ -132,7 +215,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 			UnAuthenticatedReactiveOAuth2AuthorizedClientManager unauthenticatedAuthorizedClientManager =
 					new UnAuthenticatedReactiveOAuth2AuthorizedClientManager(
 							clientRegistrationRepository,
-							(UnAuthenticatedServerOAuth2AuthorizedClientRepository) authorizedClientRepository);
+							(UnAuthenticatedServerOAuth2AuthorizedClientRepository) authorizedClientRepository,
+							authorizationFailureHandler);
 			unauthenticatedAuthorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
 			return unauthenticatedAuthorizedClientManager;
 		}
@@ -140,6 +224,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 		DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(
 				clientRegistrationRepository, authorizedClientRepository);
 		authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
+		authorizedClientManager.setAuthorizationFailureHandler(authorizationFailureHandler);
 
 		return authorizedClientManager;
 	}
@@ -316,8 +401,13 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
 		return authorizedClient(request)
 				.map(authorizedClient -> bearer(request, authorizedClient))
-				.flatMap(next::exchange)
-				.switchIfEmpty(Mono.defer(() -> next.exchange(request)));
+				.flatMap(requestWithBearer -> exchangeAndHandleResponse(requestWithBearer, next))
+				.switchIfEmpty(Mono.defer(() -> exchangeAndHandleResponse(request, next)));
+	}
+
+	private Mono<ClientResponse> exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) {
+		return next.exchange(request)
+				.transform(responseMono -> this.clientResponseHandler.handleResponse(request, responseMono));
 	}
 
 	private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
@@ -330,80 +420,102 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	}
 
 	private Mono<OAuth2AuthorizeRequest> authorizeRequest(ClientRequest request) {
-		Mono<Authentication> authentication = currentAuthentication();
-
-		Mono<String> clientRegistrationId = Mono.justOrEmpty(clientRegistrationId(request))
-				.switchIfEmpty(Mono.justOrEmpty(this.defaultClientRegistrationId))
-				.switchIfEmpty(clientRegistrationId(authentication));
+		Mono<String> clientRegistrationId = effectiveClientRegistrationId(request);
 
-		Mono<Optional<ServerWebExchange>> serverWebExchange = Mono.justOrEmpty(serverWebExchange(request))
-				.switchIfEmpty(currentServerWebExchange())
-				.map(Optional::of)
-				.defaultIfEmpty(Optional.empty());
+		Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request);
 
-		return Mono.zip(clientRegistrationId, authentication, serverWebExchange)
+		return Mono.zip(clientRegistrationId, this.currentAuthenticationMono, serverWebExchange)
 				.map(t3 -> {
 					OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withClientRegistrationId(t3.getT1()).principal(t3.getT2());
-					if (t3.getT3().isPresent()) {
-						builder.attribute(ServerWebExchange.class.getName(), t3.getT3().get());
-					}
+					t3.getT3().ifPresent(exchange -> builder.attribute(ServerWebExchange.class.getName(), exchange));
 					return builder.build();
 				});
 	}
 
-	private Mono<OAuth2AuthorizeRequest> reauthorizeRequest(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
-		Mono<Authentication> authentication = currentAuthentication();
+	/**
+	 * Returns a {@link Mono} the emits the {@code clientRegistrationId}
+	 * that is active for the given request.
+	 *
+	 * @param request the request for which to retrieve the {@code clientRegistrationId}
+	 * @return a mono that emits the {@code clientRegistrationId}
+	 * 	       that is active for the given request.
+	 */
+	private Mono<String> effectiveClientRegistrationId(ClientRequest request) {
+		return Mono.justOrEmpty(clientRegistrationId(request))
+				.switchIfEmpty(Mono.justOrEmpty(this.defaultClientRegistrationId))
+				.switchIfEmpty(clientRegistrationIdMono);
+	}
 
-		Mono<Optional<ServerWebExchange>> serverWebExchange = Mono.justOrEmpty(serverWebExchange(request))
-				.switchIfEmpty(currentServerWebExchange())
+	/**
+	 * Returns a {@link Mono} that emits an {@link Optional} for the {@link ServerWebExchange}
+	 * that is active for the given request.
+	 *
+	 * <p>The returned {@link Mono} will never complete empty.
+	 * Instead, it will emit an empty {@link Optional} if no exchange is active.</p>
+	 *
+	 * @param request the request for which to retrieve the exchange
+	 * @return a {@link Mono} that emits an {@link Optional} for the {@link ServerWebExchange}
+	 * 	       that is active for the given request.
+	 */
+	private Mono<Optional<ServerWebExchange>> effectiveServerWebExchange(ClientRequest request) {
+		return Mono.justOrEmpty(serverWebExchange(request))
+				.switchIfEmpty(currentServerWebExchangeMono)
 				.map(Optional::of)
 				.defaultIfEmpty(Optional.empty());
+	}
 
-		return Mono.zip(authentication, serverWebExchange)
+	private Mono<OAuth2AuthorizeRequest> reauthorizeRequest(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
+		Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request);
+
+		return Mono.zip(this.currentAuthenticationMono, serverWebExchange)
 				.map(t2 -> {
 					OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient).principal(t2.getT1());
-					if (t2.getT2().isPresent()) {
-						builder.attribute(ServerWebExchange.class.getName(), t2.getT2().get());
-					}
+					t2.getT2().ifPresent(exchange -> builder.attribute(ServerWebExchange.class.getName(), exchange));
 					return builder.build();
 				});
 	}
 
-	private Mono<Authentication> currentAuthentication() {
-		return ReactiveSecurityContextHolder.getContext()
-				.map(SecurityContext::getAuthentication)
-				.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
-	}
-
-	private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
-		return authentication
-				.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
-				.cast(OAuth2AuthenticationToken.class)
-				.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
-	}
-
-	private Mono<ServerWebExchange> currentServerWebExchange() {
-		return Mono.subscriberContext()
-				.filter(c -> c.hasKey(ServerWebExchange.class))
-				.map(c -> c.get(ServerWebExchange.class));
-	}
-
 	private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
 		return ClientRequest.from(request)
 					.headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
 					.build();
 	}
 
+	/**
+	 * Sets the handler that handles authentication and authorization failures when communicating
+	 * to the OAuth 2.0 Resource Server.
+	 *
+	 * <p>For example, a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler}
+	 * is typically used to remove the cached {@link OAuth2AuthorizedClient},
+	 * so that the same token is no longer used in future requests to the Resource Server.</p>
+	 *
+	 * <p>The failure handler used by default depends on which constructor was used
+	 * to construct this {@link ServerOAuth2AuthorizedClientExchangeFilterFunction}.
+	 * See the constructors for more details.</p>
+	 *
+	 * @param authorizationFailureHandler the handler that handles authentication and authorization failures.
+	 * @since 5.3
+	 */
+	public void setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) {
+		Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
+		this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
+	}
+
 	private static class UnAuthenticatedReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager {
 		private final ReactiveClientRegistrationRepository clientRegistrationRepository;
 		private final UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository;
+		private final ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler;
+		private final ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler;
 		private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
 
 		private UnAuthenticatedReactiveOAuth2AuthorizedClientManager(
 				ReactiveClientRegistrationRepository clientRegistrationRepository,
-				UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+				UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository,
+				ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) {
 			this.clientRegistrationRepository = clientRegistrationRepository;
 			this.authorizedClientRepository = authorizedClientRepository;
+			this.authorizationSuccessHandler = new SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler(authorizedClientRepository);
+			this.authorizationFailureHandler = authorizationFailureHandler;
 		}
 
 		@Override
@@ -418,8 +530,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 					.flatMap(authorizedClient -> {
 						// Re-authorize
 						return Mono.just(OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient).principal(principal).build())
-								.flatMap(this.authorizedClientProvider::authorize)
-								.flatMap(reauthorizedClient -> this.authorizedClientRepository.saveAuthorizedClient(reauthorizedClient, principal, null).thenReturn(reauthorizedClient))
+								.flatMap(authorizationContext -> authorize(authorizationContext, principal))
 								// Default to the existing authorizedClient if the client was not re-authorized
 								.defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ?
 										authorizeRequest.getAuthorizedClient() : authorizedClient);
@@ -430,15 +541,184 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 								.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
 										"Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
 								.flatMap(clientRegistration -> Mono.just(OAuth2AuthorizationContext.withClientRegistration(clientRegistration).principal(principal).build()))
-								.flatMap(this.authorizedClientProvider::authorize)
-								.flatMap(authorizedClient -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null).thenReturn(authorizedClient))
+								.flatMap(authorizationContext -> authorize(authorizationContext, principal))
 								.subscriberContext(context)
 					));
 		}
 
+		/**
+		 * Performs authorization and then delegates to either the {@link #authorizationSuccessHandler}
+		 * or {@link #authorizationFailureHandler}, depending on the authorization result.
+		 *
+		 * @param authorizationContext the context to authorize
+		 * @param principal the principle to authorize
+		 * @return a {@link Mono} that emits the authorized client after the authorization attempt succeeds
+		 *         and the {@link #authorizationSuccessHandler} has completed,
+		 *         or completes with an exception after the authorization attempt fails
+		 *         and the {@link #authorizationFailureHandler} has completed
+		 */
+		private Mono<OAuth2AuthorizedClient> authorize(
+				OAuth2AuthorizationContext authorizationContext,
+				Authentication principal) {
+
+			return this.authorizedClientProvider.authorize(authorizationContext)
+					// Delegates to the authorizationSuccessHandler of the successful authorization
+					.flatMap(authorizedClient -> this.authorizationSuccessHandler.onAuthorizationSuccess(
+									authorizedClient,
+									principal,
+									Collections.emptyMap())
+							.thenReturn(authorizedClient))
+					// Delegates to  the authorizationFailureHandler of the failed authorization
+					.onErrorResume(OAuth2AuthorizationException.class, authorizationException -> this.authorizationFailureHandler.onAuthorizationFailure(
+									authorizationException,
+									principal,
+									Collections.emptyMap())
+							.then(Mono.error(authorizationException)));
+		}
+
 		private void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) {
 			Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null");
 			this.authorizedClientProvider = authorizedClientProvider;
 		}
 	}
+
+	/**
+	 * Forwards authentication and authorization failures to a
+	 * {@link ReactiveOAuth2AuthorizationFailureHandler}.
+	 *
+	 * @since 5.3
+	 */
+	private class AuthorizationFailureForwarder implements ClientResponseHandler {
+
+		/**
+		 * A map of HTTP Status Code to OAuth 2.0 Error codes for
+		 * HTTP status codes that should be interpreted as
+		 * authentication or authorization failures.
+		 */
+		private final Map<Integer, String> httpStatusToOAuth2ErrorCodeMap;
+
+		/**
+		 * The {@link ReactiveOAuth2AuthorizationFailureHandler} to notify
+		 * when an authentication/authorization failure occurs.
+		 */
+		private final ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler;
+
+		private AuthorizationFailureForwarder(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) {
+			Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null");
+			this.authorizationFailureHandler = authorizationFailureHandler;
+
+			Map<Integer, String> httpStatusToOAuth2Error = new HashMap<>();
+			httpStatusToOAuth2Error.put(HttpStatus.UNAUTHORIZED.value(), OAuth2ErrorCodes.INVALID_TOKEN);
+			httpStatusToOAuth2Error.put(HttpStatus.FORBIDDEN.value(), OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
+			this.httpStatusToOAuth2ErrorCodeMap = Collections.unmodifiableMap(httpStatusToOAuth2Error);
+		}
+
+		@Override
+		public Mono<ClientResponse> handleResponse(
+				ClientRequest request,
+				Mono<ClientResponse> responseMono) {
+
+			return responseMono
+				.flatMap(response -> handleHttpStatus(request, response.rawStatusCode(), null)
+						.thenReturn(response))
+				.onErrorResume(WebClientResponseException.class, e -> handleHttpStatus(request, e.getRawStatusCode(), e)
+						.then(Mono.error(e)))
+				.onErrorResume(OAuth2AuthorizationException.class, e -> handleAuthorizationException(request, e)
+						.then(Mono.error(e)));
+		}
+
+		/**
+		 * Handles the given http status code returned from a resource server
+		 * by notifying the authorization failure handler if the http status
+		 * code is in the {@link #httpStatusToOAuth2ErrorCodeMap}.
+		 *
+		 * @param request the request being processed
+		 * @param httpStatusCode the http status returned by the resource server
+		 * @param exception The root cause exception for the failure (nullable)
+		 * @return a {@link Mono} that completes empty after the authorization failure handler completes.
+		 */
+		private Mono<Void> handleHttpStatus(ClientRequest request, int httpStatusCode, @Nullable Exception exception) {
+			return Mono.justOrEmpty(this.httpStatusToOAuth2ErrorCodeMap.get(httpStatusCode))
+					.flatMap(oauth2ErrorCode -> {
+						Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request);
+
+						Mono<String> clientRegistrationId = effectiveClientRegistrationId(request);
+
+						return Mono.zip(currentAuthenticationMono, serverWebExchange, clientRegistrationId)
+								.flatMap(tuple3 -> handleAuthorizationFailure(
+										tuple3.getT1(),              // Authentication principal
+										tuple3.getT2().orElse(null), // ServerWebExchange exchange
+										createAuthorizationException(
+												tuple3.getT3(),      // String clientRegistrationId
+												oauth2ErrorCode,
+												exception)));
+					});
+		}
+
+		/**
+		 * Handles the given OAuth2AuthorizationException that occurred downstream
+		 * by notifying the authorization failure handler.
+		 *
+		 * @param request the request being processed
+		 * @param exception the authorization exception to include in the failure event.
+		 * @return a {@link Mono} that completes empty after the authorization failure handler completes.
+		 */
+		private Mono<Void> handleAuthorizationException(ClientRequest request, OAuth2AuthorizationException exception) {
+			Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request);
+
+			return Mono.zip(currentAuthenticationMono, serverWebExchange)
+					.flatMap(tuple2 -> handleAuthorizationFailure(
+							tuple2.getT1(),              // Authentication principal
+							tuple2.getT2().orElse(null), // ServerWebExchange exchange
+							exception));
+		}
+
+		/**
+		 * Creates an authorization exception using the given parameters.
+		 *
+		 * @param clientRegistrationId the client registration id of the client that failed authentication/authorization.
+		 * @param oauth2ErrorCode the OAuth 2.0 error code to use in the authorization failure event
+		 * @param exception The root cause exception for the failure (nullable)
+		 * @return an authorization exception using the given parameters.
+		 */
+		private ClientAuthorizationException createAuthorizationException(
+				String clientRegistrationId,
+				String oauth2ErrorCode,
+				@Nullable Exception exception) {
+			return new ClientAuthorizationException(
+					new OAuth2Error(
+							oauth2ErrorCode,
+							null,
+							"https://tools.ietf.org/html/rfc6750#section-3.1"),
+					clientRegistrationId,
+					exception);
+		}
+
+
+		/**
+		 * Delegates to the authorization failure handler of the failed authorization.
+		 *
+		 * @param principal the principal associated with the failed authorization attempt
+		 * @param exchange the currently active exchange
+		 * @param exception the authorization exception to include in the failure event.
+		 * @return a {@link Mono} that completes empty after the authorization failure handler completes.
+		 */
+		private Mono<Void> handleAuthorizationFailure(
+				Authentication principal,
+				ServerWebExchange exchange,
+				OAuth2AuthorizationException exception) {
+
+			return this.authorizationFailureHandler.onAuthorizationFailure(
+					exception,
+					principal,
+					createAttributes(exchange));
+		}
+
+		private Map<String, Object> createAttributes(ServerWebExchange exchange) {
+			if (exchange == null) {
+				return Collections.emptyMap();
+			}
+			return Collections.singletonMap(ServerWebExchange.class.getName(), exchange);
+		}
+	}
 }

+ 233 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,9 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -34,6 +37,7 @@ import java.util.Map;
 import java.util.function.Function;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
@@ -59,6 +63,7 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
 	private OAuth2AuthorizedClient authorizedClient;
 	private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
 	private PublisherProbe<Void> saveAuthorizedClientProbe;
+	private PublisherProbe<Void> removeAuthorizedClientProbe;
 
 	@SuppressWarnings("unchecked")
 	@Before
@@ -67,6 +72,8 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
 		this.authorizedClientService = mock(ReactiveOAuth2AuthorizedClientService.class);
 		this.saveAuthorizedClientProbe = PublisherProbe.empty();
 		when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(this.saveAuthorizedClientProbe.mono());
+		this.removeAuthorizedClientProbe = PublisherProbe.empty();
+		when(this.authorizedClientService.removeAuthorizedClient(any(), any())).thenReturn(this.removeAuthorizedClientProbe.mono());
 		this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
 		this.contextAttributesMapper = mock(Function.class);
 		when(this.contextAttributesMapper.apply(any())).thenReturn(Mono.empty());
@@ -109,6 +116,20 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
 				.hasMessage("contextAttributesMapper cannot be null");
 	}
 
+	@Test
+	public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationSuccessHandler cannot be null");
+	}
+
+	@Test
+	public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationFailureHandler cannot be null");
+	}
+
 	@Test
 	public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.authorizedClientManager.authorize(null))
@@ -187,6 +208,214 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
 		verify(this.authorizedClientService).saveAuthorizedClient(
 				eq(this.authorizedClient), eq(this.principal));
 		this.saveAuthorizedClientProbe.assertWasSubscribed();
+		verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any());
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void authorizeWhenNotAuthorizedAndSupportedProviderAndCustomSuccessHandlerThenInvokeCustomSuccessHandler() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		when(this.authorizedClientService.loadAuthorizedClient(
+				any(), any())).thenReturn(Mono.empty());
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+		PublisherProbe<Void> authorizationSuccessHandlerProbe = PublisherProbe.empty();
+		this.authorizedClientManager.setAuthorizationSuccessHandler((client, principal, attributes) -> authorizationSuccessHandlerProbe.mono());
+
+		Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
+
+		StepVerifier.create(authorizedClient)
+				.expectNext(this.authorizedClient)
+				.verifyComplete();
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		authorizationSuccessHandlerProbe.assertWasSubscribed();
+		verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any());
+		verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any());
+	}
+
+	@Test
+	public void authorizeWhenInvalidTokenThenRemoveAuthorizedClient() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		when(this.authorizedClientService.loadAuthorizedClient(
+				any(), any())).thenReturn(Mono.empty());
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+
+		ClientAuthorizationException exception = new ClientAuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null),
+				this.clientRegistration.getRegistrationId());
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception));
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
+				.isEqualTo(exception);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		verify(this.authorizedClientService).removeAuthorizedClient(
+				eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()));
+		this.removeAuthorizedClientProbe.assertWasSubscribed();
+		verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any());
+	}
+
+	@Test
+	public void authorizeWhenInvalidGrantThenRemoveAuthorizedClient() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		when(this.authorizedClientService.loadAuthorizedClient(
+				any(), any())).thenReturn(Mono.empty());
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+
+		ClientAuthorizationException exception = new ClientAuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null),
+				this.clientRegistration.getRegistrationId());
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception));
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
+				.isEqualTo(exception);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		verify(this.authorizedClientService).removeAuthorizedClient(
+				eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()));
+		this.removeAuthorizedClientProbe.assertWasSubscribed();
+		verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any());
+	}
+
+	@Test
+	public void authorizeWhenServerErrorThenDoNotRemoveAuthorizedClient() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		when(this.authorizedClientService.loadAuthorizedClient(
+				any(), any())).thenReturn(Mono.empty());
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+
+		ClientAuthorizationException exception = new ClientAuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, null, null),
+				this.clientRegistration.getRegistrationId());
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception));
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
+				.isEqualTo(exception);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any());
+		verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any());
+	}
+
+	@Test
+	public void authorizeWhenOAuth2AuthorizationExceptionThenDoNotRemoveAuthorizedClient() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		when(this.authorizedClientService.loadAuthorizedClient(
+				any(), any())).thenReturn(Mono.empty());
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+
+		OAuth2AuthorizationException exception = new OAuth2AuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null));
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception));
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
+				.isEqualTo(exception);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any());
+		verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any());
+	}
+
+	@Test
+	public void authorizeWhenOAuth2AuthorizationExceptionAndCustomFailureHandlerThenInvokeCustomFailureHandler() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		when(this.authorizedClientService.loadAuthorizedClient(
+				any(), any())).thenReturn(Mono.empty());
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+
+		OAuth2AuthorizationException exception = new OAuth2AuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null));
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception));
+
+		PublisherProbe<Void> authorizationFailureHandlerProbe = PublisherProbe.empty();
+		this.authorizedClientManager.setAuthorizationFailureHandler((client, principal, attributes) -> authorizationFailureHandlerProbe.mono());
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
+				.isEqualTo(exception);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		authorizationFailureHandlerProbe.assertWasSubscribed();
+		verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any());
+		verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any());
 	}
 
 	@SuppressWarnings("unchecked")
@@ -222,6 +451,7 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
 		verify(this.authorizedClientService).saveAuthorizedClient(
 				eq(reauthorizedClient), eq(this.principal));
 		this.saveAuthorizedClientProbe.assertWasSubscribed();
+		verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any());
 	}
 
 	@SuppressWarnings("unchecked")
@@ -277,6 +507,7 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
 		verify(this.authorizedClientService).saveAuthorizedClient(
 				eq(reauthorizedClient), eq(this.principal));
 		this.saveAuthorizedClientProbe.assertWasSubscribed();
+		verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any());
 	}
 
 	@SuppressWarnings("unchecked")
@@ -302,6 +533,7 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
 		verify(this.authorizedClientService).saveAuthorizedClient(
 				eq(reauthorizedClient), eq(this.principal));
 		this.saveAuthorizedClientProbe.assertWasSubscribed();
+		verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any());
 
 		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
 

+ 6 - 6
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,10 +24,10 @@ import org.junit.Test;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
@@ -178,7 +178,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
 		this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value()));
 
 		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block())
-			.isInstanceOf(OAuth2AuthorizationException.class)
+			.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("unauthorized_client"))
 			.hasMessageContaining("unauthorized_client");
 	}
 
@@ -189,7 +189,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
 		this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value()));
 
 		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block())
-				.isInstanceOf(OAuth2AuthorizationException.class)
+				.isInstanceOf(ClientAuthorizationException.class)
 				.hasMessageContaining("server_error");
 	}
 
@@ -204,7 +204,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
 		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
 
 		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block())
-				.isInstanceOf(OAuth2AuthorizationException.class)
+				.isInstanceOf(ClientAuthorizationException.class)
 				.hasMessageContaining("invalid_token_response");
 	}
 
@@ -307,7 +307,7 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
 		this.tokenResponseClient.getTokenResponse(pkceAuthorizationCodeGrantRequest()).block();
 		String body = this.server.takeRequest().getBody().readUtf8();
 
-		assertThat(body).isEqualTo("grant_type=authorization_code&code=code&redirect_uri=%7BbaseUrl%7D%2F%7Baction%7D%2Foauth2%2Fcode%2F%7BregistrationId%7D&client_id=client-id&code_verifier=code-verifier-1234");
+		assertThat(body).isEqualTo("grant_type=authorization_code&client_id=client-id&code=code&redirect_uri=%7BbaseUrl%7D%2F%7Baction%7D%2Foauth2%2Fcode%2F%7BregistrationId%7D&code_verifier=code-verifier-1234");
 	}
 
 	private OAuth2AuthorizationCodeGrantRequest pkceAuthorizationCodeGrantRequest() {

+ 11 - 5
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import org.junit.Before;
 import org.junit.Test;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@@ -32,6 +33,7 @@ import org.springframework.web.reactive.function.client.WebClient;
 import org.springframework.web.reactive.function.client.WebClientResponseException;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.Mockito.*;
 
 /**
@@ -103,7 +105,7 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
 
 		assertThat(response.getAccessToken()).isNotNull();
 		assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
-		assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser&client_id=client-id&client_secret=client-secret");
+		assertThat(body).isEqualTo("grant_type=client_credentials&client_id=client-id&client_secret=client-secret&scope=read%3Auser");
 	}
 
 	@Test
@@ -147,15 +149,19 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
 		verify(customClient, atLeastOnce()).post();
 	}
 
-	@Test(expected = WebClientResponseException.class)
-	// gh-6089
+	@Test
 	public void getTokenResponseWhenInvalidResponse() throws WebClientResponseException {
 		ClientRegistration registration = this.clientRegistration.build();
 		enqueueUnexpectedResponse();
 
 		OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
 
-		OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
+		assertThatThrownBy(() -> this.client.getTokenResponse(request).block())
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response"))
+				.hasMessageContaining("[invalid_token_response]")
+				.hasMessageContaining("Empty OAuth 2.0 Access Token Response")
+				.hasMessageContaining("HTTP Status Code: 301");
+
 	}
 
 	private void enqueueUnexpectedResponse(){

+ 14 - 10
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,11 +24,11 @@ import org.junit.Test;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 
 import java.time.Instant;
@@ -148,8 +148,10 @@ public class WebClientReactivePasswordTokenResponseClientTests {
 				this.clientRegistrationBuilder.build(), this.username, this.password);
 
 		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block())
-				.isInstanceOf(OAuth2AuthorizationException.class)
-				.hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response")
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response"))
+				.hasMessageContaining("[invalid_token_response]")
+				.hasMessageContaining("An error occurred parsing the Access Token response")
+				.hasMessageContaining("HTTP Status Code: 200")
 				.hasCauseInstanceOf(Throwable.class);
 	}
 
@@ -186,9 +188,10 @@ public class WebClientReactivePasswordTokenResponseClientTests {
 				this.clientRegistrationBuilder.build(), this.username, this.password);
 
 		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block())
-				.isInstanceOf(OAuth2AuthorizationException.class)
-				.hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response")
-				.hasMessageContaining("HTTP Status Code 400");
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("unauthorized_client"))
+				.hasMessageContaining("[unauthorized_client]")
+				.hasMessageContaining("Error retrieving OAuth 2.0 Access Token")
+				.hasMessageContaining("HTTP Status Code: 400");
 	}
 
 	@Test
@@ -199,9 +202,10 @@ public class WebClientReactivePasswordTokenResponseClientTests {
 				this.clientRegistrationBuilder.build(), this.username, this.password);
 
 		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block())
-				.isInstanceOf(OAuth2AuthorizationException.class)
-				.hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response")
-				.hasMessageContaining("HTTP Status Code 500");
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response"))
+				.hasMessageContaining("[invalid_token_response]")
+				.hasMessageContaining("Empty OAuth 2.0 Access Token Response")
+				.hasMessageContaining("HTTP Status Code: 500");
 	}
 
 	private MockResponse jsonResponse(String json) {

+ 12 - 10
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,11 +24,11 @@ import org.junit.Test;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
@@ -153,8 +153,9 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
 				this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
 
 		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block())
-				.isInstanceOf(OAuth2AuthorizationException.class)
-				.hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response")
+				.isInstanceOf(ClientAuthorizationException.class)
+				.hasMessageContaining("[invalid_token_response]")
+				.hasMessageContaining("An error occurred parsing the Access Token response")
 				.hasCauseInstanceOf(Throwable.class);
 	}
 
@@ -191,9 +192,9 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
 				this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
 
 		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block())
-				.isInstanceOf(OAuth2AuthorizationException.class)
-				.hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response")
-				.hasMessageContaining("HTTP Status Code 400");
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("unauthorized_client"))
+				.hasMessageContaining("[unauthorized_client]")
+				.hasMessageContaining("HTTP Status Code: 400");
 	}
 
 	@Test
@@ -204,9 +205,10 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
 				this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
 
 		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block())
-				.isInstanceOf(OAuth2AuthorizationException.class)
-				.hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response")
-				.hasMessageContaining("HTTP Status Code 500");
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response"))
+				.hasMessageContaining("[invalid_token_response]")
+				.hasMessageContaining("Empty OAuth 2.0 Access Token Response")
+				.hasMessageContaining("HTTP Status Code: 500");
 	}
 
 	private MockResponse jsonResponse(String json) {

+ 229 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
 import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
@@ -31,6 +32,9 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -45,6 +49,7 @@ import java.util.Map;
 import java.util.function.Function;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.Mockito.*;
 
@@ -67,6 +72,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 	private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
 	private PublisherProbe<OAuth2AuthorizedClient> loadAuthorizedClientProbe;
 	private PublisherProbe<Void> saveAuthorizedClientProbe;
+	private PublisherProbe<Void> removeAuthorizedClientProbe;
 
 	@SuppressWarnings("unchecked")
 	@Before
@@ -81,6 +87,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 		this.saveAuthorizedClientProbe = PublisherProbe.empty();
 		when(this.authorizedClientRepository.saveAuthorizedClient(
 				any(OAuth2AuthorizedClient.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(this.saveAuthorizedClientProbe.mono());
+		this.removeAuthorizedClientProbe = PublisherProbe.empty();
+		when(this.authorizedClientRepository.removeAuthorizedClient(
+				any(String.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(this.removeAuthorizedClientProbe.mono());
 		this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
 		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty());
 		this.contextAttributesMapper = mock(Function.class);
@@ -119,6 +128,20 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 				.hasMessage("authorizedClientProvider cannot be null");
 	}
 
+	@Test
+	public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationSuccessHandler cannot be null");
+	}
+
+	@Test
+	public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationFailureHandler cannot be null");
+	}
+
 	@Test
 	public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null))
@@ -204,8 +227,211 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 		verify(this.authorizedClientRepository).saveAuthorizedClient(
 				eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange));
 		this.saveAuthorizedClientProbe.assertWasSubscribed();
+		verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any());
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void authorizeWhenNotAuthorizedAndSupportedProviderAndCustomSuccessHandlerThenInvokeCustomSuccessHandler() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+		when(this.authorizedClientProvider.authorize(
+				any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+
+		PublisherProbe<Void> authorizationSuccessHandlerProbe = PublisherProbe.empty();
+		this.authorizedClientManager.setAuthorizationSuccessHandler((client, principal, attributes) -> authorizationSuccessHandlerProbe.mono());
+
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
+				.subscriberContext(this.context).block();
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		assertThat(authorizedClient).isSameAs(this.authorizedClient);
+		authorizationSuccessHandlerProbe.assertWasSubscribed();
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
+		verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any());
 	}
 
+	@SuppressWarnings("unchecked")
+	@Test
+	public void authorizeWhenInvalidTokenThenRemoveAuthorizedClient() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+
+		ClientAuthorizationException exception = new ClientAuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null),
+				this.clientRegistration.getRegistrationId());
+
+		when(this.authorizedClientProvider.authorize(
+				any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception));
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest)
+				.subscriberContext(this.context).block())
+				.isEqualTo(exception);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		verify(this.authorizedClientRepository).removeAuthorizedClient(
+				eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.serverWebExchange));
+		this.removeAuthorizedClientProbe.assertWasSubscribed();
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void authorizeWhenInvalidGrantThenRemoveAuthorizedClient() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+
+		ClientAuthorizationException exception = new ClientAuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null),
+				this.clientRegistration.getRegistrationId());
+
+		when(this.authorizedClientProvider.authorize(
+				any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception));
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest)
+				.subscriberContext(this.context).block())
+				.isEqualTo(exception);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		verify(this.authorizedClientRepository).removeAuthorizedClient(
+				eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.serverWebExchange));
+		this.removeAuthorizedClientProbe.assertWasSubscribed();
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void authorizeWhenServerErrorThenDoNotRemoveAuthorizedClient() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+
+		ClientAuthorizationException exception = new ClientAuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, null, null),
+				this.clientRegistration.getRegistrationId());
+
+		when(this.authorizedClientProvider.authorize(
+				any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception));
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest)
+				.subscriberContext(this.context).block())
+				.isEqualTo(exception);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any());
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void authorizeWhenOAuth2AuthorizationExceptionThenDoNotRemoveAuthorizedClient() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+
+		OAuth2AuthorizationException exception = new OAuth2AuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null));
+
+		when(this.authorizedClientProvider.authorize(
+				any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception));
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest)
+				.subscriberContext(this.context).block())
+				.isEqualTo(exception);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any());
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void authorizeWhenOAuth2AuthorizationExceptionAndCustomFailureHandlerThenInvokeCustomFailureHandler() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+
+		OAuth2AuthorizationException exception = new OAuth2AuthorizationException(
+				new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null));
+
+		when(this.authorizedClientProvider.authorize(
+				any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception));
+
+		PublisherProbe<Void> authorizationFailureHandlerProbe = PublisherProbe.empty();
+		this.authorizedClientManager.setAuthorizationFailureHandler((client, principal, attributes) -> authorizationFailureHandlerProbe.mono());
+
+		assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest)
+				.subscriberContext(this.context).block())
+				.isEqualTo(exception);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		authorizationFailureHandlerProbe.assertWasSubscribed();
+		verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any());
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
+	}
 	@SuppressWarnings("unchecked")
 	@Test
 	public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
@@ -239,6 +465,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 		verify(this.authorizedClientRepository).saveAuthorizedClient(
 				eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange));
 		this.saveAuthorizedClientProbe.assertWasSubscribed();
+		verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any());
 	}
 
 	@Test
@@ -332,6 +559,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 		verify(this.authorizedClientRepository).saveAuthorizedClient(
 				eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange));
 		this.saveAuthorizedClientProbe.assertWasSubscribed();
+		verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any());
 	}
 
 	@Test

+ 333 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java

@@ -0,0 +1,333 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web.reactive.function.client;
+
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.MediaType;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
+import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
+import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
+import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Mono;
+import reactor.util.context.Context;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.HashSet;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
+
+/**
+ * @author Phil Clay
+ */
+public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
+
+	private ReactiveClientRegistrationRepository clientRegistrationRepository;
+	private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
+	private ServerOAuth2AuthorizedClientExchangeFilterFunction authorizedClientFilter;
+	private MockWebServer server;
+	private String serverUrl;
+	private WebClient webClient;
+	private Authentication authentication;
+	private MockServerWebExchange exchange;
+
+	@Before
+	public void setUp() throws Exception {
+		this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
+		final ServerOAuth2AuthorizedClientRepository delegate = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(
+				new InMemoryReactiveOAuth2AuthorizedClientService(this.clientRegistrationRepository));
+		this.authorizedClientRepository = spy(new ServerOAuth2AuthorizedClientRepository() {
+
+			@Override
+			public <T extends OAuth2AuthorizedClient> Mono<T> loadAuthorizedClient(
+					String clientRegistrationId,
+					Authentication principal, ServerWebExchange exchange) {
+				return delegate.loadAuthorizedClient(clientRegistrationId, principal, exchange);
+			}
+
+			@Override
+			public Mono<Void> saveAuthorizedClient(
+					OAuth2AuthorizedClient authorizedClient,
+					Authentication principal, ServerWebExchange exchange) {
+				return delegate.saveAuthorizedClient(authorizedClient, principal, exchange);
+			}
+
+			@Override
+			public Mono<Void> removeAuthorizedClient(
+					String clientRegistrationId,
+					Authentication principal, ServerWebExchange exchange) {
+				return delegate.removeAuthorizedClient(clientRegistrationId, principal, exchange);
+			}
+
+		});
+		this.authorizedClientFilter = new ServerOAuth2AuthorizedClientExchangeFilterFunction(
+				this.clientRegistrationRepository, this.authorizedClientRepository);
+		this.server = new MockWebServer();
+		this.server.start();
+		this.serverUrl = this.server.url("/").toString();
+		this.webClient = WebClient.builder()
+				.filter(this.authorizedClientFilter)
+				.build();
+		this.authentication = new TestingAuthenticationToken("principal", "password");
+		this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/").build()).build();
+	}
+
+	@After
+	public void cleanup() throws Exception {
+		this.server.shutdown();
+	}
+
+	@Test
+	public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() {
+		String accessTokenResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\",\n" +
+				"   \"scope\": \"read write\"\n" +
+				"}\n";
+		String clientResponse = "{\n" +
+				"	\"attribute1\": \"value1\",\n" +
+				"	\"attribute2\": \"value2\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenResponse));
+		this.server.enqueue(jsonResponse(clientResponse));
+
+		ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl).build();
+		when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(Mono.just(clientRegistration));
+
+		this.webClient
+				.get()
+				.uri(this.serverUrl)
+				.attributes(clientRegistrationId(clientRegistration.getRegistrationId()))
+				.retrieve()
+				.bodyToMono(String.class)
+				.subscriberContext(Context.of(ServerWebExchange.class, this.exchange))
+				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication))
+				.block();
+
+		assertThat(this.server.getRequestCount()).isEqualTo(2);
+
+		ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class);
+		verify(this.authorizedClientRepository).saveAuthorizedClient(
+				authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange));
+		assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration);
+	}
+
+	@Test
+	public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() {
+		String accessTokenResponse = "{\n" +
+				"	\"access_token\": \"refreshed-access-token\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\"\n" +
+				"}\n";
+		String clientResponse = "{\n" +
+				"	\"attribute1\": \"value1\",\n" +
+				"	\"attribute2\": \"value2\"\n" +
+				"}\n";
+
+		this.server.enqueue(jsonResponse(accessTokenResponse));
+		this.server.enqueue(jsonResponse(clientResponse));
+
+		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl).build();
+		when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(Mono.just(clientRegistration));
+
+		Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
+		Instant expiresAt = issuedAt.plus(Duration.ofHours(1));
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				"expired-access-token", issuedAt, expiresAt, new HashSet<>(Arrays.asList("read", "write")));
+		OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken();
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				clientRegistration, this.authentication.getName(), accessToken, refreshToken);
+		doReturn(Mono.just(authorizedClient)).when(this.authorizedClientRepository).loadAuthorizedClient(
+				eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.exchange));
+
+		this.webClient
+				.get()
+				.uri(this.serverUrl)
+				.attributes(clientRegistrationId(clientRegistration.getRegistrationId()))
+				.retrieve()
+				.bodyToMono(String.class)
+				.subscriberContext(Context.of(ServerWebExchange.class, this.exchange))
+				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication))
+				.block();
+
+		assertThat(this.server.getRequestCount()).isEqualTo(2);
+
+		ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class);
+		verify(this.authorizedClientRepository).saveAuthorizedClient(
+				authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange));
+		OAuth2AuthorizedClient refreshedAuthorizedClient = authorizedClientCaptor.getValue();
+		assertThat(refreshedAuthorizedClient.getClientRegistration()).isSameAs(clientRegistration);
+		assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token");
+	}
+
+	@Test
+	public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() {
+		String accessTokenResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\",\n" +
+				"   \"scope\": \"read write\"\n" +
+				"}\n";
+		String clientResponse = "{\n" +
+				"	\"attribute1\": \"value1\",\n" +
+				"	\"attribute2\": \"value2\"\n" +
+				"}\n";
+
+		// Client 1
+		this.server.enqueue(jsonResponse(accessTokenResponse));
+		this.server.enqueue(jsonResponse(clientResponse));
+
+		ClientRegistration clientRegistration1 = TestClientRegistrations.clientCredentials()
+				.registrationId("client-1").tokenUri(this.serverUrl).build();
+		when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration1.getRegistrationId()))).thenReturn(Mono.just(clientRegistration1));
+
+		// Client 2
+		this.server.enqueue(jsonResponse(accessTokenResponse));
+		this.server.enqueue(jsonResponse(clientResponse));
+
+		ClientRegistration clientRegistration2 = TestClientRegistrations.clientCredentials()
+				.registrationId("client-2").tokenUri(this.serverUrl).build();
+		when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))).thenReturn(Mono.just(clientRegistration2));
+
+		this.webClient
+				.get()
+				.uri(this.serverUrl)
+				.attributes(clientRegistrationId(clientRegistration1.getRegistrationId()))
+				.retrieve()
+				.bodyToMono(String.class)
+				.flatMap(response -> this.webClient
+						.get()
+						.uri(this.serverUrl)
+						.attributes(clientRegistrationId(clientRegistration2.getRegistrationId()))
+						.retrieve()
+						.bodyToMono(String.class))
+				.subscriberContext(Context.of(ServerWebExchange.class, this.exchange))
+				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication))
+				.block();
+
+		assertThat(this.server.getRequestCount()).isEqualTo(4);
+
+		ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class);
+		verify(this.authorizedClientRepository, times(2)).saveAuthorizedClient(
+				authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange));
+		assertThat(authorizedClientCaptor.getAllValues().get(0).getClientRegistration()).isSameAs(clientRegistration1);
+		assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2);
+	}
+
+	/**
+	 * When a non-expired {@link OAuth2AuthorizedClient} exists
+	 * but the resource server returns 401,
+	 * then remove the {@link OAuth2AuthorizedClient} from the repository.
+	 */
+	@Test
+	public void requestWhenUnauthorizedThenReAuthorize() {
+		String accessTokenResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\",\n" +
+				"   \"scope\": \"read write\"\n" +
+				"}\n";
+		String clientResponse = "{\n" +
+				"	\"attribute1\": \"value1\",\n" +
+				"	\"attribute2\": \"value2\"\n" +
+				"}\n";
+		this.server.enqueue(new MockResponse().setResponseCode(HttpStatus.UNAUTHORIZED.value()));
+		this.server.enqueue(jsonResponse(accessTokenResponse));
+		this.server.enqueue(jsonResponse(clientResponse));
+
+		ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl).build();
+		when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(Mono.just(clientRegistration));
+
+		OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write");
+		OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken();
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				clientRegistration, this.authentication.getName(), accessToken, refreshToken);
+		doReturn(Mono.just(authorizedClient))
+				.doReturn(Mono.empty())
+				.when(this.authorizedClientRepository).loadAuthorizedClient(
+						eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.exchange));
+
+		Mono<String> requestMono = this.webClient
+				.get()
+				.uri(this.serverUrl)
+				.attributes(clientRegistrationId(clientRegistration.getRegistrationId()))
+				.retrieve()
+				.bodyToMono(String.class)
+				.subscriberContext(Context.of(ServerWebExchange.class, this.exchange))
+				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication));
+
+		// first try should fail, and remove the cached authorized client
+		assertThatCode(requestMono::block)
+				.isInstanceOfSatisfying(WebClientResponseException.class, e -> assertThat(e.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED));
+
+		assertThat(this.server.getRequestCount()).isEqualTo(1);
+
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
+		verify(this.authorizedClientRepository).removeAuthorizedClient(
+				eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.exchange));
+
+		// second try should retrieve the authorized client and succeed
+		requestMono.block();
+
+		assertThat(this.server.getRequestCount()).isEqualTo(3);
+
+		ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class);
+		verify(this.authorizedClientRepository).saveAuthorizedClient(
+				authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange));
+		assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration);
+	}
+
+	private MockResponse jsonResponse(String json) {
+		return new MockResponse()
+				.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+				.setBody(json);
+	}
+}

+ 268 - 3
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.springframework.security.oauth2.client.web.reactive.function.client;
 
 import org.junit.Before;
@@ -27,6 +26,7 @@ import org.springframework.core.codec.ByteBufferEncoder;
 import org.springframework.core.codec.CharSequenceEncoder;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
+import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.http.codec.EncoderHttpMessageWriter;
 import org.springframework.http.codec.FormHttpMessageWriter;
@@ -39,11 +39,15 @@ import org.springframework.http.server.reactive.ServerHttpRequest;
 import org.springframework.mock.http.client.reactive.MockClientHttpRequest;
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
@@ -59,6 +63,9 @@ import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2Autho
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -67,11 +74,15 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.util.StringUtils;
 import org.springframework.web.reactive.function.BodyInserter;
 import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.ExchangeFunction;
+import org.springframework.web.reactive.function.client.WebClientResponseException;
 import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
+import reactor.test.publisher.PublisherProbe;
 import reactor.util.context.Context;
 
 import java.net.URI;
+import java.nio.charset.StandardCharsets;
 import java.time.Duration;
 import java.time.Instant;
 import java.util.ArrayList;
@@ -82,8 +93,16 @@ import java.util.Map;
 import java.util.Optional;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.Assertions.entry;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.mockito.Mockito.when;
 import static org.springframework.http.HttpMethod.GET;
 import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
 import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
@@ -109,6 +128,18 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 	@Mock
 	private ReactiveOAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> passwordTokenResponseClient;
 
+	@Mock
+	private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler;
+
+	@Captor
+	private ArgumentCaptor<OAuth2AuthorizationException> authorizationExceptionCaptor;
+
+	@Captor
+	private ArgumentCaptor<Authentication> authenticationCaptor;
+
+	@Captor
+	private ArgumentCaptor<Map<String, Object>> attributesCaptor;
+
 	private ServerWebExchange serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build();
 
 	@Captor
@@ -414,6 +445,240 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		assertThat(getBody(request0)).isEmpty();
 	}
 
+	@Test
+	public void filterWhenUnauthorizedThenInvokeFailureHandler() {
+		function.setAuthorizationFailureHandler(authorizationFailureHandler);
+
+		PublisherProbe<Void> publisherProbe = PublisherProbe.empty();
+		when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono());
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, refreshToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		when(exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.UNAUTHORIZED.value());
+
+		this.function.filter(request, this.exchange)
+				.subscriberContext(serverWebExchange())
+				.block();
+
+		assertThat(publisherProbe.wasSubscribed()).isTrue();
+
+		verify(authorizationFailureHandler).onAuthorizationFailure(
+				authorizationExceptionCaptor.capture(),
+				authenticationCaptor.capture(),
+				attributesCaptor.capture());
+
+		assertThat(authorizationExceptionCaptor.getValue())
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> {
+					assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId());
+					assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token");
+					assertThat(e).hasNoCause();
+					assertThat(e).hasMessageContaining("[invalid_token]");
+				});
+		assertThat(authenticationCaptor.getValue())
+				.isInstanceOf(AnonymousAuthenticationToken.class);
+		assertThat(attributesCaptor.getValue())
+				.containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange));
+	}
+
+	@Test
+	public void filterWhenUnauthorizedWithWebClientExceptionThenInvokeFailureHandler() {
+		function.setAuthorizationFailureHandler(authorizationFailureHandler);
+
+		PublisherProbe<Void> publisherProbe = PublisherProbe.empty();
+		when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono());
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, refreshToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		WebClientResponseException exception = WebClientResponseException.create(
+				HttpStatus.UNAUTHORIZED.value(),
+				HttpStatus.UNAUTHORIZED.getReasonPhrase(),
+				HttpHeaders.EMPTY,
+				new byte[0],
+				StandardCharsets.UTF_8);
+
+		ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception);
+
+		assertThatCode(() -> this.function.filter(request, throwingExchangeFunction)
+				.subscriberContext(serverWebExchange())
+				.block())
+				.isEqualTo(exception);
+
+		assertThat(publisherProbe.wasSubscribed()).isTrue();
+
+		verify(authorizationFailureHandler).onAuthorizationFailure(
+				authorizationExceptionCaptor.capture(),
+				authenticationCaptor.capture(),
+				attributesCaptor.capture());
+
+		assertThat(authorizationExceptionCaptor.getValue())
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> {
+					assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId());
+					assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token");
+					assertThat(e).hasCause(exception);
+					assertThat(e).hasMessageContaining("[invalid_token]");
+				});
+		assertThat(authenticationCaptor.getValue())
+				.isInstanceOf(AnonymousAuthenticationToken.class);
+		assertThat(attributesCaptor.getValue())
+				.containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange));
+	}
+
+	@Test
+	public void filterWhenForbiddenThenInvokeFailureHandler() {
+		function.setAuthorizationFailureHandler(authorizationFailureHandler);
+
+		PublisherProbe<Void> publisherProbe = PublisherProbe.empty();
+		when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono());
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, refreshToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		when(exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.FORBIDDEN.value());
+
+		this.function.filter(request, this.exchange)
+				.subscriberContext(serverWebExchange())
+				.block();
+
+		assertThat(publisherProbe.wasSubscribed()).isTrue();
+
+		verify(authorizationFailureHandler).onAuthorizationFailure(
+				authorizationExceptionCaptor.capture(),
+				authenticationCaptor.capture(),
+				attributesCaptor.capture());
+
+		assertThat(authorizationExceptionCaptor.getValue())
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> {
+					assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId());
+					assertThat(e.getError().getErrorCode()).isEqualTo("insufficient_scope");
+					assertThat(e).hasNoCause();
+					assertThat(e).hasMessageContaining("[insufficient_scope]");
+				});
+		assertThat(authenticationCaptor.getValue())
+				.isInstanceOf(AnonymousAuthenticationToken.class);
+		assertThat(attributesCaptor.getValue())
+				.containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange));
+	}
+
+	@Test
+	public void filterWhenForbiddenWithWebClientExceptionThenInvokeFailureHandler() {
+		function.setAuthorizationFailureHandler(authorizationFailureHandler);
+
+		PublisherProbe<Void> publisherProbe = PublisherProbe.empty();
+		when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono());
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, refreshToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		WebClientResponseException exception = WebClientResponseException.create(
+				HttpStatus.FORBIDDEN.value(),
+				HttpStatus.FORBIDDEN.getReasonPhrase(),
+				HttpHeaders.EMPTY,
+				new byte[0],
+				StandardCharsets.UTF_8);
+
+		ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception);
+
+		assertThatCode(() -> this.function.filter(request, throwingExchangeFunction)
+						.subscriberContext(serverWebExchange())
+						.block())
+				.isEqualTo(exception);
+
+		assertThat(publisherProbe.wasSubscribed()).isTrue();
+
+		verify(authorizationFailureHandler).onAuthorizationFailure(
+				authorizationExceptionCaptor.capture(),
+				authenticationCaptor.capture(),
+				attributesCaptor.capture());
+
+		assertThat(authorizationExceptionCaptor.getValue())
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> {
+					assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId());
+					assertThat(e.getError().getErrorCode()).isEqualTo("insufficient_scope");
+					assertThat(e).hasCause(exception);
+					assertThat(e).hasMessageContaining("[insufficient_scope]");
+				});
+		assertThat(authenticationCaptor.getValue())
+				.isInstanceOf(AnonymousAuthenticationToken.class);
+		assertThat(attributesCaptor.getValue())
+				.containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange));
+	}
+
+	@Test
+	public void filterWhenAuthorizationExceptionThenInvokeFailureHandler() {
+		function.setAuthorizationFailureHandler(authorizationFailureHandler);
+
+		PublisherProbe<Void> publisherProbe = PublisherProbe.empty();
+		when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono());
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, refreshToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		OAuth2AuthorizationException exception = new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null));
+
+		ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception);
+
+		assertThatCode(() -> this.function.filter(request, throwingExchangeFunction)
+						.subscriberContext(serverWebExchange())
+						.block())
+				.isEqualTo(exception);
+
+		assertThat(publisherProbe.wasSubscribed()).isTrue();
+
+		verify(authorizationFailureHandler).onAuthorizationFailure(
+				authorizationExceptionCaptor.capture(),
+				authenticationCaptor.capture(),
+				attributesCaptor.capture());
+
+		assertThat(authorizationExceptionCaptor.getValue())
+				.isSameAs(exception);
+		assertThat(authenticationCaptor.getValue())
+				.isInstanceOf(AnonymousAuthenticationToken.class);
+		assertThat(attributesCaptor.getValue())
+				.containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange));
+	}
+
+	@Test
+	public void filterWhenOtherHttpStatusShouldNotInvokeFailureHandler() {
+		function.setAuthorizationFailureHandler(authorizationFailureHandler);
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, refreshToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		when(exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.BAD_REQUEST.value());
+
+		this.function.filter(request, this.exchange)
+				.subscriberContext(serverWebExchange())
+				.block();
+
+		verify(authorizationFailureHandler, never()).onAuthorizationFailure(any(), any(), any());
+	}
+
 	@Test
 	public void filterWhenPasswordClientNotAuthorizedThenGetNewToken() {
 		TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");

+ 30 - 3
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthorizationException.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +15,8 @@
  */
 package org.springframework.security.oauth2.core;
 
+import org.springframework.util.Assert;
+
 /**
  * Base exception for OAuth 2.0 Authorization errors.
  *
@@ -30,7 +32,19 @@ public class OAuth2AuthorizationException extends RuntimeException {
 	 * @param error the {@link OAuth2Error OAuth 2.0 Error}
 	 */
 	public OAuth2AuthorizationException(OAuth2Error error) {
-		super(error.toString());
+		this(error, error.toString());
+	}
+
+	/**
+	 * Constructs an {@code OAuth2AuthorizationException} using the provided parameters.
+	 *
+	 * @param error the {@link OAuth2Error OAuth 2.0 Error}
+	 * @param message the exception message
+	 * @since 5.3
+	 */
+	public OAuth2AuthorizationException(OAuth2Error error, String message) {
+		super(message);
+		Assert.notNull(error, "error must not be null");
 		this.error = error;
 	}
 
@@ -41,7 +55,20 @@ public class OAuth2AuthorizationException extends RuntimeException {
 	 * @param cause the root cause
 	 */
 	public OAuth2AuthorizationException(OAuth2Error error, Throwable cause) {
-		super(error.toString(), cause);
+		this(error, error.toString(), cause);
+	}
+
+	/**
+	 * Constructs an {@code OAuth2AuthorizationException} using the provided parameters.
+	 *
+	 * @param error the {@link OAuth2Error OAuth 2.0 Error}
+	 * @param message the exception message
+	 * @param cause the root cause
+	 * @since 5.3
+	 */
+	public OAuth2AuthorizationException(OAuth2Error error, String message, Throwable cause) {
+		super(message, cause);
+		Assert.notNull(error, "error must not be null");
 		this.error = error;
 	}
 

+ 22 - 1
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -53,6 +53,27 @@ public interface OAuth2ErrorCodes {
 	 */
 	String INVALID_SCOPE = "invalid_scope";
 
+	/**
+	 * {@code insufficient_scope} - The request requires higher privileges than
+	 *  provided by the access token.
+	 *  The resource server SHOULD respond with the HTTP 403 (Forbidden)
+	 *  status code and MAY include the "scope" attribute with the scope
+	 *  necessary to access the protected resource.
+	 *
+	 * @see <a href="https://tools.ietf.org/html/rfc6750#section-3.1">RFC-6750 - Section 3.1 - Error Codes</a>
+	 */
+	String INSUFFICIENT_SCOPE = "insufficient_scope";
+
+	/**
+	 * {@code invalid_token} - The access token provided is expired, revoked,
+	 * malformed, or invalid for other reasons.
+	 * The resource SHOULD respond with the HTTP 401 (Unauthorized) status code.
+	 * The client MAY request a new access token and retry the protected resource request.
+	 *
+	 * @see <a href="https://tools.ietf.org/html/rfc6750#section-3.1">RFC-6750 - Section 3.1 - Error Codes</a>
+	 */
+	String INVALID_TOKEN = "invalid_token";
+
 	/**
 	 * {@code server_error} - The authorization server encountered an
 	 * unexpected condition that prevented it from fulfilling the request.

+ 14 - 3
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -58,6 +58,10 @@ class OAuth2AccessTokenResponseBodyExtractor
 		ParameterizedTypeReference<Map<String, Object>> type = new ParameterizedTypeReference<Map<String, Object>>() {};
 		BodyExtractor<Mono<Map<String, Object>>, ReactiveHttpInputMessage> delegate = BodyExtractors.toMono(type);
 		return delegate.extract(inputMessage, context)
+				.onErrorMap(e -> new OAuth2AuthorizationException(
+						invalidTokenResponse("An error occurred parsing the Access Token response: " + e.getMessage()), e))
+				.switchIfEmpty(Mono.error(() -> new OAuth2AuthorizationException(
+						invalidTokenResponse("Empty OAuth 2.0 Access Token Response"))))
 				.map(OAuth2AccessTokenResponseBodyExtractor::parse)
 				.flatMap(OAuth2AccessTokenResponseBodyExtractor::oauth2AccessTokenResponse)
 				.map(OAuth2AccessTokenResponseBodyExtractor::oauth2AccessTokenResponse);
@@ -68,12 +72,19 @@ class OAuth2AccessTokenResponseBodyExtractor
 			return TokenResponse.parse(new JSONObject(json));
 		}
 		catch (ParseException pe) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
-					"An error occurred parsing the Access Token response: " + pe.getMessage(), null);
+			OAuth2Error oauth2Error = invalidTokenResponse(
+					"An error occurred parsing the Access Token response: " + pe.getMessage());
 			throw new OAuth2AuthorizationException(oauth2Error, pe);
 		}
 	}
 
+	private static OAuth2Error invalidTokenResponse(String message) {
+		return new OAuth2Error(
+				INVALID_TOKEN_RESPONSE_ERROR_CODE,
+				message,
+				null);
+	}
+
 	private static Mono<AccessTokenResponse> oauth2AccessTokenResponse(TokenResponse tokenResponse) {
 		if (tokenResponse.indicatesSuccess()) {
 			return Mono.just(tokenResponse)

+ 19 - 3
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@ import org.springframework.http.codec.json.Jackson2JsonDecoder;
 import org.springframework.http.server.reactive.ServerHttpResponse;
 import org.springframework.mock.http.client.reactive.MockClientHttpResponse;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.web.reactive.function.BodyExtractor;
 import reactor.core.publisher.Mono;
@@ -92,8 +93,23 @@ public class OAuth2BodyExtractorsTests {
 
 		Mono<OAuth2AccessTokenResponse> result = extractor.extract(response, this.context);
 
-		assertThatCode(() -> result.block())
-				.isInstanceOf(RuntimeException.class);
+		assertThatCode(result::block)
+				.isInstanceOf(OAuth2AuthorizationException.class)
+				.hasMessageContaining("An error occurred parsing the Access Token response");
+	}
+
+	@Test
+	public void oauth2AccessTokenResponseWhenEmptyThenException() {
+		BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = OAuth2BodyExtractors
+				.oauth2AccessTokenResponse();
+
+		MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK);
+
+		Mono<OAuth2AccessTokenResponse> result = extractor.extract(response, this.context);
+
+		assertThatCode(result::block)
+				.isInstanceOf(OAuth2AuthorizationException.class)
+				.hasMessageContaining("Empty OAuth 2.0 Access Token Response");
 	}
 
 	@Test