소스 검색

OAuth2AuthorizedClientArgumentResolver Uses ServerOAuth2AuthorizedClientRepository

Issue: gh-5621
Rob Winch 7 년 전
부모
커밋
dd7925cb63

+ 21 - 2
config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java

@@ -21,6 +21,8 @@ import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.ImportSelector;
 import org.springframework.core.type.AnnotationMetadata;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.util.ClassUtils;
 import org.springframework.web.reactive.config.WebFluxConfigurer;
@@ -51,20 +53,37 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector {
 
 	@Configuration
 	static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer {
+		private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
+
 		private ReactiveOAuth2AuthorizedClientService authorizedClientService;
 
 		@Override
 		public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) {
-			if (this.authorizedClientService != null) {
-				configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientService));
+			if (this.authorizedClientRepository != null) {
+				configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(getAuthorizedClientRepository()));
 			}
 		}
 
+		@Autowired(required = false)
+		public void setAuthorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+			this.authorizedClientRepository = authorizedClientRepository;
+		}
+
 		@Autowired(required = false)
 		public void setAuthorizedClientService(List<ReactiveOAuth2AuthorizedClientService> authorizedClientService) {
 			if (authorizedClientService.size() == 1) {
 				this.authorizedClientService = authorizedClientService.get(0);
 			}
 		}
+
+		private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() {
+			if (this.authorizedClientRepository != null) {
+				return this.authorizedClientRepository;
+			}
+			if (this.authorizedClientService != null) {
+				return new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(this.authorizedClientService);
+			}
+			return null;
+		}
 	}
 }

+ 16 - 12
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java

@@ -18,14 +18,16 @@ package org.springframework.security.oauth2.client.web.reactive.result.method.an
 
 import org.springframework.core.MethodParameter;
 import org.springframework.core.annotation.AnnotatedElementUtils;
+import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
-import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 import org.springframework.web.reactive.BindingContext;
@@ -54,16 +56,16 @@ import reactor.core.publisher.Mono;
  * @see RegisteredOAuth2AuthorizedClient
  */
 public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver {
-	private final ReactiveOAuth2AuthorizedClientService authorizedClientService;
+	private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
 
 	/**
 	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
 	 *
-	 * @param authorizedClientService the authorized client service
+	 * @param authorizedClientRepository the authorized client repository
 	 */
-	public OAuth2AuthorizedClientArgumentResolver(ReactiveOAuth2AuthorizedClientService authorizedClientService) {
-		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
-		this.authorizedClientService = authorizedClientService;
+	public OAuth2AuthorizedClientArgumentResolver(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
+		this.authorizedClientRepository = authorizedClientRepository;
 	}
 
 	@Override
@@ -84,20 +86,22 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
 					.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalArgumentException(
 							"Unable to resolve the Client Registration Identifier. It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."))));
 
-			Mono<String> principalName = ReactiveSecurityContextHolder.getContext()
-					.map(SecurityContext::getAuthentication).map(Authentication::getName);
+			Mono<Authentication> principal = ReactiveSecurityContextHolder.getContext()
+					.map(SecurityContext::getAuthentication)
+					.defaultIfEmpty(new AnonymousAuthenticationToken("key", "anonymous",
+							AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")));
 
 			Mono<OAuth2AuthorizedClient> authorizedClient = Mono
-					.zip(clientRegistrationId, principalName).switchIfEmpty(
+					.zip(clientRegistrationId, principal).switchIfEmpty(
 							clientRegistrationId.flatMap(id -> Mono.error(new IllegalStateException(
 									"Unable to resolve the Authorized Client with registration identifier \""
 											+ id
 											+ "\". An \"authenticated\" or \"unauthenticated\" session is required. To allow for unauthenticated access, ensure ServerHttpSecurity.anonymous() is configured."))))
 					.flatMap(zipped -> {
 						String registrationId = zipped.getT1();
-						String username = zipped.getT2();
-						return this.authorizedClientService
-								.loadAuthorizedClient(registrationId, username).switchIfEmpty(Mono.defer(() -> Mono
+						Authentication authentication = zipped.getT2();
+						return this.authorizedClientRepository
+								.loadAuthorizedClient(registrationId, authentication, exchange).switchIfEmpty(Mono.defer(() -> Mono
 										.error(new ClientAuthorizationRequiredException(
 												registrationId))));
 					}).cast(OAuth2AuthorizedClient.class);

+ 7 - 12
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java

@@ -27,9 +27,9 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
-import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.util.ReflectionUtils;
 import reactor.core.publisher.Hooks;
 import reactor.core.publisher.Mono;
@@ -51,7 +51,7 @@ import static org.mockito.Mockito.when;
 @RunWith(MockitoJUnitRunner.class)
 public class OAuth2AuthorizedClientArgumentResolverTests {
 	@Mock
-	private ReactiveOAuth2AuthorizedClientService authorizedClientService;
+	private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
 	private OAuth2AuthorizedClientArgumentResolver argumentResolver;
 	private OAuth2AuthorizedClient authorizedClient;
 
@@ -59,9 +59,9 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
 
 	@Before
 	public void setUp() {
-		this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientService);
+		this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository);
 		this.authorizedClient = mock(OAuth2AuthorizedClient.class);
-		when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(Mono.just(this.authorizedClient));
+		when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.just(this.authorizedClient));
 		Hooks.onOperatorDebug();
 	}
 
@@ -100,21 +100,16 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
 	@Test
 	public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() {
 		this.authentication = mock(OAuth2AuthenticationToken.class);
-		when(this.authentication.getName()).thenReturn("client1");
 		when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1");
 		MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
 		resolveArgument(methodParameter);
 	}
 
 	@Test
-	public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenThrowIllegalStateException() {
+	public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenResolves() {
 		this.authentication = null;
 		MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
-		assertThatThrownBy(() -> resolveArgument(methodParameter))
-				.isInstanceOf(IllegalStateException.class)
-				.hasMessage("Unable to resolve the Authorized Client with registration identifier \"client1\". " +
-						"An \"authenticated\" or \"unauthenticated\" session is required. " +
-						"To allow for unauthenticated access, ensure ServerHttpSecurity.anonymous() is configured.");
+		assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient);
 	}
 
 	@Test
@@ -125,7 +120,7 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
 
 	@Test
 	public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() {
-		when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(Mono.empty());
+		when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.empty());
 		MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
 		assertThatThrownBy(() -> resolveArgument(methodParameter))
 				.isInstanceOf(ClientAuthorizationRequiredException.class);