Bläddra i källkod

Polish gh-13976

Closes gh-13757
Steve Riesenberg 1 år sedan
förälder
incheckning
5161712c35

+ 8 - 10
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java

@@ -44,22 +44,23 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
 	private final AtomicReference<Mono<JWKSet>> cachedJWKSet = new AtomicReference<>(Mono.empty());
 
 	/**
-	 * cached url for jwk set.
+	 * The cached JWK set URL.
 	 */
 	private final AtomicReference<String> cachedJwkSetUrl = new AtomicReference<>();
 
 	private WebClient webClient = WebClient.create();
 
-	private Mono<String> jwkSetURLProvider;
+	private final Mono<String> jwkSetUrlProvider;
 
 	ReactiveRemoteJWKSource(String jwkSetURL) {
 		Assert.hasText(jwkSetURL, "jwkSetURL cannot be empty");
-		this.cachedJwkSetUrl.set(jwkSetURL);
+		this.jwkSetUrlProvider = Mono.just(jwkSetURL);
 	}
 
-	ReactiveRemoteJWKSource(Mono<String> jwkSetURLProvider) {
-		Assert.notNull(jwkSetURLProvider, "jwkSetURLProvider cannot be null");
-		this.jwkSetURLProvider = jwkSetURLProvider;
+	ReactiveRemoteJWKSource(Mono<String> jwkSetUrlProvider) {
+		Assert.notNull(jwkSetUrlProvider, "jwkSetUrlProvider cannot be null");
+		this.jwkSetUrlProvider = Mono.fromCallable(this.cachedJwkSetUrl::get)
+			.switchIfEmpty(Mono.defer(() -> jwkSetUrlProvider.doOnNext(this.cachedJwkSetUrl::set)));
 	}
 
 	@Override
@@ -105,10 +106,7 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
 	 */
 	private Mono<JWKSet> getJWKSet() {
 		// @formatter:off
-		return Mono.justOrEmpty(this.cachedJwkSetUrl.get())
-				.switchIfEmpty(Mono.defer(() -> this.jwkSetURLProvider
-					.doOnNext(this.cachedJwkSetUrl::set))
-				)
+		return this.jwkSetUrlProvider
 				.flatMap((jwkSetURL) -> this.webClient.get()
 					.uri(jwkSetURL)
 					.retrieve()

+ 12 - 11
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,15 +32,16 @@ import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
-import org.springframework.web.reactive.function.client.WebClientResponseException;
 import reactor.core.publisher.Mono;
 
+import org.springframework.web.reactive.function.client.WebClientResponseException;
+
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
-import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.doThrow;
+import static org.mockito.BDDMockito.willReturn;
+import static org.mockito.BDDMockito.willThrow;
 
 /**
  * @author Rob Winch
@@ -168,14 +169,14 @@ public class ReactiveRemoteJWKSourceTests {
 	@Test
 	public void getShouldRecoverAndReturnKeysAfterErrorCase() {
 		given(this.matcher.matches(any())).willReturn(true);
-		this.source = new ReactiveRemoteJWKSource(Mono.fromSupplier(mockStringSupplier));
-		doThrow(WebClientResponseException.ServiceUnavailable.class).when(this.mockStringSupplier).get();
+		this.source = new ReactiveRemoteJWKSource(Mono.fromSupplier(this.mockStringSupplier));
+		willThrow(WebClientResponseException.ServiceUnavailable.class).given(this.mockStringSupplier).get();
 		// first case: id provider has error state
-		assertThatThrownBy(() -> this.source.get(this.selector).block())
-			.isExactlyInstanceOf(WebClientResponseException.ServiceUnavailable.class);
+		assertThatExceptionOfType(WebClientResponseException.ServiceUnavailable.class)
+			.isThrownBy(() -> this.source.get(this.selector).block());
 		// second case: id provider is healthy again
-		doReturn(this.server.url("/").toString()).when(this.mockStringSupplier).get();
-		var actual = this.source.get(this.selector).block();
+		willReturn(this.server.url("/").toString()).given(this.mockStringSupplier).get();
+		List<JWK> actual = this.source.get(this.selector).block();
 		assertThat(actual).isNotEmpty();
 	}