2
0
Эх сурвалжийг харах

Jwt Claim Mapping

This introduces a hook for users to customize standard Jwt Claim
values in cases where the JWT issuer isn't spec compliant or where the
user needs to add or remove claims.

Fixes: gh-5223
Josh Cummings 7 жил өмнө
parent
commit
9e0f171d47

+ 241 - 0
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/MappedJwtClaimSetConverter.java

@@ -0,0 +1,241 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.jwt;
+
+import java.net.MalformedURLException;
+import java.net.URI;
+import java.net.URL;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+import org.springframework.core.convert.converter.Converter;
+import org.springframework.util.Assert;
+
+/**
+ * Converts a JWT claim set, claim by claim. Can be configured with custom converters
+ * by claim name.
+ *
+ * @author Josh Cummings
+ * @since 5.1
+ */
+public final class MappedJwtClaimSetConverter
+		implements Converter<Map<String, Object>, Map<String, Object>> {
+
+	private static final Converter<Object, Collection<String>> AUDIENCE_CONVERTER = new AudienceConverter();
+	private static final Converter<Object, URL> ISSUER_CONVERTER = new IssuerConverter();
+	private static final Converter<Object, String> STRING_CONVERTER = new StringConverter();
+	private static final Converter<Object, Instant> TEMPORAL_CONVERTER = new InstantConverter();
+
+	private final Map<String, Converter<Object, ?>> claimConverters;
+
+	/**
+	 * Constructs a {@link MappedJwtClaimSetConverter} with the provided arguments
+	 *
+	 * This will completely replace any set of default converters.
+	 *
+	 * @param claimConverters The {@link Map} of converters to use
+	 */
+	public MappedJwtClaimSetConverter(Map<String, Converter<Object, ?>> claimConverters) {
+		Assert.notNull(claimConverters, "claimConverters cannot be null");
+		this.claimConverters = new HashMap<>(claimConverters);
+	}
+
+	/**
+	 * Construct a {@link MappedJwtClaimSetConverter}, overriding individual claim
+	 * converters with the provided {@link Map} of {@link Converter}s.
+	 *
+	 * For example, the following would give an instance that is configured with only the default
+	 * claim converters:
+	 *
+	 * <pre>
+	 * 	MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap());
+	 * </pre>
+	 *
+	 * Or, the following would supply a custom converter for the subject, leaving the other defaults
+	 * in place:
+	 *
+	 * <pre>
+	 * 	MappedJwtClaimsSetConverter.withDefaults(
+	 * 		Collections.singletonMap(JwtClaimNames.SUB, new UserDetailsServiceJwtSubjectConverter()));
+	 * </pre>
+	 *
+	 * To completely replace the underlying {@link Map} of converters, {@see MappedJwtClaimSetConverter(Map)}.
+	 *
+	 * @param claimConverters
+	 * @return An instance of {@link MappedJwtClaimSetConverter} that contains the converters provided,
+	 *   plus any defaults that were not overridden.
+	 */
+	public static MappedJwtClaimSetConverter withDefaults
+			(Map<String, Converter<Object, ?>> claimConverters) {
+		Assert.notNull(claimConverters, "claimConverters cannot be null");
+
+		Map<String, Converter<Object, ?>> claimNameToConverter = new HashMap<>();
+		claimNameToConverter.put(JwtClaimNames.AUD, AUDIENCE_CONVERTER);
+		claimNameToConverter.put(JwtClaimNames.EXP, TEMPORAL_CONVERTER);
+		claimNameToConverter.put(JwtClaimNames.IAT, TEMPORAL_CONVERTER);
+		claimNameToConverter.put(JwtClaimNames.ISS, ISSUER_CONVERTER);
+		claimNameToConverter.put(JwtClaimNames.JTI, STRING_CONVERTER);
+		claimNameToConverter.put(JwtClaimNames.NBF, TEMPORAL_CONVERTER);
+		claimNameToConverter.put(JwtClaimNames.SUB, STRING_CONVERTER);
+		claimNameToConverter.putAll(claimConverters);
+
+		return new MappedJwtClaimSetConverter(claimNameToConverter);
+	}
+
+	/**
+	 * {@inheritDoc}
+	 */
+	@Override
+	public Map<String, Object> convert(Map<String, Object> claims) {
+		Assert.notNull(claims, "claims cannot be null");
+
+		Map<String, Object> mappedClaims = new HashMap<>(claims);
+
+		for (Map.Entry<String, Converter<Object, ?>> entry : this.claimConverters.entrySet()) {
+			String claimName = entry.getKey();
+			Converter<Object, ?> converter = entry.getValue();
+			if (converter != null) {
+				Object claim = claims.get(claimName);
+				Object mappedClaim = converter.convert(claim);
+				mappedClaims.compute(claimName, (key, value) -> mappedClaim);
+			}
+		}
+
+		Instant issuedAt = (Instant) mappedClaims.get(JwtClaimNames.IAT);
+		Instant expiresAt = (Instant) mappedClaims.get(JwtClaimNames.EXP);
+		if (issuedAt == null && expiresAt != null) {
+			mappedClaims.put(JwtClaimNames.IAT, expiresAt.minusSeconds(1));
+		}
+
+		return mappedClaims;
+	}
+
+	/**
+	 * Coerces an <a target="_blank" href="https://tools.ietf.org/html/rfc7519#section-4.1.3">Audience</a> claim
+	 * into a {@link Collection<String>}, ignoring null values, and throwing an error if its coercion efforts fail.
+	 */
+	private static class AudienceConverter implements Converter<Object, Collection<String>> {
+
+		@Override
+		public Collection<String> convert(Object source) {
+			if (source == null) {
+				return null;
+			}
+
+			if (source instanceof Collection) {
+				return ((Collection<?>) source).stream()
+						.filter(Objects::nonNull)
+						.map(Objects::toString)
+						.collect(Collectors.toList());
+			}
+
+			return Arrays.asList(source.toString());
+		}
+	}
+
+	/**
+	 * Coerces an <a target="_blank" href="https://tools.ietf.org/html/rfc7519#section-4.1.1">Issuer</a> claim
+	 * into a {@link URL}, ignoring null values, and throwing an error if its coercion efforts fail.
+	 */
+	private static class IssuerConverter implements Converter<Object, URL> {
+
+		@Override
+		public URL convert(Object source) {
+			if (source == null) {
+				return null;
+			}
+
+			if (source instanceof URL) {
+				return (URL) source;
+			}
+
+			if (source instanceof URI) {
+				return toUrl((URI) source);
+			}
+
+			return toUrl(source.toString());
+		}
+
+		private URL toUrl(URI source) {
+			try {
+				return source.toURL();
+			} catch (MalformedURLException e) {
+				throw new IllegalStateException("Could not coerce " + source + " into a URL", e);
+			}
+		}
+
+		private URL toUrl(String source) {
+			try {
+				return new URL(source);
+			} catch (MalformedURLException e) {
+				throw new IllegalStateException("Could not coerce " + source + " into a URL", e);
+			}
+		}
+	}
+
+	/**
+	 * Coerces a claim into an {@link Instant}, ignoring null values, and throwing an error
+	 * if its coercion efforts fail.
+	 */
+	private static class InstantConverter implements Converter<Object, Instant> {
+		@Override
+		public Instant convert(Object source) {
+			if (source == null) {
+				return null;
+			}
+
+			if (source instanceof Instant) {
+				return (Instant) source;
+			}
+
+			if (source instanceof Date) {
+				return ((Date) source).toInstant();
+			}
+
+			if (source instanceof Number) {
+				return Instant.ofEpochSecond(((Number) source).longValue());
+			}
+
+			try {
+				return Instant.ofEpochSecond(Long.parseLong(source.toString()));
+			} catch (Exception e) {
+				throw new IllegalStateException("Could not coerce " + source + " into an Instant", e);
+			}
+		}
+	}
+
+	/**
+	 * Coerces a claim into a {@link String}, ignoring null values, and throwing an error if its
+	 * coercion efforts fail.
+	 */
+	private static class StringConverter implements Converter<Object, String> {
+		@Override
+		public String convert(Object source) {
+			if (source == null) {
+				return null;
+			}
+
+			return source.toString();
+		}
+	}
+}

+ 18 - 14
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java

@@ -40,6 +40,7 @@ import com.nimbusds.jwt.SignedJWT;
 import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
 import com.nimbusds.jwt.proc.DefaultJWTProcessor;
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
@@ -78,8 +79,11 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
 	private final ConfigurableJWTProcessor<SecurityContext> jwtProcessor;
 	private final RestOperationsResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever();
 
+	private Converter<Map<String, Object>, Map<String, Object>> claimSetConverter =
+			MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap());
 	private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
 
+
 	/**
 	 * Constructs a {@code NimbusJwtDecoderJwkSupport} using the provided parameters.
 	 *
@@ -134,6 +138,16 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
 		this.jwtValidator = jwtValidator;
 	}
 
+	/**
+	 * Use the following {@link Converter} for manipulating the JWT's claim set
+	 *
+	 * @param claimSetConverter the {@link Converter} to use
+	 */
+	public final void setClaimSetConverter(Converter<Map<String, Object>, Map<String, Object>> claimSetConverter) {
+		Assert.notNull(claimSetConverter, "claimSetConverter cannot be null");
+		this.claimSetConverter = claimSetConverter;
+	}
+
 	private JWT parse(String token) {
 		try {
 			return JWTParser.parse(token);
@@ -149,22 +163,12 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
 			// Verify the signature
 			JWTClaimsSet jwtClaimsSet = this.jwtProcessor.process(parsedJwt, null);
 
-			Instant expiresAt = null;
-			if (jwtClaimsSet.getExpirationTime() != null) {
-				expiresAt = jwtClaimsSet.getExpirationTime().toInstant();
-			}
-			Instant issuedAt = null;
-			if (jwtClaimsSet.getIssueTime() != null) {
-				issuedAt = jwtClaimsSet.getIssueTime().toInstant();
-			} else if (expiresAt != null) {
-				// Default to expiresAt - 1 second
-				issuedAt = Instant.from(expiresAt).minusSeconds(1);
-			}
-
 			Map<String, Object> headers = new LinkedHashMap<>(parsedJwt.getHeader().toJSONObject());
+			Map<String, Object> claims = this.claimSetConverter.convert(jwtClaimsSet.getClaims());
 
-			jwt = new Jwt(token, issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
-
+			Instant expiresAt = (Instant) claims.get(JwtClaimNames.EXP);
+			Instant issuedAt = (Instant) claims.get(JwtClaimNames.IAT);
+			jwt = new Jwt(token, issuedAt, expiresAt, headers, claims);
 		} catch (RemoteKeySourceException ex) {
 			if (ex.getCause() instanceof ParseException) {
 				throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"));

+ 223 - 0
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/MappedJwtClaimSetConverterTests.java

@@ -0,0 +1,223 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.jwt;
+
+import java.net.URI;
+import java.net.URL;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Test;
+
+import org.springframework.core.convert.converter.Converter;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link MappedJwtClaimSetConverter}
+ *
+ * @author Josh Cummings
+ */
+public class MappedJwtClaimSetConverterTests {
+	@Test
+	public void convertWhenUsingCustomExpiresAtConverterThenIssuedAtConverterStillConsultsIt() {
+		Instant at = Instant.ofEpochMilli(1000000000000L);
+		Converter<Object, Instant> expiresAtConverter = mock(Converter.class);
+		when(expiresAtConverter.convert(any())).thenReturn(at);
+
+		MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter
+				.withDefaults(Collections.singletonMap(JwtClaimNames.EXP, expiresAtConverter));
+
+		Map<String, Object> source = new HashMap<>();
+		Map<String, Object> target = converter.convert(source);
+
+		assertThat(target.get(JwtClaimNames.IAT)).
+				isEqualTo(Instant.ofEpochMilli(at.toEpochMilli()).minusSeconds(1));
+	}
+
+	@Test
+	public void convertWhenUsingDefaultsThenBasesIssuedAtOffOfExpiration() {
+		MappedJwtClaimSetConverter converter =
+				MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap());
+
+		Map<String, Object> source = Collections.singletonMap(JwtClaimNames.EXP, 1000000000L);
+		Map<String, Object> target = converter.convert(source);
+
+		assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(Instant.ofEpochSecond(1000000000L));
+		assertThat(target.get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1000000000L).minusSeconds(1));
+	}
+
+	@Test
+	public void convertWhenUsingDefaultsThenCoercesAudienceAccordingToJwtSpec() {
+		MappedJwtClaimSetConverter converter =
+				MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap());
+
+		Map<String, Object> source = Collections.singletonMap(JwtClaimNames.AUD, "audience");
+		Map<String, Object> target = converter.convert(source);
+
+		assertThat(target.get(JwtClaimNames.AUD)).isInstanceOf(Collection.class);
+		assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("audience"));
+
+		source = Collections.singletonMap(JwtClaimNames.AUD, Arrays.asList("one", "two"));
+		target = converter.convert(source);
+
+		assertThat(target.get(JwtClaimNames.AUD)).isInstanceOf(Collection.class);
+		assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("one", "two"));
+	}
+
+	@Test
+	public void convertWhenUsingDefaultsThenCoercesAllAttributesInJwtSpec() throws Exception {
+		MappedJwtClaimSetConverter converter =
+				MappedJwtClaimSetConverter.withDefaults(Collections.emptyMap());
+
+		Map<String, Object> source = new HashMap<>();
+		source.put(JwtClaimNames.JTI, 1);
+		source.put(JwtClaimNames.AUD, "audience");
+		source.put(JwtClaimNames.EXP, 2000000000L);
+		source.put(JwtClaimNames.IAT, new Date(1000000000000L));
+		source.put(JwtClaimNames.ISS, "https://any.url");
+		source.put(JwtClaimNames.NBF, 1000000000);
+		source.put(JwtClaimNames.SUB, 1234);
+
+		Map<String, Object> target = converter.convert(source);
+
+		assertThat(target.get(JwtClaimNames.JTI)).isEqualTo("1");
+		assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("audience"));
+		assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(Instant.ofEpochSecond(2000000000L));
+		assertThat(target.get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1000000000L));
+		assertThat(target.get(JwtClaimNames.ISS)).isEqualTo(new URL("https://any.url"));
+		assertThat(target.get(JwtClaimNames.NBF)).isEqualTo(Instant.ofEpochSecond(1000000000L));
+		assertThat(target.get(JwtClaimNames.SUB)).isEqualTo("1234");
+	}
+
+	@Test
+	public void convertWhenUsingCustomConverterThenAllOtherDefaultsAreStillUsed() throws Exception {
+		Converter<Object, String> claimConverter = mock(Converter.class);
+		MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter
+				.withDefaults(Collections.singletonMap(JwtClaimNames.SUB, claimConverter));
+		when(claimConverter.convert(any(Object.class))).thenReturn("1234");
+
+		Map<String, Object> source = new HashMap<>();
+		source.put(JwtClaimNames.JTI, 1);
+		source.put(JwtClaimNames.AUD, "audience");
+		source.put(JwtClaimNames.EXP, Instant.ofEpochSecond(2000000000L));
+		source.put(JwtClaimNames.IAT, new Date(1000000000000L));
+		source.put(JwtClaimNames.ISS, URI.create("https://any.url"));
+		source.put(JwtClaimNames.NBF, "1000000000");
+		source.put(JwtClaimNames.SUB, 2345);
+
+		Map<String, Object> target = converter.convert(source);
+
+		assertThat(target.get(JwtClaimNames.JTI)).isEqualTo("1");
+		assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(Arrays.asList("audience"));
+		assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(Instant.ofEpochSecond(2000000000L));
+		assertThat(target.get(JwtClaimNames.IAT)).isEqualTo(Instant.ofEpochSecond(1000000000L));
+		assertThat(target.get(JwtClaimNames.ISS)).isEqualTo(new URL("https://any.url"));
+		assertThat(target.get(JwtClaimNames.NBF)).isEqualTo(Instant.ofEpochSecond(1000000000L));
+		assertThat(target.get(JwtClaimNames.SUB)).isEqualTo("1234");
+	}
+
+	@Test
+	public void convertWhenConverterReturnsNullThenClaimIsRemoved() {
+		MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter
+				.withDefaults(Collections.emptyMap());
+
+		Map<String, Object> source = Collections.singletonMap(JwtClaimNames.ISS, null);
+		Map<String, Object> target = converter.convert(source);
+
+		assertThat(target).doesNotContainKey(JwtClaimNames.ISS);
+	}
+
+	@Test
+	public void convertWhenConverterReturnsValueWhenEntryIsMissingThenEntryIsAdded() {
+		Converter<Object, String> claimConverter = mock(Converter.class);
+		MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter
+				.withDefaults(Collections.singletonMap("custom-claim", claimConverter));
+		when(claimConverter.convert(any())).thenReturn("custom-value");
+
+		Map<String, Object> source = new HashMap<>();
+		Map<String, Object> target = converter.convert(source);
+
+		assertThat(target.get("custom-claim")).isEqualTo("custom-value");
+	}
+
+	@Test
+	public void convertWhenUsingConstructorThenOnlyConvertersInThatMapAreUsedForConversion() {
+		Converter<Object, String> claimConverter = mock(Converter.class);
+		MappedJwtClaimSetConverter converter = new MappedJwtClaimSetConverter(
+				Collections.singletonMap(JwtClaimNames.SUB, claimConverter));
+		when(claimConverter.convert(any(Object.class))).thenReturn("1234");
+
+		Map<String, Object> source = new HashMap<>();
+		source.put(JwtClaimNames.JTI, new Object());
+		source.put(JwtClaimNames.AUD, new Object());
+		source.put(JwtClaimNames.EXP, Instant.ofEpochSecond(1L));
+		source.put(JwtClaimNames.IAT, Instant.ofEpochSecond(1L));
+		source.put(JwtClaimNames.ISS, new Object());
+		source.put(JwtClaimNames.NBF, new Object());
+		source.put(JwtClaimNames.SUB, new Object());
+
+		Map<String, Object> target = converter.convert(source);
+
+		assertThat(target.get(JwtClaimNames.JTI)).isEqualTo(source.get(JwtClaimNames.JTI));
+		assertThat(target.get(JwtClaimNames.AUD)).isEqualTo(source.get(JwtClaimNames.AUD));
+		assertThat(target.get(JwtClaimNames.EXP)).isEqualTo(source.get(JwtClaimNames.EXP));
+		assertThat(target.get(JwtClaimNames.IAT)).isEqualTo(source.get(JwtClaimNames.IAT));
+		assertThat(target.get(JwtClaimNames.ISS)).isEqualTo(source.get(JwtClaimNames.ISS));
+		assertThat(target.get(JwtClaimNames.NBF)).isEqualTo(source.get(JwtClaimNames.NBF));
+		assertThat(target.get(JwtClaimNames.SUB)).isEqualTo("1234");
+	}
+
+	@Test
+	public void convertWhenUsingDefaultsThenFailedConversionThrowsIllegalStateException() {
+		MappedJwtClaimSetConverter converter = MappedJwtClaimSetConverter
+				.withDefaults(Collections.emptyMap());
+
+		Map<String, Object> badIssuer = Collections.singletonMap(JwtClaimNames.ISS, "badly-formed-iss");
+		assertThatCode(() -> converter.convert(badIssuer)).isInstanceOf(IllegalStateException.class);
+
+		Map<String, Object> badIssuedAt = Collections.singletonMap(JwtClaimNames.IAT, "badly-formed-iat");
+		assertThatCode(() -> converter.convert(badIssuedAt)).isInstanceOf(IllegalStateException.class);
+
+		Map<String, Object> badExpiresAt = Collections.singletonMap(JwtClaimNames.EXP, "badly-formed-exp");
+		assertThatCode(() -> converter.convert(badExpiresAt)).isInstanceOf(IllegalStateException.class);
+
+		Map<String, Object> badNotBefore = Collections.singletonMap(JwtClaimNames.NBF, "badly-formed-nbf");
+		assertThatCode(() -> converter.convert(badNotBefore)).isInstanceOf(IllegalStateException.class);
+	}
+
+	@Test
+	public void constructWhenAnyParameterIsNullThenIllegalArgumentException() {
+		assertThatCode(() -> new MappedJwtClaimSetConverter(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void withDefaultsWhenAnyParameterIsNullThenIllegalArgumentException() {
+		assertThatCode(() -> MappedJwtClaimSetConverter.withDefaults(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+}

+ 28 - 0
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java

@@ -16,6 +16,8 @@
 package org.springframework.security.oauth2.jwt;
 
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.Map;
 
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.JWSHeader;
@@ -33,6 +35,7 @@ import org.powermock.core.classloader.annotations.PowerMockIgnore;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.RequestEntity;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
@@ -40,6 +43,7 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
 import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
 import org.springframework.web.client.RestTemplate;
 
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
@@ -228,4 +232,28 @@ public class NimbusJwtDecoderJwkSupportTests {
 					.hasFieldOrPropertyWithValue("errors", Arrays.asList(firstFailure, secondFailure));
 		}
 	}
+
+	@Test
+	public void decodeWhenUsingSignedJwtThenReturnsClaimsGivenByClaimSetConverter() throws Exception {
+		try ( MockWebServer server = new MockWebServer() ) {
+			server.enqueue(new MockResponse().setBody(JWK_SET));
+			String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
+
+			NimbusJwtDecoderJwkSupport decoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
+
+			Converter<Map<String, Object>, Map<String, Object>> claimSetConverter = mock(Converter.class);
+			when(claimSetConverter.convert(any(Map.class))).thenReturn(Collections.singletonMap("custom", "value"));
+			decoder.setClaimSetConverter(claimSetConverter);
+
+			Jwt jwt = decoder.decode(SIGNED_JWT);
+			assertThat(jwt.getClaims().size()).isEqualTo(1);
+			assertThat(jwt.getClaims().get("custom")).isEqualTo("value");
+		}
+	}
+
+	@Test
+	public void setClaimSetConverterWhenIsNullThenThrowsIllegalArgumentException() {
+		assertThatCode(() -> jwtDecoder.setClaimSetConverter(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
 }