瀏覽代碼

Resolve OAuth2Error from WWW-Authenticate header

Issue gh-7699
Joe Grandja 5 年之前
父節點
當前提交
f2da2c56be

+ 72 - 38
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -16,8 +16,8 @@
 
 package org.springframework.security.oauth2.client.web.reactive.function.client;
 
+import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpStatus;
-import org.springframework.lang.Nullable;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
@@ -34,20 +34,22 @@ import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClient
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
 import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizedClientProvider;
-import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler;
-import org.springframework.security.oauth2.client.web.SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
+import org.springframework.security.oauth2.client.web.RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler;
+import org.springframework.security.oauth2.client.web.SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
 import org.springframework.web.reactive.function.client.ClientRequest;
 import org.springframework.web.reactive.function.client.ClientResponse;
 import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
@@ -62,6 +64,8 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Optional;
 import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
  * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the
@@ -614,32 +618,84 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 		}
 
 		@Override
-		public Mono<ClientResponse> handleResponse(
-				ClientRequest request,
-				Mono<ClientResponse> responseMono) {
-
+		public Mono<ClientResponse> handleResponse(ClientRequest request, Mono<ClientResponse> responseMono) {
 			return responseMono
-				.flatMap(response -> handleHttpStatus(request, response.rawStatusCode(), null)
+				.flatMap(response -> handleResponse(request, response)
 						.thenReturn(response))
-				.onErrorResume(WebClientResponseException.class, e -> handleHttpStatus(request, e.getRawStatusCode(), e)
+				.onErrorResume(WebClientResponseException.class, e -> handleWebClientResponseException(request, e)
 						.then(Mono.error(e)))
 				.onErrorResume(OAuth2AuthorizationException.class, e -> handleAuthorizationException(request, e)
 						.then(Mono.error(e)));
 		}
 
+		private Mono<Void> handleResponse(ClientRequest request, ClientResponse response) {
+			return Mono.justOrEmpty(resolveErrorIfPossible(response))
+					.flatMap(oauth2Error -> {
+						Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request);
+
+						Mono<String> clientRegistrationId = effectiveClientRegistrationId(request);
+
+						return Mono.zip(currentAuthenticationMono, serverWebExchange, clientRegistrationId)
+								.flatMap(tuple3 -> handleAuthorizationFailure(
+										tuple3.getT1(),              // Authentication principal
+										tuple3.getT2().orElse(null), // ServerWebExchange exchange
+										new ClientAuthorizationException(
+												oauth2Error,
+												tuple3.getT3())));      // String clientRegistrationId
+					});
+		}
+
+		private OAuth2Error resolveErrorIfPossible(ClientResponse response) {
+			// Try to resolve from 'WWW-Authenticate' header
+			if (!response.headers().header(HttpHeaders.WWW_AUTHENTICATE).isEmpty()) {
+				String wwwAuthenticateHeader = response.headers().header(HttpHeaders.WWW_AUTHENTICATE).get(0);
+				Map<String, String> authParameters = parseAuthParameters(wwwAuthenticateHeader);
+				if (authParameters.containsKey(OAuth2ParameterNames.ERROR)) {
+					return new OAuth2Error(
+							authParameters.get(OAuth2ParameterNames.ERROR),
+							authParameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION),
+							authParameters.get(OAuth2ParameterNames.ERROR_URI));
+				}
+			}
+			return resolveErrorIfPossible(response.rawStatusCode());
+		}
+
+		private OAuth2Error resolveErrorIfPossible(int statusCode) {
+			if (this.httpStatusToOAuth2ErrorCodeMap.containsKey(statusCode)) {
+				return new OAuth2Error(
+						this.httpStatusToOAuth2ErrorCodeMap.get(statusCode),
+						null,
+						"https://tools.ietf.org/html/rfc6750#section-3.1");
+			}
+			return null;
+		}
+
+		private Map<String, String> parseAuthParameters(String wwwAuthenticateHeader) {
+			return Stream.of(wwwAuthenticateHeader)
+					.filter(header -> !StringUtils.isEmpty(header))
+					.filter(header -> header.toLowerCase().startsWith("bearer"))
+					.map(header -> header.substring("bearer".length()))
+					.map(header -> header.split(","))
+					.flatMap(Stream::of)
+					.map(parameter -> parameter.split("="))
+					.filter(parameter -> parameter.length > 1)
+					.collect(Collectors.toMap(
+							parameters -> parameters[0].trim(),
+							parameters -> parameters[1].trim().replace("\"", "")));
+		}
+
 		/**
 		 * Handles the given http status code returned from a resource server
 		 * by notifying the authorization failure handler if the http status
 		 * code is in the {@link #httpStatusToOAuth2ErrorCodeMap}.
 		 *
 		 * @param request the request being processed
-		 * @param httpStatusCode the http status returned by the resource server
-		 * @param exception The root cause exception for the failure (nullable)
+		 * @param exception The root cause exception for the failure
 		 * @return a {@link Mono} that completes empty after the authorization failure handler completes.
 		 */
-		private Mono<Void> handleHttpStatus(ClientRequest request, int httpStatusCode, @Nullable Exception exception) {
-			return Mono.justOrEmpty(this.httpStatusToOAuth2ErrorCodeMap.get(httpStatusCode))
-					.flatMap(oauth2ErrorCode -> {
+		private Mono<Void> handleWebClientResponseException(ClientRequest request, WebClientResponseException exception) {
+			return Mono.justOrEmpty(resolveErrorIfPossible(exception.getRawStatusCode()))
+					.flatMap(oauth2Error -> {
 						Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request);
 
 						Mono<String> clientRegistrationId = effectiveClientRegistrationId(request);
@@ -648,9 +704,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 								.flatMap(tuple3 -> handleAuthorizationFailure(
 										tuple3.getT1(),              // Authentication principal
 										tuple3.getT2().orElse(null), // ServerWebExchange exchange
-										createAuthorizationException(
+										new ClientAuthorizationException(
+												oauth2Error,
 												tuple3.getT3(),      // String clientRegistrationId
-												oauth2ErrorCode,
 												exception)));
 					});
 		}
@@ -673,28 +729,6 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 							exception));
 		}
 
-		/**
-		 * Creates an authorization exception using the given parameters.
-		 *
-		 * @param clientRegistrationId the client registration id of the client that failed authentication/authorization.
-		 * @param oauth2ErrorCode the OAuth 2.0 error code to use in the authorization failure event
-		 * @param exception The root cause exception for the failure (nullable)
-		 * @return an authorization exception using the given parameters.
-		 */
-		private ClientAuthorizationException createAuthorizationException(
-				String clientRegistrationId,
-				String oauth2ErrorCode,
-				@Nullable Exception exception) {
-			return new ClientAuthorizationException(
-					new OAuth2Error(
-							oauth2ErrorCode,
-							null,
-							"https://tools.ietf.org/html/rfc6750#section-3.1"),
-					clientRegistrationId,
-					exception);
-		}
-
-
 		/**
 		 * Delegates to the authorization failure handler of the failed authorization.
 		 *

+ 51 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -74,6 +74,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.util.StringUtils;
 import org.springframework.web.reactive.function.BodyInserter;
 import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.ClientResponse;
 import org.springframework.web.reactive.function.client.ExchangeFunction;
 import org.springframework.web.reactive.function.client.WebClientResponseException;
 import org.springframework.web.server.ServerWebExchange;
@@ -98,6 +99,7 @@ import static org.assertj.core.api.Assertions.entry;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
@@ -173,6 +175,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		this.authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
 		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager);
 		when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
+		when(this.exchange.getResponse().headers()).thenReturn(mock(ClientResponse.Headers.class));
 	}
 
 	@Test
@@ -621,6 +624,54 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 				.containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange));
 	}
 
+	@Test
+	public void filterWhenWWWAuthenticateHeaderIncludesErrorThenInvokeFailureHandler() {
+		function.setAuthorizationFailureHandler(authorizationFailureHandler);
+
+		PublisherProbe<Void> publisherProbe = PublisherProbe.empty();
+		when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono());
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, refreshToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		String wwwAuthenticateHeader = "Bearer error=\"insufficient_scope\", " +
+				"error_description=\"The request requires higher privileges than provided by the access token.\", " +
+				"error_uri=\"https://tools.ietf.org/html/rfc6750#section-3.1\"";
+		ClientResponse.Headers headers = mock(ClientResponse.Headers.class);
+		when(headers.header(eq(HttpHeaders.WWW_AUTHENTICATE)))
+				.thenReturn(Collections.singletonList(wwwAuthenticateHeader));
+		when(this.exchange.getResponse().headers()).thenReturn(headers);
+
+		this.function.filter(request, this.exchange)
+				.subscriberContext(serverWebExchange())
+				.block();
+
+		assertThat(publisherProbe.wasSubscribed()).isTrue();
+
+		verify(authorizationFailureHandler).onAuthorizationFailure(
+				authorizationExceptionCaptor.capture(),
+				authenticationCaptor.capture(),
+				attributesCaptor.capture());
+
+		assertThat(authorizationExceptionCaptor.getValue())
+				.isInstanceOfSatisfying(ClientAuthorizationException.class, e -> {
+					assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId());
+					assertThat(e.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
+					assertThat(e.getError().getDescription()).isEqualTo("The request requires higher privileges than provided by the access token.");
+					assertThat(e.getError().getUri()).isEqualTo("https://tools.ietf.org/html/rfc6750#section-3.1");
+					assertThat(e).hasNoCause();
+					assertThat(e).hasMessageContaining(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
+				});
+		assertThat(authenticationCaptor.getValue())
+				.isInstanceOf(AnonymousAuthenticationToken.class);
+		assertThat(attributesCaptor.getValue())
+				.containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange));
+	}
+
 	@Test
 	public void filterWhenAuthorizationExceptionThenInvokeFailureHandler() {
 		function.setAuthorizationFailureHandler(authorizationFailureHandler);