Selaa lähdekoodia

Don't force downcasting of RequestAttributes to ServletRequestAttributes

Fixes gh-7952
Stephane Maldini 5 vuotta sitten
vanhempi
commit
851be025e9

+ 10 - 20
config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.web.context.request.RequestAttributes;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
 import reactor.core.CoreSubscriber;
@@ -92,32 +93,21 @@ class SecurityReactorContextConfiguration {
 		}
 
 		private static boolean contextAttributesAvailable() {
-			HttpServletRequest servletRequest = null;
-			HttpServletResponse servletResponse = null;
-			ServletRequestAttributes requestAttributes =
-					(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
-			if (requestAttributes != null) {
-				servletRequest = requestAttributes.getRequest();
-				servletResponse = requestAttributes.getResponse();
-			}
-			Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
-			if (authentication != null || servletRequest != null || servletResponse != null) {
-				return true;
-			}
-			return false;
+			return SecurityContextHolder.getContext().getAuthentication() != null ||
+					RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes;
 		}
 
 		private static Map<Object, Object> getContextAttributes() {
 			HttpServletRequest servletRequest = null;
 			HttpServletResponse servletResponse = null;
-			ServletRequestAttributes requestAttributes =
-					(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
-			if (requestAttributes != null) {
-				servletRequest = requestAttributes.getRequest();
-				servletResponse = requestAttributes.getResponse();
+			RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
+			if (requestAttributes instanceof ServletRequestAttributes) {
+				ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
+				servletRequest = servletRequestAttributes.getRequest();
+				servletResponse = servletRequestAttributes.getResponse();	// possible null
 			}
 			Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
-			if (authentication == null && servletRequest == null && servletResponse == null) {
+			if (authentication == null && servletRequest == null) {
 				return Collections.emptyMap();
 			}
 

+ 49 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
+import org.springframework.web.context.request.RequestAttributes;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
 import org.springframework.web.reactive.function.client.ClientRequest;
@@ -36,6 +37,7 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
 import reactor.core.CoreSubscriber;
 import reactor.core.publisher.BaseSubscriber;
 import reactor.core.publisher.Mono;
+import reactor.core.publisher.Operators;
 import reactor.test.StepVerifier;
 import reactor.util.context.Context;
 
@@ -139,6 +141,52 @@ public class SecurityReactorContextConfigurationTests {
 		assertThat(resultContext).isSameAs(parentContext);
 	}
 
+	@Test
+	public void createSubscriberIfNecessaryWhenNotServletRequestAttributesThenStillCreate() {
+		RequestContextHolder.setRequestAttributes(
+				new RequestAttributes() {
+					@Override
+					public Object getAttribute(String name, int scope) {
+						return null;
+					}
+
+					@Override
+					public void setAttribute(String name, Object value, int scope) {
+					}
+
+					@Override
+					public void removeAttribute(String name, int scope) {
+					}
+
+					@Override
+					public String[] getAttributeNames(int scope) {
+						return new String[0];
+					}
+
+					@Override
+					public void registerDestructionCallback(String name, Runnable callback, int scope) {
+					}
+
+					@Override
+					public Object resolveReference(String key) {
+						return null;
+					}
+
+					@Override
+					public String getSessionId() {
+						return null;
+					}
+
+					@Override
+					public Object getSessionMutex() {
+						return null;
+					}
+				});
+
+		CoreSubscriber<Object> subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(Operators.emptySubscriber());
+		assertThat(subscriber).isInstanceOf(SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.class);
+	}
+
 	@Test
 	public void createPublisherWhenLastOperatorAddedThenSecurityContextAttributesAvailable() {
 		// Trigger the importing of SecurityReactorContextConfiguration via OAuth2ImportSelector

+ 8 - 7
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -35,6 +35,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
+import org.springframework.web.context.request.RequestAttributes;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
 
@@ -121,9 +122,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
 	private static HttpServletRequest getHttpServletRequestOrDefault(Map<String, Object> attributes) {
 		HttpServletRequest servletRequest = (HttpServletRequest) attributes.get(HttpServletRequest.class.getName());
 		if (servletRequest == null) {
-			ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
-			if (context != null) {
-				servletRequest = context.getRequest();
+			RequestAttributes context = RequestContextHolder.getRequestAttributes();
+			if (context instanceof ServletRequestAttributes) {
+				servletRequest = ((ServletRequestAttributes) context).getRequest();
 			}
 		}
 		return servletRequest;
@@ -132,9 +133,9 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
 	private static HttpServletResponse getHttpServletResponseOrDefault(Map<String, Object> attributes) {
 		HttpServletResponse servletResponse = (HttpServletResponse) attributes.get(HttpServletResponse.class.getName());
 		if (servletResponse == null) {
-			ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
-			if (context != null) {
-				servletResponse = context.getResponse();
+			RequestAttributes context = RequestContextHolder.getRequestAttributes();
+			if (context instanceof ServletRequestAttributes) {
+				servletResponse =  ((ServletRequestAttributes) context).getResponse();
 			}
 		}
 		return servletResponse;

+ 6 - 9
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -36,6 +36,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.util.Assert;
+import org.springframework.web.context.request.RequestAttributes;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
 import org.springframework.web.reactive.function.client.ClientRequest;
@@ -389,15 +390,11 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 				attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
 			return;
 		}
-		ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
-		HttpServletRequest request = null;
-		HttpServletResponse response = null;
-		if (context != null) {
-			request = context.getRequest();
-			response = context.getResponse();
+		RequestAttributes context = RequestContextHolder.getRequestAttributes();
+		if (context instanceof ServletRequestAttributes) {
+			attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME,  ((ServletRequestAttributes) context).getRequest());
+			attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ((ServletRequestAttributes) context).getResponse());
 		}
-		attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
-		attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
 	}
 
 	private void populateDefaultAuthentication(Map<String, Object> attrs) {