浏览代码

Merge branch '6.3.x'

Closes gh-16139
Steve Riesenberg 9 月之前
父节点
当前提交
77233daae7

+ 8 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -80,7 +80,13 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author
 		if (registration == null) {
 			return null;
 		}
-		return (T) this.authorizedClients.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
+		OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients
+			.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
+		if (cachedAuthorizedClient == null) {
+			return null;
+		}
+		return (T) new OAuth2AuthorizedClient(registration, cachedAuthorizedClient.getPrincipalName(),
+				cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken());
 	}
 
 	@Override

+ 10 - 3
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -62,8 +62,15 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
 		Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
 		Assert.hasText(principalName, "principalName cannot be empty");
 		return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
-			.map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName))
-			.flatMap((identifier) -> Mono.justOrEmpty(this.authorizedClients.get(identifier)));
+			.mapNotNull((clientRegistration) -> {
+				OAuth2AuthorizedClientId id = new OAuth2AuthorizedClientId(clientRegistrationId, principalName);
+				OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients.get(id);
+				if (cachedAuthorizedClient == null) {
+					return null;
+				}
+				return new OAuth2AuthorizedClient(clientRegistration, cachedAuthorizedClient.getPrincipalName(),
+						cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken());
+			});
 	}
 
 	@Override

+ 72 - 5
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,22 +18,25 @@ package org.springframework.security.oauth2.client;
 
 import java.util.Collections;
 import java.util.Map;
+import java.util.function.Consumer;
 
 import org.junit.jupiter.api.Test;
 
+import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.assertj.core.api.Assertions.assertThatObject;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
-import static org.mockito.Mockito.mock;
+import static org.mockito.BDDMockito.mock;
 
 /**
  * Tests for {@link InMemoryOAuth2AuthorizedClientService}.
@@ -79,9 +82,11 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
 	@Test
 	public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedClients() {
 		String registrationId = this.registration3.getRegistrationId();
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration3, this.principalName1,
+				mock(OAuth2AccessToken.class));
 		Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap(
 				new OAuth2AuthorizedClientId(this.registration3.getRegistrationId(), this.principalName1),
-				mock(OAuth2AuthorizedClient.class));
+				authorizedClient);
 		ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
 		given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3);
 		InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
@@ -124,7 +129,35 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
 		this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
 		OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
 			.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
-		assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
+		assertThat(loadedAuthorizedClient).satisfies(isEqualTo(authorizedClient));
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnAuthorizedClientWithUpdatedClientRegistration() {
+		ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.registration1)
+			.clientSecret("updated secret")
+			.build();
+
+		ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
+		given(clientRegistrationRepository.findByRegistrationId(this.registration1.getRegistrationId()))
+			.willReturn(this.registration1, updatedRegistration);
+
+		InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
+				clientRegistrationRepository);
+
+		OAuth2AuthorizedClient cachedAuthorizedClient = new OAuth2AuthorizedClient(this.registration1,
+				this.principalName1, mock(OAuth2AccessToken.class), mock(OAuth2RefreshToken.class));
+		authorizedClientService.saveAuthorizedClient(cachedAuthorizedClient,
+				new TestingAuthenticationToken(this.principalName1, null));
+
+		OAuth2AuthorizedClient authorizedClientWithUpdatedRegistration = new OAuth2AuthorizedClient(updatedRegistration,
+				this.principalName1, mock(OAuth2AccessToken.class), mock(OAuth2RefreshToken.class));
+		OAuth2AuthorizedClient firstLoadedClient = authorizedClientService
+			.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
+		OAuth2AuthorizedClient secondLoadedClient = authorizedClientService
+			.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
+		assertThat(firstLoadedClient).satisfies(isEqualTo(cachedAuthorizedClient));
+		assertThat(secondLoadedClient).satisfies(isEqualTo(authorizedClientWithUpdatedRegistration));
 	}
 
 	@Test
@@ -148,7 +181,7 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
 		this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
 		OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
 			.loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2);
-		assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
+		assertThat(loadedAuthorizedClient).satisfies(isEqualTo(authorizedClient));
 	}
 
 	@Test
@@ -180,4 +213,38 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
 		assertThat(loadedAuthorizedClient).isNull();
 	}
 
+	private static Consumer<OAuth2AuthorizedClient> isEqualTo(OAuth2AuthorizedClient expected) {
+		return (actual) -> {
+			assertThat(actual).isNotNull();
+			assertThat(actual.getClientRegistration().getRegistrationId())
+				.isEqualTo(expected.getClientRegistration().getRegistrationId());
+			assertThat(actual.getClientRegistration().getClientName())
+				.isEqualTo(expected.getClientRegistration().getClientName());
+			assertThat(actual.getClientRegistration().getRedirectUri())
+				.isEqualTo(expected.getClientRegistration().getRedirectUri());
+			assertThat(actual.getClientRegistration().getAuthorizationGrantType())
+				.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
+			assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
+				.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
+			assertThat(actual.getClientRegistration().getClientId())
+				.isEqualTo(expected.getClientRegistration().getClientId());
+			assertThat(actual.getClientRegistration().getClientSecret())
+				.isEqualTo(expected.getClientRegistration().getClientSecret());
+			assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
+			assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
+			assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
+			assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
+			assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
+			assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
+			if (expected.getRefreshToken() != null) {
+				assertThat(actual.getRefreshToken()).isNotNull();
+				assertThat(actual.getRefreshToken().getTokenValue())
+					.isEqualTo(expected.getRefreshToken().getTokenValue());
+				assertThat(actual.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt());
+				assertThat(actual.getRefreshToken().getExpiresAt())
+					.isEqualTo(expected.getRefreshToken().getExpiresAt());
+			}
+		};
+	}
+
 }

+ 74 - 4
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,12 +18,14 @@ package org.springframework.security.oauth2.client;
 
 import java.time.Duration;
 import java.time.Instant;
+import java.util.function.Consumer;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
+import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 import reactor.test.StepVerifier;
 
@@ -34,7 +36,9 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.BDDMockito.given;
 
@@ -56,8 +60,9 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
 
 	private Authentication principal = new TestingAuthenticationToken(this.principalName, "notused");
 
-	OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", Instant.now(),
-			Instant.now().plus(Duration.ofDays(1)));
+	private OAuth2AccessToken accessToken;
+
+	private OAuth2RefreshToken refreshToken;
 
 	// @formatter:off
 	private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId)
@@ -79,6 +84,11 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
 	public void setup() {
 		this.authorizedClientService = new InMemoryReactiveOAuth2AuthorizedClientService(
 				this.clientRegistrationRepository);
+
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plus(Duration.ofDays(1));
+		this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", issuedAt, expiresAt);
+		this.refreshToken = new OAuth2RefreshToken("refresh", issuedAt, expiresAt);
 	}
 
 	@Test
@@ -153,11 +163,37 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
 				.saveAuthorizedClient(authorizedClient, this.principal)
 				.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
 		StepVerifier.create(saveAndLoad)
-				.expectNext(authorizedClient)
+				.assertNext(isEqualTo(authorizedClient))
 				.verifyComplete();
 		// @formatter:on
 	}
 
+	@Test
+	@SuppressWarnings("unchecked")
+	public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnsAuthorizedClientWithUpdatedClientRegistration() {
+		ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.clientRegistration)
+			.clientSecret("updated secret")
+			.build();
+
+		given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId))
+			.willReturn(Mono.just(this.clientRegistration), Mono.just(updatedRegistration));
+
+		OAuth2AuthorizedClient cachedAuthorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
+				this.principalName, this.accessToken, this.refreshToken);
+		OAuth2AuthorizedClient authorizedClientWithChangedRegistration = new OAuth2AuthorizedClient(updatedRegistration,
+				this.principalName, this.accessToken, this.refreshToken);
+
+		Flux<OAuth2AuthorizedClient> saveAndLoadTwice = this.authorizedClientService
+			.saveAuthorizedClient(cachedAuthorizedClient, this.principal)
+			.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName))
+			.concatWith(
+					this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
+		StepVerifier.create(saveAndLoadTwice)
+			.assertNext(isEqualTo(cachedAuthorizedClient))
+			.assertNext(isEqualTo(authorizedClientWithChangedRegistration))
+			.verifyComplete();
+	}
+
 	@Test
 	public void saveAuthorizedClientWhenAuthorizedClientNullThenIllegalArgumentException() {
 		OAuth2AuthorizedClient authorizedClient = null;
@@ -246,4 +282,38 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
 		// @formatter:on
 	}
 
+	private static Consumer<OAuth2AuthorizedClient> isEqualTo(OAuth2AuthorizedClient expected) {
+		return (actual) -> {
+			assertThat(actual).isNotNull();
+			assertThat(actual.getClientRegistration().getRegistrationId())
+				.isEqualTo(expected.getClientRegistration().getRegistrationId());
+			assertThat(actual.getClientRegistration().getClientName())
+				.isEqualTo(expected.getClientRegistration().getClientName());
+			assertThat(actual.getClientRegistration().getRedirectUri())
+				.isEqualTo(expected.getClientRegistration().getRedirectUri());
+			assertThat(actual.getClientRegistration().getAuthorizationGrantType())
+				.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
+			assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
+				.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
+			assertThat(actual.getClientRegistration().getClientId())
+				.isEqualTo(expected.getClientRegistration().getClientId());
+			assertThat(actual.getClientRegistration().getClientSecret())
+				.isEqualTo(expected.getClientRegistration().getClientSecret());
+			assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
+			assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
+			assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
+			assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
+			assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
+			assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
+			if (expected.getRefreshToken() != null) {
+				assertThat(actual.getRefreshToken()).isNotNull();
+				assertThat(actual.getRefreshToken().getTokenValue())
+					.isEqualTo(expected.getRefreshToken().getTokenValue());
+				assertThat(actual.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt());
+				assertThat(actual.getRefreshToken().getExpiresAt())
+					.isEqualTo(expected.getRefreshToken().getExpiresAt());
+			}
+		};
+	}
+
 }