Browse Source

Add Cache to NimbusJwtDecoderJwkSetUriBuilder

PR gh-8332
Mykyta Bezverkhyi 5 years ago
parent
commit
9133cc24e4

+ 66 - 1
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

@@ -34,6 +34,8 @@ import javax.crypto.SecretKey;
 import com.nimbusds.jose.JOSEException;
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.RemoteKeySourceException;
+import com.nimbusds.jose.jwk.JWKSet;
+import com.nimbusds.jose.jwk.source.JWKSetCache;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.jwk.source.RemoteJWKSet;
 import com.nimbusds.jose.proc.JWSKeySelector;
@@ -49,6 +51,7 @@ import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
 import com.nimbusds.jwt.proc.DefaultJWTProcessor;
 import com.nimbusds.jwt.proc.JWTProcessor;
 
+import org.springframework.cache.Cache;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
@@ -68,6 +71,7 @@ import org.springframework.web.client.RestTemplate;
  *
  * @author Josh Cummings
  * @author Joe Grandja
+ * @author Mykyta Bezverkhyi
  * @since 5.2
  */
 public final class NimbusJwtDecoder implements JwtDecoder {
@@ -215,6 +219,7 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 		private String jwkSetUri;
 		private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
 		private RestOperations restOperations = new RestTemplate();
+		private Cache cache;
 
 		private JwkSetUriJwtDecoderBuilder(String jwkSetUri) {
 			Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
@@ -264,6 +269,20 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 			return this;
 		}
 
+		/**
+		 * Use the given {@link Cache} to store
+		 * <a href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a>.
+		 *
+		 * @param cache the {@link Cache} to be used to store JWK Set
+		 * @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations
+		 * @since 5.4
+		 */
+		public JwkSetUriJwtDecoderBuilder cache(Cache cache) {
+			Assert.notNull(cache, "cache cannot be null");
+			this.cache = cache;
+			return this;
+		}
+
 		JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSource) {
 			if (this.signatureAlgorithms.isEmpty()) {
 				return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
@@ -280,9 +299,17 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 			}
 		}
 
+		JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever) {
+			if (this.cache == null) {
+				return new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever);
+			}
+			ResourceRetriever cachingJwkSetRetriever = new CachingResourceRetriever(this.cache, jwkSetRetriever);
+			return new RemoteJWKSet<>(toURL(this.jwkSetUri), cachingJwkSetRetriever, new NoOpJwkSetCache());
+		}
+
 		JWTProcessor<SecurityContext> processor() {
 			ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations);
-			JWKSource<SecurityContext> jwkSource = new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever);
+			JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever);
 			ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
 			jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));
 
@@ -309,6 +336,44 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 			}
 		}
 
+		private static class NoOpJwkSetCache implements JWKSetCache {
+			@Override
+			public void put(JWKSet jwkSet) {
+			}
+
+			@Override
+			public JWKSet get() {
+				return null;
+			}
+
+			@Override
+			public boolean requiresRefresh() {
+				return true;
+			}
+		}
+
+		private static class CachingResourceRetriever implements ResourceRetriever {
+			private final Cache cache;
+			private final ResourceRetriever resourceRetriever;
+
+			CachingResourceRetriever(Cache cache, ResourceRetriever resourceRetriever) {
+				this.cache = cache;
+				this.resourceRetriever = resourceRetriever;
+			}
+
+			@Override
+			public Resource retrieveResource(URL url) throws IOException {
+				String jwkSet;
+				try {
+					jwkSet = cache.get(url.toString(), () -> resourceRetriever.retrieveResource(url).getContent());
+				} catch (Exception ex) {
+					throw new IOException(ex);
+				}
+
+				return new Resource(jwkSet, "UTF-8");
+			}
+		}
+
 		private static class RestOperationsResourceRetriever implements ResourceRetriever {
 			private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
 			private final RestOperations restOperations;

+ 88 - 2
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java

@@ -32,6 +32,7 @@ import java.util.Collections;
 import java.util.Date;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.Callable;
 import javax.crypto.SecretKey;
 
 import com.nimbusds.jose.JWSAlgorithm;
@@ -55,6 +56,8 @@ import org.junit.BeforeClass;
 import org.junit.Test;
 
 import org.mockito.ArgumentCaptor;
+import org.springframework.cache.Cache;
+import org.springframework.cache.concurrent.ConcurrentMapCache;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
@@ -66,6 +69,7 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
 import org.springframework.security.oauth2.jose.TestKeys;
 import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.web.client.RestClientException;
 import org.springframework.web.client.RestOperations;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -75,6 +79,8 @@ import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;
 import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withPublicKey;
@@ -85,6 +91,7 @@ import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecre
  *
  * @author Josh Cummings
  * @author Joe Grandja
+ * @author Mykyta Bezverkhyi
  */
 public class NimbusJwtDecoderTests {
 	private static final String JWK_SET = "{\"keys\":[{\"p\":\"49neceJFs8R6n7WamRGy45F5Tv0YM-R2ODK3eSBUSLOSH2tAqjEVKOkLE5fiNA3ygqq15NcKRadB2pTVf-Yb5ZIBuKzko8bzYIkIqYhSh_FAdEEr0vHF5fq_yWSvc6swsOJGqvBEtuqtJY027u-G2gAQasCQdhyejer68zsTn8M\",\"kty\":\"RSA\",\"q\":\"tWR-ysspjZ73B6p2vVRVyHwP3KQWL5KEQcdgcmMOE_P_cPs98vZJfLhxobXVmvzuEWBpRSiqiuyKlQnpstKt94Cy77iO8m8ISfF3C9VyLWXi9HUGAJb99irWABFl3sNDff5K2ODQ8CmuXLYM25OwN3ikbrhEJozlXg_NJFSGD4E\",\"d\":\"FkZHYZlw5KSoqQ1i2RA2kCUygSUOf1OqMt3uomtXuUmqKBm_bY7PCOhmwbvbn4xZYEeHuTR8Xix-0KpHe3NKyWrtRjkq1T_un49_1LLVUhJ0dL-9_x0xRquVjhl_XrsRXaGMEHs8G9pLTvXQ1uST585gxIfmCe0sxPZLvwoic-bXf64UZ9BGRV3lFexWJQqCZp2S21HfoU7wiz6kfLRNi-K4xiVNB1gswm_8o5lRuY7zB9bRARQ3TS2G4eW7p5sxT3CgsGiQD3_wPugU8iDplqAjgJ5ofNJXZezoj0t6JMB_qOpbrmAM1EnomIPebSLW7Ky9SugEd6KMdL5lW6AuAQ\",\"e\":\"AQAB\",\"use\":\"sig\",\"kid\":\"one\",\"qi\":\"wdkFu_tV2V1l_PWUUimG516Zvhqk2SWDw1F7uNDD-Lvrv_WNRIJVzuffZ8WYiPy8VvYQPJUrT2EXL8P0ocqwlaSTuXctrORcbjwgxDQDLsiZE0C23HYzgi0cofbScsJdhcBg7d07LAf7cdJWG0YVl1FkMCsxUlZ2wTwHfKWf-v4\",\"dp\":\"uwnPxqC-IxG4r33-SIT02kZC1IqC4aY7PWq0nePiDEQMQWpjjNH50rlq9EyLzbtdRdIouo-jyQXB01K15-XXJJ60dwrGLYNVqfsTd0eGqD1scYJGHUWG9IDgCsxyEnuG3s0AwbW2UolWVSsU2xMZGb9PurIUZECeD1XDZwMp2s0\",\"dq\":\"hra786AunB8TF35h8PpROzPoE9VJJMuLrc6Esm8eZXMwopf0yhxfN2FEAvUoTpLJu93-UH6DKenCgi16gnQ0_zt1qNNIVoRfg4rw_rjmsxCYHTVL3-RDeC8X_7TsEySxW0EgFTHh-nr6I6CQrAJjPM88T35KHtdFATZ7BCBB8AE\",\"n\":\"oXJ8OyOv_eRnce4akdanR4KYRfnC2zLV4uYNQpcFn6oHL0dj7D6kxQmsXoYgJV8ZVDn71KGmuLvolxsDncc2UrhyMBY6DVQVgMSVYaPCTgW76iYEKGgzTEw5IBRQL9w3SRJWd3VJTZZQjkXef48Ocz06PGF3lhbz4t5UEZtdF4rIe7u-977QwHuh7yRPBQ3sII-cVoOUMgaXB9SHcGF2iZCtPzL_IffDUcfhLQteGebhW8A6eUHgpD5A1PQ-JCw_G7UOzZAjjDjtNM2eqm8j-Ms_gqnm4MiCZ4E-9pDN77CAAPVN7kuX6ejs9KBXpk01z48i9fORYk9u7rAkh1HuQw\"}]}";
@@ -247,6 +254,21 @@ public class NimbusJwtDecoderTests {
 		}
 	}
 
+	@Test
+	public void shouldThrowJwtExceptionWhenJwkSetEndpointHasNotRespondedAndCacheIsConfigured() throws Exception {
+		try ( MockWebServer server = new MockWebServer() ) {
+			Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
+			String jwkSetUri = server.url("/.well-known/jwks.json").toString();
+			NimbusJwtDecoder jwtDecoder = withJwkSetUri(jwkSetUri).cache(cache).build();
+
+			server.shutdown();
+			assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
+					.isInstanceOf(JwtException.class)
+					.isNotInstanceOf(BadJwtException.class)
+					.hasMessageContaining("An error occurred while attempting to decode the Jwt");
+		}
+	}
+
 	@Test
 	public void withJwkSetUriWhenNullOrEmptyThenThrowsException() {
 		Assertions.assertThatCode(() -> withJwkSetUri(null)).isInstanceOf(IllegalArgumentException.class);
@@ -264,6 +286,12 @@ public class NimbusJwtDecoderTests {
 		Assertions.assertThatCode(() -> builder.restOperations(null)).isInstanceOf(IllegalArgumentException.class);
 	}
 
+	@Test
+	public void shouldThrowIllegalArgumentExceptionWhenJwkSetCacheIsNull() {
+		NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = withJwkSetUri(JWK_SET_URI);
+		Assertions.assertThatCode(() -> builder.cache(null)).isInstanceOf(IllegalArgumentException.class);
+	}
+
 	@Test
 	public void withPublicKeyWhenNullThenThrowsException() {
 		assertThatThrownBy(() -> withPublicKey(null))
@@ -425,7 +453,7 @@ public class NimbusJwtDecoderTests {
 		RestOperations restOperations = mock(RestOperations.class);
 		when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
 				.thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
-		JWTProcessor<SecurityContext> processor = withJwkSetUri("https://issuer/.well-known/jwks.json")
+		JWTProcessor<SecurityContext> processor = withJwkSetUri(JWK_SET_URI)
 				.restOperations(restOperations)
 				.processor();
 		NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(processor);
@@ -436,6 +464,64 @@ public class NimbusJwtDecoderTests {
 		assertThat(acceptHeader).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
 	}
 
+	@Test
+	public void shouldStoreRetrievedJwkSetToCache() {
+		// given
+		Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
+		RestOperations restOperations = mock(RestOperations.class);
+		when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
+				.thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
+		NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
+				.restOperations(restOperations)
+				.cache(cache)
+				.build();
+		// when
+		jwtDecoder.decode(SIGNED_JWT);
+		// then
+		assertThat(cache.get(JWK_SET_URI, String.class)).isEqualTo(JWK_SET);
+		ArgumentCaptor<RequestEntity> requestEntityCaptor = ArgumentCaptor.forClass(RequestEntity.class);
+		verify(restOperations).exchange(requestEntityCaptor.capture(), eq(String.class));
+		verifyNoMoreInteractions(restOperations);
+		List<MediaType> acceptHeader = requestEntityCaptor.getValue().getHeaders().getAccept();
+		assertThat(acceptHeader).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
+	}
+
+	@Test
+	public void shouldDecodeJwtUsingJwkSetCache() {
+		// given
+		RestOperations restOperations = mock(RestOperations.class);
+		Cache cache = mock(Cache.class);
+		when(cache.get(eq(JWK_SET_URI), any(Callable.class))).thenReturn(JWK_SET);
+		NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
+				.cache(cache)
+				.restOperations(restOperations)
+				.build();
+		// when
+		jwtDecoder.decode(SIGNED_JWT);
+		// then
+		verify(cache).get(eq(JWK_SET_URI), any(Callable.class));
+		verifyNoMoreInteractions(cache);
+		verifyNoInteractions(restOperations);
+	}
+
+	@Test
+	public void shouldThrowJwtExceptionWhenExceptionOccurredWhileRetrievingJwkSetInsideCachingRetriever() {
+		// given
+		Cache cache = new ConcurrentMapCache("test-jwk-set-cache");
+		RestOperations restOperations = mock(RestOperations.class);
+		when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
+				.thenThrow(new RestClientException("Cannot retrieve JWK Set"));
+		NimbusJwtDecoder jwtDecoder = withJwkSetUri(JWK_SET_URI)
+				.restOperations(restOperations)
+				.cache(cache)
+				.build();
+		// then
+		assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
+				.isInstanceOf(JwtException.class)
+				.isNotInstanceOf(BadJwtException.class)
+				.hasMessageContaining("An error occurred while attempting to decode the Jwt");
+	}
+
 	private RSAPublicKey key() throws InvalidKeySpecException {
 		byte[] decoded = Base64.getDecoder().decode(VERIFY_KEY.getBytes());
 		EncodedKeySpec spec = new X509EncodedKeySpec(decoded);
@@ -466,7 +552,7 @@ public class NimbusJwtDecoderTests {
 		RestOperations restOperations = mock(RestOperations.class);
 		when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
 				.thenReturn(new ResponseEntity<>(jwkResponse, HttpStatus.OK));
-		return withJwkSetUri("https://issuer/.well-known/jwks.json")
+		return withJwkSetUri(JWK_SET_URI)
 				.restOperations(restOperations)
 				.processor();
 	}