浏览代码

OAuth2AuthorizationResponseUtils uses MultiMap

Fixes: gh-5331
Rob Winch 7 年之前
父节点
当前提交
c696640276

+ 13 - 8
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java

@@ -15,6 +15,13 @@
  */
 package org.springframework.security.oauth2.client.web;
 
+import java.io.IOException;
+
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
@@ -39,16 +46,11 @@ import org.springframework.security.web.savedrequest.RequestCache;
 import org.springframework.security.web.savedrequest.SavedRequest;
 import org.springframework.security.web.util.UrlUtils;
 import org.springframework.util.Assert;
+import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
 import org.springframework.web.util.UriComponentsBuilder;
 
-import javax.servlet.FilterChain;
-import javax.servlet.ServletException;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import java.io.IOException;
-
 /**
  * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
  * which handles the processing of the OAuth 2.0 Authorization Response.
@@ -147,8 +149,9 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 		}
 		String requestUrl = UrlUtils.buildFullRequestUrl(request.getScheme(), request.getServerName(),
 				request.getServerPort(), request.getRequestURI(), null);
+		MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
 		if (requestUrl.equals(authorizationRequest.getRedirectUri()) &&
-				OAuth2AuthorizationResponseUtils.isAuthorizationResponse(request)) {
+				OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) {
 			return true;
 		}
 		return false;
@@ -162,7 +165,9 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 		String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID);
 		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
 
-		OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(request);
+		MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
+		String redirectUri = request.getRequestURL().toString();
+		OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, redirectUri);
 
 		OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken(
 			clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse));

+ 29 - 16
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationResponseUtils.java

@@ -15,12 +15,14 @@
  */
 package org.springframework.security.oauth2.client.web;
 
+import java.util.Map;
+
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 
-import javax.servlet.http.HttpServletRequest;
-
 /**
  * Utility methods for an OAuth 2.0 Authorization Response.
  *
@@ -33,25 +35,36 @@ final class OAuth2AuthorizationResponseUtils {
 	private OAuth2AuthorizationResponseUtils() {
 	}
 
-	static boolean isAuthorizationResponse(HttpServletRequest request) {
+	static MultiValueMap<String, String> toMultiMap(Map<String, String[]> map) {
+		MultiValueMap<String, String> params = new LinkedMultiValueMap<>(map.size());
+		map.forEach((key, values) -> {
+			if (values.length > 0) {
+				for (String value : values) {
+					params.add(key, value);
+				}
+			}
+		});
+		return params;
+	}
+
+	static boolean isAuthorizationResponse(MultiValueMap<String, String> request) {
 		return isAuthorizationResponseSuccess(request) || isAuthorizationResponseError(request);
 	}
 
-	static boolean isAuthorizationResponseSuccess(HttpServletRequest request) {
-		return StringUtils.hasText(request.getParameter(OAuth2ParameterNames.CODE)) &&
-			StringUtils.hasText(request.getParameter(OAuth2ParameterNames.STATE));
+	static boolean isAuthorizationResponseSuccess(MultiValueMap<String, String> request) {
+		return StringUtils.hasText(request.getFirst(OAuth2ParameterNames.CODE)) &&
+			StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE));
 	}
 
-	static boolean isAuthorizationResponseError(HttpServletRequest request) {
-		return StringUtils.hasText(request.getParameter(OAuth2ParameterNames.ERROR)) &&
-			StringUtils.hasText(request.getParameter(OAuth2ParameterNames.STATE));
+	static boolean isAuthorizationResponseError(MultiValueMap<String, String> request) {
+		return StringUtils.hasText(request.getFirst(OAuth2ParameterNames.ERROR)) &&
+			StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE));
 	}
 
-	static OAuth2AuthorizationResponse convert(HttpServletRequest request) {
-		String code = request.getParameter(OAuth2ParameterNames.CODE);
-		String errorCode = request.getParameter(OAuth2ParameterNames.ERROR);
-		String state = request.getParameter(OAuth2ParameterNames.STATE);
-		String redirectUri = request.getRequestURL().toString();
+	static OAuth2AuthorizationResponse convert(MultiValueMap<String, String> request, String redirectUri) {
+		String code = request.getFirst(OAuth2ParameterNames.CODE);
+		String errorCode = request.getFirst(OAuth2ParameterNames.ERROR);
+		String state = request.getFirst(OAuth2ParameterNames.STATE);
 
 		if (StringUtils.hasText(code)) {
 			return OAuth2AuthorizationResponse.success(code)
@@ -59,8 +72,8 @@ final class OAuth2AuthorizationResponseUtils {
 				.state(state)
 				.build();
 		} else {
-			String errorDescription = request.getParameter(OAuth2ParameterNames.ERROR_DESCRIPTION);
-			String errorUri = request.getParameter(OAuth2ParameterNames.ERROR_URI);
+			String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION);
+			String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI);
 			return OAuth2AuthorizationResponse.error(errorCode)
 				.redirectUri(redirectUri)
 				.errorDescription(errorDescription)

+ 11 - 7
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java

@@ -15,6 +15,12 @@
  */
 package org.springframework.security.oauth2.client.web;
 
+import java.io.IOException;
+
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
@@ -35,11 +41,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.util.Assert;
-
-import javax.servlet.ServletException;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import java.io.IOException;
+import org.springframework.util.MultiValueMap;
 
 /**
  * An implementation of an {@link AbstractAuthenticationProcessingFilter} for OAuth 2.0 Login.
@@ -134,7 +136,8 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
 	public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
 			throws AuthenticationException, IOException, ServletException {
 
-		if (!OAuth2AuthorizationResponseUtils.isAuthorizationResponse(request)) {
+		MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
+		if (!OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params)) {
 			OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST);
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 		}
@@ -152,7 +155,8 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
 					"Client Registration not found with Id: " + registrationId, null);
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 		}
-		OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(request);
+		String redirectUri = request.getRequestURL().toString();
+		OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, redirectUri);
 
 		OAuth2LoginAuthenticationToken authenticationRequest = new OAuth2LoginAuthenticationToken(
 				clientRegistration, new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse));