瀏覽代碼

Finer variables for OAuth2 redirectUriTemplate expansion

Fixes #6239
Marek Sabo 6 年之前
父節點
當前提交
7cfb17a8a3

+ 43 - 11
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java

@@ -27,6 +27,8 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.web.util.UrlUtils;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+import org.springframework.web.util.UriComponents;
 import org.springframework.web.util.UriComponentsBuilder;
 
 import javax.servlet.http.HttpServletRequest;
@@ -54,6 +56,7 @@ import java.util.Map;
  */
 public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver {
 	private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId";
+	private static final char PATH_DELIMITER = '/';
 	private final ClientRegistrationRepository clientRegistrationRepository;
 	private final AntPathRequestMatcher authorizationRequestMatcher;
 	private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
@@ -127,7 +130,7 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
 					") for Client Registration with Id: " + clientRegistration.getRegistrationId());
 		}
 
-		String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction);
+		String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction);
 
 		OAuth2AuthorizationRequest authorizationRequest = builder
 				.clientId(clientRegistration.getClientId())
@@ -149,20 +152,49 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
 		return null;
 	}
 
-	private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
-		// Supported URI variables -> baseUrl, action, registrationId
-		// Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"
+	/**
+	 * Expands the {@link ClientRegistration#getRedirectUriTemplate()} with following provided variables:<br/>
+	 * - baseUrl (e.g. https://localhost/app) <br/>
+	 * - baseScheme (e.g. https) <br/>
+	 * - baseHost (e.g. localhost) <br/>
+	 * - basePort (e.g. :8080) <br/>
+	 * - basePath (e.g. /app) <br/>
+	 * - registrationId (e.g. google) <br/>
+	 * - action (e.g. login) <br/>
+	 * <p/>
+	 * Null variables are provided as empty strings.
+	 * <p/>
+	 * Default redirectUriTemplate is: {@link org.springframework.security.config.oauth2.client}.CommonOAuth2Provider#DEFAULT_REDIRECT_URL
+	 *
+	 * @return expanded URI
+	 */
+	private static String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
 		Map<String, String> uriVariables = new HashMap<>();
 		uriVariables.put("registrationId", clientRegistration.getRegistrationId());
-		String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
-				.replaceQuery(null)
+
+		UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
 				.replacePath(request.getContextPath())
-				.build()
-				.toUriString();
-		uriVariables.put("baseUrl", baseUrl);
-		if (action != null) {
-			uriVariables.put("action", action);
+				.replaceQuery(null)
+				.fragment(null)
+				.build();
+		String scheme = uriComponents.getScheme();
+		uriVariables.put("baseScheme", scheme == null ? "" : scheme);
+		String host = uriComponents.getHost();
+		uriVariables.put("baseHost", host == null ? "" : host);
+		// following logic is based on HierarchicalUriComponents#toUriString()
+		int port = uriComponents.getPort();
+		uriVariables.put("basePort", port == -1 ? "" : ":" + port);
+		String path = uriComponents.getPath();
+		if (StringUtils.hasLength(path)) {
+			if (path.charAt(0) != PATH_DELIMITER) {
+				path = PATH_DELIMITER + path;
+			}
 		}
+		uriVariables.put("basePath", path == null ? "" : path);
+		uriVariables.put("baseUrl", uriComponents.toUriString());
+
+		uriVariables.put("action", action == null ? "" : action);
+
 		return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
 				.buildAndExpand(uriVariables)
 				.toUriString();

+ 44 - 13
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java

@@ -30,8 +30,10 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
 import org.springframework.web.server.ResponseStatusException;
 import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.util.UriComponents;
 import org.springframework.web.util.UriComponentsBuilder;
 import reactor.core.publisher.Mono;
 
@@ -63,8 +65,9 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
 	/**
 	 * The default pattern used to resolve the {@link ClientRegistration#getRegistrationId()}
 	 */
-	public static final String DEFAULT_AUTHORIZATION_REQUEST_PATTERN = "/oauth2/authorization/{" + DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME
-			+ "}";
+	public static final String DEFAULT_AUTHORIZATION_REQUEST_PATTERN = "/oauth2/authorization/{" + DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME + "}";
+
+	private static final char PATH_DELIMITER = '/';
 
 	private final ServerWebExchangeMatcher authorizationRequestMatcher;
 
@@ -121,8 +124,7 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
 
 	private OAuth2AuthorizationRequest authorizationRequest(ServerWebExchange exchange,
 			ClientRegistration clientRegistration) {
-		String redirectUriStr = this
-					.expandRedirectUri(exchange.getRequest(), clientRegistration);
+		String redirectUriStr = expandRedirectUri(exchange.getRequest(), clientRegistration);
 
 		Map<String, Object> attributes = new HashMap<>();
 		attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
@@ -153,23 +155,52 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
 				.build();
 	}
 
-	private String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) {
-		// Supported URI variables -> baseUrl, action, registrationId
-		// Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"
+	/**
+	 * Expands the {@link ClientRegistration#getRedirectUriTemplate()} with following provided variables:<br/>
+	 * - baseUrl (e.g. https://localhost/app) <br/>
+	 * - baseScheme (e.g. https) <br/>
+	 * - baseHost (e.g. localhost) <br/>
+	 * - basePort (e.g. :8080) <br/>
+	 * - basePath (e.g. /app) <br/>
+	 * - registrationId (e.g. google) <br/>
+	 * - action (e.g. login) <br/>
+	 * <p/>
+	 * Null variables are provided as empty strings.
+	 * <p/>
+	 * Default redirectUriTemplate is: {@link org.springframework.security.config.oauth2.client}.CommonOAuth2Provider#DEFAULT_REDIRECT_URL
+	 *
+	 * @return expanded URI
+	 */
+	private static String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) {
 		Map<String, String> uriVariables = new HashMap<>();
 		uriVariables.put("registrationId", clientRegistration.getRegistrationId());
 
-		String baseUrl = UriComponentsBuilder.fromUri(request.getURI())
+		UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI())
 				.replacePath(request.getPath().contextPath().value())
 				.replaceQuery(null)
-				.build()
-				.toUriString();
-		uriVariables.put("baseUrl", baseUrl);
+				.fragment(null)
+				.build();
+		String scheme = uriComponents.getScheme();
+		uriVariables.put("baseScheme", scheme == null ? "" : scheme);
+		String host = uriComponents.getHost();
+		uriVariables.put("baseHost", host == null ? "" : host);
+		// following logic is based on HierarchicalUriComponents#toUriString()
+		int port = uriComponents.getPort();
+		uriVariables.put("basePort", port == -1 ? "" : ":" + port);
+		String path = uriComponents.getPath();
+		if (StringUtils.hasLength(path)) {
+			if (path.charAt(0) != PATH_DELIMITER) {
+				path = PATH_DELIMITER + path;
+			}
+		}
+		uriVariables.put("basePath", path == null ? "" : path);
+		uriVariables.put("baseUrl", uriComponents.toUriString());
 
+		String action = "";
 		if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
-			String loginAction = "login";
-			uriVariables.put("action", loginAction);
+			action = "login";
 		}
+		uriVariables.put("action", action);
 
 		return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
 				.buildAndExpand(uriVariables)

+ 93 - 2
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java

@@ -41,15 +41,17 @@ import static org.assertj.core.api.Assertions.entry;
 public class DefaultOAuth2AuthorizationRequestResolverTests {
 	private ClientRegistration registration1;
 	private ClientRegistration registration2;
+	private ClientRegistration fineRedirectUriTemplateRegistration;
 	private ClientRegistration pkceRegistration;
 	private ClientRegistrationRepository clientRegistrationRepository;
-	private String authorizationRequestBaseUri = "/oauth2/authorization";
+	private final String authorizationRequestBaseUri = "/oauth2/authorization";
 	private DefaultOAuth2AuthorizationRequestResolver resolver;
 
 	@Before
 	public void setUp() {
 		this.registration1 = TestClientRegistrations.clientRegistration().build();
 		this.registration2 = TestClientRegistrations.clientRegistration2().build();
+		this.fineRedirectUriTemplateRegistration = fineRedirectUriTemplateClientRegistration().build();
 		this.pkceRegistration = TestClientRegistrations.clientRegistration()
 				.registrationId("pkce-client-registration-id")
 				.clientId("pkce-client-id")
@@ -58,7 +60,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
 				.build();
 
 		this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
-				this.registration1, this.registration2, this.pkceRegistration);
+				this.registration1, this.registration2, this.fineRedirectUriTemplateRegistration, this.pkceRegistration);
 		this.resolver = new DefaultOAuth2AuthorizationRequestResolver(
 				this.clientRegistrationRepository, this.authorizationRequestBaseUri);
 	}
@@ -152,6 +154,80 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
 				"http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
 	}
 
+	@Test
+	public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpRedirectUriWithExtraVarsExpanded() {
+		ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServerPort(8080);
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate());
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+				"http://localhost:8080/login/oauth2/code/" + clientRegistration.getRegistrationId());
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpsRedirectUriWithExtraVarsExpanded() {
+		ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setScheme("https");
+		request.setServerPort(8081);
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate());
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+				"https://localhost:8081/login/oauth2/code/" + clientRegistration.getRegistrationId());
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriWithExtraVarsExcludesPort() {
+		ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setScheme("http");
+		request.setServerPort(80);
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate());
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+				"http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriWithExtraVarsExcludesPort() {
+		ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setScheme("https");
+		request.setServerPort(443);
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate());
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+				"https://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
+	}
+
+	@Test
+	public void resolveWhenAuthorizationRequestHasNoPortThenExpandedRedirectUriWithExtraVarsExcludesPort() {
+		ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
+		String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setScheme("https");
+		request.setServerPort(-1);
+		request.setServletPath(requestUri);
+
+		OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+		assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate());
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+				"https://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
+	}
+
 	// gh-5520
 	@Test
 	public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpandedExcludesQueryString() {
@@ -301,4 +377,19 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
 						"code_challenge_method=S256&" +
 						"code_challenge=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
 	}
+
+	private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() {
+		return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration")
+				.redirectUriTemplate("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}")
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.scope("read:user")
+				.authorizationUri("https://example.com/login/oauth/authorize")
+				.tokenUri("https://example.com/login/oauth/access_token")
+				.userInfoUri("https://api.example.com/user")
+				.userNameAttributeName("id")
+				.clientName("Fine Redirect Uri Template Client")
+				.clientId("fine-redirect-uri-template-client")
+				.clientSecret("client-secret");
+	}
 }