Răsfoiți Sursa

Merge branch '5.8.x' into 6.0.x

Closes gh-12368
Marcus Da Coregio 2 ani în urmă
părinte
comite
898c36287c

+ 20 - 20
web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java

@@ -21,6 +21,8 @@ import java.util.function.Supplier;
 
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.ServletException;
+import jakarta.servlet.ServletRequest;
+import jakarta.servlet.ServletResponse;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
 
@@ -28,7 +30,7 @@ import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.util.Assert;
-import org.springframework.web.filter.OncePerRequestFilter;
+import org.springframework.web.filter.GenericFilterBean;
 
 /**
  * A {@link jakarta.servlet.Filter} that uses the {@link SecurityContextRepository} to
@@ -40,17 +42,18 @@ import org.springframework.web.filter.OncePerRequestFilter;
  * mechanisms to choose individually if authentication should be persisted.
  *
  * @author Rob Winch
+ * @author Marcus da Coregio
  * @since 5.7
  */
-public class SecurityContextHolderFilter extends OncePerRequestFilter {
+public class SecurityContextHolderFilter extends GenericFilterBean {
+
+	private static final String FILTER_APPLIED = SecurityContextHolderFilter.class.getName() + ".APPLIED";
 
 	private final SecurityContextRepository securityContextRepository;
 
 	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
 			.getContextHolderStrategy();
 
-	private boolean shouldNotFilterErrorDispatch;
-
 	/**
 	 * Creates a new instance.
 	 * @param securityContextRepository the repository to use. Cannot be null.
@@ -61,23 +64,29 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter {
 	}
 
 	@Override
-	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
+	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
+			throws IOException, ServletException {
+		doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
+	}
+
+	private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
 			throws ServletException, IOException {
+		if (request.getAttribute(FILTER_APPLIED) != null) {
+			chain.doFilter(request, response);
+			return;
+		}
+		request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
 		Supplier<SecurityContext> deferredContext = this.securityContextRepository.loadDeferredContext(request);
 		try {
 			this.securityContextHolderStrategy.setDeferredContext(deferredContext);
-			filterChain.doFilter(request, response);
+			chain.doFilter(request, response);
 		}
 		finally {
 			this.securityContextHolderStrategy.clearContext();
+			request.removeAttribute(FILTER_APPLIED);
 		}
 	}
 
-	@Override
-	protected boolean shouldNotFilterErrorDispatch() {
-		return this.shouldNotFilterErrorDispatch;
-	}
-
 	/**
 	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
 	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
@@ -89,13 +98,4 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter {
 		this.securityContextHolderStrategy = securityContextHolderStrategy;
 	}
 
-	/**
-	 * Disables {@link SecurityContextHolderFilter} for error dispatch.
-	 * @param shouldNotFilterErrorDispatch if the Filter should be disabled for error
-	 * dispatch. Default is false.
-	 */
-	public void setShouldNotFilterErrorDispatch(boolean shouldNotFilterErrorDispatch) {
-		this.shouldNotFilterErrorDispatch = shouldNotFilterErrorDispatch;
-	}
-
 }

+ 40 - 5
web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java

@@ -18,6 +18,7 @@ package org.springframework.security.web.context;
 
 import java.util.function.Supplier;
 
+import jakarta.servlet.DispatcherType;
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
@@ -25,11 +26,15 @@ import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Captor;
+import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 
+import org.springframework.mock.web.MockFilterChain;
 import org.springframework.security.authentication.TestAuthentication;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContext;
@@ -39,11 +44,17 @@ import org.springframework.security.core.context.SecurityContextImpl;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.lenient;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
 
 @ExtendWith(MockitoExtension.class)
 class SecurityContextHolderFilterTests {
 
+	private static final String FILTER_APPLIED = "org.springframework.security.web.context.SecurityContextHolderFilter.APPLIED";
+
 	@Mock
 	private SecurityContextRepository repository;
 
@@ -104,14 +115,38 @@ class SecurityContextHolderFilterTests {
 	}
 
 	@Test
-	void shouldNotFilterErrorDispatchWhenDefault() {
-		assertThat(this.filter.shouldNotFilterErrorDispatch()).isFalse();
+	void doFilterWhenFilterAppliedThenDoNothing() throws Exception {
+		given(this.request.getAttribute(FILTER_APPLIED)).willReturn(true);
+		this.filter.doFilter(this.request, this.response, new MockFilterChain());
+		verify(this.request, times(1)).getAttribute(FILTER_APPLIED);
+		verifyNoInteractions(this.repository, this.response);
 	}
 
 	@Test
-	void shouldNotFilterErrorDispatchWhenOverridden() {
-		this.filter.setShouldNotFilterErrorDispatch(true);
-		assertThat(this.filter.shouldNotFilterErrorDispatch()).isTrue();
+	void doFilterWhenNotAppliedThenSetsAndRemovesAttribute() throws Exception {
+		given(this.repository.loadDeferredContext(this.requestArg.capture())).willReturn(
+				new SupplierDeferredSecurityContext(SecurityContextHolder::createEmptyContext, this.strategy));
+
+		this.filter.doFilter(this.request, this.response, new MockFilterChain());
+
+		InOrder inOrder = inOrder(this.request, this.repository);
+		inOrder.verify(this.request).setAttribute(FILTER_APPLIED, true);
+		inOrder.verify(this.repository).loadDeferredContext(this.request);
+		inOrder.verify(this.request).removeAttribute(FILTER_APPLIED);
+	}
+
+	@ParameterizedTest
+	@EnumSource(DispatcherType.class)
+	void doFilterWhenAnyDispatcherTypeThenFilter(DispatcherType dispatcherType) throws Exception {
+		lenient().when(this.request.getDispatcherType()).thenReturn(dispatcherType);
+		Authentication authentication = TestAuthentication.authenticatedUser();
+		SecurityContext expectedContext = new SecurityContextImpl(authentication);
+		given(this.repository.loadDeferredContext(this.requestArg.capture()))
+				.willReturn(new SupplierDeferredSecurityContext(() -> expectedContext, this.strategy));
+		FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext())
+				.isEqualTo(expectedContext);
+
+		this.filter.doFilter(this.request, this.response, filterChain);
 	}
 
 }