Prechádzať zdrojové kódy

BearerTokenAuthenticationFilter.securityContextRepository

Issue gh-10953
Rob Winch 3 rokov pred
rodič
commit
9db79aa5d7

+ 17 - 0
oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java

@@ -38,6 +38,8 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.OncePerRequestFilter;
 
@@ -75,6 +77,8 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 
 	private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	/**
 	 * Construct a {@code BearerTokenAuthenticationFilter} using the provided parameter(s)
 	 * @param authenticationManagerResolver
@@ -131,6 +135,7 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 			SecurityContext context = SecurityContextHolder.createEmptyContext();
 			context.setAuthentication(authenticationResult);
 			SecurityContextHolder.setContext(context);
+			this.securityContextRepository.saveContext(context, request, response);
 			if (this.logger.isDebugEnabled()) {
 				this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authenticationResult));
 			}
@@ -143,6 +148,18 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 		}
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	/**
 	 * Set the {@link BearerTokenResolver} to use. Defaults to
 	 * {@link DefaultBearerTokenResolver}.

+ 25 - 0
oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java

@@ -36,18 +36,23 @@ import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManagerResolver;
 import org.springframework.security.authentication.AuthenticationServiceException;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.BearerTokenError;
 import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes;
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.context.SecurityContextRepository;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 
@@ -102,6 +107,26 @@ public class BearerTokenAuthenticationFilterTests {
 		assertThat(captor.getValue().getPrincipal()).isEqualTo("token");
 	}
 
+	@Test
+	public void doFilterWhenSecurityContextRepositoryThenSaves() throws ServletException, IOException {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		String token = "token";
+		given(this.bearerTokenResolver.resolve(this.request)).willReturn(token);
+		TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("test", "password");
+		given(this.authenticationManager.authenticate(any())).willReturn(expectedAuthentication);
+		BearerTokenAuthenticationFilter filter = addMocks(
+				new BearerTokenAuthenticationFilter(this.authenticationManager));
+		filter.setSecurityContextRepository(securityContextRepository);
+		filter.doFilter(this.request, this.response, this.filterChain);
+		ArgumentCaptor<BearerTokenAuthenticationToken> captor = ArgumentCaptor
+				.forClass(BearerTokenAuthenticationToken.class);
+		verify(this.authenticationManager).authenticate(captor.capture());
+		assertThat(captor.getValue().getPrincipal()).isEqualTo(token);
+		ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
+		verify(securityContextRepository).saveContext(contextArg.capture(), eq(this.request), eq(this.response));
+		assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo(expectedAuthentication.getName());
+	}
+
 	@Test
 	public void doFilterWhenUsingAuthenticationManagerResolverThenAuthenticates() throws Exception {
 		BearerTokenAuthenticationFilter filter = addMocks(