|
@@ -16,11 +16,14 @@
|
|
|
|
|
|
package org.springframework.security.oauth2.jwt;
|
|
|
|
|
|
+import java.net.MalformedURLException;
|
|
|
+import java.net.URL;
|
|
|
import java.security.interfaces.RSAPublicKey;
|
|
|
import java.util.Collection;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashSet;
|
|
|
import java.util.LinkedHashMap;
|
|
|
+import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Set;
|
|
|
import java.util.function.Consumer;
|
|
@@ -28,6 +31,7 @@ import java.util.function.Function;
|
|
|
|
|
|
import javax.crypto.SecretKey;
|
|
|
|
|
|
+import com.nimbusds.jose.Algorithm;
|
|
|
import com.nimbusds.jose.Header;
|
|
|
import com.nimbusds.jose.JOSEException;
|
|
|
import com.nimbusds.jose.JWSAlgorithm;
|
|
@@ -35,6 +39,8 @@ import com.nimbusds.jose.JWSHeader;
|
|
|
import com.nimbusds.jose.jwk.JWK;
|
|
|
import com.nimbusds.jose.jwk.JWKMatcher;
|
|
|
import com.nimbusds.jose.jwk.JWKSelector;
|
|
|
+import com.nimbusds.jose.jwk.JWKSet;
|
|
|
+import com.nimbusds.jose.jwk.KeyUse;
|
|
|
import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet;
|
|
|
import com.nimbusds.jose.jwk.source.JWKSource;
|
|
|
import com.nimbusds.jose.proc.BadJOSEException;
|
|
@@ -50,6 +56,8 @@ import com.nimbusds.jwt.SignedJWT;
|
|
|
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
|
|
|
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
|
|
|
import com.nimbusds.jwt.proc.JWTProcessor;
|
|
|
+import org.apache.commons.logging.Log;
|
|
|
+import org.apache.commons.logging.LogFactory;
|
|
|
import reactor.core.publisher.Flux;
|
|
|
import reactor.core.publisher.Mono;
|
|
|
|
|
@@ -273,6 +281,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
|
*/
|
|
|
public static final class JwkSetUriReactiveJwtDecoderBuilder {
|
|
|
|
|
|
+ private static final Log log = LogFactory.getLog(JwkSetUriReactiveJwtDecoderBuilder.class);
|
|
|
+
|
|
|
private final String jwkSetUri;
|
|
|
|
|
|
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
|
|
@@ -354,17 +364,63 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
|
}
|
|
|
|
|
|
JWSKeySelector<JWKSecurityContext> jwsKeySelector(JWKSource<JWKSecurityContext> jwkSource) {
|
|
|
- if (this.signatureAlgorithms.isEmpty()) {
|
|
|
- return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
|
|
|
+ Set<SignatureAlgorithm> algorithms = new HashSet<>();
|
|
|
+ if (!this.signatureAlgorithms.isEmpty()) {
|
|
|
+ algorithms.addAll(this.signatureAlgorithms);
|
|
|
+ } else {
|
|
|
+ algorithms.addAll(fetchSignatureAlgorithms());
|
|
|
+ }
|
|
|
+
|
|
|
+ if (algorithms.isEmpty()) {
|
|
|
+ algorithms.add(SignatureAlgorithm.RS256);
|
|
|
}
|
|
|
+
|
|
|
Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
|
|
|
- for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
|
|
|
- JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
|
|
|
- jwsAlgorithms.add(jwsAlgorithm);
|
|
|
+ for (SignatureAlgorithm signatureAlgorithm : algorithms) {
|
|
|
+ jwsAlgorithms.add(JWSAlgorithm.parse(signatureAlgorithm.getName()));
|
|
|
}
|
|
|
+
|
|
|
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
|
|
|
}
|
|
|
|
|
|
+ private Set<SignatureAlgorithm> fetchSignatureAlgorithms() {
|
|
|
+ if (StringUtils.isEmpty(jwkSetUri)) {
|
|
|
+ return Collections.emptySet();
|
|
|
+ }
|
|
|
+ try {
|
|
|
+ return parseAlgorithms(JWKSet.load(toURL(jwkSetUri), 5000, 5000, 0));
|
|
|
+ } catch (Exception ex) {
|
|
|
+ throw new IllegalArgumentException("Failed to load Signature Algorithms from remote JWK source.", ex);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private Set<SignatureAlgorithm> parseAlgorithms(JWKSet jwkSet) {
|
|
|
+ if (jwkSet == null) {
|
|
|
+ throw new IllegalArgumentException(String.format("No JWKs received from %s", jwkSetUri));
|
|
|
+ }
|
|
|
+
|
|
|
+ List<JWK> jwks = new ArrayList<>();
|
|
|
+ for (JWK jwk : jwkSet.getKeys()) {
|
|
|
+ KeyUse keyUse = jwk.getKeyUse();
|
|
|
+ if (keyUse != null && keyUse.equals(KeyUse.SIGNATURE)) {
|
|
|
+ jwks.add(jwk);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ Set<SignatureAlgorithm> algorithms = new HashSet<>();
|
|
|
+ for (JWK jwk : jwks) {
|
|
|
+ Algorithm algorithm = jwk.getAlgorithm();
|
|
|
+ if (algorithm != null) {
|
|
|
+ SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.from(algorithm.getName());
|
|
|
+ if (signatureAlgorithm != null) {
|
|
|
+ algorithms.add(signatureAlgorithm);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return algorithms;
|
|
|
+ }
|
|
|
+
|
|
|
Converter<JWT, Mono<JWTClaimsSet>> processor() {
|
|
|
JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
|
|
|
DefaultJWTProcessor<JWKSecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
|
|
@@ -399,6 +455,13 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
|
return new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader));
|
|
|
}
|
|
|
|
|
|
+ private static URL toURL(String url) {
|
|
|
+ try {
|
|
|
+ return new URL(url);
|
|
|
+ } catch (MalformedURLException ex) {
|
|
|
+ throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex);
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/**
|