Forráskód Böngészése

Fix client_secret_basic authentication failures and return challenge

Closes gh-468
Joe Grandja 4 hónapja
szülő
commit
42c18c856f

+ 45 - 15
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2024 the original author or authors.
+ * Copyright 2020-2025 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
@@ -53,6 +54,7 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
+import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.OncePerRequestFilter;
@@ -90,6 +92,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 
 	private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
 
+	private final BasicAuthenticationEntryPoint basicAuthenticationEntryPoint = new BasicAuthenticationEntryPoint();
+
 	private AuthenticationConverter authenticationConverter;
 
 	private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess;
@@ -110,6 +114,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 		Assert.notNull(requestMatcher, "requestMatcher cannot be null");
 		this.authenticationManager = authenticationManager;
 		this.requestMatcher = requestMatcher;
+		this.basicAuthenticationEntryPoint.setRealmName("default");
 		// @formatter:off
 		this.authenticationConverter = new DelegatingAuthenticationConverter(
 				Arrays.asList(
@@ -130,8 +135,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 			return;
 		}
 
+		Authentication authenticationRequest = null;
 		try {
-			Authentication authenticationRequest = this.authenticationConverter.convert(request);
+			authenticationRequest = this.authenticationConverter.convert(request);
 			if (authenticationRequest instanceof AbstractAuthenticationToken) {
 				((AbstractAuthenticationToken) authenticationRequest)
 					.setDetails(this.authenticationDetailsSource.buildDetails(request));
@@ -148,7 +154,14 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 			if (this.logger.isTraceEnabled()) {
 				this.logger.trace(LogMessage.format("Client authentication failed: %s", ex.getError()), ex);
 			}
-			this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
+			if (authenticationRequest instanceof OAuth2ClientAuthenticationToken clientAuthentication) {
+				this.authenticationFailureHandler.onAuthenticationFailure(request, response,
+						new OAuth2ClientAuthenticationException(ex.getError(), ex, clientAuthentication));
+			}
+			else {
+				this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
+			}
+
 		}
 	}
 
@@ -200,21 +213,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 	}
 
 	private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
-			AuthenticationException exception) throws IOException {
+			AuthenticationException authenticationException) throws IOException {
 
 		SecurityContextHolder.clearContext();
 
-		// TODO
-		// The authorization server MAY return an HTTP 401 (Unauthorized) status code
-		// to indicate which HTTP authentication schemes are supported.
-		// If the client attempted to authenticate via the "Authorization" request header
-		// field,
-		// the authorization server MUST respond with an HTTP 401 (Unauthorized) status
-		// code and
-		// include the "WWW-Authenticate" response header field
-		// matching the authentication scheme used by the client.
-
-		OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
+		if (authenticationException instanceof OAuth2ClientAuthenticationException clientAuthenticationException) {
+			OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationException
+				.getClientAuthentication();
+			if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC
+				.equals(clientAuthentication.getClientAuthenticationMethod())) {
+				this.basicAuthenticationEntryPoint.commence(request, response, authenticationException);
+				return;
+			}
+		}
+
+		OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
 		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
 		if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
 			httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
@@ -249,4 +262,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 		}
 	}
 
+	private static final class OAuth2ClientAuthenticationException extends OAuth2AuthenticationException {
+
+		private final OAuth2ClientAuthenticationToken clientAuthentication;
+
+		private OAuth2ClientAuthenticationException(OAuth2Error error, Throwable cause,
+				OAuth2ClientAuthenticationToken clientAuthentication) {
+			super(error, cause);
+			Assert.notNull(clientAuthentication, "clientAuthentication cannot be null");
+			this.clientAuthentication = clientAuthentication;
+		}
+
+		private OAuth2ClientAuthenticationToken getClientAuthentication() {
+			return this.clientAuthentication;
+		}
+
+	}
+
 }

+ 5 - 5
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2024 the original author or authors.
+ * Copyright 2020-2025 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -538,7 +538,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 	}
 
 	@Test
-	public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenBadRequest() throws Exception {
+	public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenUnauthorized() throws Exception {
 		this.spring.register(AuthorizationServerConfiguration.class).autowire();
 
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@@ -569,7 +569,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
 				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
 				.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
-			.andExpect(status().isBadRequest());
+			.andExpect(status().isUnauthorized());
 	}
 
 	// gh-1011
@@ -601,7 +601,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 	}
 
 	@Test
-	public void requestWhenConfidentialClientWithPkceAndMissingCodeChallengeButCodeVerifierProvidedThenBadRequest()
+	public void requestWhenConfidentialClientWithPkceAndMissingCodeChallengeButCodeVerifierProvidedThenUnauthorized()
 			throws Exception {
 		this.spring.register(AuthorizationServerConfiguration.class).autowire();
 
@@ -631,7 +631,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
 				.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER)
 				.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
-			.andExpect(status().isBadRequest());
+			.andExpect(status().isUnauthorized());
 	}
 
 	@Test

+ 12 - 14
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2025 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.mockito.ArgumentCaptor;
 
+import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.converter.HttpMessageConverter;
@@ -175,26 +176,25 @@ public class OAuth2ClientAuthenticationFilterTests {
 
 	// gh-889
 	@Test
-	public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenInvalidRequestError()
-			throws Exception {
+	public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenReturnChallenge() throws Exception {
 		// Hex 00 -> null
 		String clientId = new String(Hex.decode("00"), StandardCharsets.UTF_8);
-		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
+		assertWhenInvalidClientIdThenReturnChallenge(clientId);
 
 		// Hex 0a61 -> line feed + a
 		clientId = new String(Hex.decode("0a61"), StandardCharsets.UTF_8);
-		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
+		assertWhenInvalidClientIdThenReturnChallenge(clientId);
 
 		// Hex 1b -> escape
 		clientId = new String(Hex.decode("1b"), StandardCharsets.UTF_8);
-		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
+		assertWhenInvalidClientIdThenReturnChallenge(clientId);
 
 		// Hex 1b61 -> escape + a
 		clientId = new String(Hex.decode("1b61"), StandardCharsets.UTF_8);
-		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
+		assertWhenInvalidClientIdThenReturnChallenge(clientId);
 	}
 
-	private void assertWhenInvalidClientIdThenInvalidRequestError(String clientId) throws Exception {
+	private void assertWhenInvalidClientIdThenReturnChallenge(String clientId) throws Exception {
 		given(this.authenticationConverter.convert(any(HttpServletRequest.class)))
 			.willReturn(new OAuth2ClientAuthenticationToken(clientId, ClientAuthenticationMethod.CLIENT_SECRET_BASIC,
 					"secret", null));
@@ -210,13 +210,12 @@ public class OAuth2ClientAuthenticationFilterTests {
 		verifyNoInteractions(this.authenticationManager);
 
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		OAuth2Error error = readError(response);
-		assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
+		assertThat(response.getHeader(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Basic realm=\"default\"");
 	}
 
 	@Test
-	public void doFilterWhenRequestMatchesAndBadCredentialsThenInvalidClientError() throws Exception {
+	public void doFilterWhenRequestMatchesAndBadCredentialsThenReturnChallenge() throws Exception {
 		given(this.authenticationConverter.convert(any(HttpServletRequest.class)))
 			.willReturn(new OAuth2ClientAuthenticationToken("clientId", ClientAuthenticationMethod.CLIENT_SECRET_BASIC,
 					"invalid-secret", null));
@@ -235,8 +234,7 @@ public class OAuth2ClientAuthenticationFilterTests {
 
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
-		OAuth2Error error = readError(response);
-		assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+		assertThat(response.getHeader(HttpHeaders.WWW_AUTHENTICATE)).isEqualTo("Basic realm=\"default\"");
 	}
 
 	@Test