Browse Source

Remove SignedJWT Check

JWTProcessor already does sufficient checking to confirm that the JWT
is of the appropriate type.

Fixes: gh-7034
Josh Cummings 6 years ago
parent
commit
37d108ccc2

+ 17 - 16
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

@@ -16,6 +16,17 @@
 
 package org.springframework.security.oauth2.jwt;
 
+import java.io.IOException;
+import java.net.MalformedURLException;
+import java.net.URL;
+import java.security.interfaces.RSAPublicKey;
+import java.text.ParseException;
+import java.time.Instant;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import javax.crypto.SecretKey;
+
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.RemoteKeySourceException;
 import com.nimbusds.jose.jwk.JWKSet;
@@ -32,10 +43,11 @@ import com.nimbusds.jose.util.ResourceRetriever;
 import com.nimbusds.jwt.JWT;
 import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.JWTParser;
-import com.nimbusds.jwt.SignedJWT;
+import com.nimbusds.jwt.PlainJWT;
 import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
 import com.nimbusds.jwt.proc.DefaultJWTProcessor;
 import com.nimbusds.jwt.proc.JWTProcessor;
+
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
@@ -51,17 +63,6 @@ import org.springframework.util.Assert;
 import org.springframework.web.client.RestOperations;
 import org.springframework.web.client.RestTemplate;
 
-import javax.crypto.SecretKey;
-import java.io.IOException;
-import java.net.MalformedURLException;
-import java.net.URL;
-import java.security.interfaces.RSAPublicKey;
-import java.text.ParseException;
-import java.time.Instant;
-import java.util.Collections;
-import java.util.LinkedHashMap;
-import java.util.Map;
-
 /**
  * A low-level Nimbus implementation of {@link JwtDecoder} which takes a raw Nimbus configuration.
  *
@@ -119,11 +120,11 @@ public final class NimbusJwtDecoder implements JwtDecoder {
 	@Override
 	public Jwt decode(String token) throws JwtException {
 		JWT jwt = parse(token);
-		if (jwt instanceof SignedJWT) {
-			Jwt createdJwt = createJwt(token, jwt);
-			return validateJwt(createdJwt);
+		if (jwt instanceof PlainJWT) {
+			throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
 		}
-		throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
+		Jwt createdJwt = createJwt(token, jwt);
+		return validateJwt(createdJwt);
 	}
 
 	private JWT parse(String token) {

+ 37 - 30
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

@@ -15,6 +15,15 @@
  */
 package org.springframework.security.oauth2.jwt;
 
+import java.security.interfaces.RSAPublicKey;
+import java.time.Instant;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.function.Function;
+import javax.crypto.SecretKey;
+
+import com.nimbusds.jose.Header;
 import com.nimbusds.jose.JOSEException;
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.JWSHeader;
@@ -35,9 +44,13 @@ import com.nimbusds.jose.proc.SecurityContext;
 import com.nimbusds.jwt.JWT;
 import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.JWTParser;
+import com.nimbusds.jwt.PlainJWT;
 import com.nimbusds.jwt.SignedJWT;
 import com.nimbusds.jwt.proc.DefaultJWTProcessor;
 import com.nimbusds.jwt.proc.JWTProcessor;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
@@ -46,16 +59,6 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.util.Assert;
 import org.springframework.web.reactive.function.client.WebClient;
-import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
-
-import javax.crypto.SecretKey;
-import java.security.interfaces.RSAPublicKey;
-import java.time.Instant;
-import java.util.Collections;
-import java.util.LinkedHashMap;
-import java.util.Map;
-import java.util.function.Function;
 
 /**
  * An implementation of a {@link ReactiveJwtDecoder} that "decodes" a
@@ -75,7 +78,7 @@ import java.util.function.Function;
  * @see <a target="_blank" href="https://connect2id.com/products/nimbus-jose-jwt">Nimbus JOSE + JWT SDK</a>
  */
 public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
-	private final Converter<SignedJWT, Mono<JWTClaimsSet>> jwtProcessor;
+	private final Converter<JWT, Mono<JWTClaimsSet>> jwtProcessor;
 
 	private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
 	private Converter<Map<String, Object>, Map<String, Object>> claimSetConverter =
@@ -106,7 +109,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 	 * @param jwtProcessor the {@link Converter} used to process and verify the signed Jwt and return the Jwt Claim Set
 	 * @since 5.2
 	 */
-	public NimbusReactiveJwtDecoder(Converter<SignedJWT, Mono<JWTClaimsSet>> jwtProcessor) {
+	public NimbusReactiveJwtDecoder(Converter<JWT, Mono<JWTClaimsSet>> jwtProcessor) {
 		this.jwtProcessor = jwtProcessor;
 	}
 
@@ -133,10 +136,10 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 	@Override
 	public Mono<Jwt> decode(String token) throws JwtException {
 		JWT jwt = parse(token);
-		if (jwt instanceof SignedJWT) {
-			return this.decode((SignedJWT) jwt);
+		if (jwt instanceof PlainJWT) {
+			throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
 		}
-		throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
+		return this.decode(jwt);
 	}
 
 	private JWT parse(String token) {
@@ -147,7 +150,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 		}
 	}
 
-	private Mono<Jwt> decode(SignedJWT parsedToken) {
+	private Mono<Jwt> decode(JWT parsedToken) {
 		try {
 			return this.jwtProcessor.convert(parsedToken)
 				.map(set -> createJwt(parsedToken, set))
@@ -280,7 +283,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			return new NimbusReactiveJwtDecoder(processor());
 		}
 
-		Converter<SignedJWT, Mono<JWTClaimsSet>> processor() {
+		Converter<JWT, Mono<JWTClaimsSet>> processor() {
 			JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
 
 			JWSKeySelector<JWKSecurityContext> jwsKeySelector =
@@ -292,20 +295,20 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
 			source.setWebClient(this.webClient);
 
-			return signedJWT -> {
-				JWKSelector selector = createSelector(signedJWT.getHeader());
+			return jwt -> {
+				JWKSelector selector = createSelector(jwt.getHeader());
 				return source.get(selector)
 						.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
-						.map(jwkList -> createClaimsSet(jwtProcessor, signedJWT, new JWKSecurityContext(jwkList)));
+						.map(jwkList -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList)));
 			};
 		}
 
-		private JWKSelector createSelector(JWSHeader header) {
+		private JWKSelector createSelector(Header header) {
 			if (!this.jwsAlgorithm.equals(header.getAlgorithm())) {
 				throw new JwtException("Unsupported algorithm of " + header.getAlgorithm());
 			}
 
-			return new JWKSelector(JWKMatcher.forJWSHeader(header));
+			return new JWKSelector(JWKMatcher.forJWSHeader((JWSHeader) header));
 		}
 	}
 
@@ -353,7 +356,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			return new NimbusReactiveJwtDecoder(processor());
 		}
 
-		Converter<SignedJWT, Mono<JWTClaimsSet>> processor() {
+		Converter<JWT, Mono<JWTClaimsSet>> processor() {
 			if (!JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm)) {
 				throw new IllegalStateException("The provided key is of type RSA; " +
 						"however the signature algorithm is of some other type: " +
@@ -370,7 +373,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			// Spring Security validates the claim set independent from Nimbus
 			jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { });
 
-			return signedJWT -> Mono.just(signedJWT).map(jwt -> createClaimsSet(jwtProcessor, jwt, null));
+			return jwt -> Mono.just(createClaimsSet(jwtProcessor, jwt, null));
 		}
 	}
 
@@ -414,7 +417,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			return new NimbusReactiveJwtDecoder(processor());
 		}
 
-		Converter<SignedJWT, Mono<JWTClaimsSet>> processor() {
+		Converter<JWT, Mono<JWTClaimsSet>> processor() {
 			JWKSource<SecurityContext> jwkSource = new ImmutableSecret<>(this.secretKey);
 			JWSKeySelector<SecurityContext> jwsKeySelector =
 					new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
@@ -424,7 +427,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			// Spring Security validates the claim set independent from Nimbus
 			jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { });
 
-			return signedJWT -> Mono.just(signedJWT).map(jwt -> createClaimsSet(jwtProcessor, jwt, null));
+			return jwt -> Mono.just(createClaimsSet(jwtProcessor, jwt, null));
 		}
 	}
 
@@ -464,7 +467,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			return new NimbusReactiveJwtDecoder(processor());
 		}
 
-		Converter<SignedJWT, Mono<JWTClaimsSet>> processor() {
+		Converter<JWT, Mono<JWTClaimsSet>> processor() {
 			JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
 			JWSKeySelector<JWKSecurityContext> jwsKeySelector =
 					new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
@@ -472,11 +475,15 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 			jwtProcessor.setJWSKeySelector(jwsKeySelector);
 			jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
 
-			return signedJWT ->
-					this.jwkSource.apply(signedJWT)
+			return jwt -> {
+				if (jwt instanceof SignedJWT) {
+					return this.jwkSource.apply((SignedJWT) jwt)
 							.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
 							.collectList()
-							.map(jwks -> createClaimsSet(jwtProcessor, signedJWT, new JWKSecurityContext(jwks)));
+							.map(jwks -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwks)));
+				}
+				throw new JwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm());
+			};
 		}
 	}