Преглед на файлове

Use param matching for Authorization Response

Fixes gh-4576
Joe Grandja преди 8 години
родител
ревизия
9a8ddebc94

+ 1 - 2
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java

@@ -36,7 +36,6 @@ import org.springframework.security.oauth2.core.AccessToken;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.oidc.client.user.OidcUserService;
 import org.springframework.security.web.util.matcher.RequestMatcher;
-import org.springframework.security.web.util.matcher.RequestVariablesExtractor;
 import org.springframework.util.Assert;
 
 import java.net.URI;
@@ -48,7 +47,7 @@ import java.util.Map;
 /**
  * @author Joe Grandja
  */
-final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecurityBuilder<H>, R extends RequestMatcher & RequestVariablesExtractor> extends
+final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecurityBuilder<H>, R extends RequestMatcher> extends
 		AbstractAuthenticationFilterConfigurer<H, AuthorizationCodeAuthenticationFilterConfigurer<H, R>, AuthorizationCodeAuthenticationProcessingFilter> {
 
 	private R authorizationResponseMatcher;

+ 1 - 1
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java

@@ -166,7 +166,7 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
 		private RedirectionEndpointConfig() {
 		}
 
-		public <R extends RequestMatcher & RequestVariablesExtractor> RedirectionEndpointConfig requestMatcher(R authorizationResponseMatcher) {
+		public <R extends RequestMatcher> RedirectionEndpointConfig requestMatcher(R authorizationResponseMatcher) {
 			Assert.notNull(authorizationResponseMatcher, "authorizationResponseMatcher cannot be null");
 			OAuth2LoginConfigurer.this.authorizationCodeAuthenticationFilterConfigurer.authorizationResponseMatcher(authorizationResponseMatcher);
 			return this;

+ 23 - 18
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilter.java

@@ -37,9 +37,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
 import org.springframework.security.oauth2.core.endpoint.TokenResponseAttributes;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
-import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
-import org.springframework.security.web.util.matcher.RequestVariablesExtractor;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 
@@ -111,20 +109,18 @@ import java.io.IOException;
  */
 public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAuthenticationProcessingFilter {
 	public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code";
-	public static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId";
-	public static final String DEFAULT_AUTHORIZATION_RESPONSE_URI = DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}";
 	private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
 	private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter";
 	private static final String INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE = "invalid_redirect_uri_parameter";
 	private final ErrorResponseAttributesConverter errorResponseConverter = new ErrorResponseAttributesConverter();
 	private final AuthorizationCodeAuthorizationResponseAttributesConverter authorizationCodeResponseConverter =
 		new AuthorizationCodeAuthorizationResponseAttributesConverter();
-	private RequestMatcher authorizationResponseMatcher = new AntPathRequestMatcher(DEFAULT_AUTHORIZATION_RESPONSE_URI);
 	private ClientRegistrationRepository clientRegistrationRepository;
+	private RequestMatcher authorizationResponseMatcher = new AuthorizationResponseMatcher();
 	private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
 
 	public AuthorizationCodeAuthenticationProcessingFilter() {
-		super(DEFAULT_AUTHORIZATION_RESPONSE_URI);
+		super(new AuthorizationResponseMatcher());
 	}
 
 	@Override
@@ -140,17 +136,8 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut
 		}
 
 		AuthorizationRequestAttributes matchingAuthorizationRequest = this.resolveAuthorizationRequest(request);
-
-		String registrationId = ((RequestVariablesExtractor)this.getAuthorizationResponseMatcher())
-			.extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME);
-		ClientRegistration clientRegistration = null;
-		if (!StringUtils.isEmpty(registrationId)) {
-			clientRegistration = this.getClientRegistrationRepository().findByRegistrationId(registrationId);
-		}
-		if (clientRegistration == null || !matchingAuthorizationRequest.getClientId().equals(clientRegistration.getClientId())) {
-			OAuth2Error oauth2Error = new OAuth2Error(OAuth2Error.INVALID_REQUEST_ERROR_CODE);
-			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
-		}
+		String registrationId = (String)matchingAuthorizationRequest.getAdditionalParameters().get(OAuth2Parameter.REGISTRATION_ID);
+		ClientRegistration clientRegistration = this.getClientRegistrationRepository().findByRegistrationId(registrationId);
 
 		// The clientRegistration.redirectUri may contain Uri template variables, whether it's configured by
 		// the user or configured by default. In these cases, the redirectUri will be expanded and ultimately changed
@@ -180,7 +167,7 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut
 		return this.authorizationResponseMatcher;
 	}
 
-	public final <T extends RequestMatcher & RequestVariablesExtractor> void setAuthorizationResponseMatcher(T authorizationResponseMatcher) {
+	public final <T extends RequestMatcher> void setAuthorizationResponseMatcher(T authorizationResponseMatcher) {
 		Assert.notNull(authorizationResponseMatcher, "authorizationResponseMatcher cannot be null");
 		this.authorizationResponseMatcher = authorizationResponseMatcher;
 		this.setRequiresAuthenticationRequestMatcher(authorizationResponseMatcher);
@@ -228,4 +215,22 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 		}
 	}
+
+	private static class AuthorizationResponseMatcher implements RequestMatcher {
+
+		@Override
+		public boolean matches(HttpServletRequest request) {
+			return this.successResponse(request) || this.errorResponse(request);
+		}
+
+		private boolean successResponse(HttpServletRequest request) {
+			return StringUtils.hasText(request.getParameter(OAuth2Parameter.CODE)) &&
+				StringUtils.hasText(request.getParameter(OAuth2Parameter.STATE));
+		}
+
+		private boolean errorResponse(HttpServletRequest request) {
+			return StringUtils.hasText(request.getParameter(OAuth2Parameter.ERROR)) &&
+				StringUtils.hasText(request.getParameter(OAuth2Parameter.STATE));
+		}
+	}
 }

+ 5 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeRequestRedirectFilter.java

@@ -19,6 +19,7 @@ import org.springframework.security.crypto.keygen.StringKeyGenerator;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
 import org.springframework.security.web.DefaultRedirectStrategy;
 import org.springframework.security.web.RedirectStrategy;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -122,6 +123,9 @@ public class AuthorizationCodeRequestRedirectFilter extends OncePerRequestFilter
 
 		String redirectUriStr = this.expandRedirectUri(request, clientRegistration);
 
+		Map<String,Object> additionalParameters = new HashMap<>();
+		additionalParameters.put(OAuth2Parameter.REGISTRATION_ID, clientRegistration.getRegistrationId());
+
 		AuthorizationRequestAttributes authorizationRequestAttributes =
 			AuthorizationRequestAttributes.withAuthorizationCode()
 				.clientId(clientRegistration.getClientId())
@@ -129,6 +133,7 @@ public class AuthorizationCodeRequestRedirectFilter extends OncePerRequestFilter
 				.redirectUri(redirectUriStr)
 				.scope(clientRegistration.getScope())
 				.state(this.stateGenerator.generateKey())
+				.additionalParameters(additionalParameters)
 				.build();
 
 		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request, response);

+ 6 - 3
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilterTests.java

@@ -38,9 +38,8 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand
 import javax.servlet.FilterChain;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.mockito.Matchers.any;
+import java.util.HashMap;
+import java.util.Map;
 
 /**
  * Tests {@link AuthorizationCodeAuthenticationProcessingFilter}.
@@ -233,6 +232,9 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 											ClientRegistration clientRegistration,
 											String state) {
 
+		Map<String,Object> additionalParameters = new HashMap<>();
+		additionalParameters.put(OAuth2Parameter.REGISTRATION_ID, clientRegistration.getRegistrationId());
+
 		AuthorizationRequestAttributes authorizationRequestAttributes =
 			AuthorizationRequestAttributes.withAuthorizationCode()
 				.clientId(clientRegistration.getClientId())
@@ -240,6 +242,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 				.redirectUri(clientRegistration.getRedirectUri())
 				.scope(clientRegistration.getScope())
 				.state(state)
+				.additionalParameters(additionalParameters)
 				.build();
 
 		authorizationRequestRepository.saveAuthorizationRequest(authorizationRequestAttributes, request, response);

+ 19 - 2
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/AuthorizationRequestAttributes.java

@@ -21,7 +21,9 @@ import org.springframework.util.CollectionUtils;
 
 import java.io.Serializable;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
+import java.util.Map;
 import java.util.Set;
 
 /**
@@ -43,6 +45,7 @@ public final class AuthorizationRequestAttributes implements Serializable {
 	private String redirectUri;
 	private Set<String> scope;
 	private String state;
+	private Map<String,Object> additionalParameters;
 
 	private AuthorizationRequestAttributes() {
 	}
@@ -75,6 +78,10 @@ public final class AuthorizationRequestAttributes implements Serializable {
 		return this.state;
 	}
 
+	public Map<String, Object> getAdditionalParameters() {
+		return this.additionalParameters;
+	}
+
 	public static Builder withAuthorizationCode() {
 		return new Builder(AuthorizationGrantType.AUTHORIZATION_CODE);
 	}
@@ -107,8 +114,7 @@ public final class AuthorizationRequestAttributes implements Serializable {
 		}
 
 		public Builder scope(Set<String> scope) {
-			this.authorizationRequest.scope = Collections.unmodifiableSet(
-				CollectionUtils.isEmpty(scope) ? Collections.emptySet() : new LinkedHashSet<>(scope));
+			this.authorizationRequest.scope = scope;
 			return this;
 		}
 
@@ -117,9 +123,20 @@ public final class AuthorizationRequestAttributes implements Serializable {
 			return this;
 		}
 
+		public Builder additionalParameters(Map<String,Object> additionalParameters) {
+			this.authorizationRequest.additionalParameters = additionalParameters;
+			return this;
+		}
+
 		public AuthorizationRequestAttributes build() {
 			Assert.hasText(this.authorizationRequest.clientId, "clientId cannot be empty");
 			Assert.hasText(this.authorizationRequest.authorizeUri, "authorizeUri cannot be empty");
+			this.authorizationRequest.scope = Collections.unmodifiableSet(
+				CollectionUtils.isEmpty(this.authorizationRequest.scope) ?
+					Collections.emptySet() : new LinkedHashSet<>(this.authorizationRequest.scope));
+			this.authorizationRequest.additionalParameters = Collections.unmodifiableMap(
+				CollectionUtils.isEmpty(this.authorizationRequest.additionalParameters) ?
+					Collections.emptyMap() : new LinkedHashMap<>(this.authorizationRequest.additionalParameters));
 			return this.authorizationRequest;
 		}
 	}

+ 3 - 1
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2Parameter.java

@@ -16,7 +16,7 @@
 package org.springframework.security.oauth2.core.endpoint;
 
 /**
- * Standard parameters defined in the OAuth Parameters Registry
+ * Standard and additional (custom) parameters defined in the OAuth Parameters Registry
  * and used by the authorization endpoint and token endpoint.
  *
  * @author Joe Grandja
@@ -43,4 +43,6 @@ public interface OAuth2Parameter {
 
 	String ERROR_URI = "error_uri";
 
+	String REGISTRATION_ID = "registration_id";		// Non-standard additional parameter
+
 }