Browse Source

Remove cache from (Reactive)OidcIdTokenDecoderFactory

Closes gh-16647

Signed-off-by: iigolovko <iigolovko@ginc-it.ru>
Ivan Golovko 7 months ago
parent
commit
979ac7c336

+ 8 - 13
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java

@@ -23,7 +23,6 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Function;
 
 import javax.crypto.spec.SecretKeySpec;
@@ -78,8 +77,6 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
 
 	private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = createDefaultClaimTypeConverter();
 
-	private final Map<String, JwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
-
 	private Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory();
 
 	private Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver = (
@@ -135,16 +132,14 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
 	@Override
 	public JwtDecoder createDecoder(ClientRegistration clientRegistration) {
 		Assert.notNull(clientRegistration, "clientRegistration cannot be null");
-		return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> {
-			NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
-			jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration));
-			Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory
-				.apply(clientRegistration);
-			if (claimTypeConverter != null) {
-				jwtDecoder.setClaimSetConverter(claimTypeConverter);
-			}
-			return jwtDecoder;
-		});
+		NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
+		jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration));
+		Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory
+			.apply(clientRegistration);
+		if (claimTypeConverter != null) {
+			jwtDecoder.setClaimSetConverter(claimTypeConverter);
+		}
+		return jwtDecoder;
 	}
 
 	private NimbusJwtDecoder buildDecoder(ClientRegistration clientRegistration) {

+ 8 - 13
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java

@@ -23,7 +23,6 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Function;
 
 import javax.crypto.spec.SecretKeySpec;
@@ -80,8 +79,6 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod
 	private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter(
 			createDefaultClaimTypeConverters());
 
-	private final Map<String, ReactiveJwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
-
 	private Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory();
 
 	private Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver = (
@@ -126,16 +123,14 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod
 	@Override
 	public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) {
 		Assert.notNull(clientRegistration, "clientRegistration cannot be null");
-		return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> {
-			NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
-			jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration));
-			Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory
-				.apply(clientRegistration);
-			if (claimTypeConverter != null) {
-				jwtDecoder.setClaimSetConverter(claimTypeConverter);
-			}
-			return jwtDecoder;
-		});
+		NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
+		jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration));
+		Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory
+			.apply(clientRegistration);
+		if (claimTypeConverter != null) {
+			jwtDecoder.setClaimSetConverter(claimTypeConverter);
+		}
+		return jwtDecoder;
 	}
 
 	private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistration) {

+ 11 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java

@@ -34,6 +34,7 @@ import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
 import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtDecoder;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@@ -46,6 +47,7 @@ import static org.mockito.Mockito.verify;
 /**
  * @author Joe Grandja
  * @author Rafael Dominguez
+ * @author Ivan Golovko
  * @since 5.2
  */
 public class OidcIdTokenDecoderFactoryTests {
@@ -177,4 +179,13 @@ public class OidcIdTokenDecoderFactoryTests {
 		verify(customClaimTypeConverterFactory).apply(same(clientRegistration));
 	}
 
+	// gh-16647
+	@Test
+	public void createDecoderWhenCachingRemovedThenReturnNewDecoder() {
+		ClientRegistration clientRegistration = this.registration.build();
+		JwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		JwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		assertThat(decoder1).isNotSameAs(decoder2);
+	}
+
 }

+ 11 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java

@@ -34,6 +34,7 @@ import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
 import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@@ -47,6 +48,7 @@ import static org.mockito.Mockito.verify;
  * @author Joe Grandja
  * @author Rafael Dominguez
  * @author Ubaid ur Rehman
+ * @author Ivan Golovko
  * @since 5.2
  */
 public class ReactiveOidcIdTokenDecoderFactoryTests {
@@ -177,4 +179,13 @@ public class ReactiveOidcIdTokenDecoderFactoryTests {
 		verify(customClaimTypeConverterFactory).apply(same(clientRegistration));
 	}
 
+	// gh-16647
+	@Test
+	public void createDecoderWhenCachingRemovedThenReturnNewDecoder() {
+		ClientRegistration clientRegistration = this.registration.build();
+		ReactiveJwtDecoder decoder1 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		ReactiveJwtDecoder decoder2 = this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		assertThat(decoder1).isNotSameAs(decoder2);
+	}
+
 }