|  | @@ -1,5 +1,5 @@
 | 
	
		
			
				|  |  |  /*
 | 
	
		
			
				|  |  | - * Copyright 2002-2018 the original author or authors.
 | 
	
		
			
				|  |  | + * Copyright 2002-2020 the original author or authors.
 | 
	
		
			
				|  |  |   *
 | 
	
		
			
				|  |  |   * Licensed under the Apache License, Version 2.0 (the "License");
 | 
	
		
			
				|  |  |   * you may not use this file except in compliance with the License.
 | 
	
	
		
			
				|  | @@ -18,10 +18,6 @@ package org.springframework.security.oauth2.client.web;
 | 
	
		
			
				|  |  |  import org.junit.After;
 | 
	
		
			
				|  |  |  import org.junit.Before;
 | 
	
		
			
				|  |  |  import org.junit.Test;
 | 
	
		
			
				|  |  | -import org.junit.runner.RunWith;
 | 
	
		
			
				|  |  | -import org.powermock.core.classloader.annotations.PowerMockIgnore;
 | 
	
		
			
				|  |  | -import org.powermock.core.classloader.annotations.PrepareForTest;
 | 
	
		
			
				|  |  | -import org.powermock.modules.junit4.PowerMockRunner;
 | 
	
		
			
				|  |  |  import org.springframework.mock.web.MockHttpServletRequest;
 | 
	
		
			
				|  |  |  import org.springframework.mock.web.MockHttpServletResponse;
 | 
	
		
			
				|  |  |  import org.springframework.security.authentication.AnonymousAuthenticationToken;
 | 
	
	
		
			
				|  | @@ -39,36 +35,44 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 | 
	
		
			
				|  |  |  import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 | 
	
		
			
				|  |  |  import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
 | 
	
		
			
				|  |  |  import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 | 
	
		
			
				|  |  | -import org.springframework.security.oauth2.core.OAuth2AccessToken;
 | 
	
		
			
				|  |  |  import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 | 
	
		
			
				|  |  |  import org.springframework.security.oauth2.core.OAuth2Error;
 | 
	
		
			
				|  |  |  import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 | 
	
		
			
				|  |  | -import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 | 
	
		
			
				|  |  | -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 | 
	
		
			
				|  |  |  import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 | 
	
		
			
				|  |  |  import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 | 
	
		
			
				|  |  |  import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
 | 
	
		
			
				|  |  |  import org.springframework.security.web.savedrequest.RequestCache;
 | 
	
		
			
				|  |  | +import org.springframework.security.web.util.UrlUtils;
 | 
	
		
			
				|  |  | +import org.springframework.util.CollectionUtils;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import javax.servlet.FilterChain;
 | 
	
		
			
				|  |  |  import javax.servlet.http.HttpServletRequest;
 | 
	
		
			
				|  |  |  import javax.servlet.http.HttpServletResponse;
 | 
	
		
			
				|  |  |  import javax.servlet.http.HttpSession;
 | 
	
		
			
				|  |  |  import java.util.HashMap;
 | 
	
		
			
				|  |  | +import java.util.LinkedHashMap;
 | 
	
		
			
				|  |  |  import java.util.Map;
 | 
	
		
			
				|  |  | +import java.util.stream.Collectors;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import static org.assertj.core.api.Assertions.assertThat;
 | 
	
		
			
				|  |  |  import static org.assertj.core.api.Assertions.assertThatThrownBy;
 | 
	
		
			
				|  |  | -import static org.mockito.Mockito.*;
 | 
	
		
			
				|  |  | +import static org.mockito.Mockito.any;
 | 
	
		
			
				|  |  | +import static org.mockito.Mockito.mock;
 | 
	
		
			
				|  |  | +import static org.mockito.Mockito.spy;
 | 
	
		
			
				|  |  | +import static org.mockito.Mockito.times;
 | 
	
		
			
				|  |  | +import static org.mockito.Mockito.verify;
 | 
	
		
			
				|  |  | +import static org.mockito.Mockito.verifyZeroInteractions;
 | 
	
		
			
				|  |  | +import static org.mockito.Mockito.when;
 | 
	
		
			
				|  |  | +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
 | 
	
		
			
				|  |  | +import static org.springframework.security.oauth2.core.TestOAuth2RefreshTokens.refreshToken;
 | 
	
		
			
				|  |  | +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges.success;
 | 
	
		
			
				|  |  | +import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  /**
 | 
	
		
			
				|  |  |   * Tests for {@link OAuth2AuthorizationCodeGrantFilter}.
 | 
	
		
			
				|  |  |   *
 | 
	
		
			
				|  |  |   * @author Joe Grandja
 | 
	
		
			
				|  |  |   */
 | 
	
		
			
				|  |  | -@PowerMockIgnore("javax.security.*")
 | 
	
		
			
				|  |  | -@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class, OAuth2AuthorizationCodeGrantFilter.class})
 | 
	
		
			
				|  |  | -@RunWith(PowerMockRunner.class)
 | 
	
		
			
				|  |  |  public class OAuth2AuthorizationCodeGrantFilterTests {
 | 
	
		
			
				|  |  |  	private ClientRegistration registration1;
 | 
	
		
			
				|  |  |  	private String principalName1 = "principal-1";
 | 
	
	
		
			
				|  | @@ -132,8 +136,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 | 
	
		
			
				|  |  |  		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 | 
	
		
			
				|  |  |  		request.setServletPath(requestUri);
 | 
	
		
			
				|  |  |  		// NOTE: A valid Authorization Response contains either a 'code' or 'error' parameter.
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		HttpServletResponse response = mock(HttpServletResponse.class);
 | 
	
		
			
				|  |  | +		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  |  		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		this.filter.doFilter(request, response, filterChain);
 | 
	
	
		
			
				|  | @@ -143,94 +146,142 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  |  	public void doFilterWhenAuthorizationRequestNotFoundThenNotProcessed() throws Exception {
 | 
	
		
			
				|  |  | -		String requestUri = "/path";
 | 
	
		
			
				|  |  | -		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 | 
	
		
			
				|  |  | -		request.setServletPath(requestUri);
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.CODE, "code");
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.STATE, "state");
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		HttpServletResponse response = mock(HttpServletResponse.class);
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/path");
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 | 
	
		
			
				|  |  | +		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  |  		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		this.filter.doFilter(request, response, filterChain);
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  | -	public void doFilterWhenAuthorizationResponseUrlDoesNotMatchAuthorizationRequestRedirectUriThenNotProcessed() throws Exception {
 | 
	
		
			
				|  |  | +	public void doFilterWhenAuthorizationRequestRedirectUriDoesNotMatchThenNotProcessed() throws Exception {
 | 
	
		
			
				|  |  |  		String requestUri = "/callback/client-1";
 | 
	
		
			
				|  |  | -		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 | 
	
		
			
				|  |  | -		request.setServletPath(requestUri);
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.CODE, "code");
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.STATE, "state");
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		HttpServletResponse response = mock(HttpServletResponse.class);
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri);
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 | 
	
		
			
				|  |  | +		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  | +		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 | 
	
		
			
				|  |  | +		authorizationResponse.setRequestURI(requestUri + "-no-match");
 | 
	
		
			
				|  |  |  		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		this.setUpAuthorizationRequest(request, response, this.registration1);
 | 
	
		
			
				|  |  | -		request.setRequestURI(requestUri + "-no-match");
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		this.filter.doFilter(request, response, filterChain);
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	// gh-7963
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  | -	public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
 | 
	
		
			
				|  |  | +	public void doFilterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed() throws Exception {
 | 
	
		
			
				|  |  | +		// 1) redirect_uri with query parameters
 | 
	
		
			
				|  |  |  		String requestUri = "/callback/client-1";
 | 
	
		
			
				|  |  | -		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 | 
	
		
			
				|  |  | -		request.setServletPath(requestUri);
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.CODE, "code");
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.STATE, "state");
 | 
	
		
			
				|  |  | +		Map<String, String> parameters = new LinkedHashMap<>();
 | 
	
		
			
				|  |  | +		parameters.put("param1", "value1");
 | 
	
		
			
				|  |  | +		parameters.put("param2", "value2");
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters);
 | 
	
		
			
				|  |  | +		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  | +		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 | 
	
		
			
				|  |  | +		this.setUpAuthenticationResult(this.registration1);
 | 
	
		
			
				|  |  | +		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  | +		verifyZeroInteractions(filterChain);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		// 2) redirect_uri with query parameters AND authorization response additional parameters
 | 
	
		
			
				|  |  | +		Map<String, String> additionalParameters = new LinkedHashMap<>();
 | 
	
		
			
				|  |  | +		additionalParameters.put("auth-param1", "value1");
 | 
	
		
			
				|  |  | +		additionalParameters.put("auth-param2", "value2");
 | 
	
		
			
				|  |  | +		response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  | +		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 | 
	
		
			
				|  |  | +		authorizationResponse = createAuthorizationResponse(authorizationRequest, additionalParameters);
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  | +		verifyZeroInteractions(filterChain);
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	// gh-7963
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void doFilterWhenAuthorizationRequestRedirectUriParametersDoesNotMatchThenNotProcessed() throws Exception {
 | 
	
		
			
				|  |  | +		String requestUri = "/callback/client-1";
 | 
	
		
			
				|  |  | +		Map<String, String> parameters = new LinkedHashMap<>();
 | 
	
		
			
				|  |  | +		parameters.put("param1", "value1");
 | 
	
		
			
				|  |  | +		parameters.put("param2", "value2");
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters);
 | 
	
		
			
				|  |  |  		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  | +		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 | 
	
		
			
				|  |  | +		this.setUpAuthenticationResult(this.registration1);
 | 
	
		
			
				|  |  |  		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		this.setUpAuthorizationRequest(request, response, this.registration1);
 | 
	
		
			
				|  |  | +		// 1) Parameter value
 | 
	
		
			
				|  |  | +		Map<String, String> parametersNotMatch = new LinkedHashMap<>(parameters);
 | 
	
		
			
				|  |  | +		parametersNotMatch.put("param2", "value8");
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(
 | 
	
		
			
				|  |  | +				createAuthorizationRequest(requestUri, parametersNotMatch));
 | 
	
		
			
				|  |  | +		authorizationResponse.setSession(authorizationRequest.getSession());
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  | +		verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		// 2) Parameter order
 | 
	
		
			
				|  |  | +		parametersNotMatch = new LinkedHashMap<>();
 | 
	
		
			
				|  |  | +		parametersNotMatch.put("param2", "value2");
 | 
	
		
			
				|  |  | +		parametersNotMatch.put("param1", "value1");
 | 
	
		
			
				|  |  | +		authorizationResponse = createAuthorizationResponse(
 | 
	
		
			
				|  |  | +				createAuthorizationRequest(requestUri, parametersNotMatch));
 | 
	
		
			
				|  |  | +		authorizationResponse.setSession(authorizationRequest.getSession());
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  | +		verify(filterChain, times(2)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		// 3) Parameter missing
 | 
	
		
			
				|  |  | +		parametersNotMatch = new LinkedHashMap<>(parameters);
 | 
	
		
			
				|  |  | +		parametersNotMatch.remove("param2");
 | 
	
		
			
				|  |  | +		authorizationResponse = createAuthorizationResponse(
 | 
	
		
			
				|  |  | +				createAuthorizationRequest(requestUri, parametersNotMatch));
 | 
	
		
			
				|  |  | +		authorizationResponse.setSession(authorizationRequest.getSession());
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  | +		verify(filterChain, times(3)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void doFilterWhenAuthorizationRequestMatchThenAuthorizationRequestRemoved() throws Exception {
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 | 
	
		
			
				|  |  | +		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  | +		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  | +		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 | 
	
		
			
				|  |  |  		this.setUpAuthenticationResult(this.registration1);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		this.filter.doFilter(request, response, filterChain);
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull();
 | 
	
		
			
				|  |  | +		assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(authorizationResponse)).isNull();
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  |  	public void doFilterWhenAuthorizationFailsThenHandleOAuth2AuthorizationException() throws Exception {
 | 
	
		
			
				|  |  | -		String requestUri = "/callback/client-1";
 | 
	
		
			
				|  |  | -		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 | 
	
		
			
				|  |  | -		request.setServletPath(requestUri);
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.CODE, "code");
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.STATE, "state");
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 | 
	
		
			
				|  |  |  		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  |  		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  | +		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		this.setUpAuthorizationRequest(request, response, this.registration1);
 | 
	
		
			
				|  |  |  		OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT);
 | 
	
		
			
				|  |  |  		when(this.authenticationManager.authenticate(any(Authentication.class)))
 | 
	
		
			
				|  |  |  			.thenThrow(new OAuth2AuthorizationException(error));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		this.filter.doFilter(request, response, filterChain);
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1?error=invalid_grant");
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  | -	public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSavedToService() throws Exception {
 | 
	
		
			
				|  |  | -		String requestUri = "/callback/client-1";
 | 
	
		
			
				|  |  | -		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 | 
	
		
			
				|  |  | -		request.setServletPath(requestUri);
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.CODE, "code");
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.STATE, "state");
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +	public void doFilterWhenAuthorizationSucceedsThenAuthorizedClientSavedToService() throws Exception {
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 | 
	
		
			
				|  |  |  		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  |  		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		this.setUpAuthorizationRequest(request, response, this.registration1);
 | 
	
		
			
				|  |  | +		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 | 
	
		
			
				|  |  |  		this.setUpAuthenticationResult(this.registration1);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		this.filter.doFilter(request, response, filterChain);
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
 | 
	
		
			
				|  |  |  			this.registration1.getRegistrationId(), this.principalName1);
 | 
	
	
		
			
				|  | @@ -242,40 +293,31 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  | -	public void doFilterWhenAuthorizationResponseSuccessThenRedirected() throws Exception {
 | 
	
		
			
				|  |  | -		String requestUri = "/callback/client-1";
 | 
	
		
			
				|  |  | -		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 | 
	
		
			
				|  |  | -		request.setServletPath(requestUri);
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.CODE, "code");
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.STATE, "state");
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +	public void doFilterWhenAuthorizationSucceedsThenRedirected() throws Exception {
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 | 
	
		
			
				|  |  |  		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  |  		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		this.setUpAuthorizationRequest(request, response, this.registration1);
 | 
	
		
			
				|  |  | +		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 | 
	
		
			
				|  |  |  		this.setUpAuthenticationResult(this.registration1);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		this.filter.doFilter(request, response, filterChain);
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  | -	public void doFilterWhenAuthorizationResponseSuccessHasSavedRequestThenRedirectedToSavedRequest() throws Exception {
 | 
	
		
			
				|  |  | +	public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception {
 | 
	
		
			
				|  |  |  		String requestUri = "/saved-request";
 | 
	
		
			
				|  |  |  		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 | 
	
		
			
				|  |  |  		request.setServletPath(requestUri);
 | 
	
		
			
				|  |  |  		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  |  		RequestCache requestCache = new HttpSessionRequestCache();
 | 
	
		
			
				|  |  |  		requestCache.saveRequest(request, response);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		requestUri = "/callback/client-1";
 | 
	
		
			
				|  |  | -		request.setRequestURI(requestUri);
 | 
	
		
			
				|  |  | +		request.setRequestURI("/callback/client-1");
 | 
	
		
			
				|  |  |  		request.addParameter(OAuth2ParameterNames.CODE, "code");
 | 
	
		
			
				|  |  |  		request.addParameter(OAuth2ParameterNames.STATE, "state");
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  		this.setUpAuthorizationRequest(request, response, this.registration1);
 | 
	
		
			
				|  |  |  		this.setUpAuthenticationResult(this.registration1);
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -285,36 +327,30 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  | -	public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
 | 
	
		
			
				|  |  | +	public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
 | 
	
		
			
				|  |  |  		AnonymousAuthenticationToken anonymousPrincipal =
 | 
	
		
			
				|  |  |  				new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
 | 
	
		
			
				|  |  |  		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
 | 
	
		
			
				|  |  |  		securityContext.setAuthentication(anonymousPrincipal);
 | 
	
		
			
				|  |  |  		SecurityContextHolder.setContext(securityContext);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		String requestUri = "/callback/client-1";
 | 
	
		
			
				|  |  | -		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 | 
	
		
			
				|  |  | -		request.setServletPath(requestUri);
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.CODE, "code");
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.STATE, "state");
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 | 
	
		
			
				|  |  |  		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  |  		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		this.setUpAuthorizationRequest(request, response, this.registration1);
 | 
	
		
			
				|  |  | +		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 | 
	
		
			
				|  |  |  		this.setUpAuthenticationResult(this.registration1);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		this.filter.doFilter(request, response, filterChain);
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
 | 
	
		
			
				|  |  | -				this.registration1.getRegistrationId(), anonymousPrincipal, request);
 | 
	
		
			
				|  |  | +				this.registration1.getRegistrationId(), anonymousPrincipal, authorizationResponse);
 | 
	
		
			
				|  |  |  		assertThat(authorizedClient).isNotNull();
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  		assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
 | 
	
		
			
				|  |  |  		assertThat(authorizedClient.getPrincipalName()).isEqualTo(anonymousPrincipal.getName());
 | 
	
		
			
				|  |  |  		assertThat(authorizedClient.getAccessToken()).isNotNull();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		HttpSession session = request.getSession(false);
 | 
	
		
			
				|  |  | +		HttpSession session = authorizationResponse.getSession(false);
 | 
	
		
			
				|  |  |  		assertThat(session).isNotNull();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		@SuppressWarnings("unchecked")
 | 
	
	
		
			
				|  | @@ -326,33 +362,27 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  | -	public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception {
 | 
	
		
			
				|  |  | +	public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception {
 | 
	
		
			
				|  |  |  		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
 | 
	
		
			
				|  |  |  		SecurityContextHolder.setContext(securityContext);		// null Authentication
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		String requestUri = "/callback/client-1";
 | 
	
		
			
				|  |  | -		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 | 
	
		
			
				|  |  | -		request.setServletPath(requestUri);
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.CODE, "code");
 | 
	
		
			
				|  |  | -		request.addParameter(OAuth2ParameterNames.STATE, "state");
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
 | 
	
		
			
				|  |  |  		MockHttpServletResponse response = new MockHttpServletResponse();
 | 
	
		
			
				|  |  |  		FilterChain filterChain = mock(FilterChain.class);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		this.setUpAuthorizationRequest(request, response, this.registration1);
 | 
	
		
			
				|  |  | +		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
 | 
	
		
			
				|  |  |  		this.setUpAuthenticationResult(this.registration1);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		this.filter.doFilter(request, response, filterChain);
 | 
	
		
			
				|  |  | +		this.filter.doFilter(authorizationResponse, response, filterChain);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
 | 
	
		
			
				|  |  | -				this.registration1.getRegistrationId(), null, request);
 | 
	
		
			
				|  |  | +				this.registration1.getRegistrationId(), null, authorizationResponse);
 | 
	
		
			
				|  |  |  		assertThat(authorizedClient).isNotNull();
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  		assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
 | 
	
		
			
				|  |  |  		assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser");
 | 
	
		
			
				|  |  |  		assertThat(authorizedClient.getAccessToken()).isNotNull();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		HttpSession session = request.getSession(false);
 | 
	
		
			
				|  |  | +		HttpSession session = authorizationResponse.getSession(false);
 | 
	
		
			
				|  |  |  		assertThat(session).isNotNull();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		@SuppressWarnings("unchecked")
 | 
	
	
		
			
				|  | @@ -363,23 +393,57 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 | 
	
		
			
				|  |  |  		assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient);
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	private static MockHttpServletRequest createAuthorizationRequest(String requestUri) {
 | 
	
		
			
				|  |  | +		return createAuthorizationRequest(requestUri, new LinkedHashMap<>());
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	private static MockHttpServletRequest createAuthorizationRequest(String requestUri, Map<String, String> parameters) {
 | 
	
		
			
				|  |  | +		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 | 
	
		
			
				|  |  | +		request.setServletPath(requestUri);
 | 
	
		
			
				|  |  | +		if (!CollectionUtils.isEmpty(parameters)) {
 | 
	
		
			
				|  |  | +			parameters.forEach(request::addParameter);
 | 
	
		
			
				|  |  | +			request.setQueryString(
 | 
	
		
			
				|  |  | +					parameters.entrySet().stream()
 | 
	
		
			
				|  |  | +							.map(e -> e.getKey() + "=" + e.getValue())
 | 
	
		
			
				|  |  | +							.collect(Collectors.joining("&")));
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		return request;
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	private static MockHttpServletRequest createAuthorizationResponse(MockHttpServletRequest authorizationRequest) {
 | 
	
		
			
				|  |  | +		return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>());
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	private static MockHttpServletRequest createAuthorizationResponse(
 | 
	
		
			
				|  |  | +			MockHttpServletRequest authorizationRequest, Map<String, String> additionalParameters) {
 | 
	
		
			
				|  |  | +		MockHttpServletRequest authorizationResponse = new MockHttpServletRequest(
 | 
	
		
			
				|  |  | +				authorizationRequest.getMethod(), authorizationRequest.getRequestURI());
 | 
	
		
			
				|  |  | +		authorizationResponse.setServletPath(authorizationRequest.getRequestURI());
 | 
	
		
			
				|  |  | +		authorizationRequest.getParameterMap().forEach(authorizationResponse::addParameter);
 | 
	
		
			
				|  |  | +		authorizationResponse.addParameter(OAuth2ParameterNames.CODE, "code");
 | 
	
		
			
				|  |  | +		authorizationResponse.addParameter(OAuth2ParameterNames.STATE, "state");
 | 
	
		
			
				|  |  | +		additionalParameters.forEach(authorizationResponse::addParameter);
 | 
	
		
			
				|  |  | +		authorizationResponse.setQueryString(
 | 
	
		
			
				|  |  | +				authorizationResponse.getParameterMap().entrySet().stream()
 | 
	
		
			
				|  |  | +						.map(e -> e.getKey() + "=" + e.getValue()[0])
 | 
	
		
			
				|  |  | +						.collect(Collectors.joining("&")));
 | 
	
		
			
				|  |  | +		authorizationResponse.setSession(authorizationRequest.getSession());
 | 
	
		
			
				|  |  | +		return authorizationResponse;
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
 | 
	
		
			
				|  |  |  											ClientRegistration registration) {
 | 
	
		
			
				|  |  |  		Map<String, Object> additionalParameters = new HashMap<>();
 | 
	
		
			
				|  |  |  		additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
 | 
	
		
			
				|  |  | -		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
 | 
	
		
			
				|  |  | -		when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);
 | 
	
		
			
				|  |  | -		when(authorizationRequest.getRedirectUri()).thenReturn(request.getRequestURL().toString());
 | 
	
		
			
				|  |  | -		when(authorizationRequest.getState()).thenReturn("state");
 | 
	
		
			
				|  |  | +		OAuth2AuthorizationRequest authorizationRequest = request()
 | 
	
		
			
				|  |  | +				.additionalParameters(additionalParameters)
 | 
	
		
			
				|  |  | +				.redirectUri(UrlUtils.buildFullRequestUrl(request)).build();
 | 
	
		
			
				|  |  |  		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	private void setUpAuthenticationResult(ClientRegistration registration) {
 | 
	
		
			
				|  |  | -		OAuth2AuthorizationCodeAuthenticationToken authentication = mock(OAuth2AuthorizationCodeAuthenticationToken.class);
 | 
	
		
			
				|  |  | -		when(authentication.getClientRegistration()).thenReturn(registration);
 | 
	
		
			
				|  |  | -		when(authentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class));
 | 
	
		
			
				|  |  | -		when(authentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
 | 
	
		
			
				|  |  | -		when(authentication.getRefreshToken()).thenReturn(mock(OAuth2RefreshToken.class));
 | 
	
		
			
				|  |  | +		OAuth2AuthorizationCodeAuthenticationToken authentication =
 | 
	
		
			
				|  |  | +				new OAuth2AuthorizationCodeAuthenticationToken(registration, success(), noScopes(), refreshToken());
 | 
	
		
			
				|  |  |  		when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authentication);
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  }
 |