浏览代码

Polish JwtEncoder APIs

Closes gh-391
Joe Grandja 4 年之前
父节点
当前提交
6b5d9f0fe5

+ 63 - 60
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jwt/JoseHeader.java

@@ -24,7 +24,7 @@ import java.util.Set;
 import java.util.function.Consumer;
 
 import org.springframework.security.oauth2.core.converter.ClaimConversionService;
-import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
+import org.springframework.security.oauth2.jose.JwaAlgorithm;
 import org.springframework.util.Assert;
 
 /**
@@ -36,9 +36,9 @@ import org.springframework.util.Assert;
  * @author Joe Grandja
  * @since 0.0.1
  * @see Jwt
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7519#section-5">JWT JOSE Header</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7515#section-4">JWS JOSE Header</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7516#section-4">JWE JOSE Header</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7519#section-5">JWT JOSE Header</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7515#section-4">JWS JOSE Header</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7516#section-4">JWE JOSE Header</a>
  */
 public final class JoseHeader {
 	private final Map<String, Object> headers;
@@ -48,12 +48,13 @@ public final class JoseHeader {
 	}
 
 	/**
-	 * Returns the {@link JwsAlgorithm JWS algorithm} used to digitally sign the JWS.
+	 * Returns the {@link JwaAlgorithm JWA algorithm} used to digitally sign the JWS or encrypt the JWE.
 	 *
-	 * @return the JWS algorithm
+	 * @return the {@link JwaAlgorithm}
 	 */
-	public JwsAlgorithm getJwsAlgorithm() {
-		return getHeader(JoseHeaderNames.ALG);
+	@SuppressWarnings("unchecked")
+	public <T extends JwaAlgorithm> T getAlgorithm() {
+		return (T) getHeader(JoseHeaderNames.ALG);
 	}
 
 	/**
@@ -62,7 +63,7 @@ public final class JoseHeader {
 	 *
 	 * @return the JWK Set URL
 	 */
-	public URL getJwkSetUri() {
+	public URL getJwkSetUrl() {
 		return getHeader(JoseHeaderNames.JKU);
 	}
 
@@ -91,13 +92,16 @@ public final class JoseHeader {
 	 *
 	 * @return the X.509 URL
 	 */
-	public URL getX509Uri() {
+	public URL getX509Url() {
 		return getHeader(JoseHeaderNames.X5U);
 	}
 
 	/**
 	 * Returns the X.509 certificate chain that contains the X.509 public key certificate
-	 * or certificate chain corresponding to the key used to digitally sign the JWS or encrypt the JWE.
+	 * or certificate chain corresponding to the key used to digitally sign the JWS or
+	 * encrypt the JWE. The certificate or certificate chain is represented as a
+	 * {@code List} of certificate value {@code String}s. Each {@code String} in the
+	 * {@code List} is a Base64-encoded DER PKIX certificate value.
 	 *
 	 * @return the X.509 certificate chain
 	 */
@@ -125,16 +129,6 @@ public final class JoseHeader {
 		return getHeader(JoseHeaderNames.X5T_S256);
 	}
 
-	/**
-	 * Returns the critical headers that indicates which extensions to the JWS/JWE/JWA specifications
-	 * are being used that MUST be understood and processed.
-	 *
-	 * @return the critical headers
-	 */
-	public Set<String> getCritical() {
-		return getHeader(JoseHeaderNames.CRIT);
-	}
-
 	/**
 	 * Returns the type header that declares the media type of the JWS/JWE.
 	 *
@@ -153,6 +147,16 @@ public final class JoseHeader {
 		return getHeader(JoseHeaderNames.CTY);
 	}
 
+	/**
+	 * Returns the critical headers that indicates which extensions to the JWS/JWE/JWA specifications
+	 * are being used that MUST be understood and processed.
+	 *
+	 * @return the critical headers
+	 */
+	public Set<String> getCritical() {
+		return getHeader(JoseHeaderNames.CRIT);
+	}
+
 	/**
 	 * Returns the headers.
 	 *
@@ -185,13 +189,13 @@ public final class JoseHeader {
 	}
 
 	/**
-	 * Returns a new {@link Builder}, initialized with the provided {@link JwsAlgorithm}.
+	 * Returns a new {@link Builder}, initialized with the provided {@link JwaAlgorithm}.
 	 *
-	 * @param jwsAlgorithm the {@link JwsAlgorithm}
+	 * @param jwaAlgorithm the {@link JwaAlgorithm}
 	 * @return the {@link Builder}
 	 */
-	public static Builder withAlgorithm(JwsAlgorithm jwsAlgorithm) {
-		return new Builder(jwsAlgorithm);
+	public static Builder withAlgorithm(JwaAlgorithm jwaAlgorithm) {
+		return new Builder(jwaAlgorithm);
 	}
 
 	/**
@@ -213,9 +217,8 @@ public final class JoseHeader {
 		private Builder() {
 		}
 
-		private Builder(JwsAlgorithm jwsAlgorithm) {
-			Assert.notNull(jwsAlgorithm, "jwsAlgorithm cannot be null");
-			header(JoseHeaderNames.ALG, jwsAlgorithm);
+		private Builder(JwaAlgorithm jwaAlgorithm) {
+			algorithm(jwaAlgorithm);
 		}
 
 		private Builder(JoseHeader headers) {
@@ -224,24 +227,25 @@ public final class JoseHeader {
 		}
 
 		/**
-		 * Sets the {@link JwsAlgorithm JWS algorithm} used to digitally sign the JWS.
+		 * Sets the {@link JwaAlgorithm JWA algorithm} used to digitally sign the JWS or encrypt the JWE.
 		 *
-		 * @param jwsAlgorithm the JWS algorithm
+		 * @param jwaAlgorithm the {@link JwaAlgorithm}
 		 * @return the {@link Builder}
 		 */
-		public Builder jwsAlgorithm(JwsAlgorithm jwsAlgorithm) {
-			return header(JoseHeaderNames.ALG, jwsAlgorithm);
+		public Builder algorithm(JwaAlgorithm jwaAlgorithm) {
+			Assert.notNull(jwaAlgorithm, "jwaAlgorithm cannot be null");
+			return header(JoseHeaderNames.ALG, jwaAlgorithm);
 		}
 
 		/**
 		 * Sets the JWK Set URL that refers to the resource of a set of JSON-encoded public keys,
 		 * one of which corresponds to the key used to digitally sign the JWS or encrypt the JWE.
 		 *
-		 * @param jwkSetUri the JWK Set URL
+		 * @param jwkSetUrl the JWK Set URL
 		 * @return the {@link Builder}
 		 */
-		public Builder jwkSetUri(String jwkSetUri) {
-			return header(JoseHeaderNames.JKU, jwkSetUri);
+		public Builder jwkSetUrl(String jwkSetUrl) {
+			return header(JoseHeaderNames.JKU, convertAsURL(JoseHeaderNames.JKU, jwkSetUrl));
 		}
 
 		/**
@@ -269,16 +273,19 @@ public final class JoseHeader {
 		 * Sets the X.509 URL that refers to the resource for the X.509 public key certificate
 		 * or certificate chain corresponding to the key used to digitally sign the JWS or encrypt the JWE.
 		 *
-		 * @param x509Uri the X.509 URL
+		 * @param x509Url the X.509 URL
 		 * @return the {@link Builder}
 		 */
-		public Builder x509Uri(String x509Uri) {
-			return header(JoseHeaderNames.X5U, x509Uri);
+		public Builder x509Url(String x509Url) {
+			return header(JoseHeaderNames.X5U, convertAsURL(JoseHeaderNames.X5U, x509Url));
 		}
 
 		/**
 		 * Sets the X.509 certificate chain that contains the X.509 public key certificate
-		 * or certificate chain corresponding to the key used to digitally sign the JWS or encrypt the JWE.
+		 * or certificate chain corresponding to the key used to digitally sign the JWS or
+		 * encrypt the JWE. The certificate or certificate chain is represented as a
+		 * {@code List} of certificate value {@code String}s. Each {@code String} in the
+		 * {@code List} is a Base64-encoded DER PKIX certificate value.
 		 *
 		 * @param x509CertificateChain the X.509 certificate chain
 		 * @return the {@link Builder}
@@ -309,17 +316,6 @@ public final class JoseHeader {
 			return header(JoseHeaderNames.X5T_S256, x509SHA256Thumbprint);
 		}
 
-		/**
-		 * Sets the critical headers that indicates which extensions to the JWS/JWE/JWA specifications
-		 * are being used that MUST be understood and processed.
-		 *
-		 * @param headerNames the critical header names
-		 * @return the {@link Builder}
-		 */
-		public Builder critical(Set<String> headerNames) {
-			return header(JoseHeaderNames.CRIT, headerNames);
-		}
-
 		/**
 		 * Sets the type header that declares the media type of the JWS/JWE.
 		 *
@@ -340,6 +336,17 @@ public final class JoseHeader {
 			return header(JoseHeaderNames.CTY, contentType);
 		}
 
+		/**
+		 * Sets the critical headers that indicates which extensions to the JWS/JWE/JWA specifications
+		 * are being used that MUST be understood and processed.
+		 *
+		 * @param headerNames the critical header names
+		 * @return the {@link Builder}
+		 */
+		public Builder critical(Set<String> headerNames) {
+			return header(JoseHeaderNames.CRIT, headerNames);
+		}
+
 		/**
 		 * Sets the header.
 		 *
@@ -373,19 +380,15 @@ public final class JoseHeader {
 		 */
 		public JoseHeader build() {
 			Assert.notEmpty(this.headers, "headers cannot be empty");
-			convertAsURL(JoseHeaderNames.JKU);
-			convertAsURL(JoseHeaderNames.X5U);
 			return new JoseHeader(this.headers);
 		}
 
-		private void convertAsURL(String header) {
-			Object value = this.headers.get(header);
-			if (value != null) {
-				URL convertedValue = ClaimConversionService.getSharedInstance().convert(value, URL.class);
-				Assert.isTrue(convertedValue != null,
-						() -> "Unable to convert header '" + header + "' of type '" + value.getClass() + "' to URL.");
-				this.headers.put(header, convertedValue);
-			}
+		private static URL convertAsURL(String header, String value) {
+			URL convertedValue = ClaimConversionService.getSharedInstance().convert(value, URL.class);
+			Assert.isTrue(convertedValue != null,
+					() -> "Unable to convert header '" + header + "' of type '" + value.getClass() + "' to URL.");
+			return convertedValue;
 		}
+
 	}
 }

+ 16 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jwt/JwtClaimsSet.java

@@ -15,6 +15,7 @@
  */
 package org.springframework.security.oauth2.jwt;
 
+import java.net.URL;
 import java.time.Instant;
 import java.util.Collections;
 import java.util.HashMap;
@@ -22,6 +23,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.function.Consumer;
 
+import org.springframework.security.oauth2.core.converter.ClaimConversionService;
 import org.springframework.util.Assert;
 
 /**
@@ -32,7 +34,7 @@ import org.springframework.util.Assert;
  * @since 0.0.1
  * @see Jwt
  * @see JwtClaimAccessor
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7519#section-4">JWT Claims Set</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7519#section-4">JWT Claims Set</a>
  */
 public final class JwtClaimsSet implements JwtClaimAccessor {
 	private final Map<String, Object> claims;
@@ -166,10 +168,10 @@ public final class JwtClaimsSet implements JwtClaimAccessor {
 		}
 
 		/**
-		 * A {@code Consumer} to be provided access to the claims set
+		 * A {@code Consumer} to be provided access to the claims
 		 * allowing the ability to add, replace, or remove.
 		 *
-		 * @param claimsConsumer a {@code Consumer} of the claims set
+		 * @param claimsConsumer a {@code Consumer} of the claims
 		 */
 		public Builder claims(Consumer<Map<String, Object>> claimsConsumer) {
 			claimsConsumer.accept(this.claims);
@@ -183,6 +185,17 @@ public final class JwtClaimsSet implements JwtClaimAccessor {
 		 */
 		public JwtClaimsSet build() {
 			Assert.notEmpty(this.claims, "claims cannot be empty");
+
+			// The value of the 'iss' claim is a String or URL (StringOrURI).
+			// Attempt to convert to URL.
+			Object issuer = this.claims.get(JwtClaimNames.ISS);
+			if (issuer != null) {
+				URL convertedValue = ClaimConversionService.getSharedInstance().convert(issuer, URL.class);
+				if (convertedValue != null) {
+					this.claims.put(JwtClaimNames.ISS, convertedValue);
+				}
+			}
+
 			return new JwtClaimsSet(this.claims);
 		}
 	}

+ 5 - 5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jwt/JwtEncoder.java

@@ -32,11 +32,11 @@ package org.springframework.security.oauth2.jwt;
  * @see JoseHeader
  * @see JwtClaimsSet
  * @see JwtDecoder
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7519">JSON Web Token (JWT)</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7515">JSON Web Signature (JWS)</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7516">JSON Web Encryption (JWE)</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7515#section-3.1">JWS Compact Serialization</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7516#section-3.1">JWE Compact Serialization</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7519">JSON Web Token (JWT)</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7515">JSON Web Signature (JWS)</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7516">JSON Web Encryption (JWE)</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7515#section-3.1">JWS Compact Serialization</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7516#section-3.1">JWE Compact Serialization</a>
  */
 @FunctionalInterface
 public interface JwtEncoder {

+ 143 - 110
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java

@@ -15,26 +15,28 @@
  */
 package org.springframework.security.oauth2.jwt;
 
+import java.net.URI;
 import java.net.URL;
 import java.time.Instant;
+import java.util.ArrayList;
 import java.util.Date;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.stream.Collectors;
 
 import com.nimbusds.jose.JOSEException;
 import com.nimbusds.jose.JOSEObjectType;
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jose.JWSSigner;
-import com.nimbusds.jose.KeySourceException;
 import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
 import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.JWKMatcher;
 import com.nimbusds.jose.jwk.JWKSelector;
+import com.nimbusds.jose.jwk.KeyType;
+import com.nimbusds.jose.jwk.KeyUse;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.proc.SecurityContext;
 import com.nimbusds.jose.produce.JWSSignerFactory;
@@ -62,27 +64,17 @@ import org.springframework.util.StringUtils;
  * @see JwtEncoder
  * @see com.nimbusds.jose.jwk.source.JWKSource
  * @see com.nimbusds.jose.jwk.JWK
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7519">JSON Web Token
- * (JWT)</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7515">JSON Web Signature
- * (JWS)</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7515#section-3.1">JWS
- * Compact Serialization</a>
- * @see <a target="_blank" href="https://connect2id.com/products/nimbus-jose-jwt">Nimbus
- * JOSE + JWT SDK</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7519">JSON Web Token (JWT)</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7515">JSON Web Signature (JWS)</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc7515#section-3.1">JWS Compact Serialization</a>
+ * @see <a target="_blank" href="https://connect2id.com/products/nimbus-jose-jwt">Nimbus JOSE + JWT SDK</a>
  */
 public final class NimbusJwsEncoder implements JwtEncoder {
-
 	private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to encode the Jwt: %s";
-
 	private static final Converter<JoseHeader, JWSHeader> JWS_HEADER_CONVERTER = new JwsHeaderConverter();
-
 	private static final Converter<JwtClaimsSet, JWTClaimsSet> JWT_CLAIMS_SET_CONVERTER = new JwtClaimsSetConverter();
-
 	private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();
-
 	private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();
-
 	private final JWKSource<SecurityContext> jwkSource;
 
 	/**
@@ -100,108 +92,126 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 		Assert.notNull(claims, "claims cannot be null");
 
 		JWK jwk = selectJwk(headers);
-		if (jwk == null) {
-			throw new JwtEncodingException(
-					String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
+		headers = addKeyIdentifierHeadersIfNecessary(headers, jwk);
+
+		String jws = serialize(headers, claims, jwk);
+
+		return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims());
+	}
+
+	private JWK selectJwk(JoseHeader headers) {
+		List<JWK> jwks;
+		try {
+			JWKSelector jwkSelector = new JWKSelector(createJwkMatcher(headers));
+			jwks = this.jwkSource.get(jwkSelector, null);
+		} catch (Exception ex) {
+			throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
+					"Failed to select a JWK signing key -> " + ex.getMessage()), ex);
+		}
+
+		if (jwks.size() > 1) {
+			throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
+					"Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'"));
 		}
-		else if (!StringUtils.hasText(jwk.getKeyID())) {
+
+		if (jwks.isEmpty()) {
 			throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
-					"The \"kid\" (key ID) from the selected JWK cannot be empty"));
+					"Failed to select a JWK signing key"));
 		}
 
-		// @formatter:off
-		headers = JoseHeader.from(headers)
-				.type(JOSEObjectType.JWT.getType())
-				.keyId(jwk.getKeyID())
-				.build();
-		claims = JwtClaimsSet.from(claims)
-				.id(UUID.randomUUID().toString())
-				.build();
-		// @formatter:on
+		return jwks.get(0);
+	}
 
+	private String serialize(JoseHeader headers, JwtClaimsSet claims, JWK jwk) {
 		JWSHeader jwsHeader = JWS_HEADER_CONVERTER.convert(headers);
 		JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims);
 
-		JWSSigner jwsSigner = this.jwsSigners.computeIfAbsent(jwk, (key) -> {
-			try {
-				return JWS_SIGNER_FACTORY.createJWSSigner(key);
-			}
-			catch (JOSEException ex) {
-				throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
-						"Failed to create a JWS Signer -> " + ex.getMessage()), ex);
-			}
-		});
+		JWSSigner jwsSigner = this.jwsSigners.computeIfAbsent(jwk, NimbusJwsEncoder::createSigner);
 
 		SignedJWT signedJwt = new SignedJWT(jwsHeader, jwtClaimsSet);
 		try {
 			signedJwt.sign(jwsSigner);
+		} catch (JOSEException ex) {
+			throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
+					"Failed to sign the JWT -> " + ex.getMessage()), ex);
 		}
-		catch (JOSEException ex) {
-			throw new JwtEncodingException(
-					String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to sign the JWT -> " + ex.getMessage()), ex);
+		return signedJwt.serialize();
+	}
+
+	private static JWKMatcher createJwkMatcher(JoseHeader headers) {
+		JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(headers.getAlgorithm().getName());
+
+		if (JWSAlgorithm.Family.RSA.contains(jwsAlgorithm) || JWSAlgorithm.Family.EC.contains(jwsAlgorithm)) {
+			// @formatter:off
+			return new JWKMatcher.Builder()
+					.keyType(KeyType.forAlgorithm(jwsAlgorithm))
+					.keyID(headers.getKeyId())
+					.keyUses(KeyUse.SIGNATURE, null)
+					.algorithms(jwsAlgorithm, null)
+					.x509CertSHA256Thumbprint(Base64URL.from(headers.getX509SHA256Thumbprint()))
+					.build();
+			// @formatter:on
+		} else if (JWSAlgorithm.Family.HMAC_SHA.contains(jwsAlgorithm)) {
+			// @formatter:off
+			return new JWKMatcher.Builder()
+					.keyType(KeyType.forAlgorithm(jwsAlgorithm))
+					.keyID(headers.getKeyId())
+					.privateOnly(true)
+					.algorithms(jwsAlgorithm, null)
+					.build();
+			// @formatter:on
 		}
-		String jws = signedJwt.serialize();
 
-		return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims());
+		return null;
 	}
 
-	private JWK selectJwk(JoseHeader headers) {
-		JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(headers.getJwsAlgorithm().getName());
-		JWSHeader jwsHeader = new JWSHeader(jwsAlgorithm);
-		JWKSelector jwkSelector = new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader));
+	private static JoseHeader addKeyIdentifierHeadersIfNecessary(JoseHeader headers, JWK jwk) {
+		// Check if headers have already been added
+		if (StringUtils.hasText(headers.getKeyId()) && StringUtils.hasText(headers.getX509SHA256Thumbprint())) {
+			return headers;
+		}
+		// Check if headers can be added from JWK
+		if (!StringUtils.hasText(jwk.getKeyID()) && jwk.getX509CertSHA256Thumbprint() == null) {
+			return headers;
+		}
 
-		List<JWK> jwks;
-		try {
-			jwks = this.jwkSource.get(jwkSelector, null);
+		JoseHeader.Builder headersBuilder = JoseHeader.from(headers);
+		if (!StringUtils.hasText(headers.getKeyId()) && StringUtils.hasText(jwk.getKeyID())) {
+			headersBuilder.keyId(jwk.getKeyID());
 		}
-		catch (KeySourceException ex) {
-			throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
-					"Failed to select a JWK signing key -> " + ex.getMessage()), ex);
+		if (!StringUtils.hasText(headers.getX509SHA256Thumbprint()) && jwk.getX509CertSHA256Thumbprint() != null) {
+			headersBuilder.x509SHA256Thumbprint(jwk.getX509CertSHA256Thumbprint().toString());
 		}
 
-		if (jwks.size() > 1) {
+		return headersBuilder.build();
+	}
+
+	private static JWSSigner createSigner(JWK jwk) {
+		try {
+			return JWS_SIGNER_FACTORY.createJWSSigner(jwk);
+		} catch (JOSEException ex) {
 			throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
-					"Found multiple JWK signing keys for algorithm '" + jwsAlgorithm.getName() + "'"));
+					"Failed to create a JWS Signer -> " + ex.getMessage()), ex);
 		}
-
-		return !jwks.isEmpty() ? jwks.get(0) : null;
 	}
 
 	private static class JwsHeaderConverter implements Converter<JoseHeader, JWSHeader> {
 
 		@Override
 		public JWSHeader convert(JoseHeader headers) {
-			JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.parse(headers.getJwsAlgorithm().getName()));
+			JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.parse(headers.getAlgorithm().getName()));
 
-			Set<String> critical = headers.getCritical();
-			if (!CollectionUtils.isEmpty(critical)) {
-				builder.criticalParams(critical);
-			}
-
-			String contentType = headers.getContentType();
-			if (StringUtils.hasText(contentType)) {
-				builder.contentType(contentType);
-			}
-
-			URL jwkSetUri = headers.getJwkSetUri();
-			if (jwkSetUri != null) {
-				try {
-					builder.jwkURL(jwkSetUri.toURI());
-				}
-				catch (Exception ex) {
-					throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
-							"Failed to convert '" + JoseHeaderNames.JKU + "' JOSE header to a URI"), ex);
-				}
+			if (headers.getJwkSetUrl() != null) {
+				builder.jwkURL(convertAsURI(JoseHeaderNames.JKU, headers.getJwkSetUrl()));
 			}
 
 			Map<String, Object> jwk = headers.getJwk();
 			if (!CollectionUtils.isEmpty(jwk)) {
 				try {
 					builder.jwk(JWK.parse(jwk));
-				}
-				catch (Exception ex) {
+				} catch (Exception ex) {
 					throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
-							"Failed to convert '" + JoseHeaderNames.JWK + "' JOSE header"), ex);
+							"Unable to convert '" + JoseHeaderNames.JWK + "' JOSE header"), ex);
 				}
 			}
 
@@ -210,14 +220,17 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 				builder.keyID(keyId);
 			}
 
-			String type = headers.getType();
-			if (StringUtils.hasText(type)) {
-				builder.type(new JOSEObjectType(type));
+			if (headers.getX509Url() != null) {
+				builder.x509CertURL(convertAsURI(JoseHeaderNames.X5U, headers.getX509Url()));
 			}
 
 			List<String> x509CertificateChain = headers.getX509CertificateChain();
 			if (!CollectionUtils.isEmpty(x509CertificateChain)) {
-				builder.x509CertChain(x509CertificateChain.stream().map(Base64::new).collect(Collectors.toList()));
+				List<Base64> x5cList = new ArrayList<>();
+				x509CertificateChain.forEach((x5c) -> x5cList.add(new Base64(x5c)));
+				if (!x5cList.isEmpty()) {
+					builder.x509CertChain(x5cList);
+				}
 			}
 
 			String x509SHA1Thumbprint = headers.getX509SHA1Thumbprint();
@@ -230,27 +243,43 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 				builder.x509CertSHA256Thumbprint(new Base64URL(x509SHA256Thumbprint));
 			}
 
-			URL x509Uri = headers.getX509Uri();
-			if (x509Uri != null) {
-				try {
-					builder.x509CertURL(x509Uri.toURI());
-				}
-				catch (Exception ex) {
-					throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
-							"Failed to convert '" + JoseHeaderNames.X5U + "' JOSE header to a URI"), ex);
-				}
+			String type = headers.getType();
+			if (StringUtils.hasText(type)) {
+				builder.type(new JOSEObjectType(type));
+			}
+
+			String contentType = headers.getContentType();
+			if (StringUtils.hasText(contentType)) {
+				builder.contentType(contentType);
+			}
+
+			Set<String> critical = headers.getCritical();
+			if (!CollectionUtils.isEmpty(critical)) {
+				builder.criticalParams(critical);
 			}
 
-			Map<String, Object> customHeaders = headers.getHeaders().entrySet().stream()
-					.filter((header) -> !JWSHeader.getRegisteredParameterNames().contains(header.getKey()))
-					.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
-			if (!CollectionUtils.isEmpty(customHeaders)) {
+			Map<String, Object> customHeaders = new HashMap<>();
+			headers.getHeaders().forEach((name, value) -> {
+				if (!JWSHeader.getRegisteredParameterNames().contains(name)) {
+					customHeaders.put(name, value);
+				}
+			});
+			if (!customHeaders.isEmpty()) {
 				builder.customParams(customHeaders);
 			}
 
 			return builder.build();
 		}
 
+		private static URI convertAsURI(String header, URL url) {
+			try {
+				return url.toURI();
+			} catch (Exception ex) {
+				throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
+						"Unable to convert '" + header + "' JOSE header to a URI"), ex);
+			}
+		}
+
 	}
 
 	private static class JwtClaimsSetConverter implements Converter<JwtClaimsSet, JWTClaimsSet> {
@@ -259,9 +288,10 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 		public JWTClaimsSet convert(JwtClaimsSet claims) {
 			JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder();
 
-			URL issuer = claims.getIssuer();
+			// NOTE: The value of the 'iss' claim is a String or URL (StringOrURI).
+			Object issuer = claims.getClaim(JwtClaimNames.ISS);
 			if (issuer != null) {
-				builder.issuer(issuer.toExternalForm());
+				builder.issuer(issuer.toString());
 			}
 
 			String subject = claims.getSubject();
@@ -274,11 +304,6 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 				builder.audience(audience);
 			}
 
-			Instant issuedAt = claims.getIssuedAt();
-			if (issuedAt != null) {
-				builder.issueTime(Date.from(issuedAt));
-			}
-
 			Instant expiresAt = claims.getExpiresAt();
 			if (expiresAt != null) {
 				builder.expirationTime(Date.from(expiresAt));
@@ -289,15 +314,23 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 				builder.notBeforeTime(Date.from(notBefore));
 			}
 
+			Instant issuedAt = claims.getIssuedAt();
+			if (issuedAt != null) {
+				builder.issueTime(Date.from(issuedAt));
+			}
+
 			String jwtId = claims.getId();
 			if (StringUtils.hasText(jwtId)) {
 				builder.jwtID(jwtId);
 			}
 
-			Map<String, Object> customClaims = claims.getClaims().entrySet().stream()
-					.filter((claim) -> !JWTClaimsSet.getRegisteredNames().contains(claim.getKey()))
-					.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
-			if (!CollectionUtils.isEmpty(customClaims)) {
+			Map<String, Object> customClaims = new HashMap<>();
+			claims.getClaims().forEach((name, value) -> {
+				if (!JWTClaimsSet.getRegisteredNames().contains(name)) {
+					customClaims.put(name, value);
+				}
+			});
+			if (!customClaims.isEmpty()) {
 				customClaims.forEach(builder::claim);
 			}
 

+ 8 - 7
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jwt/JoseHeaderTests.java

@@ -17,6 +17,7 @@ package org.springframework.security.oauth2.jwt;
 
 import org.junit.Test;
 
+import org.springframework.security.oauth2.jose.JwaAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -33,18 +34,18 @@ public class JoseHeaderTests {
 	public void withAlgorithmWhenNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> JoseHeader.withAlgorithm(null))
 				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("jwsAlgorithm cannot be null");
+				.hasMessage("jwaAlgorithm cannot be null");
 	}
 
 	@Test
 	public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() {
 		JoseHeader expectedJoseHeader = TestJoseHeaders.joseHeader().build();
 
-		JoseHeader joseHeader = JoseHeader.withAlgorithm(expectedJoseHeader.getJwsAlgorithm())
-				.jwkSetUri(expectedJoseHeader.getJwkSetUri().toExternalForm())
+		JoseHeader joseHeader = JoseHeader.withAlgorithm(expectedJoseHeader.getAlgorithm())
+				.jwkSetUrl(expectedJoseHeader.getJwkSetUrl().toExternalForm())
 				.jwk(expectedJoseHeader.getJwk())
 				.keyId(expectedJoseHeader.getKeyId())
-				.x509Uri(expectedJoseHeader.getX509Uri().toExternalForm())
+				.x509Url(expectedJoseHeader.getX509Url().toExternalForm())
 				.x509CertificateChain(expectedJoseHeader.getX509CertificateChain())
 				.x509SHA1Thumbprint(expectedJoseHeader.getX509SHA1Thumbprint())
 				.x509SHA256Thumbprint(expectedJoseHeader.getX509SHA256Thumbprint())
@@ -53,11 +54,11 @@ public class JoseHeaderTests {
 				.headers(headers -> headers.put("custom-header-name", "custom-header-value"))
 				.build();
 
-		assertThat(joseHeader.getJwsAlgorithm()).isEqualTo(expectedJoseHeader.getJwsAlgorithm());
-		assertThat(joseHeader.getJwkSetUri()).isEqualTo(expectedJoseHeader.getJwkSetUri());
+		assertThat(joseHeader.<JwaAlgorithm>getAlgorithm()).isEqualTo(expectedJoseHeader.getAlgorithm());
+		assertThat(joseHeader.getJwkSetUrl()).isEqualTo(expectedJoseHeader.getJwkSetUrl());
 		assertThat(joseHeader.getJwk()).isEqualTo(expectedJoseHeader.getJwk());
 		assertThat(joseHeader.getKeyId()).isEqualTo(expectedJoseHeader.getKeyId());
-		assertThat(joseHeader.getX509Uri()).isEqualTo(expectedJoseHeader.getX509Uri());
+		assertThat(joseHeader.getX509Url()).isEqualTo(expectedJoseHeader.getX509Url());
 		assertThat(joseHeader.getX509CertificateChain()).isEqualTo(expectedJoseHeader.getX509CertificateChain());
 		assertThat(joseHeader.getX509SHA1Thumbprint()).isEqualTo(expectedJoseHeader.getX509SHA1Thumbprint());
 		assertThat(joseHeader.getX509SHA256Thumbprint()).isEqualTo(expectedJoseHeader.getX509SHA256Thumbprint());

+ 23 - 24
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java

@@ -39,6 +39,7 @@ import org.mockito.stubbing.Answer;
 
 import org.springframework.security.oauth2.jose.TestJwks;
 import org.springframework.security.oauth2.jose.TestKeys;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@@ -111,7 +112,7 @@ public class NimbusJwsEncoderTests {
 		this.jwkList.add(rsaJwk);
 		this.jwkList.add(rsaJwk);
 
-		JoseHeader joseHeader = TestJoseHeaders.joseHeader().build();
+		JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
 		JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
 
 		assertThatExceptionOfType(JwtEncodingException.class)
@@ -129,24 +130,6 @@ public class NimbusJwsEncoderTests {
 				.withMessageContaining("Failed to select a JWK signing key");
 	}
 
-	@Test
-	public void encodeWhenJwkKidNullThenThrowJwtEncodingException() throws Exception {
-		// @formatter:off
-		RSAKey rsaJwk = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY)
-				.keyID(null)
-				.build();
-		// @formatter:on
-
-		this.jwkList.add(rsaJwk);
-
-		JoseHeader joseHeader = TestJoseHeaders.joseHeader().build();
-		JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
-
-		assertThatExceptionOfType(JwtEncodingException.class)
-				.isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet))
-				.withMessageContaining("The \"kid\" (key ID) from the selected JWK cannot be empty");
-	}
-
 	@Test
 	public void encodeWhenJwkUseEncryptionThenThrowJwtEncodingException() throws Exception {
 		// @formatter:off
@@ -172,15 +155,31 @@ public class NimbusJwsEncoderTests {
 		RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
 		this.jwkList.add(rsaJwk);
 
-		JoseHeader joseHeader = TestJoseHeaders.joseHeader().build();
+		JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
 		JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
 
 		Jwt encodedJws = this.jwsEncoder.encode(joseHeader, jwtClaimsSet);
 
-		// Assert headers/claims were added
-		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.TYP)).isEqualTo("JWT");
+		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(joseHeader.getAlgorithm());
+		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.JKU)).isNull();
+		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.JWK)).isNull();
 		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk.getKeyID());
-		assertThat(encodedJws.getId()).isNotNull();
+		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5U)).isNull();
+		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5C)).isNull();
+		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5T)).isNull();
+		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5T_S256)).isNull();
+		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.TYP)).isNull();
+		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.CTY)).isNull();
+		assertThat(encodedJws.getHeaders().get(JoseHeaderNames.CRIT)).isNull();
+
+		assertThat(encodedJws.getIssuer()).isEqualTo(jwtClaimsSet.getIssuer());
+		assertThat(encodedJws.getSubject()).isEqualTo(jwtClaimsSet.getSubject());
+		assertThat(encodedJws.getAudience()).isEqualTo(jwtClaimsSet.getAudience());
+		assertThat(encodedJws.getExpiresAt()).isEqualTo(jwtClaimsSet.getExpiresAt());
+		assertThat(encodedJws.getNotBefore()).isEqualTo(jwtClaimsSet.getNotBefore());
+		assertThat(encodedJws.getIssuedAt()).isEqualTo(jwtClaimsSet.getIssuedAt());
+		assertThat(encodedJws.getId()).isEqualTo(jwtClaimsSet.getId());
+		assertThat(encodedJws.<String>getClaim("custom-claim-name")).isEqualTo("custom-claim-value");
 
 		NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk.toRSAPublicKey()).build();
 		jwtDecoder.decode(encodedJws.getTokenValue());
@@ -200,7 +199,7 @@ public class NimbusJwsEncoderTests {
 		JwkListResultCaptor jwkListResultCaptor = new JwkListResultCaptor();
 		willAnswer(jwkListResultCaptor).given(jwkSourceDelegate).get(any(), any());
 
-		JoseHeader joseHeader = TestJoseHeaders.joseHeader().build();
+		JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
 		JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
 
 		Jwt encodedJws = jwsEncoder.encode(joseHeader, jwtClaimsSet);

+ 2 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jwt/TestJoseHeaders.java

@@ -36,10 +36,10 @@ public final class TestJoseHeaders {
 	public static JoseHeader.Builder joseHeader(SignatureAlgorithm signatureAlgorithm) {
 		// @formatter:off
 		return JoseHeader.withAlgorithm(signatureAlgorithm)
-				.jwkSetUri("https://provider.com/oauth2/jwks")
+				.jwkSetUrl("https://provider.com/oauth2/jwks")
 				.jwk(rsaJwk())
 				.keyId("keyId")
-				.x509Uri("https://provider.com/oauth2/x509")
+				.x509Url("https://provider.com/oauth2/x509")
 				.x509CertificateChain(Arrays.asList("x509Cert1", "x509Cert2"))
 				.x509SHA1Thumbprint("x509SHA1Thumbprint")
 				.x509SHA256Thumbprint("x509SHA256Thumbprint")