Browse Source

HttpSessionOAuth2AuthorizationRequestRepository handle multiple OAuth2AuthorizationRequest per session

Fixes gh-5110
Joe Grandja 7 năm trước cách đây
mục cha
commit
59cef7d339

+ 30 - 6
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,11 +16,14 @@
 package org.springframework.security.oauth2.client.web;
 
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.util.Assert;
 
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import javax.servlet.http.HttpSession;
+import java.util.HashMap;
+import java.util.Map;
 
 /**
  * An implementation of an {@link AuthorizationRequestRepository} that stores
@@ -39,9 +42,10 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
 	@Override
 	public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
 		Assert.notNull(request, "request cannot be null");
-		HttpSession session = request.getSession(false);
-		if (session != null) {
-			return (OAuth2AuthorizationRequest) session.getAttribute(this.sessionAttributeName);
+		Assert.hasText(request.getParameter(OAuth2ParameterNames.STATE), "state parameter cannot be empty");
+		Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
+		if (authorizationRequests != null) {
+			return authorizationRequests.get(request.getParameter(OAuth2ParameterNames.STATE));
 		}
 		return null;
 	}
@@ -55,7 +59,9 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
 			this.removeAuthorizationRequest(request);
 			return;
 		}
-		request.getSession().setAttribute(this.sessionAttributeName, authorizationRequest);
+		Assert.hasText(authorizationRequest.getState(), "authorizationRequest.state cannot be empty");
+		Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request, true);
+		authorizationRequests.put(authorizationRequest.getState(), authorizationRequest);
 	}
 
 	@Override
@@ -63,8 +69,26 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
 		Assert.notNull(request, "request cannot be null");
 		OAuth2AuthorizationRequest authorizationRequest = this.loadAuthorizationRequest(request);
 		if (authorizationRequest != null) {
-			request.getSession().removeAttribute(this.sessionAttributeName);
+			Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
+			authorizationRequests.remove(authorizationRequest.getState());
 		}
 		return authorizationRequest;
 	}
+
+	private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request) {
+		return this.getAuthorizationRequests(request, false);
+	}
+
+	private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request, boolean createSession) {
+		Map<String, OAuth2AuthorizationRequest> authorizationRequests = null;
+		HttpSession session = request.getSession(createSession);
+		if (session != null) {
+			authorizationRequests = (Map<String, OAuth2AuthorizationRequest>) session.getAttribute(this.sessionAttributeName);
+			if (authorizationRequests == null) {
+				authorizationRequests = new HashMap<>();
+				session.setAttribute(this.sessionAttributeName, authorizationRequests);
+			}
+		}
+		return authorizationRequests;
+	}
 }

+ 79 - 2
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java

@@ -22,9 +22,11 @@ import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 /**
  * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}.
@@ -44,8 +46,10 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 
 	@Test
 	public void loadAuthorizationRequestWhenNotSavedThenReturnNull() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
 		OAuth2AuthorizationRequest authorizationRequest =
-			this.authorizationRequestRepository.loadAuthorizationRequest(new MockHttpServletRequest());
+			this.authorizationRequestRepository.loadAuthorizationRequest(request);
 
 		assertThat(authorizationRequest).isNull();
 	}
@@ -54,15 +58,69 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 	public void loadAuthorizationRequestWhenSavedThenReturnAuthorizationRequest() {
 		MockHttpServletRequest request = new MockHttpServletRequest();
 		MockHttpServletResponse response = new MockHttpServletResponse();
+
 		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+		when(authorizationRequest.getState()).thenReturn("state-1234");
 
 		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
+		request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
 		OAuth2AuthorizationRequest loadedAuthorizationRequest =
 			this.authorizationRequestRepository.loadAuthorizationRequest(request);
 
 		assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
 	}
 
+	// gh-5110
+	@Test
+	public void loadAuthorizationRequestWhenMultipleSavedThenReturnMatchingAuthorizationRequest() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+
+		String state1 = "state-1122";
+		OAuth2AuthorizationRequest authorizationRequest1 = mock(OAuth2AuthorizationRequest.class);
+		when(authorizationRequest1.getState()).thenReturn(state1);
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
+
+		String state2 = "state-3344";
+		OAuth2AuthorizationRequest authorizationRequest2 = mock(OAuth2AuthorizationRequest.class);
+		when(authorizationRequest2.getState()).thenReturn(state2);
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
+
+		String state3 = "state-5566";
+		OAuth2AuthorizationRequest authorizationRequest3 = mock(OAuth2AuthorizationRequest.class);
+		when(authorizationRequest3.getState()).thenReturn(state3);
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
+
+		request.addParameter(OAuth2ParameterNames.STATE, state1);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest1 =
+			this.authorizationRequestRepository.loadAuthorizationRequest(request);
+		assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1);
+
+		request.removeParameter(OAuth2ParameterNames.STATE);
+		request.addParameter(OAuth2ParameterNames.STATE, state2);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest2 =
+			this.authorizationRequestRepository.loadAuthorizationRequest(request);
+		assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
+
+		request.removeParameter(OAuth2ParameterNames.STATE);
+		request.addParameter(OAuth2ParameterNames.STATE, state3);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest3 =
+			this.authorizationRequestRepository.loadAuthorizationRequest(request);
+		assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void loadAuthorizationRequestWhenSavedAndStateParameterNullThenThrowIllegalArgumentException() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+
+		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+		when(authorizationRequest.getState()).thenReturn("state-1234");
+		this.authorizationRequestRepository.saveAuthorizationRequest(
+			authorizationRequest, request, new MockHttpServletResponse());
+
+		this.authorizationRequestRepository.loadAuthorizationRequest(request);
+	}
+
 	@Test(expected = IllegalArgumentException.class)
 	public void saveAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
 		this.authorizationRequestRepository.saveAuthorizationRequest(
@@ -75,13 +133,22 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 			mock(OAuth2AuthorizationRequest.class), new MockHttpServletRequest(), null);
 	}
 
+	@Test(expected = IllegalArgumentException.class)
+	public void saveAuthorizationRequestWhenStateNullThenThrowIllegalArgumentException() {
+		this.authorizationRequestRepository.saveAuthorizationRequest(
+			mock(OAuth2AuthorizationRequest.class), new MockHttpServletRequest(), new MockHttpServletResponse());
+	}
+
 	@Test
 	public void saveAuthorizationRequestWhenNotNullThenSaved() {
 		MockHttpServletRequest request = new MockHttpServletRequest();
-		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
 
+		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+		when(authorizationRequest.getState()).thenReturn("state-1234");
 		this.authorizationRequestRepository.saveAuthorizationRequest(
 			authorizationRequest, request, new MockHttpServletResponse());
+
+		request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
 		OAuth2AuthorizationRequest loadedAuthorizationRequest =
 			this.authorizationRequestRepository.loadAuthorizationRequest(request);
 
@@ -92,12 +159,17 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 	public void saveAuthorizationRequestWhenNullThenRemoved() {
 		MockHttpServletRequest request = new MockHttpServletRequest();
 		MockHttpServletResponse response = new MockHttpServletResponse();
+
 		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+		when(authorizationRequest.getState()).thenReturn("state-1234");
 
 		this.authorizationRequestRepository.saveAuthorizationRequest(		// Save
 			authorizationRequest, request, response);
+
+		request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
 		this.authorizationRequestRepository.saveAuthorizationRequest(		// Null value removes
 			null, request, response);
+
 		OAuth2AuthorizationRequest loadedAuthorizationRequest =
 			this.authorizationRequestRepository.loadAuthorizationRequest(request);
 
@@ -113,10 +185,14 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 	public void removeAuthorizationRequestWhenSavedThenRemoved() {
 		MockHttpServletRequest request = new MockHttpServletRequest();
 		MockHttpServletResponse response = new MockHttpServletResponse();
+
 		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+		when(authorizationRequest.getState()).thenReturn("state-1234");
 
 		this.authorizationRequestRepository.saveAuthorizationRequest(
 			authorizationRequest, request, response);
+
+		request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
 		OAuth2AuthorizationRequest removedAuthorizationRequest =
 			this.authorizationRequestRepository.removeAuthorizationRequest(request);
 		OAuth2AuthorizationRequest loadedAuthorizationRequest =
@@ -129,6 +205,7 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 	@Test
 	public void removeAuthorizationRequestWhenNotSavedThenNotRemoved() {
 		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
 
 		OAuth2AuthorizationRequest removedAuthorizationRequest =
 			this.authorizationRequestRepository.removeAuthorizationRequest(request);

+ 16 - 31
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java

@@ -17,6 +17,7 @@ package org.springframework.security.oauth2.client.web;
 
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 import org.springframework.http.HttpStatus;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
@@ -26,8 +27,6 @@ import org.springframework.security.oauth2.client.registration.InMemoryClientReg
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
-import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 
 import javax.servlet.FilterChain;
 import javax.servlet.http.HttpServletRequest;
@@ -153,7 +152,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 	}
 
 	@Test
-	public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenAuthorizationRequestSavedInSession() throws Exception {
+	public void doFilterWhenAuthorizationRequestAuthorizationCodeGrantThenAuthorizationRequestSaved() throws Exception {
 		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
 			"/" + this.registration2.getRegistrationId();
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
@@ -162,31 +161,14 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 		FilterChain filterChain = mock(FilterChain.class);
 
 		AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
-			new HttpSessionOAuth2AuthorizationRequestRepository();
+				mock(AuthorizationRequestRepository.class);
 		this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
 
 		this.filter.doFilter(request, response, filterChain);
 
 		verifyZeroInteractions(filterChain);
-
-		OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
-
-		assertThat(authorizationRequest).isNotNull();
-		assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(
-			this.registration2.getProviderDetails().getAuthorizationUri());
-		assertThat(authorizationRequest.getGrantType()).isEqualTo(
-			this.registration2.getAuthorizationGrantType());
-		assertThat(authorizationRequest.getResponseType()).isEqualTo(
-			OAuth2AuthorizationResponseType.CODE);
-		assertThat(authorizationRequest.getClientId()).isEqualTo(
-			this.registration2.getClientId());
-		assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
-			"http://localhost/login/oauth2/code/registration-2");
-		assertThat(authorizationRequest.getScopes()).isEqualTo(
-			this.registration2.getScopes());
-		assertThat(authorizationRequest.getState()).isNotNull();
-		assertThat(authorizationRequest.getAdditionalParameters()
-			.get(OAuth2ParameterNames.REGISTRATION_ID)).isEqualTo(this.registration2.getRegistrationId());
+		verify(authorizationRequestRepository).saveAuthorizationRequest(
+			any(OAuth2AuthorizationRequest.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
 	@Test
@@ -206,7 +188,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 	}
 
 	@Test
-	public void doFilterWhenAuthorizationRequestImplicitGrantThenAuthorizationRequestNotSavedInSession() throws Exception {
+	public void doFilterWhenAuthorizationRequestImplicitGrantThenAuthorizationRequestNotSaved() throws Exception {
 		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
 			"/" + this.registration3.getRegistrationId();
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
@@ -215,16 +197,14 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 		FilterChain filterChain = mock(FilterChain.class);
 
 		AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
-			new HttpSessionOAuth2AuthorizationRequestRepository();
+				mock(AuthorizationRequestRepository.class);
 		this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
 
 		this.filter.doFilter(request, response, filterChain);
 
 		verifyZeroInteractions(filterChain);
-
-		OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
-
-		assertThat(authorizationRequest).isNull();
+		verify(authorizationRequestRepository, times(0)).saveAuthorizationRequest(
+				any(OAuth2AuthorizationRequest.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
 	@Test
@@ -255,14 +235,19 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 		FilterChain filterChain = mock(FilterChain.class);
 
 		AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
-			new HttpSessionOAuth2AuthorizationRequestRepository();
+				mock(AuthorizationRequestRepository.class);
 		this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
 
 		this.filter.doFilter(request, response, filterChain);
 
+		ArgumentCaptor<OAuth2AuthorizationRequest> authorizationRequestArgCaptor =
+			ArgumentCaptor.forClass(OAuth2AuthorizationRequest.class);
+
 		verifyZeroInteractions(filterChain);
+		verify(authorizationRequestRepository).saveAuthorizationRequest(
+			authorizationRequestArgCaptor.capture(), any(HttpServletRequest.class), any(HttpServletResponse.class));
 
-		OAuth2AuthorizationRequest authorizationRequest = authorizationRequestRepository.loadAuthorizationRequest(request);
+		OAuth2AuthorizationRequest authorizationRequest = authorizationRequestArgCaptor.getValue();
 
 		assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(
 			this.registration2.getRedirectUriTemplate());

+ 11 - 7
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java

@@ -200,15 +200,16 @@ public class OAuth2LoginAuthenticationFilterTests {
 	@Test
 	public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
 		String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
+		String state = "state";
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		request.addParameter(OAuth2ParameterNames.CODE, "code");
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
+		request.addParameter(OAuth2ParameterNames.STATE, state);
 
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
-		this.setUpAuthorizationRequest(request, response, this.registration2);
+		this.setUpAuthorizationRequest(request, response, this.registration2, state);
 		this.setUpAuthenticationResult(this.registration2);
 
 		this.filter.doFilter(request, response, filterChain);
@@ -219,15 +220,16 @@ public class OAuth2LoginAuthenticationFilterTests {
 	@Test
 	public void doFilterWhenAuthorizationResponseValidThenAuthorizedClientSaved() throws Exception {
 		String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
+		String state = "state";
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		request.addParameter(OAuth2ParameterNames.CODE, "code");
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
+		request.addParameter(OAuth2ParameterNames.STATE, state);
 
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
-		this.setUpAuthorizationRequest(request, response, this.registration1);
+		this.setUpAuthorizationRequest(request, response, this.registration1, state);
 		this.setUpAuthenticationResult(this.registration1);
 
 		this.filter.doFilter(request, response, filterChain);
@@ -248,15 +250,16 @@ public class OAuth2LoginAuthenticationFilterTests {
 		this.filter.setAuthenticationManager(this.authenticationManager);
 
 		String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId();
+		String state = "state";
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		request.addParameter(OAuth2ParameterNames.CODE, "code");
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
+		request.addParameter(OAuth2ParameterNames.STATE, state);
 
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
-		this.setUpAuthorizationRequest(request, response, this.registration2);
+		this.setUpAuthorizationRequest(request, response, this.registration2, state);
 		this.setUpAuthenticationResult(this.registration2);
 
 		this.filter.doFilter(request, response, filterChain);
@@ -285,8 +288,9 @@ public class OAuth2LoginAuthenticationFilterTests {
 	}
 
 	private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
-											ClientRegistration registration) {
+											ClientRegistration registration, String state) {
 		OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
+		when(authorizationRequest.getState()).thenReturn(state);
 		Map<String, Object> additionalParameters = new HashMap<>();
 		additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
 		when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);

+ 1 - 1
samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java

@@ -250,7 +250,7 @@ public class OAuth2LoginApplicationTests {
 
 		HtmlElement errorElement = page.getBody().getFirstByXPath("p");
 		assertThat(errorElement).isNotNull();
-		assertThat(errorElement.asText()).contains("invalid_state_parameter");
+		assertThat(errorElement.asText()).contains("authorization_request_not_found");
 	}
 
 	@Test