فهرست منبع

Simplify customizing OAuth2AuthorizationRequest

Fixes gh-7696
Joe Grandja 5 سال پیش
والد
کامیت
23ce717380

+ 20 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -41,6 +41,7 @@ import java.security.NoSuchAlgorithmException;
 import java.util.Base64;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.function.Consumer;
 
 /**
  * An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
@@ -66,6 +67,7 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
 	private final AntPathRequestMatcher authorizationRequestMatcher;
 	private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
 	private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
+	private Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer = customizer -> {};
 
 	/**
 	 * Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided parameters.
@@ -98,6 +100,18 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
 		return resolve(request, registrationId, redirectUriAction);
 	}
 
+	/**
+	 * Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
+	 * allowing for further customizations.
+	 *
+	 * @since 5.3
+	 * @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
+	 */
+	public void setAuthorizationRequestCustomizer(Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer) {
+		Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null");
+		this.authorizationRequestCustomizer = authorizationRequestCustomizer;
+	}
+
 	private String getAction(HttpServletRequest request, String defaultAction) {
 		String action = request.getParameter("action");
 		if (action == null) {
@@ -144,16 +158,17 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
 
 		String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction);
 
-		OAuth2AuthorizationRequest authorizationRequest = builder
+		builder
 				.clientId(clientRegistration.getClientId())
 				.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
 				.redirectUri(redirectUriStr)
 				.scopes(clientRegistration.getScopes())
 				.state(this.stateGenerator.generateKey())
-				.attributes(attributes)
-				.build();
+				.attributes(attributes);
+
+		this.authorizationRequestCustomizer.accept(builder);
 
-		return authorizationRequest;
+		return builder.build();
 	}
 
 	private String resolveRegistrationId(HttpServletRequest request) {

+ 24 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -46,6 +46,7 @@ import java.security.NoSuchAlgorithmException;
 import java.util.Base64;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.function.Consumer;
 
 /**
  * The default implementation of {@link ServerOAuth2AuthorizationRequestResolver}.
@@ -81,6 +82,8 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
 
 	private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
 
+	private Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer = customizer -> {};
+
 	/**
 	 * Creates a new instance
 	 * @param clientRegistrationRepository the repository to resolve the {@link ClientRegistration}
@@ -121,6 +124,18 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
 			.map(clientRegistration -> authorizationRequest(exchange, clientRegistration));
 	}
 
+	/**
+	 * Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
+	 * allowing for further customizations.
+	 *
+	 * @since 5.3
+	 * @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
+	 */
+	public final void setAuthorizationRequestCustomizer(Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer) {
+		Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null");
+		this.authorizationRequestCustomizer = authorizationRequestCustomizer;
+	}
+
 	private Mono<ClientRegistration> findByRegistrationId(ServerWebExchange exchange, String clientRegistration) {
 		return this.clientRegistrationRepository.findByRegistrationId(clientRegistration)
 				.switchIfEmpty(Mono.error(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid client registration id")));
@@ -155,13 +170,17 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
 					"Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue()
 							+ ") for Client Registration with Id: " + clientRegistration.getRegistrationId());
 		}
-		return builder
+		builder
 				.clientId(clientRegistration.getClientId())
 				.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
-				.redirectUri(redirectUriStr).scopes(clientRegistration.getScopes())
+				.redirectUri(redirectUriStr)
+				.scopes(clientRegistration.getScopes())
 				.state(this.stateGenerator.generateKey())
-				.attributes(attributes)
-				.build();
+				.attributes(attributes);
+
+		this.authorizationRequestCustomizer.accept(builder);
+
+		return builder.build();
 	}
 
 	/**

+ 5 - 4
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java

@@ -27,6 +27,7 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization
 import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
 
+import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -70,8 +71,8 @@ public class OAuth2AuthorizationRequestMixinTests {
 				this.authorizationRequestBuilder
 						.scopes(null)
 						.state(null)
-						.additionalParameters(null)
-						.attributes(null)
+						.additionalParameters(Collections.emptyMap())
+						.attributes(Collections.emptyMap())
 						.build();
 		String expectedJson = asJson(authorizationRequest);
 		String json = this.mapper.writeValueAsString(authorizationRequest);
@@ -118,8 +119,8 @@ public class OAuth2AuthorizationRequestMixinTests {
 				this.authorizationRequestBuilder
 						.scopes(null)
 						.state(null)
-						.additionalParameters(null)
-						.attributes(null)
+						.additionalParameters(Collections.emptyMap())
+						.attributes(Collections.emptyMap())
 						.build();
 		String json = asJson(expectedAuthorizationRequest);
 		OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class);

+ 80 - 2
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -31,7 +31,9 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
 
-import static org.assertj.core.api.Assertions.*;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.Assertions.entry;
 
 /**
  * Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
@@ -81,6 +83,12 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
+	@Test
+	public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
 	@Test
 	public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() {
 		String requestUri = "/path";
@@ -414,6 +422,76 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
 						"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
 	}
 
+	// gh-7696
+	@Test
+	public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() {
+		ClientRegistration clientRegistration = this.oidcRegistration;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+
+		this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer
+				.additionalParameters(params -> params.remove(OidcParameterNames.NONCE))
+				.attributes(attrs -> attrs.remove(OidcParameterNames.NONCE)));
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
+		assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
+		assertThat(authorizationRequest.getAuthorizationRequestUri())
+				.matches("https://example.com/login/oauth/authorize\\?" +
+						"response_type=code&client_id=client-id&" +
+						"scope=openid&state=.{15,}&" +
+						"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id");
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() {
+		ClientRegistration clientRegistration = this.oidcRegistration;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+
+		this.resolver.setAuthorizationRequestCustomizer(customizer ->
+				customizer.authorizationRequestUri(uriBuilder -> {
+					uriBuilder.queryParam("param1", "value1");
+					return uriBuilder.build();
+				})
+		);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getAuthorizationRequestUri())
+				.matches("https://example.com/login/oauth/authorize\\?" +
+						"response_type=code&client_id=client-id&" +
+						"scope=openid&state=.{15,}&" +
+						"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" +
+						"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
+						"param1=value1");
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() {
+		ClientRegistration clientRegistration = this.oidcRegistration;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+
+		this.resolver.setAuthorizationRequestCustomizer(customizer ->
+				customizer.parameters(params -> {
+					params.put("appid", params.get("client_id"));
+					params.remove("client_id");
+				})
+		);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getAuthorizationRequestUri())
+				.matches("https://example.com/login/oauth/authorize\\?" +
+						"response_type=code&" +
+						"scope=openid&state=.{15,}&" +
+						"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" +
+						"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
+						"appid=client-id");
+	}
+
 	private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() {
 		return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration")
 				.redirectUriTemplate("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}")

+ 81 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -37,6 +37,7 @@ import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assertions.catchThrowableOfType;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.when;
@@ -59,6 +60,12 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
 		this.resolver = new DefaultServerOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository);
 	}
 
+	@Test
+	public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
 	@Test
 	public void resolveWhenNotMatchThenNull() {
 		assertThat(resolve("/")).isNull();
@@ -139,6 +146,79 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
 				"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
 	}
 
+	// gh-7696
+	@Test
+	public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() {
+		when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
+				Mono.just(TestClientRegistrations.clientRegistration()
+						.scope(OidcScopes.OPENID)
+						.build()));
+
+		this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer
+				.additionalParameters(params -> params.remove(OidcParameterNames.NONCE))
+				.attributes(attrs -> attrs.remove(OidcParameterNames.NONCE)));
+
+		OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
+
+		assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
+		assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
+		assertThat(authorizationRequest.getAuthorizationRequestUri())
+				.matches("https://example.com/login/oauth/authorize\\?" +
+						"response_type=code&client_id=client-id&" +
+						"scope=openid&state=.{15,}&" +
+						"redirect_uri=/login/oauth2/code/registration-id");
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() {
+		when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
+				Mono.just(TestClientRegistrations.clientRegistration()
+						.scope(OidcScopes.OPENID)
+						.build()));
+
+		this.resolver.setAuthorizationRequestCustomizer(customizer ->
+				customizer.authorizationRequestUri(uriBuilder -> {
+					uriBuilder.queryParam("param1", "value1");
+					return uriBuilder.build();
+				})
+		);
+
+		OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
+
+		assertThat(authorizationRequest.getAuthorizationRequestUri())
+				.matches("https://example.com/login/oauth/authorize\\?" +
+						"response_type=code&client_id=client-id&" +
+						"scope=openid&state=.{15,}&" +
+						"redirect_uri=/login/oauth2/code/registration-id&" +
+						"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
+						"param1=value1");
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() {
+		when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
+				Mono.just(TestClientRegistrations.clientRegistration()
+						.scope(OidcScopes.OPENID)
+						.build()));
+
+		this.resolver.setAuthorizationRequestCustomizer(customizer ->
+				customizer.parameters(params -> {
+					params.put("appid", params.get("client_id"));
+					params.remove("client_id");
+				})
+		);
+
+		OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
+
+		assertThat(authorizationRequest.getAuthorizationRequestUri())
+				.matches("https://example.com/login/oauth/authorize\\?" +
+						"response_type=code&" +
+						"scope=openid&state=.{15,}&" +
+						"redirect_uri=/login/oauth2/code/registration-id&" +
+						"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
+						"appid=client-id");
+	}
+
 	private OAuth2AuthorizationRequest resolve(String path) {
 		ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get(path));
 		return this.resolver.resolve(exchange).block();

+ 116 - 44
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java

@@ -22,16 +22,21 @@ import org.springframework.util.CollectionUtils;
 import org.springframework.util.LinkedMultiValueMap;
 import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
-import org.springframework.web.util.UriComponentsBuilder;
+import org.springframework.web.util.DefaultUriBuilderFactory;
+import org.springframework.web.util.UriBuilder;
 import org.springframework.web.util.UriUtils;
 
 import java.io.Serializable;
+import java.net.URI;
 import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Consumer;
+import java.util.function.Function;
 
 /**
  * A representation of an OAuth 2.0 Authorization Request
@@ -108,7 +113,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 	/**
 	 * Returns the scope(s).
 	 *
-	 * @return the scope(s)
+	 * @return the scope(s), or an empty {@code Set} if not available
 	 */
 	public Set<String> getScopes() {
 		return this.scopes;
@@ -124,31 +129,31 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 	}
 
 	/**
-	 * Returns the additional parameters used in the request.
+	 * Returns the additional parameter(s) used in the request.
 	 *
-	 * @return a {@code Map} of the additional parameters used in the request
+	 * @return a {@code Map} of the additional parameter(s), or an empty {@code Map} if not available
 	 */
 	public Map<String, Object> getAdditionalParameters() {
 		return this.additionalParameters;
 	}
 
 	/**
-	 * Returns the attributes associated to the request.
+	 * Returns the attribute(s) associated to the request.
 	 *
 	 * @since 5.2
-	 * @return a {@code Map} of the attributes associated to the request
+	 * @return a {@code Map} of the attribute(s), or an empty {@code Map} if not available
 	 */
 	public Map<String, Object> getAttributes() {
 		return this.attributes;
 	}
 
 	/**
-	 * Returns the value of an attribute associated to the request, or {@code null} if not available.
+	 * Returns the value of an attribute associated to the request.
 	 *
 	 * @since 5.2
 	 * @param name the name of the attribute
 	 * @param <T> the type of the attribute
-	 * @return the value of the attribute associated to the request
+	 * @return the value of the attribute associated to the request, or {@code null} if not available
 	 */
 	@SuppressWarnings("unchecked")
 	public <T> T getAttribute(String name) {
@@ -219,9 +224,12 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		private String redirectUri;
 		private Set<String> scopes;
 		private String state;
-		private Map<String, Object> additionalParameters;
+		private Consumer<Map<String, Object>> additionalParametersConsumer = params -> {};
+		private Consumer<Map<String, Object>> parametersConsumer = params -> {};
+		private Consumer<Map<String, Object>> attributesConsumer = attrs -> {};
 		private String authorizationRequestUri;
-		private Map<String, Object> attributes;
+		private Function<UriBuilder, URI> authorizationRequestUriFunction = builder -> builder.build();
+		private final DefaultUriBuilderFactory uriBuilderFactory;
 
 		private Builder(AuthorizationGrantType authorizationGrantType) {
 			Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
@@ -231,6 +239,10 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 			} else if (AuthorizationGrantType.IMPLICIT.equals(authorizationGrantType)) {
 				this.responseType = OAuth2AuthorizationResponseType.TOKEN;
 			}
+			this.uriBuilderFactory = new DefaultUriBuilderFactory();
+			// The supplied authorizationUri may contain encoded parameters
+			// so disable encoding in UriBuilder and instead apply encoding within this builder
+			this.uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.NONE);
 		}
 
 		/**
@@ -274,7 +286,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		 */
 		public Builder scope(String... scope) {
 			if (scope != null && scope.length > 0) {
-				return this.scopes(toLinkedHashSet(scope));
+				return scopes(new LinkedHashSet<>(Arrays.asList(scope)));
 			}
 			return this;
 		}
@@ -302,13 +314,43 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		}
 
 		/**
-		 * Sets the additional parameters used in the request.
+		 * Sets the additional parameter(s) used in the request.
 		 *
-		 * @param additionalParameters the additional parameters used in the request
+		 * @param additionalParameters the additional parameter(s) used in the request
 		 * @return the {@link Builder}
 		 */
 		public Builder additionalParameters(Map<String, Object> additionalParameters) {
-			this.additionalParameters = additionalParameters;
+			if (additionalParameters != null) {
+				return additionalParameters(params -> params.putAll(additionalParameters));
+			}
+			return this;
+		}
+
+		/**
+		 * A {@code Consumer} to be provided access to the additional parameter(s)
+		 * allowing the ability to add, replace, or remove.
+		 *
+		 * @since 5.3
+		 * @param additionalParametersConsumer a {@code Consumer} of the additional parameters
+		 */
+		public Builder additionalParameters(Consumer<Map<String, Object>> additionalParametersConsumer) {
+			if (additionalParametersConsumer != null) {
+				this.additionalParametersConsumer = additionalParametersConsumer;
+			}
+			return this;
+		}
+
+		/**
+		 * A {@code Consumer} to be provided access to all the parameters
+		 * allowing the ability to add, replace, or remove.
+		 *
+		 * @since 5.3
+		 * @param parametersConsumer a {@code Consumer} of all the parameters
+		 */
+		public Builder parameters(Consumer<Map<String, Object>> parametersConsumer) {
+			if (parametersConsumer != null) {
+				this.parametersConsumer = parametersConsumer;
+			}
 			return this;
 		}
 
@@ -320,7 +362,23 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		 * @return the {@link Builder}
 		 */
 		public Builder attributes(Map<String, Object> attributes) {
-			this.attributes = attributes;
+			if (attributes != null) {
+				return attributes(attrs -> attrs.putAll(attributes));
+			}
+			return this;
+		}
+
+		/**
+		 * A {@code Consumer} to be provided access to the attribute(s)
+		 * allowing the ability to add, replace, or remove.
+		 *
+		 * @since 5.3
+		 * @param attributesConsumer a {@code Consumer} of the attribute(s)
+		 */
+		public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) {
+			if (attributesConsumer != null) {
+				this.attributesConsumer = attributesConsumer;
+			}
 			return this;
 		}
 
@@ -340,6 +398,20 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 			return this;
 		}
 
+		/**
+		 * A {@code Function} to be provided a {@code UriBuilder} representation
+		 * of the OAuth 2.0 Authorization Request allowing for further customizations.
+		 *
+		 * @since 5.3
+		 * @param authorizationRequestUriFunction a {@code Function} to be provided a {@code UriBuilder} representation of the OAuth 2.0 Authorization Request
+		 */
+		public Builder authorizationRequestUri(Function<UriBuilder, URI> authorizationRequestUriFunction) {
+			if (authorizationRequestUriFunction != null) {
+				this.authorizationRequestUriFunction = authorizationRequestUriFunction;
+			}
+			return this;
+		}
+
 		/**
 		 * Builds a new {@link OAuth2AuthorizationRequest}.
 		 *
@@ -362,53 +434,53 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 			authorizationRequest.scopes = Collections.unmodifiableSet(
 				CollectionUtils.isEmpty(this.scopes) ?
 					Collections.emptySet() : new LinkedHashSet<>(this.scopes));
-			authorizationRequest.additionalParameters = Collections.unmodifiableMap(
-				CollectionUtils.isEmpty(this.additionalParameters) ?
-					Collections.emptyMap() : new LinkedHashMap<>(this.additionalParameters));
+			Map<String, Object> additionalParameters = new LinkedHashMap<>();
+			this.additionalParametersConsumer.accept(additionalParameters);
+			authorizationRequest.additionalParameters = Collections.unmodifiableMap(additionalParameters);
+			Map<String, Object> attributes = new LinkedHashMap<>();
+			this.attributesConsumer.accept(attributes);
+			authorizationRequest.attributes = Collections.unmodifiableMap(attributes);
 			authorizationRequest.authorizationRequestUri =
 					StringUtils.hasText(this.authorizationRequestUri) ?
-						this.authorizationRequestUri : this.buildAuthorizationRequestUri();
-			authorizationRequest.attributes = Collections.unmodifiableMap(
-					CollectionUtils.isEmpty(this.attributes) ?
-							Collections.emptyMap() : new LinkedHashMap<>(this.attributes));
+							this.authorizationRequestUri : this.buildAuthorizationRequestUri();
 
 			return authorizationRequest;
 		}
 
 		private String buildAuthorizationRequestUri() {
-			MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
-			parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, encodeQueryParam(this.responseType.getValue()));
-			parameters.set(OAuth2ParameterNames.CLIENT_ID, encodeQueryParam(this.clientId));
+			Map<String, Object> parameters = getParameters();	// Not encoded
+			this.parametersConsumer.accept(parameters);
+			MultiValueMap<String, String> queryParams = new LinkedMultiValueMap<>();
+			parameters.forEach((k, v) -> queryParams.set(
+					encodeQueryParam(k), encodeQueryParam(v.toString())));		// Encoded
+			UriBuilder uriBuilder = this.uriBuilderFactory.uriString(this.authorizationUri)
+					.queryParams(queryParams);
+			return this.authorizationRequestUriFunction.apply(uriBuilder).toString();
+		}
+
+		private Map<String, Object> getParameters() {
+			Map<String, Object> parameters = new LinkedHashMap<>();
+			parameters.put(OAuth2ParameterNames.RESPONSE_TYPE, this.responseType.getValue());
+			parameters.put(OAuth2ParameterNames.CLIENT_ID, this.clientId);
 			if (!CollectionUtils.isEmpty(this.scopes)) {
-				parameters.set(OAuth2ParameterNames.SCOPE,
-						encodeQueryParam(StringUtils.collectionToDelimitedString(this.scopes, " ")));
+				parameters.put(OAuth2ParameterNames.SCOPE,
+						StringUtils.collectionToDelimitedString(this.scopes, " "));
 			}
 			if (this.state != null) {
-				parameters.set(OAuth2ParameterNames.STATE, encodeQueryParam(this.state));
+				parameters.put(OAuth2ParameterNames.STATE, this.state);
 			}
 			if (this.redirectUri != null) {
-				parameters.set(OAuth2ParameterNames.REDIRECT_URI, encodeQueryParam(this.redirectUri));
+				parameters.put(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri);
 			}
-			if (!CollectionUtils.isEmpty(this.additionalParameters)) {
-				this.additionalParameters.forEach((k, v) ->
-						parameters.set(encodeQueryParam(k), encodeQueryParam(v.toString())));
-			}
-
-			return UriComponentsBuilder.fromHttpUrl(this.authorizationUri)
-					.queryParams(parameters)
-					.build()
-					.toUriString();
+			Map<String, Object> additionalParameters = new LinkedHashMap<>();
+			this.additionalParametersConsumer.accept(additionalParameters);
+			additionalParameters.forEach((k, v) -> parameters.put(k, v.toString()));
+			return parameters;
 		}
 
 		// Encode query parameter value according to RFC 3986
 		private static String encodeQueryParam(String value) {
 			return UriUtils.encodeQueryParam(value, StandardCharsets.UTF_8);
 		}
-
-		private LinkedHashSet<String> toLinkedHashSet(String... scope) {
-			LinkedHashSet<String> result = new LinkedHashSet<>();
-			Collections.addAll(result, scope);
-			return result;
-		}
 	}
 }

+ 18 - 2
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java

@@ -18,13 +18,16 @@ package org.springframework.security.oauth2.core.endpoint;
 import org.junit.Test;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 
+import java.net.URI;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
 
-import static org.assertj.core.api.Assertions.*;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /**
  * Tests for {@link OAuth2AuthorizationRequest}.
@@ -126,7 +129,7 @@ public class OAuth2AuthorizationRequestTests {
 					.redirectUri(REDIRECT_URI)
 					.scopes(SCOPES)
 					.state(STATE)
-					.additionalParameters(null)
+					.additionalParameters((Map) null)
 					.build())
 				.doesNotThrowAnyException();
 	}
@@ -220,6 +223,19 @@ public class OAuth2AuthorizationRequestTests {
 		assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI);
 	}
 
+	@Test
+	public void buildWhenAuthorizationRequestUriFunctionSetThenOverridesDefault() {
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri(AUTHORIZATION_URI)
+				.clientId(CLIENT_ID)
+				.redirectUri(REDIRECT_URI)
+				.scopes(SCOPES)
+				.state(STATE)
+				.authorizationRequestUri(uriBuilder -> URI.create(AUTHORIZATION_URI))
+				.build();
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI);
+	}
+
 	@Test
 	public void buildWhenAuthorizationRequestUriNotSetThenDefaultSet() {
 		Map<String, Object> additionalParameters = new HashMap<>();