|
@@ -16,15 +16,7 @@
|
|
|
|
|
|
package org.springframework.security.oauth2.jwt;
|
|
|
|
|
|
-import com.nimbusds.jose.KeySourceException;
|
|
|
-import com.nimbusds.jose.jwk.JWK;
|
|
|
-import com.nimbusds.jose.jwk.JWKMatcher;
|
|
|
-import com.nimbusds.jose.jwk.JWKSelector;
|
|
|
-import com.nimbusds.jose.jwk.source.JWKSetParseException;
|
|
|
-import com.nimbusds.jose.jwk.source.JWKSetRetrievalException;
|
|
|
-import java.io.IOException;
|
|
|
-import java.net.MalformedURLException;
|
|
|
-import java.net.URL;
|
|
|
+import java.net.URI;
|
|
|
import java.security.interfaces.RSAPublicKey;
|
|
|
import java.text.ParseException;
|
|
|
import java.util.Arrays;
|
|
@@ -32,7 +24,6 @@ import java.util.Collection;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashSet;
|
|
|
import java.util.LinkedHashMap;
|
|
|
-import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Set;
|
|
|
import java.util.concurrent.locks.ReentrantLock;
|
|
@@ -43,8 +34,13 @@ import javax.crypto.SecretKey;
|
|
|
|
|
|
import com.nimbusds.jose.JOSEException;
|
|
|
import com.nimbusds.jose.JWSAlgorithm;
|
|
|
+import com.nimbusds.jose.KeySourceException;
|
|
|
+import com.nimbusds.jose.RemoteKeySourceException;
|
|
|
import com.nimbusds.jose.jwk.JWKSet;
|
|
|
+import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator;
|
|
|
+import com.nimbusds.jose.jwk.source.JWKSetSource;
|
|
|
import com.nimbusds.jose.jwk.source.JWKSource;
|
|
|
+import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
|
|
|
import com.nimbusds.jose.proc.JWSKeySelector;
|
|
|
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
|
|
|
import com.nimbusds.jose.proc.SecurityContext;
|
|
@@ -170,7 +166,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
|
|
|
.build();
|
|
|
// @formatter:on
|
|
|
}
|
|
|
- catch (KeySourceException ex) {
|
|
|
+ catch (RemoteKeySourceException ex) {
|
|
|
this.logger.trace("Failed to retrieve JWK set", ex);
|
|
|
if (ex.getCause() instanceof ParseException) {
|
|
|
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
|
|
@@ -383,7 +379,11 @@ public final class NimbusJwtDecoder implements JwtDecoder {
|
|
|
|
|
|
JWKSource<SecurityContext> jwkSource() {
|
|
|
String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
|
|
|
- return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri);
|
|
|
+ return JWKSourceBuilder.create(new SpringJWKSource<>(this.restOperations, this.cache, jwkSetUri))
|
|
|
+ .refreshAheadCache(false)
|
|
|
+ .rateLimited(false)
|
|
|
+ .cache(this.cache instanceof NoOpCache)
|
|
|
+ .build();
|
|
|
}
|
|
|
|
|
|
JWTProcessor<SecurityContext> processor() {
|
|
@@ -405,16 +405,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
|
|
|
return new NimbusJwtDecoder(processor());
|
|
|
}
|
|
|
|
|
|
- private static URL toURL(String url) {
|
|
|
- try {
|
|
|
- return new URL(url);
|
|
|
- }
|
|
|
- catch (MalformedURLException ex) {
|
|
|
- throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- private static final class SpringJWKSource<C extends SecurityContext> implements JWKSource<C> {
|
|
|
+ private static final class SpringJWKSource<C extends SecurityContext> implements JWKSetSource<C> {
|
|
|
|
|
|
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
|
|
|
|
|
@@ -424,120 +415,63 @@ public final class NimbusJwtDecoder implements JwtDecoder {
|
|
|
|
|
|
private final Cache cache;
|
|
|
|
|
|
- private final URL url;
|
|
|
-
|
|
|
private final String jwkSetUri;
|
|
|
|
|
|
- private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) {
|
|
|
+ private JWKSet jwkSet;
|
|
|
+
|
|
|
+ private SpringJWKSource(RestOperations restOperations, Cache cache, String jwkSetUri) {
|
|
|
Assert.notNull(restOperations, "restOperations cannot be null");
|
|
|
this.restOperations = restOperations;
|
|
|
this.cache = cache;
|
|
|
- this.url = url;
|
|
|
this.jwkSetUri = jwkSetUri;
|
|
|
- }
|
|
|
-
|
|
|
-
|
|
|
- @Override
|
|
|
- public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
|
|
|
- String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class);
|
|
|
- JWKSet jwkSet = null;
|
|
|
- if (cachedJwkSet != null) {
|
|
|
- jwkSet = parse(cachedJwkSet);
|
|
|
- }
|
|
|
- if (jwkSet == null) {
|
|
|
- if(reentrantLock.tryLock()) {
|
|
|
- try {
|
|
|
- String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class);
|
|
|
- if (cachedJwkSetAfterLock != null) {
|
|
|
- jwkSet = parse(cachedJwkSetAfterLock);
|
|
|
- }
|
|
|
- if(jwkSet == null) {
|
|
|
- try {
|
|
|
- jwkSet = fetchJWKSet();
|
|
|
- } catch (IOException e) {
|
|
|
- throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
|
|
|
- }
|
|
|
- }
|
|
|
- } finally {
|
|
|
- reentrantLock.unlock();
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- List<JWK> matches = jwkSelector.select(jwkSet);
|
|
|
- if(!matches.isEmpty()) {
|
|
|
- return matches;
|
|
|
- }
|
|
|
- String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
|
|
|
- if (soughtKeyID == null) {
|
|
|
- return Collections.emptyList();
|
|
|
- }
|
|
|
- if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
|
|
|
- return Collections.emptyList();
|
|
|
- }
|
|
|
-
|
|
|
- if(reentrantLock.tryLock()) {
|
|
|
+ String jwks = this.cache.get(this.jwkSetUri, String.class);
|
|
|
+ if (jwks != null) {
|
|
|
try {
|
|
|
- String jwkSetUri = this.cache.get(this.jwkSetUri, String.class);
|
|
|
- JWKSet cacheJwkSet = parse(jwkSetUri);
|
|
|
- if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) {
|
|
|
- try {
|
|
|
- jwkSet = fetchJWKSet();
|
|
|
- } catch (IOException e) {
|
|
|
- throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
|
|
|
- }
|
|
|
- } else if (jwkSetUri != null) {
|
|
|
- jwkSet = parse(jwkSetUri);
|
|
|
- }
|
|
|
- } finally {
|
|
|
- reentrantLock.unlock();
|
|
|
+ this.jwkSet = JWKSet.parse(jwks);
|
|
|
+ }
|
|
|
+ catch (ParseException ignored) {
|
|
|
+ // Ignore invalid cache value
|
|
|
}
|
|
|
}
|
|
|
- if(jwkSet == null) {
|
|
|
- return Collections.emptyList();
|
|
|
- }
|
|
|
- return jwkSelector.select(jwkSet);
|
|
|
}
|
|
|
|
|
|
- private JWKSet fetchJWKSet() throws IOException, KeySourceException {
|
|
|
+ private String fetchJwks() throws Exception {
|
|
|
HttpHeaders headers = new HttpHeaders();
|
|
|
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
|
|
|
- ResponseEntity<String> response = getResponse(headers);
|
|
|
- if (response.getStatusCode().value() != 200) {
|
|
|
- throw new IOException(response.toString());
|
|
|
- }
|
|
|
- try {
|
|
|
- String jwkSet = response.getBody();
|
|
|
- this.cache.put(this.jwkSetUri, jwkSet);
|
|
|
- return JWKSet.parse(jwkSet);
|
|
|
- } catch (ParseException e) {
|
|
|
- throw new JWKSetParseException("Unable to parse JWK set", e);
|
|
|
- }
|
|
|
+ RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, URI.create(this.jwkSetUri));
|
|
|
+ ResponseEntity<String> response = this.restOperations.exchange(request, String.class);
|
|
|
+ String jwks = response.getBody();
|
|
|
+ this.jwkSet = JWKSet.parse(jwks);
|
|
|
+ return jwks;
|
|
|
}
|
|
|
|
|
|
- private ResponseEntity<String> getResponse(HttpHeaders headers) throws IOException {
|
|
|
+ @Override
|
|
|
+ public JWKSet getJWKSet(JWKSetCacheRefreshEvaluator refreshEvaluator, long currentTime, C context)
|
|
|
+ throws KeySourceException {
|
|
|
try {
|
|
|
- RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI());
|
|
|
- return this.restOperations.exchange(request, String.class);
|
|
|
- } catch (Exception ex) {
|
|
|
- throw new IOException(ex);
|
|
|
+ this.reentrantLock.lock();
|
|
|
+ if (refreshEvaluator.requiresRefresh(this.jwkSet)) {
|
|
|
+ this.cache.invalidate();
|
|
|
+ }
|
|
|
+ this.cache.get(this.jwkSetUri, this::fetchJwks);
|
|
|
+ return this.jwkSet;
|
|
|
}
|
|
|
- }
|
|
|
-
|
|
|
- private JWKSet parse(String cachedJwkSet) {
|
|
|
- JWKSet jwkSet = null;
|
|
|
- try {
|
|
|
- jwkSet = JWKSet.parse(cachedJwkSet);
|
|
|
- } catch (ParseException ignored) {
|
|
|
- // Ignore invalid cache value
|
|
|
+ catch (Cache.ValueRetrievalException ex) {
|
|
|
+ if (ex.getCause() instanceof RemoteKeySourceException keys) {
|
|
|
+ throw keys;
|
|
|
+ }
|
|
|
+ throw new RemoteKeySourceException(ex.getCause().getMessage(), ex.getCause());
|
|
|
+ }
|
|
|
+ finally {
|
|
|
+ this.reentrantLock.unlock();
|
|
|
}
|
|
|
- return jwkSet;
|
|
|
}
|
|
|
|
|
|
- private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
|
|
|
- Set<String> keyIDs = jwkMatcher.getKeyIDs();
|
|
|
- return (keyIDs == null || keyIDs.isEmpty()) ?
|
|
|
- null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null);
|
|
|
+ @Override
|
|
|
+ public void close() {
|
|
|
+
|
|
|
}
|
|
|
+
|
|
|
}
|
|
|
|
|
|
}
|