Browse Source

Use Nimbus Multiple Algorithm Support

Closes gh-8623
Josh Cummings 5 năm trước cách đây
mục cha
commit
aa84c79e87

+ 0 - 54
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWSAlgorithmMapJWSKeySelector.java

@@ -1,54 +0,0 @@
-/*
- * Copyright 2002-2019 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.oauth2.jwt;
-
-import java.security.Key;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-import com.nimbusds.jose.JWSAlgorithm;
-import com.nimbusds.jose.JWSHeader;
-import com.nimbusds.jose.KeySourceException;
-import com.nimbusds.jose.proc.JWSKeySelector;
-import com.nimbusds.jose.proc.SecurityContext;
-
-/**
- * Class for delegating to a Nimbus JWSKeySelector by the given JWSAlgorithm
- *
- * @author Josh Cummings
- */
-class JWSAlgorithmMapJWSKeySelector<C extends SecurityContext> implements JWSKeySelector<C> {
-	private Map<JWSAlgorithm, JWSKeySelector<C>> jwsKeySelectors;
-
-	JWSAlgorithmMapJWSKeySelector(Map<JWSAlgorithm, JWSKeySelector<C>> jwsKeySelectors) {
-		this.jwsKeySelectors = jwsKeySelectors;
-	}
-
-	@Override
-	public List<? extends Key> selectJWSKeys(JWSHeader header, C context) throws KeySourceException {
-		JWSKeySelector<C> keySelector = this.jwsKeySelectors.get(header.getAlgorithm());
-		if (keySelector == null) {
-			throw new IllegalArgumentException("Unsupported algorithm of " + header.getAlgorithm());
-		}
-		return keySelector.selectJWSKeys(header, context);
-	}
-
-	public Set<JWSAlgorithm> getExpectedJWSAlgorithms() {
-		return this.jwsKeySelectors.keySet();
-	}
-}

+ 4 - 8
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

@@ -23,7 +23,6 @@ import java.security.interfaces.RSAPublicKey;
 import java.text.ParseException;
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.Map;
@@ -286,16 +285,13 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 		JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSource) {
 			if (this.signatureAlgorithms.isEmpty()) {
 				return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
-			} else if (this.signatureAlgorithms.size() == 1) {
-				JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName());
-				return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource);
 			} else {
-				Map<JWSAlgorithm, JWSKeySelector<SecurityContext>> jwsKeySelectors = new HashMap<>();
+				Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
 				for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
-					JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName());
-					jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource));
+					JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
+					jwsAlgorithms.add(jwsAlgorithm);
 				}
-				return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors);
+				return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
 			}
 		}
 

+ 11 - 17
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

@@ -17,7 +17,6 @@ package org.springframework.security.oauth2.jwt;
 
 import java.security.interfaces.RSAPublicKey;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.Map;
@@ -307,16 +306,13 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 		JWSKeySelector<JWKSecurityContext> jwsKeySelector(JWKSource<JWKSecurityContext> jwkSource) {
 			if (this.signatureAlgorithms.isEmpty()) {
 				return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
-			} else if (this.signatureAlgorithms.size() == 1) {
-				JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName());
-				return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource);
 			} else {
-				Map<JWSAlgorithm, JWSKeySelector<JWKSecurityContext>> jwsKeySelectors = new HashMap<>();
+				Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
 				for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
-					JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName());
-					jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource));
+					JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
+					jwsAlgorithms.add(jwsAlgorithm);
 				}
-				return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors);
+				return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
 			}
 		}
 
@@ -330,7 +326,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
 			source.setWebClient(this.webClient);
 
-			Set<JWSAlgorithm> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector);
+			Function<JWSAlgorithm, Boolean> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector);
 			return jwt -> {
 				JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader());
 				return source.get(selector)
@@ -339,22 +335,20 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			};
 		}
 
-		private Set<JWSAlgorithm> getExpectedJwsAlgorithms(JWSKeySelector<?> jwsKeySelector) {
+		private Function<JWSAlgorithm, Boolean> getExpectedJwsAlgorithms(JWSKeySelector<?> jwsKeySelector) {
 			if (jwsKeySelector instanceof JWSVerificationKeySelector) {
-				return Collections.singleton(((JWSVerificationKeySelector<?>) jwsKeySelector).getExpectedJWSAlgorithm());
-			}
-			if (jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector) {
-				return ((JWSAlgorithmMapJWSKeySelector<?>) jwsKeySelector).getExpectedJWSAlgorithms();
+				return ((JWSVerificationKeySelector<?>) jwsKeySelector)::isAllowed;
 			}
 			throw new IllegalArgumentException("Unsupported key selector type " + jwsKeySelector.getClass());
 		}
 
-		private JWKSelector createSelector(Set<JWSAlgorithm> expectedJwsAlgorithms, Header header) {
-			if (!expectedJwsAlgorithms.contains(header.getAlgorithm())) {
+		private JWKSelector createSelector(Function<JWSAlgorithm, Boolean> expectedJwsAlgorithms, Header header) {
+			JWSHeader jwsHeader = (JWSHeader) header;
+			if (!expectedJwsAlgorithms.apply(jwsHeader.getAlgorithm())) {
 				throw new BadJwtException("Unsupported algorithm of " + header.getAlgorithm());
 			}
 
-			return new JWKSelector(JWKMatcher.forJWSHeader((JWSHeader) header));
+			return new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader));
 		}
 	}
 

+ 11 - 9
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java

@@ -415,8 +415,8 @@ public class NimbusJwtDecoderTests {
 		assertThat(jwsKeySelector instanceof JWSVerificationKeySelector);
 		JWSVerificationKeySelector<?> jwsVerificationKeySelector =
 				(JWSVerificationKeySelector<?>) jwsKeySelector;
-		assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm())
-				.isEqualTo(JWSAlgorithm.RS256);
+		assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS256))
+				.isTrue();
 	}
 
 	@Test
@@ -428,8 +428,8 @@ public class NimbusJwtDecoderTests {
 		assertThat(jwsKeySelector instanceof JWSVerificationKeySelector);
 		JWSVerificationKeySelector<?> jwsVerificationKeySelector =
 				(JWSVerificationKeySelector<?>) jwsKeySelector;
-		assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm())
-				.isEqualTo(JWSAlgorithm.RS512);
+		assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS512))
+				.isTrue();
 	}
 
 	@Test
@@ -440,11 +440,13 @@ public class NimbusJwtDecoderTests {
 						.jwsAlgorithm(SignatureAlgorithm.RS256)
 						.jwsAlgorithm(SignatureAlgorithm.RS512)
 						.jwsKeySelector(jwkSource);
-		assertThat(jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector);
-		JWSAlgorithmMapJWSKeySelector<?> jwsAlgorithmMapKeySelector =
-				(JWSAlgorithmMapJWSKeySelector<?>) jwsKeySelector;
-		assertThat(jwsAlgorithmMapKeySelector.getExpectedJWSAlgorithms())
-				.containsExactlyInAnyOrder(JWSAlgorithm.RS256, JWSAlgorithm.RS512);
+		assertThat(jwsKeySelector instanceof JWSVerificationKeySelector);
+		JWSVerificationKeySelector<?> jwsAlgorithmMapKeySelector =
+				(JWSVerificationKeySelector<?>) jwsKeySelector;
+		assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS256))
+				.isTrue();
+		assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS512))
+				.isTrue();
 	}
 
 	// gh-7290

+ 11 - 9
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java

@@ -395,8 +395,8 @@ public class NimbusReactiveJwtDecoderTests {
 		assertThat(jwsKeySelector instanceof JWSVerificationKeySelector);
 		JWSVerificationKeySelector<JWKSecurityContext> jwsVerificationKeySelector =
 				(JWSVerificationKeySelector<JWKSecurityContext>) jwsKeySelector;
-		assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm())
-				.isEqualTo(JWSAlgorithm.RS256);
+		assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS256))
+				.isTrue();
 	}
 
 	@Test
@@ -408,8 +408,8 @@ public class NimbusReactiveJwtDecoderTests {
 		assertThat(jwsKeySelector instanceof JWSVerificationKeySelector);
 		JWSVerificationKeySelector<JWKSecurityContext> jwsVerificationKeySelector =
 				(JWSVerificationKeySelector<JWKSecurityContext>) jwsKeySelector;
-		assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm())
-				.isEqualTo(JWSAlgorithm.RS512);
+		assertThat(jwsVerificationKeySelector.isAllowed(JWSAlgorithm.RS512))
+				.isTrue();
 	}
 
 	@Test
@@ -420,11 +420,13 @@ public class NimbusReactiveJwtDecoderTests {
 						.jwsAlgorithm(SignatureAlgorithm.RS256)
 						.jwsAlgorithm(SignatureAlgorithm.RS512)
 						.jwsKeySelector(jwkSource);
-		assertThat(jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector);
-		JWSAlgorithmMapJWSKeySelector<?> jwsAlgorithmMapKeySelector =
-				(JWSAlgorithmMapJWSKeySelector<?>) jwsKeySelector;
-		assertThat(jwsAlgorithmMapKeySelector.getExpectedJWSAlgorithms())
-				.containsExactlyInAnyOrder(JWSAlgorithm.RS256, JWSAlgorithm.RS512);
+		assertThat(jwsKeySelector instanceof JWSVerificationKeySelector);
+		JWSVerificationKeySelector<?> jwsAlgorithmMapKeySelector =
+				(JWSVerificationKeySelector<?>) jwsKeySelector;
+		assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS256))
+				.isTrue();
+		assertThat(jwsAlgorithmMapKeySelector.isAllowed(JWSAlgorithm.RS512))
+				.isTrue();
 	}
 
 	private SignedJWT signedJwt(SecretKey secretKey, MacAlgorithm jwsAlgorithm, JWTClaimsSet claimsSet) throws Exception {