浏览代码

Allow redirect status code to be customized

Closes gh-12797
Mark Chesney 1 年之前
父节点
当前提交
d9399dfda0

+ 28 - 2
web/src/main/java/org/springframework/security/web/DefaultRedirectStrategy.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2016 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -24,6 +24,8 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.commons.logging.LogFactory;
 
 
 import org.springframework.core.log.LogMessage;
 import org.springframework.core.log.LogMessage;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
 import org.springframework.security.web.util.UrlUtils;
 import org.springframework.security.web.util.UrlUtils;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
 
 
@@ -32,6 +34,7 @@ import org.springframework.util.Assert;
  * the framework.
  * the framework.
  *
  *
  * @author Luke Taylor
  * @author Luke Taylor
+ * @author Mark Chesney
  * @since 3.0
  * @since 3.0
  */
  */
 public class DefaultRedirectStrategy implements RedirectStrategy {
 public class DefaultRedirectStrategy implements RedirectStrategy {
@@ -40,6 +43,8 @@ public class DefaultRedirectStrategy implements RedirectStrategy {
 
 
 	private boolean contextRelative;
 	private boolean contextRelative;
 
 
+	private HttpStatus statusCode = HttpStatus.FOUND;
+
 	/**
 	/**
 	 * Redirects the response to the supplied URL.
 	 * Redirects the response to the supplied URL.
 	 * <p>
 	 * <p>
@@ -55,7 +60,14 @@ public class DefaultRedirectStrategy implements RedirectStrategy {
 		if (this.logger.isDebugEnabled()) {
 		if (this.logger.isDebugEnabled()) {
 			this.logger.debug(LogMessage.format("Redirecting to %s", redirectUrl));
 			this.logger.debug(LogMessage.format("Redirecting to %s", redirectUrl));
 		}
 		}
-		response.sendRedirect(redirectUrl);
+		if (this.statusCode == HttpStatus.FOUND) {
+			response.sendRedirect(redirectUrl);
+		}
+		else {
+			response.setHeader(HttpHeaders.LOCATION, redirectUrl);
+			response.setStatus(this.statusCode.value());
+			response.getWriter().flush();
+		}
 	}
 	}
 
 
 	protected String calculateRedirectUrl(String contextPath, String url) {
 	protected String calculateRedirectUrl(String contextPath, String url) {
@@ -96,4 +108,18 @@ public class DefaultRedirectStrategy implements RedirectStrategy {
 		return this.contextRelative;
 		return this.contextRelative;
 	}
 	}
 
 
+	/**
+	 * Sets the HTTP status code to use. The default is {@link HttpStatus#FOUND}.
+	 * <p>
+	 * Note that according to RFC 7231, with {@link HttpStatus#FOUND}, a user agent MAY
+	 * change the request method from POST to GET for the subsequent request. If this
+	 * behavior is undesired, {@link HttpStatus#TEMPORARY_REDIRECT} can be used instead.
+	 * @param statusCode the HTTP status code to use.
+	 * @since 6.2
+	 */
+	public void setStatusCode(HttpStatus statusCode) {
+		Assert.notNull(statusCode, "statusCode cannot be null");
+		this.statusCode = statusCode;
+	}
+
 }
 }

+ 20 - 1
web/src/test/java/org/springframework/security/web/DefaultRedirectStrategyTests.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2016 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ package org.springframework.security.web;
 
 
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Test;
 
 
+import org.springframework.http.HttpStatus;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
 
 
@@ -26,6 +27,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
 
 
 /**
 /**
  * @author Luke Taylor
  * @author Luke Taylor
+ * @author Mark Chesney
  * @since 3.0
  * @since 3.0
  */
  */
 public class DefaultRedirectStrategyTests {
 public class DefaultRedirectStrategyTests {
@@ -64,4 +66,21 @@ public class DefaultRedirectStrategyTests {
 			.isThrownBy(() -> rds.sendRedirect(request, response, "https://redirectme.somewhere.else"));
 			.isThrownBy(() -> rds.sendRedirect(request, response, "https://redirectme.somewhere.else"));
 	}
 	}
 
 
+	@Test
+	public void statusCodeIsHandledCorrectly() throws Exception {
+		// given
+		DefaultRedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
+		redirectStrategy.setStatusCode(HttpStatus.TEMPORARY_REDIRECT);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+
+		// when
+		redirectStrategy.sendRedirect(request, response, "/requested");
+
+		// then
+		assertThat(response.isCommitted()).isTrue();
+		assertThat(response.getRedirectedUrl()).isEqualTo("/requested");
+		assertThat(response.getStatus()).isEqualTo(307);
+	}
+
 }
 }

+ 32 - 0
web/src/test/java/org/springframework/security/web/session/SessionManagementFilterTests.java

@@ -23,6 +23,7 @@ import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Test;
 
 
+import org.springframework.http.HttpStatus;
 import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
@@ -210,6 +211,37 @@ public class SessionManagementFilterTests {
 		assertThat(response.getStatus()).isEqualTo(302);
 		assertThat(response.getStatus()).isEqualTo(302);
 	}
 	}
 
 
+	@Test
+	public void responseIsRedirectedToRequestedUrlIfStatusCodeIsSetAndSessionIsInvalid() throws Exception {
+		// given
+		DefaultRedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
+		redirectStrategy.setStatusCode(HttpStatus.TEMPORARY_REDIRECT);
+		RequestedUrlRedirectInvalidSessionStrategy invalidSessionStrategy = new RequestedUrlRedirectInvalidSessionStrategy();
+		invalidSessionStrategy.setCreateNewSession(true);
+		invalidSessionStrategy.setRedirectStrategy(redirectStrategy);
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		SessionAuthenticationStrategy sessionAuthenticationStrategy = mock(SessionAuthenticationStrategy.class);
+		SessionManagementFilter filter = new SessionManagementFilter(securityContextRepository,
+				sessionAuthenticationStrategy);
+		filter.setInvalidSessionStrategy(invalidSessionStrategy);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setRequestedSessionId("xxx");
+		request.setRequestedSessionIdValid(false);
+		request.setRequestURI("/requested");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain chain = mock(FilterChain.class);
+
+		// when
+		filter.doFilter(request, response, chain);
+
+		// then
+		verify(securityContextRepository).containsContext(request);
+		verifyNoMoreInteractions(securityContextRepository, sessionAuthenticationStrategy, chain);
+		assertThat(response.isCommitted()).isTrue();
+		assertThat(response.getRedirectedUrl()).isEqualTo("/requested");
+		assertThat(response.getStatus()).isEqualTo(307);
+	}
+
 	@Test
 	@Test
 	public void customAuthenticationTrustResolver() throws Exception {
 	public void customAuthenticationTrustResolver() throws Exception {
 		AuthenticationTrustResolver trustResolver = mock(AuthenticationTrustResolver.class);
 		AuthenticationTrustResolver trustResolver = mock(AuthenticationTrustResolver.class);