Procházet zdrojové kódy

OAuth2AuthorizationCodeGrantWebFilter should handle OAuth2AuthorizationException

Fixes gh-8609
Joe Grandja před 5 roky
rodič
revize
a372ec9ef5

+ 17 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -20,7 +20,11 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.util.Assert;
 
 /**
@@ -40,6 +44,7 @@ import org.springframework.util.Assert;
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response</a>
  */
 public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
+	private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter";
 	private final OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
 
 	/**
@@ -59,8 +64,18 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 		OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
 			(OAuth2AuthorizationCodeAuthenticationToken) authentication;
 
-		OAuth2AuthorizationExchangeValidator.validate(
-			authorizationCodeAuthentication.getAuthorizationExchange());
+		OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication
+				.getAuthorizationExchange().getAuthorizationResponse();
+		if (authorizationResponse.statusError()) {
+			throw new OAuth2AuthorizationException(authorizationResponse.getError());
+		}
+
+		OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication
+				.getAuthorizationExchange().getAuthorizationRequest();
+		if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
+			OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
+			throw new OAuth2AuthorizationException(oauth2Error);
+		}
 
 		OAuth2AccessTokenResponse accessTokenResponse =
 			this.accessTokenResponseClient.getTokenResponse(

+ 17 - 4
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,9 +22,13 @@ import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessT
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.util.Assert;
 import reactor.core.publisher.Mono;
@@ -55,8 +59,8 @@ import java.util.function.Function;
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response</a>
  */
-public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements
-		ReactiveAuthenticationManager {
+public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements ReactiveAuthenticationManager {
+	private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter";
 	private final ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
 
 	public OAuth2AuthorizationCodeReactiveAuthenticationManager(
@@ -70,7 +74,16 @@ public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements
 		return Mono.defer(() -> {
 			OAuth2AuthorizationCodeAuthenticationToken token = (OAuth2AuthorizationCodeAuthenticationToken) authentication;
 
-			OAuth2AuthorizationExchangeValidator.validate(token.getAuthorizationExchange());
+			OAuth2AuthorizationResponse authorizationResponse = token.getAuthorizationExchange().getAuthorizationResponse();
+			if (authorizationResponse.statusError()) {
+				return Mono.error(new OAuth2AuthorizationException(authorizationResponse.getError()));
+			}
+
+			OAuth2AuthorizationRequest authorizationRequest = token.getAuthorizationExchange().getAuthorizationRequest();
+			if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
+				OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
+				return Mono.error(new OAuth2AuthorizationException(oauth2Error));
+			}
 
 			OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest(
 					token.getClientRegistration(),

+ 0 - 47
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationExchangeValidator.java

@@ -1,47 +0,0 @@
-/*
- * Copyright 2002-2019 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.client.authentication;
-
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
-import org.springframework.security.oauth2.core.OAuth2Error;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
-
-/**
- * A validator for an &quot;exchange&quot; of an OAuth 2.0 Authorization Request and Response.
- *
- * @author Joe Grandja
- * @since 5.1
- * @see OAuth2AuthorizationExchange
- */
-final class OAuth2AuthorizationExchangeValidator {
-	private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter";
-
-	static void validate(OAuth2AuthorizationExchange authorizationExchange) {
-		OAuth2AuthorizationRequest authorizationRequest = authorizationExchange.getAuthorizationRequest();
-		OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse();
-
-		if (authorizationResponse.statusError()) {
-			throw new OAuth2AuthorizationException(authorizationResponse.getError());
-		}
-
-		if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
-			throw new OAuth2AuthorizationException(oauth2Error);
-		}
-	}
-}

+ 12 - 4
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java

@@ -27,6 +27,8 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -146,15 +148,21 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 		return this.requiresAuthenticationMatcher.matches(exchange)
 				.filter(ServerWebExchangeMatcher.MatchResult::isMatch)
-				.flatMap(matchResult -> this.authenticationConverter.convert(exchange))
+				.flatMap(matchResult ->
+						this.authenticationConverter.convert(exchange)
+								.onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(
+										e.getError(), e.getError().toString())))
 				.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
-				.flatMap(token -> authenticate(exchange, chain, token));
+				.flatMap(token -> authenticate(exchange, chain, token))
+				.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
+						.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e));
 	}
 
-	private Mono<Void> authenticate(ServerWebExchange exchange,
-			WebFilterChain chain, Authentication token) {
+	private Mono<Void> authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) {
 		WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
 		return this.authenticationManager.authenticate(token)
+				.onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(
+						e.getError(), e.getError().toString()))
 				.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass()))))
 				.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange))
 				.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler

+ 1 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java

@@ -18,7 +18,6 @@ package org.springframework.security.oauth2.client.web.server;
 
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
-import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -33,7 +32,7 @@ import org.springframework.web.util.UriComponentsBuilder;
 import reactor.core.publisher.Mono;
 
 /**
- * Converts from a {@link ServerWebExchange} to an {@link OAuth2LoginAuthenticationToken} that can be authenticated. The
+ * Converts from a {@link ServerWebExchange} to an {@link OAuth2AuthorizationCodeAuthenticationToken} that can be authenticated. The
  * converter does not validate any errors it only performs a conversion.
  * @author Rob Winch
  * @since 5.1

+ 54 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java

@@ -29,6 +29,9 @@ import org.springframework.security.oauth2.client.authentication.TestOAuth2Autho
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.CollectionUtils;
@@ -41,6 +44,7 @@ import java.util.LinkedHashMap;
 import java.util.Map;
 
 import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -226,6 +230,56 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
 		verifyZeroInteractions(this.authenticationManager);
 	}
 
+	// gh-8609
+	@Test
+	public void filterWhenAuthenticationConverterThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() {
+		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
+		when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.empty());
+
+		MockServerHttpRequest authorizationRequest =
+				createAuthorizationRequest("/authorization/callback");
+		OAuth2AuthorizationRequest oauth2AuthorizationRequest =
+				createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
+		MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
+		MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
+		DefaultWebFilterChain chain = new DefaultWebFilterChain(
+				e -> e.getResponse().setComplete(), Collections.emptyList());
+
+		this.authorizationRequestRepository.saveAuthorizationRequest(oauth2AuthorizationRequest, exchange).block();
+
+		assertThatThrownBy(() -> this.filter.filter(exchange, chain).block())
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("client_registration_not_found");
+		verifyZeroInteractions(this.authenticationManager);
+	}
+
+	// gh-8609
+	@Test
+	public void filterWhenAuthenticationManagerThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() {
+		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
+		when(this.clientRegistrationRepository.findByRegistrationId(any()))
+				.thenReturn(Mono.just(clientRegistration));
+
+		MockServerHttpRequest authorizationRequest =
+				createAuthorizationRequest("/authorization/callback");
+		OAuth2AuthorizationRequest oauth2AuthorizationRequest =
+				createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
+
+		when(this.authenticationManager.authenticate(any()))
+				.thenReturn(Mono.error(new OAuth2AuthorizationException(new OAuth2Error("authorization_error"))));
+
+		MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
+		MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
+		DefaultWebFilterChain chain = new DefaultWebFilterChain(
+				e -> e.getResponse().setComplete(), Collections.emptyList());
+
+		this.authorizationRequestRepository.saveAuthorizationRequest(oauth2AuthorizationRequest, exchange).block();
+
+		assertThatThrownBy(() -> this.filter.filter(exchange, chain).block())
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("authorization_error");
+	}
+
 	private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(
 			MockServerHttpRequest authorizationRequest, ClientRegistration registration) {
 		Map<String, Object> additionalParameters = new HashMap<>();