2
0
Эх сурвалжийг харах

PKCE cannot be true and AuthorizationGrantType != AUTHORIZATION_CODE

PKCE is only valid for AuthorizationGrantType.AUTHORIZATION_CODE so the
code should validate this.

Issue gh-16382
Rob Winch 7 сар өмнө
parent
commit
f9498d3885

+ 6 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java

@@ -711,6 +711,12 @@ public final class ClientRegistration implements Serializable {
 							"AuthorizationGrantType: %s does not match the pre-defined constant %s and won't match a valid OAuth2AuthorizedClientProvider",
 							this.authorizationGrantType, authorizationGrantType));
 				}
+				if (!AuthorizationGrantType.AUTHORIZATION_CODE.equals(this.authorizationGrantType)
+						&& this.clientSettings.isRequireProofKey()) {
+					throw new IllegalStateException(
+							"clientSettings.isRequireProofKey=true is only valid with authorizationGrantType=AUTHORIZATION_CODE. Got authorizationGrantType="
+									+ this.authorizationGrantType);
+				}
 			}
 		}
 

+ 62 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java

@@ -16,14 +16,20 @@
 
 package org.springframework.security.oauth2.client.registration;
 
+import java.lang.reflect.Field;
+import java.lang.reflect.Modifier;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
 
 import org.springframework.security.oauth2.core.AuthenticationMethod;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
@@ -31,6 +37,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
 
 /**
  * Tests for {@link ClientRegistration}.
@@ -776,4 +783,59 @@ public class ClientRegistrationTests {
 		assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isFalse();
 	}
 
+	// gh-16382
+	@Test
+	void buildWhenNewAuthorizationCodeAndPkceThenBuilds() {
+		ClientSettings pkceEnabled = ClientSettings.builder().requireProofKey(true).build();
+		ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSettings(pkceEnabled)
+			.authorizationGrantType(new AuthorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()))
+			.redirectUri(REDIRECT_URI)
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.build();
+
+		// proof key should be false for passivity
+		assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isTrue();
+	}
+
+	@ParameterizedTest
+	@MethodSource("invalidPkceGrantTypes")
+	void buildWhenInvalidGrantTypeForPkceThenException(AuthorizationGrantType invalidGrantType) {
+		ClientSettings pkceEnabled = ClientSettings.builder().requireProofKey(true).build();
+		ClientRegistration.Builder builder = ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSettings(pkceEnabled)
+			.authorizationGrantType(invalidGrantType)
+			.redirectUri(REDIRECT_URI)
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI);
+
+		assertThatIllegalStateException().describedAs(
+				"clientSettings.isRequireProofKey=true is only valid with authorizationGrantType=AUTHORIZATION_CODE. Got authorizationGrantType={}",
+				invalidGrantType)
+			.isThrownBy(builder::build);
+	}
+
+	static List<AuthorizationGrantType> invalidPkceGrantTypes() {
+		return Arrays.stream(AuthorizationGrantType.class.getFields())
+			.filter((field) -> Modifier.isFinal(field.getModifiers())
+					&& field.getType() == AuthorizationGrantType.class)
+			.map((field) -> getStaticValue(field, AuthorizationGrantType.class))
+			.filter((grantType) -> grantType != AuthorizationGrantType.AUTHORIZATION_CODE)
+			// ensure works with .equals
+			.map((grantType) -> new AuthorizationGrantType(grantType.getValue()))
+			.collect(Collectors.toList());
+	}
+
+	private static <T> T getStaticValue(Field field, Class<T> clazz) {
+		try {
+			return (T) field.get(null);
+		}
+		catch (IllegalAccessException ex) {
+			throw new RuntimeException(ex);
+		}
+	}
+
 }