Bläddra i källkod

ReactorContextWebFilter & SecurityContextServerWebExchangeWebFilter

Issue: gh-4719
Rob Winch 7 år sedan
förälder
incheckning
437ba56415

+ 4 - 4
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -44,8 +44,8 @@ import org.springframework.security.web.server.authorization.AuthorizationContex
 import org.springframework.security.web.server.authorization.AuthorizationWebFilter;
 import org.springframework.security.web.server.authorization.DelegatingReactiveAuthorizationManager;
 import org.springframework.security.web.server.authorization.ExceptionTranslationWebFilter;
-import org.springframework.security.web.server.context.AuthenticationReactorContextWebFilter;
-import org.springframework.security.web.server.context.SecurityContextRepositoryWebFilter;
+import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
+import org.springframework.security.web.server.context.ReactorContextWebFilter;
 import org.springframework.security.web.server.context.ServerSecurityContextRepository;
 import org.springframework.security.web.server.context.ServerWebExchangeAttributeServerSecurityContextRepository;
 import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
@@ -211,7 +211,7 @@ public class ServerHttpSecurity {
 		if(this.logout != null) {
 			this.logout.configure(this);
 		}
-		this.addFilterAt(new AuthenticationReactorContextWebFilter(), SecurityWebFiltersOrder.AUTHENTICATION_CONTEXT);
+		this.addFilterAt(new SecurityContextServerWebExchangeWebFilter(), SecurityWebFiltersOrder.AUTHENTICATION_CONTEXT);
 		if(this.authorizeExchangeBuilder != null) {
 			ServerAuthenticationEntryPoint serverAuthenticationEntryPoint = getServerAuthenticationEntryPoint();
 			ExceptionTranslationWebFilter exceptionTranslationWebFilter = new ExceptionTranslationWebFilter();
@@ -262,8 +262,8 @@ public class ServerHttpSecurity {
 		if(repository == null) {
 			return null;
 		}
-		WebFilter result = new SecurityContextRepositoryWebFilter(repository);
 		return new OrderedWebFilter(result, SecurityWebFiltersOrder.SECURITY_CONTEXT_REPOSITORY.getOrder());
+		WebFilter result = new ReactorContextWebFilter(repository);
 	}
 
 	private ServerHttpSecurity() {}

+ 0 - 1
web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java

@@ -26,7 +26,6 @@ import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.web.server.ServerHttpBasicAuthenticationConverter;
 import org.springframework.security.web.server.WebFilterExchange;
 import org.springframework.security.web.server.context.ServerSecurityContextRepository;
-import org.springframework.security.web.server.context.SecurityContextRepositoryServerWebExchange;
 import org.springframework.security.web.server.context.ServerWebExchangeAttributeServerSecurityContextRepository;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;

+ 9 - 5
web/src/main/java/org/springframework/security/web/server/context/SecurityContextRepositoryWebFilter.java → web/src/main/java/org/springframework/security/web/server/context/ReactorContextWebFilter.java

@@ -15,6 +15,8 @@
  */
 package org.springframework.security.web.server.context;
 
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.util.Assert;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
@@ -25,18 +27,20 @@ import reactor.core.publisher.Mono;
  * @author Rob Winch
  * @since 5.0
  */
-public class SecurityContextRepositoryWebFilter implements WebFilter {
+public class ReactorContextWebFilter implements WebFilter {
 	private final ServerSecurityContextRepository repository;
 
-	public SecurityContextRepositoryWebFilter(ServerSecurityContextRepository repository) {
+	public ReactorContextWebFilter(ServerSecurityContextRepository repository) {
 		Assert.notNull(repository, "repository cannot be null");
 		this.repository = repository;
 	}
 
 	@Override
 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
-		SecurityContextRepositoryServerWebExchange delegate =
-				new SecurityContextRepositoryServerWebExchange(exchange, repository);
-		return chain.filter(delegate);
+		return chain.filter(exchange)
+			.subscriberContext(c -> c.hasKey(SecurityContext.class) ? c :
+				Mono.defer(() -> this.repository.load(exchange))
+					.as(ReactiveSecurityContextHolder::withSecurityContext)
+			);
 	}
 }

+ 7 - 10
web/src/main/java/org/springframework/security/web/server/context/SecurityContextRepositoryServerWebExchange.java → web/src/main/java/org/springframework/security/web/server/context/SecurityContextServerWebExchange.java

@@ -17,6 +17,8 @@ package org.springframework.security.web.server.context;
 
 import java.security.Principal;
 
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.ServerWebExchangeDecorator;
 
@@ -26,22 +28,17 @@ import reactor.core.publisher.Mono;
  * @author Rob Winch
  * @since 5.0
  */
-public class SecurityContextRepositoryServerWebExchange extends ServerWebExchangeDecorator {
-	private final ServerSecurityContextRepository repository;
+public class SecurityContextServerWebExchange extends ServerWebExchangeDecorator {
+	private final Mono<SecurityContext> context;
 
-	public SecurityContextRepositoryServerWebExchange(ServerWebExchange delegate, ServerSecurityContextRepository repository) {
+	public SecurityContextServerWebExchange(ServerWebExchange delegate, Mono<SecurityContext> context) {
 		super(delegate);
-		this.repository = repository;
+		this.context = context;
 	}
 
 	@Override
 	@SuppressWarnings("unchecked")
 	public <T extends Principal> Mono<T> getPrincipal() {
-		return Mono.defer(() ->
-			this.repository.load(this)
-				.filter(c -> c.getAuthentication() != null)
-				.flatMap(c -> Mono.just((T) c.getAuthentication()))
-				.switchIfEmpty( super.getPrincipal() )
-		);
+		return this.context.map(c -> (T) c.getAuthentication());
 	}
 }

+ 2 - 10
web/src/main/java/org/springframework/security/web/server/context/AuthenticationReactorContextWebFilter.java → web/src/main/java/org/springframework/security/web/server/context/SecurityContextServerWebExchangeWebFilter.java

@@ -34,19 +34,11 @@ import java.security.Principal;
  * @author Rob Winch
  * @since 5.0
  */
-public class AuthenticationReactorContextWebFilter implements WebFilter {
+public class SecurityContextServerWebExchangeWebFilter implements WebFilter {
 
 	@Override
 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 
-		return chain.filter(exchange)
-				.subscriberContext(createContext(exchange));
-	}
-
-	private Context createContext(ServerWebExchange exchange) {
-		return exchange.getPrincipal()
-			.cast(Authentication.class)
-			.map(SecurityContextImpl::new)
-			.as(ReactiveSecurityContextHolder::withSecurityContext);
+		return chain.filter(new SecurityContextServerWebExchange(exchange, ReactiveSecurityContextHolder.getContext()));
 	}
 }

+ 102 - 0
web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java

@@ -0,0 +1,102 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.context;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextImpl;
+import org.springframework.security.test.web.reactive.server.WebTestHandler;
+import reactor.core.publisher.Mono;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.*;
+
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class ReactorContextWebFilterTests {
+	@Mock
+	private Authentication principal;
+
+	@Mock
+	private ServerSecurityContextRepository repository;
+
+	private MockServerHttpRequest.BaseBuilder<?> exchange = MockServerHttpRequest.get("/");
+
+	private ReactorContextWebFilter filter;
+
+	private WebTestHandler handler;
+
+
+	@Before
+	public void setup() {
+		this.filter = new ReactorContextWebFilter(this.repository);
+		this.handler = WebTestHandler.bindToWebFilters(this.filter);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorNullSecurityContextRepository() {
+		ServerSecurityContextRepository repository = null;
+		new ReactorContextWebFilter(repository);
+	}
+
+	@Test
+	public void filterWhenNoPrincipalAccessThenNoInteractions() {
+		this.handler.exchange(this.exchange);
+
+		verifyZeroInteractions(this.repository);
+	}
+
+	@Test
+	public void filterWhenGetPrincipalMonoThenNoInteractions() {
+		this.handler = WebTestHandler.bindToWebFilters(this.filter, (e,c) -> {
+			ReactiveSecurityContextHolder.getContext();
+			return c.filter(e);
+		});
+
+		this.handler.exchange(this.exchange);
+
+		verifyZeroInteractions(this.repository);
+	}
+
+	@Test
+	public void filterWhenPrincipalAndGetPrincipalThenInteractAndUseOriginalPrincipal() {
+		SecurityContextImpl context = new SecurityContextImpl(this.principal);
+		when(this.repository.load(any())).thenReturn(Mono.just(context));
+		this.handler = WebTestHandler.bindToWebFilters(this.filter, (e,c) ->
+			ReactiveSecurityContextHolder.getContext()
+				.map(SecurityContext::getAuthentication)
+				.doOnSuccess( p -> assertThat(p).isSameAs(this.principal))
+				.flatMap(p -> c.filter(e))
+		);
+
+		WebTestHandler.WebHandlerResult result = this.handler.exchange(this.exchange);
+
+		verify(this.repository).load(any());
+	}
+}

+ 25 - 26
web/src/test/java/org/springframework/security/web/server/context/AuthenticationReactorContextWebFilterTests.java → web/src/test/java/org/springframework/security/web/server/context/SecurityContextServerWebExchangeWebFilterTests.java

@@ -36,57 +36,56 @@ import static org.assertj.core.api.Assertions.assertThat;
  * @author Rob Winch
  * @since 5.0
  */
-public class AuthenticationReactorContextWebFilterTests {
-	AuthenticationReactorContextWebFilter filter = new AuthenticationReactorContextWebFilter();
+public class SecurityContextServerWebExchangeWebFilterTests {
+	SecurityContextServerWebExchangeWebFilter filter = new SecurityContextServerWebExchangeWebFilter();
 
-	Principal principal = new TestingAuthenticationToken("user","password", "ROLE_USER");
+	Authentication principal = new TestingAuthenticationToken("user","password", "ROLE_USER");
 
 	ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/").build());
 
 	@Test
 	public void filterWhenExistingContextAndPrincipalNotNullThenContextPopulated() {
-		exchange = exchange.mutate().principal(Mono.just(principal)).build();
-		StepVerifier.create(filter.filter(exchange,
-			new DefaultWebFilterChain( e ->
-				ReactiveSecurityContextHolder.getContext()
-					.map(SecurityContext::getAuthentication)
+		Mono<Void> result = this.filter.filter(this.exchange, new DefaultWebFilterChain( e ->
+				e.getPrincipal()
 					.doOnSuccess(contextPrincipal -> assertThat(contextPrincipal).isEqualTo(principal))
 					.flatMap( contextPrincipal -> Mono.subscriberContext())
 					.doOnSuccess( context -> assertThat(context.<String>get("foo")).isEqualTo("bar"))
 					.then()
 			)
 		)
-		.subscriberContext( context -> context.put("foo", "bar")))
-		.verifyComplete();
+		.subscriberContext( context -> context.put("foo", "bar"))
+		.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.principal));
+
+		StepVerifier.create(result)
+			.verifyComplete();
 	}
 
 	@Test
 	public void filterWhenPrincipalNotNullThenContextPopulated() {
-		exchange = exchange.mutate().principal(Mono.just(principal)).build();
-		StepVerifier.create(filter.filter(exchange,
-			new DefaultWebFilterChain( e ->
-				ReactiveSecurityContextHolder.getContext()
-					.map(SecurityContext::getAuthentication)
-					.doOnSuccess(contextPrincipal -> assertThat(contextPrincipal).isEqualTo(principal))
+		Mono<Void> result = this.filter.filter(this.exchange, new DefaultWebFilterChain( e ->
+				e.getPrincipal()
+					.doOnSuccess(contextPrincipal -> assertThat(contextPrincipal).isEqualTo(this.principal))
 					.then()
 			)
-		))
-		.verifyComplete();
+		)
+		.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.principal));
+
+		StepVerifier.create(result)
+			.verifyComplete();
 	}
 
 	@Test
 	public void filterWhenPrincipalNullThenContextEmpty() {
 		Authentication defaultAuthentication = new TestingAuthenticationToken("anonymouse","anonymous", "TEST");
-		StepVerifier.create(filter.filter(exchange,
-			new DefaultWebFilterChain( e ->
-				ReactiveSecurityContextHolder.getContext()
-					.map(SecurityContext::getAuthentication)
+		Mono<Void> result = this.filter.filter(this.exchange, new DefaultWebFilterChain( e ->
+				e.getPrincipal()
 					.defaultIfEmpty(defaultAuthentication)
 					.doOnSuccess( contextPrincipal -> assertThat(contextPrincipal).isEqualTo(defaultAuthentication)
-				)
-				.then()
+					)
+					.then()
 			)
-		))
-		.verifyComplete();
+		);
+		StepVerifier.create(result)
+			.verifyComplete();
 	}
 }

+ 0 - 113
web/src/test/java/org/springframework/security/web/server/context/ServerSecurityContextRepositoryWebFilterTests.java

@@ -1,113 +0,0 @@
-/*
- * Copyright 2002-2017 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.web.server.context;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.junit.MockitoJUnitRunner;
-import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
-import org.springframework.mock.web.server.MockServerWebExchange;
-import org.springframework.security.core.Authentication;
-import org.springframework.security.core.context.SecurityContextImpl;
-import org.springframework.security.test.web.reactive.server.WebTestHandler;
-import org.springframework.web.server.ServerWebExchange;
-import reactor.core.publisher.Mono;
-
-import java.security.Principal;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.*;
-
-
-/**
- * @author Rob Winch
- * @since 5.0
- */
-@RunWith(MockitoJUnitRunner.class)
-public class ServerSecurityContextRepositoryWebFilterTests {
-	@Mock
-	Authentication principal;
-
-	@Mock ServerSecurityContextRepository repository;
-
-	MockServerHttpRequest.BaseBuilder<?> exchange = MockServerHttpRequest.get("/");
-
-	SecurityContextRepositoryWebFilter filter;
-
-	WebTestHandler filters;
-
-
-	@Before
-	public void setup() {
-		filter = new SecurityContextRepositoryWebFilter(repository);
-		filters = WebTestHandler.bindToWebFilters(filter);
-	}
-
-	@Test(expected = IllegalArgumentException.class)
-	public void constructorNullSecurityContextRepository() {
-		ServerSecurityContextRepository repository = null;
-		new SecurityContextRepositoryWebFilter(repository);
-	}
-
-	@Test
-	public void filterWhenNoPrincipalAccessThenNoInteractions() {
-		filters.exchange(exchange);
-
-		verifyZeroInteractions(repository);
-	}
-
-	@Test
-	public void filterWhenGetPrincipalMonoThenNoInteractions() {
-		filters = WebTestHandler.bindToWebFilters(filter, (e,c) -> {
-			Mono<Principal> p = e.getPrincipal();
-			return c.filter(e);
-		});
-
-		filters.exchange(exchange);
-
-		verifyZeroInteractions(repository);
-	}
-
-	// We must use the original principal if the result is empty for test support to work
-	@Test
-	public void filterWhenEmptyAndGetPrincipalThenInteractAndUseOriginalPrincipal() {
-		when(repository.load(any())).thenReturn(Mono.empty());
-		filters = WebTestHandler.bindToWebFilters(filter, (e,c) -> e.getPrincipal().flatMap( p-> c.filter(e))) ;
-
-		ServerWebExchange exchangeWithPrincipal = MockServerWebExchange.from(exchange.build()).mutate().principal(Mono.just(principal)).build();
-		WebTestHandler.WebHandlerResult result = filters.exchange(exchangeWithPrincipal);
-
-		verify(repository).load(any());
-		assertThat(result.getExchange().getPrincipal().block()).isSameAs(principal);
-	}
-
-	@Test
-	public void filterWhenPrincipalAndGetPrincipalThenInteractAndUseOriginalPrincipal() {
-		SecurityContextImpl context = new SecurityContextImpl();
-		context.setAuthentication(principal);
-		when(repository.load(any())).thenReturn(Mono.just(context));
-		filters = WebTestHandler.bindToWebFilters(filter, (e,c) -> e.getPrincipal().flatMap( p-> c.filter(e))) ;
-
-		WebTestHandler.WebHandlerResult result = filters.exchange(exchange);
-
-		verify(repository).load(any());
-		assertThat(result.getExchange().getPrincipal().block()).isSameAs(principal);
-	}
-}