|
@@ -16,6 +16,9 @@
|
|
|
|
|
|
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
|
|
|
|
|
+import org.reactivestreams.Subscription;
|
|
|
+import org.springframework.beans.factory.DisposableBean;
|
|
|
+import org.springframework.beans.factory.InitializingBean;
|
|
|
import org.springframework.http.HttpHeaders;
|
|
|
import org.springframework.http.HttpMethod;
|
|
|
import org.springframework.http.MediaType;
|
|
@@ -44,8 +47,12 @@ import org.springframework.web.reactive.function.client.ClientResponse;
|
|
|
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
|
|
|
import org.springframework.web.reactive.function.client.ExchangeFunction;
|
|
|
import org.springframework.web.reactive.function.client.WebClient;
|
|
|
+import reactor.core.CoreSubscriber;
|
|
|
+import reactor.core.publisher.Hooks;
|
|
|
import reactor.core.publisher.Mono;
|
|
|
+import reactor.core.publisher.Operators;
|
|
|
import reactor.core.scheduler.Schedulers;
|
|
|
+import reactor.util.context.Context;
|
|
|
|
|
|
import javax.servlet.http.HttpServletRequest;
|
|
|
import javax.servlet.http.HttpServletResponse;
|
|
@@ -98,7 +105,9 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
|
|
|
* @author Rob Winch
|
|
|
* @since 5.1
|
|
|
*/
|
|
|
-public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
|
|
|
+public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|
|
+ implements ExchangeFilterFunction, InitializingBean, DisposableBean {
|
|
|
+
|
|
|
/**
|
|
|
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
|
|
|
*/
|
|
@@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|
|
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
|
|
|
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
|
|
|
|
|
|
+ private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
|
|
|
+
|
|
|
private Clock clock = Clock.systemUTC();
|
|
|
|
|
|
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
|
|
@@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|
|
|
|
|
private String defaultClientRegistrationId;
|
|
|
|
|
|
- public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
|
|
|
+ public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
|
|
|
+ }
|
|
|
|
|
|
public ServletOAuth2AuthorizedClientExchangeFilterFunction(
|
|
|
ClientRegistrationRepository clientRegistrationRepository,
|
|
@@ -132,6 +144,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|
|
this.authorizedClientRepository = authorizedClientRepository;
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
+ public void afterPropertiesSet() throws Exception {
|
|
|
+ Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void destroy() throws Exception {
|
|
|
+ Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
|
|
|
* client_credentials grant.
|
|
@@ -266,15 +288,36 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|
|
|
|
|
@Override
|
|
|
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
|
|
|
- Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
|
|
|
- .map(OAuth2AuthorizedClient.class::cast);
|
|
|
- return Mono.justOrEmpty(attribute)
|
|
|
- .flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient))
|
|
|
+ return Mono.just(request)
|
|
|
+ .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
|
|
|
+ .switchIfEmpty(mergeRequestAttributesFromContext(request))
|
|
|
+ .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
|
|
|
+ .flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes())))
|
|
|
.map(authorizedClient -> bearer(request, authorizedClient))
|
|
|
.flatMap(next::exchange)
|
|
|
.switchIfEmpty(next.exchange(request));
|
|
|
}
|
|
|
|
|
|
+ private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
|
|
|
+ return Mono.just(ClientRequest.from(request))
|
|
|
+ .flatMap(builder -> Mono.subscriberContext()
|
|
|
+ .map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))))
|
|
|
+ .map(ClientRequest.Builder::build);
|
|
|
+ }
|
|
|
+
|
|
|
+ private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
|
|
|
+ if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
|
|
|
+ attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
|
|
|
+ }
|
|
|
+ if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
|
|
|
+ attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
|
|
|
+ }
|
|
|
+ if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
|
|
|
+ attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
|
|
|
+ }
|
|
|
+ populateDefaultOAuth2AuthorizedClient(attrs);
|
|
|
+ }
|
|
|
+
|
|
|
private void populateDefaultRequestResponse(Map<String, Object> attrs) {
|
|
|
if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(
|
|
|
HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
|
|
@@ -435,6 +478,19 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|
|
.build();
|
|
|
}
|
|
|
|
|
|
+ private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
|
|
|
+ HttpServletRequest request = null;
|
|
|
+ HttpServletResponse response = null;
|
|
|
+ ServletRequestAttributes requestAttributes =
|
|
|
+ (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
|
|
+ if (requestAttributes != null) {
|
|
|
+ request = requestAttributes.getRequest();
|
|
|
+ response = requestAttributes.getResponse();
|
|
|
+ }
|
|
|
+ Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
|
|
|
+ return new RequestContextSubscriber<>(delegate, request, response, authentication);
|
|
|
+ }
|
|
|
+
|
|
|
private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
|
|
|
return BodyInserters
|
|
|
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
|
|
@@ -508,4 +564,55 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
|
|
|
return new UnsupportedOperationException("Not Supported");
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
|
|
|
+ private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
|
|
|
+ private final CoreSubscriber<T> delegate;
|
|
|
+ private final HttpServletRequest request;
|
|
|
+ private final HttpServletResponse response;
|
|
|
+ private final Authentication authentication;
|
|
|
+
|
|
|
+ private RequestContextSubscriber(CoreSubscriber<T> delegate,
|
|
|
+ HttpServletRequest request,
|
|
|
+ HttpServletResponse response,
|
|
|
+ Authentication authentication) {
|
|
|
+ this.delegate = delegate;
|
|
|
+ this.request = request;
|
|
|
+ this.response = response;
|
|
|
+ this.authentication = authentication;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Context currentContext() {
|
|
|
+ Context context = this.delegate.currentContext();
|
|
|
+ if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
|
|
|
+ return context;
|
|
|
+ }
|
|
|
+ return Context.of(
|
|
|
+ CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
|
|
|
+ HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
|
|
|
+ HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
|
|
|
+ AUTHENTICATION_ATTR_NAME, this.authentication);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onSubscribe(Subscription s) {
|
|
|
+ this.delegate.onSubscribe(s);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onNext(T t) {
|
|
|
+ this.delegate.onNext(t);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onError(Throwable t) {
|
|
|
+ this.delegate.onError(t);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onComplete() {
|
|
|
+ this.delegate.onComplete();
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|