Quellcode durchsuchen

Add ServerRequestCache setter in OAuth2AuthorizationCodeGrantWebFilter

Fixes gh-8536
Parikshit Dutta vor 5 Jahren
Ursprung
Commit
28d2cfa14a

+ 8 - 0
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -236,6 +236,7 @@ import static org.springframework.security.web.server.util.matcher.ServerWebExch
  * @author Rafiullah Hamedy
  * @author Eddú Meléndez
  * @author Joe Grandja
+ * @author Parikshit Dutta
  * @since 5.0
  */
 public class ServerHttpSecurity {
@@ -1511,10 +1512,17 @@ public class ServerHttpSecurity {
 			OAuth2AuthorizationCodeGrantWebFilter codeGrantWebFilter = new OAuth2AuthorizationCodeGrantWebFilter(
 					authenticationManager, authenticationConverter, authorizedClientRepository);
 			codeGrantWebFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
+			if (http.requestCache != null) {
+				codeGrantWebFilter.setRequestCache(http.requestCache.requestCache);
+			}
 
 			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(
 					clientRegistrationRepository);
 			oauthRedirectFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
+			if (http.requestCache != null) {
+				oauthRedirectFilter.setRequestCache(http.requestCache.requestCache);
+			}
+
 			http.addFilterAt(codeGrantWebFilter, SecurityWebFiltersOrder.OAUTH2_AUTHORIZATION_CODE);
 			http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
 		}

+ 20 - 4
config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
 
 package org.springframework.security.config.web.server;
 
+import java.net.URI;
+
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -48,6 +50,7 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio
 import org.springframework.security.test.context.support.WithMockUser;
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
 import org.springframework.test.context.junit4.SpringRunner;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.bind.annotation.GetMapping;
@@ -62,6 +65,7 @@ import static org.mockito.Mockito.when;
 
 /**
  * @author Rob Winch
+ * @author Parikshit Dutta
  * @since 5.1
  */
 @RunWith(SpringRunner.class)
@@ -146,6 +150,7 @@ public class OAuth2ClientSpecTests {
 		ServerAuthenticationConverter converter = config.authenticationConverter;
 		ReactiveAuthenticationManager manager = config.manager;
 		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = config.authorizationRequestRepository;
+		ServerRequestCache requestCache = config.requestCache;
 
 		OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
 				.redirectUri("/authorize/oauth2/code/registration-id")
@@ -163,6 +168,7 @@ public class OAuth2ClientSpecTests {
 		when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest));
 		when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c")));
 		when(manager.authenticate(any())).thenReturn(Mono.just(result));
+		when(requestCache.getRedirectUri(any())).thenReturn(Mono.just(URI.create("/saved-request")));
 
 		this.client.get()
 				.uri(uriBuilder ->
@@ -175,6 +181,7 @@ public class OAuth2ClientSpecTests {
 
 		verify(converter).convert(any());
 		verify(manager).authenticate(any());
+		verify(requestCache).getRedirectUri(any());
 	}
 
 	@EnableWebFlux
@@ -197,13 +204,17 @@ public class OAuth2ClientSpecTests {
 
 		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
 
+		ServerRequestCache requestCache = mock(ServerRequestCache.class);
+
 		@Bean
 		public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
 			http
 				.oauth2Client()
 					.authenticationConverter(this.authenticationConverter)
 					.authenticationManager(this.manager)
-					.authorizationRequestRepository(this.authorizationRequestRepository);
+					.authorizationRequestRepository(this.authorizationRequestRepository)
+					.and()
+				.requestCache(c -> c.requestCache(this.requestCache));
 			return http.build();
 		}
 	}
@@ -217,6 +228,7 @@ public class OAuth2ClientSpecTests {
 		ServerAuthenticationConverter converter = config.authenticationConverter;
 		ReactiveAuthenticationManager manager = config.manager;
 		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = config.authorizationRequestRepository;
+		ServerRequestCache requestCache = config.requestCache;
 
 		OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
 				.redirectUri("/authorize/oauth2/code/registration-id")
@@ -234,6 +246,7 @@ public class OAuth2ClientSpecTests {
 		when(authorizationRequestRepository.loadAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest));
 		when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c")));
 		when(manager.authenticate(any())).thenReturn(Mono.just(result));
+		when(requestCache.getRedirectUri(any())).thenReturn(Mono.just(URI.create("/saved-request")));
 
 		this.client.get()
 				.uri(uriBuilder ->
@@ -246,6 +259,7 @@ public class OAuth2ClientSpecTests {
 
 		verify(converter).convert(any());
 		verify(manager).authenticate(any());
+		verify(requestCache).getRedirectUri(any());
 	}
 
 	@Configuration
@@ -256,6 +270,8 @@ public class OAuth2ClientSpecTests {
 
 		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
 
+		ServerRequestCache requestCache = mock(ServerRequestCache.class);
+
 		@Bean
 		public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
 			http
@@ -263,8 +279,8 @@ public class OAuth2ClientSpecTests {
 					oauth2Client
 						.authenticationConverter(this.authenticationConverter)
 						.authenticationManager(this.manager)
-						.authorizationRequestRepository(this.authorizationRequestRepository)
-				);
+						.authorizationRequestRepository(this.authorizationRequestRepository))
+				.requestCache(c -> c.requestCache(this.requestCache));
 			return http.build();
 		}
 	}

+ 30 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java

@@ -35,6 +35,8 @@ import org.springframework.security.web.server.authentication.RedirectServerAuth
 import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
 import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
 import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
+import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.util.Assert;
 import org.springframework.web.server.ServerWebExchange;
@@ -80,6 +82,7 @@ import java.util.Set;
  *
  * @author Rob Winch
  * @author Joe Grandja
+ * @author Parikshit Dutta
  * @since 5.1
  * @see OAuth2AuthorizationCodeAuthenticationToken
  * @see org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeReactiveAuthenticationManager
@@ -111,6 +114,8 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
 
 	private ServerWebExchangeMatcher requiresAuthenticationMatcher;
 
+	private ServerRequestCache requestCache = new WebSessionServerRequestCache();
+
 	private AnonymousAuthenticationToken anonymousToken = new AnonymousAuthenticationToken("key", "anonymous",
 					AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
 
@@ -129,7 +134,10 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
 		authenticationConverter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
 		this.authenticationConverter = authenticationConverter;
 		this.defaultAuthenticationConverter = true;
-		this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
+		RedirectServerAuthenticationSuccessHandler authenticationSuccessHandler =
+				new RedirectServerAuthenticationSuccessHandler();
+		authenticationSuccessHandler.setRequestCache(this.requestCache);
+		this.authenticationSuccessHandler = authenticationSuccessHandler;
 		this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception);
 	}
 
@@ -144,7 +152,10 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
 		this.authorizedClientRepository = authorizedClientRepository;
 		this.requiresAuthenticationMatcher = this::matchesAuthorizationResponse;
 		this.authenticationConverter = authenticationConverter;
-		this.authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
+		RedirectServerAuthenticationSuccessHandler authenticationSuccessHandler =
+				new RedirectServerAuthenticationSuccessHandler();
+		authenticationSuccessHandler.setRequestCache(this.requestCache);
+		this.authenticationSuccessHandler = authenticationSuccessHandler;
 		this.authenticationFailureHandler = (webFilterExchange, exception) -> Mono.error(exception);
 	}
 
@@ -169,6 +180,23 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
 		}
 	}
 
+	/**
+	 * Sets the {@link ServerRequestCache} used for loading a previously saved request (if available)
+	 * and replaying it after completing the processing of the OAuth 2.0 Authorization Response.
+	 *
+	 * @since 5.4
+	 * @param requestCache the cache used for loading a previously saved request (if available)
+	 */
+	public final void setRequestCache(ServerRequestCache requestCache) {
+		Assert.notNull(requestCache, "requestCache cannot be null");
+		this.requestCache = requestCache;
+		updateDefaultAuthenticationSuccessHandler();
+	}
+
+	private void updateDefaultAuthenticationSuccessHandler() {
+		((RedirectServerAuthenticationSuccessHandler) this.authenticationSuccessHandler).setRequestCache(this.requestCache);
+	}
+
 	@Override
 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 		return this.requiresAuthenticationMatcher.matches(exchange)

+ 46 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java

@@ -31,17 +31,22 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
 import org.springframework.util.CollectionUtils;
+import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.handler.DefaultWebFilterChain;
 import reactor.core.publisher.Mono;
 
+import java.net.URI;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.Map;
 
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
@@ -50,6 +55,7 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author
 
 /**
  * @author Rob Winch
+ * @author Parikshit Dutta
  * @since 5.1
  */
 @RunWith(MockitoJUnitRunner.class)
@@ -99,6 +105,12 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
+	@Test
+	public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() {
+		assertThatCode(() -> this.filter.setRequestCache(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
 	@Test
 	public void filterWhenNotMatchThenAuthenticationManagerNotCalled() {
 		MockServerWebExchange exchange = MockServerWebExchange
@@ -233,6 +245,40 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
 		verifyNoInteractions(this.authenticationManager);
 	}
 
+	@Test
+	public void filterWhenAuthorizationSucceedsAndRequestCacheConfiguredThenRequestCacheUsed() {
+		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
+		when(this.clientRegistrationRepository.findByRegistrationId(any()))
+				.thenReturn(Mono.just(clientRegistration));
+		when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any()))
+				.thenReturn(Mono.empty());
+		when(this.authenticationManager.authenticate(any()))
+				.thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated()));
+
+		MockServerHttpRequest authorizationRequest = createAuthorizationRequest("/authorization/callback");
+		OAuth2AuthorizationRequest oauth2AuthorizationRequest =
+				createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
+		when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
+				.thenReturn(Mono.just(oauth2AuthorizationRequest));
+		when(this.authorizationRequestRepository.removeAuthorizationRequest(any()))
+				.thenReturn(Mono.just(oauth2AuthorizationRequest));
+
+		MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
+		MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
+		DefaultWebFilterChain chain = new DefaultWebFilterChain(
+				e -> e.getResponse().setComplete(), Collections.emptyList());
+
+		ServerRequestCache requestCache = mock(ServerRequestCache.class);
+		when(requestCache.getRedirectUri(any(ServerWebExchange.class))).thenReturn(Mono.just(URI.create("/saved-request")));
+
+		this.filter.setRequestCache(requestCache);
+
+		this.filter.filter(exchange, chain).block();
+
+		verify(requestCache).getRedirectUri(exchange);
+		assertThat(exchange.getResponse().getHeaders().getLocation().toString()).isEqualTo("/saved-request");
+	}
+
 	private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(
 			MockServerHttpRequest authorizationRequest, ClientRegistration registration) {
 		Map<String, Object> attributes = new HashMap<>();