浏览代码

Add NimbusReactiveAuthorizationCodeTokenResponseClient

Issue: gh-4807
Rob Winch 7 年之前
父节点
当前提交
3cd2ddf793

+ 171 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClient.java

@@ -0,0 +1,171 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.endpoint;
+
+import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials;
+
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+
+import org.springframework.core.ParameterizedTypeReference;
+import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
+import org.springframework.util.CollectionUtils;
+import org.springframework.web.reactive.function.BodyInserters;
+import org.springframework.web.reactive.function.client.ExchangeFilterFunctions;
+import org.springframework.web.reactive.function.client.WebClient;
+
+import com.nimbusds.oauth2.sdk.AccessTokenResponse;
+import com.nimbusds.oauth2.sdk.ErrorObject;
+import com.nimbusds.oauth2.sdk.ParseException;
+import com.nimbusds.oauth2.sdk.TokenErrorResponse;
+import com.nimbusds.oauth2.sdk.TokenResponse;
+import com.nimbusds.oauth2.sdk.token.AccessToken;
+
+import net.minidev.json.JSONObject;
+import reactor.core.publisher.Mono;
+
+/**
+ * An implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} that "exchanges"
+ * an authorization code credential for an access token credential
+ * at the Authorization Server's Token Endpoint.
+ *
+ * <p>
+ * <b>NOTE:</b> This implementation uses the Nimbus OAuth 2.0 SDK internally.
+ *
+ * @author Rob Winch
+ * @since 5.1
+ * @see OAuth2AccessTokenResponseClient
+ * @see OAuth2AuthorizationCodeGrantRequest
+ * @see OAuth2AccessTokenResponse
+ * @see <a target="_blank" href="https://connect2id.com/products/nimbus-oauth-openid-connect-sdk">Nimbus OAuth 2.0 SDK</a>
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request (Authorization Code Grant)</a>
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response (Authorization Code Grant)</a>
+ */
+public class NimbusReactiveAuthorizationCodeTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {
+	private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response";
+
+	private WebClient webClient = WebClient.builder()
+			.filter(ExchangeFilterFunctions.basicAuthentication())
+			.build();
+
+	@Override
+	public Mono<OAuth2AccessTokenResponse> getTokenResponse(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest)
+			throws OAuth2AuthenticationException {
+
+		return Mono.defer(() -> {
+			ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
+
+			OAuth2AuthorizationExchange authorizationExchange = authorizationGrantRequest.getAuthorizationExchange();
+			String tokenUri = clientRegistration.getProviderDetails().getTokenUri();
+			BodyInserters.FormInserter<String> body = body(authorizationExchange);
+
+			return this.webClient.post()
+					.uri(tokenUri)
+					.accept(MediaType.APPLICATION_JSON)
+					.attributes(basicAuthenticationCredentials(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
+					.body(body)
+					.retrieve()
+					.onStatus(s -> false, response -> {
+						throw new IllegalStateException("Disabled Status Handlers");
+					})
+					.bodyToMono(new ParameterizedTypeReference<Map<String, String>>() {})
+					.map(json -> parse(json))
+					.flatMap(tokenResponse -> accessTokenResponse(tokenResponse))
+					.map(accessTokenResponse -> {
+						AccessToken accessToken = accessTokenResponse.getTokens().getAccessToken();
+						OAuth2AccessToken.TokenType accessTokenType = null;
+						if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(
+								accessToken.getType().getValue())) {
+							accessTokenType = OAuth2AccessToken.TokenType.BEARER;
+						}
+						long expiresIn = accessToken.getLifetime();
+
+						// As per spec, in section 5.1 Successful Access Token Response
+						// https://tools.ietf.org/html/rfc6749#section-5.1
+						// If AccessTokenResponse.scope is empty, then default to the scope
+						// originally requested by the client in the Authorization Request
+						Set<String> scopes;
+						if (CollectionUtils.isEmpty(
+								accessToken.getScope())) {
+							scopes = new LinkedHashSet<>(
+									authorizationExchange.getAuthorizationRequest().getScopes());
+						}
+						else {
+							scopes = new LinkedHashSet<>(
+									accessToken.getScope().toStringList());
+						}
+
+						Map<String, Object> additionalParameters = new LinkedHashMap<>(
+								accessTokenResponse.getCustomParameters());
+
+						return OAuth2AccessTokenResponse.withToken(accessToken.getValue())
+								.tokenType(accessTokenType)
+								.expiresIn(expiresIn)
+								.scopes(scopes)
+								.additionalParameters(additionalParameters)
+								.build();
+					});
+		});
+	}
+
+	private static BodyInserters.FormInserter<String> body(OAuth2AuthorizationExchange authorizationExchange) {
+		OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse();
+		String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri();
+		BodyInserters.FormInserter<String> body = BodyInserters
+				.fromFormData("grant_type", AuthorizationGrantType.AUTHORIZATION_CODE.getValue())
+				.with("code", authorizationResponse.getCode());
+		if (redirectUri != null) {
+			body.with("redirect_uri", redirectUri);
+		}
+		return body;
+	}
+
+	private static Mono<AccessTokenResponse> accessTokenResponse(TokenResponse tokenResponse) {
+		if (tokenResponse.indicatesSuccess()) {
+			return Mono.just(tokenResponse)
+					.cast(AccessTokenResponse.class);
+		}
+		TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse;
+		ErrorObject errorObject = tokenErrorResponse.getErrorObject();
+		OAuth2Error oauth2Error = new OAuth2Error(errorObject.getCode(),
+				errorObject.getDescription(), (errorObject.getURI() != null ?
+				errorObject.getURI().toString() :
+				null));
+
+		return Mono.error(new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()));
+	}
+
+	private static TokenResponse parse(Map<String, String> json) {
+		try {
+			return TokenResponse.parse(new JSONObject(json));
+		}
+		catch (ParseException pe) {
+			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
+					"An error occurred parsing the Access Token response: " + pe.getMessage(), null);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), pe);
+		}
+	}
+}

+ 50 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/ReactiveOAuth2AccessTokenResponseClient.java

@@ -0,0 +1,50 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.endpoint;
+
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+
+import reactor.core.publisher.Mono;
+
+/**
+ * A reactive strategy for &quot;exchanging&quot; an authorization grant credential
+ * (e.g. an Authorization Code) for an access token credential
+ * at the Authorization Server's Token Endpoint.
+ *
+ * @author Rob Winch
+ * @since 5.1
+ * @see AbstractOAuth2AuthorizationGrantRequest
+ * @see OAuth2AccessTokenResponse
+ * @see AuthorizationGrantType
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.3">Section 1.3 Authorization Grant</a>
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request (Authorization Code Grant)</a>
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response (Authorization Code Grant)</a>
+ */
+public interface ReactiveOAuth2AccessTokenResponseClient<T extends AbstractOAuth2AuthorizationGrantRequest>  {
+
+	/**
+	 * Exchanges the authorization grant credential, provided in the authorization grant request,
+	 * for an access token credential at the Authorization Server's Token Endpoint.
+	 *
+	 * @param authorizationGrantRequest the authorization grant request that contains the authorization grant credential
+	 * @return an {@link OAuth2AccessTokenResponse} that contains the {@link OAuth2AccessTokenResponse#getAccessToken() access token} credential
+	 * @throws OAuth2AuthenticationException if an error occurs while attempting to exchange for the access token credential
+	 */
+	Mono<OAuth2AccessTokenResponse> getTokenResponse(T authorizationGrantRequest) throws OAuth2AuthenticationException;
+
+}

+ 262 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClientTests.java

@@ -0,0 +1,262 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.endpoint;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import java.time.Instant;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
+
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+
+/**
+ * @author Rob Winch
+ * @since 5.1
+ */
+public class NimbusReactiveAuthorizationCodeTokenResponseClientTests {
+	private ClientRegistration.Builder clientRegistration;
+
+	private NimbusReactiveAuthorizationCodeTokenResponseClient tokenResponseClient = new NimbusReactiveAuthorizationCodeTokenResponseClient();
+
+	private MockWebServer server;
+
+	@Before
+	public void setup() throws Exception {
+		this.server = new MockWebServer();
+		this.server.start();
+
+		String tokenUri = this.server.url("/oauth2/token").toString();
+
+		this.clientRegistration = ClientRegistration.withRegistrationId("github")
+				.redirectUriTemplate("https://example.com/oauth2/code/github")
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.scope("read:user")
+				.authorizationUri("https://github.com/login/oauth/authorize")
+				.tokenUri(tokenUri)
+				.userInfoUri("https://api.example.com/user")
+				.userNameAttributeName("user-name")
+				.clientName("GitHub")
+				.clientId("clientId")
+				.jwkSetUri("https://example.com/oauth2/jwk")
+				.clientSecret("clientSecret");
+	}
+
+	@After
+	public void cleanup() throws Exception {
+		this.server.shutdown();
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\",\n" +
+				"   \"scope\": \"openid profile\",\n" +
+				"   \"custom_parameter_1\": \"custom-value-1\",\n" +
+				"   \"custom_parameter_2\": \"custom-value-2\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+
+		Instant expiresAtBefore = Instant.now().plusSeconds(3600);
+
+		OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block();
+
+		Instant expiresAtAfter = Instant.now().plusSeconds(3600);
+
+		assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234");
+		assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(
+				OAuth2AccessToken.TokenType.BEARER);
+		assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter);
+		assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile");
+		assertThat(accessTokenResponse.getAdditionalParameters().size()).isEqualTo(2);
+		assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_1", "custom-value-1");
+		assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_2", "custom-value-2");
+	}
+
+//	@Test
+//	public void getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() throws Exception {
+//		this.exception.expect(IllegalArgumentException.class);
+//
+//		String redirectUri = "http:\\example.com";
+//		when(this.clientRegistration.getRedirectUriTemplate()).thenReturn(redirectUri);
+//
+//		this.tokenResponseClient.getTokenResponse(
+//				new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+//	}
+//
+//	@Test
+//	public void getTokenResponseWhenTokenUriMalformedThenThrowIllegalArgumentException() throws Exception {
+//		this.exception.expect(IllegalArgumentException.class);
+//
+//		String tokenUri = "http:\\provider.com\\oauth2\\token";
+//		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+//
+//		this.tokenResponseClient.getTokenResponse(
+//				new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+//	}
+//
+//	@Test
+//	public void getTokenResponseWhenSuccessResponseInvalidThenThrowOAuth2AuthenticationException() throws Exception {
+//		this.exception.expect(OAuth2AuthenticationException.class);
+//		this.exception.expectMessage(containsString("invalid_token_response"));
+//
+//		MockWebServer server = new MockWebServer();
+//
+//		String accessTokenSuccessResponse = "{\n" +
+//				"	\"access_token\": \"access-token-1234\",\n" +
+//				"   \"token_type\": \"bearer\",\n" +
+//				"   \"expires_in\": \"3600\",\n" +
+//				"   \"scope\": \"openid profile\",\n" +
+//				"   \"custom_parameter_1\": \"custom-value-1\",\n" +
+//				"   \"custom_parameter_2\": \"custom-value-2\"\n";
+//		//			"}\n";		// Make the JSON invalid/malformed
+//
+//		server.enqueue(new MockResponse()
+//				.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+//				.setBody(accessTokenSuccessResponse));
+//		server.start();
+//
+//		String tokenUri = server.url("/oauth2/token").toString();
+//		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+//
+//		try {
+//			this.tokenResponseClient.getTokenResponse(
+//					new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+//		} finally {
+//			server.shutdown();
+//		}
+//	}
+//
+//	@Test
+//	public void getTokenResponseWhenTokenUriInvalidThenThrowAuthenticationServiceException() throws Exception {
+//		this.exception.expect(AuthenticationServiceException.class);
+//
+//		String tokenUri = "http://invalid-provider.com/oauth2/token";
+//		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+//
+//		this.tokenResponseClient.getTokenResponse(
+//				new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+//	}
+//
+	@Test
+	public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
+		String accessTokenErrorResponse = "{\n" +
+				"   \"error\": \"unauthorized_client\"\n" +
+				"}\n";
+
+		this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value()));
+
+		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block())
+			.isInstanceOf(OAuth2AuthenticationException.class)
+			.hasMessageContaining("unauthorized_client");
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthenticationException() throws Exception {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"not-bearer\",\n" +
+				"   \"expires_in\": \"3600\"\n" +
+				"}\n";
+
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block())
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("invalid_token_response");
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() throws Exception {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\",\n" +
+				"   \"scope\": \"openid profile\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		this.clientRegistration.scope("openid", "profile", "email", "address");
+
+		OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block();
+
+		assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile");
+	}
+
+	@Test
+	public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseUsingRequestedScope() throws Exception {
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+
+		this.clientRegistration.scope("openid", "profile", "email", "address");
+
+		OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block();
+
+		assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("openid", "profile", "email", "address");
+	}
+
+	private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest() {
+		ClientRegistration registration = this.clientRegistration.build();
+		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest
+				.authorizationCode()
+				.clientId(registration.getClientId())
+				.state("state")
+				.authorizationUri(registration.getProviderDetails().getAuthorizationUri())
+				.redirectUri(registration.getRedirectUriTemplate())
+				.scopes(registration.getScopes())
+				.build();
+		OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse
+				.success("code")
+				.state("state")
+				.redirectUri(registration.getRedirectUriTemplate())
+				.build();
+		OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
+				authorizationResponse);
+		return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange);
+	}
+
+	private MockResponse jsonResponse(String json) {
+		return new MockResponse()
+			.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
+			.setBody(json);
+	}
+}