فهرست منبع

Store one request by default in WebSessionOAuth2ServerAuthorizationRequestRepository

Related to gh-9649
Closes gh-9857
Closes gh-9912
Steve Riesenberg 4 سال پیش
والد
کامیت
a108868529

+ 62 - 51
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@ import org.springframework.web.server.WebSession;
  * {@link OAuth2AuthorizationRequest} in the {@code WebSession}.
  *
  * @author Rob Winch
+ * @author Steve Riesenberg
  * @since 5.1
  * @see AuthorizationRequestRepository
  * @see OAuth2AuthorizationRequest
@@ -46,6 +47,8 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 
 	private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
 
+	private boolean allowMultipleAuthorizationRequests;
+
 	@Override
 	public Mono<OAuth2AuthorizationRequest> loadAuthorizationRequest(ServerWebExchange exchange) {
 		String state = getStateParameter(exchange);
@@ -53,7 +56,9 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 			return Mono.empty();
 		}
 		// @formatter:off
-		return getStateToAuthorizationRequest(exchange)
+		return this.getSessionAttributes(exchange)
+				.filter((sessionAttrs) -> sessionAttrs.containsKey(this.sessionAttributeName))
+				.map(this::getAuthorizationRequests)
 				.filter((stateToAuthorizationRequest) -> stateToAuthorizationRequest.containsKey(state))
 				.map((stateToAuthorizationRequest) -> stateToAuthorizationRequest.get(state));
 		// @formatter:on
@@ -63,10 +68,20 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 	public Mono<Void> saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest,
 			ServerWebExchange exchange) {
 		Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
+		Assert.notNull(exchange, "exchange cannot be null");
 		// @formatter:off
-		return saveStateToAuthorizationRequest(exchange)
-				.doOnNext((stateToAuthorizationRequest) -> stateToAuthorizationRequest
-						.put(authorizationRequest.getState(), authorizationRequest))
+		return getSessionAttributes(exchange)
+				.doOnNext((sessionAttrs) -> {
+					if (this.allowMultipleAuthorizationRequests) {
+						Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(
+								sessionAttrs);
+						authorizationRequests.put(authorizationRequest.getState(), authorizationRequest);
+						sessionAttrs.put(this.sessionAttributeName, authorizationRequests);
+					}
+					else {
+						sessionAttrs.put(this.sessionAttributeName, authorizationRequest);
+					}
+				})
 				.then();
 		// @formatter:on
 	}
@@ -78,30 +93,21 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 			return Mono.empty();
 		}
 		// @formatter:off
-		return exchange.getSession()
-				.map(WebSession::getAttributes)
-				.handle((sessionAttrs, sink) -> {
-					Map<String, OAuth2AuthorizationRequest> stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest(
+		return getSessionAttributes(exchange)
+				.flatMap((sessionAttrs) -> {
+					Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(
 							sessionAttrs);
-					if (stateToAuthzRequest == null) {
-						sink.complete();
-						return;
-					}
-					OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state);
-					if (stateToAuthzRequest.isEmpty()) {
+					OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(state);
+					if (authorizationRequests.isEmpty()) {
 						sessionAttrs.remove(this.sessionAttributeName);
 					}
-					else if (removedValue != null) {
-						// gh-7327 Overwrite the existing Map to ensure the state is saved for
-						// distributed sessions
-						sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest);
-					}
-					if (removedValue == null) {
-						sink.complete();
+					else if (authorizationRequests.size() == 1) {
+						sessionAttrs.put(this.sessionAttributeName, authorizationRequests.values().iterator().next());
 					}
 					else {
-						sink.next(removedValue);
+						sessionAttrs.put(this.sessionAttributeName, authorizationRequests);
 					}
+					return Mono.justOrEmpty(originalRequest);
 				});
 		// @formatter:on
 	}
@@ -120,36 +126,41 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 		return exchange.getSession().map(WebSession::getAttributes);
 	}
 
-	private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange) {
-		Assert.notNull(exchange, "exchange cannot be null");
-
-		// @formatter:off
-		return getSessionAttributes(exchange)
-				.flatMap((sessionAttrs) -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
-		// @formatter:on
-	}
-
-	private Mono<Map<String, OAuth2AuthorizationRequest>> saveStateToAuthorizationRequest(ServerWebExchange exchange) {
-		Assert.notNull(exchange, "exchange cannot be null");
-		// @formatter:off
-		return getSessionAttributes(exchange)
-				.doOnNext((sessionAttrs) -> {
-					Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName);
-					if (stateToAuthzRequest == null) {
-						stateToAuthzRequest = new HashMap<String, OAuth2AuthorizationRequest>();
-					}
-					// No matter stateToAuthzRequest was in session or not, we should always put
-					// it into session again
-					// in case of redis or hazelcast session. #6215
-					sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest);
-				})
-				.flatMap((sessionAttrs) -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
-		// @formatter:on
+	private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(Map<String, Object> sessionAttrs) {
+		Object sessionAttributeValue = sessionAttrs.get(this.sessionAttributeName);
+		if (sessionAttributeValue == null) {
+			return new HashMap<>();
+		}
+		else if (sessionAttributeValue instanceof OAuth2AuthorizationRequest) {
+			OAuth2AuthorizationRequest oauth2AuthorizationRequest = (OAuth2AuthorizationRequest) sessionAttributeValue;
+			Map<String, OAuth2AuthorizationRequest> authorizationRequests = new HashMap<>(1);
+			authorizationRequests.put(oauth2AuthorizationRequest.getState(), oauth2AuthorizationRequest);
+			return authorizationRequests;
+		}
+		else if (sessionAttributeValue instanceof Map) {
+			@SuppressWarnings("unchecked")
+			Map<String, OAuth2AuthorizationRequest> authorizationRequests = (Map<String, OAuth2AuthorizationRequest>) sessionAttrs
+					.get(this.sessionAttributeName);
+			return authorizationRequests;
+		}
+		else {
+			throw new IllegalStateException(
+					"authorizationRequests is supposed to be a Map or OAuth2AuthorizationRequest but actually is a "
+							+ sessionAttributeValue.getClass());
+		}
 	}
 
-	private Map<String, OAuth2AuthorizationRequest> sessionAttrsMapStateToAuthorizationRequest(
-			Map<String, Object> sessionAttrs) {
-		return (Map<String, OAuth2AuthorizationRequest>) sessionAttrs.get(this.sessionAttributeName);
+	/**
+	 * Configure if multiple {@link OAuth2AuthorizationRequest}s should be stored per
+	 * session. Default is false (not allow multiple {@link OAuth2AuthorizationRequest}
+	 * per session).
+	 * @param allowMultipleAuthorizationRequests true allows more than one
+	 * {@link OAuth2AuthorizationRequest} to be stored per session.
+	 * @since 5.5
+	 */
+	@Deprecated
+	public void setAllowMultipleAuthorizationRequests(boolean allowMultipleAuthorizationRequests) {
+		this.allowMultipleAuthorizationRequests = allowMultipleAuthorizationRequests;
 	}
 
 }

+ 252 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests.java

@@ -0,0 +1,252 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.server;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Before;
+import org.junit.Test;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+import org.springframework.http.codec.ServerCodecConfigurer;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebSession;
+import org.springframework.web.server.adapter.DefaultServerWebExchange;
+import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
+import org.springframework.web.server.session.WebSessionManager;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Tests for {@link WebSessionOAuth2ServerAuthorizationRequestRepository} when
+ * {@link WebSessionOAuth2ServerAuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)}
+ * is enabled.
+ *
+ * @author Steve Riesenberg
+ */
+
+public class WebSessionOAuth2ServerAuthorizationRequestRepositoryAllowMultipleAuthorizationRequestsTests
+		extends WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
+
+	@Before
+	public void setup() {
+		this.repository = new WebSessionOAuth2ServerAuthorizationRequestRepository();
+		this.repository.setAllowMultipleAuthorizationRequests(true);
+	}
+
+	@Test
+	public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
+		String oldState = "state0";
+		// @formatter:off
+		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, oldState)
+				.build();
+		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(oldState)
+				.build();
+		// @formatter:on
+		WebSessionManager sessionManager = (e) -> this.exchange.getSession();
+		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		// @formatter:off
+		Mono<OAuth2AuthorizationRequest> saveAndSaveAndLoad = this.repository
+				.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
+				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.then(this.repository.loadAuthorizationRequest(oldExchange));
+		StepVerifier.create(saveAndSaveAndLoad)
+				.expectNext(oldAuthorizationRequest)
+				.verifyComplete();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
+				.expectNext(this.authorizationRequest)
+				.verifyComplete();
+		// @formatter:on
+	}
+
+	// gh-5145
+	@Test
+	public void loadAuthorizationRequestWhenSavedWithAllowMultipleAuthorizationRequestsThenReturnOldAuthorizationRequest() {
+		// save 2 requests with legacy (allowMultipleAuthorizationRequests=true) and load
+		// with new
+		WebSessionOAuth2ServerAuthorizationRequestRepository legacy = new WebSessionOAuth2ServerAuthorizationRequestRepository();
+		legacy.setAllowMultipleAuthorizationRequests(true);
+		// @formatter:off
+		String state1 = "state-1122";
+		OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state1)
+				.build();
+		StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest1, this.exchange))
+				.verifyComplete();
+		String state2 = "state-3344";
+		OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state2)
+				.build();
+		StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest2, this.exchange))
+				.verifyComplete();
+		ServerHttpRequest newRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, state1)
+				.build();
+		ServerWebExchange newExchange = this.exchange.mutate()
+				.request(newRequest)
+				.build();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange))
+				.expectNext(authorizationRequest1)
+				.verifyComplete();
+		// @formatter:on
+	}
+
+	// gh-5145
+	@Test
+	public void saveAuthorizationRequestWhenSavedWithAllowMultipleAuthorizationRequestsThenLoadNewAuthorizationRequest() {
+		// save 2 requests with legacy (allowMultipleAuthorizationRequests=true), save
+		// with new, and load with new
+		WebSessionOAuth2ServerAuthorizationRequestRepository legacy = new WebSessionOAuth2ServerAuthorizationRequestRepository();
+		legacy.setAllowMultipleAuthorizationRequests(true);
+		// @formatter:off
+		String state1 = "state-1122";
+		OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state1)
+				.build();
+		StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest1, this.exchange))
+				.verifyComplete();
+		String state2 = "state-3344";
+		OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state2)
+				.build();
+		StepVerifier.create(legacy.saveAuthorizationRequest(authorizationRequest2, this.exchange))
+				.verifyComplete();
+		String state3 = "state-5566";
+		OAuth2AuthorizationRequest authorizationRequest3 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state3)
+				.build();
+		ServerHttpRequest newRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, state3)
+				.build();
+		ServerWebExchange newExchange = this.exchange.mutate()
+				.request(newRequest)
+				.build();
+		Mono<OAuth2AuthorizationRequest> saveAndLoad = this.repository
+				.saveAuthorizationRequest(authorizationRequest3, this.exchange)
+				.then(this.repository.loadAuthorizationRequest(newExchange));
+		StepVerifier.create(saveAndLoad)
+				.expectNext(authorizationRequest3)
+				.verifyComplete();
+		// @formatter:on
+	}
+
+	@Test
+	public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() {
+		String oldState = "state0";
+		// @formatter:off
+		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, oldState)
+				.build();
+		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(oldState)
+				.build();
+		// @formatter:on
+		WebSessionManager sessionManager = (e) -> this.exchange.getSession();
+		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		// @formatter:off
+		Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
+				.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
+				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.then(this.repository.removeAuthorizationRequest(this.exchange));
+		StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
+				.verifyComplete();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
+				.verifyComplete();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange))
+				.expectNext(oldAuthorizationRequest)
+				.verifyComplete();
+		// @formatter:on
+	}
+
+	// gh-7327
+	@Test
+	public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() {
+		String oldState = "state0";
+		// @formatter:off
+		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, oldState)
+				.build();
+		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(oldState)
+				.build();
+		// @formatter:on
+		Map<String, Object> sessionAttrs = spy(new HashMap<>());
+		WebSession session = mock(WebSession.class);
+		given(session.getAttributes()).willReturn(sessionAttrs);
+		WebSessionManager sessionManager = (e) -> Mono.just(session);
+		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		// @formatter:off
+		Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
+				.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
+				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.then(this.repository.removeAuthorizationRequest(this.exchange));
+		StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
+				.verifyComplete();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
+				.verifyComplete();
+		// @formatter:on
+		verify(sessionAttrs, times(3)).put(any(), any());
+	}
+
+}

+ 159 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests.java

@@ -0,0 +1,159 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.server;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Before;
+import org.junit.Test;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+import org.springframework.http.codec.ServerCodecConfigurer;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebSession;
+import org.springframework.web.server.adapter.DefaultServerWebExchange;
+import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
+import org.springframework.web.server.session.WebSessionManager;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Tests for {@link WebSessionOAuth2ServerAuthorizationRequestRepository} when
+ * {@link WebSessionOAuth2ServerAuthorizationRequestRepository#setAllowMultipleAuthorizationRequests(boolean)}
+ * is disabled.
+ *
+ * @author Steve Riesenberg
+ */
+public class WebSessionOAuth2ServerAuthorizationRequestRepositoryDoNotAllowMultipleAuthorizationRequestsTests
+		extends WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
+
+	@Before
+	public void setup() {
+		this.repository = new WebSessionOAuth2ServerAuthorizationRequestRepository();
+		this.repository.setAllowMultipleAuthorizationRequests(false);
+	}
+
+	// gh-5145
+	@Test
+	public void loadAuthorizationRequestWhenMultipleSavedThenReturnLastAuthorizationRequest() {
+		// @formatter:off
+		String state1 = "state-1122";
+		OAuth2AuthorizationRequest authorizationRequest1 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state1)
+				.build();
+		StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest1, this.exchange))
+				.verifyComplete();
+		String state2 = "state-3344";
+		OAuth2AuthorizationRequest authorizationRequest2 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state2)
+				.build();
+		StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest2, this.exchange))
+				.verifyComplete();
+		String state3 = "state-5566";
+		OAuth2AuthorizationRequest authorizationRequest3 = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(state3)
+				.build();
+		StepVerifier.create(this.repository.saveAuthorizationRequest(authorizationRequest3, this.exchange))
+				.verifyComplete();
+		ServerHttpRequest newRequest1 = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, state1)
+				.build();
+		ServerWebExchange newExchange1 = this.exchange.mutate()
+				.request(newRequest1)
+				.build();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange1))
+				.verifyComplete();
+		ServerHttpRequest newRequest2 = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, state2)
+				.build();
+		ServerWebExchange newExchange2 = this.exchange.mutate()
+				.request(newRequest2)
+				.build();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange2))
+				.verifyComplete();
+		ServerHttpRequest newRequest3 = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, state3)
+				.build();
+		ServerWebExchange newExchange3 = this.exchange.mutate()
+				.request(newRequest3)
+				.build();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(newExchange3))
+				.expectNext(authorizationRequest3)
+				.verifyComplete();
+		// @formatter:on
+	}
+
+	// gh-5145
+	@Test
+	public void removeAuthorizationRequestWhenMultipleThenSessionAttributeRemoved() {
+		String oldState = "state0";
+		// @formatter:off
+		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, oldState)
+				.build();
+		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(oldState)
+				.build();
+		// @formatter:on
+		Map<String, Object> sessionAttrs = spy(new HashMap<>());
+		WebSession session = mock(WebSession.class);
+		given(session.getAttributes()).willReturn(sessionAttrs);
+		WebSessionManager sessionManager = (e) -> Mono.just(session);
+		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
+				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		// @formatter:off
+		Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
+				.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
+				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.then(this.repository.removeAuthorizationRequest(this.exchange));
+		StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
+				.verifyComplete();
+		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
+				.verifyComplete();
+		// @formatter:on
+		verify(sessionAttrs, times(2)).put(anyString(), any());
+		verify(sessionAttrs).remove(anyString());
+	}
+
+}

+ 5 - 120
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,43 +16,31 @@
 
 package org.springframework.security.oauth2.client.web.server;
 
-import java.util.HashMap;
 import java.util.Map;
 
 import org.junit.Test;
 import reactor.core.publisher.Mono;
 import reactor.test.StepVerifier;
 
-import org.springframework.http.codec.ServerCodecConfigurer;
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
-import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
 import org.springframework.mock.web.server.MockServerWebExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebSession;
-import org.springframework.web.server.adapter.DefaultServerWebExchange;
-import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
-import org.springframework.web.server.session.WebSessionManager;
 
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.BDDMockito.given;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
 
 /**
  * @author Rob Winch
  * @since 5.1
  */
-public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
+public abstract class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
 
-	private WebSessionOAuth2ServerAuthorizationRequestRepository repository = new WebSessionOAuth2ServerAuthorizationRequestRepository();
+	protected WebSessionOAuth2ServerAuthorizationRequestRepository repository;
 
 	// @formatter:off
-	private OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+	protected OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
 			.authorizationUri("https://example.com/oauth2/authorize")
 			.clientId("client-id")
 			.redirectUri("http://localhost/client-1")
@@ -60,7 +48,7 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
 			.build();
 	// @formatter:on
 
-	private ServerWebExchange exchange = MockServerWebExchange
+	protected ServerWebExchange exchange = MockServerWebExchange
 			.from(MockServerHttpRequest.get("/").queryParam(OAuth2ParameterNames.STATE, "state"));
 
 	@Test
@@ -114,39 +102,6 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
 		// @formatter:on
 	}
 
-	@Test
-	public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
-		String oldState = "state0";
-		// @formatter:off
-		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
-				.queryParam(OAuth2ParameterNames.STATE, oldState)
-				.build();
-		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
-				.authorizationUri("https://example.com/oauth2/authorize")
-				.clientId("client-id")
-				.redirectUri("http://localhost/client-1")
-				.state(oldState)
-				.build();
-		// @formatter:on
-		WebSessionManager sessionManager = (e) -> this.exchange.getSession();
-		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
-				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
-				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-		// @formatter:off
-		Mono<OAuth2AuthorizationRequest> saveAndSaveAndLoad = this.repository
-				.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
-				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
-				.then(this.repository.loadAuthorizationRequest(oldExchange));
-		StepVerifier.create(saveAndSaveAndLoad)
-				.expectNext(oldAuthorizationRequest)
-				.verifyComplete();
-		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
-				.expectNext(this.authorizationRequest)
-				.verifyComplete();
-		// @formatter:on
-	}
-
 	@Test
 	public void saveAuthorizationRequestWhenAuthorizationRequestNullThenThrowsIllegalArgumentException() {
 		this.authorizationRequest = null;
@@ -211,76 +166,6 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
 		// @formatter:on
 	}
 
-	@Test
-	public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() {
-		String oldState = "state0";
-		// @formatter:off
-		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
-				.queryParam(OAuth2ParameterNames.STATE, oldState)
-				.build();
-		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
-				.authorizationUri("https://example.com/oauth2/authorize")
-				.clientId("client-id")
-				.redirectUri("http://localhost/client-1")
-				.state(oldState)
-				.build();
-		// @formatter:on
-		WebSessionManager sessionManager = (e) -> this.exchange.getSession();
-		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
-				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
-				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-		// @formatter:off
-		Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
-				.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
-				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
-				.then(this.repository.removeAuthorizationRequest(this.exchange));
-		StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
-				.verifyComplete();
-		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
-				.verifyComplete();
-		StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange))
-				.expectNext(oldAuthorizationRequest)
-				.verifyComplete();
-		// @formatter:on
-	}
-
-	// gh-7327
-	@Test
-	public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() {
-		String oldState = "state0";
-		// @formatter:off
-		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
-				.queryParam(OAuth2ParameterNames.STATE, oldState)
-				.build();
-		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
-				.authorizationUri("https://example.com/oauth2/authorize")
-				.clientId("client-id")
-				.redirectUri("http://localhost/client-1")
-				.state(oldState)
-				.build();
-		// @formatter:on
-		Map<String, Object> sessionAttrs = spy(new HashMap<>());
-		WebSession session = mock(WebSession.class);
-		given(session.getAttributes()).willReturn(sessionAttrs);
-		WebSessionManager sessionManager = (e) -> Mono.just(session);
-		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
-				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
-				sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
-		// @formatter:off
-		Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
-				.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
-				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
-				.then(this.repository.removeAuthorizationRequest(this.exchange));
-		StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
-				.verifyComplete();
-		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
-				.verifyComplete();
-		// @formatter:on
-		verify(sessionAttrs, times(3)).put(any(), any());
-	}
-
 	private void assertSessionStartedIs(boolean expected) {
 		// @formatter:off
 		Mono<Boolean> isStarted = this.exchange.getSession()