|
@@ -32,6 +32,7 @@ import java.util.stream.Collectors;
|
|
|
import com.nimbusds.oauth2.sdk.util.StringUtils;
|
|
|
import reactor.core.publisher.Mono;
|
|
|
|
|
|
+import org.springframework.context.ApplicationContext;
|
|
|
import org.springframework.core.convert.converter.Converter;
|
|
|
import org.springframework.http.client.reactive.ClientHttpConnector;
|
|
|
import org.springframework.lang.Nullable;
|
|
@@ -44,9 +45,13 @@ import org.springframework.security.core.context.SecurityContext;
|
|
|
import org.springframework.security.core.context.SecurityContextImpl;
|
|
|
import org.springframework.security.core.userdetails.User;
|
|
|
import org.springframework.security.core.userdetails.UserDetails;
|
|
|
+import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
|
|
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
|
|
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
|
|
|
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
|
|
|
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
|
+import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
|
|
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
|
|
import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
|
|
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
|
@@ -70,15 +75,21 @@ import org.springframework.security.oauth2.server.resource.introspection.OAuth2I
|
|
|
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors;
|
|
|
import org.springframework.security.web.server.csrf.CsrfWebFilter;
|
|
|
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
|
|
|
+import org.springframework.test.util.ReflectionTestUtils;
|
|
|
import org.springframework.test.web.reactive.server.MockServerConfigurer;
|
|
|
import org.springframework.test.web.reactive.server.WebTestClient;
|
|
|
import org.springframework.test.web.reactive.server.WebTestClientConfigurer;
|
|
|
import org.springframework.util.Assert;
|
|
|
+import org.springframework.util.ClassUtils;
|
|
|
+import org.springframework.web.reactive.result.method.HandlerMethodArgumentResolver;
|
|
|
+import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer;
|
|
|
+import org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter;
|
|
|
import org.springframework.web.server.ServerWebExchange;
|
|
|
import org.springframework.web.server.WebFilter;
|
|
|
import org.springframework.web.server.WebFilterChain;
|
|
|
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
|
|
|
|
|
|
+import static java.lang.Boolean.TRUE;
|
|
|
import static org.springframework.security.oauth2.jwt.JwtClaimNames.SUB;
|
|
|
|
|
|
/**
|
|
@@ -1121,9 +1132,18 @@ public class SecurityMockServerConfigurers {
|
|
|
|
|
|
private Consumer<List<WebFilter>> addAuthorizedClientFilter() {
|
|
|
OAuth2AuthorizedClient client = getClient();
|
|
|
- return filters -> filters.add(0, (exchange, chain) ->
|
|
|
- authorizedClientRepository.saveAuthorizedClient(client, null, exchange)
|
|
|
- .then(chain.filter(exchange)));
|
|
|
+ return filters -> filters.add(0, (exchange, chain) -> {
|
|
|
+ ReactiveOAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServerTestUtils
|
|
|
+ .getOAuth2AuthorizedClientManager(exchange);
|
|
|
+ if (!(authorizationClientManager instanceof TestReactiveOAuth2AuthorizedClientManager)) {
|
|
|
+ authorizationClientManager =
|
|
|
+ new TestReactiveOAuth2AuthorizedClientManager(authorizationClientManager);
|
|
|
+ OAuth2ClientServerTestUtils.setOAuth2AuthorizedClientManager(exchange, authorizationClientManager);
|
|
|
+ }
|
|
|
+ TestReactiveOAuth2AuthorizedClientManager.enable(exchange);
|
|
|
+ exchange.getAttributes().put(TestReactiveOAuth2AuthorizedClientManager.TOKEN_ATTR_NAME, client);
|
|
|
+ return chain.filter(exchange);
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
private OAuth2AuthorizedClient getClient() {
|
|
@@ -1141,5 +1161,136 @@ public class SecurityMockServerConfigurers {
|
|
|
.clientSecret("test-secret")
|
|
|
.tokenUri("https://idp.example.org/oauth/token");
|
|
|
}
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Used to wrap the {@link OAuth2AuthorizedClientManager} to provide support for testing when the
|
|
|
+ * request is wrapped
|
|
|
+ */
|
|
|
+ private static class TestReactiveOAuth2AuthorizedClientManager
|
|
|
+ implements ReactiveOAuth2AuthorizedClientManager {
|
|
|
+
|
|
|
+ final static String TOKEN_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class
|
|
|
+ .getName().concat(".TOKEN");
|
|
|
+
|
|
|
+ final static String ENABLED_ATTR_NAME = TestReactiveOAuth2AuthorizedClientManager.class
|
|
|
+ .getName().concat(".ENABLED");
|
|
|
+
|
|
|
+ private final ReactiveOAuth2AuthorizedClientManager delegate;
|
|
|
+
|
|
|
+ private TestReactiveOAuth2AuthorizedClientManager(ReactiveOAuth2AuthorizedClientManager delegate) {
|
|
|
+ this.delegate = delegate;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizeRequest authorizeRequest) {
|
|
|
+ ServerWebExchange exchange =
|
|
|
+ authorizeRequest.getAttribute(ServerWebExchange.class.getName());
|
|
|
+ if (isEnabled(exchange)) {
|
|
|
+ OAuth2AuthorizedClient client = exchange.getAttribute(TOKEN_ATTR_NAME);
|
|
|
+ return Mono.just(client);
|
|
|
+ } else {
|
|
|
+ return this.delegate.authorize(authorizeRequest);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public static void enable(ServerWebExchange exchange) {
|
|
|
+ exchange.getAttributes().put(ENABLED_ATTR_NAME, TRUE);
|
|
|
+ }
|
|
|
+
|
|
|
+ public boolean isEnabled(ServerWebExchange exchange) {
|
|
|
+ return TRUE.equals(exchange.getAttribute(ENABLED_ATTR_NAME));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private static class OAuth2ClientServerTestUtils {
|
|
|
+ private static final ServerOAuth2AuthorizedClientRepository DEFAULT_CLIENT_REPO =
|
|
|
+ new WebSessionServerOAuth2AuthorizedClientRepository();
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Gets the {@link ReactiveOAuth2AuthorizedClientManager} for the specified {@link ServerWebExchange}.
|
|
|
+ * If one is not found, one based off of {@link WebSessionServerOAuth2AuthorizedClientRepository} is used.
|
|
|
+ *
|
|
|
+ * @param exchange the {@link ServerWebExchange} to obtain the
|
|
|
+ * {@link ReactiveOAuth2AuthorizedClientManager}
|
|
|
+ * @return the {@link ReactiveOAuth2AuthorizedClientManager} for the specified
|
|
|
+ * {@link ServerWebExchange}
|
|
|
+ */
|
|
|
+ public static ReactiveOAuth2AuthorizedClientManager getOAuth2AuthorizedClientManager(ServerWebExchange exchange) {
|
|
|
+ OAuth2AuthorizedClientArgumentResolver resolver =
|
|
|
+ findResolver(exchange, OAuth2AuthorizedClientArgumentResolver.class);
|
|
|
+ if (resolver == null) {
|
|
|
+ return authorizeRequest -> DEFAULT_CLIENT_REPO.loadAuthorizedClient
|
|
|
+ (authorizeRequest.getClientRegistrationId(), authorizeRequest.getPrincipal(), exchange);
|
|
|
+ }
|
|
|
+ return (ReactiveOAuth2AuthorizedClientManager)
|
|
|
+ ReflectionTestUtils.getField(resolver, "authorizedClientManager");
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Sets the {@link ReactiveOAuth2AuthorizedClientManager} for the specified {@link ServerWebExchange}.
|
|
|
+ *
|
|
|
+ * @param exchange the {@link ServerWebExchange} to obtain the
|
|
|
+ * {@link ReactiveOAuth2AuthorizedClientManager}
|
|
|
+ * @param manager the {@link ReactiveOAuth2AuthorizedClientManager} to set
|
|
|
+ */
|
|
|
+ public static void setOAuth2AuthorizedClientManager(ServerWebExchange exchange,
|
|
|
+ ReactiveOAuth2AuthorizedClientManager manager) {
|
|
|
+ OAuth2AuthorizedClientArgumentResolver resolver =
|
|
|
+ findResolver(exchange, OAuth2AuthorizedClientArgumentResolver.class);
|
|
|
+ if (resolver == null) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ ReflectionTestUtils.setField(resolver, "authorizedClientManager", manager);
|
|
|
+ }
|
|
|
+
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ static <T extends HandlerMethodArgumentResolver> T findResolver(ServerWebExchange exchange,
|
|
|
+ Class<T> resolverClass) {
|
|
|
+ if (!ClassUtils.isPresent
|
|
|
+ ("org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerAdapter", null)) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ return WebFluxClasspathGuard.findResolver(exchange, resolverClass);
|
|
|
+ }
|
|
|
+
|
|
|
+ private static class WebFluxClasspathGuard {
|
|
|
+ static <T extends HandlerMethodArgumentResolver> T findResolver(ServerWebExchange exchange,
|
|
|
+ Class<T> resolverClass) {
|
|
|
+ RequestMappingHandlerAdapter handlerAdapter = getRequestMappingHandlerAdapter(exchange);
|
|
|
+ if (handlerAdapter == null) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ ArgumentResolverConfigurer configurer = handlerAdapter.getArgumentResolverConfigurer();
|
|
|
+ if (configurer == null) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ List<HandlerMethodArgumentResolver> resolvers = (List<HandlerMethodArgumentResolver>)
|
|
|
+ ReflectionTestUtils.invokeGetterMethod(configurer, "customResolvers");
|
|
|
+ if (resolvers == null) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ for (HandlerMethodArgumentResolver resolver : resolvers) {
|
|
|
+ if (resolverClass.isAssignableFrom(resolver.getClass())) {
|
|
|
+ return (T) resolver;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ private static RequestMappingHandlerAdapter getRequestMappingHandlerAdapter(ServerWebExchange exchange) {
|
|
|
+ ApplicationContext context = exchange.getApplicationContext();
|
|
|
+ if (context != null) {
|
|
|
+ String[] names = context.getBeanNamesForType(RequestMappingHandlerAdapter.class);
|
|
|
+ if (names.length > 0) {
|
|
|
+ return (RequestMappingHandlerAdapter) context.getBean(names[0]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private OAuth2ClientServerTestUtils() {
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|