浏览代码

OAuth2LoginAuthenticationWebFilter should handle OAuth2AuthorizationException

Issue gh-8609
Joe Grandja 5 年之前
父节点
当前提交
e146a7c16b

+ 8 - 1
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -34,6 +34,8 @@ import java.util.Map;
 import java.util.Optional;
 import java.util.function.Function;
 
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import reactor.core.publisher.Mono;
 import reactor.util.context.Context;
 
@@ -578,7 +580,12 @@ public class ServerHttpSecurity {
 
 		private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) {
 			if (this.authenticationConverter == null) {
-				this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
+				ServerOAuth2AuthorizationCodeAuthenticationTokenConverter delegate =
+						new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
+				ServerAuthenticationConverter authenticationConverter = exchange ->
+						delegate.convert(exchange).onErrorMap(OAuth2AuthorizationException.class,
+								e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()));
+				this.authenticationConverter = authenticationConverter;
 			}
 			return this.authenticationConverter;
 		}

+ 24 - 7
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -16,12 +16,6 @@
 
 package org.springframework.security.config.web.server;
 
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
 import org.junit.Rule;
 import org.junit.Test;
 import org.openqa.selenium.WebDriver;
@@ -67,13 +61,18 @@ import org.springframework.web.reactive.config.EnableWebFlux;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
-
 import org.springframework.web.server.WebHandler;
 import reactor.core.publisher.Mono;
 
 import java.time.Duration;
 import java.time.Instant;
 
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
 /**
  * @author Rob Winch
  * @since 5.1
@@ -301,6 +300,24 @@ public class OAuth2LoginTests {
 		}
 	}
 
+	// gh-8609
+	@Test
+	public void oauth2LoginWhenAuthenticationConverterFailsThenDefaultRedirectToLogin() {
+		this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class).autowire();
+
+		WebTestClient webTestClient = WebTestClientBuilder
+				.bindToWebFilters(this.springSecurity)
+				.build();
+
+		webTestClient.get()
+				.uri("/login/oauth2/code/google")
+				.exchange()
+				.expectStatus()
+				.is3xxRedirection()
+				.expectHeader()
+				.valueEquals("Location", "/login?error");
+	}
+
 	static class GitHubWebFilter implements WebFilter {
 
 		@Override

+ 6 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -117,13 +117,14 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 					.getAuthorizationExchange().getAuthorizationResponse();
 
 			if (authorizationResponse.statusError()) {
-				throw new OAuth2AuthenticationException(
-						authorizationResponse.getError(), authorizationResponse.getError().toString());
+				return Mono.error(new OAuth2AuthenticationException(
+						authorizationResponse.getError(), authorizationResponse.getError().toString()));
 			}
 
 			if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
 				OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
-				throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+				return Mono.error(new OAuth2AuthenticationException(
+						oauth2Error, oauth2Error.toString()));
 			}
 
 			OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest(
@@ -156,7 +157,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 					INVALID_ID_TOKEN_ERROR_CODE,
 					"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
 					null);
-			throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString());
+			return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()));
 		}
 
 		return createOidcToken(clientRegistration, accessTokenResponse)

+ 6 - 7
web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -89,17 +89,16 @@ public class AuthenticationWebFilter implements WebFilter {
 			.filter( matchResult -> matchResult.isMatch())
 			.flatMap( matchResult -> this.authenticationConverter.convert(exchange))
 			.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
-			.flatMap( token -> authenticate(exchange, chain, token));
+			.flatMap( token -> authenticate(exchange, chain, token))
+			.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
+					.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e));
 	}
 
-	private Mono<Void> authenticate(ServerWebExchange exchange,
-		WebFilterChain chain, Authentication token) {
+	private Mono<Void> authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) {
 		WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
 		return this.authenticationManager.authenticate(token)
 			.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass()))))
-			.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange))
-			.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
-				.onAuthenticationFailure(webFilterExchange, e));
+			.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange));
 	}
 
 	protected Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {