فهرست منبع

formLogin() and login() implement Mergable

This is necessary so that default requests like Spring REST Docs work.

Closes gh-7572
Dávid Kovács 5 سال پیش
والد
کامیت
fa9898dd6d

+ 66 - 11
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java

@@ -15,16 +15,18 @@
  */
 package org.springframework.security.test.web.servlet.request;
 
-import javax.servlet.ServletContext;
-
+import org.springframework.beans.Mergeable;
 import org.springframework.http.MediaType;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.RequestBuilder;
+import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;
 import org.springframework.test.web.servlet.request.RequestPostProcessor;
 import org.springframework.web.util.UriComponentsBuilder;
 
+import javax.servlet.ServletContext;
+
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
 
@@ -86,15 +88,23 @@ public final class SecurityMockMvcRequestBuilders {
 	 * @author Rob Winch
 	 * @since 4.0
 	 */
-	public static final class LogoutRequestBuilder implements RequestBuilder {
+	public static final class LogoutRequestBuilder implements RequestBuilder, Mergeable {
 		private String logoutUrl = "/logout";
 		private RequestPostProcessor postProcessor = csrf();
+		private Mergeable parent;
 
 		@Override
 		public MockHttpServletRequest buildRequest(ServletContext servletContext) {
-			MockHttpServletRequest request = post(this.logoutUrl)
-					.accept(MediaType.TEXT_HTML, MediaType.ALL)
-					.buildRequest(servletContext);
+			MockHttpServletRequestBuilder logoutRequest = post(this.logoutUrl)
+					.accept(MediaType.TEXT_HTML, MediaType.ALL);
+
+			if (this.parent != null) {
+				logoutRequest = (MockHttpServletRequestBuilder) logoutRequest.merge(this.parent);
+			}
+
+			MockHttpServletRequest request = logoutRequest.buildRequest(servletContext);
+			logoutRequest.postProcessRequest(request);
+
 			return this.postProcessor.postProcessRequest(request);
 		}
 
@@ -122,6 +132,24 @@ public final class SecurityMockMvcRequestBuilders {
 			return this;
 		}
 
+		@Override
+		public boolean isMergeEnabled() {
+			return true;
+		}
+
+		@Override
+		public Object merge(Object parent) {
+			if (parent == null) {
+				return this;
+			}
+			if (parent instanceof Mergeable) {
+				this.parent = (Mergeable) parent;
+				return this;
+			} else {
+				throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
+			}
+		}
+
 		private LogoutRequestBuilder() {
 		}
 	}
@@ -132,22 +160,31 @@ public final class SecurityMockMvcRequestBuilders {
 	 * @author Rob Winch
 	 * @since 4.0
 	 */
-	public static final class FormLoginRequestBuilder implements RequestBuilder {
+	public static final class FormLoginRequestBuilder implements RequestBuilder, Mergeable {
 		private String usernameParam = "username";
 		private String passwordParam = "password";
 		private String username = "user";
 		private String password = "password";
 		private String loginProcessingUrl = "/login";
 		private MediaType acceptMediaType = MediaType.APPLICATION_FORM_URLENCODED;
+		private Mergeable parent;
 
 		private RequestPostProcessor postProcessor = csrf();
 
 		@Override
 		public MockHttpServletRequest buildRequest(ServletContext servletContext) {
-			MockHttpServletRequest request = post(this.loginProcessingUrl)
-					.accept(this.acceptMediaType).param(this.usernameParam, this.username)
-					.param(this.passwordParam, this.password)
-					.buildRequest(servletContext);
+			MockHttpServletRequestBuilder loginRequest = post(this.loginProcessingUrl)
+					.accept(this.acceptMediaType)
+					.param(this.usernameParam, this.username)
+					.param(this.passwordParam, this.password);
+
+			if (this.parent != null) {
+				loginRequest = (MockHttpServletRequestBuilder) loginRequest.merge(this.parent);
+			}
+
+			MockHttpServletRequest request = loginRequest.buildRequest(servletContext);
+			loginRequest.postProcessRequest(request);
+
 			return this.postProcessor.postProcessRequest(request);
 		}
 
@@ -258,6 +295,24 @@ public final class SecurityMockMvcRequestBuilders {
 			return this;
 		}
 
+		@Override
+		public boolean isMergeEnabled() {
+			return true;
+		}
+
+		@Override
+		public Object merge(Object parent) {
+			if (parent == null) {
+				return this;
+			}
+			if (parent instanceof Mergeable ) {
+				this.parent = (Mergeable) parent;
+				return this;
+			} else {
+				throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
+			}
+		}
+
 		private FormLoginRequestBuilder() {
 		}
 	}

+ 1 - 1
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@@ -410,7 +410,7 @@ public final class SecurityMockMvcRequestPostProcessors {
 
 			private final CsrfTokenRepository delegate;
 
-			private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
+			TestCsrfTokenRepository(CsrfTokenRepository delegate) {
 				this.delegate = delegate;
 			}
 

+ 37 - 1
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java

@@ -17,14 +17,25 @@ package org.springframework.security.test.web.servlet.request;
 
 import org.junit.Before;
 import org.junit.Test;
-
+import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockServletContext;
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
 import org.springframework.security.web.csrf.CsrfToken;
+import org.springframework.test.web.servlet.MockMvc;
+import org.springframework.test.web.servlet.MvcResult;
+import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
+import org.springframework.test.web.servlet.request.RequestPostProcessor;
+import org.springframework.test.web.servlet.setup.MockMvcBuilders;
+
+import java.util.Arrays;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.powermock.api.mockito.PowerMockito.when;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
 
 public class SecurityMockMvcRequestBuildersFormLoginTests {
@@ -82,6 +93,31 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 		assertThat(request.getRequestURI()).isEqualTo("/uri-login/val1/val2");
 	}
 
+	/**
+	 * spring-restdocs uses postprocessors to do its trick. It will work only if these are merged together
+	 * with our request builders. (gh-7572)
+	 * @throws Exception
+	 */
+	@Test
+	public void postProcessorsAreMergedDuringMockMvcPerform() throws Exception {
+		RequestPostProcessor postProcessor = mock(RequestPostProcessor.class);
+		when(postProcessor.postProcessRequest(any())).thenAnswer(i -> i.getArgument(0));
+		MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new Object())
+				.defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor))
+				.build();
+
+
+		MvcResult mvcResult = mockMvc.perform(formLogin()).andReturn();
+		assertThat(mvcResult.getRequest().getMethod()).isEqualTo(HttpMethod.POST.name());
+		assertThat(mvcResult.getRequest().getHeader("Accept"))
+				.isEqualTo(MediaType.toString(Arrays.asList(MediaType.APPLICATION_FORM_URLENCODED)));
+		assertThat(mvcResult.getRequest().getParameter("username")).isEqualTo("user");
+		assertThat(mvcResult.getRequest().getParameter("password")).isEqualTo("password");
+		assertThat(mvcResult.getRequest().getRequestURI()).isEqualTo("/login");
+		assertThat(mvcResult.getRequest().getParameter("_csrf")).isNotEmpty();
+		verify(postProcessor).postProcessRequest(any());
+	}
+
 	// gh-3920
 	@Test
 	public void usesAcceptMediaForContentNegotiation() {

+ 37 - 3
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java

@@ -15,15 +15,28 @@
  */
 package org.springframework.security.test.web.servlet.request;
 
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
-
 import org.junit.Before;
 import org.junit.Test;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.MediaType;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockServletContext;
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
 import org.springframework.security.web.csrf.CsrfToken;
+import org.springframework.test.web.servlet.MockMvc;
+import org.springframework.test.web.servlet.MvcResult;
+import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
+import org.springframework.test.web.servlet.request.RequestPostProcessor;
+import org.springframework.test.web.servlet.setup.MockMvcBuilders;
+
+import java.util.Arrays;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.powermock.api.mockito.PowerMockito.when;
+import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
 
 public class SecurityMockMvcRequestBuildersFormLogoutTests {
 	private MockServletContext servletContext;
@@ -71,4 +84,25 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 		assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2");
 	}
 
+	/**
+	 * spring-restdocs uses postprocessors to do its trick. It will work only if these are merged together
+	 * with our request builders. (gh-7572)
+	 * @throws Exception
+	 */
+	@Test
+	public void postProcessorsAreMergedDuringMockMvcPerform() throws Exception {
+		RequestPostProcessor postProcessor = mock(RequestPostProcessor.class);
+		when(postProcessor.postProcessRequest(any())).thenAnswer(i -> i.getArgument(0));
+		MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new Object())
+				.defaultRequest(MockMvcRequestBuilders.get("/").with(postProcessor))
+				.build();
+
+		MvcResult mvcResult = mockMvc.perform(logout()).andReturn();
+		assertThat(mvcResult.getRequest().getMethod()).isEqualTo(HttpMethod.POST.name());
+		assertThat(mvcResult.getRequest().getHeader("Accept"))
+				.isEqualTo(MediaType.toString(Arrays.asList(MediaType.TEXT_HTML, MediaType.ALL)));
+		assertThat(mvcResult.getRequest().getRequestURI()).isEqualTo("/logout");
+		assertThat(mvcResult.getRequest().getParameter("_csrf")).isNotEmpty();
+		verify(postProcessor).postProcessRequest(any());
+	}
 }