Explorar o código

Revert "Fix client_secret_basic authentication failures and return challenge"

This reverts commit 42c18c856f168c24930ca480d2d26d84d455af47.
Joe Grandja hai 4 meses
pai
achega
c624d0a908

+ 15 - 45
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2025 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,7 +34,6 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
-import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
@@ -54,7 +53,6 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
-import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.OncePerRequestFilter;
@@ -92,8 +90,6 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 
 	private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
 
-	private final BasicAuthenticationEntryPoint basicAuthenticationEntryPoint = new BasicAuthenticationEntryPoint();
-
 	private AuthenticationConverter authenticationConverter;
 
 	private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess;
@@ -114,7 +110,6 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 		Assert.notNull(requestMatcher, "requestMatcher cannot be null");
 		this.authenticationManager = authenticationManager;
 		this.requestMatcher = requestMatcher;
-		this.basicAuthenticationEntryPoint.setRealmName("default");
 		// @formatter:off
 		this.authenticationConverter = new DelegatingAuthenticationConverter(
 				Arrays.asList(
@@ -135,9 +130,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 			return;
 		}
 
-		Authentication authenticationRequest = null;
 		try {
-			authenticationRequest = this.authenticationConverter.convert(request);
+			Authentication authenticationRequest = this.authenticationConverter.convert(request);
 			if (authenticationRequest instanceof AbstractAuthenticationToken) {
 				((AbstractAuthenticationToken) authenticationRequest)
 					.setDetails(this.authenticationDetailsSource.buildDetails(request));
@@ -154,14 +148,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 			if (this.logger.isTraceEnabled()) {
 				this.logger.trace(LogMessage.format("Client authentication failed: %s", ex.getError()), ex);
 			}
-			if (authenticationRequest instanceof OAuth2ClientAuthenticationToken clientAuthentication) {
-				this.authenticationFailureHandler.onAuthenticationFailure(request, response,
-						new OAuth2ClientAuthenticationException(ex.getError(), ex, clientAuthentication));
-			}
-			else {
-				this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
-			}
-
+			this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
 		}
 	}
 
@@ -213,21 +200,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 	}
 
 	private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
-			AuthenticationException authenticationException) throws IOException {
+			AuthenticationException exception) throws IOException {
 
 		SecurityContextHolder.clearContext();
 
-		if (authenticationException instanceof OAuth2ClientAuthenticationException clientAuthenticationException) {
-			OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationException
-				.getClientAuthentication();
-			if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC
-				.equals(clientAuthentication.getClientAuthenticationMethod())) {
-				this.basicAuthenticationEntryPoint.commence(request, response, authenticationException);
-				return;
-			}
-		}
-
-		OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
+		// TODO
+		// The authorization server MAY return an HTTP 401 (Unauthorized) status code
+		// to indicate which HTTP authentication schemes are supported.
+		// If the client attempted to authenticate via the "Authorization" request header
+		// field,
+		// the authorization server MUST respond with an HTTP 401 (Unauthorized) status
+		// code and
+		// include the "WWW-Authenticate" response header field
+		// matching the authentication scheme used by the client.
+
+		OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
 		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
 		if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
 			httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
@@ -262,21 +249,4 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 		}
 	}
 
-	private static final class OAuth2ClientAuthenticationException extends OAuth2AuthenticationException {
-
-		private final OAuth2ClientAuthenticationToken clientAuthentication;
-
-		private OAuth2ClientAuthenticationException(OAuth2Error error, Throwable cause,
-				OAuth2ClientAuthenticationToken clientAuthentication) {
-			super(error, cause);
-			Assert.notNull(clientAuthentication, "clientAuthentication cannot be null");
-			this.clientAuthentication = clientAuthentication;
-		}
-
-		private OAuth2ClientAuthenticationToken getClientAuthentication() {
-			return this.clientAuthentication;
-		}
-
-	}
-
 }

+ 5 - 5
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2025 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -538,7 +538,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 	}
 
 	@Test
-	public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenUnauthorized() throws Exception {
+	public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenBadRequest() throws Exception {
 		this.spring.register(AuthorizationServerConfiguration.class).autowire();
 
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@@ -569,7 +569,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
 				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
 				.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
-			.andExpect(status().isUnauthorized());
+			.andExpect(status().isBadRequest());
 	}
 
 	// gh-1011
@@ -601,7 +601,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 	}
 
 	@Test
-	public void requestWhenConfidentialClientWithPkceAndMissingCodeChallengeButCodeVerifierProvidedThenUnauthorized()
+	public void requestWhenConfidentialClientWithPkceAndMissingCodeChallengeButCodeVerifierProvidedThenBadRequest()
 			throws Exception {
 		this.spring.register(AuthorizationServerConfiguration.class).autowire();
 
@@ -631,7 +631,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
 				.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER)
 				.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
-			.andExpect(status().isUnauthorized());
+			.andExpect(status().isBadRequest());
 	}
 
 	@Test

+ 14 - 12
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2025 the original author or authors.
+ * Copyright 2020-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,7 +26,6 @@ import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.mockito.ArgumentCaptor;
 
-import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.converter.HttpMessageConverter;
@@ -176,25 +175,26 @@ public class OAuth2ClientAuthenticationFilterTests {
 
 	// gh-889
 	@Test
-	public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenReturnChallenge() throws Exception {
+	public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenInvalidRequestError()
+			throws Exception {
 		// Hex 00 -> null
 		String clientId = new String(Hex.decode("00"), StandardCharsets.UTF_8);
-		assertWhenInvalidClientIdThenReturnChallenge(clientId);
+		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
 
 		// Hex 0a61 -> line feed + a
 		clientId = new String(Hex.decode("0a61"), StandardCharsets.UTF_8);
-		assertWhenInvalidClientIdThenReturnChallenge(clientId);
+		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
 
 		// Hex 1b -> escape
 		clientId = new String(Hex.decode("1b"), StandardCharsets.UTF_8);
-		assertWhenInvalidClientIdThenReturnChallenge(clientId);
+		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
 
 		// Hex 1b61 -> escape + a
 		clientId = new String(Hex.decode("1b61"), StandardCharsets.UTF_8);
-		assertWhenInvalidClientIdThenReturnChallenge(clientId);
+		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
 	}
 
-	private void assertWhenInvalidClientIdThenReturnChallenge(String clientId) throws Exception {
+	private void assertWhenInvalidClientIdThenInvalidRequestError(String clientId) throws Exception {
 		given(this.authenticationConverter.convert(any(HttpServletRequest.class)))
 			.willReturn(new OAuth2ClientAuthenticationToken(clientId, ClientAuthenticationMethod.CLIENT_SECRET_BASIC,
 					"secret", null));
@@ -210,12 +210,13 @@ public class OAuth2ClientAuthenticationFilterTests {
 		verifyNoInteractions(this.authenticationManager);
 
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
-		assertThat(response.getHeader(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Basic realm=\"default\"");
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		OAuth2Error error = readError(response);
+		assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
 	}
 
 	@Test
-	public void doFilterWhenRequestMatchesAndBadCredentialsThenReturnChallenge() throws Exception {
+	public void doFilterWhenRequestMatchesAndBadCredentialsThenInvalidClientError() throws Exception {
 		given(this.authenticationConverter.convert(any(HttpServletRequest.class)))
 			.willReturn(new OAuth2ClientAuthenticationToken("clientId", ClientAuthenticationMethod.CLIENT_SECRET_BASIC,
 					"invalid-secret", null));
@@ -234,7 +235,8 @@ public class OAuth2ClientAuthenticationFilterTests {
 
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
-		assertThat(response.getHeader(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Basic realm=\"default\"");
+		OAuth2Error error = readError(response);
+		assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
 	}
 
 	@Test