瀏覽代碼

Add OAuth2Authorization.authorizationGrantType

Issue gh-213
Joe Grandja 4 年之前
父節點
當前提交
7261b40cd5

+ 30 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java

@@ -24,6 +24,7 @@ import java.util.function.Consumer;
 
 import org.springframework.lang.Nullable;
 import org.springframework.security.oauth2.core.AbstractOAuth2Token;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
@@ -39,6 +40,7 @@ import org.springframework.util.Assert;
  * @author Krisztian Toth
  * @since 0.0.1
  * @see RegisteredClient
+ * @see AuthorizationGrantType
  * @see AbstractOAuth2Token
  * @see OAuth2AccessToken
  * @see OAuth2RefreshToken
@@ -47,6 +49,7 @@ public class OAuth2Authorization implements Serializable {
 	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
 	private String registeredClientId;
 	private String principalName;
+	private AuthorizationGrantType authorizationGrantType;
 	private Map<Class<? extends AbstractOAuth2Token>, Token<?>> tokens;
 	private Map<String, Object> attributes;
 
@@ -71,6 +74,15 @@ public class OAuth2Authorization implements Serializable {
 		return this.principalName;
 	}
 
+	/**
+	 * Returns the {@link AuthorizationGrantType authorization grant type} used for the authorization.
+	 *
+	 * @return the {@link AuthorizationGrantType} used for the authorization
+	 */
+	public AuthorizationGrantType getAuthorizationGrantType() {
+		return this.authorizationGrantType;
+	}
+
 	/**
 	 * Returns the {@link Token} of type {@link OAuth2AccessToken}.
 	 *
@@ -157,13 +169,15 @@ public class OAuth2Authorization implements Serializable {
 		OAuth2Authorization that = (OAuth2Authorization) obj;
 		return Objects.equals(this.registeredClientId, that.registeredClientId) &&
 				Objects.equals(this.principalName, that.principalName) &&
+				Objects.equals(this.authorizationGrantType, that.authorizationGrantType) &&
 				Objects.equals(this.tokens, that.tokens) &&
 				Objects.equals(this.attributes, that.attributes);
 	}
 
 	@Override
 	public int hashCode() {
-		return Objects.hash(this.registeredClientId, this.principalName, this.tokens, this.attributes);
+		return Objects.hash(this.registeredClientId, this.principalName,
+				this.authorizationGrantType, this.tokens, this.attributes);
 	}
 
 	/**
@@ -187,6 +201,7 @@ public class OAuth2Authorization implements Serializable {
 		Assert.notNull(authorization, "authorization cannot be null");
 		return new Builder(authorization.getRegisteredClientId())
 				.principalName(authorization.getPrincipalName())
+				.authorizationGrantType(authorization.getAuthorizationGrantType())
 				.tokens(authorization.tokens)
 				.attributes(attrs -> attrs.putAll(authorization.getAttributes()));
 	}
@@ -292,6 +307,7 @@ public class OAuth2Authorization implements Serializable {
 		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
 		private final String registeredClientId;
 		private String principalName;
+		private AuthorizationGrantType authorizationGrantType;
 		private Map<Class<? extends AbstractOAuth2Token>, Token<?>> tokens = new HashMap<>();
 		private final Map<String, Object> attributes = new HashMap<>();
 
@@ -310,6 +326,17 @@ public class OAuth2Authorization implements Serializable {
 			return this;
 		}
 
+		/**
+		 * Sets the {@link AuthorizationGrantType authorization grant type} used for the authorization.
+		 *
+		 * @param authorizationGrantType the {@link AuthorizationGrantType}
+		 * @return the {@link Builder}
+		 */
+		public Builder authorizationGrantType(AuthorizationGrantType authorizationGrantType) {
+			this.authorizationGrantType = authorizationGrantType;
+			return this;
+		}
+
 		/**
 		 * Sets the {@link OAuth2AccessToken access token}.
 		 *
@@ -401,10 +428,12 @@ public class OAuth2Authorization implements Serializable {
 		 */
 		public OAuth2Authorization build() {
 			Assert.hasText(this.principalName, "principalName cannot be empty");
+			Assert.notNull(this.authorizationGrantType, "authorizationGrantType cannot be null");
 
 			OAuth2Authorization authorization = new OAuth2Authorization();
 			authorization.registeredClientId = this.registeredClientId;
 			authorization.principalName = this.principalName;
+			authorization.authorizationGrantType = this.authorizationGrantType;
 			authorization.tokens = Collections.unmodifiableMap(this.tokens);
 			authorization.attributes = Collections.unmodifiableMap(this.attributes);
 			return authorization;

+ 1 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java

@@ -124,6 +124,7 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica
 
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName(clientPrincipal.getName())
+				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
 				.token(accessToken)
 				.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
 				.build();

+ 1 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -193,6 +193,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 		OAuth2AuthorizationRequest authorizationRequest = authorizationRequestContext.buildAuthorizationRequest();
 		OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName(principal.getName())
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
 				.attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL, principal)
 				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest);
 

+ 8 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java

@@ -21,6 +21,7 @@ import java.time.temporal.ChronoUnit;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -39,6 +40,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
 public class InMemoryOAuth2AuthorizationServiceTests {
 	private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
 	private static final String PRINCIPAL_NAME = "principal";
+	private static final AuthorizationGrantType AUTHORIZATION_GRANT_TYPE = AuthorizationGrantType.AUTHORIZATION_CODE;
 	private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode(
 			"code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES));
 	private InMemoryOAuth2AuthorizationService authorizationService;
@@ -59,6 +61,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	public void saveWhenAuthorizationProvidedThenSaved() {
 		OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
 				.token(AUTHORIZATION_CODE)
 				.build();
 		this.authorizationService.save(expectedAuthorization);
@@ -79,6 +82,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	public void removeWhenAuthorizationProvidedThenRemoved() {
 		OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
 				.token(AUTHORIZATION_CODE)
 				.build();
 
@@ -105,6 +109,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 		String state = "state";
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
 				.attribute(OAuth2AuthorizationAttributeNames.STATE, state)
 				.build();
 		this.authorizationService.save(authorization);
@@ -120,6 +125,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	public void findByTokenWhenAuthorizationCodeExistsThenFound() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
 				.token(AUTHORIZATION_CODE)
 				.build();
 		this.authorizationService.save(authorization);
@@ -137,6 +143,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 				"access-token", Instant.now().minusSeconds(60), Instant.now());
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
 				.token(AUTHORIZATION_CODE)
 				.accessToken(accessToken)
 				.build();
@@ -154,6 +161,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now());
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
 				.refreshToken(refreshToken)
 				.build();
 		this.authorizationService.save(authorization);

+ 13 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java

@@ -20,6 +20,7 @@ import java.time.temporal.ChronoUnit;
 
 import org.junit.Test;
 
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -38,6 +39,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
 public class OAuth2AuthorizationTests {
 	private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
 	private static final String PRINCIPAL_NAME = "principal";
+	private static final AuthorizationGrantType AUTHORIZATION_GRANT_TYPE = AuthorizationGrantType.AUTHORIZATION_CODE;
 	private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken(
 			OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
 	private static final OAuth2RefreshToken REFRESH_TOKEN = new OAuth2RefreshToken("refresh-token", Instant.now());
@@ -62,6 +64,7 @@ public class OAuth2AuthorizationTests {
 	public void fromWhenAuthorizationProvidedThenCopied() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
 				.token(AUTHORIZATION_CODE)
 				.accessToken(ACCESS_TOKEN)
 				.build();
@@ -69,6 +72,7 @@ public class OAuth2AuthorizationTests {
 
 		assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId());
 		assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName());
+		assertThat(authorizationResult.getAuthorizationGrantType()).isEqualTo(authorization.getAuthorizationGrantType());
 		assertThat(authorizationResult.getAccessToken()).isEqualTo(authorization.getAccessToken());
 		assertThat(authorizationResult.getToken(OAuth2AuthorizationCode.class))
 				.isEqualTo(authorization.getToken(OAuth2AuthorizationCode.class));
@@ -82,6 +86,13 @@ public class OAuth2AuthorizationTests {
 				.hasMessage("principalName cannot be empty");
 	}
 
+	@Test
+	public void buildWhenAuthorizationGrantTypeNotProvidedThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT).principalName(PRINCIPAL_NAME).build())
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationGrantType cannot be null");
+	}
+
 	@Test
 	public void attributeWhenNameNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() ->
@@ -104,6 +115,7 @@ public class OAuth2AuthorizationTests {
 	public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
 				.token(AUTHORIZATION_CODE)
 				.accessToken(ACCESS_TOKEN)
 				.refreshToken(REFRESH_TOKEN)
@@ -111,6 +123,7 @@ public class OAuth2AuthorizationTests {
 
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
+		assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AUTHORIZATION_GRANT_TYPE);
 		assertThat(authorization.getToken(OAuth2AuthorizationCode.class).getToken()).isEqualTo(AUTHORIZATION_CODE);
 		assertThat(authorization.getAccessToken().getToken()).isEqualTo(ACCESS_TOKEN);
 		assertThat(authorization.getRefreshToken().getToken()).isEqualTo(REFRESH_TOKEN);

+ 2 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java

@@ -21,6 +21,7 @@ import java.util.Collections;
 import java.util.Map;
 
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
@@ -61,6 +62,7 @@ public class TestOAuth2Authorizations {
 				.build();
 		return OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName("principal")
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
 				.token(authorizationCode)
 				.accessToken(accessToken)
 				.refreshToken(refreshToken)

+ 1 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java

@@ -204,6 +204,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(clientPrincipal.getRegisteredClient().getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(clientPrincipal.getName());
+		assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS);
 		assertThat(authorization.getAccessToken()).isNotNull();
 		assertThat(authorization.getAccessToken().getToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes());
 		assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);

+ 4 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -467,6 +467,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = authorizationCaptor.getValue();
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
+		assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
 		assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
 				.isEqualTo(this.authentication);
 
@@ -516,6 +517,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = authorizationCaptor.getValue();
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
+		assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
 		assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
 				.isEqualTo(this.authentication);
 
@@ -563,6 +565,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = authorizationCaptor.getValue();
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
+		assertThat(authorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
 		assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
 				.isEqualTo(this.authentication);
 
@@ -795,6 +798,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
 		assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
+		assertThat(updatedAuthorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
 		assertThat(updatedAuthorization.getToken(OAuth2AuthorizationCode.class)).isNotNull();
 		assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull();
 		assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST))