|
@@ -17,7 +17,6 @@ package org.springframework.security.oauth2.jwt;
|
|
|
|
|
|
import java.security.interfaces.RSAPublicKey;
|
|
|
import java.util.Collections;
|
|
|
-import java.util.HashMap;
|
|
|
import java.util.HashSet;
|
|
|
import java.util.LinkedHashMap;
|
|
|
import java.util.Map;
|
|
@@ -307,16 +306,13 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
|
JWSKeySelector<JWKSecurityContext> jwsKeySelector(JWKSource<JWKSecurityContext> jwkSource) {
|
|
|
if (this.signatureAlgorithms.isEmpty()) {
|
|
|
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
|
|
|
- } else if (this.signatureAlgorithms.size() == 1) {
|
|
|
- JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName());
|
|
|
- return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource);
|
|
|
} else {
|
|
|
- Map<JWSAlgorithm, JWSKeySelector<JWKSecurityContext>> jwsKeySelectors = new HashMap<>();
|
|
|
+ Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
|
|
|
for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
|
|
|
- JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName());
|
|
|
- jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource));
|
|
|
+ JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
|
|
|
+ jwsAlgorithms.add(jwsAlgorithm);
|
|
|
}
|
|
|
- return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors);
|
|
|
+ return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -330,7 +326,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
|
ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
|
|
|
source.setWebClient(this.webClient);
|
|
|
|
|
|
- Set<JWSAlgorithm> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector);
|
|
|
+ Function<JWSAlgorithm, Boolean> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector);
|
|
|
return jwt -> {
|
|
|
JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader());
|
|
|
return source.get(selector)
|
|
@@ -339,22 +335,20 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
|
};
|
|
|
}
|
|
|
|
|
|
- private Set<JWSAlgorithm> getExpectedJwsAlgorithms(JWSKeySelector<?> jwsKeySelector) {
|
|
|
+ private Function<JWSAlgorithm, Boolean> getExpectedJwsAlgorithms(JWSKeySelector<?> jwsKeySelector) {
|
|
|
if (jwsKeySelector instanceof JWSVerificationKeySelector) {
|
|
|
- return Collections.singleton(((JWSVerificationKeySelector<?>) jwsKeySelector).getExpectedJWSAlgorithm());
|
|
|
- }
|
|
|
- if (jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector) {
|
|
|
- return ((JWSAlgorithmMapJWSKeySelector<?>) jwsKeySelector).getExpectedJWSAlgorithms();
|
|
|
+ return ((JWSVerificationKeySelector<?>) jwsKeySelector)::isAllowed;
|
|
|
}
|
|
|
throw new IllegalArgumentException("Unsupported key selector type " + jwsKeySelector.getClass());
|
|
|
}
|
|
|
|
|
|
- private JWKSelector createSelector(Set<JWSAlgorithm> expectedJwsAlgorithms, Header header) {
|
|
|
- if (!expectedJwsAlgorithms.contains(header.getAlgorithm())) {
|
|
|
+ private JWKSelector createSelector(Function<JWSAlgorithm, Boolean> expectedJwsAlgorithms, Header header) {
|
|
|
+ JWSHeader jwsHeader = (JWSHeader) header;
|
|
|
+ if (!expectedJwsAlgorithms.apply(jwsHeader.getAlgorithm())) {
|
|
|
throw new BadJwtException("Unsupported algorithm of " + header.getAlgorithm());
|
|
|
}
|
|
|
|
|
|
- return new JWKSelector(JWKMatcher.forJWSHeader((JWSHeader) header));
|
|
|
+ return new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader));
|
|
|
}
|
|
|
}
|
|
|
|