Переглянути джерело

Single-key Key Selector

Fixes: gh-7049
Fixes: gh-7056
Josh Cummings 6 роки тому
батько
коміт
ce79ef2634

+ 4 - 16
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

@@ -29,10 +29,6 @@ import javax.crypto.SecretKey;
 
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.RemoteKeySourceException;
-import com.nimbusds.jose.jwk.JWKSet;
-import com.nimbusds.jose.jwk.RSAKey;
-import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
-import com.nimbusds.jose.jwk.source.ImmutableSecret;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.jwk.source.RemoteJWKSet;
 import com.nimbusds.jose.proc.JWSKeySelector;
@@ -316,17 +312,12 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 	 */
 	public static final class PublicKeyJwtDecoderBuilder {
 		private JWSAlgorithm jwsAlgorithm;
-		private RSAKey key;
+		private RSAPublicKey key;
 
 		private PublicKeyJwtDecoderBuilder(RSAPublicKey key) {
 			Assert.notNull(key, "key cannot be null");
 			this.jwsAlgorithm = JWSAlgorithm.RS256;
-			this.key = rsaKey(key);
-		}
-
-		private static RSAKey rsaKey(RSAPublicKey publicKey) {
-			return new RSAKey.Builder(publicKey)
-					.build();
+			this.key = key;
 		}
 
 		/**
@@ -352,10 +343,8 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 						this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512.");
 			}
 
-			JWKSet jwkSet = new JWKSet(this.key);
-			JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(jwkSet);
 			JWSKeySelector<SecurityContext> jwsKeySelector =
-					new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
+					new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key);
 			DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
 			jwtProcessor.setJWSKeySelector(jwsKeySelector);
 
@@ -414,9 +403,8 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 		}
 
 		JWTProcessor<SecurityContext> processor() {
-			JWKSource<SecurityContext> jwkSource = new ImmutableSecret<>(this.secretKey);
 			JWSKeySelector<SecurityContext> jwsKeySelector =
-					new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
+					new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.secretKey);
 			DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
 			jwtProcessor.setJWSKeySelector(jwsKeySelector);
 

+ 4 - 17
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

@@ -30,12 +30,7 @@ import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.JWKMatcher;
 import com.nimbusds.jose.jwk.JWKSelector;
-import com.nimbusds.jose.jwk.JWKSet;
-import com.nimbusds.jose.jwk.RSAKey;
-import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
-import com.nimbusds.jose.jwk.source.ImmutableSecret;
 import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet;
-import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.proc.BadJOSEException;
 import com.nimbusds.jose.proc.JWKSecurityContext;
 import com.nimbusds.jose.proc.JWSKeySelector;
@@ -318,20 +313,15 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 	 * @since 5.2
 	 */
 	public static final class PublicKeyReactiveJwtDecoderBuilder {
-		private final RSAKey key;
+		private final RSAPublicKey key;
 		private JWSAlgorithm jwsAlgorithm;
 
 		private PublicKeyReactiveJwtDecoderBuilder(RSAPublicKey key) {
 			Assert.notNull(key, "key cannot be null");
-			this.key = rsaKey(key);
+			this.key = key;
 			this.jwsAlgorithm = JWSAlgorithm.RS256;
 		}
 
-		private static RSAKey rsaKey(RSAPublicKey publicKey) {
-			return new RSAKey.Builder(publicKey)
-					.build();
-		}
-
 		/**
 		 * Use the given signing
 		 * <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>.
@@ -363,10 +353,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 						this.jwsAlgorithm + ". Please indicate one of RS256, RS384, or RS512.");
 			}
 
-			JWKSet jwkSet = new JWKSet(this.key);
-			JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(jwkSet);
 			JWSKeySelector<SecurityContext> jwsKeySelector =
-					new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
+					new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.key);
 			DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
 			jwtProcessor.setJWSKeySelector(jwsKeySelector);
 
@@ -418,9 +406,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 		}
 
 		Converter<JWT, Mono<JWTClaimsSet>> processor() {
-			JWKSource<SecurityContext> jwkSource = new ImmutableSecret<>(this.secretKey);
 			JWSKeySelector<SecurityContext> jwsKeySelector =
-					new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
+					new SingleKeyJWSKeySelector<>(this.jwsAlgorithm, this.secretKey);
 			DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
 			jwtProcessor.setJWSKeySelector(jwsKeySelector);
 

+ 54 - 0
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/SingleKeyJWSKeySelector.java

@@ -0,0 +1,54 @@
+/*
+ * Copyright 2002-2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.jwt;
+
+import java.security.Key;
+import java.util.Arrays;
+import java.util.List;
+
+import com.nimbusds.jose.JWSAlgorithm;
+import com.nimbusds.jose.JWSHeader;
+import com.nimbusds.jose.proc.JWSKeySelector;
+import com.nimbusds.jose.proc.SecurityContext;
+
+import org.springframework.util.Assert;
+
+/**
+ * An internal implementation of {@link JWSKeySelector} that always returns the same key
+ *
+ * @author Josh Cummings
+ * @since 5.2
+ */
+final class SingleKeyJWSKeySelector<C extends SecurityContext> implements JWSKeySelector<C> {
+	private final List<Key> keySet;
+	private final JWSAlgorithm expectedJwsAlgorithm;
+
+	SingleKeyJWSKeySelector(JWSAlgorithm expectedJwsAlgorithm, Key key) {
+		Assert.notNull(expectedJwsAlgorithm, "expectedJwsAlgorithm cannot be null");
+		Assert.notNull(key, "key cannot be null");
+		this.keySet = Arrays.asList(key);
+		this.expectedJwsAlgorithm = expectedJwsAlgorithm;
+	}
+
+	@Override
+	public List<? extends Key> selectJWSKeys(JWSHeader header, C context) {
+		if (!this.expectedJwsAlgorithm.equals(header.getAlgorithm())) {
+			throw new IllegalArgumentException("Unsupported algorithm of " + header.getAlgorithm());
+		}
+		return this.keySet;
+	}
+}

+ 76 - 1
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestKeys.java

@@ -15,18 +15,93 @@
  */
 package org.springframework.security.oauth2.jose;
 
+import java.security.KeyFactory;
+import java.security.NoSuchAlgorithmException;
+import java.security.interfaces.RSAPrivateKey;
+import java.security.interfaces.RSAPublicKey;
+import java.security.spec.InvalidKeySpecException;
+import java.security.spec.PKCS8EncodedKeySpec;
+import java.security.spec.X509EncodedKeySpec;
+import java.util.Base64;
 import javax.crypto.SecretKey;
 import javax.crypto.spec.SecretKeySpec;
-import java.util.Base64;
 
 /**
  * @author Joe Grandja
  * @since 5.2
  */
 public class TestKeys {
+	public static final KeyFactory kf;
+
+	static {
+		try {
+			kf = KeyFactory.getInstance("RSA");
+		} catch (NoSuchAlgorithmException e) {
+			throw new IllegalStateException(e);
+		}
+	}
+
 	public static final String DEFAULT_ENCODED_SECRET_KEY = "bCzY/M48bbkwBEWjmNSIEPfwApcvXOnkCxORBEbPr+4=";
 
 	public static final SecretKey DEFAULT_SECRET_KEY =
 			new SecretKeySpec(Base64.getDecoder().decode(DEFAULT_ENCODED_SECRET_KEY), "AES");
 
+	public static final String DEFAULT_RSA_PUBLIC_KEY =
+			"MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3FlqJr5TRskIQIgdE3Dd" +
+			"7D9lboWdcTUT8a+fJR7MAvQm7XXNoYkm3v7MQL1NYtDvL2l8CAnc0WdSTINU6IRv" +
+			"c5Kqo2Q4csNX9SHOmEfzoROjQqahEcve1jBXluoCXdYuYpx4/1tfRgG6ii4Uhxh6" +
+			"iI8qNMJQX+fLfqhbfYfxBQVRPywBkAbIP4x1EAsbC6FSNmkhCxiMNqEgxaIpY8C2" +
+			"kJdJ/ZIV+WW4noDdzpKqHcwmB8FsrumlVY/DNVvUSDIipiq9PbP4H99TXN1o746o" +
+			"RaNa07rq1hoCgMSSy+85SagCoxlmyE+D+of9SsMY8Ol9t0rdzpobBuhyJ/o5dfvj" +
+			"KwIDAQAB";
+
+	public static final RSAPublicKey DEFAULT_PUBLIC_KEY = publicKey();
+
+	private static RSAPublicKey publicKey() {
+		X509EncodedKeySpec spec = new X509EncodedKeySpec(Base64.getDecoder().decode(DEFAULT_RSA_PUBLIC_KEY));
+		try {
+			return (RSAPublicKey) kf.generatePublic(spec);
+		} catch (InvalidKeySpecException e) {
+			throw new IllegalArgumentException(e);
+		}
+	}
+
+	public static final String DEFAULT_RSA_PRIVATE_KEY =
+			"MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDcWWomvlNGyQhA" +
+			"iB0TcN3sP2VuhZ1xNRPxr58lHswC9Cbtdc2hiSbe/sxAvU1i0O8vaXwICdzRZ1JM" +
+			"g1TohG9zkqqjZDhyw1f1Ic6YR/OhE6NCpqERy97WMFeW6gJd1i5inHj/W19GAbqK" +
+			"LhSHGHqIjyo0wlBf58t+qFt9h/EFBVE/LAGQBsg/jHUQCxsLoVI2aSELGIw2oSDF" +
+			"oiljwLaQl0n9khX5ZbiegN3OkqodzCYHwWyu6aVVj8M1W9RIMiKmKr09s/gf31Nc" +
+			"3WjvjqhFo1rTuurWGgKAxJLL7zlJqAKjGWbIT4P6h/1Kwxjw6X23St3OmhsG6HIn" +
+			"+jl1++MrAgMBAAECggEBAMf820wop3pyUOwI3aLcaH7YFx5VZMzvqJdNlvpg1jbE" +
+			"E2Sn66b1zPLNfOIxLcBG8x8r9Ody1Bi2Vsqc0/5o3KKfdgHvnxAB3Z3dPh2WCDek" +
+			"lCOVClEVoLzziTuuTdGO5/CWJXdWHcVzIjPxmK34eJXioiLaTYqN3XKqKMdpD0ZG" +
+			"mtNTGvGf+9fQ4i94t0WqIxpMpGt7NM4RHy3+Onggev0zLiDANC23mWrTsUgect/7" +
+			"62TYg8g1bKwLAb9wCBT+BiOuCc2wrArRLOJgUkj/F4/gtrR9ima34SvWUyoUaKA0" +
+			"bi4YBX9l8oJwFGHbU9uFGEMnH0T/V0KtIB7qetReywkCgYEA9cFyfBIQrYISV/OA" +
+			"+Z0bo3vh2aL0QgKrSXZ924cLt7itQAHNZ2ya+e3JRlTczi5mnWfjPWZ6eJB/8MlH" +
+			"Gpn12o/POEkU+XjZZSPe1RWGt5g0S3lWqyx9toCS9ACXcN9tGbaqcFSVI73zVTRA" +
+			"8J9grR0fbGn7jaTlTX2tnlOTQ60CgYEA5YjYpEq4L8UUMFkuj+BsS3u0oEBnzuHd" +
+			"I9LEHmN+CMPosvabQu5wkJXLuqo2TxRnAznsA8R3pCLkdPGoWMCiWRAsCn979TdY" +
+			"QbqO2qvBAD2Q19GtY7lIu6C35/enQWzJUMQE3WW0OvjLzZ0l/9mA2FBRR+3F9A1d" +
+			"rBdnmv0c3TcCgYEAi2i+ggVZcqPbtgrLOk5WVGo9F1GqUBvlgNn30WWNTx4zIaEk" +
+			"HSxtyaOLTxtq2odV7Kr3LGiKxwPpn/T+Ief+oIp92YcTn+VfJVGw4Z3BezqbR8lA" +
+			"Uf/+HF5ZfpMrVXtZD4Igs3I33Duv4sCuqhEvLWTc44pHifVloozNxYfRfU0CgYBN" +
+			"HXa7a6cJ1Yp829l62QlJKtx6Ymj95oAnQu5Ez2ROiZMqXRO4nucOjGUP55Orac1a" +
+			"FiGm+mC/skFS0MWgW8evaHGDbWU180wheQ35hW6oKAb7myRHtr4q20ouEtQMdQIF" +
+			"snV39G1iyqeeAsf7dxWElydXpRi2b68i3BIgzhzebQKBgQCdUQuTsqV9y/JFpu6H" +
+			"c5TVvhG/ubfBspI5DhQqIGijnVBzFT//UfIYMSKJo75qqBEyP2EJSmCsunWsAFsM" +
+			"TszuiGTkrKcZy9G0wJqPztZZl2F2+bJgnA6nBEV7g5PA4Af+QSmaIhRwqGDAuROR" +
+			"47jndeyIaMTNETEmOnms+as17g==";
+
+	public static final RSAPrivateKey DEFAULT_PRIVATE_KEY = privateKey();
+
+	private static RSAPrivateKey privateKey() {
+		PKCS8EncodedKeySpec spec = new PKCS8EncodedKeySpec(Base64.getDecoder().decode(DEFAULT_RSA_PRIVATE_KEY));
+		try {
+			return (RSAPrivateKey) kf.generatePrivate(spec);
+		} catch (InvalidKeySpecException e) {
+			throw new IllegalArgumentException(e);
+		}
+	}
 }

+ 70 - 18
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java

@@ -16,11 +16,29 @@
 
 package org.springframework.security.oauth2.jwt;
 
+import java.security.KeyFactory;
+import java.security.NoSuchAlgorithmException;
+import java.security.PrivateKey;
+import java.security.interfaces.RSAPrivateKey;
+import java.security.interfaces.RSAPublicKey;
+import java.security.spec.EncodedKeySpec;
+import java.security.spec.InvalidKeySpecException;
+import java.security.spec.X509EncodedKeySpec;
+import java.text.ParseException;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.Base64;
+import java.util.Collections;
+import java.util.Date;
+import java.util.Map;
+import javax.crypto.SecretKey;
+
 import com.nimbusds.jose.JOSEException;
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jose.JWSSigner;
 import com.nimbusds.jose.crypto.MACSigner;
+import com.nimbusds.jose.crypto.RSASSASigner;
 import com.nimbusds.jose.proc.BadJOSEException;
 import com.nimbusds.jose.proc.SecurityContext;
 import com.nimbusds.jwt.JWTClaimsSet;
@@ -32,6 +50,7 @@ import okhttp3.mockwebserver.MockWebServer;
 import org.assertj.core.api.Assertions;
 import org.junit.BeforeClass;
 import org.junit.Test;
+
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.RequestEntity;
@@ -44,21 +63,6 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.web.client.RestOperations;
 
-import javax.crypto.SecretKey;
-import java.security.KeyFactory;
-import java.security.NoSuchAlgorithmException;
-import java.security.interfaces.RSAPublicKey;
-import java.security.spec.EncodedKeySpec;
-import java.security.spec.InvalidKeySpecException;
-import java.security.spec.X509EncodedKeySpec;
-import java.text.ParseException;
-import java.time.Instant;
-import java.util.Arrays;
-import java.util.Base64;
-import java.util.Collections;
-import java.util.Date;
-import java.util.Map;
-
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
@@ -66,7 +70,9 @@ import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
-import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.*;
+import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;
+import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withPublicKey;
+import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecretKey;
 
 /**
  * Tests for {@link NimbusJwtDecoder}
@@ -266,6 +272,23 @@ public class NimbusJwtDecoderTests {
 				.isEqualTo("test-subject");
 	}
 
+	// gh-7049
+	@Test
+	public void decodeWhenUsingPublicKeyWithKidThenStillUsesKey() throws Exception {
+		RSAPublicKey publicKey = TestKeys.DEFAULT_PUBLIC_KEY;
+		RSAPrivateKey privateKey = TestKeys.DEFAULT_PRIVATE_KEY;
+		JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.RS256).keyID("one").build();
+		JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
+				.subject("test-subject")
+				.expirationTime(Date.from(Instant.now().plusSeconds(60)))
+				.build();
+		SignedJWT signedJwt = signedJwt(privateKey, header, claimsSet);
+		NimbusJwtDecoder decoder = withPublicKey(publicKey).signatureAlgorithm(SignatureAlgorithm.RS256).build();
+		assertThat(decoder.decode(signedJwt.serialize()))
+				.extracting(Jwt::getSubject)
+				.isEqualTo("test-subject");
+	}
+
 	@Test
 	public void decodeWhenSignatureMismatchesAlgorithmThenThrowsException() throws Exception {
 		NimbusJwtDecoder decoder = withPublicKey(key()).signatureAlgorithm(SignatureAlgorithm.RS512).build();
@@ -315,7 +338,23 @@ public class NimbusJwtDecoderTests {
 		NimbusJwtDecoder decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS512).build();
 		assertThatThrownBy(() -> decoder.decode(signedJWT.serialize()))
 				.isInstanceOf(JwtException.class)
-				.hasMessage("An error occurred while attempting to decode the Jwt: Signed JWT rejected: Another algorithm expected, or no matching key(s) found");
+				.hasMessageContaining("Unsupported algorithm of HS256");
+	}
+
+	// gh-7056
+	@Test
+	public void decodeWhenUsingSecertKeyWithKidThenStillUsesKey() throws Exception {
+		SecretKey secretKey = TestKeys.DEFAULT_SECRET_KEY;
+		JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.HS256).keyID("one").build();
+		JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
+				.subject("test-subject")
+				.expirationTime(Date.from(Instant.now().plusSeconds(60)))
+				.build();
+		SignedJWT signedJwt = signedJwt(secretKey, header, claimsSet);
+		NimbusJwtDecoder decoder = withSecretKey(secretKey).macAlgorithm(MacAlgorithm.HS256).build();
+		assertThat(decoder.decode(signedJwt.serialize()))
+				.extracting(Jwt::getSubject)
+				.isEqualTo("test-subject");
 	}
 
 	private RSAPublicKey key() throws InvalidKeySpecException {
@@ -325,8 +364,21 @@ public class NimbusJwtDecoderTests {
 	}
 
 	private SignedJWT signedJwt(SecretKey secretKey, MacAlgorithm jwsAlgorithm, JWTClaimsSet claimsSet) throws Exception {
-		SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.parse(jwsAlgorithm.getName())), claimsSet);
+		return signedJwt(secretKey, new JWSHeader(JWSAlgorithm.parse(jwsAlgorithm.getName())), claimsSet);
+	}
+
+	private SignedJWT signedJwt(SecretKey secretKey, JWSHeader header, JWTClaimsSet claimsSet) throws Exception {
 		JWSSigner signer = new MACSigner(secretKey);
+		return signedJwt(signer, header, claimsSet);
+	}
+
+	private SignedJWT signedJwt(PrivateKey privateKey, JWSHeader header, JWTClaimsSet claimsSet) throws Exception {
+		JWSSigner signer = new RSASSASigner(privateKey);
+		return signedJwt(signer, header, claimsSet);
+	}
+
+	private SignedJWT signedJwt(JWSSigner signer, JWSHeader header, JWTClaimsSet claimsSet) throws Exception {
+		SignedJWT signedJWT = new SignedJWT(header, claimsSet);
 		signedJWT.sign(signer);
 		return signedJWT;
 	}