|
@@ -1,5 +1,5 @@
|
|
|
/*
|
|
|
- * Copyright 2002-2020 the original author or authors.
|
|
|
+ * Copyright 2002-2021 the original author or authors.
|
|
|
*
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -16,10 +16,14 @@
|
|
|
|
|
|
package org.springframework.security.config.annotation.web.configuration;
|
|
|
|
|
|
+import java.util.Collection;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.Map;
|
|
|
+import java.util.Set;
|
|
|
+import java.util.concurrent.ConcurrentHashMap;
|
|
|
import java.util.function.Function;
|
|
|
+import java.util.function.Supplier;
|
|
|
|
|
|
import javax.servlet.http.HttpServletRequest;
|
|
|
import javax.servlet.http.HttpServletResponse;
|
|
@@ -36,7 +40,6 @@ import org.springframework.beans.factory.InitializingBean;
|
|
|
import org.springframework.context.annotation.Bean;
|
|
|
import org.springframework.context.annotation.Configuration;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
-import org.springframework.security.core.context.SecurityContext;
|
|
|
import org.springframework.security.core.context.SecurityContextHolder;
|
|
|
import org.springframework.web.context.request.RequestAttributes;
|
|
|
import org.springframework.web.context.request.RequestContextHolder;
|
|
@@ -68,17 +71,22 @@ class SecurityReactorContextConfiguration {
|
|
|
|
|
|
private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR";
|
|
|
|
|
|
+ private static final Map<Object, Supplier<Object>> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>();
|
|
|
+
|
|
|
+ static {
|
|
|
+ CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class,
|
|
|
+ SecurityReactorContextSubscriberRegistrar::getRequest);
|
|
|
+ CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class,
|
|
|
+ SecurityReactorContextSubscriberRegistrar::getResponse);
|
|
|
+ CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class,
|
|
|
+ SecurityReactorContextSubscriberRegistrar::getAuthentication);
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public void afterPropertiesSet() throws Exception {
|
|
|
Function<? super Publisher<Object>, ? extends Publisher<Object>> lifter = Operators
|
|
|
.liftPublisher((pub, sub) -> createSubscriberIfNecessary(sub));
|
|
|
- Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, (pub) -> {
|
|
|
- if (!contextAttributesAvailable()) {
|
|
|
- // No need to decorate so return original Publisher
|
|
|
- return pub;
|
|
|
- }
|
|
|
- return lifter.apply(pub);
|
|
|
- });
|
|
|
+ Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, lifter::apply);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -94,45 +102,30 @@ class SecurityReactorContextConfiguration {
|
|
|
return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes());
|
|
|
}
|
|
|
|
|
|
- private static boolean contextAttributesAvailable() {
|
|
|
- SecurityContext context = SecurityContextHolder.peekContext();
|
|
|
- Authentication authentication = null;
|
|
|
- if (context != null) {
|
|
|
- authentication = context.getAuthentication();
|
|
|
- }
|
|
|
- return authentication != null
|
|
|
- || RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes;
|
|
|
+ private static Map<Object, Object> getContextAttributes() {
|
|
|
+ return new LoadingMap<>(CONTEXT_ATTRIBUTE_VALUE_LOADERS);
|
|
|
}
|
|
|
|
|
|
- private static Map<Object, Object> getContextAttributes() {
|
|
|
- HttpServletRequest servletRequest = null;
|
|
|
- HttpServletResponse servletResponse = null;
|
|
|
+ private static HttpServletRequest getRequest() {
|
|
|
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
|
|
|
if (requestAttributes instanceof ServletRequestAttributes) {
|
|
|
ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
|
|
|
- servletRequest = servletRequestAttributes.getRequest();
|
|
|
- servletResponse = servletRequestAttributes.getResponse(); // possible null
|
|
|
- }
|
|
|
- SecurityContext context = SecurityContextHolder.peekContext();
|
|
|
- Authentication authentication = null;
|
|
|
- if (context != null) {
|
|
|
- authentication = context.getAuthentication();
|
|
|
- }
|
|
|
- if (authentication == null && servletRequest == null) {
|
|
|
- return Collections.emptyMap();
|
|
|
- }
|
|
|
- Map<Object, Object> contextAttributes = new HashMap<>();
|
|
|
- if (servletRequest != null) {
|
|
|
- contextAttributes.put(HttpServletRequest.class, servletRequest);
|
|
|
- }
|
|
|
- if (servletResponse != null) {
|
|
|
- contextAttributes.put(HttpServletResponse.class, servletResponse);
|
|
|
+ return servletRequestAttributes.getRequest();
|
|
|
}
|
|
|
- if (authentication != null) {
|
|
|
- contextAttributes.put(Authentication.class, authentication);
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ private static HttpServletResponse getResponse() {
|
|
|
+ RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
|
|
|
+ if (requestAttributes instanceof ServletRequestAttributes) {
|
|
|
+ ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
|
|
|
+ return servletRequestAttributes.getResponse(); // possible null
|
|
|
}
|
|
|
+ return null;
|
|
|
+ }
|
|
|
|
|
|
- return contextAttributes;
|
|
|
+ private static Authentication getAuthentication() {
|
|
|
+ return SecurityContextHolder.getContext().getAuthentication();
|
|
|
}
|
|
|
|
|
|
}
|
|
@@ -185,4 +178,112 @@ class SecurityReactorContextConfiguration {
|
|
|
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * A map that computes each value when {@link #get} is invoked
|
|
|
+ */
|
|
|
+ static class LoadingMap<K, V> implements Map<K, V> {
|
|
|
+
|
|
|
+ private final Map<K, V> loaded = new ConcurrentHashMap<>();
|
|
|
+
|
|
|
+ private final Map<K, Supplier<V>> loaders;
|
|
|
+
|
|
|
+ LoadingMap(Map<K, Supplier<V>> loaders) {
|
|
|
+ this.loaders = Collections.unmodifiableMap(new HashMap<>(loaders));
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int size() {
|
|
|
+ return this.loaders.size();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean isEmpty() {
|
|
|
+ return this.loaders.isEmpty();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean containsKey(Object key) {
|
|
|
+ return this.loaders.containsKey(key);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Set<K> keySet() {
|
|
|
+ return this.loaders.keySet();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public V get(Object key) {
|
|
|
+ if (!this.loaders.containsKey(key)) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ "This map only supports the following keys: " + this.loaders.keySet());
|
|
|
+ }
|
|
|
+ return this.loaded.computeIfAbsent((K) key, (k) -> this.loaders.get(k).get());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public V put(K key, V value) {
|
|
|
+ if (!this.loaders.containsKey(key)) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ "This map only supports the following keys: " + this.loaders.keySet());
|
|
|
+ }
|
|
|
+ return this.loaded.put(key, value);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public V remove(Object key) {
|
|
|
+ if (!this.loaders.containsKey(key)) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ "This map only supports the following keys: " + this.loaders.keySet());
|
|
|
+ }
|
|
|
+ return this.loaded.remove(key);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void putAll(Map<? extends K, ? extends V> m) {
|
|
|
+ for (Map.Entry<? extends K, ? extends V> entry : m.entrySet()) {
|
|
|
+ put(entry.getKey(), entry.getValue());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void clear() {
|
|
|
+ this.loaded.clear();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean containsValue(Object value) {
|
|
|
+ return this.loaded.containsValue(value);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Collection<V> values() {
|
|
|
+ return this.loaded.values();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Set<Entry<K, V>> entrySet() {
|
|
|
+ return this.loaded.entrySet();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean equals(Object o) {
|
|
|
+ if (this == o) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (o == null || getClass() != o.getClass()) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ LoadingMap<?, ?> that = (LoadingMap<?, ?>) o;
|
|
|
+
|
|
|
+ return this.loaded.equals(that.loaded);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int hashCode() {
|
|
|
+ return this.loaded.hashCode();
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
}
|