|
@@ -37,13 +37,13 @@ import javax.servlet.ServletResponse;
|
|
|
import javax.servlet.http.HttpServletRequest;
|
|
|
import javax.servlet.http.HttpServletResponse;
|
|
|
import java.lang.reflect.Constructor;
|
|
|
+import java.util.Collections;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.Map;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|
|
import static org.mockito.Mockito.*;
|
|
|
-import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
|
|
|
|
|
|
/**
|
|
|
* Tests for {@link OAuth2AuthorizationRequestRedirectFilter}.
|
|
@@ -274,7 +274,6 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
|
|
|
|
|
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
|
|
|
verify(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
- assertThat(request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME)).isNull();
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -288,7 +287,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
|
|
doThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId()))
|
|
|
.when(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
|
|
|
|
|
|
- OAuth2AuthorizationRequestResolver resolver = req -> null;
|
|
|
+ OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
|
|
|
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
@@ -315,14 +314,13 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
|
|
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
|
|
|
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
|
|
|
|
|
|
- OAuth2AuthorizationRequestResolver resolver = req -> {
|
|
|
- OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req);
|
|
|
- Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
|
|
|
- additionalParameters.put("idp", req.getParameter("idp"));
|
|
|
- return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
|
|
|
- .additionalParameters(additionalParameters)
|
|
|
- .build();
|
|
|
- };
|
|
|
+ OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
|
|
|
+ OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest
|
|
|
+ .from(defaultAuthorizationRequestResolver.resolve(request))
|
|
|
+ .additionalParameters(
|
|
|
+ Collections.singletonMap("idp", request.getParameter("idp")))
|
|
|
+ .build();
|
|
|
+ when(resolver.resolve(any())).thenReturn(result);
|
|
|
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
@@ -347,19 +345,23 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
|
|
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
|
|
|
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
|
|
|
|
|
|
- OAuth2AuthorizationRequestResolver resolver = req -> {
|
|
|
- OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req);
|
|
|
- Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
|
|
|
- additionalParameters.put(loginHintParamName, req.getParameter(loginHintParamName));
|
|
|
- String customAuthorizationRequestUri = UriComponentsBuilder
|
|
|
- .fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri())
|
|
|
- .queryParam(loginHintParamName, additionalParameters.get(loginHintParamName))
|
|
|
- .build(true).toUriString();
|
|
|
- return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
|
|
|
- .additionalParameters(additionalParameters)
|
|
|
- .authorizationRequestUri(customAuthorizationRequestUri)
|
|
|
- .build();
|
|
|
- };
|
|
|
+ OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
|
|
|
+
|
|
|
+ OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
|
|
|
+ Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
|
|
|
+ additionalParameters.put(loginHintParamName, request.getParameter(loginHintParamName));
|
|
|
+ String customAuthorizationRequestUri = UriComponentsBuilder
|
|
|
+ .fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri())
|
|
|
+ .queryParam(loginHintParamName, additionalParameters.get(loginHintParamName))
|
|
|
+ .build(true).toUriString();
|
|
|
+ OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest
|
|
|
+ .from(defaultAuthorizationRequestResolver.resolve(request))
|
|
|
+ .additionalParameters(
|
|
|
+ Collections.singletonMap("idp", request.getParameter("idp")))
|
|
|
+ .authorizationRequestUri(customAuthorizationRequestUri)
|
|
|
+ .build();
|
|
|
+ when(resolver.resolve(any())).thenReturn(result);
|
|
|
+
|
|
|
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|