Explorar o código

Additional Test for HttpSessionSecurityContextRepository

Issue gh-9387
Rob Winch %!s(int64=4) %!d(string=hai) anos
pai
achega
95da12110b

+ 58 - 0
web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java

@@ -16,12 +16,16 @@
 
 package org.springframework.security.web.context;
 
+import java.io.IOException;
 import java.lang.annotation.ElementType;
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
 import java.lang.annotation.Target;
 
+import javax.servlet.Filter;
+import javax.servlet.ServletException;
 import javax.servlet.ServletOutputStream;
+import javax.servlet.http.HttpServlet;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequestWrapper;
 import javax.servlet.http.HttpServletResponse;
@@ -31,6 +35,7 @@ import javax.servlet.http.HttpSession;
 import org.junit.After;
 import org.junit.Test;
 
+import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpSession;
@@ -38,10 +43,14 @@ import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationTrustResolver;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.core.Transient;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextImpl;
+import org.springframework.security.core.userdetails.User;
+import org.springframework.security.core.userdetails.UserDetails;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -162,6 +171,48 @@ public class HttpSessionSecurityContextRepositoryTests {
 		verify(session).setAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY, ctx);
 	}
 
+	@Test
+	public void saveContextWhenSaveNewContextThenOriginalContextThenOriginalContextSaved() throws Exception {
+		HttpSessionSecurityContextRepository repository = new HttpSessionSecurityContextRepository();
+		SecurityContextPersistenceFilter securityContextPersistenceFilter = new SecurityContextPersistenceFilter(
+				repository);
+
+		UserDetails original = User.withUsername("user").password("password").roles("USER").build();
+		SecurityContext originalContext = createSecurityContext(original);
+		UserDetails impersonate = User.withUserDetails(original).username("impersonate").build();
+		SecurityContext impersonateContext = createSecurityContext(impersonate);
+
+		MockHttpServletRequest mockRequest = new MockHttpServletRequest();
+		MockHttpServletResponse mockResponse = new MockHttpServletResponse();
+
+		Filter saveImpersonateContext = (request, response, chain) -> {
+			SecurityContextHolder.setContext(impersonateContext);
+			// ensure the response is committed to trigger save
+			response.flushBuffer();
+			chain.doFilter(request, response);
+		};
+		Filter saveOriginalContext = (request, response, chain) -> {
+			SecurityContextHolder.setContext(originalContext);
+			chain.doFilter(request, response);
+		};
+		HttpServlet servlet = new HttpServlet() {
+			@Override
+			protected void service(HttpServletRequest req, HttpServletResponse resp)
+					throws ServletException, IOException {
+				resp.getWriter().write("Hi");
+			}
+		};
+
+		SecurityContextHolder.setContext(originalContext);
+		MockFilterChain chain = new MockFilterChain(servlet, saveImpersonateContext, saveOriginalContext);
+
+		securityContextPersistenceFilter.doFilter(mockRequest, mockResponse, chain);
+
+		assertThat(
+				mockRequest.getSession().getAttribute(HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY))
+						.isEqualTo(originalContext);
+	}
+
 	@Test
 	public void nonSecurityContextInSessionIsIgnored() {
 		HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
@@ -577,6 +628,13 @@ public class HttpSessionSecurityContextRepositoryTests {
 		assertThat(session).isNull();
 	}
 
+	private SecurityContext createSecurityContext(UserDetails userDetails) {
+		UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(userDetails,
+				userDetails.getPassword(), userDetails.getAuthorities());
+		SecurityContext securityContext = new SecurityContextImpl(token);
+		return securityContext;
+	}
+
 	@Transient
 	private static class SomeTransientAuthentication extends AbstractAuthenticationToken {