Browse Source

Add ServletOAuth2AuthorizedClientExchangeFilterFunction

Fixes: gh-5545
Rob Winch 7 years ago
parent
commit
1b79bbed7f

+ 408 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -0,0 +1,408 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.reactive.function.client;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.MediaType;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.util.Assert;
+import org.springframework.web.context.request.RequestContextHolder;
+import org.springframework.web.context.request.ServletRequestAttributes;
+import org.springframework.web.reactive.function.BodyInserters;
+import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.ClientResponse;
+import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
+import org.springframework.web.reactive.function.client.ExchangeFunction;
+import org.springframework.web.reactive.function.client.WebClient;
+import reactor.core.publisher.Mono;
+import reactor.core.scheduler.Schedulers;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Base64;
+import java.util.Collection;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Consumer;
+
+import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
+import static org.springframework.security.web.http.SecurityHeaders.bearerToken;
+
+/**
+ * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the
+ * token as a Bearer Token. It also provides mechanisms for looking up the {@link OAuth2AuthorizedClient}. This class is
+ * intended to be used in a servlet environment.
+ *
+ * Example usage:
+ *
+ * <pre>
+ * OAuth2AuthorizedClientExchangeFilterFunction oauth2 = new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientService);
+ * WebClient webClient = WebClient.builder()
+ *    .apply(oauth2.oauth2Configuration())
+ *    .build();
+ * Mono<String> response = webClient
+ *    .get()
+ *    .uri(uri)
+ *    .attributes(oauth2AuthorizedClient(authorizedClient))
+ *    // ...
+ *    .retrieve()
+ *    .bodyToMono(String.class);
+ * </pre>
+ *
+ * An attempt to automatically refresh the token will be made if all of the following
+ * are true:
+ *
+ * <ul>
+ * <li>The ReactiveOAuth2AuthorizedClientService on the
+ * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction} is not null</li>
+ * <li>A refresh token is present on the OAuth2AuthorizedClient</li>
+ * <li>The access token will be expired in
+ * {@link #setAccessTokenExpiresSkew(Duration)}</li>
+ * <li>The {@link ReactiveSecurityContextHolder} will be used to attempt to save
+ * the token. If it is empty, then the principal name on the OAuth2AuthorizedClient
+ * will be used to create an Authentication for saving.</li>
+ * </ul>
+ *
+ * @author Rob Winch
+ * @since 5.1
+ */
+public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
+	/**
+	 * The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
+	 */
+	private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
+	private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID");
+	private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName();
+	private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
+	private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
+
+	private Clock clock = Clock.systemUTC();
+
+	private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
+
+	private OAuth2AuthorizedClientRepository authorizedClientRepository;
+
+	public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
+
+	public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientRepository authorizedClientRepository) {
+		this.authorizedClientRepository = authorizedClientRepository;
+	}
+
+	/**
+	 * Configures the builder with {@link #defaultRequest()} and adds this as a {@link ExchangeFilterFunction}
+	 * @return the {@link Consumer} to configure the builder
+	 */
+	public Consumer<WebClient.Builder> oauth2Configuration() {
+		return builder -> builder.defaultRequest(defaultRequest()).filter(this);
+	}
+
+	/**
+	 * Provides defaults for the {@link HttpServletRequest} and the {@link HttpServletResponse} using
+	 * {@link RequestContextHolder}. It also provides defaults for the {@link Authentication} using
+	 * {@link SecurityContextHolder}. It also can default the {@link OAuth2AuthorizedClient} using the
+	 * {@link #clientRegistrationId(String)} or the {@link #authentication(Authentication)}.
+	 * @return the {@link Consumer} to populate the attributes
+	 */
+	public Consumer<WebClient.RequestHeadersSpec<?>> defaultRequest() {
+		return spec -> {
+			spec.attributes(attrs -> {
+				populateDefaultRequestResponse(attrs);
+				populateDefaultAuthentication(attrs);
+				populateDefaultOAuth2AuthorizedClient(attrs);
+			});
+		};
+	}
+
+	/**
+	 * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
+	 * providing the Bearer Token.
+	 *
+	 * @param authorizedClient the {@link OAuth2AuthorizedClient} to use.
+	 * @return the {@link Consumer} to populate the attributes
+	 */
+	public static Consumer<Map<String, Object>> oauth2AuthorizedClient(OAuth2AuthorizedClient authorizedClient) {
+		return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
+	}
+
+	/**
+	 * Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to
+	 * be used to look up the {@link OAuth2AuthorizedClient}.
+	 *
+	 * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to
+	 * be used to look up the {@link OAuth2AuthorizedClient}.
+	 * @return the {@link Consumer} to populate the attributes
+	 */
+	public static Consumer<Map<String, Object>> clientRegistrationId(String clientRegistrationId) {
+		return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
+	}
+
+	/**
+	 * Modifies the {@link ClientRequest#attributes()} to include the {@link Authentication} used to
+	 * look up and save the {@link OAuth2AuthorizedClient}. The value is defaulted in
+	 * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction#defaultRequest()}
+	 *
+	 * @param authentication the {@link Authentication} to use.
+	 * @return the {@link Consumer} to populate the attributes
+	 */
+	public static Consumer<Map<String, Object>> authentication(Authentication authentication) {
+		return attributes -> attributes.put(AUTHENTICATION_ATTR_NAME, authentication);
+	}
+
+	/**
+	 * Modifies the {@link ClientRequest#attributes()} to include the {@link HttpServletRequest} used to
+	 * look up and save the {@link OAuth2AuthorizedClient}. The value is defaulted in
+	 * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction#defaultRequest()}
+	 *
+	 * @param request the {@link HttpServletRequest} to use.
+	 * @return the {@link Consumer} to populate the attributes
+	 */
+	public static Consumer<Map<String, Object>> httpServletRequest(HttpServletRequest request) {
+		return attributes -> attributes.put(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
+	}
+
+	/**
+	 * Modifies the {@link ClientRequest#attributes()} to include the {@link HttpServletResponse} used to
+	 * save the {@link OAuth2AuthorizedClient}. The value is defaulted in
+	 * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction#defaultRequest()}
+	 *
+	 * @param response the {@link HttpServletResponse} to use.
+	 * @return the {@link Consumer} to populate the attributes
+	 */
+	public static Consumer<Map<String, Object>> httpServletResponse(HttpServletResponse response) {
+		return attributes -> attributes.put(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
+	}
+
+	/**
+	 * An access token will be considered expired by comparing its expiration to now +
+	 * this skewed Duration. The default is 1 minute.
+	 * @param accessTokenExpiresSkew the Duration to use.
+	 */
+	public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
+		Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null");
+		this.accessTokenExpiresSkew = accessTokenExpiresSkew;
+	}
+
+	@Override
+	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
+		Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
+				.map(OAuth2AuthorizedClient.class::cast);
+		return Mono.justOrEmpty(attribute)
+				.flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient))
+				.map(authorizedClient -> bearer(request, authorizedClient))
+				.flatMap(next::exchange)
+				.switchIfEmpty(next.exchange(request));
+	}
+
+	private void populateDefaultRequestResponse(Map<String, Object> attrs) {
+		if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(
+				HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
+			return;
+		}
+		ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
+		HttpServletRequest request = null;
+		HttpServletResponse response = null;
+		if (context != null) {
+			request = context.getRequest();
+			response = context.getResponse();
+		}
+		attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
+		attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
+	}
+
+	private void populateDefaultAuthentication(Map<String, Object> attrs) {
+		if (attrs.containsKey(AUTHENTICATION_ATTR_NAME)) {
+			return;
+		}
+		Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+		attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
+	}
+
+	private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
+		if (this.authorizedClientRepository == null || attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) {
+			return;
+		}
+
+		Authentication authentication = getAuthentication(attrs);
+		String clientRegistrationId = getClientRegistrationId(attrs);
+		if (clientRegistrationId == null  && authentication instanceof OAuth2AuthenticationToken) {
+			clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
+		}
+		if (clientRegistrationId != null) {
+			HttpServletRequest request = (HttpServletRequest) attrs.get(
+					HTTP_SERVLET_REQUEST_ATTR_NAME);
+			OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository
+					.loadAuthorizedClient(clientRegistrationId, authentication,
+							request);
+			oauth2AuthorizedClient(authorizedClient).accept(attrs);
+		}
+	}
+
+	private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
+		if (shouldRefresh(authorizedClient)) {
+			return refreshAuthorizedClient(request, next, authorizedClient);
+		}
+		return Mono.just(authorizedClient);
+	}
+
+	private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ClientRequest request, ExchangeFunction next,
+			OAuth2AuthorizedClient authorizedClient) {
+		ClientRegistration clientRegistration = authorizedClient
+				.getClientRegistration();
+		String tokenUri = clientRegistration
+				.getProviderDetails().getTokenUri();
+		ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri))
+				.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
+				.headers(httpBasic(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
+				.body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue()))
+				.build();
+		return next.exchange(refreshRequest)
+				.flatMap(response -> response.body(oauth2AccessTokenResponse()))
+				.map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()))
+				.map(result -> {
+					Authentication principal = (Authentication) request.attribute(
+							AUTHENTICATION_ATTR_NAME).orElse(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()));
+					HttpServletRequest httpRequest = (HttpServletRequest) request.attributes().get(
+							HTTP_SERVLET_REQUEST_ATTR_NAME);
+					HttpServletResponse httpResponse = (HttpServletResponse) request.attributes().get(
+							HTTP_SERVLET_RESPONSE_ATTR_NAME);
+					this.authorizedClientRepository.saveAuthorizedClient(result, principal, httpRequest, httpResponse);
+					return result;
+				})
+				.publishOn(Schedulers.elastic());
+	}
+
+	private static Consumer<HttpHeaders> httpBasic(String username, String password) {
+		return httpHeaders -> {
+			String credentialsString = username + ":" + password;
+			byte[] credentialBytes = credentialsString.getBytes(StandardCharsets.ISO_8859_1);
+			byte[] encodedBytes = Base64.getEncoder().encode(credentialBytes);
+			String encodedCredentials = new String(encodedBytes, StandardCharsets.ISO_8859_1);
+			httpHeaders.set(HttpHeaders.AUTHORIZATION, "Basic " + encodedCredentials);
+		};
+	}
+
+	private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
+		if (this.authorizedClientRepository == null) {
+			return false;
+		}
+		OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
+		if (refreshToken == null) {
+			return false;
+		}
+		Instant now = this.clock.instant();
+		Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
+		if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {
+			return true;
+		}
+		return false;
+	}
+
+	private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
+		return ClientRequest.from(request)
+					.headers(bearerToken(authorizedClient.getAccessToken().getTokenValue()))
+					.build();
+	}
+
+	private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
+		return BodyInserters
+				.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
+				.with("refresh_token", refreshToken);
+	}
+
+	static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map<String, Object> attrs) {
+		return (OAuth2AuthorizedClient) attrs.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
+	}
+
+	static String getClientRegistrationId(Map<String, Object> attrs) {
+		return (String) attrs.get(CLIENT_REGISTRATION_ID_ATTR_NAME);
+	}
+
+	static Authentication getAuthentication(Map<String, Object> attrs) {
+		return (Authentication) attrs.get(AUTHENTICATION_ATTR_NAME);
+	}
+
+	static HttpServletRequest getRequest(Map<String, Object> attrs) {
+		return (HttpServletRequest) attrs.get(HTTP_SERVLET_REQUEST_ATTR_NAME);
+	}
+
+	static HttpServletResponse getResponse(Map<String, Object> attrs) {
+		return (HttpServletResponse) attrs.get(HTTP_SERVLET_RESPONSE_ATTR_NAME);
+	}
+
+	private static class PrincipalNameAuthentication implements Authentication {
+		private final String username;
+
+		private PrincipalNameAuthentication(String username) {
+			this.username = username;
+		}
+
+		@Override
+		public Collection<? extends GrantedAuthority> getAuthorities() {
+			throw unsupported();
+		}
+
+		@Override
+		public Object getCredentials() {
+			throw unsupported();
+		}
+
+		@Override
+		public Object getDetails() {
+			throw unsupported();
+		}
+
+		@Override
+		public Object getPrincipal() {
+			throw unsupported();
+		}
+
+		@Override
+		public boolean isAuthenticated() {
+			throw unsupported();
+		}
+
+		@Override
+		public void setAuthenticated(boolean isAuthenticated)
+				throws IllegalArgumentException {
+			throw unsupported();
+		}
+
+		@Override
+		public String getName() {
+			return this.username;
+		}
+
+		private UnsupportedOperationException unsupported() {
+			return new UnsupportedOperationException("Not Supported");
+		}
+	}
+}

+ 478 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -0,0 +1,478 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.reactive.function.client;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.springframework.core.codec.ByteBufferEncoder;
+import org.springframework.core.codec.CharSequenceEncoder;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.codec.EncoderHttpMessageWriter;
+import org.springframework.http.codec.FormHttpMessageWriter;
+import org.springframework.http.codec.HttpMessageWriter;
+import org.springframework.http.codec.ResourceHttpMessageWriter;
+import org.springframework.http.codec.ServerSentEventHttpMessageWriter;
+import org.springframework.http.codec.json.Jackson2JsonEncoder;
+import org.springframework.http.codec.multipart.MultipartHttpMessageWriter;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.mock.http.client.reactive.MockClientHttpRequest;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+import org.springframework.web.context.request.RequestContextHolder;
+import org.springframework.web.context.request.ServletRequestAttributes;
+import org.springframework.web.reactive.function.BodyInserter;
+import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.WebClient;
+import reactor.core.publisher.Mono;
+
+import java.net.URI;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Consumer;
+
+import static org.assertj.core.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.mockito.Mockito.when;
+import static org.springframework.http.HttpMethod.GET;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
+
+/**
+ * @author Rob Winch
+ * @since 5.1
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
+	@Mock
+	private OAuth2AuthorizedClientRepository authorizedClientRepository;
+	@Mock
+	private WebClient.RequestHeadersSpec<?> spec;
+	@Captor
+	private ArgumentCaptor<Consumer<Map<String, Object>>> attrs;
+
+	/**
+	 * Used for get the attributes from defaultRequest.
+	 */
+	private Map<String, Object> result = new HashMap<>();
+
+	private ServletOAuth2AuthorizedClientExchangeFilterFunction function = new ServletOAuth2AuthorizedClientExchangeFilterFunction();
+
+	private MockExchangeFunction exchange = new MockExchangeFunction();
+
+	private Authentication authentication;
+
+	private ClientRegistration github = ClientRegistration.withRegistrationId("github")
+			.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
+			.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.scope("read:user")
+			.authorizationUri("https://github.com/login/oauth/authorize")
+			.tokenUri("https://github.com/login/oauth/access_token")
+			.userInfoUri("https://api.github.com/user")
+			.userNameAttributeName("id")
+			.clientName("GitHub")
+			.clientId("clientId")
+			.clientSecret("clientSecret")
+			.build();
+
+	private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+			"token-0",
+			Instant.now(),
+			Instant.now().plus(Duration.ofDays(1)));
+
+	@Before
+	public void setup() {
+		this.authentication = new TestingAuthenticationToken("test", "this");
+	}
+
+	@After
+	public void cleanup() {
+		SecurityContextHolder.clearContext();
+		RequestContextHolder.resetRequestAttributes();
+	}
+
+	@Test
+	public void defaultRequestRequestResponseWhenNullRequestContextThenRequestAndResponseNull() {
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+		assertThat(getRequest(attrs)).isNull();
+		assertThat(getResponse(attrs)).isNull();
+	}
+
+	@Test
+	public void defaultRequestRequestResponseWhenRequestContextThenRequestAndResponseSet() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response));
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+		assertThat(getRequest(attrs)).isEqualTo(request);
+		assertThat(getResponse(attrs)).isEqualTo(response);
+	}
+
+	@Test
+	public void defaultRequestAuthenticationWhenSecurityContextEmptyThenAuthenticationNull() {
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+		assertThat(getAuthentication(attrs)).isNull();
+	}
+
+	@Test
+	public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationSet() {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		SecurityContextHolder.getContext().setAuthentication(this.authentication);
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+		assertThat(getAuthentication(attrs)).isEqualTo(this.authentication);
+		verifyZeroInteractions(this.authorizedClientRepository);
+	}
+
+	@Test
+	public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride() {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken);
+		oauth2AuthorizedClient(authorizedClient).accept(this.result);
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+		assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient);
+		verifyZeroInteractions(this.authorizedClientRepository);
+	}
+
+	@Test
+	public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+		assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
+		verifyZeroInteractions(this.authorizedClientRepository);
+	}
+
+	@Test
+	public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+		assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
+		verifyZeroInteractions(this.authorizedClientRepository);
+	}
+
+	@Test
+	public void defaultRequestOAuth2AuthorizedClientWhenRepositoryNullThenOAuth2AuthorizedClient() {
+		OAuth2User user = mock(OAuth2User.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
+		authentication(token).accept(this.result);
+
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+
+		assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
+	}
+
+	@Test
+	public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		OAuth2User user = mock(OAuth2User.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken);
+		when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient);
+		authentication(token).accept(this.result);
+
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+
+		assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient);
+		verify(this.authorizedClientRepository).loadAuthorizedClient(eq(token.getAuthorizedClientRegistrationId()), any(), any());
+	}
+
+	@Test
+	public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		OAuth2User user = mock(OAuth2User.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken);
+		when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient);
+		authentication(token).accept(this.result);
+		clientRegistrationId("explicit").accept(this.result);
+
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+
+		assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient);
+		verify(this.authorizedClientRepository).loadAuthorizedClient(eq("explicit"), any(), any());
+	}
+
+	@Test
+	public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		OAuth2User user = mock(OAuth2User.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken);
+		when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient);
+		clientRegistrationId("id").accept(this.result);
+
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+
+		assertThat(getOAuth2AuthorizedClient(attrs)).isEqualTo(authorizedClient);
+		verify(this.authorizedClientRepository).loadAuthorizedClient(eq("id"), any(), any());
+	}
+
+	private Map<String, Object> getDefaultRequestAttributes() {
+		this.function.defaultRequest().accept(this.spec);
+		verify(this.spec).attributes(this.attrs.capture());
+
+		this.attrs.getValue().accept(this.result);
+
+		return this.result;
+	}
+
+	@Test
+	public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
+	}
+
+	@Test
+	public void filterWhenAuthorizedClientThenAuthorizationHeader() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer " + this.accessToken.getTokenValue());
+	}
+
+	@Test
+	public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.header(HttpHeaders.AUTHORIZATION, "Existing")
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		HttpHeaders headers = this.exchange.getRequest().headers();
+		assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
+	}
+
+	@Test
+	public void filterWhenRefreshRequiredThenRefresh() {
+		OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
+				.tokenType(OAuth2AccessToken.TokenType.BEARER)
+				.expiresIn(3600)
+				.refreshToken("refresh-1")
+				.build();
+		when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response));
+		Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
+		Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
+		Instant refreshTokenExpiresAt = Instant.now().plus(Duration.ofHours(1));
+
+		this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
+				this.accessToken.getTokenValue(),
+				issuedAt,
+				accessTokenExpiresAt);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken, refreshToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.attributes(authentication(this.authentication))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any());
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(2);
+
+		ClientRequest request0 = requests.get(0);
+		assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50SWQ6Y2xpZW50U2VjcmV0");
+		assertThat(request0.url().toASCIIString()).isEqualTo("https://github.com/login/oauth/access_token");
+		assertThat(request0.method()).isEqualTo(HttpMethod.POST);
+		assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token");
+
+		ClientRequest request1 = requests.get(1);
+		assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1");
+		assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request1.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request1)).isEmpty();
+	}
+
+	@Test
+	public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
+		OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
+				.tokenType(OAuth2AccessToken.TokenType.BEARER)
+				.expiresIn(3600)
+				.refreshToken("refresh-1")
+				.build();
+		when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response));
+		Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
+		Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
+		Instant refreshTokenExpiresAt = Instant.now().plus(Duration.ofHours(1));
+
+		this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
+				this.accessToken.getTokenValue(),
+				issuedAt,
+				accessTokenExpiresAt);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken, refreshToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		this.function.filter(request, this.exchange)
+				.block();
+
+		verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any(), any());
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(2);
+
+		ClientRequest request0 = requests.get(0);
+		assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50SWQ6Y2xpZW50U2VjcmV0");
+		assertThat(request0.url().toASCIIString()).isEqualTo("https://github.com/login/oauth/access_token");
+		assertThat(request0.method()).isEqualTo(HttpMethod.POST);
+		assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token");
+
+		ClientRequest request1 = requests.get(1);
+		assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1");
+		assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request1.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request1)).isEmpty();
+	}
+
+	@Test
+	public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+
+		ClientRequest request0 = requests.get(0);
+		assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
+		assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request0.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request0)).isEmpty();
+	}
+
+	@Test
+	public void filterWhenNotExpiredThenShouldRefreshFalse() {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.github,
+				"principalName", this.accessToken, refreshToken);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+
+		ClientRequest request0 = requests.get(0);
+		assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
+		assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request0.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request0)).isEmpty();
+	}
+
+	private static String getBody(ClientRequest request) {
+		final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
+		messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));
+		messageWriters.add(new EncoderHttpMessageWriter<>(CharSequenceEncoder.textPlainOnly()));
+		messageWriters.add(new ResourceHttpMessageWriter());
+		Jackson2JsonEncoder jsonEncoder = new Jackson2JsonEncoder();
+		messageWriters.add(new EncoderHttpMessageWriter<>(jsonEncoder));
+		messageWriters.add(new ServerSentEventHttpMessageWriter(jsonEncoder));
+		messageWriters.add(new FormHttpMessageWriter());
+		messageWriters.add(new EncoderHttpMessageWriter<>(CharSequenceEncoder.allMimeTypes()));
+		messageWriters.add(new MultipartHttpMessageWriter(messageWriters));
+
+		BodyInserter.Context context = new BodyInserter.Context() {
+			@Override
+			public List<HttpMessageWriter<?>> messageWriters() {
+				return messageWriters;
+			}
+
+			@Override
+			public Optional<ServerHttpRequest> serverRequest() {
+				return Optional.empty();
+			}
+
+			@Override
+			public Map<String, Object> hints() {
+				return new HashMap<>();
+			}
+		};
+
+		MockClientHttpRequest body = new MockClientHttpRequest(HttpMethod.GET, "/");
+		request.body().insert(body, context).block();
+		return body.getBodyAsString().block();
+	}
+
+}