Browse Source

Add SecurityContextHolderStrategy to OAuth2

Issue gh-11060
Josh Cummings 3 years ago
parent
commit
1d72a05c32

+ 17 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@ import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
@@ -103,6 +104,9 @@ import org.springframework.web.util.UriComponentsBuilder;
  */
 public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	private final ClientRegistrationRepository clientRegistrationRepository;
 
 	private final OAuth2AuthorizedClientRepository authorizedClientRepository;
@@ -158,6 +162,17 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 		this.requestCache = requestCache;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
@@ -232,7 +247,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 			this.redirectStrategy.sendRedirect(request, response, uriBuilder.build().encode().toString());
 			return;
 		}
-		Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
+		Authentication currentAuthentication = this.securityContextHolderStrategy.getContext().getAuthentication();
 		String principalName = (currentAuthentication != null) ? currentAuthentication.getName() : "anonymousUser";
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
 				authenticationResult.getClientRegistration(), principalName, authenticationResult.getAccessToken(),

+ 18 - 3
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
@@ -72,6 +73,9 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
 	private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous",
 			"anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	private OAuth2AuthorizedClientManager authorizedClientManager;
 
 	private boolean defaultAuthorizedClientManager;
@@ -120,7 +124,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
 					+ "It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or "
 					+ "@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
 		}
-		Authentication principal = SecurityContextHolder.getContext().getAuthentication();
+		Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication();
 		if (principal == null) {
 			principal = ANONYMOUS_AUTHENTICATION;
 		}
@@ -140,7 +144,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
 	private String resolveClientRegistrationId(MethodParameter parameter) {
 		RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils
 				.findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class);
-		Authentication principal = SecurityContextHolder.getContext().getAuthentication();
+		Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication();
 		if (!StringUtils.isEmpty(authorizedClientAnnotation.registrationId())) {
 			return authorizedClientAnnotation.registrationId();
 		}
@@ -179,6 +183,17 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
 		updateDefaultAuthorizedClientManager(clientCredentialsTokenResponseClient);
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	private void updateDefaultAuthorizedClientManager(
 			OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
 		// @formatter:off

+ 1 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.

+ 17 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -38,6 +38,7 @@ import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
@@ -151,6 +152,9 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 	private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken("anonymous",
 			"anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	@Deprecated
 	private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
 
@@ -304,6 +308,17 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 		this.defaultClientRegistrationId = clientRegistrationId;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	/**
 	 * Configures the builder with {@link #defaultRequest()} and adds this as a
 	 * {@link ExchangeFilterFunction}
@@ -513,7 +528,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 		if (attrs.containsKey(AUTHENTICATION_ATTR_NAME)) {
 			return;
 		}
-		Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+		Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
 		attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
 	}
 

+ 20 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -39,6 +39,8 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
@@ -306,6 +308,23 @@ public class OAuth2AuthorizationCodeGrantFilterTests {
 		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
 	}
 
+	@Test
+	public void doFilterWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
+		MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1");
+		MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+		this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1);
+		this.setUpAuthenticationResult(this.registration1);
+		SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
+		given(strategy.getContext())
+				.willReturn(new SecurityContextImpl(new TestingAuthenticationToken("user", "password")));
+		this.filter.setSecurityContextHolderStrategy(strategy);
+		this.filter.doFilter(authorizationResponse, response, filterChain);
+		verify(strategy).getContext();
+		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
+	}
+
 	@Test
 	public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSavedRequest() throws Exception {
 		String requestUri = "/saved-request";

+ 16 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,6 +34,8 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
@@ -70,6 +72,7 @@ import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 
@@ -254,6 +257,18 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
 				new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1);
 	}
 
+	@Test
+	public void resolveArgumentWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
+		SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
+		given(strategy.getContext()).willReturn(new SecurityContextImpl(this.authentication));
+		this.argumentResolver.setSecurityContextHolderStrategy(strategy);
+		MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient",
+				OAuth2AuthorizedClient.class);
+		assertThat(this.argumentResolver.resolveArgument(methodParameter, null,
+				new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1);
+		verify(strategy, atLeastOnce()).getContext();
+	}
+
 	@Test
 	public void resolveArgumentWhenRegistrationIdInvalidThenThrowIllegalArgumentException() {
 		MethodParameter methodParameter = this.getMethodParameter("registrationIdInvalid",

+ 15 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -65,6 +65,8 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.oauth2.client.ClientAuthorizationException;
 import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
@@ -282,6 +284,18 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		verifyNoInteractions(this.authorizedClientRepository);
 	}
 
+	@Test
+	public void defaultRequestAuthenticationWhenCustomSecurityContextHolderStrategyThenAuthenticationSet() {
+		SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
+		given(strategy.getContext()).willReturn(new SecurityContextImpl(this.authentication));
+		this.function.setSecurityContextHolderStrategy(strategy);
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+		assertThat(ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication(attrs))
+				.isEqualTo(this.authentication);
+		verify(strategy).getContext();
+		verifyNoInteractions(this.authorizedClientRepository);
+	}
+
 	private Map<String, Object> getDefaultRequestAttributes() {
 		this.function.defaultRequest().accept(this.spec);
 		verify(this.spec).attributes(this.attrs.capture());

+ 19 - 4
oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider;
@@ -64,6 +65,9 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 
 	private final AuthenticationManagerResolver<HttpServletRequest> authenticationManagerResolver;
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	private AuthenticationEntryPoint authenticationEntryPoint = new BearerTokenAuthenticationEntryPoint();
 
 	private AuthenticationFailureHandler authenticationFailureHandler = (request, response, exception) -> {
@@ -132,9 +136,9 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 		try {
 			AuthenticationManager authenticationManager = this.authenticationManagerResolver.resolve(request);
 			Authentication authenticationResult = authenticationManager.authenticate(authenticationRequest);
-			SecurityContext context = SecurityContextHolder.createEmptyContext();
+			SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
 			context.setAuthentication(authenticationResult);
-			SecurityContextHolder.setContext(context);
+			this.securityContextHolderStrategy.setContext(context);
 			this.securityContextRepository.saveContext(context, request, response);
 			if (this.logger.isDebugEnabled()) {
 				this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authenticationResult));
@@ -142,12 +146,23 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 			filterChain.doFilter(request, response);
 		}
 		catch (AuthenticationException failed) {
-			SecurityContextHolder.clearContext();
+			this.securityContextHolderStrategy.clearContext();
 			this.logger.trace("Failed to process authentication request", failed);
 			this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed);
 		}
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	/**
 	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
 	 * authentication success. The default action is not to save the

+ 15 - 1
oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -38,6 +38,8 @@ import org.springframework.security.authentication.AuthenticationManagerResolver
 import org.springframework.security.authentication.AuthenticationServiceException;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.BearerTokenError;
@@ -205,6 +207,18 @@ public class BearerTokenAuthenticationFilterTests {
 		verify(this.authenticationDetailsSource).buildDetails(this.request);
 	}
 
+	@Test
+	public void doFilterWhenCustomSecurityContextHolderStrategyThenUses() throws ServletException, IOException {
+		given(this.bearerTokenResolver.resolve(this.request)).willReturn("token");
+		BearerTokenAuthenticationFilter filter = addMocks(
+				new BearerTokenAuthenticationFilter(this.authenticationManager));
+		SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
+		given(strategy.createEmptyContext()).willReturn(new SecurityContextImpl());
+		filter.setSecurityContextHolderStrategy(strategy);
+		filter.doFilter(this.request, this.response, this.filterChain);
+		verify(strategy).setContext(any());
+	}
+
 	@Test
 	public void setAuthenticationEntryPointWhenNullThenThrowsException() {
 		BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(this.authenticationManager);