Explorar o código

AuthenticationFilter Session Fixation Protection

Fixes gh-7446
Josh Cummings %!s(int64=6) %!d(string=hai) anos
pai
achega
7576dc44d7

+ 6 - 1
web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java

@@ -16,12 +16,12 @@
 package org.springframework.security.web.authentication;
 
 import java.io.IOException;
-
 import javax.servlet.Filter;
 import javax.servlet.FilterChain;
 import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
 
 import org.springframework.http.HttpStatus;
 import org.springframework.security.authentication.AuthenticationManager;
@@ -146,6 +146,11 @@ public class AuthenticationFilter extends OncePerRequestFilter {
 				return;
 			}
 
+			HttpSession session = request.getSession(false);
+			if (session != null) {
+				request.changeSessionId();
+			}
+
 			successfulAuthentication(request, response, filterChain, authenticationResult);
 		} catch (AuthenticationException e) {
 			unsuccessfulAuthentication(request, response, e);

+ 31 - 8
web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java

@@ -15,14 +15,6 @@
  */
 package org.springframework.security.web.authentication;
 
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyZeroInteractions;
-import static org.mockito.Mockito.when;
-
 import javax.servlet.FilterChain;
 import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
@@ -35,9 +27,12 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
+
 import org.springframework.http.HttpStatus;
+import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.mock.web.MockHttpSession;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManagerResolver;
 import org.springframework.security.authentication.BadCredentialsException;
@@ -46,6 +41,14 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.mockito.Mockito.when;
+
 /**
  * @author Sergey Bespalov
  * @since 5.2.0
@@ -246,4 +249,24 @@ public class AuthenticationFilterTests {
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
 	}
 
+	// gh-7446
+	@Test
+	public void filterWhenSuccessfulAuthenticationThenSessionIdChanges() throws Exception {
+		Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER");
+		when(this.authenticationConverter.convert(any())).thenReturn(authentication);
+		when(this.authenticationManager.authenticate(any())).thenReturn(authentication);
+
+		MockHttpSession session = new MockHttpSession();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", "/");
+		request.setSession(session);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain chain = new MockFilterChain();
+
+		String sessionId = session.getId();
+		AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager, this.authenticationConverter);
+		filter.doFilter(request, response, chain);
+
+		assertThat(session.getId()).isNotEqualTo(sessionId);
+	}
+
 }