瀏覽代碼

Store one request by default in WebSessionOAuth2ServerAuthorizationRequestRepository

Related to gh-9649
Closes gh-9857
Steve Riesenberg 4 年之前
父節點
當前提交
67a18f564a

+ 76 - 51
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@ import reactor.core.publisher.Mono;
  * {@link OAuth2AuthorizationRequest} in the {@code WebSession}.
  *
  * @author Rob Winch
+ * @author Steve Riesenberg
  * @since 5.1
  * @see AuthorizationRequestRepository
  * @see OAuth2AuthorizationRequest
@@ -46,6 +47,8 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 
 	private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
 
+	private boolean allowMultipleAuthorizationRequests;
+
 	@Override
 	public Mono<OAuth2AuthorizationRequest> loadAuthorizationRequest(
 			ServerWebExchange exchange) {
@@ -53,17 +56,33 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 		if (state == null) {
 			return Mono.empty();
 		}
-		return getStateToAuthorizationRequest(exchange)
-				.filter(stateToAuthorizationRequest -> stateToAuthorizationRequest.containsKey(state))
-				.map(stateToAuthorizationRequest -> stateToAuthorizationRequest.get(state));
+		// @formatter:off
+		return this.getSessionAttributes(exchange)
+				.filter((sessionAttrs) -> sessionAttrs.containsKey(this.sessionAttributeName))
+				.map(this::getAuthorizationRequests)
+				.filter((stateToAuthorizationRequest) -> stateToAuthorizationRequest.containsKey(state))
+				.map((stateToAuthorizationRequest) -> stateToAuthorizationRequest.get(state));
+		// @formatter:on
 	}
 
 	@Override
 	public Mono<Void> saveAuthorizationRequest(
 			OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) {
 		Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
-		return saveStateToAuthorizationRequest(exchange)
-				.doOnNext(stateToAuthorizationRequest -> stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest))
+		Assert.notNull(exchange, "exchange cannot be null");
+		// @formatter:off
+		return getSessionAttributes(exchange)
+				.doOnNext((sessionAttrs) -> {
+					if (this.allowMultipleAuthorizationRequests) {
+						Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(
+								sessionAttrs);
+						authorizationRequests.put(authorizationRequest.getState(), authorizationRequest);
+						sessionAttrs.put(this.sessionAttributeName, authorizationRequests);
+					}
+					else {
+						sessionAttrs.put(this.sessionAttributeName, authorizationRequest);
+					}
+				})
 				.then();
 	}
 
@@ -74,27 +93,24 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 		if (state == null) {
 			return Mono.empty();
 		}
-		return exchange.getSession()
-			.map(WebSession::getAttributes)
-			.handle((sessionAttrs, sink) -> {
-				Map<String, OAuth2AuthorizationRequest> stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest(sessionAttrs);
-				if (stateToAuthzRequest == null) {
-					sink.complete();
-					return;
-				}
-				OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state);
-				if (stateToAuthzRequest.isEmpty()) {
-					sessionAttrs.remove(this.sessionAttributeName);
-				} else if (removedValue != null) {
-					// gh-7327 Overwrite the existing Map to ensure the state is saved for distributed sessions
-					sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest);
-				}
-				if (removedValue == null) {
-					sink.complete();
-				} else {
-					sink.next(removedValue);
-				}
-			});
+		// @formatter:off
+		return getSessionAttributes(exchange)
+				.flatMap((sessionAttrs) -> {
+					Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(
+							sessionAttrs);
+					OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(state);
+					if (authorizationRequests.isEmpty()) {
+						sessionAttrs.remove(this.sessionAttributeName);
+					}
+					else if (authorizationRequests.size() == 1) {
+						sessionAttrs.put(this.sessionAttributeName, authorizationRequests.values().iterator().next());
+					}
+					else {
+						sessionAttrs.put(this.sessionAttributeName, authorizationRequests);
+					}
+					return Mono.justOrEmpty(originalRequest);
+				});
+		// @formatter:on
 	}
 
 	/**
@@ -111,31 +127,40 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 		return exchange.getSession().map(WebSession::getAttributes);
 	}
 
-	private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange) {
-		Assert.notNull(exchange, "exchange cannot be null");
-
-		return getSessionAttributes(exchange)
-			.flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
-	}
-
-	private Mono<Map<String, OAuth2AuthorizationRequest>> saveStateToAuthorizationRequest(ServerWebExchange exchange) {
-		Assert.notNull(exchange, "exchange cannot be null");
-
-		return getSessionAttributes(exchange)
-			.doOnNext(sessionAttrs -> {
-				Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName);
-
-				if (stateToAuthzRequest == null) {
-					stateToAuthzRequest = new HashMap<String, OAuth2AuthorizationRequest>();
-				}
-
-				// No matter stateToAuthzRequest was in session or not, we should always put it into session again
-				// in case of redis or hazelcast session. #6215
-				sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest);
-			}).flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
+	private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(Map<String, Object> sessionAttrs) {
+		Object sessionAttributeValue = sessionAttrs.get(this.sessionAttributeName);
+		if (sessionAttributeValue == null) {
+			return new HashMap<>();
+		}
+		else if (sessionAttributeValue instanceof OAuth2AuthorizationRequest) {
+			OAuth2AuthorizationRequest oauth2AuthorizationRequest = (OAuth2AuthorizationRequest) sessionAttributeValue;
+			Map<String, OAuth2AuthorizationRequest> authorizationRequests = new HashMap<>(1);
+			authorizationRequests.put(oauth2AuthorizationRequest.getState(), oauth2AuthorizationRequest);
+			return authorizationRequests;
+		}
+		else if (sessionAttributeValue instanceof Map) {
+			@SuppressWarnings("unchecked")
+			Map<String, OAuth2AuthorizationRequest> authorizationRequests = (Map<String, OAuth2AuthorizationRequest>) sessionAttrs
+					.get(this.sessionAttributeName);
+			return authorizationRequests;
+		}
+		else {
+			throw new IllegalStateException(
+					"authorizationRequests is supposed to be a Map or OAuth2AuthorizationRequest but actually is a "
+							+ sessionAttributeValue.getClass());
+		}
 	}
 
-	private Map<String, OAuth2AuthorizationRequest> sessionAttrsMapStateToAuthorizationRequest(Map<String, Object> sessionAttrs) {
-		return (Map<String, OAuth2AuthorizationRequest>) sessionAttrs.get(this.sessionAttributeName);
+	/**
+	 * Configure if multiple {@link OAuth2AuthorizationRequest}s should be stored per
+	 * session. Default is false (not allow multiple {@link OAuth2AuthorizationRequest}
+	 * per session).
+	 * @param allowMultipleAuthorizationRequests true allows more than one
+	 * {@link OAuth2AuthorizationRequest} to be stored per session.
+	 * @since 5.5
+	 */
+	@Deprecated
+	public void setAllowMultipleAuthorizationRequests(boolean allowMultipleAuthorizationRequests) {
+		this.allowMultipleAuthorizationRequests = allowMultipleAuthorizationRequests;
 	}
 }

+ 252 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java

@@ -0,0 +1,252 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.server;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Before;
+import org.junit.Test;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+import org.springframework.http.codec.ServerCodecConfigurer;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebSession;
+import org.springframework.web.server.adapter.DefaultServerWebExchange;
+import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
+import org.springframework.web.server.session.WebSessionManager;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Tests for {@link WebSessionOAuth2ServerAuthorizationRequestRepository} when
+ * {@link WebSessionOAuth2ServerAuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)}
+ * is enabled.
+ *
+ * @author Steve Riesenberg
+ */
+
+public class WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests
+		extends WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
+
+	@Before
+	public void setup() {
+		this.repository = new WebSessionOAuth2ServerAuthorizationRequestRepository();
+		this.repository.setAllowMultipleAuthorizationRequests(true);
+	}
+
+	@Test
+	public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
+		String oldState = "state0";
+		// @formatter:off
+		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, oldState)
+				.build();
+		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(oldState)
+				.build();
+		// @formatter:on
+		WebSessionManager sessionManager = (e) -> this.exchange.getSession();
+		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		// @formatter:off
+		Mono<OAuth2AuthorizationRequest> saveAndSaveAndLoad = this.repository
+				.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
+				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.then(this.repository.loadAuthorizationRequest(oldExchange));
+		StepVerifier.create(saveAndSaveAndLoad)
+				.expectNext(oldAuthorizationRequest)
+				.verifyComplete();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
+				.expectNext(this.authorizationRequest)
+				.verifyComplete();
+		// @formatter:on
+	}
+
+	// gh-5145
+	@Test
+	public void loadAuthorizationRequestWhenSavedWithAllowMultipleAuthorizationRequestsThenReturnOldAuthorizationRequest() {
+		// save 2 requests with legacy (allowMultipleAuthorizationRequests=true) and load
+		// with new
+		WebSessionOAuth2ServerAuthorizationRequestRepository legacy = new WebSessionOAuth2ServerAuthorizationRequestRepository();
+		legacy.setAllowMultipleAuthorizationRequests(true);
+		// @formatter:off
+		String state1 = "state-1122";
+		OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state1)
+				.build();
+		StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest1, this.exchange))
+				.verifyComplete();
+		String state2 = "state-3344";
+		OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state2)
+				.build();
+		StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest2, this.exchange))
+				.verifyComplete();
+		ServerHttpRequest newRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, state1)
+				.build();
+		ServerWebExchange newExchange = this.exchange.mutate()
+				.request(newRequest)
+				.build();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange))
+				.expectNext(authorizationRequest1)
+				.verifyComplete();
+		// @formatter:on
+	}
+
+	// gh-5145
+	@Test
+	public void saveAuthorizationRequestWhenSavedWithAllowMultipleAuthorizationRequestsThenLoadNewAuthorizationRequest() {
+		// save 2 requests with legacy (allowMultipleAuthorizationRequests=true), save
+		// with new, and load with new
+		WebSessionOAuth2ServerAuthorizationRequestRepository legacy = new WebSessionOAuth2ServerAuthorizationRequestRepository();
+		legacy.setAllowMultipleAuthorizationRequests(true);
+		// @formatter:off
+		String state1 = "state-1122";
+		OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state1)
+				.build();
+		StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest1, this.exchange))
+				.verifyComplete();
+		String state2 = "state-3344";
+		OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state2)
+				.build();
+		StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest2, this.exchange))
+				.verifyComplete();
+		String state3 = "state-5566";
+		OAuth2AuthorizationRequest authorizationRequest3 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state3)
+				.build();
+		ServerHttpRequest newRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, state3)
+				.build();
+		ServerWebExchange newExchange = this.exchange.mutate()
+				.request(newRequest)
+				.build();
+		Mono<OAuth2AuthorizationRequest> saveAndLoad = this.repository
+				.saveAuthorizationRequest(authorizationRequest3, this.exchange)
+				.then(this.repository.loadAuthorizationRequest(newExchange));
+		StepVerifier.create(saveAndLoad)
+				.expectNext(authorizationRequest3)
+				.verifyComplete();
+		// @formatter:on
+	}
+
+	@Test
+	public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() {
+		String oldState = "state0";
+		// @formatter:off
+		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, oldState)
+				.build();
+		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(oldState)
+				.build();
+		// @formatter:on
+		WebSessionManager sessionManager = (e) -> this.exchange.getSession();
+		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		// @formatter:off
+		Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
+				.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
+				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.then(this.repository.removeAuthorizationRequest(this.exchange));
+		StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
+				.verifyComplete();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
+				.verifyComplete();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange))
+				.expectNext(oldAuthorizationRequest)
+				.verifyComplete();
+		// @formatter:on
+	}
+
+	// gh-7327
+	@Test
+	public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() {
+		String oldState = "state0";
+		// @formatter:off
+		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, oldState)
+				.build();
+		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(oldState)
+				.build();
+		// @formatter:on
+		Map<String, Object> sessionAttrs = spy(new HashMap<>());
+		WebSession session = mock(WebSession.class);
+		given(session.getAttributes()).willReturn(sessionAttrs);
+		WebSessionManager sessionManager = (e) -> Mono.just(session);
+		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		// @formatter:off
+		Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
+				.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
+				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.then(this.repository.removeAuthorizationRequest(this.exchange));
+		StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
+				.verifyComplete();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
+				.verifyComplete();
+		// @formatter:on
+		verify(sessionAttrs, times(3)).put(any(), any());
+	}
+
+}

+ 159 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java

@@ -0,0 +1,159 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.server;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Before;
+import org.junit.Test;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+import org.springframework.http.codec.ServerCodecConfigurer;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebSession;
+import org.springframework.web.server.adapter.DefaultServerWebExchange;
+import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
+import org.springframework.web.server.session.WebSessionManager;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Tests for {@link WebSessionOAuth2ServerAuthorizationRequestRepository} when
+ * {@link WebSessionOAuth2ServerAuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)}
+ * is disabled.
+ *
+ * @author Steve Riesenberg
+ */
+public class WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests
+		extends WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
+
+	@Before
+	public void setup() {
+		this.repository = new WebSessionOAuth2ServerAuthorizationRequestRepository();
+		this.repository.setAllowMultipleAuthorizationRequests(false);
+	}
+
+	// gh-5145
+	@Test
+	public void loadAuthorizationRequestWhenMultipleSavedThenReturnLastAuthorizationRequest() {
+		// @formatter:off
+		String state1 = "state-1122";
+		OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state1)
+				.build();
+		StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest1, this.exchange))
+				.verifyComplete();
+		String state2 = "state-3344";
+		OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state2)
+				.build();
+		StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest2, this.exchange))
+				.verifyComplete();
+		String state3 = "state-5566";
+		OAuth2AuthorizationRequest authorizationRequest3 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state3)
+				.build();
+		StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest3, this.exchange))
+				.verifyComplete();
+		ServerHttpRequest newRequest1 = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, state1)
+				.build();
+		ServerWebExchange newExchange1 = this.exchange.mutate()
+				.request(newRequest1)
+				.build();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange1))
+				.verifyComplete();
+		ServerHttpRequest newRequest2 = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, state2)
+				.build();
+		ServerWebExchange newExchange2 = this.exchange.mutate()
+				.request(newRequest2)
+				.build();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange2))
+				.verifyComplete();
+		ServerHttpRequest newRequest3 = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, state3)
+				.build();
+		ServerWebExchange newExchange3 = this.exchange.mutate()
+				.request(newRequest3)
+				.build();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange3))
+				.expectNext(authorizationRequest3)
+				.verifyComplete();
+		// @formatter:on
+	}
+
+	// gh-5145
+	@Test
+	public void removeAuthorizationRequestWhenMultipleThenSessionAttributeRemoved() {
+		String oldState = "state0";
+		// @formatter:off
+		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, oldState)
+				.build();
+		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(oldState)
+				.build();
+		// @formatter:on
+		Map<String, Object> sessionAttrs = spy(new HashMap<>());
+		WebSession session = mock(WebSession.class);
+		given(session.getAttributes()).willReturn(sessionAttrs);
+		WebSessionManager sessionManager = (e) -> Mono.just(session);
+		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		// @formatter:off
+		Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
+				.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
+				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.then(this.repository.removeAuthorizationRequest(this.exchange));
+		StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
+				.verifyComplete();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
+				.verifyComplete();
+		// @formatter:on
+		verify(sessionAttrs, times(2)).put(anyString(), any());
+		verify(sessionAttrs).remove(anyString());
+	}
+
+}

+ 11 - 130
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,51 +16,39 @@
 
 package org.springframework.security.oauth2.client.web.server;
 
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-import java.util.HashMap;
 import java.util.Map;
 
 import org.junit.Test;
-import org.springframework.http.codec.ServerCodecConfigurer;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
-import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
 import org.springframework.mock.web.server.MockServerWebExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebSession;
-import org.springframework.web.server.adapter.DefaultServerWebExchange;
-import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
-import org.springframework.web.server.session.WebSessionManager;
 
-import reactor.core.publisher.Mono;
-import reactor.test.StepVerifier;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /**
  * @author Rob Winch
  * @since 5.1
  */
-public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
+public abstract class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
 
-	private WebSessionOAuth2ServerAuthorizationRequestRepository repository =
-			new WebSessionOAuth2ServerAuthorizationRequestRepository();
+	protected WebSessionOAuth2ServerAuthorizationRequestRepository repository;
 
-	private OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+	// @formatter:off
+	protected OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
 			.authorizationUri("https://example.com/oauth2/authorize")
 			.clientId("client-id")
 			.redirectUri("http://localhost/client-1")
 			.state("state")
 			.build();
 
-	private ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")
-			.queryParam(OAuth2ParameterNames.STATE, "state"));
+	protected ServerWebExchange exchange = MockServerWebExchange
+			.from(MockServerHttpRequest.get("/").queryParam(OAuth2ParameterNames.STATE, "state"));
 
 	@Test
 	public void loadAuthorizationRequestWhenNullExchangeThenIllegalArgumentException() {
@@ -106,39 +94,6 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
 				.verifyComplete();
 	}
 
-	@Test
-	public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
-		String oldState = "state0";
-		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
-				.queryParam(OAuth2ParameterNames.STATE, oldState).build();
-
-		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
-				.authorizationUri("https://example.com/oauth2/authorize")
-				.clientId("client-id")
-				.redirectUri("http://localhost/client-1")
-				.state(oldState)
-				.build();
-
-		WebSessionManager sessionManager = e -> this.exchange.getSession();
-
-		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
-				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
-				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-
-		Mono<OAuth2AuthorizationRequest> saveAndSaveAndLoad = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
-				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
-				.then(this.repository.loadAuthorizationRequest(oldExchange));
-
-		StepVerifier.create(saveAndSaveAndLoad)
-				.expectNext(oldAuthorizationRequest)
-				.verifyComplete();
-
-		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
-				.expectNext(this.authorizationRequest)
-				.verifyComplete();
-	}
-
 	@Test
 	public void saveAuthorizationRequestWhenAuthorizationRequestNullThenThrowsIllegalArgumentException() {
 		this.authorizationRequest = null;
@@ -203,80 +158,6 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
 				.verifyComplete();
 	}
 
-	@Test
-	public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() {
-		String oldState = "state0";
-		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
-				.queryParam(OAuth2ParameterNames.STATE, oldState).build();
-
-		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
-				.authorizationUri("https://example.com/oauth2/authorize")
-				.clientId("client-id")
-				.redirectUri("http://localhost/client-1")
-				.state(oldState)
-				.build();
-
-		WebSessionManager sessionManager = e -> this.exchange.getSession();
-
-		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
-				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
-				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-
-		Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
-				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
-				.then(this.repository.removeAuthorizationRequest(this.exchange));
-
-		StepVerifier.create(saveAndSaveAndRemove)
-				.expectNext(this.authorizationRequest)
-				.verifyComplete();
-
-		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
-				.verifyComplete();
-
-		StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange))
-				.expectNext(oldAuthorizationRequest)
-				.verifyComplete();
-	}
-
-	// gh-7327
-	@Test
-	public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() {
-		String oldState = "state0";
-		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
-				.queryParam(OAuth2ParameterNames.STATE, oldState).build();
-
-		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
-				.authorizationUri("https://example.com/oauth2/authorize")
-				.clientId("client-id")
-				.redirectUri("http://localhost/client-1")
-				.state(oldState)
-				.build();
-
-		Map<String, Object> sessionAttrs = spy(new HashMap<>());
-		WebSession session = mock(WebSession.class);
-		when(session.getAttributes()).thenReturn(sessionAttrs);
-		WebSessionManager sessionManager = e -> Mono.just(session);
-
-		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
-				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
-				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-
-		Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
-				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
-				.then(this.repository.removeAuthorizationRequest(this.exchange));
-
-		StepVerifier.create(saveAndSaveAndRemove)
-				.expectNext(this.authorizationRequest)
-				.verifyComplete();
-
-		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
-				.verifyComplete();
-
-		verify(sessionAttrs, times(3)).put(any(), any());
-	}
-
 	private void assertSessionStartedIs(boolean expected) {
 		Mono<Boolean> isStarted = this.exchange.getSession().map(WebSession::isStarted);
 		StepVerifier.create(isStarted)