2
0
Эх сурвалжийг харах

Allow configuring a custom OAuth2AuthorizationRequestResolver

Fixes gh-5521
Joe Grandja 7 жил өмнө
parent
commit
2cd548221d

+ 26 - 6
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java

@@ -28,6 +28,7 @@ import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAut
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.web.savedrequest.RequestCache;
 import org.springframework.util.Assert;
@@ -147,6 +148,7 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
 		 */
 		public class AuthorizationEndpointConfig {
 			private String authorizationRequestBaseUri;
+			private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
 			private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
 
 			private AuthorizationEndpointConfig() {
@@ -164,6 +166,18 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
 				return this;
 			}
 
+			/**
+			 * Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
+			 *
+			 * @param authorizationRequestResolver the resolver used for resolving {@link OAuth2AuthorizationRequest}'s
+			 * @return the {@link AuthorizationEndpointConfig} for further configuration
+			 */
+			public AuthorizationEndpointConfig authorizationRequestResolver(OAuth2AuthorizationRequestResolver authorizationRequestResolver) {
+				Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null");
+				this.authorizationRequestResolver = authorizationRequestResolver;
+				return this;
+			}
+
 			/**
 			 * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s.
 			 *
@@ -267,14 +281,20 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
 	}
 
 	private void configure(B builder, AuthorizationCodeGrantConfigurer authorizationCodeGrantConfigurer) throws Exception {
-		String authorizationRequestBaseUri = authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestBaseUri;
-		if (authorizationRequestBaseUri == null) {
-			authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
+		OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter;
+
+		if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestResolver != null) {
+			authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
+					authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestResolver);
+		} else {
+			String authorizationRequestBaseUri = authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestBaseUri;
+			if (authorizationRequestBaseUri == null) {
+				authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
+			}
+			authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
+					OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), authorizationRequestBaseUri);
 		}
 
-		OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
-			OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), authorizationRequestBaseUri);
-
 		if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository != null) {
 			authorizationRequestFilter.setAuthorizationRequestRepository(
 				authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository);

+ 27 - 6
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java

@@ -44,6 +44,7 @@ import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -178,6 +179,7 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
 	 */
 	public class AuthorizationEndpointConfig {
 		private String authorizationRequestBaseUri;
+		private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
 		private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
 
 		private AuthorizationEndpointConfig() {
@@ -195,6 +197,19 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
 			return this;
 		}
 
+		/**
+		 * Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
+		 *
+		 * @since 5.1
+		 * @param authorizationRequestResolver the resolver used for resolving {@link OAuth2AuthorizationRequest}'s
+		 * @return the {@link AuthorizationEndpointConfig} for further configuration
+		 */
+		public AuthorizationEndpointConfig authorizationRequestResolver(OAuth2AuthorizationRequestResolver authorizationRequestResolver) {
+			Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null");
+			this.authorizationRequestResolver = authorizationRequestResolver;
+			return this;
+		}
+
 		/**
 		 * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s.
 		 *
@@ -444,13 +459,19 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
 
 	@Override
 	public void configure(B http) throws Exception {
-		String authorizationRequestBaseUri = this.authorizationEndpointConfig.authorizationRequestBaseUri;
-		if (authorizationRequestBaseUri == null) {
-			authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
-		}
+		OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter;
 
-		OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
-			OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), authorizationRequestBaseUri);
+		if (this.authorizationEndpointConfig.authorizationRequestResolver != null) {
+			authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
+					this.authorizationEndpointConfig.authorizationRequestResolver);
+		} else {
+			String authorizationRequestBaseUri = this.authorizationEndpointConfig.authorizationRequestBaseUri;
+			if (authorizationRequestBaseUri == null) {
+				authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
+			}
+			authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
+					OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), authorizationRequestBaseUri);
+		}
 
 		if (this.authorizationEndpointConfig.authorizationRequestRepository != null) {
 			authorizationRequestFilter.setAuthorizationRequestRepository(

+ 31 - 0
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java

@@ -37,7 +37,9 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
+import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -74,6 +76,8 @@ public class OAuth2ClientConfigurerTests {
 
 	private static OAuth2AuthorizedClientService authorizedClientService;
 
+	private static OAuth2AuthorizationRequestResolver authorizationRequestResolver;
+
 	private static OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
 
 	private static RequestCache requestCache;
@@ -103,6 +107,8 @@ public class OAuth2ClientConfigurerTests {
 			.build();
 		clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1);
 		authorizedClientService = new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository);
+		authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
+				clientRegistrationRepository, "/oauth2/authorization");
 
 		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234")
 				.tokenType(OAuth2AccessToken.TokenType.BEARER)
@@ -173,6 +179,28 @@ public class OAuth2ClientConfigurerTests {
 		verify(requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
+	// gh-5521
+	@Test
+	public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception {
+		// Override default resolver
+		OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver;
+		authorizationRequestResolver = request -> {
+			OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
+			Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
+			additionalParameters.put("param1", "value1");
+			return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
+					.additionalParameters(additionalParameters)
+					.build();
+		};
+
+		this.spring.register(OAuth2ClientConfig.class).autowire();
+
+		MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
+				.andExpect(status().is3xxRedirection())
+				.andReturn();
+		assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fclient-1&param1=value1");
+	}
+
 	@EnableWebSecurity
 	@EnableWebMvc
 	static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter {
@@ -188,6 +216,9 @@ public class OAuth2ClientConfigurerTests {
 				.oauth2()
 					.client()
 						.authorizationCodeGrant()
+							.authorizationEndpoint()
+								.authorizationRequestResolver(authorizationRequestResolver)
+								.and()
 							.tokenEndpoint()
 								.accessTokenResponseClient(accessTokenResponseClient);
 		}

+ 46 - 3
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java

@@ -42,7 +42,9 @@ import org.springframework.security.oauth2.client.registration.InMemoryClientReg
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
+import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@@ -105,11 +107,9 @@ public class OAuth2LoginConfigurerTests {
 	@Before
 	public void setup() {
 		this.request = new MockHttpServletRequest("GET", "");
+		this.request.setServletPath("/login/oauth2/code/google");
 		this.response = new MockHttpServletResponse();
 		this.filterChain = new MockFilterChain();
-
-		this.request.setMethod("GET");
-		this.request.setServletPath("/login/oauth2/code/google");
 	}
 
 	@After
@@ -225,6 +225,20 @@ public class OAuth2LoginConfigurerTests {
 				.isInstanceOf(OAuth2UserAuthority.class).hasToString("ROLE_USER");
 	}
 
+	// gh-5521
+	@Test
+	public void oauth2LoginWithCustomAuthorizationRequestParameters() throws Exception {
+		loadConfig(OAuth2LoginConfigCustomAuthorizationRequestResolver.class);
+
+		String requestUri = "/oauth2/authorization/google";
+		this.request = new MockHttpServletRequest("GET", requestUri);
+		this.request.setServletPath(requestUri);
+
+		this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
+
+		assertThat(this.response.getRedirectedUrl()).matches("https://accounts.google.com/o/oauth2/v2/auth\\?response_type=code&client_id=clientId&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
+	}
+
 	@Test
 	public void oidcLogin() throws Exception {
 		// setup application context
@@ -406,6 +420,35 @@ public class OAuth2LoginConfigurerTests {
 		}
 	}
 
+	@EnableWebSecurity
+	static class OAuth2LoginConfigCustomAuthorizationRequestResolver extends CommonWebSecurityConfigurerAdapter {
+		private ClientRegistrationRepository clientRegistrationRepository =
+				new InMemoryClientRegistrationRepository(CLIENT_REGISTRATION);
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			http
+				.oauth2Login()
+					.clientRegistrationRepository(this.clientRegistrationRepository)
+					.authorizationEndpoint()
+						.authorizationRequestResolver(this.getAuthorizationRequestResolver());
+			super.configure(http);
+		}
+
+		private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() {
+			OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver =
+					new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, "/oauth2/authorization");
+			return request -> {
+				OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
+				Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
+				additionalParameters.put("custom-param1", "custom-value1");
+				return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
+						.additionalParameters(additionalParameters)
+						.build();
+			};
+		}
+	}
+
 	private static abstract class CommonWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter {
 		@Override
 		protected void configure(HttpSecurity http) throws Exception {