Browse Source

Polish gh-79

Joe Grandja 5 years ago
parent
commit
4c8f89af5c

+ 0 - 10
core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java

@@ -33,7 +33,6 @@ import java.util.function.Consumer;
  *
  * @author Joe Grandja
  * @author Krisztian Toth
- * @author Madhu Bhat
  * @since 0.0.1
  * @see RegisteredClient
  * @see OAuth2AccessToken
@@ -75,15 +74,6 @@ public class OAuth2Authorization implements Serializable {
 		return this.accessToken;
 	}
 
-	/**
-	 * Sets the access token {@link OAuth2AccessToken} in the {@link OAuth2Authorization}.
-	 *
-	 * @param accessToken the access token
-	 */
-	public final void setAccessToken(OAuth2AccessToken accessToken) {
-		this.accessToken = accessToken;
-	}
-
 	/**
 	 * Returns the attribute(s) associated to the authorization.
 	 *

+ 4 - 4
core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java

@@ -17,8 +17,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
 
 import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.core.Authentication;
-import org.springframework.security.core.SpringSecurityCoreVersion;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.server.authorization.Version;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 
 import java.util.Collections;
@@ -28,7 +28,7 @@ import java.util.Collections;
  * @author Madhu Bhat
  */
 public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
-	private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID;
+	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
 	private RegisteredClient registeredClient;
 	private Authentication clientPrincipal;
 	private OAuth2AccessToken accessToken;
@@ -52,9 +52,9 @@ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthentication
 	}
 
 	/**
-	 * Returns the access token {@link OAuth2AccessToken}.
+	 * Returns the {@link OAuth2AccessToken access token}.
 	 *
-	 * @return the access token
+	 * @return the {@link OAuth2AccessToken}
 	 */
 	public OAuth2AccessToken getAccessToken() {
 		return this.accessToken;

+ 17 - 8
core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java

@@ -18,7 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
 import org.springframework.lang.Nullable;
 import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.core.Authentication;
-import org.springframework.security.core.SpringSecurityCoreVersion;
+import org.springframework.security.oauth2.server.authorization.Version;
 
 import java.util.Collections;
 
@@ -27,7 +27,7 @@ import java.util.Collections;
  * @author Madhu Bhat
  */
 public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
-	private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID;
+	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
 	private String code;
 	private Authentication clientPrincipal;
 	private String clientId;
@@ -37,26 +37,26 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
 			Authentication clientPrincipal, @Nullable String redirectUri) {
 		super(Collections.emptyList());
 		this.code = code;
-		this.redirectUri = redirectUri;
 		this.clientPrincipal = clientPrincipal;
+		this.redirectUri = redirectUri;
 	}
 
 	public OAuth2AuthorizationCodeAuthenticationToken(String code,
 			String clientId, @Nullable String redirectUri) {
 		super(Collections.emptyList());
 		this.code = code;
-		this.redirectUri = redirectUri;
 		this.clientId = clientId;
+		this.redirectUri = redirectUri;
 	}
 
 	@Override
-	public Object getCredentials() {
-		return null;
+	public Object getPrincipal() {
+		return this.clientPrincipal != null ? this.clientPrincipal : this.clientId;
 	}
 
 	@Override
-	public Object getPrincipal() {
-		return null;
+	public Object getCredentials() {
+		return "";
 	}
 
 	/**
@@ -67,4 +67,13 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
 	public String getCode() {
 		return this.code;
 	}
+
+	/**
+	 * Returns the redirectUri.
+	 *
+	 * @return the redirectUri
+	 */
+	public String getRedirectUri() {
+		return this.redirectUri;
+	}
 }

+ 2 - 17
core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -38,7 +38,6 @@ import org.springframework.security.web.RedirectStrategy;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
-import org.springframework.util.LinkedMultiValueMap;
 import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
@@ -53,7 +52,6 @@ import java.util.Arrays;
 import java.util.Base64;
 import java.util.Collections;
 import java.util.HashSet;
-import java.util.Map;
 import java.util.Set;
 
 /**
@@ -123,7 +121,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 		// Validate the request to ensure that all required parameters are present and valid
 		// ---------------
 
-		MultiValueMap<String, String> parameters = getParameters(request);
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
 		String stateParameter = parameters.getFirst(OAuth2ParameterNames.STATE);
 
 		// client_id (REQUIRED)
@@ -258,7 +256,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 	}
 
 	private static OAuth2AuthorizationRequest convertAuthorizationRequest(HttpServletRequest request) {
-		MultiValueMap<String, String> parameters = getParameters(request);
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
 
 		Set<String> scopes = Collections.emptySet();
 		if (parameters.containsKey(OAuth2ParameterNames.SCOPE)) {
@@ -282,17 +280,4 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 								.forEach(e -> additionalParameters.put(e.getKey(), e.getValue().get(0))))
 				.build();
 	}
-
-	private static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
-		Map<String, String[]> parameterMap = request.getParameterMap();
-		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
-		parameterMap.forEach((key, values) -> {
-			if (values.length > 0) {
-				for (String value : values) {
-					parameters.add(key, value);
-				}
-			}
-		});
-		return parameters;
-	}
 }

+ 49 - 0
core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2EndpointUtils.java

@@ -0,0 +1,49 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.web;
+
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
+
+import javax.servlet.http.HttpServletRequest;
+import java.util.Map;
+
+/**
+ * Utility methods for the OAuth 2.0 Protocol Endpoints.
+ *
+ * @author Joe Grandja
+ * @since 0.0.1
+ * @see OAuth2AuthorizationEndpointFilter
+ * @see OAuth2TokenEndpointFilter
+ */
+final class OAuth2EndpointUtils {
+
+	private OAuth2EndpointUtils() {
+	}
+
+	static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
+		Map<String, String[]> parameterMap = request.getParameterMap();
+		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
+		parameterMap.forEach((key, values) -> {
+			if (values.length > 0) {
+				for (String value : values) {
+					parameters.add(key, value);
+				}
+			}
+		});
+		return parameters;
+	}
+}

+ 121 - 95
core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java

@@ -15,13 +15,11 @@
  */
 package org.springframework.security.oauth2.server.authorization.web;
 
-import com.fasterxml.jackson.annotation.JsonInclude;
-import com.fasterxml.jackson.databind.ObjectMapper;
 import org.springframework.core.convert.converter.Converter;
-import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpStatus;
-import org.springframework.http.MediaType;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
@@ -30,15 +28,17 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
+import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
-import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
+import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
 
@@ -47,145 +47,171 @@ import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import java.io.IOException;
-import java.io.Writer;
+import java.time.temporal.ChronoUnit;
 
 /**
- * This {@code Filter} is used by the client to obtain an access token by presenting
- * its authorization grant.
+ * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
+ * which handles the processing of the OAuth 2.0 Access Token Request.
  *
  * <p>
- * It converts the OAuth 2.0 Access Token Request to {@link OAuth2AuthorizationCodeAuthenticationToken},
- * which is then authenticated by the {@link AuthenticationManager} and gets back
- * {@link OAuth2AccessTokenAuthenticationToken} which has the {@link OAuth2AccessToken} if the request
- * was successfully authenticated. The {@link OAuth2AccessToken} is then updated in the in-flight {@link OAuth2Authorization}
- * and sent back to the client. In case the authentication fails, an HTTP 401 (Unauthorized) response is returned.
+ * It converts the OAuth 2.0 Access Token Request to an {@link OAuth2AuthorizationCodeAuthenticationToken},
+ * which is then authenticated by the {@link AuthenticationManager}.
+ * If the authentication succeeds, the {@link AuthenticationManager} returns an
+ * {@link OAuth2AccessTokenAuthenticationToken}, which contains
+ * the {@link OAuth2AccessToken} that is returned in the response.
+ * In case of any error, an {@link OAuth2Error} is returned in the response.
  *
  * <p>
  * By default, this {@code Filter} responds to access token requests
- * at the {@code URI} {@code /oauth2/token} and {@code HttpMethod} {@code POST}
- * using the default {@link AntPathRequestMatcher}.
+ * at the {@code URI} {@code /oauth2/token} and {@code HttpMethod} {@code POST}.
  *
  * <p>
- * The default base {@code URI} {@code /oauth2/token} may be overridden
- * via the constructor {@link #OAuth2TokenEndpointFilter(OAuth2AuthorizationService, AuthenticationManager, String)}.
+ * The default endpoint {@code URI} {@code /oauth2/token} may be overridden
+ * via the constructor {@link #OAuth2TokenEndpointFilter(AuthenticationManager, OAuth2AuthorizationService, String)}.
  *
  * @author Joe Grandja
  * @author Madhu Bhat
+ * @since 0.0.1
+ * @see AuthenticationManager
+ * @see OAuth2AuthorizationService
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a>
  */
 public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 	/**
 	 * The default endpoint {@code URI} for access token requests.
 	 */
-	private static final String DEFAULT_TOKEN_ENDPOINT_URI = "/oauth2/token";
+	public static final String DEFAULT_TOKEN_ENDPOINT_URI = "/oauth2/token";
 
-	private Converter<HttpServletRequest, Authentication> authorizationGrantConverter = this::convert;
-	private AuthenticationManager authenticationManager;
-	private OAuth2AuthorizationService authorizationService;
-	private RequestMatcher uriMatcher;
-	private ObjectMapper objectMapper = new ObjectMapper().setSerializationInclusion(JsonInclude.Include.NON_NULL);
+	private final AuthenticationManager authenticationManager;
+	private final OAuth2AuthorizationService authorizationService;
+	private final RequestMatcher tokenEndpointMatcher;
+	private final Converter<HttpServletRequest, Authentication> authorizationGrantAuthenticationConverter =
+			new AuthorizationCodeAuthenticationConverter();
+	private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
+			new OAuth2AccessTokenResponseHttpMessageConverter();
+	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
+			new OAuth2ErrorHttpMessageConverter();
 
 	/**
 	 * Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
 	 *
-	 * @param authorizationService  the authorization service implementation
-	 * @param authenticationManager the authentication manager implementation
+	 * @param authenticationManager the authentication manager
+	 * @param authorizationService the authorization service
 	 */
-	public OAuth2TokenEndpointFilter(OAuth2AuthorizationService authorizationService, AuthenticationManager authenticationManager) {
-		Assert.notNull(authorizationService, "authorizationService cannot be null");
-		Assert.notNull(authenticationManager, "authenticationManager cannot be null");
-		this.authenticationManager = authenticationManager;
-		this.authorizationService = authorizationService;
-		this.uriMatcher = new AntPathRequestMatcher(DEFAULT_TOKEN_ENDPOINT_URI, HttpMethod.POST.name());
+	public OAuth2TokenEndpointFilter(AuthenticationManager authenticationManager,
+			OAuth2AuthorizationService authorizationService) {
+		this(authenticationManager, authorizationService, DEFAULT_TOKEN_ENDPOINT_URI);
 	}
 
 	/**
 	 * Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
 	 *
-	 * @param authorizationService  the authorization service implementation
-	 * @param authenticationManager the authentication manager implementation
-	 * @param tokenEndpointUri      the token endpoint's uri
+	 * @param authenticationManager the authentication manager
+	 * @param authorizationService the authorization service
+	 * @param tokenEndpointUri the endpoint {@code URI} for access token requests
 	 */
-	public OAuth2TokenEndpointFilter(OAuth2AuthorizationService authorizationService, AuthenticationManager authenticationManager,
-			String tokenEndpointUri) {
-		Assert.notNull(authorizationService, "authorizationService cannot be null");
+	public OAuth2TokenEndpointFilter(AuthenticationManager authenticationManager,
+			OAuth2AuthorizationService authorizationService, String tokenEndpointUri) {
 		Assert.notNull(authenticationManager, "authenticationManager cannot be null");
+		Assert.notNull(authorizationService, "authorizationService cannot be null");
 		Assert.hasText(tokenEndpointUri, "tokenEndpointUri cannot be empty");
 		this.authenticationManager = authenticationManager;
 		this.authorizationService = authorizationService;
-		this.uriMatcher = new AntPathRequestMatcher(tokenEndpointUri, HttpMethod.POST.name());
+		this.tokenEndpointMatcher = new AntPathRequestMatcher(tokenEndpointUri, HttpMethod.POST.name());
 	}
 
 	@Override
-	protected void doFilterInternal(HttpServletRequest request,
-			HttpServletResponse response, FilterChain filterChain)
+	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
-		if (uriMatcher.matches(request)) {
-			try {
-				if (validateAccessTokenRequest(request)) {
-					OAuth2AuthorizationCodeAuthenticationToken authCodeAuthToken =
-							(OAuth2AuthorizationCodeAuthenticationToken) authorizationGrantConverter.convert(request);
-					OAuth2AccessTokenAuthenticationToken accessTokenAuthenticationToken =
-							(OAuth2AccessTokenAuthenticationToken) authenticationManager.authenticate(authCodeAuthToken);
-					if (accessTokenAuthenticationToken.isAuthenticated()) {
-						OAuth2Authorization authorization = authorizationService
-								.findByTokenAndTokenType(authCodeAuthToken.getCode(), TokenType.AUTHORIZATION_CODE);
-						authorization.setAccessToken(accessTokenAuthenticationToken.getAccessToken());
-						authorizationService.save(authorization);
-						writeSuccessResponse(response, accessTokenAuthenticationToken.getAccessToken());
-					} else {
-						throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
-					}
-				}
-			} catch (OAuth2AuthenticationException exception) {
-				SecurityContextHolder.clearContext();
-				writeFailureResponse(response, exception.getError());
-			}
-		} else {
+
+		if (!this.tokenEndpointMatcher.matches(request)) {
 			filterChain.doFilter(request, response);
+			return;
+		}
+
+		try {
+			Authentication authorizationGrantAuthentication =
+					this.authorizationGrantAuthenticationConverter.convert(request);
+			OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+					(OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication);
+			sendAccessTokenResponse(response, accessTokenAuthentication.getAccessToken());
+		} catch (OAuth2AuthenticationException ex) {
+			SecurityContextHolder.clearContext();
+			sendErrorResponse(response, ex.getError());
 		}
 	}
 
-	private boolean validateAccessTokenRequest(HttpServletRequest request) {
-		if (StringUtils.isEmpty(request.getParameter(OAuth2ParameterNames.CODE))
-				|| StringUtils.isEmpty(request.getParameter(OAuth2ParameterNames.REDIRECT_URI))
-				|| StringUtils.isEmpty(request.getParameter(OAuth2ParameterNames.GRANT_TYPE))) {
-			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
-		} else if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(request.getParameter(OAuth2ParameterNames.GRANT_TYPE))) {
-			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE));
+	private void sendAccessTokenResponse(HttpServletResponse response, OAuth2AccessToken accessToken) throws IOException {
+		OAuth2AccessTokenResponse.Builder builder =
+				OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
+						.tokenType(accessToken.getTokenType())
+						.scopes(accessToken.getScopes());
+		if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) {
+			builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()));
 		}
-		return true;
+		OAuth2AccessTokenResponse accessTokenResponse = builder.build();
+		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
+		this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse);
 	}
 
-	private OAuth2AuthorizationCodeAuthenticationToken convert(HttpServletRequest request) {
-		Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
-		return new OAuth2AuthorizationCodeAuthenticationToken(
-				request.getParameter(OAuth2ParameterNames.CODE),
-				clientPrincipal,
-				request.getParameter(OAuth2ParameterNames.REDIRECT_URI)
-		);
+	private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException {
+		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
+		httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
+		this.errorHttpResponseConverter.write(error, null, httpResponse);
 	}
 
-	private void writeSuccessResponse(HttpServletResponse response, OAuth2AccessToken body) throws IOException {
-		try (Writer out = response.getWriter()) {
-			response.setStatus(HttpStatus.OK.value());
-			response.setContentType(MediaType.APPLICATION_JSON_VALUE);
-			response.setCharacterEncoding("UTF-8");
-			response.setHeader(HttpHeaders.CACHE_CONTROL, "no-store");
-			response.setHeader(HttpHeaders.PRAGMA, "no-cache");
-			out.write(objectMapper.writeValueAsString(body));
-		}
+	private static OAuth2AuthenticationException throwError(String errorCode, String parameterName) {
+		OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
+				"https://tools.ietf.org/html/rfc6749#section-5.2");
+		throw new OAuth2AuthenticationException(error);
 	}
 
-	private void writeFailureResponse(HttpServletResponse response, OAuth2Error error) throws IOException {
-		try (Writer out = response.getWriter()) {
-			if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_CLIENT)) {
-				response.setStatus(HttpStatus.UNAUTHORIZED.value());
+	private static class AuthorizationCodeAuthenticationConverter implements Converter<HttpServletRequest, Authentication> {
+
+		@Override
+		public Authentication convert(HttpServletRequest request) {
+			MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
+
+			// grant_type (REQUIRED)
+			String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
+			if (!StringUtils.hasText(grantType) ||
+					parameters.get(OAuth2ParameterNames.GRANT_TYPE).size() != 1) {
+				throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.GRANT_TYPE);
+			}
+			if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) {
+				throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, OAuth2ParameterNames.GRANT_TYPE);
+			}
+
+			// client_id (REQUIRED)
+			String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
+			Authentication clientPrincipal = null;
+			if (StringUtils.hasText(clientId)) {
+				if (parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
+					throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
+				}
 			} else {
-				response.setStatus(HttpStatus.BAD_REQUEST.value());
+				clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
 			}
-			response.setContentType(MediaType.APPLICATION_JSON_VALUE);
-			response.setCharacterEncoding("UTF-8");
-			out.write(objectMapper.writeValueAsString(error));
+
+			// code (REQUIRED)
+			String code = parameters.getFirst(OAuth2ParameterNames.CODE);
+			if (!StringUtils.hasText(code) ||
+					parameters.get(OAuth2ParameterNames.CODE).size() != 1) {
+				throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CODE);
+			}
+
+			// redirect_uri (REQUIRED)
+			// Required only if the "redirect_uri" parameter was included in the authorization request
+			String redirectUri = parameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
+			if (StringUtils.hasText(redirectUri) &&
+					parameters.get(OAuth2ParameterNames.REDIRECT_URI).size() != 1) {
+				throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
+			}
+
+			return clientPrincipal != null ?
+					new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri) :
+					new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri);
 		}
 	}
 }

+ 183 - 128
core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java

@@ -15,36 +15,47 @@
  */
 package org.springframework.security.oauth2.server.authorization.web;
 
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
-import org.springframework.http.HttpHeaders;
+import org.mockito.ArgumentCaptor;
 import org.springframework.http.HttpStatus;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.mock.http.client.MockClientHttpResponse;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
-import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
+import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
+import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
-import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 
 import javax.servlet.FilterChain;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
+import java.time.Duration;
 import java.time.Instant;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.function.Consumer;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.anyString;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
@@ -53,178 +64,222 @@ import static org.mockito.Mockito.when;
  * Tests for {@link OAuth2TokenEndpointFilter}.
  *
  * @author Madhu Bhat
+ * @author Joe Grandja
  */
 public class OAuth2TokenEndpointFilterTests {
-
+	private AuthenticationManager authenticationManager;
+	private OAuth2AuthorizationService authorizationService;
 	private OAuth2TokenEndpointFilter filter;
-	private OAuth2AuthorizationService authorizationService = mock(OAuth2AuthorizationService.class);
-	private AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
-	private FilterChain filterChain = mock(FilterChain.class);
-	private String requestUri;
-	private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
-	private static final String PRINCIPAL_NAME = "principal";
-	private static final String AUTHORIZATION_CODE = "code";
+	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
+			new OAuth2ErrorHttpMessageConverter();
+	private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
+			new OAuth2AccessTokenResponseHttpMessageConverter();
 
 	@Before
 	public void setUp() {
-		this.filter = new OAuth2TokenEndpointFilter(this.authorizationService, this.authenticationManager);
-		this.requestUri = "/oauth2/token";
+		this.authenticationManager = mock(AuthenticationManager.class);
+		this.authorizationService = mock(OAuth2AuthorizationService.class);
+		this.filter = new OAuth2TokenEndpointFilter(this.authenticationManager, this.authorizationService);
+	}
+
+	@After
+	public void cleanup() {
+		SecurityContextHolder.clearContext();
+	}
+
+	@Test
+	public void constructorWhenAuthenticationManagerNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2TokenEndpointFilter(null, this.authorizationService))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authenticationManager cannot be null");
 	}
 
 	@Test
-	public void constructorServiceAndManagerWhenNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> {
-			new OAuth2TokenEndpointFilter(null, null);
-		}).isInstanceOf(IllegalArgumentException.class);
+	public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2TokenEndpointFilter(this.authenticationManager, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationService cannot be null");
 	}
 
 	@Test
-	public void constructorServiceAndManagerAndEndpointWhenNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> {
-			new OAuth2TokenEndpointFilter(null, null, null);
-		}).isInstanceOf(IllegalArgumentException.class);
+	public void constructorWhenTokenEndpointUriNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2TokenEndpointFilter(this.authenticationManager, this.authorizationService, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("tokenEndpointUri cannot be empty");
 	}
 
 	@Test
-	public void doFilterWhenNotTokenRequestThenNextFilter() throws Exception {
-		this.requestUri = "/path";
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", this.requestUri);
-		request.setServletPath(this.requestUri);
+	public void doFilterWhenNotTokenRequestThenNotProcessed() throws Exception {
+		String requestUri = "/path";
+		MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
+		request.setServletPath(requestUri);
 		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
 
-		this.filter.doFilter(request, response, this.filterChain);
+		this.filter.doFilter(request, response, filterChain);
 
-		verify(this.filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
 	@Test
-	public void doFilterWhenAccessTokenRequestWithoutGrantTypeThenRespondWithBadRequest() throws Exception {
-		MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
-		request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode");
-		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri");
-		request.setServletPath(this.requestUri);
+	public void doFilterWhenTokenRequestGetThenNotProcessed() throws Exception {
+		String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
 		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
 
-		this.filter.doFilter(request, response, this.filterChain);
+		this.filter.doFilter(request, response, filterChain);
 
-		verifyNoInteractions(this.filterChain);
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"invalid_request\"}");
+		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
 	@Test
-	public void doFilterWhenAccessTokenRequestWithoutCodeThenRespondWithBadRequest() throws Exception {
-		MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
-		request.addParameter(OAuth2ParameterNames.GRANT_TYPE, "testGrantType");
-		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri");
-		request.setServletPath(this.requestUri);
-		MockHttpServletResponse response = new MockHttpServletResponse();
+	public void doFilterWhenTokenRequestMissingGrantTypeThenInvalidRequestError() throws Exception {
+		doFilterWhenTokenRequestInvalidParameterThenError(
+				OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.removeParameter(OAuth2ParameterNames.GRANT_TYPE));
+	}
 
-		this.filter.doFilter(request, response, this.filterChain);
+	@Test
+	public void doFilterWhenTokenRequestMultipleGrantTypeThenInvalidRequestError() throws Exception {
+		doFilterWhenTokenRequestInvalidParameterThenError(
+				OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()));
+	}
 
-		verifyNoInteractions(this.filterChain);
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"invalid_request\"}");
+	@Test
+	public void doFilterWhenTokenRequestInvalidGrantTypeThenUnsupportedGrantTypeError() throws Exception {
+		doFilterWhenTokenRequestInvalidParameterThenError(
+				OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE,
+				request -> request.setParameter(OAuth2ParameterNames.GRANT_TYPE, "invalid-grant-type"));
 	}
 
 	@Test
-	public void doFilterWhenAccessTokenRequestWithoutRedirectUriThenRespondWithBadRequest() throws Exception {
-		MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
-		request.addParameter(OAuth2ParameterNames.GRANT_TYPE, "testGrantType");
-		request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode");
-		request.setServletPath(this.requestUri);
-		MockHttpServletResponse response = new MockHttpServletResponse();
+	public void doFilterWhenTokenRequestMultipleClientIdThenInvalidRequestError() throws Exception {
+		doFilterWhenTokenRequestInvalidParameterThenError(
+				OAuth2ParameterNames.CLIENT_ID, OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> {
+					request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1");
+					request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2");
+				});
+	}
 
-		this.filter.doFilter(request, response, this.filterChain);
+	@Test
+	public void doFilterWhenTokenRequestMissingCodeThenInvalidRequestError() throws Exception {
+		doFilterWhenTokenRequestInvalidParameterThenError(
+				OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.removeParameter(OAuth2ParameterNames.CODE));
+	}
 
-		verifyNoInteractions(this.filterChain);
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"invalid_request\"}");
+	@Test
+	public void doFilterWhenTokenRequestMultipleCodeThenInvalidRequestError() throws Exception {
+		doFilterWhenTokenRequestInvalidParameterThenError(
+				OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.addParameter(OAuth2ParameterNames.CODE, "code-2"));
 	}
 
 	@Test
-	public void doFilterWhenAccessTokenRequestWithoutAuthCodeGrantTypeThenRespondWithBadRequest() throws Exception {
-		MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
-		request.addParameter(OAuth2ParameterNames.GRANT_TYPE, "testGrantType");
-		request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode");
-		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri");
-		request.setServletPath(this.requestUri);
+	public void doFilterWhenTokenRequestMultipleRedirectUriThenInvalidRequestError() throws Exception {
+		doFilterWhenTokenRequestInvalidParameterThenError(
+				OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"));
+	}
+
+	@Test
+	public void doFilterWhenTokenRequestValidThenAccessTokenResponse() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(
+				OAuth2AccessToken.TokenType.BEARER, "token",
+				Instant.now(), Instant.now().plus(Duration.ofHours(1)),
+				new HashSet<>(Arrays.asList("scope1", "scope2")));
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				new OAuth2AccessTokenAuthenticationToken(
+						registeredClient, clientPrincipal, accessToken);
+
+		when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication);
+
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		securityContext.setAuthentication(clientPrincipal);
+		SecurityContextHolder.setContext(securityContext);
+
+		MockHttpServletRequest request = createTokenRequest(registeredClient);
 		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
 
-		this.filter.doFilter(request, response, this.filterChain);
+		this.filter.doFilter(request, response, filterChain);
 
-		verifyNoInteractions(this.filterChain);
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		assertThat(response.getContentAsString()).isEqualTo("{\"errorCode\":\"unsupported_grant_type\"}");
+		verifyNoInteractions(filterChain);
+
+		ArgumentCaptor<OAuth2AuthorizationCodeAuthenticationToken> authorizationCodeAuthenticationCaptor =
+				ArgumentCaptor.forClass(OAuth2AuthorizationCodeAuthenticationToken.class);
+		verify(this.authenticationManager).authenticate(authorizationCodeAuthenticationCaptor.capture());
+
+		OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
+				authorizationCodeAuthenticationCaptor.getValue();
+		assertThat(authorizationCodeAuthentication.getCode()).isEqualTo(
+				request.getParameter(OAuth2ParameterNames.CODE));
+		assertThat(authorizationCodeAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
+		assertThat(authorizationCodeAuthentication.getRedirectUri()).isEqualTo(
+				request.getParameter(OAuth2ParameterNames.REDIRECT_URI));
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
+		OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
+
+		OAuth2AccessToken accessTokenResult = accessTokenResponse.getAccessToken();
+		assertThat(accessTokenResult.getTokenType()).isEqualTo(accessToken.getTokenType());
+		assertThat(accessTokenResult.getTokenValue()).isEqualTo(accessToken.getTokenValue());
+		assertThat(accessTokenResult.getIssuedAt()).isBetween(
+				accessToken.getIssuedAt().minusSeconds(1), accessToken.getIssuedAt().plusSeconds(1));
+		assertThat(accessTokenResult.getExpiresAt()).isBetween(
+				accessToken.getExpiresAt().minusSeconds(1), accessToken.getExpiresAt().plusSeconds(1));
+		assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes());
 	}
 
-	@Test
-	public void doFilterWhenAccessTokenRequestIsNotAuthenticatedThenRespondWithUnauthorized() throws Exception {
-		MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
-		request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
-		request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode");
-		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri");
-		request.setServletPath(this.requestUri);
+	private void doFilterWhenTokenRequestInvalidParameterThenError(String parameterName, String errorCode,
+			Consumer<MockHttpServletRequest> requestConsumer) throws Exception {
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+
+		MockHttpServletRequest request = createTokenRequest(registeredClient);
+		requestConsumer.accept(request);
 		MockHttpServletResponse response = new MockHttpServletResponse();
-		Authentication clientPrincipal = mock(Authentication.class);
-		RegisteredClient registeredClient = mock(RegisteredClient.class);
+		FilterChain filterChain = mock(FilterChain.class);
 
-		OAuth2AccessToken accessToken = new OAuth2AccessToken(
-				OAuth2AccessToken.TokenType.BEARER,  "testToken", Instant.now().minusSeconds(60), Instant.now());
-		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
-				.principalName(PRINCIPAL_NAME)
-				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
-				.build();
-		OAuth2AccessTokenAuthenticationToken accessTokenAuthenticationToken =
-				new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken);
-		accessTokenAuthenticationToken.setAuthenticated(false);
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
 
-		when(this.authorizationService.findByTokenAndTokenType(anyString(), any(TokenType.class))).thenReturn(authorization);
-		when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(accessTokenAuthenticationToken);
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		OAuth2Error error = readError(response);
+		assertThat(error.getErrorCode()).isEqualTo(errorCode);
+		assertThat(error.getDescription()).isEqualTo("OAuth 2.0 Parameter: " + parameterName);
+	}
 
-		this.filter.doFilter(request, response, this.filterChain);
+	private OAuth2Error readError(MockHttpServletResponse response) throws Exception {
+		MockClientHttpResponse httpResponse = new MockClientHttpResponse(
+				response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));
+		return this.errorHttpResponseConverter.read(OAuth2Error.class, httpResponse);
+	}
 
-		verifyNoInteractions(this.filterChain);
-		verify(this.authorizationService, times(0)).save(authorization);
-		verify(this.authenticationManager, times(1)).authenticate(any(Authentication.class));
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
-		assertThat(response.getContentAsString())
-				.isEqualTo("{\"errorCode\":\"invalid_client\"}");
+	private OAuth2AccessTokenResponse readAccessTokenResponse(MockHttpServletResponse response) throws Exception {
+		MockClientHttpResponse httpResponse = new MockClientHttpResponse(
+				response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));
+		return this.accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
 	}
 
-	@Test
-	public void doFilterWhenValidAccessTokenRequestThenRespondWithAccessToken() throws Exception {
-		MockHttpServletRequest request = new MockHttpServletRequest("POST", this.requestUri);
+	private static MockHttpServletRequest createTokenRequest(RegisteredClient registeredClient) {
+		String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]);
+
+		String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI;
+		MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
+		request.setServletPath(requestUri);
+
 		request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
-		request.addParameter(OAuth2ParameterNames.CODE, "testAuthCode");
-		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "testRedirectUri");
-		request.setServletPath(this.requestUri);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		Authentication clientPrincipal = mock(Authentication.class);
-		RegisteredClient registeredClient = mock(RegisteredClient.class);
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]);
 
-		OAuth2AccessToken accessToken = new OAuth2AccessToken(
-				OAuth2AccessToken.TokenType.BEARER,  "testToken", Instant.now().minusSeconds(60), Instant.now());
-		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
-				.principalName(PRINCIPAL_NAME)
-				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
-				.build();
-		OAuth2AccessTokenAuthenticationToken accessTokenAuthenticationToken =
-				new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken);
-		accessTokenAuthenticationToken.setAuthenticated(true);
-
-		when(this.authorizationService.findByTokenAndTokenType(anyString(), any(TokenType.class))).thenReturn(authorization);
-		when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(accessTokenAuthenticationToken);
-
-		this.filter.doFilter(request, response, this.filterChain);
-
-		verifyNoInteractions(this.filterChain);
-		verify(this.authorizationService, times(1)).save(authorization);
-		verify(this.authenticationManager, times(1)).authenticate(any(Authentication.class));
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
-		assertThat(response.getContentAsString()).contains("\"tokenValue\":\"testToken\"");
-		assertThat(response.getContentAsString()).contains("\"tokenType\":{\"value\":\"Bearer\"}");
-		assertThat(response.getHeader(HttpHeaders.CACHE_CONTROL)).isEqualTo("no-store");
-		assertThat(response.getHeader(HttpHeaders.PRAGMA)).isEqualTo("no-cache");
+		return request;
 	}
 }