Joe Grandja vor 5 Jahren
Ursprung
Commit
3ab314c8aa

+ 15 - 23
core/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java

@@ -17,33 +17,33 @@ package org.springframework.security.oauth2.server.authorization;
 
 
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
 
 
-import java.util.Collections;
 import java.util.List;
 import java.util.List;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.CopyOnWriteArrayList;
 
 
 /**
 /**
- * In-memory implementation of {@link OAuth2AuthorizationService}.
+ * An {@link OAuth2AuthorizationService} that stores {@link OAuth2Authorization}'s in-memory.
  *
  *
  * @author Krisztian Toth
  * @author Krisztian Toth
+ * @since 0.0.1
+ * @see OAuth2AuthorizationService
  */
  */
 public final class InMemoryOAuth2AuthorizationService implements OAuth2AuthorizationService {
 public final class InMemoryOAuth2AuthorizationService implements OAuth2AuthorizationService {
 	private final List<OAuth2Authorization> authorizations;
 	private final List<OAuth2Authorization> authorizations;
 
 
 	/**
 	/**
-	 * Creates an {@link InMemoryOAuth2AuthorizationService}.
+	 * Constructs an {@code InMemoryOAuth2AuthorizationService}.
 	 */
 	 */
 	public InMemoryOAuth2AuthorizationService() {
 	public InMemoryOAuth2AuthorizationService() {
-		this(Collections.emptyList());
+		this.authorizations = new CopyOnWriteArrayList<>();
 	}
 	}
 
 
 	/**
 	/**
-	 * Creates an {@link InMemoryOAuth2AuthorizationService} with the provided {@link List}<{@link OAuth2Authorization}>
-	 * as the in-memory store.
+	 * Constructs an {@code InMemoryOAuth2AuthorizationService} using the provided parameters.
 	 *
 	 *
-	 * @param authorizations a {@link List}<{@link OAuth2Authorization}> object to use as the store
+	 * @param authorizations the initial {@code List} of {@link OAuth2Authorization}(s)
 	 */
 	 */
 	public InMemoryOAuth2AuthorizationService(List<OAuth2Authorization> authorizations) {
 	public InMemoryOAuth2AuthorizationService(List<OAuth2Authorization> authorizations) {
-		Assert.notNull(authorizations, "authorizations cannot be null");
+		Assert.notEmpty(authorizations, "authorizations cannot be empty");
 		this.authorizations = new CopyOnWriteArrayList<>(authorizations);
 		this.authorizations = new CopyOnWriteArrayList<>(authorizations);
 	}
 	}
 
 
@@ -58,26 +58,18 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 		Assert.hasText(token, "token cannot be empty");
 		Assert.hasText(token, "token cannot be empty");
 		Assert.notNull(tokenType, "tokenType cannot be null");
 		Assert.notNull(tokenType, "tokenType cannot be null");
 		return this.authorizations.stream()
 		return this.authorizations.stream()
-				.filter(authorization -> doesMatch(authorization, token, tokenType))
+				.filter(authorization -> hasToken(authorization, token, tokenType))
 				.findFirst()
 				.findFirst()
 				.orElse(null);
 				.orElse(null);
-
 	}
 	}
 
 
-	private boolean doesMatch(OAuth2Authorization authorization, String token, TokenType tokenType) {
-		if (tokenType.equals(TokenType.ACCESS_TOKEN)) {
-			return isAccessTokenEqual(token, authorization);
-		} else if (tokenType.equals(TokenType.AUTHORIZATION_CODE)) {
-			return isAuthorizationCodeEqual(token, authorization);
+	private boolean hasToken(OAuth2Authorization authorization, String token, TokenType tokenType) {
+		if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
+			return token.equals(authorization.getAttributes().get(TokenType.AUTHORIZATION_CODE.getValue()));
+		} else if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
+			return authorization.getAccessToken() != null &&
+					authorization.getAccessToken().getTokenValue().equals(token);
 		}
 		}
 		return false;
 		return false;
 	}
 	}
-
-	private boolean isAccessTokenEqual(String token, OAuth2Authorization authorization) {
-		return authorization.getAccessToken() != null && token.equals(authorization.getAccessToken().getTokenValue());
-	}
-
-	private boolean isAuthorizationCodeEqual(String token, OAuth2Authorization authorization) {
-		return token.equals(authorization.getAttributes().get(TokenType.AUTHORIZATION_CODE.getValue()));
-	}
 }
 }

+ 65 - 67
core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java

@@ -16,8 +16,10 @@
 package org.springframework.security.oauth2.server.authorization;
 package org.springframework.security.oauth2.server.authorization;
 
 
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
 
 
+import java.io.Serializable;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Map;
@@ -25,12 +27,18 @@ import java.util.Objects;
 import java.util.function.Consumer;
 import java.util.function.Consumer;
 
 
 /**
 /**
- * Represents a collection of attributes which describe an OAuth 2.0 authorization context.
+ * A representation of an OAuth 2.0 Authorization,
+ * which holds state related to the authorization granted to the {@link #getRegisteredClientId() client}
+ * by the {@link #getPrincipalName() resource owner}.
  *
  *
  * @author Joe Grandja
  * @author Joe Grandja
  * @author Krisztian Toth
  * @author Krisztian Toth
+ * @since 0.0.1
+ * @see RegisteredClient
+ * @see OAuth2AccessToken
  */
  */
-public class OAuth2Authorization {
+public class OAuth2Authorization implements Serializable {
+	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
 	private String registeredClientId;
 	private String registeredClientId;
 	private String principalName;
 	private String principalName;
 	private OAuth2AccessToken accessToken;
 	private OAuth2AccessToken accessToken;
@@ -39,43 +47,64 @@ public class OAuth2Authorization {
 	protected OAuth2Authorization() {
 	protected OAuth2Authorization() {
 	}
 	}
 
 
+	/**
+	 * Returns the identifier for the {@link RegisteredClient#getId() registered client}.
+	 *
+	 * @return the {@link RegisteredClient#getId()}
+	 */
 	public String getRegisteredClientId() {
 	public String getRegisteredClientId() {
 		return this.registeredClientId;
 		return this.registeredClientId;
 	}
 	}
 
 
+	/**
+	 * Returns the resource owner's {@code Principal} name.
+	 *
+	 * @return the resource owner's {@code Principal} name
+	 */
 	public String getPrincipalName() {
 	public String getPrincipalName() {
 		return this.principalName;
 		return this.principalName;
 	}
 	}
 
 
+	/**
+	 * Returns the {@link OAuth2AccessToken access token} credential.
+	 *
+	 * @return the {@link OAuth2AccessToken}
+	 */
 	public OAuth2AccessToken getAccessToken() {
 	public OAuth2AccessToken getAccessToken() {
 		return this.accessToken;
 		return this.accessToken;
 	}
 	}
 
 
+	/**
+	 * Returns the attribute(s) associated to the authorization.
+	 *
+	 * @return a {@code Map} of the attribute(s)
+	 */
 	public Map<String, Object> getAttributes() {
 	public Map<String, Object> getAttributes() {
 		return this.attributes;
 		return this.attributes;
 	}
 	}
 
 
 	/**
 	/**
-	 * Returns an attribute with the provided name or {@code null} if not found.
+	 * Returns the value of an attribute associated to the authorization.
 	 *
 	 *
 	 * @param name the name of the attribute
 	 * @param name the name of the attribute
-	 * @param <T>  the type of the attribute
-	 * @return the found attribute or {@code null}
+	 * @param <T> the type of the attribute
+	 * @return the value of the attribute associated to the authorization, or {@code null} if not available
 	 */
 	 */
+	@SuppressWarnings("unchecked")
 	public <T> T getAttribute(String name) {
 	public <T> T getAttribute(String name) {
 		Assert.hasText(name, "name cannot be empty");
 		Assert.hasText(name, "name cannot be empty");
 		return (T) this.attributes.get(name);
 		return (T) this.attributes.get(name);
 	}
 	}
 
 
 	@Override
 	@Override
-	public boolean equals(Object o) {
-		if (this == o) {
+	public boolean equals(Object obj) {
+		if (this == obj) {
 			return true;
 			return true;
 		}
 		}
-		if (o == null || getClass() != o.getClass()) {
+		if (obj == null || getClass() != obj.getClass()) {
 			return false;
 			return false;
 		}
 		}
-		OAuth2Authorization that = (OAuth2Authorization) o;
+		OAuth2Authorization that = (OAuth2Authorization) obj;
 		return Objects.equals(this.registeredClientId, that.registeredClientId) &&
 		return Objects.equals(this.registeredClientId, that.registeredClientId) &&
 				Objects.equals(this.principalName, that.principalName) &&
 				Objects.equals(this.principalName, that.principalName) &&
 				Objects.equals(this.accessToken, that.accessToken) &&
 				Objects.equals(this.accessToken, that.accessToken) &&
@@ -88,59 +117,34 @@ public class OAuth2Authorization {
 	}
 	}
 
 
 	/**
 	/**
-	 * Returns an empty {@link Builder}.
+	 * Returns a new {@link Builder}, initialized with the provided {@link RegisteredClient#getId()}.
 	 *
 	 *
+	 * @param registeredClient the {@link RegisteredClient}
 	 * @return the {@link Builder}
 	 * @return the {@link Builder}
 	 */
 	 */
-	public static Builder builder() {
-		return new Builder();
+	public static Builder withRegisteredClient(RegisteredClient registeredClient) {
+		Assert.notNull(registeredClient, "registeredClient cannot be null");
+		return new Builder(registeredClient.getId());
 	}
 	}
 
 
 	/**
 	/**
-	 * Returns a new {@link Builder}, initialized with the provided {@link OAuth2Authorization}.
-	 *
-	 * @param authorization the {@link OAuth2Authorization} to copy from
-	 * @return the {@link Builder}
+	 * A builder for {@link OAuth2Authorization}.
 	 */
 	 */
-	public static Builder withAuthorization(OAuth2Authorization authorization) {
-		Assert.notNull(authorization, "authorization cannot be null");
-		return new Builder(authorization);
-	}
-
-	/**
-	 * Builder class for {@link OAuth2Authorization}.
-	 */
-	public static class Builder {
+	public static class Builder implements Serializable {
+		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
 		private String registeredClientId;
 		private String registeredClientId;
 		private String principalName;
 		private String principalName;
 		private OAuth2AccessToken accessToken;
 		private OAuth2AccessToken accessToken;
 		private Map<String, Object> attributes = new HashMap<>();
 		private Map<String, Object> attributes = new HashMap<>();
 
 
-		protected Builder() {
-		}
-
-		protected Builder(OAuth2Authorization authorization) {
-			this.registeredClientId = authorization.registeredClientId;
-			this.principalName = authorization.principalName;
-			this.accessToken = authorization.accessToken;
-			this.attributes = authorization.attributes;
-		}
-
-		/**
-		 * Sets the registered client identifier.
-		 *
-		 * @param registeredClientId the client id
-		 * @return the {@link Builder}
-		 */
-		public Builder registeredClientId(String registeredClientId) {
+		protected Builder(String registeredClientId) {
 			this.registeredClientId = registeredClientId;
 			this.registeredClientId = registeredClientId;
-			return this;
 		}
 		}
 
 
 		/**
 		/**
-		 * Sets the principal name.
+		 * Sets the resource owner's {@code Principal} name.
 		 *
 		 *
-		 * @param principalName the principal name
+		 * @param principalName the resource owner's {@code Principal} name
 		 * @return the {@link Builder}
 		 * @return the {@link Builder}
 		 */
 		 */
 		public Builder principalName(String principalName) {
 		public Builder principalName(String principalName) {
@@ -149,7 +153,7 @@ public class OAuth2Authorization {
 		}
 		}
 
 
 		/**
 		/**
-		 * Sets the {@link OAuth2AccessToken}.
+		 * Sets the {@link OAuth2AccessToken access token} credential.
 		 *
 		 *
 		 * @param accessToken the {@link OAuth2AccessToken}
 		 * @param accessToken the {@link OAuth2AccessToken}
 		 * @return the {@link Builder}
 		 * @return the {@link Builder}
@@ -160,23 +164,24 @@ public class OAuth2Authorization {
 		}
 		}
 
 
 		/**
 		/**
-		 * Adds the attribute with the specified name and {@link String} value to the attributes map.
+		 * Adds an attribute associated to the authorization.
 		 *
 		 *
-		 * @param name  the name of the attribute
+		 * @param name the name of the attribute
 		 * @param value the value of the attribute
 		 * @param value the value of the attribute
 		 * @return the {@link Builder}
 		 * @return the {@link Builder}
 		 */
 		 */
-		public Builder attribute(String name, String value) {
+		public Builder attribute(String name, Object value) {
 			Assert.hasText(name, "name cannot be empty");
 			Assert.hasText(name, "name cannot be empty");
-			Assert.hasText(value, "value cannot be empty");
+			Assert.notNull(value, "value cannot be null");
 			this.attributes.put(name, value);
 			this.attributes.put(name, value);
 			return this;
 			return this;
 		}
 		}
 
 
 		/**
 		/**
-		 * A {@code Consumer} of the attributes map allowing to access or modify its content.
+		 * A {@code Consumer} of the attributes {@code Map}
+		 * allowing the ability to add, replace, or remove.
 		 *
 		 *
-		 * @param attributesConsumer a {@link Consumer} of the attributes map
+		 * @param attributesConsumer a {@link Consumer} of the attributes {@code Map}
 		 * @return the {@link Builder}
 		 * @return the {@link Builder}
 		 */
 		 */
 		public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) {
 		public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) {
@@ -190,22 +195,15 @@ public class OAuth2Authorization {
 		 * @return the {@link OAuth2Authorization}
 		 * @return the {@link OAuth2Authorization}
 		 */
 		 */
 		public OAuth2Authorization build() {
 		public OAuth2Authorization build() {
-			Assert.hasText(this.registeredClientId, "registeredClientId cannot be empty");
 			Assert.hasText(this.principalName, "principalName cannot be empty");
 			Assert.hasText(this.principalName, "principalName cannot be empty");
-			if (this.accessToken == null && this.attributes.get(TokenType.AUTHORIZATION_CODE.getValue()) == null) {
-				throw new IllegalArgumentException("either accessToken has to be set or the authorization code with key '"
-						+ TokenType.AUTHORIZATION_CODE.getValue() + "' must be provided in the attributes map");
-			}
-			return create();
-		}
-
-		private OAuth2Authorization create() {
-			OAuth2Authorization oAuth2Authorization = new OAuth2Authorization();
-			oAuth2Authorization.registeredClientId = this.registeredClientId;
-			oAuth2Authorization.principalName = this.principalName;
-			oAuth2Authorization.accessToken = this.accessToken;
-			oAuth2Authorization.attributes = Collections.unmodifiableMap(this.attributes);
-			return oAuth2Authorization;
+			Assert.notNull(this.attributes.get(TokenType.AUTHORIZATION_CODE.getValue()), "authorization code cannot be null");
+
+			OAuth2Authorization authorization = new OAuth2Authorization();
+			authorization.registeredClientId = this.registeredClientId;
+			authorization.principalName = this.principalName;
+			authorization.accessToken = this.accessToken;
+			authorization.attributes = Collections.unmodifiableMap(this.attributes);
+			return authorization;
 		}
 		}
 	}
 	}
 }
 }

+ 18 - 0
core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationService.java

@@ -16,12 +16,30 @@
 package org.springframework.security.oauth2.server.authorization;
 package org.springframework.security.oauth2.server.authorization;
 
 
 /**
 /**
+ * Implementations of this interface are responsible for the management
+ * of {@link OAuth2Authorization OAuth 2.0 Authorization(s)}.
+ *
  * @author Joe Grandja
  * @author Joe Grandja
+ * @since 0.0.1
+ * @see OAuth2Authorization
  */
  */
 public interface OAuth2AuthorizationService {
 public interface OAuth2AuthorizationService {
 
 
+	/**
+	 * Saves the {@link OAuth2Authorization}.
+	 *
+	 * @param authorization the {@link OAuth2Authorization}
+	 */
 	void save(OAuth2Authorization authorization);
 	void save(OAuth2Authorization authorization);
 
 
+	/**
+	 * Returns the {@link OAuth2Authorization} containing the provided {@code token},
+	 * or {@code null} if not found.
+	 *
+	 * @param token the token credential
+	 * @param tokenType the {@link TokenType token type}
+	 * @return the {@link OAuth2Authorization} if found, otherwise {@code null}
+	 */
 	OAuth2Authorization findByTokenAndTokenType(String token, TokenType tokenType);
 	OAuth2Authorization findByTokenAndTokenType(String token, TokenType tokenType);
 
 
 }
 }

+ 61 - 58
core/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java

@@ -15,105 +15,108 @@
  */
  */
 package org.springframework.security.oauth2.server.authorization;
 package org.springframework.security.oauth2.server.authorization;
 
 
+import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 
 
 import java.time.Instant;
 import java.time.Instant;
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Collections;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 
 /**
 /**
- * Unit tests for {@link InMemoryOAuth2AuthorizationService}.
+ * Tests for {@link InMemoryOAuth2AuthorizationService}.
  *
  *
  * @author Krisztian Toth
  * @author Krisztian Toth
  */
  */
 public class InMemoryOAuth2AuthorizationServiceTests {
 public class InMemoryOAuth2AuthorizationServiceTests {
+	private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
+	private static final String PRINCIPAL_NAME = "principal";
+	private static final String AUTHORIZATION_CODE = "code";
+	private InMemoryOAuth2AuthorizationService authorizationService;
 
 
-	private static final String TOKEN = "token";
-	private static final TokenType AUTHORIZATION_CODE = TokenType.AUTHORIZATION_CODE;
-	private static final TokenType ACCESS_TOKEN = TokenType.ACCESS_TOKEN;
-	private static final Instant ISSUED_AT = Instant.now().minusSeconds(60);
-	private static final Instant EXPIRES_AT = Instant.now();
+	@Before
+	public void setup() {
+		this.authorizationService = new InMemoryOAuth2AuthorizationService();
+	}
 
 
-	private InMemoryOAuth2AuthorizationService authorizationService;
+	@Test
+	public void constructorWhenAuthorizationListNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new InMemoryOAuth2AuthorizationService(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizations cannot be empty");
+	}
 
 
 	@Test
 	@Test
-	public void saveWhenAuthorizationProvidedThenSavedInList() {
-		authorizationService = new InMemoryOAuth2AuthorizationService(new ArrayList<>());
+	public void saveWhenAuthorizationNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizationService.save(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorization cannot be null");
+	}
 
 
-		OAuth2Authorization authorization = OAuth2Authorization.builder()
-				.registeredClientId("clientId")
-				.principalName("principalName")
-				.attribute(AUTHORIZATION_CODE.getValue(), TOKEN)
+	@Test
+	public void saveWhenAuthorizationProvidedThenSaved() {
+		OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.principalName(PRINCIPAL_NAME)
+				.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
 				.build();
 				.build();
-		authorizationService.save(authorization);
+		this.authorizationService.save(expectedAuthorization);
 
 
-		assertThat(authorizationService.findByTokenAndTokenType(TOKEN, AUTHORIZATION_CODE)).isEqualTo(authorization);
+		OAuth2Authorization authorization = this.authorizationService.findByTokenAndTokenType(
+				AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE);
+		assertThat(authorization).isEqualTo(expectedAuthorization);
 	}
 	}
 
 
 	@Test
 	@Test
-	public void saveWhenAuthorizationNotProvidedThenThrowIllegalArgumentException() {
-		authorizationService = new InMemoryOAuth2AuthorizationService(new ArrayList<>());
+	public void findByTokenAndTokenTypeWhenTokenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizationService.findByTokenAndTokenType(null, TokenType.AUTHORIZATION_CODE))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("token cannot be empty");
+	}
 
 
-		assertThatThrownBy(() -> authorizationService.save(null))
-				.isInstanceOf(IllegalArgumentException.class);
+	@Test
+	public void findByTokenAndTokenTypeWhenTokenTypeNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizationService.findByTokenAndTokenType(AUTHORIZATION_CODE, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("tokenType cannot be null");
 	}
 	}
 
 
 	@Test
 	@Test
-	public void findByTokenAndTokenTypeWhenTokenTypeIsAuthorizationCodeThenFound() {
-		OAuth2Authorization authorization = OAuth2Authorization.builder()
-				.registeredClientId("clientId")
-				.principalName("principalName")
-				.attribute(AUTHORIZATION_CODE.getValue(), TOKEN)
+	public void findByTokenAndTokenTypeWhenTokenTypeAuthorizationCodeThenFound() {
+		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.principalName(PRINCIPAL_NAME)
+				.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
 				.build();
 				.build();
-		authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization));
+		this.authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization));
 
 
-		OAuth2Authorization result = authorizationService.findByTokenAndTokenType(TOKEN, TokenType.AUTHORIZATION_CODE);
+		OAuth2Authorization result = this.authorizationService.findByTokenAndTokenType(
+				AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE);
 		assertThat(authorization).isEqualTo(result);
 		assertThat(authorization).isEqualTo(result);
 	}
 	}
 
 
 	@Test
 	@Test
-	public void findByTokenAndTokenTypeWhenTokenTypeIsAccessTokenThenFound() {
-		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, TOKEN, ISSUED_AT,
-				EXPIRES_AT);
-		OAuth2Authorization authorization = OAuth2Authorization.builder()
-				.registeredClientId("clientId")
-				.principalName("principalName")
+	public void findByTokenAndTokenTypeWhenTokenTypeAccessTokenThenFound() {
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				"access-token", Instant.now().minusSeconds(60), Instant.now());
+		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.principalName(PRINCIPAL_NAME)
+				.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
 				.accessToken(accessToken)
 				.accessToken(accessToken)
 				.build();
 				.build();
-		authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization));
+		this.authorizationService.save(authorization);
 
 
-		OAuth2Authorization result = authorizationService.findByTokenAndTokenType(TOKEN, ACCESS_TOKEN);
+		OAuth2Authorization result = this.authorizationService.findByTokenAndTokenType(
+				"access-token", TokenType.ACCESS_TOKEN);
 		assertThat(authorization).isEqualTo(result);
 		assertThat(authorization).isEqualTo(result);
 	}
 	}
 
 
 	@Test
 	@Test
-	public void findByTokenAndTokenTypeWhenTokenWithTokenTypeDoesNotExistThenNull() {
-		OAuth2Authorization authorization = OAuth2Authorization.builder()
-				.registeredClientId("clientId")
-				.principalName("principalName")
-				.attribute(AUTHORIZATION_CODE.getValue(), TOKEN)
-				.build();
-		authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization));
-
-		OAuth2Authorization result = authorizationService.findByTokenAndTokenType(TOKEN, ACCESS_TOKEN);
+	public void findByTokenAndTokenTypeWhenTokenDoesNotExistThenNull() {
+		OAuth2Authorization result = this.authorizationService.findByTokenAndTokenType(
+				"access-token", TokenType.ACCESS_TOKEN);
 		assertThat(result).isNull();
 		assertThat(result).isNull();
 	}
 	}
-
-	@Test
-	public void findByTokenAndTokenTypeWhenTokenNullThenThrowIllegalArgumentException() {
-		authorizationService = new InMemoryOAuth2AuthorizationService();
-		assertThatThrownBy(() -> authorizationService.findByTokenAndTokenType(null, TokenType.AUTHORIZATION_CODE))
-				.isInstanceOf(IllegalArgumentException.class);
-	}
-
-	@Test
-	public void findByTokenAndTokenTypeWhenTokenTypeNullThenThrowIllegalArgumentException() {
-		authorizationService = new InMemoryOAuth2AuthorizationService();
-		assertThatThrownBy(() -> authorizationService.findByTokenAndTokenType(TOKEN, null))
-				.isInstanceOf(IllegalArgumentException.class);
-	}
 }
 }

+ 38 - 78
core/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java

@@ -17,120 +17,80 @@ package org.springframework.security.oauth2.server.authorization;
 
 
 import org.junit.Test;
 import org.junit.Test;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 
 
 import java.time.Instant;
 import java.time.Instant;
-import java.util.Collections;
-import java.util.Map;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.data.MapEntry.entry;
 
 
 /**
 /**
- * Unit tests For {@link OAuth2Authorization}.
+ * Tests for {@link OAuth2Authorization}.
  *
  *
  * @author Krisztian Toth
  * @author Krisztian Toth
  */
  */
 public class OAuth2AuthorizationTests {
 public class OAuth2AuthorizationTests {
-
-	public static final String REGISTERED_CLIENT_ID = "clientId";
-	public static final String PRINCIPAL_NAME = "principal";
-	public static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
-			"token", Instant.now().minusSeconds(60), Instant.now());
-	public static final String AUTHORIZATION_CODE_VALUE = TokenType.AUTHORIZATION_CODE.getValue();
-	public static final String CODE = "code";
-	public static final Map<String, Object> ATTRIBUTES = Collections.singletonMap(AUTHORIZATION_CODE_VALUE, CODE);
+	private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
+	private static final String PRINCIPAL_NAME = "principal";
+	private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken(
+			OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now().minusSeconds(60), Instant.now());
+	private static final String AUTHORIZATION_CODE = "code";
 
 
 	@Test
 	@Test
-	public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() {
-		OAuth2Authorization authorization = OAuth2Authorization.builder()
-				.registeredClientId(REGISTERED_CLIENT_ID)
-				.principalName(PRINCIPAL_NAME)
-				.accessToken(ACCESS_TOKEN)
-				.attribute(AUTHORIZATION_CODE_VALUE, CODE)
-				.build();
-
-		assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT_ID);
-		assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
-		assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN);
-		assertThat(authorization.getAttributes()).isEqualTo(ATTRIBUTES);
+	public void withRegisteredClientWhenRegisteredClientNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2Authorization.withRegisteredClient(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("registeredClient cannot be null");
 	}
 	}
 
 
 	@Test
 	@Test
-	public void buildWhenBuildThenImmutableMapIsCreated() {
-		OAuth2Authorization authorization = OAuth2Authorization.builder()
-				.registeredClientId(REGISTERED_CLIENT_ID)
-				.principalName(PRINCIPAL_NAME)
-				.accessToken(ACCESS_TOKEN)
-				.attribute("any", "value")
-				.build();
-
-		assertThatThrownBy(() -> authorization.getAttributes().put("any", "value"))
-				.isInstanceOf(UnsupportedOperationException.class);
-	}
-
-	@Test
-	public void buildWhenAccessTokenAndAuthorizationCodeNotProvidedThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() ->
-				OAuth2Authorization.builder()
-						.registeredClientId(REGISTERED_CLIENT_ID)
-						.principalName(PRINCIPAL_NAME)
-						.build()
-		).isInstanceOf(IllegalArgumentException.class);
-	}
-
-	@Test
-	public void buildWhenRegisteredClientIdNotProvidedThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() ->
-				OAuth2Authorization.builder()
-						.principalName(PRINCIPAL_NAME)
-						.accessToken(ACCESS_TOKEN)
-						.attribute(AUTHORIZATION_CODE_VALUE, CODE)
-						.build()
-		).isInstanceOf(IllegalArgumentException.class);
+	public void buildWhenPrincipalNameNotProvidedThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT).build())
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("principalName cannot be empty");
 	}
 	}
 
 
 	@Test
 	@Test
-	public void buildWhenPrincipalNameNotProvidedThenThrowIllegalArgumentException() {
+	public void buildWhenAuthorizationCodeNotProvidedThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() ->
 		assertThatThrownBy(() ->
-				OAuth2Authorization.builder()
-						.registeredClientId(REGISTERED_CLIENT_ID)
-						.accessToken(ACCESS_TOKEN)
-						.attribute(AUTHORIZATION_CODE_VALUE, CODE)
-						.build()
-		).isInstanceOf(IllegalArgumentException.class);
+				OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+						.principalName(PRINCIPAL_NAME).build())
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorization code cannot be null");
 	}
 	}
 
 
 	@Test
 	@Test
-	public void buildWhenAttributeSetWithNullNameThenThrowIllegalArgumentException() {
+	public void attributeWhenNameNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() ->
 		assertThatThrownBy(() ->
-				OAuth2Authorization.builder()
-						.attribute(null, CODE)
-		).isInstanceOf(IllegalArgumentException.class);
+				OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+						.attribute(null, AUTHORIZATION_CODE))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("name cannot be empty");
 	}
 	}
 
 
 	@Test
 	@Test
-	public void buildWhenAttributeSetWithNullValueThenThrowIllegalArgumentException() {
+	public void attributeWhenValueNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() ->
 		assertThatThrownBy(() ->
-				OAuth2Authorization.builder()
-						.attribute(AUTHORIZATION_CODE_VALUE, null)
-		).isInstanceOf(IllegalArgumentException.class);
+				OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+						.attribute(TokenType.AUTHORIZATION_CODE.getValue(), null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("value cannot be null");
 	}
 	}
 
 
 	@Test
 	@Test
-	public void withOAuth2AuthorizationWhenAuthorizationProvidedThenAllAttributesAreCopied() {
-		OAuth2Authorization authorizationToCopy = OAuth2Authorization.builder()
-				.registeredClientId(REGISTERED_CLIENT_ID)
+	public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() {
+		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
 				.principalName(PRINCIPAL_NAME)
-				.attribute(AUTHORIZATION_CODE_VALUE, CODE)
-				.build();
-
-		OAuth2Authorization authorization = OAuth2Authorization.withAuthorization(authorizationToCopy)
 				.accessToken(ACCESS_TOKEN)
 				.accessToken(ACCESS_TOKEN)
+				.attribute(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE)
 				.build();
 				.build();
 
 
-		assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT_ID);
+		assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
 		assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
 		assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN);
 		assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN);
-		assertThat(authorization.getAttributes()).isEqualTo(ATTRIBUTES);
+		assertThat(authorization.getAttributes()).containsExactly(
+				entry(TokenType.AUTHORIZATION_CODE.getValue(), AUTHORIZATION_CODE));
 	}
 	}
 }
 }