浏览代码

Merge branch '5.8.x' into 6.2.x

Steve Riesenberg 1 年之前
父节点
当前提交
5a1d261ce0

+ 4 - 8
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -94,14 +94,10 @@ public final class WebSessionServerOAuth2AuthorizedClientRepository implements S
 		// @formatter:on
 	}
 
-	@SuppressWarnings("unchecked")
 	private Map<String, OAuth2AuthorizedClient> getAuthorizedClients(WebSession session) {
-		Map<String, OAuth2AuthorizedClient> authorizedClients = (session != null)
-				? (Map<String, OAuth2AuthorizedClient>) session.getAttribute(this.sessionAttributeName) : null;
-		if (authorizedClients == null) {
-			authorizedClients = new HashMap<>();
-		}
-		return authorizedClients;
+		Assert.notNull(session, "session cannot be null");
+		Map<String, OAuth2AuthorizedClient> authorizedClients = session.getAttribute(this.sessionAttributeName);
+		return (authorizedClients != null) ? authorizedClients : new HashMap<>();
 	}
 
 }

+ 28 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
 package org.springframework.security.oauth2.client.web.server;
 
 import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Mono;
 
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
@@ -24,10 +25,12 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebSession;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 
 /**
@@ -202,4 +205,28 @@ public class WebSessionServerOAuth2AuthorizedClientRepositoryTests {
 		assertThat(loadedAuthorizedClient2).isSameAs(authorizedClient2);
 	}
 
+	@Test
+	public void saveAuthorizedClientWhenSessionIsNullThenThrowIllegalArgumentException() {
+		ServerWebExchange exchange = mock(ServerWebExchange.class);
+		given(exchange.getSession()).willReturn(Mono.empty());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1,
+				mock(OAuth2AccessToken.class));
+		// @formatter:off
+		assertThatIllegalArgumentException()
+			.isThrownBy(() -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, exchange).block())
+			.withMessage("session cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenSessionIsNullThenThrowIllegalArgumentException() {
+		ServerWebExchange exchange = mock(ServerWebExchange.class);
+		given(exchange.getSession()).willReturn(Mono.empty());
+		// @formatter:off
+		assertThatIllegalArgumentException()
+			.isThrownBy(() -> this.authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, exchange).block())
+			.withMessage("session cannot be null");
+		// @formatter:on
+	}
+
 }