فهرست منبع

Add SecurityContextHolderStrategy Java Configuration for OAuth2

Issue gh-11061
Josh Cummings 3 سال پیش
والد
کامیت
1d22316574

+ 15 - 2
config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.Import;
 import org.springframework.context.annotation.ImportSelector;
 import org.springframework.core.type.AnnotationMetadata;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
@@ -75,11 +76,18 @@ final class OAuth2ClientConfiguration {
 
 		private OAuth2AuthorizedClientManager authorizedClientManager;
 
+		private SecurityContextHolderStrategy securityContextHolderStrategy;
+
 		@Override
 		public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
 			OAuth2AuthorizedClientManager authorizedClientManager = getAuthorizedClientManager();
 			if (authorizedClientManager != null) {
-				argumentResolvers.add(new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager));
+				OAuth2AuthorizedClientArgumentResolver resolver = new OAuth2AuthorizedClientArgumentResolver(
+						authorizedClientManager);
+				if (this.securityContextHolderStrategy != null) {
+					resolver.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
+				}
+				argumentResolvers.add(resolver);
 			}
 		}
 
@@ -110,6 +118,11 @@ final class OAuth2ClientConfiguration {
 			}
 		}
 
+		@Autowired(required = false)
+		void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) {
+			this.securityContextHolderStrategy = strategy;
+		}
+
 		private OAuth2AuthorizedClientManager getAuthorizedClientManager() {
 			if (this.authorizedClientManager != null) {
 				return this.authorizedClientManager;

+ 2 - 1
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -272,6 +272,7 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>>
 			if (this.authorizationRequestRepository != null) {
 				authorizationCodeGrantFilter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
 			}
+			authorizationCodeGrantFilter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy());
 			RequestCache requestCache = builder.getSharedObject(RequestCache.class);
 			if (requestCache != null) {
 				authorizationCodeGrantFilter.setRequestCache(requestCache);

+ 2 - 1
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -270,6 +270,7 @@ public final class OAuth2ResourceServerConfigurer<H extends HttpSecurityBuilder<
 		BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(resolver);
 		filter.setBearerTokenResolver(bearerTokenResolver);
 		filter.setAuthenticationEntryPoint(this.authenticationEntryPoint);
+		filter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy());
 		filter = postProcess(filter);
 		http.addFilter(filter);
 	}

+ 25 - 0
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java

@@ -42,6 +42,7 @@ import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.event.AuthenticationSuccessEvent;
+import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
@@ -53,6 +54,8 @@ import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
+import org.springframework.security.core.context.SecurityContextChangedListener;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
@@ -99,7 +102,10 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.setAuthentication;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@@ -193,6 +199,25 @@ public class OAuth2LoginConfigurerTests {
 				.hasToString("ROLE_USER");
 	}
 
+	@Test
+	public void requestWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
+		loadConfig(OAuth2LoginConfig.class, SecurityContextChangedListenerConfig.class);
+		OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest();
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response);
+		this.request.setParameter("code", "code123");
+		this.request.setParameter("state", authorizationRequest.getState());
+		this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
+		Authentication authentication = this.securityContextRepository
+				.loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication();
+		assertThat(authentication.getAuthorities()).hasSize(1);
+		assertThat(authentication.getAuthorities()).first().isInstanceOf(OAuth2UserAuthority.class)
+				.hasToString("ROLE_USER");
+		SecurityContextHolderStrategy strategy = this.context.getBean(SecurityContextHolderStrategy.class);
+		verify(strategy, atLeastOnce()).getContext();
+		SecurityContextChangedListener listener = this.context.getBean(SecurityContextChangedListener.class);
+		verify(listener).securityContextChanged(setAuthentication(OAuth2AuthenticationToken.class));
+	}
+
 	@Test
 	public void requestWhenOauth2LoginInLambdaThenAuthenticationContainsOauth2UserAuthority() throws Exception {
 		loadConfig(OAuth2LoginInLambdaConfig.class);

+ 37 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -51,6 +51,7 @@ import org.hamcrest.core.StringEndsWith;
 import org.hamcrest.core.StringStartsWith;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.verification.VerificationMode;
 
 import org.springframework.beans.factory.BeanCreationException;
 import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
@@ -82,6 +83,7 @@ import org.springframework.security.authentication.AuthenticationManagerResolver
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.authentication.AuthenticationServiceException;
 import org.springframework.security.config.annotation.ObjectPostProcessor;
+import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
 import org.springframework.security.config.annotation.method.configuration.EnableGlobalMethodSecurity;
 import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
@@ -93,6 +95,8 @@ import org.springframework.security.config.test.SpringTestContextExtension;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
+import org.springframework.security.core.context.SecurityContextChangedListener;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.core.userdetails.UserDetailsService;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
@@ -152,6 +156,7 @@ import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
@@ -217,6 +222,33 @@ public class OAuth2ResourceServerConfigurerTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void getWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
+		this.spring.register(RestOperationsConfig.class, DefaultConfig.class, BasicController.class, SecurityContextChangedListenerConfig.class).autowire();
+		mockRestOperations(jwks("Default"));
+		String token = this.token("ValidNoScopes");
+		// @formatter:off
+		this.mvc.perform(get("/").with(bearerToken(token)))
+				.andExpect(status().isOk())
+				.andExpect(content().string("ok"));
+		// @formatter:on
+		verifyBean(SecurityContextHolderStrategy.class, atLeastOnce()).getContext();
+	}
+
+	@Test
+	public void getWhenSecurityContextHolderStrategyThenUses() throws Exception {
+		this.spring.register(RestOperationsConfig.class, DefaultConfig.class,
+				SecurityContextChangedListenerConfig.class, BasicController.class).autowire();
+		mockRestOperations(jwks("Default"));
+		String token = this.token("ValidNoScopes");
+		// @formatter:off
+		this.mvc.perform(get("/").with(bearerToken(token)))
+				.andExpect(status().isOk())
+				.andExpect(content().string("ok"));
+		// @formatter:on
+		verifyBean(SecurityContextChangedListener.class, atLeastOnce()).securityContextChanged(any());
+	}
+
 	@Test
 	public void getWhenUsingDefaultsInLambdaWithValidBearerTokenThenAcceptsRequest() throws Exception {
 		this.spring.register(RestOperationsConfig.class, DefaultInLambdaConfig.class, BasicController.class).autowire();
@@ -1418,6 +1450,10 @@ public class OAuth2ResourceServerConfigurerTests {
 		return verify(this.spring.getContext().getBean(beanClass));
 	}
 
+	private <T> T verifyBean(Class<T> beanClass, VerificationMode mode) {
+		return verify(this.spring.getContext().getBean(beanClass), mode);
+	}
+
 	private String json(String name) throws IOException {
 		return resource(name + ".json");
 	}

+ 1 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configurers/openid/OpenIDLoginConfigurerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.