Przeglądaj źródła

Parameterize getFilter() method in HttpSecurityBeanDefinitionParserTests.

Removes the need for casting to specific filter type.
Luke Taylor 15 lat temu
rodzic
commit
51abedcbef

+ 32 - 38
config/src/test/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParserTests.java

@@ -218,7 +218,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "   <form-login default-target-url='/default' always-use-default-target='true' />" +
                 "</http>" + AUTH_PROVIDER_XML);
         // These will be matched by the default pattern "/**"
-        UsernamePasswordAuthenticationFilter filter = (UsernamePasswordAuthenticationFilter) getFilters("/anything").get(1);
+        UsernamePasswordAuthenticationFilter filter = getFilter(UsernamePasswordAuthenticationFilter.class);
         assertEquals("/default", FieldUtils.getFieldValue(filter, "successHandler.defaultTargetUrl"));
         assertEquals(Boolean.TRUE, FieldUtils.getFieldValue(filter, "successHandler.alwaysUseDefaultTargetUrl"));
     }
@@ -250,7 +250,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "   <form-login />" +
                 "   <anonymous enabled='true' username='joe' granted-authority='anonymity' key='customKey' />" +
                 "</http>" + AUTH_PROVIDER_XML);
-        AnonymousAuthenticationFilter filter = (AnonymousAuthenticationFilter) getFilters("/anything").get(5);
+        AnonymousAuthenticationFilter filter = getFilter(AnonymousAuthenticationFilter.class);
         assertEquals("customKey", filter.getKey());
         assertEquals("joe", filter.getUserAttribute().getPassword());
         assertEquals("anonymity", filter.getUserAttribute().getAuthorities().get(0).getAuthority());
@@ -313,7 +313,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "        <intercept-url pattern='/**' access='ROLE_C' />" +
                 "    </http>" + AUTH_PROVIDER_XML);
 
-        FilterSecurityInterceptor fis = (FilterSecurityInterceptor) getFilter(FilterSecurityInterceptor.class);
+        FilterSecurityInterceptor fis = getFilter(FilterSecurityInterceptor.class);
 
         FilterInvocationSecurityMetadataSource fids = fis.getSecurityMetadataSource();
         Collection<ConfigAttribute> attrDef = fids.getAttributes(createFilterinvocation("/Secure", null));
@@ -363,7 +363,7 @@ public class HttpSecurityBeanDefinitionParserTests {
 
     private void checkPropertyValues() throws Exception {
         // Check the security attribute
-        FilterSecurityInterceptor fis = (FilterSecurityInterceptor) getFilter(FilterSecurityInterceptor.class);
+        FilterSecurityInterceptor fis = getFilter(FilterSecurityInterceptor.class);
         FilterInvocationSecurityMetadataSource fids = fis.getSecurityMetadataSource();
         Collection<ConfigAttribute> attrs = fids.getAttributes(createFilterinvocation("/secure", null));
         assertNotNull(attrs);
@@ -371,12 +371,11 @@ public class HttpSecurityBeanDefinitionParserTests {
         assertTrue(attrs.contains(new SecurityConfig("ROLE_A")));
 
         // Check the form login properties are set
-        UsernamePasswordAuthenticationFilter apf = (UsernamePasswordAuthenticationFilter)
-                getFilter(UsernamePasswordAuthenticationFilter.class);
+        UsernamePasswordAuthenticationFilter apf = getFilter(UsernamePasswordAuthenticationFilter.class);
         assertEquals("/defaultTarget", FieldUtils.getFieldValue(apf, "successHandler.defaultTargetUrl"));
         assertEquals("/authFailure", FieldUtils.getFieldValue(apf, "failureHandler.defaultFailureUrl"));
 
-        ExceptionTranslationFilter etf = (ExceptionTranslationFilter) getFilter(ExceptionTranslationFilter.class);
+        ExceptionTranslationFilter etf = getFilter(ExceptionTranslationFilter.class);
         assertEquals("/loginPage", FieldUtils.getFieldValue(etf, "authenticationEntryPoint.loginFormUrl"));
     }
 
@@ -389,7 +388,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "        <intercept-url pattern='/secure*' method='POST' access='ROLE_A,ROLE_B' />" +
                 "    </http>" + AUTH_PROVIDER_XML);
 
-        FilterSecurityInterceptor fis = (FilterSecurityInterceptor) getFilter(FilterSecurityInterceptor.class);
+        FilterSecurityInterceptor fis = getFilter(FilterSecurityInterceptor.class);
         FilterInvocationSecurityMetadataSource fids = fis.getSecurityMetadataSource();
         Collection<ConfigAttribute> attrs = fids.getAttributes(createFilterinvocation("/secure", "POST"));
         assertEquals(2, attrs.size());
@@ -400,9 +399,8 @@ public class HttpSecurityBeanDefinitionParserTests {
     @Test
     public void oncePerRequestAttributeIsSupported() throws Exception {
         setContext("<http once-per-request='false'><http-basic /></http>" + AUTH_PROVIDER_XML);
-        List<Filter> filters = getFilters("/someurl");
 
-        FilterSecurityInterceptor fsi = (FilterSecurityInterceptor) filters.get(filters.size() - 1);
+        FilterSecurityInterceptor fsi = getFilter(FilterSecurityInterceptor.class);
 
         assertFalse(fsi.isObserveOncePerRequest());
     }
@@ -410,9 +408,8 @@ public class HttpSecurityBeanDefinitionParserTests {
     @Test
     public void accessDeniedPageAttributeIsSupported() throws Exception {
         setContext("<http access-denied-page='/access-denied'><http-basic /></http>" + AUTH_PROVIDER_XML);
-        List<Filter> filters = getFilters("/someurl");
 
-        ExceptionTranslationFilter etf = (ExceptionTranslationFilter) filters.get(filters.size() - 2);
+        ExceptionTranslationFilter etf = getFilter(ExceptionTranslationFilter.class);
 
         assertEquals("/access-denied", FieldUtils.getFieldValue(etf, "accessDeniedHandler.errorPage"));
     }
@@ -510,7 +507,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "    <http auto-config='true'>" +
                 "        <access-denied-handler error-page=\"#{'/go' + '-away'} \" />" +
                 "    </http>" + AUTH_PROVIDER_XML);
-        ExceptionTranslationFilter filter = (ExceptionTranslationFilter) getFilter(ExceptionTranslationFilter.class);
+        ExceptionTranslationFilter filter = getFilter(ExceptionTranslationFilter.class);
         assertEquals("/go-away", FieldUtils.getFieldValue(filter, "accessDeniedHandler.errorPage"));
     }
 
@@ -521,7 +518,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "    <http auto-config='true'>" +
                 "        <access-denied-handler ref='adh'/>" +
                 "    </http>" + AUTH_PROVIDER_XML);
-        ExceptionTranslationFilter filter = (ExceptionTranslationFilter) getFilter(ExceptionTranslationFilter.class);
+        ExceptionTranslationFilter filter = getFilter(ExceptionTranslationFilter.class);
         AccessDeniedHandlerImpl adh = (AccessDeniedHandlerImpl) appContext.getBean("adh");
         assertSame(adh, FieldUtils.getFieldValue(filter, "accessDeniedHandler"));
     }
@@ -532,7 +529,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "    <http auto-config='true' access-denied-page='/go-away'>" +
                 "        <access-denied-handler error-page='/go-away'/>" +
                 "    </http>" + AUTH_PROVIDER_XML);
-        ExceptionTranslationFilter filter = (ExceptionTranslationFilter) getFilter(ExceptionTranslationFilter.class);
+        ExceptionTranslationFilter filter = getFilter(ExceptionTranslationFilter.class);
         assertEquals("/go-away", FieldUtils.getFieldValue(filter, "accessDeniedHandler.errorPage"));
     }
 
@@ -543,7 +540,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "    <http auto-config='true'>" +
                 "        <access-denied-handler error-page='/go-away' ref='adh'/>" +
                 "    </http>" + AUTH_PROVIDER_XML);
-        ExceptionTranslationFilter filter = (ExceptionTranslationFilter) getFilter(ExceptionTranslationFilter.class);
+        ExceptionTranslationFilter filter = getFilter(ExceptionTranslationFilter.class);
         assertEquals("/go-away", FieldUtils.getFieldValue(filter, "accessDeniedHandler.errorPage"));
     }
 
@@ -703,9 +700,8 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "<http auto-config='true'>" +
                 "    <x509 subject-principal-regex='${subject-principal-regex}'/>" +
                 "</http>"  + AUTH_PROVIDER_XML);
-        List<Filter> filters = getFilters("/someurl");
 
-        X509AuthenticationFilter filter = (X509AuthenticationFilter) filters.get(2);
+        X509AuthenticationFilter filter = getFilter(X509AuthenticationFilter.class);
         SubjectDnX509PrincipalExtractor pe = (SubjectDnX509PrincipalExtractor) FieldUtils.getFieldValue(filter, "principalExtractor");
         Pattern p = (Pattern) FieldUtils.getFieldValue(pe, "subjectDnPattern");
         assertEquals("uid=(.*),", p.pattern());
@@ -723,7 +719,7 @@ public class HttpSecurityBeanDefinitionParserTests {
 
         assertTrue(filters.get(0) instanceof ConcurrentSessionFilter);
         assertNotNull(appContext.getBean("sr"));
-        SessionManagementFilter smf = (SessionManagementFilter) getFilter(SessionManagementFilter.class);
+        SessionManagementFilter smf = getFilter(SessionManagementFilter.class);
         assertNotNull(smf);
         checkSessionRegistry();
     }
@@ -777,7 +773,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "        <concurrency-control max-sessions='2' error-if-maximum-exceeded='true' />" +
                 "    </session-management>" +
                 "</http>"  + AUTH_PROVIDER_XML);
-        SessionManagementFilter seshFilter = (SessionManagementFilter) getFilter(SessionManagementFilter.class);
+        SessionManagementFilter seshFilter = getFilter(SessionManagementFilter.class);
         UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken("bob", "pass");
         SecurityContextHolder.getContext().setAuthentication(auth);
         // Register 2 sessions and then check a third
@@ -800,9 +796,8 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "</http>" +
                 "<b:bean id='cache' class='" + HttpSessionRequestCache.class.getName() + "'/>" +
                 AUTH_PROVIDER_XML);
-        ExceptionTranslationFilter etf = (ExceptionTranslationFilter) getFilter(ExceptionTranslationFilter.class);
         Object requestCache = appContext.getBean("cache");
-        assertSame(requestCache, FieldUtils.getFieldValue(etf, "requestCache"));
+        assertSame(requestCache, FieldUtils.getFieldValue(getFilter(ExceptionTranslationFilter.class), "requestCache"));
     }
 
     @Test
@@ -811,9 +806,8 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "<http auto-config='true' entry-point-ref='entryPoint'/>" +
                 "<b:bean id='entryPoint' class='" + MockEntryPoint.class.getName() + "'>" +
                 "</b:bean>" + AUTH_PROVIDER_XML);
-        ExceptionTranslationFilter etf = (ExceptionTranslationFilter) getFilters("/someurl").get(AUTO_CONFIG_FILTERS-2);
         assertTrue("ExceptionTranslationFilter should be configured with custom entry point",
-                etf.getAuthenticationEntryPoint() instanceof MockEntryPoint);
+                getFilter(ExceptionTranslationFilter.class).getAuthenticationEntryPoint() instanceof MockEntryPoint);
     }
 
     @SuppressWarnings("unused")
@@ -970,7 +964,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "    <intercept-url pattern='/someurl' access='ROLE_A'/>" +
                 "    <intercept-url pattern='/someurl' access='ROLE_B'/>" +
                 "</http>" + AUTH_PROVIDER_XML);
-        FilterSecurityInterceptor fis = (FilterSecurityInterceptor) getFilter(FilterSecurityInterceptor.class);
+        FilterSecurityInterceptor fis = getFilter(FilterSecurityInterceptor.class);
 
         FilterInvocationSecurityMetadataSource fids = fis.getSecurityMetadataSource();
         Collection<ConfigAttribute> attrDef = fids.getAttributes(createFilterinvocation("/someurl", null));
@@ -985,7 +979,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "<http create-session='always' security-context-repository-ref='repo'>" +
                 "    <http-basic />" +
                 "</http>" + AUTH_PROVIDER_XML);
-        SecurityContextPersistenceFilter filter = (SecurityContextPersistenceFilter) getFilter(SecurityContextPersistenceFilter.class);;
+        SecurityContextPersistenceFilter filter = getFilter(SecurityContextPersistenceFilter.class);;
         HttpSessionSecurityContextRepository repo = (HttpSessionSecurityContextRepository) appContext.getBean("repo");
         assertSame(repo, FieldUtils.getFieldValue(filter, "repo"));
         assertTrue((Boolean)FieldUtils.getFieldValue(filter, "forceEagerSessionCreation"));
@@ -1008,7 +1002,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "        <intercept-url pattern='/**' access='permitAll()' />" +
                 "    </http>" + AUTH_PROVIDER_XML);
 
-        FilterSecurityInterceptor fis = (FilterSecurityInterceptor) getFilter(FilterSecurityInterceptor.class);
+        FilterSecurityInterceptor fis = getFilter(FilterSecurityInterceptor.class);
 
         FilterInvocationSecurityMetadataSource fids = fis.getSecurityMetadataSource();
         Collection<ConfigAttribute> attrDef = fids.getAttributes(createFilterinvocation("/secure", null));
@@ -1037,7 +1031,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "<b:bean id='sh' class='" + SavedRequestAwareAuthenticationSuccessHandler.class.getName() +"'/>" +
                 "<b:bean id='fh' class='" + SimpleUrlAuthenticationFailureHandler.class.getName() + "'/>" +
                 AUTH_PROVIDER_XML);
-        UsernamePasswordAuthenticationFilter apf = (UsernamePasswordAuthenticationFilter) getFilter(UsernamePasswordAuthenticationFilter.class);
+        UsernamePasswordAuthenticationFilter apf = getFilter(UsernamePasswordAuthenticationFilter.class);
         AuthenticationSuccessHandler sh = (AuthenticationSuccessHandler) appContext.getBean("sh");
         AuthenticationFailureHandler fh = (AuthenticationFailureHandler) appContext.getBean("fh");
         assertSame(sh, FieldUtils.getFieldValue(apf, "successHandler"));
@@ -1069,7 +1063,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "   <form-login />" +
                 "</http>" +
                 AUTH_PROVIDER_XML);
-        ExceptionTranslationFilter etf = (ExceptionTranslationFilter) getFilter(ExceptionTranslationFilter.class);
+        ExceptionTranslationFilter etf = getFilter(ExceptionTranslationFilter.class);
         LoginUrlAuthenticationEntryPoint ap = (LoginUrlAuthenticationEntryPoint) etf.getAuthenticationEntryPoint();
         assertEquals("/spring_security_login", ap.getLoginFormUrl());
         // Default login filter should be present since we haven't specified any login URLs
@@ -1084,7 +1078,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "   <form-login login-page='/form_login_page' />" +
                 "</http>" +
                 AUTH_PROVIDER_XML);
-        ExceptionTranslationFilter etf = (ExceptionTranslationFilter) getFilter(ExceptionTranslationFilter.class);
+        ExceptionTranslationFilter etf = getFilter(ExceptionTranslationFilter.class);
         LoginUrlAuthenticationEntryPoint ap = (LoginUrlAuthenticationEntryPoint) etf.getAuthenticationEntryPoint();
         assertEquals("/form_login_page", ap.getLoginFormUrl());
         try {
@@ -1102,7 +1096,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "   <form-login />" +
                 "</http>" +
                 AUTH_PROVIDER_XML);
-        ExceptionTranslationFilter etf = (ExceptionTranslationFilter) getFilter(ExceptionTranslationFilter.class);
+        ExceptionTranslationFilter etf =  getFilter(ExceptionTranslationFilter.class);
         LoginUrlAuthenticationEntryPoint ap = (LoginUrlAuthenticationEntryPoint) etf.getAuthenticationEntryPoint();
         assertEquals("/openid_login", ap.getLoginFormUrl());
     }
@@ -1120,7 +1114,7 @@ public class HttpSecurityBeanDefinitionParserTests {
                 "   </openid-login>" +
                 "</http>" +
                 AUTH_PROVIDER_XML);
-        OpenIDAuthenticationFilter apf = (OpenIDAuthenticationFilter) getFilter(OpenIDAuthenticationFilter.class);
+        OpenIDAuthenticationFilter apf = getFilter(OpenIDAuthenticationFilter.class);
 
         OpenID4JavaConsumer consumer = (OpenID4JavaConsumer) FieldUtils.getFieldValue(apf, "consumer");
         List<OpenIDAttribute> attributes = (List<OpenIDAttribute>) FieldUtils.getFieldValue(consumer, "attributesToFetch");
@@ -1164,11 +1158,11 @@ public class HttpSecurityBeanDefinitionParserTests {
     }
 
     @SuppressWarnings("unchecked")
-    private List<Filter> getFilters(String url) throws Exception {
+    private <T extends Filter> List<T> getFilters(String url) throws Exception {
         FilterChainProxy fcp = (FilterChainProxy) appContext.getBean(BeanIds.FILTER_CHAIN_PROXY);
         Method getFilters = fcp.getClass().getDeclaredMethod("getFilters", String.class);
         getFilters.setAccessible(true);
-        return (List<Filter>) ReflectionUtils.invokeMethod(getFilters, fcp, new Object[] {url});
+        return (List<T>) ReflectionUtils.invokeMethod(getFilters, fcp, new Object[] {url});
     }
 
     private FilterInvocation createFilterinvocation(String path, String method) {
@@ -1181,10 +1175,10 @@ public class HttpSecurityBeanDefinitionParserTests {
         return new FilterInvocation(request, new MockHttpServletResponse(), new MockFilterChain());
     }
 
-    private Object getFilter(Class<? extends Filter> type) throws Exception {
-        List<Filter> filters = getFilters("/any");
+    private <T extends Filter> T getFilter(Class<T> type) throws Exception {
+        List<T> filters = getFilters("/any");
 
-        for (Filter f : filters) {
+        for (T f : filters) {
             if (f.getClass().isAssignableFrom(type)) {
                 return f;
             }
@@ -1194,7 +1188,7 @@ public class HttpSecurityBeanDefinitionParserTests {
     }
 
     private RememberMeServices getRememberMeServices() throws Exception {
-        return ((RememberMeAuthenticationFilter)getFilter(RememberMeAuthenticationFilter.class)).getRememberMeServices();
+        return getFilter(RememberMeAuthenticationFilter.class).getRememberMeServices();
     }
 
 }