|
@@ -18,6 +18,7 @@ package org.springframework.security.web.context;
|
|
|
|
|
|
import java.util.function.Supplier;
|
|
|
|
|
|
+import jakarta.servlet.DispatcherType;
|
|
|
import jakarta.servlet.FilterChain;
|
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
|
import jakarta.servlet.http.HttpServletResponse;
|
|
@@ -25,11 +26,15 @@ import org.junit.jupiter.api.AfterEach;
|
|
|
import org.junit.jupiter.api.BeforeEach;
|
|
|
import org.junit.jupiter.api.Test;
|
|
|
import org.junit.jupiter.api.extension.ExtendWith;
|
|
|
+import org.junit.jupiter.params.ParameterizedTest;
|
|
|
+import org.junit.jupiter.params.provider.EnumSource;
|
|
|
import org.mockito.ArgumentCaptor;
|
|
|
import org.mockito.Captor;
|
|
|
+import org.mockito.InOrder;
|
|
|
import org.mockito.Mock;
|
|
|
import org.mockito.junit.jupiter.MockitoExtension;
|
|
|
|
|
|
+import org.springframework.mock.web.MockFilterChain;
|
|
|
import org.springframework.security.authentication.TestAuthentication;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
import org.springframework.security.core.context.SecurityContext;
|
|
@@ -39,11 +44,17 @@ import org.springframework.security.core.context.SecurityContextImpl;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.mockito.BDDMockito.given;
|
|
|
+import static org.mockito.Mockito.inOrder;
|
|
|
+import static org.mockito.Mockito.lenient;
|
|
|
+import static org.mockito.Mockito.times;
|
|
|
import static org.mockito.Mockito.verify;
|
|
|
+import static org.mockito.Mockito.verifyNoInteractions;
|
|
|
|
|
|
@ExtendWith(MockitoExtension.class)
|
|
|
class SecurityContextHolderFilterTests {
|
|
|
|
|
|
+ private static final String FILTER_APPLIED = "org.springframework.security.web.context.SecurityContextHolderFilter.APPLIED";
|
|
|
+
|
|
|
@Mock
|
|
|
private SecurityContextRepository repository;
|
|
|
|
|
@@ -104,14 +115,38 @@ class SecurityContextHolderFilterTests {
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- void shouldNotFilterErrorDispatchWhenDefault() {
|
|
|
- assertThat(this.filter.shouldNotFilterErrorDispatch()).isFalse();
|
|
|
+ void doFilterWhenFilterAppliedThenDoNothing() throws Exception {
|
|
|
+ given(this.request.getAttribute(FILTER_APPLIED)).willReturn(true);
|
|
|
+ this.filter.doFilter(this.request, this.response, new MockFilterChain());
|
|
|
+ verify(this.request, times(1)).getAttribute(FILTER_APPLIED);
|
|
|
+ verifyNoInteractions(this.repository, this.response);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- void shouldNotFilterErrorDispatchWhenOverridden() {
|
|
|
- this.filter.setShouldNotFilterErrorDispatch(true);
|
|
|
- assertThat(this.filter.shouldNotFilterErrorDispatch()).isTrue();
|
|
|
+ void doFilterWhenNotAppliedThenSetsAndRemovesAttribute() throws Exception {
|
|
|
+ given(this.repository.loadDeferredContext(this.requestArg.capture())).willReturn(
|
|
|
+ new SupplierDeferredSecurityContext(SecurityContextHolder::createEmptyContext, this.strategy));
|
|
|
+
|
|
|
+ this.filter.doFilter(this.request, this.response, new MockFilterChain());
|
|
|
+
|
|
|
+ InOrder inOrder = inOrder(this.request, this.repository);
|
|
|
+ inOrder.verify(this.request).setAttribute(FILTER_APPLIED, true);
|
|
|
+ inOrder.verify(this.repository).loadDeferredContext(this.request);
|
|
|
+ inOrder.verify(this.request).removeAttribute(FILTER_APPLIED);
|
|
|
+ }
|
|
|
+
|
|
|
+ @ParameterizedTest
|
|
|
+ @EnumSource(DispatcherType.class)
|
|
|
+ void doFilterWhenAnyDispatcherTypeThenFilter(DispatcherType dispatcherType) throws Exception {
|
|
|
+ lenient().when(this.request.getDispatcherType()).thenReturn(dispatcherType);
|
|
|
+ Authentication authentication = TestAuthentication.authenticatedUser();
|
|
|
+ SecurityContext expectedContext = new SecurityContextImpl(authentication);
|
|
|
+ given(this.repository.loadDeferredContext(this.requestArg.capture()))
|
|
|
+ .willReturn(new SupplierDeferredSecurityContext(() -> expectedContext, this.strategy));
|
|
|
+ FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext())
|
|
|
+ .isEqualTo(expectedContext);
|
|
|
+
|
|
|
+ this.filter.doFilter(this.request, this.response, filterChain);
|
|
|
}
|
|
|
|
|
|
}
|