Pārlūkot izejas kodu

Add ServerRequestCache

Fixes: gh-4789
Rob Winch 7 gadi atpakaļ
vecāks
revīzija
1b70efce2b

+ 4 - 0
config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java

@@ -46,6 +46,10 @@ public enum SecurityWebFiltersOrder {
 	 * {@link org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter}
 	 */
 	SECURITY_CONTEXT_SERVER_WEB_EXCHANGE,
+	/**
+	 * {@link org.springframework.security.web.server.savedrequest.ServerRequestCacheWebFilter}
+	 */
+	SERVER_REQUEST_CACHE,
 	LOGOUT,
 	EXCEPTION_TRANSLATION,
 	AUTHORIZATION,

+ 56 - 5
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -39,7 +39,6 @@ import org.springframework.security.web.server.authentication.ServerAuthenticati
 import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
 import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
 import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
-import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler;
 import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
 import org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler;
 import org.springframework.security.web.server.authorization.AuthorizationContext;
@@ -47,10 +46,10 @@ import org.springframework.security.web.server.authorization.AuthorizationWebFil
 import org.springframework.security.web.server.authorization.DelegatingReactiveAuthorizationManager;
 import org.springframework.security.web.server.authorization.ExceptionTranslationWebFilter;
 import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
-import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
+import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository;
 import org.springframework.security.web.server.context.ReactorContextWebFilter;
+import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.security.web.server.context.ServerSecurityContextRepository;
-import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository;
 import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
 import org.springframework.security.web.server.csrf.CsrfWebFilter;
 import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
@@ -62,6 +61,10 @@ import org.springframework.security.web.server.header.ServerHttpHeadersWriter;
 import org.springframework.security.web.server.header.StrictTransportSecurityServerHttpHeadersWriter;
 import org.springframework.security.web.server.header.XFrameOptionsServerHttpHeadersWriter;
 import org.springframework.security.web.server.header.XXssProtectionServerHttpHeadersWriter;
+import org.springframework.security.web.server.savedrequest.NoOpServerRequestCache;
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
+import org.springframework.security.web.server.savedrequest.ServerRequestCacheWebFilter;
+import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
 import org.springframework.security.web.server.ui.LoginPageGeneratingWebFilter;
 import org.springframework.security.web.server.ui.LogoutPageGeneratingWebFilter;
 import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher;
@@ -102,6 +105,8 @@ public class ServerHttpSecurity {
 
 	private HttpBasicBuilder httpBasic;
 
+	private final RequestCacheBuilder requestCache = new RequestCacheBuilder();
+
 	private FormLoginBuilder formLogin;
 
 	private LogoutBuilder logout = new LogoutBuilder();
@@ -198,6 +203,10 @@ public class ServerHttpSecurity {
 		return this.logout;
 	}
 
+	public RequestCacheBuilder requestCache() {
+		return this.requestCache;
+	}
+
 	public ServerHttpSecurity authenticationManager(ReactiveAuthenticationManager manager) {
 		this.authenticationManager = manager;
 		return this;
@@ -239,6 +248,7 @@ public class ServerHttpSecurity {
 		if(this.logout != null) {
 			this.logout.configure(this);
 		}
+		this.requestCache.configure(this);
 		this.addFilterAt(new SecurityContextServerWebExchangeWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE);
 		if(this.authorizeExchangeBuilder != null) {
 			ServerAuthenticationEntryPoint serverAuthenticationEntryPoint = getServerAuthenticationEntryPoint();
@@ -433,6 +443,35 @@ public class ServerHttpSecurity {
 		private ExceptionHandlingBuilder() {}
 	}
 
+	/**
+	 * @author Rob Winch
+	 * @since 5.0
+	 */
+	public class RequestCacheBuilder {
+		private ServerRequestCache requestCache = new WebSessionServerRequestCache();
+
+		public RequestCacheBuilder requestCache(ServerRequestCache requestCache) {
+			Assert.notNull(requestCache, "requestCache cannot be null");
+			this.requestCache = requestCache;
+			return this;
+		}
+
+		protected void configure(ServerHttpSecurity http) {
+			http.addFilterAt(new ServerRequestCacheWebFilter(), SecurityWebFiltersOrder.SERVER_REQUEST_CACHE);
+		}
+
+		public ServerHttpSecurity and() {
+			return ServerHttpSecurity.this;
+		}
+
+		public ServerHttpSecurity disable() {
+			this.requestCache = NoOpServerRequestCache.getInstance();
+			return and();
+		}
+
+		private RequestCacheBuilder() {}
+	}
+
 	/**
 	 * @author Rob Winch
 	 * @since 5.0
@@ -489,6 +528,10 @@ public class ServerHttpSecurity {
 	 * @since 5.0
 	 */
 	public class FormLoginBuilder {
+		private final RedirectServerAuthenticationSuccessHandler defaultSuccessHandler = new RedirectServerAuthenticationSuccessHandler("/");
+
+		private RedirectServerAuthenticationEntryPoint defaultEntryPoint;
+
 		private ReactiveAuthenticationManager authenticationManager;
 
 		private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
@@ -499,7 +542,7 @@ public class ServerHttpSecurity {
 
 		private ServerAuthenticationFailureHandler serverAuthenticationFailureHandler;
 
-		private ServerAuthenticationSuccessHandler serverAuthenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler("/");
+		private ServerAuthenticationSuccessHandler serverAuthenticationSuccessHandler = this.defaultSuccessHandler;
 
 		public FormLoginBuilder authenticationManager(ReactiveAuthenticationManager authenticationManager) {
 			this.authenticationManager = authenticationManager;
@@ -514,7 +557,8 @@ public class ServerHttpSecurity {
 		}
 
 		public FormLoginBuilder loginPage(String loginPage) {
-			this.serverAuthenticationEntryPoint =  new RedirectServerAuthenticationEntryPoint(loginPage);
+			this.defaultEntryPoint = new RedirectServerAuthenticationEntryPoint(loginPage);
+			this.serverAuthenticationEntryPoint = this.defaultEntryPoint;
 			this.requiresAuthenticationMatcher = ServerWebExchangeMatchers.pathMatchers(HttpMethod.POST, loginPage);
 			this.serverAuthenticationFailureHandler = new RedirectServerAuthenticationFailureHandler(loginPage + "?error");
 			return this;
@@ -553,6 +597,13 @@ public class ServerHttpSecurity {
 			if(this.serverAuthenticationEntryPoint == null) {
 				loginPage("/login");
 			}
+			if(http.requestCache != null) {
+				ServerRequestCache requestCache = http.requestCache.requestCache;
+				this.defaultSuccessHandler.setRequestCache(requestCache);
+				if(this.defaultEntryPoint != null) {
+					this.defaultEntryPoint.setRequestCache(requestCache);
+				}
+			}
 			MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher(
 				MediaType.TEXT_HTML);
 			htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));

+ 6 - 9
config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java

@@ -21,15 +21,9 @@ import org.openqa.selenium.WebDriver;
 import org.openqa.selenium.WebElement;
 import org.openqa.selenium.support.FindBy;
 import org.openqa.selenium.support.PageFactory;
-import org.springframework.security.authentication.ReactiveAuthenticationManager;
-import org.springframework.security.authentication.UserDetailsRepositoryReactiveAuthenticationManager;
 import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
-import org.springframework.security.core.userdetails.MapReactiveUserDetailsService;
-import org.springframework.security.core.userdetails.User;
-import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
-import org.springframework.security.web.context.SaveContextOnUpdateOrErrorResponseWrapperTests;
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.WebFilterChainProxy;
 import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler;
@@ -39,7 +33,6 @@ import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.ResponseBody;
 import org.springframework.web.server.ServerWebExchange;
-import reactor.core.publisher.Mono;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -143,7 +136,7 @@ public class FormLoginTests {
 			.webTestClientSetup(webTestClient)
 			.build();
 
-		DefaultLoginPage loginPage = HomePage.to(driver, DefaultLoginPage.class)
+		DefaultLoginPage loginPage = DefaultLoginPage.to(driver)
 			.assertAt();
 
 		HomePage homePage = loginPage.loginForm()
@@ -238,6 +231,11 @@ public class FormLoginTests {
 			return this.loginForm;
 		}
 
+		static DefaultLoginPage to(WebDriver driver) {
+			driver.get("http://localhost/login");
+			return PageFactory.initElements(driver, DefaultLoginPage.class);
+		}
+
 		public static class LoginForm {
 			private WebDriver driver;
 			private WebElement username;
@@ -347,6 +345,5 @@ public class FormLoginTests {
 				+ "  </body>\n"
 				+ "</html>";
 		}
-
 	}
 }

+ 147 - 0
config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java

@@ -0,0 +1,147 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.web.server;
+
+import org.junit.Test;
+import org.openqa.selenium.WebDriver;
+import org.openqa.selenium.WebElement;
+import org.openqa.selenium.support.FindBy;
+import org.openqa.selenium.support.PageFactory;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
+import org.springframework.security.config.web.server.FormLoginTests.DefaultLoginPage;
+import org.springframework.security.config.web.server.FormLoginTests.HomePage;
+import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
+import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
+import org.springframework.security.web.server.SecurityWebFilterChain;
+import org.springframework.security.web.server.WebFilterChainProxy;
+import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler;
+import org.springframework.security.web.server.csrf.CsrfToken;
+import org.springframework.security.web.server.savedrequest.NoOpServerRequestCache;
+import org.springframework.stereotype.Controller;
+import org.springframework.test.web.reactive.server.WebTestClient;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.ResponseBody;
+import org.springframework.web.server.ServerWebExchange;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+public class RequestCacheTests {
+	private ServerHttpSecurity http = ServerHttpSecurityConfigurationBuilder.httpWithDefaultAuthentication();
+
+	@Test
+	public void defaultFormLoginRequestCache() {
+		SecurityWebFilterChain securityWebFilter = this.http
+			.authorizeExchange()
+			.anyExchange().authenticated()
+			.and()
+			.formLogin().and()
+			.build();
+
+		WebTestClient webTestClient = WebTestClient
+			.bindToController(new SecuredPageController(), new WebTestClientBuilder.Http200RestController())
+			.webFilter(new WebFilterChainProxy(securityWebFilter))
+			.build();
+
+		WebDriver driver = WebTestClientHtmlUnitDriverBuilder
+			.webTestClientSetup(webTestClient)
+			.build();
+
+		DefaultLoginPage loginPage = SecuredPage.to(driver, DefaultLoginPage.class)
+			.assertAt();
+
+		SecuredPage securedPage = loginPage.loginForm()
+			.username("user")
+			.password("password")
+			.submit(SecuredPage.class);
+
+		securedPage.assertAt();
+	}
+
+	@Test
+	public void requestCacheNoOp() {
+		SecurityWebFilterChain securityWebFilter = this.http
+			.authorizeExchange()
+				.anyExchange().authenticated()
+				.and()
+			.formLogin().and()
+			.requestCache()
+				.requestCache(NoOpServerRequestCache.getInstance())
+				.and()
+			.build();
+
+		WebTestClient webTestClient = WebTestClient
+			.bindToController(new SecuredPageController(), new WebTestClientBuilder.Http200RestController())
+			.webFilter(new WebFilterChainProxy(securityWebFilter))
+			.build();
+
+		WebDriver driver = WebTestClientHtmlUnitDriverBuilder
+			.webTestClientSetup(webTestClient)
+			.build();
+
+		DefaultLoginPage loginPage = SecuredPage.to(driver, DefaultLoginPage.class)
+			.assertAt();
+
+		HomePage securedPage = loginPage.loginForm()
+			.username("user")
+			.password("password")
+			.submit(HomePage.class);
+
+		securedPage.assertAt();
+	}
+
+	public static class SecuredPage {
+		private WebDriver driver;
+
+		public SecuredPage(WebDriver driver) {
+			this.driver = driver;
+		}
+
+		public void assertAt() {
+			assertThat(this.driver.getTitle()).isEqualTo("Secured");
+		}
+
+		static <T> T to(WebDriver driver, Class<T> page) {
+			driver.get("http://localhost/secured");
+			return PageFactory.initElements(driver, page);
+		}
+	}
+
+	@Controller
+	public static class SecuredPageController {
+		@ResponseBody
+		@GetMapping("/secured")
+		public String login(ServerWebExchange exchange) {
+			CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
+			return
+				"<!DOCTYPE html>\n"
+					+ "<html lang=\"en\">\n"
+					+ "  <head>\n"
+					+ "    <title>Secured</title>\n"
+					+ "  </head>\n"
+					+ "  <body>\n"
+					+ "    <h1>Secured</h1>\n"
+					+ "  </body>\n"
+					+ "</html>";
+		}
+	}
+}

+ 11 - 1
web/src/main/java/org/springframework/security/web/server/authentication/RedirectServerAuthenticationEntryPoint.java

@@ -20,6 +20,8 @@ import java.net.URI;
 
 import org.springframework.security.web.server.DefaultServerRedirectStrategy;
 import org.springframework.security.web.server.ServerRedirectStrategy;
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
+import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
 import reactor.core.publisher.Mono;
 
 import org.springframework.security.core.AuthenticationException;
@@ -39,14 +41,22 @@ public class RedirectServerAuthenticationEntryPoint
 
 	private ServerRedirectStrategy serverRedirectStrategy = new DefaultServerRedirectStrategy();
 
+	private ServerRequestCache requestCache = new WebSessionServerRequestCache();
+
 	public RedirectServerAuthenticationEntryPoint(String location) {
 		Assert.notNull(location, "location cannot be null");
 		this.location = URI.create(location);
 	}
 
+	public void setRequestCache(ServerRequestCache requestCache) {
+		Assert.notNull(requestCache, "requestCache cannot be null");
+		this.requestCache = requestCache;
+	}
+
 	@Override
 	public Mono<Void> commence(ServerWebExchange exchange, AuthenticationException e) {
-		return this.serverRedirectStrategy.sendRedirect(exchange, this.location);
+		return this.requestCache.saveRequest(exchange)
+			.then(this.serverRedirectStrategy.sendRedirect(exchange, this.location));
 	}
 
 	/**

+ 14 - 1
web/src/main/java/org/springframework/security/web/server/authentication/RedirectServerAuthenticationSuccessHandler.java

@@ -20,6 +20,8 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.web.server.DefaultServerRedirectStrategy;
 import org.springframework.security.web.server.ServerRedirectStrategy;
 import org.springframework.security.web.server.WebFilterExchange;
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
+import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
 import org.springframework.util.Assert;
 import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
@@ -36,17 +38,28 @@ public class RedirectServerAuthenticationSuccessHandler
 
 	private ServerRedirectStrategy serverRedirectStrategy = new DefaultServerRedirectStrategy();
 
+	private ServerRequestCache requestCache = new WebSessionServerRequestCache();
+
 	public RedirectServerAuthenticationSuccessHandler() {}
 
 	public RedirectServerAuthenticationSuccessHandler(String location) {
 		this.location = URI.create(location);
 	}
 
+	public void setRequestCache(ServerRequestCache requestCache) {
+		Assert.notNull(requestCache, "requestCache cannot be null");
+		this.requestCache = requestCache;
+	}
+
 	@Override
 	public Mono<Void> onAuthenticationSuccess(WebFilterExchange webFilterExchange,
 		Authentication authentication) {
 		ServerWebExchange exchange = webFilterExchange.getExchange();
-		return this.serverRedirectStrategy.sendRedirect(exchange, this.location);
+		return this.requestCache.getRequest(exchange)
+			.map(r -> r.getPath().pathWithinApplication().value())
+			.map(URI::create)
+			.defaultIfEmpty(this.location)
+			.flatMap(location -> this.serverRedirectStrategy.sendRedirect(exchange, location));
 	}
 
 	/**

+ 54 - 0
web/src/main/java/org/springframework/security/web/server/savedrequest/NoOpServerRequestCache.java

@@ -0,0 +1,54 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.savedrequest;
+
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Mono;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+public class NoOpServerRequestCache implements ServerRequestCache {
+	@Override
+	public Mono<Void> saveRequest(ServerWebExchange exchange) {
+		return Mono.empty();
+	}
+
+	@Override
+	public Mono<ServerHttpRequest> getRequest(ServerWebExchange exchange) {
+		return Mono.empty();
+	}
+
+	@Override
+	public Mono<ServerHttpRequest> getMatchingRequest(
+		ServerWebExchange exchange) {
+		return Mono.empty();
+	}
+
+	@Override
+	public Mono<ServerHttpRequest> removeRequest(ServerWebExchange exchange) {
+		return Mono.empty();
+	}
+
+	public static NoOpServerRequestCache getInstance() {
+		return new NoOpServerRequestCache();
+	}
+
+	private NoOpServerRequestCache() {}
+}

+ 64 - 0
web/src/main/java/org/springframework/security/web/server/savedrequest/ServerRequestCache.java

@@ -0,0 +1,64 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.savedrequest;
+
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Mono;
+
+/**
+ * Saves a {@link ServerHttpRequest} so it can be "replayed" later. This is useful for
+ * when a page was requested and authentication is necessary.
+ *
+ * @author Rob Winch
+ * @since 5.0
+ */
+public interface ServerRequestCache {
+
+	/**
+	 * Save the {@link ServerHttpRequest}
+	 * @param exchange the exchange to save
+	 * @return Return a {@code Mono<Void>} which only replays complete and error signals
+	 * from this {@link Mono}.
+	 */
+	Mono<Void> saveRequest(ServerWebExchange exchange);
+
+	/**
+	 * Get the saved {@link ServerHttpRequest}
+	 * @param exchange the exchange to obtain the saved {@link ServerHttpRequest} from
+	 * @return the {@link ServerHttpRequest}
+	 */
+	Mono<ServerHttpRequest> getRequest(ServerWebExchange exchange);
+
+	/**
+	 * If the provided {@link ServerWebExchange} matches the saved {@link ServerHttpRequest}
+	 * gets the saved {@link ServerHttpRequest}
+	 * @param exchange the exchange to obtain the request from
+	 * @return the {@link ServerHttpRequest}
+	 */
+	Mono<ServerHttpRequest> getMatchingRequest(ServerWebExchange exchange);
+
+	/**
+	 * If the {@link ServerWebExchange} contains a saved {@link ServerHttpRequest} remove
+	 * and return it.
+	 *
+	 * @param exchange the {@link ServerWebExchange} to obtain and remove the
+	 * {@link ServerHttpRequest}
+	 * @return the {@link ServerHttpRequest}
+	 */
+	Mono<ServerHttpRequest> removeRequest(ServerWebExchange exchange);
+}

+ 47 - 0
web/src/main/java/org/springframework/security/web/server/savedrequest/ServerRequestCacheWebFilter.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.savedrequest;
+
+import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebFilter;
+import org.springframework.web.server.WebFilterChain;
+import reactor.core.publisher.Mono;
+
+/**
+ * A {@link WebFilter} that replays any matching request in {@link ServerRequestCache}
+ *
+ * @author Rob Winch
+ * @since 5.0
+ */
+public class ServerRequestCacheWebFilter implements WebFilter {
+	private ServerRequestCache requestCache = new WebSessionServerRequestCache();
+
+	@Override
+	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
+		return this.requestCache.getMatchingRequest(exchange)
+			.flatMap(r -> this.requestCache.removeRequest(exchange))
+			.map(r -> exchange.mutate().request(r).build())
+			.defaultIfEmpty(exchange)
+			.flatMap(e -> chain.filter(e));
+	}
+
+	public void setRequestCache(ServerRequestCache requestCache) {
+		Assert.notNull(requestCache, "requestCache cannot be null");
+		this.requestCache = requestCache;
+	}
+}

+ 99 - 0
web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java

@@ -0,0 +1,99 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.savedrequest;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
+import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
+import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebSession;
+import reactor.core.publisher.Mono;
+
+import java.net.URI;
+
+/**
+ * An implementation of {@link ServerRequestCache} that saves the
+ * {@link ServerHttpRequest} in the {@link WebSession}.
+ *
+ * The current implementation only saves the URL that was requested.
+ *
+ * @author Rob Winch
+ * @since 5.0
+ */
+public class WebSessionServerRequestCache implements ServerRequestCache {
+	private static final String DEFAULT_SAVED_REQUEST_ATTR = "SPRING_SECURITY_SAVED_REQUEST";
+
+	protected final Log logger = LogFactory.getLog(this.getClass());
+
+	private String sessionAttrName = DEFAULT_SAVED_REQUEST_ATTR;
+
+	private ServerWebExchangeMatcher saveRequestMatcher = ServerWebExchangeMatchers.pathMatchers(
+		HttpMethod.GET, "/**");
+
+	/**
+	 * Sets the matcher to determine if the request should be saved. The default is to match
+	 * on any GET request.
+	 *
+	 * @param saveRequestMatcher
+	 */
+	public void setSaveRequestMatcher(ServerWebExchangeMatcher saveRequestMatcher) {
+		Assert.notNull(saveRequestMatcher, "saveRequestMatcher cannot be null");
+		this.saveRequestMatcher = saveRequestMatcher;
+	}
+
+	@Override
+	public Mono<Void> saveRequest(ServerWebExchange exchange) {
+		return this.saveRequestMatcher.matches(exchange)
+			.filter(m -> m.isMatch())
+			.flatMap(m -> exchange.getSession())
+			.map(WebSession::getAttributes)
+			.doOnNext(attrs -> attrs.put(this.sessionAttrName, pathInApplication(exchange.getRequest())))
+			.then();
+	}
+
+	@Override
+	public Mono<ServerHttpRequest> getRequest(ServerWebExchange exchange) {
+		return exchange.getSession()
+			.flatMap(session -> Mono.justOrEmpty(session.<String>getAttribute(this.sessionAttrName)))
+			.map(path -> exchange.getRequest().mutate().path(path).build());
+	}
+
+	@Override
+	public Mono<ServerHttpRequest> getMatchingRequest(
+		ServerWebExchange exchange) {
+		return getRequest(exchange)
+			.filter( request -> pathInApplication(request).equals(
+				pathInApplication(exchange.getRequest())));
+	}
+
+	@Override
+	public Mono<ServerHttpRequest> removeRequest(ServerWebExchange exchange) {
+		return exchange.getSession()
+			.map(WebSession::getAttributes)
+			.flatMap(attrs -> Mono.justOrEmpty(attrs.remove(this.sessionAttrName)))
+			.cast(String.class)
+			.map(path -> exchange.getRequest().mutate().path(path).build());
+	}
+
+	private static String pathInApplication(ServerHttpRequest request) {
+		return request.getPath().pathWithinApplication().value();
+	}
+}

+ 82 - 0
web/src/test/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCacheTests.java

@@ -0,0 +1,82 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.savedrequest;
+
+import org.junit.Test;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
+
+import static org.assertj.core.api.Assertions.*;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+public class WebSessionServerRequestCacheTests {
+	private WebSessionServerRequestCache cache = new WebSessionServerRequestCache();
+
+	@Test
+	public void saveRequestGetRequestWhenGetThenFound() {
+		MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/secured/"));
+		this.cache.saveRequest(exchange).block();
+
+		ServerHttpRequest saved = this.cache.getRequest(exchange).block();
+
+		assertThat(saved.getURI()).isEqualTo(exchange.getRequest().getURI());
+	}
+
+	@Test
+	public void saveRequestGetRequestWhenPostThenNotFound() {
+		MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/secured/"));
+		this.cache.saveRequest(exchange).block();
+
+		assertThat(this.cache.getRequest(exchange).block()).isNull();
+	}
+
+	@Test
+	public void saveRequestGetRequestWhenPostAndCustomMatcherThenFound() {
+		this.cache.setSaveRequestMatcher(e -> ServerWebExchangeMatcher.MatchResult.match());
+		MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/secured/"));
+		this.cache.saveRequest(exchange).block();
+
+		ServerHttpRequest saved = this.cache.getRequest(exchange).block();
+
+		assertThat(saved.getURI()).isEqualTo(exchange.getRequest().getURI());
+	}
+
+	@Test
+	public void saveRequestRemoveRequestWhenThenFound() {
+		MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/secured/"));
+		this.cache.saveRequest(exchange).block();
+
+		ServerHttpRequest saved = this.cache.removeRequest(exchange).block();
+
+		assertThat(saved.getURI()).isEqualTo(exchange.getRequest().getURI());
+	}
+
+	@Test
+	public void removeRequestGetRequestWhenDefaultThenNotFound() {
+		MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/secured/"));
+		this.cache.saveRequest(exchange).block();
+
+		this.cache.removeRequest(exchange).block();
+
+		assertThat(this.cache.getRequest(exchange).block()).isNull();
+	}
+}