Browse Source

Add SecurityContextRepository.loadContext(HttpServletRequest)

This allows loading the SecurityContext lazily, without the need for the
response, and does not attempt to automatically save the request when
the response is comitted.

Closes gh-11028
Rob Winch 3 years ago
parent
commit
67fd46bfa6

+ 2 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configurers/SecurityContextConfigurerTests.java

@@ -81,7 +81,8 @@ public class SecurityContextConfigurerTests {
 	@Test
 	@Test
 	public void securityContextWhenInvokedTwiceThenUsesOriginalSecurityContextRepository() throws Exception {
 	public void securityContextWhenInvokedTwiceThenUsesOriginalSecurityContextRepository() throws Exception {
 		this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire();
 		this.spring.register(DuplicateDoesNotOverrideConfig.class).autowire();
-		given(DuplicateDoesNotOverrideConfig.SCR.loadContext(any())).willReturn(mock(SecurityContext.class));
+		given(DuplicateDoesNotOverrideConfig.SCR.loadContext(any(HttpRequestResponseHolder.class)))
+				.willReturn(mock(SecurityContext.class));
 		this.mvc.perform(get("/"));
 		this.mvc.perform(get("/"));
 		verify(DuplicateDoesNotOverrideConfig.SCR).loadContext(any(HttpRequestResponseHolder.class));
 		verify(DuplicateDoesNotOverrideConfig.SCR).loadContext(any(HttpRequestResponseHolder.class));
 	}
 	}

+ 1 - 3
config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java

@@ -126,7 +126,6 @@ import static org.springframework.security.test.web.servlet.request.SecurityMock
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.x509;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.x509;
-import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@@ -470,11 +469,10 @@ public class MiscHttpConfigTests {
 		this.spring.configLocations(xml("ExplicitSaveAndExplicitRepository")).autowire();
 		this.spring.configLocations(xml("ExplicitSaveAndExplicitRepository")).autowire();
 		SecurityContextRepository repository = this.spring.getContext().getBean(SecurityContextRepository.class);
 		SecurityContextRepository repository = this.spring.getContext().getBean(SecurityContextRepository.class);
 		SecurityContext context = new SecurityContextImpl(new TestingAuthenticationToken("user", "password"));
 		SecurityContext context = new SecurityContextImpl(new TestingAuthenticationToken("user", "password"));
-		given(repository.loadContext(any(HttpRequestResponseHolder.class))).willReturn(context);
+		given(repository.loadContext(any(HttpServletRequest.class))).willReturn(() -> context);
 		// @formatter:off
 		// @formatter:off
 		MvcResult result = this.mvc.perform(formLogin())
 		MvcResult result = this.mvc.perform(formLogin())
 				.andExpect(status().is3xxRedirection())
 				.andExpect(status().is3xxRedirection())
-				.andExpect(authenticated())
 				.andReturn();
 				.andReturn();
 		// @formatter:on
 		// @formatter:on
 		verify(repository, atLeastOnce()).saveContext(any(SecurityContext.class), any(HttpServletRequest.class),
 		verify(repository, atLeastOnce()).saveContext(any(SecurityContext.class), any(HttpServletRequest.class),

+ 7 - 4
web/src/main/java/org/springframework/security/web/context/RequestAttributeSecurityContextRepository.java

@@ -16,6 +16,8 @@
 
 
 package org.springframework.security.web.context;
 package org.springframework.security.web.context;
 
 
+import java.util.function.Supplier;
+
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import javax.servlet.http.HttpServletResponse;
 
 
@@ -64,17 +66,18 @@ public final class RequestAttributeSecurityContextRepository implements Security
 
 
 	@Override
 	@Override
 	public boolean containsContext(HttpServletRequest request) {
 	public boolean containsContext(HttpServletRequest request) {
-		return loadContext(request) != null;
+		return loadContext(request).get() != null;
 	}
 	}
 
 
 	@Override
 	@Override
 	public SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder) {
 	public SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder) {
-		SecurityContext context = loadContext(requestResponseHolder.getRequest());
+		SecurityContext context = loadContext(requestResponseHolder.getRequest()).get();
 		return (context != null) ? context : SecurityContextHolder.createEmptyContext();
 		return (context != null) ? context : SecurityContextHolder.createEmptyContext();
 	}
 	}
 
 
-	private SecurityContext loadContext(HttpServletRequest request) {
-		return (SecurityContext) request.getAttribute(this.requestAttributeName);
+	@Override
+	public Supplier<SecurityContext> loadContext(HttpServletRequest request) {
+		return () -> (SecurityContext) request.getAttribute(this.requestAttributeName);
 	}
 	}
 
 
 	@Override
 	@Override

+ 1 - 2
web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java

@@ -58,8 +58,7 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter {
 	@Override
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
 			throws ServletException, IOException {
-		SecurityContext securityContext = this.securityContextRepository
-				.loadContext(new HttpRequestResponseHolder(request, response));
+		SecurityContext securityContext = this.securityContextRepository.loadContext(request).get();
 		try {
 		try {
 			SecurityContextHolder.setContext(securityContext);
 			SecurityContextHolder.setContext(securityContext);
 			filterChain.doFilter(request, response);
 			filterChain.doFilter(request, response);

+ 16 - 0
web/src/main/java/org/springframework/security/web/context/SecurityContextRepository.java

@@ -16,6 +16,8 @@
 
 
 package org.springframework.security.web.context;
 package org.springframework.security.web.context;
 
 
+import java.util.function.Supplier;
+
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import javax.servlet.http.HttpServletResponse;
 
 
@@ -61,6 +63,20 @@ public interface SecurityContextRepository {
 	 */
 	 */
 	SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder);
 	SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder);
 
 
+	/**
+	 * Obtains the security context for the supplied request. For an unauthenticated user,
+	 * an empty context implementation should be returned. This method should not return
+	 * null.
+	 * @param request the {@link HttpServletRequest} to load the {@link SecurityContext}
+	 * from
+	 * @return a {@link Supplier} that returns the {@link SecurityContext} which cannot be
+	 * null.
+	 * @since 5.7
+	 */
+	default Supplier<SecurityContext> loadContext(HttpServletRequest request) {
+		return () -> loadContext(new HttpRequestResponseHolder(request, null));
+	}
+
 	/**
 	/**
 	 * Stores the security context on completion of a request.
 	 * Stores the security context on completion of a request.
 	 * @param context the non-null context which was obtained from the holder.
 	 * @param context the non-null context which was obtained from the holder.

+ 28 - 0
web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java

@@ -64,6 +64,7 @@ import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
 
 
 /**
 /**
  * @author Luke Taylor
  * @author Luke Taylor
@@ -142,6 +143,33 @@ public class HttpSessionSecurityContextRepositoryTests {
 		assertThat(repo.loadContext(holder)).isEqualTo(SecurityContextHolder.createEmptyContext());
 		assertThat(repo.loadContext(holder)).isEqualTo(SecurityContextHolder.createEmptyContext());
 	}
 	}
 
 
+	@Test
+	public void loadContextHttpServletRequestWhenNotSavedThenEmptyContextReturned() {
+		HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		assertThat(repo.loadContext(request).get()).isEqualTo(SecurityContextHolder.createEmptyContext());
+	}
+
+	@Test
+	public void loadContextHttpServletRequestWhenSavedThenSavedContextReturned() {
+		SecurityContextImpl expectedContext = new SecurityContextImpl(this.testToken);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
+		repo.saveContext(expectedContext, request, response);
+		assertThat(repo.loadContext(request).get()).isEqualTo(expectedContext);
+	}
+
+	@Test
+	public void loadContextHttpServletRequestWhenNotAccessedThenHttpSessionNotAccessed() {
+		HttpSession session = mock(HttpSession.class);
+		HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setSession(session);
+		repo.loadContext(request);
+		verifyNoInteractions(session);
+	}
+
 	@Test
 	@Test
 	public void existingContextIsSuccessFullyLoadedFromSessionAndSavedBack() {
 	public void existingContextIsSuccessFullyLoadedFromSessionAndSavedBack() {
 		HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
 		HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();

+ 2 - 5
web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java

@@ -50,11 +50,8 @@ class SecurityContextHolderFilterTests {
 	@Mock
 	@Mock
 	private HttpServletResponse response;
 	private HttpServletResponse response;
 
 
-	@Mock
-	private FilterChain chain;
-
 	@Captor
 	@Captor
-	private ArgumentCaptor<HttpRequestResponseHolder> requestResponse;
+	private ArgumentCaptor<HttpServletRequest> requestArg;
 
 
 	private SecurityContextHolderFilter filter;
 	private SecurityContextHolderFilter filter;
 
 
@@ -72,7 +69,7 @@ class SecurityContextHolderFilterTests {
 	void doFilterThenSetsAndClearsSecurityContextHolder() throws Exception {
 	void doFilterThenSetsAndClearsSecurityContextHolder() throws Exception {
 		Authentication authentication = TestAuthentication.authenticatedUser();
 		Authentication authentication = TestAuthentication.authenticatedUser();
 		SecurityContext expectedContext = new SecurityContextImpl(authentication);
 		SecurityContext expectedContext = new SecurityContextImpl(authentication);
-		given(this.repository.loadContext(this.requestResponse.capture())).willReturn(expectedContext);
+		given(this.repository.loadContext(this.requestArg.capture())).willReturn(() -> expectedContext);
 		FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext())
 		FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext())
 				.isEqualTo(expectedContext);
 				.isEqualTo(expectedContext);