浏览代码

Add tests for unknown KID error

Issue gh-11621
tinolazreg 3 年之前
父节点
当前提交
3e73fa6954

+ 80 - 0
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java

@@ -36,12 +36,16 @@ import java.util.concurrent.Callable;
 
 import javax.crypto.SecretKey;
 
+import com.nimbusds.jose.JOSEException;
 import com.nimbusds.jose.JOSEObjectType;
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jose.JWSSigner;
 import com.nimbusds.jose.crypto.MACSigner;
 import com.nimbusds.jose.crypto.RSASSASigner;
+import com.nimbusds.jose.jwk.JWKSet;
+import com.nimbusds.jose.jwk.RSAKey;
+import com.nimbusds.jose.jwk.gen.RSAKeyGenerator;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.proc.BadJOSEException;
 import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier;
@@ -82,6 +86,7 @@ import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -660,6 +665,81 @@ public class NimbusJwtDecoderTests {
 		verifyNoInteractions(restOperations);
 	}
 
+	@Test
+	public void decodeWhenCacheAndUnknownKidShouldTriggerFetchOfJwkSet() throws JOSEException {
+		RestOperations restOperations = mock(RestOperations.class);
+
+		Cache cache = mock(Cache.class);
+		given(cache.get(eq(JWK_SET_URI), any(Callable.class))).willReturn(JWK_SET);
+
+		RSAKey rsaJWK = new RSAKeyGenerator(2048)
+				.keyID("new_kid")
+				.generate();
+		String jwkSetWithNewKid = new JWKSet(rsaJWK).toPublicJWKSet().toString();
+		given(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
+				.willReturn(new ResponseEntity<>(jwkSetWithNewKid, HttpStatus.OK));
+
+		// @formatter:off
+		NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI)
+				.cache(cache)
+				.restOperations(restOperations)
+				.build();
+		// @formatter:on
+
+		// Decode JWT with new KID
+		JWSSigner signer = new RSASSASigner(rsaJWK);
+		JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
+				.expirationTime(Date.from(Instant.now().plusSeconds(60)))
+				.build();
+		SignedJWT signedJWT = new SignedJWT(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJWK.getKeyID()).build(), claimsSet);
+		signedJWT.sign(signer);
+		String token = signedJWT.serialize();
+
+		jwtDecoder.decode(token);
+
+		ArgumentCaptor<RequestEntity>  requestEntityCaptor = ArgumentCaptor.forClass(RequestEntity.class);
+		verify(restOperations).exchange(requestEntityCaptor.capture(), eq(String.class));
+		verifyNoMoreInteractions(restOperations);
+		assertThat(requestEntityCaptor.getValue().getHeaders().getAccept()).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
+	}
+
+	@Test
+	public void decodeWithoutCacheSpecifiedAndUnknownKidShouldTriggerFetchOfJwkSet() throws JOSEException {
+		RestOperations restOperations = mock(RestOperations.class);
+
+		RSAKey rsaJWK = new RSAKeyGenerator(2048)
+				.keyID("new_kid")
+				.generate();
+		String jwkSetWithNewKid = new JWKSet(rsaJWK).toPublicJWKSet().toString();
+		given(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
+				.willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK), new ResponseEntity<>(jwkSetWithNewKid, HttpStatus.OK));
+
+		// @formatter:off
+		NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI)
+				.restOperations(restOperations)
+				.build();
+		// @formatter:on
+		jwtDecoder.decode(SIGNED_JWT);
+
+		// Decode JWT with new KID
+		JWSSigner signer = new RSASSASigner(rsaJWK);
+		JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
+				.expirationTime(Date.from(Instant.now().plusSeconds(60)))
+				.build();
+		SignedJWT signedJWT = new SignedJWT(new JWSHeader.Builder(JWSAlgorithm.RS256).keyID(rsaJWK.getKeyID()).build(), claimsSet);
+		signedJWT.sign(signer);
+		String token = signedJWT.serialize();
+
+		jwtDecoder.decode(token);
+
+		ArgumentCaptor<RequestEntity>  requestEntityCaptor = ArgumentCaptor.forClass(RequestEntity.class);
+		verify(restOperations, times(2)).exchange(requestEntityCaptor.capture(), eq(String.class));
+		verifyNoMoreInteractions(restOperations);
+		List<RequestEntity> requestEntities = requestEntityCaptor.getAllValues();
+		assertThat(requestEntities.get(0).getHeaders().getAccept()).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
+		assertThat(requestEntities.get(1).getHeaders().getAccept()).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
+	}
+
 	@Test
 	public void decodeWhenCacheIsConfiguredAndValueLoaderErrorsThenThrowsJwtException() {
 		Cache cache = new ConcurrentMapCache("test-jwk-set-cache");