Browse Source

Share JWKSource Instances

Closes gh-10312
Josh Cummings 3 năm trước cách đây
mục cha
commit
7b599d4770

+ 0 - 3
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java

@@ -1250,7 +1250,6 @@ public class OAuth2ResourceServerConfigurerTests {
 		String jwtThree = jwtFromIssuer(issuerThree);
 		mockWebServer(String.format(metadata, issuerOne, issuerOne));
 		mockWebServer(jwkSet);
-		mockWebServer(jwkSet);
 		// @formatter:off
 		this.mvc.perform(get("/authenticated").with(bearerToken(jwtOne)))
 				.andExpect(status().isOk())
@@ -1258,7 +1257,6 @@ public class OAuth2ResourceServerConfigurerTests {
 		// @formatter:on
 		mockWebServer(String.format(metadata, issuerTwo, issuerTwo));
 		mockWebServer(jwkSet);
-		mockWebServer(jwkSet);
 		// @formatter:off
 		this.mvc.perform(get("/authenticated").with(bearerToken(jwtTwo)))
 				.andExpect(status().isOk())
@@ -1266,7 +1264,6 @@ public class OAuth2ResourceServerConfigurerTests {
 		// @formatter:on
 		mockWebServer(String.format(metadata, issuerThree, issuerThree));
 		mockWebServer(jwkSet);
-		mockWebServer(jwkSet);
 		// @formatter:off
 		this.mvc.perform(get("/authenticated").with(bearerToken(jwtThree)))
 				.andExpect(status().isUnauthorized())

+ 0 - 3
config/src/test/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests.java

@@ -707,21 +707,18 @@ public class OAuth2ResourceServerBeanDefinitionParserTests {
 		String jwtThree = jwtFromIssuer(issuerThree);
 		mockWebServer(String.format(metadata, issuerOne, issuerOne));
 		mockWebServer(jwkSet);
-		mockWebServer(jwkSet);
 		// @formatter:off
 		this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtOne))
 				.andExpect(status().isNotFound());
 		// @formatter:on
 		mockWebServer(String.format(metadata, issuerTwo, issuerTwo));
 		mockWebServer(jwkSet);
-		mockWebServer(jwkSet);
 		// @formatter:off
 		this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtTwo))
 				.andExpect(status().isNotFound());
 		// @formatter:on
 		mockWebServer(String.format(metadata, issuerThree, issuerThree));
 		mockWebServer(jwkSet);
-		mockWebServer(jwkSet);
 		// @formatter:off
 		this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtThree))
 				.andExpect(status().isUnauthorized())

+ 20 - 2
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoderProviderConfigurationUtils.java

@@ -31,7 +31,10 @@ import com.nimbusds.jose.jwk.JWKSelector;
 import com.nimbusds.jose.jwk.KeyType;
 import com.nimbusds.jose.jwk.KeyUse;
 import com.nimbusds.jose.jwk.source.JWKSource;
+import com.nimbusds.jose.proc.JWSKeySelector;
+import com.nimbusds.jose.proc.JWSVerificationKeySelector;
 import com.nimbusds.jose.proc.SecurityContext;
+import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
 
 import org.springframework.core.ParameterizedTypeReference;
 import org.springframework.http.RequestEntity;
@@ -82,7 +85,17 @@ final class JwtDecoderProviderConfigurationUtils {
 				+ "\" provided in the configuration did not " + "match the requested issuer \"" + issuer + "\"");
 	}
 
-	static Set<SignatureAlgorithm> getSignatureAlgorithms(JWKSource<SecurityContext> jwkSource) {
+	static <C extends SecurityContext> void addJWSAlgorithms(ConfigurableJWTProcessor<C> jwtProcessor) {
+		JWSKeySelector<C> selector = jwtProcessor.getJWSKeySelector();
+		if (selector instanceof JWSVerificationKeySelector) {
+			JWKSource<C> jwkSource = ((JWSVerificationKeySelector<C>) selector).getJWKSource();
+			Set<JWSAlgorithm> algorithms = getJWSAlgorithms(jwkSource);
+			selector = new JWSVerificationKeySelector<>(algorithms, jwkSource);
+			jwtProcessor.setJWSKeySelector(selector);
+		}
+	}
+
+	static <C extends SecurityContext> Set<JWSAlgorithm> getJWSAlgorithms(JWKSource<C> jwkSource) {
 		JWKMatcher jwkMatcher = new JWKMatcher.Builder().publicOnly(true).keyUses(KeyUse.SIGNATURE, null)
 				.keyTypes(KeyType.RSA, KeyType.EC).build();
 		Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
@@ -106,6 +119,12 @@ final class JwtDecoderProviderConfigurationUtils {
 		catch (KeySourceException ex) {
 			throw new IllegalStateException(ex);
 		}
+		Assert.notEmpty(jwsAlgorithms, "Failed to find any algorithms from the JWK set");
+		return jwsAlgorithms;
+	}
+
+	static Set<SignatureAlgorithm> getSignatureAlgorithms(JWKSource<SecurityContext> jwkSource) {
+		Set<JWSAlgorithm> jwsAlgorithms = getJWSAlgorithms(jwkSource);
 		Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
 		for (JWSAlgorithm jwsAlgorithm : jwsAlgorithms) {
 			SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.from(jwsAlgorithm.getName());
@@ -113,7 +132,6 @@ final class JwtDecoderProviderConfigurationUtils {
 				signatureAlgorithms.add(signatureAlgorithm);
 			}
 		}
-		Assert.notEmpty(signatureAlgorithms, "Failed to find any algorithms from the JWK set");
 		return signatureAlgorithms;
 	}
 

+ 1 - 21
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtDecoders.java

@@ -16,17 +16,9 @@
 
 package org.springframework.security.oauth2.jwt;
 
-import java.io.IOException;
-import java.io.UncheckedIOException;
-import java.net.URL;
 import java.util.Map;
-import java.util.Set;
-
-import com.nimbusds.jose.jwk.source.RemoteJWKSet;
-import com.nimbusds.jose.proc.SecurityContext;
 
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
-import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.util.Assert;
 
 /**
@@ -117,22 +109,10 @@ public final class JwtDecoders {
 		JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer);
 		OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefaultWithIssuer(issuer);
 		String jwkSetUri = configuration.get("jwks_uri").toString();
-		RemoteJWKSet<SecurityContext> jwkSource = new RemoteJWKSet<>(url(jwkSetUri));
-		Set<SignatureAlgorithm> signatureAlgorithms = JwtDecoderProviderConfigurationUtils
-				.getSignatureAlgorithms(jwkSource);
 		NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(jwkSetUri)
-				.jwsAlgorithms((algs) -> algs.addAll(signatureAlgorithms)).build();
+				.jwtProcessorCustomizer(JwtDecoderProviderConfigurationUtils::addJWSAlgorithms).build();
 		jwtDecoder.setJwtValidator(jwtValidator);
 		return jwtDecoder;
 	}
 
-	private static URL url(String url) {
-		try {
-			return new URL(url);
-		}
-		catch (IOException ex) {
-			throw new UncheckedIOException(ex);
-		}
-	}
-
 }

+ 21 - 6
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

@@ -17,12 +17,14 @@
 package org.springframework.security.oauth2.jwt;
 
 import java.security.interfaces.RSAPublicKey;
+import java.time.Duration;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import java.util.function.Function;
 
@@ -274,19 +276,20 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 	 */
 	public static final class JwkSetUriReactiveJwtDecoderBuilder {
 
+		private static final Duration FOREVER = Duration.ofMillis(Long.MAX_VALUE);
+
 		private final String jwkSetUri;
 
 		private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
 
 		private WebClient webClient = WebClient.create();
 
-		private Consumer<ConfigurableJWTProcessor<JWKSecurityContext>> jwtProcessorCustomizer;
+		private BiFunction<ReactiveRemoteJWKSource, ConfigurableJWTProcessor<JWKSecurityContext>, Mono<ConfigurableJWTProcessor<JWKSecurityContext>>> jwtProcessorCustomizer;
 
 		private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
 			Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
 			this.jwkSetUri = jwkSetUri;
-			this.jwtProcessorCustomizer = (processor) -> {
-			};
+			this.jwtProcessorCustomizer = (source, processor) -> Mono.just(processor);
 		}
 
 		/**
@@ -342,6 +345,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 		public JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer(
 				Consumer<ConfigurableJWTProcessor<JWKSecurityContext>> jwtProcessorCustomizer) {
 			Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null");
+			this.jwtProcessorCustomizer = (source, processor) -> {
+				jwtProcessorCustomizer.accept(processor);
+				return Mono.just(processor);
+			};
+			return this;
+		}
+
+		JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer(
+				BiFunction<ReactiveRemoteJWKSource, ConfigurableJWTProcessor<JWKSecurityContext>, Mono<ConfigurableJWTProcessor<JWKSecurityContext>>> jwtProcessorCustomizer) {
+			Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null");
 			this.jwtProcessorCustomizer = jwtProcessorCustomizer;
 			return this;
 		}
@@ -373,15 +386,17 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			jwtProcessor.setJWSKeySelector(jwsKeySelector);
 			jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {
 			});
-			this.jwtProcessorCustomizer.accept(jwtProcessor);
 			ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
 			source.setWebClient(this.webClient);
 			Function<JWSAlgorithm, Boolean> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector);
+			Mono<ConfigurableJWTProcessor<JWKSecurityContext>> jwtProcessorMono = this.jwtProcessorCustomizer
+					.apply(source, jwtProcessor)
+					.cache((processor) -> FOREVER, (ex) -> Duration.ZERO, () -> Duration.ZERO);
 			return (jwt) -> {
 				JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader());
-				return source.get(selector)
+				return jwtProcessorMono.flatMap((processor) -> source.get(selector)
 						.onErrorMap((ex) -> new IllegalStateException("Could not obtain the keys", ex))
-						.map((jwkList) -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList)));
+						.map((jwkList) -> createClaimsSet(processor, jwt, new JWKSecurityContext(jwkList))));
 			};
 		}
 

+ 81 - 0
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoderProviderConfigurationUtils.java

@@ -0,0 +1,81 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.jwt;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import com.nimbusds.jose.JWSAlgorithm;
+import com.nimbusds.jose.KeySourceException;
+import com.nimbusds.jose.jwk.JWK;
+import com.nimbusds.jose.jwk.JWKMatcher;
+import com.nimbusds.jose.jwk.JWKSelector;
+import com.nimbusds.jose.jwk.KeyType;
+import com.nimbusds.jose.jwk.KeyUse;
+import com.nimbusds.jose.jwk.source.JWKSource;
+import com.nimbusds.jose.proc.JWSKeySelector;
+import com.nimbusds.jose.proc.JWSVerificationKeySelector;
+import com.nimbusds.jose.proc.SecurityContext;
+import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
+import reactor.core.publisher.Mono;
+
+import org.springframework.util.Assert;
+
+final class ReactiveJwtDecoderProviderConfigurationUtils {
+
+	static <C extends SecurityContext> Mono<ConfigurableJWTProcessor<C>> addJWSAlgorithms(
+			ReactiveRemoteJWKSource jwkSource, ConfigurableJWTProcessor<C> jwtProcessor) {
+		JWSKeySelector<C> selector = jwtProcessor.getJWSKeySelector();
+		if (!(selector instanceof JWSVerificationKeySelector)) {
+			return Mono.just(jwtProcessor);
+		}
+		JWKSource<C> delegate = ((JWSVerificationKeySelector<C>) selector).getJWKSource();
+		return getJWSAlgorithms(jwkSource).map((algorithms) -> new JWSVerificationKeySelector<>(algorithms, delegate))
+				.map((replacement) -> {
+					jwtProcessor.setJWSKeySelector(replacement);
+					return jwtProcessor;
+				});
+	}
+
+	static Mono<Set<JWSAlgorithm>> getJWSAlgorithms(ReactiveRemoteJWKSource jwkSource) {
+		JWKMatcher jwkMatcher = new JWKMatcher.Builder().publicOnly(true).keyUses(KeyUse.SIGNATURE, null)
+				.keyTypes(KeyType.RSA, KeyType.EC).build();
+		return jwkSource.get(new JWKSelector(jwkMatcher)).map((jwks) -> {
+			Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
+			for (JWK jwk : jwks) {
+				if (jwk.getAlgorithm() != null) {
+					JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(jwk.getAlgorithm().getName());
+					jwsAlgorithms.add(jwsAlgorithm);
+				}
+				else {
+					if (jwk.getKeyType() == KeyType.RSA) {
+						jwsAlgorithms.addAll(JWSAlgorithm.Family.RSA);
+					}
+					else if (jwk.getKeyType() == KeyType.EC) {
+						jwsAlgorithms.addAll(JWSAlgorithm.Family.EC);
+					}
+				}
+			}
+			Assert.notEmpty(jwsAlgorithms, "Failed to find any algorithms from the JWK set");
+			return jwsAlgorithms;
+		}).onErrorMap(KeySourceException.class, (ex) -> new IllegalStateException(ex));
+	}
+
+	private ReactiveJwtDecoderProviderConfigurationUtils() {
+	}
+
+}

+ 1 - 21
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecoders.java

@@ -16,17 +16,9 @@
 
 package org.springframework.security.oauth2.jwt;
 
-import java.io.IOException;
-import java.io.UncheckedIOException;
-import java.net.URL;
 import java.util.Map;
-import java.util.Set;
-
-import com.nimbusds.jose.jwk.source.RemoteJWKSet;
-import com.nimbusds.jose.proc.SecurityContext;
 
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
-import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.util.Assert;
 
 /**
@@ -115,22 +107,10 @@ public final class ReactiveJwtDecoders {
 		JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer);
 		OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefaultWithIssuer(issuer);
 		String jwkSetUri = configuration.get("jwks_uri").toString();
-		RemoteJWKSet<SecurityContext> jwkSource = new RemoteJWKSet<>(url(jwkSetUri));
-		Set<SignatureAlgorithm> signatureAlgorithms = JwtDecoderProviderConfigurationUtils
-				.getSignatureAlgorithms(jwkSource);
 		NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri)
-				.jwsAlgorithms((algs) -> algs.addAll(signatureAlgorithms)).build();
+				.jwtProcessorCustomizer(ReactiveJwtDecoderProviderConfigurationUtils::addJWSAlgorithms).build();
 		jwtDecoder.setJwtValidator(jwtValidator);
 		return jwtDecoder;
 	}
 
-	private static URL url(String url) {
-		try {
-			return new URL(url);
-		}
-		catch (IOException ex) {
-			throw new UncheckedIOException(ex);
-		}
-	}
-
 }

+ 0 - 1
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java

@@ -308,7 +308,6 @@ public class JwtDecodersTests {
 	private void prepareConfigurationResponse(String body) {
 		this.server.enqueue(response(body));
 		this.server.enqueue(response(JWK_SET));
-		this.server.enqueue(response(JWK_SET));
 	}
 
 	private void prepareConfigurationResponseOidc() {

+ 3 - 1
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java

@@ -29,6 +29,7 @@ import java.util.Base64;
 import java.util.Collections;
 import java.util.Date;
 import java.util.Map;
+import java.util.function.Consumer;
 
 import javax.crypto.SecretKey;
 
@@ -45,6 +46,7 @@ import com.nimbusds.jose.proc.JWSKeySelector;
 import com.nimbusds.jose.proc.JWSVerificationKeySelector;
 import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.SignedJWT;
+import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
 import org.junit.jupiter.api.AfterEach;
@@ -314,7 +316,7 @@ public class NimbusReactiveJwtDecoderTests {
 		assertThatIllegalArgumentException()
 				.isThrownBy(() -> NimbusReactiveJwtDecoder
 						.withJwkSetUri(this.jwkSetUri)
-						.jwtProcessorCustomizer(null)
+						.jwtProcessorCustomizer((Consumer<ConfigurableJWTProcessor<JWKSecurityContext>>) null)
 						.build()
 				)
 				.withMessage("jwtProcessorCustomizer cannot be null");

+ 0 - 1
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveJwtDecodersTests.java

@@ -282,7 +282,6 @@ public class ReactiveJwtDecodersTests {
 	private void prepareConfigurationResponse(String body) {
 		this.server.enqueue(response(body));
 		this.server.enqueue(response(JWK_SET));
-		this.server.enqueue(response(JWK_SET));
 	}
 
 	private void prepareConfigurationResponseOidc() {