Steve Riesenberg 9 mesiacov pred
rodič
commit
0eb6acde96

+ 2 - 6
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java

@@ -68,12 +68,8 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
 				if (cachedAuthorizedClient == null) {
 					return null;
 				}
-			// @formatter:off
-				return new OAuth2AuthorizedClient(clientRegistration,
-					cachedAuthorizedClient.getPrincipalName(),
-					cachedAuthorizedClient.getAccessToken(),
-					cachedAuthorizedClient.getRefreshToken());
-			// @formatter:on
+				return new OAuth2AuthorizedClient(clientRegistration, cachedAuthorizedClient.getPrincipalName(),
+						cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken());
 			});
 	}
 

+ 53 - 41
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java

@@ -18,15 +18,18 @@ package org.springframework.security.oauth2.client;
 
 import java.util.Collections;
 import java.util.Map;
+import java.util.function.Consumer;
 
 import org.junit.jupiter.api.Test;
 
+import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -126,7 +129,7 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
 		this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
 		OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
 			.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
-		assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
+		assertThat(loadedAuthorizedClient).satisfies(isEqualTo(authorizedClient));
 	}
 
 	@Test
@@ -134,27 +137,27 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
 		ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.registration1)
 			.clientSecret("updated secret")
 			.build();
-		ClientRegistrationRepository repository = mock(ClientRegistrationRepository.class);
-		given(repository.findByRegistrationId(this.registration1.getRegistrationId())).willReturn(this.registration1,
-				updatedRegistration);
 
-		Authentication authentication = mock(Authentication.class);
-		given(authentication.getName()).willReturn(this.principalName1);
+		ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
+		given(clientRegistrationRepository.findByRegistrationId(this.registration1.getRegistrationId()))
+			.willReturn(this.registration1, updatedRegistration);
 
-		InMemoryOAuth2AuthorizedClientService service = new InMemoryOAuth2AuthorizedClientService(repository);
+		InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
+				clientRegistrationRepository);
 
-		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1,
-				mock(OAuth2AccessToken.class));
-		service.saveAuthorizedClient(authorizedClient, authentication);
+		OAuth2AuthorizedClient cachedAuthorizedClient = new OAuth2AuthorizedClient(this.registration1,
+				this.principalName1, mock(OAuth2AccessToken.class), mock(OAuth2RefreshToken.class));
+		authorizedClientService.saveAuthorizedClient(cachedAuthorizedClient,
+				new TestingAuthenticationToken(this.principalName1, null));
 
 		OAuth2AuthorizedClient authorizedClientWithUpdatedRegistration = new OAuth2AuthorizedClient(updatedRegistration,
-				this.principalName1, mock(OAuth2AccessToken.class));
-		OAuth2AuthorizedClient firstLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
-				this.principalName1);
-		OAuth2AuthorizedClient secondLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
-				this.principalName1);
-		assertAuthorizedClientEquals(authorizedClient, firstLoadedClient);
-		assertAuthorizedClientEquals(authorizedClientWithUpdatedRegistration, secondLoadedClient);
+				this.principalName1, mock(OAuth2AccessToken.class), mock(OAuth2RefreshToken.class));
+		OAuth2AuthorizedClient firstLoadedClient = authorizedClientService
+			.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
+		OAuth2AuthorizedClient secondLoadedClient = authorizedClientService
+			.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
+		assertThat(firstLoadedClient).satisfies(isEqualTo(cachedAuthorizedClient));
+		assertThat(secondLoadedClient).satisfies(isEqualTo(authorizedClientWithUpdatedRegistration));
 	}
 
 	@Test
@@ -178,7 +181,7 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
 		this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
 		OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
 			.loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2);
-		assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
+		assertThat(loadedAuthorizedClient).satisfies(isEqualTo(authorizedClient));
 	}
 
 	@Test
@@ -210,29 +213,38 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
 		assertThat(loadedAuthorizedClient).isNull();
 	}
 
-	private static void assertAuthorizedClientEquals(OAuth2AuthorizedClient expected, OAuth2AuthorizedClient actual) {
-		assertThat(actual).isNotNull();
-		assertThat(actual.getClientRegistration().getRegistrationId())
-			.isEqualTo(expected.getClientRegistration().getRegistrationId());
-		assertThat(actual.getClientRegistration().getClientName())
-			.isEqualTo(expected.getClientRegistration().getClientName());
-		assertThat(actual.getClientRegistration().getRedirectUri())
-			.isEqualTo(expected.getClientRegistration().getRedirectUri());
-		assertThat(actual.getClientRegistration().getAuthorizationGrantType())
-			.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
-		assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
-			.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
-		assertThat(actual.getClientRegistration().getClientId())
-			.isEqualTo(expected.getClientRegistration().getClientId());
-		assertThat(actual.getClientRegistration().getClientSecret())
-			.isEqualTo(expected.getClientRegistration().getClientSecret());
-		assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
-		assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
-		assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
-		assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
-		assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
-		assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
-		assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken());
+	private static Consumer<OAuth2AuthorizedClient> isEqualTo(OAuth2AuthorizedClient expected) {
+		return (actual) -> {
+			assertThat(actual).isNotNull();
+			assertThat(actual.getClientRegistration().getRegistrationId())
+				.isEqualTo(expected.getClientRegistration().getRegistrationId());
+			assertThat(actual.getClientRegistration().getClientName())
+				.isEqualTo(expected.getClientRegistration().getClientName());
+			assertThat(actual.getClientRegistration().getRedirectUri())
+				.isEqualTo(expected.getClientRegistration().getRedirectUri());
+			assertThat(actual.getClientRegistration().getAuthorizationGrantType())
+				.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
+			assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
+				.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
+			assertThat(actual.getClientRegistration().getClientId())
+				.isEqualTo(expected.getClientRegistration().getClientId());
+			assertThat(actual.getClientRegistration().getClientSecret())
+				.isEqualTo(expected.getClientRegistration().getClientSecret());
+			assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
+			assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
+			assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
+			assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
+			assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
+			assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
+			if (expected.getRefreshToken() != null) {
+				assertThat(actual.getRefreshToken()).isNotNull();
+				assertThat(actual.getRefreshToken().getTokenValue())
+					.isEqualTo(expected.getRefreshToken().getTokenValue());
+				assertThat(actual.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt());
+				assertThat(actual.getRefreshToken().getExpiresAt())
+					.isEqualTo(expected.getRefreshToken().getExpiresAt());
+			}
+		};
 	}
 
 }

+ 27 - 13
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java

@@ -36,6 +36,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -59,8 +60,9 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
 
 	private Authentication principal = new TestingAuthenticationToken(this.principalName, "notused");
 
-	OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", Instant.now(),
-			Instant.now().plus(Duration.ofDays(1)));
+	private OAuth2AccessToken accessToken;
+
+	private OAuth2RefreshToken refreshToken;
 
 	// @formatter:off
 	private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId)
@@ -82,6 +84,11 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
 	public void setup() {
 		this.authorizedClientService = new InMemoryReactiveOAuth2AuthorizedClientService(
 				this.clientRegistrationRepository);
+
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plus(Duration.ofDays(1));
+		this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", issuedAt, expiresAt);
+		this.refreshToken = new OAuth2RefreshToken("refresh", issuedAt, expiresAt);
 	}
 
 	@Test
@@ -163,26 +170,26 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
 
 	@Test
 	@SuppressWarnings("unchecked")
-	public void loadAuthorizedClientWhenClientRegistrationChangedThenCurrentVersionFound() {
-		ClientRegistration changedClientRegistration = ClientRegistration
-			.withClientRegistration(this.clientRegistration)
+	public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnsAuthorizedClientWithUpdatedClientRegistration() {
+		ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.clientRegistration)
 			.clientSecret("updated secret")
 			.build();
 
 		given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId))
-			.willReturn(Mono.just(this.clientRegistration), Mono.just(changedClientRegistration));
-		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
-				this.principalName, this.accessToken);
-		OAuth2AuthorizedClient authorizedClientWithChangedRegistration = new OAuth2AuthorizedClient(
-				changedClientRegistration, this.principalName, this.accessToken);
+			.willReturn(Mono.just(this.clientRegistration), Mono.just(updatedRegistration));
+
+		OAuth2AuthorizedClient cachedAuthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
+				this.principalName, this.accessToken, this.refreshToken);
+		OAuth2AuthorizedClient authorizedClientWithChangedRegistration = new OAuth2AuthorizedClient(updatedRegistration,
+				this.principalName, this.accessToken, this.refreshToken);
 
 		Flux<OAuth2AuthorizedClient> saveAndLoadTwice = this.authorizedClientService
-			.saveAuthorizedClient(authorizedClient, this.principal)
+			.saveAuthorizedClient(cachedAuthorizedClient, this.principal)
 			.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName))
 			.concatWith(
 					this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
 		StepVerifier.create(saveAndLoadTwice)
-			.assertNext(isEqualTo(authorizedClient))
+			.assertNext(isEqualTo(cachedAuthorizedClient))
 			.assertNext(isEqualTo(authorizedClientWithChangedRegistration))
 			.verifyComplete();
 	}
@@ -298,7 +305,14 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
 			assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
 			assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
 			assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
-			assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken());
+			if (expected.getRefreshToken() != null) {
+				assertThat(actual.getRefreshToken()).isNotNull();
+				assertThat(actual.getRefreshToken().getTokenValue())
+					.isEqualTo(expected.getRefreshToken().getTokenValue());
+				assertThat(actual.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt());
+				assertThat(actual.getRefreshToken().getExpiresAt())
+					.isEqualTo(expected.getRefreshToken().getExpiresAt());
+			}
 		};
 	}