Bladeren bron

Polish Nimbus JWK Source Implementation

Issue gh-16251
Josh Cummings 6 maanden geleden
bovenliggende
commit
11113adf62

+ 49 - 115
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

@@ -16,15 +16,7 @@
 
 package org.springframework.security.oauth2.jwt;
 
-import com.nimbusds.jose.KeySourceException;
-import com.nimbusds.jose.jwk.JWK;
-import com.nimbusds.jose.jwk.JWKMatcher;
-import com.nimbusds.jose.jwk.JWKSelector;
-import com.nimbusds.jose.jwk.source.JWKSetParseException;
-import com.nimbusds.jose.jwk.source.JWKSetRetrievalException;
-import java.io.IOException;
-import java.net.MalformedURLException;
-import java.net.URL;
+import java.net.URI;
 import java.security.interfaces.RSAPublicKey;
 import java.text.ParseException;
 import java.util.Arrays;
@@ -32,7 +24,6 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
-import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.locks.ReentrantLock;
@@ -43,8 +34,13 @@ import javax.crypto.SecretKey;
 
 import com.nimbusds.jose.JOSEException;
 import com.nimbusds.jose.JWSAlgorithm;
+import com.nimbusds.jose.KeySourceException;
+import com.nimbusds.jose.RemoteKeySourceException;
 import com.nimbusds.jose.jwk.JWKSet;
+import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator;
+import com.nimbusds.jose.jwk.source.JWKSetSource;
 import com.nimbusds.jose.jwk.source.JWKSource;
+import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
 import com.nimbusds.jose.proc.JWSKeySelector;
 import com.nimbusds.jose.proc.JWSVerificationKeySelector;
 import com.nimbusds.jose.proc.SecurityContext;
@@ -170,7 +166,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 					.build();
 			// @formatter:on
 		}
-		catch (KeySourceException ex) {
+		catch (RemoteKeySourceException ex) {
 			this.logger.trace("Failed to retrieve JWK set", ex);
 			if (ex.getCause() instanceof ParseException) {
 				throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
@@ -383,7 +379,11 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 
 		JWKSource<SecurityContext> jwkSource() {
 			String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
-			return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri);
+			return JWKSourceBuilder.create(new SpringJWKSource<>(this.restOperations, this.cache, jwkSetUri))
+				.refreshAheadCache(false)
+				.rateLimited(false)
+				.cache(this.cache instanceof NoOpCache)
+				.build();
 		}
 
 		JWTProcessor<SecurityContext> processor() {
@@ -405,16 +405,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 			return new NimbusJwtDecoder(processor());
 		}
 
-		private static URL toURL(String url) {
-			try {
-				return new URL(url);
-			}
-			catch (MalformedURLException ex) {
-				throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex);
-			}
-		}
-
-		private static final class SpringJWKSource<C extends SecurityContext> implements JWKSource<C> {
+		private static final class SpringJWKSource<C extends SecurityContext> implements JWKSetSource<C> {
 
 			private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
 
@@ -424,120 +415,63 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 
 			private final Cache cache;
 
-			private final URL url;
-
 			private final String jwkSetUri;
 
-			private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) {
+			private JWKSet jwkSet;
+
+			private SpringJWKSource(RestOperations restOperations, Cache cache, String jwkSetUri) {
 				Assert.notNull(restOperations, "restOperations cannot be null");
 				this.restOperations = restOperations;
 				this.cache = cache;
-				this.url = url;
 				this.jwkSetUri = jwkSetUri;
-			}
-
-
-			@Override
-			public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
-				String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class);
-				JWKSet jwkSet = null;
-				if (cachedJwkSet != null) {
-					jwkSet = parse(cachedJwkSet);
-				}
-				if (jwkSet == null) {
-					if(reentrantLock.tryLock()) {
-						try {
-							String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class);
-							if (cachedJwkSetAfterLock != null) {
-								jwkSet = parse(cachedJwkSetAfterLock);
-							}
-							if(jwkSet == null) {
-								try {
-									jwkSet = fetchJWKSet();
-								} catch (IOException e) {
-									throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
-								}
-							}
-						} finally {
-							reentrantLock.unlock();
-						}
-					}
-				}
-				List<JWK> matches = jwkSelector.select(jwkSet);
-				if(!matches.isEmpty()) {
-					return matches;
-				}
-				String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
-				if (soughtKeyID == null) {
-					return Collections.emptyList();
-				}
-				if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
-					return Collections.emptyList();
-				}
-
-				if(reentrantLock.tryLock()) {
+				String jwks = this.cache.get(this.jwkSetUri, String.class);
+				if (jwks != null) {
 					try {
-						String jwkSetUri = this.cache.get(this.jwkSetUri, String.class);
-						JWKSet cacheJwkSet = parse(jwkSetUri);
-						if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) {
-							try {
-								jwkSet = fetchJWKSet();
-							} catch (IOException e) {
-								throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
-							}
-						} else if (jwkSetUri != null) {
-							jwkSet = parse(jwkSetUri);
-						}
-					} finally {
-						reentrantLock.unlock();
+						this.jwkSet = JWKSet.parse(jwks);
+					}
+					catch (ParseException ignored) {
+						// Ignore invalid cache value
 					}
 				}
-				if(jwkSet == null) {
-					return Collections.emptyList();
-				}
-				return jwkSelector.select(jwkSet);
 			}
 
-			private JWKSet fetchJWKSet() throws IOException, KeySourceException {
+			private String fetchJwks() throws Exception {
 				HttpHeaders headers = new HttpHeaders();
 				headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
-				ResponseEntity<String> response = getResponse(headers);
-				if (response.getStatusCode().value() != 200) {
-					throw new IOException(response.toString());
-				}
-				try {
-					String jwkSet = response.getBody();
-					this.cache.put(this.jwkSetUri, jwkSet);
-					return JWKSet.parse(jwkSet);
-				} catch (ParseException e) {
-					throw new JWKSetParseException("Unable to parse JWK set", e);
-				}
+				RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, URI.create(this.jwkSetUri));
+				ResponseEntity<String> response = this.restOperations.exchange(request, String.class);
+				String jwks = response.getBody();
+				this.jwkSet = JWKSet.parse(jwks);
+				return jwks;
 			}
 
-			private ResponseEntity<String> getResponse(HttpHeaders headers) throws IOException {
+			@Override
+			public JWKSet getJWKSet(JWKSetCacheRefreshEvaluator refreshEvaluator, long currentTime, C context)
+					throws KeySourceException {
 				try {
-					RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI());
-					return this.restOperations.exchange(request, String.class);
-				} catch (Exception ex) {
-					throw new IOException(ex);
+					this.reentrantLock.lock();
+					if (refreshEvaluator.requiresRefresh(this.jwkSet)) {
+						this.cache.invalidate();
+					}
+					this.cache.get(this.jwkSetUri, this::fetchJwks);
+					return this.jwkSet;
 				}
-			}
-
-			private JWKSet parse(String cachedJwkSet) {
-				JWKSet jwkSet = null;
-				try {
-					jwkSet = JWKSet.parse(cachedJwkSet);
-				} catch (ParseException ignored) {
-					// Ignore invalid cache value
+				catch (Cache.ValueRetrievalException ex) {
+					if (ex.getCause() instanceof RemoteKeySourceException keys) {
+						throw keys;
+					}
+					throw new RemoteKeySourceException(ex.getCause().getMessage(), ex.getCause());
+				}
+				finally {
+					this.reentrantLock.unlock();
 				}
-				return jwkSet;
 			}
 
-			private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
-				Set<String> keyIDs = jwkMatcher.getKeyIDs();
-				return (keyIDs == null || keyIDs.isEmpty()) ?
-						null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null);
+			@Override
+			public void close() {
+
 			}
+
 		}
 
 	}

+ 1 - 2
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2025 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -308,7 +308,6 @@ public class JwtDecodersTests {
 	private void prepareConfigurationResponse(String body) {
 		this.server.enqueue(response(body));
 		this.server.enqueue(response(JWK_SET));
-		this.server.enqueue(response(JWK_SET)); // default NoOpCache
 	}
 
 	private void prepareConfigurationResponseOidc() {