Forráskód Böngészése

Allow configuration of SessionAuthenticationStrategy for CSRF

Closes gh-5300
Michael Vitz 7 éve
szülő
commit
09e8ae42ed

+ 41 - 2
config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.web.access.AccessDeniedHandler;
 import org.springframework.security.web.access.AccessDeniedHandlerImpl;
 import org.springframework.security.web.access.DelegatingAccessDeniedHandler;
+import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfLogoutHandler;
@@ -81,6 +82,7 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 			new HttpSessionCsrfTokenRepository());
 	private RequestMatcher requireCsrfProtectionMatcher = CsrfFilter.DEFAULT_CSRF_MATCHER;
 	private List<RequestMatcher> ignoredCsrfProtectionMatchers = new ArrayList<>();
+	private SessionAuthenticationStrategy sessionAuthenticationStrategy;
 	private final ApplicationContext context;
 
 	/**
@@ -179,6 +181,26 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 				.and();
 	}
 
+	/**
+	 * <p>
+	 * Specify the {@link SessionAuthenticationStrategy} to use. The default is a
+	 * {@link CsrfAuthenticationStrategy}.
+	 * </p>
+	 *
+	 * @author Michael Vitz
+	 * @since 5.1
+	 *
+	 * @param sessionAuthenticationStrategy the {@link SessionAuthenticationStrategy} to use
+	 * @return the {@link CsrfConfigurer} for further customizations
+	 */
+	public CsrfConfigurer<H> sessionAuthenticationStrategy(
+			SessionAuthenticationStrategy sessionAuthenticationStrategy) {
+		Assert.notNull(sessionAuthenticationStrategy,
+				"sessionAuthenticationStrategy cannot be null");
+		this.sessionAuthenticationStrategy = sessionAuthenticationStrategy;
+		return this;
+	}
+
 	@SuppressWarnings("unchecked")
 	@Override
 	public void configure(H http) throws Exception {
@@ -200,7 +222,7 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 				.getConfigurer(SessionManagementConfigurer.class);
 		if (sessionConfigurer != null) {
 			sessionConfigurer.addSessionAuthenticationStrategy(
-					new CsrfAuthenticationStrategy(this.csrfTokenRepository));
+					getSessionAuthenticationStrategy());
 		}
 		filter = postProcess(filter);
 		http.addFilter(filter);
@@ -289,6 +311,23 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 		return new DelegatingAccessDeniedHandler(handlers, defaultAccessDeniedHandler);
 	}
 
+	/**
+	 * Gets the {@link SessionAuthenticationStrategy} to use. If none was set by the user a
+	 * {@link CsrfAuthenticationStrategy} is created.
+	 *
+	 * @author Michael Vitz
+	 * @since 5.1
+	 *
+	 * @return the {@link SessionAuthenticationStrategy}
+	 */
+	private SessionAuthenticationStrategy getSessionAuthenticationStrategy() {
+		if (sessionAuthenticationStrategy != null) {
+			return sessionAuthenticationStrategy;
+		} else {
+			return new CsrfAuthenticationStrategy(this.csrfTokenRepository);
+		}
+	}
+
 	/**
 	 * Allows registering {@link RequestMatcher} instances that should be ignored (even if
 	 * the {@link HttpServletRequest} matches the

+ 65 - 8
config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java

@@ -29,8 +29,10 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
 import org.springframework.security.config.test.SpringTestRule;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.core.userdetails.PasswordEncodedUser;
 import org.springframework.security.web.access.AccessDeniedHandler;
+import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
 import org.springframework.security.web.csrf.DefaultCsrfToken;
 import org.springframework.security.web.firewall.StrictHttpFirewall;
@@ -60,14 +62,7 @@ import static org.springframework.security.test.web.servlet.request.SecurityMock
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
 import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated;
 import static org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.unauthenticated;
-import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete;
-import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
-import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.head;
-import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.options;
-import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.patch;
-import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
-import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.put;
-import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.request;
+import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
 
@@ -76,6 +71,8 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
  *
  * @author Rob Winch
  * @author Eleftheria Stein
+ * @author Michael Vitz
+ * @author Sam Simmons
  */
 public class CsrfConfigurerTests {
 	@Rule
@@ -684,6 +681,66 @@ public class CsrfConfigurerTests {
 		}
 	}
 
+	@EnableWebSecurity
+	static class NullAuthenticationStrategy extends WebSecurityConfigurerAdapter {
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.csrf()
+					.sessionAuthenticationStrategy(null);
+			// @formatter:on
+		}
+	}
+
+	@Test
+	public void getWhenNullAuthenticationStrategyThenException() {
+		assertThatThrownBy(() -> this.spring.register(NullAuthenticationStrategy.class).autowire())
+				.isInstanceOf(BeanCreationException.class)
+				.hasRootCauseInstanceOf(IllegalArgumentException.class);
+	}
+
+	@EnableWebSecurity
+	static class CsrfAuthenticationStrategyConfig extends WebSecurityConfigurerAdapter {
+		static SessionAuthenticationStrategy STRATEGY;
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.formLogin()
+					.and()
+					.csrf()
+					.sessionAuthenticationStrategy(STRATEGY);
+			// @formatter:on
+		}
+
+		@Override
+		protected void configure(AuthenticationManagerBuilder auth) throws Exception {
+			// @formatter:off
+			auth
+					.inMemoryAuthentication()
+					.withUser(PasswordEncodedUser.user());
+			// @formatter:on
+		}
+	}
+
+	@Test
+	public void csrfAuthenticationStrategyConfiguredThenStrategyUsed() throws Exception {
+		CsrfAuthenticationStrategyConfig.STRATEGY = mock(SessionAuthenticationStrategy.class);
+
+		this.spring.register(CsrfAuthenticationStrategyConfig.class).autowire();
+
+		this.mvc.perform(post("/login")
+				.with(csrf())
+				.param("username", "user")
+				.param("password", "password"))
+				.andExpect(redirectedUrl("/"));
+
+		verify(CsrfAuthenticationStrategyConfig.STRATEGY, atLeastOnce())
+				.onAuthentication(any(Authentication.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
 	@RestController
 	static class BasicController {
 		@GetMapping("/")