Преглед изворни кода

ServerOAuth2AuthorizedClientExchangeFilterFunction defaultOAuth2AuthorizedClient

Defaults to use the OAuth2AuthenticationToken to resolve the authorized client

Issue: gh-4921
Rob Winch пре 7 година
родитељ
комит
ac78258847

+ 31 - 6
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -27,6 +27,7 @@ import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
@@ -86,6 +87,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 
 	private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
 
+	private boolean defaultOAuth2AuthorizedClient;
+
 	private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
 			new WebClientReactiveClientCredentialsTokenResponseClient();
 
@@ -174,6 +177,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 		return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
 	}
 
+	/**
+	 * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
+	 * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
+	 * resolved from the current Authentication.
+	 * @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
+	 *                                      Default is false.
+	 */
+	public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
+		this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
+	}
+
 	/**
 	 * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
 	 * client_credentials grant.
@@ -216,14 +230,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 		if (this.authorizedClientRepository == null) {
 			return Mono.empty();
 		}
-		String clientRegistrationId = (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
-		if (clientRegistrationId == null) {
-			return Mono.empty();
-		}
+
 		ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
 		return currentAuthentication()
-			.flatMap(principal -> loadAuthorizedClient(clientRegistrationId, exchange, principal)
-		);
+			.flatMap(principal -> clientRegistrationId(request, principal)
+					.flatMap(clientRegistrationId -> loadAuthorizedClient(clientRegistrationId, exchange, principal))
+			);
+	}
+
+	private Mono<String> clientRegistrationId(ClientRequest request, Authentication authentication) {
+		return Mono.justOrEmpty(request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME))
+				.cast(String.class)
+				.switchIfEmpty(clientRegistrationId(authentication));
+	}
+
+	private Mono<String> clientRegistrationId(Authentication authentication) {
+		return Mono.justOrEmpty(authentication)
+				.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
+				.cast(OAuth2AuthenticationToken.class)
+				.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
 	}
 
 	private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,

+ 59 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -34,8 +34,10 @@ import org.springframework.http.codec.multipart.MultipartHttpMessageWriter;
 import org.springframework.http.server.reactive.ServerHttpRequest;
 import org.springframework.mock.http.client.reactive.MockClientHttpRequest;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
@@ -43,6 +45,8 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
+import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.web.reactive.function.BodyInserter;
 import org.springframework.web.reactive.function.client.ClientRequest;
 import reactor.core.publisher.Mono;
@@ -51,6 +55,7 @@ import java.net.URI;
 import java.time.Duration;
 import java.time.Instant;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -60,6 +65,7 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 import static org.springframework.http.HttpMethod.GET;
 import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
@@ -293,6 +299,59 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		assertThat(getBody(request0)).isEmpty();
 	}
 
+	@Test
+	public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClientResolved() {
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
+		this.function.setDefaultOAuth2AuthorizedClient(true);
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, refreshToken);
+		when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
+		when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.build();
+
+		OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections
+				.singletonMap("user", "rob"), "user");
+		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id");
+		this.function
+				.filter(request, this.exchange)
+				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
+				.block();
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+
+		ClientRequest request0 = requests.get(0);
+		assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
+		assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request0.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request0)).isEmpty();
+	}
+
+	@Test
+	public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() {
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
+
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.build();
+
+		OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections
+				.singletonMap("user", "rob"), "user");
+		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id");
+
+		this.function
+				.filter(request, this.exchange)
+				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
+				.block();
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+
+		verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository);
+	}
+
 	private static String getBody(ClientRequest request) {
 		final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
 		messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));