Răsfoiți Sursa

OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> @Bean is discovered by OAuth2ClientConfiguration

Fixes gh-6572
Daniel Fritz 6 ani în urmă
părinte
comite
bfe1e6a154

+ 15 - 3
config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,11 +15,15 @@
  */
 package org.springframework.security.config.annotation.web.configuration;
 
+import java.util.List;
+import java.util.Optional;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.Import;
 import org.springframework.context.annotation.ImportSelector;
 import org.springframework.core.type.AnnotationMetadata;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
@@ -27,8 +31,6 @@ import org.springframework.util.ClassUtils;
 import org.springframework.web.method.support.HandlerMethodArgumentResolver;
 import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
 
-import java.util.List;
-
 /**
  * {@link Configuration} for OAuth 2.0 Client support.
  *
@@ -60,6 +62,7 @@ final class OAuth2ClientConfiguration {
 	static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer {
 		private ClientRegistrationRepository clientRegistrationRepository;
 		private OAuth2AuthorizedClientRepository authorizedClientRepository;
+		private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient;
 
 		@Override
 		public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
@@ -67,6 +70,9 @@ final class OAuth2ClientConfiguration {
 				OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver =
 						new OAuth2AuthorizedClientArgumentResolver(
 								this.clientRegistrationRepository, this.authorizedClientRepository);
+				if (this.accessTokenResponseClient != null) {
+					authorizedClientArgumentResolver.setClientCredentialsTokenResponseClient(this.accessTokenResponseClient);
+				}
 				argumentResolvers.add(authorizedClientArgumentResolver);
 			}
 		}
@@ -84,5 +90,11 @@ final class OAuth2ClientConfiguration {
 				this.authorizedClientRepository = authorizedClientRepositories.get(0);
 			}
 		}
+
+		@Autowired
+		public void setAccessTokenResponseClient(
+				Optional<OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest>> accessTokenResponseClient) {
+			accessTokenResponseClient.ifPresent(client -> this.accessTokenResponseClient = client);
+		}
 	}
 }

+ 122 - 17
config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +15,21 @@
  */
 package org.springframework.security.config.annotation.web.configuration;
 
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.mockito.Mockito.when;
+import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientCredentials;
+import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
+import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
+
+import javax.servlet.http.HttpServletRequest;
 import org.junit.Rule;
 import org.junit.Test;
 import org.springframework.beans.factory.NoSuchBeanDefinitionException;
@@ -26,26 +41,18 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.servlet.config.annotation.EnableWebMvc;
 
-import javax.servlet.http.HttpServletRequest;
-
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
-import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
-import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
-import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
-
 /**
  * Tests for {@link OAuth2ClientConfiguration}.
  *
@@ -64,26 +71,66 @@ public class OAuth2ClientConfigurationTests {
 		String principalName = "user1";
 		TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
 
+		ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
 		OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
 		OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class);
 		when(authorizedClientRepository.loadAuthorizedClient(
-				eq(clientRegistrationId), eq(authentication), any(HttpServletRequest.class))).thenReturn(authorizedClient);
+				eq(clientRegistrationId), eq(authentication), any(HttpServletRequest.class)))
+				.thenReturn(authorizedClient);
 
 		OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
 		when(authorizedClient.getAccessToken()).thenReturn(accessToken);
 
+		OAuth2AccessTokenResponseClient accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
+
+		OAuth2AuthorizedClientArgumentResolverConfig.CLIENT_REGISTRATION_REPOSITORY = clientRegistrationRepository;
 		OAuth2AuthorizedClientArgumentResolverConfig.AUTHORIZED_CLIENT_REPOSITORY = authorizedClientRepository;
+		OAuth2AuthorizedClientArgumentResolverConfig.ACCESS_TOKEN_RESPONSE_CLIENT = accessTokenResponseClient;
 		this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire();
 
 		this.mockMvc.perform(get("/authorized-client").with(authentication(authentication)))
-			.andExpect(status().isOk())
-			.andExpect(content().string("resolved"));
+				.andExpect(status().isOk())
+				.andExpect(content().string("resolved"));
+		verifyZeroInteractions(accessTokenResponseClient);
+	}
+
+	@Test
+	public void requestWhenAuthorizedClientNotFoundAndClientCredentialsThenTokenResponseClientIsUsed() throws Exception {
+		String clientRegistrationId = "client1";
+		String principalName = "user1";
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
+
+		ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
+		OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
+		OAuth2AccessTokenResponseClient accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
+
+		ClientRegistration clientRegistration = clientCredentials().registrationId(clientRegistrationId).build();
+		when(clientRegistrationRepository.findByRegistrationId(clientRegistrationId)).thenReturn(clientRegistration);
+
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234")
+				.tokenType(OAuth2AccessToken.TokenType.BEARER)
+				.expiresIn(300)
+				.build();
+		when(accessTokenResponseClient.getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class)))
+				.thenReturn(accessTokenResponse);
+
+		OAuth2AuthorizedClientArgumentResolverConfig.CLIENT_REGISTRATION_REPOSITORY = clientRegistrationRepository;
+		OAuth2AuthorizedClientArgumentResolverConfig.AUTHORIZED_CLIENT_REPOSITORY = authorizedClientRepository;
+		OAuth2AuthorizedClientArgumentResolverConfig.ACCESS_TOKEN_RESPONSE_CLIENT = accessTokenResponseClient;
+		this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire();
+
+		this.mockMvc.perform(get("/authorized-client").with(authentication(authentication)))
+				.andExpect(status().isOk())
+				.andExpect(content().string("resolved"));
+		verify(accessTokenResponseClient, times(1)).getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class));
 	}
 
 	@EnableWebMvc
 	@EnableWebSecurity
 	static class OAuth2AuthorizedClientArgumentResolverConfig extends WebSecurityConfigurerAdapter {
+		static ClientRegistrationRepository CLIENT_REGISTRATION_REPOSITORY;
 		static OAuth2AuthorizedClientRepository AUTHORIZED_CLIENT_REPOSITORY;
+		static OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> ACCESS_TOKEN_RESPONSE_CLIENT;
 
 		@Override
 		protected void configure(HttpSecurity http) throws Exception {
@@ -100,13 +147,18 @@ public class OAuth2ClientConfigurationTests {
 
 		@Bean
 		public ClientRegistrationRepository clientRegistrationRepository() {
-			return mock(ClientRegistrationRepository.class);
+			return CLIENT_REGISTRATION_REPOSITORY;
 		}
 
 		@Bean
 		public OAuth2AuthorizedClientRepository authorizedClientRepository() {
 			return AUTHORIZED_CLIENT_REPOSITORY;
 		}
+
+		@Bean
+		public OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient() {
+			return ACCESS_TOKEN_RESPONSE_CLIENT;
+		}
 	}
 
 	// gh-5321
@@ -147,6 +199,11 @@ public class OAuth2ClientConfigurationTests {
 		public OAuth2AuthorizedClientRepository authorizedClientRepository2() {
 			return mock(OAuth2AuthorizedClientRepository.class);
 		}
+
+		@Bean
+		public OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient() {
+			return mock(OAuth2AccessTokenResponseClient.class);
+		}
 	}
 
 	@Test
@@ -208,5 +265,53 @@ public class OAuth2ClientConfigurationTests {
 		public OAuth2AuthorizedClientRepository authorizedClientRepository() {
 			return mock(OAuth2AuthorizedClientRepository.class);
 		}
+
+		@Bean
+		public OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient() {
+			return mock(OAuth2AccessTokenResponseClient.class);
+		}
+	}
+
+	@Test
+	public void loadContextWhenAccessTokenResponseClientRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() {
+		assertThatThrownBy(() -> this.spring.register(AccessTokenResponseClientRegisteredTwiceConfig.class).autowire())
+				.hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class)
+				.hasMessageContaining("expected single matching bean but found 2: accessTokenResponseClient1,accessTokenResponseClient2");
+	}
+
+	@EnableWebMvc
+	@EnableWebSecurity
+	static class AccessTokenResponseClientRegisteredTwiceConfig extends WebSecurityConfigurerAdapter {
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+				.authorizeRequests()
+					.anyRequest().authenticated()
+					.and()
+				.oauth2Login();
+			// @formatter:on
+		}
+
+		@Bean
+		public ClientRegistrationRepository clientRegistrationRepository() {
+			return mock(ClientRegistrationRepository.class);
+		}
+
+		@Bean
+		public OAuth2AuthorizedClientRepository authorizedClientRepository() {
+			return mock(OAuth2AuthorizedClientRepository.class);
+		}
+
+		@Bean
+		public OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient1() {
+			return mock(OAuth2AccessTokenResponseClient.class);
+		}
+
+		@Bean
+		public OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient2() {
+			return mock(OAuth2AccessTokenResponseClient.class);
+		}
 	}
 }