Browse Source

ServerHttpSecurity ReactiveJwtDecoder discovery

This makes so that WebFlux OAuth 2.0 Resource Server configuration
will pick up a ReactiveJwtDecoder exposed as a bean.

Fixes: gh-5720
Josh Cummings 7 năm trước cách đây
mục cha
commit
cba2444e1a

+ 16 - 1
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -741,8 +741,9 @@ public class ServerHttpSecurity {
 
 			protected void configure(ServerHttpSecurity http) {
 				BearerTokenServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint();
+				ReactiveJwtDecoder jwtDecoder = this.getJwtDecoder();
 				JwtReactiveAuthenticationManager authenticationManager = new JwtReactiveAuthenticationManager(
-						this.jwtDecoder);
+						jwtDecoder);
 				AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager);
 				oauth2.setServerAuthenticationConverter(new ServerBearerTokenAuthenticationConverter());
 				oauth2.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(entryPoint));
@@ -752,6 +753,13 @@ public class ServerHttpSecurity {
 						.and()
 					.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION);
 			}
+
+			protected ReactiveJwtDecoder getJwtDecoder() {
+				if (this.jwtDecoder == null) {
+					return getBean(ReactiveJwtDecoder.class);
+				}
+				return this.jwtDecoder;
+			}
 		}
 
 		public ServerHttpSecurity and() {
@@ -2014,6 +2022,13 @@ public class ServerHttpSecurity {
 		private LogoutSpec() {}
 	}
 
+	private <T> T getBean(Class<T> beanClass) {
+		if (this.context == null) {
+			return null;
+		}
+		return this.context.getBean(beanClass);
+	}
+
 	private <T> T getBeanOrNull(Class<T> beanClass) {
 		return getBeanOrNull(ResolvableType.forClass(beanClass));
 	}

+ 120 - 2
config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java

@@ -34,6 +34,8 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import reactor.core.publisher.Mono;
 
+import org.springframework.beans.factory.NoSuchBeanDefinitionException;
+import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.ApplicationContext;
 import org.springframework.context.annotation.Bean;
@@ -41,13 +43,24 @@ import org.springframework.security.config.annotation.web.reactive.EnableWebFlux
 import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
 import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.test.context.junit4.SpringRunner;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.context.support.GenericWebApplicationContext;
+import org.springframework.web.reactive.DispatcherHandler;
 import org.springframework.web.reactive.config.EnableWebFlux;
 
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.hamcrest.core.StringStartsWith.startsWith;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
 /**
  * Tests for {@link org.springframework.security.config.web.server.ServerHttpSecurity.OAuth2ResourceServerSpec}
  */
@@ -105,7 +118,7 @@ public class OAuth2ResourceServerSpecTests {
 				.headers(headers -> headers.setBearerAuth(this.expired))
 				.exchange()
 				.expectStatus().isUnauthorized()
-				.expectHeader().exists(HttpHeaders.WWW_AUTHENTICATE);
+				.expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"invalid_token\""));
 	}
 
 	@Test
@@ -116,7 +129,22 @@ public class OAuth2ResourceServerSpecTests {
 				.headers(headers -> headers.setBearerAuth(this.unsignedToken))
 				.exchange()
 				.expectStatus().isUnauthorized()
-				.expectHeader().exists(HttpHeaders.WWW_AUTHENTICATE);
+				.expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"invalid_token\""));
+	}
+
+	@Test
+	public void getWhenCustomDecoderThenAuthenticatesAccordingly() {
+		this.spring.register(CustomDecoderConfig.class, RootController.class).autowire();
+
+		ReactiveJwtDecoder jwtDecoder = this.spring.getContext().getBean(ReactiveJwtDecoder.class);
+		when(jwtDecoder.decode(anyString())).thenReturn(Mono.just(this.jwt));
+
+		this.client.get()
+				.headers(headers -> headers.setBearerAuth("token"))
+				.exchange()
+				.expectStatus().isOk();
+
+		verify(jwtDecoder).decode(anyString());
 	}
 
 	@Test
@@ -132,6 +160,67 @@ public class OAuth2ResourceServerSpecTests {
 				.expectStatus().isOk();
 	}
 
+	@Test
+	public void getJwtDecoderWhenBeanWiredAndDslWiredThenDslTakesPrecedence() {
+		GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext();
+		ServerHttpSecurity http = new ServerHttpSecurity();
+		http.setApplicationContext(context);
+
+		ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class);
+		ReactiveJwtDecoder dslWiredJwtDecoder = mock(ReactiveJwtDecoder.class);
+		context.registerBean(ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder);
+
+		ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt();
+		jwt.jwtDecoder(dslWiredJwtDecoder);
+
+		assertThat(jwt.getJwtDecoder()).isEqualTo(dslWiredJwtDecoder);
+	}
+
+	@Test
+	public void getJwtDecoderWhenTwoBeansWiredAndDslWiredThenDslTakesPrecedence() {
+		GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext();
+		ServerHttpSecurity http = new ServerHttpSecurity();
+		http.setApplicationContext(context);
+
+		ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class);
+		ReactiveJwtDecoder dslWiredJwtDecoder = mock(ReactiveJwtDecoder.class);
+		context.registerBean("firstJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder);
+		context.registerBean("secondJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder);
+
+		ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt();
+		jwt.jwtDecoder(dslWiredJwtDecoder);
+
+		assertThat(jwt.getJwtDecoder()).isEqualTo(dslWiredJwtDecoder);
+	}
+
+	@Test
+	public void getJwtDecoderWhenTwoBeansWiredThenThrowsWiringException() {
+		GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext();
+		ServerHttpSecurity http = new ServerHttpSecurity();
+		http.setApplicationContext(context);
+
+		ReactiveJwtDecoder beanWiredJwtDecoder = mock(ReactiveJwtDecoder.class);
+		context.registerBean("firstJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder);
+		context.registerBean("secondJwtDecoder", ReactiveJwtDecoder.class, () -> beanWiredJwtDecoder);
+
+		ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt();
+
+		assertThatCode(() -> jwt.getJwtDecoder())
+				.isInstanceOf(NoUniqueBeanDefinitionException.class);
+	}
+
+	@Test
+	public void getJwtDecoderWhenNoBeansAndNoDslWiredThenWiringException() {
+		GenericWebApplicationContext context = autowireWebServerGenericWebApplicationContext();
+		ServerHttpSecurity http = new ServerHttpSecurity();
+		http.setApplicationContext(context);
+
+		ServerHttpSecurity.OAuth2ResourceServerSpec.JwtSpec jwt = http.oauth2ResourceServer().jwt();
+
+		assertThatCode(() -> jwt.getJwtDecoder())
+				.isInstanceOf(NoSuchBeanDefinitionException.class);
+	}
+
 	@EnableWebFlux
 	@EnableWebFluxSecurity
 	static class PublicKeyConfig {
@@ -187,6 +276,28 @@ public class OAuth2ResourceServerSpecTests {
 		}
 	}
 
+	@EnableWebFlux
+	@EnableWebFluxSecurity
+	static class CustomDecoderConfig {
+		ReactiveJwtDecoder jwtDecoder = mock(ReactiveJwtDecoder.class);
+
+		@Bean
+		SecurityWebFilterChain springSecurity(ServerHttpSecurity http) {
+			// @formatter:off
+			http
+				.oauth2ResourceServer()
+					.jwt();
+			// @formatter:on
+
+			return http.build();
+		}
+
+		@Bean
+		ReactiveJwtDecoder jwtDecoder() {
+			return this.jwtDecoder;
+		}
+	}
+
 	@RestController
 	static class RootController {
 		@GetMapping
@@ -194,4 +305,11 @@ public class OAuth2ResourceServerSpecTests {
 			return Mono.just("ok");
 		}
 	}
+
+	private GenericWebApplicationContext autowireWebServerGenericWebApplicationContext() {
+		GenericWebApplicationContext context = new GenericWebApplicationContext();
+		context.registerBean("webHandler", DispatcherHandler.class);
+		this.spring.context(context).autowire();
+		return (GenericWebApplicationContext) this.spring.getContext();
+	}
 }