Kaynağa Gözat

Polish gh-77

Joe Grandja 5 yıl önce
ebeveyn
işleme
fbc98d511c

+ 2 - 1
core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java

@@ -15,6 +15,7 @@
  */
 package org.springframework.security.oauth2.server.authorization;
 
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.Assert;
 
 import java.util.List;
@@ -65,7 +66,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 
 	private boolean hasToken(OAuth2Authorization authorization, String token, TokenType tokenType) {
 		if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
-			return token.equals(authorization.getAttributes().get(TokenType.AUTHORIZATION_CODE.getValue()));
+			return token.equals(authorization.getAttributes().get(OAuth2ParameterNames.class.getName().concat(".CODE")));
 		} else if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
 			return authorization.getAccessToken() != null &&
 					authorization.getAccessToken().getTokenValue().equals(token);

+ 2 - 1
core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java

@@ -16,6 +16,7 @@
 package org.springframework.security.oauth2.server.authorization;
 
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.util.Assert;
 
@@ -196,7 +197,7 @@ public class OAuth2Authorization implements Serializable {
 		 */
 		public OAuth2Authorization build() {
 			Assert.hasText(this.principalName, "principalName cannot be empty");
-			Assert.notNull(this.attributes.get(TokenType.AUTHORIZATION_CODE.getValue()), "authorization code cannot be null");
+			Assert.notNull(this.attributes.get(OAuth2ParameterNames.class.getName().concat(".CODE")), "authorization code cannot be null");
 
 			OAuth2Authorization authorization = new OAuth2Authorization();
 			authorization.registeredClientId = this.registeredClientId;

+ 205 - 157
core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -15,29 +15,21 @@
  */
 package org.springframework.security.oauth2.server.authorization.web;
 
-import java.io.IOException;
-import java.util.stream.Stream;
-
-import javax.servlet.FilterChain;
-import javax.servlet.ServletException;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
-import org.springframework.core.convert.converter.Converter;
+import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpStatus;
+import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
 import org.springframework.security.crypto.keygen.StringKeyGenerator;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
-import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.web.DefaultRedirectStrategy;
@@ -45,201 +37,257 @@ import org.springframework.security.web.RedirectStrategy;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
 import org.springframework.web.util.UriComponentsBuilder;
 
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Base64;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
 /**
+ * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
+ * which handles the processing of the OAuth 2.0 Authorization Request.
+ *
  * @author Joe Grandja
  * @author Paurav Munshi
  * @since 0.0.1
+ * @see RegisteredClientRepository
+ * @see OAuth2AuthorizationService
+ * @see OAuth2Authorization
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.1">Section 4.1.1 Authorization Request</a>
  */
 public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
-
-	private static final String DEFAULT_ENDPOINT = "/oauth2/authorize";
-
-	private Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter = new OAuth2AuthorizationRequestConverter();
-	private RegisteredClientRepository registeredClientRepository;
-	private OAuth2AuthorizationService authorizationService;
-	private StringKeyGenerator codeGenerator = new Base64StringKeyGenerator();
-	private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
-	private RequestMatcher authorizationEndpointMatcher = new AntPathRequestMatcher(DEFAULT_ENDPOINT);
-
+	/**
+	 * The default endpoint {@code URI} for authorization requests.
+	 */
+	public static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";
+
+	private final RegisteredClientRepository registeredClientRepository;
+	private final OAuth2AuthorizationService authorizationService;
+	private final RequestMatcher authorizationEndpointMatcher;
+	private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
+	private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
+
+	/**
+	 * Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
+	 *
+	 * @param registeredClientRepository the repository of registered clients
+	 * @param authorizationService the authorization service
+	 */
 	public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
 			OAuth2AuthorizationService authorizationService) {
-		Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null.");
-		Assert.notNull(authorizationService, "authorizationService cannot be null.");
-		this.registeredClientRepository = registeredClientRepository;
-		this.authorizationService = authorizationService;
-	}
-
-	public final void setAuthorizationRequestConverter(
-			Converter<HttpServletRequest, OAuth2AuthorizationRequest> authorizationRequestConverter) {
-		Assert.notNull(authorizationRequestConverter, "authorizationRequestConverter cannot be set to null");
-		this.authorizationRequestConverter = authorizationRequestConverter;
-	}
-
-	public final void setCodeGenerator(StringKeyGenerator codeGenerator) {
-		Assert.notNull(codeGenerator, "codeGenerator cannot be set to null");
-		this.codeGenerator = codeGenerator;
-	}
-
-	public final void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) {
-		Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be set to null");
-		this.authorizationRedirectStrategy = authorizationRedirectStrategy;
+		this(registeredClientRepository, authorizationService, DEFAULT_AUTHORIZATION_ENDPOINT_URI);
 	}
 
-	public final void setAuthorizationEndpointMatcher(RequestMatcher authorizationEndpointMatcher) {
-		Assert.notNull(authorizationEndpointMatcher, "authorizationEndpointMatcher cannot be set to null");
-		this.authorizationEndpointMatcher = authorizationEndpointMatcher;
-	}
-
-	@Override
-	protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
-		boolean pathMatch = this.authorizationEndpointMatcher.matches(request);
-		String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE);
-		boolean responseTypeMatch = OAuth2ParameterNames.CODE.equals(responseType);
-		if (pathMatch && responseTypeMatch) {
-			return false;
-		}else {
-			return true;
-		}
+	/**
+	 * Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
+	 *
+	 * @param registeredClientRepository the repository of registered clients
+	 * @param authorizationService the authorization service
+	 * @param authorizationEndpointUri the endpoint {@code URI} for authorization requests
+	 */
+	public OAuth2AuthorizationEndpointFilter(RegisteredClientRepository registeredClientRepository,
+			OAuth2AuthorizationService authorizationService, String authorizationEndpointUri) {
+		Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
+		Assert.notNull(authorizationService, "authorizationService cannot be null");
+		Assert.hasText(authorizationEndpointUri, "authorizationEndpointUri cannot be empty");
+		this.registeredClientRepository = registeredClientRepository;
+		this.authorizationService = authorizationService;
+		this.authorizationEndpointMatcher = new AntPathRequestMatcher(
+				authorizationEndpointUri, HttpMethod.GET.name());
 	}
 
 	@Override
-	protected void doFilterInternal(HttpServletRequest request,
-			HttpServletResponse response, FilterChain filterChain)
+	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
 
-		RegisteredClient client = null;
-		OAuth2AuthorizationRequest authorizationRequest = null;
-		OAuth2Authorization authorization = null;
-
-		try {
-			checkUserAuthenticated();
-			Authentication auth = SecurityContextHolder.getContext().getAuthentication();
-			client = fetchRegisteredClient(request);
-
-			authorizationRequest = this.authorizationRequestConverter.convert(request);
-			validateAuthorizationRequest(authorizationRequest, client);
-
-			String code = this.codeGenerator.generateKey();
-			authorization = buildOAuth2Authorization(auth, client, authorizationRequest, code);
-			this.authorizationService.save(authorization);
+		if (!this.authorizationEndpointMatcher.matches(request) || !isPrincipalAuthenticated()) {
+			filterChain.doFilter(request, response);
+			return;
+		}
 
-			String redirectUri = getRedirectUri(authorizationRequest, client);
-			sendCodeOnSuccess(request, response, authorizationRequest, redirectUri, code);
+//		TODO
+//		The authorization server validates the request to ensure that all
+//		required parameters are present and valid.  If the request is valid,
+//		the authorization server authenticates the resource owner and obtains
+//		an authorization decision (by asking the resource owner or by
+//		establishing approval via other means).
+
+		MultiValueMap<String, String> parameters = getParameters(request);
+		String stateParameter = parameters.getFirst(OAuth2ParameterNames.STATE);
+
+		// client_id (REQUIRED)
+		String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
+		if (!StringUtils.hasText(clientId) ||
+				parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
+			OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
+			sendErrorResponse(request, response, error, stateParameter, null);	// when redirectUri is null then don't redirect
+			return;
+		}
+		RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
+		if (registeredClient == null) {
+			OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
+			sendErrorResponse(request, response, error, stateParameter, null);	// when redirectUri is null then don't redirect
+			return;
+		} else if (!registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.AUTHORIZATION_CODE)) {
+			OAuth2Error error = createError(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID);
+			sendErrorResponse(request, response, error, stateParameter, null);	// when redirectUri is null then don't redirect
+			return;
 		}
-		catch(OAuth2AuthorizationException authorizationException) {
-			OAuth2Error authorizationError = authorizationException.getError();
 
-			if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
-					|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
-				sendErrorInResponse(response, authorizationError);
-			}
-			else if (authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE)
-					|| authorizationError.getErrorCode().equals(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT)) {
-				String redirectUri = getRedirectUri(authorizationRequest, client);
-				sendErrorInRedirect(request, response, authorizationRequest, authorizationError, redirectUri);
-			}
-			else {
-				throw new ServletException(authorizationException);
+		// redirect_uri (OPTIONAL)
+		String redirectUriParameter = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
+		if (StringUtils.hasText(redirectUriParameter)) {
+			if (!registeredClient.getRedirectUris().contains(redirectUriParameter) ||
+					parameters.get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) {
+				OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
+				sendErrorResponse(request, response, error, stateParameter, null);	// when redirectUri is null then don't redirect
+				return;
 			}
+		} else if (registeredClient.getRedirectUris().size() != 1) {
+			OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
+			sendErrorResponse(request, response, error, stateParameter, null);	// when redirectUri is null then don't redirect
+			return;
 		}
 
-	}
-
-	private void checkUserAuthenticated() {
-		Authentication currentAuth = SecurityContextHolder.getContext().getAuthentication();
-		if (currentAuth==null || !currentAuth.isAuthenticated()) {
-			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
+		String redirectUri = StringUtils.hasText(redirectUriParameter) ?
+				redirectUriParameter : registeredClient.getRedirectUris().iterator().next();
+
+		// response_type (REQUIRED)
+		String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE);
+		if (!StringUtils.hasText(responseType) ||
+				parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) {
+			OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE);
+			sendErrorResponse(request, response, error, stateParameter, redirectUri);
+			return;
+		} else if (!responseType.equals(OAuth2AuthorizationResponseType.CODE.getValue())) {
+			OAuth2Error error = createError(OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE, OAuth2ParameterNames.RESPONSE_TYPE);
+			sendErrorResponse(request, response, error, stateParameter, redirectUri);
+			return;
 		}
-	}
 
-	private RegisteredClient fetchRegisteredClient(HttpServletRequest request) throws OAuth2AuthorizationException {
-		String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
-		if (StringUtils.isEmpty(clientId)) {
-			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
-		}
+		Authentication principal = SecurityContextHolder.getContext().getAuthentication();
+		String code = this.codeGenerator.generateKey();
+		OAuth2AuthorizationRequest authorizationRequest = convertAuthorizationRequest(request);
 
-		RegisteredClient client = this.registeredClientRepository.findByClientId(clientId);
-		if (client==null) {
-			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
-		}
+		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
+				.principalName(principal.getName())
+				.attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), code)
+				.attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest)
+				.build();
 
-		boolean isAuthorizationGrantAllowed = Stream.of(client.getAuthorizationGrantTypes())
-				.anyMatch(grantType -> grantType.contains(AuthorizationGrantType.AUTHORIZATION_CODE));
-		if (!isAuthorizationGrantAllowed) {
-			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED));
-		}
+		this.authorizationService.save(authorization);
 
-		return client;
+//		TODO security checks for code parameter
+//		The authorization code MUST expire shortly after it is issued to mitigate the risk of leaks.
+//		A maximum authorization code lifetime of 10 minutes is RECOMMENDED.
+//		The client MUST NOT use the authorization code more than once.
+//		If an authorization code is used more than once, the authorization server MUST deny the request
+//		and SHOULD revoke (when possible) all tokens previously issued based on that authorization code.
+//		The authorization code is bound to the client identifier and redirection URI.
 
+		sendAuthorizationResponse(request, response, authorizationRequest, code, redirectUri);
 	}
 
-	private OAuth2Authorization buildOAuth2Authorization(Authentication auth, RegisteredClient client,
-			OAuth2AuthorizationRequest authorizationRequest, String code) {
-		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(client)
-					.principalName(auth.getPrincipal().toString())
-					.attribute(TokenType.AUTHORIZATION_CODE.getValue(), code)
-					.attributes(attirbutesMap -> attirbutesMap.putAll(authorizationRequest.getAttributes()))
-					.build();
+	private void sendAuthorizationResponse(HttpServletRequest request, HttpServletResponse response,
+			OAuth2AuthorizationRequest authorizationRequest, String code, String redirectUri) throws IOException {
 
-		return authorization;
+		UriComponentsBuilder uriBuilder = UriComponentsBuilder
+				.fromUriString(redirectUri)
+				.queryParam(OAuth2ParameterNames.CODE, code);
+		if (StringUtils.hasText(authorizationRequest.getState())) {
+			uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
+		}
+		this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
 	}
 
+	private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
+			OAuth2Error error, String state, String redirectUri) throws IOException {
 
-	private void validateAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
-		String redirectUri = authorizationRequest.getRedirectUri();
-		if (StringUtils.isEmpty(redirectUri) && client.getRedirectUris().size() > 1) {
-			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
+		if (redirectUri == null) {
+			// TODO Send default html error response
+			response.sendError(HttpStatus.BAD_REQUEST.value(), error.toString());
+			return;
 		}
-		if (!StringUtils.isEmpty(redirectUri) && !client.getRedirectUris().contains(redirectUri)) {
-			throw new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
+
+		UriComponentsBuilder uriBuilder = UriComponentsBuilder
+				.fromUriString(redirectUri)
+				.queryParam(OAuth2ParameterNames.ERROR, error.getErrorCode());
+		if (StringUtils.hasText(error.getDescription())) {
+			uriBuilder.queryParam(OAuth2ParameterNames.ERROR_DESCRIPTION, error.getDescription());
 		}
+		if (StringUtils.hasText(error.getUri())) {
+			uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri());
+		}
+		if (StringUtils.hasText(state)) {
+			uriBuilder.queryParam(OAuth2ParameterNames.STATE, state);
+		}
+		this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
 	}
 
-	private String getRedirectUri(OAuth2AuthorizationRequest authorizationRequest, RegisteredClient client) {
-		return !StringUtils.isEmpty(authorizationRequest.getRedirectUri())
-		? authorizationRequest.getRedirectUri()
-		: client.getRedirectUris().stream().findFirst().get();
+	private static boolean isPrincipalAuthenticated() {
+		return isPrincipalAuthenticated(SecurityContextHolder.getContext().getAuthentication());
 	}
 
-	private void sendCodeOnSuccess(HttpServletRequest request, HttpServletResponse response,
-			OAuth2AuthorizationRequest authorizationRequest, String redirectUri, String code) throws IOException {
-		UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
-				.queryParam(OAuth2ParameterNames.CODE, code);
-		if (!StringUtils.isEmpty(authorizationRequest.getState())) {
-			redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
-		}
-
-		String finalRedirectUri = redirectUriBuilder.toUriString();
-		this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectUri);
+	private static boolean isPrincipalAuthenticated(Authentication principal) {
+		return principal != null &&
+				!AnonymousAuthenticationToken.class.isAssignableFrom(principal.getClass()) &&
+				principal.isAuthenticated();
 	}
 
-	private void sendErrorInResponse(HttpServletResponse response, OAuth2Error authorizationError) throws IOException {
-		int errorStatus = -1;
-		String errorCode = authorizationError.getErrorCode();
-		if (errorCode.equals(OAuth2ErrorCodes.ACCESS_DENIED)) {
-			errorStatus=HttpStatus.FORBIDDEN.value();
-		}
-		else {
-			errorStatus=HttpStatus.INTERNAL_SERVER_ERROR.value();
-		}
-		response.sendError(errorStatus, authorizationError.getErrorCode());
+	private static OAuth2Error createError(String errorCode, String parameterName) {
+		return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
+				"https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
 	}
 
-	private void sendErrorInRedirect(HttpServletRequest request, HttpServletResponse response,
-			OAuth2AuthorizationRequest authorizationRequest, OAuth2Error authorizationError,
-			String redirectUri) throws IOException {
-		UriComponentsBuilder redirectUriBuilder = UriComponentsBuilder.fromUriString(redirectUri)
-				.queryParam(OAuth2ParameterNames.ERROR, authorizationError.getErrorCode());
+	private static OAuth2AuthorizationRequest convertAuthorizationRequest(HttpServletRequest request) {
+		MultiValueMap<String, String> parameters = getParameters(request);
 
-		if (!StringUtils.isEmpty(authorizationRequest.getState())) {
-			redirectUriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
+		Set<String> scopes = Collections.emptySet();
+		if (parameters.containsKey(OAuth2ParameterNames.SCOPE)) {
+			String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE);
+			scopes = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
 		}
 
-		String finalRedirectURI = redirectUriBuilder.toUriString();
-		this.authorizationRedirectStrategy.sendRedirect(request, response, finalRedirectURI);
+		return OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri(request.getRequestURL().toString())
+				.clientId(parameters.getFirst(OAuth2ParameterNames.CLIENT_ID))
+				.redirectUri(parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI))
+				.scopes(scopes)
+				.state(parameters.getFirst(OAuth2ParameterNames.STATE))
+				.additionalParameters(additionalParameters ->
+						parameters.entrySet().stream()
+								.filter(e -> !e.getKey().equals(OAuth2ParameterNames.RESPONSE_TYPE) &&
+										!e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) &&
+										!e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI) &&
+										!e.getKey().equals(OAuth2ParameterNames.SCOPE) &&
+										!e.getKey().equals(OAuth2ParameterNames.STATE))
+								.forEach(e -> additionalParameters.put(e.getKey(), e.getValue().get(0))))
+				.build();
+	}
+
+	private static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
+		Map<String, String[]> parameterMap = request.getParameterMap();
+		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
+		parameterMap.forEach((key, values) -> {
+			if (values.length > 0) {
+				for (String value : values) {
+					parameters.add(key, value);
+				}
+			}
+		});
+		return parameters;
 	}
 }

+ 0 - 55
core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationRequestConverter.java

@@ -1,55 +0,0 @@
-/*
- * Copyright 2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization.web;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.LinkedHashSet;
-import java.util.Set;
-
-import javax.servlet.http.HttpServletRequest;
-
-import org.springframework.core.convert.converter.Converter;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
-import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.util.StringUtils;
-
-/**
- * @author Paurav Munshi
- * @since 0.0.1
- * @see Converter
- */
-public class OAuth2AuthorizationRequestConverter implements Converter<HttpServletRequest, OAuth2AuthorizationRequest> {
-
-	@Override
-	public OAuth2AuthorizationRequest convert(HttpServletRequest request) {
-		String scope = request.getParameter(OAuth2ParameterNames.SCOPE);
-		Set<String> scopes = !StringUtils.isEmpty(scope)
-				? new LinkedHashSet<String>(Arrays.asList(scope.split(" ")))
-				: Collections.emptySet();
-
-		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
-				.clientId(request.getParameter(OAuth2ParameterNames.CLIENT_ID))
-				.redirectUri(request.getParameter(OAuth2ParameterNames.REDIRECT_URI))
-				.scopes(scopes)
-				.state(request.getParameter(OAuth2ParameterNames.STATE))
-				.authorizationUri(request.getServletPath())
-				.build();
-
-		return authorizationRequest;
-	}
-
-}

+ 4 - 3
core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java

@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization;
 import org.junit.Before;
 import org.junit.Test;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 
@@ -61,7 +62,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	public void saveWhenAuthorizationProvidedThenSaved() {
 		OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
+				.attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
 				.build();
 		this.authorizationService.save(expectedAuthorization);
 
@@ -88,7 +89,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	public void findByTokenAndTokenTypeWhenTokenTypeAuthorizationCodeThenFound() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
+				.attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
 				.build();
 		this.authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization));
 
@@ -103,7 +104,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 				"access-token", Instant.now().minusSeconds(60), Instant.now());
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
+				.attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
 				.accessToken(accessToken)
 				.build();
 		this.authorizationService.save(authorization);

+ 3 - 2
core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java

@@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.authorization;
 
 import org.junit.Test;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 
@@ -84,13 +85,13 @@ public class OAuth2AuthorizationTests {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
 				.accessToken(ACCESS_TOKEN)
-				.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
+				.attribute(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE)
 				.build();
 
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
 		assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN);
 		assertThat(authorization.getAttributes()).containsExactly(
-				entry(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE));
+				entry(OAuth2ParameterNames.class.getName().concat(".CODE"), AUTHORIZATION_CODE));
 	}
 }

+ 0 - 36
core/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java

@@ -46,40 +46,4 @@ public class TestRegisteredClients {
 				.scope("profile")
 				.scope("email");
 	}
-
-	public static RegisteredClient.Builder validAuthorizationGrantRegisteredClient() {
-		return RegisteredClient.withId("valid_client_id")
-				.clientId("valid_client")
-				.clientSecret("valid_secret")
-				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
-				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
-				.redirectUri("http://localhost:8080/test-application/callback")
-				.scope("openid")
-				.scope("profile")
-				.scope("email");
-	}
-
-	public static RegisteredClient.Builder validAuthorizationGrantClientMultiRedirectUris() {
-		return RegisteredClient.withId("valid_client_multi_uri_id")
-				.clientId("valid_client_multi_uri")
-				.clientSecret("valid_secret")
-				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
-				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
-				.redirectUri("http://localhost:8080/test-application/callback")
-				.redirectUri("http://localhost:8080/another-test-application/callback")
-				.scope("openid")
-				.scope("profile")
-				.scope("email");
-	}
-
-	public static RegisteredClient.Builder validClientCredentialsGrantRegisteredClient() {
-		return RegisteredClient.withId("valid_cc_client_id")
-				.clientId("valid_cc_client")
-				.clientSecret("valid_secret")
-				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
-				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
-				.scope("openid")
-				.scope("profile")
-				.scope("email");
-	}
 }

+ 0 - 371
core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTest.java

@@ -1,371 +0,0 @@
-/*
- * Copyright 2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization.web;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyString;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-import javax.servlet.FilterChain;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.springframework.http.HttpStatus;
-import org.springframework.mock.web.MockHttpServletRequest;
-import org.springframework.mock.web.MockHttpServletResponse;
-import org.springframework.security.core.Authentication;
-import org.springframework.security.core.context.SecurityContextHolder;
-import org.springframework.security.crypto.keygen.StringKeyGenerator;
-import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
-import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
-import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
-import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
-import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
-import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
-
-
-/**
- * Tests for {@link OAuth2AuthorizationEndpointFilter}.
- *
- * @author Paurav Munshi
- * @since 0.0.1
- */
-
-public class OAuth2AuthorizationEndpointFilterTest {
-
-	private static final String VALID_CLIENT = "valid_client";
-	private static final String VALID_CLIENT_MULTI_URI = "valid_client_multi_uri";
-	private static final String VALID_CC_CLIENT = "valid_cc_client";
-
-	private OAuth2AuthorizationEndpointFilter filter;
-
-	private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class);
-	private StringKeyGenerator codeGenerator = mock(StringKeyGenerator.class);
-	private RegisteredClientRepository registeredClientRepository = mock(RegisteredClientRepository.class);
-	private Authentication authentication = mock(Authentication.class);
-
-	@Before
-	public void setUp() {
-		this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService);
-		this.filter.setCodeGenerator(this.codeGenerator);
-
-		SecurityContextHolder.getContext().setAuthentication(this.authentication);
-	}
-
-	@Test
-	public void constructorWhenRegisteredClientRepositoryIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
-		assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService))
-			.isInstanceOf(IllegalArgumentException.class);
-	}
-
-	@Test
-	public void constructorWhenAuthorizationServiceIsNullThenIllegalArgumentExceptionIsThrows() throws Exception {
-		assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null))
-			.isInstanceOf(IllegalArgumentException.class);
-	}
-
-	@Test
-	public void setAuthorizationEndpointMatcherWhenAuthorizationEndpointMatcherIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
-		assertThatThrownBy(() ->this.filter.setAuthorizationEndpointMatcher(null))
-			.isInstanceOf(IllegalArgumentException.class);
-	}
-
-	@Test
-	public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
-		assertThatThrownBy(() ->this.filter.setAuthorizationRedirectStrategy(null))
-			.isInstanceOf(IllegalArgumentException.class);
-	}
-
-	@Test
-	public void setAuthorizationRequestConverterWhenAuthorizationRequestConverterIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
-		assertThatThrownBy(() ->this.filter.setAuthorizationRequestConverter(null))
-			.isInstanceOf(IllegalArgumentException.class);
-	}
-
-	@Test
-	public void setCodeGeneratorWhenCodeGeneratorIsNullThenIllegalArgumentExceptionIsThrown() throws Exception {
-		assertThatThrownBy(() ->this.filter.setCodeGenerator(null))
-			.isInstanceOf(IllegalArgumentException.class);
-	}
-
-	@Test
-	public void doFilterWhenValidRequestIsReceivedThenResponseRedirectedToRedirectURIWithCode() throws Exception {
-		MockHttpServletRequest request = getValidMockHttpServletRequest();
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
-		when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
-		when(this.codeGenerator.generateKey()).thenReturn("sample_code");
-		when(this.authentication.getPrincipal()).thenReturn("test-user");
-		when(this.authentication.isAuthenticated()).thenReturn(true);
-
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verify(this.authentication).isAuthenticated();
-		verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
-		verify(this.authorizationService).save(any(OAuth2Authorization.class));
-		verify(this.codeGenerator).generateKey();
-		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
-		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
-
-	}
-
-	@Test
-	public void doFilterWhenValidRequestWithBlankRedirectURIIsReceivedThenResponseRedirectedToConfiguredRedirectURI() throws Exception {
-		MockHttpServletRequest request = getValidMockHttpServletRequest();
-		request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
-		when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
-		when(this.codeGenerator.generateKey()).thenReturn("sample_code");
-		when(this.authentication.getPrincipal()).thenReturn("test-user");
-		when(this.authentication.isAuthenticated()).thenReturn(true);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verify(this.authentication).isAuthenticated();
-		verify(this.registeredClientRepository).findByClientId(VALID_CLIENT);
-		verify(this.authorizationService).save(any(OAuth2Authorization.class));
-		verify(this.codeGenerator).generateKey();
-		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
-		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost:8080/test-application/callback?code=sample_code&state=teststate");
-
-	}
-
-	@Test
-	public void doFilterWhenRedirectURINotPresentAndClientHasMulitipleUrisThenErrorIsSentInResponse() throws Exception {
-		MockHttpServletRequest request = getValidMockHttpServletRequest();
-		request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT_MULTI_URI);
-		request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "");
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantClientMultiRedirectUris().build();
-		when(this.registeredClientRepository.findByClientId(VALID_CLIENT_MULTI_URI)).thenReturn(registeredClient);
-		when(this.authentication.isAuthenticated()).thenReturn(true);
-
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verify(this.authentication, times(1)).isAuthenticated();
-		verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT_MULTI_URI);
-		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
-		verify(this.codeGenerator, times(0)).generateKey();
-		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
-		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
-
-	}
-
-	@Test
-	public void doFilterWhenRequestedRedirectUriNotConfiguredInClientThenErrorSentInResponse() throws Exception {
-		MockHttpServletRequest request = getValidMockHttpServletRequest();
-		request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/not-configred-app/callback");
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		RegisteredClient registeredClient = TestRegisteredClients.validAuthorizationGrantRegisteredClient().build();
-		when(this.registeredClientRepository.findByClientId(VALID_CLIENT)).thenReturn(registeredClient);
-		when(this.authentication.isAuthenticated()).thenReturn(true);
-
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verify(this.authentication, times(1)).isAuthenticated();
-		verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CLIENT);
-		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
-		verify(this.codeGenerator, times(0)).generateKey();
-		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
-		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
-
-	}
-
-	@Test
-	public void doFilterWhenClientIdDoesNotSupportAuthorizationGrantFlowThenErrorSentInResponse() throws Exception {
-		MockHttpServletRequest request = getValidMockHttpServletRequest();
-		request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CC_CLIENT);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		RegisteredClient registeredClient = TestRegisteredClients.validClientCredentialsGrantRegisteredClient().build();
-		when(this.registeredClientRepository.findByClientId(VALID_CC_CLIENT)).thenReturn(registeredClient);
-		when(this.authentication.isAuthenticated()).thenReturn(true);
-
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verify(this.authentication, times(1)).isAuthenticated();
-		verify(this.registeredClientRepository, times(1)).findByClientId(VALID_CC_CLIENT);
-		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
-		verify(this.codeGenerator, times(0)).generateKey();
-		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
-		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
-
-	}
-
-	@Test
-	public void doFilterWhenClientIdIsMissinInRequestThenErrorSentInResponse() throws Exception {
-		MockHttpServletRequest request = getValidMockHttpServletRequest();
-		request.setParameter(OAuth2ParameterNames.CLIENT_ID, "");
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		when(this.authentication.isAuthenticated()).thenReturn(true);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verify(this.authentication).isAuthenticated();
-		verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
-		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
-		verify(this.codeGenerator, times(0)).generateKey();
-		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
-		assertThat(response.getContentAsString()).isEmpty();
-		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
-
-	}
-
-	@Test
-	public void doFilterWhenUnregisteredClientInRequestThenErrorIsSentInResponse() throws Exception {
-		MockHttpServletRequest request = getValidMockHttpServletRequest();
-		request.setParameter(OAuth2ParameterNames.CLIENT_ID, "unregistered_client");
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		when(this.registeredClientRepository.findByClientId("unregistered_client")).thenReturn(null);
-		when(this.codeGenerator.generateKey()).thenReturn("sample_code");
-		when(this.authentication.isAuthenticated()).thenReturn(true);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verify(this.authentication).isAuthenticated();
-		verify(this.registeredClientRepository, times(1)).findByClientId("unregistered_client");
-		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
-		verify(this.codeGenerator, times(0)).generateKey();
-		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
-		assertThat(response.getContentAsString()).isEmpty();
-		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
-
-	}
-
-	@Test
-	public void doFilterWhenUnauthenticatedUserInRequestThenErrorIsSentInResponse() throws Exception {
-		MockHttpServletRequest request = getValidMockHttpServletRequest();
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		when(authentication.isAuthenticated()).thenReturn(false);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verify(this.authentication).isAuthenticated();
-		verify(this.registeredClientRepository, times(0)).findByClientId(anyString());
-		verify(this.authorizationService, times(0)).save(any(OAuth2Authorization.class));
-		verify(this.codeGenerator, times(0)).generateKey();
-		verify(filterChain, times(0)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.FORBIDDEN.value());
-		assertThat(response.getContentAsString()).isEmpty();
-		assertThat(response.getErrorMessage()).isEqualTo(OAuth2ErrorCodes.ACCESS_DENIED);
-
-	}
-
-	@Test
-	public void doFilterWhenRequestEndPointIsNotAuthorizationEndpointThenFilterShouldProceedWithFilterChain() throws Exception {
-		MockHttpServletRequest request = getValidMockHttpServletRequest();
-		request.setServletPath("/custom/authorize");
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
-		spyFilter.doFilter(request, response, filterChain);
-
-		verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-		verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
-		verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
-	}
-
-	@Test
-	public void doFilterWhenResponseTypeIsNotPresentInRequestThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
-		MockHttpServletRequest request = getValidMockHttpServletRequest();
-		request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "");
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
-		spyFilter.doFilter(request, response, filterChain);
-
-		verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
-		verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
-		verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-	}
-
-	@Test
-	public void doFilterWhenResponseTypeInRequestIsUnsupportedThenErrorIsSentInRedirectURIQueryParameter() throws Exception {
-		MockHttpServletRequest request = getValidMockHttpServletRequest();
-		request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "token");
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		OAuth2AuthorizationEndpointFilter spyFilter = spy(this.filter);
-		spyFilter.doFilter(request, response, filterChain);
-
-		verify(spyFilter, times(1)).shouldNotFilter(any(HttpServletRequest.class));
-		verify(spyFilter, times(0)).doFilterInternal(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
-		verify(filterChain, times(1)).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-	}
-
-	private MockHttpServletRequest getValidMockHttpServletRequest() {
-
-		MockHttpServletRequest request = new MockHttpServletRequest();
-		request.setParameter(OAuth2ParameterNames.CLIENT_ID, VALID_CLIENT);
-		request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "code");
-		request.setParameter(OAuth2ParameterNames.SCOPE, "openid profile email");
-		request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "http://localhost:8080/test-application/callback");
-		request.setParameter(OAuth2ParameterNames.STATE, "teststate");
-		request.setServletPath("/oauth2/authorize");
-
-		return request;
-
-
-	}
-
-}

+ 399 - 0
core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -0,0 +1,399 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.web;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.springframework.http.HttpStatus;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.util.StringUtils;
+
+import javax.servlet.FilterChain;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.util.Set;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link OAuth2AuthorizationEndpointFilter}.
+ *
+ * @author Paurav Munshi
+ * @author Joe Grandja
+ * @since 0.0.1
+ */
+public class OAuth2AuthorizationEndpointFilterTests {
+	private RegisteredClientRepository registeredClientRepository;
+	private OAuth2AuthorizationService authorizationService;
+	private OAuth2AuthorizationEndpointFilter filter;
+	private TestingAuthenticationToken authentication;
+
+	@Before
+	public void setUp() {
+		this.registeredClientRepository = mock(RegisteredClientRepository.class);
+		this.authorizationService = mock(OAuth2AuthorizationService.class);
+		this.filter = new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService);
+		this.authentication = new TestingAuthenticationToken("principalName", "password");
+		this.authentication.setAuthenticated(true);
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		securityContext.setAuthentication(this.authentication);
+		SecurityContextHolder.setContext(securityContext);
+	}
+
+	@After
+	public void cleanup() {
+		SecurityContextHolder.clearContext();
+	}
+
+	@Test
+	public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(null, this.authorizationService))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("registeredClientRepository cannot be null");
+	}
+
+	@Test
+	public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationService cannot be null");
+	}
+
+	@Test
+	public void constructorWhenAuthorizationEndpointUriNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AuthorizationEndpointFilter(this.registeredClientRepository, this.authorizationService, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationEndpointUri cannot be empty");
+	}
+
+	@Test
+	public void doFilterWhenNotAuthorizationRequestThenNotProcessed() throws Exception {
+		String requestUri = "/path";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestPostThenNotProcessed() throws Exception {
+		String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestNotAuthenticatedThenNotProcessed() throws Exception {
+		String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.authentication.setAuthenticated(false);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request.removeParameter(OAuth2ParameterNames.CLIENT_ID);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestAndClientNotAuthorizedToRequestCodeThenUnauthorizedClientError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.authorizationGrantTypes(Set::clear)
+				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+				.build();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		assertThat(response.getErrorMessage()).isEqualTo("[unauthorized_client] OAuth 2.0 Parameter: client_id");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestInvalidRedirectUriThenInvalidRequestError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestMultipleRedirectUriThenInvalidRequestError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestExcludesRedirectUriAndMultipleRegisteredThenInvalidRequestError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().redirectUri("https://example2.com").build();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request.removeParameter(OAuth2ParameterNames.REDIRECT_URI);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestMissingResponseTypeThenInvalidRequestError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+				"error=invalid_request&" +
+				"error_description=OAuth%202.0%20Parameter:%20response_type&" +
+				"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
+				"state=state");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestMultipleResponseTypeThenInvalidRequestError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+				"error=invalid_request&" +
+				"error_description=OAuth%202.0%20Parameter:%20response_type&" +
+				"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
+				"state=state");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestInvalidResponseTypeThenUnsupportedResponseTypeError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+				"error=unsupported_response_type&" +
+				"error_description=OAuth%202.0%20Parameter:%20response_type&" +
+				"error_uri=https://tools.ietf.org/html/rfc6749%23section-4.1.2.1&" +
+				"state=state");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestValidThenAuthorizationResponse() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state");
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+
+		OAuth2Authorization authorization = authorizationCaptor.getValue();
+		assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
+		assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
+
+		String code = authorization.getAttribute(OAuth2ParameterNames.class.getName().concat(".CODE"));
+		assertThat(code).isNotNull();
+
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
+		assertThat(authorizationRequest).isNotNull();
+		assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo("http://localhost/oauth2/authorize");
+		assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+		assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
+		assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId());
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(registeredClient.getRedirectUris().iterator().next());
+		assertThat(authorizationRequest.getScopes()).containsExactlyInAnyOrderElementsOf(registeredClient.getScopes());
+		assertThat(authorizationRequest.getState()).isEqualTo("state");
+		assertThat(authorizationRequest.getAdditionalParameters()).isEmpty();
+	}
+
+	private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
+		String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]);
+
+		String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+
+		request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
+		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]);
+		request.addParameter(OAuth2ParameterNames.SCOPE,
+				StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
+		request.addParameter(OAuth2ParameterNames.STATE, "state");
+
+		return request;
+	}
+}