瀏覽代碼

Fix NimbusJwkSetEndpointFilter

Closes gh-198
Joe Grandja 4 年之前
父節點
當前提交
b7996e26d0

+ 3 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/NimbusJwkSetEndpointFilter.java

@@ -73,7 +73,7 @@ public class NimbusJwkSetEndpointFilter extends OncePerRequestFilter {
 		Assert.notNull(jwkSource, "jwkSource cannot be null");
 		Assert.hasText(jwkSetEndpointUri, "jwkSetEndpointUri cannot be empty");
 		this.jwkSource = jwkSource;
-		this.jwkSelector = new JWKSelector(new JWKMatcher.Builder().publicOnly(true).build());
+		this.jwkSelector = new JWKSelector(new JWKMatcher.Builder().build());
 		this.requestMatcher = new AntPathRequestMatcher(jwkSetEndpointUri, HttpMethod.GET.name());
 	}
 
@@ -91,12 +91,12 @@ public class NimbusJwkSetEndpointFilter extends OncePerRequestFilter {
 			jwkSet = new JWKSet(this.jwkSource.get(this.jwkSelector, null));
 		}
 		catch (Exception ex) {
-			throw new IllegalStateException("Failed to select the JWK public key(s) -> " + ex.getMessage(), ex);
+			throw new IllegalStateException("Failed to select the JWK(s) -> " + ex.getMessage(), ex);
 		}
 
 		response.setContentType(MediaType.APPLICATION_JSON_VALUE);
 		try (Writer writer = response.getWriter()) {
-			writer.write(jwkSet.toString());
+			writer.write(jwkSet.toString());	// toString() excludes private keys
 		}
 	}
 }

+ 9 - 6
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/NimbusJwkSetEndpointFilterTests.java

@@ -15,14 +15,15 @@
  */
 package org.springframework.security.oauth2.server.authorization.web;
 
-import java.util.Arrays;
-import java.util.Collections;
+import java.util.ArrayList;
+import java.util.List;
 
 import javax.servlet.FilterChain;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
 import com.nimbusds.jose.jwk.ECKey;
+import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.KeyUse;
 import com.nimbusds.jose.jwk.OctetSequenceKey;
@@ -40,7 +41,6 @@ import org.springframework.security.oauth2.jose.TestJwks;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
@@ -51,12 +51,14 @@ import static org.mockito.Mockito.verifyNoInteractions;
  * @author Joe Grandja
  */
 public class NimbusJwkSetEndpointFilterTests {
+	private List<JWK> jwkList;
 	private JWKSource<SecurityContext> jwkSource;
 	private NimbusJwkSetEndpointFilter filter;
 
 	@Before
 	public void setUp() {
-		this.jwkSource = mock(JWKSource.class);
+		this.jwkList = new ArrayList<>();
+		this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList));
 		this.filter = new NimbusJwkSetEndpointFilter(this.jwkSource);
 	}
 
@@ -103,8 +105,9 @@ public class NimbusJwkSetEndpointFilterTests {
 	@Test
 	public void doFilterWhenAsymmetricKeysThenJwkSetResponse() throws Exception {
 		RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
+		this.jwkList.add(rsaJwk);
 		ECKey ecJwk = TestJwks.DEFAULT_EC_JWK;
-		given(this.jwkSource.get(any(), any())).willReturn(Arrays.asList(rsaJwk, ecJwk));
+		this.jwkList.add(ecJwk);
 
 		String requestUri = NimbusJwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI;
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
@@ -137,7 +140,7 @@ public class NimbusJwkSetEndpointFilterTests {
 	@Test
 	public void doFilterWhenSymmetricKeysThenJwkSetResponseEmpty() throws Exception {
 		OctetSequenceKey secretJwk = TestJwks.DEFAULT_SECRET_JWK;
-		given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(secretJwk));
+		this.jwkList.add(secretJwk);
 
 		String requestUri = NimbusJwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI;
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);