Jelajahi Sumber

Add OAuth2UserAuthenticationProvider

Moved logic from AuthorizationCodeAuthenticationProvider
to OAuth2UserAuthenticationProvider (new) related to
loading user attributes via OAuth2UserService.

This re-factor is part of the work required for Issue gh-4513
Joe Grandja 8 tahun lalu
induk
melakukan
5c14e48b18

+ 13 - 5
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java

@@ -23,6 +23,7 @@ import org.springframework.security.oauth2.client.authentication.AuthorizationCo
 import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticator;
 import org.springframework.security.oauth2.client.authentication.AuthorizationGrantAuthenticator;
 import org.springframework.security.oauth2.client.authentication.DelegatingAuthorizationGrantAuthenticator;
+import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationProvider;
 import org.springframework.security.oauth2.client.authentication.jwt.JwtDecoderRegistry;
 import org.springframework.security.oauth2.client.authentication.jwt.nimbus.NimbusJwtDecoderRegistry;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
@@ -128,13 +129,20 @@ final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecuri
 
 	@Override
 	public void init(H http) throws Exception {
-		AuthorizationCodeAuthenticationProvider authenticationProvider = new AuthorizationCodeAuthenticationProvider(
-			this.getAuthorizationCodeAuthenticator(), this.getAccessTokenRepository(), this.getUserInfoService());
+		AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider =
+			new AuthorizationCodeAuthenticationProvider(
+				this.getAuthorizationCodeAuthenticator(), this.getAccessTokenRepository());
+		authorizationCodeAuthenticationProvider = this.postProcess(authorizationCodeAuthenticationProvider);
+		http.authenticationProvider(authorizationCodeAuthenticationProvider);
+
+		OAuth2UserAuthenticationProvider oauth2UserAuthenticationProvider =
+			new OAuth2UserAuthenticationProvider(this.getUserInfoService());
 		if (this.userAuthoritiesMapper != null) {
-			authenticationProvider.setAuthoritiesMapper(this.userAuthoritiesMapper);
+			oauth2UserAuthenticationProvider.setAuthoritiesMapper(this.userAuthoritiesMapper);
 		}
-		authenticationProvider = this.postProcess(authenticationProvider);
-		http.authenticationProvider(authenticationProvider);
+		oauth2UserAuthenticationProvider = this.postProcess(oauth2UserAuthenticationProvider);
+		http.authenticationProvider(oauth2UserAuthenticationProvider);
+
 		super.init(http);
 	}
 

+ 3 - 53
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/AuthorizationCodeAuthenticationProvider.java

@@ -15,57 +15,32 @@
  */
 package org.springframework.security.oauth2.client.authentication;
 
-import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
-import org.springframework.security.core.GrantedAuthority;
-import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
 import org.springframework.security.oauth2.client.token.SecurityTokenRepository;
-import org.springframework.security.oauth2.client.user.OAuth2UserService;
 import org.springframework.security.oauth2.core.AccessToken;
-import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken;
-import org.springframework.security.oauth2.oidc.client.authentication.OidcUserAuthenticationToken;
-import org.springframework.security.oauth2.oidc.core.user.OidcUser;
 import org.springframework.util.Assert;
 
-import java.util.Collection;
-
 /**
  * An implementation of an {@link AuthenticationProvider} that is responsible for authenticating
  * an <i>authorization code</i> credential with the authorization server's <i>Token Endpoint</i>
  * and if valid, exchanging it for an <i>access token</i> credential and optionally an
  * <i>id token</i> credential (for OpenID Connect Authorization Code Flow).
- * Additionally, it will also obtain the end-user's (resource owner) attributes from the <i>UserInfo Endpoint</i>
- * (using the <i>access token</i>) and create a <code>Principal</code> in the form of an {@link OAuth2User}
- * associating it with the returned {@link OAuth2UserAuthenticationToken}.
  *
  * <p>
  * The {@link AuthorizationCodeAuthenticationProvider} uses an {@link AuthorizationGrantAuthenticator}
  * to authenticate the {@link AuthorizationCodeAuthenticationToken#getAuthorizationCode()} and ultimately
  * return an <i>&quot;Authorized Client&quot;</i> as an {@link OAuth2ClientAuthenticationToken}.
  *
- * <p>
- * It will then call {@link OAuth2UserService#loadUser(OAuth2ClientAuthenticationToken)}
- * to obtain the end-user's (resource owner) attributes in the form of an {@link OAuth2User}.
- *
- * <p>
- * Finally, it will create an {@link OAuth2UserAuthenticationToken}, associating the {@link OAuth2User}
- * and {@link OAuth2ClientAuthenticationToken} and return it to the {@link AuthenticationManager},
- * at which point the {@link OAuth2UserAuthenticationToken} is considered <i>&quot;authenticated&quot;</i>.
- *
  * @author Joe Grandja
  * @since 5.0
  * @see AuthorizationCodeAuthenticationToken
  * @see OAuth2ClientAuthenticationToken
  * @see OidcClientAuthenticationToken
- * @see OAuth2UserAuthenticationToken
- * @see OidcUserAuthenticationToken
  * @see AuthorizationGrantAuthenticator
- * @see OAuth2UserService
- * @see OAuth2User
- * @see OidcUser
+ * @see SecurityTokenRepository
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant Flow</a>
  * @see <a target="_blank" href="http://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth">Section 3.1 OpenID Connect Authorization Code Flow</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a>
@@ -75,20 +50,15 @@ import java.util.Collection;
 public class AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
 	private final AuthorizationGrantAuthenticator<AuthorizationCodeAuthenticationToken> authorizationCodeAuthenticator;
 	private final SecurityTokenRepository<AccessToken> accessTokenRepository;
-	private final OAuth2UserService userService;
-	private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
 
 	public AuthorizationCodeAuthenticationProvider(
 			AuthorizationGrantAuthenticator<AuthorizationCodeAuthenticationToken> authorizationCodeAuthenticator,
-			SecurityTokenRepository<AccessToken> accessTokenRepository,
-			OAuth2UserService userService) {
+			SecurityTokenRepository<AccessToken> accessTokenRepository) {
 
 		Assert.notNull(authorizationCodeAuthenticator, "authorizationCodeAuthenticator cannot be null");
 		Assert.notNull(accessTokenRepository, "accessTokenRepository cannot be null");
-		Assert.notNull(userService, "userService cannot be null");
 		this.authorizationCodeAuthenticator = authorizationCodeAuthenticator;
 		this.accessTokenRepository = accessTokenRepository;
-		this.userService = userService;
 	}
 
 	@Override
@@ -103,27 +73,7 @@ public class AuthorizationCodeAuthenticationProvider implements AuthenticationPr
 			oauth2ClientAuthentication.getAccessToken(),
 			oauth2ClientAuthentication.getClientRegistration());
 
-		OAuth2User oauth2User = this.userService.loadUser(oauth2ClientAuthentication);
-
-		Collection<? extends GrantedAuthority> mappedAuthorities =
-				this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
-
-		OAuth2UserAuthenticationToken oauth2UserAuthentication;
-		if (OidcUser.class.isAssignableFrom(oauth2User.getClass())) {
-			oauth2UserAuthentication = new OidcUserAuthenticationToken(
-				(OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)oauth2ClientAuthentication);
-		} else {
-			oauth2UserAuthentication = new OAuth2UserAuthenticationToken(
-				oauth2User, mappedAuthorities, oauth2ClientAuthentication);
-		}
-		oauth2UserAuthentication.setDetails(oauth2ClientAuthentication.getDetails());
-
-		return oauth2UserAuthentication;
-	}
-
-	public final void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
-		Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null");
-		this.authoritiesMapper = authoritiesMapper;
+		return oauth2ClientAuthentication;
 	}
 
 	@Override

+ 99 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationProvider.java

@@ -0,0 +1,99 @@
+/*
+ * Copyright 2012-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.authentication;
+
+import org.springframework.security.authentication.AuthenticationProvider;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
+import org.springframework.security.oauth2.client.user.DefaultOAuth2UserService;
+import org.springframework.security.oauth2.client.user.OAuth2UserService;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken;
+import org.springframework.security.oauth2.oidc.client.authentication.OidcUserAuthenticationToken;
+import org.springframework.security.oauth2.oidc.client.user.OidcUserService;
+import org.springframework.security.oauth2.oidc.core.user.OidcUser;
+import org.springframework.util.Assert;
+
+import java.util.Collection;
+
+/**
+ * An implementation of an {@link AuthenticationProvider} that is responsible
+ * for obtaining the user attributes of the <i>End-User</i> (resource owner)
+ * from the <i>UserInfo Endpoint</i> and creating a <code>Principal</code>
+ * in the form of an {@link OAuth2User}.
+ *
+ * <p>
+ * The {@link OAuth2UserAuthenticationProvider} uses an {@link OAuth2UserService}
+ * for loading the {@link OAuth2User} and then associating it
+ * to the returned {@link OAuth2UserAuthenticationToken}.
+ *
+ * @author Joe Grandja
+ * @since 5.0
+ * @see OAuth2UserAuthenticationToken
+ * @see OidcUserAuthenticationToken
+ * @see OAuth2ClientAuthenticationToken
+ * @see OidcClientAuthenticationToken
+ * @see OAuth2UserService
+ * @see DefaultOAuth2UserService
+ * @see OidcUserService
+ * @see OAuth2User
+ * @see OidcUser
+ */
+public class OAuth2UserAuthenticationProvider implements AuthenticationProvider {
+	private final OAuth2UserService userService;
+	private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
+
+	public OAuth2UserAuthenticationProvider(OAuth2UserService userService) {
+		Assert.notNull(userService, "userService cannot be null");
+		this.userService = userService;
+	}
+
+	@Override
+	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
+		OAuth2UserAuthenticationToken oauth2UserAuthentication = (OAuth2UserAuthenticationToken) authentication;
+
+		OAuth2ClientAuthenticationToken oauth2ClientAuthentication = oauth2UserAuthentication.getClientAuthentication();
+
+		OAuth2User oauth2User = this.userService.loadUser(oauth2ClientAuthentication);
+
+		Collection<? extends GrantedAuthority> mappedAuthorities =
+				this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
+
+		OAuth2UserAuthenticationToken authenticationResult;
+		if (OidcUser.class.isAssignableFrom(oauth2User.getClass())) {
+			authenticationResult = new OidcUserAuthenticationToken(
+				(OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)oauth2ClientAuthentication);
+		} else {
+			authenticationResult = new OAuth2UserAuthenticationToken(
+				oauth2User, mappedAuthorities, oauth2ClientAuthentication);
+		}
+		authenticationResult.setDetails(oauth2ClientAuthentication.getDetails());
+
+		return authenticationResult;
+	}
+
+	public final void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
+		Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null");
+		this.authoritiesMapper = authoritiesMapper;
+	}
+
+	@Override
+	public boolean supports(Class<?> authentication) {
+		return OAuth2UserAuthenticationToken.class.isAssignableFrom(authentication);
+	}
+}

+ 6 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationToken.java

@@ -19,6 +19,7 @@ import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.SpringSecurityCoreVersion;
+import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.util.Assert;
 
@@ -42,14 +43,17 @@ public class OAuth2UserAuthenticationToken extends AbstractAuthenticationToken {
 	private final OAuth2User principal;
 	private final OAuth2ClientAuthenticationToken clientAuthentication;
 
+	public OAuth2UserAuthenticationToken(OAuth2ClientAuthenticationToken clientAuthentication) {
+		this(null, AuthorityUtils.NO_AUTHORITIES, clientAuthentication);
+	}
+
 	public OAuth2UserAuthenticationToken(OAuth2User principal, Collection<? extends GrantedAuthority> authorities,
 											OAuth2ClientAuthenticationToken clientAuthentication) {
 		super(authorities);
-		Assert.notNull(principal, "principal cannot be null");
 		Assert.notNull(clientAuthentication, "clientAuthentication cannot be null");
 		this.principal = principal;
 		this.clientAuthentication = clientAuthentication;
-		this.setAuthenticated(true);
+		this.setAuthenticated(principal != null);
 	}
 
 	@Override

+ 82 - 4
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilter.java

@@ -18,10 +18,14 @@ package org.springframework.security.oauth2.client.web;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationProvider;
 import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken;
+import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.converter.AuthorizationCodeAuthorizationResponseAttributesConverter;
 import org.springframework.security.oauth2.client.web.converter.ErrorResponseAttributesConverter;
@@ -30,6 +34,10 @@ import org.springframework.security.oauth2.core.endpoint.AuthorizationCodeAuthor
 import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes;
 import org.springframework.security.oauth2.core.endpoint.ErrorResponseAttributes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken;
+import org.springframework.security.oauth2.oidc.client.authentication.OidcUserAuthenticationToken;
+import org.springframework.security.oauth2.oidc.core.user.OidcUser;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
@@ -84,6 +92,7 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut
 	private ClientRegistrationRepository clientRegistrationRepository;
 	private RequestMatcher authorizationResponseMatcher = new AuthorizationResponseMatcher();
 	private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
+	private final ClientRegistrationIdentifierStrategy<String> providerIdentifierStrategy = new ProviderIdentifierStrategy();
 
 	public AuthorizationCodeAuthenticationProcessingFilter() {
 		super(new AuthorizationResponseMatcher());
@@ -119,14 +128,27 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut
 		AuthorizationCodeAuthorizationResponseAttributes authorizationCodeResponseAttributes =
 				this.authorizationCodeResponseConverter.apply(request);
 
-		AuthorizationCodeAuthenticationToken authRequest = new AuthorizationCodeAuthenticationToken(
+		AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = new AuthorizationCodeAuthenticationToken(
 				authorizationCodeResponseAttributes.getCode(), clientRegistration, matchingAuthorizationRequest);
+		authorizationCodeAuthentication.setDetails(this.authenticationDetailsSource.buildDetails(request));
 
-		authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request));
+		OAuth2ClientAuthenticationToken oauth2ClientAuthentication =
+			(OAuth2ClientAuthenticationToken)this.getAuthenticationManager().authenticate(authorizationCodeAuthentication);
 
-		Authentication authenticated = this.getAuthenticationManager().authenticate(authRequest);
+		OAuth2UserAuthenticationToken oauth2UserAuthentication;
+		if (this.authenticated() && this.authenticatedSameProviderAs(oauth2ClientAuthentication)) {
+			// Create a new user authentication (using same principal)
+			// but with a different client authentication association
+			oauth2UserAuthentication = (OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
+			oauth2UserAuthentication = this.createUserAuthentication(oauth2UserAuthentication, oauth2ClientAuthentication);
+		} else {
+			// Authenticate the user... the user needs to be authenticated
+			// before we can associate the client authentication to the user
+			oauth2UserAuthentication = (OAuth2UserAuthenticationToken)this.getAuthenticationManager().authenticate(
+				this.createUserAuthentication(oauth2ClientAuthentication));
+		}
 
-		return authenticated;
+		return oauth2UserAuthentication;
 	}
 
 	public RequestMatcher getAuthorizationResponseMatcher() {
@@ -182,6 +204,50 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut
 		}
 	}
 
+	private boolean authenticated() {
+		Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
+		return currentAuthentication != null &&
+			currentAuthentication instanceof OAuth2UserAuthenticationToken &&
+			currentAuthentication.isAuthenticated();
+	}
+
+	private boolean authenticatedSameProviderAs(OAuth2ClientAuthenticationToken oauth2ClientAuthentication) {
+		OAuth2UserAuthenticationToken userAuthentication =
+			(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
+
+		String userProviderId = this.providerIdentifierStrategy.getIdentifier(
+			userAuthentication.getClientAuthentication().getClientRegistration());
+		String clientProviderId = this.providerIdentifierStrategy.getIdentifier(
+			oauth2ClientAuthentication.getClientRegistration());
+
+		return userProviderId.equals(clientProviderId);
+	}
+
+	private OAuth2UserAuthenticationToken createUserAuthentication(OAuth2ClientAuthenticationToken clientAuthentication) {
+		if (OidcClientAuthenticationToken.class.isAssignableFrom(clientAuthentication.getClass())) {
+			return new OidcUserAuthenticationToken((OidcClientAuthenticationToken)clientAuthentication);
+		} else {
+			return new OAuth2UserAuthenticationToken(clientAuthentication);
+		}
+	}
+
+	private OAuth2UserAuthenticationToken createUserAuthentication(
+		OAuth2UserAuthenticationToken currentUserAuthentication,
+		OAuth2ClientAuthenticationToken newClientAuthentication) {
+
+		if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) {
+			return new OidcUserAuthenticationToken(
+				(OidcUser) currentUserAuthentication.getPrincipal(),
+				currentUserAuthentication.getAuthorities(),
+				newClientAuthentication);
+		} else {
+			return new OAuth2UserAuthenticationToken(
+				(OAuth2User)currentUserAuthentication.getPrincipal(),
+				currentUserAuthentication.getAuthorities(),
+				newClientAuthentication);
+		}
+	}
+
 	private static class AuthorizationResponseMatcher implements RequestMatcher {
 
 		@Override
@@ -199,4 +265,16 @@ public class AuthorizationCodeAuthenticationProcessingFilter extends AbstractAut
 				StringUtils.hasText(request.getParameter(OAuth2Parameter.STATE));
 		}
 	}
+
+	private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy<String> {
+
+		@Override
+		public String getIdentifier(ClientRegistration clientRegistration) {
+			StringBuilder builder = new StringBuilder();
+			builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]");
+			builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]");
+			builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]");
+			return builder.toString();
+		}
+	}
 }

+ 11 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/oidc/client/authentication/OidcUserAuthenticationToken.java

@@ -17,6 +17,8 @@ package org.springframework.security.oauth2.oidc.client.authentication;
 
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken;
 import org.springframework.security.oauth2.oidc.core.user.OidcUser;
 
@@ -38,8 +40,17 @@ import java.util.Collection;
  */
 public class OidcUserAuthenticationToken extends OAuth2UserAuthenticationToken {
 
+	public OidcUserAuthenticationToken(OidcClientAuthenticationToken clientAuthentication) {
+		this(null, AuthorityUtils.NO_AUTHORITIES, clientAuthentication);
+	}
+
 	public OidcUserAuthenticationToken(OidcUser principal, Collection<? extends GrantedAuthority> authorities,
 										OidcClientAuthenticationToken clientAuthentication) {
+		this(principal, authorities, (OAuth2ClientAuthenticationToken)clientAuthentication);
+	}
+
+	public OidcUserAuthenticationToken(OidcUser principal, Collection<? extends GrantedAuthority> authorities,
+										OAuth2ClientAuthenticationToken clientAuthentication) {
 		super(principal, authorities, clientAuthentication);
 	}
 }

+ 28 - 18
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationProcessingFilterTests.java

@@ -23,15 +23,20 @@ import org.mockito.Mockito;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AuthenticationManager;
-import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken;
+import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.core.AccessToken;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.AuthorizationRequestAttributes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
+import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 
@@ -41,6 +46,8 @@ import javax.servlet.http.HttpServletResponse;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.mockito.Mockito.mock;
+
 /**
  * Tests {@link AuthorizationCodeAuthenticationProcessingFilter}.
  *
@@ -58,7 +65,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestURI);
 		request.setServletPath(requestURI);
 		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = Mockito.mock(FilterChain.class);
+		FilterChain filterChain = mock(FilterChain.class);
 
 		filter.doFilter(request, response, filterChain);
 
@@ -71,7 +78,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 		ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
 
 		AuthorizationCodeAuthenticationProcessingFilter filter = Mockito.spy(setupFilter(clientRegistration));
-		AuthenticationFailureHandler failureHandler = Mockito.mock(AuthenticationFailureHandler.class);
+		AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
 		filter.setAuthenticationFailureHandler(failureHandler);
 
 		MockHttpServletRequest request = this.setupRequest(clientRegistration);
@@ -79,7 +86,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 		request.addParameter(OAuth2Parameter.ERROR, errorCode);
 		request.addParameter(OAuth2Parameter.STATE, "some state");
 		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = Mockito.mock(FilterChain.class);
+		FilterChain filterChain = mock(FilterChain.class);
 
 		filter.doFilter(request, response, filterChain);
 
@@ -90,14 +97,17 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 
 	@Test
 	public void doFilterWhenAuthorizationCodeSuccessResponseThenAuthenticationSuccessHandlerIsCalled() throws Exception {
-		TestingAuthenticationToken authentication = new TestingAuthenticationToken("joe", "password", "user", "admin");
-		AuthenticationManager authenticationManager = Mockito.mock(AuthenticationManager.class);
-		Mockito.when(authenticationManager.authenticate(Matchers.any(Authentication.class))).thenReturn(authentication);
-
 		ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
+		OAuth2ClientAuthenticationToken clientAuthentication = new OAuth2ClientAuthenticationToken(
+			clientRegistration, mock(AccessToken.class));
+		OAuth2UserAuthenticationToken userAuthentication = new OAuth2UserAuthenticationToken(
+			mock(OAuth2User.class), AuthorityUtils.createAuthorityList("ROLE_USER"), clientAuthentication);
+		SecurityContextHolder.getContext().setAuthentication(userAuthentication);
+		AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
+		Mockito.when(authenticationManager.authenticate(Matchers.any(Authentication.class))).thenReturn(clientAuthentication);
 
 		AuthorizationCodeAuthenticationProcessingFilter filter = Mockito.spy(setupFilter(authenticationManager, clientRegistration));
-		AuthenticationSuccessHandler successHandler = Mockito.mock(AuthenticationSuccessHandler.class);
+		AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class);
 		filter.setAuthenticationSuccessHandler(successHandler);
 		AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
 		filter.setAuthorizationRequestRepository(authorizationRequestRepository);
@@ -109,7 +119,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 		request.addParameter(OAuth2Parameter.STATE, state);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state);
-		FilterChain filterChain = Mockito.mock(FilterChain.class);
+		FilterChain filterChain = mock(FilterChain.class);
 
 		filter.doFilter(request, response, filterChain);
 
@@ -118,7 +128,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 		ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
 		Mockito.verify(successHandler).onAuthenticationSuccess(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
 				authenticationArgCaptor.capture());
-		Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(authentication);
+		Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(userAuthentication);
 	}
 
 	@Test
@@ -126,7 +136,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 		ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
 
 		AuthorizationCodeAuthenticationProcessingFilter filter = Mockito.spy(setupFilter(clientRegistration));
-		AuthenticationFailureHandler failureHandler = Mockito.mock(AuthenticationFailureHandler.class);
+		AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
 		filter.setAuthenticationFailureHandler(failureHandler);
 
 		MockHttpServletRequest request = this.setupRequest(clientRegistration);
@@ -135,7 +145,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 		request.addParameter(OAuth2Parameter.CODE, authCode);
 		request.addParameter(OAuth2Parameter.STATE, state);
 		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = Mockito.mock(FilterChain.class);
+		FilterChain filterChain = mock(FilterChain.class);
 
 		filter.doFilter(request, response, filterChain);
 
@@ -147,7 +157,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 		ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
 
 		AuthorizationCodeAuthenticationProcessingFilter filter = Mockito.spy(setupFilter(clientRegistration));
-		AuthenticationFailureHandler failureHandler = Mockito.mock(AuthenticationFailureHandler.class);
+		AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
 		filter.setAuthenticationFailureHandler(failureHandler);
 		AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
 		filter.setAuthorizationRequestRepository(authorizationRequestRepository);
@@ -159,7 +169,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 		request.addParameter(OAuth2Parameter.STATE, state);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, "some state");
-		FilterChain filterChain = Mockito.mock(FilterChain.class);
+		FilterChain filterChain = mock(FilterChain.class);
 
 		filter.doFilter(request, response, filterChain);
 
@@ -171,7 +181,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 		ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
 
 		AuthorizationCodeAuthenticationProcessingFilter filter = Mockito.spy(setupFilter(clientRegistration));
-		AuthenticationFailureHandler failureHandler = Mockito.mock(AuthenticationFailureHandler.class);
+		AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
 		filter.setAuthenticationFailureHandler(failureHandler);
 		AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
 		filter.setAuthorizationRequestRepository(authorizationRequestRepository);
@@ -184,7 +194,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 		request.addParameter(OAuth2Parameter.STATE, state);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state);
-		FilterChain filterChain = Mockito.mock(FilterChain.class);
+		FilterChain filterChain = mock(FilterChain.class);
 
 		filter.doFilter(request, response, filterChain);
 
@@ -209,7 +219,7 @@ public class AuthorizationCodeAuthenticationProcessingFilterTests {
 	}
 
 	private AuthorizationCodeAuthenticationProcessingFilter setupFilter(ClientRegistration... clientRegistrations) throws Exception {
-		AuthenticationManager authenticationManager = Mockito.mock(AuthenticationManager.class);
+		AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
 
 		return setupFilter(authenticationManager, clientRegistrations);
 	}

+ 2 - 0
samples/boot/oauth2login/src/main/resources/META-INF/oauth2-clients-defaults.yml

@@ -10,6 +10,7 @@ security:
           authorization-uri: "https://accounts.google.com/o/oauth2/v2/auth"
           token-uri: "https://www.googleapis.com/oauth2/v4/token"
           user-info-uri: "https://www.googleapis.com/oauth2/v3/userinfo"
+          user-name-attribute-name: "sub"
           jwk-set-uri: "https://www.googleapis.com/oauth2/v3/certs"
           client-name: Google
         github:
@@ -38,3 +39,4 @@ security:
           redirect-uri: "{scheme}://{serverName}:{serverPort}{contextPath}/oauth2/authorize/code/{registrationId}"
           scope: openid, profile, email, address, phone
           client-name: Okta
+          user-name-attribute-name: "sub"