Przeglądaj źródła

Add support configuring OAuth2AuthorizationRequestResolver as bean

Closes gh-15236
Max Batischev 1 rok temu
rodzic
commit
4e52eda0f5

+ 10 - 8
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2023 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -58,7 +58,7 @@ import org.springframework.util.Assert;
  * {@link ClientRegistrationRepository} {@code @Bean} may be registered instead.
  *
  * <h2>Security Filters</h2>
- *
+ * <p>
  * The following {@code Filter}'s are populated for {@link #authorizationCodeGrant()}:
  *
  * <ul>
@@ -67,7 +67,7 @@ import org.springframework.util.Assert;
  * </ul>
  *
  * <h2>Shared Objects Created</h2>
- *
+ * <p>
  * The following shared objects are populated:
  *
  * <ul>
@@ -76,7 +76,7 @@ import org.springframework.util.Assert;
  * </ul>
  *
  * <h2>Shared Objects Used</h2>
- *
+ * <p>
  * The following shared objects are used:
  *
  * <ul>
@@ -283,10 +283,12 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>>
 			if (this.authorizationRequestResolver != null) {
 				return this.authorizationRequestResolver;
 			}
-			ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils
-				.getClientRegistrationRepository(getBuilder());
-			return new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository,
-					OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
+			ResolvableType resolvableType = ResolvableType.forClass(OAuth2AuthorizationRequestResolver.class);
+			OAuth2AuthorizationRequestResolver bean = getBeanOrNull(resolvableType);
+			return (bean != null) ? bean
+					: new DefaultOAuth2AuthorizationRequestResolver(
+							OAuth2ClientConfigurerUtils.getClientRegistrationRepository(getBuilder()),
+							OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
 		}
 
 		private OAuth2AuthorizationCodeGrantFilter createAuthorizationCodeGrantFilter(B builder) {

+ 6 - 3
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -4532,9 +4532,12 @@ public class ServerHttpSecurity {
 		}
 
 		private OAuth2AuthorizationRequestRedirectWebFilter getRedirectWebFilter() {
-			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter;
-			if (this.authorizationRequestResolver != null) {
-				return new OAuth2AuthorizationRequestRedirectWebFilter(this.authorizationRequestResolver);
+			ServerOAuth2AuthorizationRequestResolver result = this.authorizationRequestResolver;
+			if (result == null) {
+				result = getBeanOrNull(ServerOAuth2AuthorizationRequestResolver.class);
+			}
+			if (result != null) {
+				return new OAuth2AuthorizationRequestRedirectWebFilter(result);
 			}
 			return new OAuth2AuthorizationRequestRedirectWebFilter(getClientRegistrationRepository());
 		}

+ 68 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -285,6 +285,18 @@ public class OAuth2ClientConfigurerTests {
 		verify(authorizationRedirectStrategy).sendRedirect(any(), any(), anyString());
 	}
 
+	@Test
+	public void configureWhenCustomAuthorizationRequestResolverBeanPresentThenAuthorizationRequestIncludesCustomParameters()
+			throws Exception {
+		this.spring.register(OAuth2ClientBeanConfig.class).autowire();
+		// @formatter:off
+		this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
+				.andExpect(status().is3xxRedirection())
+				.andReturn();
+		// @formatter:on
+		verify(authorizationRequestResolver).resolve(any());
+	}
+
 	@EnableWebSecurity
 	@Configuration
 	@EnableWebMvc
@@ -362,4 +374,59 @@ public class OAuth2ClientConfigurerTests {
 
 	}
 
+	@EnableWebSecurity
+	@Configuration
+	@EnableWebMvc
+	static class OAuth2ClientBeanConfig {
+
+		@Bean
+		SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.authorizeRequests()
+					.anyRequest().authenticated()
+					.and()
+					.requestCache()
+					.requestCache(requestCache)
+					.and()
+					.oauth2Client()
+					.authorizationCodeGrant()
+					.authorizationRedirectStrategy(authorizationRedirectStrategy)
+					.accessTokenResponseClient(accessTokenResponseClient);
+			return http.build();
+			// @formatter:on
+		}
+
+		@Bean
+		ClientRegistrationRepository clientRegistrationRepository() {
+			return clientRegistrationRepository;
+		}
+
+		@Bean
+		OAuth2AuthorizedClientRepository authorizedClientRepository() {
+			return authorizedClientRepository;
+		}
+
+		@Bean
+		OAuth2AuthorizationRequestResolver authorizationRequestResolver() {
+			OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver;
+			authorizationRequestResolver = mock(OAuth2AuthorizationRequestResolver.class);
+			given(authorizationRequestResolver.resolve(any()))
+				.willAnswer((invocation) -> defaultAuthorizationRequestResolver.resolve(invocation.getArgument(0)));
+			return authorizationRequestResolver;
+		}
+
+		@RestController
+		class ResourceController {
+
+			@GetMapping("/resource1")
+			String resource1(
+					@RegisteredOAuth2AuthorizedClient("registration-1") OAuth2AuthorizedClient authorizedClient) {
+				return "resource1";
+			}
+
+		}
+
+	}
+
 }

+ 12 - 0
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -64,6 +64,7 @@ import org.springframework.security.oauth2.client.registration.InMemoryReactiveC
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
+import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
@@ -457,6 +458,7 @@ public class OAuth2LoginTests {
 		OidcUser user = TestOidcUsers.create();
 		ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = config.userService;
 		given(userService.loadUser(any())).willReturn(Mono.just(user));
+		ServerOAuth2AuthorizationRequestResolver resolver = config.resolver;
 		// @formatter:off
 		webTestClient.get()
 				.uri("/login/oauth2/code/google")
@@ -466,6 +468,7 @@ public class OAuth2LoginTests {
 		verify(config.jwtDecoderFactory).createDecoder(any());
 		verify(tokenResponseClient).getTokenResponse(any());
 		verify(securityContextRepository).save(any(), any());
+		verify(resolver).resolve(any());
 	}
 
 	// gh-5562
@@ -837,6 +840,10 @@ public class OAuth2LoginTests {
 
 		ServerSecurityContextRepository securityContextRepository = mock(ServerSecurityContextRepository.class);
 
+		ServerOAuth2AuthorizationRequestResolver resolver = spy(
+				new DefaultServerOAuth2AuthorizationRequestResolver(new InMemoryReactiveClientRegistrationRepository(
+						TestClientRegistrations.clientRegistration().build())));
+
 		@Bean
 		SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
 			// @formatter:off
@@ -864,6 +871,11 @@ public class OAuth2LoginTests {
 			return this.jwtDecoderFactory;
 		}
 
+		@Bean
+		ServerOAuth2AuthorizationRequestResolver resolver() {
+			return this.resolver;
+		}
+
 		@Bean
 		ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient() {
 			return this.tokenResponseClient;