|
@@ -40,6 +40,8 @@ import com.nimbusds.jwt.proc.DefaultJWTProcessor;
|
|
import com.nimbusds.jwt.proc.JWTProcessor;
|
|
import com.nimbusds.jwt.proc.JWTProcessor;
|
|
import reactor.core.publisher.Mono;
|
|
import reactor.core.publisher.Mono;
|
|
|
|
|
|
|
|
+import org.springframework.security.oauth2.core.OAuth2TokenValidator;
|
|
|
|
+import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
|
|
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
|
|
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
|
|
import org.springframework.util.Assert;
|
|
import org.springframework.util.Assert;
|
|
|
|
|
|
@@ -67,6 +69,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
|
|
|
|
private final JWKSelectorFactory jwkSelectorFactory;
|
|
private final JWKSelectorFactory jwkSelectorFactory;
|
|
|
|
|
|
|
|
+ private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
|
|
|
|
+
|
|
public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
|
|
public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
|
|
JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
|
|
JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
|
|
|
|
|
|
@@ -77,6 +81,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
new JWSVerificationKeySelector<>(algorithm, jwkSource);
|
|
new JWSVerificationKeySelector<>(algorithm, jwkSource);
|
|
DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>();
|
|
DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>();
|
|
jwtProcessor.setJWSKeySelector(jwsKeySelector);
|
|
jwtProcessor.setJWSKeySelector(jwsKeySelector);
|
|
|
|
+ jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
|
|
|
|
|
|
this.jwtProcessor = jwtProcessor;
|
|
this.jwtProcessor = jwtProcessor;
|
|
this.reactiveJwkSource = new ReactiveJWKSourceAdapter(jwkSource);
|
|
this.reactiveJwkSource = new ReactiveJWKSourceAdapter(jwkSource);
|
|
@@ -98,6 +103,7 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
|
|
|
|
DefaultJWTProcessor<JWKContext> jwtProcessor = new DefaultJWTProcessor<>();
|
|
DefaultJWTProcessor<JWKContext> jwtProcessor = new DefaultJWTProcessor<>();
|
|
jwtProcessor.setJWSKeySelector(jwsKeySelector);
|
|
jwtProcessor.setJWSKeySelector(jwsKeySelector);
|
|
|
|
+ jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
|
|
this.jwtProcessor = jwtProcessor;
|
|
this.jwtProcessor = jwtProcessor;
|
|
|
|
|
|
this.reactiveJwkSource = new ReactiveRemoteJWKSource(jwkSetUrl);
|
|
this.reactiveJwkSource = new ReactiveRemoteJWKSource(jwkSetUrl);
|
|
@@ -106,6 +112,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ /**
|
|
|
|
+ * Use the provided {@link OAuth2TokenValidator} to validate incoming {@link Jwt}s.
|
|
|
|
+ *
|
|
|
|
+ * @param jwtValidator the {@link OAuth2TokenValidator} to use
|
|
|
|
+ */
|
|
|
|
+ public void setJwtValidator(OAuth2TokenValidator<Jwt> jwtValidator) {
|
|
|
|
+ Assert.notNull(jwtValidator, "jwtValidator cannot be null");
|
|
|
|
+ this.jwtValidator = jwtValidator;
|
|
|
|
+ }
|
|
|
|
+
|
|
@Override
|
|
@Override
|
|
public Mono<Jwt> decode(String token) throws JwtException {
|
|
public Mono<Jwt> decode(String token) throws JwtException {
|
|
JWT jwt = parse(token);
|
|
JWT jwt = parse(token);
|
|
@@ -131,7 +147,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
|
|
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
|
|
.map(jwkList -> createClaimsSet(parsedToken, jwkList))
|
|
.map(jwkList -> createClaimsSet(parsedToken, jwkList))
|
|
.map(set -> createJwt(parsedToken, set))
|
|
.map(set -> createJwt(parsedToken, set))
|
|
- .onErrorMap(e -> !(e instanceof IllegalStateException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
|
|
|
|
|
|
+ .map(this::validateJwt)
|
|
|
|
+ .onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
|
|
} catch (RuntimeException ex) {
|
|
} catch (RuntimeException ex) {
|
|
throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex);
|
|
throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex);
|
|
}
|
|
}
|
|
@@ -164,6 +181,17 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
|
|
return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
|
|
return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ private Jwt validateJwt(Jwt jwt) {
|
|
|
|
+ OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt);
|
|
|
|
+
|
|
|
|
+ if ( result.hasErrors() ) {
|
|
|
|
+ String message = result.getErrors().iterator().next().getDescription();
|
|
|
|
+ throw new JwtValidationException(message, result.getErrors());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return jwt;
|
|
|
|
+ }
|
|
|
|
+
|
|
private static RSAKey rsaKey(RSAPublicKey publicKey) {
|
|
private static RSAKey rsaKey(RSAPublicKey publicKey) {
|
|
return new RSAKey.Builder(publicKey)
|
|
return new RSAKey.Builder(publicKey)
|
|
.build();
|
|
.build();
|