Răsfoiți Sursa

DefaultReactiveOAuth2AuthorizedClientManager defaults ServerWebExchange

Fixes gh-7390
Joe Grandja 6 ani în urmă
părinte
comite
dcdeab596d

+ 42 - 20
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java

@@ -70,35 +70,52 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
 
 		String clientRegistrationId = authorizeRequest.getClientRegistrationId();
 		Authentication principal = authorizeRequest.getPrincipal();
-
 		ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
-		Assert.notNull(serverWebExchange, "serverWebExchange cannot be null");
 
 		return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
-				.switchIfEmpty(Mono.defer(() ->
-						this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
+				.switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
 				.flatMap(authorizedClient -> {
 					// Re-authorize
 					return authorizationContext(authorizeRequest, authorizedClient)
 							.flatMap(this.authorizedClientProvider::authorize)
-							.doOnNext(reauthorizedClient ->
-									this.authorizedClientRepository.saveAuthorizedClient(
-											reauthorizedClient, principal, serverWebExchange))
+							.flatMap(reauthorizedClient -> saveAuthorizedClient(reauthorizedClient, principal, serverWebExchange))
 							// Default to the existing authorizedClient if the client was not re-authorized
 							.defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ?
 									authorizeRequest.getAuthorizedClient() : authorizedClient);
 				})
-				.switchIfEmpty(Mono.defer(() ->
+				.switchIfEmpty(Mono.deferWithContext(context ->
 						// Authorize
 						this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
 								.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
 										"Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
 								.flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration))
 								.flatMap(this.authorizedClientProvider::authorize)
-								.doOnNext(authorizedClient ->
-										this.authorizedClientRepository.saveAuthorizedClient(
-												authorizedClient, principal, serverWebExchange))
-				));
+								.flatMap(authorizedClient -> saveAuthorizedClient(authorizedClient, principal, serverWebExchange))
+								.subscriberContext(context)
+						)
+				);
+	}
+
+	private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId, Authentication principal, ServerWebExchange serverWebExchange) {
+		return Mono.justOrEmpty(serverWebExchange)
+				.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
+				.flatMap(exchange -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange));
+	}
+
+	private Mono<OAuth2AuthorizedClient> saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, ServerWebExchange serverWebExchange) {
+		return Mono.justOrEmpty(serverWebExchange)
+				.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
+				.map(exchange -> {
+					this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange);
+					return authorizedClient;
+				})
+				.defaultIfEmpty(authorizedClient);
+	}
+
+	private static Mono<ServerWebExchange> currentServerWebExchange() {
+		return Mono.subscriberContext()
+				.filter(c -> c.hasKey(ServerWebExchange.class))
+				.map(c -> c.get(ServerWebExchange.class));
 	}
 
 	private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest,
@@ -158,15 +175,20 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
 
 		@Override
 		public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) {
-			Map<String, Object> contextAttributes = Collections.emptyMap();
 			ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
-			String scope = serverWebExchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
-			if (StringUtils.hasText(scope)) {
-				contextAttributes = new HashMap<>();
-				contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
-						StringUtils.delimitedListToStringArray(scope, " "));
-			}
-			return Mono.just(contextAttributes);
+			return Mono.justOrEmpty(serverWebExchange)
+					.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
+					.flatMap(exchange -> {
+						Map<String, Object> contextAttributes = Collections.emptyMap();
+						String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
+						if (StringUtils.hasText(scope)) {
+							contextAttributes = new HashMap<>();
+							contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
+									StringUtils.delimitedListToStringArray(scope, " "));
+						}
+						return Mono.just(contextAttributes);
+					})
+					.defaultIfEmpty(Collections.emptyMap());
 		}
 	}
 }

+ 59 - 41
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java

@@ -34,9 +34,9 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.util.StringUtils;
 import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
+import reactor.util.context.Context;
 
 import java.util.Collections;
 import java.util.HashMap;
@@ -64,6 +64,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 	private Authentication principal;
 	private OAuth2AuthorizedClient authorizedClient;
 	private MockServerWebExchange serverWebExchange;
+	private Context context;
 	private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
 
 	@SuppressWarnings("unchecked")
@@ -75,6 +76,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 		this.authorizedClientRepository = mock(ServerOAuth2AuthorizedClientRepository.class);
 		when(this.authorizedClientRepository.loadAuthorizedClient(
 				anyString(), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty());
+		when(this.authorizedClientRepository.saveAuthorizedClient(
+				any(OAuth2AuthorizedClient.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty());
 		this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
 		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty());
 		this.contextAttributesMapper = mock(Function.class);
@@ -88,6 +91,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 		this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
 				TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken());
 		this.serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build();
+		this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
 		this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class);
 	}
 
@@ -119,16 +123,6 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 				.hasMessage("contextAttributesMapper cannot be null");
 	}
 
-	@Test
-	public void authorizeWhenServerWebExchangeIsNullThenThrowIllegalArgumentException() {
-		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
-				.principal(this.principal)
-				.build();
-		assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("serverWebExchange cannot be null");
-	}
-
 	@Test
 	public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.authorizedClientManager.authorize(null).block())
@@ -140,9 +134,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 	public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
 		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id")
 				.principal(this.principal)
-				.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
 				.build();
-		assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
+		assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block())
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'");
 	}
@@ -155,9 +148,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 
 		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
 				.principal(this.principal)
-				.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
 				.build();
-		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
+				.subscriberContext(this.context).block();
 
 		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
 		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@@ -168,8 +161,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isNull();
-		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
-				any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
 	}
 
 	@SuppressWarnings("unchecked")
@@ -177,15 +169,14 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 	public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
 		when(this.clientRegistrationRepository.findByRegistrationId(
 				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
-
 		when(this.authorizedClientProvider.authorize(
 				any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
 
 		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
 				.principal(this.principal)
-				.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
 				.build();
-		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
+				.subscriberContext(this.context).block();
 
 		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
 		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@@ -200,6 +191,31 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 				eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange));
 	}
 
+	@Test
+	public void authorizeWhenNotAuthorizedAndSupportedProviderAndExchangeUnavailableThenAuthorizedButNotSaved() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
+
+		when(this.authorizedClientProvider.authorize(
+				any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		assertThat(authorizedClient).isSameAs(this.authorizedClient);
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
+	}
+
 	@SuppressWarnings("unchecked")
 	@Test
 	public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
@@ -216,9 +232,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 
 		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
 				.principal(this.principal)
-				.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
 				.build();
-		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
+				.subscriberContext(this.context).block();
 
 		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
 		verify(this.contextAttributesMapper).apply(any());
@@ -241,21 +257,18 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
 
 		// Set custom contextAttributesMapper capable of mapping the form parameters
-		this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> {
-			ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
-			return Mono.just(serverWebExchange)
+		this.authorizedClientManager.setContextAttributesMapper(authorizeRequest ->
+				currentServerWebExchange()
 					.flatMap(ServerWebExchange::getFormData)
 					.map(formData -> {
 						Map<String, Object> contextAttributes = new HashMap<>();
 						String username = formData.getFirst(OAuth2ParameterNames.USERNAME);
+						contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
 						String password = formData.getFirst(OAuth2ParameterNames.PASSWORD);
-						if (StringUtils.hasText(username) && StringUtils.hasText(password)) {
-							contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
-							contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
-						}
+						contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
 						return contextAttributes;
-					});
-		});
+					})
+		);
 
 		this.serverWebExchange = MockServerWebExchange.builder(
 				MockServerHttpRequest
@@ -263,12 +276,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 						.contentType(MediaType.APPLICATION_FORM_URLENCODED)
 						.body("username=username&password=password"))
 				.build();
+		this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
 
 		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
 				.principal(this.principal)
-				.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
 				.build();
-		this.authorizedClientManager.authorize(authorizeRequest).block();
+		this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block();
 
 		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
 
@@ -284,9 +297,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 	public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
 		OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
 				.principal(this.principal)
-				.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
 				.build();
-		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest)
+				.subscriberContext(this.context).block();
 
 		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
 		verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@@ -297,8 +310,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
 
 		assertThat(authorizedClient).isSameAs(this.authorizedClient);
-		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(
-				any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
 	}
 
 	@SuppressWarnings("unchecked")
@@ -312,9 +324,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 
 		OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
 				.principal(this.principal)
-				.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
 				.build();
-		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest)
+				.subscriberContext(this.context).block();
 
 		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
 		verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@@ -346,12 +358,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 						.get("/")
 						.queryParam(OAuth2ParameterNames.SCOPE, "read write"))
 				.build();
+		this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
 
 		OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
 				.principal(this.principal)
-				.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
 				.build();
-		this.authorizedClientManager.authorize(reauthorizeRequest).block();
+		this.authorizedClientManager.authorize(reauthorizeRequest).subscriberContext(this.context).block();
 
 		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
 
@@ -359,4 +371,10 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
 		String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
 		assertThat(requestScopeAttribute).contains("read", "write");
 	}
+
+	private Mono<ServerWebExchange> currentServerWebExchange() {
+		return Mono.subscriberContext()
+				.filter(c -> c.hasKey(ServerWebExchange.class))
+				.map(c -> c.get(ServerWebExchange.class));
+	}
 }