Browse Source

DigestAuthenticationFilter.securityContextRepository

Issue gh-10953
Rob Winch 3 years ago
parent
commit
ba7fb0cb14

+ 17 - 0
web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java

@@ -49,6 +49,8 @@ import org.springframework.security.core.userdetails.UserDetailsService;
 import org.springframework.security.core.userdetails.UsernameNotFoundException;
 import org.springframework.security.core.userdetails.cache.NullUserCache;
 import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 import org.springframework.web.filter.GenericFilterBean;
@@ -106,6 +108,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
 
 	private boolean createAuthenticatedToken = false;
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	@Override
 	public void afterPropertiesSet() {
 		Assert.notNull(this.userDetailsService, "A UserDetailsService is required");
@@ -192,6 +196,7 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
 		SecurityContext context = SecurityContextHolder.createEmptyContext();
 		context.setAuthentication(authentication);
 		SecurityContextHolder.setContext(context);
+		this.securityContextRepository.saveContext(context, request, response);
 		chain.doFilter(request, response);
 	}
 
@@ -271,6 +276,18 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
 		this.createAuthenticatedToken = createAuthenticatedToken;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	private class DigestData {
 
 		private final String username;

+ 24 - 0
web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java

@@ -29,6 +29,7 @@ import org.apache.commons.codec.digest.DigestUtils;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
@@ -40,10 +41,12 @@ import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.core.userdetails.UserDetailsService;
 import org.springframework.security.core.userdetails.cache.NullUserCache;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.util.StringUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -389,4 +392,25 @@ public class DigestAuthenticationFilterTests {
 		assertThat(existingAuthentication).isSameAs(existingContext.getAuthentication());
 	}
 
+	@Test
+	public void testSecurityContextRepository() throws Exception {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
+		String responseDigest = DigestAuthUtils.generateDigest(false, USERNAME, REALM, PASSWORD, "GET", REQUEST_URI,
+				QOP, NONCE, NC, CNONCE);
+		this.request.addHeader("Authorization",
+				createAuthorizationHeader(USERNAME, REALM, NONCE, REQUEST_URI, responseDigest, QOP, NC, CNONCE));
+		this.filter.setSecurityContextRepository(securityContextRepository);
+		this.filter.setCreateAuthenticatedToken(true);
+		MockHttpServletResponse response = executeFilterInContainerSimulator(this.filter, this.request, true);
+		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
+		assertThat(((UserDetails) SecurityContextHolder.getContext().getAuthentication().getPrincipal()).getUsername())
+				.isEqualTo(USERNAME);
+		assertThat(SecurityContextHolder.getContext().getAuthentication().isAuthenticated()).isTrue();
+		assertThat(SecurityContextHolder.getContext().getAuthentication().getAuthorities())
+				.isEqualTo(AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"));
+		verify(securityContextRepository).saveContext(contextArg.capture(), eq(this.request), eq(response));
+		assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo(USERNAME);
+	}
+
 }