|
@@ -1,5 +1,5 @@
|
|
|
/*
|
|
|
- * Copyright 2020-2021 the original author or authors.
|
|
|
+ * Copyright 2020-2022 the original author or authors.
|
|
|
*
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -17,6 +17,7 @@ package org.springframework.security.oauth2.server.authorization;
|
|
|
|
|
|
import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
+import java.util.LinkedHashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.concurrent.ConcurrentHashMap;
|
|
@@ -41,8 +42,29 @@ import org.springframework.util.Assert;
|
|
|
* @see OAuth2AuthorizationService
|
|
|
*/
|
|
|
public final class InMemoryOAuth2AuthorizationService implements OAuth2AuthorizationService {
|
|
|
+ private int maxInitializedAuthorizations = 100;
|
|
|
+
|
|
|
+ /*
|
|
|
+ * Stores "initialized" (uncompleted) authorizations, where an access token has not yet been granted.
|
|
|
+ * This state occurs with the authorization_code grant flow during the user consent step OR
|
|
|
+ * when the code is returned in the authorization response but the access token request is not yet initiated.
|
|
|
+ */
|
|
|
+ private Map<String, OAuth2Authorization> initializedAuthorizations =
|
|
|
+ Collections.synchronizedMap(new MaxSizeHashMap<>(this.maxInitializedAuthorizations));
|
|
|
+
|
|
|
+ /*
|
|
|
+ * Stores "completed" authorizations, where an access token has been granted.
|
|
|
+ */
|
|
|
private final Map<String, OAuth2Authorization> authorizations = new ConcurrentHashMap<>();
|
|
|
|
|
|
+ /*
|
|
|
+ * Constructor used for testing only.
|
|
|
+ */
|
|
|
+ InMemoryOAuth2AuthorizationService(int maxInitializedAuthorizations) {
|
|
|
+ this.maxInitializedAuthorizations = maxInitializedAuthorizations;
|
|
|
+ this.initializedAuthorizations = Collections.synchronizedMap(new MaxSizeHashMap<>(this.maxInitializedAuthorizations));
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* Constructs an {@code InMemoryOAuth2AuthorizationService}.
|
|
|
*/
|
|
@@ -77,20 +99,31 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
|
|
|
@Override
|
|
|
public void save(OAuth2Authorization authorization) {
|
|
|
Assert.notNull(authorization, "authorization cannot be null");
|
|
|
- this.authorizations.put(authorization.getId(), authorization);
|
|
|
+ if (isComplete(authorization)) {
|
|
|
+ this.authorizations.put(authorization.getId(), authorization);
|
|
|
+ } else {
|
|
|
+ this.initializedAuthorizations.put(authorization.getId(), authorization);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public void remove(OAuth2Authorization authorization) {
|
|
|
Assert.notNull(authorization, "authorization cannot be null");
|
|
|
- this.authorizations.remove(authorization.getId(), authorization);
|
|
|
+ if (isComplete(authorization)) {
|
|
|
+ this.authorizations.remove(authorization.getId(), authorization);
|
|
|
+ } else {
|
|
|
+ this.initializedAuthorizations.remove(authorization.getId(), authorization);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
@Nullable
|
|
|
@Override
|
|
|
public OAuth2Authorization findById(String id) {
|
|
|
Assert.hasText(id, "id cannot be empty");
|
|
|
- return this.authorizations.get(id);
|
|
|
+ OAuth2Authorization authorization = this.authorizations.get(id);
|
|
|
+ return authorization != null ?
|
|
|
+ authorization :
|
|
|
+ this.initializedAuthorizations.get(id);
|
|
|
}
|
|
|
|
|
|
@Nullable
|
|
@@ -102,9 +135,18 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
|
|
|
return authorization;
|
|
|
}
|
|
|
}
|
|
|
+ for (OAuth2Authorization authorization : this.initializedAuthorizations.values()) {
|
|
|
+ if (hasToken(authorization, token, tokenType)) {
|
|
|
+ return authorization;
|
|
|
+ }
|
|
|
+ }
|
|
|
return null;
|
|
|
}
|
|
|
|
|
|
+ private static boolean isComplete(OAuth2Authorization authorization) {
|
|
|
+ return authorization.getAccessToken() != null;
|
|
|
+ }
|
|
|
+
|
|
|
private static boolean hasToken(OAuth2Authorization authorization, String token, @Nullable OAuth2TokenType tokenType) {
|
|
|
if (tokenType == null) {
|
|
|
return matchesState(authorization, token) ||
|
|
@@ -144,4 +186,19 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
|
|
|
authorization.getToken(OAuth2RefreshToken.class);
|
|
|
return refreshToken != null && refreshToken.getToken().getTokenValue().equals(token);
|
|
|
}
|
|
|
+
|
|
|
+ private static final class MaxSizeHashMap<K, V> extends LinkedHashMap<K, V> {
|
|
|
+ private final int maxSize;
|
|
|
+
|
|
|
+ private MaxSizeHashMap(int maxSize) {
|
|
|
+ this.maxSize = maxSize;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
|
|
|
+ return size() > this.maxSize;
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
}
|