ソースを参照

Polish OAuth2ClientAuthenticationProvider

Commit 5c31fb1b7e7a0efbb60cb7aa34762ad5577eba45
Joe Grandja 5 年 前
コミット
7720e275e4

+ 21 - 11
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java

@@ -82,15 +82,22 @@ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvide
 			throwInvalidClient();
 		}
 
+		boolean authenticatedCredentials = false;
+
 		if (clientAuthentication.getCredentials() != null) {
 			String clientSecret = clientAuthentication.getCredentials().toString();
 			// TODO Use PasswordEncoder.matches()
 			if (!registeredClient.getClientSecret().equals(clientSecret)) {
 				throwInvalidClient();
 			}
+			authenticatedCredentials = true;
 		}
 
-		authenticatePkceIfAvailable(clientAuthentication, registeredClient);
+		authenticatedCredentials = authenticatedCredentials ||
+				authenticatePkceIfAvailable(clientAuthentication, registeredClient);
+		if (!authenticatedCredentials) {
+			throwInvalidClient();
+		}
 
 		return new OAuth2ClientAuthenticationToken(registeredClient);
 	}
@@ -100,12 +107,12 @@ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvide
 		return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
 	}
 
-	private void authenticatePkceIfAvailable(OAuth2ClientAuthenticationToken clientAuthentication,
+	private boolean authenticatePkceIfAvailable(OAuth2ClientAuthenticationToken clientAuthentication,
 			RegisteredClient registeredClient) {
 
 		Map<String, Object> parameters = clientAuthentication.getAdditionalParameters();
 		if (CollectionUtils.isEmpty(parameters) || !authorizationCodeGrant(parameters)) {
-			return;
+			return false;
 		}
 
 		OAuth2Authorization authorization = this.authorizationService.findByToken(
@@ -120,16 +127,19 @@ public class OAuth2ClientAuthenticationProvider implements AuthenticationProvide
 
 		String codeChallenge = (String) authorizationRequest.getAdditionalParameters()
 				.get(PkceParameterNames.CODE_CHALLENGE);
-		if (StringUtils.hasText(codeChallenge)) {
-			String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters()
-					.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
-			String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER);
-			if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
-				throwInvalidClient();
-			}
-		} else if (registeredClient.getClientSettings().requireProofKey()) {
+		if (!StringUtils.hasText(codeChallenge) &&
+				registeredClient.getClientSettings().requireProofKey()) {
 			throwInvalidClient();
 		}
+
+		String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters()
+				.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
+		String codeVerifier = (String) parameters.get(PkceParameterNames.CODE_VERIFIER);
+		if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
+			throwInvalidClient();
+		}
+
+		return true;
 	}
 
 	private static boolean authorizationCodeGrant(Map<String, Object> parameters) {

+ 12 - 13
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java

@@ -37,7 +37,6 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
 
 /**
@@ -120,23 +119,22 @@ public class OAuth2ClientAuthenticationProviderTests {
 	}
 
 	@Test
-	public void authenticateWhenValidCredentialsThenAuthenticated() {
+	public void authenticateWhenClientSecretNotProvidedThenThrowOAuth2AuthenticationException() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
 				.thenReturn(registeredClient);
 
-		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
-				registeredClient.getClientId(), registeredClient.getClientSecret(), null);
-		OAuth2ClientAuthenticationToken authenticationResult =
-				(OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication);
-		assertThat(authenticationResult.isAuthenticated()).isTrue();
-		assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId());
-		assertThat(authenticationResult.getCredentials()).isNull();
-		assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient);
+		OAuth2ClientAuthenticationToken authentication =
+				new OAuth2ClientAuthenticationToken(registeredClient.getClientId(), null);
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
 	}
 
 	@Test
-	public void authenticateWhenNotPkceThenContinueAuthenticated() {
+	public void authenticateWhenValidCredentialsThenAuthenticated() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
 				.thenReturn(registeredClient);
@@ -146,8 +144,9 @@ public class OAuth2ClientAuthenticationProviderTests {
 		OAuth2ClientAuthenticationToken authenticationResult =
 				(OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication);
 		assertThat(authenticationResult.isAuthenticated()).isTrue();
-
-		verifyNoInteractions(this.authorizationService);
+		assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId());
+		assertThat(authenticationResult.getCredentials()).isNull();
+		assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient);
 	}
 
 	@Test