Quellcode durchsuchen

Polish spring-security-messaging main code

Manually polish `spring-security-messaging` following the formatting
and checkstyle fixes.

Issue gh-8945
Phillip Webb vor 5 Jahren
Ursprung
Commit
ad1dbf425f
16 geänderte Dateien mit 60 neuen und 103 gelöschten Zeilen
  1. 1 2
      messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java
  2. 3 5
      messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java
  3. 1 1
      messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java
  4. 3 7
      messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java
  5. 0 2
      messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java
  6. 1 7
      messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java
  7. 18 23
      messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java
  8. 1 8
      messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java
  9. 1 8
      messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java
  10. 8 4
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java
  11. 4 5
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java
  12. 9 7
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java
  13. 4 5
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java
  14. 3 10
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java
  15. 0 2
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcher.java
  16. 3 7
      messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java

+ 1 - 2
messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java

@@ -19,8 +19,7 @@ package org.springframework.security.messaging.access.expression;
 import org.springframework.expression.EvaluationContext;
 
 /**
- *
- * /** Allows post processing the {@link EvaluationContext}
+ * Allows post processing the {@link EvaluationContext}
  *
  * <p>
  * This API is intentionally kept package scope as it may evolve over time.

+ 3 - 5
messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java

@@ -38,6 +38,9 @@ import org.springframework.security.messaging.util.matcher.MessageMatcher;
  */
 public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
 
+	private ExpressionBasedMessageSecurityMetadataSourceFactory() {
+	}
+
 	/**
 	 * Create a {@link MessageSecurityMetadataSource} that uses {@link MessageMatcher}
 	 * mapped to Spring Expressions. Each entry is considered in order and only the first
@@ -108,9 +111,7 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
 	public static MessageSecurityMetadataSource createExpressionMessageMetadataSource(
 			LinkedHashMap<MessageMatcher<?>, String> matcherToExpression,
 			SecurityExpressionHandler<Message<Object>> handler) {
-
 		LinkedHashMap<MessageMatcher<?>, Collection<ConfigAttribute>> matcherToAttrs = new LinkedHashMap<>();
-
 		for (Map.Entry<MessageMatcher<?>, String> entry : matcherToExpression.entrySet()) {
 			MessageMatcher<?> matcher = entry.getKey();
 			String rawExpression = entry.getValue();
@@ -121,7 +122,4 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
 		return new DefaultMessageSecurityMetadataSource(matcherToAttrs);
 	}
 
-	private ExpressionBasedMessageSecurityMetadataSourceFactory() {
-	}
-
 }

+ 1 - 1
messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java

@@ -69,7 +69,7 @@ class MessageExpressionConfigAttribute implements ConfigAttribute, EvaluationCon
 	@Override
 	public EvaluationContext postProcess(EvaluationContext ctx, Message<?> message) {
 		if (this.matcher instanceof SimpDestinationMessageMatcher) {
-			final Map<String, String> variables = ((SimpDestinationMessageMatcher) this.matcher)
+			Map<String, String> variables = ((SimpDestinationMessageMatcher) this.matcher)
 					.extractPathVariables(message);
 			for (Map.Entry<String, String> entry : variables.entrySet()) {
 				ctx.setVariable(entry.getKey(), entry.getValue());

+ 3 - 7
messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java

@@ -44,19 +44,15 @@ public class MessageExpressionVoter<T> implements AccessDecisionVoter<Message<T>
 
 	@Override
 	public int vote(Authentication authentication, Message<T> message, Collection<ConfigAttribute> attributes) {
-		assert authentication != null;
-		assert message != null;
-		assert attributes != null;
-
+		Assert.notNull(authentication, "authentication must not be null");
+		Assert.notNull(message, "message must not be null");
+		Assert.notNull(attributes, "attributes must not be null");
 		MessageExpressionConfigAttribute attr = findConfigAttribute(attributes);
-
 		if (attr == null) {
 			return ACCESS_ABSTAIN;
 		}
-
 		EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, message);
 		ctx = attr.postProcess(ctx, message);
-
 		return ExpressionUtils.evaluateAsBoolean(attr.getAuthorizeExpression(), ctx) ? ACCESS_GRANTED : ACCESS_DENIED;
 	}
 

+ 0 - 2
messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java

@@ -65,11 +65,9 @@ public final class DefaultMessageSecurityMetadataSource implements MessageSecuri
 	@Override
 	public Collection<ConfigAttribute> getAllConfigAttributes() {
 		Set<ConfigAttribute> allAttributes = new HashSet<>();
-
 		for (Collection<ConfigAttribute> entry : this.messageMap.values()) {
 			allAttributes.addAll(entry);
 		}
-
 		return allAttributes;
 	}
 

+ 1 - 7
messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java

@@ -98,26 +98,20 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
 			return null;
 		}
 		Object principal = authentication.getPrincipal();
-
 		AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter);
-
 		String expressionToParse = authPrincipal.expression();
 		if (StringUtils.hasLength(expressionToParse)) {
 			StandardEvaluationContext context = new StandardEvaluationContext();
 			context.setRootObject(principal);
 			context.setVariable("this", principal);
-
 			Expression expression = this.parser.parseExpression(expressionToParse);
 			principal = expression.getValue(context);
 		}
-
 		if (principal != null && !parameter.getParameterType().isAssignableFrom(principal.getClass())) {
 			if (authPrincipal.errorOnInvalidType()) {
 				throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType());
 			}
-			else {
-				return null;
-			}
+			return null;
 		}
 		return principal;
 	}

+ 18 - 23
messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java

@@ -43,9 +43,9 @@ import org.springframework.util.Assert;
 public final class SecurityContextChannelInterceptor extends ChannelInterceptorAdapter
 		implements ExecutorChannelInterceptor {
 
-	private final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
+	private static final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
 
-	private static final ThreadLocal<Stack<SecurityContext>> ORIGINAL_CONTEXT = new ThreadLocal<>();
+	private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>();
 
 	private final String authenticationHeaderName;
 
@@ -110,46 +110,41 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA
 
 	private void setup(Message<?> message) {
 		SecurityContext currentContext = SecurityContextHolder.getContext();
-
-		Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get();
+		Stack<SecurityContext> contextStack = originalContext.get();
 		if (contextStack == null) {
 			contextStack = new Stack<>();
-			ORIGINAL_CONTEXT.set(contextStack);
+			originalContext.set(contextStack);
 		}
 		contextStack.push(currentContext);
-
 		Object user = message.getHeaders().get(this.authenticationHeaderName);
-
-		Authentication authentication;
-		if ((user instanceof Authentication)) {
-			authentication = (Authentication) user;
-		}
-		else {
-			authentication = this.anonymous;
-		}
+		Authentication authentication = getAuthentication(user);
 		SecurityContext context = SecurityContextHolder.createEmptyContext();
 		context.setAuthentication(authentication);
 		SecurityContextHolder.setContext(context);
 	}
 
-	private void cleanup() {
-		Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get();
+	private Authentication getAuthentication(Object user) {
+		if ((user instanceof Authentication)) {
+			return (Authentication) user;
+		}
+		return this.anonymous;
+	}
 
+	private void cleanup() {
+		Stack<SecurityContext> contextStack = originalContext.get();
 		if (contextStack == null || contextStack.isEmpty()) {
 			SecurityContextHolder.clearContext();
-			ORIGINAL_CONTEXT.remove();
+			originalContext.remove();
 			return;
 		}
-
-		SecurityContext originalContext = contextStack.pop();
-
+		SecurityContext context = contextStack.pop();
 		try {
-			if (this.EMPTY_CONTEXT.equals(originalContext)) {
+			if (SecurityContextChannelInterceptor.EMPTY_CONTEXT.equals(context)) {
 				SecurityContextHolder.clearContext();
-				ORIGINAL_CONTEXT.remove();
+				originalContext.remove();
 			}
 			else {
-				SecurityContextHolder.setContext(originalContext);
+				SecurityContextHolder.setContext(context);
 			}
 		}
 		catch (Throwable ex) {

+ 1 - 8
messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java

@@ -134,28 +134,21 @@ public class AuthenticationPrincipalArgumentResolver implements HandlerMethodArg
 
 	private Object resolvePrincipal(MethodParameter parameter, Object principal) {
 		AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter);
-
 		String expressionToParse = authPrincipal.expression();
 		if (StringUtils.hasLength(expressionToParse)) {
 			StandardEvaluationContext context = new StandardEvaluationContext();
 			context.setRootObject(principal);
 			context.setVariable("this", principal);
 			context.setBeanResolver(this.beanResolver);
-
 			Expression expression = this.parser.parseExpression(expressionToParse);
 			principal = expression.getValue(context);
 		}
-
 		if (isInvalidType(parameter, principal)) {
-
 			if (authPrincipal.errorOnInvalidType()) {
 				throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType());
 			}
-			else {
-				return null;
-			}
+			return null;
 		}
-
 		return principal;
 	}
 

+ 1 - 8
messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java

@@ -133,28 +133,21 @@ public class CurrentSecurityContextArgumentResolver implements HandlerMethodArgu
 
 	private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) {
 		CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter);
-
 		String expressionToParse = contextAnno.expression();
 		if (StringUtils.hasLength(expressionToParse)) {
 			StandardEvaluationContext context = new StandardEvaluationContext();
 			context.setRootObject(securityContext);
 			context.setVariable("this", securityContext);
 			context.setBeanResolver(this.beanResolver);
-
 			Expression expression = this.parser.parseExpression(expressionToParse);
 			securityContext = expression.getValue(context);
 		}
-
 		if (isInvalidType(parameter, securityContext)) {
-
 			if (contextAnno.errorOnInvalidType()) {
 				throw new ClassCastException(securityContext + " is not assignable to " + parameter.getParameterType());
 			}
-			else {
-				return null;
-			}
+			return null;
 		}
-
 		return securityContext;
 	}
 

+ 8 - 4
messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java

@@ -31,7 +31,13 @@ import org.springframework.util.Assert;
  */
 public abstract class AbstractMessageMatcherComposite<T> implements MessageMatcher<T> {
 
-	protected final Log LOGGER = LogFactory.getLog(getClass());
+	protected final Log logger = LogFactory.getLog(getClass());
+
+	/**
+	 * @deprecated since 5.4 in favor of {@link #logger}
+	 */
+	@Deprecated
+	protected final Log LOGGER = this.logger;
 
 	private final List<MessageMatcher<T>> messageMatchers;
 
@@ -41,9 +47,7 @@ public abstract class AbstractMessageMatcherComposite<T> implements MessageMatch
 	 */
 	AbstractMessageMatcherComposite(List<MessageMatcher<T>> messageMatchers) {
 		Assert.notEmpty(messageMatchers, "messageMatchers must contain a value");
-		if (messageMatchers.contains(null)) {
-			throw new IllegalArgumentException("messageMatchers cannot contain null values");
-		}
+		Assert.isTrue(!messageMatchers.contains(null), "messageMatchers cannot contain null values");
 		this.messageMatchers = messageMatchers;
 
 	}

+ 4 - 5
messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java

@@ -18,6 +18,7 @@ package org.springframework.security.messaging.util.matcher;
 
 import java.util.List;
 
+import org.springframework.core.log.LogMessage;
 import org.springframework.messaging.Message;
 
 /**
@@ -49,15 +50,13 @@ public final class AndMessageMatcher<T> extends AbstractMessageMatcherComposite<
 	@Override
 	public boolean matches(Message<? extends T> message) {
 		for (MessageMatcher<T> matcher : getMessageMatchers()) {
-			if (this.LOGGER.isDebugEnabled()) {
-				this.LOGGER.debug("Trying to match using " + matcher);
-			}
+			this.logger.debug(LogMessage.format("Trying to match using %s", matcher));
 			if (!matcher.matches(message)) {
-				this.LOGGER.debug("Did not match");
+				this.logger.debug("Did not match");
 				return false;
 			}
 		}
-		this.LOGGER.debug("All messageMatchers returned true");
+		this.logger.debug("All messageMatchers returned true");
 		return true;
 	}
 

+ 9 - 7
messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java

@@ -26,17 +26,11 @@ import org.springframework.messaging.Message;
  */
 public interface MessageMatcher<T> {
 
-	/**
-	 * Returns true if the {@link Message} matches, else false
-	 * @param message the {@link Message} to match on
-	 * @return true if the {@link Message} matches, else false
-	 */
-	boolean matches(Message<? extends T> message);
-
 	/**
 	 * Matches every {@link Message}
 	 */
 	MessageMatcher<Object> ANY_MESSAGE = new MessageMatcher<Object>() {
+
 		@Override
 		public boolean matches(Message<?> message) {
 			return true;
@@ -46,6 +40,14 @@ public interface MessageMatcher<T> {
 		public String toString() {
 			return "ANY_MESSAGE";
 		}
+
 	};
 
+	/**
+	 * Returns true if the {@link Message} matches, else false
+	 * @param message the {@link Message} to match on
+	 * @return true if the {@link Message} matches, else false
+	 */
+	boolean matches(Message<? extends T> message);
+
 }

+ 4 - 5
messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java

@@ -18,6 +18,7 @@ package org.springframework.security.messaging.util.matcher;
 
 import java.util.List;
 
+import org.springframework.core.log.LogMessage;
 import org.springframework.messaging.Message;
 
 /**
@@ -49,15 +50,13 @@ public final class OrMessageMatcher<T> extends AbstractMessageMatcherComposite<T
 	@Override
 	public boolean matches(Message<? extends T> message) {
 		for (MessageMatcher<T> matcher : getMessageMatchers()) {
-			if (this.LOGGER.isDebugEnabled()) {
-				this.LOGGER.debug("Trying to match using " + matcher);
-			}
+			this.logger.debug(LogMessage.format("Trying to match using %s", matcher));
 			if (matcher.matches(message)) {
-				this.LOGGER.debug("matched");
+				this.logger.debug("matched");
 				return true;
 			}
 		}
-		this.LOGGER.debug("No matches found");
+		this.logger.debug("No matches found");
 		return false;
 	}
 

+ 3 - 10
messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java

@@ -107,11 +107,8 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
 	private SimpDestinationMessageMatcher(String pattern, SimpMessageType type, PathMatcher pathMatcher) {
 		Assert.notNull(pattern, "pattern cannot be null");
 		Assert.notNull(pathMatcher, "pathMatcher cannot be null");
-		if (!isTypeWithDestination(type)) {
-			throw new IllegalArgumentException(
-					"SimpMessageType " + type + " does not contain a destination and so cannot be matched on.");
-		}
-
+		Assert.isTrue(isTypeWithDestination(type),
+				() -> "SimpMessageType " + type + " does not contain a destination and so cannot be matched on.");
 		this.matcher = pathMatcher;
 		this.messageTypeMatcher = (type != null) ? new SimpMessageTypeMatcher(type) : ANY_MESSAGE;
 		this.pattern = pattern;
@@ -122,7 +119,6 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
 		if (!this.messageTypeMatcher.matches(message)) {
 			return false;
 		}
-
 		String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders());
 		return destination != null && this.matcher.match(this.pattern, destination);
 	}
@@ -144,10 +140,7 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
 	}
 
 	private boolean isTypeWithDestination(SimpMessageType type) {
-		if (type == null) {
-			return true;
-		}
-		return SimpMessageType.MESSAGE.equals(type) || SimpMessageType.SUBSCRIBE.equals(type);
+		return type == null || SimpMessageType.MESSAGE.equals(type) || SimpMessageType.SUBSCRIBE.equals(type);
 	}
 
 	/**

+ 0 - 2
messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcher.java

@@ -49,7 +49,6 @@ public class SimpMessageTypeMatcher implements MessageMatcher<Object> {
 	public boolean matches(Message<?> message) {
 		MessageHeaders headers = message.getHeaders();
 		SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
-
 		return this.typeToMatch == messageType;
 	}
 
@@ -63,7 +62,6 @@ public class SimpMessageTypeMatcher implements MessageMatcher<Object> {
 		}
 		SimpMessageTypeMatcher otherMatcher = (SimpMessageTypeMatcher) other;
 		return ObjectUtils.nullSafeEquals(this.typeToMatch, otherMatcher.typeToMatch);
-
 	}
 
 	@Override

+ 3 - 7
messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java

@@ -46,23 +46,19 @@ public final class CsrfChannelInterceptor extends ChannelInterceptorAdapter {
 		if (!this.matcher.matches(message)) {
 			return message;
 		}
-
 		Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
 		CsrfToken expectedToken = (sessionAttributes != null)
 				? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null;
-
 		if (expectedToken == null) {
 			throw new MissingCsrfTokenException(null);
 		}
-
 		String actualTokenValue = SimpMessageHeaderAccessor.wrap(message)
 				.getFirstNativeHeader(expectedToken.getHeaderName());
-
 		boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue);
-		if (csrfCheckPassed) {
-			return message;
+		if (!csrfCheckPassed) {
+			throw new InvalidCsrfTokenException(expectedToken, actualTokenValue);
 		}
-		throw new InvalidCsrfTokenException(expectedToken, actualTokenValue);
+		return message;
 	}
 
 }