ソースを参照

ServerOAuth2AuthorizedClientExchangeFilterFunction clientRegistrationId

You can now provide the clientRegistrationId and
ServerOAuth2AuthorizedClientExchangeFilterFunction will look up the authorized client automatically.

Issue: gh-4921
Rob Winch 7 年 前
コミット
89f2874bff

+ 62 - 9
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -19,14 +19,19 @@ package org.springframework.security.oauth2.client.web.reactive.function.client;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
+import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.OAuth2ClientException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.util.Assert;
 import org.springframework.web.reactive.function.BodyInserters;
@@ -61,10 +66,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	 */
 	private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
 
+	/**
+	 * The client request attribute name used to locate the {@link ClientRegistration#getRegistrationId()}
+	 */
+	private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID");
+
 	/**
 	 * The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
 	 */
 	private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
+	public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
+			AuthorityUtils.createAuthorityList("ROLE_USER"));
 
 	private Clock clock = Clock.systemUTC();
 
@@ -74,8 +86,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 
 	public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
 
-	public ServerOAuth2AuthorizedClientExchangeFilterFunction(
-			ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+	public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
 		this.authorizedClientRepository = authorizedClientRepository;
 	}
 
@@ -141,6 +152,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 		return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
 	}
 
+	/**
+	 * Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to
+	 * be used to look up the {@link OAuth2AuthorizedClient}.
+	 *
+	 * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to
+	 * be used to look up the {@link OAuth2AuthorizedClient}.
+	 * @return the {@link Consumer} to populate the attributes
+	 */
+	public static Consumer<Map<String, Object>> clientRegistrationId(String clientRegistrationId) {
+		return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
+	}
+
 	/**
 	 * An access token will be considered expired by comparing its expiration to now +
 	 * this skewed Duration. The default is 1 minute.
@@ -153,17 +176,42 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 
 	@Override
 	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
-		Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
-				.map(OAuth2AuthorizedClient.class::cast);
 		ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
-		return Mono.justOrEmpty(attribute)
-				.flatMap(authorizedClient -> authorizedClient(next, authorizedClient, exchange))
+		return authorizedClient(request)
+				.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, exchange))
 				.map(authorizedClient -> bearer(request, authorizedClient))
 				.flatMap(next::exchange)
 				.switchIfEmpty(next.exchange(request));
 	}
 
-	private Mono<OAuth2AuthorizedClient> authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
+	private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
+		Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
+				.map(OAuth2AuthorizedClient.class::cast);
+		return Mono.justOrEmpty(attribute)
+				.switchIfEmpty(findAuthorizedClientByRegistrationId(request));
+	}
+
+	private Mono<OAuth2AuthorizedClient> findAuthorizedClientByRegistrationId(ClientRequest request) {
+		if (this.authorizedClientRepository == null) {
+			return Mono.empty();
+		}
+		String clientRegistrationId = (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
+		if (clientRegistrationId == null) {
+			return Mono.empty();
+		}
+		ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
+		return currentAuthentication()
+			.flatMap(principal -> loadAuthorizedClient(clientRegistrationId, exchange, principal)
+		);
+	}
+
+	private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
+			ServerWebExchange exchange, Authentication principal) {
+		return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
+			.switchIfEmpty(Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)));
+	}
+
+	private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
 		if (shouldRefresh(authorizedClient)) {
 			return refreshAuthorizedClient(next, authorizedClient, exchange);
 		}
@@ -184,13 +232,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 		return next.exchange(request)
 				.flatMap(response -> response.body(oauth2AccessTokenResponse()))
 				.map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()))
-				.flatMap(result -> ReactiveSecurityContextHolder.getContext()
-						.map(SecurityContext::getAuthentication)
+				.flatMap(result -> currentAuthentication()
 						.defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()))
 						.flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange))
 						.thenReturn(result));
 	}
 
+	private Mono<Authentication> currentAuthentication() {
+		return ReactiveSecurityContextHolder.getContext()
+				.map(SecurityContext::getAuthentication)
+				.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
+	}
+
 	private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
 		if (this.authorizedClientRepository == null) {
 			return false;

+ 25 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -61,6 +61,7 @@ import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.springframework.http.HttpMethod.GET;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
 import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
 
 /**
@@ -263,6 +264,30 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		assertThat(getBody(request0)).isEmpty();
 	}
 
+	@Test
+	public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, refreshToken);
+		when(this.auth2AuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(clientRegistrationId(this.registration.getRegistrationId()))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+
+		ClientRequest request0 = requests.get(0);
+		assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
+		assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request0.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request0)).isEmpty();
+	}
+
 	private static String getBody(ClientRequest request) {
 		final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
 		messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));