Эх сурвалжийг харах

Optimize InMemoryOAuth2AuthorizationService

Closes gh-654
Joe Grandja 3 жил өмнө
parent
commit
5b7d900424

+ 61 - 4
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2020-2021 the original author or authors.
+ * Copyright 2020-2022 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.authorization;
 
 
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentHashMap;
@@ -41,8 +42,29 @@ import org.springframework.util.Assert;
  * @see OAuth2AuthorizationService
  * @see OAuth2AuthorizationService
  */
  */
 public final class InMemoryOAuth2AuthorizationService implements OAuth2AuthorizationService {
 public final class InMemoryOAuth2AuthorizationService implements OAuth2AuthorizationService {
+	private int maxInitializedAuthorizations = 100;
+
+	/*
+	 * Stores "initialized" (uncompleted) authorizations, where an access token has not yet been granted.
+	 * This state occurs with the authorization_code grant flow during the user consent step OR
+	 * when the code is returned in the authorization response but the access token request is not yet initiated.
+	 */
+	private Map<String, OAuth2Authorization> initializedAuthorizations =
+			Collections.synchronizedMap(new MaxSizeHashMap<>(this.maxInitializedAuthorizations));
+
+	/*
+	 * Stores "completed" authorizations, where an access token has been granted.
+	 */
 	private final Map<String, OAuth2Authorization> authorizations = new ConcurrentHashMap<>();
 	private final Map<String, OAuth2Authorization> authorizations = new ConcurrentHashMap<>();
 
 
+	/*
+	 * Constructor used for testing only.
+	 */
+	InMemoryOAuth2AuthorizationService(int maxInitializedAuthorizations) {
+		this.maxInitializedAuthorizations = maxInitializedAuthorizations;
+		this.initializedAuthorizations = Collections.synchronizedMap(new MaxSizeHashMap<>(this.maxInitializedAuthorizations));
+	}
+
 	/**
 	/**
 	 * Constructs an {@code InMemoryOAuth2AuthorizationService}.
 	 * Constructs an {@code InMemoryOAuth2AuthorizationService}.
 	 */
 	 */
@@ -77,20 +99,31 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 	@Override
 	@Override
 	public void save(OAuth2Authorization authorization) {
 	public void save(OAuth2Authorization authorization) {
 		Assert.notNull(authorization, "authorization cannot be null");
 		Assert.notNull(authorization, "authorization cannot be null");
-		this.authorizations.put(authorization.getId(), authorization);
+		if (isComplete(authorization)) {
+			this.authorizations.put(authorization.getId(), authorization);
+		} else {
+			this.initializedAuthorizations.put(authorization.getId(), authorization);
+		}
 	}
 	}
 
 
 	@Override
 	@Override
 	public void remove(OAuth2Authorization authorization) {
 	public void remove(OAuth2Authorization authorization) {
 		Assert.notNull(authorization, "authorization cannot be null");
 		Assert.notNull(authorization, "authorization cannot be null");
-		this.authorizations.remove(authorization.getId(), authorization);
+		if (isComplete(authorization)) {
+			this.authorizations.remove(authorization.getId(), authorization);
+		} else {
+			this.initializedAuthorizations.remove(authorization.getId(), authorization);
+		}
 	}
 	}
 
 
 	@Nullable
 	@Nullable
 	@Override
 	@Override
 	public OAuth2Authorization findById(String id) {
 	public OAuth2Authorization findById(String id) {
 		Assert.hasText(id, "id cannot be empty");
 		Assert.hasText(id, "id cannot be empty");
-		return this.authorizations.get(id);
+		OAuth2Authorization authorization = this.authorizations.get(id);
+		return authorization != null ?
+				authorization :
+				this.initializedAuthorizations.get(id);
 	}
 	}
 
 
 	@Nullable
 	@Nullable
@@ -102,9 +135,18 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 				return authorization;
 				return authorization;
 			}
 			}
 		}
 		}
+		for (OAuth2Authorization authorization : this.initializedAuthorizations.values()) {
+			if (hasToken(authorization, token, tokenType)) {
+				return authorization;
+			}
+		}
 		return null;
 		return null;
 	}
 	}
 
 
+	private static boolean isComplete(OAuth2Authorization authorization) {
+		return authorization.getAccessToken() != null;
+	}
+
 	private static boolean hasToken(OAuth2Authorization authorization, String token, @Nullable OAuth2TokenType tokenType) {
 	private static boolean hasToken(OAuth2Authorization authorization, String token, @Nullable OAuth2TokenType tokenType) {
 		if (tokenType == null) {
 		if (tokenType == null) {
 			return matchesState(authorization, token) ||
 			return matchesState(authorization, token) ||
@@ -144,4 +186,19 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 				authorization.getToken(OAuth2RefreshToken.class);
 				authorization.getToken(OAuth2RefreshToken.class);
 		return refreshToken != null && refreshToken.getToken().getTokenValue().equals(token);
 		return refreshToken != null && refreshToken.getToken().getTokenValue().equals(token);
 	}
 	}
+
+	private static final class MaxSizeHashMap<K, V> extends LinkedHashMap<K, V> {
+		private final int maxSize;
+
+		private MaxSizeHashMap(int maxSize) {
+			this.maxSize = maxSize;
+		}
+
+		@Override
+		protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
+			return size() > this.maxSize;
+		}
+
+	}
+
 }
 }

+ 0 - 2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

@@ -266,8 +266,6 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 					.build();
 					.build();
 			this.authorizationService.save(authorization);
 			this.authorizationService.save(authorization);
 
 
-			// TODO Need to remove 'in-flight' authorization if consent step is not completed (e.g. approved or cancelled)
-
 			Set<String> currentAuthorizedScopes = currentAuthorizationConsent != null ?
 			Set<String> currentAuthorizedScopes = currentAuthorizationConsent != null ?
 					currentAuthorizationConsent.getScopes() : null;
 					currentAuthorizationConsent.getScopes() : null;
 
 

+ 38 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2020-2021 the original author or authors.
+ * Copyright 2020-2022 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -132,6 +132,43 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 		assertThat(authorization).isNotEqualTo(originalAuthorization);
 		assertThat(authorization).isNotEqualTo(originalAuthorization);
 	}
 	}
 
 
+	@Test
+	public void saveWhenInitializedAuthorizationsReachMaxThenOldestRemoved() {
+		int maxInitializedAuthorizations = 5;
+		InMemoryOAuth2AuthorizationService authorizationService =
+				new InMemoryOAuth2AuthorizationService(maxInitializedAuthorizations);
+
+		OAuth2Authorization initialAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID + "-initial")
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.attribute(OAuth2ParameterNames.STATE, "state-initial")
+				.build();
+		authorizationService.save(initialAuthorization);
+
+		OAuth2Authorization authorization = authorizationService.findById(initialAuthorization.getId());
+		assertThat(authorization).isEqualTo(initialAuthorization);
+		authorization = authorizationService.findByToken(
+				initialAuthorization.getAttribute(OAuth2ParameterNames.STATE), STATE_TOKEN_TYPE);
+		assertThat(authorization).isEqualTo(initialAuthorization);
+
+		for (int i = 0; i < maxInitializedAuthorizations; i++) {
+			authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+					.id(ID + "-" + i)
+					.principalName(PRINCIPAL_NAME)
+					.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+					.attribute(OAuth2ParameterNames.STATE, "state-" + i)
+					.build();
+			authorizationService.save(authorization);
+		}
+
+		authorization = authorizationService.findById(initialAuthorization.getId());
+		assertThat(authorization).isNull();
+		authorization = authorizationService.findByToken(
+				initialAuthorization.getAttribute(OAuth2ParameterNames.STATE), STATE_TOKEN_TYPE);
+		assertThat(authorization).isNull();
+	}
+
 	@Test
 	@Test
 	public void removeWhenAuthorizationNullThenThrowIllegalArgumentException() {
 	public void removeWhenAuthorizationNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.authorizationService.remove(null))
 		assertThatThrownBy(() -> this.authorizationService.remove(null))