Browse Source

Inject TestOAuth2AuthorizedClientRepository

Fixes gh-8603
Josh Cummings 5 năm trước cách đây
mục cha
commit
900f551890

+ 1 - 0
test/spring-security-test.gradle

@@ -11,6 +11,7 @@ dependencies {
 	optional project(':spring-security-oauth2-jose')
 	optional project(':spring-security-oauth2-resource-server')
 	optional 'io.projectreactor:reactor-core'
+	optional 'org.springframework:spring-webmvc'
 	optional 'org.springframework:spring-webflux'
 
 	provided 'javax.servlet:javax.servlet-api'

+ 154 - 3
test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java

@@ -32,6 +32,7 @@ import java.util.stream.Collectors;
 import com.nimbusds.oauth2.sdk.util.StringUtils;
 import reactor.core.publisher.Mono;
 
+import org.springframework.context.ApplicationContext;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.client.reactive.ClientHttpConnector;
 import org.springframework.lang.Nullable;
@@ -44,9 +45,13 @@ import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.UserDetails;
+import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
@@ -70,15 +75,21 @@ import org.springframework.security.oauth2.server.resource.introspection.OAuth2I
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors;
 import org.springframework.security.web.server.csrf.CsrfWebFilter;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
+import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.test.web.reactive.server.MockServerConfigurer;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.test.web.reactive.server.WebTestClientConfigurer;
 import org.springframework.util.Assert;
+import org.springframework.util.ClassUtils;
+import org.springframework.web.reactive.result.method.HandlerMethodArgumentResolver;
+import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer;
+import org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
 import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
 
+import static java.lang.Boolean.TRUE;
 import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB;
 
 /**
@@ -1121,9 +1132,18 @@ public class SecurityMockServerConfigurers {
 
 		private Consumer<List<WebFilter>> addAuthorizedClientFilter() {
 			OAuth2AuthorizedClient client = getClient();
-			return filters -> filters.add(0, (exchange, chain) ->
-					authorizedClientRepository.saveAuthorizedClient(client, null, exchange)
-							.then(chain.filter(exchange)));
+			return filters -> filters.add(0, (exchange, chain) -> {
+				ReactiveOAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServerTestUtils
+						.getOAuth2AuthorizedClientManager(exchange);
+				if (!(authorizationClientManager instanceof TestReactiveOAuth2AuthorizedClientManager)) {
+					authorizationClientManager =
+							new TestReactiveOAuth2AuthorizedClientManager(authorizationClientManager);
+					OAuth2ClientServerTestUtils.setOAuth2AuthorizedClientManager(exchange, authorizationClientManager);
+				}
+				TestReactiveOAuth2AuthorizedClientManager.enable(exchange);
+				exchange.getAttributes().put(TestReactiveOAuth2AuthorizedClientManager.TOKEN_ATTR_NAME, client);
+				return chain.filter(exchange);
+			});
 		}
 
 		private OAuth2AuthorizedClient getClient() {
@@ -1141,5 +1161,136 @@ public class SecurityMockServerConfigurers {
 					.clientSecret("test-secret")
 					.tokenUri("https://idp.example.org/oauth/token");
 		}
+
+		/**
+		 * Used to wrap the {@link OAuth2AuthorizedClientManager} to provide support for testing when the
+		 * request is wrapped
+		 */
+		private static class TestReactiveOAuth2AuthorizedClientManager
+				implements ReactiveOAuth2AuthorizedClientManager {
+
+			final static String TOKEN_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class
+					.getName().concat(".TOKEN");
+
+			final static String ENABLED_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class
+					.getName().concat(".ENABLED");
+
+			private final ReactiveOAuth2AuthorizedClientManager delegate;
+
+			private TestReactiveOAuth2AuthorizedClientManager(ReactiveOAuth2AuthorizedClientManager delegate) {
+				this.delegate = delegate;
+			}
+
+			@Override
+			public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizeRequest authorizeRequest) {
+				ServerWebExchange exchange =
+						authorizeRequest.getAttribute(ServerWebExchange.class.getName());
+				if (isEnabled(exchange)) {
+					OAuth2AuthorizedClient client = exchange.getAttribute(TOKEN_ATTR_NAME);
+					return Mono.just(client);
+				} else {
+					return this.delegate.authorize(authorizeRequest);
+				}
+			}
+
+			public static void enable(ServerWebExchange exchange) {
+				exchange.getAttributes().put(ENABLED_ATTR_NAME, TRUE);
+			}
+
+			public boolean isEnabled(ServerWebExchange exchange) {
+				return TRUE.equals(exchange.getAttribute(ENABLED_ATTR_NAME));
+			}
+		}
+
+		private static class OAuth2ClientServerTestUtils {
+			private static final ServerOAuth2AuthorizedClientRepository DEFAULT_CLIENT_REPO =
+					new WebSessionServerOAuth2AuthorizedClientRepository();
+
+			/**
+			 * Gets the {@link ReactiveOAuth2AuthorizedClientManager} for the specified {@link ServerWebExchange}.
+			 * If one is not found, one based off of {@link WebSessionServerOAuth2AuthorizedClientRepository} is used.
+			 *
+			 * @param exchange the {@link ServerWebExchange} to obtain the
+			 * {@link ReactiveOAuth2AuthorizedClientManager}
+			 * @return the {@link ReactiveOAuth2AuthorizedClientManager} for the specified
+			 * {@link ServerWebExchange}
+			 */
+			public static ReactiveOAuth2AuthorizedClientManager getOAuth2AuthorizedClientManager(ServerWebExchange exchange) {
+				OAuth2AuthorizedClientArgumentResolver resolver =
+						findResolver(exchange, OAuth2AuthorizedClientArgumentResolver.class);
+				if (resolver == null) {
+					return authorizeRequest -> DEFAULT_CLIENT_REPO.loadAuthorizedClient
+							(authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), exchange);
+				}
+				return (ReactiveOAuth2AuthorizedClientManager)
+						ReflectionTestUtils.getField(resolver, "authorizedClientManager");
+			}
+
+			/**
+			 * Sets the {@link ReactiveOAuth2AuthorizedClientManager} for the specified {@link ServerWebExchange}.
+			 *
+			 * @param exchange the {@link ServerWebExchange} to obtain the
+			 * {@link ReactiveOAuth2AuthorizedClientManager}
+			 * @param manager the {@link ReactiveOAuth2AuthorizedClientManager} to set
+			 */
+			public static void setOAuth2AuthorizedClientManager(ServerWebExchange exchange,
+					ReactiveOAuth2AuthorizedClientManager manager) {
+				OAuth2AuthorizedClientArgumentResolver resolver =
+						findResolver(exchange, OAuth2AuthorizedClientArgumentResolver.class);
+				if (resolver == null) {
+					return;
+				}
+				ReflectionTestUtils.setField(resolver, "authorizedClientManager", manager);
+			}
+
+			@SuppressWarnings("unchecked")
+			static <T extends HandlerMethodArgumentResolver> T findResolver(ServerWebExchange exchange,
+					Class<T> resolverClass) {
+				if (!ClassUtils.isPresent
+						("org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter", null)) {
+					return null;
+				}
+				return WebFluxClasspathGuard.findResolver(exchange, resolverClass);
+			}
+
+			private static class WebFluxClasspathGuard {
+				static <T extends HandlerMethodArgumentResolver> T findResolver(ServerWebExchange exchange,
+						Class<T> resolverClass) {
+					RequestMappingHandlerAdapter handlerAdapter = getRequestMappingHandlerAdapter(exchange);
+					if (handlerAdapter == null) {
+						return null;
+					}
+					ArgumentResolverConfigurer configurer = handlerAdapter.getArgumentResolverConfigurer();
+					if (configurer == null) {
+						return null;
+					}
+					List<HandlerMethodArgumentResolver> resolvers = (List<HandlerMethodArgumentResolver>)
+							ReflectionTestUtils.invokeGetterMethod(configurer, "customResolvers");
+					if (resolvers == null) {
+						return null;
+					}
+					for (HandlerMethodArgumentResolver resolver : resolvers) {
+						if (resolverClass.isAssignableFrom(resolver.getClass())) {
+							return (T) resolver;
+						}
+					}
+					return null;
+				}
+
+				private static RequestMappingHandlerAdapter getRequestMappingHandlerAdapter(ServerWebExchange exchange) {
+					ApplicationContext context = exchange.getApplicationContext();
+					if (context != null) {
+						String[] names = context.getBeanNamesForType(RequestMappingHandlerAdapter.class);
+						if (names.length > 0) {
+							return (RequestMappingHandlerAdapter) context.getBean(names[0]);
+						}
+					}
+					return null;
+				}
+			}
+
+			private OAuth2ClientServerTestUtils() {
+			}
+		}
 	}
 }

+ 147 - 3
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@@ -35,6 +35,7 @@ import java.util.Set;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
+import javax.servlet.ServletContext;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
@@ -56,11 +57,14 @@ import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.UserDetails;
+import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.DefaultOAuth2AuthenticatedPrincipal;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -89,10 +93,16 @@ import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
+import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.request.RequestPostProcessor;
 import org.springframework.util.Assert;
+import org.springframework.util.ClassUtils;
 import org.springframework.util.DigestUtils;
+import org.springframework.web.context.WebApplicationContext;
+import org.springframework.web.context.support.WebApplicationContextUtils;
+import org.springframework.web.method.support.HandlerMethodArgumentResolver;
+import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter;
 
 import static java.lang.Boolean.TRUE;
 import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB;
@@ -1657,9 +1667,16 @@ public final class SecurityMockMvcRequestPostProcessors {
 			}
 			OAuth2AuthorizedClient client = new OAuth2AuthorizedClient
 					(this.clientRegistration, this.principalName, this.accessToken);
-			OAuth2AuthorizedClientRepository authorizedClientRepository =
-					new HttpSessionOAuth2AuthorizedClientRepository();
-			authorizedClientRepository.saveAuthorizedClient(client, null, request, new MockHttpServletResponse());
+
+			OAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServletTestUtils
+					.getOAuth2AuthorizedClientManager(request);
+			if (!(authorizationClientManager instanceof TestOAuth2AuthorizedClientManager)) {
+				authorizationClientManager =
+						new TestOAuth2AuthorizedClientManager(authorizationClientManager);
+				OAuth2ClientServletTestUtils.setOAuth2AuthorizedClientManager(request, authorizationClientManager);
+			}
+			TestOAuth2AuthorizedClientManager.enable(request);
+			request.setAttribute(TestOAuth2AuthorizedClientManager.TOKEN_ATTR_NAME, client);
 			return request;
 		}
 
@@ -1670,6 +1687,133 @@ public final class SecurityMockMvcRequestPostProcessors {
 					.clientSecret("test-secret")
 					.tokenUri("https://idp.example.org/oauth/token");
 		}
+
+		/**
+		 * Used to wrap the {@link OAuth2AuthorizedClientManager} to provide support for testing when the
+		 * request is wrapped
+		 */
+		private static class TestOAuth2AuthorizedClientManager
+				implements OAuth2AuthorizedClientManager {
+
+			final static String TOKEN_ATTR_NAME = TestOAuth2AuthorizedClientManager.class.getName()
+					.concat(".TOKEN");
+
+			final static String ENABLED_ATTR_NAME = TestOAuth2AuthorizedClientManager.class
+					.getName().concat(".ENABLED");
+
+			private final OAuth2AuthorizedClientManager delegate;
+
+			private TestOAuth2AuthorizedClientManager(OAuth2AuthorizedClientManager delegate) {
+				this.delegate = delegate;
+			}
+
+			@Override
+			public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) {
+				HttpServletRequest request =
+						authorizeRequest.getAttribute(HttpServletRequest.class.getName());
+				if (isEnabled(request)) {
+					return (OAuth2AuthorizedClient) request.getAttribute(TOKEN_ATTR_NAME);
+				} else {
+					return this.delegate.authorize(authorizeRequest);
+				}
+			}
+
+			public static void enable(HttpServletRequest request) {
+				request.setAttribute(ENABLED_ATTR_NAME, TRUE);
+			}
+
+			public boolean isEnabled(HttpServletRequest request) {
+				return TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME));
+			}
+		}
+
+		private static class OAuth2ClientServletTestUtils {
+			private static final OAuth2AuthorizedClientRepository DEFAULT_CLIENT_REPO =
+					new HttpSessionOAuth2AuthorizedClientRepository();
+
+			/**
+			 * Gets the {@link OAuth2AuthorizedClientManager} for the specified {@link HttpServletRequest}.
+			 * If one is not found, one based off of {@link HttpSessionOAuth2AuthorizedClientRepository} is used.
+			 *
+			 * @param request the {@link HttpServletRequest} to obtain the
+			 * {@link OAuth2AuthorizedClientManager}
+			 * @return the {@link OAuth2AuthorizedClientManager} for the specified
+			 * {@link HttpServletRequest}
+			 */
+			public static OAuth2AuthorizedClientManager getOAuth2AuthorizedClientManager(HttpServletRequest request) {
+				OAuth2AuthorizedClientArgumentResolver resolver =
+						findResolver(request, OAuth2AuthorizedClientArgumentResolver.class);
+				if (resolver == null) {
+					return authorizeRequest -> DEFAULT_CLIENT_REPO.loadAuthorizedClient
+							(authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), request);
+				}
+				return (OAuth2AuthorizedClientManager)
+						ReflectionTestUtils.getField(resolver, "authorizedClientManager");
+			}
+
+			/**
+			 * Sets the {@link OAuth2AuthorizedClientManager} for the specified {@link HttpServletRequest}.
+			 *
+			 * @param request the {@link HttpServletRequest} to obtain the
+			 * {@link OAuth2AuthorizedClientManager}
+			 * @param manager the {@link OAuth2AuthorizedClientManager} to set
+			 */
+			public static void setOAuth2AuthorizedClientManager(HttpServletRequest request,
+					OAuth2AuthorizedClientManager manager) {
+				OAuth2AuthorizedClientArgumentResolver resolver =
+						findResolver(request, OAuth2AuthorizedClientArgumentResolver.class);
+				if (resolver == null) {
+					return;
+				}
+				ReflectionTestUtils.setField(resolver, "authorizedClientManager", manager);
+			}
+
+			@SuppressWarnings("unchecked")
+			static <T extends HandlerMethodArgumentResolver> T findResolver(HttpServletRequest request,
+					Class<T> resolverClass) {
+				if (!ClassUtils.isPresent
+						("org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter", null)) {
+					return null;
+				}
+				return WebMvcClasspathGuard.findResolver(request, resolverClass);
+			}
+
+			private static class WebMvcClasspathGuard {
+				static <T extends HandlerMethodArgumentResolver> T findResolver(HttpServletRequest request,
+						Class<T> resolverClass) {
+					ServletContext servletContext = request.getServletContext();
+					RequestMappingHandlerAdapter mapping = getRequestMappingHandlerAdapter(servletContext);
+					if (mapping == null) {
+						return null;
+					}
+					List<HandlerMethodArgumentResolver> resolvers = mapping.getCustomArgumentResolvers();
+					if (resolvers == null) {
+						return null;
+					}
+					for (HandlerMethodArgumentResolver resolver : resolvers) {
+						if (resolverClass.isAssignableFrom(resolver.getClass())) {
+							return (T) resolver;
+						}
+					}
+					return null;
+				}
+
+				private static RequestMappingHandlerAdapter getRequestMappingHandlerAdapter(ServletContext servletContext) {
+					WebApplicationContext context = WebApplicationContextUtils
+							.getWebApplicationContext(servletContext);
+					if (context != null) {
+						String[] names = context.getBeanNamesForType(RequestMappingHandlerAdapter.class);
+						if (names.length > 0) {
+							return (RequestMappingHandlerAdapter) context.getBean(names[0]);
+						}
+					}
+					return null;
+				}
+			}
+
+			private OAuth2ClientServletTestUtils() {
+			}
+		}
 	}
 
 	private SecurityMockMvcRequestPostProcessors() {

+ 37 - 5
test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java

@@ -21,26 +21,32 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
+import reactor.core.publisher.Mono;
 
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
-import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.reactive.DispatcherHandler;
+import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
 import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
 import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockOAuth2Client;
@@ -53,18 +59,18 @@ public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMock
 	@Mock
 	private ReactiveClientRegistrationRepository clientRegistrationRepository;
 
+	@Mock
+	private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
+
 	private WebTestClient client;
 
 	@Before
 	public void setup() {
-		ServerOAuth2AuthorizedClientRepository authorizedClientRepository =
-				new WebSessionServerOAuth2AuthorizedClientRepository();
-
 		this.client = WebTestClient
 				.bindToController(this.controller)
 				.argumentResolvers(c -> c.addCustomResolver(
 						new OAuth2AuthorizedClientArgumentResolver
-								(this.clientRegistrationRepository, authorizedClientRepository)))
+								(this.clientRegistrationRepository, this.authorizedClientRepository)))
 				.webFilter(new SecurityContextServerWebExchangeWebFilter())
 				.apply(springSecurity())
 				.configureClient()
@@ -162,6 +168,32 @@ public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMock
 		assertThat(client.getRefreshToken()).isNull();
 	}
 
+	@Test
+	public void oauth2ClientWhenUsedOnceThenDoesNotAffectRemainingTests() throws Exception {
+		this.client.mutateWith(mockOAuth2Client("registration-id"))
+				.get().uri("/client")
+				.exchange()
+				.expectStatus().isOk();
+
+		OAuth2AuthorizedClient client = this.controller.authorizedClient;
+		assertThat(client).isNotNull();
+		assertThat(client.getClientRegistration().getClientId()).isEqualTo("test-client");
+
+		client = new OAuth2AuthorizedClient(clientRegistration().build(), "sub", noScopes());
+		when(this.authorizedClientRepository
+				.loadAuthorizedClient(eq("registration-id"), any(Authentication.class), any(ServerWebExchange.class)))
+				.thenReturn(Mono.just(client));
+		this.client
+				.get().uri("/client")
+				.exchange()
+				.expectStatus().isOk();
+		client = this.controller.authorizedClient;
+		assertThat(client).isNotNull();
+		assertThat(client.getClientRegistration().getClientId()).isEqualTo("client-id");
+		verify(this.authorizedClientRepository).loadAuthorizedClient(
+				eq("registration-id"), any(Authentication.class), any(ServerWebExchange.class));
+	}
+
 	@RestController
 	static class OAuth2LoginController {
 		volatile OAuth2AuthorizedClient authorizedClient;

+ 4 - 5
test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2LoginTests.java

@@ -36,7 +36,6 @@ import org.springframework.security.oauth2.client.authentication.OAuth2Authentic
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
-import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
@@ -55,18 +54,18 @@ public class SecurityMockServerConfigurersOAuth2LoginTests extends AbstractMockS
 	@Mock
 	private ReactiveClientRegistrationRepository clientRegistrationRepository;
 
+	@Mock
+	private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
+
 	private WebTestClient client;
 
 	@Before
 	public void setup() {
-		ServerOAuth2AuthorizedClientRepository authorizedClientRepository =
-				new WebSessionServerOAuth2AuthorizedClientRepository();
-
 		this.client = WebTestClient
 				.bindToController(this.controller)
 				.argumentResolvers(c -> c.addCustomResolver(
 						new OAuth2AuthorizedClientArgumentResolver
-								(this.clientRegistrationRepository, authorizedClientRepository)))
+								(this.clientRegistrationRepository, this.authorizedClientRepository)))
 				.webFilter(new SecurityContextServerWebExchangeWebFilter())
 				.apply(springSecurity())
 				.configureClient()

+ 5 - 6
test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOidcLoginTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -35,7 +35,6 @@ import org.springframework.security.oauth2.client.authentication.OAuth2Authentic
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
-import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
@@ -57,18 +56,18 @@ public class SecurityMockServerConfigurersOidcLoginTests extends AbstractMockSer
 	@Mock
 	private ReactiveClientRegistrationRepository clientRegistrationRepository;
 
+	@Mock
+	private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
+
 	private WebTestClient client;
 
 	@Before
 	public void setup() {
-		ServerOAuth2AuthorizedClientRepository authorizedClientRepository =
-				new WebSessionServerOAuth2AuthorizedClientRepository();
-
 		this.client = WebTestClient
 				.bindToController(this.controller)
 				.argumentResolvers(c -> c.addCustomResolver(
 						new OAuth2AuthorizedClientArgumentResolver
-								(this.clientRegistrationRepository, authorizedClientRepository)))
+								(this.clientRegistrationRepository, this.authorizedClientRepository)))
 				.webFilter(new SecurityContextServerWebExchangeWebFilter())
 				.apply(springSecurity())
 				.configureClient()

+ 24 - 2
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2ClientTests.java

@@ -15,6 +15,8 @@
  */
 package org.springframework.security.test.web.servlet.request;
 
+import javax.servlet.http.HttpServletRequest;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -26,11 +28,11 @@ import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
-import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.test.context.TestSecurityContextHolder;
@@ -45,7 +47,11 @@ import org.springframework.web.context.WebApplicationContext;
 import org.springframework.web.servlet.config.annotation.EnableWebMvc;
 
 import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
 import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oauth2Client;
@@ -138,6 +144,22 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2ClientTests {
 				.andExpect(content().string("no-scopes"));
 	}
 
+	@Test
+	public void oauth2ClientWhenUsedOnceThenDoesNotAffectRemainingTests() throws Exception {
+		this.mvc.perform(get("/client-id")
+				.with(oauth2Client("registration-id")))
+				.andExpect(content().string("test-client"));
+
+		OAuth2AuthorizedClient client = new OAuth2AuthorizedClient(clientRegistration().build(), "sub", noScopes());
+		OAuth2AuthorizedClientRepository repository = this.context.getBean(OAuth2AuthorizedClientRepository.class);
+		when(repository.loadAuthorizedClient(eq("registration-id"), any(Authentication.class), any(HttpServletRequest.class)))
+				.thenReturn(client);
+		this.mvc.perform(get("/client-id"))
+				.andExpect(content().string("client-id"));
+		verify(repository).loadAuthorizedClient(
+				eq("registration-id"), any(Authentication.class), any(HttpServletRequest.class));
+	}
+
 	@EnableWebSecurity
 	@EnableWebMvc
 	static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter {
@@ -158,7 +180,7 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2ClientTests {
 
 		@Bean
 		OAuth2AuthorizedClientRepository authorizedClientRepository() {
-			return new HttpSessionOAuth2AuthorizedClientRepository();
+			return mock(OAuth2AuthorizedClientRepository.class);
 		}
 
 		@RestController

+ 2 - 3
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2LoginTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -37,7 +37,6 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
-import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2User;
@@ -182,7 +181,7 @@ public class SecurityMockMvcRequestPostProcessorsOAuth2LoginTests {
 
 		@Bean
 		OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository() {
-			return new HttpSessionOAuth2AuthorizedClientRepository();
+			return mock(OAuth2AuthorizedClientRepository.class);
 		}
 
 		@RestController

+ 2 - 3
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOidcLoginTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -36,7 +36,6 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
-import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
@@ -190,7 +189,7 @@ public class SecurityMockMvcRequestPostProcessorsOidcLoginTests {
 
 		@Bean
 		OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository() {
-			return new HttpSessionOAuth2AuthorizedClientRepository();
+			return mock(OAuth2AuthorizedClientRepository.class);
 		}
 
 		@RestController