浏览代码

Add RequestCache setter in OAuth2AuthorizationCodeGrantFilter

Fixes gh-8120
Parikshit Dutta 5 年之前
父节点
当前提交
1e211b6558

+ 6 - 1
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -79,6 +79,7 @@ import org.springframework.util.Assert;
  * </ul>
  *
  * @author Joe Grandja
+ * @author Parikshit Dutta
  * @since 5.1
  * @see OAuth2AuthorizationRequestRedirectFilter
  * @see OAuth2AuthorizationCodeGrantFilter
@@ -256,6 +257,10 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
 			if (this.authorizationRequestRepository != null) {
 				authorizationCodeGrantFilter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
 			}
+			RequestCache requestCache = builder.getSharedObject(RequestCache.class);
+			if (requestCache != null) {
+				authorizationCodeGrantFilter.setRequestCache(requestCache);
+			}
 			return authorizationCodeGrantFilter;
 		}
 

+ 39 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -75,6 +75,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
  * Tests for {@link OAuth2ClientConfigurer}.
  *
  * @author Joe Grandja
+ * @author Parikshit Dutta
  */
 public class OAuth2ClientConfigurerTests {
 	private static ClientRegistrationRepository clientRegistrationRepository;
@@ -208,6 +209,43 @@ public class OAuth2ClientConfigurerTests {
 		verify(requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
+	@Test
+	public void configureWhenRequestCacheProvidedAndClientAuthorizationSucceedsThenRequestCacheUsed() throws Exception {
+		this.spring.register(OAuth2ClientConfig.class).autowire();
+
+		// Setup the Authorization Request in the session
+		Map<String, Object> attributes = new HashMap<>();
+		attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId());
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri())
+				.clientId(this.registration1.getClientId())
+				.redirectUri("http://localhost/client-1")
+				.state("state")
+				.attributes(attributes)
+				.build();
+
+		AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
+				new HttpSessionOAuth2AuthorizationRequestRepository();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", "");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
+
+		MockHttpSession session = (MockHttpSession) request.getSession();
+
+		String principalName = "user1";
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
+
+		this.mockMvc.perform(get("/client-1")
+				.param(OAuth2ParameterNames.CODE, "code")
+				.param(OAuth2ParameterNames.STATE, "state")
+				.with(authentication(authentication))
+				.session(session))
+				.andExpect(status().is3xxRedirection())
+				.andExpect(redirectedUrl("http://localhost/client-1"));
+
+		verify(requestCache).getRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
 	// gh-5521
 	@Test
 	public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception {

+ 14 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java

@@ -83,6 +83,7 @@ import java.util.Set;
  * </ul>
  *
  * @author Joe Grandja
+ * @author Parikshit Dutta
  * @since 5.1
  * @see OAuth2AuthorizationCodeAuthenticationToken
  * @see OAuth2AuthorizationCodeAuthenticationProvider
@@ -104,7 +105,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 		new HttpSessionOAuth2AuthorizationRequestRepository();
 	private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
 	private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
-	private final RequestCache requestCache = new HttpSessionRequestCache();
+	private RequestCache requestCache = new HttpSessionRequestCache();
 
 	/**
 	 * Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided parameters.
@@ -134,6 +135,18 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 		this.authorizationRequestRepository = authorizationRequestRepository;
 	}
 
+	/**
+	 * Sets the {@link RequestCache} used for loading a previously saved request (if available)
+	 * and replaying it after completing the processing of the OAuth 2.0 Authorization Response.
+	 *
+	 * @since 5.4
+	 * @param requestCache the cache used for loading a previously saved request (if available)
+	 */
+	public final void setRequestCache(RequestCache requestCache) {
+		Assert.notNull(requestCache, "requestCache cannot be null");
+		this.requestCache = requestCache;
+	}
+
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 		throws ServletException, IOException {

+ 29 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java

@@ -72,6 +72,7 @@ import static org.springframework.security.oauth2.core.endpoint.TestOAuth2Author
  * Tests for {@link OAuth2AuthorizationCodeGrantFilter}.
  *
  * @author Joe Grandja
+ * @author Parikshit Dutta
  */
 public class OAuth2AuthorizationCodeGrantFilterTests {
 	private ClientRegistration registration1;
@@ -130,6 +131,12 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
+	@Test
+	public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.filter.setRequestCache(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
 	@Test
 	public void doFilterWhenNotAuthorizationResponseThenNotProcessed() throws Exception {
 		String requestUri = "/path";
@@ -326,6 +333,28 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request");
 	}
 
+	@Test
+	public void doFilterWhenAuthorizationSucceedsAndRequestCacheConfiguredThenRequestCacheUsed() throws Exception {
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+
+		FilterChain filterChain = mock(FilterChain.class);
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
+		this.setUpAuthenticationResult(this.registration1);
+
+		RequestCache requestCache = spy(HttpSessionRequestCache.class);
+		this.filter.setRequestCache(requestCache);
+
+		authorizationRequest.setRequestURI("/saved-request");
+		requestCache.saveRequest(authorizationRequest, response);
+
+		this.filter.doFilter(authorizationResponse, response, filterChain);
+
+		verify(requestCache).getRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
+		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request");
+	}
+
 	@Test
 	public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
 		AnonymousAuthenticationToken anonymousPrincipal =