Browse Source

InMemoryOAuth2AuthorizationService uniquely identifies OAuth2Authorization

Closes gh-98
Joe Grandja 5 years ago
parent
commit
eeca3df66b

+ 40 - 22
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java

@@ -16,48 +16,37 @@
 package org.springframework.security.oauth2.server.authorization;
 package org.springframework.security.oauth2.server.authorization;
 
 
 import org.springframework.lang.Nullable;
 import org.springframework.lang.Nullable;
+import org.springframework.security.core.SpringSecurityCoreVersion2;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
 
 
-import java.util.List;
-import java.util.concurrent.CopyOnWriteArrayList;
+import java.io.Serializable;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
 
 
 /**
 /**
  * An {@link OAuth2AuthorizationService} that stores {@link OAuth2Authorization}'s in-memory.
  * An {@link OAuth2AuthorizationService} that stores {@link OAuth2Authorization}'s in-memory.
  *
  *
  * @author Krisztian Toth
  * @author Krisztian Toth
+ * @author Joe Grandja
  * @since 0.0.1
  * @since 0.0.1
  * @see OAuth2AuthorizationService
  * @see OAuth2AuthorizationService
  */
  */
 public final class InMemoryOAuth2AuthorizationService implements OAuth2AuthorizationService {
 public final class InMemoryOAuth2AuthorizationService implements OAuth2AuthorizationService {
-	private final List<OAuth2Authorization> authorizations;
-
-	/**
-	 * Constructs an {@code InMemoryOAuth2AuthorizationService}.
-	 */
-	public InMemoryOAuth2AuthorizationService() {
-		this.authorizations = new CopyOnWriteArrayList<>();
-	}
-
-	/**
-	 * Constructs an {@code InMemoryOAuth2AuthorizationService} using the provided parameters.
-	 *
-	 * @param authorizations the initial {@code List} of {@link OAuth2Authorization}(s)
-	 */
-	public InMemoryOAuth2AuthorizationService(List<OAuth2Authorization> authorizations) {
-		Assert.notEmpty(authorizations, "authorizations cannot be empty");
-		this.authorizations = new CopyOnWriteArrayList<>(authorizations);
-	}
+	private final Map<OAuth2AuthorizationId, OAuth2Authorization> authorizations = new ConcurrentHashMap<>();
 
 
 	@Override
 	@Override
 	public void save(OAuth2Authorization authorization) {
 	public void save(OAuth2Authorization authorization) {
 		Assert.notNull(authorization, "authorization cannot be null");
 		Assert.notNull(authorization, "authorization cannot be null");
-		this.authorizations.add(authorization);
+		OAuth2AuthorizationId authorizationId = new OAuth2AuthorizationId(
+				authorization.getRegisteredClientId(), authorization.getPrincipalName());
+		this.authorizations.put(authorizationId, authorization);
 	}
 	}
 
 
 	@Override
 	@Override
 	public OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType) {
 	public OAuth2Authorization findByToken(String token, @Nullable TokenType tokenType) {
 		Assert.hasText(token, "token cannot be empty");
 		Assert.hasText(token, "token cannot be empty");
-		return this.authorizations.stream()
+		return this.authorizations.values().stream()
 				.filter(authorization -> hasToken(authorization, token, tokenType))
 				.filter(authorization -> hasToken(authorization, token, tokenType))
 				.findFirst()
 				.findFirst()
 				.orElse(null);
 				.orElse(null);
@@ -72,4 +61,33 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 		}
 		}
 		return false;
 		return false;
 	}
 	}
+
+	private static class OAuth2AuthorizationId implements Serializable {
+		private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
+		private final String registeredClientId;
+		private final String principalName;
+
+		private OAuth2AuthorizationId(String registeredClientId, String principalName) {
+			this.registeredClientId = registeredClientId;
+			this.principalName = principalName;
+		}
+
+		@Override
+		public boolean equals(Object obj) {
+			if (this == obj) {
+				return true;
+			}
+			if (obj == null || getClass() != obj.getClass()) {
+				return false;
+			}
+			OAuth2AuthorizationId that = (OAuth2AuthorizationId) obj;
+			return Objects.equals(this.registeredClientId, that.registeredClientId) &&
+					Objects.equals(this.principalName, that.principalName);
+		}
+
+		@Override
+		public int hashCode() {
+			return Objects.hash(this.registeredClientId, this.principalName);
+		}
+	}
 }
 }

+ 1 - 9
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java

@@ -22,7 +22,6 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 
 
 import java.time.Instant;
 import java.time.Instant;
-import java.util.Collections;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -43,13 +42,6 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 		this.authorizationService = new InMemoryOAuth2AuthorizationService();
 		this.authorizationService = new InMemoryOAuth2AuthorizationService();
 	}
 	}
 
 
-	@Test
-	public void constructorWhenAuthorizationListNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new InMemoryOAuth2AuthorizationService(null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("authorizations cannot be empty");
-	}
-
 	@Test
 	@Test
 	public void saveWhenAuthorizationNullThenThrowIllegalArgumentException() {
 	public void saveWhenAuthorizationNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.authorizationService.save(null))
 		assertThatThrownBy(() -> this.authorizationService.save(null))
@@ -83,7 +75,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 				.principalName(PRINCIPAL_NAME)
 				.principalName(PRINCIPAL_NAME)
 				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
 				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
 				.build();
 				.build();
-		this.authorizationService = new InMemoryOAuth2AuthorizationService(Collections.singletonList(authorization));
+		this.authorizationService.save(authorization);
 
 
 		OAuth2Authorization result = this.authorizationService.findByToken(
 		OAuth2Authorization result = this.authorizationService.findByToken(
 				AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE);
 				AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE);