Просмотр исходного кода

Add support for custom authorization request parameters

Fixes gh-4911
Joe Grandja 7 лет назад
Родитель
Сommit
779597af2a
12 измененных файлов с 867 добавлено и 423 удалено
  1. 2 2
      config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java
  2. 169 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java
  3. 41 120
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java
  4. 3 2
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilter.java
  5. 43 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestResolver.java
  6. 0 53
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestUriBuilder.java
  7. 242 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java
  8. 73 102
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java
  9. 0 58
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestUriBuilderTests.java
  10. 92 1
      oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java
  11. 201 84
      oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java
  12. 1 1
      samples/boot/authcodegrant/src/integration-test/java/org/springframework/security/samples/OAuth2AuthorizationCodeGrantApplicationTests.java

+ 2 - 2
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java

@@ -120,7 +120,7 @@ public class OAuth2ClientConfigurerTests {
 		MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
 			.andExpect(status().is3xxRedirection())
 			.andReturn();
-		assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/client-1");
+		assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fclient-1");
 	}
 
 	@Test
@@ -168,7 +168,7 @@ public class OAuth2ClientConfigurerTests {
 		MvcResult mvcResult = this.mockMvc.perform(get("/resource1").with(user("user1")))
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
-		assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/client-1");
+		assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fclient-1");
 
 		verify(requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}

+ 169 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java

@@ -0,0 +1,169 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
+import org.springframework.security.crypto.keygen.StringKeyGenerator;
+import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.web.util.UrlUtils;
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
+import org.springframework.util.Assert;
+import org.springframework.web.util.UriComponentsBuilder;
+
+import javax.servlet.http.HttpServletRequest;
+import java.util.Base64;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
+
+/**
+ * An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
+ * resolve an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest}
+ * using the default request {@code URI} pattern {@code /oauth2/authorization/{registrationId}}.
+ *
+ * <p>
+ * <b>NOTE:</b> The default base {@code URI} {@code /oauth2/authorization} may be overridden
+ * via it's constructor {@link #DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository, String)}.
+ *
+ * @author Joe Grandja
+ * @since 5.1
+ * @see OAuth2AuthorizationRequestResolver
+ * @see OAuth2AuthorizationRequestRedirectFilter
+ */
+public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver {
+	private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId";
+	private final ClientRegistrationRepository clientRegistrationRepository;
+	private final AntPathRequestMatcher authorizationRequestMatcher;
+	private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
+
+	/**
+	 * Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided parameters.
+	 *
+	 * @param clientRegistrationRepository the repository of client registrations
+	 * @param authorizationRequestBaseUri the base {@code URI} used for resolving authorization requests
+	 */
+	public DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository clientRegistrationRepository,
+														String authorizationRequestBaseUri) {
+		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+		Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty");
+		this.clientRegistrationRepository = clientRegistrationRepository;
+		this.authorizationRequestMatcher = new AntPathRequestMatcher(
+				authorizationRequestBaseUri + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}");
+	}
+
+	@Override
+	public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
+		String registrationId = this.resolveRegistrationId(request);
+		if (registrationId == null) {
+			return null;
+		}
+
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
+		if (clientRegistration == null) {
+			throw new IllegalArgumentException("Invalid Client Registration with Id: " + registrationId);
+		}
+
+		OAuth2AuthorizationRequest.Builder builder;
+		if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
+			builder = OAuth2AuthorizationRequest.authorizationCode();
+		} else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) {
+			builder = OAuth2AuthorizationRequest.implicit();
+		} else {
+			throw new IllegalArgumentException("Invalid Authorization Grant Type ("  +
+					clientRegistration.getAuthorizationGrantType().getValue() +
+					") for Client Registration with Id: " + clientRegistration.getRegistrationId());
+		}
+
+		String redirectUriAction = this.resolveRedirectUriAction(request, clientRegistration);
+		String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction);
+
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
+
+		OAuth2AuthorizationRequest authorizationRequest = builder
+				.clientId(clientRegistration.getClientId())
+				.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
+				.redirectUri(redirectUriStr)
+				.scopes(clientRegistration.getScopes())
+				.state(this.stateGenerator.generateKey())
+				.additionalParameters(additionalParameters)
+				.build();
+
+		return authorizationRequest;
+	}
+
+	private String resolveRegistrationId(HttpServletRequest request) {
+		// Check for ClientAuthorizationRequiredException which may have been set
+		// in the request by OAuth2AuthorizationRequestRedirectFilter
+		ClientAuthorizationRequiredException authzEx =
+				(ClientAuthorizationRequiredException) request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME);
+		if (authzEx != null) {
+			return authzEx.getClientRegistrationId();
+		}
+		if (this.authorizationRequestMatcher.matches(request)) {
+			return this.authorizationRequestMatcher
+					.extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME);
+		}
+		return null;
+	}
+
+	private String resolveRedirectUriAction(HttpServletRequest request, ClientRegistration clientRegistration) {
+		String action = null;
+		if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
+			String loginAction = "login";
+			String authorizeAction = "authorize";
+			String actionParameter = request.getParameter("action");
+			if (request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME) != null) {
+				// Check for ClientAuthorizationRequiredException which may have been set
+				// in the request by OAuth2AuthorizationRequestRedirectFilter
+				action = authorizeAction;
+			} else if (actionParameter == null) {
+				action = loginAction;		// Default
+			} else {
+				if (actionParameter.equalsIgnoreCase(loginAction)) {
+					action = loginAction;
+				} else {
+					action = authorizeAction;
+				}
+			}
+		}
+		return action;
+	}
+
+	private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
+		// Supported URI variables -> baseUrl, action, registrationId
+		// Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"
+		Map<String, String> uriVariables = new HashMap<>();
+		uriVariables.put("registrationId", clientRegistration.getRegistrationId());
+		String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
+				.replacePath(request.getContextPath())
+				.build()
+				.toUriString();
+		uriVariables.put("baseUrl", baseUrl);
+		if (action != null) {
+			uriVariables.put("action", action);
+		}
+		return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
+				.buildAndExpand(uriVariables)
+				.toUriString();
+	}
+}

+ 41 - 120
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java

@@ -16,34 +16,24 @@
 package org.springframework.security.oauth2.client.web;
 
 import org.springframework.http.HttpStatus;
-import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
-import org.springframework.security.crypto.keygen.StringKeyGenerator;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
-import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.web.DefaultRedirectStrategy;
 import org.springframework.security.web.RedirectStrategy;
 import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
 import org.springframework.security.web.savedrequest.RequestCache;
 import org.springframework.security.web.util.ThrowableAnalyzer;
-import org.springframework.security.web.util.UrlUtils;
-import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.OncePerRequestFilter;
-import org.springframework.web.util.UriComponentsBuilder;
 
 import javax.servlet.FilterChain;
 import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import java.io.IOException;
-import java.net.URI;
-import java.util.Base64;
-import java.util.HashMap;
-import java.util.Map;
 
 /**
  * This {@code Filter} initiates the authorization code grant or implicit grant flow
@@ -58,19 +48,24 @@ import java.util.Map;
  *
  * <p>
  * By default, this {@code Filter} responds to authorization requests
- * at the {@code URI} {@code /oauth2/authorization/{registrationId}}.
+ * at the {@code URI} {@code /oauth2/authorization/{registrationId}}
+ * using the default {@link OAuth2AuthorizationRequestResolver}.
  * The {@code URI} template variable {@code {registrationId}} represents the
  * {@link ClientRegistration#getRegistrationId() registration identifier} of the client
  * that is used for initiating the OAuth 2.0 Authorization Request.
  *
  * <p>
- * <b>NOTE:</b> The default base {@code URI} {@code /oauth2/authorization} may be overridden
- * via it's constructor {@link #OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository, String)}.
+ * The default base {@code URI} {@code /oauth2/authorization} may be overridden
+ * via the constructor {@link #OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository, String)},
+ * or alternatively, an {@code OAuth2AuthorizationRequestResolver} may be provided to the constructor
+ * {@link #OAuth2AuthorizationRequestRedirectFilter(OAuth2AuthorizationRequestResolver)}
+ * to override the resolving of authorization requests.
 
  * @author Joe Grandja
  * @author Rob Winch
  * @since 5.0
  * @see OAuth2AuthorizationRequest
+ * @see OAuth2AuthorizationRequestResolver
  * @see AuthorizationRequestRepository
  * @see ClientRegistration
  * @see ClientRegistrationRepository
@@ -84,18 +79,14 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 	 * The default base {@code URI} used for authorization requests.
 	 */
 	public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization";
-	private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId";
-	private static final String AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME =
+	static final String AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME =
 			ClientAuthorizationRequiredException.class.getName() + ".AUTHORIZATION_REQUIRED_EXCEPTION";
-	private final AntPathRequestMatcher authorizationRequestMatcher;
-	private final ClientRegistrationRepository clientRegistrationRepository;
-	private final OAuth2AuthorizationRequestUriBuilder authorizationRequestUriBuilder = new OAuth2AuthorizationRequestUriBuilder();
+	private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
 	private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
-	private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
+	private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
 	private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
 		new HttpSessionOAuth2AuthorizationRequestRepository();
 	private RequestCache requestCache = new HttpSessionRequestCache();
-	private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
 
 	/**
 	 * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters.
@@ -112,14 +103,23 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 	 * @param clientRegistrationRepository the repository of client registrations
 	 * @param authorizationRequestBaseUri the base {@code URI} used for authorization requests
 	 */
-	public OAuth2AuthorizationRequestRedirectFilter(
-		ClientRegistrationRepository clientRegistrationRepository, String authorizationRequestBaseUri) {
-
-		Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty");
+	public OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository clientRegistrationRepository,
+													String authorizationRequestBaseUri) {
 		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
-		this.authorizationRequestMatcher = new AntPathRequestMatcher(
-			authorizationRequestBaseUri + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}");
-		this.clientRegistrationRepository = clientRegistrationRepository;
+		Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty");
+		this.authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
+				clientRegistrationRepository, authorizationRequestBaseUri);
+	}
+
+	/**
+	 * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters.
+	 *
+	 * @since 5.1
+	 * @param authorizationRequestResolver the resolver used for resolving authorization requests
+	 */
+	public OAuth2AuthorizationRequestRedirectFilter(OAuth2AuthorizationRequestResolver authorizationRequestResolver) {
+		Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null");
+		this.authorizationRequestResolver = authorizationRequestResolver;
 	}
 
 	/**
@@ -147,12 +147,14 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
 
-		if (this.shouldRequestAuthorization(request, response)) {
-			try {
-				this.sendRedirectForAuthorization(request, response);
-			} catch (Exception failed) {
-				this.unsuccessfulRedirectForAuthorization(request, response, failed);
+		try {
+			OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request);
+			if (authorizationRequest != null) {
+				this.sendRedirectForAuthorization(request, response, authorizationRequest);
+				return;
 			}
+		} catch (Exception failed) {
+			this.unsuccessfulRedirectForAuthorization(request, response, failed);
 			return;
 		}
 
@@ -168,7 +170,11 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 			if (authzEx != null) {
 				try {
 					request.setAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME, authzEx);
-					this.sendRedirectForAuthorization(request, response, authzEx.getClientRegistrationId());
+					OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request);
+					if (authorizationRequest == null) {
+						throw authzEx;
+					}
+					this.sendRedirectForAuthorization(request, response, authorizationRequest);
 					this.requestCache.saveRequest(request, response);
 				} catch (Exception failed) {
 					this.unsuccessfulRedirectForAuthorization(request, response, failed);
@@ -188,61 +194,13 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 		}
 	}
 
-	private boolean shouldRequestAuthorization(HttpServletRequest request, HttpServletResponse response) {
-		return this.authorizationRequestMatcher.matches(request);
-	}
-
-	private void sendRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response)
-		throws IOException, ServletException {
-
-		String registrationId = this.authorizationRequestMatcher
-			.extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME);
-		this.sendRedirectForAuthorization(request, response, registrationId);
-	}
-
 	private void sendRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
-												String registrationId) throws IOException, ServletException {
-
-		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
-		if (clientRegistration == null) {
-			throw new IllegalArgumentException("Invalid Client Registration with Id: " + registrationId);
-		}
-		this.sendRedirectForAuthorization(request, response, clientRegistration);
-	}
-
-	private void sendRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
-												ClientRegistration clientRegistration) throws IOException, ServletException {
-
-		String redirectUriStr = this.expandRedirectUri(request, clientRegistration);
-
-		Map<String, Object> additionalParameters = new HashMap<>();
-		additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
-
-		OAuth2AuthorizationRequest.Builder builder;
-		if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
-			builder = OAuth2AuthorizationRequest.authorizationCode();
-		} else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) {
-			builder = OAuth2AuthorizationRequest.implicit();
-		} else {
-			throw new IllegalArgumentException("Invalid Authorization Grant Type ("  +
-					clientRegistration.getAuthorizationGrantType().getValue() +
-					") for Client Registration with Id: " + clientRegistration.getRegistrationId());
-		}
-		OAuth2AuthorizationRequest authorizationRequest = builder
-				.clientId(clientRegistration.getClientId())
-				.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
-				.redirectUri(redirectUriStr)
-				.scopes(clientRegistration.getScopes())
-				.state(this.stateGenerator.generateKey())
-				.additionalParameters(additionalParameters)
-				.build();
+												OAuth2AuthorizationRequest authorizationRequest) throws IOException, ServletException {
 
 		if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationRequest.getGrantType())) {
 			this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
 		}
-
-		URI redirectUri = this.authorizationRequestUriBuilder.build(authorizationRequest);
-		this.authorizationRedirectStrategy.sendRedirect(request, response, redirectUri.toString());
+		this.authorizationRedirectStrategy.sendRedirect(request, response, authorizationRequest.getAuthorizationRequestUri());
 	}
 
 	private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
@@ -254,43 +212,6 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 		response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
 	}
 
-	private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) {
-		// Supported URI variables -> baseUrl, action, registrationId
-		// Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"
-		Map<String, String> uriVariables = new HashMap<>();
-		uriVariables.put("registrationId", clientRegistration.getRegistrationId());
-
-		String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
-				.replacePath(request.getContextPath())
-				.build()
-				.toUriString();
-		uriVariables.put("baseUrl", baseUrl);
-
-		if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
-			String loginAction = "login";
-			String authorizeAction = "authorize";
-			String actionParameter = "action";
-			String action;
-			if (request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME) != null) {
-				action = authorizeAction;
-			} else if (request.getParameter(actionParameter) == null) {
-				action = loginAction;
-			} else {
-				String actionValue = request.getParameter(actionParameter);
-				if (loginAction.equalsIgnoreCase(actionValue)) {
-					action = loginAction;
-				} else {
-					action = authorizeAction;
-				}
-			}
-			uriVariables.put("action", action);
-		}
-
-		return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
-			.buildAndExpand(uriVariables)
-			.toUriString();
-	}
-
 	private static final class DefaultThrowableAnalyzer extends ThrowableAnalyzer {
 		protected void initExtractorMap() {
 			super.initExtractorMap();

+ 3 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectWebFilter.java

@@ -87,7 +87,6 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter {
 			ClientAuthorizationRequiredException.class.getName() + ".AUTHORIZATION_REQUIRED_EXCEPTION";
 	private final ServerWebExchangeMatcher authorizationRequestMatcher;
 	private final ReactiveClientRegistrationRepository clientRegistrationRepository;
-	private final OAuth2AuthorizationRequestUriBuilder authorizationRequestUriBuilder = new OAuth2AuthorizationRequestUriBuilder();
 	private final ServerRedirectStrategy authorizationRedirectStrategy = new DefaultServerRedirectStrategy();
 	private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
 	private ReactiveAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
@@ -184,7 +183,9 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter {
 						.saveAuthorizationRequest(authorizationRequest, exchange);
 			}
 
-			URI redirectUri = this.authorizationRequestUriBuilder.build(authorizationRequest);
+			URI redirectUri = UriComponentsBuilder
+					.fromUriString(authorizationRequest.getAuthorizationRequestUri())
+					.build(true).toUri();
 			return saveAuthorizationRequest
 					.then(this.authorizationRedirectStrategy.sendRedirect(exchange, redirectUri));
 		});

+ 43 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestResolver.java

@@ -0,0 +1,43 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+
+import javax.servlet.http.HttpServletRequest;
+
+/**
+ * Implementations of this interface are capable of resolving
+ * an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest}.
+ * Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for resolving Authorization Requests.
+ *
+ * @author Joe Grandja
+ * @since 5.1
+ * @see OAuth2AuthorizationRequest
+ * @see OAuth2AuthorizationRequestRedirectFilter
+ */
+public interface OAuth2AuthorizationRequestResolver {
+
+	/**
+	 * Returns the {@link OAuth2AuthorizationRequest} resolved from
+	 * the provided {@code HttpServletRequest} or {@code null} if not available.
+	 *
+	 * @param request the {@code HttpServletRequest}
+	 * @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not available
+	 */
+	OAuth2AuthorizationRequest resolve(HttpServletRequest request);
+
+}

+ 0 - 53
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestUriBuilder.java

@@ -1,53 +0,0 @@
-/*
- * Copyright 2002-2017 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.client.web;
-
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
-import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.util.Assert;
-import org.springframework.util.StringUtils;
-import org.springframework.web.util.UriComponentsBuilder;
-
-import java.net.URI;
-import java.util.Set;
-
-/**
- * A {@code URI} builder for an OAuth 2.0 Authorization Request.
- *
- * @author Joe Grandja
- * @since 5.0
- * @see OAuth2AuthorizationRequest
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.1">Section 4.1.1 Authorization Code Grant Request</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.2.1">Section 4.2.1 Implicit Grant Request</a>
- */
-class OAuth2AuthorizationRequestUriBuilder {
-
-	URI build(OAuth2AuthorizationRequest authorizationRequest) {
-		Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
-		Set<String> scopes = authorizationRequest.getScopes();
-		UriComponentsBuilder uriBuilder = UriComponentsBuilder
-			.fromUriString(authorizationRequest.getAuthorizationUri())
-			.queryParam(OAuth2ParameterNames.RESPONSE_TYPE, authorizationRequest.getResponseType().getValue())
-			.queryParam(OAuth2ParameterNames.CLIENT_ID, authorizationRequest.getClientId())
-			.queryParam(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " "))
-			.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
-		if (authorizationRequest.getRedirectUri() != null) {
-			uriBuilder.queryParam(OAuth2ParameterNames.REDIRECT_URI, authorizationRequest.getRedirectUri());
-		}
-
-		return uriBuilder.build().encode().toUri();
-	}
-}

+ 242 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java

@@ -0,0 +1,242 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+
+import static org.assertj.core.api.Assertions.*;
+
+/**
+ * Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
+ *
+ * @author Joe Grandja
+ */
+public class DefaultOAuth2AuthorizationRequestResolverTests {
+	private ClientRegistration registration1;
+	private ClientRegistration registration2;
+	private ClientRegistrationRepository clientRegistrationRepository;
+	private String authorizationRequestBaseUri = "/oauth2/authorization";
+	private DefaultOAuth2AuthorizationRequestResolver resolver;
+
+	@Before
+	public void setUp() {
+		this.registration1 = ClientRegistration.withRegistrationId("registration-1")
+				.clientId("client-1")
+				.clientSecret("secret")
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
+				.scope("user")
+				.authorizationUri("https://provider.com/oauth2/authorize")
+				.tokenUri("https://provider.com/oauth2/token")
+				.userInfoUri("https://provider.com/oauth2/user")
+				.userNameAttributeName("id")
+				.clientName("client-1")
+				.build();
+		this.registration2 = ClientRegistration.withRegistrationId("registration-2")
+				.clientId("client-2")
+				.clientSecret("secret")
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
+				.scope("openid", "profile", "email")
+				.authorizationUri("https://provider.com/oauth2/authorize")
+				.tokenUri("https://provider.com/oauth2/token")
+				.userInfoUri("https://provider.com/oauth2/userinfo")
+				.jwkSetUri("https://provider.com/oauth2/keys")
+				.clientName("client-2")
+				.build();
+		this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
+				this.registration1, this.registration2);
+		this.resolver = new DefaultOAuth2AuthorizationRequestResolver(
+				this.clientRegistrationRepository, this.authorizationRequestBaseUri);
+	}
+
+	@Test
+	public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new DefaultOAuth2AuthorizationRequestResolver(null, this.authorizationRequestBaseUri))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void constructorWhenAuthorizationRequestBaseUriIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() {
+		String requestUri = "/path";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest).isNull();
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestWithInvalidClientThenThrowIllegalArgumentException() {
+		ClientRegistration clientRegistration = this.registration1;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId() + "-invalid";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+
+		assertThatThrownBy(() -> this.resolver.resolve(request))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("Invalid Client Registration with Id: " + clientRegistration.getRegistrationId() + "-invalid");
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestWithValidClientThenResolves() {
+		ClientRegistration clientRegistration = this.registration1;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest).isNotNull();
+		assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(
+				clientRegistration.getProviderDetails().getAuthorizationUri());
+		assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+		assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
+		assertThat(authorizationRequest.getClientId()).isEqualTo(clientRegistration.getClientId());
+		assertThat(authorizationRequest.getRedirectUri())
+				.isEqualTo("http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
+		assertThat(authorizationRequest.getScopes()).isEqualTo(clientRegistration.getScopes());
+		assertThat(authorizationRequest.getState()).isNotNull();
+		assertThat(authorizationRequest.getAdditionalParameters())
+				.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-1");
+	}
+
+	@Test
+	public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenResolves() {
+		ClientRegistration clientRegistration = this.registration2;
+		String requestUri = "/path";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.setAttribute(
+				OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
+				new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest).isNotNull();
+		assertThat(authorizationRequest.getAdditionalParameters())
+				.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpanded() {
+		ClientRegistration clientRegistration = this.registration2;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(
+				clientRegistration.getRedirectUriTemplate());
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+				"http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriExcludesPort() {
+		ClientRegistration clientRegistration = this.registration1;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setScheme("http");
+		request.setServerName("example.com");
+		request.setServerPort(80);
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Fexample.com%2Flogin%2Foauth2%2Fcode%2Fregistration-1");
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriExcludesPort() {
+		ClientRegistration clientRegistration = this.registration1;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setScheme("https");
+		request.setServerName("example.com");
+		request.setServerPort(443);
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=https%3A%2F%2Fexample.com%2Flogin%2Foauth2%2Fcode%2Fregistration-1");
+	}
+
+	@Test
+	public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenRedirectUriIsAuthorize() {
+		ClientRegistration clientRegistration = this.registration1;
+		String requestUri = "/path";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.setAttribute(
+				OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
+				new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestOAuth2LoginThenRedirectUriIsLogin() {
+		ClientRegistration clientRegistration = this.registration2;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-2&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-2");
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestHasActionParameterAuthorizeThenRedirectUriIsAuthorize() {
+		ClientRegistration clientRegistration = this.registration1;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.addParameter("action", "authorize");
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestHasActionParameterLoginThenRedirectUriIsLogin() {
+		ClientRegistration clientRegistration = this.registration2;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.addParameter("action", "login");
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-2&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-2");
+	}
+}

+ 73 - 102
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@ package org.springframework.security.oauth2.client.web;
 
 import org.junit.Before;
 import org.junit.Test;
-import org.mockito.ArgumentCaptor;
 import org.springframework.http.HttpStatus;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
@@ -29,16 +28,22 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.web.savedrequest.RequestCache;
+import org.springframework.util.ClassUtils;
+import org.springframework.web.util.UriComponentsBuilder;
 
 import javax.servlet.FilterChain;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
+import java.lang.reflect.Constructor;
+import java.util.HashMap;
+import java.util.Map;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.Mockito.*;
+import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
 
 /**
  * Tests for {@link OAuth2AuthorizationRequestRedirectFilter}.
@@ -100,7 +105,9 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 
 	@Test
 	public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2AuthorizationRequestRedirectFilter(null))
+		Constructor<OAuth2AuthorizationRequestRedirectFilter> constructor = ClassUtils.getConstructorIfAvailable(
+				OAuth2AuthorizationRequestRedirectFilter.class, ClientRegistrationRepository.class);
+		assertThatThrownBy(() -> constructor.newInstance(null))
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
@@ -110,6 +117,14 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
+	@Test
+	public void constructorWhenAuthorizationRequestResolverIsNullThenThrowIllegalArgumentException() {
+		Constructor<OAuth2AuthorizationRequestRedirectFilter> constructor = ClassUtils.getConstructorIfAvailable(
+				OAuth2AuthorizationRequestRedirectFilter.class, OAuth2AuthorizationRequestResolver.class);
+		assertThatThrownBy(() -> constructor.newInstance(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
 	@Test
 	public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.filter.setAuthorizationRequestRepository(null))
@@ -165,7 +180,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 
 		verifyZeroInteractions(filterChain);
 
-		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/login/oauth2/code/registration-1");
+		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-1");
 	}
 
 	@Test
@@ -201,7 +216,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 
 		verifyZeroInteractions(filterChain);
 
-		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=token&client_id=client-3&scope=openid%20profile%20email&state=.{15,}&redirect_uri=http://localhost/authorize/oauth2/implicit/registration-3");
+		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=token&client_id=client-3&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fimplicit%2Fregistration-3");
 	}
 
 	@Test
@@ -239,75 +254,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 
 		verifyZeroInteractions(filterChain);
 
-		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/login/oauth2/code/registration-1");
-	}
-
-	@Test
-	public void doFilterWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpanded() throws Exception {
-		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
-			"/" + this.registration2.getRegistrationId();
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
-				mock(AuthorizationRequestRepository.class);
-		this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		ArgumentCaptor<OAuth2AuthorizationRequest> authorizationRequestArgCaptor =
-			ArgumentCaptor.forClass(OAuth2AuthorizationRequest.class);
-
-		verifyZeroInteractions(filterChain);
-		verify(authorizationRequestRepository).saveAuthorizationRequest(
-			authorizationRequestArgCaptor.capture(), any(HttpServletRequest.class), any(HttpServletResponse.class));
-
-		OAuth2AuthorizationRequest authorizationRequest = authorizationRequestArgCaptor.getValue();
-
-		assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(
-			this.registration2.getRedirectUriTemplate());
-		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
-			"http://localhost/login/oauth2/code/registration-2");
-	}
-
-	@Test
-	public void doFilterWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriExcludesPort() throws Exception {
-		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
-			"/" + this.registration1.getRegistrationId();
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setScheme("http");
-		request.setServerName("example.com");
-		request.setServerPort(80);
-		request.setServletPath(requestUri);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verifyZeroInteractions(filterChain);
-
-		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://example.com/login/oauth2/code/registration-1");
-	}
-
-	@Test
-	public void doFilterWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriExcludesPort() throws Exception {
-		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
-			"/" + this.registration1.getRegistrationId();
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setScheme("https");
-		request.setServerName("example.com");
-		request.setServerPort(443);
-		request.setServletPath(requestUri);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verifyZeroInteractions(filterChain);
-
-		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=https://example.com/login/oauth2/code/registration-1");
+		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-1");
 	}
 
 	@Test
@@ -325,13 +272,13 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 
 		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 
-		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/authorize/oauth2/code/registration-1");
-
+		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
 		verify(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
+		assertThat(request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME)).isNull();
 	}
 
 	@Test
-	public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenRedirectUriIsAuthorize() throws Exception {
+	public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownButAuthorizationRequestNotResolvedThenStatusInternalServerError() throws Exception {
 		String requestUri = "/path";
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
@@ -341,60 +288,84 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 		doThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId()))
 				.when(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
 
-		this.filter.doFilter(request, response, filterChain);
+		OAuth2AuthorizationRequestResolver resolver = req -> null;
+		OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
 
-		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+		filter.doFilter(request, response, filterChain);
 
-		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/authorize/oauth2/code/registration-1");
-	}
-
-	@Test
-	public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectUriIsLogin() throws Exception {
-		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
-				"/" + this.registration2.getRegistrationId();
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.filter.doFilter(request, response, filterChain);
+		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 
 		verifyZeroInteractions(filterChain);
 
-		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-2&scope=openid%20profile%20email&state=.{15,}&redirect_uri=http://localhost/login/oauth2/code/registration-2");
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
+		assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
 	}
 
+	// gh-4911
 	@Test
-	public void doFilterWhenAuthorizationRequestHasActionParameterAuthorizeThenRedirectUriIsAuthorize() throws Exception {
+	public void doFilterWhenAuthorizationRequestAndAdditionalParametersProvidedThenAuthorizationRequestIncludesAdditionalParameters() throws Exception {
 		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
 				"/" + this.registration1.getRegistrationId();
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.addParameter("action", "authorize");
 		request.setServletPath(requestUri);
+		request.addParameter("idp", "https://other.provider.com");
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
-		this.filter.doFilter(request, response, filterChain);
+		OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
+				this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
+
+		OAuth2AuthorizationRequestResolver resolver = req -> {
+			OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req);
+			Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
+			additionalParameters.put("idp", req.getParameter("idp"));
+			return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
+					.additionalParameters(additionalParameters)
+					.build();
+		};
+		OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
+
+		filter.doFilter(request, response, filterChain);
 
 		verifyZeroInteractions(filterChain);
 
-		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/authorize/oauth2/code/registration-1");
+		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-1&idp=https%3A%2F%2Fother.provider.com");
 	}
 
+	// gh-4911, gh-5244
 	@Test
-	public void doFilterWhenAuthorizationRequestHasActionParameterLoginThenRedirectUriIsLogin() throws Exception {
+	public void doFilterWhenAuthorizationRequestAndCustomAuthorizationRequestUriSetThenCustomAuthorizationRequestUriUsed() throws Exception {
 		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
-				"/" + this.registration2.getRegistrationId();
+				"/" + this.registration1.getRegistrationId();
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.addParameter("action", "login");
 		request.setServletPath(requestUri);
+		String loginHintParamName = "login_hint";
+		request.addParameter(loginHintParamName, "user@provider.com");
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
-		this.filter.doFilter(request, response, filterChain);
+		OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
+				this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
+
+		OAuth2AuthorizationRequestResolver resolver = req -> {
+			OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req);
+			Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
+			additionalParameters.put(loginHintParamName, req.getParameter(loginHintParamName));
+			String customAuthorizationRequestUri = UriComponentsBuilder
+					.fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri())
+					.queryParam(loginHintParamName, additionalParameters.get(loginHintParamName))
+					.build(true).toUriString();
+			return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
+					.additionalParameters(additionalParameters)
+					.authorizationRequestUri(customAuthorizationRequestUri)
+					.build();
+		};
+		OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
+
+		filter.doFilter(request, response, filterChain);
 
 		verifyZeroInteractions(filterChain);
 
-		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-2&scope=openid%20profile%20email&state=.{15,}&redirect_uri=http://localhost/login/oauth2/code/registration-2");
+		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-1&login_hint=user@provider\\.com");
 	}
 }

+ 0 - 58
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestUriBuilderTests.java

@@ -1,58 +0,0 @@
-/*
- * Copyright 2002-2017 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.oauth2.client.web;
-
-import org.junit.Test;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
-
-import java.net.URI;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
-
-import static org.assertj.core.api.Assertions.assertThat;
-
-/**
- * Tests for {@link OAuth2AuthorizationRequestUriBuilder}.
- *
- * @author Rob Winch
- * @since 5.0
- */
-public class OAuth2AuthorizationRequestUriBuilderTests {
-	private OAuth2AuthorizationRequestUriBuilder builder = new OAuth2AuthorizationRequestUriBuilder();
-
-	@Test(expected = IllegalArgumentException.class)
-	public void buildWhenAuthorizationRequestIsNullThenThrowIllegalArgumentException() {
-		this.builder.build(null);
-	}
-
-	@Test
-	public void buildWhenScopeMultiThenSeparatedByEncodedSpace() {
-		OAuth2AuthorizationRequest request = OAuth2AuthorizationRequest.implicit()
-			.additionalParameters(Collections.singletonMap("foo", "bar"))
-			.authorizationUri("https://idp.example.com/oauth2/v2/auth")
-			.clientId("client-id")
-			.state("thestate")
-			.redirectUri("https://client.example.com/login/oauth2")
-			.scopes(new HashSet<>(Arrays.asList("openid", "user")))
-			.build();
-
-		URI result = this.builder.build(request);
-
-		assertThat(result.toASCIIString()).isEqualTo("https://idp.example.com/oauth2/v2/auth?response_type=token&client_id=client-id&scope=openid%20user&state=thestate&redirect_uri=https://client.example.com/login/oauth2");
-	}
-}

+ 92 - 1
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,14 +19,18 @@ import org.springframework.security.core.SpringSecurityCoreVersion;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
 
 import java.io.Serializable;
+import java.io.UnsupportedEncodingException;
+import java.net.URLEncoder;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.StringJoiner;
 import java.util.stream.Collectors;
 
 /**
@@ -50,6 +54,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 	private Set<String> scopes;
 	private String state;
 	private Map<String, Object> additionalParameters;
+	private String authorizationRequestUri;
 
 	private OAuth2AuthorizationRequest() {
 	}
@@ -126,6 +131,20 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		return this.additionalParameters;
 	}
 
+	/**
+	 * Returns the {@code URI} string representation of the OAuth 2.0 Authorization Request.
+	 *
+	 * <p>
+	 * <b>NOTE:</b> The {@code URI} string is encoded in the
+	 * {@code application/x-www-form-urlencoded} MIME format.
+	 *
+	 * @since 5.1
+	 * @return the {@code URI} string representation of the OAuth 2.0 Authorization Request
+	 */
+	public String getAuthorizationRequestUri() {
+		return this.authorizationRequestUri;
+	}
+
 	/**
 	 * Returns a new {@link Builder}, initialized with the authorization code grant type.
 	 *
@@ -144,6 +163,26 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		return new Builder(AuthorizationGrantType.IMPLICIT);
 	}
 
+	/**
+	 * Returns a new {@link Builder}, initialized with the values
+	 * from the provided {@code authorizationRequest}.
+	 *
+	 * @since 5.1
+	 * @param authorizationRequest the authorization request used for initializing the {@link Builder}
+	 * @return the {@link Builder}
+	 */
+	public static Builder from(OAuth2AuthorizationRequest authorizationRequest) {
+		Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
+
+		return new Builder(authorizationRequest.getGrantType())
+				.authorizationUri(authorizationRequest.getAuthorizationUri())
+				.clientId(authorizationRequest.getClientId())
+				.redirectUri(authorizationRequest.getRedirectUri())
+				.scopes(authorizationRequest.getScopes())
+				.state(authorizationRequest.getState())
+				.additionalParameters(authorizationRequest.getAdditionalParameters());
+	}
+
 	/**
 	 * A builder for {@link OAuth2AuthorizationRequest}.
 	 */
@@ -156,6 +195,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		private Set<String> scopes;
 		private String state;
 		private Map<String, Object> additionalParameters;
+		private String authorizationRequestUri;
 
 		private Builder(AuthorizationGrantType authorizationGrantType) {
 			Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
@@ -247,6 +287,22 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 			return this;
 		}
 
+		/**
+		 * Sets the {@code URI} string representation of the OAuth 2.0 Authorization Request.
+		 *
+		 * <p>
+		 * <b>NOTE:</b> The {@code URI} string is <b>required</b> to be encoded in the
+		 * {@code application/x-www-form-urlencoded} MIME format.
+		 *
+		 * @since 5.1
+		 * @param authorizationRequestUri the {@code URI} string representation of the OAuth 2.0 Authorization Request
+		 * @return the {@link Builder}
+		 */
+		public Builder authorizationRequestUri(String authorizationRequestUri) {
+			this.authorizationRequestUri = authorizationRequestUri;
+			return this;
+		}
+
 		/**
 		 * Builds a new {@link OAuth2AuthorizationRequest}.
 		 *
@@ -272,7 +328,42 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 			authorizationRequest.additionalParameters = Collections.unmodifiableMap(
 				CollectionUtils.isEmpty(this.additionalParameters) ?
 					Collections.emptyMap() : new LinkedHashMap<>(this.additionalParameters));
+			authorizationRequest.authorizationRequestUri =
+					StringUtils.hasText(this.authorizationRequestUri) ?
+						this.authorizationRequestUri : this.buildAuthorizationRequestUri();
+
 			return authorizationRequest;
 		}
+
+		private String buildAuthorizationRequestUri() {
+			Map<String, String> parameters = new LinkedHashMap<>();
+			parameters.put(OAuth2ParameterNames.RESPONSE_TYPE, this.responseType.getValue());
+			parameters.put(OAuth2ParameterNames.CLIENT_ID, this.clientId);
+			if (!CollectionUtils.isEmpty(this.scopes)) {
+				parameters.put(OAuth2ParameterNames.SCOPE,
+						StringUtils.collectionToDelimitedString(this.scopes, " "));
+			}
+			if (this.state != null) {
+				parameters.put(OAuth2ParameterNames.STATE, this.state);
+			}
+			if (this.redirectUri != null) {
+				parameters.put(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri);
+			}
+			if (!CollectionUtils.isEmpty(this.additionalParameters)) {
+				this.additionalParameters.entrySet().stream()
+						.filter(e -> !e.getKey().equals(OAuth2ParameterNames.REGISTRATION_ID))
+						.forEach(e -> parameters.put(e.getKey(), e.getValue().toString()));
+			}
+
+			try {
+				StringJoiner queryParams = new StringJoiner("&");
+				for (String paramName : parameters.keySet()) {
+					queryParams.add(paramName + "=" + URLEncoder.encode(parameters.get(paramName), "UTF-8"));
+				}
+				return this.authorizationUri + "?" + queryParams.toString();
+			} catch (UnsupportedEncodingException ex) {
+				throw new IllegalArgumentException("Unable to build authorization request uri: " + ex.getMessage(), ex);
+			}
+		}
 	}
 }

+ 201 - 84
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,9 +16,6 @@
 package org.springframework.security.oauth2.core.endpoint;
 
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 
 import java.util.Arrays;
@@ -27,8 +24,7 @@ import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
 
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.Assertions.*;
 
 /**
  * Tests for {@link OAuth2AuthorizationRequest}.
@@ -36,8 +32,6 @@ import static org.assertj.core.api.Assertions.assertThatCode;
  * @author Luander Ribeiro
  * @author Joe Grandja
  */
-@RunWith(PowerMockRunner.class)
-@PrepareForTest(OAuth2AuthorizationRequest.class)
 public class OAuth2AuthorizationRequestTests {
 	private static final String AUTHORIZATION_URI = "https://provider.com/oauth2/authorize";
 	private static final String CLIENT_ID = "client-id";
@@ -45,59 +39,107 @@ public class OAuth2AuthorizationRequestTests {
 	private static final Set<String> SCOPES = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
 	private static final String STATE = "state";
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void buildWhenAuthorizationUriIsNullThenThrowIllegalArgumentException() {
-		OAuth2AuthorizationRequest.authorizationCode()
-			.authorizationUri(null)
-			.clientId(CLIENT_ID)
-			.redirectUri(REDIRECT_URI)
-			.scopes(SCOPES)
-			.state(STATE)
-			.build();
+		assertThatThrownBy(() ->
+				OAuth2AuthorizationRequest.authorizationCode()
+					.authorizationUri(null)
+					.clientId(CLIENT_ID)
+					.redirectUri(REDIRECT_URI)
+					.scopes(SCOPES)
+					.state(STATE)
+					.build()
+		).isInstanceOf(IllegalArgumentException.class);
 	}
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void buildWhenClientIdIsNullThenThrowIllegalArgumentException() {
-		OAuth2AuthorizationRequest.authorizationCode()
-			.authorizationUri(AUTHORIZATION_URI)
-			.clientId(null)
-			.redirectUri(REDIRECT_URI)
-			.scopes(SCOPES)
-			.state(STATE)
-			.build();
+		assertThatThrownBy(() ->
+				OAuth2AuthorizationRequest.authorizationCode()
+					.authorizationUri(AUTHORIZATION_URI)
+					.clientId(null)
+					.redirectUri(REDIRECT_URI)
+					.scopes(SCOPES)
+					.state(STATE)
+					.build()
+		).isInstanceOf(IllegalArgumentException.class);
 	}
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void buildWhenRedirectUriIsNullForImplicitThenThrowIllegalArgumentException() {
-		OAuth2AuthorizationRequest.implicit()
-			.authorizationUri(AUTHORIZATION_URI)
-			.clientId(CLIENT_ID)
-			.redirectUri(null)
-			.scopes(SCOPES)
-			.state(STATE)
-			.build();
+		assertThatThrownBy(() ->
+				OAuth2AuthorizationRequest.implicit()
+					.authorizationUri(AUTHORIZATION_URI)
+					.clientId(CLIENT_ID)
+					.redirectUri(null)
+					.scopes(SCOPES)
+					.state(STATE)
+					.build()
+		).isInstanceOf(IllegalArgumentException.class);
 	}
 
 	@Test
 	public void buildWhenRedirectUriIsNullForAuthorizationCodeThenDoesNotThrowAnyException() {
-		assertThatCode(() -> OAuth2AuthorizationRequest.authorizationCode()
-			.authorizationUri(AUTHORIZATION_URI)
-			.clientId(CLIENT_ID)
-			.redirectUri(null)
-			.scopes(SCOPES)
-			.state(STATE)
-			.build()).doesNotThrowAnyException();
+		assertThatCode(() ->
+				OAuth2AuthorizationRequest.authorizationCode()
+					.authorizationUri(AUTHORIZATION_URI)
+					.clientId(CLIENT_ID)
+					.redirectUri(null)
+					.scopes(SCOPES)
+					.state(STATE)
+					.build())
+				.doesNotThrowAnyException();
+	}
+
+	@Test
+	public void buildWhenScopesIsNullThenDoesNotThrowAnyException() {
+		assertThatCode(() ->
+				OAuth2AuthorizationRequest.authorizationCode()
+					.authorizationUri(AUTHORIZATION_URI)
+					.clientId(CLIENT_ID)
+					.redirectUri(REDIRECT_URI)
+					.scopes(null)
+					.state(STATE)
+					.build())
+				.doesNotThrowAnyException();
+	}
+
+	@Test
+	public void buildWhenStateIsNullThenDoesNotThrowAnyException() {
+		assertThatCode(() ->
+				OAuth2AuthorizationRequest.authorizationCode()
+					.authorizationUri(AUTHORIZATION_URI)
+					.clientId(CLIENT_ID)
+					.redirectUri(REDIRECT_URI)
+					.scopes(SCOPES)
+					.state(null)
+					.build())
+				.doesNotThrowAnyException();
+	}
+
+	@Test
+	public void buildWhenAdditionalParametersIsNullThenDoesNotThrowAnyException() {
+		assertThatCode(() ->
+				OAuth2AuthorizationRequest.authorizationCode()
+					.authorizationUri(AUTHORIZATION_URI)
+					.clientId(CLIENT_ID)
+					.redirectUri(REDIRECT_URI)
+					.scopes(SCOPES)
+					.state(STATE)
+					.additionalParameters(null)
+					.build())
+				.doesNotThrowAnyException();
 	}
 
 	@Test
 	public void buildWhenImplicitThenGrantTypeResponseTypeIsSet() {
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.implicit()
-			.authorizationUri(AUTHORIZATION_URI)
-			.clientId(CLIENT_ID)
-			.redirectUri(REDIRECT_URI)
-			.scopes(SCOPES)
-			.state(STATE)
-			.build();
+				.authorizationUri(AUTHORIZATION_URI)
+				.clientId(CLIENT_ID)
+				.redirectUri(REDIRECT_URI)
+				.scopes(SCOPES)
+				.state(STATE)
+				.build();
 		assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.IMPLICIT);
 		assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.TOKEN);
 	}
@@ -105,30 +147,31 @@ public class OAuth2AuthorizationRequestTests {
 	@Test
 	public void buildWhenAuthorizationCodeThenGrantTypeResponseTypeIsSet() {
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
-			.authorizationUri(AUTHORIZATION_URI)
-			.clientId(CLIENT_ID)
-			.redirectUri(null)
-			.scopes(SCOPES)
-			.state(STATE)
-			.build();
+				.authorizationUri(AUTHORIZATION_URI)
+				.clientId(CLIENT_ID)
+				.redirectUri(null)
+				.scopes(SCOPES)
+				.state(STATE)
+				.build();
 		assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
 		assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
 	}
 
 	@Test
-	public void buildWhenAllAttributesProvidedThenAllAttributesAreSet() {
+	public void buildWhenAllValuesProvidedThenAllValuesAreSet() {
 		Map<String, Object> additionalParameters = new HashMap<>();
 		additionalParameters.put("param1", "value1");
 		additionalParameters.put("param2", "value2");
 
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
-			.authorizationUri(AUTHORIZATION_URI)
-			.clientId(CLIENT_ID)
-			.redirectUri(REDIRECT_URI)
-			.scopes(SCOPES)
-			.state(STATE)
-			.additionalParameters(additionalParameters)
-			.build();
+				.authorizationUri(AUTHORIZATION_URI)
+				.clientId(CLIENT_ID)
+				.redirectUri(REDIRECT_URI)
+				.scopes(SCOPES)
+				.state(STATE)
+				.additionalParameters(additionalParameters)
+				.authorizationRequestUri(AUTHORIZATION_URI)
+				.build();
 
 		assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI);
 		assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
@@ -138,39 +181,113 @@ public class OAuth2AuthorizationRequestTests {
 		assertThat(authorizationRequest.getScopes()).isEqualTo(SCOPES);
 		assertThat(authorizationRequest.getState()).isEqualTo(STATE);
 		assertThat(authorizationRequest.getAdditionalParameters()).isEqualTo(additionalParameters);
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI);
 	}
 
 	@Test
-	public void buildWhenScopesIsNullThenDoesNotThrowAnyException() {
-		assertThatCode(() -> OAuth2AuthorizationRequest.authorizationCode()
-			.authorizationUri(AUTHORIZATION_URI)
-			.clientId(CLIENT_ID)
-			.redirectUri(REDIRECT_URI)
-			.scopes(null)
-			.state(STATE)
-			.build()).doesNotThrowAnyException();
+	public void buildWhenScopesMultiThenSeparatedByEncodedSpace() {
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.implicit()
+				.authorizationUri(AUTHORIZATION_URI)
+				.clientId(CLIENT_ID)
+				.redirectUri(REDIRECT_URI)
+				.scopes(SCOPES)
+				.state(STATE)
+				.build();
+
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?response_type=token&client_id=client-id&scope=scope1+scope2&state=state&redirect_uri=http%3A%2F%2Fexample.com");
 	}
 
 	@Test
-	public void buildWhenStateIsNullThenDoesNotThrowAnyException() {
-		assertThatCode(() -> OAuth2AuthorizationRequest.authorizationCode()
-			.authorizationUri(AUTHORIZATION_URI)
-			.clientId(CLIENT_ID)
-			.redirectUri(REDIRECT_URI)
-			.scopes(SCOPES)
-			.state(null)
-			.build()).doesNotThrowAnyException();
+	public void buildWhenAuthorizationRequestUriSetThenOverridesDefault() {
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri(AUTHORIZATION_URI)
+				.clientId(CLIENT_ID)
+				.redirectUri(REDIRECT_URI)
+				.scopes(SCOPES)
+				.state(STATE)
+				.authorizationRequestUri(AUTHORIZATION_URI)
+				.build();
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI);
 	}
 
 	@Test
-	public void buildWhenAdditionalParametersIsNullThenDoesNotThrowAnyException() {
-		assertThatCode(() -> OAuth2AuthorizationRequest.authorizationCode()
-			.authorizationUri(AUTHORIZATION_URI)
-			.clientId(CLIENT_ID)
-			.redirectUri(REDIRECT_URI)
-			.scopes(SCOPES)
-			.state(STATE)
-			.additionalParameters(null)
-			.build()).doesNotThrowAnyException();
+	public void buildWhenAuthorizationRequestUriNotSetThenDefaultSet() {
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put("param1", "value1");
+		additionalParameters.put("param2", "value2");
+
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri(AUTHORIZATION_URI)
+				.clientId(CLIENT_ID)
+				.redirectUri(REDIRECT_URI)
+				.scopes(SCOPES)
+				.state(STATE)
+				.additionalParameters(additionalParameters)
+				.build();
+
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).isNotNull();
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?response_type=code&client_id=client-id&scope=scope1+scope2&state=state&redirect_uri=http%3A%2F%2Fexample.com&param1=value1&param2=value2");
+	}
+
+	@Test
+	public void buildWhenRequiredParametersSetThenAuthorizationRequestUriIncludesRequiredParametersOnly() {
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri(AUTHORIZATION_URI)
+				.clientId(CLIENT_ID)
+				.build();
+
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?response_type=code&client_id=client-id");
+	}
+
+	@Test
+	public void buildWhenAuthorizationRequestIncludesRegistrationIdParameterThenAuthorizationRequestUriDoesNotIncludeRegistrationIdParameter() {
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put("param1", "value1");
+		additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, "registration1");
+
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri(AUTHORIZATION_URI)
+				.clientId(CLIENT_ID)
+				.redirectUri(REDIRECT_URI)
+				.scopes(SCOPES)
+				.state(STATE)
+				.additionalParameters(additionalParameters)
+				.build();
+
+		assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?response_type=code&client_id=client-id&scope=scope1+scope2&state=state&redirect_uri=http%3A%2F%2Fexample.com&param1=value1");
+	}
+
+	@Test
+	public void fromWhenAuthorizationRequestIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2AuthorizationRequest.from(null)).isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void fromWhenAuthorizationRequestProvidedThenValuesAreCopied() {
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put("param1", "value1");
+		additionalParameters.put("param2", "value2");
+
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri(AUTHORIZATION_URI)
+				.clientId(CLIENT_ID)
+				.redirectUri(REDIRECT_URI)
+				.scopes(SCOPES)
+				.state(STATE)
+				.additionalParameters(additionalParameters)
+				.build();
+
+		OAuth2AuthorizationRequest authorizationRequestCopy =
+				OAuth2AuthorizationRequest.from(authorizationRequest).build();
+
+		assertThat(authorizationRequestCopy.getAuthorizationUri()).isEqualTo(authorizationRequest.getAuthorizationUri());
+		assertThat(authorizationRequestCopy.getGrantType()).isEqualTo(authorizationRequest.getGrantType());
+		assertThat(authorizationRequestCopy.getResponseType()).isEqualTo(authorizationRequest.getResponseType());
+		assertThat(authorizationRequestCopy.getClientId()).isEqualTo(authorizationRequest.getClientId());
+		assertThat(authorizationRequestCopy.getRedirectUri()).isEqualTo(authorizationRequest.getRedirectUri());
+		assertThat(authorizationRequestCopy.getScopes()).isEqualTo(authorizationRequest.getScopes());
+		assertThat(authorizationRequestCopy.getState()).isEqualTo(authorizationRequest.getState());
+		assertThat(authorizationRequestCopy.getAdditionalParameters()).isEqualTo(authorizationRequest.getAdditionalParameters());
+		assertThat(authorizationRequestCopy.getAuthorizationRequestUri()).isEqualTo(authorizationRequest.getAuthorizationRequestUri());
 	}
 }

+ 1 - 1
samples/boot/authcodegrant/src/integration-test/java/org/springframework/security/samples/OAuth2AuthorizationCodeGrantApplicationTests.java

@@ -88,7 +88,7 @@ public class OAuth2AuthorizationCodeGrantApplicationTests {
 		MvcResult mvcResult = this.mockMvc.perform(get("/repos").with(user("user")))
 			.andExpect(status().is3xxRedirection())
 			.andReturn();
-		assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://github.com/login/oauth/authorize\\?response_type=code&client_id=your-app-client-id&scope=public_repo&state=.{15,}&redirect_uri=http://localhost/github-repos");
+		assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://github.com/login/oauth/authorize\\?response_type=code&client_id=your-app-client-id&scope=public_repo&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fgithub-repos");
 	}
 
 	@Test