2
0
Эх сурвалжийг харах

OAuth2AuthorizationCodeGrantFilter matches on query parameters

Fixes gh-7963
Joe Grandja 5 жил өмнө
parent
commit
3c86239b39

+ 33 - 15
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -41,6 +41,7 @@ import org.springframework.util.Assert;
 import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
+import org.springframework.web.util.UriComponents;
 import org.springframework.web.util.UriComponentsBuilder;
 
 import javax.servlet.FilterChain;
@@ -48,6 +49,11 @@ import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import java.io.IOException;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
 
 /**
  * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
@@ -132,24 +138,39 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 		throws ServletException, IOException {
 
-		if (this.shouldProcessAuthorizationResponse(request)) {
-			this.processAuthorizationResponse(request, response);
+		if (matchesAuthorizationResponse(request)) {
+			processAuthorizationResponse(request, response);
 			return;
 		}
 
 		filterChain.doFilter(request, response);
 	}
 
-	private boolean shouldProcessAuthorizationResponse(HttpServletRequest request) {
+	private boolean matchesAuthorizationResponse(HttpServletRequest request) {
+		MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
+		if (!OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) {
+			return false;
+		}
 		OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.loadAuthorizationRequest(request);
 		if (authorizationRequest == null) {
 			return false;
 		}
-		String requestUrl = UrlUtils.buildFullRequestUrl(request.getScheme(), request.getServerName(),
-				request.getServerPort(), request.getRequestURI(), null);
-		MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
-		if (requestUrl.equals(authorizationRequest.getRedirectUri()) &&
-				OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) {
+
+		// Compare redirect_uri
+		UriComponents requestUri = UriComponentsBuilder.fromUriString(UrlUtils.buildFullRequestUrl(request)).build();
+		UriComponents redirectUri = UriComponentsBuilder.fromUriString(authorizationRequest.getRedirectUri()).build();
+		Set<Map.Entry<String, List<String>>> requestUriParameters = new LinkedHashSet<>(requestUri.getQueryParams().entrySet());
+		Set<Map.Entry<String, List<String>>> redirectUriParameters = new LinkedHashSet<>(redirectUri.getQueryParams().entrySet());
+		// Remove the additional request parameters (if any) from the authorization response (request)
+		// before doing an exact comparison with the authorizationRequest.getRedirectUri() parameters (if any)
+		requestUriParameters.retainAll(redirectUriParameters);
+
+		if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) &&
+				Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) &&
+				Objects.equals(requestUri.getHost(), redirectUri.getHost()) &&
+				Objects.equals(requestUri.getPort(), redirectUri.getPort()) &&
+				Objects.equals(requestUri.getPath(), redirectUri.getPath()) &&
+				Objects.equals(requestUriParameters.toString(), redirectUriParameters.toString())) {
 			return true;
 		}
 		return false;
@@ -165,10 +186,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
 
 		MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
-		String redirectUri = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
-				.replaceQuery(null)
-				.build()
-				.toUriString();
+		String redirectUri = UrlUtils.buildFullRequestUrl(request);
 		OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, redirectUri);
 
 		OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken(
@@ -183,7 +201,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 		} catch (OAuth2AuthorizationException ex) {
 			OAuth2Error error = ex.getError();
 			UriComponentsBuilder uriBuilder = UriComponentsBuilder
-				.fromUriString(authorizationResponse.getRedirectUri())
+				.fromUriString(authorizationRequest.getRedirectUri())
 				.queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode());
 			if (!StringUtils.isEmpty(error.getDescription())) {
 				uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription());
@@ -206,7 +224,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 
 		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request, response);
 
-		String redirectUrl = authorizationResponse.getRedirectUri();
+		String redirectUrl = authorizationRequest.getRedirectUri();
 		SavedRequest savedRequest = this.requestCache.getRequest(request, response);
 		if (savedRequest != null) {
 			redirectUrl = savedRequest.getRedirectUrl();

+ 171 - 102
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,17 +15,9 @@
  */
 package org.springframework.security.oauth2.client.web;
 
-import java.util.HashMap;
-import java.util.Map;
-import javax.servlet.FilterChain;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import javax.servlet.http.HttpSession;
-
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
-
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
@@ -50,13 +42,26 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
 import org.springframework.security.web.savedrequest.RequestCache;
+import org.springframework.security.web.util.UrlUtils;
+import org.springframework.util.CollectionUtils;
+
+import javax.servlet.FilterChain;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.stream.Collectors;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
 import static org.springframework.security.oauth2.core.TestOAuth2RefreshTokens.refreshToken;
@@ -131,8 +136,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		// NOTE: A valid Authorization Response contains either a 'code' or 'error' parameter.
-
-		HttpServletResponse response = mock(HttpServletResponse.class);
+		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
 		this.filter.doFilter(request, response, filterChain);
@@ -142,94 +146,142 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 
 	@Test
 	public void doFilterWhenAuthorizationRequestNotFoundThenNotProcessed() throws Exception {
-		String requestUri = "/path";
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
-		request.addParameter(OAuth2ParameterNames.CODE, "code");
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
-
-		HttpServletResponse response = mock(HttpServletResponse.class);
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/path");
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
+		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
-		this.filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(authorizationResponse, response, filterChain);
 
 		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
 	@Test
-	public void doFilterWhenAuthorizationResponseUrlDoesNotMatchAuthorizationRequestRedirectUriThenNotProcessed() throws Exception {
+	public void doFilterWhenAuthorizationRequestRedirectUriDoesNotMatchThenNotProcessed() throws Exception {
 		String requestUri = "/callback/client-1";
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
-		request.addParameter(OAuth2ParameterNames.CODE, "code");
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
-
-		HttpServletResponse response = mock(HttpServletResponse.class);
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri);
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
+		authorizationResponse.setRequestURI(requestUri + "-no-match");
 		FilterChain filterChain = mock(FilterChain.class);
 
-		this.setUpAuthorizationRequest(request, response, this.registration1);
-		request.setRequestURI(requestUri + "-no-match");
-
-		this.filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(authorizationResponse, response, filterChain);
 
 		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
+	// gh-7963
 	@Test
-	public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
+	public void doFilterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed() throws Exception {
+		// 1) redirect_uri with query parameters
 		String requestUri = "/callback/client-1";
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
-		request.addParameter(OAuth2ParameterNames.CODE, "code");
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
+		Map<String, String> parameters = new LinkedHashMap<>();
+		parameters.put("param1", "value1");
+		parameters.put("param2", "value2");
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
+		this.setUpAuthenticationResult(this.registration1);
+		FilterChain filterChain = mock(FilterChain.class);
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
+		this.filter.doFilter(authorizationResponse, response, filterChain);
+		verifyNoInteractions(filterChain);
+
+		// 2) redirect_uri with query parameters AND authorization response additional parameters
+		Map<String, String> additionalParameters = new LinkedHashMap<>();
+		additionalParameters.put("auth-param1", "value1");
+		additionalParameters.put("auth-param2", "value2");
+		response = new MockHttpServletResponse();
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
+		authorizationResponse = createAuthorizationResponse(authorizationRequest, additionalParameters);
+		this.filter.doFilter(authorizationResponse, response, filterChain);
+		verifyNoInteractions(filterChain);
+	}
 
+	// gh-7963
+	@Test
+	public void doFilterWhenAuthorizationRequestRedirectUriParametersDoesNotMatchThenNotProcessed() throws Exception {
+		String requestUri = "/callback/client-1";
+		Map<String, String> parameters = new LinkedHashMap<>();
+		parameters.put("param1", "value1");
+		parameters.put("param2", "value2");
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters);
 		MockHttpServletResponse response = new MockHttpServletResponse();
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
+		this.setUpAuthenticationResult(this.registration1);
 		FilterChain filterChain = mock(FilterChain.class);
 
-		this.setUpAuthorizationRequest(request, response, this.registration1);
+		// 1) Parameter value
+		Map<String, String> parametersNotMatch = new LinkedHashMap<>(parameters);
+		parametersNotMatch.put("param2", "value8");
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(
+				createAuthorizationRequest(requestUri, parametersNotMatch));
+		authorizationResponse.setSession(authorizationRequest.getSession());
+		this.filter.doFilter(authorizationResponse, response, filterChain);
+		verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+
+		// 2) Parameter order
+		parametersNotMatch = new LinkedHashMap<>();
+		parametersNotMatch.put("param2", "value2");
+		parametersNotMatch.put("param1", "value1");
+		authorizationResponse = createAuthorizationResponse(
+				createAuthorizationRequest(requestUri, parametersNotMatch));
+		authorizationResponse.setSession(authorizationRequest.getSession());
+		this.filter.doFilter(authorizationResponse, response, filterChain);
+		verify(filterChain, times(2)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+
+		// 3) Parameter missing
+		parametersNotMatch = new LinkedHashMap<>(parameters);
+		parametersNotMatch.remove("param2");
+		authorizationResponse = createAuthorizationResponse(
+				createAuthorizationRequest(requestUri, parametersNotMatch));
+		authorizationResponse.setSession(authorizationRequest.getSession());
+		this.filter.doFilter(authorizationResponse, response, filterChain);
+		verify(filterChain, times(3)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestMatchThenAuthorizationRequestRemoved() throws Exception {
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 		this.setUpAuthenticationResult(this.registration1);
 
-		this.filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(authorizationResponse, response, filterChain);
 
-		assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull();
+		assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(authorizationResponse)).isNull();
 	}
 
 	@Test
 	public void doFilterWhenAuthorizationFailsThenHandleOAuth2AuthorizationException() throws Exception {
-		String requestUri = "/callback/client-1";
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
-		request.addParameter(OAuth2ParameterNames.CODE, "code");
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
-
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 
-		this.setUpAuthorizationRequest(request, response, this.registration1);
 		OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT);
 		when(this.authenticationManager.authenticate(any(Authentication.class)))
 			.thenThrow(new OAuth2AuthorizationException(error));
 
-		this.filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(authorizationResponse, response, filterChain);
 
 		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1?error=invalid_grant");
 	}
 
 	@Test
-	public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSavedToService() throws Exception {
-		String requestUri = "/callback/client-1";
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
-		request.addParameter(OAuth2ParameterNames.CODE, "code");
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
-
+	public void doFilterWhenAuthorizationSucceedsThenAuthorizedClientSavedToService() throws Exception {
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
-
-		this.setUpAuthorizationRequest(request, response, this.registration1);
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 		this.setUpAuthenticationResult(this.registration1);
 
-		this.filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(authorizationResponse, response, filterChain);
 
 		OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
 			this.registration1.getRegistrationId(), this.principalName1);
@@ -241,40 +293,31 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 	}
 
 	@Test
-	public void doFilterWhenAuthorizationResponseSuccessThenRedirected() throws Exception {
-		String requestUri = "/callback/client-1";
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
-		request.addParameter(OAuth2ParameterNames.CODE, "code");
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
-
+	public void doFilterWhenAuthorizationSucceedsThenRedirected() throws Exception {
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
-
-		this.setUpAuthorizationRequest(request, response, this.registration1);
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 		this.setUpAuthenticationResult(this.registration1);
 
-		this.filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(authorizationResponse, response, filterChain);
 
 		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
 	}
 
 	@Test
-	public void doFilterWhenAuthorizationResponseSuccessHasSavedRequestThenRedirectedToSavedRequest() throws Exception {
+	public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception {
 		String requestUri = "/saved-request";
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		RequestCache requestCache = new HttpSessionRequestCache();
 		requestCache.saveRequest(request, response);
-
-		requestUri = "/callback/client-1";
-		request.setRequestURI(requestUri);
+		request.setRequestURI("/callback/client-1");
 		request.addParameter(OAuth2ParameterNames.CODE, "code");
 		request.addParameter(OAuth2ParameterNames.STATE, "state");
-
 		FilterChain filterChain = mock(FilterChain.class);
-
 		this.setUpAuthorizationRequest(request, response, this.registration1);
 		this.setUpAuthenticationResult(this.registration1);
 
@@ -284,36 +327,30 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 	}
 
 	@Test
-	public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
+	public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
 		AnonymousAuthenticationToken anonymousPrincipal =
 				new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
 		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
 		securityContext.setAuthentication(anonymousPrincipal);
 		SecurityContextHolder.setContext(securityContext);
 
-		String requestUri = "/callback/client-1";
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
-		request.addParameter(OAuth2ParameterNames.CODE, "code");
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
-
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
-
-		this.setUpAuthorizationRequest(request, response, this.registration1);
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 		this.setUpAuthenticationResult(this.registration1);
 
-		this.filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(authorizationResponse, response, filterChain);
 
 		OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
-				this.registration1.getRegistrationId(), anonymousPrincipal, request);
+				this.registration1.getRegistrationId(), anonymousPrincipal, authorizationResponse);
 		assertThat(authorizedClient).isNotNull();
-
 		assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
 		assertThat(authorizedClient.getPrincipalName()).isEqualTo(anonymousPrincipal.getName());
 		assertThat(authorizedClient.getAccessToken()).isNotNull();
 
-		HttpSession session = request.getSession(false);
+		HttpSession session = authorizationResponse.getSession(false);
 		assertThat(session).isNotNull();
 
 		@SuppressWarnings("unchecked")
@@ -325,33 +362,27 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 	}
 
 	@Test
-	public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception {
+	public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception {
 		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
 		SecurityContextHolder.setContext(securityContext);		// null Authentication
 
-		String requestUri = "/callback/client-1";
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
-		request.addParameter(OAuth2ParameterNames.CODE, "code");
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
-
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
-
-		this.setUpAuthorizationRequest(request, response, this.registration1);
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 		this.setUpAuthenticationResult(this.registration1);
 
-		this.filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(authorizationResponse, response, filterChain);
 
 		OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
-				this.registration1.getRegistrationId(), null, request);
+				this.registration1.getRegistrationId(), null, authorizationResponse);
 		assertThat(authorizedClient).isNotNull();
-
 		assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
 		assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser");
 		assertThat(authorizedClient.getAccessToken()).isNotNull();
 
-		HttpSession session = request.getSession(false);
+		HttpSession session = authorizationResponse.getSession(false);
 		assertThat(session).isNotNull();
 
 		@SuppressWarnings("unchecked")
@@ -362,13 +393,51 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 		assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient);
 	}
 
+	private static MockHttpServletRequest createAuthorizationRequest(String requestUri) {
+		return createAuthorizationRequest(requestUri, new LinkedHashMap<>());
+	}
+
+	private static MockHttpServletRequest createAuthorizationRequest(String requestUri, Map<String, String> parameters) {
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		if (!CollectionUtils.isEmpty(parameters)) {
+			parameters.forEach(request::addParameter);
+			request.setQueryString(
+					parameters.entrySet().stream()
+							.map(e -> e.getKey() + "=" + e.getValue())
+							.collect(Collectors.joining("&")));
+		}
+		return request;
+	}
+
+	private static MockHttpServletRequest createAuthorizationResponse(MockHttpServletRequest authorizationRequest) {
+		return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>());
+	}
+
+	private static MockHttpServletRequest createAuthorizationResponse(
+			MockHttpServletRequest authorizationRequest, Map<String, String> additionalParameters) {
+		MockHttpServletRequest authorizationResponse = new MockHttpServletRequest(
+				authorizationRequest.getMethod(), authorizationRequest.getRequestURI());
+		authorizationResponse.setServletPath(authorizationRequest.getRequestURI());
+		authorizationRequest.getParameterMap().forEach(authorizationResponse::addParameter);
+		authorizationResponse.addParameter(OAuth2ParameterNames.CODE, "code");
+		authorizationResponse.addParameter(OAuth2ParameterNames.STATE, "state");
+		additionalParameters.forEach(authorizationResponse::addParameter);
+		authorizationResponse.setQueryString(
+				authorizationResponse.getParameterMap().entrySet().stream()
+						.map(e -> e.getKey() + "=" + e.getValue()[0])
+						.collect(Collectors.joining("&")));
+		authorizationResponse.setSession(authorizationRequest.getSession());
+		return authorizationResponse;
+	}
+
 	private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
 											ClientRegistration registration) {
-		Map<String, Object> additionalParameters = new HashMap<>();
-		additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
+		Map<String, Object> attributes = new HashMap<>();
+		attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
 		OAuth2AuthorizationRequest authorizationRequest = request()
-				.additionalParameters(additionalParameters)
-				.redirectUri(request.getRequestURL().toString()).build();
+				.attributes(attributes)
+				.redirectUri(UrlUtils.buildFullRequestUrl(request)).build();
 		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
 	}