|
@@ -248,6 +248,72 @@ public class SecurityContextHolderAwareRequestFilterTests {
|
|
|
verifyZeroInteractions(authenticationEntryPoint, authenticationManager, logoutHandler);
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void getAsyncContextStart() throws Exception {
|
|
|
+ ArgumentCaptor<Runnable> runnableCaptor = ArgumentCaptor.forClass(Runnable.class);
|
|
|
+ SecurityContext context = SecurityContextHolder.createEmptyContext();
|
|
|
+ TestingAuthenticationToken expectedAuth = new TestingAuthenticationToken("user", "password","ROLE_USER");
|
|
|
+ context.setAuthentication(expectedAuth);
|
|
|
+ SecurityContextHolder.setContext(context);
|
|
|
+ AsyncContext asyncContext = mock(AsyncContext.class);
|
|
|
+ when(request.getAsyncContext()).thenReturn(asyncContext);
|
|
|
+ Runnable runnable = new Runnable() {
|
|
|
+ public void run() {}
|
|
|
+ };
|
|
|
+
|
|
|
+ wrappedRequest().getAsyncContext().start(runnable);
|
|
|
+
|
|
|
+ verifyZeroInteractions(authenticationManager, logoutHandler);
|
|
|
+ verify(asyncContext).start(runnableCaptor.capture());
|
|
|
+ DelegatingSecurityContextRunnable wrappedRunnable = (DelegatingSecurityContextRunnable) runnableCaptor.getValue();
|
|
|
+ assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, SecurityContext.class)).isEqualTo(context);
|
|
|
+ assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, Runnable.class)).isEqualTo(runnable);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void startAsyncStart() throws Exception {
|
|
|
+ ArgumentCaptor<Runnable> runnableCaptor = ArgumentCaptor.forClass(Runnable.class);
|
|
|
+ SecurityContext context = SecurityContextHolder.createEmptyContext();
|
|
|
+ TestingAuthenticationToken expectedAuth = new TestingAuthenticationToken("user", "password","ROLE_USER");
|
|
|
+ context.setAuthentication(expectedAuth);
|
|
|
+ SecurityContextHolder.setContext(context);
|
|
|
+ AsyncContext asyncContext = mock(AsyncContext.class);
|
|
|
+ when(request.startAsync()).thenReturn(asyncContext);
|
|
|
+ Runnable runnable = new Runnable() {
|
|
|
+ public void run() {}
|
|
|
+ };
|
|
|
+
|
|
|
+ wrappedRequest().startAsync().start(runnable);
|
|
|
+
|
|
|
+ verifyZeroInteractions(authenticationManager, logoutHandler);
|
|
|
+ verify(asyncContext).start(runnableCaptor.capture());
|
|
|
+ DelegatingSecurityContextRunnable wrappedRunnable = (DelegatingSecurityContextRunnable) runnableCaptor.getValue();
|
|
|
+ assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, SecurityContext.class)).isEqualTo(context);
|
|
|
+ assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, Runnable.class)).isEqualTo(runnable);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void startAsyncWithRequestResponseStart() throws Exception {
|
|
|
+ ArgumentCaptor<Runnable> runnableCaptor = ArgumentCaptor.forClass(Runnable.class);
|
|
|
+ SecurityContext context = SecurityContextHolder.createEmptyContext();
|
|
|
+ TestingAuthenticationToken expectedAuth = new TestingAuthenticationToken("user", "password","ROLE_USER");
|
|
|
+ context.setAuthentication(expectedAuth);
|
|
|
+ SecurityContextHolder.setContext(context);
|
|
|
+ AsyncContext asyncContext = mock(AsyncContext.class);
|
|
|
+ when(request.startAsync(request,response)).thenReturn(asyncContext);
|
|
|
+ Runnable runnable = new Runnable() {
|
|
|
+ public void run() {}
|
|
|
+ };
|
|
|
+
|
|
|
+ wrappedRequest().startAsync(request, response).start(runnable);
|
|
|
+
|
|
|
+ verifyZeroInteractions(authenticationManager, logoutHandler);
|
|
|
+ verify(asyncContext).start(runnableCaptor.capture());
|
|
|
+ DelegatingSecurityContextRunnable wrappedRunnable = (DelegatingSecurityContextRunnable) runnableCaptor.getValue();
|
|
|
+ assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, SecurityContext.class)).isEqualTo(context);
|
|
|
+ assertThat(WhiteboxImpl.getInternalState(wrappedRunnable, Runnable.class)).isEqualTo(runnable);
|
|
|
+ }
|
|
|
+
|
|
|
private HttpServletRequest wrappedRequest() throws Exception {
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
verify(filterChain).doFilter(requestCaptor.capture(), any(HttpServletResponse.class));
|