Explorar o código

Merge branch '5.8.x' into 6.0.x

Closes gh-14040
Steve Riesenberg hai 1 ano
pai
achega
bb732e9d35

+ 23 - 10
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -43,23 +43,34 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
 	 */
 	private final AtomicReference<Mono<JWKSet>> cachedJWKSet = new AtomicReference<>(Mono.empty());
 
+	/**
+	 * The cached JWK set URL.
+	 */
+	private final AtomicReference<String> cachedJwkSetUrl = new AtomicReference<>();
+
 	private WebClient webClient = WebClient.create();
 
-	private final String jwkSetURL;
+	private final Mono<String> jwkSetUrlProvider;
 
 	ReactiveRemoteJWKSource(String jwkSetURL) {
 		Assert.hasText(jwkSetURL, "jwkSetURL cannot be empty");
-		this.jwkSetURL = jwkSetURL;
+		this.jwkSetUrlProvider = Mono.just(jwkSetURL);
+	}
+
+	ReactiveRemoteJWKSource(Mono<String> jwkSetUrlProvider) {
+		Assert.notNull(jwkSetUrlProvider, "jwkSetUrlProvider cannot be null");
+		this.jwkSetUrlProvider = Mono.fromCallable(this.cachedJwkSetUrl::get)
+			.switchIfEmpty(Mono.defer(() -> jwkSetUrlProvider.doOnNext(this.cachedJwkSetUrl::set)));
 	}
 
 	@Override
 	public Mono<List<JWK>> get(JWKSelector jwkSelector) {
 		// @formatter:off
 		return this.cachedJWKSet.get()
-				.switchIfEmpty(Mono.defer(() -> getJWKSet()))
+				.switchIfEmpty(Mono.defer(this::getJWKSet))
 				.flatMap((jwkSet) -> get(jwkSelector, jwkSet))
 				.switchIfEmpty(Mono.defer(() -> getJWKSet()
-						.map((jwkSet) -> jwkSelector.select(jwkSet)))
+						.map(jwkSelector::select))
 				);
 		// @formatter:on
 	}
@@ -95,13 +106,15 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource {
 	 */
 	private Mono<JWKSet> getJWKSet() {
 		// @formatter:off
-		return this.webClient.get()
-				.uri(this.jwkSetURL)
-				.retrieve()
-				.bodyToMono(String.class)
+		return this.jwkSetUrlProvider
+				.flatMap((jwkSetURL) -> this.webClient.get()
+					.uri(jwkSetURL)
+					.retrieve()
+					.bodyToMono(String.class)
+				)
 				.map(this::parse)
 				.doOnNext((jwkSet) -> this.cachedJWKSet
-						.set(Mono.just(jwkSet))
+					.set(Mono.just(jwkSet))
 				)
 				.cache();
 		// @formatter:on

+ 25 - 1
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.jwt;
 
 import java.util.Collections;
 import java.util.List;
+import java.util.function.Supplier;
 
 import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.JWKMatcher;
@@ -31,10 +32,16 @@ import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
+import reactor.core.publisher.Mono;
+
+import org.springframework.web.reactive.function.client.WebClientResponseException;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
+import static org.mockito.BDDMockito.willReturn;
+import static org.mockito.BDDMockito.willThrow;
 
 /**
  * @author Rob Winch
@@ -52,6 +59,9 @@ public class ReactiveRemoteJWKSourceTests {
 
 	private MockWebServer server;
 
+	@Mock
+	private Supplier<String> mockStringSupplier;
+
 	// @formatter:off
 	private String keys = "{\n"
 			+ "    \"keys\": [\n"
@@ -156,4 +166,18 @@ public class ReactiveRemoteJWKSourceTests {
 		assertThat(this.source.get(this.selector).block()).isEmpty();
 	}
 
+	@Test
+	public void getShouldRecoverAndReturnKeysAfterErrorCase() {
+		given(this.matcher.matches(any())).willReturn(true);
+		this.source = new ReactiveRemoteJWKSource(Mono.fromSupplier(this.mockStringSupplier));
+		willThrow(WebClientResponseException.ServiceUnavailable.class).given(this.mockStringSupplier).get();
+		// first case: id provider has error state
+		assertThatExceptionOfType(WebClientResponseException.ServiceUnavailable.class)
+			.isThrownBy(() -> this.source.get(this.selector).block());
+		// second case: id provider is healthy again
+		willReturn(this.server.url("/").toString()).given(this.mockStringSupplier).get();
+		List<JWK> actual = this.source.get(this.selector).block();
+		assertThat(actual).isNotEmpty();
+	}
+
 }