Pārlūkot izejas kodu

Polish setJwkSelector

Make so that it runs only when selection is needed.
Require the provided selector be non-null.
Add Tests.

Issue gh-16170
Josh Cummings 6 mēneši atpakaļ
vecāks
revīzija
6793334575

+ 22 - 17
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

@@ -87,17 +87,12 @@ public final class NimbusJwtEncoder implements JwtEncoder {
 
 	private final JWKSource<SecurityContext> jwkSource;
 
-	private Converter<List<JWK>, JWK> jwkSelector= (jwks)->{
-		if (jwks.size() > 1) {
-			throw new JwtEncodingException(String.format(
-					"Failed to select a key since there are multiple for the signing algorithm [%s]; " +
-							"please specify a selector in NimbusJwsEncoder#setJwkSelector",jwks.get(0).getAlgorithm()));
-		}
-		if (jwks.isEmpty()) {
-			throw new JwtEncodingException(
-					String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
-		}
-		return jwks.get(0);
+	private Converter<List<JWK>, JWK> jwkSelector = (jwks) -> {
+		throw new JwtEncodingException(
+				String.format(
+						"Failed to select a key since there are multiple for the signing algorithm [%s]; "
+								+ "please specify a selector in NimbusJwsEncoder#setJwkSelector",
+						jwks.get(0).getAlgorithm()));
 	};
 
 	/**
@@ -108,17 +103,20 @@ public final class NimbusJwtEncoder implements JwtEncoder {
 		Assert.notNull(jwkSource, "jwkSource cannot be null");
 		this.jwkSource = jwkSource;
 	}
+
 	/**
-	 * Use this strategy to reduce the list of matching JWKs down to a since one.
-	 * <p> For example, you can call {@code setJwkSelector(List::getFirst)} in order
-	 * to have this encoder select the first match.
+	 * Use this strategy to reduce the list of matching JWKs when there is more than one.
+	 * <p>
+	 * For example, you can call {@code setJwkSelector(List::getFirst)} in order to have
+	 * this encoder select the first match.
 	 *
-	 * <p> By default, the class with throw an exception if there is more than one result.
+	 * <p>
+	 * By default, the class with throw an exception.
 	 * @since 6.5
 	 */
 	public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
-		if(null!=jwkSelector)
-			this.jwkSelector = jwkSelector;
+		Assert.notNull(jwkSelector, "jwkSelector cannot be null");
+		this.jwkSelector = jwkSelector;
 	}
 
 	@Override
@@ -149,6 +147,13 @@ public final class NimbusJwtEncoder implements JwtEncoder {
 			throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
 					"Failed to select a JWK signing key -> " + ex.getMessage()), ex);
 		}
+		if (jwks.isEmpty()) {
+			throw new JwtEncodingException(
+					String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
+		}
+		if (jwks.size() == 1) {
+			return jwks.get(0);
+		}
 		return this.jwkSelector.convert(jwks);
 	}
 

+ 5 - 1
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2025 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -59,6 +59,10 @@ public final class TestJwks {
 	private TestJwks() {
 	}
 
+	public static RSAKey.Builder rsa() {
+		return jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY);
+	}
+
 	public static RSAKey.Builder jwk(RSAPublicKey publicKey, RSAPrivateKey privateKey) {
 		// @formatter:off
 		return new RSAKey.Builder(publicKey)

+ 56 - 3
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2025 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
+import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.KeySourceException;
 import com.nimbusds.jose.jwk.ECKey;
 import com.nimbusds.jose.jwk.JWK;
@@ -39,6 +40,7 @@ import org.junit.jupiter.api.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.oauth2.jose.TestJwks;
 import org.springframework.security.oauth2.jose.TestKeys;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
@@ -51,6 +53,8 @@ import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.willAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
 
 /**
  * Tests for {@link NimbusJwtEncoder}.
@@ -109,7 +113,7 @@ public class NimbusJwtEncoderTests {
 
 	@Test
 	public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception {
-		RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
+		RSAKey rsaJwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
 		this.jwkList.add(rsaJwk);
 		this.jwkList.add(rsaJwk);
 
@@ -118,7 +122,7 @@ public class NimbusJwtEncoderTests {
 
 		assertThatExceptionOfType(JwtEncodingException.class)
 			.isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet)))
-			.withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'");
+			.withMessageContaining("Failed to select a key since there are multiple for the signing algorithm [RS256]");
 	}
 
 	@Test
@@ -291,6 +295,55 @@ public class NimbusJwtEncoderTests {
 		assertThat(jwk1.getKeyID()).isNotEqualTo(jwk2.getKeyID());
 	}
 
+	@Test
+	public void encodeWhenMultipleKeysThenJwkSelectorUsed() throws Exception {
+		JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
+		JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
+		given(jwkSource.get(any(), any())).willReturn(List.of(jwk, jwk));
+		Converter<List<JWK>, JWK> selector = mock(Converter.class);
+		given(selector.convert(any())).willReturn(TestJwks.DEFAULT_RSA_JWK);
+
+		NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
+		jwtEncoder.setJwkSelector(selector);
+
+		JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
+		jwtEncoder.encode(JwtEncoderParameters.from(claims));
+
+		verify(selector).convert(any());
+	}
+
+	@Test
+	public void encodeWhenSingleKeyThenJwkSelectorIsNotUsed() throws Exception {
+		JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build();
+		JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
+		given(jwkSource.get(any(), any())).willReturn(List.of(jwk));
+		Converter<List<JWK>, JWK> selector = mock(Converter.class);
+
+		NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
+		jwtEncoder.setJwkSelector(selector);
+
+		JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
+		jwtEncoder.encode(JwtEncoderParameters.from(claims));
+
+		verifyNoInteractions(selector);
+	}
+
+	@Test
+	public void encodeWhenNoKeysThenJwkSelectorIsNotUsed() throws Exception {
+		JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
+		given(jwkSource.get(any(), any())).willReturn(List.of());
+		Converter<List<JWK>, JWK> selector = mock(Converter.class);
+
+		NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource);
+		jwtEncoder.setJwkSelector(selector);
+
+		JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build();
+		assertThatExceptionOfType(JwtEncodingException.class)
+			.isThrownBy(() -> jwtEncoder.encode(JwtEncoderParameters.from(claims)));
+
+		verifyNoInteractions(selector);
+	}
+
 	private static final class JwkListResultCaptor implements Answer<List<JWK>> {
 
 		private List<JWK> result;