Bladeren bron

Add tests to oauth2-client

Fixes gh-4299
Joe Grandja 7 jaren geleden
bovenliggende
commit
473ac0e37c
32 gewijzigde bestanden met toevoegingen van 3328 en 264 verwijderingen
  1. 4 0
      oauth2/oauth2-client/spring-security-oauth2-client.gradle
  2. 3 1
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java
  3. 1 1
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java
  4. 3 2
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/NimbusUserInfoResponseClient.java
  5. 2 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java
  6. 1 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java
  7. 2 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java
  8. 1 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserService.java
  9. 3 2
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/NimbusUserInfoResponseClient.java
  10. 5 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java
  11. 2 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestUriBuilder.java
  12. 188 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java
  13. 72 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientTests.java
  14. 77 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthenticationTokenTests.java
  15. 215 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java
  16. 131 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationTokenTests.java
  17. 300 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java
  18. 66 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestTests.java
  19. 414 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java
  20. 73 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java
  21. 260 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java
  22. 354 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java
  23. 3 1
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java
  24. 267 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java
  25. 197 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java
  26. 84 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserServiceTests.java
  27. 63 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java
  28. 138 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java
  29. 202 52
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java
  30. 7 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestUriBuilderTests.java
  31. 190 134
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java
  32. 0 71
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/TestUtil.java

+ 4 - 0
oauth2/oauth2-client/spring-security-oauth2-client.gradle

@@ -9,5 +9,9 @@ dependencies {
 
 	optional project(':spring-security-oauth2-jose')
 
+	testCompile powerMock2Dependencies
+	testCompile 'com.squareup.okhttp3:mockwebserver'
+	testCompile 'com.fasterxml.jackson.core:jackson-databind'
+
 	provided 'javax.servlet:javax.servlet-api'
 }

+ 3 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java

@@ -100,7 +100,9 @@ public class NimbusAuthorizationCodeTokenResponseClient implements OAuth2AccessT
 			httpRequest.setReadTimeout(30000);
 			tokenResponse = com.nimbusds.oauth2.sdk.TokenResponse.parse(httpRequest.send());
 		} catch (ParseException pe) {
-			throw new OAuth2AuthenticationException(new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE), pe);
+			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
+				"An error occurred parsing the Access Token response: " + pe.getMessage(), null);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), pe);
 		} catch (IOException ioe) {
 			throw new AuthenticationServiceException("An error occurred while sending the Access Token Request: " +
 					ioe.getMessage(), ioe);

+ 1 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java

@@ -262,7 +262,7 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati
 		// 10. The iat Claim can be used to reject tokens that were issued too far away from the current time,
 		// limiting the amount of time that nonces need to be stored to prevent attacks.
 		// The acceptable range is Client specific.
-		Instant maxIssuedAt = now.plusSeconds(30);
+		Instant maxIssuedAt = Instant.now().plusSeconds(30);
 		if (issuedAt.isAfter(maxIssuedAt)) {
 			this.throwInvalidIdTokenException();
 		}

+ 3 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/NimbusUserInfoResponseClient.java

@@ -27,6 +27,7 @@ import org.springframework.http.HttpHeaders;
 import org.springframework.http.client.AbstractClientHttpResponse;
 import org.springframework.http.client.ClientHttpResponse;
 import org.springframework.http.converter.GenericHttpMessageConverter;
+import org.springframework.http.converter.HttpMessageNotReadableException;
 import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
 import org.springframework.security.authentication.AuthenticationServiceException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -57,7 +58,7 @@ final class NimbusUserInfoResponseClient {
 			userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken());
 		try {
 			return (T) this.genericHttpMessageConverter.read(returnType, userInfoResponse);
-		} catch (IOException ex) {
+		} catch (IOException | HttpMessageNotReadableException ex) {
 			OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
 				"An error occurred reading the UserInfo Success response: " + ex.getMessage(), null);
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
@@ -69,7 +70,7 @@ final class NimbusUserInfoResponseClient {
 			userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken());
 		try {
 			return (T) this.genericHttpMessageConverter.read(typeReference.getType(), null, userInfoResponse);
-		} catch (IOException ex) {
+		} catch (IOException | HttpMessageNotReadableException ex) {
 			OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
 				"An error occurred reading the UserInfo Success response: " + ex.getMessage(), null);
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);

+ 2 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java

@@ -26,6 +26,7 @@ import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
 import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
+import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 
 import java.util.Arrays;
@@ -52,6 +53,7 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 
 	@Override
 	public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
+		Assert.notNull(userRequest, "userRequest cannot be null");
 		OidcUserInfo userInfo = null;
 		if (this.shouldRetrieveUserInfo(userRequest)) {
 			ParameterizedTypeReference<Map<String, Object>> typeReference =

+ 1 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java

@@ -49,6 +49,7 @@ public class CustomUserTypesOAuth2UserService implements OAuth2UserService<OAuth
 
 	@Override
 	public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
+		Assert.notNull(userRequest, "userRequest cannot be null");
 		String registrationId = userRequest.getClientRegistration().getRegistrationId();
 		Class<? extends OAuth2User> customUserType;
 		if ((customUserType = this.customUserTypes.get(registrationId)) == null) {

+ 2 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java

@@ -23,6 +23,7 @@ import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
+import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 
 import java.util.HashSet;
@@ -52,6 +53,7 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
 
 	@Override
 	public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
+		Assert.notNull(userRequest, "userRequest cannot be null");
 		String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName();
 		if (!StringUtils.hasText(userNameAttributeName)) {
 			OAuth2Error oauth2Error = new OAuth2Error(

+ 1 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserService.java

@@ -51,6 +51,7 @@ public class DelegatingOAuth2UserService<R extends OAuth2UserRequest, U extends
 
 	@Override
 	public U loadUser(R userRequest) throws OAuth2AuthenticationException {
+		Assert.notNull(userRequest, "userRequest cannot be null");
 		return this.userServices.stream()
 			.map(userService -> userService.loadUser(userRequest))
 			.filter(Objects::nonNull)

+ 3 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/NimbusUserInfoResponseClient.java

@@ -27,6 +27,7 @@ import org.springframework.http.HttpHeaders;
 import org.springframework.http.client.AbstractClientHttpResponse;
 import org.springframework.http.client.ClientHttpResponse;
 import org.springframework.http.converter.GenericHttpMessageConverter;
+import org.springframework.http.converter.HttpMessageNotReadableException;
 import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
 import org.springframework.security.authentication.AuthenticationServiceException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -54,7 +55,7 @@ final class NimbusUserInfoResponseClient {
 			userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken());
 		try {
 			return (T) this.genericHttpMessageConverter.read(returnType, userInfoResponse);
-		} catch (IOException ex) {
+		} catch (IOException | HttpMessageNotReadableException ex) {
 			OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
 				"An error occurred reading the UserInfo Success response: " + ex.getMessage(), null);
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
@@ -66,7 +67,7 @@ final class NimbusUserInfoResponseClient {
 			userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken());
 		try {
 			return (T) this.genericHttpMessageConverter.read(typeReference.getType(), null, userInfoResponse);
-		} catch (IOException ex) {
+		} catch (IOException | HttpMessageNotReadableException ex) {
 			OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
 				"An error occurred reading the UserInfo Success response: " + ex.getMessage(), null);
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);

+ 5 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java

@@ -16,6 +16,7 @@
 package org.springframework.security.oauth2.client.web;
 
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.util.Assert;
 
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
@@ -36,6 +37,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
 
 	@Override
 	public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
+		Assert.notNull(request, "request cannot be null");
 		HttpSession session = request.getSession(false);
 		if (session != null) {
 			return (OAuth2AuthorizationRequest) session.getAttribute(this.sessionAttributeName);
@@ -46,6 +48,8 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
 	@Override
 	public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, HttpServletRequest request,
 											HttpServletResponse response) {
+		Assert.notNull(request, "request cannot be null");
+		Assert.notNull(response, "response cannot be null");
 		if (authorizationRequest == null) {
 			this.removeAuthorizationRequest(request);
 			return;
@@ -55,6 +59,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
 
 	@Override
 	public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request) {
+		Assert.notNull(request, "request cannot be null");
 		OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request);
 		if (authorizationRequest != null) {
 			request.getSession().removeAttribute(this.sessionAttributeName);

+ 2 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestUriBuilder.java

@@ -17,6 +17,7 @@ package org.springframework.security.oauth2.client.web;
 
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 import org.springframework.web.util.UriComponentsBuilder;
 
@@ -35,6 +36,7 @@ import java.util.Set;
 class OAuth2AuthorizationRequestUriBuilder {
 
 	URI build(OAuth2AuthorizationRequest authorizationRequest) {
+		Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
 		Set<String> scopes = authorizationRequest.getScopes();
 		UriComponentsBuilder uriBuilder = UriComponentsBuilder
 			.fromUriString(authorizationRequest.getAuthorizationUri())

+ 188 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java

@@ -0,0 +1,188 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.junit.Test;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link InMemoryOAuth2AuthorizedClientService}.
+ *
+ * @author Joe Grandja
+ */
+public class InMemoryOAuth2AuthorizedClientServiceTests {
+	private String registrationId1 = "registration-1";
+	private String registrationId2 = "registration-2";
+	private String registrationId3 = "registration-3";
+	private String principalName1 = "principal-1";
+	private String principalName2 = "principal-2";
+
+	private ClientRegistration registration1 = ClientRegistration.withRegistrationId(this.registrationId1)
+		.clientId("client-1")
+		.clientSecret("secret")
+		.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+		.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+		.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
+		.scope("user")
+		.authorizationUri("https://provider.com/oauth2/authorize")
+		.tokenUri("https://provider.com/oauth2/token")
+		.userInfoUri("https://provider.com/oauth2/user")
+		.userNameAttributeName("id")
+		.clientName("client-1")
+		.build();
+
+	private ClientRegistration registration2 = ClientRegistration.withRegistrationId(this.registrationId2)
+		.clientId("client-2")
+		.clientSecret("secret")
+		.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+		.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+		.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
+		.scope("openid", "profile", "email")
+		.authorizationUri("https://provider.com/oauth2/authorize")
+		.tokenUri("https://provider.com/oauth2/token")
+		.userInfoUri("https://provider.com/oauth2/userinfo")
+		.jwkSetUri("https://provider.com/oauth2/keys")
+		.clientName("client-2")
+		.build();
+
+	private ClientRegistration registration3 = ClientRegistration.withRegistrationId(this.registrationId3)
+		.clientId("client-3")
+		.clientSecret("secret")
+		.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+		.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+		.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
+		.scope("openid", "profile")
+		.authorizationUri("https://provider.com/oauth2/authorize")
+		.tokenUri("https://provider.com/oauth2/token")
+		.userInfoUri("https://provider.com/oauth2/userinfo")
+		.jwkSetUri("https://provider.com/oauth2/keys")
+		.clientName("client-3")
+		.build();
+
+	private ClientRegistrationRepository clientRegistrationRepository =
+		new InMemoryClientRegistrationRepository(this.registration1, this.registration2, this.registration3);
+
+	private InMemoryOAuth2AuthorizedClientService authorizedClientService =
+		new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
+
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
+		new InMemoryOAuth2AuthorizedClientService(null);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
+		this.authorizedClientService.loadAuthorizedClient(null, this.principalName1);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
+		this.authorizedClientService.loadAuthorizedClient(this.registrationId1, null);
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenClientRegistrationNotFoundThenReturnNull() {
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+			"registration-not-found", this.principalName1);
+		assertThat(authorizedClient).isNull();
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenClientRegistrationFoundButNotAssociatedToPrincipalThenReturnNull() {
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+			this.registrationId1, "principal-not-found");
+		assertThat(authorizedClient).isNull();
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenClientRegistrationFoundAndAssociatedToPrincipalThenReturnAuthorizedClient() {
+		Authentication authentication = mock(Authentication.class);
+		when(authentication.getName()).thenReturn(this.principalName1);
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+			this.registration1, this.principalName1, mock(OAuth2AccessToken.class));
+		this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
+
+		OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient(
+			this.registrationId1, this.principalName1);
+		assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() {
+		this.authorizedClientService.saveAuthorizedClient(null, mock(Authentication.class));
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void saveAuthorizedClientWhenPrincipalIsNullThenThrowIllegalArgumentException() {
+		this.authorizedClientService.saveAuthorizedClient(mock(OAuth2AuthorizedClient.class), null);
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenSavedThenCanLoad() {
+		Authentication authentication = mock(Authentication.class);
+		when(authentication.getName()).thenReturn(this.principalName2);
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+			this.registration3, this.principalName2, mock(OAuth2AccessToken.class));
+		this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
+
+		OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient(
+			this.registrationId3, this.principalName2);
+		assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
+		this.authorizedClientService.removeAuthorizedClient(null, this.principalName2);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
+		this.authorizedClientService.removeAuthorizedClient(this.registrationId2, null);
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenSavedThenRemoved() {
+		Authentication authentication = mock(Authentication.class);
+		when(authentication.getName()).thenReturn(this.principalName2);
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+			this.registration2, this.principalName2, mock(OAuth2AccessToken.class));
+		this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
+
+		OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient(
+			this.registrationId2, this.principalName2);
+		assertThat(loadedAuthorizedClient).isNotNull();
+
+		this.authorizedClientService.removeAuthorizedClient(this.registrationId2, this.principalName2);
+
+		loadedAuthorizedClient = this.authorizedClientService.loadAuthorizedClient(
+			this.registrationId2, this.principalName2);
+		assertThat(loadedAuthorizedClient).isNull();
+	}
+}

+ 72 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientTests.java

@@ -0,0 +1,72 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link OAuth2AuthorizedClient}.
+ *
+ * @author Joe Grandja
+ */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest(ClientRegistration.class)
+public class OAuth2AuthorizedClientTests {
+	private ClientRegistration clientRegistration;
+	private String principalName;
+	private OAuth2AccessToken accessToken;
+
+	@Before
+	public void setUp() {
+		this.clientRegistration = mock(ClientRegistration.class);
+		this.principalName = "principal";
+		this.accessToken = mock(OAuth2AccessToken.class);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
+		new OAuth2AuthorizedClient(null, this.principalName, this.accessToken);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
+		new OAuth2AuthorizedClient(this.clientRegistration, null, this.accessToken);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
+		new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, null);
+	}
+
+	@Test
+	public void constructorWhenAllParametersProvidedAndValidThenCreated() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+			this.clientRegistration, this.principalName, this.accessToken);
+
+		assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName);
+		assertThat(authorizedClient.getAccessToken()).isEqualTo(this.accessToken);
+	}
+}

+ 77 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthenticationTokenTests.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.authentication;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link OAuth2AuthenticationToken}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2AuthenticationTokenTests {
+	private OAuth2User principal;
+	private Collection<? extends GrantedAuthority> authorities;
+	private String authorizedClientRegistrationId;
+
+	@Before
+	public void setUp() {
+		this.principal = mock(OAuth2User.class);
+		this.authorities = Collections.emptyList();
+		this.authorizedClientRegistrationId = "client-registration-1";
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenPrincipalIsNullThenThrowIllegalArgumentException() {
+		new OAuth2AuthenticationToken(null, this.authorities, this.authorizedClientRegistrationId);
+	}
+
+	@Test
+	public void constructorWhenAuthoritiesIsNullThenCreated() {
+		new OAuth2AuthenticationToken(this.principal, null, this.authorizedClientRegistrationId);
+	}
+
+	@Test
+	public void constructorWhenAuthoritiesIsEmptyThenCreated() {
+		new OAuth2AuthenticationToken(this.principal, Collections.emptyList(), this.authorizedClientRegistrationId);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenAuthorizedClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
+		new OAuth2AuthenticationToken(this.principal, this.authorities, null);
+	}
+
+	@Test
+	public void constructorWhenAllParametersProvidedAndValidThenCreated() {
+		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
+			this.principal, this.authorities, this.authorizedClientRegistrationId);
+
+		assertThat(authentication.getPrincipal()).isEqualTo(this.principal);
+		assertThat(authentication.getCredentials()).isEqualTo("");
+		assertThat(authentication.getAuthorities()).isEqualTo(this.authorities);
+		assertThat(authentication.getAuthorizedClientRegistrationId()).isEqualTo(this.authorizedClientRegistrationId);
+		assertThat(authentication.isAuthenticated()).isEqualTo(true);
+	}
+}

+ 215 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java

@@ -0,0 +1,215 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.authentication;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.mockito.stubbing.Answer;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
+import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashSet;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyCollection;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link OAuth2LoginAuthenticationProvider}.
+ *
+ * @author Joe Grandja
+ */
+@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class,
+	OAuth2AuthorizationResponse.class, OAuth2AccessTokenResponse.class})
+@RunWith(PowerMockRunner.class)
+public class OAuth2LoginAuthenticationProviderTests {
+	private ClientRegistration clientRegistration;
+	private OAuth2AuthorizationRequest authorizationRequest;
+	private OAuth2AuthorizationResponse authorizationResponse;
+	private OAuth2AuthorizationExchange authorizationExchange;
+	private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
+	private OAuth2UserService<OAuth2UserRequest, OAuth2User> userService;
+	private OAuth2LoginAuthenticationProvider authenticationProvider;
+
+	@Rule
+	public ExpectedException exception = ExpectedException.none();
+
+	@Before
+	@SuppressWarnings("unchecked")
+	public void setUp() throws Exception {
+		this.clientRegistration = mock(ClientRegistration.class);
+		this.authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+		this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
+		this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
+		this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
+		this.userService = mock(OAuth2UserService.class);
+		this.authenticationProvider = new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, this.userService);
+
+		when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
+		when(this.authorizationRequest.getState()).thenReturn("12345");
+		when(this.authorizationResponse.getState()).thenReturn("12345");
+		when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
+		when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
+	}
+
+	@Test
+	public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		new OAuth2LoginAuthenticationProvider(null, this.userService);
+	}
+
+	@Test
+	public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, null);
+	}
+
+	@Test
+	public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		this.authenticationProvider.setAuthoritiesMapper(null);
+	}
+
+	@Test
+	public void supportsWhenTypeOAuth2LoginAuthenticationTokenThenReturnTrue() {
+		assertThat(this.authenticationProvider.supports(OAuth2LoginAuthenticationToken.class)).isTrue();
+	}
+
+	@Test
+	public void authenticateWhenAuthorizationRequestContainsOpenidScopeThenReturnNull() {
+		when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Collections.singleton("openid")));
+
+		OAuth2LoginAuthenticationToken authentication =
+			(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
+				new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+
+		assertThat(authentication).isNull();
+	}
+
+	@Test
+	public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_REQUEST));
+
+		when(this.authorizationResponse.statusError()).thenReturn(true);
+		when(this.authorizationResponse.getError()).thenReturn(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST));
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_state_parameter"));
+
+		when(this.authorizationRequest.getState()).thenReturn("12345");
+		when(this.authorizationResponse.getState()).thenReturn("67890");
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenAuthorizationResponseRedirectUriNotEqualAuthorizationRequestRedirectUriThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_redirect_uri_parameter"));
+
+		when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
+		when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example2.com");
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenLoginSuccessThenReturnAuthentication() {
+		OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
+		OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
+		when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
+		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
+
+		OAuth2User principal = mock(OAuth2User.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		when(principal.getAuthorities()).thenAnswer(
+			(Answer<List<GrantedAuthority>>) invocation -> authorities);
+		when(this.userService.loadUser(any())).thenReturn(principal);
+
+		OAuth2LoginAuthenticationToken authentication =
+			(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
+				new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+
+		assertThat(authentication.isAuthenticated()).isTrue();
+		assertThat(authentication.getPrincipal()).isEqualTo(principal);
+		assertThat(authentication.getCredentials()).isEqualTo("");
+		assertThat(authentication.getAuthorities()).isEqualTo(authorities);
+		assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
+		assertThat(authentication.getAccessToken()).isEqualTo(accessToken);
+	}
+
+	@Test
+	public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {
+		OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
+		OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
+		when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
+		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
+
+		OAuth2User principal = mock(OAuth2User.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		when(principal.getAuthorities()).thenAnswer(
+			(Answer<List<GrantedAuthority>>) invocation -> authorities);
+		when(this.userService.loadUser(any())).thenReturn(principal);
+
+		List<GrantedAuthority> mappedAuthorities = AuthorityUtils.createAuthorityList("ROLE_OAUTH2_USER");
+		GrantedAuthoritiesMapper authoritiesMapper = mock(GrantedAuthoritiesMapper.class);
+		when(authoritiesMapper.mapAuthorities(anyCollection())).thenAnswer(
+			(Answer<List<GrantedAuthority>>) invocation -> mappedAuthorities);
+		this.authenticationProvider.setAuthoritiesMapper(authoritiesMapper);
+
+		OAuth2LoginAuthenticationToken authentication =
+			(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
+				new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+
+		assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
+	}
+}

+ 131 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationTokenTests.java

@@ -0,0 +1,131 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.authentication;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link OAuth2LoginAuthenticationToken}.
+ *
+ * @author Joe Grandja
+ */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationExchange.class})
+public class OAuth2LoginAuthenticationTokenTests {
+	private OAuth2User principal;
+	private Collection<? extends GrantedAuthority> authorities;
+	private ClientRegistration clientRegistration;
+	private OAuth2AuthorizationExchange authorizationExchange;
+	private OAuth2AccessToken accessToken;
+
+	@Before
+	public void setUp() {
+		this.principal = mock(OAuth2User.class);
+		this.authorities = Collections.emptyList();
+		this.clientRegistration = mock(ClientRegistration.class);
+		this.authorizationExchange = mock(OAuth2AuthorizationExchange.class);
+		this.accessToken = mock(OAuth2AccessToken.class);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorAuthorizationRequestResponseWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
+		new OAuth2LoginAuthenticationToken(null, this.authorizationExchange);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorAuthorizationRequestResponseWhenAuthorizationExchangeIsNullThenThrowIllegalArgumentException() {
+		new OAuth2LoginAuthenticationToken(this.clientRegistration, null);
+	}
+
+	@Test
+	public void constructorAuthorizationRequestResponseWhenAllParametersProvidedAndValidThenCreated() {
+		OAuth2LoginAuthenticationToken authentication = new OAuth2LoginAuthenticationToken(
+			this.clientRegistration, this.authorizationExchange);
+
+		assertThat(authentication.getPrincipal()).isNull();
+		assertThat(authentication.getCredentials()).isEqualTo("");
+		assertThat(authentication.getAuthorities()).isEqualTo(Collections.emptyList());
+		assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
+		assertThat(authentication.getAccessToken()).isNull();
+		assertThat(authentication.isAuthenticated()).isEqualTo(false);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorTokenRequestResponseWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
+		new OAuth2LoginAuthenticationToken(null, this.authorizationExchange, this.principal,
+			this.authorities, this.accessToken);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorTokenRequestResponseWhenAuthorizationExchangeIsNullThenThrowIllegalArgumentException() {
+		new OAuth2LoginAuthenticationToken(this.clientRegistration, null, this.principal,
+			this.authorities, this.accessToken);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorTokenRequestResponseWhenPrincipalIsNullThenThrowIllegalArgumentException() {
+		new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange, null,
+			this.authorities, this.accessToken);
+	}
+
+	@Test
+	public void constructorTokenRequestResponseWhenAuthoritiesIsNullThenCreated() {
+		new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange,
+			this.principal, null, this.accessToken);
+	}
+
+	@Test
+	public void constructorTokenRequestResponseWhenAuthoritiesIsEmptyThenCreated() {
+		new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange,
+			this.principal, Collections.emptyList(), this.accessToken);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorTokenRequestResponseWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
+		new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange, this.principal,
+			this.authorities, null);
+	}
+
+	@Test
+	public void constructorTokenRequestResponseWhenAllParametersProvidedAndValidThenCreated() {
+		OAuth2LoginAuthenticationToken authentication = new OAuth2LoginAuthenticationToken(
+			this.clientRegistration, this.authorizationExchange, this.principal, this.authorities, this.accessToken);
+
+		assertThat(authentication.getPrincipal()).isEqualTo(this.principal);
+		assertThat(authentication.getCredentials()).isEqualTo("");
+		assertThat(authentication.getAuthorities()).isEqualTo(this.authorities);
+		assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
+		assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken);
+		assertThat(authentication.isAuthenticated()).isEqualTo(true);
+	}
+}

+ 300 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java

@@ -0,0 +1,300 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.endpoint;
+
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PowerMockIgnore;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.security.authentication.AuthenticationServiceException;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
+
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.LinkedHashSet;
+import java.util.Set;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link NimbusAuthorizationCodeTokenResponseClient}.
+ *
+ * @author Joe Grandja
+ */
+@PowerMockIgnore("okhttp3.*")
+@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class, OAuth2AuthorizationResponse.class, OAuth2AuthorizationExchange.class})
+@RunWith(PowerMockRunner.class)
+public class NimbusAuthorizationCodeTokenResponseClientTests {
+	private ClientRegistration clientRegistration;
+	private ClientRegistration.ProviderDetails providerDetails;
+	private OAuth2AuthorizationRequest authorizationRequest;
+	private OAuth2AuthorizationResponse authorizationResponse;
+	private OAuth2AuthorizationExchange authorizationExchange;
+	private NimbusAuthorizationCodeTokenResponseClient tokenResponseClient = new NimbusAuthorizationCodeTokenResponseClient();
+
+	@Rule
+	public ExpectedException exception = ExpectedException.none();
+
+	@Before
+	public void setUp() throws Exception {
+		this.clientRegistration = mock(ClientRegistration.class);
+		this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
+		this.authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+		this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
+		this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
+
+		when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
+		when(this.clientRegistration.getClientId()).thenReturn("client-id");
+		when(this.clientRegistration.getClientSecret()).thenReturn("secret");
+		when(this.clientRegistration.getClientAuthenticationMethod()).thenReturn(ClientAuthenticationMethod.BASIC);
+		when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
+		when(this.authorizationResponse.getCode()).thenReturn("code");
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
+		MockWebServer server = new MockWebServer();
+
+		String accessTokenSuccessResponse = "{\n" +
+			"	\"access_token\": \"access-token-1234\",\n" +
+			"   \"token_type\": \"bearer\",\n" +
+			"   \"expires_in\": \"3600\",\n" +
+			"   \"scope\": \"openid profile\",\n" +
+			"   \"custom_parameter_1\": \"custom-value-1\",\n" +
+			"   \"custom_parameter_2\": \"custom-value-2\"\n" +
+			"}\n";
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(accessTokenSuccessResponse));
+		server.start();
+
+		String tokenUri = server.url("/oauth2/token").toString();
+		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+
+		Instant expiresAtBefore = Instant.now().plusSeconds(3600);
+
+		OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(
+			new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+
+		Instant expiresAtAfter = Instant.now().plusSeconds(3600);
+
+		server.shutdown();
+
+		assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234");
+		assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER);
+		assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter);
+		assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile");
+		assertThat(accessTokenResponse.getAdditionalParameters().size()).isEqualTo(2);
+		assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_1", "custom-value-1");
+		assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_2", "custom-value-2");
+	}
+
+	@Test
+	public void getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() throws Exception {
+		this.exception.expect(IllegalArgumentException.class);
+
+		String redirectUri = "http:\\example.com";
+		when(this.clientRegistration.getRedirectUri()).thenReturn(redirectUri);
+
+		this.tokenResponseClient.getTokenResponse(
+			new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void getTokenResponseWhenTokenUriMalformedThenThrowIllegalArgumentException() throws Exception {
+		this.exception.expect(IllegalArgumentException.class);
+
+		String tokenUri = "http:\\provider.com\\oauth2\\token";
+		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+
+		this.tokenResponseClient.getTokenResponse(
+			new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_token_response"));
+
+		MockWebServer server = new MockWebServer();
+
+		String accessTokenSuccessResponse = "{\n" +
+			"	\"access_token\": \"access-token-1234\",\n" +
+			"   \"token_type\": \"bearer\",\n" +
+			"   \"expires_in\": \"3600\",\n" +
+			"   \"scope\": \"openid profile\",\n" +
+			"   \"custom_parameter_1\": \"custom-value-1\",\n" +
+			"   \"custom_parameter_2\": \"custom-value-2\"\n";
+//			"}\n";		// Make the JSON invalid/malformed
+
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(accessTokenSuccessResponse));
+		server.start();
+
+		String tokenUri = server.url("/oauth2/token").toString();
+		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+
+		try {
+			this.tokenResponseClient.getTokenResponse(
+				new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+		} finally {
+			server.shutdown();
+		}
+	}
+
+	@Test
+	public void getTokenResponseWhenTokenUriInvalidThenThrowAuthenticationServiceException() throws Exception {
+		this.exception.expect(AuthenticationServiceException.class);
+
+		String tokenUri = "http://invalid-provider.com/oauth2/token";
+		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+
+		this.tokenResponseClient.getTokenResponse(
+			new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("unauthorized_client"));
+
+		MockWebServer server = new MockWebServer();
+
+		String accessTokenErrorResponse = "{\n" +
+			"   \"error\": \"unauthorized_client\"\n" +
+			"}\n";
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+			.setResponseCode(500)
+			.setBody(accessTokenErrorResponse));
+		server.start();
+
+		String tokenUri = server.url("/oauth2/token").toString();
+		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+
+		try {
+			this.tokenResponseClient.getTokenResponse(
+				new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+		} finally {
+			server.shutdown();
+		}
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_token_response"));
+
+		MockWebServer server = new MockWebServer();
+
+		String accessTokenSuccessResponse = "{\n" +
+			"	\"access_token\": \"access-token-1234\",\n" +
+			"   \"token_type\": \"not-bearer\",\n" +
+			"   \"expires_in\": \"3600\"\n" +
+			"}\n";
+
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(accessTokenSuccessResponse));
+		server.start();
+
+		String tokenUri = server.url("/oauth2/token").toString();
+		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+
+		try {
+			this.tokenResponseClient.getTokenResponse(
+				new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+		} finally {
+			server.shutdown();
+		}
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() throws Exception {
+		MockWebServer server = new MockWebServer();
+
+		String accessTokenSuccessResponse = "{\n" +
+			"	\"access_token\": \"access-token-1234\",\n" +
+			"   \"token_type\": \"bearer\",\n" +
+			"   \"expires_in\": \"3600\",\n" +
+			"   \"scope\": \"openid profile\"\n" +
+			"}\n";
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(accessTokenSuccessResponse));
+		server.start();
+
+		String tokenUri = server.url("/oauth2/token").toString();
+		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+
+		Set<String> requestedScopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email", "address"));
+		when(this.authorizationRequest.getScopes()).thenReturn(requestedScopes);
+
+		OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(
+			new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+
+		server.shutdown();
+
+		assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile");
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseUsingRequestedScope() throws Exception {
+		MockWebServer server = new MockWebServer();
+
+		String accessTokenSuccessResponse = "{\n" +
+			"	\"access_token\": \"access-token-1234\",\n" +
+			"   \"token_type\": \"bearer\",\n" +
+			"   \"expires_in\": \"3600\"\n" +
+			"}\n";
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(accessTokenSuccessResponse));
+		server.start();
+
+		String tokenUri = server.url("/oauth2/token").toString();
+		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+
+		Set<String> requestedScopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email", "address"));
+		when(this.authorizationRequest.getScopes()).thenReturn(requestedScopes);
+
+		OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(
+			new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+
+		server.shutdown();
+
+		assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile", "email", "address");
+	}
+}

+ 66 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestTests.java

@@ -0,0 +1,66 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.endpoint;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link OAuth2AuthorizationCodeGrantRequest}.
+ *
+ * @author Joe Grandja
+ */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationExchange.class})
+public class OAuth2AuthorizationCodeGrantRequestTests {
+	private ClientRegistration clientRegistration;
+	private OAuth2AuthorizationExchange authorizationExchange;
+
+	@Before
+	public void setUp() {
+		this.clientRegistration = mock(ClientRegistration.class);
+		this.authorizationExchange = mock(OAuth2AuthorizationExchange.class);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
+		new OAuth2AuthorizationCodeGrantRequest(null, this.authorizationExchange);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenAuthorizationExchangeIsNullThenThrowIllegalArgumentException() {
+		new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, null);
+	}
+
+	@Test
+	public void constructorWhenAllParametersProvidedAndValidThenCreated() {
+		OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest =
+			new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange);
+
+		assertThat(authorizationCodeGrantRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationCodeGrantRequest.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
+		assertThat(authorizationCodeGrantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+	}
+}

+ 414 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java

@@ -0,0 +1,414 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.oidc.authentication;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.mockito.stubbing.Answer;
+import org.powermock.api.mockito.PowerMockito;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
+import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
+import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
+import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
+import org.springframework.security.oauth2.core.oidc.user.OidcUser;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtDecoder;
+
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.mockito.ArgumentMatchers.*;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link OidcAuthorizationCodeAuthenticationProvider}.
+ *
+ * @author Joe Grandja
+ */
+@PrepareForTest({ClientRegistration.class, OAuth2AuthorizationRequest.class, OAuth2AuthorizationResponse.class,
+	OAuth2AccessTokenResponse.class, OidcAuthorizationCodeAuthenticationProvider.class})
+@RunWith(PowerMockRunner.class)
+public class OidcAuthorizationCodeAuthenticationProviderTests {
+	private ClientRegistration clientRegistration;
+	private ClientRegistration.ProviderDetails providerDetails;
+	private OAuth2AuthorizationRequest authorizationRequest;
+	private OAuth2AuthorizationResponse authorizationResponse;
+	private OAuth2AuthorizationExchange authorizationExchange;
+	private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
+	private OAuth2AccessTokenResponse accessTokenResponse;
+	private OAuth2AccessToken accessToken;
+	private OAuth2UserService<OidcUserRequest, OidcUser> userService;
+	private OidcAuthorizationCodeAuthenticationProvider authenticationProvider;
+
+	@Rule
+	public ExpectedException exception = ExpectedException.none();
+
+	@Before
+	@SuppressWarnings("unchecked")
+	public void setUp() throws Exception {
+		this.clientRegistration = mock(ClientRegistration.class);
+		this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
+		this.authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+		this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
+		this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
+		this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
+		this.accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
+		this.accessToken = mock(OAuth2AccessToken.class);
+		this.userService = mock(OAuth2UserService.class);
+		this.authenticationProvider = PowerMockito.spy(
+			new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService));
+
+		when(this.clientRegistration.getRegistrationId()).thenReturn("client-registration-id-1");
+		when(this.clientRegistration.getClientId()).thenReturn("client1");
+		when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
+		when(this.providerDetails.getJwkSetUri()).thenReturn("https://provider.com/oauth2/keys");
+		when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Arrays.asList("openid", "profile", "email")));
+		when(this.authorizationRequest.getState()).thenReturn("12345");
+		when(this.authorizationResponse.getState()).thenReturn("12345");
+		when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
+		when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
+		when(this.accessTokenResponse.getAccessToken()).thenReturn(this.accessToken);
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
+		when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(additionalParameters);
+		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse);
+	}
+
+	@Test
+	public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		new OidcAuthorizationCodeAuthenticationProvider(null, this.userService);
+	}
+
+	@Test
+	public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, null);
+	}
+
+	@Test
+	public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		this.authenticationProvider.setAuthoritiesMapper(null);
+	}
+
+	@Test
+	public void supportsWhenTypeOAuth2LoginAuthenticationTokenThenReturnTrue() {
+		assertThat(this.authenticationProvider.supports(OAuth2LoginAuthenticationToken.class)).isTrue();
+	}
+
+	@Test
+	public void authenticateWhenAuthorizationRequestDoesNotContainOpenidScopeThenReturnNull() {
+		when(this.authorizationRequest.getScopes()).thenReturn(new LinkedHashSet<>(Collections.singleton("scope1")));
+
+		OAuth2LoginAuthenticationToken authentication =
+			(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
+				new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+
+		assertThat(authentication).isNull();
+	}
+
+	@Test
+	public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_SCOPE));
+
+		when(this.authorizationResponse.statusError()).thenReturn(true);
+		when(this.authorizationResponse.getError()).thenReturn(new OAuth2Error(OAuth2ErrorCodes.INVALID_SCOPE));
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_state_parameter"));
+
+		when(this.authorizationRequest.getState()).thenReturn("34567");
+		when(this.authorizationResponse.getState()).thenReturn("89012");
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenAuthorizationResponseRedirectUriNotEqualAuthorizationRequestRedirectUriThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_redirect_uri_parameter"));
+
+		when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example1.com");
+		when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example2.com");
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenTokenResponseDoesNotContainIdTokenThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_id_token"));
+
+		when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(Collections.emptyMap());
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenJwkSetUriNotSetThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("missing_signature_verifier"));
+
+		when(this.providerDetails.getJwkSetUri()).thenReturn(null);
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenIdTokenIssuerClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_id_token"));
+
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.SUB, "subject1");
+
+		this.setUpIdToken(claims);
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenIdTokenSubjectClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_id_token"));
+
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
+
+		this.setUpIdToken(claims);
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenIdTokenAudienceClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_id_token"));
+
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
+		claims.put(IdTokenClaimNames.SUB, "subject1");
+
+		this.setUpIdToken(claims);
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenIdTokenAudienceClaimDoesNotContainClientIdThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_id_token"));
+
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
+		claims.put(IdTokenClaimNames.SUB, "subject1");
+		claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client"));
+
+		this.setUpIdToken(claims);
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenIdTokenMultipleAudienceClaimAndAuthorizedPartyClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_id_token"));
+
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
+		claims.put(IdTokenClaimNames.SUB, "subject1");
+		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
+
+		this.setUpIdToken(claims);
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenIdTokenAuthorizedPartyClaimNotEqualToClientIdThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_id_token"));
+
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
+		claims.put(IdTokenClaimNames.SUB, "subject1");
+		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
+		claims.put(IdTokenClaimNames.AZP, "other-client");
+
+		this.setUpIdToken(claims);
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenIdTokenExpiresAtIsBeforeNowThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_id_token"));
+
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
+		claims.put(IdTokenClaimNames.SUB, "subject1");
+		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
+		claims.put(IdTokenClaimNames.AZP, "client1");
+
+		Instant issuedAt = Instant.now().minusSeconds(10);
+		Instant expiresAt = Instant.from(issuedAt).plusSeconds(5);
+
+		this.setUpIdToken(claims, issuedAt, expiresAt);
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenIdTokenIssuedAtIsAfterMaxIssuedAtThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_id_token"));
+
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
+		claims.put(IdTokenClaimNames.SUB, "subject1");
+		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
+		claims.put(IdTokenClaimNames.AZP, "client1");
+
+		Instant issuedAt = Instant.now().plusSeconds(35);
+		Instant expiresAt = Instant.from(issuedAt).plusSeconds(60);
+
+		this.setUpIdToken(claims, issuedAt, expiresAt);
+
+		this.authenticationProvider.authenticate(
+			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+	}
+
+	@Test
+	public void authenticateWhenLoginSuccessThenReturnAuthentication() throws Exception {
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
+		claims.put(IdTokenClaimNames.SUB, "subject1");
+		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
+		claims.put(IdTokenClaimNames.AZP, "client1");
+		this.setUpIdToken(claims);
+
+		OidcUser principal = mock(OidcUser.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		when(principal.getAuthorities()).thenAnswer(
+			(Answer<List<GrantedAuthority>>) invocation -> authorities);
+		when(this.userService.loadUser(any())).thenReturn(principal);
+
+		OAuth2LoginAuthenticationToken authentication =
+			(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
+				new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+
+		assertThat(authentication.isAuthenticated()).isTrue();
+		assertThat(authentication.getPrincipal()).isEqualTo(principal);
+		assertThat(authentication.getCredentials()).isEqualTo("");
+		assertThat(authentication.getAuthorities()).isEqualTo(authorities);
+		assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
+		assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken);
+	}
+
+	@Test
+	public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() throws Exception {
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
+		claims.put(IdTokenClaimNames.SUB, "subject1");
+		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
+		claims.put(IdTokenClaimNames.AZP, "client1");
+		this.setUpIdToken(claims);
+
+		OidcUser principal = mock(OidcUser.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		when(principal.getAuthorities()).thenAnswer(
+			(Answer<List<GrantedAuthority>>) invocation -> authorities);
+		when(this.userService.loadUser(any())).thenReturn(principal);
+
+		List<GrantedAuthority> mappedAuthorities = AuthorityUtils.createAuthorityList("ROLE_OIDC_USER");
+		GrantedAuthoritiesMapper authoritiesMapper = mock(GrantedAuthoritiesMapper.class);
+		when(authoritiesMapper.mapAuthorities(anyCollection())).thenAnswer(
+			(Answer<List<GrantedAuthority>>) invocation -> mappedAuthorities);
+		this.authenticationProvider.setAuthoritiesMapper(authoritiesMapper);
+
+		OAuth2LoginAuthenticationToken authentication =
+			(OAuth2LoginAuthenticationToken)this.authenticationProvider.authenticate(
+				new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+
+		assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
+	}
+
+	private void setUpIdToken(Map<String, Object> claims) throws Exception {
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
+		this.setUpIdToken(claims, issuedAt, expiresAt);
+	}
+
+	private void setUpIdToken(Map<String, Object> claims, Instant issuedAt, Instant expiresAt) throws Exception {
+		Map<String, Object> headers = new HashMap<>();
+		headers.put("alg", "RS256");
+
+		Jwt idToken = new Jwt("id-token", issuedAt, expiresAt, headers, claims);
+
+		JwtDecoder jwtDecoder = mock(JwtDecoder.class);
+		when(jwtDecoder.decode(anyString())).thenReturn(idToken);
+		PowerMockito.doReturn(jwtDecoder).when(this.authenticationProvider, "getJwtDecoder", any(ClientRegistration.class));
+	}
+}

+ 73 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java

@@ -0,0 +1,73 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.oidc.userinfo;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link OidcUserRequest}.
+ *
+ * @author Joe Grandja
+ */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest(ClientRegistration.class)
+public class OidcUserRequestTests {
+	private ClientRegistration clientRegistration;
+	private OAuth2AccessToken accessToken;
+	private OidcIdToken idToken;
+
+	@Before
+	public void setUp() {
+		this.clientRegistration = mock(ClientRegistration.class);
+		this.accessToken = mock(OAuth2AccessToken.class);
+		this.idToken = mock(OidcIdToken.class);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
+		new OidcUserRequest(null, this.accessToken, this.idToken);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
+		new OidcUserRequest(this.clientRegistration, null, this.idToken);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenIdTokenIsNullThenThrowIllegalArgumentException() {
+		new OidcUserRequest(this.clientRegistration, this.accessToken, null);
+	}
+
+	@Test
+	public void constructorWhenAllParametersProvidedAndValidThenCreated() {
+		OidcUserRequest userRequest = new OidcUserRequest(
+			this.clientRegistration, this.accessToken, this.idToken);
+
+		assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
+		assertThat(userRequest.getIdToken()).isEqualTo(this.idToken);
+	}
+}

+ 260 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java

@@ -0,0 +1,260 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.oidc.userinfo;
+
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PowerMockIgnore;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.security.authentication.AuthenticationServiceException;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+import org.springframework.security.oauth2.core.oidc.OidcScopes;
+import org.springframework.security.oauth2.core.oidc.user.OidcUser;
+import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link OidcUserService}.
+ *
+ * @author Joe Grandja
+ */
+@PowerMockIgnore("okhttp3.*")
+@PrepareForTest(ClientRegistration.class)
+@RunWith(PowerMockRunner.class)
+public class OidcUserServiceTests {
+	private ClientRegistration clientRegistration;
+	private ClientRegistration.ProviderDetails providerDetails;
+	private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
+	private OAuth2AccessToken accessToken;
+	private OidcIdToken idToken;
+	private OidcUserService userService = new OidcUserService();
+
+	@Rule
+	public ExpectedException exception = ExpectedException.none();
+
+	@Before
+	public void setUp() throws Exception {
+		this.clientRegistration = mock(ClientRegistration.class);
+		this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
+		this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
+		when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
+		when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint);
+		when(this.clientRegistration.getAuthorizationGrantType()).thenReturn(AuthorizationGrantType.AUTHORIZATION_CODE);
+
+		this.accessToken = mock(OAuth2AccessToken.class);
+		Set<String> authorizedScopes = new LinkedHashSet<>(Arrays.asList(OidcScopes.OPENID, OidcScopes.PROFILE));
+		when(this.accessToken.getScopes()).thenReturn(authorizedScopes);
+
+		this.idToken = mock(OidcIdToken.class);
+		Map<String, Object> idTokenClaims = new HashMap<>();
+		idTokenClaims.put(IdTokenClaimNames.ISS, "https://provider.com");
+		idTokenClaims.put(IdTokenClaimNames.SUB, "subject1");
+		when(this.idToken.getClaims()).thenReturn(idTokenClaims);
+		when(this.idToken.getSubject()).thenReturn("subject1");
+	}
+
+	@Test
+	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		this.userService.loadUser(null);
+	}
+
+	@Test
+	public void loadUserWhenUserInfoUriIsNullThenUserInfoEndpointNotRequested() {
+		when(this.userInfoEndpoint.getUri()).thenReturn(null);
+
+		OidcUser user = this.userService.loadUser(
+			new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
+		assertThat(user.getUserInfo()).isNull();
+	}
+
+	@Test
+	public void loadUserWhenAuthorizedScopesDoesNotContainUserInfoScopesThenUserInfoEndpointNotRequested() {
+		Set<String> authorizedScopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
+		when(this.accessToken.getScopes()).thenReturn(authorizedScopes);
+
+		when(this.userInfoEndpoint.getUri()).thenReturn("http://provider.com/user");
+
+		OidcUser user = this.userService.loadUser(
+			new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
+		assertThat(user.getUserInfo()).isNull();
+	}
+
+	@Test
+	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
+		MockWebServer server = new MockWebServer();
+
+		String userInfoResponse = "{\n" +
+			"	\"sub\": \"subject1\",\n" +
+			"   \"name\": \"first last\",\n" +
+			"   \"given_name\": \"first\",\n" +
+			"   \"family_name\": \"last\",\n" +
+			"   \"preferred_username\": \"user1\",\n" +
+			"   \"email\": \"user1@example.com\"\n" +
+			"}\n";
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(userInfoResponse));
+
+		server.start();
+
+		String userInfoUri = server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		OidcUser user = this.userService.loadUser(
+			new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
+
+		server.shutdown();
+
+		assertThat(user.getIdToken()).isNotNull();
+		assertThat(user.getUserInfo()).isNotNull();
+		assertThat(user.getUserInfo().getClaims().size()).isEqualTo(6);
+		assertThat(user.getIdToken()).isEqualTo(this.idToken);
+		assertThat(user.getName()).isEqualTo("subject1");
+		assertThat(user.getUserInfo().getSubject()).isEqualTo("subject1");
+		assertThat(user.getUserInfo().getFullName()).isEqualTo("first last");
+		assertThat(user.getUserInfo().getGivenName()).isEqualTo("first");
+		assertThat(user.getUserInfo().getFamilyName()).isEqualTo("last");
+		assertThat(user.getUserInfo().getPreferredUsername()).isEqualTo("user1");
+		assertThat(user.getUserInfo().getEmail()).isEqualTo("user1@example.com");
+
+		assertThat(user.getAuthorities().size()).isEqualTo(1);
+		assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OidcUserAuthority.class);
+		OidcUserAuthority userAuthority = (OidcUserAuthority)user.getAuthorities().iterator().next();
+		assertThat(userAuthority.getAuthority()).isEqualTo("ROLE_USER");
+		assertThat(userAuthority.getIdToken()).isEqualTo(user.getIdToken());
+		assertThat(userAuthority.getUserInfo()).isEqualTo(user.getUserInfo());
+	}
+
+	@Test
+	public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_user_info_response"));
+
+		MockWebServer server = new MockWebServer();
+
+		String userInfoResponse = "{\n" +
+			"	\"sub\": \"other-subject\"\n" +
+			"}\n";
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(userInfoResponse));
+
+		server.start();
+
+		String userInfoUri = server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		try {
+			this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
+		} finally {
+			server.shutdown();
+		}
+	}
+
+	@Test
+	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_user_info_response"));
+
+		MockWebServer server = new MockWebServer();
+
+		String userInfoResponse = "{\n" +
+			"	\"sub\": \"subject1\",\n" +
+			"   \"name\": \"first last\",\n" +
+			"   \"given_name\": \"first\",\n" +
+			"   \"family_name\": \"last\",\n" +
+			"   \"preferred_username\": \"user1\",\n" +
+			"   \"email\": \"user1@example.com\"\n";
+//			"}\n";		// Make the JSON invalid/malformed
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(userInfoResponse));
+
+		server.start();
+
+		String userInfoUri = server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		try {
+			this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
+		} finally {
+			server.shutdown();
+		}
+	}
+
+	@Test
+	public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_user_info_response"));
+
+		MockWebServer server = new MockWebServer();
+		server.enqueue(new MockResponse().setResponseCode(500));
+		server.start();
+
+		String userInfoUri = server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		try {
+			this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
+		} finally {
+			server.shutdown();
+		}
+	}
+
+	@Test
+	public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
+		this.exception.expect(AuthenticationServiceException.class);
+
+		String userInfoUri = "http://invalid-provider.com/user";
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
+	}
+}

+ 354 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java

@@ -0,0 +1,354 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.registration;
+
+import org.junit.Test;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+
+import java.util.Arrays;
+import java.util.LinkedHashSet;
+import java.util.Set;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests for {@link ClientRegistration}.
+ *
+ * @author Joe Grandja
+ */
+public class ClientRegistrationTests {
+	private static final String REGISTRATION_ID = "registration-1";
+	private static final String CLIENT_ID = "client-1";
+	private static final String CLIENT_SECRET = "secret";
+	private static final String REDIRECT_URI = "https://example.com";
+	private static final Set<String> SCOPES = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email"));
+	private static final String AUTHORIZATION_URI = "https://provider.com/oauth2/authorization";
+	private static final String TOKEN_URI = "https://provider.com/oauth2/token";
+	private static final String JWK_SET_URI = "https://provider.com/oauth2/keys";
+	private static final String CLIENT_NAME = "Client 1";
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationGrantTypeIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(null)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.jwkSetUri(JWK_SET_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test
+	public void buildWhenAuthorizationCodeGrantAllAttributesProvidedThenAllAttributesAreSet() {
+		ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.jwkSetUri(JWK_SET_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+
+		assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID);
+		assertThat(registration.getClientId()).isEqualTo(CLIENT_ID);
+		assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET);
+		assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC);
+		assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+		assertThat(registration.getRedirectUri()).isEqualTo(REDIRECT_URI);
+		assertThat(registration.getScopes()).isEqualTo(SCOPES);
+		assertThat(registration.getProviderDetails().getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI);
+		assertThat(registration.getProviderDetails().getTokenUri()).isEqualTo(TOKEN_URI);
+		assertThat(registration.getProviderDetails().getJwkSetUri()).isEqualTo(JWK_SET_URI);
+		assertThat(registration.getClientName()).isEqualTo(CLIENT_NAME);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationCodeGrantRegistrationIdIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(null)
+			.clientId(CLIENT_ID)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.jwkSetUri(JWK_SET_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationCodeGrantClientIdIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(null)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.jwkSetUri(JWK_SET_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationCodeGrantClientSecretIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSecret(null)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.jwkSetUri(JWK_SET_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(null)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.jwkSetUri(JWK_SET_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationCodeGrantRedirectUriIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(null)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.jwkSetUri(JWK_SET_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationCodeGrantScopeIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(REDIRECT_URI)
+			.scope(null)
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.jwkSetUri(JWK_SET_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationCodeGrantAuthorizationUriIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(null)
+			.tokenUri(TOKEN_URI)
+			.jwkSetUri(JWK_SET_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationCodeGrantTokenUriIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(null)
+			.jwkSetUri(JWK_SET_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationCodeGrantJwkSetUriIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.jwkSetUri(null)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationCodeGrantClientNameIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.jwkSetUri(JWK_SET_URI)
+			.clientName(null)
+			.build();
+	}
+
+	@Test
+	public void buildWhenAuthorizationCodeGrantScopeDoesNotContainOpenidThenJwkSetUriNotRequired() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.clientSecret(CLIENT_SECRET)
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri(REDIRECT_URI)
+			.scope("scope1")
+			.authorizationUri(AUTHORIZATION_URI)
+			.tokenUri(TOKEN_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test
+	public void buildWhenImplicitGrantAllAttributesProvidedThenAllAttributesAreSet() {
+		ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+
+		assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID);
+		assertThat(registration.getClientId()).isEqualTo(CLIENT_ID);
+		assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.IMPLICIT);
+		assertThat(registration.getRedirectUri()).isEqualTo(REDIRECT_URI);
+		assertThat(registration.getScopes()).isEqualTo(SCOPES);
+		assertThat(registration.getProviderDetails().getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI);
+		assertThat(registration.getClientName()).isEqualTo(CLIENT_NAME);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenImplicitGrantRegistrationIdIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(null)
+			.clientId(CLIENT_ID)
+			.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenImplicitGrantClientIdIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(null)
+			.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenImplicitGrantRedirectUriIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
+			.redirectUri(null)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenImplicitGrantScopeIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
+			.redirectUri(REDIRECT_URI)
+			.scope(null)
+			.authorizationUri(AUTHORIZATION_URI)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenImplicitGrantAuthorizationUriIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(null)
+			.clientName(CLIENT_NAME)
+			.build();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenImplicitGrantClientNameIsNullThenThrowIllegalArgumentException() {
+		ClientRegistration.withRegistrationId(REGISTRATION_ID)
+			.clientId(CLIENT_ID)
+			.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
+			.redirectUri(REDIRECT_URI)
+			.scope(SCOPES.toArray(new String[0]))
+			.authorizationUri(AUTHORIZATION_URI)
+			.clientName(null)
+			.build();
+	}
+}

+ 3 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java

@@ -24,9 +24,11 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
-import static org.assertj.core.api.Assertions.*;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /**
+ * Tests for {@link InMemoryClientRegistrationRepository}.
+ *
  * @author Rob Winch
  * @since 5.0
  */

+ 267 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java

@@ -0,0 +1,267 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.userinfo;
+
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PowerMockIgnore;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.security.authentication.AuthenticationServiceException;
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link CustomUserTypesOAuth2UserService}.
+ *
+ * @author Joe Grandja
+ */
+@PowerMockIgnore("okhttp3.*")
+@PrepareForTest(ClientRegistration.class)
+@RunWith(PowerMockRunner.class)
+public class CustomUserTypesOAuth2UserServiceTests {
+	private ClientRegistration clientRegistration;
+	private ClientRegistration.ProviderDetails providerDetails;
+	private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
+	private OAuth2AccessToken accessToken;
+	private CustomUserTypesOAuth2UserService userService;
+
+	@Rule
+	public ExpectedException exception = ExpectedException.none();
+
+	@Before
+	public void setUp() throws Exception {
+		this.clientRegistration = mock(ClientRegistration.class);
+		this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
+		this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
+		when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
+		when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint);
+		String registrationId = "client-registration-id-1";
+		when(this.clientRegistration.getRegistrationId()).thenReturn(registrationId);
+		this.accessToken = mock(OAuth2AccessToken.class);
+
+		Map<String, Class<? extends OAuth2User>> customUserTypes = new HashMap<>();
+		customUserTypes.put(registrationId, CustomOAuth2User.class);
+		this.userService = new CustomUserTypesOAuth2UserService(customUserTypes);
+	}
+
+	@Test
+	public void constructorWhenCustomUserTypesIsNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		new CustomUserTypesOAuth2UserService(null);
+	}
+
+	@Test
+	public void constructorWhenCustomUserTypesIsEmptyThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		new CustomUserTypesOAuth2UserService(Collections.emptyMap());
+	}
+
+	@Test
+	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		this.userService.loadUser(null);
+	}
+
+	@Test
+	public void loadUserWhenCustomUserTypeNotFoundThenReturnNull() {
+		when(this.clientRegistration.getRegistrationId()).thenReturn("other-client-registration-id-1");
+
+		OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+		assertThat(user).isNull();
+	}
+
+	@Test
+	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
+		MockWebServer server = new MockWebServer();
+
+		String userInfoResponse = "{\n" +
+			"	\"id\": \"12345\",\n" +
+			"   \"name\": \"first last\",\n" +
+			"   \"login\": \"user1\",\n" +
+			"   \"email\": \"user1@example.com\"\n" +
+			"}\n";
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(userInfoResponse));
+
+		server.start();
+
+		String userInfoUri = server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+
+		server.shutdown();
+
+		assertThat(user.getName()).isEqualTo("first last");
+		assertThat(user.getAttributes().size()).isEqualTo(4);
+		assertThat(user.getAttributes().get("id")).isEqualTo("12345");
+		assertThat(user.getAttributes().get("name")).isEqualTo("first last");
+		assertThat(user.getAttributes().get("login")).isEqualTo("user1");
+		assertThat(user.getAttributes().get("email")).isEqualTo("user1@example.com");
+
+		assertThat(user.getAuthorities().size()).isEqualTo(1);
+		assertThat(user.getAuthorities().iterator().next().getAuthority()).isEqualTo("ROLE_USER");
+	}
+
+	@Test
+	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_user_info_response"));
+
+		MockWebServer server = new MockWebServer();
+
+		String userInfoResponse = "{\n" +
+			"	\"id\": \"12345\",\n" +
+			"   \"name\": \"first last\",\n" +
+			"   \"login\": \"user1\",\n" +
+			"   \"email\": \"user1@example.com\"\n";
+//			"}\n";		// Make the JSON invalid/malformed
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(userInfoResponse));
+
+		server.start();
+
+		String userInfoUri = server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		try {
+			this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+		} finally {
+			server.shutdown();
+		}
+	}
+
+	@Test
+	public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_user_info_response"));
+
+		MockWebServer server = new MockWebServer();
+		server.enqueue(new MockResponse().setResponseCode(500));
+		server.start();
+
+		String userInfoUri = server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		try {
+			this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+		} finally {
+			server.shutdown();
+		}
+	}
+
+	@Test
+	public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
+		this.exception.expect(AuthenticationServiceException.class);
+
+		String userInfoUri = "http://invalid-provider.com/user";
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+	}
+
+	public static class CustomOAuth2User implements OAuth2User {
+		private List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		private String id;
+		private String name;
+		private String login;
+		private String email;
+
+		public CustomOAuth2User() {
+		}
+
+		@Override
+		public Collection<? extends GrantedAuthority> getAuthorities() {
+			return this.authorities;
+		}
+
+		@Override
+		public Map<String, Object> getAttributes() {
+			Map<String, Object> attributes = new HashMap<>();
+			attributes.put("id", this.getId());
+			attributes.put("name", this.getName());
+			attributes.put("login", this.getLogin());
+			attributes.put("email", this.getEmail());
+			return attributes;
+		}
+
+		public String getId() {
+			return this.id;
+		}
+
+		public void setId(String id) {
+			this.id = id;
+		}
+
+		@Override
+		public String getName() {
+			return this.name;
+		}
+
+		public void setName(String name) {
+			this.name = name;
+		}
+
+		public String getLogin() {
+			return this.login;
+		}
+
+		public void setLogin(String login) {
+			this.login = login;
+		}
+
+		public String getEmail() {
+			return this.email;
+		}
+
+		public void setEmail(String email) {
+			this.email = email;
+		}
+	}
+}

+ 197 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java

@@ -0,0 +1,197 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.userinfo;
+
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PowerMockIgnore;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.security.authentication.AuthenticationServiceException;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link DefaultOAuth2UserService}.
+ *
+ * @author Joe Grandja
+ */
+@PowerMockIgnore("okhttp3.*")
+@PrepareForTest(ClientRegistration.class)
+@RunWith(PowerMockRunner.class)
+public class DefaultOAuth2UserServiceTests {
+	private ClientRegistration clientRegistration;
+	private ClientRegistration.ProviderDetails providerDetails;
+	private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
+	private OAuth2AccessToken accessToken;
+	private DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
+
+	@Rule
+	public ExpectedException exception = ExpectedException.none();
+
+	@Before
+	public void setUp() throws Exception {
+		this.clientRegistration = mock(ClientRegistration.class);
+		this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
+		this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
+		when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
+		when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint);
+		this.accessToken = mock(OAuth2AccessToken.class);
+	}
+
+	@Test
+	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		this.userService.loadUser(null);
+	}
+
+	@Test
+	public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("missing_user_name_attribute"));
+
+		when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(null);
+		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+	}
+
+	@Test
+	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
+		MockWebServer server = new MockWebServer();
+
+		String userInfoResponse = "{\n" +
+			"	\"user-name\": \"user1\",\n" +
+			"   \"first-name\": \"first\",\n" +
+			"   \"last-name\": \"last\",\n" +
+			"   \"middle-name\": \"middle\",\n" +
+			"   \"address\": \"address\",\n" +
+			"   \"email\": \"user1@example.com\"\n" +
+			"}\n";
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(userInfoResponse));
+
+		server.start();
+
+		String userInfoUri = server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+
+		server.shutdown();
+
+		assertThat(user.getName()).isEqualTo("user1");
+		assertThat(user.getAttributes().size()).isEqualTo(6);
+		assertThat(user.getAttributes().get("user-name")).isEqualTo("user1");
+		assertThat(user.getAttributes().get("first-name")).isEqualTo("first");
+		assertThat(user.getAttributes().get("last-name")).isEqualTo("last");
+		assertThat(user.getAttributes().get("middle-name")).isEqualTo("middle");
+		assertThat(user.getAttributes().get("address")).isEqualTo("address");
+		assertThat(user.getAttributes().get("email")).isEqualTo("user1@example.com");
+
+		assertThat(user.getAuthorities().size()).isEqualTo(1);
+		assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OAuth2UserAuthority.class);
+		OAuth2UserAuthority userAuthority = (OAuth2UserAuthority)user.getAuthorities().iterator().next();
+		assertThat(userAuthority.getAuthority()).isEqualTo("ROLE_USER");
+		assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
+	}
+
+	@Test
+	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_user_info_response"));
+
+		MockWebServer server = new MockWebServer();
+
+		String userInfoResponse = "{\n" +
+			"	\"user-name\": \"user1\",\n" +
+			"   \"first-name\": \"first\",\n" +
+			"   \"last-name\": \"last\",\n" +
+			"   \"middle-name\": \"middle\",\n" +
+			"   \"address\": \"address\",\n" +
+			"   \"email\": \"user1@example.com\"\n";
+//			"}\n";		// Make the JSON invalid/malformed
+		server.enqueue(new MockResponse()
+			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(userInfoResponse));
+
+		server.start();
+
+		String userInfoUri = server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		try {
+			this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+		} finally {
+			server.shutdown();
+		}
+	}
+
+	@Test
+	public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("invalid_user_info_response"));
+
+		MockWebServer server = new MockWebServer();
+		server.enqueue(new MockResponse().setResponseCode(500));
+		server.start();
+
+		String userInfoUri = server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		try {
+			this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+		} finally {
+			server.shutdown();
+		}
+	}
+
+	@Test
+	public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
+		this.exception.expect(AuthenticationServiceException.class);
+
+		String userInfoUri = "http://invalid-provider.com/user";
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+	}
+}

+ 84 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DelegatingOAuth2UserServiceTests.java

@@ -0,0 +1,84 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.userinfo;
+
+import org.junit.Test;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link DelegatingOAuth2UserService}.
+ *
+ * @author Joe Grandja
+ */
+public class DelegatingOAuth2UserServiceTests {
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenUserServicesIsNullThenThrowIllegalArgumentException() {
+		new DelegatingOAuth2UserService<>(null);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenUserServicesIsEmptyThenThrowIllegalArgumentException() {
+		new DelegatingOAuth2UserService<>(Collections.emptyList());
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	@SuppressWarnings("unchecked")
+	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
+		DelegatingOAuth2UserService<OAuth2UserRequest, OAuth2User> delegatingUserService =
+			new DelegatingOAuth2UserService<>(
+				Arrays.asList(mock(OAuth2UserService.class), mock(OAuth2UserService.class)));
+		delegatingUserService.loadUser(null);
+	}
+
+	@Test
+	@SuppressWarnings("unchecked")
+	public void loadUserWhenUserServiceCanLoadThenReturnUser() {
+		OAuth2UserService<OAuth2UserRequest, OAuth2User> userService1 = mock(OAuth2UserService.class);
+		OAuth2UserService<OAuth2UserRequest, OAuth2User> userService2 = mock(OAuth2UserService.class);
+		OAuth2UserService<OAuth2UserRequest, OAuth2User> userService3 = mock(OAuth2UserService.class);
+		OAuth2User mockUser = mock(OAuth2User.class);
+		when(userService3.loadUser(any(OAuth2UserRequest.class))).thenReturn(mockUser);
+
+		DelegatingOAuth2UserService<OAuth2UserRequest, OAuth2User> delegatingUserService =
+			new DelegatingOAuth2UserService<>(Arrays.asList(userService1, userService2, userService3));
+
+		OAuth2User loadedUser = delegatingUserService.loadUser(mock(OAuth2UserRequest.class));
+		assertThat(loadedUser).isEqualTo(mockUser);
+	}
+
+	@Test
+	@SuppressWarnings("unchecked")
+	public void loadUserWhenUserServiceCannotLoadThenReturnNull() {
+		OAuth2UserService<OAuth2UserRequest, OAuth2User> userService1 = mock(OAuth2UserService.class);
+		OAuth2UserService<OAuth2UserRequest, OAuth2User> userService2 = mock(OAuth2UserService.class);
+		OAuth2UserService<OAuth2UserRequest, OAuth2User> userService3 = mock(OAuth2UserService.class);
+
+		DelegatingOAuth2UserService<OAuth2UserRequest, OAuth2User> delegatingUserService =
+			new DelegatingOAuth2UserService<>(Arrays.asList(userService1, userService2, userService3));
+
+		OAuth2User loadedUser = delegatingUserService.loadUser(mock(OAuth2UserRequest.class));
+		assertThat(loadedUser).isNull();
+	}
+}

+ 63 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java

@@ -0,0 +1,63 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.userinfo;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link OAuth2UserRequest}.
+ *
+ * @author Joe Grandja
+ */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest(ClientRegistration.class)
+public class OAuth2UserRequestTests {
+	private ClientRegistration clientRegistration;
+	private OAuth2AccessToken accessToken;
+
+	@Before
+	public void setUp() {
+		this.clientRegistration = mock(ClientRegistration.class);
+		this.accessToken = mock(OAuth2AccessToken.class);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
+		new OAuth2UserRequest(null, this.accessToken);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
+		new OAuth2UserRequest(this.clientRegistration, null);
+	}
+
+	@Test
+	public void constructorWhenAllParametersProvidedAndValidThenCreated() {
+		OAuth2UserRequest userRequest = new OAuth2UserRequest(this.clientRegistration, this.accessToken);
+
+		assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
+	}
+}

+ 138 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java

@@ -0,0 +1,138 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.web;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}.
+ *
+ * @author Joe Grandja
+ */
+@PrepareForTest(OAuth2AuthorizationRequest.class)
+@RunWith(PowerMockRunner.class)
+public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
+	private HttpSessionOAuth2AuthorizationRequestRepository authorizationRequestRepository =
+		new HttpSessionOAuth2AuthorizationRequestRepository();
+
+	@Test(expected = IllegalArgumentException.class)
+	public void loadAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
+		this.authorizationRequestRepository.loadAuthorizationRequest(null);
+	}
+
+	@Test
+	public void loadAuthorizationRequestWhenNotSavedThenReturnNull() {
+		OAuth2AuthorizationRequest authorizationRequest =
+			this.authorizationRequestRepository.loadAuthorizationRequest(new MockHttpServletRequest());
+
+		assertThat(authorizationRequest).isNull();
+	}
+
+	@Test
+	public void loadAuthorizationRequestWhenSavedThenReturnAuthorizationRequest() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest =
+			this.authorizationRequestRepository.loadAuthorizationRequest(request);
+
+		assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void saveAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
+		this.authorizationRequestRepository.saveAuthorizationRequest(
+			mock(OAuth2AuthorizationRequest.class), null, new MockHttpServletResponse());
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void saveAuthorizationRequestWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() {
+		this.authorizationRequestRepository.saveAuthorizationRequest(
+			mock(OAuth2AuthorizationRequest.class), new MockHttpServletRequest(), null);
+	}
+
+	@Test
+	public void saveAuthorizationRequestWhenNotNullThenSaved() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+
+		this.authorizationRequestRepository.saveAuthorizationRequest(
+			authorizationRequest, request, new MockHttpServletResponse());
+		OAuth2AuthorizationRequest loadedAuthorizationRequest =
+			this.authorizationRequestRepository.loadAuthorizationRequest(request);
+
+		assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
+	}
+
+	@Test
+	public void saveAuthorizationRequestWhenNullThenRemoved() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+
+		this.authorizationRequestRepository.saveAuthorizationRequest(		// Save
+			authorizationRequest, request, response);
+		this.authorizationRequestRepository.saveAuthorizationRequest(		// Null value removes
+			null, request, response);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest =
+			this.authorizationRequestRepository.loadAuthorizationRequest(request);
+
+		assertThat(loadedAuthorizationRequest).isNull();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void removeAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
+		this.authorizationRequestRepository.removeAuthorizationRequest(null);
+	}
+
+	@Test
+	public void removeAuthorizationRequestWhenSavedThenRemoved() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+
+		this.authorizationRequestRepository.saveAuthorizationRequest(
+			authorizationRequest, request, response);
+		OAuth2AuthorizationRequest removedAuthorizationRequest =
+			this.authorizationRequestRepository.removeAuthorizationRequest(request);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest =
+			this.authorizationRequestRepository.loadAuthorizationRequest(request);
+
+		assertThat(removedAuthorizationRequest).isNotNull();
+		assertThat(loadedAuthorizationRequest).isNull();
+	}
+
+	@Test
+	public void removeAuthorizationRequestWhenNotSavedThenNotRemoved() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+
+		OAuth2AuthorizationRequest removedAuthorizationRequest =
+			this.authorizationRequestRepository.removeAuthorizationRequest(request);
+
+		assertThat(removedAuthorizationRequest).isNull();
+	}
+}

+ 202 - 52
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java

@@ -15,109 +15,259 @@
  */
 package org.springframework.security.oauth2.client.web;
 
-import org.assertj.core.api.Assertions;
+import org.junit.Before;
 import org.junit.Test;
-import org.mockito.Matchers;
-import org.mockito.Mockito;
+import org.springframework.http.HttpStatus;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 
 import javax.servlet.FilterChain;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.*;
+
 /**
- * Tests {@link OAuth2AuthorizationRequestRedirectFilter}.
+ * Tests for {@link OAuth2AuthorizationRequestRedirectFilter}.
  *
  * @author Joe Grandja
  */
 public class OAuth2AuthorizationRequestRedirectFilterTests {
+	private ClientRegistration registration1;
+	private ClientRegistration registration2;
+	private ClientRegistration registration3;
+	private ClientRegistrationRepository clientRegistrationRepository;
+	private OAuth2AuthorizationRequestRedirectFilter filter;
+
+	@Before
+	public void setUp() {
+		this.registration1 = ClientRegistration.withRegistrationId("registration-1")
+			.clientId("client-1")
+			.clientSecret("secret")
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
+			.scope("user")
+			.authorizationUri("https://provider.com/oauth2/authorize")
+			.tokenUri("https://provider.com/oauth2/token")
+			.userInfoUri("https://provider.com/oauth2/user")
+			.userNameAttributeName("id")
+			.clientName("client-1")
+			.build();
+		this.registration2 = ClientRegistration.withRegistrationId("registration-2")
+			.clientId("client-2")
+			.clientSecret("secret")
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
+			.scope("openid", "profile", "email")
+			.authorizationUri("https://provider.com/oauth2/authorize")
+			.tokenUri("https://provider.com/oauth2/token")
+			.userInfoUri("https://provider.com/oauth2/userinfo")
+			.jwkSetUri("https://provider.com/oauth2/keys")
+			.clientName("client-2")
+			.build();
+		this.registration3 = ClientRegistration.withRegistrationId("registration-3")
+			.clientId("client-3")
+			.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
+			.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/implicit/{registrationId}")
+			.scope("openid", "profile", "email")
+			.authorizationUri("https://provider.com/oauth2/authorize")
+			.tokenUri("https://provider.com/oauth2/token")
+			.userInfoUri("https://provider.com/oauth2/userinfo")
+			.clientName("client-3")
+			.build();
+		this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
+			this.registration1, this.registration2, this.registration3);
+		this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository);
+	}
 
 	@Test(expected = IllegalArgumentException.class)
 	public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
 		new OAuth2AuthorizationRequestRedirectFilter(null);
 	}
 
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenAuthorizationRequestBaseUriIsNullThenThrowIllegalArgumentException() {
+		new OAuth2AuthorizationRequestRedirectFilter(null, this.clientRegistrationRepository);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() {
+		this.filter.setAuthorizationRequestRepository(null);
+	}
+
 	@Test
-	public void doFilterWhenRequestDoesNotMatchClientThenContinueChain() throws Exception {
-		ClientRegistration clientRegistration = TestUtil.googleClientRegistration();
-		String authorizationUri = clientRegistration.getProviderDetails().getAuthorizationUri().toString();
-		OAuth2AuthorizationRequestRedirectFilter filter =
-				setupFilter(authorizationUri, clientRegistration);
-
-		String requestURI = "/path";
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestURI);
-		request.setServletPath(requestURI);
+	public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception {
+		String requestUri = "/path";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
 		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = Mockito.mock(FilterChain.class);
+		FilterChain filterChain = mock(FilterChain.class);
 
-		filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(request, response, filterChain);
 
-		Mockito.verify(filterChain).doFilter(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
+		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
 	@Test
-	public void doFilterWhenRequestMatchesClientThenRedirectForAuthorization() throws Exception {
-		ClientRegistration clientRegistration = TestUtil.googleClientRegistration();
-		String authorizationUri = clientRegistration.getProviderDetails().getAuthorizationUri().toString();
-		OAuth2AuthorizationRequestRedirectFilter filter =
-				setupFilter(authorizationUri, clientRegistration);
+	public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusBadRequest() throws Exception {
+		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
+			"/" + this.registration1.getRegistrationId() + "-invalid";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyZeroInteractions(filterChain);
 
-		String requestUri = TestUtil.AUTHORIZATION_BASE_URI + "/" + clientRegistration.getRegistrationId();
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.BAD_REQUEST.getReasonPhrase());
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenRedirectForAuthorization() throws Exception {
+		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
+			"/" + this.registration1.getRegistrationId();
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = Mockito.mock(FilterChain.class);
+		FilterChain filterChain = mock(FilterChain.class);
 
-		filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(request, response, filterChain);
 
-		Mockito.verifyZeroInteractions(filterChain);        // Request should not proceed up the chain
+		verifyZeroInteractions(filterChain);
 
-		Assertions.assertThat(response.getRedirectedUrl()).matches("https://accounts.google.com/o/oauth2/auth\\?response_type=code&client_id=google-client-id&scope=openid%20email%20profile&state=.{15,}&redirect_uri=https://localhost:8080/login/oauth2/code/google");
+		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost:80/login/oauth2/code/registration-1");
 	}
 
 	@Test
-	public void doFilterWhenRequestMatchesClientThenAuthorizationRequestSavedInSession() throws Exception {
-		ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
-		String authorizationUri = clientRegistration.getProviderDetails().getAuthorizationUri().toString();
-		OAuth2AuthorizationRequestRedirectFilter filter =
-				setupFilter(authorizationUri, clientRegistration);
+	public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenAuthorizationRequestSavedInSession() throws Exception {
+		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
+			"/" + this.registration2.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
 		AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
 			new HttpSessionOAuth2AuthorizationRequestRepository();
-		filter.setAuthorizationRequestRepository(authorizationRequestRepository);
+		this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyZeroInteractions(filterChain);
+
+		OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
+
+		assertThat(authorizationRequest).isNotNull();
+		assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(
+			this.registration2.getProviderDetails().getAuthorizationUri());
+		assertThat(authorizationRequest.getGrantType()).isEqualTo(
+			this.registration2.getAuthorizationGrantType());
+		assertThat(authorizationRequest.getResponseType()).isEqualTo(
+			OAuth2AuthorizationResponseType.CODE);
+		assertThat(authorizationRequest.getClientId()).isEqualTo(
+			this.registration2.getClientId());
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+			"http://localhost:80/login/oauth2/code/registration-2");
+		assertThat(authorizationRequest.getScopes()).isEqualTo(
+			this.registration2.getScopes());
+		assertThat(authorizationRequest.getState()).isNotNull();
+		assertThat(authorizationRequest.getAdditionalParameters()
+			.get(OAuth2ParameterNames.REGISTRATION_ID)).isEqualTo(this.registration2.getRegistrationId());
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestImplicitGrantThenRedirectForAuthorization() throws Exception {
+		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
+			"/" + this.registration3.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyZeroInteractions(filterChain);
 
-		String requestUri = TestUtil.AUTHORIZATION_BASE_URI + "/" + clientRegistration.getRegistrationId();
+		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=token&client_id=client-3&scope=openid%20profile%20email&state=.{15,}&redirect_uri=http://localhost:80/login/oauth2/implicit/registration-3");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestImplicitGrantThenAuthorizationRequestNotSavedInSession() throws Exception {
+		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
+			"/" + this.registration3.getRegistrationId();
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = Mockito.mock(FilterChain.class);
+		FilterChain filterChain = mock(FilterChain.class);
+
+		AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
+			new HttpSessionOAuth2AuthorizationRequestRepository();
+		this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
 
-		filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(request, response, filterChain);
 
-		Mockito.verifyZeroInteractions(filterChain);        // Request should not proceed up the chain
+		verifyZeroInteractions(filterChain);
 
-		// The authorization request attributes are saved in the session before the redirect happens
-		OAuth2AuthorizationRequest authorizationRequest =
-				authorizationRequestRepository.loadAuthorizationRequest(request);
-		Assertions.assertThat(authorizationRequest).isNotNull();
+		OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
 
-		Assertions.assertThat(authorizationRequest.getAuthorizationUri()).isNotNull();
-		Assertions.assertThat(authorizationRequest.getGrantType()).isNotNull();
-		Assertions.assertThat(authorizationRequest.getResponseType()).isNotNull();
-		Assertions.assertThat(authorizationRequest.getClientId()).isNotNull();
-		Assertions.assertThat(authorizationRequest.getRedirectUri()).isNotNull();
-		Assertions.assertThat(authorizationRequest.getScopes()).isNotNull();
-		Assertions.assertThat(authorizationRequest.getState()).isNotNull();
+		assertThat(authorizationRequest).isNull();
 	}
 
-	private OAuth2AuthorizationRequestRedirectFilter setupFilter(String authorizationUri,
-																	ClientRegistration... clientRegistrations) throws Exception {
-		ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository(clientRegistrations);
-		OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(clientRegistrationRepository);
-		return filter;
+	@Test
+	public void doFilterWhenCustomAuthorizationRequestBaseUriThenRedirectForAuthorization() throws Exception {
+		String authorizationRequestBaseUri = "/custom/authorization";
+		this.filter = new OAuth2AuthorizationRequestRedirectFilter(authorizationRequestBaseUri, this.clientRegistrationRepository);
+
+		String requestUri = authorizationRequestBaseUri + "/" + this.registration1.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyZeroInteractions(filterChain);
+
+		assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost:80/login/oauth2/code/registration-1");
+	}
+
+	@Test
+	public void doFilterWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpanded() throws Exception {
+		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
+			"/" + this.registration2.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
+			new HttpSessionOAuth2AuthorizationRequestRepository();
+		this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyZeroInteractions(filterChain);
+
+		OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
+
+		assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(
+			this.registration2.getRedirectUri());
+		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+			"http://localhost:80/login/oauth2/code/registration-2");
 	}
 }

+ 7 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestUriBuilderTests.java

@@ -27,12 +27,19 @@ import java.util.HashSet;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /**
+ * Tests for {@link OAuth2AuthorizationRequestUriBuilder}.
+ *
  * @author Rob Winch
  * @since 5.0
  */
 public class OAuth2AuthorizationRequestUriBuilderTests {
 	private OAuth2AuthorizationRequestUriBuilder builder = new OAuth2AuthorizationRequestUriBuilder();
 
+	@Test(expected = IllegalArgumentException.class)
+	public void buildWhenAuthorizationRequestIsNullThenThrowIllegalArgumentException() {
+		this.builder.build(null);
+	}
+
 	@Test
 	public void buildWhenScopeMultiThenSeparatedByEncodedSpace() {
 		OAuth2AuthorizationRequest request = OAuth2AuthorizationRequest.implicit()

+ 190 - 134
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java

@@ -15,32 +15,36 @@
  */
 package org.springframework.security.oauth2.client.web;
 
-import org.assertj.core.api.Assertions;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
-import org.mockito.Matchers;
-import org.mockito.Mockito;
+import org.powermock.core.classloader.annotations.PowerMockIgnore;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.authority.AuthorityUtils;
-import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
-import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
-import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 
 import javax.servlet.FilterChain;
 import javax.servlet.http.HttpServletRequest;
@@ -48,183 +52,235 @@ import javax.servlet.http.HttpServletResponse;
 import java.util.HashMap;
 import java.util.Map;
 
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.*;
 
 /**
- * Tests {@link OAuth2LoginAuthenticationFilter}.
+ * Tests for {@link OAuth2LoginAuthenticationFilter}.
  *
  * @author Joe Grandja
  */
+@PowerMockIgnore("javax.security.*")
+@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class})
+@RunWith(PowerMockRunner.class)
 public class OAuth2LoginAuthenticationFilterTests {
+	private ClientRegistration registration1;
+	private ClientRegistration registration2;
+	private String principalName1 = "principal-1";
+	private ClientRegistrationRepository clientRegistrationRepository;
+	private OAuth2AuthorizedClientService authorizedClientService;
+	private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
+	private AuthenticationFailureHandler failureHandler;
+	private AuthenticationManager authenticationManager;
+	private OAuth2LoginAuthenticationFilter filter;
+
+	@Before
+	public void setUp() {
+		this.registration1 = ClientRegistration.withRegistrationId("registration-1")
+			.clientId("client-1")
+			.clientSecret("secret")
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
+			.scope("user")
+			.authorizationUri("https://provider.com/oauth2/authorize")
+			.tokenUri("https://provider.com/oauth2/token")
+			.userInfoUri("https://provider.com/oauth2/user")
+			.userNameAttributeName("id")
+			.clientName("client-1")
+			.build();
+		this.registration2 = ClientRegistration.withRegistrationId("registration-2")
+			.clientId("client-2")
+			.clientSecret("secret")
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
+			.scope("openid", "profile", "email")
+			.authorizationUri("https://provider.com/oauth2/authorize")
+			.tokenUri("https://provider.com/oauth2/token")
+			.userInfoUri("https://provider.com/oauth2/userinfo")
+			.jwkSetUri("https://provider.com/oauth2/keys")
+			.clientName("client-2")
+			.build();
+		this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
+			this.registration1, this.registration2);
+		this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
+		this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
+		this.failureHandler = mock(AuthenticationFailureHandler.class);
+		this.authenticationManager = mock(AuthenticationManager.class);
+		this.filter = spy(new OAuth2LoginAuthenticationFilter(
+			this.clientRegistrationRepository, this.authorizedClientService));
+		this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
+		this.filter.setAuthenticationFailureHandler(this.failureHandler);
+		this.filter.setAuthenticationManager(this.authenticationManager);
+	}
 
-	@Test
-	public void doFilterWhenNotAuthorizationCodeResponseThenContinueChain() throws Exception {
-		ClientRegistration clientRegistration = TestUtil.googleClientRegistration();
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
+		new OAuth2LoginAuthenticationFilter(null, this.authorizedClientService);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
+		new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, null);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorWhenFilterProcessesUrlIsNullThenThrowIllegalArgumentException() {
+		new OAuth2LoginAuthenticationFilter(null, this.clientRegistrationRepository, this.authorizedClientService);
+	}
 
-		OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
+	@Test(expected = IllegalArgumentException.class)
+	public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() {
+		this.filter.setAuthorizationRequestRepository(null);
+	}
 
-		String requestURI = "/path";
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestURI);
-		request.setServletPath(requestURI);
+	@Test
+	public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
+		String requestUri = "/path";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
-		filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(request, response, filterChain);
 
-		Mockito.verify(filterChain).doFilter(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
-		Mockito.verify(filter, Mockito.never()).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
+		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+		verify(this.filter, never()).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
 	@Test
-	public void doFilterWhenAuthorizationCodeErrorResponseThenAuthenticationFailureHandlerIsCalled() throws Exception {
-		ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
+	public void doFilterWhenAuthorizationResponseInvalidThenInvalidRequestError() throws Exception {
+		String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		// NOTE:
+		// A valid Authorization Response contains either a 'code' or 'error' parameter.
+		// Don't set it to force an invalid Authorization Response.
 
-		OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
-		AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
-		filter.setAuthenticationFailureHandler(failureHandler);
-
-		MockHttpServletRequest request = this.setupRequest(clientRegistration);
-		String errorCode = OAuth2ErrorCodes.INVALID_GRANT;
-		request.addParameter(OAuth2ParameterNames.ERROR, errorCode);
-		request.addParameter(OAuth2ParameterNames.STATE, "some state");
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
-		filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(request, response, filterChain);
+
+		ArgumentCaptor<AuthenticationException> authenticationExceptionArgCaptor = ArgumentCaptor.forClass(AuthenticationException.class);
+		verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), any(HttpServletResponse.class),
+			authenticationExceptionArgCaptor.capture());
 
-		Mockito.verify(filter).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
-		Mockito.verify(failureHandler).onAuthenticationFailure(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
-				Matchers.any(AuthenticationException.class));
+		assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class);
+		OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor.getValue();
+		assertThat(authenticationException.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
 	}
 
 	@Test
-	public void doFilterWhenAuthorizationCodeSuccessResponseThenAuthenticationSuccessHandlerIsCalled() throws Exception {
-		ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
-		OAuth2User oauth2User = mock(OAuth2User.class);
-		when(oauth2User.getName()).thenReturn("principal name");
-		OAuth2LoginAuthenticationToken loginAuthentication = mock(OAuth2LoginAuthenticationToken.class);
-		when(loginAuthentication.getPrincipal()).thenReturn(oauth2User);
-		when(loginAuthentication.getClientRegistration()).thenReturn(clientRegistration);
-		when(loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
+	public void doFilterWhenAuthorizationResponseAuthorizationRequestNotFoundThenAuthorizationRequestNotFoundError() throws Exception {
+		String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(OAuth2ParameterNames.STATE, "state");
 
-		OAuth2AuthenticationToken userAuthentication = new OAuth2AuthenticationToken(
-			oauth2User, AuthorityUtils.NO_AUTHORITIES, clientRegistration.getRegistrationId());
-		SecurityContextHolder.getContext().setAuthentication(userAuthentication);
-		AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
-		when(authenticationManager.authenticate(Matchers.any(Authentication.class))).thenReturn(loginAuthentication);
-
-		OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(authenticationManager, clientRegistration));
-		AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class);
-		filter.setAuthenticationSuccessHandler(successHandler);
-		AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
-			new HttpSessionOAuth2AuthorizationRequestRepository();
-		filter.setAuthorizationRequestRepository(authorizationRequestRepository);
-
-		MockHttpServletRequest request = this.setupRequest(clientRegistration);
-		String authCode = "some code";
-		String state = "some state";
-		request.addParameter(OAuth2ParameterNames.CODE, authCode);
-		request.addParameter(OAuth2ParameterNames.STATE, state);
 		MockHttpServletResponse response = new MockHttpServletResponse();
-		setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state);
 		FilterChain filterChain = mock(FilterChain.class);
 
-		filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(request, response, filterChain);
 
-		Mockito.verify(filter).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
+		ArgumentCaptor<AuthenticationException> authenticationExceptionArgCaptor = ArgumentCaptor.forClass(AuthenticationException.class);
+		verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), any(HttpServletResponse.class),
+			authenticationExceptionArgCaptor.capture());
 
-		ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
-		Mockito.verify(successHandler).onAuthenticationSuccess(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
-				authenticationArgCaptor.capture());
-		Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(userAuthentication);
+		assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class);
+		OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor.getValue();
+		assertThat(authenticationException.getError().getErrorCode()).isEqualTo("authorization_request_not_found");
 	}
 
 	@Test
-	public void doFilterWhenAuthorizationCodeSuccessResponseAndNoMatchingAuthorizationRequestThenThrowOAuth2AuthenticationExceptionAuthorizationRequestNotFound() throws Exception {
-		ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
-
-		OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
-		AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
-		filter.setAuthenticationFailureHandler(failureHandler);
-
-		MockHttpServletRequest request = this.setupRequest(clientRegistration);
-		String authCode = "some code";
-		String state = "some state";
-		request.addParameter(OAuth2ParameterNames.CODE, authCode);
-		request.addParameter(OAuth2ParameterNames.STATE, state);
+	public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
+		String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(OAuth2ParameterNames.STATE, "state");
+
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
-		filter.doFilter(request, response, filterChain);
+		this.setUpAuthorizationRequest(request, response, this.registration2);
+		this.setUpAuthenticationResult(this.registration2);
 
-		verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(filter, failureHandler, "authorization_request_not_found");
-	}
+		this.filter.doFilter(request, response, filterChain);
 
-	private void verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(OAuth2LoginAuthenticationFilter filter,
-																		AuthenticationFailureHandler failureHandler,
-																		String errorCode) throws Exception {
-
-		Mockito.verify(filter).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
-
-		ArgumentCaptor<AuthenticationException> authenticationExceptionArgCaptor =
-				ArgumentCaptor.forClass(AuthenticationException.class);
-		Mockito.verify(failureHandler).onAuthenticationFailure(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
-				authenticationExceptionArgCaptor.capture());
-		Assertions.assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class);
-		OAuth2AuthenticationException oauth2AuthenticationException =
-				(OAuth2AuthenticationException)authenticationExceptionArgCaptor.getValue();
-		Assertions.assertThat(oauth2AuthenticationException.getError()).isNotNull();
-		Assertions.assertThat(oauth2AuthenticationException.getError().getErrorCode()).isEqualTo(errorCode);
+		assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull();
 	}
 
-	private OAuth2LoginAuthenticationFilter setupFilter(ClientRegistration... clientRegistrations) throws Exception {
-		AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
+	@Test
+	public void doFilterWhenAuthorizationResponseValidThenAuthorizedClientSaved() throws Exception {
+		String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(OAuth2ParameterNames.STATE, "state");
+
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.setUpAuthorizationRequest(request, response, this.registration1);
+		this.setUpAuthenticationResult(this.registration1);
+
+		this.filter.doFilter(request, response, filterChain);
 
-		return setupFilter(authenticationManager, clientRegistrations);
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+			this.registration1.getRegistrationId(), this.principalName1);
+		assertThat(authorizedClient).isNotNull();
+		assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
+		assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName1);
+		assertThat(authorizedClient.getAccessToken()).isNotNull();
 	}
 
-	private OAuth2LoginAuthenticationFilter setupFilter(
-			AuthenticationManager authenticationManager, ClientRegistration... clientRegistrations) throws Exception {
+	@Test
+	public void doFilterWhenCustomFilterProcessesUrlThenFilterProcesses() throws Exception {
+		String filterProcessesUrl = "/login/oauth2/custom/*";
+		this.filter = spy(new OAuth2LoginAuthenticationFilter(filterProcessesUrl,
+			this.clientRegistrationRepository, this.authorizedClientService));
+		this.filter.setAuthenticationManager(this.authenticationManager);
+
+		String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId();
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(OAuth2ParameterNames.STATE, "state");
+
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
 
-		ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository(clientRegistrations);
+		this.setUpAuthorizationRequest(request, response, this.registration2);
+		this.setUpAuthenticationResult(this.registration2);
 
-		OAuth2LoginAuthenticationFilter filter = new OAuth2LoginAuthenticationFilter(
-			clientRegistrationRepository, mock(OAuth2AuthorizedClientService.class));
-		filter.setAuthenticationManager(authenticationManager);
+		this.filter.doFilter(request, response, filterChain);
 
-		return filter;
+		verifyZeroInteractions(filterChain);
+		verify(this.filter).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
-	private void setupAuthorizationRequest(AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository,
-											HttpServletRequest request,
-											HttpServletResponse response,
-											ClientRegistration clientRegistration,
-											String state) {
-
-		Map<String,Object> additionalParameters = new HashMap<>();
-		additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
-
-		OAuth2AuthorizationRequest authorizationRequest =
-			OAuth2AuthorizationRequest.authorizationCode()
-				.clientId(clientRegistration.getClientId())
-				.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
-				.redirectUri(clientRegistration.getRedirectUri())
-				.scopes(clientRegistration.getScopes())
-				.state(state)
-				.additionalParameters(additionalParameters)
-				.build();
-
-		authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
+	private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
+											ClientRegistration registration) {
+		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
+		when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
 	}
 
-	private MockHttpServletRequest setupRequest(ClientRegistration clientRegistration) {
-		String requestURI = TestUtil.AUTHORIZE_BASE_URI + "/" + clientRegistration.getRegistrationId();
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestURI);
-		request.setScheme(TestUtil.DEFAULT_SCHEME);
-		request.setServerName(TestUtil.DEFAULT_SERVER_NAME);
-		request.setServerPort(TestUtil.DEFAULT_SERVER_PORT);
-		request.setServletPath(requestURI);
-		return request;
+	private void setUpAuthenticationResult(ClientRegistration registration) {
+		OAuth2User user = mock(OAuth2User.class);
+		when(user.getName()).thenReturn(this.principalName1);
+		OAuth2LoginAuthenticationToken loginAuthentication = mock(OAuth2LoginAuthenticationToken.class);
+		when(loginAuthentication.getPrincipal()).thenReturn(user);
+		when(loginAuthentication.getAuthorities()).thenReturn(AuthorityUtils.createAuthorityList("ROLE_USER"));
+		when(loginAuthentication.getClientRegistration()).thenReturn(registration);
+		when(loginAuthentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class));
+		when(loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
+		when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(loginAuthentication);
 	}
 }

+ 0 - 71
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/TestUtil.java

@@ -1,71 +0,0 @@
-/*
- * Copyright 2002-2017 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.client.web;
-
-import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.core.AuthorizationGrantType;
-
-
-/**
- * @author Joe Grandja
- */
-class TestUtil {
-	static final String DEFAULT_SCHEME = "https";
-	static final String DEFAULT_SERVER_NAME = "localhost";
-	static final int DEFAULT_SERVER_PORT = 8080;
-	static final String DEFAULT_SERVER_URL = DEFAULT_SCHEME + "://" + DEFAULT_SERVER_NAME + ":" + DEFAULT_SERVER_PORT;
-	static final String AUTHORIZATION_BASE_URI = "/oauth2/authorization";
-	static final String AUTHORIZE_BASE_URI = "/login/oauth2/code";
-	static final String GOOGLE_REGISTRATION_ID = "google";
-	static final String GITHUB_REGISTRATION_ID = "github";
-
-	static ClientRegistration googleClientRegistration() {
-		return googleClientRegistration(DEFAULT_SERVER_URL + AUTHORIZE_BASE_URI + "/" + GOOGLE_REGISTRATION_ID);
-	}
-
-	static ClientRegistration googleClientRegistration(String redirectUri) {
-		return ClientRegistration.withRegistrationId(GOOGLE_REGISTRATION_ID)
-			.clientId("google-client-id")
-			.clientSecret("secret")
-			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
-			.clientName("Google Client")
-			.authorizationUri("https://accounts.google.com/o/oauth2/auth")
-			.tokenUri("https://accounts.google.com/o/oauth2/token")
-			.userInfoUri("https://www.googleapis.com/oauth2/v3/userinfo")
-			.jwkSetUri("https://www.googleapis.com/oauth2/v3/certs")
-			.redirectUri(redirectUri)
-			.scope("openid", "email", "profile")
-			.build();
-	}
-
-	static ClientRegistration githubClientRegistration() {
-		return githubClientRegistration(DEFAULT_SERVER_URL + AUTHORIZE_BASE_URI + "/" + GITHUB_REGISTRATION_ID);
-	}
-
-	static ClientRegistration githubClientRegistration(String redirectUri) {
-		return ClientRegistration.withRegistrationId(GITHUB_REGISTRATION_ID)
-			.clientId("github-client-id")
-			.clientSecret("secret")
-			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
-			.clientName("GitHub Client")
-			.authorizationUri("https://github.com/login/oauth/authorize")
-			.tokenUri("https://github.com/login/oauth/access_token")
-			.userInfoUri("https://api.github.com/user")
-			.redirectUri(redirectUri)
-			.scope("user")
-			.build();
-	}
-}