Просмотр исходного кода

Add support for client_credentials grant

Fixes gh-4982
Joe Grandja 7 лет назад
Родитель
Сommit
952743269d
12 измененных файлов с 1081 добавлено и 32 удалено
  1. 12 2
      config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java
  2. 5 0
      config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java
  3. 270 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java
  4. 56 0
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java
  5. 13 1
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java
  6. 85 11
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java
  7. 326 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java
  8. 76 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestTests.java
  9. 88 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java
  10. 118 17
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java
  11. 1 0
      oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java
  12. 31 1
      oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java

+ 12 - 2
config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

@@ -20,6 +20,7 @@ import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.Import;
 import org.springframework.context.annotation.ImportSelector;
 import org.springframework.core.type.AnnotationMetadata;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.util.ClassUtils;
@@ -57,17 +58,26 @@ final class OAuth2ClientConfiguration {
 
 	@Configuration
 	static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer {
+		private ClientRegistrationRepository clientRegistrationRepository;
 		private OAuth2AuthorizedClientRepository authorizedClientRepository;
 
 		@Override
 		public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
-			if (this.authorizedClientRepository != null) {
+			if (this.clientRegistrationRepository != null && this.authorizedClientRepository != null) {
 				OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver =
-						new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository);
+						new OAuth2AuthorizedClientArgumentResolver(
+								this.clientRegistrationRepository, this.authorizedClientRepository);
 				argumentResolvers.add(authorizedClientArgumentResolver);
 			}
 		}
 
+		@Autowired(required = false)
+		public void setClientRegistrationRepository(List<ClientRegistrationRepository> clientRegistrationRepositories) {
+			if (clientRegistrationRepositories.size() == 1) {
+				this.clientRegistrationRepository = clientRegistrationRepositories.get(0);
+			}
+		}
+
 		@Autowired(required = false)
 		public void setAuthorizedClientRepository(List<OAuth2AuthorizedClientRepository> authorizedClientRepositories) {
 			if (authorizedClientRepositories.size() == 1) {

+ 5 - 0
config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java

@@ -98,6 +98,11 @@ public class OAuth2ClientConfigurationTests {
 			}
 		}
 
+		@Bean
+		public ClientRegistrationRepository clientRegistrationRepository() {
+			return mock(ClientRegistrationRepository.class);
+		}
+
 		@Bean
 		public OAuth2AuthorizedClientRepository authorizedClientRepository() {
 			return AUTHORIZED_CLIENT_REPOSITORY;

+ 270 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java

@@ -0,0 +1,270 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.endpoint;
+
+import org.springframework.core.ParameterizedTypeReference;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.MediaType;
+import org.springframework.http.RequestEntity;
+import org.springframework.http.ResponseEntity;
+import org.springframework.http.client.ClientHttpResponse;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
+import org.springframework.util.StringUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+import org.springframework.web.client.RestOperations;
+import org.springframework.web.client.RestTemplate;
+import org.springframework.web.util.UriComponentsBuilder;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * The default implementation of an {@link OAuth2AccessTokenResponseClient}
+ * for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant.
+ * This implementation uses a {@link RestOperations} when requesting
+ * an access token credential at the Authorization Server's Token Endpoint.
+ *
+ * @author Joe Grandja
+ * @since 5.1
+ * @see OAuth2AccessTokenResponseClient
+ * @see OAuth2ClientCredentialsGrantRequest
+ * @see OAuth2AccessTokenResponse
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4.2">Section 4.4.2 Access Token Request (Client Credentials Grant)</a>
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4.3">Section 4.4.3 Access Token Response (Client Credentials Grant)</a>
+ */
+public class DefaultClientCredentialsTokenResponseClient implements OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {
+	private static final String INVALID_TOKEN_REQUEST_ERROR_CODE = "invalid_token_request";
+
+	private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response";
+
+	private static final String[] TOKEN_RESPONSE_PARAMETER_NAMES = {
+			OAuth2ParameterNames.ACCESS_TOKEN,
+			OAuth2ParameterNames.TOKEN_TYPE,
+			OAuth2ParameterNames.EXPIRES_IN,
+			OAuth2ParameterNames.SCOPE,
+			OAuth2ParameterNames.REFRESH_TOKEN
+	};
+
+	private RestOperations restOperations;
+
+	public DefaultClientCredentialsTokenResponseClient() {
+		RestTemplate restTemplate = new RestTemplate();
+		// Disable the ResponseErrorHandler as errors are handled directly within this class
+		restTemplate.setErrorHandler(new NoOpResponseErrorHandler());
+		this.restOperations = restTemplate;
+	}
+
+	@Override
+	public OAuth2AccessTokenResponse getTokenResponse(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest)
+			throws OAuth2AuthenticationException {
+
+		Assert.notNull(clientCredentialsGrantRequest, "clientCredentialsGrantRequest cannot be null");
+
+		// Build request
+		RequestEntity<MultiValueMap<String, String>> request = this.buildRequest(clientCredentialsGrantRequest);
+
+		// Exchange
+		ResponseEntity<Map<String, String>> response;
+		try {
+			response = this.restOperations.exchange(
+					request, new ParameterizedTypeReference<Map<String, String>>() {});
+		} catch (Exception ex) {
+			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_REQUEST_ERROR_CODE,
+					"An error occurred while sending the Access Token Request: " + ex.getMessage(), null);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
+		}
+
+		Map<String, String> responseParameters = response.getBody();
+
+		// Check for Error Response
+		if (response.getStatusCodeValue() != 200) {
+			OAuth2Error oauth2Error = this.parseErrorResponse(responseParameters);
+			if (oauth2Error == null) {
+				oauth2Error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR);
+			}
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+		}
+
+		// Success Response
+		OAuth2AccessTokenResponse tokenResponse;
+		try {
+			tokenResponse = this.parseTokenResponse(responseParameters);
+		} catch (Exception ex) {
+			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
+					"An error occurred parsing the Access Token response (200 OK): " + ex.getMessage(), null);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
+		}
+
+		if (tokenResponse == null) {
+			// This should never happen as long as the provider
+			// implements a Successful Response as defined in Section 5.1
+			// https://tools.ietf.org/html/rfc6749#section-5.1
+			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
+					"An error occurred parsing the Access Token response (200 OK). " +
+							"Missing required parameters: access_token and/or token_type", null);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+		}
+
+		if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) {
+			// As per spec, in Section 5.1 Successful Access Token Response
+			// https://tools.ietf.org/html/rfc6749#section-5.1
+			// If AccessTokenResponse.scope is empty, then default to the scope
+			// originally requested by the client in the Token Request
+			tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse)
+					.scopes(clientCredentialsGrantRequest.getClientRegistration().getScopes())
+					.build();
+		}
+
+		return tokenResponse;
+	}
+
+	private RequestEntity<MultiValueMap<String, String>> buildRequest(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) {
+		HttpHeaders headers = this.buildHeaders(clientCredentialsGrantRequest);
+		MultiValueMap<String, String> formParameters = this.buildFormParameters(clientCredentialsGrantRequest);
+		URI uri = UriComponentsBuilder.fromUriString(clientCredentialsGrantRequest.getClientRegistration().getProviderDetails().getTokenUri())
+				.build()
+				.toUri();
+
+		return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri);
+	}
+
+	private HttpHeaders buildHeaders(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) {
+		ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration();
+
+		HttpHeaders headers = new HttpHeaders();
+		headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
+		headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
+		if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
+			headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
+		}
+
+		return headers;
+	}
+
+	private MultiValueMap<String, String> buildFormParameters(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) {
+		ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration();
+
+		MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
+		formParameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue());
+		if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
+			formParameters.add(OAuth2ParameterNames.SCOPE,
+					StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
+		}
+		if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
+			formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
+			formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
+		}
+
+		return formParameters;
+	}
+
+	private OAuth2Error parseErrorResponse(Map<String, String> responseParameters) {
+		if (CollectionUtils.isEmpty(responseParameters) ||
+				!responseParameters.containsKey(OAuth2ParameterNames.ERROR)) {
+			return null;
+		}
+
+		String errorCode = responseParameters.get(OAuth2ParameterNames.ERROR);
+		String errorDescription = responseParameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION);
+		String errorUri = responseParameters.get(OAuth2ParameterNames.ERROR_URI);
+
+		return new OAuth2Error(errorCode, errorDescription, errorUri);
+	}
+
+	private OAuth2AccessTokenResponse parseTokenResponse(Map<String, String> responseParameters) {
+		if (CollectionUtils.isEmpty(responseParameters) ||
+				!responseParameters.containsKey(OAuth2ParameterNames.ACCESS_TOKEN) ||
+				!responseParameters.containsKey(OAuth2ParameterNames.TOKEN_TYPE)) {
+			return null;
+		}
+
+		String accessToken = responseParameters.get(OAuth2ParameterNames.ACCESS_TOKEN);
+
+		OAuth2AccessToken.TokenType accessTokenType = null;
+		if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(
+				responseParameters.get(OAuth2ParameterNames.TOKEN_TYPE))) {
+			accessTokenType = OAuth2AccessToken.TokenType.BEARER;
+		}
+
+		long expiresIn = 0;
+		if (responseParameters.containsKey(OAuth2ParameterNames.EXPIRES_IN)) {
+			try {
+				expiresIn = Long.valueOf(responseParameters.get(OAuth2ParameterNames.EXPIRES_IN));
+			} catch (NumberFormatException ex) { }
+		}
+
+		Set<String> scopes = Collections.emptySet();
+		if (responseParameters.containsKey(OAuth2ParameterNames.SCOPE)) {
+			String scope = responseParameters.get(OAuth2ParameterNames.SCOPE);
+			scopes = Arrays.stream(StringUtils.delimitedListToStringArray(scope, " ")).collect(Collectors.toSet());
+		}
+
+		Map<String, Object> additionalParameters = new LinkedHashMap<>();
+		Set<String> tokenResponseParameterNames = Stream.of(TOKEN_RESPONSE_PARAMETER_NAMES).collect(Collectors.toSet());
+		responseParameters.entrySet().stream()
+				.filter(e -> !tokenResponseParameterNames.contains(e.getKey()))
+				.forEach(e -> additionalParameters.put(e.getKey(), e.getValue()));
+
+		return OAuth2AccessTokenResponse.withToken(accessToken)
+				.tokenType(accessTokenType)
+				.expiresIn(expiresIn)
+				.scopes(scopes)
+				.additionalParameters(additionalParameters)
+				.build();
+	}
+
+	/**
+	 * Sets the {@link RestOperations} used when requesting the access token response.
+	 *
+	 * @param restOperations the {@link RestOperations} used when requesting the access token response
+	 */
+	public final void setRestOperations(RestOperations restOperations) {
+		Assert.notNull(restOperations, "restOperations cannot be null");
+		this.restOperations = restOperations;
+	}
+
+	private static class NoOpResponseErrorHandler implements ResponseErrorHandler {
+
+		@Override
+		public boolean hasError(ClientHttpResponse response) throws IOException {
+			return false;
+		}
+
+		@Override
+		public void handleError(ClientHttpResponse response) throws IOException {
+		}
+	}
+}

+ 56 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequest.java

@@ -0,0 +1,56 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.endpoint;
+
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.util.Assert;
+
+/**
+ * An OAuth 2.0 Client Credentials Grant request that holds
+ * the client's credentials in {@link #getClientRegistration()}.
+ *
+ * @author Joe Grandja
+ * @since 5.1
+ * @see AbstractOAuth2AuthorizationGrantRequest
+ * @see ClientRegistration
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.3.4">Section 1.3.4 Client Credentials Grant</a>
+ */
+public class OAuth2ClientCredentialsGrantRequest extends AbstractOAuth2AuthorizationGrantRequest {
+	private final ClientRegistration clientRegistration;
+
+	/**
+	 * Constructs an {@code OAuth2ClientCredentialsGrantRequest} using the provided parameters.
+	 *
+	 * @param clientRegistration the client registration
+	 */
+	public OAuth2ClientCredentialsGrantRequest(ClientRegistration clientRegistration) {
+		super(AuthorizationGrantType.CLIENT_CREDENTIALS);
+		Assert.notNull(clientRegistration, "clientRegistration cannot be null");
+		Assert.isTrue(AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()),
+				"clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS");
+		this.clientRegistration = clientRegistration;
+	}
+
+	/**
+	 * Returns the {@link ClientRegistration client registration}.
+	 *
+	 * @return the {@link ClientRegistration}
+	 */
+	public ClientRegistration getClientRegistration() {
+		return this.clientRegistration;
+	}
+}

+ 13 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java

@@ -448,7 +448,9 @@ public final class ClientRegistration {
 		 */
 		public ClientRegistration build() {
 			Assert.notNull(this.authorizationGrantType, "authorizationGrantType cannot be null");
-			if (AuthorizationGrantType.IMPLICIT.equals(this.authorizationGrantType)) {
+			if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(this.authorizationGrantType)) {
+				this.validateClientCredentialsGrantType();
+			} else if (AuthorizationGrantType.IMPLICIT.equals(this.authorizationGrantType)) {
 				this.validateImplicitGrantType();
 			} else {
 				this.validateAuthorizationCodeGrantType();
@@ -507,5 +509,15 @@ public final class ClientRegistration {
 			Assert.hasText(this.authorizationUri, "authorizationUri cannot be empty");
 			Assert.hasText(this.clientName, "clientName cannot be empty");
 		}
+
+		private void validateClientCredentialsGrantType() {
+			Assert.isTrue(AuthorizationGrantType.CLIENT_CREDENTIALS.equals(this.authorizationGrantType),
+					() -> "authorizationGrantType must be " + AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
+			Assert.hasText(this.registrationId, "registrationId cannot be empty");
+			Assert.hasText(this.clientId, "clientId cannot be empty");
+			Assert.hasText(this.clientSecret, "clientSecret cannot be empty");
+			Assert.notNull(this.clientAuthenticationMethod, "clientAuthenticationMethod cannot be null");
+			Assert.hasText(this.tokenUri, "tokenUri cannot be empty");
+		}
 	}
 }

+ 85 - 11
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java

@@ -25,7 +25,14 @@ import org.springframework.security.oauth2.client.ClientAuthorizationRequiredExc
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 import org.springframework.web.bind.support.WebDataBinderFactory;
@@ -34,6 +41,7 @@ import org.springframework.web.method.support.HandlerMethodArgumentResolver;
 import org.springframework.web.method.support.ModelAndViewContainer;
 
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 
 /**
  * An implementation of a {@link HandlerMethodArgumentResolver} that is capable
@@ -56,15 +64,22 @@ import javax.servlet.http.HttpServletRequest;
  * @see RegisteredOAuth2AuthorizedClient
  */
 public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver {
+	private final ClientRegistrationRepository clientRegistrationRepository;
 	private final OAuth2AuthorizedClientRepository authorizedClientRepository;
+	private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
+			new DefaultClientCredentialsTokenResponseClient();
 
 	/**
 	 * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
 	 *
-	 * @param authorizedClientRepository the authorized client repository
+	 * @param clientRegistrationRepository the repository of client registrations
+	 * @param authorizedClientRepository the repository of authorized clients
 	 */
-	public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientRepository authorizedClientRepository) {
+	public OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository clientRegistrationRepository,
+													OAuth2AuthorizedClientRepository authorizedClientRepository) {
+		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
 		Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
+		this.clientRegistrationRepository = clientRegistrationRepository;
 		this.authorizedClientRepository = authorizedClientRepository;
 	}
 
@@ -83,8 +98,43 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
 									NativeWebRequest webRequest,
 									@Nullable WebDataBinderFactory binderFactory) throws Exception {
 
+		String clientRegistrationId = this.resolveClientRegistrationId(parameter);
+		if (StringUtils.isEmpty(clientRegistrationId)) {
+			throw new IllegalArgumentException("Unable to resolve the Client Registration Identifier. " +
+					"It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or " +
+					"@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
+		}
+
+		Authentication principal = SecurityContextHolder.getContext().getAuthentication();
+		HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);
+
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
+				clientRegistrationId, principal, servletRequest);
+		if (authorizedClient != null) {
+			return authorizedClient;
+		}
+
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
+		if (clientRegistration == null) {
+			return null;
+		}
+
+		if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
+			throw new ClientAuthorizationRequiredException(clientRegistrationId);
+		}
+
+		if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
+			HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class);
+			authorizedClient = this.authorizeClientCredentialsClient(clientRegistration, servletRequest, servletResponse);
+		}
+
+		return authorizedClient;
+	}
+
+	private String resolveClientRegistrationId(MethodParameter parameter) {
 		RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils.findMergedAnnotation(
 				parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class);
+
 		Authentication principal = SecurityContextHolder.getContext().getAuthentication();
 
 		String clientRegistrationId = null;
@@ -95,17 +145,41 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
 		} else if (principal != null && OAuth2AuthenticationToken.class.isAssignableFrom(principal.getClass())) {
 			clientRegistrationId = ((OAuth2AuthenticationToken) principal).getAuthorizedClientRegistrationId();
 		}
-		if (StringUtils.isEmpty(clientRegistrationId)) {
-			throw new IllegalArgumentException("Unable to resolve the Client Registration Identifier. " +
-					"It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
-		}
 
-		OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
-			clientRegistrationId, principal, webRequest.getNativeRequest(HttpServletRequest.class));
-		if (authorizedClient == null) {
-			throw new ClientAuthorizationRequiredException(clientRegistrationId);
-		}
+		return clientRegistrationId;
+	}
+
+	private OAuth2AuthorizedClient authorizeClientCredentialsClient(ClientRegistration clientRegistration,
+																	HttpServletRequest request, HttpServletResponse response) {
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(clientRegistration);
+		OAuth2AccessTokenResponse tokenResponse =
+				this.clientCredentialsTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
+
+		Authentication principal = SecurityContextHolder.getContext().getAuthentication();
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				clientRegistration,
+				(principal != null ? principal.getName() : "anonymousUser"),
+				tokenResponse.getAccessToken());
+
+		this.authorizedClientRepository.saveAuthorizedClient(
+				authorizedClient,
+				principal,
+				request,
+				response);
 
 		return authorizedClient;
 	}
+
+	/**
+	 * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant.
+	 *
+	 * @param clientCredentialsTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant
+	 */
+	public final void setClientCredentialsTokenResponseClient(
+			OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
+		Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
+		this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
+	}
 }

+ 326 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClientTests.java

@@ -0,0 +1,326 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.endpoint;
+
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import okhttp3.mockwebserver.RecordedRequest;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+
+import java.time.Instant;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link DefaultClientCredentialsTokenResponseClient}.
+ *
+ * @author Joe Grandja
+ */
+public class DefaultClientCredentialsTokenResponseClientTests {
+	private DefaultClientCredentialsTokenResponseClient tokenResponseClient = new DefaultClientCredentialsTokenResponseClient();
+	private ClientRegistration clientRegistration;
+	private MockWebServer server;
+
+	@Before
+	public void setup() throws Exception {
+		this.server = new MockWebServer();
+		this.server.start();
+
+		String tokenUri = this.server.url("/oauth2/token").toString();
+
+		this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
+				.clientId("client-1")
+				.clientSecret("secret")
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+				.scope("read", "write")
+				.tokenUri(tokenUri)
+				.build();
+	}
+
+	@After
+	public void cleanup() throws Exception {
+		this.server.shutdown();
+	}
+
+	@Test
+	public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.tokenResponseClient.setRestOperations(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\",\n" +
+				"   \"scope\": \"read write\",\n" +
+				"   \"custom_parameter_1\": \"custom-value-1\",\n" +
+				"   \"custom_parameter_2\": \"custom-value-2\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		Instant expiresAtBefore = Instant.now().plusSeconds(3600);
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(this.clientRegistration);
+
+		OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
+
+		Instant expiresAtAfter = Instant.now().plusSeconds(3600);
+
+		RecordedRequest recordedRequest = this.server.takeRequest();
+		assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString());
+		assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON.toString());
+		assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).startsWith(MediaType.APPLICATION_FORM_URLENCODED.toString());
+
+		String formParameters = recordedRequest.getBody().readUtf8();
+		assertThat(formParameters).contains("grant_type=client_credentials");
+		assertThat(formParameters).contains("scope=read+write");
+
+		assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234");
+		assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER);
+		assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter);
+		assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write");
+		assertThat(accessTokenResponse.getRefreshToken()).isNull();
+		assertThat(accessTokenResponse.getAdditionalParameters().size()).isEqualTo(2);
+		assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_1", "custom-value-1");
+		assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_2", "custom-value-2");
+	}
+
+	@Test
+	public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(this.clientRegistration);
+
+		this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
+
+		RecordedRequest recordedRequest = this.server.takeRequest();
+		assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic ");
+	}
+
+	@Test
+	public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		ClientRegistration clientRegistration = this.from(this.clientRegistration)
+				.clientAuthenticationMethod(ClientAuthenticationMethod.POST)
+				.build();
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(clientRegistration);
+
+		this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
+
+		RecordedRequest recordedRequest = this.server.takeRequest();
+		assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
+
+		String formParameters = recordedRequest.getBody().readUtf8();
+		assertThat(formParameters).contains("client_id=client-1");
+		assertThat(formParameters).contains("client_secret=secret");
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthenticationException() {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"not-bearer\",\n" +
+				"   \"expires_in\": \"3600\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(this.clientRegistration);
+
+		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response (200 OK): tokenType cannot be null");
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseAndMissingTokenTypeParameterThenThrowOAuth2AuthenticationException() {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(this.clientRegistration);
+
+		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response (200 OK). Missing required parameters: access_token and/or token_type");
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\",\n" +
+				"   \"scope\": \"read\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(this.clientRegistration);
+
+		OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
+
+		assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read");
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasDefaultScope() {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(this.clientRegistration);
+
+		OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
+
+		assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write");
+	}
+
+	@Test
+	public void getTokenResponseWhenTokenUriMalformedThenThrowOAuth2AuthenticationException() {
+		String malformedTokenUri = "http:\\provider.com\\oauth2\\token";
+		ClientRegistration clientRegistration = this.from(this.clientRegistration)
+				.tokenUri(malformedTokenUri)
+				.build();
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(clientRegistration);
+
+		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("[invalid_token_request] An error occurred while sending the Access Token Request:");
+	}
+
+	@Test
+	public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthenticationException() {
+		String invalidTokenUri = "http://invalid-provider.com/oauth2/token";
+		ClientRegistration clientRegistration = this.from(this.clientRegistration)
+				.tokenUri(invalidTokenUri)
+				.build();
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(clientRegistration);
+
+		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("[invalid_token_request] An error occurred while sending the Access Token Request:");
+	}
+
+	@Test
+	public void getTokenResponseWhenMalformedResponseThenThrowOAuth2AuthenticationException() {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\",\n" +
+				"   \"scope\": \"read write\",\n" +
+				"   \"custom_parameter_1\": \"custom-value-1\",\n" +
+				"   \"custom_parameter_2\": \"custom-value-2\"\n";
+//			"}\n";		// Make the JSON invalid/malformed
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(this.clientRegistration);
+
+		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("[invalid_token_request] An error occurred while sending the Access Token Request:");
+	}
+
+	@Test
+	public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthenticationException() {
+		String accessTokenErrorResponse = "{\n" +
+				"   \"error\": \"unauthorized_client\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(this.clientRegistration);
+
+		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("[unauthorized_client]");
+	}
+
+	@Test
+	public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthenticationException() {
+		this.server.enqueue(new MockResponse().setResponseCode(500));
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(this.clientRegistration);
+
+		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("[server_error]");
+	}
+
+	private MockResponse jsonResponse(String json) {
+		return new MockResponse()
+				.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+				.setBody(json);
+	}
+
+	private ClientRegistration.Builder from(ClientRegistration registration) {
+		return ClientRegistration.withRegistrationId(registration.getRegistrationId())
+				.clientId(registration.getClientId())
+				.clientSecret(registration.getClientSecret())
+				.clientAuthenticationMethod(registration.getClientAuthenticationMethod())
+				.authorizationGrantType(registration.getAuthorizationGrantType())
+				.scope(registration.getScopes())
+				.tokenUri(registration.getProviderDetails().getTokenUri());
+	}
+}

+ 76 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2ClientCredentialsGrantRequestTests.java

@@ -0,0 +1,76 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.endpoint;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Java6Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link OAuth2ClientCredentialsGrantRequest}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2ClientCredentialsGrantRequestTests {
+	private ClientRegistration clientRegistration;
+
+	@Before
+	public void setup() {
+		this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
+				.clientId("client-1")
+				.clientSecret("secret")
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+				.scope("read", "write")
+				.tokenUri("https://provider.com/oauth2/token")
+				.build();
+	}
+
+	@Test
+	public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2ClientCredentialsGrantRequest(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void constructorWhenClientRegistrationInvalidGrantTypeThenThrowIllegalArgumentException() {
+		ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1")
+				.clientId("client-1")
+				.authorizationGrantType(AuthorizationGrantType.IMPLICIT)
+				.redirectUriTemplate("https://localhost:8080/redirect-uri")
+				.authorizationUri("https://provider.com/oauth2/auth")
+				.clientName("Client 1")
+				.build();
+
+		assertThatThrownBy(() -> new OAuth2ClientCredentialsGrantRequest(clientRegistration))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS");
+	}
+
+	@Test
+	public void constructorWhenValidParametersProvidedThenCreated() {
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(this.clientRegistration);
+
+		assertThat(clientCredentialsGrantRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(clientCredentialsGrantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS);
+	}
+}

+ 88 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java

@@ -25,6 +25,7 @@ import java.util.LinkedHashSet;
 import java.util.Set;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /**
  * Tests for {@link ClientRegistration}.
@@ -411,4 +412,91 @@ public class ClientRegistrationTests {
 
 		assertThat(registration.getRegistrationId()).isEqualTo(overriddenId);
 	}
+
+	@Test
+	public void buildWhenClientCredentialsGrantAllAttributesProvidedThenAllAttributesAreSet() {
+		ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
+				.clientId(CLIENT_ID)
+				.clientSecret(CLIENT_SECRET)
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+				.scope(SCOPES.toArray(new String[0]))
+				.tokenUri(TOKEN_URI)
+				.clientName(CLIENT_NAME)
+				.build();
+
+		assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID);
+		assertThat(registration.getClientId()).isEqualTo(CLIENT_ID);
+		assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET);
+		assertThat(registration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC);
+		assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS);
+		assertThat(registration.getScopes()).isEqualTo(SCOPES);
+		assertThat(registration.getProviderDetails().getTokenUri()).isEqualTo(TOKEN_URI);
+		assertThat(registration.getClientName()).isEqualTo(CLIENT_NAME);
+	}
+
+	@Test
+	public void buildWhenClientCredentialsGrantRegistrationIdIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() ->
+				ClientRegistration.withRegistrationId(null)
+						.clientId(CLIENT_ID)
+						.clientSecret(CLIENT_SECRET)
+						.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+						.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+						.tokenUri(TOKEN_URI)
+						.build()
+		).isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void buildWhenClientCredentialsGrantClientIdIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() ->
+				ClientRegistration.withRegistrationId(REGISTRATION_ID)
+						.clientId(null)
+						.clientSecret(CLIENT_SECRET)
+						.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+						.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+						.tokenUri(TOKEN_URI)
+						.build()
+		).isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void buildWhenClientCredentialsGrantClientSecretIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() ->
+				ClientRegistration.withRegistrationId(REGISTRATION_ID)
+						.clientId(CLIENT_ID)
+						.clientSecret(null)
+						.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+						.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+						.tokenUri(TOKEN_URI)
+						.build()
+		).isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void buildWhenClientCredentialsGrantClientAuthenticationMethodIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() ->
+				ClientRegistration.withRegistrationId(REGISTRATION_ID)
+						.clientId(CLIENT_ID)
+						.clientSecret(CLIENT_SECRET)
+						.clientAuthenticationMethod(null)
+						.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+						.tokenUri(TOKEN_URI)
+						.build()
+		).isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void buildWhenClientCredentialsGrantTokenUriIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() ->
+				ClientRegistration.withRegistrationId(REGISTRATION_ID)
+						.clientId(CLIENT_ID)
+						.clientSecret(CLIENT_SECRET)
+						.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+						.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+						.tokenUri(null)
+						.build()
+		).isInstanceOf(IllegalArgumentException.class);
+	}
 }

+ 118 - 17
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java

@@ -20,6 +20,7 @@ import org.junit.Before;
 import org.junit.Test;
 import org.springframework.core.MethodParameter;
 import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
@@ -27,7 +28,16 @@ import org.springframework.security.oauth2.client.ClientAuthorizationRequiredExc
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.util.ReflectionUtils;
 import org.springframework.web.context.request.ServletWebRequest;
 
@@ -38,8 +48,8 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.*;
 
 /**
  * Tests for {@link OAuth2AuthorizedClientArgumentResolver}.
@@ -47,22 +57,58 @@ import static org.mockito.Mockito.when;
  * @author Joe Grandja
  */
 public class OAuth2AuthorizedClientArgumentResolverTests {
+	private TestingAuthenticationToken authentication;
+	private String principalName = "principal-1";
+	private ClientRegistration registration1;
+	private ClientRegistration registration2;
+	private ClientRegistrationRepository clientRegistrationRepository;
+	private OAuth2AuthorizedClient authorizedClient1;
+	private OAuth2AuthorizedClient authorizedClient2;
 	private OAuth2AuthorizedClientRepository authorizedClientRepository;
 	private OAuth2AuthorizedClientArgumentResolver argumentResolver;
-	private OAuth2AuthorizedClient authorizedClient;
 	private MockHttpServletRequest request;
 
 	@Before
 	public void setup() {
-		this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
-		this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository);
-		this.authorizedClient = mock(OAuth2AuthorizedClient.class);
-		this.request = new MockHttpServletRequest();
-		when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class)))
-				.thenReturn(this.authorizedClient);
+		this.authentication = new TestingAuthenticationToken(this.principalName, "password");
 		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
-		securityContext.setAuthentication(mock(Authentication.class));
+		securityContext.setAuthentication(this.authentication);
 		SecurityContextHolder.setContext(securityContext);
+
+		this.registration1 = ClientRegistration.withRegistrationId("client1")
+				.clientId("client-1")
+				.clientSecret("secret")
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.redirectUriTemplate("{baseUrl}/login/oauth2/code/{registrationId}")
+				.scope("user")
+				.authorizationUri("https://provider.com/oauth2/authorize")
+				.tokenUri("https://provider.com/oauth2/token")
+				.userInfoUri("https://provider.com/oauth2/user")
+				.userNameAttributeName("id")
+				.clientName("client-1")
+				.build();
+		this.registration2 = ClientRegistration.withRegistrationId("client2")
+				.clientId("client-2")
+				.clientSecret("secret")
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+				.scope("read", "write")
+				.tokenUri("https://provider.com/oauth2/token")
+				.build();
+		this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, this.registration2);
+		this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
+		this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(
+				this.clientRegistrationRepository, this.authorizedClientRepository);
+		this.authorizedClient1 = new OAuth2AuthorizedClient(this.registration1, this.principalName, mock(OAuth2AccessToken.class));
+		when(this.authorizedClientRepository.loadAuthorizedClient(
+				eq(this.registration1.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class)))
+				.thenReturn(this.authorizedClient1);
+		this.authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName, mock(OAuth2AccessToken.class));
+		when(this.authorizedClientRepository.loadAuthorizedClient(
+				eq(this.registration2.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class)))
+				.thenReturn(this.authorizedClient2);
+		this.request = new MockHttpServletRequest();
 	}
 
 	@After
@@ -71,8 +117,20 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
 	}
 
 	@Test
-	public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null))
+	public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null, this.authorizedClientRepository))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.argumentResolver.setClientCredentialsTokenResponseClient(null))
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
@@ -101,7 +159,7 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
 	}
 
 	@Test
-	public void resolveArgumentWhenRegistrationIdEmptyAndNotOAuth2AuthenticationThenThrowIllegalArgumentException() throws Exception {
+	public void resolveArgumentWhenRegistrationIdEmptyAndNotOAuth2AuthenticationThenThrowIllegalArgumentException() {
 		MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
 		assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, null, null))
 				.isInstanceOf(IllegalArgumentException.class)
@@ -116,18 +174,26 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
 		securityContext.setAuthentication(authentication);
 		SecurityContextHolder.setContext(securityContext);
 		MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
-		this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request), null);
+		assertThat(this.argumentResolver.resolveArgument(
+				methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient1);
 	}
 
 	@Test
-	public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() throws Exception {
+	public void resolveArgumentWhenAuthorizedClientFoundThenResolves() throws Exception {
 		MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
 		assertThat(this.argumentResolver.resolveArgument(
-				methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient);
+				methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient1);
+	}
+
+	@Test
+	public void resolveArgumentWhenRegistrationIdInvalidThenDoesNotResolve() throws Exception {
+		MethodParameter methodParameter = this.getMethodParameter("registrationIdInvalid", OAuth2AuthorizedClient.class);
+		assertThat(this.argumentResolver.resolveArgument(
+				methodParameter, null, new ServletWebRequest(this.request), null)).isNull();
 	}
 
 	@Test
-	public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() throws Exception {
+	public void resolveArgumentWhenAuthorizedClientNotFoundForAuthorizationCodeClientThenThrowClientAuthorizationRequiredException() {
 		when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class)))
 				.thenReturn(null);
 		MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
@@ -135,6 +201,35 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
 				.isInstanceOf(ClientAuthorizationRequiredException.class);
 	}
 
+	@SuppressWarnings("unchecked")
+	@Test
+	public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClientThenResolvesFromTokenResponseClient() throws Exception {
+		OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
+				mock(OAuth2AccessTokenResponseClient.class);
+		this.argumentResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient);
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
+				.withToken("access-token-1234")
+				.tokenType(OAuth2AccessToken.TokenType.BEARER)
+				.expiresIn(3600)
+				.build();
+		when(clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
+
+		when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class)))
+				.thenReturn(null);
+		MethodParameter methodParameter = this.getMethodParameter("clientCredentialsClient", OAuth2AuthorizedClient.class);
+
+		OAuth2AuthorizedClient authorizedClient = (OAuth2AuthorizedClient) this.argumentResolver.resolveArgument(
+				methodParameter, null, new ServletWebRequest(this.request), null);
+
+		assertThat(authorizedClient).isNotNull();
+		assertThat(authorizedClient.getClientRegistration()).isSameAs(this.registration2);
+		assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName);
+		assertThat(authorizedClient.getAccessToken()).isSameAs(accessTokenResponse.getAccessToken());
+
+		verify(this.authorizedClientRepository).saveAuthorizedClient(
+				eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), eq(null));
+	}
+
 	private MethodParameter getMethodParameter(String methodName, Class<?>... paramTypes) {
 		Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes);
 		return new MethodParameter(method, 0);
@@ -155,5 +250,11 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
 
 		void registrationIdEmpty(@RegisteredOAuth2AuthorizedClient OAuth2AuthorizedClient authorizedClient) {
 		}
+
+		void registrationIdInvalid(@RegisteredOAuth2AuthorizedClient("invalid") OAuth2AuthorizedClient authorizedClient) {
+		}
+
+		void clientCredentialsClient(@RegisteredOAuth2AuthorizedClient("client2") OAuth2AuthorizedClient authorizedClient) {
+		}
 	}
 }

+ 1 - 0
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java

@@ -38,6 +38,7 @@ public final class AuthorizationGrantType implements Serializable {
 	public static final AuthorizationGrantType AUTHORIZATION_CODE = new AuthorizationGrantType("authorization_code");
 	public static final AuthorizationGrantType IMPLICIT = new AuthorizationGrantType("implicit");
 	public static final AuthorizationGrantType REFRESH_TOKEN = new AuthorizationGrantType("refresh_token");
+	public static final AuthorizationGrantType CLIENT_CREDENTIALS = new AuthorizationGrantType("client_credentials");
 	private final String value;
 
 	/**

+ 31 - 1
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,6 +25,11 @@ package org.springframework.security.oauth2.core.endpoint;
  */
 public interface OAuth2ParameterNames {
 
+	/**
+	 * {@code grant_type} - used in Access Token Request.
+	 */
+	String GRANT_TYPE = "grant_type";
+
 	/**
 	 * {@code response_type} - used in Authorization Request.
 	 */
@@ -35,6 +40,11 @@ public interface OAuth2ParameterNames {
 	 */
 	String CLIENT_ID = "client_id";
 
+	/**
+	 * {@code client_secret} - used in Access Token Request.
+	 */
+	String CLIENT_SECRET = "client_secret";
+
 	/**
 	 * {@code redirect_uri} - used in Authorization Request and Access Token Request.
 	 */
@@ -55,6 +65,26 @@ public interface OAuth2ParameterNames {
 	 */
 	String CODE = "code";
 
+	/**
+	 * {@code access_token} - used in Authorization Response and Access Token Response.
+	 */
+	String ACCESS_TOKEN = "access_token";
+
+	/**
+	 * {@code token_type} - used in Authorization Response and Access Token Response.
+	 */
+	String TOKEN_TYPE = "token_type";
+
+	/**
+	 * {@code expires_in} - used in Authorization Response and Access Token Response.
+	 */
+	String EXPIRES_IN = "expires_in";
+
+	/**
+	 * {@code refresh_token} - used in Access Token Request and Access Token Response.
+	 */
+	String REFRESH_TOKEN = "refresh_token";
+
 	/**
 	 * {@code error} - used in Authorization Response and Access Token Response.
 	 */