Browse Source

Remove Deprecated Usages of RemoteJWKSet

Closes gh-16251

Signed-off-by: Daeho Kwon <trewq231@naver.com>
Daeho Kwon 6 months ago
parent
commit
7b7abb28bb

+ 113 - 67
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2023 the original author or authors.
+ * Copyright 2002-2025 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,12 @@
 
 package org.springframework.security.oauth2.jwt;
 
+import com.nimbusds.jose.KeySourceException;
+import com.nimbusds.jose.jwk.JWK;
+import com.nimbusds.jose.jwk.JWKMatcher;
+import com.nimbusds.jose.jwk.JWKSelector;
+import com.nimbusds.jose.jwk.source.JWKSetParseException;
+import com.nimbusds.jose.jwk.source.JWKSetRetrievalException;
 import java.io.IOException;
 import java.net.MalformedURLException;
 import java.net.URL;
@@ -26,8 +32,10 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.locks.ReentrantLock;
 import java.util.function.Consumer;
 import java.util.function.Function;
 
@@ -35,17 +43,12 @@ import javax.crypto.SecretKey;
 
 import com.nimbusds.jose.JOSEException;
 import com.nimbusds.jose.JWSAlgorithm;
-import com.nimbusds.jose.RemoteKeySourceException;
 import com.nimbusds.jose.jwk.JWKSet;
-import com.nimbusds.jose.jwk.source.JWKSetCache;
 import com.nimbusds.jose.jwk.source.JWKSource;
-import com.nimbusds.jose.jwk.source.RemoteJWKSet;
 import com.nimbusds.jose.proc.JWSKeySelector;
 import com.nimbusds.jose.proc.JWSVerificationKeySelector;
 import com.nimbusds.jose.proc.SecurityContext;
 import com.nimbusds.jose.proc.SingleKeyJWSKeySelector;
-import com.nimbusds.jose.util.Resource;
-import com.nimbusds.jose.util.ResourceRetriever;
 import com.nimbusds.jwt.JWT;
 import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.JWTParser;
@@ -57,6 +60,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
 import org.springframework.cache.Cache;
+import org.springframework.cache.support.NoOpCache;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
@@ -80,6 +84,7 @@ import org.springframework.web.client.RestTemplate;
  * @author Josh Cummings
  * @author Joe Grandja
  * @author Mykyta Bezverkhyi
+ * @author Daeho Kwon
  * @since 5.2
  */
 public final class NimbusJwtDecoder implements JwtDecoder {
@@ -165,7 +170,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 					.build();
 			// @formatter:on
 		}
-		catch (RemoteKeySourceException ex) {
+		catch (KeySourceException ex) {
 			this.logger.trace("Failed to retrieve JWK set", ex);
 			if (ex.getCause() instanceof ParseException) {
 				throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
@@ -273,7 +278,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 
 		private RestOperations restOperations = new RestTemplate();
 
-		private Cache cache;
+		private Cache cache = new NoOpCache("default");
 
 		private Consumer<ConfigurableJWTProcessor<SecurityContext>> jwtProcessorCustomizer;
 
@@ -376,18 +381,13 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 			return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
 		}
 
-		JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) {
-			if (this.cache == null) {
-				return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever);
-			}
-			JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
-			return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache);
+		JWKSource<SecurityContext> jwkSource() {
+			String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
+			return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri);
 		}
 
 		JWTProcessor<SecurityContext> processor() {
-			ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations);
-			String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
-			JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever, jwkSetUri);
+			JWKSource<SecurityContext> jwkSource = jwkSource();
 			ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
 			jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));
 			// Spring Security validates the claim set independent from Nimbus
@@ -414,84 +414,130 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 			}
 		}
 
-		private static final class SpringJWKSetCache implements JWKSetCache {
+		private static final class SpringJWKSource<C extends SecurityContext> implements JWKSource<C> {
 
-			private final String jwkSetUri;
+			private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
+
+			private final ReentrantLock reentrantLock = new ReentrantLock();
+
+			private final RestOperations restOperations;
 
 			private final Cache cache;
 
-			private JWKSet jwkSet;
+			private final URL url;
 
-			SpringJWKSetCache(String jwkSetUri, Cache cache) {
-				this.jwkSetUri = jwkSetUri;
+			private final String jwkSetUri;
+
+			private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) {
+				Assert.notNull(restOperations, "restOperations cannot be null");
+				this.restOperations = restOperations;
 				this.cache = cache;
-				this.updateJwkSetFromCache();
+				this.url = url;
+				this.jwkSetUri = jwkSetUri;
 			}
 
-			private void updateJwkSetFromCache() {
+
+			@Override
+			public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
 				String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class);
+				JWKSet jwkSet = null;
 				if (cachedJwkSet != null) {
-					try {
-						this.jwkSet = JWKSet.parse(cachedJwkSet);
-					}
-					catch (ParseException ignored) {
-						// Ignore invalid cache value
+					jwkSet = parse(cachedJwkSet);
+				}
+				if (jwkSet == null) {
+					if(reentrantLock.tryLock()) {
+						try {
+							String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class);
+							if (cachedJwkSetAfterLock != null) {
+								jwkSet = parse(cachedJwkSetAfterLock);
+							}
+							if(jwkSet == null) {
+								try {
+									jwkSet = fetchJWKSet();
+								} catch (IOException e) {
+									throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
+								}
+							}
+						} finally {
+							reentrantLock.unlock();
+						}
 					}
 				}
-			}
-
-			// Note: Only called from inside a synchronized block in RemoteJWKSet.
-			@Override
-			public void put(JWKSet jwkSet) {
-				this.jwkSet = jwkSet;
-				this.cache.put(this.jwkSetUri, jwkSet.toString(false));
-			}
-
-			@Override
-			public JWKSet get() {
-				return (!requiresRefresh()) ? this.jwkSet : null;
-
-			}
-
-			@Override
-			public boolean requiresRefresh() {
-				return this.cache.get(this.jwkSetUri) == null;
-			}
-
-		}
-
-		private static class RestOperationsResourceRetriever implements ResourceRetriever {
-
-			private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
-
-			private final RestOperations restOperations;
+				List<JWK> matches = jwkSelector.select(jwkSet);
+				if(!matches.isEmpty()) {
+					return matches;
+				}
+				String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
+				if (soughtKeyID == null) {
+					return Collections.emptyList();
+				}
+				if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
+					return Collections.emptyList();
+				}
 
-			RestOperationsResourceRetriever(RestOperations restOperations) {
-				Assert.notNull(restOperations, "restOperations cannot be null");
-				this.restOperations = restOperations;
+				if(reentrantLock.tryLock()) {
+					try {
+						String jwkSetUri = this.cache.get(this.jwkSetUri, String.class);
+						JWKSet cacheJwkSet = parse(jwkSetUri);
+						if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) {
+							try {
+								jwkSet = fetchJWKSet();
+							} catch (IOException e) {
+								throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
+							}
+						} else if (jwkSetUri != null) {
+							jwkSet = parse(jwkSetUri);
+						}
+					} finally {
+						reentrantLock.unlock();
+					}
+				}
+				if(jwkSet == null) {
+					return Collections.emptyList();
+				}
+				return jwkSelector.select(jwkSet);
 			}
 
-			@Override
-			public Resource retrieveResource(URL url) throws IOException {
+			private JWKSet fetchJWKSet() throws IOException, KeySourceException {
 				HttpHeaders headers = new HttpHeaders();
 				headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
-				ResponseEntity<String> response = getResponse(url, headers);
+				ResponseEntity<String> response = getResponse(headers);
 				if (response.getStatusCode().value() != 200) {
 					throw new IOException(response.toString());
 				}
-				return new Resource(response.getBody(), "UTF-8");
+				try {
+					String jwkSet = response.getBody();
+					this.cache.put(this.jwkSetUri, jwkSet);
+					return JWKSet.parse(jwkSet);
+				} catch (ParseException e) {
+					throw new JWKSetParseException("Unable to parse JWK set", e);
+				}
 			}
 
-			private ResponseEntity<String> getResponse(URL url, HttpHeaders headers) throws IOException {
+			private ResponseEntity<String> getResponse(HttpHeaders headers) throws IOException {
 				try {
-					RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI());
+					RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI());
 					return this.restOperations.exchange(request, String.class);
-				}
-				catch (Exception ex) {
+				} catch (Exception ex) {
 					throw new IOException(ex);
 				}
 			}
 
+			private JWKSet parse(String cachedJwkSet) {
+				JWKSet jwkSet = null;
+				try {
+					jwkSet = JWKSet.parse(cachedJwkSet);
+				} catch (ParseException ignored) {
+					// Ignore invalid cache value
+				}
+				return jwkSet;
+			}
+
+			private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
+				Set<String> keyIDs = jwkMatcher.getKeyIDs();
+				return (keyIDs == null || keyIDs.isEmpty()) ?
+						null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null);
+			}
 		}
 
 	}

+ 2 - 1
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2025 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -308,6 +308,7 @@ public class JwtDecodersTests {
 	private void prepareConfigurationResponse(String body) {
 		this.server.enqueue(response(body));
 		this.server.enqueue(response(JWK_SET));
+		this.server.enqueue(response(JWK_SET)); // default NoOpCache
 	}
 
 	private void prepareConfigurationResponseOidc() {