Parcourir la source

Fix OAuth2AuthorizationRequest additionalParameters/attributes Consumer

Fixes gh-8177
Joe Grandja il y a 5 ans
Parent
commit
46baf38f59

+ 4 - 5
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java

@@ -27,7 +27,6 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization
 import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
 
-import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -71,8 +70,8 @@ public class OAuth2AuthorizationRequestMixinTests {
 				this.authorizationRequestBuilder
 						.scopes(null)
 						.state(null)
-						.additionalParameters(Collections.emptyMap())
-						.attributes(Collections.emptyMap())
+						.additionalParameters(Map::clear)
+						.attributes(Map::clear)
 						.build();
 		String expectedJson = asJson(authorizationRequest);
 		String json = this.mapper.writeValueAsString(authorizationRequest);
@@ -119,8 +118,8 @@ public class OAuth2AuthorizationRequestMixinTests {
 				this.authorizationRequestBuilder
 						.scopes(null)
 						.state(null)
-						.additionalParameters(Collections.emptyMap())
-						.attributes(Collections.emptyMap())
+						.additionalParameters(Map::clear)
+						.attributes(Map::clear)
 						.build();
 		String json = asJson(expectedAuthorizationRequest);
 		OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class);

+ 1 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java

@@ -437,6 +437,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
 		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
 		assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
 		assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
+		assertThat(authorizationRequest.getAttributes()).containsKey(OAuth2ParameterNames.REGISTRATION_ID);
 		assertThat(authorizationRequest.getAuthorizationRequestUri())
 				.matches("https://example.com/login/oauth/authorize\\?" +
 						"response_type=code&client_id=client-id&" +

+ 2 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java

@@ -29,6 +29,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
@@ -162,6 +163,7 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
 
 		assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
 		assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
+		assertThat(authorizationRequest.getAttributes()).containsKey(OAuth2ParameterNames.REGISTRATION_ID);
 		assertThat(authorizationRequest.getAuthorizationRequestUri())
 				.matches("https://example.com/login/oauth/authorize\\?" +
 						"response_type=code&client_id=client-id&" +

+ 2 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest.java

@@ -35,6 +35,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import reactor.core.publisher.Mono;
 
 import java.util.Collections;
+import java.util.Map;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -96,7 +97,7 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest {
 
 	@Test
 	public void applyWhenAttributesMissingThenOAuth2AuthorizationException() {
-		this.authorizationRequest.attributes(Collections.emptyMap());
+		this.authorizationRequest.attributes(Map::clear);
 		when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build()));
 
 		assertThatThrownBy(() -> applyConverter())

+ 12 - 18
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java

@@ -229,9 +229,9 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		private String redirectUri;
 		private Set<String> scopes;
 		private String state;
-		private Consumer<Map<String, Object>> additionalParametersConsumer = params -> {};
+		private Map<String, Object> additionalParameters = new LinkedHashMap<>();
 		private Consumer<Map<String, Object>> parametersConsumer = params -> {};
-		private Consumer<Map<String, Object>> attributesConsumer = attrs -> {};
+		private Map<String, Object> attributes = new LinkedHashMap<>();
 		private String authorizationRequestUri;
 		private Function<UriBuilder, URI> authorizationRequestUriFunction = builder -> builder.build();
 		private final DefaultUriBuilderFactory uriBuilderFactory;
@@ -325,8 +325,8 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		 * @return the {@link Builder}
 		 */
 		public Builder additionalParameters(Map<String, Object> additionalParameters) {
-			if (additionalParameters != null) {
-				return additionalParameters(params -> params.putAll(additionalParameters));
+			if (!CollectionUtils.isEmpty(additionalParameters)) {
+				this.additionalParameters.putAll(additionalParameters);
 			}
 			return this;
 		}
@@ -340,7 +340,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		 */
 		public Builder additionalParameters(Consumer<Map<String, Object>> additionalParametersConsumer) {
 			if (additionalParametersConsumer != null) {
-				this.additionalParametersConsumer = additionalParametersConsumer;
+				additionalParametersConsumer.accept(this.additionalParameters);
 			}
 			return this;
 		}
@@ -367,8 +367,8 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		 * @return the {@link Builder}
 		 */
 		public Builder attributes(Map<String, Object> attributes) {
-			if (attributes != null) {
-				return attributes(attrs -> attrs.putAll(attributes));
+			if (!CollectionUtils.isEmpty(attributes)) {
+				this.attributes.putAll(attributes);
 			}
 			return this;
 		}
@@ -382,7 +382,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		 */
 		public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) {
 			if (attributesConsumer != null) {
-				this.attributesConsumer = attributesConsumer;
+				attributesConsumer.accept(this.attributes);
 			}
 			return this;
 		}
@@ -439,12 +439,8 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 			authorizationRequest.scopes = Collections.unmodifiableSet(
 				CollectionUtils.isEmpty(this.scopes) ?
 					Collections.emptySet() : new LinkedHashSet<>(this.scopes));
-			Map<String, Object> additionalParameters = new LinkedHashMap<>();
-			this.additionalParametersConsumer.accept(additionalParameters);
-			authorizationRequest.additionalParameters = Collections.unmodifiableMap(additionalParameters);
-			Map<String, Object> attributes = new LinkedHashMap<>();
-			this.attributesConsumer.accept(attributes);
-			authorizationRequest.attributes = Collections.unmodifiableMap(attributes);
+			authorizationRequest.additionalParameters = Collections.unmodifiableMap(this.additionalParameters);
+			authorizationRequest.attributes = Collections.unmodifiableMap(this.attributes);
 			authorizationRequest.authorizationRequestUri =
 					StringUtils.hasText(this.authorizationRequestUri) ?
 							this.authorizationRequestUri : this.buildAuthorizationRequestUri();
@@ -457,7 +453,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 			this.parametersConsumer.accept(parameters);
 			MultiValueMap<String, String> queryParams = new LinkedMultiValueMap<>();
 			parameters.forEach((k, v) -> queryParams.set(
-					encodeQueryParam(k), encodeQueryParam(v.toString())));		// Encoded
+					encodeQueryParam(k), encodeQueryParam(String.valueOf(v))));		// Encoded
 			UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri)
 					.queryParams(queryParams);
 			return this.authorizationRequestUriFunction.apply(uriBuilder).toString();
@@ -477,9 +473,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 			if (this.redirectUri != null) {
 				parameters.put(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri);
 			}
-			Map<String, Object> additionalParameters = new LinkedHashMap<>();
-			this.additionalParametersConsumer.accept(additionalParameters);
-			additionalParameters.forEach((k, v) -> parameters.put(k, v.toString()));
+			parameters.putAll(this.additionalParameters);
 			return parameters;
 		}
 

+ 2 - 2
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java

@@ -121,7 +121,7 @@ public class OAuth2AuthorizationRequestTests {
 	}
 
 	@Test
-	public void buildWhenAdditionalParametersIsNullThenDoesNotThrowAnyException() {
+	public void buildWhenAdditionalParametersEmptyThenDoesNotThrowAnyException() {
 		assertThatCode(() ->
 				OAuth2AuthorizationRequest.authorizationCode()
 					.authorizationUri(AUTHORIZATION_URI)
@@ -129,7 +129,7 @@ public class OAuth2AuthorizationRequestTests {
 					.redirectUri(REDIRECT_URI)
 					.scopes(SCOPES)
 					.state(STATE)
-					.additionalParameters((Map) null)
+					.additionalParameters(Map::clear)
 					.build())
 				.doesNotThrowAnyException();
 	}