Pārlūkot izejas kodu

Assert WebSession is not null

Issue gh-14975
JANG 1 gadu atpakaļ
vecāks
revīzija
1695d03b72

+ 4 - 7
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepository.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -96,12 +96,9 @@ public final class WebSessionServerOAuth2AuthorizedClientRepository implements S
 
 	@SuppressWarnings("unchecked")
 	private Map<String, OAuth2AuthorizedClient> getAuthorizedClients(WebSession session) {
-		Map<String, OAuth2AuthorizedClient> authorizedClients = (session != null)
-				? (Map<String, OAuth2AuthorizedClient>) session.getAttribute(this.sessionAttributeName) : null;
-		if (authorizedClients == null) {
-			authorizedClients = new HashMap<>();
-		}
-		return authorizedClients;
+		Assert.notNull(session, "session cannot be null");
+		Map<String, OAuth2AuthorizedClient> authorizedClients = session.getAttribute(this.sessionAttributeName);
+		return (authorizedClients != null) ? authorizedClients : new HashMap<>();
 	}
 
 }

+ 23 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/WebSessionServerOAuth2AuthorizedClientRepositoryTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,10 +25,12 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.web.server.WebSession;
+import reactor.core.publisher.Mono;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 /**
  * @author Rob Winch
@@ -201,5 +203,25 @@ public class WebSessionServerOAuth2AuthorizedClientRepositoryTests {
 		assertThat(loadedAuthorizedClient2).isNotNull();
 		assertThat(loadedAuthorizedClient2).isSameAs(authorizedClient2);
 	}
+	
+	@Test
+	public void saveAuthorizedClientWhenSessionIsNullThenThrowIllegalArgumentException() {
+		MockServerWebExchange mockedExchange = mock(MockServerWebExchange.class);
+		when(mockedExchange.getSession()).thenReturn(Mono.empty());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1,
+				mock(OAuth2AccessToken.class));
+		assertThatIllegalArgumentException().isThrownBy(
+				() -> authorizedClientRepository.saveAuthorizedClient(authorizedClient, null, mockedExchange).block())
+				.withMessage("session cannot be null");
+	}
+	
+	@Test
+	public void removeAuthorizedClientWhenSessionIsNullThenThrowIllegalArgumentException() {
+		MockServerWebExchange mockedExchange = mock(MockServerWebExchange.class);
+		when(mockedExchange.getSession()).thenReturn(Mono.empty());
+		assertThatIllegalArgumentException().isThrownBy(
+				() -> authorizedClientRepository.removeAuthorizedClient(this.registrationId1, null, mockedExchange).block())
+				.withMessage("session cannot be null");
+	}
 
 }