浏览代码

OAuth2AuthorizationRequestRedirectWebFilter handles ClientAuthorizationRequiredException

Fixes: gh-5383
Rob Winch 7 年之前
父节点
当前提交
2658577396

+ 1 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilter.java

@@ -136,6 +136,7 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter {
 			.map(ServerWebExchangeMatcher.MatchResult::getVariables)
 			.map(variables -> variables.get(REGISTRATION_ID_URI_VARIABLE_NAME))
 			.cast(String.class)
+			.onErrorResume(ClientAuthorizationRequiredException.class, e -> Mono.just(e.getClientRegistrationId()))
 			.flatMap(clientRegistrationId -> this.findByRegistrationId(exchange, clientRegistrationId))
 			.flatMap(clientRegistration -> sendRedirectForAuthorization(exchange, clientRegistration));
 	}

+ 12 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilterTests.java

@@ -21,6 +21,7 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
+import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
@@ -133,4 +134,15 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests {
 		});
 		verify(this.authzRequestRepository).saveAuthorizationRequest(any(), any());
 	}
+
+	@Test
+	public void filterWhenExceptionThenRedirected() {
+		FilteringWebHandler webHandler = new FilteringWebHandler(e -> Mono.error(new ClientAuthorizationRequiredException(this.github.getRegistrationId())), Arrays.asList(this.filter));
+		this.client = WebTestClient.bindToWebHandler(webHandler).build();
+		FluxExchangeResult<String> result = this.client.get()
+				.uri("https://example.com/foo").exchange()
+				.expectStatus()
+				.is3xxRedirection()
+				.returnResult(String.class);
+	}
 }