|
@@ -16,14 +16,20 @@
|
|
|
|
|
|
package org.springframework.security.oauth2.client.registration;
|
|
|
|
|
|
+import java.lang.reflect.Field;
|
|
|
+import java.lang.reflect.Modifier;
|
|
|
+import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
import java.util.LinkedHashMap;
|
|
|
+import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Set;
|
|
|
import java.util.stream.Collectors;
|
|
|
import java.util.stream.Stream;
|
|
|
|
|
|
import org.junit.jupiter.api.Test;
|
|
|
+import org.junit.jupiter.params.ParameterizedTest;
|
|
|
+import org.junit.jupiter.params.provider.MethodSource;
|
|
|
|
|
|
import org.springframework.security.oauth2.core.AuthenticationMethod;
|
|
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
|
@@ -31,6 +37,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
|
|
+import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
|
|
|
|
|
|
/**
|
|
|
* Tests for {@link ClientRegistration}.
|
|
@@ -776,4 +783,59 @@ public class ClientRegistrationTests {
|
|
|
assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isFalse();
|
|
|
}
|
|
|
|
|
|
+ // gh-16382
|
|
|
+ @Test
|
|
|
+ void buildWhenNewAuthorizationCodeAndPkceThenBuilds() {
|
|
|
+ ClientSettings pkceEnabled = ClientSettings.builder().requireProofKey(true).build();
|
|
|
+ ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
|
|
|
+ .clientId(CLIENT_ID)
|
|
|
+ .clientSettings(pkceEnabled)
|
|
|
+ .authorizationGrantType(new AuthorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()))
|
|
|
+ .redirectUri(REDIRECT_URI)
|
|
|
+ .authorizationUri(AUTHORIZATION_URI)
|
|
|
+ .tokenUri(TOKEN_URI)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ // proof key should be false for passivity
|
|
|
+ assertThat(clientRegistration.getClientSettings().isRequireProofKey()).isTrue();
|
|
|
+ }
|
|
|
+
|
|
|
+ @ParameterizedTest
|
|
|
+ @MethodSource("invalidPkceGrantTypes")
|
|
|
+ void buildWhenInvalidGrantTypeForPkceThenException(AuthorizationGrantType invalidGrantType) {
|
|
|
+ ClientSettings pkceEnabled = ClientSettings.builder().requireProofKey(true).build();
|
|
|
+ ClientRegistration.Builder builder = ClientRegistration.withRegistrationId(REGISTRATION_ID)
|
|
|
+ .clientId(CLIENT_ID)
|
|
|
+ .clientSettings(pkceEnabled)
|
|
|
+ .authorizationGrantType(invalidGrantType)
|
|
|
+ .redirectUri(REDIRECT_URI)
|
|
|
+ .authorizationUri(AUTHORIZATION_URI)
|
|
|
+ .tokenUri(TOKEN_URI);
|
|
|
+
|
|
|
+ assertThatIllegalStateException().describedAs(
|
|
|
+ "clientSettings.isRequireProofKey=true is only valid with authorizationGrantType=AUTHORIZATION_CODE. Got authorizationGrantType={}",
|
|
|
+ invalidGrantType)
|
|
|
+ .isThrownBy(builder::build);
|
|
|
+ }
|
|
|
+
|
|
|
+ static List<AuthorizationGrantType> invalidPkceGrantTypes() {
|
|
|
+ return Arrays.stream(AuthorizationGrantType.class.getFields())
|
|
|
+ .filter((field) -> Modifier.isFinal(field.getModifiers())
|
|
|
+ && field.getType() == AuthorizationGrantType.class)
|
|
|
+ .map((field) -> getStaticValue(field, AuthorizationGrantType.class))
|
|
|
+ .filter((grantType) -> grantType != AuthorizationGrantType.AUTHORIZATION_CODE)
|
|
|
+ // ensure works with .equals
|
|
|
+ .map((grantType) -> new AuthorizationGrantType(grantType.getValue()))
|
|
|
+ .collect(Collectors.toList());
|
|
|
+ }
|
|
|
+
|
|
|
+ private static <T> T getStaticValue(Field field, Class<T> clazz) {
|
|
|
+ try {
|
|
|
+ return (T) field.get(null);
|
|
|
+ }
|
|
|
+ catch (IllegalAccessException ex) {
|
|
|
+ throw new RuntimeException(ex);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
}
|