Browse Source

Use SecurityContextHolderStrategy for Digest

Issue gh-11060
Josh Cummings 3 years ago
parent
commit
5086409dcf

+ 19 - 4
web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java

@@ -42,6 +42,7 @@ import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.SpringSecurityMessageSource;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.core.userdetails.UserCache;
 import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.core.userdetails.UserDetailsService;
@@ -93,6 +94,9 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
 
 	private static final Log logger = LogFactory.getLog(DigestAuthenticationFilter.class);
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
 
 	private DigestAuthenticationEntryPoint authenticationEntryPoint;
@@ -192,9 +196,9 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
 		logger.debug(LogMessage.format("Authentication success for user: '%s' with response: '%s'",
 				digestAuth.getUsername(), digestAuth.getResponse()));
 		Authentication authentication = createSuccessfulAuthentication(request, user);
-		SecurityContext context = SecurityContextHolder.createEmptyContext();
+		SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
 		context.setAuthentication(authentication);
-		SecurityContextHolder.setContext(context);
+		this.securityContextHolderStrategy.setContext(context);
 		this.securityContextRepository.saveContext(context, request, response);
 		chain.doFilter(request, response);
 	}
@@ -214,8 +218,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
 
 	private void fail(HttpServletRequest request, HttpServletResponse response, AuthenticationException failed)
 			throws IOException, ServletException {
-		SecurityContext context = SecurityContextHolder.createEmptyContext();
-		SecurityContextHolder.setContext(context);
+		SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
+		this.securityContextHolderStrategy.setContext(context);
 		logger.debug(failed);
 		this.authenticationEntryPoint.commence(request, response, failed);
 	}
@@ -287,6 +291,17 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
 		this.securityContextRepository = securityContextRepository;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	private class DigestData {
 
 		private final String username;

+ 47 - 0
web/src/test/java/org/springframework/security/MockSecurityContextHolderStrategy.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security;
+
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.core.context.SecurityContextImpl;
+
+public class MockSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
+
+	private SecurityContext mock;
+
+	@Override
+	public void clearContext() {
+		this.mock = null;
+	}
+
+	@Override
+	public SecurityContext getContext() {
+		return this.mock;
+	}
+
+	@Override
+	public void setContext(SecurityContext context) {
+		this.mock = context;
+	}
+
+	@Override
+	public SecurityContext createEmptyContext() {
+		return new SecurityContextImpl();
+	}
+
+}

+ 16 - 0
web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java

@@ -30,10 +30,12 @@ import org.mockito.ArgumentCaptor;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.MockSecurityContextHolderStrategy;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.core.userdetails.UserDetailsService;
@@ -44,8 +46,10 @@ import org.springframework.util.StringUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
@@ -306,6 +310,18 @@ public class DigestAuthenticationFilterTests {
 		assertThatIllegalArgumentException().isThrownBy(filter::afterPropertiesSet);
 	}
 
+	@Test
+	public void authenticateUsesCustomSecurityContextHolderStrategy() throws Exception {
+		SecurityContextHolderStrategy securityContextHolderStrategy = spy(new MockSecurityContextHolderStrategy());
+		String responseDigest = DigestAuthUtils.generateDigest(false, USERNAME, REALM, PASSWORD, "GET", REQUEST_URI,
+				QOP, NONCE, NC, CNONCE);
+		this.request.addHeader("Authorization",
+				createAuthorizationHeader(USERNAME, REALM, NONCE, REQUEST_URI, responseDigest, QOP, NC, CNONCE));
+		this.filter.setSecurityContextHolderStrategy(securityContextHolderStrategy);
+		executeFilterInContainerSimulator(this.filter, this.request, true);
+		verify(securityContextHolderStrategy).setContext(any());
+	}
+
 	@Test
 	public void successfulLoginThenFailedLoginResultsInSessionLosingToken() throws Exception {
 		String responseDigest = DigestAuthUtils.generateDigest(false, USERNAME, REALM, PASSWORD, "GET", REQUEST_URI,