|
@@ -22,6 +22,7 @@ import org.springframework.beans.factory.InitializingBean;
|
|
|
import org.springframework.http.HttpHeaders;
|
|
|
import org.springframework.http.HttpMethod;
|
|
|
import org.springframework.http.MediaType;
|
|
|
+import org.springframework.lang.Nullable;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
import org.springframework.security.core.GrantedAuthority;
|
|
|
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
|
@@ -103,6 +104,7 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
|
|
|
* </ul>
|
|
|
*
|
|
|
* @author Rob Winch
|
|
|
+ * @author Roman Matiushchenko
|
|
|
* @since 5.1
|
|
|
*/
|
|
|
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|
@@ -146,7 +148,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|
|
|
|
|
@Override
|
|
|
public void afterPropertiesSet() throws Exception {
|
|
|
- Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
|
|
|
+ Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.liftPublisher((s, sub) -> createRequestContextSubscriberIfNecessary(sub)));
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -319,14 +321,22 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|
|
}
|
|
|
|
|
|
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
|
|
|
- if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
|
|
|
- attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
|
|
|
- }
|
|
|
- if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
|
|
|
- attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
|
|
|
- }
|
|
|
- if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
|
|
|
- attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
|
|
|
+ RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx);
|
|
|
+ if (holder != null) {
|
|
|
+ HttpServletRequest request = holder.getRequest();
|
|
|
+ if (request != null) {
|
|
|
+ attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
|
|
|
+ }
|
|
|
+
|
|
|
+ HttpServletResponse response = holder.getResponse();
|
|
|
+ if (response != null) {
|
|
|
+ attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
|
|
|
+ }
|
|
|
+
|
|
|
+ Authentication authentication = holder.getAuthentication();
|
|
|
+ if (authentication != null) {
|
|
|
+ attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
|
|
|
+ }
|
|
|
}
|
|
|
populateDefaultOAuth2AuthorizedClient(attrs);
|
|
|
}
|
|
@@ -488,7 +498,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|
|
.build();
|
|
|
}
|
|
|
|
|
|
- private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
|
|
|
+ <T> CoreSubscriber<T> createRequestContextSubscriberIfNecessary(CoreSubscriber<T> delegate) {
|
|
|
HttpServletRequest request = null;
|
|
|
HttpServletResponse response = null;
|
|
|
ServletRequestAttributes requestAttributes =
|
|
@@ -498,6 +508,10 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|
|
response = requestAttributes.getResponse();
|
|
|
}
|
|
|
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
|
|
|
+ if (authentication == null && request == null && response == null) {
|
|
|
+ //do not need to create RequestContextSubscriber with empty data
|
|
|
+ return delegate;
|
|
|
+ }
|
|
|
return new RequestContextSubscriber<>(delegate, request, response, authentication);
|
|
|
}
|
|
|
|
|
@@ -575,34 +589,37 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
|
|
|
- private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
|
|
|
+ static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
|
|
|
+ static final String REQUEST_CONTEXT_DATA_HOLDER =
|
|
|
+ RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER");
|
|
|
private final CoreSubscriber<T> delegate;
|
|
|
- private final HttpServletRequest request;
|
|
|
- private final HttpServletResponse response;
|
|
|
- private final Authentication authentication;
|
|
|
+ private final Context context;
|
|
|
|
|
|
- private RequestContextSubscriber(CoreSubscriber<T> delegate,
|
|
|
- HttpServletRequest request,
|
|
|
- HttpServletResponse response,
|
|
|
- Authentication authentication) {
|
|
|
+ RequestContextSubscriber(CoreSubscriber<T> delegate,
|
|
|
+ HttpServletRequest request,
|
|
|
+ HttpServletResponse response,
|
|
|
+ Authentication authentication) {
|
|
|
this.delegate = delegate;
|
|
|
- this.request = request;
|
|
|
- this.response = response;
|
|
|
- this.authentication = authentication;
|
|
|
+
|
|
|
+ Context parentContext = this.delegate.currentContext();
|
|
|
+ Context context;
|
|
|
+ if (parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER)) {
|
|
|
+ context = parentContext;
|
|
|
+ } else {
|
|
|
+ context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER, new RequestContextDataHolder(request, response, authentication));
|
|
|
+ }
|
|
|
+
|
|
|
+ this.context = context;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Nullable
|
|
|
+ private static RequestContextDataHolder getRequestContext(Context ctx) {
|
|
|
+ return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER, null);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public Context currentContext() {
|
|
|
- Context context = this.delegate.currentContext();
|
|
|
- if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
|
|
|
- return context;
|
|
|
- }
|
|
|
- return Context.of(
|
|
|
- CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
|
|
|
- HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
|
|
|
- HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
|
|
|
- AUTHENTICATION_ATTR_NAME, this.authentication);
|
|
|
+ return this.context;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -625,4 +642,33 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
|
|
|
this.delegate.onComplete();
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ static class RequestContextDataHolder {
|
|
|
+ private final HttpServletRequest request;
|
|
|
+ private final HttpServletResponse response;
|
|
|
+ private final Authentication authentication;
|
|
|
+
|
|
|
+ RequestContextDataHolder(@Nullable HttpServletRequest request,
|
|
|
+ @Nullable HttpServletResponse response,
|
|
|
+ @Nullable Authentication authentication) {
|
|
|
+ this.request = request;
|
|
|
+ this.response = response;
|
|
|
+ this.authentication = authentication;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Nullable
|
|
|
+ private HttpServletRequest getRequest() {
|
|
|
+ return this.request;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Nullable
|
|
|
+ private HttpServletResponse getResponse() {
|
|
|
+ return this.response;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Nullable
|
|
|
+ private Authentication getAuthentication() {
|
|
|
+ return this.authentication;
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|