Browse Source

AuthenticationEntryPoint & AccessDeniedHandler use Mono<Void>

Rob Winch 8 years ago
parent
commit
8f5069053e

+ 3 - 4
webflux/src/main/java/org/springframework/security/web/server/AuthenticationEntryPoint.java

@@ -17,12 +17,11 @@
  */
 package org.springframework.security.web.server;
 
+import reactor.core.publisher.Mono;
+
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.web.server.ServerWebExchange;
 
-
-import reactor.core.publisher.Mono;
-
 /**
  *
  * @author Rob Winch
@@ -30,5 +29,5 @@ import reactor.core.publisher.Mono;
  */
 public interface AuthenticationEntryPoint {
 
-	<T> Mono<T> commence(ServerWebExchange exchange, AuthenticationException e);
+	Mono<Void> commence(ServerWebExchange exchange, AuthenticationException e);
 }

+ 25 - 5
webflux/src/main/java/org/springframework/security/web/server/authentication/www/HttpBasicAuthenticationEntryPoint.java

@@ -21,6 +21,7 @@ import org.springframework.http.HttpStatus;
 import org.springframework.http.server.reactive.ServerHttpResponse;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.web.server.AuthenticationEntryPoint;
+import org.springframework.util.Assert;
 import org.springframework.web.server.ServerWebExchange;
 
 import reactor.core.publisher.Mono;
@@ -31,12 +32,31 @@ import reactor.core.publisher.Mono;
  * @since 5.0
  */
 public class HttpBasicAuthenticationEntryPoint implements AuthenticationEntryPoint {
+	private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
+	private static final String DEFAULT_REALM = "Realm";
+	private static String WWW_AUTHENTICATE_FORMAT = "Basic realm=\"%s\"";
+
+	private String headerValue = createHeaderValue(DEFAULT_REALM);
 
 	@Override
-	public <T> Mono<T> commence(ServerWebExchange exchange, AuthenticationException e) {
-		ServerHttpResponse response = exchange.getResponse();
-		response.setStatusCode(HttpStatus.UNAUTHORIZED);
-		response.getHeaders().set("WWW-Authenticate", "Basic realm=\"Realm\"");
-		return Mono.empty();
+	public Mono<Void> commence(ServerWebExchange exchange, AuthenticationException e) {
+		return Mono.fromRunnable(() -> {
+			ServerHttpResponse response = exchange.getResponse();
+			response.setStatusCode(HttpStatus.UNAUTHORIZED);
+			response.getHeaders().set(WWW_AUTHENTICATE, this.headerValue);
+		});
+	}
+
+	/**
+	 * Sets the realm to be used
+	 * @param realm the realm. Default is "Realm"
+	 */
+	public void setRealm(String realm) {
+		this.headerValue = createHeaderValue(realm);
+	}
+
+	private static String createHeaderValue(String realm) {
+		Assert.notNull(realm, "realm cannot be null");
+		return String.format(WWW_AUTHENTICATE_FORMAT, realm);
 	}
 }

+ 1 - 1
webflux/src/main/java/org/springframework/security/web/server/authorization/AccessDeniedHandler.java

@@ -29,5 +29,5 @@ import reactor.core.publisher.Mono;
  */
 public interface AccessDeniedHandler {
 
-	<T> Mono<T> handle(ServerWebExchange exchange, AccessDeniedException denied);
+	Mono<Void> handle(ServerWebExchange exchange, AccessDeniedException denied);
 }

+ 33 - 7
webflux/src/main/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilter.java

@@ -17,16 +17,17 @@
  */
 package org.springframework.security.web.server.authorization;
 
+import reactor.core.publisher.Mono;
 
 import org.springframework.http.HttpStatus;
 import org.springframework.security.access.AccessDeniedException;
 import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
 import org.springframework.security.web.server.AuthenticationEntryPoint;
 import org.springframework.security.web.server.authentication.www.HttpBasicAuthenticationEntryPoint;
+import org.springframework.util.Assert;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
-import reactor.core.publisher.Mono;
 
 /**
  *
@@ -34,18 +35,43 @@ import reactor.core.publisher.Mono;
  * @since 5.0
  */
 public class ExceptionTranslationWebFilter implements WebFilter {
-	private AuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint();
+	private AuthenticationEntryPoint authenticationEntryPoint = new HttpBasicAuthenticationEntryPoint();
 
 	private AccessDeniedHandler accessDeniedHandler = new HttpStatusAccessDeniedHandler(HttpStatus.FORBIDDEN);
 
 	@Override
 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 		return chain.filter(exchange)
-			.onErrorResume(AccessDeniedException.class, denied -> {
-				return exchange.getPrincipal()
-					.switchIfEmpty( Mono.defer( () -> entryPoint.commence(exchange, new AuthenticationCredentialsNotFoundException("Not Authenticated", denied))))
-					.flatMap( principal -> accessDeniedHandler.handle(exchange, denied));
-			});
+			.onErrorResume(AccessDeniedException.class, denied -> exchange.getPrincipal()
+				.switchIfEmpty( commenceAuthentication(exchange, denied))
+				.flatMap( principal -> this.accessDeniedHandler.handle(exchange, denied))
+			);
+	}
+
+	/**
+	 * Sets the access denied handler.
+	 * @param accessDeniedHandler the access denied handler to use. Default is
+	 * HttpStatusAccessDeniedHandler with HttpStatus.FORBIDDEN
+	 */
+	public void setAccessDeniedHandler(AccessDeniedHandler accessDeniedHandler) {
+		Assert.notNull(accessDeniedHandler, "accessDeniedHandler cannot be null");
+		this.accessDeniedHandler = accessDeniedHandler;
 	}
 
+	/**
+	 * Sets the authentication entry point used when authentication is required
+	 * @param authenticationEntryPoint the authentication entry point to use. Default is
+	 * {@link HttpBasicAuthenticationEntryPoint}
+	 */
+	public void setAuthenticationEntryPoint(
+		AuthenticationEntryPoint authenticationEntryPoint) {
+		Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint cannot be null");
+		this.authenticationEntryPoint = authenticationEntryPoint;
+	}
+
+	private <T> Mono<T> commenceAuthentication(ServerWebExchange exchange, AccessDeniedException denied) {
+		return this.authenticationEntryPoint.commence(exchange, new AuthenticationCredentialsNotFoundException("Not Authenticated", denied))
+			.then(Mono.empty());
+	}
 }
+

+ 7 - 5
webflux/src/main/java/org/springframework/security/web/server/authorization/HttpStatusAccessDeniedHandler.java

@@ -18,13 +18,15 @@
 
 package org.springframework.security.web.server.authorization;
 
+import reactor.core.publisher.Mono;
+
 import org.springframework.http.HttpStatus;
 import org.springframework.security.access.AccessDeniedException;
+import org.springframework.util.Assert;
 import org.springframework.web.server.ServerWebExchange;
-import org.springframework.web.server.WebFilter;
-import reactor.core.publisher.Mono;
 
 /**
+ * Sets an HTTP Status that is provided when
  * @author Rob Winch
  * @since 5.0
  */
@@ -32,12 +34,12 @@ public class HttpStatusAccessDeniedHandler implements AccessDeniedHandler {
 	private final HttpStatus httpStatus;
 
 	public HttpStatusAccessDeniedHandler(HttpStatus httpStatus) {
+		Assert.notNull(httpStatus, "httpStatus cannot be null");
 		this.httpStatus = httpStatus;
 	}
 
 	@Override
-	public <T> Mono<T> handle(ServerWebExchange exchange, AccessDeniedException e) {
-		exchange.getResponse().setStatusCode(HttpStatus.FORBIDDEN);
-		return Mono.empty();
+	public Mono<Void> handle(ServerWebExchange exchange, AccessDeniedException e) {
+		return Mono.fromRunnable(() -> exchange.getResponse().setStatusCode(HttpStatus.FORBIDDEN));
 	}
 }

+ 84 - 0
webflux/src/test/java/org/springframework/security/web/server/authentication/www/HttpBasicAuthenticationEntryPointTests.java

@@ -0,0 +1,84 @@
+/*
+ *
+ *  * Copyright 2002-2017 the original author or authors.
+ *  *
+ *  * Licensed under the Apache License, Version 2.0 (the "License");
+ *  * you may not use this file except in compliance with the License.
+ *  * You may obtain a copy of the License at
+ *  *
+ *  *      http://www.apache.org/licenses/LICENSE-2.0
+ *  *
+ *  * Unless required by applicable law or agreed to in writing, software
+ *  * distributed under the License is distributed on an "AS IS" BASIS,
+ *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  * See the License for the specific language governing permissions and
+ *  * limitations under the License.
+ *
+ */
+
+package org.springframework.security.web.server.authentication.www;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+
+import org.springframework.http.HttpStatus;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
+import org.springframework.security.core.AuthenticationException;
+import org.springframework.web.server.ServerWebExchange;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.verifyZeroInteractions;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class HttpBasicAuthenticationEntryPointTests {
+	@Mock
+	private ServerWebExchange exchange;
+	private HttpBasicAuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint();
+
+	private AuthenticationException exception = new AuthenticationCredentialsNotFoundException("Authenticate");
+
+	@Test
+	public void commenceWhenNoSubscribersThenNoActions() {
+		this.entryPoint.commence(this.exchange,
+				this.exception);
+
+		verifyZeroInteractions(this.exchange);
+	}
+
+	@Test
+	public void commenceWhenSubscribeThenStatusAndHeaderSet() {
+		this.exchange = MockServerHttpRequest.get("/").toExchange();
+
+		this.entryPoint.commence(this.exchange, this.exception).block();
+
+		assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(
+			HttpStatus.UNAUTHORIZED);
+		assertThat(this.exchange.getResponse().getHeaders().get("WWW-Authenticate")).containsOnly(
+			"Basic realm=\"Realm\"");
+	}
+
+	@Test
+	public void commenceWhenCustomRealmThenStatusAndHeaderSet() {
+		this.entryPoint.setRealm("Custom");
+		this.exchange = MockServerHttpRequest.get("/").toExchange();
+
+		this.entryPoint.commence(this.exchange, this.exception).block();
+
+		assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(
+			HttpStatus.UNAUTHORIZED);
+		assertThat(this.exchange.getResponse().getHeaders().get("WWW-Authenticate")).containsOnly(
+			"Basic realm=\"Custom\"");
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void setRealmWhenNullThenException() {
+		this.entryPoint.setRealm(null);
+	}
+}

+ 178 - 0
webflux/src/test/java/org/springframework/security/web/server/authorization/ExceptionTranslationWebFilterTests.java

@@ -0,0 +1,178 @@
+/*
+ *
+ *  * Copyright 2002-2017 the original author or authors.
+ *  *
+ *  * Licensed under the Apache License, Version 2.0 (the "License");
+ *  * you may not use this file except in compliance with the License.
+ *  * You may obtain a copy of the License at
+ *  *
+ *  *      http://www.apache.org/licenses/LICENSE-2.0
+ *  *
+ *  * Unless required by applicable law or agreed to in writing, software
+ *  * distributed under the License is distributed on an "AS IS" BASIS,
+ *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  * See the License for the specific language governing permissions and
+ *  * limitations under the License.
+ *
+ */
+
+package org.springframework.security.web.server.authorization;
+
+import java.security.Principal;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+import org.springframework.http.HttpStatus;
+import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
+import org.springframework.security.access.AccessDeniedException;
+import org.springframework.security.web.server.AuthenticationEntryPoint;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebFilterChain;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class ExceptionTranslationWebFilterTests {
+	@Mock
+	private Principal principal;
+	@Mock
+	private ServerWebExchange exchange;
+	@Mock
+	private WebFilterChain chain;
+	@Mock
+	private AccessDeniedHandler deniedHandler;
+	@Mock
+	private AuthenticationEntryPoint entryPoint;
+
+	private TestMono<Void> deniedMono = TestMono.create();
+	private TestMono<Void> entryPointMono = TestMono.create();
+
+	private ExceptionTranslationWebFilter filter = new ExceptionTranslationWebFilter();
+
+	@Before
+	public void setup() {
+		when(this.exchange.getResponse()).thenReturn(new MockServerHttpResponse());
+		when(this.deniedHandler.handle(any(), any())).thenReturn(this.deniedMono.mono());
+		when(this.entryPoint.commence(any(), any())).thenReturn(this.entryPointMono.mono());
+
+		this.filter.setAuthenticationEntryPoint(this.entryPoint);
+		this.filter.setAccessDeniedHandler(this.deniedHandler);
+	}
+
+	@Test
+	public void filterWhenNoExceptionThenNotHandled() {
+		when(this.chain.filter(this.exchange)).thenReturn(Mono.empty());
+
+		StepVerifier.create(this.filter.filter(this.exchange, this.chain))
+			.expectComplete()
+			.verify();
+
+		assertThat(this.deniedMono.isInvoked()).isFalse();
+		assertThat(this.entryPointMono.isInvoked()).isFalse();
+	}
+
+	@Test
+	public void filterWhenNotAccessDeniedExceptionThenNotHandled() {
+		when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new IllegalArgumentException("oops")));
+
+		StepVerifier.create(this.filter.filter(this.exchange, this.chain))
+			.expectError(IllegalArgumentException.class)
+			.verify();
+
+		assertThat(this.deniedMono.isInvoked()).isFalse();
+		assertThat(this.entryPointMono.isInvoked()).isFalse();
+	}
+
+	@Test
+	public void filterWhenAccessDeniedExceptionAndNotAuthenticatedThenHandled() {
+		when(this.exchange.getPrincipal()).thenReturn(Mono.empty());
+		when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized")));
+
+		StepVerifier.create(this.filter.filter(this.exchange, this.chain))
+			.expectComplete()
+			.verify();
+
+		assertThat(this.deniedMono.isInvoked()).isFalse();
+		assertThat(this.entryPointMono.isInvoked()).isTrue();
+	}
+
+	@Test
+	public void filterWhenDefaultsAndAccessDeniedExceptionAndAuthenticatedThenForbidden() {
+		this.filter = new ExceptionTranslationWebFilter();
+		when(this.exchange.getPrincipal()).thenReturn(Mono.just(this.principal));
+		when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized")));
+
+		StepVerifier.create(this.filter.filter(this.exchange, this.chain))
+			.expectComplete()
+			.verify();
+
+		assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(
+			HttpStatus.FORBIDDEN);
+	}
+
+	@Test
+	public void filterWhenDefaultsAndAccessDeniedExceptionAndNotAuthenticatedThenUnauthorized() {
+		this.filter = new ExceptionTranslationWebFilter();
+		when(this.exchange.getPrincipal()).thenReturn(Mono.empty());
+		when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized")));
+
+		StepVerifier.create(this.filter.filter(this.exchange, this.chain))
+			.expectComplete()
+			.verify();
+
+		assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(
+			HttpStatus.UNAUTHORIZED);
+	}
+
+	@Test
+	public void filterWhenAccessDeniedExceptionAndAuthenticatedThenHandled() {
+		when(this.exchange.getPrincipal()).thenReturn(Mono.just(this.principal));
+		when(this.chain.filter(this.exchange)).thenReturn(Mono.error(new AccessDeniedException("Not Authorized")));
+
+		StepVerifier.create(this.filter.filter(this.exchange, this.chain))
+			.expectComplete()
+			.verify();
+
+		assertThat(this.deniedMono.isInvoked()).isTrue();
+		assertThat(this.entryPointMono.isInvoked()).isFalse();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void setAccessDeniedHandlerWhenNullThenException() {
+		this.filter.setAccessDeniedHandler(null);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void setAuthenticationEntryPointWhenNullThenException() {
+		this.filter.setAuthenticationEntryPoint(null);
+	}
+
+	static class TestMono<T> {
+		private final AtomicBoolean invoked = new AtomicBoolean();
+
+		public Mono<T> mono() {
+			return Mono.<T>empty().doOnSubscribe(s -> this.invoked.set(true));
+		}
+
+		public boolean isInvoked() {
+			return this.invoked.get();
+		}
+
+		public static <T> TestMono<T> create() {
+			return new TestMono<T>();
+		}
+	}
+}

+ 67 - 0
webflux/src/test/java/org/springframework/security/web/server/authorization/HttpStatusAccessDeniedHandlerTests.java

@@ -0,0 +1,67 @@
+/*
+ *
+ *  * Copyright 2002-2017 the original author or authors.
+ *  *
+ *  * Licensed under the Apache License, Version 2.0 (the "License");
+ *  * you may not use this file except in compliance with the License.
+ *  * You may obtain a copy of the License at
+ *  *
+ *  *      http://www.apache.org/licenses/LICENSE-2.0
+ *  *
+ *  * Unless required by applicable law or agreed to in writing, software
+ *  * distributed under the License is distributed on an "AS IS" BASIS,
+ *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  * See the License for the specific language governing permissions and
+ *  * limitations under the License.
+ *
+ */
+
+package org.springframework.security.web.server.authorization;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+
+import org.springframework.http.HttpStatus;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.security.access.AccessDeniedException;
+import org.springframework.web.server.ServerWebExchange;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.verifyZeroInteractions;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class HttpStatusAccessDeniedHandlerTests {
+	@Mock
+	private ServerWebExchange exchange;
+	private final HttpStatus httpStatus = HttpStatus.FORBIDDEN;
+	private HttpStatusAccessDeniedHandler handler = new HttpStatusAccessDeniedHandler(this.httpStatus);
+
+	private AccessDeniedException exception = new AccessDeniedException("Forbidden");
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorHttpStatusWhenNullThenException() {
+		new HttpStatusAccessDeniedHandler((HttpStatus) null);
+	}
+
+	@Test
+	public void commenceWhenNoSubscribersThenNoActions() {
+		this.handler.handle(this.exchange, this.exception);
+
+		verifyZeroInteractions(this.exchange);
+	}
+
+	@Test
+	public void commenceWhenSubscribeThenStatusSet() {
+		this.exchange = MockServerHttpRequest.get("/").toExchange();
+
+		this.handler.handle(this.exchange, this.exception).block();
+
+		assertThat(this.exchange.getResponse().getStatusCode()).isEqualTo(this.httpStatus);
+	}
+}