Jelajahi Sumber

WebFlux oauth2Login() redirects on failed authentication

Fixes gh-5562 gh-6484
Joe Grandja 5 tahun lalu
induk
melakukan
c40a17b4d1

+ 10 - 2
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -986,7 +986,7 @@ public class ServerHttpSecurity {
 
 		private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
 
-		private ServerAuthenticationFailureHandler authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception);
+		private ServerAuthenticationFailureHandler authenticationFailureHandler;
 
 		/**
 		 * Configures the {@link ReactiveAuthenticationManager} to use. The default is
@@ -1028,6 +1028,7 @@ public class ServerHttpSecurity {
 
 		/**
 		 * The {@link ServerAuthenticationFailureHandler} used after authentication failure.
+		 * Defaults to {@link RedirectServerAuthenticationFailureHandler} redirecting to "/login?error".
 		 *
 		 * @since 5.2
 		 * @param authenticationFailureHandler the failure handler to use
@@ -1175,7 +1176,7 @@ public class ServerHttpSecurity {
 			authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
 
 			authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler);
-			authenticationFilter.setAuthenticationFailureHandler(this.authenticationFailureHandler);
+			authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler());
 			authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
 
 			MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher(
@@ -1192,6 +1193,13 @@ public class ServerHttpSecurity {
 			http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
 		}
 
+		private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() {
+			if (this.authenticationFailureHandler == null) {
+				this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error");
+			}
+			return this.authenticationFailureHandler;
+		}
+
 		private ServerWebExchangeMatcher createAttemptAuthenticationRequestMatcher() {
 			return new PathPatternParserServerWebExchangeMatcher("/login/oauth2/code/{registrationId}");
 		}

+ 80 - 0
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -70,6 +70,7 @@ import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.core.user.TestOAuth2Users;
 import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtValidationException;
 import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
 import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
@@ -518,6 +519,85 @@ public class OAuth2LoginTests {
 		verify(securityContextRepository).save(any(), any());
 	}
 
+	// gh-5562
+	@Test
+	public void oauth2LoginWhenAccessTokenRequestFailsThenDefaultRedirectToLogin() {
+		this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class,
+				OAuth2LoginWithCustomBeansConfig.class).autowire();
+
+		WebTestClient webTestClient = WebTestClientBuilder
+				.bindToWebFilters(this.springSecurity)
+				.build();
+
+		OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().scope("openid").build();
+		OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build();
+		OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response);
+		OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid");
+		OAuth2AuthorizationCodeAuthenticationToken authenticationToken = new OAuth2AuthorizationCodeAuthenticationToken(google, exchange, accessToken);
+
+		OAuth2LoginWithCustomBeansConfig config = this.spring.getContext().getBean(OAuth2LoginWithCustomBeansConfig.class);
+
+		ServerAuthenticationConverter converter = config.authenticationConverter;
+		when(converter.convert(any())).thenReturn(Mono.just(authenticationToken));
+
+		ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> tokenResponseClient = config.tokenResponseClient;
+		OAuth2Error oauth2Error = new OAuth2Error("invalid_request", "Invalid request", null);
+		when(tokenResponseClient.getTokenResponse(any())).thenThrow(new OAuth2AuthenticationException(oauth2Error));
+
+		webTestClient.get()
+				.uri("/login/oauth2/code/google")
+				.exchange()
+				.expectStatus()
+					.is3xxRedirection()
+				.expectHeader()
+					.valueEquals("Location", "/login?error");
+	}
+
+	// gh-6484
+	@Test
+	public void oauth2LoginWhenIdTokenValidationFailsThenDefaultRedirectToLogin() {
+		this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class,
+				OAuth2LoginWithCustomBeansConfig.class).autowire();
+
+		WebTestClient webTestClient = WebTestClientBuilder
+				.bindToWebFilters(this.springSecurity)
+				.build();
+
+		OAuth2LoginWithCustomBeansConfig config = this.spring.getContext().getBean(OAuth2LoginWithCustomBeansConfig.class);
+
+		OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().scope("openid").build();
+		OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build();
+		OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response);
+		OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("openid");
+		OAuth2AuthorizationCodeAuthenticationToken authenticationToken = new OAuth2AuthorizationCodeAuthenticationToken(google, exchange, accessToken);
+
+		ServerAuthenticationConverter converter = config.authenticationConverter;
+		when(converter.convert(any())).thenReturn(Mono.just(authenticationToken));
+
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
+				.tokenType(accessToken.getTokenType())
+				.scopes(accessToken.getScopes())
+				.additionalParameters(additionalParameters)
+				.build();
+		ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> tokenResponseClient = config.tokenResponseClient;
+		when(tokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
+
+		ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory = config.jwtDecoderFactory;
+		OAuth2Error oauth2Error = new OAuth2Error("invalid_id_token", "Invalid ID Token", null);
+		when(jwtDecoderFactory.createDecoder(any())).thenReturn(token ->
+				Mono.error(new JwtValidationException("ID Token validation failed", Collections.singleton(oauth2Error))));
+
+		webTestClient.get()
+				.uri("/login/oauth2/code/google")
+				.exchange()
+				.expectStatus()
+					.is3xxRedirection()
+				.expectHeader()
+					.valueEquals("Location", "/login?error");
+	}
+
 	@Configuration
 	static class OAuth2LoginWithCustomBeansConfig {