|
@@ -18,8 +18,12 @@ package org.springframework.security.oauth2.jwt;
|
|
import java.security.interfaces.RSAPublicKey;
|
|
import java.security.interfaces.RSAPublicKey;
|
|
import java.time.Instant;
|
|
import java.time.Instant;
|
|
import java.util.Collections;
|
|
import java.util.Collections;
|
|
|
|
+import java.util.HashMap;
|
|
|
|
+import java.util.HashSet;
|
|
import java.util.LinkedHashMap;
|
|
import java.util.LinkedHashMap;
|
|
import java.util.Map;
|
|
import java.util.Map;
|
|
|
|
+import java.util.Set;
|
|
|
|
+import java.util.function.Consumer;
|
|
import java.util.function.Function;
|
|
import java.util.function.Function;
|
|
import javax.crypto.SecretKey;
|
|
import javax.crypto.SecretKey;
|
|
|
|
|
|
@@ -31,6 +35,7 @@ import com.nimbusds.jose.jwk.JWK;
|
|
import com.nimbusds.jose.jwk.JWKMatcher;
|
|
import com.nimbusds.jose.jwk.JWKMatcher;
|
|
import com.nimbusds.jose.jwk.JWKSelector;
|
|
import com.nimbusds.jose.jwk.JWKSelector;
|
|
import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet;
|
|
import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet;
|
|
|
|
+import com.nimbusds.jose.jwk.source.JWKSource;
|
|
import com.nimbusds.jose.proc.BadJOSEException;
|
|
import com.nimbusds.jose.proc.BadJOSEException;
|
|
import com.nimbusds.jose.proc.JWKSecurityContext;
|
|
import com.nimbusds.jose.proc.JWKSecurityContext;
|
|
import com.nimbusds.jose.proc.JWSKeySelector;
|
|
import com.nimbusds.jose.proc.JWSKeySelector;
|
|
@@ -233,7 +238,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
*/
|
|
*/
|
|
public static final class JwkSetUriReactiveJwtDecoderBuilder {
|
|
public static final class JwkSetUriReactiveJwtDecoderBuilder {
|
|
private final String jwkSetUri;
|
|
private final String jwkSetUri;
|
|
- private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.RS256;
|
|
|
|
|
|
+ private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
|
|
private WebClient webClient = WebClient.create();
|
|
private WebClient webClient = WebClient.create();
|
|
|
|
|
|
private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
|
|
private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
|
|
@@ -242,15 +247,30 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
}
|
|
}
|
|
|
|
|
|
/**
|
|
/**
|
|
- * Use the given signing
|
|
|
|
- * <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>.
|
|
|
|
|
|
+ * Append the given signing
|
|
|
|
+ * <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>
|
|
|
|
+ * to the set of algorithms to use.
|
|
*
|
|
*
|
|
* @param signatureAlgorithm the algorithm to use
|
|
* @param signatureAlgorithm the algorithm to use
|
|
* @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
|
|
* @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
|
|
*/
|
|
*/
|
|
public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithm(SignatureAlgorithm signatureAlgorithm) {
|
|
public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithm(SignatureAlgorithm signatureAlgorithm) {
|
|
Assert.notNull(signatureAlgorithm, "sig cannot be null");
|
|
Assert.notNull(signatureAlgorithm, "sig cannot be null");
|
|
- this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
|
|
|
|
|
|
+ this.signatureAlgorithms.add(signatureAlgorithm);
|
|
|
|
+ return this;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /**
|
|
|
|
+ * Configure the list of
|
|
|
|
+ * <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithms</a>
|
|
|
|
+ * to use with the given {@link Consumer}.
|
|
|
|
+ *
|
|
|
|
+ * @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring the algorithm list
|
|
|
|
+ * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
|
|
|
|
+ */
|
|
|
|
+ public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithms(Consumer<Set<SignatureAlgorithm>> signatureAlgorithmsConsumer) {
|
|
|
|
+ Assert.notNull(signatureAlgorithmsConsumer, "signatureAlgorithmsConsumer cannot be null");
|
|
|
|
+ signatureAlgorithmsConsumer.accept(this.signatureAlgorithms);
|
|
return this;
|
|
return this;
|
|
}
|
|
}
|
|
|
|
|
|
@@ -278,28 +298,53 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
return new NimbusReactiveJwtDecoder(processor());
|
|
return new NimbusReactiveJwtDecoder(processor());
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ JWSKeySelector<JWKSecurityContext> jwsKeySelector(JWKSource<JWKSecurityContext> jwkSource) {
|
|
|
|
+ if (this.signatureAlgorithms.isEmpty()) {
|
|
|
|
+ return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
|
|
|
|
+ } else if (this.signatureAlgorithms.size() == 1) {
|
|
|
|
+ JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName());
|
|
|
|
+ return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource);
|
|
|
|
+ } else {
|
|
|
|
+ Map<JWSAlgorithm, JWSKeySelector<JWKSecurityContext>> jwsKeySelectors = new HashMap<>();
|
|
|
|
+ for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
|
|
|
|
+ JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName());
|
|
|
|
+ jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource));
|
|
|
|
+ }
|
|
|
|
+ return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
Converter<JWT, Mono<JWTClaimsSet>> processor() {
|
|
Converter<JWT, Mono<JWTClaimsSet>> processor() {
|
|
JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
|
|
JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
|
|
-
|
|
|
|
- JWSKeySelector<JWKSecurityContext> jwsKeySelector =
|
|
|
|
- new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
|
|
|
|
DefaultJWTProcessor<JWKSecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
|
|
DefaultJWTProcessor<JWKSecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
|
|
|
|
+ JWSKeySelector<JWKSecurityContext> jwsKeySelector = jwsKeySelector(jwkSource);
|
|
jwtProcessor.setJWSKeySelector(jwsKeySelector);
|
|
jwtProcessor.setJWSKeySelector(jwsKeySelector);
|
|
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
|
|
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
|
|
|
|
|
|
ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
|
|
ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
|
|
source.setWebClient(this.webClient);
|
|
source.setWebClient(this.webClient);
|
|
|
|
|
|
|
|
+ Set<JWSAlgorithm> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector);
|
|
return jwt -> {
|
|
return jwt -> {
|
|
- JWKSelector selector = createSelector(jwt.getHeader());
|
|
|
|
|
|
+ JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader());
|
|
return source.get(selector)
|
|
return source.get(selector)
|
|
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
|
|
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
|
|
.map(jwkList -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList)));
|
|
.map(jwkList -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList)));
|
|
};
|
|
};
|
|
}
|
|
}
|
|
|
|
|
|
- private JWKSelector createSelector(Header header) {
|
|
|
|
- if (!this.jwsAlgorithm.equals(header.getAlgorithm())) {
|
|
|
|
|
|
+ private Set<JWSAlgorithm> getExpectedJwsAlgorithms(JWSKeySelector<?> jwsKeySelector) {
|
|
|
|
+ if (jwsKeySelector instanceof JWSVerificationKeySelector) {
|
|
|
|
+ return Collections.singleton(((JWSVerificationKeySelector<?>) jwsKeySelector).getExpectedJWSAlgorithm());
|
|
|
|
+ }
|
|
|
|
+ if (jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector) {
|
|
|
|
+ return ((JWSAlgorithmMapJWSKeySelector<?>) jwsKeySelector).getExpectedJWSAlgorithms();
|
|
|
|
+ }
|
|
|
|
+ throw new IllegalArgumentException("Unsupported key selector type " + jwsKeySelector.getClass());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private JWKSelector createSelector(Set<JWSAlgorithm> expectedJwsAlgorithms, Header header) {
|
|
|
|
+ if (!expectedJwsAlgorithms.contains(header.getAlgorithm())) {
|
|
throw new JwtException("Unsupported algorithm of " + header.getAlgorithm());
|
|
throw new JwtException("Unsupported algorithm of " + header.getAlgorithm());
|
|
}
|
|
}
|
|
|
|
|