Просмотр исходного кода

Post-process AuthenticationRequestFilter

Fixes gh-8552
Josh Cummings 5 лет назад
Родитель
Сommit
51a0cffd36

+ 2 - 2
config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java

@@ -323,9 +323,9 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
 		private Filter build(B http) {
 			Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
 
-			return new Saml2WebSsoAuthenticationRequestFilter(
+			return postProcess(new Saml2WebSsoAuthenticationRequestFilter(
 							Saml2LoginConfigurer.this.relyingPartyRegistrationRepository,
-							authenticationRequestResolver);
+							authenticationRequestResolver));
 		}
 
 		private Saml2AuthenticationRequestFactory getResolver(B http) {

+ 56 - 0
config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

@@ -23,6 +23,7 @@ import java.util.Base64;
 import java.util.Collection;
 import java.util.Collections;
 import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -55,9 +56,13 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
 import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
 import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
+import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
+import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.context.HttpRequestResponseHolder;
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
@@ -66,10 +71,15 @@ import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.test.web.servlet.MockMvc;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
 import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
+import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
 
 /**
  * Tests for different Java configuration for {@link Saml2LoginConfigurer}
@@ -133,6 +143,20 @@ public class Saml2LoginConfigurerTests {
 		validateSaml2WebSsoAuthenticationFilterConfiguration();
 	}
 
+	@Test
+	public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() throws Exception {
+		this.spring.register(CustomAuthenticationRequestContextResolver.class).autowire();
+
+		Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
+		Saml2AuthenticationRequestContextResolver resolver =
+				CustomAuthenticationRequestContextResolver.resolver;
+		when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)))
+				.thenReturn(context);
+		this.mvc.perform(get("/saml2/authenticate/registration-id"))
+				.andExpect(status().isFound());
+		verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
+	}
+
 	private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
 		// get the OpenSamlAuthenticationProvider
 		Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@@ -219,6 +243,38 @@ public class Saml2LoginConfigurerTests {
 		}
 	}
 
+	@EnableWebSecurity
+	@Import(Saml2LoginConfigBeans.class)
+	static class CustomAuthenticationRequestContextResolver extends WebSecurityConfigurerAdapter {
+		private static final Saml2AuthenticationRequestContextResolver resolver =
+				mock(Saml2AuthenticationRequestContextResolver.class);
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter> processor
+					= new ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter>() {
+				@Override
+				public <O extends Saml2WebSsoAuthenticationRequestFilter> O postProcess(O filter) {
+					filter.setAuthenticationRequestContextResolver(resolver);
+					return filter;
+				}
+			};
+
+			http
+				.authorizeRequests(authz -> authz
+						.anyRequest().authenticated()
+				)
+				.saml2Login(saml2 -> saml2
+						.addObjectPostProcessor(processor)
+				);
+		}
+
+		@Bean
+		Saml2AuthenticationRequestContextResolver resolver() {
+			return resolver;
+		}
+	}
+
 	private static AuthenticationManager getAuthenticationManagerMock(String role) {
 		return new AuthenticationManager() {