소스 검색

Add postProcess support to Saml2LogoutConfigurer

Closes gh-10311
Gaurav Tiwari 3 년 전
부모
커밋
33708e61fb

+ 3 - 3
config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java

@@ -253,7 +253,7 @@ public final class Saml2LogoutConfigurer<H extends HttpSecurityBuilder<H>>
 		Saml2LogoutRequestFilter filter = new Saml2LogoutRequestFilter(registrations,
 				this.logoutRequestConfigurer.logoutRequestValidator(), logoutResponseResolver, logoutHandlers);
 		filter.setLogoutRequestMatcher(createLogoutRequestMatcher());
-		return filter;
+		return postProcess(filter);
 	}
 
 	private Saml2LogoutResponseFilter createLogoutResponseProcessingFilter(
@@ -262,7 +262,7 @@ public final class Saml2LogoutConfigurer<H extends HttpSecurityBuilder<H>>
 				this.logoutResponseConfigurer.logoutResponseValidator(), this.logoutSuccessHandler);
 		logoutResponseFilter.setLogoutRequestMatcher(createLogoutResponseMatcher());
 		logoutResponseFilter.setLogoutRequestRepository(this.logoutRequestConfigurer.logoutRequestRepository);
-		return logoutResponseFilter;
+		return postProcess(logoutResponseFilter);
 	}
 
 	private LogoutFilter createRelyingPartyLogoutFilter(RelyingPartyRegistrationResolver registrations) {
@@ -271,7 +271,7 @@ public final class Saml2LogoutConfigurer<H extends HttpSecurityBuilder<H>>
 				registrations);
 		LogoutFilter logoutFilter = new LogoutFilter(logoutRequestSuccessHandler, logoutHandlers);
 		logoutFilter.setLogoutRequestMatcher(createLogoutMatcher());
-		return logoutFilter;
+		return postProcess(logoutFilter);
 	}
 
 	private RequestMatcher createLogoutMatcher() {

+ 112 - 0
config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java

@@ -37,6 +37,7 @@ import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpSession;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.config.annotation.ObjectPostProcessor;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.test.SpringTestContext;
@@ -59,12 +60,16 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
 import org.springframework.security.saml2.provider.service.web.authentication.logout.HttpSessionLogoutRequestRepository;
+import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestFilter;
 import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestRepository;
 import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestResolver;
+import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutResponseFilter;
 import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutResponseResolver;
 import org.springframework.security.web.SecurityFilterChain;
+import org.springframework.security.web.authentication.logout.LogoutFilter;
 import org.springframework.security.web.authentication.logout.LogoutHandler;
 import org.springframework.security.web.authentication.logout.LogoutSuccessHandler;
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MvcResult;
 
@@ -75,6 +80,8 @@ import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.mock;
 import static org.mockito.BDDMockito.verify;
 import static org.mockito.BDDMockito.verifyNoInteractions;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.spy;
 import static org.springframework.security.config.Customizer.withDefaults;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
@@ -346,6 +353,47 @@ public class Saml2LogoutConfigurerTests {
 		verify(getBean(Saml2LogoutResponseValidator.class)).validate(any());
 	}
 
+	@Test
+	public void saml2LogoutWhenLogoutGetThenLogsOutAndSendsLogoutRequest() throws Exception {
+		this.spring.register(Saml2LogoutWithHttpGet.class).autowire();
+		MvcResult result = this.mvc.perform(get("/logout").with(authentication(this.user)))
+				.andExpect(status().isFound()).andReturn();
+		String location = result.getResponse().getHeader("Location");
+		LogoutHandler logoutHandler = this.spring.getContext().getBean(LogoutHandler.class);
+		assertThat(location).startsWith("https://ap.example.org/logout/saml2/request");
+		verify(logoutHandler).logout(any(), any(), any());
+	}
+
+	@Test
+	public void saml2LogoutWhenSaml2LogoutRequestFilterPostProcessedThenUses() {
+
+		Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class);
+		this.spring.register(Saml2DefaultsWithObjectPostProcessorConfig.class).autowire();
+		verify(Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor)
+				.postProcess(any(Saml2LogoutRequestFilter.class));
+
+	}
+
+	@Test
+	public void saml2LogoutWhenSaml2LogoutResponseFilterPostProcessedThenUses() {
+
+		Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class);
+		this.spring.register(Saml2DefaultsWithObjectPostProcessorConfig.class).autowire();
+		verify(Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor)
+				.postProcess(any(Saml2LogoutResponseFilter.class));
+
+	}
+
+	@Test
+	public void saml2LogoutWhenLogoutFilterPostProcessedThenUses() {
+
+		Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor = spy(ReflectingObjectPostProcessor.class);
+		this.spring.register(Saml2DefaultsWithObjectPostProcessorConfig.class).autowire();
+		verify(Saml2DefaultsWithObjectPostProcessorConfig.objectPostProcessor, atLeastOnce())
+				.postProcess(any(LogoutFilter.class));
+
+	}
+
 	private <T> T getBean(Class<T> clazz) {
 		return this.spring.getContext().getBean(clazz);
 	}
@@ -401,6 +449,61 @@ public class Saml2LogoutConfigurerTests {
 
 	}
 
+	@EnableWebSecurity
+	@Import(Saml2LoginConfigBeans.class)
+	static class Saml2LogoutWithHttpGet {
+
+		LogoutHandler mockLogoutHandler = mock(LogoutHandler.class);
+
+		@Bean
+		SecurityFilterChain web(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+				.authorizeRequests((authorize) -> authorize.anyRequest().authenticated())
+				.logout((logout) -> logout.addLogoutHandler(this.mockLogoutHandler))
+				.saml2Login(withDefaults())
+				.saml2Logout((saml2) -> saml2.addObjectPostProcessor(new ObjectPostProcessor<LogoutFilter>() {
+					@Override
+					public <O extends LogoutFilter> O postProcess(O filter) {
+						filter.setLogoutRequestMatcher(new AntPathRequestMatcher("/logout", "GET"));
+						return filter;
+					}
+				}));
+			return http.build();
+			// @formatter:on
+		}
+
+		@Bean
+		LogoutHandler logoutHandler() {
+			return this.mockLogoutHandler;
+		}
+
+	}
+
+	@EnableWebSecurity
+	@Import(Saml2LoginConfigBeans.class)
+	static class Saml2DefaultsWithObjectPostProcessorConfig {
+
+		static ObjectPostProcessor<Object> objectPostProcessor;
+
+		@Bean
+		SecurityFilterChain web(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+				.authorizeRequests((authorize) -> authorize.anyRequest().authenticated())
+				.saml2Login(withDefaults())
+				.saml2Logout(withDefaults());
+			return http.build();
+			// @formatter:on
+		}
+
+		@Bean
+		static ObjectPostProcessor<Object> objectPostProcessor() {
+			return objectPostProcessor;
+		}
+
+	}
+
 	@EnableWebSecurity
 	@Import(Saml2LoginConfigBeans.class)
 	static class Saml2LogoutComponentsConfig {
@@ -490,4 +593,13 @@ public class Saml2LogoutConfigurerTests {
 
 	}
 
+	static class ReflectingObjectPostProcessor implements ObjectPostProcessor<Object> {
+
+		@Override
+		public <O> O postProcess(O object) {
+			return object;
+		}
+
+	}
+
 }