Browse Source

Use OAuth2AuthorizedClientRepository in filters and resolver

Fixes gh-5544
Joe Grandja 7 years ago
parent
commit
9a144d742e
12 changed files with 215 additions and 87 deletions
  1. 7 9
      config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java
  2. 18 6
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java
  3. 30 7
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java
  4. 20 4
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java
  5. 27 20
      config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java
  6. 13 5
      config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java
  7. 27 13
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java
  8. 26 14
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java
  9. 21 4
      samples/boot/authcodegrant/src/integration-test/java/org/springframework/security/samples/OAuth2AuthorizationCodeGrantApplicationTests.java
  10. 8 0
      samples/boot/authcodegrant/src/main/java/sample/config/SecurityConfig.java
  11. 9 5
      samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java
  12. 9 0
      samples/boot/oauth2login/src/main/java/sample/OAuth2LoginApplication.java

+ 7 - 9
config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

@@ -20,8 +20,7 @@ import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.Import;
 import org.springframework.context.annotation.ImportSelector;
 import org.springframework.core.type.AnnotationMetadata;
-import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
-import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.util.ClassUtils;
 import org.springframework.web.method.support.HandlerMethodArgumentResolver;
@@ -58,22 +57,21 @@ final class OAuth2ClientConfiguration {
 
 	@Configuration
 	static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer {
-		private OAuth2AuthorizedClientService authorizedClientService;
+		private OAuth2AuthorizedClientRepository authorizedClientRepository;
 
 		@Override
 		public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
-			if (this.authorizedClientService != null) {
+			if (this.authorizedClientRepository != null) {
 				OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver =
-						new OAuth2AuthorizedClientArgumentResolver(
-								new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService));
+						new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository);
 				argumentResolvers.add(authorizedClientArgumentResolver);
 			}
 		}
 
 		@Autowired(required = false)
-		public void setAuthorizedClientService(List<OAuth2AuthorizedClientService> authorizedClientServices) {
-			if (authorizedClientServices.size() == 1) {
-				this.authorizedClientService = authorizedClientServices.get(0);
+		public void setAuthorizedClientRepository(List<OAuth2AuthorizedClientRepository> authorizedClientRepositories) {
+			if (authorizedClientRepositories.size() == 1) {
+				this.authorizedClientRepository = authorizedClientRepositories.get(0);
 			}
 		}
 	}

+ 18 - 6
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java

@@ -29,6 +29,7 @@ import org.springframework.security.oauth2.client.web.AuthorizationRequestReposi
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.web.savedrequest.RequestCache;
 import org.springframework.util.Assert;
@@ -63,7 +64,7 @@ import org.springframework.util.Assert;
  *
  * <ul>
  * <li>{@link ClientRegistrationRepository} (required)</li>
- * <li>{@link OAuth2AuthorizedClientService} (optional)</li>
+ * <li>{@link OAuth2AuthorizedClientRepository} (optional)</li>
  * </ul>
  *
  * <h2>Shared Objects Used</h2>
@@ -72,7 +73,7 @@ import org.springframework.util.Assert;
  *
  * <ul>
  * <li>{@link ClientRegistrationRepository}</li>
- * <li>{@link OAuth2AuthorizedClientService}</li>
+ * <li>{@link OAuth2AuthorizedClientRepository}</li>
  * </ul>
  *
  * @author Joe Grandja
@@ -80,7 +81,7 @@ import org.springframework.util.Assert;
  * @see OAuth2AuthorizationRequestRedirectFilter
  * @see OAuth2AuthorizationCodeGrantFilter
  * @see ClientRegistrationRepository
- * @see OAuth2AuthorizedClientService
+ * @see OAuth2AuthorizedClientRepository
  * @see AbstractHttpConfigurer
  */
 public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> extends
@@ -100,6 +101,18 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
 		return this;
 	}
 
+	/**
+	 * Sets the repository for authorized client(s).
+	 *
+	 * @param authorizedClientRepository the authorized client repository
+	 * @return the {@link OAuth2ClientConfigurer} for further configuration
+	 */
+	public OAuth2ClientConfigurer<B> authorizedClientRepository(OAuth2AuthorizedClientRepository authorizedClientRepository) {
+		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
+		this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository);
+		return this;
+	}
+
 	/**
 	 * Sets the service for authorized client(s).
 	 *
@@ -108,7 +121,7 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
 	 */
 	public OAuth2ClientConfigurer<B> authorizedClientService(OAuth2AuthorizedClientService authorizedClientService) {
 		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
-		this.getBuilder().setSharedObject(OAuth2AuthorizedClientService.class, authorizedClientService);
+		this.authorizedClientRepository(new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService));
 		return this;
 	}
 
@@ -309,8 +322,7 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
 
 		OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter(
 				OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder),
-				new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(
-						OAuth2ClientConfigurerUtils.getAuthorizedClientService(builder)),
+				OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(builder),
 				authenticationManager);
 
 		if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository != null) {

+ 30 - 7
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java

@@ -24,6 +24,8 @@ import org.springframework.security.config.annotation.web.configurers.AbstractHt
 import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.util.StringUtils;
 
 import java.util.Map;
@@ -61,14 +63,35 @@ final class OAuth2ClientConfigurerUtils {
 		return clientRegistrationRepositoryMap.values().iterator().next();
 	}
 
-	static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizedClientService getAuthorizedClientService(B builder) {
-		OAuth2AuthorizedClientService authorizedClientService = builder.getSharedObject(OAuth2AuthorizedClientService.class);
-		if (authorizedClientService == null) {
-			authorizedClientService = getAuthorizedClientServiceBean(builder);
-			if (authorizedClientService == null) {
-				authorizedClientService = new InMemoryOAuth2AuthorizedClientService(getClientRegistrationRepository(builder));
+	static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizedClientRepository getAuthorizedClientRepository(B builder) {
+		OAuth2AuthorizedClientRepository authorizedClientRepository = builder.getSharedObject(OAuth2AuthorizedClientRepository.class);
+		if (authorizedClientRepository == null) {
+			authorizedClientRepository = getAuthorizedClientRepositoryBean(builder);
+			if (authorizedClientRepository == null) {
+				authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(
+						getAuthorizedClientService((builder)));
 			}
-			builder.setSharedObject(OAuth2AuthorizedClientService.class, authorizedClientService);
+			builder.setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository);
+		}
+		return authorizedClientRepository;
+	}
+
+	private static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizedClientRepository getAuthorizedClientRepositoryBean(B builder) {
+		Map<String, OAuth2AuthorizedClientRepository> authorizedClientRepositoryMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
+				builder.getSharedObject(ApplicationContext.class), OAuth2AuthorizedClientRepository.class);
+		if (authorizedClientRepositoryMap.size() > 1) {
+			throw new NoUniqueBeanDefinitionException(OAuth2AuthorizedClientRepository.class, authorizedClientRepositoryMap.size(),
+					"Expected single matching bean of type '" + OAuth2AuthorizedClientRepository.class.getName() + "' but found " +
+							authorizedClientRepositoryMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(authorizedClientRepositoryMap.keySet()));
+		}
+		return (!authorizedClientRepositoryMap.isEmpty() ? authorizedClientRepositoryMap.values().iterator().next() : null);
+	}
+
+
+	private static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizedClientService getAuthorizedClientService(B builder) {
+		OAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientServiceBean(builder);
+		if (authorizedClientService == null) {
+			authorizedClientService = new InMemoryOAuth2AuthorizedClientService(getClientRegistrationRepository(builder));
 		}
 		return authorizedClientService;
 	}

+ 20 - 4
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java

@@ -42,9 +42,11 @@ import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserServ
 import org.springframework.security.oauth2.client.userinfo.DelegatingOAuth2UserService;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
+import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -92,7 +94,7 @@ import java.util.Map;
  *
  * <ul>
  * <li>{@link ClientRegistrationRepository} (required)</li>
- * <li>{@link OAuth2AuthorizedClientService} (optional)</li>
+ * <li>{@link OAuth2AuthorizedClientRepository} (optional)</li>
  * <li>{@link GrantedAuthoritiesMapper} (optional)</li>
  * </ul>
  *
@@ -102,7 +104,7 @@ import java.util.Map;
  *
  * <ul>
  * <li>{@link ClientRegistrationRepository}</li>
- * <li>{@link OAuth2AuthorizedClientService}</li>
+ * <li>{@link OAuth2AuthorizedClientRepository}</li>
  * <li>{@link GrantedAuthoritiesMapper}</li>
  * <li>{@link DefaultLoginPageGeneratingFilter} - if {@link #loginPage(String)} is not configured
  * and {@code DefaultLoginPageGeneratingFilter} is available, than a default login page will be made available</li>
@@ -115,6 +117,7 @@ import java.util.Map;
  * @see OAuth2AuthorizationRequestRedirectFilter
  * @see OAuth2LoginAuthenticationFilter
  * @see ClientRegistrationRepository
+ * @see OAuth2AuthorizedClientRepository
  * @see AbstractAuthenticationFilterConfigurer
  */
 public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> extends
@@ -139,6 +142,19 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
 		return this;
 	}
 
+	/**
+	 * Sets the repository for authorized client(s).
+	 *
+	 * @since 5.1
+	 * @param authorizedClientRepository the authorized client repository
+	 * @return the {@link OAuth2LoginConfigurer} for further configuration
+	 */
+	public OAuth2LoginConfigurer<B> authorizedClientRepository(OAuth2AuthorizedClientRepository authorizedClientRepository) {
+		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
+		this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository);
+		return this;
+	}
+
 	/**
 	 * Sets the service for authorized client(s).
 	 *
@@ -147,7 +163,7 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
 	 */
 	public OAuth2LoginConfigurer<B> authorizedClientService(OAuth2AuthorizedClientService authorizedClientService) {
 		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
-		this.getBuilder().setSharedObject(OAuth2AuthorizedClientService.class, authorizedClientService);
+		this.authorizedClientRepository(new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService));
 		return this;
 	}
 
@@ -400,7 +416,7 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
 		OAuth2LoginAuthenticationFilter authenticationFilter =
 			new OAuth2LoginAuthenticationFilter(
 				OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()),
-				OAuth2ClientConfigurerUtils.getAuthorizedClientService(this.getBuilder()),
+				OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(this.getBuilder()),
 				this.loginProcessingUrl);
 		this.setAuthenticationFilter(authenticationFilter);
 		super.loginProcessingUrl(this.loginProcessingUrl);

+ 27 - 20
config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java

@@ -21,22 +21,27 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException;
 import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
+import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
-import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.servlet.config.annotation.EnableWebMvc;
 
+import javax.servlet.http.HttpServletRequest;
+
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
-import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
+import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
@@ -57,18 +62,20 @@ public class OAuth2ClientConfigurationTests {
 	public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws Exception {
 		String clientRegistrationId = "client1";
 		String principalName = "user1";
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
 
-		OAuth2AuthorizedClientService authorizedClientService = mock(OAuth2AuthorizedClientService.class);
+		OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
 		OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class);
-		when(authorizedClientService.loadAuthorizedClient(clientRegistrationId, principalName)).thenReturn(authorizedClient);
+		when(authorizedClientRepository.loadAuthorizedClient(
+				eq(clientRegistrationId), eq(authentication), any(HttpServletRequest.class))).thenReturn(authorizedClient);
 
 		OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
 		when(authorizedClient.getAccessToken()).thenReturn(accessToken);
 
-		OAuth2AuthorizedClientArgumentResolverConfig.AUTHORIZED_CLIENT_SERVICE = authorizedClientService;
+		OAuth2AuthorizedClientArgumentResolverConfig.AUTHORIZED_CLIENT_REPOSITORY = authorizedClientRepository;
 		this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire();
 
-		this.mockMvc.perform(get("/authorized-client").with(user(principalName)))
+		this.mockMvc.perform(get("/authorized-client").with(authentication(authentication)))
 			.andExpect(status().isOk())
 			.andExpect(content().string("resolved"));
 	}
@@ -76,7 +83,7 @@ public class OAuth2ClientConfigurationTests {
 	@EnableWebMvc
 	@EnableWebSecurity
 	static class OAuth2AuthorizedClientArgumentResolverConfig extends WebSecurityConfigurerAdapter {
-		static OAuth2AuthorizedClientService AUTHORIZED_CLIENT_SERVICE;
+		static OAuth2AuthorizedClientRepository AUTHORIZED_CLIENT_REPOSITORY;
 
 		@Override
 		protected void configure(HttpSecurity http) throws Exception {
@@ -92,23 +99,23 @@ public class OAuth2ClientConfigurationTests {
 		}
 
 		@Bean
-		public OAuth2AuthorizedClientService authorizedClientService() {
-			return AUTHORIZED_CLIENT_SERVICE;
+		public OAuth2AuthorizedClientRepository authorizedClientRepository() {
+			return AUTHORIZED_CLIENT_REPOSITORY;
 		}
 	}
 
 	// gh-5321
 	@Test
-	public void loadContextWhenOAuth2AuthorizedClientServiceRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() {
-		assertThatThrownBy(() -> this.spring.register(OAuth2AuthorizedClientServiceRegisteredTwiceConfig.class).autowire())
+	public void loadContextWhenOAuth2AuthorizedClientRepositoryRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() {
+		assertThatThrownBy(() -> this.spring.register(OAuth2AuthorizedClientRepositoryRegisteredTwiceConfig.class).autowire())
 				.hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class)
-				.hasMessageContaining("Expected single matching bean of type '" + OAuth2AuthorizedClientService.class.getName() +
-					"' but found 2: authorizedClientService1,authorizedClientService2");
+				.hasMessageContaining("Expected single matching bean of type '" + OAuth2AuthorizedClientRepository.class.getName() +
+					"' but found 2: authorizedClientRepository1,authorizedClientRepository2");
 	}
 
 	@EnableWebMvc
 	@EnableWebSecurity
-	static class OAuth2AuthorizedClientServiceRegisteredTwiceConfig extends WebSecurityConfigurerAdapter {
+	static class OAuth2AuthorizedClientRepositoryRegisteredTwiceConfig extends WebSecurityConfigurerAdapter {
 
 		@Override
 		protected void configure(HttpSecurity http) throws Exception {
@@ -127,13 +134,13 @@ public class OAuth2ClientConfigurationTests {
 		}
 
 		@Bean
-		public OAuth2AuthorizedClientService authorizedClientService1() {
-			return mock(OAuth2AuthorizedClientService.class);
+		public OAuth2AuthorizedClientRepository authorizedClientRepository1() {
+			return mock(OAuth2AuthorizedClientRepository.class);
 		}
 
 		@Bean
-		public OAuth2AuthorizedClientService authorizedClientService2() {
-			return mock(OAuth2AuthorizedClientService.class);
+		public OAuth2AuthorizedClientRepository authorizedClientRepository2() {
+			return mock(OAuth2AuthorizedClientRepository.class);
 		}
 	}
 
@@ -194,8 +201,8 @@ public class OAuth2ClientConfigurationTests {
 		}
 
 		@Bean
-		public OAuth2AuthorizedClientService authorizedClientService() {
-			return mock(OAuth2AuthorizedClientService.class);
+		public OAuth2AuthorizedClientRepository authorizedClientRepository() {
+			return mock(OAuth2AuthorizedClientRepository.class);
 		}
 	}
 }

+ 13 - 5
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java

@@ -23,6 +23,7 @@ import org.springframework.context.annotation.Bean;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpSession;
+import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
@@ -36,10 +37,12 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -61,6 +64,7 @@ import java.util.Map;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.*;
+import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
@@ -76,6 +80,8 @@ public class OAuth2ClientConfigurerTests {
 
 	private static OAuth2AuthorizedClientService authorizedClientService;
 
+	private static OAuth2AuthorizedClientRepository authorizedClientRepository;
+
 	private static OAuth2AuthorizationRequestResolver authorizationRequestResolver;
 
 	private static OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
@@ -107,6 +113,7 @@ public class OAuth2ClientConfigurerTests {
 			.build();
 		clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1);
 		authorizedClientService = new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository);
+		authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService);
 		authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
 				clientRegistrationRepository, "/oauth2/authorization");
 
@@ -153,17 +160,18 @@ public class OAuth2ClientConfigurerTests {
 		MockHttpSession session = (MockHttpSession) request.getSession();
 
 		String principalName = "user1";
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
 
 		this.mockMvc.perform(get("/client-1")
 			.param(OAuth2ParameterNames.CODE, "code")
 			.param(OAuth2ParameterNames.STATE, "state")
-			.with(user(principalName))
+			.with(authentication(authentication))
 			.session(session))
 			.andExpect(status().is3xxRedirection())
 			.andExpect(redirectedUrl("http://localhost/client-1"));
 
-		OAuth2AuthorizedClient authorizedClient = authorizedClientService.loadAuthorizedClient(
-			this.registration1.getRegistrationId(), principalName);
+		OAuth2AuthorizedClient authorizedClient = authorizedClientRepository.loadAuthorizedClient(
+			this.registration1.getRegistrationId(), authentication, request);
 		assertThat(authorizedClient).isNotNull();
 	}
 
@@ -229,8 +237,8 @@ public class OAuth2ClientConfigurerTests {
 		}
 
 		@Bean
-		public OAuth2AuthorizedClientService authorizedClientService() {
-			return authorizedClientService;
+		public OAuth2AuthorizedClientRepository authorizedClientRepository() {
+			return authorizedClientRepository;
 		}
 
 		@RestController

+ 27 - 13
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,12 +15,6 @@
  */
 package org.springframework.security.oauth2.client.web;
 
-import java.io.IOException;
-
-import javax.servlet.ServletException;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
@@ -43,6 +37,11 @@ import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.util.Assert;
 import org.springframework.util.MultiValueMap;
 
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.io.IOException;
+
 /**
  * An implementation of an {@link AbstractAuthenticationProcessingFilter} for OAuth 2.0 Login.
  *
@@ -68,7 +67,7 @@ import org.springframework.util.MultiValueMap;
  * </li>
  * <li>
  *  Upon a successful authentication, an {@link OAuth2AuthenticationToken} is created (representing the End-User {@code Principal})
- *  and associated to the {@link OAuth2AuthorizedClient Authorized Client} using the {@link OAuth2AuthorizedClientService}.
+ *  and associated to the {@link OAuth2AuthorizedClient Authorized Client} using the {@link OAuth2AuthorizedClientRepository}.
  * </li>
  * <li>
  *  Finally, the {@link OAuth2AuthenticationToken} is returned and ultimately stored
@@ -88,7 +87,7 @@ import org.springframework.util.MultiValueMap;
  * @see OAuth2AuthorizationRequestRedirectFilter
  * @see ClientRegistrationRepository
  * @see OAuth2AuthorizedClient
- * @see OAuth2AuthorizedClientService
+ * @see OAuth2AuthorizedClientRepository
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.2">Section 4.1.2 Authorization Response</a>
  */
@@ -100,7 +99,7 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
 	private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
 	private static final String CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE = "client_registration_not_found";
 	private ClientRegistrationRepository clientRegistrationRepository;
-	private OAuth2AuthorizedClientService authorizedClientService;
+	private OAuth2AuthorizedClientRepository authorizedClientRepository;
 	private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
 		new HttpSessionOAuth2AuthorizationRequestRepository();
 
@@ -125,11 +124,26 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
 	public OAuth2LoginAuthenticationFilter(ClientRegistrationRepository clientRegistrationRepository,
 											OAuth2AuthorizedClientService authorizedClientService,
 											String filterProcessesUrl) {
+		this(clientRegistrationRepository,
+				new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService), filterProcessesUrl);
+	}
+
+	/**
+	 * Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided parameters.
+	 *
+	 * @since 5.1
+	 * @param clientRegistrationRepository the repository of client registrations
+	 * @param authorizedClientRepository the authorized client repository
+	 * @param filterProcessesUrl the {@code URI} where this {@code Filter} will process the authentication requests
+	 */
+	public OAuth2LoginAuthenticationFilter(ClientRegistrationRepository clientRegistrationRepository,
+											OAuth2AuthorizedClientRepository authorizedClientRepository,
+											String filterProcessesUrl) {
 		super(filterProcessesUrl);
 		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
-		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
+		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
 		this.clientRegistrationRepository = clientRegistrationRepository;
-		this.authorizedClientService = authorizedClientService;
+		this.authorizedClientRepository = authorizedClientRepository;
 	}
 
 	@Override
@@ -176,7 +190,7 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
 			authenticationResult.getAccessToken(),
 			authenticationResult.getRefreshToken());
 
-		this.authorizedClientService.saveAuthorizedClient(authorizedClient, oauth2Authentication);
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, oauth2Authentication, request, response);
 
 		return oauth2Authentication;
 	}

+ 26 - 14
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java

@@ -70,10 +70,12 @@ public class OAuth2LoginAuthenticationFilterTests {
 	private ClientRegistration registration2;
 	private String principalName1 = "principal-1";
 	private ClientRegistrationRepository clientRegistrationRepository;
+	private OAuth2AuthorizedClientRepository authorizedClientRepository;
 	private OAuth2AuthorizedClientService authorizedClientService;
 	private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
 	private AuthenticationFailureHandler failureHandler;
 	private AuthenticationManager authenticationManager;
+	private OAuth2LoginAuthenticationToken loginAuthentication;
 	private OAuth2LoginAuthenticationFilter filter;
 
 	@Before
@@ -107,11 +109,12 @@ public class OAuth2LoginAuthenticationFilterTests {
 		this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
 			this.registration1, this.registration2);
 		this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
+		this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService);
 		this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
 		this.failureHandler = mock(AuthenticationFailureHandler.class);
 		this.authenticationManager = mock(AuthenticationManager.class);
-		this.filter = spy(new OAuth2LoginAuthenticationFilter(
-			this.clientRegistrationRepository, this.authorizedClientService));
+		this.filter = spy(new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository,
+				this.authorizedClientRepository, OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI));
 		this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
 		this.filter.setAuthenticationFailureHandler(this.failureHandler);
 		this.filter.setAuthenticationManager(this.authenticationManager);
@@ -129,9 +132,16 @@ public class OAuth2LoginAuthenticationFilterTests {
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
+	@Test
+	public void constructorWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository,
+				(OAuth2AuthorizedClientRepository) null, OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
 	@Test
 	public void constructorWhenFilterProcessesUrlIsNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, this.authorizedClientService, null))
+		assertThatThrownBy(() -> new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, this.authorizedClientRepository, null))
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
@@ -276,8 +286,8 @@ public class OAuth2LoginAuthenticationFilterTests {
 
 		this.filter.doFilter(request, response, filterChain);
 
-		OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
-			this.registration1.getRegistrationId(), this.principalName1);
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
+			this.registration1.getRegistrationId(), this.loginAuthentication, request);
 		assertThat(authorizedClient).isNotNull();
 		assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
 		assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName1);
@@ -289,7 +299,7 @@ public class OAuth2LoginAuthenticationFilterTests {
 	public void doFilterWhenCustomFilterProcessesUrlThenFilterProcesses() throws Exception {
 		String filterProcessesUrl = "/login/oauth2/custom/*";
 		this.filter = spy(new OAuth2LoginAuthenticationFilter(
-			this.clientRegistrationRepository, this.authorizedClientService, filterProcessesUrl));
+			this.clientRegistrationRepository, this.authorizedClientRepository, filterProcessesUrl));
 		this.filter.setAuthenticationManager(this.authenticationManager);
 
 		String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId();
@@ -324,13 +334,15 @@ public class OAuth2LoginAuthenticationFilterTests {
 	private void setUpAuthenticationResult(ClientRegistration registration) {
 		OAuth2User user = mock(OAuth2User.class);
 		when(user.getName()).thenReturn(this.principalName1);
-		OAuth2LoginAuthenticationToken loginAuthentication = mock(OAuth2LoginAuthenticationToken.class);
-		when(loginAuthentication.getPrincipal()).thenReturn(user);
-		when(loginAuthentication.getAuthorities()).thenReturn(AuthorityUtils.createAuthorityList("ROLE_USER"));
-		when(loginAuthentication.getClientRegistration()).thenReturn(registration);
-		when(loginAuthentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class));
-		when(loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
-		when(loginAuthentication.getRefreshToken()).thenReturn(mock(OAuth2RefreshToken.class));
-		when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(loginAuthentication);
+		this.loginAuthentication = mock(OAuth2LoginAuthenticationToken.class);
+		when(this.loginAuthentication.getPrincipal()).thenReturn(user);
+		when(this.loginAuthentication.getName()).thenReturn(this.principalName1);
+		when(this.loginAuthentication.getAuthorities()).thenReturn(AuthorityUtils.createAuthorityList("ROLE_USER"));
+		when(this.loginAuthentication.getClientRegistration()).thenReturn(registration);
+		when(this.loginAuthentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class));
+		when(this.loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
+		when(this.loginAuthentication.getRefreshToken()).thenReturn(mock(OAuth2RefreshToken.class));
+		when(this.loginAuthentication.isAuthenticated()).thenReturn(true);
+		when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(this.loginAuthentication);
 	}
 }

+ 21 - 4
samples/boot/authcodegrant/src/integration-test/java/org/springframework/security/samples/OAuth2AuthorizationCodeGrantApplicationTests.java

@@ -22,24 +22,29 @@ import org.springframework.boot.SpringBootConfiguration;
 import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
 import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc;
 import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.ComponentScan;
 import org.springframework.context.annotation.Import;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpSession;
+import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
+import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
@@ -57,6 +62,7 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
+import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
@@ -78,7 +84,7 @@ public class OAuth2AuthorizationCodeGrantApplicationTests {
 	private ClientRegistrationRepository clientRegistrationRepository;
 
 	@Autowired
-	private OAuth2AuthorizedClientService authorizedClientService;
+	private OAuth2AuthorizedClientRepository authorizedClientRepository;
 
 	@Autowired
 	private MockMvc mockMvc;
@@ -116,18 +122,19 @@ public class OAuth2AuthorizationCodeGrantApplicationTests {
 		MockHttpSession session = (MockHttpSession) request.getSession();
 
 		String principalName = "user";
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
 
 		// Authorization Response
 		this.mockMvc.perform(get("/github-repos")
 			.param(OAuth2ParameterNames.CODE, "code")
 			.param(OAuth2ParameterNames.STATE, "state")
-			.with(user(principalName))
+			.with(authentication(authentication))
 			.session(session))
 			.andExpect(status().is3xxRedirection())
 			.andExpect(redirectedUrl("http://localhost/github-repos"));
 
-		OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
-			registration.getRegistrationId(), principalName);
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
+			registration.getRegistrationId(), authentication, request);
 		assertThat(authorizedClient).isNotNull();
 	}
 
@@ -164,5 +171,15 @@ public class OAuth2AuthorizationCodeGrantApplicationTests {
 	@ComponentScan(basePackages = "sample.web")
 	@Import(WebClientConfig.class)
 	public static class SpringBootApplicationTestConfig {
+
+		@Bean
+		public OAuth2AuthorizedClientService authorizedClientService(ClientRegistrationRepository clientRegistrationRepository) {
+			return new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository);
+		}
+
+		@Bean
+		public OAuth2AuthorizedClientRepository authorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) {
+			return new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService);
+		}
 	}
 }

+ 8 - 0
samples/boot/authcodegrant/src/main/java/sample/config/SecurityConfig.java

@@ -22,6 +22,9 @@ import org.springframework.security.config.annotation.web.configuration.WebSecur
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.core.userdetails.UserDetailsService;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.provisioning.InMemoryUserDetailsManager;
 
 /**
@@ -43,6 +46,11 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter {
 					.authorizationCodeGrant();
 	}
 
+	@Bean
+	public OAuth2AuthorizedClientRepository authorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) {
+		return new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService);
+	}
+
 	@Bean
 	public UserDetailsService userDetailsService() {
 		UserDetails userDetails = User.withDefaultPasswordEncoder()

+ 9 - 5
samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -45,7 +45,9 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
+import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@@ -403,12 +405,14 @@ public class OAuth2LoginApplicationTests {
 	@ComponentScan(basePackages = "sample.web")
 	public static class SpringBootApplicationTestConfig {
 
-		@Autowired
-		private ClientRegistrationRepository clientRegistrationRepository;
+		@Bean
+		public OAuth2AuthorizedClientService authorizedClientService(ClientRegistrationRepository clientRegistrationRepository) {
+			return new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository);
+		}
 
 		@Bean
-		public OAuth2AuthorizedClientService authorizedClientService() {
-			return new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
+		public OAuth2AuthorizedClientRepository authorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) {
+			return new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService);
 		}
 	}
 }

+ 9 - 0
samples/boot/oauth2login/src/main/java/sample/OAuth2LoginApplication.java

@@ -17,6 +17,10 @@ package sample;
 
 import org.springframework.boot.SpringApplication;
 import org.springframework.boot.autoconfigure.SpringBootApplication;
+import org.springframework.context.annotation.Bean;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 
 /**
  * @author Joe Grandja
@@ -27,4 +31,9 @@ public class OAuth2LoginApplication {
 	public static void main(String[] args) {
 		SpringApplication.run(OAuth2LoginApplication.class, args);
 	}
+
+	@Bean
+	public OAuth2AuthorizedClientRepository authorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) {
+		return new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(authorizedClientService);
+	}
 }