Browse Source

Add OAuth2AccessTokenResponseBodyExtractor

This externalizes converting a OAuth2AccessTokenResponse from a
ReactiveHttpInputMessage.

Fixes: gh-5475
Rob Winch 7 years ago
parent
commit
e27e1cd637

+ 11 - 93
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusReactiveAuthorizationCodeTokenResponseClient.java

@@ -15,38 +15,21 @@
  */
 package org.springframework.security.oauth2.client.endpoint;
 
-import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials;
-
-import java.util.LinkedHashMap;
-import java.util.LinkedHashSet;
-import java.util.Map;
-import java.util.Set;
-
-import org.springframework.core.ParameterizedTypeReference;
 import org.springframework.http.MediaType;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
-import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
-import org.springframework.util.CollectionUtils;
 import org.springframework.web.reactive.function.BodyInserters;
 import org.springframework.web.reactive.function.client.ExchangeFilterFunctions;
 import org.springframework.web.reactive.function.client.WebClient;
-
-import com.nimbusds.oauth2.sdk.AccessTokenResponse;
-import com.nimbusds.oauth2.sdk.ErrorObject;
-import com.nimbusds.oauth2.sdk.ParseException;
-import com.nimbusds.oauth2.sdk.TokenErrorResponse;
-import com.nimbusds.oauth2.sdk.TokenResponse;
-import com.nimbusds.oauth2.sdk.token.AccessToken;
-
-import net.minidev.json.JSONObject;
 import reactor.core.publisher.Mono;
 
+import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
+import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials;
+
 /**
  * An implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} that "exchanges"
  * an authorization code credential for an access token credential
@@ -65,8 +48,6 @@ import reactor.core.publisher.Mono;
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response (Authorization Code Grant)</a>
  */
 public class NimbusReactiveAuthorizationCodeTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {
-	private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response";
-
 	private WebClient webClient = WebClient.builder()
 			.filter(ExchangeFilterFunctions.basicAuthentication())
 			.build();
@@ -87,52 +68,15 @@ public class NimbusReactiveAuthorizationCodeTokenResponseClient implements React
 					.accept(MediaType.APPLICATION_JSON)
 					.attributes(basicAuthenticationCredentials(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
 					.body(body)
-					.retrieve()
-					.onStatus(s -> false, response -> {
-						throw new IllegalStateException("Disabled Status Handlers");
-					})
-					.bodyToMono(new ParameterizedTypeReference<Map<String, String>>() {})
-					.map(json -> parse(json))
-					.flatMap(tokenResponse -> accessTokenResponse(tokenResponse))
-					.map(accessTokenResponse -> {
-						AccessToken accessToken = accessTokenResponse.getTokens().getAccessToken();
-						OAuth2AccessToken.TokenType accessTokenType = null;
-						if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(
-								accessToken.getType().getValue())) {
-							accessTokenType = OAuth2AccessToken.TokenType.BEARER;
-						}
-						long expiresIn = accessToken.getLifetime();
-
-						// As per spec, in section 5.1 Successful Access Token Response
-						// https://tools.ietf.org/html/rfc6749#section-5.1
-						// If AccessTokenResponse.scope is empty, then default to the scope
-						// originally requested by the client in the Authorization Request
-						Set<String> scopes;
-						if (CollectionUtils.isEmpty(
-								accessToken.getScope())) {
-							scopes = new LinkedHashSet<>(
-									authorizationExchange.getAuthorizationRequest().getScopes());
-						}
-						else {
-							scopes = new LinkedHashSet<>(
-									accessToken.getScope().toStringList());
-						}
-
-						String refreshToken = null;
-						if (accessTokenResponse.getTokens().getRefreshToken() != null) {
-							refreshToken = accessTokenResponse.getTokens().getRefreshToken().getValue();
-						}
-
-						Map<String, Object> additionalParameters = new LinkedHashMap<>(
-								accessTokenResponse.getCustomParameters());
-
-						return OAuth2AccessTokenResponse.withToken(accessToken.getValue())
-								.tokenType(accessTokenType)
-								.expiresIn(expiresIn)
-								.scopes(scopes)
-								.refreshToken(refreshToken)
-								.additionalParameters(additionalParameters)
+					.exchange()
+					.flatMap(response -> response.body(oauth2AccessTokenResponse()))
+					.map(response -> {
+						if (response.getAccessToken().getScopes().isEmpty()) {
+							response = OAuth2AccessTokenResponse.withResponse(response)
+								.scopes(authorizationExchange.getAuthorizationRequest().getScopes())
 								.build();
+						}
+						return response;
 					});
 		});
 	}
@@ -148,30 +92,4 @@ public class NimbusReactiveAuthorizationCodeTokenResponseClient implements React
 		}
 		return body;
 	}
-
-	private static Mono<AccessTokenResponse> accessTokenResponse(TokenResponse tokenResponse) {
-		if (tokenResponse.indicatesSuccess()) {
-			return Mono.just(tokenResponse)
-					.cast(AccessTokenResponse.class);
-		}
-		TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse;
-		ErrorObject errorObject = tokenErrorResponse.getErrorObject();
-		OAuth2Error oauth2Error = new OAuth2Error(errorObject.getCode(),
-				errorObject.getDescription(), (errorObject.getURI() != null ?
-				errorObject.getURI().toString() :
-				null));
-
-		return Mono.error(new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()));
-	}
-
-	private static TokenResponse parse(Map<String, String> json) {
-		try {
-			return TokenResponse.parse(new JSONObject(json));
-		}
-		catch (ParseException pe) {
-			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
-					"An error occurred parsing the Access Token response: " + pe.getMessage(), null);
-			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), pe);
-		}
-	}
 }

+ 4 - 0
oauth2/oauth2-core/spring-security-oauth2-core.gradle

@@ -4,5 +4,9 @@ dependencies {
 	compile project(':spring-security-core')
 	compile springCoreDependency
 
+	optional 'com.fasterxml.jackson.core:jackson-databind'
+	optional 'com.nimbusds:oauth2-oidc-sdk'
+	optional 'org.springframework:spring-webflux'
+
 	testCompile powerMock2Dependencies
 }

+ 113 - 0
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java

@@ -0,0 +1,113 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.core.web.reactive.function;
+
+import com.nimbusds.oauth2.sdk.AccessTokenResponse;
+import com.nimbusds.oauth2.sdk.ErrorObject;
+import com.nimbusds.oauth2.sdk.ParseException;
+import com.nimbusds.oauth2.sdk.TokenErrorResponse;
+import com.nimbusds.oauth2.sdk.TokenResponse;
+import com.nimbusds.oauth2.sdk.token.AccessToken;
+import net.minidev.json.JSONObject;
+import org.springframework.core.ParameterizedTypeReference;
+import org.springframework.http.ReactiveHttpInputMessage;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.web.reactive.function.BodyExtractor;
+import org.springframework.web.reactive.function.BodyExtractors;
+import reactor.core.publisher.Mono;
+
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Provides a way to create an {@link OAuth2AccessTokenResponse} from a {@link ReactiveHttpInputMessage}
+ * @author Rob Winch
+ * @since 5.1
+ */
+class OAuth2AccessTokenResponseBodyExtractor
+		implements BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> {
+
+	private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response";
+
+	OAuth2AccessTokenResponseBodyExtractor() {}
+
+	@Override
+	public Mono<OAuth2AccessTokenResponse> extract(ReactiveHttpInputMessage inputMessage,
+			Context context) {
+		ParameterizedTypeReference<Map<String, String>> type = new ParameterizedTypeReference<Map<String, String>>() {};
+		BodyExtractor<Mono<Map<String, String>>, ReactiveHttpInputMessage> delegate = BodyExtractors.toMono(type);
+		return delegate.extract(inputMessage, context)
+				.map(json -> parse(json))
+				.flatMap(OAuth2AccessTokenResponseBodyExtractor::oauth2AccessTokenResponse)
+				.map(OAuth2AccessTokenResponseBodyExtractor::oauth2AccessTokenResponse);
+	}
+
+	private static TokenResponse parse(Map<String, String> json) {
+		try {
+			return TokenResponse.parse(new JSONObject(json));
+		}
+		catch (ParseException pe) {
+			OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
+					"An error occurred parsing the Access Token response: " + pe.getMessage(), null);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), pe);
+		}
+	}
+
+	private static Mono<AccessTokenResponse> oauth2AccessTokenResponse(TokenResponse tokenResponse) {
+		if (tokenResponse.indicatesSuccess()) {
+			return Mono.just(tokenResponse)
+					.cast(AccessTokenResponse.class);
+		}
+		TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse;
+		ErrorObject errorObject = tokenErrorResponse.getErrorObject();
+		OAuth2Error oauth2Error = new OAuth2Error(errorObject.getCode(),
+				errorObject.getDescription(), (errorObject.getURI() != null ?
+				errorObject.getURI().toString() :
+				null));
+
+		return Mono.error(new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()));
+	}
+
+	private static OAuth2AccessTokenResponse oauth2AccessTokenResponse(AccessTokenResponse accessTokenResponse) {
+		AccessToken accessToken = accessTokenResponse.getTokens().getAccessToken();
+		OAuth2AccessToken.TokenType accessTokenType = null;
+		if (OAuth2AccessToken.TokenType.BEARER.getValue()
+				.equalsIgnoreCase(accessToken.getType().getValue())) {
+			accessTokenType = OAuth2AccessToken.TokenType.BEARER;
+		}
+		long expiresIn = accessToken.getLifetime();
+
+		Set<String> scopes = accessToken.getScope() == null ?
+				Collections.emptySet() : new LinkedHashSet<>(accessToken.getScope().toStringList());
+
+		String refreshToken = null;
+		if (accessTokenResponse.getTokens().getRefreshToken() != null) {
+			refreshToken = accessTokenResponse.getTokens().getRefreshToken().getValue();
+		}
+
+		Map<String, Object> additionalParameters = new LinkedHashMap<>(accessTokenResponse.getCustomParameters());
+
+		return OAuth2AccessTokenResponse.withToken(accessToken.getValue()).tokenType(accessTokenType).expiresIn(expiresIn).scopes(scopes)
+				.refreshToken(refreshToken).additionalParameters(additionalParameters).build();
+	}
+}

+ 40 - 0
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractors.java

@@ -0,0 +1,40 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.core.web.reactive.function;
+
+import org.springframework.http.ReactiveHttpInputMessage;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.web.reactive.function.BodyExtractor;
+import reactor.core.publisher.Mono;
+
+/**
+ * Static factory methods for OAuth2 {@link BodyExtractor} implementations.
+ * @author Rob Winch
+ * @since 5.1
+ */
+public abstract class OAuth2BodyExtractors {
+
+	/**
+	 * Extractor to decode an {@link OAuth2AccessTokenResponse}
+	 * @return a BodyExtractor for {@link OAuth2AccessTokenResponse}
+	 */
+	public static BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> oauth2AccessTokenResponse() {
+		return new OAuth2AccessTokenResponseBodyExtractor();
+	}
+
+	private OAuth2BodyExtractors() {}
+}

+ 125 - 0
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java

@@ -0,0 +1,125 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.core.web.reactive.function;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.core.codec.ByteBufferDecoder;
+import org.springframework.core.codec.StringDecoder;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.MediaType;
+import org.springframework.http.ReactiveHttpInputMessage;
+import org.springframework.http.codec.DecoderHttpMessageReader;
+import org.springframework.http.codec.FormHttpMessageReader;
+import org.springframework.http.codec.HttpMessageReader;
+import org.springframework.http.codec.json.Jackson2JsonDecoder;
+import org.springframework.http.codec.xml.Jaxb2XmlDecoder;
+import org.springframework.http.server.reactive.ServerHttpResponse;
+import org.springframework.mock.http.client.reactive.MockClientHttpResponse;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.web.reactive.function.BodyExtractor;
+import reactor.core.publisher.Mono;
+
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+
+/**
+ * @author Rob Winch
+ * @since 5.1
+ */
+public class OAuth2BodyExtractorsTests {
+
+	private BodyExtractor.Context context;
+
+	private Map<String, Object> hints;
+
+	@Before
+	public void createContext() {
+		final List<HttpMessageReader<?>> messageReaders = new ArrayList<>();
+		messageReaders.add(new DecoderHttpMessageReader<>(new ByteBufferDecoder()));
+		messageReaders.add(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes()));
+		messageReaders.add(new DecoderHttpMessageReader<>(new Jaxb2XmlDecoder()));
+		messageReaders.add(new DecoderHttpMessageReader<>(new Jackson2JsonDecoder()));
+		messageReaders.add(new FormHttpMessageReader());
+
+		this.hints = new HashMap<String, Object>();
+		this.context = new BodyExtractor.Context() {
+			@Override
+			public List<HttpMessageReader<?>> messageReaders() {
+				return messageReaders;
+			}
+
+			@Override
+			public Optional<ServerHttpResponse> serverResponse() {
+				return Optional.empty();
+			}
+
+			@Override
+			public Map<String, Object> hints() {
+				return OAuth2BodyExtractorsTests.this.hints;
+			}
+		};
+	}
+
+	@Test
+	public void oauth2AccessTokenResponseWhenInvalidJsonThenException() {
+		BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = OAuth2BodyExtractors
+				.oauth2AccessTokenResponse();
+
+		MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK);
+		response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
+		response.setBody("{");
+
+		Mono<OAuth2AccessTokenResponse> result = extractor.extract(response, this.context);
+
+		assertThatCode(() -> result.block())
+				.isInstanceOf(RuntimeException.class);
+	}
+
+	@Test
+	public void oauth2AccessTokenResponseWhenValidThenCreated() throws Exception {
+		BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = OAuth2BodyExtractors
+				.oauth2AccessTokenResponse();
+
+		MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK);
+		response.getHeaders().setContentType(MediaType.APPLICATION_JSON);
+		response.setBody("{\n"
+			+ "       \"access_token\":\"2YotnFZFEjr1zCsicMWpAA\",\n"
+			+ "       \"token_type\":\"Bearer\",\n"
+			+ "       \"expires_in\":3600,\n"
+			+ "       \"refresh_token\":\"tGzv3JOkF0XG5Qx2TlKWIA\",\n"
+			+ "       \"example_parameter\":\"example_value\"\n"
+			+ "     }");
+
+		Instant now = Instant.now();
+		OAuth2AccessTokenResponse result = extractor.extract(response, this.context).block();
+
+		assertThat(result.getAccessToken().getTokenValue()).isEqualTo("2YotnFZFEjr1zCsicMWpAA");
+		assertThat(result.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER);
+		assertThat(result.getAccessToken().getExpiresAt()).isBetween(now.plusSeconds(3600), now.plusSeconds(3600 + 2));
+		assertThat(result.getRefreshToken().getTokenValue()).isEqualTo("tGzv3JOkF0XG5Qx2TlKWIA");
+		assertThat(result.getAdditionalParameters()).containsEntry("example_parameter", "example_value");
+	}
+}