소스 검색

User-Specified JwtDecoder

This exposes JwtConfigurer#decoder as well as makes the configurer
look in the application context for a bean of type JwtDecoder.

Fixes: gh-5519
Josh Cummings 7 년 전
부모
커밋
4fc1e63369

+ 47 - 19
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java

@@ -18,8 +18,8 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.se
 
 import javax.servlet.http.HttpServletRequest;
 
+import org.springframework.context.ApplicationContext;
 import org.springframework.security.authentication.AuthenticationManager;
-import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
 import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
 import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
@@ -51,7 +51,19 @@ import org.springframework.util.Assert;
  * </ul>
  *
  * <p>
- * When using {@link #jwt()}, a Jwk Set Uri must be supplied via {@link JwtConfigurer#jwkSetUri}
+ * When using {@link #jwt()}, either
+ *
+ * <ul>
+ * <li>
+ * supply a Jwk Set Uri via {@link JwtConfigurer#jwkSetUri}, or
+ * </li>
+ * <li>
+ * supply a {@link JwtDecoder} instance via {@link JwtConfigurer#decoder}, or
+ * </li>
+ * <li>
+ * expose a {@link JwtDecoder} bean
+ * </li>
+ * </ul>
  *
  * <h2>Security Filters</h2>
  *
@@ -77,10 +89,6 @@ import org.springframework.util.Assert;
  * <li>{@link AuthenticationManager}</li>
  * </ul>
  *
- * If {@link #jwt()} isn't supplied, then the {@link BearerTokenAuthenticationFilter} is still added, but without
- * any OAuth 2.0 {@link AuthenticationProvider}s. This is useful if needing to switch out Spring Security's Jwt support
- * for a custom one.
- *
  * @author Josh Cummings
  * @since 5.1
  * @see BearerTokenAuthenticationFilter
@@ -100,9 +108,14 @@ public final class OAuth2ResourceServerConfigurer<H extends HttpSecurityBuilder<
 	private BearerTokenAccessDeniedHandler accessDeniedHandler
 			= new BearerTokenAccessDeniedHandler();
 
-	private JwtConfigurer jwtConfigurer = new JwtConfigurer();
+	private JwtConfigurer jwtConfigurer;
 
 	public JwtConfigurer jwt() {
+		if ( this.jwtConfigurer == null ) {
+			ApplicationContext context = this.getBuilder().getSharedObject(ApplicationContext.class);
+			this.jwtConfigurer = new JwtConfigurer(context);
+		}
+
 		return this.jwtConfigurer;
 	}
 
@@ -133,32 +146,47 @@ public final class OAuth2ResourceServerConfigurer<H extends HttpSecurityBuilder<
 
 		http.addFilter(filter);
 
+		if ( this.jwtConfigurer == null ) {
+			throw new IllegalStateException("Jwt is the only supported format for bearer tokens " +
+					"in Spring Security and no Jwt configuration was found. Make sure to specify " +
+					"a jwk set uri by doing http.oauth2().resourceServer().jwt().jwkSetUri(uri), or wire a " +
+					"JwtDecoder instance by doing http.oauth2().resourceServer().jwt().decoder(decoder), or " +
+					"expose a JwtDecoder instance as a bean and do http.oauth2().resourceServer().jwt().");
+		}
+
 		JwtDecoder decoder = this.jwtConfigurer.getJwtDecoder();
 
-		if (decoder != null) {
-			JwtAuthenticationProvider provider =
-					new JwtAuthenticationProvider(decoder);
-			provider = postProcess(provider);
+		JwtAuthenticationProvider provider =
+				new JwtAuthenticationProvider(decoder);
+		provider = postProcess(provider);
 
-			http.authenticationProvider(provider);
-		} else {
-			throw new IllegalStateException("Jwt is the only supported format for bearer tokens " +
-					"in Spring Security and no instance of JwtDecoder could be found. Make sure to specify " +
-					"a jwk set uri by doing http.oauth2().resourceServer().jwt().jwkSetUri(uri)");
-		}
+		http.authenticationProvider(provider);
 	}
 
 	public class JwtConfigurer {
+		private final ApplicationContext context;
+
 		private JwtDecoder decoder;
 
-		private JwtConfigurer() {}
+		JwtConfigurer(ApplicationContext context) {
+			this.context = context;
+		}
+
+		public OAuth2ResourceServerConfigurer<H> decoder(JwtDecoder decoder) {
+			this.decoder = decoder;
+			return OAuth2ResourceServerConfigurer.this;
+		}
 
 		public OAuth2ResourceServerConfigurer<H> jwkSetUri(String uri) {
 			this.decoder = new NimbusJwtDecoderJwkSupport(uri);
 			return OAuth2ResourceServerConfigurer.this;
 		}
 
-		private JwtDecoder getJwtDecoder() {
+		JwtDecoder getJwtDecoder() {
+			if ( this.decoder == null ) {
+				return this.context.getBean(JwtDecoder.class);
+			}
+
 			return this.decoder;
 		}
 	}

+ 191 - 3
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java

@@ -20,6 +20,9 @@ import java.io.BufferedReader;
 import java.io.FileReader;
 import java.io.IOException;
 import java.lang.reflect.Field;
+import java.time.Instant;
+import java.util.Collections;
+import java.util.Map;
 import java.util.stream.Collectors;
 import javax.annotation.PreDestroy;
 
@@ -34,9 +37,11 @@ import org.junit.Test;
 
 import org.springframework.beans.BeansException;
 import org.springframework.beans.factory.BeanCreationException;
+import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.beans.factory.config.BeanPostProcessor;
+import org.springframework.context.ApplicationContext;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.core.io.ClassPathResource;
@@ -55,6 +60,11 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.annotation.AuthenticationPrincipal;
 import org.springframework.security.core.userdetails.UserDetailsService;
+import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtClaimNames;
+import org.springframework.security.oauth2.jwt.JwtDecoder;
+import org.springframework.security.oauth2.jwt.NimbusJwtDecoderJwkSupport;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
 import org.springframework.security.provisioning.InMemoryUserDetailsManager;
 import org.springframework.test.web.servlet.MockMvc;
@@ -68,9 +78,13 @@ import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.PostMapping;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.context.support.GenericWebApplicationContext;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@@ -86,8 +100,14 @@ import static org.springframework.web.bind.annotation.RequestMethod.POST;
  * @author Josh Cummings
  */
 public class OAuth2ResourceServerConfigurerTests {
+	private static final String JWT_TOKEN = "token";
+	private static final String JWT_SUBJECT = "mock-test-subject";
+	private static final Map<String, Object> JWT_HEADERS = Collections.singletonMap("alg", JwsAlgorithms.RS256);
+	private static final Map<String, Object> JWT_CLAIMS = Collections.singletonMap(JwtClaimNames.SUB, JWT_SUBJECT);
+	private static final Jwt JWT = new Jwt(JWT_TOKEN, Instant.MIN, Instant.MAX, JWT_HEADERS, JWT_CLAIMS);
+	private static final String JWK_SET_URI = "https://mock.org";
 
-	@Autowired
+	@Autowired(required = false)
 	MockMvc mvc;
 
 	@Autowired(required = false)
@@ -506,6 +526,130 @@ public class OAuth2ResourceServerConfigurerTests {
 		assertThat(result.getRequest().getSession(false)).isNotNull();
 	}
 
+	// -- custom jwt decoder
+
+	@Test
+	public void requestWhenCustomJwtDecoderWiredOnDslThenUsed()
+			throws Exception {
+
+		this.spring.register(CustomJwtDecoderOnDsl.class, BasicController.class).autowire();
+
+		CustomJwtDecoderOnDsl config = this.spring.getContext().getBean(CustomJwtDecoderOnDsl.class);
+		JwtDecoder decoder = config.decoder();
+
+		when(decoder.decode(anyString())).thenReturn(JWT);
+
+		this.mvc.perform(get("/authenticated")
+				.with(bearerToken(JWT_TOKEN)))
+				.andExpect(status().isOk())
+				.andExpect(content().string(JWT_SUBJECT));
+	}
+
+	@Test
+	public void requestWhenCustomJwtDecoderExposedAsBeanThenUsed()
+			throws Exception {
+
+		this.spring.register(CustomJwtDecoderAsBean.class, BasicController.class).autowire();
+
+		JwtDecoder decoder = this.spring.getContext().getBean(JwtDecoder.class);
+
+		when(decoder.decode(anyString())).thenReturn(JWT);
+
+		this.mvc.perform(get("/authenticated")
+				.with(bearerToken(JWT_TOKEN)))
+				.andExpect(status().isOk())
+				.andExpect(content().string(JWT_SUBJECT));
+	}
+
+	@Test
+	public void getJwtDecoderWhenConfiguredWithDecoderAndJwkSetUriThenLastOneWins() {
+		OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer =
+				new OAuth2ResourceServerConfigurer().new JwtConfigurer(null);
+
+		JwtDecoder decoder = mock(JwtDecoder.class);
+
+		jwtConfigurer.jwkSetUri(JWK_SET_URI);
+		jwtConfigurer.decoder(decoder);
+
+		assertThat(jwtConfigurer.getJwtDecoder()).isEqualTo(decoder);
+
+		jwtConfigurer =
+				new OAuth2ResourceServerConfigurer().new JwtConfigurer(null);
+
+		jwtConfigurer.decoder(decoder);
+		jwtConfigurer.jwkSetUri(JWK_SET_URI);
+
+		assertThat(jwtConfigurer.getJwtDecoder()).isInstanceOf(NimbusJwtDecoderJwkSupport.class);
+
+	}
+
+	@Test
+	public void getJwtDecoderWhenConflictingJwtDecodersThenTheDslWiredOneTakesPrecedence() {
+
+		JwtDecoder decoderBean = mock(JwtDecoder.class);
+		JwtDecoder decoder = mock(JwtDecoder.class);
+
+		ApplicationContext context = mock(ApplicationContext.class);
+		when(context.getBean(JwtDecoder.class)).thenReturn(decoderBean);
+
+		OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer =
+				new OAuth2ResourceServerConfigurer().new JwtConfigurer(context);
+		jwtConfigurer.decoder(decoder);
+
+		assertThat(jwtConfigurer.getJwtDecoder()).isEqualTo(decoder);
+	}
+
+	@Test
+	public void getJwtDecoderWhenContextHasBeanAndUserConfiguresJwkSetUriThenJwkSetUriTakesPrecedence() {
+
+		JwtDecoder decoder = mock(JwtDecoder.class);
+		ApplicationContext context = mock(ApplicationContext.class);
+		when(context.getBean(JwtDecoder.class)).thenReturn(decoder);
+
+		OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer =
+				new OAuth2ResourceServerConfigurer().new JwtConfigurer(context);
+
+		jwtConfigurer.jwkSetUri(JWK_SET_URI);
+
+		assertThat(jwtConfigurer.getJwtDecoder()).isNotEqualTo(decoder);
+		assertThat(jwtConfigurer.getJwtDecoder()).isInstanceOf(NimbusJwtDecoderJwkSupport.class);
+	}
+
+	@Test
+	public void getJwtDecoderWhenTwoJwtDecoderBeansAndAnotherWiredOnDslThenDslWiredOneTakesPrecedence() {
+
+		JwtDecoder decoderBean = mock(JwtDecoder.class);
+		JwtDecoder decoder = mock(JwtDecoder.class);
+
+		GenericWebApplicationContext context = new GenericWebApplicationContext();
+		context.registerBean("decoderOne", JwtDecoder.class, () -> decoderBean);
+		context.registerBean("decoderTwo", JwtDecoder.class, () -> decoderBean);
+		this.spring.context(context).autowire();
+
+		OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer =
+				new OAuth2ResourceServerConfigurer().new JwtConfigurer(context);
+		jwtConfigurer.decoder(decoder);
+
+		assertThat(jwtConfigurer.getJwtDecoder()).isEqualTo(decoder);
+	}
+
+	@Test
+	public void getJwtDecoderWhenTwoJwtDecoderBeansThenThrowsException() {
+
+		JwtDecoder decoder = mock(JwtDecoder.class);
+		GenericWebApplicationContext context = new GenericWebApplicationContext();
+		context.registerBean("decoderOne", JwtDecoder.class, () -> decoder);
+		context.registerBean("decoderTwo", JwtDecoder.class, () -> decoder);
+
+		this.spring.context(context).autowire();
+
+		OAuth2ResourceServerConfigurer.JwtConfigurer jwtConfigurer =
+				new OAuth2ResourceServerConfigurer().new JwtConfigurer(context);
+
+		assertThatCode(() -> jwtConfigurer.getJwtDecoder())
+				.isInstanceOf(NoUniqueBeanDefinitionException.class);
+	}
+
 	// -- In combination with other authentication providers
 
 	@Test
@@ -534,7 +678,7 @@ public class OAuth2ResourceServerConfigurerTests {
 
 		assertThatCode(() -> this.spring.register(JwtlessConfig.class).autowire())
 				.isInstanceOf(BeanCreationException.class)
-				.hasMessageContaining("no instance of JwtDecoder");
+				.hasMessageContaining("no Jwt configuration was found");
 	}
 
 	@Test
@@ -542,7 +686,7 @@ public class OAuth2ResourceServerConfigurerTests {
 
 		assertThatCode(() -> this.spring.register(JwtHalfConfiguredConfig.class).autowire())
 				.isInstanceOf(BeanCreationException.class)
-				.hasMessageContaining("no instance of JwtDecoder");
+				.hasMessageContaining("No qualifying bean of type");
 	}
 
 	// -- support
@@ -689,6 +833,50 @@ public class OAuth2ResourceServerConfigurerTests {
 		}
 	}
 
+	@EnableWebSecurity
+	static class CustomJwtDecoderOnDsl extends WebSecurityConfigurerAdapter {
+		JwtDecoder decoder = mock(JwtDecoder.class);
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+				.authorizeRequests()
+					.anyRequest().authenticated()
+					.and()
+				.oauth2()
+					.resourceServer()
+						.jwt()
+							.decoder(decoder());
+			// @formatter:on
+		}
+
+		JwtDecoder decoder() {
+			return this.decoder;
+		}
+	}
+
+	@EnableWebSecurity
+	static class CustomJwtDecoderAsBean extends WebSecurityConfigurerAdapter {
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+				.authorizeRequests()
+					.anyRequest().authenticated()
+					.and()
+				.oauth2()
+					.resourceServer()
+						.jwt();
+			// @formatter:on
+		}
+
+		@Bean
+		public JwtDecoder decoder() {
+			return mock(JwtDecoder.class);
+		}
+	}
+
 	@RestController
 	static class BasicController {
 		@GetMapping("/")