Browse Source

HttpSessionOAuth2AuthorizationRequestRepository: store one request by default

Add setAllowMultipleAuthorizationRequests allowing applications to
revert to the previous functionality should they need to do so.

Closes gh-5145
Intentionally regresses gh-5110
Craig Andrews 4 years ago
parent
commit
35f5ebdbcf

+ 50 - 11
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ import org.springframework.util.Assert;
  *
  * @author Joe Grandja
  * @author Rob Winch
+ * @author Craig Andrews
  * @since 5.0
  * @see AuthorizationRequestRepository
  * @see OAuth2AuthorizationRequest
@@ -45,6 +46,8 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository
 
 	private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
 
+	private boolean allowMultipleAuthorizationRequests;
+
 	@Override
 	public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
 		Assert.notNull(request, "request cannot be null");
@@ -67,9 +70,14 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository
 		}
 		String state = authorizationRequest.getState();
 		Assert.hasText(state, "authorizationRequest.state cannot be empty");
-		Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
-		authorizationRequests.put(state, authorizationRequest);
-		request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
+		if (this.allowMultipleAuthorizationRequests) {
+			Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
+			authorizationRequests.put(state, authorizationRequest);
+			request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
+		}
+		else {
+			request.getSession().setAttribute(this.sessionAttributeName, authorizationRequest);
+		}
 	}
 
 	@Override
@@ -81,11 +89,15 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository
 		}
 		Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
 		OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(stateParameter);
-		if (!authorizationRequests.isEmpty()) {
-			request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
+		if (authorizationRequests.size() == 0) {
+			request.getSession().removeAttribute(this.sessionAttributeName);
+		}
+		else if (authorizationRequests.size() == 1) {
+			request.getSession().setAttribute(this.sessionAttributeName,
+					authorizationRequests.values().iterator().next());
 		}
 		else {
-			request.getSession().removeAttribute(this.sessionAttributeName);
+			request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
 		}
 		return originalRequest;
 	}
@@ -115,12 +127,39 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository
 	 */
 	private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request) {
 		HttpSession session = request.getSession(false);
-		Map<String, OAuth2AuthorizationRequest> authorizationRequests = (session != null)
-				? (Map<String, OAuth2AuthorizationRequest>) session.getAttribute(this.sessionAttributeName) : null;
-		if (authorizationRequests == null) {
+		Object sessionAttributeValue = (session != null) ? session.getAttribute(this.sessionAttributeName) : null;
+		if (sessionAttributeValue == null) {
 			return new HashMap<>();
 		}
-		return authorizationRequests;
+		else if (sessionAttributeValue instanceof OAuth2AuthorizationRequest) {
+			OAuth2AuthorizationRequest auth2AuthorizationRequest = (OAuth2AuthorizationRequest) sessionAttributeValue;
+			Map<String, OAuth2AuthorizationRequest> authorizationRequests = new HashMap<>(1);
+			authorizationRequests.put(auth2AuthorizationRequest.getState(), auth2AuthorizationRequest);
+			return authorizationRequests;
+		}
+		else if (sessionAttributeValue instanceof Map) {
+			@SuppressWarnings("unchecked")
+			Map<String, OAuth2AuthorizationRequest> authorizationRequests = (Map<String, OAuth2AuthorizationRequest>) sessionAttributeValue;
+			return authorizationRequests;
+		}
+		else {
+			throw new IllegalStateException(
+					"authorizationRequests is supposed to be a Map or OAuth2AuthorizationRequest but actually is a "
+							+ sessionAttributeValue.getClass());
+		}
+	}
+
+	/**
+	 * Configure if multiple {@link OAuth2AuthorizationRequest}s should be stored per
+	 * session. Default is false (not allow multiple {@link OAuth2AuthorizationRequest}
+	 * per session).
+	 * @param allowMultipleAuthorizationRequests true allows more than one
+	 * {@link OAuth2AuthorizationRequest} to be stored per session.
+	 * @since 5.5
+	 */
+	@Deprecated
+	public void setAllowMultipleAuthorizationRequests(boolean allowMultipleAuthorizationRequests) {
+		this.allowMultipleAuthorizationRequests = allowMultipleAuthorizationRequests;
 	}
 
 }

+ 76 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java

@@ -0,0 +1,76 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository} when
+ * {@link HttpSessionOAuth2AuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)}
+ * is enabled.
+ *
+ * @author Joe Grandja
+ * @author Craig Andrews
+ */
+public class HttpSessionOAuth2AuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests
+		extends HttpSessionOAuth2AuthorizationRequestRepositoryTests {
+
+	@Before
+	public void setup() {
+		this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
+		this.authorizationRequestRepository.setAllowMultipleAuthorizationRequests(true);
+	}
+
+	// gh-5110
+	@Test
+	public void loadAuthorizationRequestWhenMultipleSavedThenReturnMatchingAuthorizationRequest() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		String state1 = "state-1122";
+		OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build();
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
+		String state2 = "state-3344";
+		OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build();
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
+		String state3 = "state-5566";
+		OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build();
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
+		request.addParameter(OAuth2ParameterNames.STATE, state1);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
+				.loadAuthorizationRequest(request);
+		assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1);
+		request.removeParameter(OAuth2ParameterNames.STATE);
+		request.addParameter(OAuth2ParameterNames.STATE, state2);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository
+				.loadAuthorizationRequest(request);
+		assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
+		request.removeParameter(OAuth2ParameterNames.STATE);
+		request.addParameter(OAuth2ParameterNames.STATE, state3);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository
+				.loadAuthorizationRequest(request);
+		assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
+	}
+
+}

+ 76 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java

@@ -0,0 +1,76 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository} when
+ * {@link HttpSessionOAuth2AuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)}
+ * is disabled.
+ *
+ * @author Joe Grandja
+ * @author Craig Andrews
+ */
+public class HttpSessionOAuth2AuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests
+		extends HttpSessionOAuth2AuthorizationRequestRepositoryTests {
+
+	@Before
+	public void setup() {
+		this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
+		this.authorizationRequestRepository.setAllowMultipleAuthorizationRequests(false);
+	}
+
+	// gh-5145
+	@Test
+	public void loadAuthorizationRequestWhenMultipleSavedThenReturnLastAuthorizationRequest() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		String state1 = "state-1122";
+		OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build();
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
+		String state2 = "state-3344";
+		OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build();
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
+		String state3 = "state-5566";
+		OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build();
+		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
+		request.addParameter(OAuth2ParameterNames.STATE, state1);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
+				.loadAuthorizationRequest(request);
+		assertThat(loadedAuthorizationRequest1).isNull();
+		request.removeParameter(OAuth2ParameterNames.STATE);
+		request.addParameter(OAuth2ParameterNames.STATE, state2);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository
+				.loadAuthorizationRequest(request);
+		assertThat(loadedAuthorizationRequest2).isNull();
+		request.removeParameter(OAuth2ParameterNames.STATE);
+		request.addParameter(OAuth2ParameterNames.STATE, state3);
+		OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository
+				.loadAuthorizationRequest(request);
+		assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
+	}
+
+}

+ 5 - 34
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -36,11 +36,12 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
  * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}.
  *
  * @author Joe Grandja
+ * @author Craig Andrews
  */
 @RunWith(MockitoJUnitRunner.class)
-public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
+public abstract class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 
-	private HttpSessionOAuth2AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
+	protected HttpSessionOAuth2AuthorizationRequestRepository authorizationRequestRepository;
 
 	@Test
 	public void loadAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
@@ -69,36 +70,6 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 		assertThat(loadedAuthorizationRequest).isEqualTo(authorizationRequest);
 	}
 
-	// gh-5110
-	@Test
-	public void loadAuthorizationRequestWhenMultipleSavedThenReturnMatchingAuthorizationRequest() {
-		MockHttpServletRequest request = new MockHttpServletRequest();
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		String state1 = "state-1122";
-		OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build();
-		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
-		String state2 = "state-3344";
-		OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build();
-		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
-		String state3 = "state-5566";
-		OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build();
-		this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
-		request.addParameter(OAuth2ParameterNames.STATE, state1);
-		OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
-				.loadAuthorizationRequest(request);
-		assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1);
-		request.removeParameter(OAuth2ParameterNames.STATE);
-		request.addParameter(OAuth2ParameterNames.STATE, state2);
-		OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository
-				.loadAuthorizationRequest(request);
-		assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
-		request.removeParameter(OAuth2ParameterNames.STATE);
-		request.addParameter(OAuth2ParameterNames.STATE, state3);
-		OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository
-				.loadAuthorizationRequest(request);
-		assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
-	}
-
 	@Test
 	public void loadAuthorizationRequestWhenSavedAndStateParameterNullThenReturnNull() {
 		MockHttpServletRequest request = new MockHttpServletRequest();
@@ -237,7 +208,7 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 		assertThat(removedAuthorizationRequest).isNull();
 	}
 
-	private OAuth2AuthorizationRequest.Builder createAuthorizationRequest() {
+	protected OAuth2AuthorizationRequest.Builder createAuthorizationRequest() {
 		return OAuth2AuthorizationRequest.authorizationCode().authorizationUri("https://example.com/oauth2/authorize")
 				.clientId("client-id-1234").state("state-1234");
 	}