瀏覽代碼

Support anonymous Principal for OAuth2AuthorizedClient

Fixes gh-5064
Joe Grandja 7 年之前
父節點
當前提交
371221d729
共有 11 個文件被更改,包括 777 次插入64 次删除
  1. 3 1
      config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java
  2. 5 3
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java
  3. 104 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepository.java
  4. 89 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepository.java
  5. 14 16
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java
  6. 84 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientRepository.java
  7. 10 15
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java
  8. 122 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests.java
  9. 261 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepositoryTests.java
  10. 61 8
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java
  11. 24 21
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java

+ 3 - 1
config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

@@ -21,6 +21,7 @@ import org.springframework.context.annotation.Import;
 import org.springframework.context.annotation.ImportSelector;
 import org.springframework.context.annotation.ImportSelector;
 import org.springframework.core.type.AnnotationMetadata;
 import org.springframework.core.type.AnnotationMetadata;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.util.ClassUtils;
 import org.springframework.util.ClassUtils;
 import org.springframework.web.method.support.HandlerMethodArgumentResolver;
 import org.springframework.web.method.support.HandlerMethodArgumentResolver;
@@ -63,7 +64,8 @@ final class OAuth2ClientConfiguration {
 		public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
 		public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
 			if (this.authorizedClientService != null) {
 			if (this.authorizedClientService != null) {
 				OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver =
 				OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver =
-						new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientService);
+						new OAuth2AuthorizedClientArgumentResolver(
+								new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService));
 				argumentResolvers.add(authorizedClientArgumentResolver);
 				argumentResolvers.add(authorizedClientArgumentResolver);
 			}
 			}
 		}
 		}

+ 5 - 3
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java

@@ -24,6 +24,7 @@ import org.springframework.security.oauth2.client.endpoint.NimbusAuthorizationCo
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
@@ -287,9 +288,10 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
 		AuthenticationManager authenticationManager = builder.getSharedObject(AuthenticationManager.class);
 		AuthenticationManager authenticationManager = builder.getSharedObject(AuthenticationManager.class);
 
 
 		OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter(
 		OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter(
-			OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder),
-			OAuth2ClientConfigurerUtils.getAuthorizedClientService(builder),
-			authenticationManager);
+				OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder),
+				new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(
+						OAuth2ClientConfigurerUtils.getAuthorizedClientService(builder)),
+				authenticationManager);
 
 
 		if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository != null) {
 		if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository != null) {
 			authorizationCodeGrantFilter.setAuthorizationRequestRepository(
 			authorizationCodeGrantFilter.setAuthorizationRequestRepository(

+ 104 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepository.java

@@ -0,0 +1,104 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.springframework.security.authentication.AuthenticationTrustResolver;
+import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+import org.springframework.util.Assert;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+/**
+ * An implementation of an {@link OAuth2AuthorizedClientRepository} that
+ * delegates to the provided {@link OAuth2AuthorizedClientService} if the current
+ * {@code Principal} is authenticated, otherwise,
+ * to the default (or provided) {@link OAuth2AuthorizedClientRepository}
+ * if the current request is unauthenticated (or anonymous).
+ * The default {@code OAuth2AuthorizedClientRepository} is {@link HttpSessionOAuth2AuthorizedClientRepository}.
+ *
+ * @author Joe Grandja
+ * @since 5.1
+ * @see OAuth2AuthorizedClientRepository
+ * @see OAuth2AuthorizedClient
+ * @see OAuth2AuthorizedClientService
+ * @see HttpSessionOAuth2AuthorizedClientRepository
+ */
+public final class AuthenticatedPrincipalOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository {
+	private final AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl();
+	private final OAuth2AuthorizedClientService authorizedClientService;
+	private OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository = new HttpSessionOAuth2AuthorizedClientRepository();
+
+	/**
+	 * Constructs a {@code AuthenticatedPrincipalOAuth2AuthorizedClientRepository} using the provided parameters.
+	 *
+	 * @param authorizedClientService the authorized client service
+	 */
+	public AuthenticatedPrincipalOAuth2AuthorizedClientRepository(OAuth2AuthorizedClientService authorizedClientService) {
+		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
+		this.authorizedClientService = authorizedClientService;
+	}
+
+	/**
+	 * Sets the {@link OAuth2AuthorizedClientRepository} used for requests that are unauthenticated (or anonymous).
+	 * The default is {@link HttpSessionOAuth2AuthorizedClientRepository}.
+	 *
+	 * @param anonymousAuthorizedClientRepository the repository used for requests that are unauthenticated (or anonymous)
+	 */
+	public final void setAnonymousAuthorizedClientRepository(OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository) {
+		Assert.notNull(anonymousAuthorizedClientRepository, "anonymousAuthorizedClientRepository cannot be null");
+		this.anonymousAuthorizedClientRepository = anonymousAuthorizedClientRepository;
+	}
+
+	@Override
+	public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, Authentication principal,
+																		HttpServletRequest request) {
+		if (this.isPrincipalAuthenticated(principal)) {
+			return this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName());
+		} else {
+			return this.anonymousAuthorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, request);
+		}
+	}
+
+	@Override
+	public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal,
+										HttpServletRequest request, HttpServletResponse response) {
+		if (this.isPrincipalAuthenticated(principal)) {
+			this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+		} else {
+			this.anonymousAuthorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, request, response);
+		}
+	}
+
+	@Override
+	public void removeAuthorizedClient(String clientRegistrationId, Authentication principal,
+										HttpServletRequest request, HttpServletResponse response) {
+		if (this.isPrincipalAuthenticated(principal)) {
+			this.authorizedClientService.removeAuthorizedClient(clientRegistrationId, principal.getName());
+		} else {
+			this.anonymousAuthorizedClientRepository.removeAuthorizedClient(clientRegistrationId, principal, request, response);
+		}
+	}
+
+	private boolean isPrincipalAuthenticated(Authentication authentication) {
+		return authentication != null &&
+				!this.authenticationTrustResolver.isAnonymous(authentication) &&
+				authentication.isAuthenticated();
+	}
+}

+ 89 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepository.java

@@ -0,0 +1,89 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.util.Assert;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An implementation of an {@link OAuth2AuthorizedClientRepository} that stores
+ * {@link OAuth2AuthorizedClient}'s in the {@code HttpSession}.
+ *
+ * @author Joe Grandja
+ * @since 5.1
+ * @see OAuth2AuthorizedClientRepository
+ * @see OAuth2AuthorizedClient
+ */
+public final class HttpSessionOAuth2AuthorizedClientRepository implements OAuth2AuthorizedClientRepository {
+	private static final String DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME =
+			HttpSessionOAuth2AuthorizedClientRepository.class.getName() +  ".AUTHORIZED_CLIENTS";
+	private final String sessionAttributeName = DEFAULT_AUTHORIZED_CLIENTS_ATTR_NAME;
+
+	@SuppressWarnings("unchecked")
+	@Override
+	public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, Authentication principal,
+																		HttpServletRequest request) {
+		Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
+		Assert.notNull(request, "request cannot be null");
+		return (T) this.getAuthorizedClients(request).get(clientRegistrationId);
+	}
+
+	@Override
+	public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal,
+										HttpServletRequest request, HttpServletResponse response) {
+		Assert.notNull(authorizedClient, "authorizedClient cannot be null");
+		Assert.notNull(request, "request cannot be null");
+		Assert.notNull(response, "response cannot be null");
+		Map<String, OAuth2AuthorizedClient> authorizedClients = this.getAuthorizedClients(request);
+		authorizedClients.put(authorizedClient.getClientRegistration().getRegistrationId(), authorizedClient);
+		request.getSession().setAttribute(this.sessionAttributeName, authorizedClients);
+	}
+
+	@Override
+	public void removeAuthorizedClient(String clientRegistrationId, Authentication principal,
+										HttpServletRequest request, HttpServletResponse response) {
+		Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
+		Assert.notNull(request, "request cannot be null");
+		Map<String, OAuth2AuthorizedClient> authorizedClients = this.getAuthorizedClients(request);
+		if (!authorizedClients.isEmpty()) {
+			if (authorizedClients.remove(clientRegistrationId) != null) {
+				if (!authorizedClients.isEmpty()) {
+					request.getSession().setAttribute(this.sessionAttributeName, authorizedClients);
+				} else {
+					request.getSession().removeAttribute(this.sessionAttributeName);
+				}
+			}
+		}
+	}
+
+	@SuppressWarnings("unchecked")
+	private Map<String, OAuth2AuthorizedClient> getAuthorizedClients(HttpServletRequest request) {
+		HttpSession session = request.getSession(false);
+		Map<String, OAuth2AuthorizedClient> authorizedClients = session == null ? null :
+				(Map<String, OAuth2AuthorizedClient>) session.getAttribute(this.sessionAttributeName);
+		if (authorizedClients == null) {
+			authorizedClients = new HashMap<>();
+		}
+		return authorizedClients;
+	}
+}

+ 14 - 16
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java

@@ -15,19 +15,11 @@
  */
  */
 package org.springframework.security.oauth2.client.web;
 package org.springframework.security.oauth2.client.web;
 
 
-import java.io.IOException;
-
-import javax.servlet.FilterChain;
-import javax.servlet.ServletException;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
-import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -51,6 +43,12 @@ import org.springframework.util.StringUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
 import org.springframework.web.filter.OncePerRequestFilter;
 import org.springframework.web.util.UriComponentsBuilder;
 import org.springframework.web.util.UriComponentsBuilder;
 
 
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.io.IOException;
+
 /**
 /**
  * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
  * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
  * which handles the processing of the OAuth 2.0 Authorization Response.
  * which handles the processing of the OAuth 2.0 Authorization Response.
@@ -74,7 +72,7 @@ import org.springframework.web.util.UriComponentsBuilder;
  *  Upon a successful authentication, an {@link OAuth2AuthorizedClient Authorized Client} is created by associating the
  *  Upon a successful authentication, an {@link OAuth2AuthorizedClient Authorized Client} is created by associating the
  *  {@link OAuth2AuthorizationCodeAuthenticationToken#getClientRegistration() client} to the
  *  {@link OAuth2AuthorizationCodeAuthenticationToken#getClientRegistration() client} to the
  *  {@link OAuth2AuthorizationCodeAuthenticationToken#getAccessToken() access token} and current {@code Principal}
  *  {@link OAuth2AuthorizationCodeAuthenticationToken#getAccessToken() access token} and current {@code Principal}
- *  and saving it via the {@link OAuth2AuthorizedClientService}.
+ *  and saving it via the {@link OAuth2AuthorizedClientRepository}.
  * </li>
  * </li>
  * </ul>
  * </ul>
  *
  *
@@ -88,13 +86,13 @@ import org.springframework.web.util.UriComponentsBuilder;
  * @see OAuth2AuthorizationRequestRedirectFilter
  * @see OAuth2AuthorizationRequestRedirectFilter
  * @see ClientRegistrationRepository
  * @see ClientRegistrationRepository
  * @see OAuth2AuthorizedClient
  * @see OAuth2AuthorizedClient
- * @see OAuth2AuthorizedClientService
+ * @see OAuth2AuthorizedClientRepository
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.2">Section 4.1.2 Authorization Response</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.2">Section 4.1.2 Authorization Response</a>
  */
  */
 public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 	private final ClientRegistrationRepository clientRegistrationRepository;
 	private final ClientRegistrationRepository clientRegistrationRepository;
-	private final OAuth2AuthorizedClientService authorizedClientService;
+	private final OAuth2AuthorizedClientRepository authorizedClientRepository;
 	private final AuthenticationManager authenticationManager;
 	private final AuthenticationManager authenticationManager;
 	private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
 	private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
 		new HttpSessionOAuth2AuthorizationRequestRepository();
 		new HttpSessionOAuth2AuthorizationRequestRepository();
@@ -106,17 +104,17 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 	 * Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided parameters.
 	 * Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided parameters.
 	 *
 	 *
 	 * @param clientRegistrationRepository the repository of client registrations
 	 * @param clientRegistrationRepository the repository of client registrations
-	 * @param authorizedClientService the authorized client service
+	 * @param authorizedClientRepository the authorized client repository
 	 * @param authenticationManager the authentication manager
 	 * @param authenticationManager the authentication manager
 	 */
 	 */
 	public OAuth2AuthorizationCodeGrantFilter(ClientRegistrationRepository clientRegistrationRepository,
 	public OAuth2AuthorizationCodeGrantFilter(ClientRegistrationRepository clientRegistrationRepository,
-												OAuth2AuthorizedClientService authorizedClientService,
+												OAuth2AuthorizedClientRepository authorizedClientRepository,
 												AuthenticationManager authenticationManager) {
 												AuthenticationManager authenticationManager) {
 		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
 		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
-		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
+		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
 		Assert.notNull(authenticationManager, "authenticationManager cannot be null");
 		Assert.notNull(authenticationManager, "authenticationManager cannot be null");
 		this.clientRegistrationRepository = clientRegistrationRepository;
 		this.clientRegistrationRepository = clientRegistrationRepository;
-		this.authorizedClientService = authorizedClientService;
+		this.authorizedClientRepository = authorizedClientRepository;
 		this.authenticationManager = authenticationManager;
 		this.authenticationManager = authenticationManager;
 	}
 	}
 
 
@@ -201,7 +199,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 			authenticationResult.getAccessToken(),
 			authenticationResult.getAccessToken(),
 			authenticationResult.getRefreshToken());
 			authenticationResult.getRefreshToken());
 
 
-		this.authorizedClientService.saveAuthorizedClient(authorizedClient, currentAuthentication);
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, currentAuthentication, request, response);
 
 
 		String redirectUrl = authorizationResponse.getRedirectUri();
 		String redirectUrl = authorizationResponse.getRedirectUri();
 		SavedRequest savedRequest = this.requestCache.getRequest(request, response);
 		SavedRequest savedRequest = this.requestCache.getRequest(request, response);

+ 84 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientRepository.java

@@ -0,0 +1,84 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+/**
+ * Implementations of this interface are responsible for the persistence
+ * of {@link OAuth2AuthorizedClient Authorized Client(s)} between requests.
+ *
+ * <p>
+ * The primary purpose of an {@link OAuth2AuthorizedClient Authorized Client}
+ * is to associate an {@link OAuth2AuthorizedClient#getAccessToken() Access Token} credential
+ * to a {@link OAuth2AuthorizedClient#getClientRegistration() Client} and Resource Owner,
+ * who is the {@link OAuth2AuthorizedClient#getPrincipalName() Principal}
+ * that originally granted the authorization.
+ *
+ * @author Joe Grandja
+ * @since 5.1
+ * @see OAuth2AuthorizedClient
+ * @see ClientRegistration
+ * @see Authentication
+ * @see OAuth2AccessToken
+ */
+public interface OAuth2AuthorizedClientRepository {
+
+	/**
+	 * Returns the {@link OAuth2AuthorizedClient} associated to the
+	 * provided client registration identifier and End-User {@link Authentication} (Resource Owner)
+	 * or {@code null} if not available.
+	 *
+	 * @param clientRegistrationId the identifier for the client's registration
+	 * @param principal the End-User {@link Authentication} (Resource Owner)
+	 * @param request the {@code HttpServletRequest}
+	 * @param <T> a type of OAuth2AuthorizedClient
+	 * @return the {@link OAuth2AuthorizedClient} or {@code null} if not available
+	 */
+	<T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, Authentication principal,
+																HttpServletRequest request);
+
+	/**
+	 * Saves the {@link OAuth2AuthorizedClient} associating it to
+	 * the provided End-User {@link Authentication} (Resource Owner).
+	 *
+	 * @param authorizedClient the authorized client
+	 * @param principal the End-User {@link Authentication} (Resource Owner)
+	 * @param request the {@code HttpServletRequest}
+	 * @param response the {@code HttpServletResponse}
+	 */
+	void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal,
+								HttpServletRequest request, HttpServletResponse response);
+
+	/**
+	 * Removes the {@link OAuth2AuthorizedClient} associated to the
+	 * provided client registration identifier and End-User {@link Authentication} (Resource Owner).
+	 *
+	 * @param clientRegistrationId the identifier for the client's registration
+	 * @param principal the End-User {@link Authentication} (Resource Owner)
+	 * @param request the {@code HttpServletRequest}
+	 * @param response the {@code HttpServletResponse}
+	 */
+	void removeAuthorizedClient(String clientRegistrationId, Authentication principal,
+								HttpServletRequest request, HttpServletResponse response);
+
+}

+ 10 - 15
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java

@@ -23,9 +23,9 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
-import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.web.bind.support.WebDataBinderFactory;
 import org.springframework.web.bind.support.WebDataBinderFactory;
@@ -33,6 +33,8 @@ import org.springframework.web.context.request.NativeWebRequest;
 import org.springframework.web.method.support.HandlerMethodArgumentResolver;
 import org.springframework.web.method.support.HandlerMethodArgumentResolver;
 import org.springframework.web.method.support.ModelAndViewContainer;
 import org.springframework.web.method.support.ModelAndViewContainer;
 
 
+import javax.servlet.http.HttpServletRequest;
+
 /**
 /**
  * An implementation of a {@link HandlerMethodArgumentResolver} that is capable
  * An implementation of a {@link HandlerMethodArgumentResolver} that is capable
  * of resolving a method parameter to an argument value of type {@link OAuth2AuthorizedClient}.
  * of resolving a method parameter to an argument value of type {@link OAuth2AuthorizedClient}.
@@ -54,16 +56,16 @@ import org.springframework.web.method.support.ModelAndViewContainer;
  * @see RegisteredOAuth2AuthorizedClient
  * @see RegisteredOAuth2AuthorizedClient
  */
  */
 public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver {
 public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver {
-	private final OAuth2AuthorizedClientService authorizedClientService;
+	private final OAuth2AuthorizedClientRepository authorizedClientRepository;
 
 
 	/**
 	/**
 	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
 	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
 	 *
 	 *
-	 * @param authorizedClientService the authorized client service
+	 * @param authorizedClientRepository the authorized client repository
 	 */
 	 */
-	public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientService authorizedClientService) {
-		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
-		this.authorizedClientService = authorizedClientService;
+	public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientRepository authorizedClientRepository) {
+		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
+		this.authorizedClientRepository = authorizedClientRepository;
 	}
 	}
 
 
 	@Override
 	@Override
@@ -98,15 +100,8 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
 					"It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
 					"It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
 		}
 		}
 
 
-		if (principal == null) {
-			// An Authentication is required given that an OAuth2AuthorizedClient is associated to a Principal
-			throw new IllegalStateException("Unable to resolve the Authorized Client with registration identifier \"" +
-					clientRegistrationId + "\". An \"authenticated\" or \"unauthenticated\" session is required. " +
-					"To allow for unauthenticated access, ensure HttpSecurity.anonymous() is configured.");
-		}
-
-		OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
-			clientRegistrationId, principal.getName());
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
+			clientRegistrationId, principal, webRequest.getNativeRequest(HttpServletRequest.class));
 		if (authorizedClient == null) {
 		if (authorizedClient == null) {
 			throw new ClientAuthorizationRequiredException(clientRegistrationId);
 			throw new ClientAuthorizationRequiredException(clientRegistrationId);
 		}
 		}

+ 122 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests.java

@@ -0,0 +1,122 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.AnonymousAuthenticationToken;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Tests for {@link AuthenticatedPrincipalOAuth2AuthorizedClientRepository}.
+ *
+ * @author Joe Grandja
+ */
+public class AuthenticatedPrincipalOAuth2AuthorizedClientRepositoryTests {
+	private String registrationId = "registrationId";
+	private String principalName = "principalName";
+	private OAuth2AuthorizedClientService authorizedClientService;
+	private OAuth2AuthorizedClientRepository anonymousAuthorizedClientRepository;
+	private AuthenticatedPrincipalOAuth2AuthorizedClientRepository authorizedClientRepository;
+	private MockHttpServletRequest request;
+	private MockHttpServletResponse response;
+
+	@Before
+	public void setup() {
+		this.authorizedClientService = mock(OAuth2AuthorizedClientService.class);
+		this.anonymousAuthorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
+		this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService);
+		this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(this.anonymousAuthorizedClientRepository);
+		this.request = new MockHttpServletRequest();
+		this.response = new MockHttpServletResponse();
+	}
+
+	@Test
+	public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void setAuthorizedClientRepositoryWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientRepository.setAnonymousAuthorizedClientRepository(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenAuthenticatedPrincipalThenLoadFromService() {
+		Authentication authentication = this.createAuthenticatedPrincipal();
+		this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.request);
+		verify(this.authorizedClientService).loadAuthorizedClient(this.registrationId, this.principalName);
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenAnonymousPrincipalThenLoadFromAnonymousRepository() {
+		Authentication authentication = this.createAnonymousPrincipal();
+		this.authorizedClientRepository.loadAuthorizedClient(this.registrationId, authentication, this.request);
+		verify(this.anonymousAuthorizedClientRepository).loadAuthorizedClient(this.registrationId, authentication, this.request);
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenAuthenticatedPrincipalThenSaveToService() {
+		Authentication authentication = this.createAuthenticatedPrincipal();
+		OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class);
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.request, this.response);
+		verify(this.authorizedClientService).saveAuthorizedClient(authorizedClient, authentication);
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenAnonymousPrincipalThenSaveToAnonymousRepository() {
+		Authentication authentication = this.createAnonymousPrincipal();
+		OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class);
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, this.request, this.response);
+		verify(this.anonymousAuthorizedClientRepository).saveAuthorizedClient(authorizedClient, authentication, this.request, this.response);
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenAuthenticatedPrincipalThenRemoveFromService() {
+		Authentication authentication = this.createAuthenticatedPrincipal();
+		this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.request, this.response);
+		verify(this.authorizedClientService).removeAuthorizedClient(this.registrationId, this.principalName);
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenAnonymousPrincipalThenRemoveFromAnonymousRepository() {
+		Authentication authentication = this.createAnonymousPrincipal();
+		this.authorizedClientRepository.removeAuthorizedClient(this.registrationId, authentication, this.request, this.response);
+		verify(this.anonymousAuthorizedClientRepository).removeAuthorizedClient(this.registrationId, authentication, this.request, this.response);
+	}
+
+	private Authentication createAuthenticatedPrincipal() {
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken(this.principalName, "password");
+		authentication.setAuthenticated(true);
+		return authentication;
+	}
+
+	private Authentication createAnonymousPrincipal() {
+		return new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
+	}
+}

+ 261 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizedClientRepositoryTests.java

@@ -0,0 +1,261 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+
+import javax.servlet.http.HttpSession;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link HttpSessionOAuth2AuthorizedClientRepository}.
+ *
+ * @author Joe Grandja
+ */
+public class HttpSessionOAuth2AuthorizedClientRepositoryTests {
+	private String registrationId1 = "registration-1";
+	private String registrationId2 = "registration-2";
+	private String principalName1 = "principalName-1";
+
+	private ClientRegistration registration1 = ClientRegistration.withRegistrationId(this.registrationId1)
+			.clientId("client-1")
+			.clientSecret("secret")
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}")
+			.scope("user")
+			.authorizationUri("https://provider.com/oauth2/authorize")
+			.tokenUri("https://provider.com/oauth2/token")
+			.userInfoUri("https://provider.com/oauth2/user")
+			.userNameAttributeName("id")
+			.clientName("client-1")
+			.build();
+
+	private ClientRegistration registration2 = ClientRegistration.withRegistrationId(this.registrationId2)
+			.clientId("client-2")
+			.clientSecret("secret")
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}")
+			.scope("openid", "profile", "email")
+			.authorizationUri("https://provider.com/oauth2/authorize")
+			.tokenUri("https://provider.com/oauth2/token")
+			.userInfoUri("https://provider.com/oauth2/userinfo")
+			.jwkSetUri("https://provider.com/oauth2/keys")
+			.clientName("client-2")
+			.build();
+
+	private HttpSessionOAuth2AuthorizedClientRepository authorizedClientRepository =
+			new HttpSessionOAuth2AuthorizedClientRepository();
+
+	private MockHttpServletRequest request;
+
+	private MockHttpServletResponse response;
+
+	@Before
+	public void setup() {
+		this.request = new MockHttpServletRequest();
+		this.response = new MockHttpServletResponse();
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(null, null, this.request))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenPrincipalNameIsNullThenExceptionNotThrown() {
+		this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, this.request);
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenClientRegistrationNotFoundThenReturnNull() {
+		OAuth2AuthorizedClient authorizedClient =
+				this.authorizedClientRepository.loadAuthorizedClient("registration-not-found", null, this.request);
+		assertThat(authorizedClient).isNull();
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenSavedThenReturnAuthorizedClient() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration1, this.principalName1, mock(OAuth2AccessToken.class));
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response);
+
+		OAuth2AuthorizedClient loadedAuthorizedClient =
+				this.authorizedClientRepository.loadAuthorizedClient(this.registrationId1, null, this.request);
+		assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(null, null, this.request, this.response))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenAuthenticationIsNullThenExceptionNotThrown() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response);
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
+		assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, null, this.response))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenResponseIsNullThenThrowIllegalArgumentException() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
+		assertThatThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenSavedThenSavedToSession() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response);
+
+		HttpSession session = this.request.getSession(false);
+		assertThat(session).isNotNull();
+
+		@SuppressWarnings("unchecked")
+		Map<String, OAuth2AuthorizedClient> authorizedClients = (Map<String, OAuth2AuthorizedClient>)
+				session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS");
+		assertThat(authorizedClients).isNotEmpty();
+		assertThat(authorizedClients).hasSize(1);
+		assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient);
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient(
+				null, null, this.request, this.response)).isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenPrincipalNameIsNullThenExceptionNotThrown() {
+		this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, this.request, this.response);
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenRequestIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient(
+				this.registrationId1, null, null, this.response)).isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenResponseIsNullThenExceptionNotThrown() {
+		this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, this.request, null);
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenNotSavedThenSessionNotCreated() {
+		this.authorizedClientRepository.removeAuthorizedClient(
+				this.registrationId2, null, this.request, this.response);
+		assertThat(this.request.getSession(false)).isNull();
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenClient1SavedAndClient2RemovedThenClient1NotRemoved() {
+		OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient(
+				this.registration1, this.principalName1, mock(OAuth2AccessToken.class));
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.request, this.response);
+
+		// Remove registrationId2 (never added so is not removed either)
+		this.authorizedClientRepository.removeAuthorizedClient(
+				this.registrationId2, null, this.request, this.response);
+
+		OAuth2AuthorizedClient loadedAuthorizedClient1 = this.authorizedClientRepository.loadAuthorizedClient(
+				this.registrationId1, null, this.request);
+		assertThat(loadedAuthorizedClient1).isNotNull();
+		assertThat(loadedAuthorizedClient1).isSameAs(authorizedClient1);
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenSavedThenRemoved() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response);
+		OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
+				this.registrationId2, null, this.request);
+		assertThat(loadedAuthorizedClient).isSameAs(authorizedClient);
+		this.authorizedClientRepository.removeAuthorizedClient(
+				this.registrationId2, null, this.request, this.response);
+		loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
+				this.registrationId2, null, this.request);
+		assertThat(loadedAuthorizedClient).isNull();
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenSavedThenRemovedFromSession() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration1, this.principalName1, mock(OAuth2AccessToken.class));
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, this.request, this.response);
+		OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
+				this.registrationId1, null, this.request);
+		assertThat(loadedAuthorizedClient).isSameAs(authorizedClient);
+		this.authorizedClientRepository.removeAuthorizedClient(
+				this.registrationId1, null, this.request, this.response);
+
+		HttpSession session = this.request.getSession(false);
+		assertThat(session).isNotNull();
+		assertThat(session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS")).isNull();
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenClient1Client2SavedAndClient1RemovedThenClient2NotRemoved() {
+		OAuth2AuthorizedClient authorizedClient1 = new OAuth2AuthorizedClient(
+				this.registration1, this.principalName1, mock(OAuth2AccessToken.class));
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient1, null, this.request, this.response);
+
+		OAuth2AuthorizedClient authorizedClient2 = new OAuth2AuthorizedClient(
+				this.registration2, this.principalName1, mock(OAuth2AccessToken.class));
+		this.authorizedClientRepository.saveAuthorizedClient(authorizedClient2, null, this.request, this.response);
+
+		this.authorizedClientRepository.removeAuthorizedClient(
+				this.registrationId1, null, this.request, this.response);
+
+		OAuth2AuthorizedClient loadedAuthorizedClient2 = this.authorizedClientRepository.loadAuthorizedClient(
+				this.registrationId2, null, this.request);
+		assertThat(loadedAuthorizedClient2).isNotNull();
+		assertThat(loadedAuthorizedClient2).isSameAs(authorizedClient2);
+	}
+}

+ 61 - 8
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java

@@ -15,6 +15,7 @@
  */
  */
 package org.springframework.security.oauth2.client.web;
 package org.springframework.security.oauth2.client.web;
 
 
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runner.RunWith;
@@ -23,9 +24,11 @@ import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
 import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
@@ -51,6 +54,7 @@ import org.springframework.security.web.savedrequest.RequestCache;
 import javax.servlet.FilterChain;
 import javax.servlet.FilterChain;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Map;
 
 
@@ -71,12 +75,13 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 	private String principalName1 = "principal-1";
 	private String principalName1 = "principal-1";
 	private ClientRegistrationRepository clientRegistrationRepository;
 	private ClientRegistrationRepository clientRegistrationRepository;
 	private OAuth2AuthorizedClientService authorizedClientService;
 	private OAuth2AuthorizedClientService authorizedClientService;
+	private OAuth2AuthorizedClientRepository authorizedClientRepository;
 	private AuthenticationManager authenticationManager;
 	private AuthenticationManager authenticationManager;
 	private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
 	private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
 	private OAuth2AuthorizationCodeGrantFilter filter;
 	private OAuth2AuthorizationCodeGrantFilter filter;
 
 
 	@Before
 	@Before
-	public void setUp() {
+	public void setup() {
 		this.registration1 = ClientRegistration.withRegistrationId("registration-1")
 		this.registration1 = ClientRegistration.withRegistrationId("registration-1")
 			.clientId("client-1")
 			.clientId("client-1")
 			.clientSecret("secret")
 			.clientSecret("secret")
@@ -92,32 +97,39 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 			.build();
 			.build();
 		this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1);
 		this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1);
 		this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
 		this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
+		this.authorizedClientRepository = new AuthenticatedPrincipalOAuth2AuthorizedClientRepository(this.authorizedClientService);
 		this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
 		this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
 		this.authenticationManager = mock(AuthenticationManager.class);
 		this.authenticationManager = mock(AuthenticationManager.class);
 		this.filter = spy(new OAuth2AuthorizationCodeGrantFilter(
 		this.filter = spy(new OAuth2AuthorizationCodeGrantFilter(
-			this.clientRegistrationRepository, this.authorizedClientService, this.authenticationManager));
+			this.clientRegistrationRepository, this.authorizedClientRepository, this.authenticationManager));
 		this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
 		this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
-
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken(this.principalName1, "password");
+		authentication.setAuthenticated(true);
 		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
 		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
-		securityContext.setAuthentication(new TestingAuthenticationToken(this.principalName1, "password"));
+		securityContext.setAuthentication(authentication);
 		SecurityContextHolder.setContext(securityContext);
 		SecurityContextHolder.setContext(securityContext);
 	}
 	}
 
 
+	@After
+	public void cleanup() {
+		SecurityContextHolder.clearContext();
+	}
+
 	@Test
 	@Test
 	public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
 	public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(null, this.authorizedClientService, this.authenticationManager))
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(null, this.authorizedClientRepository, this.authenticationManager))
 				.isInstanceOf(IllegalArgumentException.class);
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 	}
 
 
 	@Test
 	@Test
-	public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
+	public void constructorWhenAuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, null, this.authenticationManager))
 		assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, null, this.authenticationManager))
 				.isInstanceOf(IllegalArgumentException.class);
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 	}
 
 
 	@Test
 	@Test
 	public void constructorWhenAuthenticationManagerIsNullThenThrowIllegalArgumentException() {
 	public void constructorWhenAuthenticationManagerIsNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, this.authorizedClientService, null))
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, this.authorizedClientRepository, null))
 				.isInstanceOf(IllegalArgumentException.class);
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 	}
 
 
@@ -218,7 +230,7 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 	}
 	}
 
 
 	@Test
 	@Test
-	public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSaved() throws Exception {
+	public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSavedToService() throws Exception {
 		String requestUri = "/callback/client-1";
 		String requestUri = "/callback/client-1";
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		request.setServletPath(requestUri);
@@ -285,6 +297,47 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request");
 		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request");
 	}
 	}
 
 
+	@Test
+	public void doFilterWhenAuthorizationResponseSuccessAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception {
+		AnonymousAuthenticationToken anonymousPrincipal =
+				new AnonymousAuthenticationToken("key-1234", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		securityContext.setAuthentication(anonymousPrincipal);
+		SecurityContextHolder.setContext(securityContext);
+
+		String requestUri = "/callback/client-1";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(OAuth2ParameterNames.STATE, "state");
+
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.setUpAuthorizationRequest(request, response, this.registration1);
+		this.setUpAuthenticationResult(this.registration1);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
+				this.registration1.getRegistrationId(), anonymousPrincipal, request);
+		assertThat(authorizedClient).isNotNull();
+
+		assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
+		assertThat(authorizedClient.getPrincipalName()).isEqualTo(anonymousPrincipal.getName());
+		assertThat(authorizedClient.getAccessToken()).isNotNull();
+
+		HttpSession session = request.getSession(false);
+		assertThat(session).isNotNull();
+
+		@SuppressWarnings("unchecked")
+		Map<String, OAuth2AuthorizedClient> authorizedClients = (Map<String, OAuth2AuthorizedClient>)
+				session.getAttribute(HttpSessionOAuth2AuthorizedClientRepository.class.getName() + ".AUTHORIZED_CLIENTS");
+		assertThat(authorizedClients).isNotEmpty();
+		assertThat(authorizedClients).hasSize(1);
+		assertThat(authorizedClients.values().iterator().next()).isSameAs(authorizedClient);
+	}
+
 	private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
 	private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
 											ClientRegistration registration) {
 											ClientRegistration registration) {
 		Map<String, Object> additionalParameters = new HashMap<>();
 		Map<String, Object> additionalParameters = new HashMap<>();

+ 24 - 21
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java

@@ -15,19 +15,23 @@
  */
  */
 package org.springframework.security.oauth2.client.web.method.annotation;
 package org.springframework.security.oauth2.client.web.method.annotation;
 
 
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
 import org.springframework.core.MethodParameter;
 import org.springframework.core.MethodParameter;
+import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
-import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.util.ReflectionUtils;
 import org.springframework.util.ReflectionUtils;
+import org.springframework.web.context.request.ServletWebRequest;
 
 
+import javax.servlet.http.HttpServletRequest;
 import java.lang.reflect.Method;
 import java.lang.reflect.Method;
 
 
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
@@ -43,21 +47,29 @@ import static org.mockito.Mockito.when;
  * @author Joe Grandja
  * @author Joe Grandja
  */
  */
 public class OAuth2AuthorizedClientArgumentResolverTests {
 public class OAuth2AuthorizedClientArgumentResolverTests {
-	private OAuth2AuthorizedClientService authorizedClientService;
+	private OAuth2AuthorizedClientRepository authorizedClientRepository;
 	private OAuth2AuthorizedClientArgumentResolver argumentResolver;
 	private OAuth2AuthorizedClientArgumentResolver argumentResolver;
 	private OAuth2AuthorizedClient authorizedClient;
 	private OAuth2AuthorizedClient authorizedClient;
+	private MockHttpServletRequest request;
 
 
 	@Before
 	@Before
-	public void setUp() {
-		this.authorizedClientService = mock(OAuth2AuthorizedClientService.class);
-		this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientService);
+	public void setup() {
+		this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
+		this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository);
 		this.authorizedClient = mock(OAuth2AuthorizedClient.class);
 		this.authorizedClient = mock(OAuth2AuthorizedClient.class);
-		when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(this.authorizedClient);
+		this.request = new MockHttpServletRequest();
+		when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class)))
+				.thenReturn(this.authorizedClient);
 		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
 		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
 		securityContext.setAuthentication(mock(Authentication.class));
 		securityContext.setAuthentication(mock(Authentication.class));
 		SecurityContextHolder.setContext(securityContext);
 		SecurityContextHolder.setContext(securityContext);
 	}
 	}
 
 
+	@After
+	public void cleanup() {
+		SecurityContextHolder.clearContext();
+	}
+
 	@Test
 	@Test
 	public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
 	public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null))
 		assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null))
@@ -104,31 +116,22 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
 		securityContext.setAuthentication(authentication);
 		securityContext.setAuthentication(authentication);
 		SecurityContextHolder.setContext(securityContext);
 		SecurityContextHolder.setContext(securityContext);
 		MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
 		MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
-		this.argumentResolver.resolveArgument(methodParameter, null, null, null);
-	}
-
-	@Test
-	public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenThrowIllegalStateException() throws Exception {
-		SecurityContextHolder.clearContext();
-		MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
-		assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null))
-				.isInstanceOf(IllegalStateException.class)
-				.hasMessage("Unable to resolve the Authorized Client with registration identifier \"client1\". " +
-						"An \"authenticated\" or \"unauthenticated\" session is required. " +
-						"To allow for unauthenticated access, ensure HttpSecurity.anonymous() is configured.");
+		this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request), null);
 	}
 	}
 
 
 	@Test
 	@Test
 	public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() throws Exception {
 	public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() throws Exception {
 		MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
 		MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
-		assertThat(this.argumentResolver.resolveArgument(methodParameter, null, null, null)).isSameAs(this.authorizedClient);
+		assertThat(this.argumentResolver.resolveArgument(
+				methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient);
 	}
 	}
 
 
 	@Test
 	@Test
 	public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() throws Exception {
 	public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() throws Exception {
-		when(this.authorizedClientService.loadAuthorizedClient(anyString(), any())).thenReturn(null);
+		when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class)))
+				.thenReturn(null);
 		MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
 		MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
-		assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null))
+		assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request), null))
 				.isInstanceOf(ClientAuthorizationRequiredException.class);
 				.isInstanceOf(ClientAuthorizationRequiredException.class);
 	}
 	}