Procházet zdrojové kódy

Provide RestOperations in CustomUserTypesOAuth2UserService

Fixes gh-5602
Joe Grandja před 7 roky
rodič
revize
3b480a3a05

+ 63 - 3
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserService.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,10 +15,19 @@
  */
 package org.springframework.security.oauth2.client.userinfo;
 
+import org.springframework.core.convert.converter.Converter;
+import org.springframework.http.RequestEntity;
+import org.springframework.http.ResponseEntity;
+import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.util.Assert;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestClientException;
+import org.springframework.web.client.RestOperations;
+import org.springframework.web.client.RestTemplate;
 
 import java.util.Collections;
 import java.util.LinkedHashMap;
@@ -39,8 +48,13 @@ import java.util.Map;
  * @see ClientRegistration
  */
 public class CustomUserTypesOAuth2UserService implements OAuth2UserService<OAuth2UserRequest, OAuth2User> {
+	private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response";
+
 	private final Map<String, Class<? extends OAuth2User>> customUserTypes;
-	private NimbusUserInfoResponseClient userInfoResponseClient = new NimbusUserInfoResponseClient();
+
+	private Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = new OAuth2UserRequestEntityConverter();
+
+	private RestOperations restOperations;
 
 	/**
 	 * Constructs a {@code CustomUserTypesOAuth2UserService} using the provided parameters.
@@ -50,6 +64,9 @@ public class CustomUserTypesOAuth2UserService implements OAuth2UserService<OAuth
 	public CustomUserTypesOAuth2UserService(Map<String, Class<? extends OAuth2User>> customUserTypes) {
 		Assert.notEmpty(customUserTypes, "customUserTypes cannot be empty");
 		this.customUserTypes = Collections.unmodifiableMap(new LinkedHashMap<>(customUserTypes));
+		RestTemplate restTemplate = new RestTemplate();
+		restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler());
+		this.restOperations = restTemplate;
 	}
 
 	@Override
@@ -60,6 +77,49 @@ public class CustomUserTypesOAuth2UserService implements OAuth2UserService<OAuth
 		if ((customUserType = this.customUserTypes.get(registrationId)) == null) {
 			return null;
 		}
-		return this.userInfoResponseClient.getUserInfoResponse(userRequest, customUserType);
+
+		RequestEntity<?> request = this.requestEntityConverter.convert(userRequest);
+
+		ResponseEntity<? extends OAuth2User> response;
+		try {
+			response = this.restOperations.exchange(request, customUserType);
+		} catch (RestClientException ex) {
+			OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
+					"An error occurred while attempting to retrieve the UserInfo Resource: " + ex.getMessage(), null);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
+		}
+
+		OAuth2User oauth2User = response.getBody();
+
+		return oauth2User;
+	}
+
+	/**
+	 * Sets the {@link Converter} used for converting the {@link OAuth2UserRequest}
+	 * to a {@link RequestEntity} representation of the UserInfo Request.
+	 *
+	 * @since 5.1
+	 * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the UserInfo Request
+	 */
+	public final void setRequestEntityConverter(Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter) {
+		Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null");
+		this.requestEntityConverter = requestEntityConverter;
+	}
+
+	/**
+	 * Sets the {@link RestOperations} used when requesting the UserInfo resource.
+	 *
+	 * <p>
+	 * <b>NOTE:</b> At a minimum, the supplied {@code restOperations} must be configured with the following:
+	 * <ol>
+	 *  <li>{@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}</li>
+	 * </ol>
+	 *
+	 * @since 5.1
+	 * @param restOperations the {@link RestOperations} used when requesting the UserInfo resource
+	 */
+	public final void setRestOperations(RestOperations restOperations) {
+		Assert.notNull(restOperations, "restOperations cannot be null");
+		this.restOperations = restOperations;
 	}
 }

+ 0 - 169
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/NimbusUserInfoResponseClient.java

@@ -1,169 +0,0 @@
-/*
- * Copyright 2002-2018 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.client.userinfo;
-
-import com.nimbusds.oauth2.sdk.ErrorObject;
-import com.nimbusds.oauth2.sdk.ParseException;
-import com.nimbusds.oauth2.sdk.http.HTTPRequest;
-import com.nimbusds.oauth2.sdk.http.HTTPResponse;
-import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
-import com.nimbusds.openid.connect.sdk.UserInfoErrorResponse;
-import com.nimbusds.openid.connect.sdk.UserInfoRequest;
-import org.springframework.core.ParameterizedTypeReference;
-import org.springframework.http.HttpHeaders;
-import org.springframework.http.MediaType;
-import org.springframework.http.client.AbstractClientHttpResponse;
-import org.springframework.http.client.ClientHttpResponse;
-import org.springframework.http.converter.GenericHttpMessageConverter;
-import org.springframework.http.converter.HttpMessageNotReadableException;
-import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
-import org.springframework.security.authentication.AuthenticationServiceException;
-import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.core.AuthenticationMethod;
-import org.springframework.security.oauth2.core.OAuth2AccessToken;
-import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
-import org.springframework.security.oauth2.core.OAuth2Error;
-import org.springframework.util.Assert;
-
-import java.io.ByteArrayInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.net.URI;
-import java.nio.charset.Charset;
-
-/**
- * @author Joe Grandja
- * @since 5.0
- */
-final class NimbusUserInfoResponseClient {
-	private static final String INVALID_USER_INFO_RESPONSE_ERROR_CODE = "invalid_user_info_response";
-	private final GenericHttpMessageConverter genericHttpMessageConverter = new MappingJackson2HttpMessageConverter();
-
-	<T> T getUserInfoResponse(OAuth2UserRequest userInfoRequest, Class<T> returnType) throws OAuth2AuthenticationException {
-		ClientHttpResponse userInfoResponse = this.getUserInfoResponse(
-			userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken());
-		try {
-			return (T) this.genericHttpMessageConverter.read(returnType, userInfoResponse);
-		} catch (IOException | HttpMessageNotReadableException ex) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
-				"An error occurred reading the UserInfo Success response: " + ex.getMessage(), null);
-			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
-		}
-	}
-
-	<T> T getUserInfoResponse(OAuth2UserRequest userInfoRequest, ParameterizedTypeReference<T> typeReference) throws OAuth2AuthenticationException {
-		ClientHttpResponse userInfoResponse = this.getUserInfoResponse(
-			userInfoRequest.getClientRegistration(), userInfoRequest.getAccessToken());
-		try {
-			return (T) this.genericHttpMessageConverter.read(typeReference.getType(), null, userInfoResponse);
-		} catch (IOException | HttpMessageNotReadableException ex) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
-				"An error occurred reading the UserInfo Success response: " + ex.getMessage(), null);
-			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
-		}
-	}
-
-	private ClientHttpResponse getUserInfoResponse(ClientRegistration clientRegistration,
-													OAuth2AccessToken oauth2AccessToken) throws OAuth2AuthenticationException {
-		URI userInfoUri = URI.create(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri());
-		BearerAccessToken accessToken = new BearerAccessToken(oauth2AccessToken.getTokenValue());
-		AuthenticationMethod authenticationMethod = clientRegistration.getProviderDetails().getUserInfoEndpoint().getAuthenticationMethod();
-		HTTPRequest.Method httpMethod = AuthenticationMethod.FORM.equals(authenticationMethod)
-				? HTTPRequest.Method.POST : HTTPRequest.Method.GET;
-
-		UserInfoRequest userInfoRequest = new UserInfoRequest(userInfoUri, httpMethod, accessToken);
-		HTTPRequest httpRequest = userInfoRequest.toHTTPRequest();
-		httpRequest.setAccept(MediaType.APPLICATION_JSON_VALUE);
-		httpRequest.setConnectTimeout(30000);
-		httpRequest.setReadTimeout(30000);
-		HTTPResponse httpResponse;
-
-		try {
-			httpResponse = httpRequest.send();
-		} catch (IOException ex) {
-			throw new AuthenticationServiceException("An error occurred while sending the UserInfo Request: " +
-				ex.getMessage(), ex);
-		}
-
-		if (httpResponse.getStatusCode() == HTTPResponse.SC_OK) {
-			return new NimbusClientHttpResponse(httpResponse);
-		}
-
-		UserInfoErrorResponse userInfoErrorResponse;
-		try {
-			userInfoErrorResponse = UserInfoErrorResponse.parse(httpResponse);
-		} catch (ParseException ex) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
-				"An error occurred parsing the UserInfo Error response: " + ex.getMessage(), null);
-			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
-		}
-		ErrorObject errorObject = userInfoErrorResponse.getErrorObject();
-
-		StringBuilder errorDescription = new StringBuilder();
-		errorDescription.append("An error occurred while attempting to access the UserInfo Endpoint -> ");
-		errorDescription.append("Error details: [");
-		errorDescription.append("UserInfo Uri: ").append(userInfoUri.toString());
-		errorDescription.append(", Http Status: ").append(errorObject.getHTTPStatusCode());
-		if (errorObject.getCode() != null) {
-			errorDescription.append(", Error Code: ").append(errorObject.getCode());
-		}
-		if (errorObject.getDescription() != null) {
-			errorDescription.append(", Error Description: ").append(errorObject.getDescription());
-		}
-		errorDescription.append("]");
-
-		OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, errorDescription.toString(), null);
-		throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
-	}
-
-	private static class NimbusClientHttpResponse extends AbstractClientHttpResponse {
-		private final HTTPResponse httpResponse;
-		private final HttpHeaders headers;
-
-		private NimbusClientHttpResponse(HTTPResponse httpResponse) {
-			Assert.notNull(httpResponse, "httpResponse cannot be null");
-			this.httpResponse = httpResponse;
-			this.headers = new HttpHeaders();
-			this.headers.setAll(httpResponse.getHeaders());
-		}
-
-		@Override
-		public int getRawStatusCode() throws IOException {
-			return this.httpResponse.getStatusCode();
-		}
-
-		@Override
-		public String getStatusText() throws IOException {
-			return String.valueOf(this.getRawStatusCode());
-		}
-
-		@Override
-		public void close() {
-		}
-
-		@Override
-		public InputStream getBody() throws IOException {
-			InputStream inputStream = new ByteArrayInputStream(
-				this.httpResponse.getContent().getBytes(Charset.forName("UTF-8")));
-			return inputStream;
-		}
-
-		@Override
-		public HttpHeaders getHeaders() {
-			return this.headers;
-		}
-	}
-}

+ 44 - 41
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@ package org.springframework.security.oauth2.client.userinfo;
 
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -27,7 +28,6 @@ import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
-import org.springframework.security.authentication.AuthenticationServiceException;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -60,12 +60,15 @@ public class CustomUserTypesOAuth2UserServiceTests {
 	private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
 	private OAuth2AccessToken accessToken;
 	private CustomUserTypesOAuth2UserService userService;
+	private MockWebServer server;
 
 	@Rule
 	public ExpectedException exception = ExpectedException.none();
 
 	@Before
 	public void setUp() throws Exception {
+		this.server = new MockWebServer();
+		this.server.start();
 		this.clientRegistration = mock(ClientRegistration.class);
 		this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
 		this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
@@ -80,6 +83,11 @@ public class CustomUserTypesOAuth2UserServiceTests {
 		this.userService = new CustomUserTypesOAuth2UserService(customUserTypes);
 	}
 
+	@After
+	public void cleanup() throws Exception {
+		this.server.shutdown();
+	}
+
 	@Test
 	public void constructorWhenCustomUserTypesIsNullThenThrowIllegalArgumentException() {
 		this.exception.expect(IllegalArgumentException.class);
@@ -92,6 +100,18 @@ public class CustomUserTypesOAuth2UserServiceTests {
 		new CustomUserTypesOAuth2UserService(Collections.emptyMap());
 	}
 
+	@Test
+	public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		this.userService.setRequestEntityConverter(null);
+	}
+
+	@Test
+	public void setRestOperationsWhenNullThenThrowIllegalArgumentException() {
+		this.exception.expect(IllegalArgumentException.class);
+		this.userService.setRestOperations(null);
+	}
+
 	@Test
 	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
 		this.exception.expect(IllegalArgumentException.class);
@@ -107,30 +127,22 @@ public class CustomUserTypesOAuth2UserServiceTests {
 	}
 
 	@Test
-	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() throws Exception {
-		MockWebServer server = new MockWebServer();
-
+	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
 		String userInfoResponse = "{\n" +
 			"	\"id\": \"12345\",\n" +
 			"   \"name\": \"first last\",\n" +
 			"   \"login\": \"user1\",\n" +
 			"   \"email\": \"user1@example.com\"\n" +
 			"}\n";
-		server.enqueue(new MockResponse()
-			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-			.setBody(userInfoResponse));
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		server.start();
-
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
 		OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
 
-		server.shutdown();
-
 		assertThat(user.getName()).isEqualTo("first last");
 		assertThat(user.getAttributes().size()).isEqualTo(4);
 		assertThat(user.getAttributes().get("id")).isEqualTo("12345");
@@ -143,11 +155,9 @@ public class CustomUserTypesOAuth2UserServiceTests {
 	}
 
 	@Test
-	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
+	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
 		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_user_info_response"));
-
-		MockWebServer server = new MockWebServer();
+		this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 
 		String userInfoResponse = "{\n" +
 			"	\"id\": \"12345\",\n" +
@@ -155,48 +165,35 @@ public class CustomUserTypesOAuth2UserServiceTests {
 			"   \"login\": \"user1\",\n" +
 			"   \"email\": \"user1@example.com\"\n";
 //			"}\n";		// Make the JSON invalid/malformed
-		server.enqueue(new MockResponse()
-			.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
-			.setBody(userInfoResponse));
+		this.server.enqueue(jsonResponse(userInfoResponse));
 
-		server.start();
-
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
-		try {
-			this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
-		} finally {
-			server.shutdown();
-		}
+		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
 	}
 
 	@Test
-	public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
+	public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
 		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_user_info_response"));
+		this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
 
-		MockWebServer server = new MockWebServer();
-		server.enqueue(new MockResponse().setResponseCode(500));
-		server.start();
+		this.server.enqueue(new MockResponse().setResponseCode(500));
 
-		String userInfoUri = server.url("/user").toString();
+		String userInfoUri = this.server.url("/user").toString();
 
 		when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
 		when(this.accessToken.getTokenValue()).thenReturn("access-token");
 
-		try {
-			this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
-		} finally {
-			server.shutdown();
-		}
+		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
 	}
 
 	@Test
-	public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceException() throws Exception {
-		this.exception.expect(AuthenticationServiceException.class);
+	public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 
 		String userInfoUri = "http://invalid-provider.com/user";
 
@@ -206,6 +203,12 @@ public class CustomUserTypesOAuth2UserServiceTests {
 		this.userService.loadUser(new OAuth2UserRequest(this.clientRegistration, this.accessToken));
 	}
 
+	private MockResponse jsonResponse(String json) {
+		return new MockResponse()
+				.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+				.setBody(json);
+	}
+
 	public static class CustomOAuth2User implements OAuth2User {
 		private List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
 		private String id;