فهرست منبع

Reactive Jwt Validation

This allows a user to customize the Jwt validation steps that
NimbusReactiveJwtDecoder will take for each Jwt.

Fixes: gh-5650
Josh Cummings 7 سال پیش
والد
کامیت
01443e35b4

+ 29 - 1
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

@@ -40,6 +40,8 @@ import com.nimbusds.jwt.proc.DefaultJWTProcessor;
 import com.nimbusds.jwt.proc.JWTProcessor;
 import reactor.core.publisher.Mono;
 
+import org.springframework.security.oauth2.core.OAuth2TokenValidator;
+import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
 import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
 import org.springframework.util.Assert;
 
@@ -67,6 +69,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 
 	private final JWKSelectorFactory jwkSelectorFactory;
 
+	private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
+
 	public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
 		JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
 
@@ -77,6 +81,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 				new JWSVerificationKeySelector<>(algorithm, jwkSource);
 		DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>();
 		jwtProcessor.setJWSKeySelector(jwsKeySelector);
+		jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
 
 		this.jwtProcessor = jwtProcessor;
 		this.reactiveJwkSource = new ReactiveJWKSourceAdapter(jwkSource);
@@ -98,6 +103,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 
 		DefaultJWTProcessor<JWKContext> jwtProcessor = new DefaultJWTProcessor<>();
 		jwtProcessor.setJWSKeySelector(jwsKeySelector);
+		jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
 		this.jwtProcessor = jwtProcessor;
 
 		this.reactiveJwkSource = new ReactiveRemoteJWKSource(jwkSetUrl);
@@ -106,6 +112,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 
 	}
 
+	/**
+	 * Use the provided {@link OAuth2TokenValidator} to validate incoming {@link Jwt}s.
+	 *
+	 * @param jwtValidator the {@link OAuth2TokenValidator} to use
+	 */
+	public void setJwtValidator(OAuth2TokenValidator<Jwt> jwtValidator) {
+		Assert.notNull(jwtValidator, "jwtValidator cannot be null");
+		this.jwtValidator = jwtValidator;
+	}
+
 	@Override
 	public Mono<Jwt> decode(String token) throws JwtException {
 		JWT jwt = parse(token);
@@ -131,7 +147,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 				.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
 				.map(jwkList -> createClaimsSet(parsedToken, jwkList))
 				.map(set -> createJwt(parsedToken, set))
-				.onErrorMap(e -> !(e instanceof IllegalStateException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
+				.map(this::validateJwt)
+				.onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
 		} catch (RuntimeException ex) {
 			throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex);
 		}
@@ -164,6 +181,17 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 		return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
 	}
 
+	private Jwt validateJwt(Jwt jwt) {
+		OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt);
+
+		if ( result.hasErrors() ) {
+			String message = result.getErrors().iterator().next().getDescription();
+			throw new JwtValidationException(message, result.getErrors());
+		}
+
+		return jwt;
+	}
+
 	private static RSAKey rsaKey(RSAPublicKey publicKey) {
 		return new RSAKey.Builder(publicKey)
 				.build();

+ 34 - 7
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java

@@ -16,12 +16,6 @@
 
 package org.springframework.security.oauth2.jwt;
 
-import okhttp3.mockwebserver.MockResponse;
-import okhttp3.mockwebserver.MockWebServer;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-
 import java.net.UnknownHostException;
 import java.security.KeyFactory;
 import java.security.interfaces.RSAPublicKey;
@@ -29,8 +23,21 @@ import java.security.spec.X509EncodedKeySpec;
 import java.util.Base64;
 import java.util.Date;
 
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2TokenValidator;
+import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
+
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 /**
  * @author Rob Winch
@@ -114,7 +121,7 @@ public class NimbusReactiveJwtDecoderTests {
 	@Test
 	public void decodeWhenExpiredThenFail() {
 		assertThatCode(() -> this.decoder.decode(this.expired).block())
-				.isInstanceOf(JwtException.class);
+				.isInstanceOf(JwtValidationException.class);
 	}
 
 	@Test
@@ -155,4 +162,24 @@ public class NimbusReactiveJwtDecoderTests {
 				.isInstanceOf(JwtException.class)
 				.hasMessage("Unsupported algorithm of none");
 	}
+
+	@Test
+	public void decodeWhenUsingCustomValidatorThenValidatorIsInvoked() {
+		OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class);
+		this.decoder.setJwtValidator(jwtValidator);
+
+		OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri");
+		OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(error);
+		when(jwtValidator.validate(any(Jwt.class))).thenReturn(result);
+
+		assertThatCode(() -> this.decoder.decode(messageReadToken).block())
+				.isInstanceOf(JwtException.class)
+				.hasMessageContaining("mock-description");
+	}
+
+	@Test
+	public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() {
+		assertThatCode(() -> this.decoder.setJwtValidator(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
 }