Pārlūkot izejas kodu

Allow customize the AuthenticationConverter in BasicAuthenticationFilter

Closes gh-13988
Marcus Da Coregio 1 gadu atpakaļ
vecāks
revīzija
7e9d707c7d

+ 38 - 5
web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java

@@ -28,15 +28,16 @@ import org.springframework.core.log.LogMessage;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManager;
-import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.AuthenticationEntryPoint;
+import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.NullRememberMeServices;
 import org.springframework.security.web.authentication.NullRememberMeServices;
 import org.springframework.security.web.authentication.RememberMeServices;
 import org.springframework.security.web.authentication.RememberMeServices;
+import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
 import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
 import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
@@ -105,7 +106,7 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
 
 
 	private String credentialsCharset = "UTF-8";
 	private String credentialsCharset = "UTF-8";
 
 
-	private BasicAuthenticationConverter authenticationConverter = new BasicAuthenticationConverter();
+	private AuthenticationConverter authenticationConverter = new BasicAuthenticationConverter();
 
 
 	private SecurityContextRepository securityContextRepository = new RequestAttributeSecurityContextRepository();
 	private SecurityContextRepository securityContextRepository = new RequestAttributeSecurityContextRepository();
 
 
@@ -149,6 +150,18 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
 		this.securityContextRepository = securityContextRepository;
 		this.securityContextRepository = securityContextRepository;
 	}
 	}
 
 
+	/**
+	 * Sets the
+	 * {@link org.springframework.security.web.authentication.AuthenticationConverter} to
+	 * use. Defaults to {@link BasicAuthenticationConverter}
+	 * @param authenticationConverter the converter to use
+	 * @since 6.2
+	 */
+	public void setAuthenticationConverter(AuthenticationConverter authenticationConverter) {
+		Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
+		this.authenticationConverter = authenticationConverter;
+	}
+
 	@Override
 	@Override
 	public void afterPropertiesSet() {
 	public void afterPropertiesSet() {
 		Assert.notNull(this.authenticationManager, "An AuthenticationManager is required");
 		Assert.notNull(this.authenticationManager, "An AuthenticationManager is required");
@@ -161,7 +174,7 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
 			throws IOException, ServletException {
 			throws IOException, ServletException {
 		try {
 		try {
-			UsernamePasswordAuthenticationToken authRequest = this.authenticationConverter.convert(request);
+			Authentication authRequest = this.authenticationConverter.convert(request);
 			if (authRequest == null) {
 			if (authRequest == null) {
 				this.logger.trace("Did not process authentication request since failed to find "
 				this.logger.trace("Did not process authentication request since failed to find "
 						+ "username and password in Basic Authorization header");
 						+ "username and password in Basic Authorization header");
@@ -250,9 +263,19 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
 		this.securityContextHolderStrategy = securityContextHolderStrategy;
 		this.securityContextHolderStrategy = securityContextHolderStrategy;
 	}
 	}
 
 
+	/**
+	 * Sets the {@link AuthenticationDetailsSource} to use. By default, it is set to use
+	 * the {@link WebAuthenticationDetailsSource}. Note that this configuration applies
+	 * exclusively when the {@link #authenticationConverter} is set to
+	 * {@link BasicAuthenticationConverter}. If you are utilizing a different
+	 * implementation, you will need to manually specify the authentication details on it.
+	 * @param authenticationDetailsSource the {@link AuthenticationDetailsSource} to use.
+	 */
 	public void setAuthenticationDetailsSource(
 	public void setAuthenticationDetailsSource(
 			AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) {
 			AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) {
-		this.authenticationConverter.setAuthenticationDetailsSource(authenticationDetailsSource);
+		if (this.authenticationConverter instanceof BasicAuthenticationConverter basicAuthenticationConverter) {
+			basicAuthenticationConverter.setAuthenticationDetailsSource(authenticationDetailsSource);
+		}
 	}
 	}
 
 
 	public void setRememberMeServices(RememberMeServices rememberMeServices) {
 	public void setRememberMeServices(RememberMeServices rememberMeServices) {
@@ -260,10 +283,20 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
 		this.rememberMeServices = rememberMeServices;
 		this.rememberMeServices = rememberMeServices;
 	}
 	}
 
 
+	/**
+	 * Sets the charset to use when decoding credentials to {@link String}s. By default,
+	 * it is set to {@code UTF-8}. Note that this configuration applies exclusively when
+	 * the {@link #authenticationConverter} is set to
+	 * {@link BasicAuthenticationConverter}. If you are utilizing a different
+	 * implementation, you will need to manually specify the charset on it.
+	 * @param credentialsCharset the charset to use.
+	 */
 	public void setCredentialsCharset(String credentialsCharset) {
 	public void setCredentialsCharset(String credentialsCharset) {
 		Assert.hasText(credentialsCharset, "credentialsCharset cannot be null or empty");
 		Assert.hasText(credentialsCharset, "credentialsCharset cannot be null or empty");
 		this.credentialsCharset = credentialsCharset;
 		this.credentialsCharset = credentialsCharset;
-		this.authenticationConverter.setCredentialsCharset(Charset.forName(credentialsCharset));
+		if (this.authenticationConverter instanceof BasicAuthenticationConverter basicAuthenticationConverter) {
+			basicAuthenticationConverter.setCredentialsCharset(Charset.forName(credentialsCharset));
+		}
 	}
 	}
 
 
 	protected String getCredentialsCharset(HttpServletRequest httpRequest) {
 	protected String getCredentialsCharset(HttpServletRequest httpRequest) {

+ 57 - 0
web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java

@@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets;
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.ServletRequest;
 import jakarta.servlet.ServletRequest;
 import jakarta.servlet.ServletResponse;
 import jakarta.servlet.ServletResponse;
+import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
 import jakarta.servlet.http.HttpServletResponse;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.BeforeEach;
@@ -41,9 +42,12 @@ import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.test.web.CodecTestUtils;
 import org.springframework.security.test.web.CodecTestUtils;
+import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.WebAuthenticationDetails;
 import org.springframework.security.web.authentication.WebAuthenticationDetails;
 import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
 import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextRepository;
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
+import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.web.util.WebUtils;
 import org.springframework.web.util.WebUtils;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
@@ -488,4 +492,57 @@ public class BasicAuthenticationFilterTests {
 		assertThat(authenticationRequest.getName()).isEqualTo("rod");
 		assertThat(authenticationRequest.getName()).isEqualTo("rod");
 	}
 	}
 
 
+	@Test
+	public void doFilterWhenCustomAuthenticationConverterThatIgnoresRequestThenIgnores() throws Exception {
+		this.filter.setAuthenticationConverter(new TestAuthenticationConverter());
+		String token = "rod:koala";
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token));
+		request.setServletPath("/ignored");
+		FilterChain filterChain = mock(FilterChain.class);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		this.filter.doFilter(request, response, filterChain);
+		assertThat(response.getStatus()).isEqualTo(200);
+
+		verify(this.manager, never()).authenticate(any(Authentication.class));
+		verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
+		verifyNoMoreInteractions(this.manager, filterChain);
+	}
+
+	@Test
+	public void doFilterWhenCustomAuthenticationConverterRequestThenAuthenticate() throws Exception {
+		this.filter.setAuthenticationConverter(new TestAuthenticationConverter());
+		String token = "rod:koala";
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addHeader("Authorization", "Basic " + CodecTestUtils.encodeBase64(token));
+		request.setServletPath("/ok");
+		FilterChain filterChain = mock(FilterChain.class);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		this.filter.doFilter(request, response, filterChain);
+		assertThat(response.getStatus()).isEqualTo(200);
+		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
+		assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("rod");
+	}
+
+	@Test
+	public void setAuthenticationConverterWhenNullThenException() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationConverter(null));
+	}
+
+	static class TestAuthenticationConverter implements AuthenticationConverter {
+
+		private final RequestMatcher matcher = AntPathRequestMatcher.antMatcher("/ignored");
+
+		private final BasicAuthenticationConverter delegate = new BasicAuthenticationConverter();
+
+		@Override
+		public Authentication convert(HttpServletRequest request) {
+			if (this.matcher.matches(request)) {
+				return null;
+			}
+			return this.delegate.convert(request);
+		}
+
+	}
+
 }
 }