Browse Source

Reactive Jwt Claim Set Converter Support

Exposes setClaimSetConverter on NimbusReactiveJwtDecoder, lining it up
with the same support on NimbusJwtDecoder.

Fixes: gh-6015
Josh Cummings 6 năm trước cách đây
mục cha
commit
ae74f22e30

+ 18 - 13
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

@@ -17,6 +17,7 @@ package org.springframework.security.oauth2.jwt;
 
 import java.security.interfaces.RSAPublicKey;
 import java.time.Instant;
+import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -40,6 +41,7 @@ import com.nimbusds.jwt.proc.DefaultJWTProcessor;
 import com.nimbusds.jwt.proc.JWTProcessor;
 import reactor.core.publisher.Mono;
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
 import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
@@ -70,6 +72,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 	private final JWKSelectorFactory jwkSelectorFactory;
 
 	private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
+	private Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = MappedJwtClaimSetConverter
+			.withDefaults(Collections.emptyMap());
 
 	public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
 		JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
@@ -122,6 +126,16 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 		this.jwtValidator = jwtValidator;
 	}
 
+	/**
+	 * Use the following {@link Converter} for manipulating the JWT's claim set
+	 *
+	 * @param claimSetConverter the {@link Converter} to use
+	 */
+	public void setClaimSetConverter(Converter<Map<String, Object>, Map<String, Object>> claimSetConverter) {
+		Assert.notNull(claimSetConverter, "claimSetConverter cannot be null");
+		this.claimSetConverter = claimSetConverter;
+	}
+
 	@Override
 	public Mono<Jwt> decode(String token) throws JwtException {
 		JWT jwt = parse(token);
@@ -164,21 +178,12 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
 	}
 
 	private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) {
-		Instant expiresAt = null;
-		if (jwtClaimsSet.getExpirationTime() != null) {
-			expiresAt = jwtClaimsSet.getExpirationTime().toInstant();
-		}
-		Instant issuedAt = null;
-		if (jwtClaimsSet.getIssueTime() != null) {
-			issuedAt = jwtClaimsSet.getIssueTime().toInstant();
-		} else if (expiresAt != null) {
-			// Default to expiresAt - 1 second
-			issuedAt = Instant.from(expiresAt).minusSeconds(1);
-		}
-
 		Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject());
+		Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims());
 
-		return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
+		Instant expiresAt = (Instant) claims.get(JwtClaimNames.EXP);
+		Instant issuedAt = (Instant) claims.get(JwtClaimNames.IAT);
+		return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, claims);
 	}
 
 	private Jwt validateJwt(Jwt jwt) {

+ 25 - 2
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java

@@ -20,8 +20,10 @@ import java.net.UnknownHostException;
 import java.security.KeyFactory;
 import java.security.interfaces.RSAPublicKey;
 import java.security.spec.X509EncodedKeySpec;
+import java.time.Instant;
 import java.util.Base64;
-import java.util.Date;
+import java.util.Collections;
+import java.util.Map;
 
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
@@ -29,6 +31,7 @@ import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
@@ -37,6 +40,7 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 /**
@@ -115,7 +119,7 @@ public class NimbusReactiveJwtDecoderTests {
 
 		Jwt jwt = this.decoder.decode(withIssuedAt).block();
 
-		assertThat(jwt.getClaims().get(JwtClaimNames.IAT)).isEqualTo(new Date(1529942448000L));
+		assertThat(jwt.getClaims().get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1529942448L));
 	}
 
 	@Test
@@ -177,9 +181,28 @@ public class NimbusReactiveJwtDecoderTests {
 				.hasMessageContaining("mock-description");
 	}
 
+	@Test
+	public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() {
+		Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = mock(Converter.class);
+		this.decoder.setClaimSetConverter(claimSetConverter);
+
+		when(claimSetConverter.convert(any(Map.class))).thenReturn(Collections.singletonMap("custom", "value"));
+
+		Jwt jwt = this.decoder.decode(this.messageReadToken).block();
+		assertThat(jwt.getClaims().size()).isEqualTo(1);
+		assertThat(jwt.getClaims().get("custom")).isEqualTo("value");
+		verify(claimSetConverter).convert(any(Map.class));
+	}
+
 	@Test
 	public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() {
 		assertThatCode(() -> this.decoder.setJwtValidator(null))
 				.isInstanceOf(IllegalArgumentException.class);
 	}
+
+	@Test
+	public void setClaimSetConverterWhenNullThrowsIllegalArgumentException() {
+		assertThatCode(() -> this.decoder.setClaimSetConverter(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
 }