Selaa lähdekoodia

Merge branch '5.8.x' into 6.0.x

Closes gh-13005
Josh Cummings 2 vuotta sitten
vanhempi
commit
b423db5f93

+ 13 - 7
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -55,6 +55,8 @@ import com.nimbusds.jwt.proc.DefaultJWTProcessor;
 import com.nimbusds.jwt.proc.JWTProcessor;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
+import reactor.util.function.Tuple2;
+import reactor.util.function.Tuples;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -388,15 +390,19 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			});
 			ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
 			source.setWebClient(this.webClient);
-			Function<JWSAlgorithm, Boolean> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector);
-			Mono<ConfigurableJWTProcessor<JWKSecurityContext>> jwtProcessorMono = this.jwtProcessorCustomizer
+			Mono<Tuple2<ConfigurableJWTProcessor<JWKSecurityContext>, Function<JWSAlgorithm, Boolean>>> jwtProcessorMono = this.jwtProcessorCustomizer
 					.apply(source, jwtProcessor)
+					.map((processor) -> Tuples.of(processor, getExpectedJwsAlgorithms(processor.getJWSKeySelector())))
 					.cache((processor) -> FOREVER, (ex) -> Duration.ZERO, () -> Duration.ZERO);
 			return (jwt) -> {
-				JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader());
-				return jwtProcessorMono.flatMap((processor) -> source.get(selector)
-						.onErrorMap((ex) -> new IllegalStateException("Could not obtain the keys", ex))
-						.map((jwkList) -> createClaimsSet(processor, jwt, new JWKSecurityContext(jwkList))));
+				return jwtProcessorMono.flatMap((tuple) -> {
+					JWTProcessor<JWKSecurityContext> processor = tuple.getT1();
+					Function<JWSAlgorithm, Boolean> expectedJwsAlgorithms = tuple.getT2();
+					JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader());
+					return source.get(selector)
+							.onErrorMap((ex) -> new IllegalStateException("Could not obtain the keys", ex))
+							.map((jwkList) -> createClaimsSet(processor, jwt, new JWKSecurityContext(jwkList)));
+				});
 			};
 		}
 

+ 17 - 1
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -39,6 +39,8 @@ import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jose.JWSSigner;
 import com.nimbusds.jose.crypto.MACSigner;
 import com.nimbusds.jose.jwk.JWKSet;
+import com.nimbusds.jose.jwk.RSAKey;
+import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier;
 import com.nimbusds.jose.proc.JWKSecurityContext;
@@ -365,6 +367,20 @@ public class NimbusReactiveJwtDecoderTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void withJwkSetUriWhenJwtProcessorCustomizerSetsJWSKeySelectorThenUseCustomizedJWSKeySelector()
+			throws InvalidKeySpecException {
+		WebClient webClient = mockJwkSetResponse(new JWKSet(new RSAKey.Builder(key()).build()).toString());
+		// @formatter:off
+		NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder.withJwkSetUri(this.jwkSetUri)
+				.jwsAlgorithm(SignatureAlgorithm.ES256).webClient(webClient)
+				.jwtProcessorCustomizer((p) -> p
+						.setJWSKeySelector(new JWSVerificationKeySelector<>(JWSAlgorithm.RS512, new JWKSecurityContextJWKSet())))
+				.build();
+		assertThat(decoder.decode(this.rsa512).block()).extracting(Jwt::getSubject).isEqualTo("test-subject");
+		// @formatter:on
+	}
+
 	@Test
 	public void withPublicKeyWhenNullThenThrowsException() {
 		// @formatter:off