瀏覽代碼

WebClientReactiveClientCredentialsTokenResponseClient

Fixes: gh-5607
Rob Winch 7 年之前
父節點
當前提交
28537fa3b6

+ 107 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java

@@ -0,0 +1,107 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.endpoint;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
+import org.springframework.web.reactive.function.BodyInserters;
+import org.springframework.web.reactive.function.client.WebClient;
+import reactor.core.publisher.Mono;
+
+import java.util.Set;
+import java.util.function.Consumer;
+
+import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
+
+/**
+ * An implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} that "exchanges"
+ * an authorization code credential for an access token credential
+ * at the Authorization Server's Token Endpoint.
+ *
+ * @author Rob Winch
+ * @since 5.1
+ * @see OAuth2AccessTokenResponseClient
+ * @see OAuth2AuthorizationCodeGrantRequest
+ * @see OAuth2AccessTokenResponse
+ * @see <a target="_blank" href="https://connect2id.com/products/nimbus-oauth-openid-connect-sdk">Nimbus OAuth 2.0 SDK</a>
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request (Authorization Code Grant)</a>
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response (Authorization Code Grant)</a>
+ */
+public class WebClientReactiveClientCredentialsTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {
+	private WebClient webClient = WebClient.builder()
+			.build();
+
+	@Override
+	public Mono<OAuth2AccessTokenResponse> getTokenResponse(OAuth2ClientCredentialsGrantRequest authorizationGrantRequest)
+			throws OAuth2AuthenticationException {
+
+		return Mono.defer(() -> {
+			ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
+
+			String tokenUri = clientRegistration.getProviderDetails().getTokenUri();
+			BodyInserters.FormInserter<String> body = body(authorizationGrantRequest);
+
+			return this.webClient.post()
+					.uri(tokenUri)
+					.accept(MediaType.APPLICATION_JSON)
+					.headers(headers(clientRegistration))
+					.body(body)
+					.exchange()
+					.flatMap(response -> response.body(oauth2AccessTokenResponse()))
+					.map(response -> {
+						if (response.getAccessToken().getScopes().isEmpty()) {
+							response = OAuth2AccessTokenResponse.withResponse(response)
+								.scopes(authorizationGrantRequest.getClientRegistration().getScopes())
+								.build();
+						}
+						return response;
+					});
+		});
+	}
+
+	private Consumer<HttpHeaders> headers(ClientRegistration clientRegistration) {
+		return headers -> {
+			headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
+			headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
+			if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
+				headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
+			}
+		};
+	}
+
+	private static BodyInserters.FormInserter<String> body(OAuth2ClientCredentialsGrantRequest authorizationGrantRequest) {
+		ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
+		BodyInserters.FormInserter<String> body = BodyInserters
+				.fromFormData(OAuth2ParameterNames.GRANT_TYPE, authorizationGrantRequest.getGrantType().getValue());
+		Set<String> scopes = clientRegistration.getScopes();
+		if (!CollectionUtils.isEmpty(scopes)) {
+			String scope = StringUtils.collectionToDelimitedString(scopes, " ");
+			body.with(OAuth2ParameterNames.SCOPE, scope);
+		}
+		if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
+			body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
+			body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
+		}
+		return body;
+	}
+}

+ 126 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java

@@ -0,0 +1,126 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.endpoint;
+
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import okhttp3.mockwebserver.RecordedRequest;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+
+import static org.assertj.core.api.Assertions.*;
+
+/**
+ * @author Rob Winch
+ */
+public class WebClientReactiveClientCredentialsTokenResponseClientTests {
+
+	private MockWebServer server;
+
+	private WebClientReactiveClientCredentialsTokenResponseClient client = new WebClientReactiveClientCredentialsTokenResponseClient();
+
+	private ClientRegistration.Builder clientRegistration;
+
+	@Before
+	public void setup() throws Exception {
+		this.server = new MockWebServer();
+		this.server.start();
+
+		this.clientRegistration = TestClientRegistrations
+				.clientCredentials()
+				.tokenUri(this.server.url("/oauth2/token").uri().toASCIIString());
+	}
+
+	@After
+	public void cleanup() throws Exception {
+		this.server.shutdown();
+	}
+
+	@Test
+	public void getTokenResponseWhenHeaderThenSuccess() throws Exception {
+		enqueueJson("{\n"
+				+ "  \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+				+ "  \"token_type\":\"bearer\",\n"
+				+ "  \"expires_in\":3600,\n"
+				+ "  \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n"
+				+ "  \"scope\":\"create\"\n"
+				+ "}");
+		OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(this.clientRegistration
+				.build());
+
+		OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
+		RecordedRequest actualRequest = this.server.takeRequest();
+		String body = actualRequest.getUtf8Body();
+
+		assertThat(response.getAccessToken()).isNotNull();
+		assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
+		assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser");
+	}
+
+	@Test
+	public void getTokenResponseWhenPostThenSuccess() throws Exception {
+		ClientRegistration registration = this.clientRegistration
+				.clientAuthenticationMethod(ClientAuthenticationMethod.POST)
+				.build();
+		enqueueJson("{\n"
+				+ "  \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+				+ "  \"token_type\":\"bearer\",\n"
+				+ "  \"expires_in\":3600,\n"
+				+ "  \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n"
+				+ "  \"scope\":\"create\"\n"
+				+ "}");
+
+		OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
+
+		OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
+		String body = this.server.takeRequest().getUtf8Body();
+
+		assertThat(response.getAccessToken()).isNotNull();
+		assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser&client_id=client-id&client_secret=client-secret");
+	}
+
+	@Test
+	public void getTokenResponseWhenNoScopeThenClientRegistrationScopesDefaulted() {
+		ClientRegistration registration = this.clientRegistration.build();
+		enqueueJson("{\n"
+				+ "  \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+				+ "  \"token_type\":\"bearer\",\n"
+				+ "  \"expires_in\":3600,\n"
+				+ "  \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+				+ "}");
+		OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
+
+		OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
+
+		assertThat(response.getAccessToken().getScopes()).isEqualTo(registration.getScopes());
+	}
+
+
+	private void enqueueJson(String body) {
+		MockResponse response = new MockResponse()
+				.setBody(body)
+				.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE);
+		this.server.enqueue(response);
+	}
+}