Pārlūkot izejas kodu

Add SecurityContextHolderStrategy to Saml2

Issue gh-11060
Josh Cummings 3 gadi atpakaļ
vecāks
revīzija
e90a11b1c0

+ 16 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilter.java

@@ -29,6 +29,7 @@ import org.springframework.core.log.LogMessage;
 import org.springframework.http.MediaType;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.saml2.core.Saml2ParameterNames;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
 import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
@@ -63,6 +64,9 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 
 	private final Log logger = LogFactory.getLog(getClass());
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	private final Saml2LogoutRequestValidator logoutRequestValidator;
 
 	private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
@@ -107,7 +111,7 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 			return;
 		}
 
-		Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+		Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
 		RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request,
 				getRegistrationId(authentication));
 		if (registration == null) {
@@ -167,6 +171,17 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 		this.logoutRequestMatcher = logoutRequestMatcher;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	private String getRegistrationId(Authentication authentication) {
 		if (authentication == null) {
 			return null;

+ 8 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,6 +25,8 @@ import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.saml2.core.Saml2Error;
 import org.springframework.security.saml2.core.Saml2ParameterNames;
 import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidator;
@@ -48,6 +50,8 @@ import static org.mockito.Mockito.verifyNoInteractions;
  */
 public class Saml2LogoutRequestFilterTests {
 
+	SecurityContextHolderStrategy securityContextHolderStrategy = mock(SecurityContextHolderStrategy.class);
+
 	RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = mock(RelyingPartyRegistrationResolver.class);
 
 	Saml2LogoutRequestValidator logoutRequestValidator = mock(Saml2LogoutRequestValidator.class);
@@ -94,6 +98,8 @@ public class Saml2LogoutRequestFilterTests {
 		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
 				.assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)).build();
 		Authentication authentication = new TestingAuthenticationToken("user", "password");
+		given(this.securityContextHolderStrategy.getContext()).willReturn(new SecurityContextImpl(authentication));
+		this.logoutRequestProcessingFilter.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
 		SecurityContextHolder.getContext().setAuthentication(authentication);
 		MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/saml2/slo");
 		request.setServletPath("/logout/saml2/slo");
@@ -114,6 +120,7 @@ public class Saml2LogoutRequestFilterTests {
 		assertThat(content).contains(
 				"<meta http-equiv=\"Content-Security-Policy\" content=\"script-src 'sha256-t+jmhLjs1ocvgaHBJsFcgznRk68d37TLtbI3NE9h7EU='\">");
 		assertThat(content).contains("<script>window.onload = () => document.forms[0].submit();</script>");
+		verify(this.securityContextHolderStrategy).getContext();
 	}
 
 	@Test