Browse Source

Ensure consistent matching of redirect_uri

Fixes gh-5756
Joe Grandja 7 năm trước cách đây
mục cha
commit
9a49795abc

+ 5 - 11
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.web.DefaultRedirectStrategy;
 import org.springframework.security.web.RedirectStrategy;
+import org.springframework.security.web.util.UrlUtils;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.OncePerRequestFilter;
@@ -183,16 +184,9 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 	}
 
 	private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) {
-		int port = request.getServerPort();
-		if (("http".equals(request.getScheme()) && port == 80) || ("https".equals(request.getScheme()) && port == 443)) {
-			port = -1;		// Removes the port in UriComponentsBuilder
-		}
-
-		String baseUrl = UriComponentsBuilder.newInstance()
-			.scheme(request.getScheme())
-			.host(request.getServerName())
-			.port(port)
-			.path(request.getContextPath())
+		String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
+			.replaceQuery(null)
+			.replacePath(request.getContextPath())
 			.build()
 			.toUriString();
 

+ 6 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java

@@ -34,8 +34,10 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
 import org.springframework.security.web.context.SecurityContextRepository;
+import org.springframework.security.web.util.UrlUtils;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
+import org.springframework.web.util.UriComponentsBuilder;
 
 import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
@@ -192,7 +194,10 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
 		String code = request.getParameter(OAuth2ParameterNames.CODE);
 		String errorCode = request.getParameter(OAuth2ParameterNames.ERROR);
 		String state = request.getParameter(OAuth2ParameterNames.STATE);
-		String redirectUri = request.getRequestURL().toString();
+		String redirectUri = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
+			.replaceQuery(null)
+			.build()
+			.toUriString();
 
 		if (StringUtils.hasText(code)) {
 			return OAuth2AuthorizationResponse.success(code)

+ 124 - 3
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java

@@ -43,9 +43,12 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.util.UrlUtils;
+import org.springframework.web.util.UriComponentsBuilder;
 
 import javax.servlet.FilterChain;
 import javax.servlet.http.HttpServletRequest;
@@ -64,7 +67,7 @@ import static org.powermock.api.mockito.PowerMockito.verifyPrivate;
  * @author Joe Grandja
  */
 @PowerMockIgnore("javax.security.*")
-@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class, OAuth2LoginAuthenticationFilter.class})
+@PrepareForTest({OAuth2AuthorizationExchange.class, OAuth2LoginAuthenticationFilter.class})
 @RunWith(PowerMockRunner.class)
 public class OAuth2LoginAuthenticationFilterTests {
 	private ClientRegistration registration1;
@@ -322,15 +325,133 @@ public class OAuth2LoginAuthenticationFilterTests {
 		}
 	}
 
+	// gh-5756
+	@Test
+	public void doFilterWhenAuthorizationResponseHasDefaultPort80ThenRedirectUriMatchingExcludesPort() throws Exception {
+		String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setScheme("http");
+		request.setServerName("example.com");
+		request.setServerPort(80);
+		request.setServletPath(requestUri);
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(OAuth2ParameterNames.STATE, "state");
+
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.setUpAuthorizationRequest(request, response, this.registration2);
+		this.setUpAuthenticationResult(this.registration2);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
+		verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture());
+
+		OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue();
+		OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest();
+		OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse();
+
+		String expectedRedirectUri = "http://example.com/login/oauth2/code/registration-2";
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri);
+		assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri);
+	}
+
+	// gh-5756
+	@Test
+	public void doFilterWhenAuthorizationResponseHasDefaultPort443ThenRedirectUriMatchingExcludesPort() throws Exception {
+		String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setScheme("https");
+		request.setServerName("example.com");
+		request.setServerPort(443);
+		request.setServletPath(requestUri);
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(OAuth2ParameterNames.STATE, "state");
+
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.setUpAuthorizationRequest(request, response, this.registration2);
+		this.setUpAuthenticationResult(this.registration2);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
+		verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture());
+
+		OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue();
+		OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest();
+		OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse();
+
+		String expectedRedirectUri = "https://example.com/login/oauth2/code/registration-2";
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri);
+		assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri);
+	}
+
+	// gh-5756
+	@Test
+	public void doFilterWhenAuthorizationResponseHasNonDefaultPortThenRedirectUriMatchingIncludesPort() throws Exception {
+		String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setScheme("https");
+		request.setServerName("example.com");
+		request.setServerPort(9090);
+		request.setServletPath(requestUri);
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(OAuth2ParameterNames.STATE, "state");
+
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.setUpAuthorizationRequest(request, response, this.registration2);
+		this.setUpAuthenticationResult(this.registration2);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
+		verify(this.authenticationManager).authenticate(authenticationArgCaptor.capture());
+
+		OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) authenticationArgCaptor.getValue();
+		OAuth2AuthorizationRequest authorizationRequest = authentication.getAuthorizationExchange().getAuthorizationRequest();
+		OAuth2AuthorizationResponse authorizationResponse = authentication.getAuthorizationExchange().getAuthorizationResponse();
+
+		String expectedRedirectUri = "https://example.com:9090/login/oauth2/code/registration-2";
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(expectedRedirectUri);
+		assertThat(authorizationResponse.getRedirectUri()).isEqualTo(expectedRedirectUri);
+	}
+
 	private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
 											ClientRegistration registration) {
-		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
 		Map<String, Object> additionalParameters = new HashMap<>();
 		additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
-		when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+			.authorizationUri(registration.getProviderDetails().getAuthorizationUri())
+			.clientId(registration.getClientId())
+			.redirectUri(expandRedirectUri(request, registration))
+			.scopes(registration.getScopes())
+			.state("state")
+			.additionalParameters(additionalParameters)
+			.build();
 		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
 	}
 
+	private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) {
+		String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
+			.replaceQuery(null)
+			.replacePath(request.getContextPath())
+			.build()
+			.toUriString();
+
+		Map<String, String> uriVariables = new HashMap<>();
+		uriVariables.put("baseUrl", baseUrl);
+		uriVariables.put("registrationId", clientRegistration.getRegistrationId());
+
+		return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
+			.buildAndExpand(uriVariables)
+			.toUriString();
+	}
+
 	private void setUpAuthenticationResult(ClientRegistration registration) {
 		OAuth2User user = mock(OAuth2User.class);
 		when(user.getName()).thenReturn(this.principalName1);