Просмотр исходного кода

Add OAuth2AuthorizedClientExchangeFilterFunction

Fixes: gh-5386
Rob Winch 7 лет назад
Родитель
Сommit
c68cf991ae

+ 84 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientExchangeFilterFunction.java

@@ -0,0 +1,84 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.reactive.function.client;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.ClientResponse;
+import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
+import org.springframework.web.reactive.function.client.ExchangeFunction;
+import reactor.core.publisher.Mono;
+
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Consumer;
+
+/**
+ * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the
+ * token as a Bearer Token.
+ *
+ * @author Rob Winch
+ * @since 5.1
+ */
+public final class OAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
+	/**
+	 * The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
+	 */
+	private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
+
+	/**
+	 * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
+	 * providing the Bearer Token. Example usage:
+	 *
+	 * <pre>
+	 * Mono<String> response = this.webClient
+	 *    .get()
+	 *    .uri(uri)
+	 *    .attributes(oauth2AuthorizedClient(authorizedClient))
+	 *    // ...
+	 *    .retrieve()
+	 *    .bodyToMono(String.class);
+	 * </pre>
+	 *
+	 * @param authorizedClient the {@link OAuth2AuthorizedClient} to use.
+	 * @return the {@link Consumer} to populate the
+	 */
+	public static Consumer<Map<String, Object>> oauth2AuthorizedClient(OAuth2AuthorizedClient authorizedClient) {
+		return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
+	}
+
+	@Override
+	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
+		Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
+				.map(OAuth2AuthorizedClient.class::cast);
+		return attribute
+				.map(authorizedClient -> bearer(request, authorizedClient))
+				.map(next::exchange)
+				.orElseGet(() -> next.exchange(request));
+	}
+
+	private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
+		return ClientRequest.from(request)
+					.headers(bearerToken(authorizedClient.getAccessToken().getTokenValue()))
+					.build();
+	}
+
+	private Consumer<HttpHeaders> bearerToken(String token) {
+		return headers -> headers.set(HttpHeaders.AUTHORIZATION, "Bearer " + token);
+	}
+}

+ 47 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/MockExchangeFunction.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.reactive.function.client;
+
+import static org.mockito.Mockito.mock;
+
+import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.ClientResponse;
+import org.springframework.web.reactive.function.client.ExchangeFunction;
+
+import reactor.core.publisher.Mono;
+
+/**
+ * @author Rob Winch
+ * @since 5.1
+ */
+public class MockExchangeFunction implements ExchangeFunction {
+	private ClientRequest request;
+
+	private ClientResponse response = mock(ClientResponse.class);
+
+	public ClientRequest getRequest() {
+		return this.request;
+	}
+
+	@Override
+	public Mono<ClientResponse> exchange(ClientRequest request) {
+		return Mono.defer(() -> {
+			this.request = request;
+			return Mono.just(this.response);
+		});
+	}
+}

+ 101 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -0,0 +1,101 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.reactive.function.client;
+
+import org.junit.Test;
+import org.springframework.http.HttpHeaders;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.web.reactive.function.client.ClientRequest;
+
+import java.net.URI;
+import java.time.Duration;
+import java.time.Instant;
+
+import static org.assertj.core.api.Assertions.*;
+import static org.springframework.http.HttpMethod.GET;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
+
+/**
+ * @author Rob Winch
+ * @since 5.1
+ */
+public class OAuth2AuthorizedClientExchangeFilterFunctionTests {
+	private OAuth2AuthorizedClientExchangeFilterFunction function = new OAuth2AuthorizedClientExchangeFilterFunction();
+
+	private MockExchangeFunction exchange = new MockExchangeFunction();
+
+	private ClientRegistration github = ClientRegistration.withRegistrationId("github")
+			.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.scope("read:user")
+			.authorizationUri("https://github.com/login/oauth/authorize")
+			.tokenUri("https://github.com/login/oauth/access_token")
+			.userInfoUri("https://api.github.com/user")
+			.userNameAttributeName("id")
+			.clientName("GitHub")
+			.clientId("clientId")
+			.clientSecret("clientSecret")
+			.build();
+
+	private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+			"token",
+			Instant.now(),
+			Instant.now().plus(Duration.ofDays(1)));
+
+	@Test
+	public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+			.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
+	}
+
+	@Test
+	public void filterWhenAuthorizedClientThenAuthorizationHeader() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue());
+	}
+
+	@Test
+	public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.header(HttpHeaders.AUTHORIZATION, "Existing")
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		HttpHeaders headers = this.exchange.getRequest().headers();
+		assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
+	}
+}