ソースを参照

Enable custom configuration for HTTP client

Fixes gh-4477
Joe Grandja 8 年 前
コミット
c872499eee

+ 31 - 9
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/AuthorizationCodeAuthenticationFilterConfigurer.java

@@ -15,9 +15,11 @@
  */
 package org.springframework.security.config.annotation.web.configurers.oauth2.client;
 
+import org.springframework.context.ApplicationContext;
 import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
 import org.springframework.security.config.annotation.web.configurers.AbstractAuthenticationFilterConfigurer;
 import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
+import org.springframework.security.jose.jws.JwsAlgorithm;
 import org.springframework.security.jwt.JwtDecoder;
 import org.springframework.security.jwt.nimbus.NimbusJwtDecoderJwkSupport;
 import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationProcessingFilter;
@@ -31,6 +33,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.user.OAuth2UserService;
 import org.springframework.security.oauth2.client.user.nimbus.NimbusOAuth2UserService;
+import org.springframework.security.oauth2.core.http.HttpClientConfig;
 import org.springframework.security.oauth2.core.provider.DefaultProviderMetadata;
 import org.springframework.security.oauth2.core.provider.ProviderMetadata;
 import org.springframework.security.oauth2.core.user.OAuth2User;
@@ -113,7 +116,7 @@ final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecuri
 	@Override
 	public void init(H http) throws Exception {
 		AuthorizationCodeAuthenticationProvider authenticationProvider = new AuthorizationCodeAuthenticationProvider(
-				this.getAuthorizationCodeTokenExchanger(), this.getProviderJwtDecoderRegistry(), this.getUserInfoService());
+				this.getAuthorizationCodeTokenExchanger(http), this.getProviderJwtDecoderRegistry(http), this.getUserInfoService(http));
 		if (this.userAuthoritiesMapper != null) {
 			authenticationProvider.setAuthoritiesMapper(this.userAuthoritiesMapper);
 		}
@@ -134,14 +137,20 @@ final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecuri
 		return this.getAuthenticationFilter().getAuthorizeRequestMatcher();
 	}
 
-	private AuthorizationGrantTokenExchanger<AuthorizationCodeAuthenticationToken> getAuthorizationCodeTokenExchanger() {
+	private AuthorizationGrantTokenExchanger<AuthorizationCodeAuthenticationToken> getAuthorizationCodeTokenExchanger(H http) {
 		if (this.authorizationCodeTokenExchanger == null) {
-			this.authorizationCodeTokenExchanger = new NimbusAuthorizationCodeTokenExchanger();
+			NimbusAuthorizationCodeTokenExchanger nimbusAuthorizationCodeTokenExchanger = new NimbusAuthorizationCodeTokenExchanger();
+			HttpClientConfig httpClientConfig = this.getHttpClientConfig(http);
+			if (httpClientConfig != null) {
+				nimbusAuthorizationCodeTokenExchanger.setHttpClientConfig(httpClientConfig);
+			}
+			this.authorizationCodeTokenExchanger = nimbusAuthorizationCodeTokenExchanger;
 		}
 		return this.authorizationCodeTokenExchanger;
 	}
 
-	private ProviderJwtDecoderRegistry getProviderJwtDecoderRegistry() {
+	private ProviderJwtDecoderRegistry getProviderJwtDecoderRegistry(H http) {
+		HttpClientConfig httpClientConfig = this.getHttpClientConfig(http);
 		Map<ProviderMetadata, JwtDecoder> jwtDecoders = new HashMap<>();
 		ClientRegistrationRepository clientRegistrationRepository = OAuth2LoginConfigurer.getClientRegistrationRepository(this.getBuilder());
 		clientRegistrationRepository.getRegistrations().stream().forEach(registration -> {
@@ -159,25 +168,38 @@ final class AuthorizationCodeAuthenticationFilterConfigurer<H extends HttpSecuri
 				providerMetadata.setTokenEndpoint(this.toURL(providerDetails.getTokenUri()));
 				providerMetadata.setUserInfoEndpoint(this.toURL(providerDetails.getUserInfoUri()));
 				providerMetadata.setJwkSetUri(this.toURL(providerDetails.getJwkSetUri()));
-				jwtDecoders.put(providerMetadata, new NimbusJwtDecoderJwkSupport(providerDetails.getJwkSetUri()));
+				NimbusJwtDecoderJwkSupport nimbusJwtDecoderJwkSupport = new NimbusJwtDecoderJwkSupport(
+					providerDetails.getJwkSetUri(), JwsAlgorithm.RS256, httpClientConfig);
+				jwtDecoders.put(providerMetadata, nimbusJwtDecoderJwkSupport);
 			}
 		});
 		return new DefaultProviderJwtDecoderRegistry(jwtDecoders);
 	}
 
-	private OAuth2UserService getUserInfoService() {
+	private OAuth2UserService getUserInfoService(H http) {
 		if (this.userInfoService == null) {
-			this.userInfoService = new NimbusOAuth2UserService();
+			NimbusOAuth2UserService nimbusOAuth2UserService = new NimbusOAuth2UserService();
 			if (!this.customUserTypes.isEmpty()) {
-				((NimbusOAuth2UserService)this.userInfoService).setCustomUserTypes(this.customUserTypes);
+				nimbusOAuth2UserService.setCustomUserTypes(this.customUserTypes);
 			}
 			if (!this.userNameAttributeNames.isEmpty()) {
-				((NimbusOAuth2UserService)this.userInfoService).setUserNameAttributeNames(this.userNameAttributeNames);
+				nimbusOAuth2UserService.setUserNameAttributeNames(this.userNameAttributeNames);
+			}
+			HttpClientConfig httpClientConfig = this.getHttpClientConfig(http);
+			if (httpClientConfig != null) {
+				nimbusOAuth2UserService.setHttpClientConfig(httpClientConfig);
 			}
+			this.userInfoService = nimbusOAuth2UserService;
 		}
 		return this.userInfoService;
 	}
 
+	private HttpClientConfig getHttpClientConfig(H http) {
+		Map<String, HttpClientConfig> httpClientConfigs =
+			http.getSharedObject(ApplicationContext.class).getBeansOfType(HttpClientConfig.class);
+		return (!httpClientConfigs.isEmpty() ? httpClientConfigs.values().iterator().next() : null);
+	}
+
 	private URL toURL(String urlStr) {
 		if (!StringUtils.hasText(urlStr)) {
 			return null;

+ 15 - 2
oauth2/jwt-jose/src/main/java/org/springframework/security/jwt/nimbus/NimbusJwtDecoderJwkSupport.java

@@ -21,6 +21,8 @@ import com.nimbusds.jose.jwk.source.RemoteJWKSet;
 import com.nimbusds.jose.proc.JWSKeySelector;
 import com.nimbusds.jose.proc.JWSVerificationKeySelector;
 import com.nimbusds.jose.proc.SecurityContext;
+import com.nimbusds.jose.util.DefaultResourceRetriever;
+import com.nimbusds.jose.util.ResourceRetriever;
 import com.nimbusds.jwt.JWT;
 import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.JWTParser;
@@ -30,6 +32,7 @@ import org.springframework.security.jose.jws.JwsAlgorithm;
 import org.springframework.security.jwt.Jwt;
 import org.springframework.security.jwt.JwtDecoder;
 import org.springframework.security.jwt.JwtException;
+import org.springframework.security.oauth2.core.http.HttpClientConfig;
 import org.springframework.util.Assert;
 
 import java.net.MalformedURLException;
@@ -65,6 +68,10 @@ public class NimbusJwtDecoderJwkSupport implements JwtDecoder {
 	}
 
 	public NimbusJwtDecoderJwkSupport(String jwkSetUrl, String jwsAlgorithm) {
+		this(jwkSetUrl, jwsAlgorithm, null);
+	}
+
+	public NimbusJwtDecoderJwkSupport(String jwkSetUrl, String jwsAlgorithm, HttpClientConfig httpClientConfig) {
 		Assert.hasText(jwkSetUrl, "jwkSetUrl cannot be empty");
 		Assert.hasText(jwsAlgorithm, "jwsAlgorithm cannot be empty");
 		try {
@@ -74,10 +81,16 @@ public class NimbusJwtDecoderJwkSupport implements JwtDecoder {
 		}
 		this.jwsAlgorithm = JWSAlgorithm.parse(jwsAlgorithm);
 
-		this.jwtProcessor = new DefaultJWTProcessor<>();
-		JWKSource jwkSource = new RemoteJWKSet(this.jwkSetUrl);
+		int connectTimeout = (httpClientConfig != null ?
+			httpClientConfig.getConnectTimeout() : HttpClientConfig.DEFAULT_CONNECT_TIMEOUT);
+		int readTimeout = (httpClientConfig != null ?
+			httpClientConfig.getReadTimeout() : HttpClientConfig.DEFAULT_READ_TIMEOUT);
+		ResourceRetriever jwkSetRetriever = new DefaultResourceRetriever(connectTimeout, readTimeout);
+		JWKSource jwkSource = new RemoteJWKSet(this.jwkSetUrl, jwkSetRetriever);
 		JWSKeySelector<SecurityContext> jwsKeySelector =
 			new JWSVerificationKeySelector<SecurityContext>(this.jwsAlgorithm, jwkSource);
+
+		this.jwtProcessor = new DefaultJWTProcessor<>();
 		this.jwtProcessor.setJWSKeySelector(jwsKeySelector);
 	}
 

+ 10 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/nimbus/NimbusAuthorizationCodeTokenExchanger.java

@@ -28,11 +28,13 @@ import org.springframework.security.authentication.AuthenticationServiceExceptio
 import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.AuthorizationGrantTokenExchanger;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.http.HttpClientConfig;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.AccessToken;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.TokenResponseAttributes;
+import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 
 import java.io.IOException;
@@ -61,6 +63,7 @@ import java.util.stream.Collectors;
  */
 public class NimbusAuthorizationCodeTokenExchanger implements AuthorizationGrantTokenExchanger<AuthorizationCodeAuthenticationToken> {
 	private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response";
+	private HttpClientConfig httpClientConfig = new HttpClientConfig();
 
 	@Override
 	public TokenResponseAttributes exchange(AuthorizationCodeAuthenticationToken authorizationCodeAuthenticationToken)
@@ -90,6 +93,8 @@ public class NimbusAuthorizationCodeTokenExchanger implements AuthorizationGrant
 			TokenRequest tokenRequest = new TokenRequest(tokenUri, clientAuthentication, authorizationCodeGrant);
 			HTTPRequest httpRequest = tokenRequest.toHTTPRequest();
 			httpRequest.setAccept(MediaType.APPLICATION_JSON_VALUE);
+			httpRequest.setConnectTimeout(this.httpClientConfig.getConnectTimeout());
+			httpRequest.setReadTimeout(this.httpClientConfig.getReadTimeout());
 			tokenResponse = TokenResponse.parse(httpRequest.send());
 		} catch (ParseException pe) {
 			// This error occurs if the Access Token Response is not well-formed,
@@ -132,6 +137,11 @@ public class NimbusAuthorizationCodeTokenExchanger implements AuthorizationGrant
 			.build();
 	}
 
+	public final void setHttpClientConfig(HttpClientConfig httpClientConfig) {
+		Assert.notNull(httpClientConfig, "httpClientConfig cannot be null");
+		this.httpClientConfig = httpClientConfig;
+	}
+
 	private URI toURI(String uriStr) {
 		try {
 			return new URI(uriStr);

+ 9 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/user/nimbus/NimbusOAuth2UserService.java

@@ -31,6 +31,7 @@ import org.springframework.security.authentication.AuthenticationServiceExceptio
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.core.http.HttpClientConfig;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.user.OAuth2UserService;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -72,6 +73,7 @@ public class NimbusOAuth2UserService implements OAuth2UserService {
 	private final HttpMessageConverter jackson2HttpMessageConverter = new MappingJackson2HttpMessageConverter();
 	private Map<URI, String> userNameAttributeNames = Collections.unmodifiableMap(Collections.emptyMap());
 	private Map<URI, Class<? extends OAuth2User>> customUserTypes = Collections.unmodifiableMap(Collections.emptyMap());
+	private HttpClientConfig httpClientConfig = new HttpClientConfig();
 
 	public NimbusOAuth2UserService() {
 	}
@@ -151,6 +153,8 @@ public class NimbusOAuth2UserService implements OAuth2UserService {
 		UserInfoRequest userInfoRequest = new UserInfoRequest(userInfoUri, accessToken);
 		HTTPRequest httpRequest = userInfoRequest.toHTTPRequest();
 		httpRequest.setAccept(MediaType.APPLICATION_JSON_VALUE);
+		httpRequest.setConnectTimeout(this.httpClientConfig.getConnectTimeout());
+		httpRequest.setReadTimeout(this.httpClientConfig.getReadTimeout());
 		HTTPResponse httpResponse;
 
 		try {
@@ -215,6 +219,11 @@ public class NimbusOAuth2UserService implements OAuth2UserService {
 		this.customUserTypes = Collections.unmodifiableMap(new HashMap<>(customUserTypes));
 	}
 
+	public final void setHttpClientConfig(HttpClientConfig httpClientConfig) {
+		Assert.notNull(httpClientConfig, "httpClientConfig cannot be null");
+		this.httpClientConfig = httpClientConfig;
+	}
+
 	private URI getUserInfoUri(OAuth2AuthenticationToken token) {
 		ClientRegistration clientRegistration = token.getClientRegistration();
 		try {

+ 89 - 0
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/http/HttpClientConfig.java

@@ -0,0 +1,89 @@
+/*
+ * Copyright 2012-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.core.http;
+
+import org.springframework.util.Assert;
+
+/**
+ * This class provides the capability for configuring the underlying HTTP client.
+ *
+ * <p>
+ * To customize the configuration of the underlying HTTP client, create/configure
+ * an instance of {@link HttpClientConfig} and register it with the <code>ApplicationContext</code>.
+ *
+ * <p>
+ * For example:
+ *
+ * <pre>
+ * &#064;Bean
+ * public HttpClientConfig httpClientConfig() {
+ *    HttpClientConfig httpClientConfig = new HttpClientConfig();
+ *    httpClientConfig.setConnectTimeout(60000);
+ *    httpClientConfig.setReadTimeout(60000);
+ *    return httpClientConfig;
+ * }
+ * </pre>
+ *
+ * @author Joe Grandja
+ * @since 5.0
+ */
+public class HttpClientConfig {
+	public static final int DEFAULT_CONNECT_TIMEOUT = 30000;
+	public static final int DEFAULT_READ_TIMEOUT = 30000;
+	private int connectTimeout = DEFAULT_CONNECT_TIMEOUT;
+	private int readTimeout = DEFAULT_READ_TIMEOUT;
+
+	/**
+	 * Returns the timeout in milliseconds until a connection is established.
+	 *
+	 * @return the connect timeout value in milliseconds
+	 */
+	public int getConnectTimeout() {
+		return this.connectTimeout;
+	}
+
+	/**
+	 * Sets the timeout in milliseconds until a connection is established.
+	 * A timeout value of 0 implies the option is disabled (timeout of infinity).
+	 *
+	 * @param connectTimeout the connect timeout value in milliseconds
+	 */
+	public void setConnectTimeout(int connectTimeout) {
+		Assert.isTrue(connectTimeout >= 0, "connectTimeout cannot be negative");
+		this.connectTimeout = connectTimeout;
+	}
+
+	/**
+	 * Returns the timeout in milliseconds for inactivity when reading from the <code>InputStream</code>.
+	 *
+	 * @return the read timeout value in milliseconds
+	 */
+	public int getReadTimeout() {
+		return this.readTimeout;
+	}
+
+	/**
+	 * Sets the timeout in milliseconds for inactivity when reading from the <code>InputStream</code>.
+	 * A timeout value of 0 implies the option is disabled (timeout of infinity).
+	 *
+	 * @param readTimeout the read timeout value in milliseconds
+	 */
+	public void setReadTimeout(int readTimeout) {
+		Assert.isTrue(readTimeout >= 0, "readTimeout cannot be negative");
+		this.readTimeout = readTimeout;
+	}
+
+}