Просмотр исходного кода

add media type jwk-set+json to accept header

Fixes gh-7290
Bouke Nijhuis 6 лет назад
Родитель
Сommit
bf78e43403

+ 3 - 1
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java

@@ -20,6 +20,7 @@ import java.net.MalformedURLException;
 import java.net.URL;
 import java.text.ParseException;
 import java.time.Instant;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.Map;
@@ -210,12 +211,13 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
 	}
 
 	private static class RestOperationsResourceRetriever implements ResourceRetriever {
+		private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
 		private RestOperations restOperations = new RestTemplate();
 
 		@Override
 		public Resource retrieveResource(URL url) throws IOException {
 			HttpHeaders headers = new HttpHeaders();
-			headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
+			headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
 
 			ResponseEntity<String> response;
 			try {

+ 23 - 0
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java

@@ -17,6 +17,7 @@ package org.springframework.security.oauth2.jwt;
 
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 
 import com.nimbusds.jose.JWSAlgorithm;
@@ -31,16 +32,21 @@ import okhttp3.mockwebserver.MockWebServer;
 import org.assertj.core.api.Assertions;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.powermock.core.classloader.annotations.PowerMockIgnore;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
 
 import org.springframework.core.convert.converter.Converter;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.MediaType;
 import org.springframework.http.RequestEntity;
+import org.springframework.http.ResponseEntity;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
 import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
+import org.springframework.web.client.RestOperations;
 import org.springframework.web.client.RestTemplate;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -76,6 +82,8 @@ public class NimbusJwtDecoderJwkSupportTests {
 	private static final String MALFORMED_JWT = "eyJhbGciOiJSUzI1NiJ9.eyJuYmYiOnt9LCJleHAiOjQ2ODQyMjUwODd9.guoQvujdWvd3xw7FYQEn4D6-gzM_WqFvXdmvAUNSLbxG7fv2_LLCNujPdrBHJoYPbOwS1BGNxIKQWS1tylvqzmr1RohQ-RZ2iAM1HYQzboUlkoMkcd8ENM__ELqho8aNYBfqwkNdUOyBFoy7Syu_w2SoJADw2RTjnesKO6CVVa05bW118pDS4xWxqC4s7fnBjmZoTn4uQ-Kt9YSQZQk8YQxkJSiyanozzgyfgXULA6mPu1pTNU3FVFaK1i1av_xtH_zAPgb647ZeaNe4nahgqC5h8nhOlm8W2dndXbwAt29nd2ZWBsru_QwZz83XSKLhTPFz-mPBByZZDsyBbIHf9A";
 	private static final String UNSIGNED_JWT = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJleHAiOi0yMDMzMjI0OTcsImp0aSI6IjEyMyIsInR5cCI6IkpXVCJ9.";
 
+	private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
+
 	private NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
 
 	@Test
@@ -256,4 +264,19 @@ public class NimbusJwtDecoderJwkSupportTests {
 		assertThatCode(() -> jwtDecoder.setClaimSetConverter(null))
 				.isInstanceOf(IllegalArgumentException.class);
 	}
+
+	// gh-7290
+	@Test
+	public void decodeWhenJwkSetRequestedThenAcceptHeaderJsonAndJwkSetJson() {
+		RestOperations restOperations = mock(RestOperations.class);
+		when(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
+				.thenReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
+		NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL);
+		jwtDecoder.setRestOperations(restOperations);
+		jwtDecoder.decode(SIGNED_JWT);
+		ArgumentCaptor<RequestEntity> requestEntityCaptor = ArgumentCaptor.forClass(RequestEntity.class);
+		verify(restOperations).exchange(requestEntityCaptor.capture(), eq(String.class));
+		List<MediaType> acceptHeader = requestEntityCaptor.getValue().getHeaders().getAccept();
+		assertThat(acceptHeader).contains(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON);
+	}
 }