|
@@ -44,9 +44,12 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
|
|
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
|
|
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|
|
import org.springframework.security.oauth2.core.user.OAuth2User;
|
|
|
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
|
|
+import org.springframework.security.web.util.UrlUtils;
|
|
|
+import org.springframework.web.util.UriComponentsBuilder;
|
|
|
|
|
|
import javax.servlet.FilterChain;
|
|
|
import javax.servlet.http.HttpServletRequest;
|
|
@@ -64,7 +67,7 @@ import static org.mockito.Mockito.*;
|
|
|
* @author Joe Grandja
|
|
|
*/
|
|
|
@PowerMockIgnore("javax.security.*")
|
|
|
-@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class, OAuth2LoginAuthenticationFilter.class})
|
|
|
+@PrepareForTest({OAuth2AuthorizationExchange.class, OAuth2LoginAuthenticationFilter.class})
|
|
|
@RunWith(PowerMockRunner.class)
|
|
|
public class OAuth2LoginAuthenticationFilterTests {
|
|
|
private ClientRegistration registration1;
|
|
@@ -298,16 +301,137 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|
|
verify(this.filter).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
}
|
|
|
|
|
|
+ // gh-5890
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenAuthorizationResponseHasDefaultPort80ThenRedirectUriMatchingExcludesPort() throws Exception {
|
|
|
+ String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
|
|
|
+ String state = "state";
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setScheme("http");
|
|
|
+ request.setServerName("example.com");
|
|
|
+ request.setServerPort(80);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration2, state);
|
|
|
+ this.setUpAuthenticationResult(this.registration2);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
|
|
|
+ verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture());
|
|
|
+
|
|
|
+ OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue();
|
|
|
+ OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest();
|
|
|
+ OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse();
|
|
|
+
|
|
|
+ String expectedRedirectUri = "http://example.com/login/oauth2/code/registration-id-2";
|
|
|
+ assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
|
|
+ assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
|
|
+ }
|
|
|
+
|
|
|
+ // gh-5890
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenAuthorizationResponseHasDefaultPort443ThenRedirectUriMatchingExcludesPort() throws Exception {
|
|
|
+ String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
|
|
|
+ String state = "state";
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setScheme("https");
|
|
|
+ request.setServerName("example.com");
|
|
|
+ request.setServerPort(443);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration2, state);
|
|
|
+ this.setUpAuthenticationResult(this.registration2);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
|
|
|
+ verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture());
|
|
|
+
|
|
|
+ OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue();
|
|
|
+ OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest();
|
|
|
+ OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse();
|
|
|
+
|
|
|
+ String expectedRedirectUri = "https://example.com/login/oauth2/code/registration-id-2";
|
|
|
+ assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
|
|
+ assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
|
|
+ }
|
|
|
+
|
|
|
+ // gh-5890
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenAuthorizationResponseHasNonDefaultPortThenRedirectUriMatchingIncludesPort() throws Exception {
|
|
|
+ String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
|
|
|
+ String state = "state";
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setScheme("https");
|
|
|
+ request.setServerName("example.com");
|
|
|
+ request.setServerPort(9090);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration2, state);
|
|
|
+ this.setUpAuthenticationResult(this.registration2);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
|
|
|
+ verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture());
|
|
|
+
|
|
|
+ OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue();
|
|
|
+ OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest();
|
|
|
+ OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse();
|
|
|
+
|
|
|
+ String expectedRedirectUri = "https://example.com:9090/login/oauth2/code/registration-id-2";
|
|
|
+ assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
|
|
+ assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri);
|
|
|
+ }
|
|
|
+
|
|
|
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
|
|
ClientRegistration registration, String state) {
|
|
|
- OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
|
|
- when(authorizationRequest.getState()).thenReturn(state);
|
|
|
Map<String, Object> additionalParameters = new HashMap<>();
|
|
|
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
|
|
- when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);
|
|
|
+ OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
|
|
|
+ .authorizationUri(registration.getProviderDetails().getAuthorizationUri())
|
|
|
+ .clientId(registration.getClientId())
|
|
|
+ .redirectUri(expandRedirectUri(request, registration))
|
|
|
+ .scopes(registration.getScopes())
|
|
|
+ .state(state)
|
|
|
+ .additionalParameters(additionalParameters)
|
|
|
+ .build();
|
|
|
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
|
|
|
}
|
|
|
|
|
|
+ private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) {
|
|
|
+ String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
|
|
|
+ .replaceQuery(null)
|
|
|
+ .replacePath(request.getContextPath())
|
|
|
+ .build()
|
|
|
+ .toUriString();
|
|
|
+
|
|
|
+ Map<String, String> uriVariables = new HashMap<>();
|
|
|
+ uriVariables.put("baseUrl", baseUrl);
|
|
|
+ uriVariables.put("action", "login");
|
|
|
+ uriVariables.put("registrationId", clientRegistration.getRegistrationId());
|
|
|
+
|
|
|
+ return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
|
|
|
+ .buildAndExpand(uriVariables)
|
|
|
+ .toUriString();
|
|
|
+ }
|
|
|
+
|
|
|
private void setUpAuthenticationResult(ClientRegistration registration) {
|
|
|
OAuth2User user = mock(OAuth2User.class);
|
|
|
when(user.getName()).thenReturn(this.principalName1);
|