Browse Source

Use PortResolver Beans by Default

Closes gh-16664
Rob Winch 5 months ago
parent
commit
76a566265c

+ 14 - 0
config/src/main/java/org/springframework/security/config/annotation/web/configurers/AbstractAuthenticationFilterConfigurer.java

@@ -21,6 +21,7 @@ import java.util.Collections;
 
 
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletRequest;
 
 
+import org.springframework.context.ApplicationContext;
 import org.springframework.http.MediaType;
 import org.springframework.http.MediaType;
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManager;
@@ -28,6 +29,7 @@ import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.PortMapper;
 import org.springframework.security.web.PortMapper;
+import org.springframework.security.web.PortResolver;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
@@ -272,6 +274,10 @@ public abstract class AbstractAuthenticationFilterConfigurer<B extends HttpSecur
 		if (portMapper != null) {
 		if (portMapper != null) {
 			this.authenticationEntryPoint.setPortMapper(portMapper);
 			this.authenticationEntryPoint.setPortMapper(portMapper);
 		}
 		}
+		PortResolver portResolver = getBeanOrNull(http, PortResolver.class);
+		if (portResolver != null) {
+			this.authenticationEntryPoint.setPortResolver(portResolver);
+		}
 		RequestCache requestCache = http.getSharedObject(RequestCache.class);
 		RequestCache requestCache = http.getSharedObject(RequestCache.class);
 		if (requestCache != null) {
 		if (requestCache != null) {
 			this.defaultSuccessHandler.setRequestCache(requestCache);
 			this.defaultSuccessHandler.setRequestCache(requestCache);
@@ -412,6 +418,14 @@ public abstract class AbstractAuthenticationFilterConfigurer<B extends HttpSecur
 		this.authenticationEntryPoint = new LoginUrlAuthenticationEntryPoint(loginPage);
 		this.authenticationEntryPoint = new LoginUrlAuthenticationEntryPoint(loginPage);
 	}
 	}
 
 
+	private <C> C getBeanOrNull(B http, Class<C> clazz) {
+		ApplicationContext context = http.getSharedObject(ApplicationContext.class);
+		if (context == null) {
+			return null;
+		}
+		return context.getBeanProvider(clazz).getIfUnique();
+	}
+
 	@SuppressWarnings("unchecked")
 	@SuppressWarnings("unchecked")
 	private T getSelf() {
 	private T getSelf() {
 		return (T) this;
 		return (T) this;

+ 7 - 1
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java

@@ -83,6 +83,7 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
 import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.AuthenticationEntryPoint;
+import org.springframework.security.web.PortResolver;
 import org.springframework.security.web.RedirectStrategy;
 import org.springframework.security.web.RedirectStrategy;
 import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint;
 import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint;
 import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
 import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
@@ -578,8 +579,13 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 				new RequestHeaderRequestMatcher("X-Requested-With", "XMLHttpRequest"));
 				new RequestHeaderRequestMatcher("X-Requested-With", "XMLHttpRequest"));
 		RequestMatcher formLoginNotEnabled = getFormLoginNotEnabledRequestMatcher(http);
 		RequestMatcher formLoginNotEnabled = getFormLoginNotEnabledRequestMatcher(http);
 		LinkedHashMap<RequestMatcher, AuthenticationEntryPoint> entryPoints = new LinkedHashMap<>();
 		LinkedHashMap<RequestMatcher, AuthenticationEntryPoint> entryPoints = new LinkedHashMap<>();
+		LoginUrlAuthenticationEntryPoint loginUrlEntryPoint = new LoginUrlAuthenticationEntryPoint(providerLoginPage);
+		PortResolver portResolver = getBeanOrNull(ResolvableType.forClass(PortResolver.class));
+		if (portResolver != null) {
+			loginUrlEntryPoint.setPortResolver(portResolver);
+		}
 		entryPoints.put(new AndRequestMatcher(notXRequestedWith, new NegatedRequestMatcher(defaultLoginPageMatcher),
 		entryPoints.put(new AndRequestMatcher(notXRequestedWith, new NegatedRequestMatcher(defaultLoginPageMatcher),
-				formLoginNotEnabled), new LoginUrlAuthenticationEntryPoint(providerLoginPage));
+				formLoginNotEnabled), loginUrlEntryPoint);
 		DelegatingAuthenticationEntryPoint loginEntryPoint = new DelegatingAuthenticationEntryPoint(entryPoints);
 		DelegatingAuthenticationEntryPoint loginEntryPoint = new DelegatingAuthenticationEntryPoint(entryPoints);
 		loginEntryPoint.setDefaultEntryPoint(this.getAuthenticationEntryPoint());
 		loginEntryPoint.setDefaultEntryPoint(this.getAuthenticationEntryPoint());
 		return loginEntryPoint;
 		return loginEntryPoint;

+ 7 - 1
config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java

@@ -52,6 +52,7 @@ import org.springframework.security.saml2.provider.service.web.authentication.Op
 import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
 import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
 import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter;
 import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter;
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.AuthenticationEntryPoint;
+import org.springframework.security.web.PortResolver;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint;
 import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint;
 import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
 import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
@@ -344,8 +345,13 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		RequestMatcher notXRequestedWith = new NegatedRequestMatcher(
 		RequestMatcher notXRequestedWith = new NegatedRequestMatcher(
 				new RequestHeaderRequestMatcher("X-Requested-With", "XMLHttpRequest"));
 				new RequestHeaderRequestMatcher("X-Requested-With", "XMLHttpRequest"));
 		LinkedHashMap<RequestMatcher, AuthenticationEntryPoint> entryPoints = new LinkedHashMap<>();
 		LinkedHashMap<RequestMatcher, AuthenticationEntryPoint> entryPoints = new LinkedHashMap<>();
+		LoginUrlAuthenticationEntryPoint loginUrlEntryPoint = new LoginUrlAuthenticationEntryPoint(providerLoginPage);
+		PortResolver portResolver = getBeanOrNull(http, PortResolver.class);
+		if (portResolver != null) {
+			loginUrlEntryPoint.setPortResolver(portResolver);
+		}
 		entryPoints.put(new AndRequestMatcher(notXRequestedWith, new NegatedRequestMatcher(defaultLoginPageMatcher)),
 		entryPoints.put(new AndRequestMatcher(notXRequestedWith, new NegatedRequestMatcher(defaultLoginPageMatcher)),
-				new LoginUrlAuthenticationEntryPoint(providerLoginPage));
+				loginUrlEntryPoint);
 		DelegatingAuthenticationEntryPoint loginEntryPoint = new DelegatingAuthenticationEntryPoint(entryPoints);
 		DelegatingAuthenticationEntryPoint loginEntryPoint = new DelegatingAuthenticationEntryPoint(entryPoints);
 		loginEntryPoint.setDefaultEntryPoint(this.getAuthenticationEntryPoint());
 		loginEntryPoint.setDefaultEntryPoint(this.getAuthenticationEntryPoint());
 		return loginEntryPoint;
 		return loginEntryPoint;

+ 4 - 0
config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java

@@ -240,6 +240,10 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
 	}
 	}
 
 
 	private RuntimeBeanReference createPortResolver(BeanReference portMapper, ParserContext pc) {
 	private RuntimeBeanReference createPortResolver(BeanReference portMapper, ParserContext pc) {
+		String beanName = "portResolver";
+		if (pc.getRegistry().containsBeanDefinition(beanName)) {
+			return new RuntimeBeanReference(beanName);
+		}
 		RootBeanDefinition portResolver = new RootBeanDefinition(PortResolverImpl.class);
 		RootBeanDefinition portResolver = new RootBeanDefinition(PortResolverImpl.class);
 		portResolver.getPropertyValues().addPropertyValue("portMapper", portMapper);
 		portResolver.getPropertyValues().addPropertyValue("portMapper", portMapper);
 		String portResolverName = pc.getReaderContext().generateBeanName(portResolver);
 		String portResolverName = pc.getReaderContext().generateBeanName(portResolver);

+ 37 - 0
config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java

@@ -38,6 +38,7 @@ import org.springframework.security.core.userdetails.UserDetailsService;
 import org.springframework.security.provisioning.InMemoryUserDetailsManager;
 import org.springframework.security.provisioning.InMemoryUserDetailsManager;
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders;
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders;
 import org.springframework.security.web.PortMapper;
 import org.springframework.security.web.PortMapper;
+import org.springframework.security.web.PortResolver;
 import org.springframework.security.web.SecurityFilterChain;
 import org.springframework.security.web.SecurityFilterChain;
 import org.springframework.security.web.access.ExceptionTranslationFilter;
 import org.springframework.security.web.access.ExceptionTranslationFilter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
@@ -378,6 +379,13 @@ public class FormLoginConfigurerTests {
 		verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(ExceptionTranslationFilter.class));
 		verify(ObjectPostProcessorConfig.objectPostProcessor).postProcess(any(ExceptionTranslationFilter.class));
 	}
 	}
 
 
+	@Test
+	public void configureWhenPortResolverBeanThenPortResolverUsed() throws Exception {
+		this.spring.register(CustomPortResolverConfig.class).autowire();
+		this.mockMvc.perform(get("/requires-authentication")).andExpect(status().is3xxRedirection());
+		verify(this.spring.getContext().getBean(PortResolver.class)).getServerPort(any());
+	}
+
 	@Configuration
 	@Configuration
 	@EnableWebSecurity
 	@EnableWebSecurity
 	static class RequestCacheConfig {
 	static class RequestCacheConfig {
@@ -723,6 +731,35 @@ public class FormLoginConfigurerTests {
 
 
 	}
 	}
 
 
+	@Configuration
+	@EnableWebSecurity
+	static class CustomPortResolverConfig {
+
+		@Bean
+		SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+				.authorizeHttpRequests((requests) -> requests
+					.anyRequest().authenticated()
+				)
+				.formLogin(withDefaults())
+				.requestCache(withDefaults());
+			return http.build();
+			// @formatter:on
+		}
+
+		@Bean
+		PortResolver portResolver() {
+			return mock(PortResolver.class);
+		}
+
+		@Bean
+		UserDetailsService userDetailsService() {
+			return new InMemoryUserDetailsManager(PasswordEncodedUser.user());
+		}
+
+	}
+
 	static class ReflectingObjectPostProcessor implements ObjectPostProcessor<Object> {
 	static class ReflectingObjectPostProcessor implements ObjectPostProcessor<Object> {
 
 
 		@Override
 		@Override

+ 13 - 0
config/src/test/java/org/springframework/security/config/http/FormLoginConfigTests.java

@@ -35,6 +35,7 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.FilterChainProxy;
+import org.springframework.security.web.PortResolver;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
 import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
@@ -45,6 +46,7 @@ import org.springframework.web.bind.annotation.RestController;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
@@ -210,6 +212,17 @@ public class FormLoginConfigTests {
 		// @formatter:on
 		// @formatter:on
 	}
 	}
 
 
+	@Test
+	public void portResolver() throws Exception {
+		this.spring.configLocations(this.xml("PortResolverBean")).autowire();
+		// @formatter:off
+		this.mvc.perform(get("/requires-authentication"))
+				.andExpect(status().is3xxRedirection());
+		// @formatter:on
+		PortResolver portResolver = this.spring.getContext().getBean(PortResolver.class);
+		verify(portResolver, atLeastOnce()).getServerPort(any());
+	}
+
 	private Filter getFilter(ApplicationContext context, Class<? extends Filter> filterClass) {
 	private Filter getFilter(ApplicationContext context, Class<? extends Filter> filterClass) {
 		FilterChainProxy filterChain = context.getBean(BeanIds.FILTER_CHAIN_PROXY, FilterChainProxy.class);
 		FilterChainProxy filterChain = context.getBean(BeanIds.FILTER_CHAIN_PROXY, FilterChainProxy.class);
 		List<Filter> filters = filterChain.getFilters("/any");
 		List<Filter> filters = filterChain.getFilters("/any");

+ 24 - 0
config/src/test/kotlin/org/springframework/security/config/annotation/web/FormLoginDslTests.kt

@@ -34,6 +34,7 @@ import org.springframework.security.core.userdetails.User
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf
 import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated
 import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated
+import org.springframework.security.web.PortResolver
 import org.springframework.security.web.SecurityFilterChain
 import org.springframework.security.web.SecurityFilterChain
 import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler
 import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler
 import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler
 import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler
@@ -240,6 +241,29 @@ class FormLoginDslTests {
         }
         }
     }
     }
 
 
+    @Test
+    fun `portResolerBean is used`() {
+        this.spring.register(PortResolverBeanConfig::class.java, AllSecuredConfig::class.java, UserConfig::class.java).autowire()
+
+        val portResolver = this.spring.context.getBean(PortResolver::class.java)
+        every { portResolver.getServerPort(any()) }.returns(1234)
+        this.mockMvc.get("/")
+            .andExpect {
+                status().isFound
+                redirectedUrl("http://localhost:1234/login")
+            }
+
+        verify { portResolver.getServerPort(any()) }
+    }
+
+    @Configuration
+    open class PortResolverBeanConfig {
+        @Bean
+        open fun portResolverBean(): PortResolver {
+            return mockk()
+        }
+    }
+
     @Test
     @Test
     fun `login when custom failure url then used`() {
     fun `login when custom failure url then used`() {
         this.spring.register(FailureHandlerConfig::class.java, UserConfig::class.java).autowire()
         this.spring.register(FailureHandlerConfig::class.java, UserConfig::class.java).autowire()

+ 37 - 0
config/src/test/resources/org/springframework/security/config/http/FormLoginConfigTests-PortResolverBean.xml

@@ -0,0 +1,37 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Copyright 2002-2022 the original author or authors.
+  ~
+  ~ Licensed under the Apache License, Version 2.0 (the "License");
+  ~ you may not use this file except in compliance with the License.
+  ~ You may obtain a copy of the License at
+  ~
+  ~       https://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<b:beans xmlns:b="http://www.springframework.org/schema/beans"
+		xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+		xmlns="http://www.springframework.org/schema/security"
+		xsi:schemaLocation="
+			http://www.springframework.org/schema/security
+			https://www.springframework.org/schema/security/spring-security.xsd
+			http://www.springframework.org/schema/beans
+			https://www.springframework.org/schema/beans/spring-beans.xsd">
+
+	<b:bean id="portResolver" class="org.mockito.Mockito" factory-method="mock" scope="singleton">
+		<b:constructor-arg value="org.springframework.security.web.PortResolver" type="java.lang.Class"/>
+	</b:bean>
+
+	<http auto-config="true">
+		<csrf disabled="true"/>
+		<intercept-url pattern="/**" access="authenticated"/>
+	</http>
+
+	<b:import resource="userservice.xml"/>
+</b:beans>