2
0
Эх сурвалжийг харах

Allow Configure RequestRjectedHandler in XML

Issue gh-5007
Rob Winch 5 жил өмнө
parent
commit
4a9fa0337a

+ 44 - 2
config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java

@@ -24,14 +24,19 @@ import org.apache.commons.logging.LogFactory;
 import org.w3c.dom.Element;
 import org.w3c.dom.Element;
 
 
 import org.springframework.beans.BeanMetadataElement;
 import org.springframework.beans.BeanMetadataElement;
+import org.springframework.beans.BeansException;
 import org.springframework.beans.factory.config.BeanDefinition;
 import org.springframework.beans.factory.config.BeanDefinition;
 import org.springframework.beans.factory.config.BeanReference;
 import org.springframework.beans.factory.config.BeanReference;
+import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
 import org.springframework.beans.factory.config.ListFactoryBean;
 import org.springframework.beans.factory.config.ListFactoryBean;
 import org.springframework.beans.factory.config.MethodInvokingFactoryBean;
 import org.springframework.beans.factory.config.MethodInvokingFactoryBean;
 import org.springframework.beans.factory.config.RuntimeBeanReference;
 import org.springframework.beans.factory.config.RuntimeBeanReference;
 import org.springframework.beans.factory.parsing.BeanComponentDefinition;
 import org.springframework.beans.factory.parsing.BeanComponentDefinition;
 import org.springframework.beans.factory.parsing.CompositeComponentDefinition;
 import org.springframework.beans.factory.parsing.CompositeComponentDefinition;
+import org.springframework.beans.factory.support.AbstractBeanDefinition;
 import org.springframework.beans.factory.support.BeanDefinitionBuilder;
 import org.springframework.beans.factory.support.BeanDefinitionBuilder;
+import org.springframework.beans.factory.support.BeanDefinitionRegistry;
+import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
 import org.springframework.beans.factory.support.ManagedList;
 import org.springframework.beans.factory.support.ManagedList;
 import org.springframework.beans.factory.support.RootBeanDefinition;
 import org.springframework.beans.factory.support.RootBeanDefinition;
 import org.springframework.beans.factory.xml.BeanDefinitionParser;
 import org.springframework.beans.factory.xml.BeanDefinitionParser;
@@ -393,7 +398,8 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
 	}
 	}
 
 
 	static void registerFilterChainProxyIfNecessary(ParserContext pc, Object source) {
 	static void registerFilterChainProxyIfNecessary(ParserContext pc, Object source) {
-		if (pc.getRegistry().containsBeanDefinition(BeanIds.FILTER_CHAIN_PROXY)) {
+		BeanDefinitionRegistry registry = pc.getRegistry();
+		if (registry.containsBeanDefinition(BeanIds.FILTER_CHAIN_PROXY)) {
 			return;
 			return;
 		}
 		}
 		// Not already registered, so register the list of filter chains and the
 		// Not already registered, so register the list of filter chains and the
@@ -412,10 +418,46 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
 		BeanDefinition fcpBean = fcpBldr.getBeanDefinition();
 		BeanDefinition fcpBean = fcpBldr.getBeanDefinition();
 		pc.registerBeanComponent(new BeanComponentDefinition(fcpBean,
 		pc.registerBeanComponent(new BeanComponentDefinition(fcpBean,
 				BeanIds.FILTER_CHAIN_PROXY));
 				BeanIds.FILTER_CHAIN_PROXY));
-		pc.getRegistry().registerAlias(BeanIds.FILTER_CHAIN_PROXY,
+		registry.registerAlias(BeanIds.FILTER_CHAIN_PROXY,
 				BeanIds.SPRING_SECURITY_FILTER_CHAIN);
 				BeanIds.SPRING_SECURITY_FILTER_CHAIN);
+
+		BeanDefinitionBuilder requestRejected = BeanDefinitionBuilder.rootBeanDefinition(RequestRejectedHandlerPostProcessor.class);
+		requestRejected.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
+		requestRejected.addConstructorArgValue("requestRejectedHandler");
+		requestRejected.addConstructorArgValue(BeanIds.FILTER_CHAIN_PROXY);
+		requestRejected.addConstructorArgValue("requestRejectedHandler");
+		AbstractBeanDefinition requestRejectedBean = requestRejected.getBeanDefinition();
+		String requestRejectedPostProcessorName = pc.getReaderContext().generateBeanName(requestRejectedBean);
+		registry.registerBeanDefinition(requestRejectedPostProcessorName, requestRejectedBean);
+	}
+
+}
+
+class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor {
+	private final String beanName;
+
+	private final String targetBeanName;
+
+	private final String targetPropertyName;
+
+	RequestRejectedHandlerPostProcessor(String beanName, String targetBeanName, String targetPropertyName) {
+		this.beanName = beanName;
+		this.targetBeanName = targetBeanName;
+		this.targetPropertyName = targetPropertyName;
 	}
 	}
 
 
+	@Override
+	public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
+		if (registry.containsBeanDefinition(this.beanName)) {
+			BeanDefinition beanDefinition = registry.getBeanDefinition(this.targetBeanName);
+			beanDefinition.getPropertyValues().add(this.targetPropertyName, new RuntimeBeanReference(this.beanName));
+		}
+	}
+
+	@Override
+	public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
+
+	}
 }
 }
 
 
 class OrderDecorator implements Ordered {
 class OrderDecorator implements Ordered {

+ 17 - 0
config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java

@@ -94,6 +94,8 @@ import org.springframework.security.web.context.request.async.WebAsyncManagerInt
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.firewall.FirewalledRequest;
 import org.springframework.security.web.firewall.FirewalledRequest;
 import org.springframework.security.web.firewall.HttpFirewall;
 import org.springframework.security.web.firewall.HttpFirewall;
+import org.springframework.security.web.firewall.RequestRejectedException;
+import org.springframework.security.web.firewall.RequestRejectedHandler;
 import org.springframework.security.web.header.HeaderWriterFilter;
 import org.springframework.security.web.header.HeaderWriterFilter;
 import org.springframework.security.web.savedrequest.RequestCache;
 import org.springframework.security.web.savedrequest.RequestCache;
 import org.springframework.security.web.savedrequest.RequestCacheAwareFilter;
 import org.springframework.security.web.savedrequest.RequestCacheAwareFilter;
@@ -754,6 +756,21 @@ public class MiscHttpConfigTests {
 		verify(firewall).getFirewalledResponse(any(HttpServletResponse.class));
 		verify(firewall).getFirewalledResponse(any(HttpServletResponse.class));
 	}
 	}
 
 
+	@Test
+	public void getWhenUsingCustomRequestRejectedHandlerThenRequestRejectedHandlerIsInvoked() throws Exception {
+		this.spring.configLocations(xml("RequestRejectedHandler")).autowire();
+
+		HttpServletResponse response = new MockHttpServletResponse();
+
+		RequestRejectedException rejected = new RequestRejectedException("failed");
+		HttpFirewall firewall = this.spring.getContext().getBean(HttpFirewall.class);
+		RequestRejectedHandler requestRejectedHandler = this.spring.getContext().getBean(RequestRejectedHandler.class);
+		when(firewall.getFirewalledRequest(any(HttpServletRequest.class))).thenThrow(rejected);
+		this.mvc.perform(get("/unprotected"));
+
+		verify(requestRejectedHandler).handle(any(), any(), any());
+	}
+
 	@Test
 	@Test
 	public void getWhenUsingCustomAccessDecisionManagerThenAuthorizesAccordingly() throws Exception {
 	public void getWhenUsingCustomAccessDecisionManagerThenAuthorizesAccordingly() throws Exception {
 		this.spring.configLocations(xml("CustomAccessDecisionManager")).autowire();
 		this.spring.configLocations(xml("CustomAccessDecisionManager")).autowire();

+ 32 - 0
config/src/test/resources/org/springframework/security/config/http/MiscHttpConfigTests-RequestRejectedHandler.xml

@@ -0,0 +1,32 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Copyright 2002-2020 the original author or authors.
+  ~
+  ~ Licensed under the Apache License, Version 2.0 (the "License");
+  ~ you may not use this file except in compliance with the License.
+  ~ You may obtain a copy of the License at
+  ~
+  ~       https://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<b:beans xmlns:b="http://www.springframework.org/schema/beans"
+		xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+		xmlns="http://www.springframework.org/schema/security"
+		xsi:schemaLocation="
+			http://www.springframework.org/schema/security
+			https://www.springframework.org/schema/security/spring-security.xsd
+			http://www.springframework.org/schema/beans
+			https://www.springframework.org/schema/beans/spring-beans.xsd">
+
+	<b:import resource="MiscHttpConfigTests-HttpFirewall.xml"/>
+
+	<b:bean id="requestRejectedHandler" class="org.mockito.Mockito" factory-method="mock">
+		<b:constructor-arg value="org.springframework.security.web.firewall.RequestRejectedHandler"/>
+	</b:bean>
+</b:beans>