Browse Source

Polish Memory Leak Mitigation

Issue gh-9841
Josh Cummings 3 years ago
parent
commit
a68411566e

+ 141 - 40
config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,10 +16,14 @@
 
 package org.springframework.security.config.annotation.web.configuration;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Function;
+import java.util.function.Supplier;
 
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
@@ -36,7 +40,6 @@ import org.springframework.beans.factory.InitializingBean;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.security.core.Authentication;
-import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.web.context.request.RequestAttributes;
 import org.springframework.web.context.request.RequestContextHolder;
@@ -68,17 +71,22 @@ class SecurityReactorContextConfiguration {
 
 		private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR";
 
+		private static final Map<Object, Supplier<Object>> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>();
+
+		static {
+			CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class,
+					SecurityReactorContextSubscriberRegistrar::getRequest);
+			CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class,
+					SecurityReactorContextSubscriberRegistrar::getResponse);
+			CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class,
+					SecurityReactorContextSubscriberRegistrar::getAuthentication);
+		}
+
 		@Override
 		public void afterPropertiesSet() throws Exception {
 			Function<? super Publisher<Object>, ? extends Publisher<Object>> lifter = Operators
 					.liftPublisher((pub, sub) -> createSubscriberIfNecessary(sub));
-			Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, (pub) -> {
-				if (!contextAttributesAvailable()) {
-					// No need to decorate so return original Publisher
-					return pub;
-				}
-				return lifter.apply(pub);
-			});
+			Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, lifter::apply);
 		}
 
 		@Override
@@ -94,45 +102,30 @@ class SecurityReactorContextConfiguration {
 			return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes());
 		}
 
-		private static boolean contextAttributesAvailable() {
-			SecurityContext context = SecurityContextHolder.peekContext();
-			Authentication authentication = null;
-			if (context != null) {
-				authentication = context.getAuthentication();
-			}
-			return authentication != null
-					|| RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes;
+		private static Map<Object, Object> getContextAttributes() {
+			return new LoadingMap<>(CONTEXT_ATTRIBUTE_VALUE_LOADERS);
 		}
 
-		private static Map<Object, Object> getContextAttributes() {
-			HttpServletRequest servletRequest = null;
-			HttpServletResponse servletResponse = null;
+		private static HttpServletRequest getRequest() {
 			RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
 			if (requestAttributes instanceof ServletRequestAttributes) {
 				ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
-				servletRequest = servletRequestAttributes.getRequest();
-				servletResponse = servletRequestAttributes.getResponse(); // possible null
-			}
-			SecurityContext context = SecurityContextHolder.peekContext();
-			Authentication authentication = null;
-			if (context != null) {
-				authentication = context.getAuthentication();
-			}
-			if (authentication == null && servletRequest == null) {
-				return Collections.emptyMap();
-			}
-			Map<Object, Object> contextAttributes = new HashMap<>();
-			if (servletRequest != null) {
-				contextAttributes.put(HttpServletRequest.class, servletRequest);
-			}
-			if (servletResponse != null) {
-				contextAttributes.put(HttpServletResponse.class, servletResponse);
+				return servletRequestAttributes.getRequest();
 			}
-			if (authentication != null) {
-				contextAttributes.put(Authentication.class, authentication);
+			return null;
+		}
+
+		private static HttpServletResponse getResponse() {
+			RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
+			if (requestAttributes instanceof ServletRequestAttributes) {
+				ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) requestAttributes;
+				return servletRequestAttributes.getResponse(); // possible null
 			}
+			return null;
+		}
 
-			return contextAttributes;
+		private static Authentication getAuthentication() {
+			return SecurityContextHolder.getContext().getAuthentication();
 		}
 
 	}
@@ -185,4 +178,112 @@ class SecurityReactorContextConfiguration {
 
 	}
 
+	/**
+	 * A map that computes each value when {@link #get} is invoked
+	 */
+	static class LoadingMap<K, V> implements Map<K, V> {
+
+		private final Map<K, V> loaded = new ConcurrentHashMap<>();
+
+		private final Map<K, Supplier<V>> loaders;
+
+		LoadingMap(Map<K, Supplier<V>> loaders) {
+			this.loaders = Collections.unmodifiableMap(new HashMap<>(loaders));
+		}
+
+		@Override
+		public int size() {
+			return this.loaders.size();
+		}
+
+		@Override
+		public boolean isEmpty() {
+			return this.loaders.isEmpty();
+		}
+
+		@Override
+		public boolean containsKey(Object key) {
+			return this.loaders.containsKey(key);
+		}
+
+		@Override
+		public Set<K> keySet() {
+			return this.loaders.keySet();
+		}
+
+		@Override
+		public V get(Object key) {
+			if (!this.loaders.containsKey(key)) {
+				throw new IllegalArgumentException(
+						"This map only supports the following keys: " + this.loaders.keySet());
+			}
+			return this.loaded.computeIfAbsent((K) key, (k) -> this.loaders.get(k).get());
+		}
+
+		@Override
+		public V put(K key, V value) {
+			if (!this.loaders.containsKey(key)) {
+				throw new IllegalArgumentException(
+						"This map only supports the following keys: " + this.loaders.keySet());
+			}
+			return this.loaded.put(key, value);
+		}
+
+		@Override
+		public V remove(Object key) {
+			if (!this.loaders.containsKey(key)) {
+				throw new IllegalArgumentException(
+						"This map only supports the following keys: " + this.loaders.keySet());
+			}
+			return this.loaded.remove(key);
+		}
+
+		@Override
+		public void putAll(Map<? extends K, ? extends V> m) {
+			for (Map.Entry<? extends K, ? extends V> entry : m.entrySet()) {
+				put(entry.getKey(), entry.getValue());
+			}
+		}
+
+		@Override
+		public void clear() {
+			this.loaded.clear();
+		}
+
+		@Override
+		public boolean containsValue(Object value) {
+			return this.loaded.containsValue(value);
+		}
+
+		@Override
+		public Collection<V> values() {
+			return this.loaded.values();
+		}
+
+		@Override
+		public Set<Entry<K, V>> entrySet() {
+			return this.loaded.entrySet();
+		}
+
+		@Override
+		public boolean equals(Object o) {
+			if (this == o) {
+				return true;
+			}
+			if (o == null || getClass() != o.getClass()) {
+				return false;
+			}
+
+			LoadingMap<?, ?> that = (LoadingMap<?, ?>) o;
+
+			return this.loaded.equals(that.loaded);
+		}
+
+		@Override
+		public int hashCode() {
+			return this.loaded.hashCode();
+		}
+
+	}
+
 }

+ 0 - 5
core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java

@@ -44,11 +44,6 @@ final class GlobalSecurityContextHolderStrategy implements SecurityContextHolder
 		return contextHolder;
 	}
 
-	@Override
-	public SecurityContext peekContext() {
-		return contextHolder;
-	}
-
 	@Override
 	public void setContext(SecurityContext context) {
 		Assert.notNull(context, "Only non-null SecurityContext instances are permitted");

+ 0 - 5
core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java

@@ -44,11 +44,6 @@ final class InheritableThreadLocalSecurityContextHolderStrategy implements Secur
 		return ctx;
 	}
 
-	@Override
-	public SecurityContext peekContext() {
-		return contextHolder.get();
-	}
-
 	@Override
 	public void setContext(SecurityContext context) {
 		Assert.notNull(context, "Only non-null SecurityContext instances are permitted");

+ 0 - 8
core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java

@@ -123,14 +123,6 @@ public class SecurityContextHolder {
 		return strategy.getContext();
 	}
 
-	/**
-	 * Peeks the current <code>SecurityContext</code>.
-	 * @return the security context (may be <code>null</code>)
-	 */
-	public static SecurityContext peekContext() {
-		return strategy.peekContext();
-	}
-
 	/**
 	 * Primarily for troubleshooting purposes, this method shows how many times the class
 	 * has re-initialized its <code>SecurityContextHolderStrategy</code>.

+ 0 - 6
core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java

@@ -38,12 +38,6 @@ public interface SecurityContextHolderStrategy {
 	 */
 	SecurityContext getContext();
 
-	/**
-	 * Peeks the current context without creating an empty context.
-	 * @return a context (may be <code>null</code>)
-	 */
-	SecurityContext peekContext();
-
 	/**
 	 * Sets the current context.
 	 * @param context to the new argument (should never be <code>null</code>, although

+ 0 - 5
core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java

@@ -45,11 +45,6 @@ final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextH
 		return ctx;
 	}
 
-	@Override
-	public SecurityContext peekContext() {
-		return contextHolder.get();
-	}
-
 	@Override
 	public void setContext(SecurityContext context) {
 		Assert.notNull(context, "Only non-null SecurityContext instances are permitted");