Browse Source

Add authorization_code AuthenticationProvider

Fixes gh-68
Joe Grandja 5 years ago
parent
commit
fe286b5994

+ 14 - 0
core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java

@@ -127,6 +127,20 @@ public class OAuth2Authorization implements Serializable {
 		return new Builder(registeredClient.getId());
 	}
 
+	/**
+	 * Returns a new {@link Builder}, initialized with the values from the provided {@code authorization}.
+	 *
+	 * @param authorization the authorization used for initializing the {@link Builder}
+	 * @return the {@link Builder}
+	 */
+	public static Builder from(OAuth2Authorization authorization) {
+		Assert.notNull(authorization, "authorization cannot be null");
+		return new Builder(authorization.getRegisteredClientId())
+				.principalName(authorization.getPrincipalName())
+				.accessToken(authorization.getAccessToken())
+				.attributes(attrs -> attrs.putAll(authorization.getAttributes()));
+	}
+
 	/**
 	 * A builder for {@link OAuth2Authorization}.
 	 */

+ 35 - 7
core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java

@@ -20,35 +20,63 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.server.authorization.Version;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.util.Assert;
 
 import java.util.Collections;
 
 /**
+ * An {@link Authentication} implementation used when issuing an OAuth 2.0 Access Token.
+ *
  * @author Joe Grandja
  * @author Madhu Bhat
+ * @since 0.0.1
+ * @see AbstractAuthenticationToken
+ * @see OAuth2AuthorizationCodeAuthenticationProvider
+ * @see RegisteredClient
+ * @see OAuth2AccessToken
+ * @see OAuth2ClientAuthenticationToken
  */
 public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
 	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
-	private RegisteredClient registeredClient;
-	private Authentication clientPrincipal;
-	private OAuth2AccessToken accessToken;
+	private final RegisteredClient registeredClient;
+	private final Authentication clientPrincipal;
+	private final OAuth2AccessToken accessToken;
 
+	/**
+	 * Constructs an {@code OAuth2AccessTokenAuthenticationToken} using the provided parameters.
+	 *
+	 * @param registeredClient the registered client
+	 * @param clientPrincipal the authenticated client principal
+	 * @param accessToken the access token
+	 */
 	public OAuth2AccessTokenAuthenticationToken(RegisteredClient registeredClient,
 			Authentication clientPrincipal, OAuth2AccessToken accessToken) {
 		super(Collections.emptyList());
+		Assert.notNull(registeredClient, "registeredClient cannot be null");
+		Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
+		Assert.notNull(accessToken, "accessToken cannot be null");
 		this.registeredClient = registeredClient;
 		this.clientPrincipal = clientPrincipal;
 		this.accessToken = accessToken;
 	}
 
 	@Override
-	public Object getCredentials() {
-		return null;
+	public Object getPrincipal() {
+		return this.clientPrincipal;
 	}
 
 	@Override
-	public Object getPrincipal() {
-		return null;
+	public Object getCredentials() {
+		return "";
+	}
+
+	/**
+	 * Returns the {@link RegisteredClient registered client}.
+	 *
+	 * @return the {@link RegisteredClient}
+	 */
+	public RegisteredClient getRegisteredClient() {
+		return this.registeredClient;
 	}
 
 	/**

+ 89 - 4
core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -18,21 +18,106 @@ package org.springframework.security.oauth2.server.authorization.authentication;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
 import org.springframework.security.crypto.keygen.StringKeyGenerator;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Base64;
 
 /**
+ * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Authorization Code Grant.
+ *
  * @author Joe Grandja
+ * @since 0.0.1
+ * @see OAuth2AuthorizationCodeAuthenticationToken
+ * @see OAuth2AccessTokenAuthenticationToken
+ * @see RegisteredClientRepository
+ * @see OAuth2AuthorizationService
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a>
  */
 public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
-	private RegisteredClientRepository registeredClientRepository;
-	private OAuth2AuthorizationService authorizationService;
-	private StringKeyGenerator accessTokenGenerator;
+	private final RegisteredClientRepository registeredClientRepository;
+	private final OAuth2AuthorizationService authorizationService;
+	private final StringKeyGenerator accessTokenGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
+
+	/**
+	 * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters.
+	 *
+	 * @param registeredClientRepository the repository of registered clients
+	 * @param authorizationService the authorization service
+	 */
+	public OAuth2AuthorizationCodeAuthenticationProvider(RegisteredClientRepository registeredClientRepository,
+			OAuth2AuthorizationService authorizationService) {
+		Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
+		Assert.notNull(authorizationService, "authorizationService cannot be null");
+		this.registeredClientRepository = registeredClientRepository;
+		this.authorizationService = authorizationService;
+	}
 
 	@Override
 	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
-		return authentication;
+		OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
+				(OAuth2AuthorizationCodeAuthenticationToken) authentication;
+
+		OAuth2ClientAuthenticationToken clientPrincipal = null;
+		if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) {
+			clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal();
+		}
+		if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) {
+			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
+		}
+
+		// TODO Authenticate public client
+		// A client MAY use the "client_id" request parameter to identify itself
+		// when sending requests to the token endpoint.
+		// In the "authorization_code" "grant_type" request to the token endpoint,
+		// an unauthenticated client MUST send its "client_id" to prevent itself
+		// from inadvertently accepting a code intended for a client with a different "client_id".
+		// This protects the client from substitution of the authentication code.
+
+		OAuth2Authorization authorization = this.authorizationService.findByTokenAndTokenType(
+				authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE);
+		if (authorization == null) {
+			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
+		}
+		if (!clientPrincipal.getRegisteredClient().getId().equals(authorization.getRegisteredClientId())) {
+			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
+		}
+
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		if (StringUtils.hasText(authorizationRequest.getRedirectUri()) &&
+				!authorizationRequest.getRedirectUri().equals(authorizationCodeAuthentication.getRedirectUri())) {
+			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
+		}
+
+		String tokenValue = this.accessTokenGenerator.generateKey();
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);		// TODO Allow configuration for access token lifespan
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				tokenValue, issuedAt, expiresAt, authorizationRequest.getScopes());
+
+		authorization = OAuth2Authorization.from(authorization)
+				.accessToken(accessToken)
+				.build();
+		this.authorizationService.save(authorization);
+
+		return new OAuth2AccessTokenAuthenticationToken(
+				clientPrincipal.getRegisteredClient(), clientPrincipal, accessToken);
 	}
 
 	@Override

+ 30 - 5
core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java

@@ -19,12 +19,19 @@ import org.springframework.lang.Nullable;
 import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.server.authorization.Version;
+import org.springframework.util.Assert;
 
 import java.util.Collections;
 
 /**
+ * An {@link Authentication} implementation used for the OAuth 2.0 Authorization Code Grant.
+ *
  * @author Joe Grandja
  * @author Madhu Bhat
+ * @since 0.0.1
+ * @see AbstractAuthenticationToken
+ * @see OAuth2AuthorizationCodeAuthenticationProvider
+ * @see OAuth2ClientAuthenticationToken
  */
 public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
 	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
@@ -33,17 +40,35 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
 	private String clientId;
 	private String redirectUri;
 
+	/**
+	 * Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
+	 *
+	 * @param code the authorization code
+	 * @param clientPrincipal the authenticated client principal
+	 * @param redirectUri the redirect uri
+	 */
 	public OAuth2AuthorizationCodeAuthenticationToken(String code,
 			Authentication clientPrincipal, @Nullable String redirectUri) {
 		super(Collections.emptyList());
+		Assert.hasText(code, "code cannot be empty");
+		Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
 		this.code = code;
 		this.clientPrincipal = clientPrincipal;
 		this.redirectUri = redirectUri;
 	}
 
+	/**
+	 * Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
+	 *
+	 * @param code the authorization code
+	 * @param clientId the client identifier
+	 * @param redirectUri the redirect uri
+	 */
 	public OAuth2AuthorizationCodeAuthenticationToken(String code,
 			String clientId, @Nullable String redirectUri) {
 		super(Collections.emptyList());
+		Assert.hasText(code, "code cannot be empty");
+		Assert.hasText(clientId, "clientId cannot be empty");
 		this.code = code;
 		this.clientId = clientId;
 		this.redirectUri = redirectUri;
@@ -60,20 +85,20 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
 	}
 
 	/**
-	 * Returns the code.
+	 * Returns the authorization code.
 	 *
-	 * @return the code
+	 * @return the authorization code
 	 */
 	public String getCode() {
 		return this.code;
 	}
 
 	/**
-	 * Returns the redirectUri.
+	 * Returns the redirect uri.
 	 *
-	 * @return the redirectUri
+	 * @return the redirect uri
 	 */
-	public String getRedirectUri() {
+	public @Nullable String getRedirectUri() {
 		return this.redirectUri;
 	}
 }

+ 24 - 1
core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java

@@ -30,12 +30,13 @@ import static org.assertj.core.data.MapEntry.entry;
  * Tests for {@link OAuth2Authorization}.
  *
  * @author Krisztian Toth
+ * @author Joe Grandja
  */
 public class OAuth2AuthorizationTests {
 	private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
 	private static final String PRINCIPAL_NAME = "principal";
 	private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken(
-			OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now().minusSeconds(60), Instant.now());
+			OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
 	private static final String AUTHORIZATION_CODE = "code";
 
 	@Test
@@ -45,6 +46,28 @@ public class OAuth2AuthorizationTests {
 				.hasMessage("registeredClient cannot be null");
 	}
 
+	@Test
+	public void fromWhenAuthorizationNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2Authorization.from(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorization cannot be null");
+	}
+
+	@Test
+	public void fromWhenAuthorizationProvidedThenCopied() {
+		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.principalName(PRINCIPAL_NAME)
+				.accessToken(ACCESS_TOKEN)
+				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
+				.build();
+		OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build();
+
+		assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId());
+		assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName());
+		assertThat(authorizationResult.getAccessToken()).isEqualTo(authorization.getAccessToken());
+		assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes());
+	}
+
 	@Test
 	public void buildWhenPrincipalNameNotProvidedThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT).build())

+ 46 - 0
core/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java

@@ -0,0 +1,46 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization;
+
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import java.time.Instant;
+
+/**
+ * @author Joe Grandja
+ */
+public class TestOAuth2Authorizations {
+
+	public static OAuth2Authorization.Builder authorization() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(
+				OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://provider.com/oauth2/authorize")
+				.clientId(registeredClient.getClientId())
+				.redirectUri("https://client.com/authorized")
+				.state("state")
+				.build();
+		return OAuth2Authorization.withRegisteredClient(registeredClient)
+				.principalName("principal")
+				.accessToken(accessToken)
+				.attribute(OAuth2AuthorizationAttributeNames.CODE, "code")
+				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest);
+	}
+}

+ 70 - 0
core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationTokenTests.java

@@ -0,0 +1,70 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import org.junit.Test;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import java.time.Instant;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link OAuth2AccessTokenAuthenticationToken}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2AccessTokenAuthenticationTokenTests {
+	private RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+	private OAuth2ClientAuthenticationToken clientPrincipal =
+			new OAuth2ClientAuthenticationToken(this.registeredClient);
+	private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+			"access-token", Instant.now(), Instant.now().plusSeconds(300));
+
+	@Test
+	public void constructorWhenRegisteredClientNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(null, this.clientPrincipal, this.accessToken))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("registeredClient cannot be null");
+	}
+
+	@Test
+	public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(this.registeredClient, null, this.accessToken))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("clientPrincipal cannot be null");
+	}
+
+	@Test
+	public void constructorWhenAccessTokenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AccessTokenAuthenticationToken(this.registeredClient, this.clientPrincipal, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("accessToken cannot be null");
+	}
+
+	@Test
+	public void constructorWhenAllValuesProvidedThenCreated() {
+		OAuth2AccessTokenAuthenticationToken authentication = new OAuth2AccessTokenAuthenticationToken(
+				this.registeredClient, this.clientPrincipal, this.accessToken);
+		assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
+		assertThat(authentication.getCredentials().toString()).isEmpty();
+		assertThat(authentication.getRegisteredClient()).isEqualTo(this.registeredClient);
+		assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken);
+	}
+}

+ 178 - 0
core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -0,0 +1,178 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
+import org.springframework.security.oauth2.server.authorization.TokenType;
+import org.springframework.security.oauth2.server.authorization.client.InMemoryRegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link OAuth2AuthorizationCodeAuthenticationProvider}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2AuthorizationCodeAuthenticationProviderTests {
+	private RegisteredClient registeredClient;
+	private RegisteredClientRepository registeredClientRepository;
+	private OAuth2AuthorizationService authorizationService;
+	private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider;
+
+	@Before
+	public void setUp() {
+		this.registeredClient = TestRegisteredClients.registeredClient().build();
+		this.registeredClientRepository = new InMemoryRegisteredClientRepository(this.registeredClient);
+		this.authorizationService = mock(OAuth2AuthorizationService.class);
+		this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
+				this.registeredClientRepository, this.authorizationService);
+	}
+
+	@Test
+	public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(null, this.authorizationService))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("registeredClientRepository cannot be null");
+	}
+
+	@Test
+	public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.registeredClientRepository, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationService cannot be null");
+	}
+
+	@Test
+	public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() {
+		assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue();
+	}
+
+	@Test
+	public void authenticateWhenClientPrincipalNotOAuth2ClientAuthenticationTokenThenThrowOAuth2AuthenticationException() {
+		TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken(
+				this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+	}
+
+	@Test
+	public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() {
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+				this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+	}
+
+	@Test
+	public void authenticateWhenInvalidCodeThenThrowOAuth2AuthenticationException() {
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+	}
+
+	@Test
+	public void authenticateWhenCodeIssuedToAnotherClientThenThrowOAuth2AuthenticationException() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
+		when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+				TestRegisteredClients.registeredClient2().build());
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+	}
+
+	@Test
+	public void authenticateWhenInvalidRedirectUriThenThrowOAuth2AuthenticationException() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
+		when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid");
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+	}
+
+	@Test
+	public void authenticateWhenValidCodeThenReturnAccessToken() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
+		when(this.authorizationService.findByTokenAndTokenType(eq("code"), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri());
+
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+
+		assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
+		assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
+		assertThat(updatedAuthorization.getAccessToken()).isNotNull();
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
+	}
+}

+ 77 - 0
core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import org.junit.Test;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link OAuth2AuthorizationCodeAuthenticationToken}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2AuthorizationCodeAuthenticationTokenTests {
+	private String code = "code";
+	private OAuth2ClientAuthenticationToken clientPrincipal =
+			new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build());
+	private String clientId = "clientId";
+	private String redirectUri = "redirectUri";
+
+	@Test
+	public void constructorWhenCodeNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.clientPrincipal, this.redirectUri))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("code cannot be empty");
+	}
+
+	@Test
+	public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (Authentication) null, this.redirectUri))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("clientPrincipal cannot be null");
+	}
+
+	@Test
+	public void constructorWhenClientIdNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (String) null, this.redirectUri))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("clientId cannot be empty");
+	}
+
+	@Test
+	public void constructorWhenClientPrincipalProvidedThenCreated() {
+		OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
+				this.code, this.clientPrincipal, this.redirectUri);
+		assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
+		assertThat(authentication.getCredentials().toString()).isEmpty();
+		assertThat(authentication.getCode()).isEqualTo(this.code);
+		assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri);
+	}
+
+	@Test
+	public void constructorWhenClientIdProvidedThenCreated() {
+		OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
+				this.code, this.clientId, this.redirectUri);
+		assertThat(authentication.getPrincipal()).isEqualTo(this.clientId);
+		assertThat(authentication.getCredentials().toString()).isEmpty();
+		assertThat(authentication.getCode()).isEqualTo(this.code);
+		assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri);
+	}
+}