Browse Source

Add Support for @Transient SecurityContext

Closes gh-9995
Rob Winch 3 years ago
parent
commit
6f0029fc44

+ 40 - 0
core/src/main/java/org/springframework/security/core/context/TransientSecurityContext.java

@@ -0,0 +1,40 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.core.context;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.Transient;
+
+/**
+ * A {@link SecurityContext} that is annotated with @{@link Transient} and thus should
+ * never be stored across requests. This is useful in situations where one might run as a
+ * different user for part of a request.
+ *
+ * @author Rob Winch
+ * @since 5.7
+ */
+@Transient
+public class TransientSecurityContext extends SecurityContextImpl {
+
+	public TransientSecurityContext() {
+	}
+
+	public TransientSecurityContext(Authentication authentication) {
+		super(authentication);
+	}
+
+}

+ 7 - 4
web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java

@@ -232,11 +232,11 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
 		this.springSecurityContextKey = springSecurityContextKey;
 	}
 
-	private boolean isTransientAuthentication(Authentication authentication) {
-		if (authentication == null) {
+	private boolean isTransient(Object object) {
+		if (object == null) {
 			return false;
 		}
-		return AnnotationUtils.getAnnotation(authentication.getClass(), Transient.class) != null;
+		return AnnotationUtils.getAnnotation(object.getClass(), Transient.class) != null;
 	}
 
 	/**
@@ -329,8 +329,11 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
 		 */
 		@Override
 		protected void saveContext(SecurityContext context) {
+			if (isTransient(context)) {
+				return;
+			}
 			final Authentication authentication = context.getAuthentication();
-			if (isTransientAuthentication(authentication)) {
+			if (isTransient(authentication)) {
 				return;
 			}
 			HttpSession httpSession = this.request.getSession(false);

+ 65 - 0
web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java

@@ -43,13 +43,16 @@ import org.springframework.mock.web.MockHttpSession;
 import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationTrustResolver;
+import org.springframework.security.authentication.TestAuthentication;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Transient;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextImpl;
+import org.springframework.security.core.context.TransientSecurityContext;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.UserDetails;
 
@@ -587,6 +590,68 @@ public class HttpSessionSecurityContextRepositoryTests {
 		assertThatIllegalStateException().isThrownBy(() -> repo.saveContext(context, request, response));
 	}
 
+	@Test
+	public void saveContextWhenTransientSecurityContextThenSkipped() {
+		HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response);
+		SecurityContext context = repo.loadContext(holder);
+		SecurityContext transientSecurityContext = new TransientSecurityContext();
+		Authentication authentication = TestAuthentication.authenticatedUser();
+		transientSecurityContext.setAuthentication(authentication);
+		repo.saveContext(transientSecurityContext, holder.getRequest(), holder.getResponse());
+		MockHttpSession session = (MockHttpSession) request.getSession(false);
+		assertThat(session).isNull();
+	}
+
+	@Test
+	public void saveContextWhenTransientSecurityContextSubclassThenSkipped() {
+		HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response);
+		SecurityContext context = repo.loadContext(holder);
+		SecurityContext transientSecurityContext = new TransientSecurityContext() {
+		};
+		Authentication authentication = TestAuthentication.authenticatedUser();
+		transientSecurityContext.setAuthentication(authentication);
+		repo.saveContext(transientSecurityContext, holder.getRequest(), holder.getResponse());
+		MockHttpSession session = (MockHttpSession) request.getSession(false);
+		assertThat(session).isNull();
+	}
+
+	@Test
+	public void saveContextWhenTransientSecurityContextAndSessionExistsThenSkipped() {
+		HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.getSession(); // ensure the session exists
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response);
+		SecurityContext context = repo.loadContext(holder);
+		SecurityContext transientSecurityContext = new TransientSecurityContext();
+		Authentication authentication = TestAuthentication.authenticatedUser();
+		transientSecurityContext.setAuthentication(authentication);
+		repo.saveContext(transientSecurityContext, holder.getRequest(), holder.getResponse());
+		MockHttpSession session = (MockHttpSession) request.getSession(false);
+		assertThat(Collections.list(session.getAttributeNames())).isEmpty();
+	}
+
+	@Test
+	public void saveContextWhenTransientSecurityContextWithCustomAnnotationThenSkipped() {
+		HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response);
+		SecurityContext context = repo.loadContext(holder);
+		SecurityContext transientSecurityContext = new TransientSecurityContext();
+		Authentication authentication = TestAuthentication.authenticatedUser();
+		transientSecurityContext.setAuthentication(authentication);
+		repo.saveContext(transientSecurityContext, holder.getRequest(), holder.getResponse());
+		MockHttpSession session = (MockHttpSession) request.getSession(false);
+		assertThat(session).isNull();
+	}
+
 	@Test
 	public void saveContextWhenTransientAuthenticationThenSkipped() {
 		HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();