瀏覽代碼

Merge branch '5.7.x' into 5.8.x

Closes gh-12510
Marcus Da Coregio 2 年之前
父節點
當前提交
ae46032ced

+ 19 - 0
web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java

@@ -59,6 +59,8 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand
 import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
 import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
+import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.UrlUtils;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -146,6 +148,8 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
 
 	private AuthenticationFailureHandler failureHandler;
 
+	private SecurityContextRepository securityContextRepository = new RequestAttributeSecurityContextRepository();
+
 	@Override
 	public void afterPropertiesSet() {
 		Assert.notNull(this.userDetailsService, "userDetailsService must be specified");
@@ -183,6 +187,7 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
 				context.setAuthentication(targetUser);
 				this.securityContextHolderStrategy.setContext(context);
 				this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", targetUser));
+				this.securityContextRepository.saveContext(context, request, response);
 				// redirect to target url
 				this.successHandler.onAuthenticationSuccess(request, response, targetUser);
 			}
@@ -200,6 +205,7 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
 			context.setAuthentication(originalUser);
 			this.securityContextHolderStrategy.setContext(context);
 			this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", originalUser));
+			this.securityContextRepository.saveContext(context, request, response);
 			// redirect to target url
 			this.successHandler.onAuthenticationSuccess(request, response, originalUser);
 			return;
@@ -525,6 +531,19 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
 		this.securityContextHolderStrategy = securityContextHolderStrategy;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * switch user success. The default is
+	 * {@link RequestAttributeSecurityContextRepository}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 * @since 5.7.7
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	private static RequestMatcher createMatcher(String pattern) {
 		return new AntPathRequestMatcher(pattern, "POST", true, new UrlPathHelper());
 	}

+ 60 - 0
web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java

@@ -16,15 +16,18 @@
 
 package org.springframework.security.web.authentication.switchuser;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
 import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
 
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
+import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AccountExpiredException;
@@ -47,11 +50,15 @@ import org.springframework.security.core.userdetails.UsernameNotFoundException;
 import org.springframework.security.util.FieldUtils;
 import org.springframework.security.web.DefaultRedirectStrategy;
 import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
+import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.AnyRequestMatcher;
+import org.springframework.test.util.ReflectionTestUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -503,6 +510,59 @@ public class SwitchUserFilterTests {
 		filter.setSwitchFailureUrl("/foo");
 	}
 
+	@Test
+	void filterWhenDefaultSecurityContextRepositoryThenRequestAttributeRepository() {
+		SwitchUserFilter switchUserFilter = new SwitchUserFilter();
+		assertThat(ReflectionTestUtils.getField(switchUserFilter, "securityContextRepository"))
+				.isInstanceOf(RequestAttributeSecurityContextRepository.class);
+	}
+
+	@Test
+	void doFilterWhenSwitchUserThenSaveSecurityContext() throws ServletException, IOException {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		MockFilterChain filterChain = new MockFilterChain();
+		request.setParameter(SwitchUserFilter.SPRING_SECURITY_SWITCH_USERNAME_KEY, "jacklord");
+		request.setRequestURI("/login/impersonate");
+		SwitchUserFilter filter = new SwitchUserFilter();
+		filter.setSecurityContextRepository(securityContextRepository);
+		filter.setUserDetailsService(new MockUserDetailsService());
+		filter.setTargetUrl("/target");
+		filter.afterPropertiesSet();
+
+		filter.doFilter(request, response, filterChain);
+
+		verify(securityContextRepository).saveContext(any(), any(), any());
+	}
+
+	@Test
+	void doFilterWhenExitUserThenSaveSecurityContext() throws ServletException, IOException {
+		UsernamePasswordAuthenticationToken source = UsernamePasswordAuthenticationToken.authenticated("dano",
+				"hawaii50", ROLES_12);
+		// set current user (Admin)
+		List<GrantedAuthority> adminAuths = new ArrayList<>(ROLES_12);
+		adminAuths.add(new SwitchUserGrantedAuthority("PREVIOUS_ADMINISTRATOR", source));
+		UsernamePasswordAuthenticationToken admin = UsernamePasswordAuthenticationToken.authenticated("jacklord",
+				"hawaii50", adminAuths);
+		SecurityContextHolder.getContext().setAuthentication(admin);
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		MockFilterChain filterChain = new MockFilterChain();
+		request.setParameter(SwitchUserFilter.SPRING_SECURITY_SWITCH_USERNAME_KEY, "jacklord");
+		request.setRequestURI("/logout/impersonate");
+		SwitchUserFilter filter = new SwitchUserFilter();
+		filter.setSecurityContextRepository(securityContextRepository);
+		filter.setUserDetailsService(new MockUserDetailsService());
+		filter.setTargetUrl("/target");
+		filter.afterPropertiesSet();
+
+		filter.doFilter(request, response, filterChain);
+
+		verify(securityContextRepository).saveContext(any(), any(), any());
+	}
+
 	private class MockUserDetailsService implements UserDetailsService {
 
 		private String password = "hawaii50";