Răsfoiți Sursa

NimbusReactiveJwtDecoder Takes Reactive Processor

Fixes: gh-5937
Josh Cummings 6 ani în urmă
părinte
comite
55e8df1efe

+ 0 - 43
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWKContext.java

@@ -1,43 +0,0 @@
-/*
- * Copyright 2002-2018 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.oauth2.jwt;
-
-import com.nimbusds.jose.jwk.JWK;
-import com.nimbusds.jose.proc.SecurityContext;
-import org.springframework.util.Assert;
-
-import java.util.List;
-
-/**
- * A {@link SecurityContext} that is used by {@link JWKContextJWKSource}.
- *
- * @author Rob Winch
- * @since 5.1
- * @see JWKContextJWKSource
- */
-class JWKContext implements SecurityContext {
-	private final List<JWK> jwkList;
-
-	JWKContext(List<JWK> jwkList) {
-		Assert.notNull(jwkList, "jwkList cannot be null");
-		this.jwkList = jwkList;
-	}
-
-	public List<JWK> getJwkList() {
-		return this.jwkList;
-	}
-}

+ 0 - 43
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWKContextJWKSource.java

@@ -1,43 +0,0 @@
-/*
- * Copyright 2002-2018 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.oauth2.jwt;
-
-import com.nimbusds.jose.jwk.JWK;
-import com.nimbusds.jose.jwk.JWKSelector;
-import com.nimbusds.jose.jwk.source.JWKSource;
-
-import java.util.List;
-
-/**
- * A {@link JWKSource} used for reactive applications that returns the {@link JWK} from the {@link JWKContext}.
- *
- * <p>
- * The Nimbus {@link JWKSource} is a blocking API which means the {@link JWK} cannot be resolved using code that blocks.
- * This means that the JWK Set could not be retrieved from HTTP endpoint. To work around this the {@link JWK} is
- * resolved in the {@link ReactiveJwtDecoder} and provided via the {@link JWKContext}.
- * </p>
- *
- * @author Rob Winch
- * @since 5.1
- */
-class JWKContextJWKSource implements JWKSource<JWKContext> {
-
-	@Override
-	public List<JWK> get(JWKSelector jwkSelector, JWKContext context) {
-		return context.getJwkList();
-	}
-}

+ 0 - 62
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWKSelectorFactory.java

@@ -1,62 +0,0 @@
-/*
- * Copyright 2002-2018 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.oauth2.jwt;
-
-import com.nimbusds.jose.JWSAlgorithm;
-import com.nimbusds.jose.JWSHeader;
-import com.nimbusds.jose.KeySourceException;
-import com.nimbusds.jose.jwk.JWKMatcher;
-import com.nimbusds.jose.jwk.JWKSelector;
-import com.nimbusds.jose.proc.JWSVerificationKeySelector;
-
-/**
- * @author Rob Winch
- * @since 5.1
- */
-class JWKSelectorFactory {
-	private final DelegateSelectorFactory delegate;
-
-	JWKSelectorFactory(JWSAlgorithm expectedJWSAlgorithm) {
-		this.delegate = new DelegateSelectorFactory(expectedJWSAlgorithm);
-	}
-
-	JWKSelector createSelector(JWSHeader jwsHeader) {
-		return new JWKSelector(this.delegate.createJWKMatcher(jwsHeader));
-	}
-
-	/**
-	 * Used to expose the protected {@link #createJWKMatcher(JWSHeader)} method.
-	 */
-	private static class DelegateSelectorFactory extends JWSVerificationKeySelector {
-		/**
-		 * Creates a new JWS verification key selector.
-		 *
-		 * @param jwsAlg    The expected JWS algorithm for the objects to be
-		 *                  verified. Must not be {@code null}.
-		 */
-		public DelegateSelectorFactory(JWSAlgorithm jwsAlg) {
-			super(jwsAlg, (jwkSelector, context) -> {
-				throw new KeySourceException("JWKSelectorFactory is only intended for creating a selector");
-			});
-		}
-
-		@Override
-		public JWKMatcher createJWKMatcher(JWSHeader jwsHeader) {
-			return super.createJWKMatcher(jwsHeader);
-		}
-	}
-}

+ 262 - 53
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,26 +19,32 @@ import java.security.interfaces.RSAPublicKey;
 import java.time.Instant;
 import java.util.Collections;
 import java.util.LinkedHashMap;
-import java.util.List;
 import java.util.Map;
+import java.util.function.Function;
 
 import com.nimbusds.jose.JOSEException;
 import com.nimbusds.jose.JWSAlgorithm;
+import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jose.jwk.JWK;
+import com.nimbusds.jose.jwk.JWKMatcher;
 import com.nimbusds.jose.jwk.JWKSelector;
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.RSAKey;
 import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
+import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.proc.BadJOSEException;
+import com.nimbusds.jose.proc.JWKSecurityContext;
 import com.nimbusds.jose.proc.JWSKeySelector;
 import com.nimbusds.jose.proc.JWSVerificationKeySelector;
+import com.nimbusds.jose.proc.SecurityContext;
 import com.nimbusds.jwt.JWT;
 import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.JWTParser;
 import com.nimbusds.jwt.SignedJWT;
 import com.nimbusds.jwt.proc.DefaultJWTProcessor;
 import com.nimbusds.jwt.proc.JWTProcessor;
+import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 
 import org.springframework.core.convert.converter.Converter;
@@ -46,6 +52,7 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
 import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
 import org.springframework.util.Assert;
+import org.springframework.web.reactive.function.client.WebClient;
 
 /**
  * An implementation of a {@link ReactiveJwtDecoder} that &quot;decodes&quot; a
@@ -65,31 +72,14 @@ import org.springframework.util.Assert;
  * @see <a target="_blank" href="https://connect2id.com/products/nimbus-jose-jwt">Nimbus JOSE + JWT SDK</a>
  */
 public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
-	private final JWTProcessor<JWKContext> jwtProcessor;
-
-	private final ReactiveJWKSource reactiveJwkSource;
-
-	private final JWKSelectorFactory jwkSelectorFactory;
+	private final Converter<SignedJWT, Mono<JWTClaimsSet>> jwtProcessor;
 
 	private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
 	private Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = MappedJwtClaimSetConverter
 			.withDefaults(Collections.emptyMap());
 
 	public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
-		JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
-
-		RSAKey rsaKey = rsaKey(publicKey);
-		JWKSet jwkSet = new JWKSet(rsaKey);
-		JWKSource jwkSource = new ImmutableJWKSet<>(jwkSet);
-		JWSKeySelector<JWKContext> jwsKeySelector =
-				new JWSVerificationKeySelector<>(algorithm, jwkSource);
-		DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>();
-		jwtProcessor.setJWSKeySelector(jwsKeySelector);
-		jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
-
-		this.jwtProcessor = jwtProcessor;
-		this.reactiveJwkSource = new ReactiveJWKSourceAdapter(jwkSource);
-		this.jwkSelectorFactory = new JWKSelectorFactory(algorithm);
+		this.jwtProcessor = withPublicKey(publicKey).processor();
 	}
 
 	/**
@@ -98,22 +88,11 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 	 * @param jwkSetUrl the JSON Web Key (JWK) Set {@code URL}
 	 */
 	public NimbusReactiveJwtDecoder(String jwkSetUrl) {
-		Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty");
-		String jwsAlgorithm = JwsAlgorithms.RS256;
-		JWSAlgorithm algorithm = JWSAlgorithm.parse(jwsAlgorithm);
-		JWKSource jwkSource = new JWKContextJWKSource();
-		JWSKeySelector<JWKContext> jwsKeySelector =
-				new JWSVerificationKeySelector<>(algorithm, jwkSource);
-
-		DefaultJWTProcessor<JWKContext> jwtProcessor = new DefaultJWTProcessor<>();
-		jwtProcessor.setJWSKeySelector(jwsKeySelector);
-		jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
-		this.jwtProcessor = jwtProcessor;
-
-		this.reactiveJwkSource = new ReactiveRemoteJWKSource(jwkSetUrl);
-
-		this.jwkSelectorFactory = new JWKSelectorFactory(algorithm);
+		this.jwtProcessor = withJwkSetUri(jwkSetUrl).processor();
+	}
 
+	public NimbusReactiveJwtDecoder(Converter<SignedJWT, Mono<JWTClaimsSet>> jwtProcessor) {
+		this.jwtProcessor = jwtProcessor;
 	}
 
 	/**
@@ -155,11 +134,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 
 	private Mono<Jwt> decode(SignedJWT parsedToken) {
 		try {
-			JWKSelector selector = this.jwkSelectorFactory
-					.createSelector(parsedToken.getHeader());
-			return this.reactiveJwkSource.get(selector)
-				.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
-				.map(jwkList -> createClaimsSet(parsedToken, jwkList))
+			return this.jwtProcessor.convert(parsedToken)
 				.map(set -> createJwt(parsedToken, set))
 				.map(this::validateJwt)
 				.onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
@@ -168,15 +143,6 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 		}
 	}
 
-	private JWTClaimsSet createClaimsSet(JWT parsedToken, List<JWK> jwkList) {
-		try {
-			return this.jwtProcessor.process(parsedToken, new JWKContext(jwkList));
-		}
-		catch (BadJOSEException | JOSEException e) {
-			throw new JwtException("Failed to validate the token", e);
-		}
-	}
-
 	private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) {
 		Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject());
 		Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims());
@@ -197,8 +163,251 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 		return jwt;
 	}
 
-	private static RSAKey rsaKey(RSAPublicKey publicKey) {
-		return new RSAKey.Builder(publicKey)
-				.build();
+	/**
+	 * Use the given
+	 * <a href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a> uri to validate JWTs.
+	 *
+	 * @param jwkSetUri the JWK Set uri to use
+	 * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
+	 *
+	 * @since 5.2
+	 */
+	public static JwkSetUriReactiveJwtDecoderBuilder withJwkSetUri(String jwkSetUri) {
+		return new JwkSetUriReactiveJwtDecoderBuilder(jwkSetUri);
+	}
+
+	/**
+	 * Use the given public key to validate JWTs
+	 *
+	 * @param key the public key to use
+	 * @return a {@link PublicKeyReactiveJwtDecoderBuilder} for further configurations
+	 *
+	 * @since 5.2
+	 */
+	public static PublicKeyReactiveJwtDecoderBuilder withPublicKey(RSAPublicKey key) {
+		return new PublicKeyReactiveJwtDecoderBuilder(key);
+	}
+
+	/**
+	 * Use the given {@link Function} to validate JWTs
+	 *
+	 * @param source the {@link Function}
+	 * @return a {@link JwkSourceReactiveJwtDecoderBuilder} for further configurations
+	 *
+	 * @since 5.2
+	 */
+	public static JwkSourceReactiveJwtDecoderBuilder withJwkSource(Function<JWT, Flux<JWK>> source) {
+		return new JwkSourceReactiveJwtDecoderBuilder(source);
+	}
+
+	/**
+	 * A builder for creating {@link NimbusReactiveJwtDecoder} instances based on a
+	 * <a target="_blank" href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a> uri.
+	 *
+	 * @since 5.2
+	 */
+	public static final class JwkSetUriReactiveJwtDecoderBuilder {
+
+		private String jwkSetUri;
+		private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.RS256;
+		private WebClient webClient = WebClient.create();
+
+		private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
+			Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
+			this.jwkSetUri = jwkSetUri;
+		}
+
+		/**
+		 * Use the given signing
+		 * <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>.
+		 *
+		 * @param jwsAlgorithm the algorithm to use
+		 * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
+		 */
+		public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithm(String jwsAlgorithm) {
+			Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty");
+			this.jwsAlgorithm = JWSAlgorithm.parse(jwsAlgorithm);
+			return this;
+		}
+
+		/**
+		 * Use the given {@link WebClient} to coordinate with the authorization servers indicated in the
+		 * <a href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a> uri
+		 * as well as the
+		 * <a href="http://openid.net/specs/openid-connect-core-1_0.html#IssuerIdentifier">Issuer</a>.
+		 *
+		 * @param webClient
+		 * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
+		 */
+		public JwkSetUriReactiveJwtDecoderBuilder webClient(WebClient webClient) {
+			Assert.notNull(webClient, "webClient cannot be null");
+			this.webClient = webClient;
+			return this;
+		}
+
+		/**
+		 * Build the configured {@link NimbusReactiveJwtDecoder}.
+		 *
+		 * @return the configured {@link NimbusReactiveJwtDecoder}
+		 */
+		public NimbusReactiveJwtDecoder build() {
+			return new NimbusReactiveJwtDecoder(processor());
+		}
+
+		Converter<SignedJWT, Mono<JWTClaimsSet>> processor() {
+			JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
+
+			JWSKeySelector<JWKSecurityContext> jwsKeySelector =
+					new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
+			DefaultJWTProcessor<JWKSecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
+			jwtProcessor.setJWSKeySelector(jwsKeySelector);
+			jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
+
+			ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
+			source.setWebClient(this.webClient);
+
+			return signedJWT -> {
+				JWKSelector selector = createSelector(signedJWT.getHeader());
+				return source.get(selector)
+						.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
+						.map(jwkList -> createClaimsSet(jwtProcessor, signedJWT, new JWKSecurityContext(jwkList)));
+			};
+		}
+
+		private JWKSelector createSelector(JWSHeader header) {
+			if (!this.jwsAlgorithm.equals(header.getAlgorithm())) {
+				throw new JwtException("Unsupported algorithm of " + header.getAlgorithm());
+			}
+
+			return new JWKSelector(JWKMatcher.forJWSHeader(header));
+		}
+	}
+
+	/**
+	 * A builder for creating Nimbus {@link JWTProcessor} instances based on a
+	 * public key.
+	 *
+	 * @since 5.2
+	 */
+	public static final class PublicKeyReactiveJwtDecoderBuilder {
+		private JWSAlgorithm jwsAlgorithm;
+		private RSAKey key;
+
+		private PublicKeyReactiveJwtDecoderBuilder(RSAPublicKey key) {
+			Assert.notNull(key, "key cannot be null");
+			this.jwsAlgorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
+			this.key = rsaKey(key);
+		}
+
+		private static RSAKey rsaKey(RSAPublicKey publicKey) {
+			return new RSAKey.Builder(publicKey)
+					.build();
+		}
+
+		/**
+		 * Use the given signing
+		 * <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>.
+		 * The value should be one of
+		 * <a href="https://tools.ietf.org/html/rfc7518#section-3.3" target="_blank">RS256, RS384, or RS512</a>.
+		 *
+		 * @param jwsAlgorithm the algorithm to use
+		 * @return a {@link PublicKeyReactiveJwtDecoderBuilder} for further configurations
+		 */
+		public PublicKeyReactiveJwtDecoderBuilder jwsAlgorithm(String jwsAlgorithm) {
+			Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty");
+			this.jwsAlgorithm = JWSAlgorithm.parse(jwsAlgorithm);
+			return this;
+		}
+
+		/**
+		 * Build the configured {@link NimbusReactiveJwtDecoder}.
+		 *
+		 * @return the configured {@link NimbusReactiveJwtDecoder}
+		 */
+		public NimbusReactiveJwtDecoder build() {
+			return new NimbusReactiveJwtDecoder(processor());
+		}
+
+		Converter<SignedJWT, Mono<JWTClaimsSet>> processor() {
+			if (!JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm)) {
+				throw new IllegalStateException("The provided key is of type RSA; " +
+						"however the signature algorithm is of some other type: " +
+						this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512.");
+			}
+
+			JWKSet jwkSet = new JWKSet(this.key);
+			JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(jwkSet);
+			JWSKeySelector<SecurityContext> jwsKeySelector =
+					new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
+			DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
+			jwtProcessor.setJWSKeySelector(jwsKeySelector);
+
+			// Spring Security validates the claim set independent from Nimbus
+			jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { });
+
+			return signedJWT -> Mono.just(signedJWT).map(jwt -> createClaimsSet(jwtProcessor, jwt, null));
+		}
+	}
+
+	/**
+	 * A builder for creating {@link NimbusReactiveJwtDecoder} instances.
+	 *
+	 * @since 5.2
+	 */
+	public static final class JwkSourceReactiveJwtDecoderBuilder {
+		private Function<JWT, Flux<JWK>> jwkSource;
+		private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.RS256;
+
+		private JwkSourceReactiveJwtDecoderBuilder(Function<JWT, Flux<JWK>> jwkSource) {
+			Assert.notNull(jwkSource, "jwkSource cannot be empty");
+			this.jwkSource = jwkSource;
+		}
+
+		/**
+		 * Use the given signing
+		 * <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>.
+		 *
+		 * @param jwsAlgorithm the algorithm to use
+		 * @return a {@link JwkSourceReactiveJwtDecoderBuilder} for further configurations
+		 */
+		public JwkSourceReactiveJwtDecoderBuilder jwsAlgorithm(String jwsAlgorithm) {
+			Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty");
+			this.jwsAlgorithm = JWSAlgorithm.parse(jwsAlgorithm);
+			return this;
+		}
+
+		/**
+		 * Build the configured {@link NimbusReactiveJwtDecoder}.
+		 *
+		 * @return the configured {@link NimbusReactiveJwtDecoder}
+		 */
+		public NimbusReactiveJwtDecoder build() {
+			return new NimbusReactiveJwtDecoder(processor());
+		}
+
+		Converter<SignedJWT, Mono<JWTClaimsSet>> processor() {
+			JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
+			JWSKeySelector<JWKSecurityContext> jwsKeySelector =
+					new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
+			DefaultJWTProcessor<JWKSecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
+			jwtProcessor.setJWSKeySelector(jwsKeySelector);
+			jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
+
+			return signedJWT ->
+					this.jwkSource.apply(signedJWT)
+							.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
+							.collectList()
+							.map(jwks -> createClaimsSet(jwtProcessor, signedJWT, new JWKSecurityContext(jwks)));
+		}
+	}
+
+	private static <C extends SecurityContext> JWTClaimsSet createClaimsSet(JWTProcessor<C> jwtProcessor,
+																			JWT parsedToken, C context) {
+		try {
+			return jwtProcessor.process(parsedToken, context);
+		}
+		catch (BadJOSEException | JOSEException e) {
+			throw new JwtException("Failed to validate the token", e);
+		}
 	}
 }

+ 16 - 9
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,19 +16,21 @@
 
 package org.springframework.security.oauth2.jwt;
 
+import java.text.ParseException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicReference;
+
 import com.nimbusds.jose.RemoteKeySourceException;
 import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.JWKMatcher;
 import com.nimbusds.jose.jwk.JWKSelector;
 import com.nimbusds.jose.jwk.JWKSet;
-import org.springframework.web.reactive.function.client.WebClient;
 import reactor.core.publisher.Mono;
 
-import java.text.ParseException;
-import java.util.Collections;
-import java.util.List;
-import java.util.Set;
-import java.util.concurrent.atomic.AtomicReference;
+import org.springframework.util.Assert;
+import org.springframework.web.reactive.function.client.WebClient;
 
 /**
  * @author Rob Winch
@@ -45,14 +47,15 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
 	private final String jwkSetURL;
 
 	ReactiveRemoteJWKSource(String jwkSetURL) {
+		Assert.hasText(jwkSetURL, "jwkSetURL cannot be empty");
 		this.jwkSetURL = jwkSetURL;
 	}
 
 	public Mono<List<JWK>> get(JWKSelector jwkSelector) {
 		return this.cachedJWKSet.get()
-				.switchIfEmpty(getJWKSet())
+				.switchIfEmpty(Mono.defer(() -> getJWKSet()))
 				.flatMap(jwkSet -> get(jwkSelector, jwkSet))
-				.switchIfEmpty(getJWKSet().map(jwkSet -> jwkSelector.select(jwkSet)));
+				.switchIfEmpty(Mono.defer(() -> getJWKSet().map(jwkSet -> jwkSelector.select(jwkSet))));
 	}
 
 	private Mono<List<JWK>> get(JWKSelector jwkSelector, JWKSet jwkSet) {
@@ -133,4 +136,8 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
 		}
 		return null; // No kid in matcher
 	}
+
+	public void setWebClient(WebClient webClient) {
+		this.webClient = webClient;
+	}
 }

+ 0 - 42
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JWKContextJWKSourceTests.java

@@ -1,42 +0,0 @@
-/*
- * Copyright 2002-2018 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.oauth2.jwt;
-
-import com.nimbusds.jose.jwk.JWK;
-import org.junit.Test;
-
-import java.util.Arrays;
-
-import static org.assertj.core.api.Assertions.*;
-import static org.mockito.Mockito.mock;
-
-/**
- * @author Rob Winch
- * @since 5.1
- */
-public class JWKContextJWKSourceTests {
-	private JWKContextJWKSource source = new JWKContextJWKSource();
-
-	@Test
-	public void getWhenKeysNotEmptyThenContainsKeys() {
-		JWK key = mock(JWK.class);
-		JWKContext jwkContext = new JWKContext(Arrays.asList(key));
-
-		assertThat(this.source.get(null, jwkContext)).containsOnly(key);
-	}
-
-}

+ 0 - 54
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JWKContextTests.java

@@ -1,54 +0,0 @@
-/*
- * Copyright 2002-2018 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.oauth2.jwt;
-
-import com.nimbusds.jose.jwk.JWK;
-import org.junit.Test;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-import static org.assertj.core.api.Assertions.*;
-import static org.mockito.Mockito.mock;
-
-/**
- * @author Rob Winch
- * @since 5.1
- */
-public class JWKContextTests {
-
-	@Test
-	public void constructorWhenNullThenIllegalArgumentException() {
-		List<JWK> jwkList = null;
-		assertThatCode(() -> new JWKContext(jwkList))
-				.isInstanceOf(IllegalArgumentException.class);
-	}
-
-	@Test
-	public void getJwkListWhenEmpty() {
-		JWKContext jwkContext = new JWKContext(Collections.emptyList());
-		assertThat(jwkContext.getJwkList()).isEmpty();
-	}
-
-	@Test
-	public void getJwkListWhenNotEmpty() {
-		JWK key = mock(JWK.class);
-		JWKContext jwkContext = new JWKContext(Arrays.asList(key));
-		assertThat(jwkContext.getJwkList()).containsOnly(key);
-	}
-}

+ 139 - 4
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,30 +18,45 @@ package org.springframework.security.oauth2.jwt;
 
 import java.net.UnknownHostException;
 import java.security.KeyFactory;
+import java.security.NoSuchAlgorithmException;
 import java.security.interfaces.RSAPublicKey;
+import java.security.spec.EncodedKeySpec;
+import java.security.spec.InvalidKeySpecException;
 import java.security.spec.X509EncodedKeySpec;
+import java.text.ParseException;
 import java.time.Instant;
 import java.util.Base64;
 import java.util.Collections;
 import java.util.Map;
 
+import com.nimbusds.jose.jwk.JWKSet;
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
 import org.junit.After;
 import org.junit.Before;
+import org.junit.BeforeClass;
 import org.junit.Test;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
+import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
+import org.springframework.web.reactive.function.client.WebClient;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSetUri;
+import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSource;
+import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withPublicKey;
 
 /**
  * @author Rob Winch
@@ -50,11 +65,8 @@ import static org.mockito.Mockito.when;
 public class NimbusReactiveJwtDecoderTests {
 
 	private String expired = "eyJraWQiOiJrZXktaWQtMSIsImFsZyI6IlJTMjU2In0.eyJzY29wZSI6Im1lc3NhZ2U6cmVhZCIsImV4cCI6MTUyOTkzNzYzMX0.Dt5jFOKkB8zAmjciwvlGkj4LNStXWH0HNIfr8YYajIthBIpVgY5Hg_JL8GBmUFzKDgyusT0q60OOg8_Pdi4Lu-VTWyYutLSlNUNayMlyBaVEWfyZJnh2_OwMZr1vRys6HF-o1qZldhwcfvczHg61LwPa1ISoqaAltDTzBu9cGISz2iBUCuR0x71QhbuRNyJdjsyS96NqiM_TspyiOSxmlNch2oAef1MssOQ23CrKilIvEDsz_zk5H94q7rH0giWGdEHCENESsTJS0zvzH6r2xIWjd5WnihFpCPkwznEayxaEhrdvJqT_ceyXCIfY4m3vujPQHNDG0UshpwvDuEbPUg";
-
 	private String messageReadToken = "eyJraWQiOiJrZXktaWQtMSIsImFsZyI6IlJTMjU2In0.eyJzY29wZSI6Im1lc3NhZ2U6cmVhZCIsImV4cCI6OTIyMzM3MjAwNjA5NjM3NX0.bnQ8IJDXmQbmIXWku0YT1HOyV_3d0iQSA_0W2CmPyELhsxFETzBEEcZ0v0xCBiswDT51rwD83wbX3YXxb84fM64AhpU8wWOxLjha4J6HJX2JnlG47ydaAVD7eWGSYTavyyQ-CwUjQWrfMVcObFZLYG11ydzRYOR9-aiHcK3AobcTcS8jZFeI8EGQV_Cd3IJ018uFCf6VnXLv7eV2kRt08Go2RiPLW47ExvD7Dzzz_wDBKfb4pNem7fDvuzB3UPcp5m9QvLZicnbS_6AvDi6P1y_DFJf-1T5gkGmX5piDH1L1jg2Yl6tjmXbk5B3VhsyjJuXE6gzq1d-xie0Z1NVOxw";
-
 	private String unsignedToken = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOi0yMDMzMjI0OTcsImp0aSI6IjEyMyIsInR5cCI6IkpXVCJ9.";
-
 	private String jwkSet =
 		"{\n"
 		+ "   \"keys\":[\n"
@@ -67,10 +79,22 @@ public class NimbusReactiveJwtDecoderTests {
 		+ "      }\n"
 		+ "   ]\n"
 		+ "}";
+	private String jwkSetUri = "http://issuer/certs";
+
+	private String rsa512 = "eyJhbGciOiJSUzUxMiJ9.eyJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJleHAiOjE5NzQzMjYxMTl9.LKAx-60EBfD7jC1jb1eKcjO4uLvf3ssISV-8tN-qp7gAjSvKvj4YA9-V2mIb6jcS1X_xGmNy6EIimZXpWaBR3nJmeu-jpe85u4WaW2Ztr8ecAi-dTO7ZozwdtljKuBKKvj4u1nF70zyCNl15AozSG0W1ASrjUuWrJtfyDG6WoZ8VfNMuhtU-xUYUFvscmeZKUYQcJ1KS-oV5tHeF8aNiwQoiPC_9KXCOZtNEJFdq6-uzFdHxvOP2yex5Gbmg5hXonauIFXG2ZPPGdXzm-5xkhBpgM8U7A_6wb3So8wBvLYYm2245QUump63AJRAy8tQpwt4n9MvQxQgS3z9R-NK92A";
+	private String rsa256 = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXN1YmplY3QiLCJleHAiOjE5NzQzMjYzMzl9.CT-H2OWEqmSs1NWmnta5ealLFvM8OlbQTjGhfRcKLNxrTrzsOkqBJl-AN3k16BQU7mS32o744TiiZ29NcDlxPsr1MqTlN86-dobPiuNIDLp3A1bOVdXMcVFuMYkrNv0yW0tGS9OjEqsCCuZDkZ1by6AhsHLbGwRY-6AQdcRouZygGpOQu1hNun5j8q5DpSTY4AXKARIFlF-O3OpVbPJ0ebr3Ki-i3U9p_55H0e4-wx2bqcApWlqgofl1I8NKWacbhZgn81iibup2W7E0CzCzh71u1Mcy3xk1sYePx-dwcxJnHmxJReBBWjJZEAeCrkbnn_OCuo2fA-EQyNJtlN5F2w";
+	private String publicKey = "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAq4yKxb6SNePdDmQi9xFCrP6QvHosErQzryknQTTTffs0t3cy3Er3lIceuhZ7yQNSCDfPFqG8GoyoKhuChRiA5D+J2ab7bqTa1QJKfnCyERoscftgN2fXPHjHoiKbpGV2tMVw8mXl//tePOAiKbMJaBUnlAvJgkk1rVm08dSwpLC1sr2M19euf9jwnRGkMRZuhp9iCPgECRke5T8Ixpv0uQjSmGHnWUKTFlbj8sM83suROR1Ue64JSGScANc5vk3huJ/J97qTC+K2oKj6L8d9O8dpc4obijEOJwpydNvTYDgbiivYeSB00KS9jlBkQ5B2QqLvLVEygDl3dp59nGx6YQIDAQAB";
 
 	private MockWebServer server;
 	private NimbusReactiveJwtDecoder decoder;
 
+	private static KeyFactory kf;
+
+	@BeforeClass
+	public static void keyFactory() throws NoSuchAlgorithmException {
+		kf = KeyFactory.getInstance("RSA");
+	}
+
 	@Before
 	public void setup() throws Exception {
 		this.server = new MockWebServer();
@@ -205,4 +229,115 @@ public class NimbusReactiveJwtDecoderTests {
 		assertThatCode(() -> this.decoder.setClaimSetConverter(null))
 				.isInstanceOf(IllegalArgumentException.class);
 	}
+
+	@Test
+	public void withJwkSetUriWhenNullOrEmptyThenThrowsException() {
+		assertThatCode(() -> withJwkSetUri(null)).isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void jwsAlgorithmWhenNullOrEmptyThenThrowsException() {
+		NimbusReactiveJwtDecoder.JwkSetUriReactiveJwtDecoderBuilder builder = withJwkSetUri(this.jwkSetUri);
+		assertThatCode(() -> builder.jwsAlgorithm(null)).isInstanceOf(IllegalArgumentException.class);
+		assertThatCode(() -> builder.jwsAlgorithm("")).isInstanceOf(IllegalArgumentException.class);
+		assertThatCode(() -> builder.jwsAlgorithm("RS4096")).doesNotThrowAnyException();
+	}
+
+	@Test
+	public void restOperationsWhenNullThenThrowsException() {
+		NimbusReactiveJwtDecoder.JwkSetUriReactiveJwtDecoderBuilder builder = withJwkSetUri(this.jwkSetUri);
+		assertThatCode(() -> builder.webClient(null)).isInstanceOf(IllegalArgumentException.class);
+	}
+
+	// gh-5603
+	@Test
+	public void decodeWhenSignedThenOk() {
+		WebClient webClient = mockJwkSetResponse(this.jwkSet);
+		NimbusReactiveJwtDecoder decoder = withJwkSetUri(this.jwkSetUri).webClient(webClient).build();
+		assertThat(decoder.decode(messageReadToken).block())
+				.extracting(Jwt::getExpiresAt)
+				.isNotNull();
+		verify(webClient).get();
+	}
+
+	@Test
+	public void withPublicKeyWhenNullThenThrowsException() {
+		assertThatThrownBy(() -> withPublicKey(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void buildWhenSignatureAlgorithmMismatchesKeyTypeThenThrowsException() {
+		assertThatCode(() -> withPublicKey(key())
+				.jwsAlgorithm(JwsAlgorithms.ES256)
+				.build())
+				.isInstanceOf(IllegalStateException.class);
+	}
+
+	@Test
+	public void decodeWhenUsingPublicKeyThenSuccessfullyDecodes() throws Exception {
+		NimbusReactiveJwtDecoder decoder = withPublicKey(key()).build();
+		assertThat(decoder.decode(this.rsa256).block())
+				.extracting(Jwt::getSubject)
+				.isEqualTo("test-subject");
+	}
+
+	@Test
+	public void decodeWhenUsingPublicKeyWithRs512ThenSuccessfullyDecodes() throws Exception {
+		NimbusReactiveJwtDecoder decoder =
+				withPublicKey(key()).jwsAlgorithm(JwsAlgorithms.RS512).build();
+		assertThat(decoder.decode(this.rsa512).block())
+				.extracting(Jwt::getSubject)
+				.isEqualTo("test-subject");
+	}
+
+	@Test
+	public void decodeWhenSignatureMismatchesAlgorithmThenThrowsException() throws Exception {
+		NimbusReactiveJwtDecoder decoder =
+				withPublicKey(key()).jwsAlgorithm(JwsAlgorithms.RS512).build();
+		assertThatCode(() -> decoder.decode(this.rsa256).block())
+				.isInstanceOf(JwtException.class);
+	}
+
+	@Test
+	public void withJwkSourceWhenNullThenThrowsException() {
+		assertThatCode(() -> withJwkSource(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void decodeWhenCustomJwkSourceResolutionThenDecodes() {
+		NimbusReactiveJwtDecoder decoder =
+				withJwkSource(jwt -> Flux.fromIterable(parseJWKSet(this.jwkSet).getKeys()))
+						.build();
+
+		assertThat(decoder.decode(this.messageReadToken).block())
+				.extracting(Jwt::getExpiresAt)
+				.isNotNull();
+	}
+
+	private JWKSet parseJWKSet(String jwkSet) {
+		try {
+			return JWKSet.parse(jwkSet);
+		} catch (ParseException e) {
+			throw new IllegalArgumentException(e);
+		}
+	}
+
+	private RSAPublicKey key() throws InvalidKeySpecException {
+		byte[] decoded = Base64.getDecoder().decode(this.publicKey.getBytes());
+		EncodedKeySpec spec = new X509EncodedKeySpec(decoded);
+		return (RSAPublicKey) kf.generatePublic(spec);
+	}
+
+	private static WebClient mockJwkSetResponse(String response) {
+		WebClient real = WebClient.builder().build();
+		WebClient.RequestHeadersUriSpec spec = spy(real.get());
+		WebClient webClient = spy(WebClient.class);
+		when(webClient.get()).thenReturn(spec);
+		WebClient.ResponseSpec responseSpec = mock(WebClient.ResponseSpec.class);
+		when(responseSpec.bodyToMono(String.class)).thenReturn(Mono.just(response));
+		when(spec.retrieve()).thenReturn(responseSpec);
+		return webClient;
+	}
 }