Browse Source

Extract authentication logic from AuthorizationCodeAuthenticationFilter

Fixes gh-4590
Joe Grandja 8 years ago
parent
commit
97c938e7f3

+ 24 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeAuthenticationProvider.java

@@ -21,6 +21,9 @@ import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.oauth2.client.token.InMemoryAccessTokenRepository;
 import org.springframework.security.oauth2.client.token.SecurityTokenRepository;
 import org.springframework.security.oauth2.core.AccessToken;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.endpoint.AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.AuthorizationResponse;
 import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken;
 import org.springframework.util.Assert;
 
@@ -32,7 +35,7 @@ import org.springframework.util.Assert;
  *
  * <p>
  * The {@link AuthorizationCodeAuthenticationProvider} uses an {@link AuthorizationGrantAuthenticator}
- * to authenticate the {@link AuthorizationCodeAuthenticationToken#getAuthorizationCode()} and ultimately
+ * to authenticate the <i>authorization code</i> credential and ultimately
  * return an <i>&quot;Authorized Client&quot;</i> as an {@link OAuth2ClientAuthenticationToken}.
  *
  * @author Joe Grandja
@@ -49,6 +52,8 @@ import org.springframework.util.Assert;
  * @see <a target="_blank" href="http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse">Section 3.1.3.3 OpenID Connect Token Response</a>
  */
 public class AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
+	private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter";
+	private static final String INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE = "invalid_redirect_uri_parameter";
 	private final AuthorizationGrantAuthenticator<AuthorizationCodeAuthenticationToken> authorizationCodeAuthenticator;
 	private SecurityTokenRepository<AccessToken> accessTokenRepository = new InMemoryAccessTokenRepository();
 
@@ -64,6 +69,24 @@ public class AuthorizationCodeAuthenticationProvider implements AuthenticationPr
 		AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
 				(AuthorizationCodeAuthenticationToken) authentication;
 
+		AuthorizationRequest authorizationRequest = authorizationCodeAuthentication.getAuthorizationRequest();
+		AuthorizationResponse authorizationResponse = authorizationCodeAuthentication.getAuthorizationResponse();
+
+		if (authorizationResponse.statusError()) {
+			throw new OAuth2AuthenticationException(
+				authorizationResponse.getError(), authorizationResponse.getError().toString());
+		}
+
+		if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
+			OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+		}
+
+		if (!authorizationResponse.getRedirectUri().equals(authorizationRequest.getRedirectUri())) {
+			OAuth2Error oauth2Error = new OAuth2Error(INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+		}
+
 		OAuth2ClientAuthenticationToken oauth2ClientAuthentication =
 			this.authorizationCodeAuthenticator.authenticate(authorizationCodeAuthentication);
 

+ 16 - 12
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeAuthenticationToken.java

@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client.authentication;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.endpoint.AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.AuthorizationResponse;
 import org.springframework.util.Assert;
 
 /**
@@ -28,38 +29,37 @@ import org.springframework.util.Assert;
  * @since 5.0
  * @see AuthorizationGrantAuthenticationToken
  * @see ClientRegistration
+ * @see AuthorizationRequest
+ * @see AuthorizationResponse
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.3.1">Section 1.3.1 Authorization Code Grant</a>
  */
 public class AuthorizationCodeAuthenticationToken extends AuthorizationGrantAuthenticationToken {
-	private final String authorizationCode;
 	private final ClientRegistration clientRegistration;
 	private final AuthorizationRequest authorizationRequest;
+	private final AuthorizationResponse authorizationResponse;
+
+	public AuthorizationCodeAuthenticationToken(ClientRegistration clientRegistration,
+												AuthorizationRequest authorizationRequest,
+												AuthorizationResponse authorizationResponse) {
 
-	public AuthorizationCodeAuthenticationToken(String authorizationCode,
-												ClientRegistration clientRegistration,
-												AuthorizationRequest authorizationRequest) {
 		super(AuthorizationGrantType.AUTHORIZATION_CODE);
-		Assert.hasText(authorizationCode, "authorizationCode cannot be empty");
 		Assert.notNull(clientRegistration, "clientRegistration cannot be null");
 		Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
-		this.authorizationCode = authorizationCode;
+		Assert.notNull(authorizationResponse, "authorizationResponse cannot be null");
 		this.clientRegistration = clientRegistration;
 		this.authorizationRequest = authorizationRequest;
+		this.authorizationResponse = authorizationResponse;
 		this.setAuthenticated(false);
 	}
 
 	@Override
 	public Object getPrincipal() {
-		return this.getClientRegistration().getClientId();
+		return "";
 	}
 
 	@Override
 	public Object getCredentials() {
-		return this.getAuthorizationCode();
-	}
-
-	public String getAuthorizationCode() {
-		return this.authorizationCode;
+		return "";
 	}
 
 	public ClientRegistration getClientRegistration() {
@@ -69,4 +69,8 @@ public class AuthorizationCodeAuthenticationToken extends AuthorizationGrantAuth
 	public AuthorizationRequest getAuthorizationRequest() {
 		return this.authorizationRequest;
 	}
+
+	public AuthorizationResponse getAuthorizationResponse() {
+		return this.authorizationResponse;
+	}
 }

+ 11 - 38
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilter.java

@@ -82,8 +82,6 @@ import java.io.IOException;
 public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
 	public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code";
 	private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
-	private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter";
-	private static final String INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE = "invalid_redirect_uri_parameter";
 	private final AuthorizationResponseConverter authorizationResponseConverter = new AuthorizationResponseConverter();
 	private ClientRegistrationRepository clientRegistrationRepository;
 	private RequestMatcher authorizationResponseMatcher = new AuthorizationResponseMatcher();
@@ -98,16 +96,16 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
 	public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
 			throws AuthenticationException, IOException, ServletException {
 
-		AuthorizationResponse authorizationResponse = this.authorizationResponseConverter.apply(request);
-
-		if (authorizationResponse.statusError()) {
-			this.getAuthorizationRequestRepository().removeAuthorizationRequest(request);
-			throw new OAuth2AuthenticationException(
-				authorizationResponse.getError(), authorizationResponse.getError().toString());
+		AuthorizationRequest authorizationRequest = this.getAuthorizationRequestRepository().loadAuthorizationRequest(request);
+		if (authorizationRequest == null) {
+			OAuth2Error oauth2Error = new OAuth2Error(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 		}
+		this.getAuthorizationRequestRepository().removeAuthorizationRequest(request);
 
-		AuthorizationRequest matchingAuthorizationRequest = this.resolveAuthorizationRequest(request);
-		String registrationId = (String)matchingAuthorizationRequest.getAdditionalParameters().get(OAuth2Parameter.REGISTRATION_ID);
+		AuthorizationResponse authorizationResponse = this.authorizationResponseConverter.apply(request);
+
+		String registrationId = (String)authorizationRequest.getAdditionalParameters().get(OAuth2Parameter.REGISTRATION_ID);
 		ClientRegistration clientRegistration = this.getClientRegistrationRepository().findByRegistrationId(registrationId);
 
 		// The clientRegistration.redirectUri may contain Uri template variables, whether it's configured by
@@ -116,13 +114,13 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
 		// The resulting redirectUri used for the authorization request and saved within the AuthorizationRequestRepository
 		// MUST BE the same one used to complete the authorization code flow.
 		// Therefore, we'll create a copy of the clientRegistration and override the redirectUri
-		// with the one contained in matchingAuthorizationRequest.
+		// with the one contained in authorizationRequest.
 		clientRegistration = new ClientRegistration.Builder(clientRegistration)
-			.redirectUri(matchingAuthorizationRequest.getRedirectUri())
+			.redirectUri(authorizationRequest.getRedirectUri())
 			.build();
 
 		AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = new AuthorizationCodeAuthenticationToken(
-				authorizationResponse.getCode(), clientRegistration, matchingAuthorizationRequest);
+				clientRegistration, authorizationRequest, authorizationResponse);
 		authorizationCodeAuthentication.setDetails(this.authenticationDetailsSource.buildDetails(request));
 
 		OAuth2ClientAuthenticationToken oauth2ClientAuthentication =
@@ -172,31 +170,6 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
 		this.authorizationRequestRepository = authorizationRequestRepository;
 	}
 
-	private AuthorizationRequest resolveAuthorizationRequest(HttpServletRequest request) {
-		AuthorizationRequest authorizationRequest =
-				this.getAuthorizationRequestRepository().loadAuthorizationRequest(request);
-		if (authorizationRequest == null) {
-			OAuth2Error oauth2Error = new OAuth2Error(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE);
-			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
-		}
-		this.getAuthorizationRequestRepository().removeAuthorizationRequest(request);
-		this.assertMatchingAuthorizationRequest(request, authorizationRequest);
-		return authorizationRequest;
-	}
-
-	private void assertMatchingAuthorizationRequest(HttpServletRequest request, AuthorizationRequest authorizationRequest) {
-		String state = request.getParameter(OAuth2Parameter.STATE);
-		if (!authorizationRequest.getState().equals(state)) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
-			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
-		}
-
-		if (!request.getRequestURL().toString().equals(authorizationRequest.getRedirectUri())) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_REDIRECT_URI_PARAMETER_ERROR_CODE);
-			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
-		}
-	}
-
 	private boolean authenticated() {
 		Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
 		return currentAuthentication != null &&

+ 3 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/converter/AuthorizationResponseConverter.java

@@ -38,15 +38,18 @@ public final class AuthorizationResponseConverter implements Function<HttpServle
 		String code = request.getParameter(OAuth2Parameter.CODE);
 		String errorCode = request.getParameter(OAuth2Parameter.ERROR);
 		String state = request.getParameter(OAuth2Parameter.STATE);
+		String redirectUri = request.getRequestURL().toString();
 
 		if (StringUtils.hasText(code)) {
 			return AuthorizationResponse.success(code)
+				.redirectUri(redirectUri)
 				.state(state)
 				.build();
 		} else if (StringUtils.hasText(errorCode)) {
 			String description = request.getParameter(OAuth2Parameter.ERROR_DESCRIPTION);
 			String uri = request.getParameter(OAuth2Parameter.ERROR_URI);
 			return AuthorizationResponse.error(errorCode)
+				.redirectUri(redirectUri)
 				.errorDescription(description)
 				.errorUri(uri)
 				.state(state)

+ 2 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/nimbus/NimbusAuthorizationCodeTokenExchanger.java

@@ -76,7 +76,8 @@ public class NimbusAuthorizationCodeTokenExchanger implements AuthorizationGrant
 		ClientRegistration clientRegistration = authorizationCodeAuthenticationToken.getClientRegistration();
 
 		// Build the authorization code grant request for the token endpoint
-		AuthorizationCode authorizationCode = new AuthorizationCode(authorizationCodeAuthenticationToken.getAuthorizationCode());
+		AuthorizationCode authorizationCode = new AuthorizationCode(
+			authorizationCodeAuthenticationToken.getAuthorizationResponse().getCode());
 		URI redirectUri = this.toURI(clientRegistration.getRedirectUri());
 		AuthorizationGrant authorizationCodeGrant = new AuthorizationCodeGrant(authorizationCode, redirectUri);
 		URI tokenUri = this.toURI(clientRegistration.getProviderDetails().getTokenUri());

+ 0 - 49
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilterTests.java

@@ -152,55 +152,6 @@ public class AuthorizationCodeAuthenticationFilterTests {
 		verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(filter, failureHandler, "authorization_request_not_found");
 	}
 
-	@Test
-	public void doFilterWhenAuthorizationCodeSuccessResponseWithInvalidStateParamThenThrowOAuth2AuthenticationExceptionInvalidStateParameter() throws Exception {
-		ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
-
-		AuthorizationCodeAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
-		AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
-		filter.setAuthenticationFailureHandler(failureHandler);
-		AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
-		filter.setAuthorizationRequestRepository(authorizationRequestRepository);
-
-		MockHttpServletRequest request = this.setupRequest(clientRegistration);
-		String authCode = "some code";
-		String state = "some other state";
-		request.addParameter(OAuth2Parameter.CODE, authCode);
-		request.addParameter(OAuth2Parameter.STATE, state);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, "some state");
-		FilterChain filterChain = mock(FilterChain.class);
-
-		filter.doFilter(request, response, filterChain);
-
-		verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(filter, failureHandler, "invalid_state_parameter");
-	}
-
-	@Test
-	public void doFilterWhenAuthorizationCodeSuccessResponseWithInvalidRedirectUriParamThenThrowOAuth2AuthenticationExceptionInvalidRedirectUriParameter() throws Exception {
-		ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
-
-		AuthorizationCodeAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
-		AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
-		filter.setAuthenticationFailureHandler(failureHandler);
-		AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
-		filter.setAuthorizationRequestRepository(authorizationRequestRepository);
-
-		MockHttpServletRequest request = this.setupRequest(clientRegistration);
-		request.setRequestURI(request.getRequestURI() + "-other");
-		String authCode = "some code";
-		String state = "some state";
-		request.addParameter(OAuth2Parameter.CODE, authCode);
-		request.addParameter(OAuth2Parameter.STATE, state);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state);
-		FilterChain filterChain = mock(FilterChain.class);
-
-		filter.doFilter(request, response, filterChain);
-
-		verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(filter, failureHandler, "invalid_redirect_uri_parameter");
-	}
-
 	private void verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(AuthorizationCodeAuthenticationFilter filter,
 																		AuthenticationFailureHandler failureHandler,
 																		String errorCode) throws Exception {

+ 21 - 7
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/AuthorizationResponse.java

@@ -27,21 +27,26 @@ import org.springframework.util.StringUtils;
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.2">Section 4.1.2 Authorization Response</a>
  */
 public final class AuthorizationResponse {
-	private String code;
+	private String redirectUri;
 	private String state;
+	private String code;
 	private OAuth2Error error;
 
 	private AuthorizationResponse() {
 	}
 
-	public String getCode() {
-		return this.code;
+	public String getRedirectUri() {
+		return this.redirectUri;
 	}
 
 	public String getState() {
 		return this.state;
 	}
 
+	public String getCode() {
+		return this.code;
+	}
+
 	public OAuth2Error getError() {
 		return this.error;
 	}
@@ -65,8 +70,9 @@ public final class AuthorizationResponse {
 	}
 
 	public static class Builder {
-		private String code;
+		private String redirectUri;
 		private String state;
+		private String code;
 		private String errorCode;
 		private String errorDescription;
 		private String errorUri;
@@ -74,8 +80,8 @@ public final class AuthorizationResponse {
 		private Builder() {
 		}
 
-		public Builder code(String code) {
-			this.code = code;
+		public Builder redirectUri(String redirectUri) {
+			this.redirectUri = redirectUri;
 			return this;
 		}
 
@@ -84,6 +90,11 @@ public final class AuthorizationResponse {
 			return this;
 		}
 
+		public Builder code(String code) {
+			this.code = code;
+			return this;
+		}
+
 		public Builder errorCode(String errorCode) {
 			this.errorCode = errorCode;
 			return this;
@@ -103,14 +114,17 @@ public final class AuthorizationResponse {
 			if (StringUtils.hasText(this.code) && StringUtils.hasText(this.errorCode)) {
 				throw new IllegalArgumentException("code and errorCode cannot both be set");
 			}
+			Assert.hasText(this.redirectUri, "redirectUri cannot be empty");
+
 			AuthorizationResponse authorizationResponse = new AuthorizationResponse();
+			authorizationResponse.redirectUri = this.redirectUri;
+			authorizationResponse.state = this.state;
 			if (StringUtils.hasText(this.code)) {
 				authorizationResponse.code = this.code;
 			} else {
 				authorizationResponse.error = new OAuth2Error(
 					this.errorCode, this.errorDescription, this.errorUri);
 			}
-			authorizationResponse.state = this.state;
 			return authorizationResponse;
 		}
 	}

+ 0 - 25
samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java

@@ -44,7 +44,6 @@ import org.springframework.security.oauth2.client.web.AuthorizationCodeAuthentic
 import org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter;
 import org.springframework.security.oauth2.client.web.AuthorizationGrantTokenExchanger;
 import org.springframework.security.oauth2.core.AccessToken;
-import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
 import org.springframework.security.oauth2.core.endpoint.ResponseType;
 import org.springframework.security.oauth2.core.endpoint.TokenResponse;
@@ -282,30 +281,6 @@ public class OAuth2LoginApplicationTests {
 		assertThat(errorElement.asText()).contains("invalid_redirect_uri_parameter");
 	}
 
-	@Test
-	public void requestAuthorizationCodeGrantWhenStandardErrorCodeResponseThenDisplayLoginPageWithError() throws Exception {
-		HtmlPage page = this.webClient.getPage("/");
-		URL loginPageUrl = page.getBaseURL();
-		URL loginErrorPageUrl = new URL(loginPageUrl.toString() + "?error");
-
-		String error = OAuth2Error.INVALID_CLIENT_ERROR_CODE;
-		String state = "state";
-		String redirectUri = AUTHORIZE_BASE_URL + "/" + this.githubClientRegistration.getRegistrationId();
-
-		String authorizationResponseUri =
-				UriComponentsBuilder.fromHttpUrl(redirectUri)
-						.queryParam(OAuth2Parameter.ERROR, error)
-						.queryParam(OAuth2Parameter.STATE, state)
-						.build().encode().toUriString();
-
-		page = this.webClient.getPage(new URL(authorizationResponseUri));
-		assertThat(page.getBaseURL()).isEqualTo(loginErrorPageUrl);
-
-		HtmlElement errorElement = page.getBody().getFirstByXPath("p");
-		assertThat(errorElement).isNotNull();
-		assertThat(errorElement.asText()).contains(error);
-	}
-
 	private void assertLoginPage(HtmlPage page) throws Exception {
 		assertThat(page.getTitleText()).isEqualTo("Login Page");