Browse Source

Add defaultOAuth2AuthorizedClient flag

Fixes: gh-5619
Rob Winch 7 years ago
parent
commit
1a65abd781

+ 18 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -109,12 +109,25 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 
 	private OAuth2AuthorizedClientRepository authorizedClientRepository;
 
+	private boolean defaultOAuth2AuthorizedClient;
+
 	public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
 
 	public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientRepository authorizedClientRepository) {
 		this.authorizedClientRepository = authorizedClientRepository;
 	}
 
+	/**
+	 * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
+	 * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
+	 * resolved from the current Authentication.
+	 * @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
+	 *                                      Default is false.
+	 */
+	public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
+		this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
+	}
+
 	/**
 	 * Configures the builder with {@link #defaultRequest()} and adds this as a {@link ExchangeFilterFunction}
 	 * @return the {@link Consumer} to configure the builder
@@ -251,13 +264,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 	}
 
 	private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
-		if (this.authorizedClientRepository == null || attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) {
+		if (this.authorizedClientRepository == null
+				|| attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) {
 			return;
 		}
 
 		Authentication authentication = getAuthentication(attrs);
 		String clientRegistrationId = getClientRegistrationId(attrs);
-		if (clientRegistrationId == null  && authentication instanceof OAuth2AuthenticationToken) {
+		if (clientRegistrationId == null
+				&& this.defaultOAuth2AuthorizedClient
+				&& authentication instanceof OAuth2AuthenticationToken) {
 			clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
 		}
 		if (clientRegistrationId != null) {

+ 15 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -207,8 +207,9 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 	}
 
 	@Test
-	public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
+	public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
 		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function.setDefaultOAuth2AuthorizedClient(true);
 		OAuth2User user = mock(OAuth2User.class);
 		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
 		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
@@ -223,6 +224,19 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		verify(this.authorizedClientRepository).loadAuthorizedClient(eq(token.getAuthorizedClientRegistrationId()), any(), any());
 	}
 
+	@Test
+	public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		OAuth2User user = mock(OAuth2User.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
+		authentication(token).accept(this.result);
+
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+
+		assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
+	}
+
 	@Test
 	public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() {
 		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);