浏览代码

Provide RestOperations in DefaultOAuth2UserService

Fixes gh-5600
Joe Grandja 7 年之前
父节点
当前提交
4a8c95a3e8

+ 37 - 4
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandler.java

@@ -15,11 +15,15 @@
  */
 package org.springframework.security.oauth2.client.http;
 
+import com.nimbusds.oauth2.sdk.token.BearerTokenError;
+import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.client.ClientHttpResponse;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
+import org.springframework.util.StringUtils;
 import org.springframework.web.client.DefaultResponseErrorHandler;
 import org.springframework.web.client.ResponseErrorHandler;
 
@@ -44,10 +48,39 @@ public class OAuth2ErrorResponseErrorHandler implements ResponseErrorHandler {
 
 	@Override
 	public void handleError(ClientHttpResponse response) throws IOException {
-		if (HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) {
-			OAuth2Error oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response);
-			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+		if (!HttpStatus.BAD_REQUEST.equals(response.getStatusCode())) {
+			this.defaultErrorHandler.handleError(response);
 		}
-		this.defaultErrorHandler.handleError(response);
+
+		// A Bearer Token Error may be in the WWW-Authenticate response header
+		// See https://tools.ietf.org/html/rfc6750#section-3
+		OAuth2Error	oauth2Error = this.readErrorFromWwwAuthenticate(response.getHeaders());
+		if (oauth2Error == null) {
+			oauth2Error = this.oauth2ErrorConverter.read(OAuth2Error.class, response);
+		}
+
+		throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+	}
+
+	private OAuth2Error readErrorFromWwwAuthenticate(HttpHeaders headers) {
+		String wwwAuthenticateHeader = headers.getFirst(HttpHeaders.WWW_AUTHENTICATE);
+		if (!StringUtils.hasText(wwwAuthenticateHeader)) {
+			return null;
+		}
+
+		BearerTokenError bearerTokenError;
+		try {
+			bearerTokenError = BearerTokenError.parse(wwwAuthenticateHeader);
+		} catch (Exception ex) {
+			return null;
+		}
+
+		String errorCode = bearerTokenError.getCode() != null ?
+				bearerTokenError.getCode() : OAuth2ErrorCodes.SERVER_ERROR;
+		String errorDescription = bearerTokenError.getDescription();
+		String errorUri = bearerTokenError.getURI() != null ?
+				bearerTokenError.getURI().toString() : null;
+
+		return new OAuth2Error(errorCode, errorDescription, errorUri);
 	}
 }

+ 86 - 11
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,7 +16,12 @@
 package org.springframework.security.oauth2.client.userinfo;
 
 import org.springframework.core.ParameterizedTypeReference;
+import org.springframework.core.convert.converter.Converter;
+import org.springframework.http.RequestEntity;
+import org.springframework.http.ResponseEntity;
 import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
@@ -24,8 +29,12 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestClientException;
+import org.springframework.web.client.RestOperations;
+import org.springframework.web.client.RestTemplate;
 
-import java.util.HashSet;
+import java.util.Collections;
 import java.util.Map;
 import java.util.Set;
 
@@ -34,7 +43,7 @@ import java.util.Set;
  * <p>
  * For standard OAuth 2.0 Provider's, the attribute name used to access the user's name
  * from the UserInfo response is required and therefore must be available via
- * {@link org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}.
+ * {@link ClientRegistration.ProviderDetails.UserInfoEndpoint#getUserNameAttributeName() UserInfoEndpoint.getUserNameAttributeName()}.
  * <p>
  * <b>NOTE:</b> Attribute names are <b>not</b> standardized between providers and therefore will vary.
  * Please consult the provider's API documentation for the set of supported user attribute names.
@@ -48,8 +57,23 @@ import java.util.Set;
  */
 public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserRequest, OAuth2User> {
 	private static final String MISSING_USER_INFO_URI_ERROR_CODE = "missing_user_info_uri";
+
 	private static final String MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE = "missing_user_name_attribute";
-	private NimbusUserInfoResponseClient userInfoResponseClient = new NimbusUserInfoResponseClient();
+
+	private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response";
+
+	private static final ParameterizedTypeReference<Map<String, Object>> PARAMETERIZED_RESPONSE_TYPE =
+			new ParameterizedTypeReference<Map<String, Object>>() {};
+
+	private Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = new OAuth2UserRequestEntityConverter();
+
+	private RestOperations restOperations;
+
+	public DefaultOAuth2UserService() {
+		RestTemplate restTemplate = new RestTemplate();
+		restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler());
+		this.restOperations = restTemplate;
+	}
 
 	@Override
 	public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
@@ -64,7 +88,8 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
 			);
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 		}
-		String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName();
+		String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails()
+				.getUserInfoEndpoint().getUserNameAttributeName();
 		if (!StringUtils.hasText(userNameAttributeName)) {
 			OAuth2Error oauth2Error = new OAuth2Error(
 				MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE,
@@ -75,13 +100,63 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 		}
 
-		ParameterizedTypeReference<Map<String, Object>> typeReference =
-			new ParameterizedTypeReference<Map<String, Object>>() {};
-		Map<String, Object> userAttributes = this.userInfoResponseClient.getUserInfoResponse(userRequest, typeReference);
-		GrantedAuthority authority = new OAuth2UserAuthority(userAttributes);
-		Set<GrantedAuthority> authorities = new HashSet<>();
-		authorities.add(authority);
+		RequestEntity<?> request = this.requestEntityConverter.convert(userRequest);
+
+		ResponseEntity<Map<String, Object>> response;
+		try {
+			response = this.restOperations.exchange(request, PARAMETERIZED_RESPONSE_TYPE);
+		} catch (OAuth2AuthenticationException ex) {
+			OAuth2Error oauth2Error = ex.getError();
+			StringBuilder errorDetails = new StringBuilder();
+			errorDetails.append("Error details: [");
+			errorDetails.append("UserInfo Uri: ").append(
+					userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri());
+			errorDetails.append(", Error Code: ").append(oauth2Error.getErrorCode());
+			if (oauth2Error.getDescription() != null) {
+				errorDetails.append(", Error Description: ").append(oauth2Error.getDescription());
+			}
+			errorDetails.append("]");
+			oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
+					"An error occurred while attempting to retrieve the UserInfo Resource: " + errorDetails.toString(), null);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
+		} catch (RestClientException ex) {
+			OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
+					"An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
+		}
+
+		Map<String, Object> userAttributes = response.getBody();
+		Set<GrantedAuthority> authorities = Collections.singleton(new OAuth2UserAuthority(userAttributes));
 
 		return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName);
 	}
+
+	/**
+	 * Sets the {@link Converter} used for converting the {@link OAuth2UserRequest}
+	 * to a {@link RequestEntity} representation of the UserInfo Request.
+	 *
+	 * @since 5.1
+	 * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the UserInfo Request
+	 */
+	public final void setRequestEntityConverter(Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter) {
+		Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null");
+		this.requestEntityConverter = requestEntityConverter;
+	}
+
+	/**
+	 * Sets the {@link RestOperations} used when requesting the UserInfo resource.
+	 *
+	 * <p>
+	 * <b>NOTE:</b> At a minimum, the supplied {@code restOperations} must be configured with the following:
+	 * <ol>
+	 *  <li>{@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}</li>
+	 * </ol>
+	 *
+	 * @since 5.1
+	 * @param restOperations the {@link RestOperations} used when requesting the UserInfo resource
+	 */
+	public final void setRestOperations(RestOperations restOperations) {
+		Assert.notNull(restOperations, "restOperations cannot be null");
+		this.restOperations = restOperations;
+	}
 }

+ 81 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverter.java

@@ -0,0 +1,81 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.userinfo;
+
+import org.springframework.core.convert.converter.Converter;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.MediaType;
+import org.springframework.http.RequestEntity;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthenticationMethod;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
+import org.springframework.web.util.UriComponentsBuilder;
+
+import java.net.URI;
+import java.util.Collections;
+
+import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
+
+/**
+ * A {@link Converter} that converts the provided {@link OAuth2UserRequest}
+ * to a {@link RequestEntity} representation of a request for the UserInfo Endpoint.
+ *
+ * @author Joe Grandja
+ * @since 5.1
+ * @see Converter
+ * @see OAuth2UserRequest
+ * @see RequestEntity
+ */
+public class OAuth2UserRequestEntityConverter implements Converter<OAuth2UserRequest, RequestEntity<?>> {
+	private static final MediaType DEFAULT_CONTENT_TYPE = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8");
+
+	/**
+	 * Returns the {@link RequestEntity} used for the UserInfo Request.
+	 *
+	 * @param userRequest the user request
+	 * @return the {@link RequestEntity} used for the UserInfo Request
+	 */
+	@Override
+	public RequestEntity<?> convert(OAuth2UserRequest userRequest) {
+		ClientRegistration clientRegistration = userRequest.getClientRegistration();
+
+		HttpMethod httpMethod = HttpMethod.GET;
+		if (AuthenticationMethod.FORM.equals(clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod())) {
+			httpMethod = HttpMethod.POST;
+		}
+		HttpHeaders headers = new HttpHeaders();
+		headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
+		URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri())
+				.build()
+				.toUri();
+
+		RequestEntity<?> request;
+		if (HttpMethod.POST.equals(httpMethod)) {
+			headers.setContentType(DEFAULT_CONTENT_TYPE);
+			MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
+			formParameters.add(OAuth2ParameterNames.ACCESS_TOKEN, userRequest.getAccessToken().getTokenValue());
+			request = new RequestEntity<>(formParameters, headers, httpMethod, uri);
+		} else {
+			headers.setBearerAuth(userRequest.getAccessToken().getTokenValue());
+			request = new RequestEntity<>(headers, httpMethod, uri);
+		}
+
+		return request;
+	}
+}

+ 15 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/http/OAuth2ErrorResponseErrorHandlerTests.java

@@ -16,6 +16,7 @@
 package org.springframework.security.oauth2.client.http;
 
 import org.junit.Test;
+import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpStatus;
 import org.springframework.mock.http.client.MockClientHttpResponse;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@@ -31,7 +32,7 @@ public class OAuth2ErrorResponseErrorHandlerTests {
 	private OAuth2ErrorResponseErrorHandler errorHandler = new OAuth2ErrorResponseErrorHandler();
 
 	@Test
-	public void handleErrorWhenStatusBadRequestThenHandled() {
+	public void handleErrorWhenErrorResponseBodyThenHandled() {
 		String errorResponse = "{\n" +
 				"	\"error\": \"unauthorized_client\",\n" +
 				"   \"error_description\": \"The client is not authorized\"\n" +
@@ -44,4 +45,17 @@ public class OAuth2ErrorResponseErrorHandlerTests {
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.hasMessage("[unauthorized_client] The client is not authorized");
 	}
+
+	@Test
+	public void handleErrorWhenErrorResponseWwwAuthenticateHeaderThenHandled() {
+		String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\"";
+
+		MockClientHttpResponse response = new MockClientHttpResponse(
+				new byte[0], HttpStatus.BAD_REQUEST);
+		response.getHeaders().add(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader);
+
+		assertThatThrownBy(() -> this.errorHandler.handleError(response))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessage("[insufficient_scope] The access token expired");
+	}
 }

+ 54 - 112
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java

@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client.oidc.userinfo;
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
 import okhttp3.mockwebserver.RecordedRequest;
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -29,7 +30,6 @@ import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
-import org.springframework.security.authentication.AuthenticationServiceException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
 import org.springframework.security.oauth2.core.AuthenticationMethod;
@@ -71,12 +71,15 @@ public class OidcUserServiceTests {
 	private OAuth2AccessToken accessToken;
 	private OidcIdToken idToken;
 	private OidcUserService userService = new OidcUserService();
+	private MockWebServer server;
 
 	@Rule
 	public ExpectedException exception = ExpectedException.none();
 
 	@Before
-	public void setUp() throws Exception {
+	public void setup() throws Exception {
+		this.server = new MockWebServer();
+		this.server.start();
 		this.clientRegistration = mock(ClientRegistration.class);
 		this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
 		this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
@@ -101,6 +104,11 @@ public class OidcUserServiceTests {
 		this.userService.setOauth2UserService(new DefaultOAuth2UserService());
 	}
 
+	@After
+	public void cleanup() throws Exception {
+		this.server.shutdown();
+	}
+
 	@Test
 	public void setOauth2UserServiceWhenNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.userService.setOauth2UserService(null))
@@ -135,9 +143,7 @@ public class OidcUserServiceTests {
 	}
 
 	@Test
-	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
-		MockWebServer server = new MockWebServer();
-
+	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
 		String userInfoResponse = "{\n" +
 			"	\"sub\": \"subject1\",\n" +
 			"   \"name\": \"first last\",\n" +
@@ -146,13 +152,9 @@ public class OidcUserServiceTests {
 			"   \"preferred_username\": \"user1\",\n" +
 			"   \"email\": \"user1@example.com\"\n" +
 			"}\n";
-		server.enqueue(new MockResponse()
-			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-			.setBody(userInfoResponse));
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		server.start();
-
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
@@ -160,8 +162,6 @@ public class OidcUserServiceTests {
 		OidcUser user = this.userService.loadUser(
 			new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
 
-		server.shutdown();
-
 		assertThat(user.getIdToken()).isNotNull();
 		assertThat(user.getUserInfo()).isNotNull();
 		assertThat(user.getUserInfo().getClaims().size()).isEqualTo(6);
@@ -184,69 +184,47 @@ public class OidcUserServiceTests {
 
 	// gh-5447
 	@Test
-	public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() throws Exception {
+	public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() {
 		this.exception.expect(OAuth2AuthenticationException.class);
 		this.exception.expectMessage(containsString("invalid_user_info_response"));
 
-		MockWebServer server = new MockWebServer();
-
 		String userInfoResponse = "{\n" +
 				"	\"email\": \"full_name@provider.com\",\n" +
 				"	\"name\": \"full name\"\n" +
 				"}\n";
-		server.enqueue(new MockResponse()
-				.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-				.setBody(userInfoResponse));
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		server.start();
-
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL);
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
-		try {
-			this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
-		} finally {
-			server.shutdown();
-		}
+		this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
 	}
 
 	@Test
-	public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() throws Exception {
+	public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() {
 		this.exception.expect(OAuth2AuthenticationException.class);
 		this.exception.expectMessage(containsString("invalid_user_info_response"));
 
-		MockWebServer server = new MockWebServer();
-
 		String userInfoResponse = "{\n" +
 			"	\"sub\": \"other-subject\"\n" +
 			"}\n";
-		server.enqueue(new MockResponse()
-			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-			.setBody(userInfoResponse));
-
-		server.start();
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
-		try {
-			this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
-		} finally {
-			server.shutdown();
-		}
+		this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
 	}
 
 	@Test
-	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
+	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
 		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_user_info_response"));
-
-		MockWebServer server = new MockWebServer();
+		this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 
 		String userInfoResponse = "{\n" +
 			"	\"sub\": \"subject1\",\n" +
@@ -256,48 +234,35 @@ public class OidcUserServiceTests {
 			"   \"preferred_username\": \"user1\",\n" +
 			"   \"email\": \"user1@example.com\"\n";
 //			"}\n";		// Make the JSON invalid/malformed
-		server.enqueue(new MockResponse()
-			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-			.setBody(userInfoResponse));
-
-		server.start();
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
-		try {
-			this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
-		} finally {
-			server.shutdown();
-		}
+		this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
 	}
 
 	@Test
-	public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
+	public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
 		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_user_info_response"));
+		this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
 
-		MockWebServer server = new MockWebServer();
-		server.enqueue(new MockResponse().setResponseCode(500));
-		server.start();
+		this.server.enqueue(new MockResponse().setResponseCode(500));
 
 		String userInfoUri = server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
-		try {
-			this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
-		} finally {
-			server.shutdown();
-		}
+		this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
 	}
 
 	@Test
-	public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
-		this.exception.expect(AuthenticationServiceException.class);
+	public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 
 		String userInfoUri = "http://invalid-provider.com/user";
 
@@ -308,9 +273,7 @@ public class OidcUserServiceTests {
 	}
 
 	@Test
-	public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() throws Exception {
-		MockWebServer server = new MockWebServer();
-
+	public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() {
 		String userInfoResponse = "{\n" +
 			"	\"sub\": \"subject1\",\n" +
 			"   \"name\": \"first last\",\n" +
@@ -319,13 +282,9 @@ public class OidcUserServiceTests {
 			"   \"preferred_username\": \"user1\",\n" +
 			"   \"email\": \"user1@example.com\"\n" +
 			"}\n";
-		server.enqueue(new MockResponse()
-			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-			.setBody(userInfoResponse));
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		server.start();
-
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL);
@@ -334,16 +293,12 @@ public class OidcUserServiceTests {
 		OidcUser user = this.userService.loadUser(
 			new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
 
-		server.shutdown();
-
 		assertThat(user.getName()).isEqualTo("user1@example.com");
 	}
 
 	// gh-5294
 	@Test
 	public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception {
-		MockWebServer server = new MockWebServer();
-
 		String userInfoResponse = "{\n" +
 				"	\"sub\": \"subject1\",\n" +
 				"   \"name\": \"first last\",\n" +
@@ -352,28 +307,21 @@ public class OidcUserServiceTests {
 				"   \"preferred_username\": \"user1\",\n" +
 				"   \"email\": \"user1@example.com\"\n" +
 				"}\n";
-		server.enqueue(new MockResponse()
-				.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-				.setBody(userInfoResponse));
-
-		server.start();
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
 		this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
-		server.shutdown();
-		assertThat(server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
-				.isEqualTo(MediaType.APPLICATION_JSON_VALUE);
+		assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
+				.isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
 	}
 
 	// gh-5500
 	@Test
 	public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception {
-		MockWebServer server = new MockWebServer();
-
 		String userInfoResponse = "{\n" +
 				"	\"sub\": \"subject1\",\n" +
 				"   \"name\": \"first last\",\n" +
@@ -382,31 +330,24 @@ public class OidcUserServiceTests {
 				"   \"preferred_username\": \"user1\",\n" +
 				"   \"email\": \"user1@example.com\"\n" +
 				"}\n";
-		server.enqueue(new MockResponse()
-				.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-				.setBody(userInfoResponse));
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		server.start();
-
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
 		this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
-		server.shutdown();
-		RecordedRequest request = server.takeRequest();
+		RecordedRequest request = this.server.takeRequest();
 		assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name());
-		assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
+		assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
 		assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue());
 	}
 
 	// gh-5500
 	@Test
 	public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception {
-		MockWebServer server = new MockWebServer();
-
 		String userInfoResponse = "{\n" +
 				"	\"sub\": \"subject1\",\n" +
 				"   \"name\": \"first last\",\n" +
@@ -415,24 +356,25 @@ public class OidcUserServiceTests {
 				"   \"preferred_username\": \"user1\",\n" +
 				"   \"email\": \"user1@example.com\"\n" +
 				"}\n";
-		server.enqueue(new MockResponse()
-				.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-				.setBody(userInfoResponse));
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		server.start();
-
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM);
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
 		this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
-		server.shutdown();
-		RecordedRequest request = server.takeRequest();
+		RecordedRequest request = this.server.takeRequest();
 		assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name());
-		assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
+		assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
 		assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE)).contains(MediaType.APPLICATION_FORM_URLENCODED_VALUE);
 		assertThat(request.getBody().readUtf8()).isEqualTo("access_token=" + this.accessToken.getTokenValue());
 	}
+
+	private MockResponse jsonResponse(String json) {
+		return new MockResponse()
+				.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+				.setBody(json);
+	}
 }

+ 99 - 74
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java

@@ -18,7 +18,7 @@ package org.springframework.security.oauth2.client.userinfo;
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
 import okhttp3.mockwebserver.RecordedRequest;
-
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -30,7 +30,6 @@ import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
-import org.springframework.security.authentication.AuthenticationServiceException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.AuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -59,12 +58,15 @@ public class DefaultOAuth2UserServiceTests {
 	private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
 	private OAuth2AccessToken accessToken;
 	private DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
+	private MockWebServer server;
 
 	@Rule
 	public ExpectedException exception = ExpectedException.none();
 
 	@Before
-	public void setUp() throws Exception {
+	public void setup() throws Exception {
+		this.server = new MockWebServer();
+		this.server.start();
 		this.clientRegistration = mock(ClientRegistration.class);
 		this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
 		this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
@@ -73,6 +75,23 @@ public class DefaultOAuth2UserServiceTests {
 		this.accessToken = mock(OAuth2AccessToken.class);
 	}
 
+	@After
+	public void cleanup() throws Exception {
+		this.server.shutdown();
+	}
+
+	@Test
+	public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		this.userService.setRequestEntityConverter(null);
+	}
+
+	@Test
+	public void setRestOperationsWhenNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		this.userService.setRestOperations(null);
+	}
+
 	@Test
 	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
 		this.exception.expect(IllegalArgumentException.class);
@@ -99,9 +118,7 @@ public class DefaultOAuth2UserServiceTests {
 	}
 
 	@Test
-	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
-		MockWebServer server = new MockWebServer();
-
+	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
 		String userInfoResponse = "{\n" +
 			"	\"user-name\": \"user1\",\n" +
 			"   \"first-name\": \"first\",\n" +
@@ -110,13 +127,9 @@ public class DefaultOAuth2UserServiceTests {
 			"   \"address\": \"address\",\n" +
 			"   \"email\": \"user1@example.com\"\n" +
 			"}\n";
-		server.enqueue(new MockResponse()
-			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-			.setBody(userInfoResponse));
-
-		server.start();
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
@@ -125,8 +138,6 @@ public class DefaultOAuth2UserServiceTests {
 
 		OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
 
-		server.shutdown();
-
 		assertThat(user.getName()).isEqualTo("user1");
 		assertThat(user.getAttributes().size()).isEqualTo(6);
 		assertThat(user.getAttributes().get("user-name")).isEqualTo("user1");
@@ -144,11 +155,9 @@ public class DefaultOAuth2UserServiceTests {
 	}
 
 	@Test
-	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
+	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
 		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_user_info_response"));
-
-		MockWebServer server = new MockWebServer();
+		this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 
 		String userInfoResponse = "{\n" +
 			"	\"user-name\": \"user1\",\n" +
@@ -158,52 +167,83 @@ public class DefaultOAuth2UserServiceTests {
 			"   \"address\": \"address\",\n" +
 			"   \"email\": \"user1@example.com\"\n";
 //			"}\n";		// Make the JSON invalid/malformed
-		server.enqueue(new MockResponse()
-			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-			.setBody(userInfoResponse));
+		this.server.enqueue(jsonResponse(userInfoResponse));
+
+		String userInfoUri = this.server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
+		when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+	}
+
+	@Test
+	public void loadUserWhenUserInfoErrorResponseWwwAuthenticateHeaderThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
+		this.exception.expectMessage(containsString("Error Code: insufficient_scope, Error Description: The access token expired"));
 
-		server.start();
+		String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\"";
 
-		String userInfoUri = server.url("/user").toString();
+		MockResponse response = new MockResponse();
+		response.setHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader);
+		response.setResponseCode(400);
+		this.server.enqueue(response);
+
+		String userInfoUri = this.server.url("/user").toString();
+
+		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
+		when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
+		when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
+		when(this.accessToken.getTokenValue()).thenReturn("access-token");
+
+		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
+	}
+
+	@Test
+	public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
+		this.exception.expectMessage(containsString("Error Code: invalid_token"));
+
+		String userInfoErrorResponse = "{\n" +
+				"   \"error\": \"invalid_token\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(userInfoErrorResponse).setResponseCode(400));
+
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
 		when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
-		try {
-			this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
-		} finally {
-			server.shutdown();
-		}
+		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
 	}
 
 	@Test
-	public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
+	public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
 		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_user_info_response"));
+		this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
 
-		MockWebServer server = new MockWebServer();
-		server.enqueue(new MockResponse().setResponseCode(500));
-		server.start();
+		this.server.enqueue(new MockResponse().setResponseCode(500));
 
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
 		when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn("user-name");
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
-		try {
-			this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
-		} finally {
-			server.shutdown();
-		}
+		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
 	}
 
 	@Test
-	public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
-		this.exception.expect(AuthenticationServiceException.class);
+	public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 
 		String userInfoUri = "http://invalid-provider.com/user";
 
@@ -218,8 +258,6 @@ public class DefaultOAuth2UserServiceTests {
 	// gh-5294
 	@Test
 	public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception {
-		MockWebServer server = new MockWebServer();
-
 		String userInfoResponse = "{\n" +
 				"	\"user-name\": \"user1\",\n" +
 				"   \"first-name\": \"first\",\n" +
@@ -228,13 +266,9 @@ public class DefaultOAuth2UserServiceTests {
 				"   \"address\": \"address\",\n" +
 				"   \"email\": \"user1@example.com\"\n" +
 				"}\n";
-		server.enqueue(new MockResponse()
-				.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-				.setBody(userInfoResponse));
-
-		server.start();
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
@@ -242,16 +276,13 @@ public class DefaultOAuth2UserServiceTests {
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
 		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
-		server.shutdown();
-		assertThat(server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
-				.isEqualTo(MediaType.APPLICATION_JSON_VALUE);
+		assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
+				.isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
 	}
 
 	// gh-5500
 	@Test
 	public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception {
-		MockWebServer server = new MockWebServer();
-
 		String userInfoResponse = "{\n" +
 				"	\"user-name\": \"user1\",\n" +
 				"   \"first-name\": \"first\",\n" +
@@ -260,13 +291,9 @@ public class DefaultOAuth2UserServiceTests {
 				"   \"address\": \"address\",\n" +
 				"   \"email\": \"user1@example.com\"\n" +
 				"}\n";
-		server.enqueue(new MockResponse()
-				.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-				.setBody(userInfoResponse));
-
-		server.start();
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
@@ -274,18 +301,15 @@ public class DefaultOAuth2UserServiceTests {
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
 		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
-		server.shutdown();
-		RecordedRequest request = server.takeRequest();
+		RecordedRequest request = this.server.takeRequest();
 		assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name());
-		assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
+		assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
 		assertThat(request.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue());
 	}
 
 	// gh-5500
 	@Test
 	public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception {
-		MockWebServer server = new MockWebServer();
-
 		String userInfoResponse = "{\n" +
 				"	\"user-name\": \"user1\",\n" +
 				"   \"first-name\": \"first\",\n" +
@@ -294,13 +318,9 @@ public class DefaultOAuth2UserServiceTests {
 				"   \"address\": \"address\",\n" +
 				"   \"email\": \"user1@example.com\"\n" +
 				"}\n";
-		server.enqueue(new MockResponse()
-				.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-				.setBody(userInfoResponse));
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		server.start();
-
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM);
@@ -308,11 +328,16 @@ public class DefaultOAuth2UserServiceTests {
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
 		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
-		server.shutdown();
-		RecordedRequest request = server.takeRequest();
+		RecordedRequest request = this.server.takeRequest();
 		assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name());
-		assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
+		assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE);
 		assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE)).contains(MediaType.APPLICATION_FORM_URLENCODED_VALUE);
 		assertThat(request.getBody().readUtf8()).isEqualTo("access_token=" + this.accessToken.getTokenValue());
 	}
+
+	private MockResponse jsonResponse(String json) {
+		return new MockResponse()
+				.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+				.setBody(json);
+	}
 }

+ 125 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestEntityConverterTests.java

@@ -0,0 +1,125 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.userinfo;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.MediaType;
+import org.springframework.http.RequestEntity;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthenticationMethod;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.util.MultiValueMap;
+
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.LinkedHashSet;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
+
+/**
+ * Tests for {@link OAuth2UserRequestEntityConverter}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2UserRequestEntityConverterTests {
+	private OAuth2UserRequestEntityConverter converter = new OAuth2UserRequestEntityConverter();
+	private OAuth2UserRequest userRequest;
+
+	@Before
+	public void setup() {
+		ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1")
+				.clientId("client-1")
+				.clientSecret("secret")
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.redirectUriTemplate("https://client.com/callback/client-1")
+				.scope("read", "write")
+				.authorizationUri("https://provider.com/oauth2/authorize")
+				.tokenUri("https://provider.com/oauth2/token")
+				.userInfoUri("https://provider.com/user")
+				.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
+				.userNameAttributeName("id")
+				.build();
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(
+				OAuth2AccessToken.TokenType.BEARER, "access-token-1234", Instant.now(),
+				Instant.now().plusSeconds(3600), new LinkedHashSet<>(Arrays.asList("read", "write")));
+		this.userRequest = new OAuth2UserRequest(clientRegistration, accessToken);
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void convertWhenAuthenticationMethodHeaderThenGetRequest() {
+		RequestEntity<?> requestEntity = this.converter.convert(this.userRequest);
+
+		ClientRegistration clientRegistration = this.userRequest.getClientRegistration();
+
+		assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.GET);
+		assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo(
+				clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri());
+
+		HttpHeaders headers = requestEntity.getHeaders();
+		assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8);
+		assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo(
+				"Bearer " + this.userRequest.getAccessToken().getTokenValue());
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void convertWhenAuthenticationMethodFormThenPostRequest() {
+		ClientRegistration clientRegistration = this.from(this.userRequest.getClientRegistration())
+				.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
+				.build();
+		OAuth2UserRequest userRequest = new OAuth2UserRequest(
+				clientRegistration, this.userRequest.getAccessToken());
+
+		RequestEntity<?> requestEntity = this.converter.convert(userRequest);
+
+		assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST);
+		assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo(
+				clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri());
+
+		HttpHeaders headers = requestEntity.getHeaders();
+		assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8);
+		assertThat(headers.getContentType()).isEqualTo(
+				MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"));
+
+		MultiValueMap<String, String> formParameters = (MultiValueMap<String, String>) requestEntity.getBody();
+		assertThat(formParameters.getFirst(OAuth2ParameterNames.ACCESS_TOKEN)).isEqualTo(
+				this.userRequest.getAccessToken().getTokenValue());
+	}
+
+	private ClientRegistration.Builder from(ClientRegistration registration) {
+		return ClientRegistration.withRegistrationId(registration.getRegistrationId())
+				.clientId(registration.getClientId())
+				.clientSecret(registration.getClientSecret())
+				.clientAuthenticationMethod(registration.getClientAuthenticationMethod())
+				.authorizationGrantType(registration.getAuthorizationGrantType())
+				.redirectUriTemplate(registration.getRedirectUriTemplate())
+				.scope(registration.getScopes())
+				.authorizationUri(registration.getProviderDetails().getAuthorizationUri())
+				.tokenUri(registration.getProviderDetails().getTokenUri())
+				.userInfoUri(registration.getProviderDetails().getUserInfoEndpoint().getUri())
+				.userNameAttributeName(registration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName())
+				.clientName(registration.getClientName());
+	}
+}