Explorar el Código

Add attributes Consumer to OAuth2AuthorizationContext

Fixes gh-7385
Joe Grandja hace 6 años
padre
commit
93cda94969

+ 7 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java

@@ -21,6 +21,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
 
 import java.util.Collections;
@@ -83,7 +84,12 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen
 		}
 		OAuth2AuthorizationContext authorizationContext = contextBuilder
 				.principal(principal)
-				.attributes(this.contextAttributesMapper.apply(authorizeRequest))
+				.attributes(attributes -> {
+					Map<String, Object> contextAttributes = this.contextAttributesMapper.apply(authorizeRequest);
+					if (!CollectionUtils.isEmpty(contextAttributes)) {
+						attributes.putAll(contextAttributes);
+					}
+				})
 				.build();
 
 		authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);

+ 9 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java

@@ -25,6 +25,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.Map;
+import java.util.function.Consumer;
 
 /**
  * A context that holds authorization-specific state and is used by an {@link OAuth2AuthorizedClientProvider}
@@ -161,13 +162,16 @@ public final class OAuth2AuthorizationContext {
 		}
 
 		/**
-		 * Sets the attributes associated to the context.
+		 * Provides a {@link Consumer} access to the attributes associated to the context.
 		 *
-		 * @param attributes the attributes associated to the context
-		 * @return the {@link Builder}
+		 * @param attributesConsumer a {@link Consumer} of the attributes associated to the context
+		 * @return the {@link OAuth2AuthorizeRequest.Builder}
 		 */
-		public Builder attributes(Map<String, Object> attributes) {
-			this.attributes = attributes;
+		public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) {
+			if (this.attributes == null) {
+				this.attributes = new HashMap<>();
+			}
+			attributesConsumer.accept(this.attributes);
 			return this;
 		}
 

+ 7 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java

@@ -26,6 +26,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
@@ -95,7 +96,12 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
 		}
 		OAuth2AuthorizationContext authorizationContext = contextBuilder
 				.principal(principal)
-				.attributes(this.contextAttributesMapper.apply(authorizeRequest))
+				.attributes(attributes -> {
+					Map<String, Object> contextAttributes = this.contextAttributesMapper.apply(authorizeRequest);
+					if (!CollectionUtils.isEmpty(contextAttributes)) {
+						attributes.putAll(contextAttributes);
+					}
+				})
 				.build();
 
 		authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);

+ 11 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java

@@ -26,6 +26,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
@@ -106,7 +107,11 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
 				.flatMap(this.contextAttributesMapper::apply)
 				.map(attrs -> OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient)
 						.principal(authorizeRequest.getPrincipal())
-						.attributes(attrs)
+						.attributes(attributes -> {
+							if (!CollectionUtils.isEmpty(attrs)) {
+								attributes.putAll(attrs);
+							}
+						})
 						.build());
 	}
 
@@ -116,7 +121,11 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
 				.flatMap(this.contextAttributesMapper::apply)
 				.map(attrs -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration)
 						.principal(authorizeRequest.getPrincipal())
-						.attributes(attrs)
+						.attributes(attributes -> {
+							if (!CollectionUtils.isEmpty(attrs)) {
+								attributes.putAll(attrs);
+							}
+						})
 						.build());
 	}
 

+ 4 - 2
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java

@@ -68,8 +68,10 @@ public class OAuth2AuthorizationContextTests {
 	public void withAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() {
 		OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient)
 				.principal(this.principal)
-				.attribute("attribute1", "value1")
-				.attribute("attribute2", "value2")
+				.attributes(attributes -> {
+					attributes.put("attribute1", "value1");
+					attributes.put("attribute2", "value2");
+				})
 				.build();
 		assertThat(authorizationContext.getClientRegistration()).isSameAs(this.clientRegistration);
 		assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);