2
0
Эх сурвалжийг харах

Add CORS WebFlux Support

Fixes: gh-4832
Rob Winch 7 жил өмнө
parent
commit
cecbc2175b

+ 4 - 0
config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java

@@ -23,6 +23,10 @@ package org.springframework.security.config.web.server;
 public enum SecurityWebFiltersOrder {
 	FIRST(Integer.MIN_VALUE),
 	HTTP_HEADERS_WRITER,
+	/**
+	 * {@link org.springframework.web.cors.reactive.CorsWebFilter}
+	 */
+	CORS,
 	/**
 	 * {@link org.springframework.security.web.server.csrf.CsrfWebFilter}
 	 */

+ 83 - 0
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -111,6 +111,10 @@ import org.springframework.security.web.server.util.matcher.ServerWebExchangeMat
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
 import org.springframework.util.Assert;
 import org.springframework.util.ClassUtils;
+import org.springframework.web.cors.reactive.CorsConfigurationSource;
+import org.springframework.web.cors.reactive.CorsProcessor;
+import org.springframework.web.cors.reactive.CorsWebFilter;
+import org.springframework.web.cors.reactive.DefaultCorsProcessor;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
@@ -181,6 +185,8 @@ public class ServerHttpSecurity {
 
 	private CsrfSpec csrf = new CsrfSpec();
 
+	private CorsSpec cors = new CorsSpec();
+
 	private ExceptionHandlingSpec exceptionHandling = new ExceptionHandlingSpec();
 
 	private HttpBasicSpec httpBasic;
@@ -299,6 +305,80 @@ public class ServerHttpSecurity {
 		return this.csrf;
 	}
 
+	/**
+	 * Configures CORS headers. By default if a {@link CorsConfigurationSource} Bean is found, it will be used
+	 * to create a {@link CorsWebFilter}. If {@link CorsSpec#configurationSource(CorsConfigurationSource)} is invoked
+	 * it will be used instead. If neither has been configured, the Cors configuration will do nothing.
+	 * @return the {@link CorsSpec} to customize
+	 */
+	public CorsSpec cors() {
+		if (this.cors == null) {
+			this.cors = new CorsSpec();
+		}
+		return this.cors;
+	}
+
+	/**
+	 * Configures CORS support within Spring Security. This ensures that the {@link CorsWebFilter} is place in the
+	 * correct order.
+	 */
+	public class CorsSpec {
+		private CorsWebFilter corsFilter;
+
+		/**
+		 * Configures the {@link CorsConfigurationSource} to be used
+		 * @param source the source to use
+		 * @return the {@link CorsSpec} for additional configuration
+		 */
+		public CorsSpec configurationSource(CorsConfigurationSource source) {
+			this.corsFilter = new CorsWebFilter(source);
+			return this;
+		}
+
+		/**
+		 * Disables CORS support within Spring Security.
+		 * @return the {@link ServerHttpSecurity} to continue configuring
+		 */
+		public ServerHttpSecurity disable() {
+			ServerHttpSecurity.this.cors = null;
+			return ServerHttpSecurity.this;
+		}
+
+		/**
+		 * Allows method chaining to continue configuring the {@link ServerHttpSecurity}
+		 * @return the {@link ServerHttpSecurity} to continue configuring
+		 */
+		public ServerHttpSecurity and() {
+			return ServerHttpSecurity.this;
+		}
+
+		protected void configure(ServerHttpSecurity http) {
+			CorsWebFilter corsFilter = getCorsFilter();
+			if (corsFilter != null) {
+				http.addFilterAt(this.corsFilter, SecurityWebFiltersOrder.CORS);
+			}
+		}
+
+		private CorsWebFilter getCorsFilter() {
+			if (this.corsFilter != null) {
+				return this.corsFilter;
+			}
+
+			CorsConfigurationSource source = getBeanOrNull(CorsConfigurationSource.class);
+			if (source == null) {
+				return null;
+			}
+			CorsProcessor processor = getBeanOrNull(CorsProcessor.class);
+			if (processor == null) {
+				processor = new DefaultCorsProcessor();
+			}
+			this.corsFilter = new CorsWebFilter(source, processor);
+			return this.corsFilter;
+		}
+
+		private CorsSpec() {}
+	}
+
 	/**
 	 * Configures HTTP Basic authentication. An example configuration is provided below:
 	 *
@@ -782,6 +862,9 @@ public class ServerHttpSecurity {
 		if(this.csrf != null) {
 			this.csrf.configure(this);
 		}
+		if (this.cors != null) {
+			this.cors.configure(this);
+		}
 		if(this.httpBasic != null) {
 			this.httpBasic.authenticationManager(this.authenticationManager);
 			this.httpBasic.configure(this);

+ 117 - 0
config/src/test/java/org/springframework/security/config/web/server/CorsSpecTests.java

@@ -0,0 +1,117 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.web.server;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.springframework.context.ApplicationContext;
+import org.springframework.core.ResolvableType;
+import org.springframework.http.HttpHeaders;
+import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
+import org.springframework.test.web.reactive.server.FluxExchangeResult;
+import org.springframework.test.web.reactive.server.WebTestClient;
+import org.springframework.web.cors.CorsConfiguration;
+import org.springframework.web.cors.reactive.CorsConfigurationSource;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class CorsSpecTests {
+	@Mock
+	private CorsConfigurationSource source;
+	@Mock
+	private ApplicationContext context;
+
+	ServerHttpSecurity http;
+
+	HttpHeaders expectedHeaders = new HttpHeaders();
+
+	Set<String> headerNamesNotPresent = new HashSet<>();
+
+	@Before
+	public void setup() {
+		this.http = new TestingServerHttpSecurity()
+				.applicationContext(this.context);
+		CorsConfiguration value = new CorsConfiguration();
+		value.setAllowedOrigins(Arrays.asList("*"));
+		when(this.source.getCorsConfiguration(any())).thenReturn(value);
+	}
+
+	@Test
+	public void corsWhenEnabledThenAccessControlAllowOriginAndSecurityHeaders() {
+		this.http.cors().configurationSource(this.source);
+		this.expectedHeaders.set("Access-Control-Allow-Origin", "*");
+		this.expectedHeaders.set("X-Frame-Options", "DENY");
+		assertHeaders();
+	}
+
+	@Test
+	public void corsWhenCorsConfigurationSourceBeanThenAccessControlAllowOriginAndSecurityHeaders() {
+		when(this.context.getBeanNamesForType(any(ResolvableType.class))).thenReturn(new String[] {"source"}, new String[0]);
+		when(this.context.getBean("source")).thenReturn(this.source);
+		this.expectedHeaders.set("Access-Control-Allow-Origin", "*");
+		this.expectedHeaders.set("X-Frame-Options", "DENY");
+		assertHeaders();
+	}
+
+	@Test
+	public void corsWhenNoConfigurationSourceThenNoCorsHeaders() {
+		when(this.context.getBeanNamesForType(any(ResolvableType.class))).thenReturn(new String[0]);
+		this.headerNamesNotPresent.add("Access-Control-Allow-Origin");
+		assertHeaders();
+	}
+
+	private void assertHeaders() {
+		WebTestClient client = buildClient();
+		FluxExchangeResult<String> response = client.get()
+			.uri("https://example.com/")
+			.headers(h -> h.setOrigin("https://origin.example.com"))
+			.exchange()
+			.returnResult(String.class);
+
+		Map<String, List<String>> responseHeaders = response.getResponseHeaders();
+
+		if (!this.expectedHeaders.isEmpty()) {
+			assertThat(responseHeaders).describedAs(response.toString())
+					.containsAllEntriesOf(this.expectedHeaders);
+		}
+		if (!this.headerNamesNotPresent.isEmpty()) {
+			assertThat(responseHeaders.keySet()).doesNotContainAnyElementsOf(this.headerNamesNotPresent);
+		}
+	}
+
+	private WebTestClient buildClient() {
+		return WebTestClientBuilder
+				.bindToWebFilters(this.http.build())
+				.build();
+	}
+}

+ 32 - 0
config/src/test/java/org/springframework/security/config/web/server/TestingServerHttpSecurity.java

@@ -0,0 +1,32 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.web.server;
+
+import org.springframework.beans.BeansException;
+import org.springframework.context.ApplicationContext;
+
+/**
+ * @author Rob Winch
+ * @since 5.1
+ */
+public class TestingServerHttpSecurity extends ServerHttpSecurity {
+	public TestingServerHttpSecurity applicationContext(ApplicationContext applicationContext)
+			throws BeansException {
+		super.setApplicationContext(applicationContext);
+		return this;
+	}
+}