Browse Source

HttpSessionSecurityContextRepository does not persist @Transient Authentication

Related https://github.com/spring-projects/spring-security/pull/9993

Closes gh-482
Joe Grandja 3 years ago
parent
commit
d0e1107f36

+ 114 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java

@@ -19,6 +19,14 @@ import java.net.URI;
 import java.util.LinkedHashMap;
 import java.util.Map;
 
+import javax.servlet.AsyncContext;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletRequestWrapper;
+import javax.servlet.http.HttpServletResponse;
+
+import org.springframework.core.annotation.AnnotationUtils;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpStatus;
 import org.springframework.security.authentication.AuthenticationManager;
@@ -26,6 +34,9 @@ import org.springframework.security.config.Customizer;
 import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
 import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
 import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.Transient;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenIntrospectionAuthenticationProvider;
@@ -39,6 +50,10 @@ import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenR
 import org.springframework.security.web.access.intercept.FilterSecurityInterceptor;
 import org.springframework.security.web.authentication.HttpStatusEntryPoint;
 import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
+import org.springframework.security.web.context.HttpRequestResponseHolder;
+import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
+import org.springframework.security.web.context.SaveContextOnUpdateOrErrorResponseWrapper;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -212,6 +227,105 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 							this.tokenRevocationEndpointMatcher)
 			);
 		}
+
+		// gh-482
+		initSecurityContextRepository(builder);
+	}
+
+	private void initSecurityContextRepository(B builder) {
+		// TODO This is a temporary fix and should be removed after upgrading to Spring Security 5.7.0 GA.
+		//
+		// See:
+		// Prevent Save @Transient Authentication with existing HttpSession
+		// https://github.com/spring-projects/spring-security/pull/9993
+
+		final SecurityContextRepository securityContextRepository = builder.getSharedObject(SecurityContextRepository.class);
+		if (!(securityContextRepository instanceof HttpSessionSecurityContextRepository)) {
+			return;
+		}
+
+		SecurityContextRepository securityContextRepositoryTransientNotSaved = new SecurityContextRepository() {
+			// OAuth2ClientAuthenticationToken is @Transient and is accepted by
+			// OAuth2TokenEndpointFilter, OAuth2TokenIntrospectionEndpointFilter and OAuth2TokenRevocationEndpointFilter
+			private final RequestMatcher clientAuthenticationRequestMatcher = new OrRequestMatcher(
+					getRequestMatcher(OAuth2TokenEndpointConfigurer.class),
+					OAuth2AuthorizationServerConfigurer.this.tokenIntrospectionEndpointMatcher,
+					OAuth2AuthorizationServerConfigurer.this.tokenRevocationEndpointMatcher);
+
+			// JwtAuthenticationToken is @Transient and is accepted by
+			// OidcUserInfoEndpointFilter and OidcClientRegistrationEndpointFilter
+			private final RequestMatcher jwtAuthenticationRequestMatcher = getRequestMatcher(OidcConfigurer.class);
+
+			@Override
+			public SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder) {
+				final HttpServletRequest unwrappedRequest = requestResponseHolder.getRequest();
+				final HttpServletResponse unwrappedResponse = requestResponseHolder.getResponse();
+
+				SecurityContext securityContext = securityContextRepository.loadContext(requestResponseHolder);
+
+				if (this.clientAuthenticationRequestMatcher.matches(unwrappedRequest) ||
+						this.jwtAuthenticationRequestMatcher.matches(unwrappedRequest)) {
+
+					final SaveContextOnUpdateOrErrorResponseWrapper transientAuthenticationResponseWrapper =
+							new SaveContextOnUpdateOrErrorResponseWrapper(unwrappedResponse, false) {
+
+						@Override
+						protected void saveContext(SecurityContext context) {
+							// @Transient Authentication should not be saved
+							if (context.getAuthentication() != null) {
+								Assert.state(isTransientAuthentication(context.getAuthentication()), "Expected @Transient Authentication");
+							}
+						}
+
+					};
+					// Override the default HttpSessionSecurityContextRepository.SaveToSessionResponseWrapper
+					requestResponseHolder.setResponse(transientAuthenticationResponseWrapper);
+
+					final HttpServletRequestWrapper transientAuthenticationRequestWrapper =
+							new HttpServletRequestWrapper(unwrappedRequest) {
+
+						@Override
+						public AsyncContext startAsync() {
+							transientAuthenticationResponseWrapper.disableSaveOnResponseCommitted();
+							return super.startAsync();
+						}
+
+						@Override
+						public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
+								throws IllegalStateException {
+							transientAuthenticationResponseWrapper.disableSaveOnResponseCommitted();
+							return super.startAsync(servletRequest, servletResponse);
+						}
+
+					};
+					// Override the default HttpSessionSecurityContextRepository.SaveToSessionRequestWrapper
+					requestResponseHolder.setRequest(transientAuthenticationRequestWrapper);
+				}
+
+				return securityContext;
+			}
+
+			@Override
+			public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) {
+				Authentication authentication = context.getAuthentication();
+				if (authentication == null || isTransientAuthentication(authentication)) {
+					return;
+				}
+				securityContextRepository.saveContext(context, request, response);
+			}
+
+			@Override
+			public boolean containsContext(HttpServletRequest request) {
+				return securityContextRepository.containsContext(request);
+			}
+
+			private boolean isTransientAuthentication(Authentication authentication) {
+				return AnnotationUtils.getAnnotation(authentication.getClass(), Transient.class) != null;
+			}
+
+		};
+
+		builder.setSharedObject(SecurityContextRepository.class, securityContextRepositoryTransientNotSaved);
 	}
 
 	@Override

+ 81 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java

@@ -37,9 +37,11 @@ import com.nimbusds.jose.proc.SecurityContext;
 import org.assertj.core.matcher.AssertionMatcher;
 import org.junit.After;
 import org.junit.AfterClass;
+import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
@@ -56,6 +58,7 @@ import org.springframework.mock.http.client.MockClientHttpResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
@@ -102,6 +105,8 @@ import org.springframework.security.web.SecurityFilterChain;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
+import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MvcResult;
@@ -116,6 +121,10 @@ import static org.hamcrest.CoreMatchers.containsString;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
@@ -154,6 +163,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 	private static AuthenticationProvider authorizationRequestAuthenticationProvider;
 	private static AuthenticationSuccessHandler authorizationResponseHandler;
 	private static AuthenticationFailureHandler authorizationErrorResponseHandler;
+	private static SecurityContextRepository securityContextRepository;
 	private static String consentPage = "/oauth2/consent";
 
 	@Rule
@@ -187,6 +197,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 		authorizationRequestAuthenticationProvider = mock(AuthenticationProvider.class);
 		authorizationResponseHandler = mock(AuthenticationSuccessHandler.class);
 		authorizationErrorResponseHandler = mock(AuthenticationFailureHandler.class);
+		securityContextRepository = spy(new HttpSessionSecurityContextRepository());
 		db = new EmbeddedDatabaseBuilder()
 				.generateUniqueName(true)
 				.setType(EmbeddedDatabaseType.HSQL)
@@ -197,6 +208,11 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.build();
 	}
 
+	@Before
+	public void setup() {
+		reset(securityContextRepository);
+	}
+
 	@After
 	public void tearDown() {
 		jdbcOperations.update("truncate table oauth2_authorization");
@@ -615,6 +631,48 @@ public class OAuth2AuthorizationCodeGrantTests {
 		verify(authorizationResponseHandler).onAuthenticationSuccess(any(), any(), eq(authorizationCodeRequestAuthenticationResult));
 	}
 
+	// gh-482
+	@Test
+	public void requestWhenClientObtainsAccessTokenThenClientAuthenticationNotPersisted() throws Exception {
+		this.spring.register(AuthorizationServerConfigurationWithSecurityContextRepository.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build();
+		this.registeredClientRepository.save(registeredClient);
+
+		MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
+				.params(getAuthorizationRequestParameters(registeredClient))
+				.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
+				.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
+				.with(user("user")))
+				.andExpect(status().is3xxRedirection())
+				.andReturn();
+
+		ArgumentCaptor<org.springframework.security.core.context.SecurityContext> securityContextCaptor =
+				ArgumentCaptor.forClass(org.springframework.security.core.context.SecurityContext.class);
+		verify(securityContextRepository, times(2)).saveContext(securityContextCaptor.capture(), any(), any());
+		securityContextCaptor.getAllValues().forEach(securityContext ->
+				assertThat(securityContext.getAuthentication()).isInstanceOf(UsernamePasswordAuthenticationToken.class));
+		reset(securityContextRepository);
+
+		String authorizationCode = extractParameterFromRedirectUri(mvcResult.getResponse().getRedirectedUrl(), "code");
+		OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
+
+		this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
+				.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
+				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
+				.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER))
+				.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
+				.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
+				.andExpect(status().isOk())
+				.andExpect(jsonPath("$.access_token").isNotEmpty())
+				.andExpect(jsonPath("$.token_type").isNotEmpty())
+				.andExpect(jsonPath("$.expires_in").isNotEmpty())
+				.andExpect(jsonPath("$.refresh_token").doesNotExist())
+				.andExpect(jsonPath("$.scope").isNotEmpty());
+
+		verify(securityContextRepository, never()).saveContext(any(), any(), any());
+	}
+
 	private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
@@ -739,6 +797,29 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 	}
 
+	@EnableWebSecurity
+	static class AuthorizationServerConfigurationWithSecurityContextRepository extends AuthorizationServerConfiguration {
+		// @formatter:off
+		@Bean
+		public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception {
+			OAuth2AuthorizationServerConfigurer<HttpSecurity> authorizationServerConfigurer =
+					new OAuth2AuthorizationServerConfigurer<>();
+			RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher();
+
+			http
+					.requestMatcher(endpointsMatcher)
+					.authorizeRequests(authorizeRequests ->
+							authorizeRequests.anyRequest().authenticated()
+					)
+					.csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher))
+					.securityContext(securityContext ->
+							securityContext.securityContextRepository(securityContextRepository))
+					.apply(authorizationServerConfigurer);
+			return http.build();
+		}
+		// @formatter:on
+	}
+
 	@EnableWebSecurity
 	@Import(OAuth2AuthorizationServerConfiguration.class)
 	static class AuthorizationServerConfigurationWithJwtEncoder extends AuthorizationServerConfiguration {

+ 67 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcUserInfoTests.java

@@ -26,6 +26,8 @@ import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.proc.SecurityContext;
+import org.junit.Before;
+import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
 
@@ -62,10 +64,17 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat
 import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
 import org.springframework.security.web.SecurityFilterChain;
+import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.ResultMatcher;
 
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
 import static org.springframework.test.web.servlet.ResultMatcher.matchAll;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@@ -79,6 +88,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
  */
 public class OidcUserInfoTests {
 	private static final String DEFAULT_OIDC_USER_INFO_ENDPOINT_URI = "/userinfo";
+	private static SecurityContextRepository securityContextRepository;
 
 	@Rule
 	public final SpringTestRule spring = new SpringTestRule();
@@ -92,6 +102,16 @@ public class OidcUserInfoTests {
 	@Autowired
 	private OAuth2AuthorizationService authorizationService;
 
+	@BeforeClass
+	public static void init() {
+		securityContextRepository = spy(new HttpSessionSecurityContextRepository());
+	}
+
+	@Before
+	public void setup() {
+		reset(securityContextRepository);
+	}
+
 	@Test
 	public void requestWhenUserInfoRequestGetThenUserInfoResponse() throws Exception {
 		this.spring.register(AuthorizationServerConfiguration.class).autowire();
@@ -140,6 +160,25 @@ public class OidcUserInfoTests {
 		// @formatter:on
 	}
 
+	// gh-482
+	@Test
+	public void requestWhenUserInfoRequestThenBearerTokenAuthenticationNotPersisted() throws Exception {
+		this.spring.register(AuthorizationServerConfigurationWithSecurityContextRepository.class).autowire();
+
+		OAuth2Authorization authorization = createAuthorization();
+		this.authorizationService.save(authorization);
+
+		OAuth2AccessToken accessToken = authorization.getAccessToken().getToken();
+		// @formatter:off
+		this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
+				.header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
+				.andExpect(status().is2xxSuccessful())
+				.andExpect(userInfoResponse());
+		// @formatter:on
+
+		verify(securityContextRepository, never()).saveContext(any(), any(), any());
+	}
+
 	private static ResultMatcher userInfoResponse() {
 		// @formatter:off
 		return matchAll(
@@ -257,6 +296,34 @@ public class OidcUserInfoTests {
 		}
 	}
 
+	@EnableWebSecurity
+	static class AuthorizationServerConfigurationWithSecurityContextRepository extends AuthorizationServerConfiguration {
+
+		@Bean
+		@Override
+		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
+			OAuth2AuthorizationServerConfigurer<HttpSecurity> authorizationServerConfigurer =
+					new OAuth2AuthorizationServerConfigurer<>();
+			RequestMatcher endpointsMatcher = authorizationServerConfigurer
+					.getEndpointsMatcher();
+
+			// @formatter:off
+			http
+				.requestMatcher(endpointsMatcher)
+				.authorizeRequests(authorizeRequests ->
+					authorizeRequests.anyRequest().authenticated()
+				)
+				.csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher))
+				.oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt)
+				.securityContext(securityContext ->
+					securityContext.securityContextRepository(securityContextRepository))
+				.apply(authorizationServerConfigurer);
+			// @formatter:on
+
+			return http.build();
+		}
+	}
+
 	@EnableWebSecurity
 	static class AuthorizationServerConfiguration {