Browse Source

Implement internal cache in JtiClaimValidator

Closes gh-17107
Joe Grandja 3 months ago
parent
commit
5f7155bfc7

+ 20 - 5
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/DPoPProofJwtDecoderFactory.java

@@ -18,12 +18,12 @@ package org.springframework.security.oauth2.jwt;
 
 import java.nio.charset.StandardCharsets;
 import java.security.MessageDigest;
-import java.time.Clock;
 import java.time.Instant;
+import java.time.temporal.ChronoUnit;
 import java.util.Base64;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Function;
 
 import com.nimbusds.jose.JOSEException;
@@ -146,7 +146,7 @@ public final class DPoPProofJwtDecoderFactory implements JwtDecoderFactory<DPoPP
 
 	private static final class JtiClaimValidator implements OAuth2TokenValidator<Jwt> {
 
-		private static final Map<String, Long> jtiCache = new ConcurrentHashMap<>();
+		private static final Map<String, Long> JTI_CACHE = Collections.synchronizedMap(new JtiCache());
 
 		@Override
 		public OAuth2TokenValidatorResult validate(Jwt jwt) {
@@ -166,8 +166,8 @@ public final class DPoPProofJwtDecoderFactory implements JwtDecoderFactory<DPoPP
 				OAuth2Error error = createOAuth2Error("jti claim is invalid.");
 				return OAuth2TokenValidatorResult.failure(error);
 			}
-			Instant now = Instant.now(Clock.systemUTC());
-			if ((jtiCache.putIfAbsent(jtiHash, now.toEpochMilli())) != null) {
+			Instant expiry = Instant.now().plus(1, ChronoUnit.HOURS);
+			if ((JTI_CACHE.putIfAbsent(jtiHash, expiry.toEpochMilli())) != null) {
 				// Already used
 				OAuth2Error error = createOAuth2Error("jti claim is invalid.");
 				return OAuth2TokenValidatorResult.failure(error);
@@ -185,6 +185,21 @@ public final class DPoPProofJwtDecoderFactory implements JwtDecoderFactory<DPoPP
 			return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
 		}
 
+		private static final class JtiCache extends LinkedHashMap<String, Long> {
+
+			private static final int MAX_SIZE = 1000;
+
+			@Override
+			protected boolean removeEldestEntry(Map.Entry<String, Long> eldest) {
+				if (size() > MAX_SIZE) {
+					return true;
+				}
+				Instant expiry = Instant.ofEpochMilli(eldest.getValue());
+				return Instant.now().isAfter(expiry);
+			}
+
+		}
+
 	}
 
 }