|
@@ -37,7 +37,6 @@ import org.springframework.http.HttpStatus;
|
|
|
import org.springframework.http.MediaType;
|
|
|
import org.springframework.mock.web.MockHttpServletRequest;
|
|
|
import org.springframework.mock.web.MockHttpServletResponse;
|
|
|
-import org.springframework.mock.web.MockServletContext;
|
|
|
import org.springframework.security.authentication.AuthenticationDetailsSource;
|
|
|
import org.springframework.security.authentication.AuthenticationManager;
|
|
|
import org.springframework.security.authentication.TestingAuthenticationToken;
|
|
@@ -59,8 +58,8 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
|
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
|
|
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
|
|
import org.springframework.security.web.authentication.WebAuthenticationDetails;
|
|
|
-import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
|
|
|
import org.springframework.util.StringUtils;
|
|
|
+import org.springframework.web.util.UriComponentsBuilder;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|
@@ -173,7 +172,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
TestRegisteredClients.registeredClient().build(),
|
|
|
OAuth2ParameterNames.RESPONSE_TYPE,
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE));
|
|
|
+ request -> {
|
|
|
+ request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE);
|
|
|
+ updateQueryString(request);
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -182,7 +184,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
TestRegisteredClients.registeredClient().build(),
|
|
|
OAuth2ParameterNames.RESPONSE_TYPE,
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"));
|
|
|
+ request -> {
|
|
|
+ request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
|
|
|
+ updateQueryString(request);
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -191,7 +196,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
TestRegisteredClients.registeredClient().build(),
|
|
|
OAuth2ParameterNames.RESPONSE_TYPE,
|
|
|
OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE,
|
|
|
- request -> request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"));
|
|
|
+ request -> {
|
|
|
+ request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
|
|
|
+ updateQueryString(request);
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -200,7 +208,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
TestRegisteredClients.registeredClient().build(),
|
|
|
OAuth2ParameterNames.CLIENT_ID,
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID));
|
|
|
+ request -> {
|
|
|
+ request.removeParameter(OAuth2ParameterNames.CLIENT_ID);
|
|
|
+ updateQueryString(request);
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -209,7 +220,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
TestRegisteredClients.registeredClient().build(),
|
|
|
OAuth2ParameterNames.CLIENT_ID,
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"));
|
|
|
+ request -> {
|
|
|
+ request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2");
|
|
|
+ updateQueryString(request);
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -218,7 +232,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
TestRegisteredClients.registeredClient().build(),
|
|
|
OAuth2ParameterNames.REDIRECT_URI,
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"));
|
|
|
+ request -> {
|
|
|
+ request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
|
|
|
+ updateQueryString(request);
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -227,7 +244,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
TestRegisteredClients.registeredClient().build(),
|
|
|
OAuth2ParameterNames.SCOPE,
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.addParameter(OAuth2ParameterNames.SCOPE, "scope2"));
|
|
|
+ request -> {
|
|
|
+ request.addParameter(OAuth2ParameterNames.SCOPE, "scope2");
|
|
|
+ updateQueryString(request);
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -236,7 +256,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
TestRegisteredClients.registeredClient().build(),
|
|
|
OAuth2ParameterNames.STATE,
|
|
|
OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.addParameter(OAuth2ParameterNames.STATE, "state2"));
|
|
|
+ request -> {
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state2");
|
|
|
+ updateQueryString(request);
|
|
|
+ });
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -266,13 +289,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
request -> {
|
|
|
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge");
|
|
|
request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge");
|
|
|
- String originalQueryString = request.getQueryString();
|
|
|
- if (StringUtils.hasText(originalQueryString)) {
|
|
|
- String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE)
|
|
|
- .concat("=code-challenge").concat("&")
|
|
|
- .concat(PkceParameterNames.CODE_CHALLENGE).concat("=another-code-challenge");
|
|
|
- request.setQueryString(newQueryString);
|
|
|
- }
|
|
|
+ updateQueryString(request);
|
|
|
});
|
|
|
}
|
|
|
|
|
@@ -285,13 +302,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
request -> {
|
|
|
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
|
|
|
request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
|
|
|
- String originalQueryString = request.getQueryString();
|
|
|
- if (StringUtils.hasText(originalQueryString)) {
|
|
|
- String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE_METHOD)
|
|
|
- .concat("=S256").concat("&")
|
|
|
- .concat(PkceParameterNames.CODE_CHALLENGE_METHOD).concat("=S256");
|
|
|
- request.setQueryString(newQueryString);
|
|
|
- }
|
|
|
+ updateQueryString(request);
|
|
|
});
|
|
|
}
|
|
|
|
|
@@ -574,10 +585,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
|
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
|
|
request.addParameter("custom-param", "custom-value-1", "custom-value-2");
|
|
|
- String newQueryString = request.getQueryString().concat("custom-param")
|
|
|
- .concat("=custom-value-1").concat("&")
|
|
|
- .concat("custom-param").concat("=custom-value-2");
|
|
|
- request.setQueryString(newQueryString);
|
|
|
+ updateQueryString(request);
|
|
|
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
@@ -623,6 +631,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
|
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
|
|
request.setMethod("POST"); // OpenID Connect supports POST method
|
|
|
+ request.setQueryString(null);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
@@ -667,15 +676,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
|
private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
|
|
|
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
|
|
|
- MockHttpServletRequest request = MockMvcRequestBuilders.get(requestUri)
|
|
|
- .queryParam(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue())
|
|
|
- .queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
|
|
|
- .queryParam(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next())
|
|
|
- .queryParam(OAuth2ParameterNames.SCOPE,
|
|
|
- StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "))
|
|
|
- .queryParam(OAuth2ParameterNames.STATE, "state")
|
|
|
- .buildRequest(new MockServletContext());
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
request.setRemoteAddr(REMOTE_ADDRESS);
|
|
|
+
|
|
|
+ request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
|
|
|
+ request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
|
|
|
+ request.addParameter(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
|
|
|
+ request.addParameter(OAuth2ParameterNames.SCOPE,
|
|
|
+ StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+ updateQueryString(request);
|
|
|
+
|
|
|
return request;
|
|
|
}
|
|
|
|
|
@@ -692,6 +704,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
return request;
|
|
|
}
|
|
|
|
|
|
+ private static void updateQueryString(MockHttpServletRequest request) {
|
|
|
+ UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI());
|
|
|
+ request.getParameterMap().forEach((key, values) -> {
|
|
|
+ if (values.length > 0) {
|
|
|
+ for (String value : values) {
|
|
|
+ uriBuilder.queryParam(key, value);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ });
|
|
|
+ request.setQueryString(uriBuilder.build().getQuery());
|
|
|
+ }
|
|
|
+
|
|
|
private static String scopeCheckbox(String scope) {
|
|
|
return MessageFormat.format(
|
|
|
"<input class=\"form-check-input\" type=\"checkbox\" name=\"scope\" value=\"{0}\" id=\"{0}\">",
|