Przeglądaj źródła

Fix OAuth2 Client with Ditributed Session

Fixes: gh-6215
Zhanwei Wang 6 lat temu
rodzic
commit
a60fd43534

+ 20 - 9
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepository.java

@@ -53,7 +53,7 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 		if (state == null) {
 			return Mono.empty();
 		}
-		return getStateToAuthorizationRequest(exchange, false)
+		return getStateToAuthorizationRequest(exchange)
 				.filter(stateToAuthorizationRequest -> stateToAuthorizationRequest.containsKey(state))
 				.map(stateToAuthorizationRequest -> stateToAuthorizationRequest.get(state));
 	}
@@ -62,9 +62,8 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 	public Mono<Void> saveAuthorizationRequest(
 			OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) {
 		Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
-		return getStateToAuthorizationRequest(exchange, true)
-				.doOnNext(stateToAuthorizationRequest -> stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest))
-				.then();
+		return saveStateToAuthorizationRequest(exchange).doOnNext(stateToAuthorizationRequest ->
+				stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest)).then();
 	}
 
 	@Override
@@ -108,16 +107,28 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
 		return exchange.getSession().map(WebSession::getAttributes);
 	}
 
-	private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange, boolean create) {
+	private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange) {
+		Assert.notNull(exchange, "exchange cannot be null");
+
+		return getSessionAttributes(exchange)
+			.flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
+	}
+
+	private Mono<Map<String, OAuth2AuthorizationRequest>> saveStateToAuthorizationRequest(ServerWebExchange exchange) {
 		Assert.notNull(exchange, "exchange cannot be null");
 
 		return getSessionAttributes(exchange)
 			.doOnNext(sessionAttrs -> {
-				if (create) {
-					sessionAttrs.putIfAbsent(this.sessionAttributeName, new HashMap<String, OAuth2AuthorizationRequest>());
+				Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName);
+
+				if (stateToAuthzRequest == null) {
+					stateToAuthzRequest = new HashMap<String, OAuth2AuthorizationRequest>();
 				}
-			})
-			.flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
+
+				// No matter stateToAuthzRequest was in session or not, we should always put it into session again
+				// in case of redis or hazelcast session. #6215
+				sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest);
+			}).flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
 	}
 
 	private Map<String, OAuth2AuthorizationRequest> sessionAttrsMapStateToAuthorizationRequest(Map<String, Object> sessionAttrs) {

+ 37 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionOAuth2ServerAuthorizationRequestRepositoryTests.java

@@ -18,6 +18,13 @@ package org.springframework.security.oauth2.client.web.server;
 
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import java.util.HashMap;
 import java.util.Map;
 
 import org.junit.Test;
@@ -99,6 +106,36 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
 				.verifyComplete();
 	}
 
+	@Test
+	public void multipleSavedAuthorizationRequestAndRedisCookie() {
+		String oldState = "state0";
+		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, oldState).build();
+
+		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(oldState)
+				.build();
+
+		Map<String, Object> sessionAttrs = spy(new HashMap<>());
+		WebSession session = mock(WebSession.class);
+		when(session.getAttributes()).thenReturn(sessionAttrs);
+		WebSessionManager sessionManager = e -> Mono.just(session);
+
+		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
+				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
+				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+
+		Mono<Void> saveAndSave = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
+				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange));
+
+		StepVerifier.create(saveAndSave).verifyComplete();
+		verify(sessionAttrs, times(2)).put(any(), any());
+	}
+
 	@Test
 	public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
 		String oldState = "state0";