瀏覽代碼

Expose RestOperations in NimbusJwtDecoderJwkSupport

Fixes gh-5603
Joe Grandja 7 年之前
父節點
當前提交
16fe1c5b52

+ 58 - 19
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java

@@ -15,13 +15,6 @@
  */
 package org.springframework.security.oauth2.jwt;
 
-import java.net.MalformedURLException;
-import java.net.URL;
-import java.text.ParseException;
-import java.time.Instant;
-import java.util.LinkedHashMap;
-import java.util.Map;
-
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.RemoteKeySourceException;
 import com.nimbusds.jose.jwk.source.JWKSource;
@@ -29,7 +22,7 @@ import com.nimbusds.jose.jwk.source.RemoteJWKSet;
 import com.nimbusds.jose.proc.JWSKeySelector;
 import com.nimbusds.jose.proc.JWSVerificationKeySelector;
 import com.nimbusds.jose.proc.SecurityContext;
-import com.nimbusds.jose.util.DefaultResourceRetriever;
+import com.nimbusds.jose.util.Resource;
 import com.nimbusds.jose.util.ResourceRetriever;
 import com.nimbusds.jwt.JWT;
 import com.nimbusds.jwt.JWTClaimsSet;
@@ -37,12 +30,27 @@ import com.nimbusds.jwt.JWTParser;
 import com.nimbusds.jwt.SignedJWT;
 import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
 import com.nimbusds.jwt.proc.DefaultJWTProcessor;
-
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.MediaType;
+import org.springframework.http.RequestEntity;
+import org.springframework.http.ResponseEntity;
 import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
 import org.springframework.util.Assert;
+import org.springframework.web.client.RestOperations;
+import org.springframework.web.client.RestTemplate;
+
+import java.io.IOException;
+import java.net.MalformedURLException;
+import java.net.URL;
+import java.text.ParseException;
+import java.time.Instant;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
 
 /**
- * An implementation of a {@link JwtDecoder} that "decodes" a
+ * An implementation of a {@link JwtDecoder} that "decodes" a
  * JSON Web Token (JWT) and additionally verifies it's digital signature if the JWT is a
  * JSON Web Signature (JWS). The public key used for verification is obtained from the
  * JSON Web Key (JWK) Set {@code URL} supplied via the constructor.
@@ -63,9 +71,9 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
 	private static final String DECODING_ERROR_MESSAGE_TEMPLATE =
 			"An error occurred while attempting to decode the Jwt: %s";
 
-	private final URL jwkSetUrl;
 	private final JWSAlgorithm jwsAlgorithm;
 	private final ConfigurableJWTProcessor<SecurityContext> jwtProcessor;
+	private final RestOperationsResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever();
 
 	/**
 	 * Constructs a {@code NimbusJwtDecoderJwkSupport} using the provided parameters.
@@ -85,18 +93,15 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
 	public NimbusJwtDecoderJwkSupport(String jwkSetUrl, String jwsAlgorithm) {
 		Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty");
 		Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty");
+		JWKSource jwkSource;
 		try {
-			this.jwkSetUrl = new URL(jwkSetUrl);
+			jwkSource = new RemoteJWKSet(new URL(jwkSetUrl), this.jwkSetRetriever);
 		} catch (MalformedURLException ex) {
-			throw new IllegalArgumentException("Invalid JWK Set URL " + jwkSetUrl + " : " + ex.getMessage(), ex);
+			throw new IllegalArgumentException("Invalid JWK Set URL \"" + jwkSetUrl + "\" : " + ex.getMessage(), ex);
 		}
 		this.jwsAlgorithm = JWSAlgorithm.parse(jwsAlgorithm);
-
-		ResourceRetriever jwkSetRetriever = new DefaultResourceRetriever(30000, 30000);
-		JWKSource jwkSource = new RemoteJWKSet(this.jwkSetUrl, jwkSetRetriever);
 		JWSKeySelector<SecurityContext> jwsKeySelector =
 			new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
-
 		this.jwtProcessor = new DefaultJWTProcessor<>();
 		this.jwtProcessor.setJWSKeySelector(jwsKeySelector);
 	}
@@ -104,10 +109,9 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
 	@Override
 	public Jwt decode(String token) throws JwtException {
 		JWT jwt = this.parse(token);
-		if ( jwt instanceof SignedJWT ) {
+		if (jwt instanceof SignedJWT) {
 			return this.createJwt(token, jwt);
 		}
-
 		throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
 	}
 
@@ -158,4 +162,39 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
 
 		return jwt;
 	}
+
+	/**
+	 * Sets the {@link RestOperations} used when requesting the JSON Web Key (JWK) Set.
+	 *
+	 * @since 5.1
+	 * @param restOperations the {@link RestOperations} used when requesting the JSON Web Key (JWK) Set
+	 */
+	public final void setRestOperations(RestOperations restOperations) {
+		Assert.notNull(restOperations, "restOperations cannot be null");
+		this.jwkSetRetriever.restOperations = restOperations;
+	}
+
+	private static class RestOperationsResourceRetriever implements ResourceRetriever {
+		private RestOperations restOperations = new RestTemplate();
+
+		@Override
+		public Resource retrieveResource(URL url) throws IOException {
+			HttpHeaders headers = new HttpHeaders();
+			headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
+
+			ResponseEntity<String> response;
+			try {
+				RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI());
+				response = this.restOperations.exchange(request, String.class);
+			} catch (Exception ex) {
+				throw new IOException(ex);
+			}
+
+			if (response.getStatusCodeValue() != 200) {
+				throw new IOException(response.toString());
+			}
+
+			return new Resource(response.getBody(), "UTF-8");
+		}
+	}
 }

+ 44 - 30
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,23 +24,22 @@ import com.nimbusds.jwt.SignedJWT;
 import com.nimbusds.jwt.proc.DefaultJWTProcessor;
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
+import org.assertj.core.api.Assertions;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.powermock.core.classloader.annotations.PowerMockIgnore;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
-
+import org.springframework.http.RequestEntity;
 import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
+import org.springframework.web.client.RestTemplate;
 
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyString;
-import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.*;
 import static org.mockito.Mockito.mock;
-import static org.powermock.api.mockito.PowerMockito.mockStatic;
-import static org.powermock.api.mockito.PowerMockito.when;
-import static org.powermock.api.mockito.PowerMockito.whenNew;
+import static org.mockito.Mockito.verify;
+import static org.powermock.api.mockito.PowerMockito.*;
 
 /**
  * Tests for {@link NimbusJwtDecoderJwkSupport}.
@@ -62,6 +61,8 @@ public class NimbusJwtDecoderJwkSupportTests {
 	private static final String MALFORMED_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJuYmYiOnt9LCJleHAiOjQ2ODQyMjUwODd9.guoQvujdWvd3xw7FYQEn4D6-gzM_WqFvXdmvAUNSLbxG7fv2_LLCNujPdrBHJoYPbOwS1BGNxIKQWS1tylvqzmr1RohQ-RZ2iAM1HYQzboUlkoMkcd8ENM__ELqho8aNYBfqwkNdUOyBFoy7Syu_w2SoJADw2RTjnesKO6CVVa05bW118pDS4xWxqC4s7fnBjmZoTn4uQ-Kt9YSQZQk8YQxkJSiyanozzgyfgXULA6mPu1pTNU3FVFaK1i1av_xtH_zAPgb647ZeaNe4nahgqC5h8nhOlm8W2dndXbwAt29nd2ZWBsru_QwZz83XSKLhTPFz-mPBByZZDsyBbIHf9A";
 	private static final String UNSIGNED_JWT = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOi0yMDMzMjI0OTcsImp0aSI6IjEyMyIsInR5cCI6IkpXVCJ9.";
 
+	private NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
+
 	@Test
 	public void constructorWhenJwkSetUrlIsNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> new NimbusJwtDecoderJwkSupport(null))
@@ -80,10 +81,15 @@ public class NimbusJwtDecoderJwkSupportTests {
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
+	@Test
+	public void setRestOperationsWhenNullThenThrowIllegalArgumentException() {
+		Assertions.assertThatThrownBy(() -> this.jwtDecoder.setRestOperations(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
 	@Test
 	public void decodeWhenJwtInvalidThenThrowJwtException() {
-		NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
-		assertThatThrownBy(() -> jwtDecoder.decode("invalid"))
+		assertThatThrownBy(() -> this.jwtDecoder.decode("invalid"))
 				.isInstanceOf(JwtException.class);
 	}
 
@@ -103,16 +109,14 @@ public class NimbusJwtDecoderJwkSupportTests {
 		JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().audience("resource1").build();
 		when(jwtProcessor.process(any(JWT.class), eq(null))).thenReturn(jwtClaimsSet);
 
-		NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
+		NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL);
 		assertThatCode(() -> jwtDecoder.decode("encoded-jwt")).doesNotThrowAnyException();
 	}
 
 	// gh-5457
 	@Test
-	public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() throws Exception {
-		NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
-
-		assertThatCode(() -> jwtDecoder.decode(UNSIGNED_JWT))
+	public void decodeWhenPlainJwtThenExceptionDoesNotMentionClass() {
+		assertThatCode(() -> this.jwtDecoder.decode(UNSIGNED_JWT))
 				.isInstanceOf(JwtException.class)
 				.hasMessageContaining("Unsupported algorithm of none");
 	}
@@ -122,12 +126,11 @@ public class NimbusJwtDecoderJwkSupportTests {
 		try ( MockWebServer server = new MockWebServer() ) {
 			server.enqueue(new MockResponse().setBody(JWK_SET));
 			String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
-
-			NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
-
-			assertThatCode(() -> decoder.decode(MALFORMED_JWT))
+			NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
+			assertThatCode(() -> jwtDecoder.decode(MALFORMED_JWT))
 					.isInstanceOf(JwtException.class)
 					.hasMessage("An error occurred while attempting to decode the Jwt: Malformed payload");
+			server.shutdown();
 		}
 	}
 
@@ -136,28 +139,39 @@ public class NimbusJwtDecoderJwkSupportTests {
 		try ( MockWebServer server = new MockWebServer() ) {
 			server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET));
 			String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
-
-			NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
-
-			assertThatCode(() -> decoder.decode(SIGNED_JWT))
+			NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
+			assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
 					.isInstanceOf(JwtException.class)
 					.hasMessage("An error occurred while attempting to decode the Jwt: Malformed Jwk set");
+			server.shutdown();
 		}
 	}
 
 	@Test
-	public void decodeWhenJwkEndpointIsUnresponsiveThenRetrunsJwtException() throws Exception {
+	public void decodeWhenJwkEndpointIsUnresponsiveThenReturnsJwtException() throws Exception {
 		try ( MockWebServer server = new MockWebServer() ) {
 			server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET));
 			String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
-
-			NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
-
-			server.shutdown();
-
-			assertThatCode(() -> decoder.decode(SIGNED_JWT))
+			NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
+			assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT))
 					.isInstanceOf(JwtException.class)
 					.hasMessageContaining("An error occurred while attempting to decode the Jwt");
+			server.shutdown();
+		}
+	}
+
+	// gh-5603
+	@Test
+	public void decodeWhenCustomRestOperationsSetThenUsed() throws Exception {
+		try ( MockWebServer server = new MockWebServer() ) {
+			server.enqueue(new MockResponse().setBody(JWK_SET));
+			String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
+			NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
+			RestTemplate restTemplate = spy(new RestTemplate());
+			jwtDecoder.setRestOperations(restTemplate);
+			assertThatCode(() -> jwtDecoder.decode(SIGNED_JWT)).doesNotThrowAnyException();
+			verify(restTemplate).exchange(any(RequestEntity.class), eq(String.class));
+			server.shutdown();
 		}
 	}
 }