Browse Source

Polish OAuth2ClientAuthenticationProviderTests

Joe Grandja 3 years ago
parent
commit
b455268fa1

+ 20 - 13
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java

@@ -28,6 +28,7 @@ import org.springframework.security.crypto.factory.PasswordEncoderFactories;
 import org.springframework.security.crypto.password.PasswordEncoder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
@@ -54,6 +55,7 @@ import org.springframework.util.StringUtils;
  * @see PasswordEncoder
  */
 public final class OAuth2ClientAuthenticationProvider implements AuthenticationProvider {
+	private static final String CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-01#section-3.2.1";
 	private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
 	private final RegisteredClientRepository registeredClientRepository;
 	private final OAuth2AuthorizationService authorizationService;
@@ -95,28 +97,28 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
 		String clientId = clientAuthentication.getPrincipal().toString();
 		RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
 		if (registeredClient == null) {
-			throwInvalidClient();
+			throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
 		}
 
 		if (!registeredClient.getClientAuthenticationMethods().contains(
 				clientAuthentication.getClientAuthenticationMethod())) {
-			throwInvalidClient();
+			throwInvalidClient("authentication_method");
 		}
 
-		boolean authenticatedCredentials = false;
+		boolean credentialsAuthenticated = false;
 
 		if (clientAuthentication.getCredentials() != null) {
 			String clientSecret = clientAuthentication.getCredentials().toString();
 			if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) {
-				throwInvalidClient();
+				throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET);
 			}
-			authenticatedCredentials = true;
+			credentialsAuthenticated = true;
 		}
 
 		boolean pkceAuthenticated = authenticatePkceIfAvailable(clientAuthentication, registeredClient);
-		authenticatedCredentials = authenticatedCredentials || pkceAuthenticated;
-		if (!authenticatedCredentials) {
-			throwInvalidClient();
+		credentialsAuthenticated = credentialsAuthenticated || pkceAuthenticated;
+		if (!credentialsAuthenticated) {
+			throwInvalidClient("credentials");
 		}
 
 		return new OAuth2ClientAuthenticationToken(registeredClient,
@@ -140,7 +142,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
 				(String) parameters.get(OAuth2ParameterNames.CODE),
 				AUTHORIZATION_CODE_TOKEN_TYPE);
 		if (authorization == null) {
-			throwInvalidClient();
+			throwInvalidClient(OAuth2ParameterNames.CODE);
 		}
 
 		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
@@ -150,7 +152,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
 				.get(PkceParameterNames.CODE_CHALLENGE);
 		if (!StringUtils.hasText(codeChallenge)) {
 			if (registeredClient.getClientSettings().isRequireProofKey()) {
-				throwInvalidClient();
+				throwInvalidClient(PkceParameterNames.CODE_CHALLENGE);
 			} else {
 				return false;
 			}
@@ -160,7 +162,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
 				.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
 		String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER);
 		if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
-			throwInvalidClient();
+			throwInvalidClient(PkceParameterNames.CODE_VERIFIER);
 		}
 
 		return true;
@@ -191,7 +193,12 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
 		throw new OAuth2AuthenticationException(OAuth2ErrorCodes.SERVER_ERROR);
 	}
 
-	private static void throwInvalidClient() {
-		throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
+	private static void throwInvalidClient(String parameterName) {
+		OAuth2Error error = new OAuth2Error(
+				OAuth2ErrorCodes.INVALID_CLIENT,
+				"Client authentication failed: " + parameterName,
+				CLIENT_AUTHENTICATION_ERROR_URI);
+		throw new OAuth2AuthenticationException(error);
 	}
+
 }

+ 3 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java

@@ -178,7 +178,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 		} else {
 			httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
 		}
-		this.errorHttpResponseConverter.write(error, null, httpResponse);
+		// We don't want to reveal too much information to the caller so just return the error code
+		OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode());
+		this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
 	}
 
 }

+ 36 - 18
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java

@@ -128,8 +128,10 @@ public class OAuth2ClientAuthenticationProviderTests {
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.extracting("errorCode")
-				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_ID);
+				});
 	}
 
 	@Test
@@ -143,8 +145,10 @@ public class OAuth2ClientAuthenticationProviderTests {
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.extracting("errorCode")
-				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_SECRET);
+				});
 		verify(this.passwordEncoder).matches(any(), any());
 	}
 
@@ -159,8 +163,10 @@ public class OAuth2ClientAuthenticationProviderTests {
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.extracting("errorCode")
-				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).contains("credentials");
+				});
 	}
 
 	@Test
@@ -222,8 +228,10 @@ public class OAuth2ClientAuthenticationProviderTests {
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.extracting("errorCode")
-				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).contains(OAuth2ParameterNames.CODE);
+				});
 	}
 
 	@Test
@@ -246,8 +254,10 @@ public class OAuth2ClientAuthenticationProviderTests {
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.extracting("errorCode")
-				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
+				});
 	}
 
 	@Test
@@ -270,8 +280,10 @@ public class OAuth2ClientAuthenticationProviderTests {
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.extracting("errorCode")
-				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
+				});
 	}
 
 	@Test
@@ -294,8 +306,10 @@ public class OAuth2ClientAuthenticationProviderTests {
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.extracting("errorCode")
-				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
+				});
 	}
 
 	@Test
@@ -318,8 +332,10 @@ public class OAuth2ClientAuthenticationProviderTests {
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.extracting("errorCode")
-				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).contains(PkceParameterNames.CODE_VERIFIER);
+				});
 	}
 
 	@Test
@@ -437,8 +453,10 @@ public class OAuth2ClientAuthenticationProviderTests {
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.extracting("errorCode")
-				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).contains("authentication_method");
+				});
 	}
 
 	private static Map<String, Object> createAuthorizationCodeTokenParameters() {