|
@@ -1,5 +1,5 @@
|
|
|
/*
|
|
|
- * Copyright 2002-2018 the original author or authors.
|
|
|
+ * Copyright 2002-2020 the original author or authors.
|
|
|
*
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -15,17 +15,9 @@
|
|
|
*/
|
|
|
package org.springframework.security.oauth2.client.web;
|
|
|
|
|
|
-import java.util.HashMap;
|
|
|
-import java.util.Map;
|
|
|
-import javax.servlet.FilterChain;
|
|
|
-import javax.servlet.http.HttpServletRequest;
|
|
|
-import javax.servlet.http.HttpServletResponse;
|
|
|
-import javax.servlet.http.HttpSession;
|
|
|
-
|
|
|
import org.junit.After;
|
|
|
import org.junit.Before;
|
|
|
import org.junit.Test;
|
|
|
-
|
|
|
import org.springframework.mock.web.MockHttpServletRequest;
|
|
|
import org.springframework.mock.web.MockHttpServletResponse;
|
|
|
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
|
@@ -50,13 +42,26 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|
|
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
|
|
|
import org.springframework.security.web.savedrequest.RequestCache;
|
|
|
+import org.springframework.security.web.util.UrlUtils;
|
|
|
+import org.springframework.util.CollectionUtils;
|
|
|
+
|
|
|
+import javax.servlet.FilterChain;
|
|
|
+import javax.servlet.http.HttpServletRequest;
|
|
|
+import javax.servlet.http.HttpServletResponse;
|
|
|
+import javax.servlet.http.HttpSession;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.LinkedHashMap;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.stream.Collectors;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|
|
import static org.mockito.Mockito.any;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
import static org.mockito.Mockito.spy;
|
|
|
+import static org.mockito.Mockito.times;
|
|
|
import static org.mockito.Mockito.verify;
|
|
|
+import static org.mockito.Mockito.verifyNoInteractions;
|
|
|
import static org.mockito.Mockito.when;
|
|
|
import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
|
|
|
import static org.springframework.security.oauth2.core.TestOAuth2RefreshTokens.refreshToken;
|
|
@@ -131,8 +136,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
request.setServletPath(requestUri);
|
|
|
// NOTE: A valid Authorization Response contains either a 'code' or 'error' parameter.
|
|
|
-
|
|
|
- HttpServletResponse response = mock(HttpServletResponse.class);
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
|
this.filter.doFilter(request, response, filterChain);
|
|
@@ -142,94 +146,142 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|
|
|
|
|
@Test
|
|
|
public void doFilterWhenAuthorizationRequestNotFoundThenNotProcessed() throws Exception {
|
|
|
- String requestUri = "/path";
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
- request.setServletPath(requestUri);
|
|
|
- request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
-
|
|
|
- HttpServletResponse response = mock(HttpServletResponse.class);
|
|
|
+ MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/path");
|
|
|
+ MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
|
- this.filter.doFilter(request, response, filterChain);
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
|
|
|
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenAuthorizationResponseUrlDoesNotMatchAuthorizationRequestRedirectUriThenNotProcessed() throws Exception {
|
|
|
+ public void doFilterWhenAuthorizationRequestRedirectUriDoesNotMatchThenNotProcessed() throws Exception {
|
|
|
String requestUri = "/callback/client-1";
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
- request.setServletPath(requestUri);
|
|
|
- request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
-
|
|
|
- HttpServletResponse response = mock(HttpServletResponse.class);
|
|
|
+ MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri);
|
|
|
+ MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
|
|
+ authorizationResponse.setRequestURI(requestUri + "-no-match");
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
|
- this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
- request.setRequestURI(requestUri + "-no-match");
|
|
|
-
|
|
|
- this.filter.doFilter(request, response, filterChain);
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
|
|
|
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
}
|
|
|
|
|
|
+ // gh-7963
|
|
|
@Test
|
|
|
- public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
|
|
|
+ public void doFilterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed() throws Exception {
|
|
|
+ // 1) redirect_uri with query parameters
|
|
|
String requestUri = "/callback/client-1";
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
- request.setServletPath(requestUri);
|
|
|
- request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+ Map<String, String> parameters = new LinkedHashMap<>();
|
|
|
+ parameters.put("param1", "value1");
|
|
|
+ parameters.put("param2", "value2");
|
|
|
+ MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters);
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
|
|
+ this.setUpAuthenticationResult(this.registration1);
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+ MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
+ verifyNoInteractions(filterChain);
|
|
|
+
|
|
|
+ // 2) redirect_uri with query parameters AND authorization response additional parameters
|
|
|
+ Map<String, String> additionalParameters = new LinkedHashMap<>();
|
|
|
+ additionalParameters.put("auth-param1", "value1");
|
|
|
+ additionalParameters.put("auth-param2", "value2");
|
|
|
+ response = new MockHttpServletResponse();
|
|
|
+ this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
|
|
+ authorizationResponse = createAuthorizationResponse(authorizationRequest, additionalParameters);
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
+ verifyNoInteractions(filterChain);
|
|
|
+ }
|
|
|
|
|
|
+ // gh-7963
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenAuthorizationRequestRedirectUriParametersDoesNotMatchThenNotProcessed() throws Exception {
|
|
|
+ String requestUri = "/callback/client-1";
|
|
|
+ Map<String, String> parameters = new LinkedHashMap<>();
|
|
|
+ parameters.put("param1", "value1");
|
|
|
+ parameters.put("param2", "value2");
|
|
|
+ MockHttpServletRequest authorizationRequest = createAuthorizationRequest(requestUri, parameters);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
|
|
+ this.setUpAuthenticationResult(this.registration1);
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
|
- this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
+ // 1) Parameter value
|
|
|
+ Map<String, String> parametersNotMatch = new LinkedHashMap<>(parameters);
|
|
|
+ parametersNotMatch.put("param2", "value8");
|
|
|
+ MockHttpServletRequest authorizationResponse = createAuthorizationResponse(
|
|
|
+ createAuthorizationRequest(requestUri, parametersNotMatch));
|
|
|
+ authorizationResponse.setSession(authorizationRequest.getSession());
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
+ verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
+
|
|
|
+ // 2) Parameter order
|
|
|
+ parametersNotMatch = new LinkedHashMap<>();
|
|
|
+ parametersNotMatch.put("param2", "value2");
|
|
|
+ parametersNotMatch.put("param1", "value1");
|
|
|
+ authorizationResponse = createAuthorizationResponse(
|
|
|
+ createAuthorizationRequest(requestUri, parametersNotMatch));
|
|
|
+ authorizationResponse.setSession(authorizationRequest.getSession());
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
+ verify(filterChain, times(2)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
+
|
|
|
+ // 3) Parameter missing
|
|
|
+ parametersNotMatch = new LinkedHashMap<>(parameters);
|
|
|
+ parametersNotMatch.remove("param2");
|
|
|
+ authorizationResponse = createAuthorizationResponse(
|
|
|
+ createAuthorizationRequest(requestUri, parametersNotMatch));
|
|
|
+ authorizationResponse.setSession(authorizationRequest.getSession());
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
+ verify(filterChain, times(3)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenAuthorizationRequestMatchThenAuthorizationRequestRemoved() throws Exception {
|
|
|
+ MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
|
|
+ MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+ this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
|
|
this.setUpAuthenticationResult(this.registration1);
|
|
|
|
|
|
- this.filter.doFilter(request, response, filterChain);
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
|
|
|
- assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull();
|
|
|
+ assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(authorizationResponse)).isNull();
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void doFilterWhenAuthorizationFailsThenHandleOAuth2AuthorizationException() throws Exception {
|
|
|
- String requestUri = "/callback/client-1";
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
- request.setServletPath(requestUri);
|
|
|
- request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
-
|
|
|
+ MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
|
|
+ MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
+ this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
|
|
|
|
|
- this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT);
|
|
|
when(this.authenticationManager.authenticate(any(Authentication.class)))
|
|
|
.thenThrow(new OAuth2AuthorizationException(error));
|
|
|
|
|
|
- this.filter.doFilter(request, response, filterChain);
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
|
|
|
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1?error=invalid_grant");
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSavedToService() throws Exception {
|
|
|
- String requestUri = "/callback/client-1";
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
- request.setServletPath(requestUri);
|
|
|
- request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
-
|
|
|
+ public void doFilterWhenAuthorizationSucceedsThenAuthorizedClientSavedToService() throws Exception {
|
|
|
+ MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
|
|
+ MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
-
|
|
|
- this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
+ this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
|
|
this.setUpAuthenticationResult(this.registration1);
|
|
|
|
|
|
- this.filter.doFilter(request, response, filterChain);
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
|
|
|
OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
|
|
|
this.registration1.getRegistrationId(), this.principalName1);
|
|
@@ -241,40 +293,31 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenAuthorizationResponseSuccessThenRedirected() throws Exception {
|
|
|
- String requestUri = "/callback/client-1";
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
- request.setServletPath(requestUri);
|
|
|
- request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
-
|
|
|
+ public void doFilterWhenAuthorizationSucceedsThenRedirected() throws Exception {
|
|
|
+ MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
|
|
+ MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
-
|
|
|
- this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
+ this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
|
|
this.setUpAuthenticationResult(this.registration1);
|
|
|
|
|
|
- this.filter.doFilter(request, response, filterChain);
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
|
|
|
assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenAuthorizationResponseSuccessHasSavedRequestThenRedirectedToSavedRequest() throws Exception {
|
|
|
+ public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception {
|
|
|
String requestUri = "/saved-request";
|
|
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
request.setServletPath(requestUri);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
RequestCache requestCache = new HttpSessionRequestCache();
|
|
|
requestCache.saveRequest(request, response);
|
|
|
-
|
|
|
- requestUri = "/callback/client-1";
|
|
|
- request.setRequestURI(requestUri);
|
|
|
+ request.setRequestURI("/callback/client-1");
|
|
|
request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
-
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
-
|
|
|
this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
this.setUpAuthenticationResult(this.registration1);
|
|
|
|
|
@@ -284,36 +327,30 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
|
|
|
+ public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
|
|
|
AnonymousAuthenticationToken anonymousPrincipal =
|
|
|
new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
|
|
|
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
|
|
securityContext.setAuthentication(anonymousPrincipal);
|
|
|
SecurityContextHolder.setContext(securityContext);
|
|
|
|
|
|
- String requestUri = "/callback/client-1";
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
- request.setServletPath(requestUri);
|
|
|
- request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
-
|
|
|
+ MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
|
|
+ MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
-
|
|
|
- this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
+ this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
|
|
this.setUpAuthenticationResult(this.registration1);
|
|
|
|
|
|
- this.filter.doFilter(request, response, filterChain);
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
|
|
|
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
|
|
|
- this.registration1.getRegistrationId(), anonymousPrincipal, request);
|
|
|
+ this.registration1.getRegistrationId(), anonymousPrincipal, authorizationResponse);
|
|
|
assertThat(authorizedClient).isNotNull();
|
|
|
-
|
|
|
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
|
|
|
assertThat(authorizedClient.getPrincipalName()).isEqualTo(anonymousPrincipal.getName());
|
|
|
assertThat(authorizedClient.getAccessToken()).isNotNull();
|
|
|
|
|
|
- HttpSession session = request.getSession(false);
|
|
|
+ HttpSession session = authorizationResponse.getSession(false);
|
|
|
assertThat(session).isNotNull();
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
@@ -325,33 +362,27 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception {
|
|
|
+ public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessNullAuthenticationThenAuthorizedClientSavedToHttpSession() throws Exception {
|
|
|
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
|
|
SecurityContextHolder.setContext(securityContext); // null Authentication
|
|
|
|
|
|
- String requestUri = "/callback/client-1";
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
- request.setServletPath(requestUri);
|
|
|
- request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
-
|
|
|
+ MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
|
|
|
+ MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
-
|
|
|
- this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
+ this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
|
|
|
this.setUpAuthenticationResult(this.registration1);
|
|
|
|
|
|
- this.filter.doFilter(request, response, filterChain);
|
|
|
+ this.filter.doFilter(authorizationResponse, response, filterChain);
|
|
|
|
|
|
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
|
|
|
- this.registration1.getRegistrationId(), null, request);
|
|
|
+ this.registration1.getRegistrationId(), null, authorizationResponse);
|
|
|
assertThat(authorizedClient).isNotNull();
|
|
|
-
|
|
|
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
|
|
|
assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser");
|
|
|
assertThat(authorizedClient.getAccessToken()).isNotNull();
|
|
|
|
|
|
- HttpSession session = request.getSession(false);
|
|
|
+ HttpSession session = authorizationResponse.getSession(false);
|
|
|
assertThat(session).isNotNull();
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
@@ -362,13 +393,51 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
|
|
|
assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient);
|
|
|
}
|
|
|
|
|
|
+ private static MockHttpServletRequest createAuthorizationRequest(String requestUri) {
|
|
|
+ return createAuthorizationRequest(requestUri, new LinkedHashMap<>());
|
|
|
+ }
|
|
|
+
|
|
|
+ private static MockHttpServletRequest createAuthorizationRequest(String requestUri, Map<String, String> parameters) {
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ if (!CollectionUtils.isEmpty(parameters)) {
|
|
|
+ parameters.forEach(request::addParameter);
|
|
|
+ request.setQueryString(
|
|
|
+ parameters.entrySet().stream()
|
|
|
+ .map(e -> e.getKey() + "=" + e.getValue())
|
|
|
+ .collect(Collectors.joining("&")));
|
|
|
+ }
|
|
|
+ return request;
|
|
|
+ }
|
|
|
+
|
|
|
+ private static MockHttpServletRequest createAuthorizationResponse(MockHttpServletRequest authorizationRequest) {
|
|
|
+ return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>());
|
|
|
+ }
|
|
|
+
|
|
|
+ private static MockHttpServletRequest createAuthorizationResponse(
|
|
|
+ MockHttpServletRequest authorizationRequest, Map<String, String> additionalParameters) {
|
|
|
+ MockHttpServletRequest authorizationResponse = new MockHttpServletRequest(
|
|
|
+ authorizationRequest.getMethod(), authorizationRequest.getRequestURI());
|
|
|
+ authorizationResponse.setServletPath(authorizationRequest.getRequestURI());
|
|
|
+ authorizationRequest.getParameterMap().forEach(authorizationResponse::addParameter);
|
|
|
+ authorizationResponse.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ authorizationResponse.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+ additionalParameters.forEach(authorizationResponse::addParameter);
|
|
|
+ authorizationResponse.setQueryString(
|
|
|
+ authorizationResponse.getParameterMap().entrySet().stream()
|
|
|
+ .map(e -> e.getKey() + "=" + e.getValue()[0])
|
|
|
+ .collect(Collectors.joining("&")));
|
|
|
+ authorizationResponse.setSession(authorizationRequest.getSession());
|
|
|
+ return authorizationResponse;
|
|
|
+ }
|
|
|
+
|
|
|
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
|
|
ClientRegistration registration) {
|
|
|
- Map<String, Object> additionalParameters = new HashMap<>();
|
|
|
- additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
|
|
+ Map<String, Object> attributes = new HashMap<>();
|
|
|
+ attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
|
|
OAuth2AuthorizationRequest authorizationRequest = request()
|
|
|
- .additionalParameters(additionalParameters)
|
|
|
- .redirectUri(request.getRequestURL().toString()).build();
|
|
|
+ .attributes(attributes)
|
|
|
+ .redirectUri(UrlUtils.buildFullRequestUrl(request)).build();
|
|
|
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
|
|
|
}
|
|
|
|