|
@@ -37,6 +37,7 @@ import org.springframework.http.HttpStatus;
|
|
|
import org.springframework.http.MediaType;
|
|
|
import org.springframework.mock.web.MockHttpServletRequest;
|
|
|
import org.springframework.mock.web.MockHttpServletResponse;
|
|
|
+import org.springframework.mock.web.MockServletContext;
|
|
|
import org.springframework.security.authentication.AuthenticationDetailsSource;
|
|
|
import org.springframework.security.authentication.AuthenticationManager;
|
|
|
import org.springframework.security.authentication.TestingAuthenticationToken;
|
|
@@ -58,6 +59,7 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
|
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
|
|
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
|
|
import org.springframework.security.web.authentication.WebAuthenticationDetails;
|
|
|
+import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
|
|
|
import org.springframework.util.StringUtils;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
@@ -78,6 +80,7 @@ import static org.mockito.Mockito.when;
|
|
|
* @author Daniel Garnier-Moiroux
|
|
|
* @author Anoop Garlapati
|
|
|
* @author Dmitriy Dubson
|
|
|
+ * @author Greg Li
|
|
|
* @since 0.0.1
|
|
|
*/
|
|
|
public class OAuth2AuthorizationEndpointFilterTests {
|
|
@@ -263,6 +266,13 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
request -> {
|
|
|
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge");
|
|
|
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge");
|
|
|
+ String originalQueryString = request.getQueryString();
|
|
|
+ if (StringUtils.hasText(originalQueryString)) {
|
|
|
+ String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE)
|
|
|
+ .concat("=code-challenge").concat("&")
|
|
|
+ .concat(PkceParameterNames.CODE_CHALLENGE).concat("=another-code-challenge");
|
|
|
+ request.setQueryString(newQueryString);
|
|
|
+ }
|
|
|
});
|
|
|
}
|
|
|
|
|
@@ -275,6 +285,13 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
request -> {
|
|
|
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
|
|
|
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
|
|
|
+ String originalQueryString = request.getQueryString();
|
|
|
+ if (StringUtils.hasText(originalQueryString)) {
|
|
|
+ String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE_METHOD)
|
|
|
+ .concat("=S256").concat("&")
|
|
|
+ .concat(PkceParameterNames.CODE_CHALLENGE_METHOD).concat("=S256");
|
|
|
+ request.setQueryString(newQueryString);
|
|
|
+ }
|
|
|
});
|
|
|
}
|
|
|
|
|
@@ -557,6 +574,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
|
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
|
|
request.addParameter("custom-param", "custom-value-1", "custom-value-2");
|
|
|
+ String newQueryString = request.getQueryString().concat("custom-param")
|
|
|
+ .concat("=custom-value-1").concat("&")
|
|
|
+ .concat("custom-param").concat("=custom-value-2");
|
|
|
+ request.setQueryString(newQueryString);
|
|
|
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
@@ -646,17 +667,15 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
|
private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
|
|
|
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
- request.setServletPath(requestUri);
|
|
|
+ MockHttpServletRequest request = MockMvcRequestBuilders.get(requestUri)
|
|
|
+ .queryParam(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue())
|
|
|
+ .queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
|
|
|
+ .queryParam(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next())
|
|
|
+ .queryParam(OAuth2ParameterNames.SCOPE,
|
|
|
+ StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "))
|
|
|
+ .queryParam(OAuth2ParameterNames.STATE, "state")
|
|
|
+ .buildRequest(new MockServletContext());
|
|
|
request.setRemoteAddr(REMOTE_ADDRESS);
|
|
|
-
|
|
|
- request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
|
|
|
- request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
|
|
|
- request.addParameter(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
|
|
|
- request.addParameter(OAuth2ParameterNames.SCOPE,
|
|
|
- StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
-
|
|
|
return request;
|
|
|
}
|
|
|
|