Kaynağa Gözat

Polish NimbusJwtEncoder Builders

- Simplify withKeyPair methods to match withPublicKey convention
in NimbusJwtDecoder
- Update tests to confirm support of other algorithms
- Update constructor to apply additional JWK properties
to the default header
- Deduce the possibly algorithms for a given key based
on curve and key size
- Remove algorithm method from EC builder since the
algorithm is determined by the Curve of the EC Key

Issue gh-16267

Co-Authored-By: Suraj Bhadrike <surajbh2233@gmail.com>
Josh Cummings 2 ay önce
ebeveyn
işleme
676b44ebb0

+ 87 - 0
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JWKS.java

@@ -0,0 +1,87 @@
+/*
+ * Copyright 2002-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.jwt;
+
+import java.security.interfaces.ECPrivateKey;
+import java.security.interfaces.ECPublicKey;
+import java.security.interfaces.RSAPrivateKey;
+import java.security.interfaces.RSAPublicKey;
+import java.util.Date;
+import java.util.Set;
+
+import javax.crypto.SecretKey;
+
+import com.nimbusds.jose.JOSEException;
+import com.nimbusds.jose.JWSAlgorithm;
+import com.nimbusds.jose.crypto.impl.ECDSA;
+import com.nimbusds.jose.jwk.Curve;
+import com.nimbusds.jose.jwk.ECKey;
+import com.nimbusds.jose.jwk.KeyOperation;
+import com.nimbusds.jose.jwk.KeyUse;
+import com.nimbusds.jose.jwk.OctetSequenceKey;
+import com.nimbusds.jose.jwk.RSAKey;
+
+final class JWKS {
+
+	private JWKS() {
+
+	}
+
+	static OctetSequenceKey.Builder signing(SecretKey key) throws JOSEException {
+		Date issued = new Date();
+		return new OctetSequenceKey.Builder(key).keyOperations(Set.of(KeyOperation.SIGN))
+			.keyUse(KeyUse.SIGNATURE)
+			.algorithm(JWSAlgorithm.HS256)
+			.keyIDFromThumbprint()
+			.issueTime(issued)
+			.notBeforeTime(issued);
+	}
+
+	static ECKey.Builder signingWithEc(ECPublicKey pub, ECPrivateKey key) throws JOSEException {
+		Date issued = new Date();
+		Curve curve = Curve.forECParameterSpec(pub.getParams());
+		JWSAlgorithm algorithm = computeAlgorithm(curve);
+		return new ECKey.Builder(curve, pub).privateKey(key)
+			.keyOperations(Set.of(KeyOperation.SIGN))
+			.keyUse(KeyUse.SIGNATURE)
+			.algorithm(algorithm)
+			.keyIDFromThumbprint()
+			.issueTime(issued)
+			.notBeforeTime(issued);
+	}
+
+	private static JWSAlgorithm computeAlgorithm(Curve curve) {
+		try {
+			return ECDSA.resolveAlgorithm(curve);
+		}
+		catch (JOSEException ex) {
+			throw new IllegalArgumentException(ex);
+		}
+	}
+
+	static RSAKey.Builder signingWithRsa(RSAPublicKey pub, RSAPrivateKey key) throws JOSEException {
+		Date issued = new Date();
+		return new RSAKey.Builder(pub).privateKey(key)
+			.keyUse(KeyUse.SIGNATURE)
+			.keyOperations(Set.of(KeyOperation.SIGN))
+			.algorithm(JWSAlgorithm.RS256)
+			.keyIDFromThumbprint()
+			.issueTime(issued)
+			.notBeforeTime(issued);
+	}
+
+}

+ 138 - 148
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

@@ -19,7 +19,10 @@ package org.springframework.security.oauth2.jwt;
 import java.net.URI;
 import java.net.URL;
 import java.security.KeyPair;
+import java.security.interfaces.ECPrivateKey;
 import java.security.interfaces.ECPublicKey;
+import java.security.interfaces.RSAPrivateKey;
+import java.security.interfaces.RSAPublicKey;
 import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Date;
@@ -27,8 +30,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Consumer;
 
 import javax.crypto.SecretKey;
 
@@ -37,6 +40,7 @@ import com.nimbusds.jose.JOSEObjectType;
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jose.JWSSigner;
+import com.nimbusds.jose.crypto.MACSigner;
 import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
 import com.nimbusds.jose.jwk.Curve;
 import com.nimbusds.jose.jwk.ECKey;
@@ -58,11 +62,14 @@ import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.SignedJWT;
 
 import org.springframework.core.convert.converter.Converter;
+import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
 import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
+import org.springframework.util.function.ThrowingBiFunction;
+import org.springframework.util.function.ThrowingFunction;
 
 /**
  * An implementation of a {@link JwtEncoder} that encodes a JSON Web Token (JWT) using the
@@ -74,6 +81,8 @@ import org.springframework.util.StringUtils;
  * <b>NOTE:</b> This implementation uses the Nimbus JOSE + JWT SDK.
  *
  * @author Joe Grandja
+ * @author Josh Cummings
+ * @author Suraj Bhadrike
  * @since 5.6
  * @see JwtEncoder
  * @see com.nimbusds.jose.jwk.source.JWKSource
@@ -95,7 +104,7 @@ public final class NimbusJwtEncoder implements JwtEncoder {
 
 	private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();
 
-	private JwsHeader jwsHeader;
+	private final JwsHeader defaultJwsHeader;
 
 	private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();
 
@@ -114,10 +123,35 @@ public final class NimbusJwtEncoder implements JwtEncoder {
 	 * @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource}
 	 */
 	public NimbusJwtEncoder(JWKSource<SecurityContext> jwkSource) {
+		this.defaultJwsHeader = DEFAULT_JWS_HEADER;
 		Assert.notNull(jwkSource, "jwkSource cannot be null");
 		this.jwkSource = jwkSource;
 	}
 
+	private NimbusJwtEncoder(JWK jwk) {
+		Assert.notNull(jwk, "jwk cannot be null");
+		this.jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
+		JwsAlgorithm algorithm = SignatureAlgorithm.from(jwk.getAlgorithm().getName());
+		if (algorithm == null) {
+			algorithm = MacAlgorithm.from(jwk.getAlgorithm().getName());
+		}
+		Assert.notNull(algorithm, "Failed to derive supported algorithm from " + jwk.getAlgorithm());
+		JwsHeader.Builder builder = JwsHeader.with(algorithm).type(jwk.getKeyType().getValue()).keyId(jwk.getKeyID());
+		URI x509Url = jwk.getX509CertURL();
+		if (x509Url != null) {
+			builder.x509Url(jwk.getX509CertURL().toASCIIString());
+		}
+		List<Base64> certs = jwk.getX509CertChain();
+		if (certs != null) {
+			builder.x509CertificateChain(certs.stream().map(Base64::toString).toList());
+		}
+		Base64URL thumbprint = jwk.getX509CertSHA256Thumbprint();
+		if (thumbprint != null) {
+			builder.x509SHA256Thumbprint(thumbprint.toString());
+		}
+		this.defaultJwsHeader = builder.build();
+	}
+
 	/**
 	 * Use this strategy to reduce the list of matching JWKs when there is more than one.
 	 * <p>
@@ -133,16 +167,15 @@ public final class NimbusJwtEncoder implements JwtEncoder {
 		this.jwkSelector = jwkSelector;
 	}
 
-	public void setJwsHeader(JwsHeader jwsHeader) {
-		this.jwsHeader = jwsHeader;
-	}
-
 	@Override
 	public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
 		Assert.notNull(parameters, "parameters cannot be null");
 
 		JwsHeader headers = parameters.getJwsHeader();
-		headers = (headers != null) ? headers : (this.jwsHeader != null) ? this.jwsHeader : DEFAULT_JWS_HEADER;
+		if (headers == null) {
+			headers = this.defaultJwsHeader;
+		}
+
 		JwtClaimsSet claims = parameters.getClaims();
 
 		JWK jwk = selectJwk(headers);
@@ -387,38 +420,34 @@ public final class NimbusJwtEncoder implements JwtEncoder {
 
 	/**
 	 * Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
-	 * {@link SecretKey}.
-	 * @param secretKey the {@link SecretKey} to use for signing JWTs
-	 * @return a {@link SecretKeyJwtEncoderBuilder} for further configuration
+	 * @param publicKey the {@link RSAPublicKey} and @Param privateKey the
+	 * {@link RSAPrivateKey} to use for signing JWTs
+	 * @return a {@link RsaKeyPairJwtEncoderBuilder}
 	 * @since 7.0
 	 */
-	public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
-		Assert.notNull(secretKey, "secretKey cannot be null");
-		return new SecretKeyJwtEncoderBuilder(secretKey);
+	public static RsaKeyPairJwtEncoderBuilder withKeyPair(RSAPublicKey publicKey, RSAPrivateKey privateKey) {
+		return new RsaKeyPairJwtEncoderBuilder(publicKey, privateKey);
+	}
+
+	/**
+	 * Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
+	 * @param publicKey the {@link ECPublicKey} and @param privateKey the
+	 * {@link ECPrivateKey} to use for signing JWTs
+	 * @return a {@link EcKeyPairJwtEncoderBuilder}
+	 * @since 7.0
+	 */
+	public static EcKeyPairJwtEncoderBuilder withKeyPair(ECPublicKey publicKey, ECPrivateKey privateKey) {
+		return new EcKeyPairJwtEncoderBuilder(publicKey, privateKey);
 	}
 
 	/**
 	 * Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
-	 * {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an
-	 * {@link ECKey}.
-	 * @param keyPair the {@link KeyPair} to use for signing JWTs
-	 * @return a {@link KeyPairJwtEncoderBuilder} for further configuration
+	 * @param secretKey
+	 * @return a {@link SecretKeyJwtEncoderBuilder} for configuring the {@link JWK}
 	 * @since 7.0
 	 */
-	public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) {
-		Assert.isTrue(keyPair != null && keyPair.getPrivate() != null && keyPair.getPublic() != null,
-				"keyPair, its private key, and public key must not be null");
-		Assert.isTrue(
-				keyPair.getPrivate() instanceof java.security.interfaces.RSAKey
-						|| keyPair.getPrivate() instanceof java.security.interfaces.ECKey,
-				"keyPair must be an RSAKey or an ECKey");
-		if (keyPair.getPrivate() instanceof java.security.interfaces.RSAKey) {
-			return new RsaKeyPairJwtEncoderBuilder(keyPair);
-		}
-		if (keyPair.getPrivate() instanceof java.security.interfaces.ECKey) {
-			return new EcKeyPairJwtEncoderBuilder(keyPair);
-		}
-		throw new IllegalArgumentException("keyPair must be an RSAKey or an ECKey");
+	public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
+		return new SecretKeyJwtEncoderBuilder(secretKey);
 	}
 
 	/**
@@ -429,14 +458,29 @@ public final class NimbusJwtEncoder implements JwtEncoder {
 	 */
 	public static final class SecretKeyJwtEncoderBuilder {
 
-		private final SecretKey secretKey;
+		private static final ThrowingFunction<SecretKey, OctetSequenceKey.Builder> defaultJwk = JWKS::signing;
 
-		private String keyId;
+		private final OctetSequenceKey.Builder builder;
 
-		private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256;
+		private final Set<JWSAlgorithm> allowedAlgorithms;
 
 		private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
-			this.secretKey = secretKey;
+			Assert.notNull(secretKey, "secretKey cannot be null");
+			Set<JWSAlgorithm> allowedAlgorithms = computeAllowedAlgorithms(secretKey);
+			Assert.notEmpty(allowedAlgorithms,
+					"This key is too small for any standard JWK symmetric signing algorithm");
+			this.allowedAlgorithms = allowedAlgorithms;
+			this.builder = defaultJwk.apply(secretKey, IllegalArgumentException::new)
+				.algorithm(this.allowedAlgorithms.iterator().next());
+		}
+
+		private Set<JWSAlgorithm> computeAllowedAlgorithms(SecretKey secretKey) {
+			try {
+				return new MACSigner(secretKey).supportedJWSAlgorithms();
+			}
+			catch (JOSEException ex) {
+				throw new IllegalArgumentException(ex);
+			}
 		}
 
 		/**
@@ -446,24 +490,24 @@ public final class NimbusJwtEncoder implements JwtEncoder {
 		 * @param macAlgorithm the {@link MacAlgorithm} to use
 		 * @return this builder instance for method chaining
 		 */
-		public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) {
+		public SecretKeyJwtEncoderBuilder algorithm(MacAlgorithm macAlgorithm) {
 			Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
-			Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm),
-					() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. "
-							+ "Please use one of the HS256, HS384, or HS512 algorithms.");
-
-			this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName());
+			JWSAlgorithm jws = JWSAlgorithm.parse(macAlgorithm.getName());
+			Assert.isTrue(this.allowedAlgorithms.contains(jws), String
+				.format("This key can only support " + "the following algorithms: [%s]", this.allowedAlgorithms));
+			this.builder.algorithm(JWSAlgorithm.parse(macAlgorithm.getName()));
 			return this;
 		}
 
 		/**
-		 * Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
-		 * header.
-		 * @param keyId the key identifier
+		 * Post-process the {@link JWK} using the given {@link Consumer}. For example, you
+		 * may use this to override the default {@code kid}
+		 * @param jwkPostProcessor the post-processor to use
 		 * @return this builder instance for method chaining
 		 */
-		public SecretKeyJwtEncoderBuilder keyId(String keyId) {
-			this.keyId = keyId;
+		public SecretKeyJwtEncoderBuilder jwkPostProcessor(Consumer<OctetSequenceKey.Builder> jwkPostProcessor) {
+			Assert.notNull(jwkPostProcessor, "jwkPostProcessor cannot be null");
+			jwkPostProcessor.accept(this.builder);
 			return this;
 		}
 
@@ -474,17 +518,7 @@ public final class NimbusJwtEncoder implements JwtEncoder {
 		 * with a {@link SecretKey}.
 		 */
 		public NimbusJwtEncoder build() {
-			this.jwsAlgorithm = (this.jwsAlgorithm != null) ? this.jwsAlgorithm : JWSAlgorithm.HS256;
-
-			OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey).keyUse(KeyUse.SIGNATURE)
-				.algorithm(this.jwsAlgorithm)
-				.keyID(this.keyId);
-
-			OctetSequenceKey jwk = builder.build();
-			JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
-			NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
-			encoder.setJwsHeader(JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build());
-			return encoder;
+			return new NimbusJwtEncoder(this.builder.build());
 		}
 
 	}
@@ -495,137 +529,93 @@ public final class NimbusJwtEncoder implements JwtEncoder {
 	 *
 	 * @since 7.0
 	 */
-	public abstract static class KeyPairJwtEncoderBuilder {
-
-		private final KeyPair keyPair;
+	public static final class RsaKeyPairJwtEncoderBuilder {
 
-		private String keyId;
+		private static final ThrowingBiFunction<RSAPublicKey, RSAPrivateKey, RSAKey.Builder> defaultKid = JWKS::signingWithRsa;
 
-		private JWSAlgorithm jwsAlgorithm;
+		private final RSAKey.Builder builder;
 
-		private KeyPairJwtEncoderBuilder(KeyPair keyPair) {
-			this.keyPair = keyPair;
+		private RsaKeyPairJwtEncoderBuilder(RSAPublicKey publicKey, RSAPrivateKey privateKey) {
+			Assert.notNull(publicKey, "publicKey cannot be null");
+			Assert.notNull(privateKey, "privateKey cannot be null");
+			this.builder = defaultKid.apply(publicKey, privateKey);
 		}
 
 		/**
-		 * Sets the JWS algorithm to use for signing. Must be compatible with the key type
-		 * (RSA or EC). If not set, a default algorithm will be chosen based on the key
-		 * type (e.g., RS256 for RSA, ES256 for EC).
+		 * Sets the JWS algorithm to use for signing. Defaults to
+		 * {@link SignatureAlgorithm#RS256}. Must be an RSA-based algorithm
 		 * @param signatureAlgorithm the {@link SignatureAlgorithm} to use
 		 * @return this builder instance for method chaining
 		 */
-		public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) {
+		public RsaKeyPairJwtEncoderBuilder algorithm(SignatureAlgorithm signatureAlgorithm) {
 			Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
-			this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
+			this.builder.algorithm(JWSAlgorithm.parse(signatureAlgorithm.getName()));
 			return this;
 		}
 
 		/**
-		 * Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
-		 * header.
-		 * @param keyId the key identifier
+		 * Add commentMore actions Post-process the {@link JWK} using the given
+		 * {@link Consumer}. For example, you may use this to override the default
+		 * {@code kid}
+		 * @param jwkPostProcessor the post-processor to use
 		 * @return this builder instance for method chaining
 		 */
-		public KeyPairJwtEncoderBuilder keyId(String keyId) {
-			this.keyId = keyId;
+		public RsaKeyPairJwtEncoderBuilder jwkPostProcessor(Consumer<RSAKey.Builder> jwkPostProcessor) {
+			Assert.notNull(jwkPostProcessor, "jwkPostProcessor cannot be null");
+			jwkPostProcessor.accept(this.builder);
 			return this;
 		}
 
 		/**
 		 * Builds the {@link NimbusJwtEncoder} instance.
 		 * @return the configured {@link NimbusJwtEncoder}
-		 * @throws IllegalStateException if the key type is unsupported or the configured
-		 * JWS algorithm is not compatible with the key type.
-		 * @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown
-		 * curve)
 		 */
 		public NimbusJwtEncoder build() {
-			this.keyId = (this.keyId != null) ? this.keyId : UUID.randomUUID().toString();
-			JWK jwk = buildJwk();
-			JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
-			NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
-			JwsHeader jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName()))
-				.keyId(jwk.getKeyID())
-				.build();
-			encoder.setJwsHeader(jwsHeader);
-			return encoder;
+			return new NimbusJwtEncoder(this.builder.build());
 		}
 
-		protected abstract JWK buildJwk();
-
 	}
 
 	/**
 	 * A builder for creating {@link NimbusJwtEncoder} instances configured with a
-	 * {@link KeyPair}.
+	 * {@link ECPublicKey} and {@link ECPrivateKey}.
+	 * <p>
+	 * This builder is used to create a {@link NimbusJwtEncoder}
 	 *
 	 * @since 7.0
 	 */
-	public static final class RsaKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder {
+	public static final class EcKeyPairJwtEncoderBuilder {
 
-		private RsaKeyPairJwtEncoderBuilder(KeyPair keyPair) {
-			super(keyPair);
-		}
+		private static final ThrowingBiFunction<ECPublicKey, ECPrivateKey, ECKey.Builder> defaultKid = JWKS::signingWithEc;
 
-		@Override
-		protected JWK buildJwk() {
-			if (super.jwsAlgorithm == null) {
-				super.jwsAlgorithm = JWSAlgorithm.RS256;
-			}
-			Assert.state(JWSAlgorithm.Family.RSA.contains(super.jwsAlgorithm),
-					() -> "The algorithm '" + super.jwsAlgorithm + "' is not compatible with an RSAKey. "
-							+ "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms.");
+		private final ECKey.Builder builder;
 
-			RSAKey.Builder builder = new RSAKey.Builder(
-					(java.security.interfaces.RSAPublicKey) super.keyPair.getPublic())
-				.privateKey(super.keyPair.getPrivate())
-				.keyID(super.keyId)
-				.keyUse(KeyUse.SIGNATURE)
-				.algorithm(super.jwsAlgorithm);
-			return builder.build();
+		private EcKeyPairJwtEncoderBuilder(ECPublicKey publicKey, ECPrivateKey privateKey) {
+			Assert.notNull(publicKey, "publicKey cannot be null");
+			Assert.notNull(privateKey, "privateKey cannot be null");
+			Curve curve = Curve.forECParameterSpec(publicKey.getParams());
+			Assert.notNull(curve, "Unable to determine Curve for EC public key.");
+			this.builder = defaultKid.apply(publicKey, privateKey);
 		}
 
-	}
-
-	/**
-	 * A builder for creating {@link NimbusJwtEncoder} instances configured with a
-	 * {@link KeyPair}.
-	 *
-	 * @since 7.0
-	 */
-	public static final class EcKeyPairJwtEncoderBuilder extends KeyPairJwtEncoderBuilder {
-
-		private EcKeyPairJwtEncoderBuilder(KeyPair keyPair) {
-			super(keyPair);
+		/**
+		 * Post-process the {@link JWK} using the given {@link Consumer}. For example, you
+		 * may use this to override the default {@code kid}
+		 * @param jwkPostProcessor the post-processor to use
+		 * @return this builder instance for method chaining
+		 */
+		public EcKeyPairJwtEncoderBuilder jwkPostProcessor(Consumer<ECKey.Builder> jwkPostProcessor) {
+			Assert.notNull(jwkPostProcessor, "jwkPostProcessor cannot be null");
+			jwkPostProcessor.accept(this.builder);
+			return this;
 		}
 
-		@Override
-		protected JWK buildJwk() {
-			if (super.jwsAlgorithm == null) {
-				super.jwsAlgorithm = JWSAlgorithm.ES256;
-			}
-			Assert.state(JWSAlgorithm.Family.EC.contains(super.jwsAlgorithm),
-					() -> "The algorithm '" + super.jwsAlgorithm + "' is not compatible with an ECKey. "
-							+ "Please use one of the ES256, ES384, or ES512 algorithms.");
-
-			ECPublicKey publicKey = (ECPublicKey) super.keyPair.getPublic();
-			Curve curve = Curve.forECParameterSpec(publicKey.getParams());
-			if (curve == null) {
-				throw new JwtEncodingException("Unable to determine Curve for EC public key.");
-			}
-
-			com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey)
-				.privateKey(super.keyPair.getPrivate())
-				.keyUse(KeyUse.SIGNATURE)
-				.keyID(super.keyId)
-				.algorithm(super.jwsAlgorithm);
-
-			try {
-				return builder.build();
-			}
-			catch (IllegalStateException ex) {
-				throw new IllegalArgumentException("Failed to build ECKey: " + ex.getMessage(), ex);
-			}
+		/**
+		 * Builds the {@link NimbusJwtEncoder} instance.
+		 * @return the configured {@link NimbusJwtEncoder}
+		 */
+		public NimbusJwtEncoder build() {
+			return new NimbusJwtEncoder(this.builder.build());
 		}
 
 	}

+ 94 - 113
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java

@@ -16,8 +16,6 @@
 
 package org.springframework.security.oauth2.jwt;
 
-import java.security.KeyPair;
-import java.security.KeyPairGenerator;
 import java.security.interfaces.ECPrivateKey;
 import java.security.interfaces.ECPublicKey;
 import java.time.Instant;
@@ -27,12 +25,15 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.UUID;
+import java.util.function.Consumer;
 
 import javax.crypto.SecretKey;
 import javax.crypto.spec.SecretKeySpec;
 
+import com.nimbusds.jose.JOSEException;
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.KeySourceException;
+import com.nimbusds.jose.jwk.Curve;
 import com.nimbusds.jose.jwk.ECKey;
 import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.JWKSelector;
@@ -40,6 +41,8 @@ import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.KeyUse;
 import com.nimbusds.jose.jwk.OctetSequenceKey;
 import com.nimbusds.jose.jwk.RSAKey;
+import com.nimbusds.jose.jwk.gen.ECKeyGenerator;
+import com.nimbusds.jose.jwk.gen.RSAKeyGenerator;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.proc.SecurityContext;
 import com.nimbusds.jose.util.Base64URL;
@@ -51,12 +54,12 @@ import org.mockito.stubbing.Answer;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.oauth2.jose.TestJwks;
 import org.springframework.security.oauth2.jose.TestKeys;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
-import static org.assertj.core.api.Assertions.assertThatNoException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.willAnswer;
@@ -353,160 +356,138 @@ public class NimbusJwtEncoderTests {
 		verifyNoInteractions(selector);
 	}
 
+	// Default algorithm
 	@Test
-	void secretKeyBuilderWithDefaultAlgorithm() {
-		SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC");
+	void keyPairBuilderWithRsaDefaultAlgorithm() throws JOSEException {
+		RSAKeyGenerator generator = new RSAKeyGenerator(2048);
+		RSAKey key = generator.generate();
+		NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toRSAPublicKey(), key.toRSAPrivateKey()).build();
 		JwtClaimsSet claims = buildClaims();
-
-		NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).build();
-		Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
-
-		assertThat(jwt).isNotNull();
-		assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256");
-		assertThatNoException().isThrownBy(jwt::getClaims);
+		Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims));
 		assertJwt(jwt);
+		assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID);
 	}
 
 	@Test
-	void secretKeyBuilderWithKeyId() {
-		SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC");
-		String keyId = "test-key-id";
+	void keyPairBuilderWithEcDefaultAlgorithm() throws JOSEException {
+		ECKeyGenerator generator = new ECKeyGenerator(Curve.P_256);
+		ECKey key = generator.generate();
+		NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toECPublicKey(), key.toECPrivateKey()).build();
 		JwtClaimsSet claims = buildClaims();
-
-		NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).keyId(keyId).build();
-		Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
-
-		assertThat(jwt).isNotNull();
-		assertThat(jwt.getHeaders().get("kid").toString()).isEqualTo(keyId);
-		assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256");
-		assertThatNoException().isThrownBy(jwt::getClaims);
+		Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims));
 		assertJwt(jwt);
+		assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID);
 	}
 
 	@Test
-	void secretKeyBuilderWithCustomJwkSelector() {
-		SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC");
-		String keyId = "test-key-id";
+	void keyPairBuilderWithSecretKeyDefaultAlgorithm() {
+		SecretKey key = TestKeys.DEFAULT_SECRET_KEY;
+		NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withSecretKey(key).build();
 		JwtClaimsSet claims = buildClaims();
-
-		NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).keyId(keyId).build();
-		Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
-
-		assertThat(jwt).isNotNull();
-		assertThat(jwt.getHeaders().get("kid")).isEqualTo(keyId);
-		assertThat(jwt.getClaims()).containsEntry("sub", "subject");
-		assertThatNoException().isThrownBy(() -> jwt.getClaims());
+		Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims));
 		assertJwt(jwt);
+		assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID);
 	}
 
+	// With custom algorithm
 	@Test
-	void secretKeyBuilderWithCustomHeaders() {
-		SecretKey secretKey = new SecretKeySpec("thisIsASecretKeyUsedForTesting12345".getBytes(), "HMAC");
-		JwtClaimsSet claims = buildClaims();
-		JwsHeader headers = JwsHeader.with(org.springframework.security.oauth2.jose.jws.MacAlgorithm.HS256)
-			.type("JWT")
-			.contentType("application/jwt")
+	void keyPairBuilderWithRsaWithAlgorithm() throws JOSEException {
+		RSAKeyGenerator generator = new RSAKeyGenerator(2048);
+		RSAKey key = generator.generate();
+		NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toRSAPublicKey(), key.toRSAPrivateKey())
+			.algorithm(SignatureAlgorithm.RS384)
 			.build();
-
-		NimbusJwtEncoder encoder = NimbusJwtEncoder.withSecretKey(secretKey).build();
-		Jwt jwt = encoder.encode(JwtEncoderParameters.from(headers, claims));
-
-		assertThat(jwt).isNotNull();
-		assertThat(jwt.getHeaders().get("typ").toString()).isEqualTo("JWT");
-		assertThat(jwt.getHeaders().get("cty").toString()).isEqualTo("application/jwt");
-		assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("HS256");
-		assertThatNoException().isThrownBy(() -> jwt.getClaims());
+		JwtClaimsSet claims = buildClaims();
+		Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims));
 		assertJwt(jwt);
+		assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, SignatureAlgorithm.RS384);
+		assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID);
 	}
 
 	@Test
-	void keyPairBuilderWithRsaDefaultAlgorithm() throws Exception {
-		KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
-		keyPairGenerator.initialize(2048);
-		KeyPair keyPair = keyPairGenerator.generateKeyPair();
+	void keyPairBuilderWithEcWithAlgorithm() throws JOSEException {
+		ECKeyGenerator generator = new ECKeyGenerator(Curve.P_384);
+		ECKey key = generator.generate();
+		NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toECPublicKey(), key.toECPrivateKey()).build();
 		JwtClaimsSet claims = buildClaims();
-
-		NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).build();
-		Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
-
-		assertThat(jwt).isNotNull();
-		assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS256");
-		assertThat(jwt.getSubject()).isEqualTo(claims.getSubject());
-		assertThat(jwt.getAudience()).isEqualTo(claims.getAudience());
-		assertThatNoException().isThrownBy(() -> jwt.getClaims());
+		Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims));
 		assertJwt(jwt);
+		assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, SignatureAlgorithm.ES384);
+		assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID);
 	}
 
 	@Test
-	void keyPairBuilderWithRsaCustomAlgorithm() throws Exception {
-		KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
-		keyPairGenerator.initialize(2048);
-		KeyPair keyPair = keyPairGenerator.generateKeyPair();
+	void keyPairBuilderWithSecretKeyWithAlgorithm() {
+		String keyStr = UUID.randomUUID().toString();
+		keyStr += keyStr;
+		SecretKey Key = new SecretKeySpec(keyStr.getBytes(), "AES");
+		NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withSecretKey(Key).algorithm(MacAlgorithm.HS512).build();
 		JwtClaimsSet claims = buildClaims();
-
-		NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair)
-			.signatureAlgorithm(SignatureAlgorithm.RS512)
-			.build();
-		Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
-
-		assertThat(jwt).isNotNull();
-		assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS512");
-		assertThat(jwt.getSubject()).isEqualTo(claims.getSubject());
-		assertThatNoException().isThrownBy(() -> jwt.getClaims());
+		Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims));
 		assertJwt(jwt);
+		assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, MacAlgorithm.HS512);
+		assertThat(jwt.getHeaders()).containsKey(JoseHeaderNames.KID);
 	}
 
 	@Test
-	void keyPairBuilderWithEcDefaultAlgorithm() throws Exception {
-		KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC");
-		keyPairGenerator.initialize(256);
-		KeyPair keyPair = keyPairGenerator.generateKeyPair();
-		JwtClaimsSet claims = buildClaims();
+	void keyPairBuilderWhenShortSecretThenHigherAlgorithmNotSupported() {
+		String keyStr = UUID.randomUUID().toString();
+		SecretKey Key = new SecretKeySpec(keyStr.getBytes(), "AES");
+		assertThatExceptionOfType(IllegalArgumentException.class)
+			.isThrownBy(() -> NimbusJwtEncoder.withSecretKey(Key).algorithm(MacAlgorithm.HS512).build());
+	}
 
-		NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).build();
-		Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
+	@Test
+	void keyPairBuilderWhenTooShortSecretThenException() {
+		SecretKey Key = new SecretKeySpec("key".getBytes(), "AES");
+		assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> NimbusJwtEncoder.withSecretKey(Key));
+	}
 
-		assertThat(jwt).isNotNull();
-		assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("ES256");
-		assertThat(jwt.getSubject()).isEqualTo(claims.getSubject());
-		assertThatNoException().isThrownBy(() -> jwt.getClaims());
+	// with custom jwkPostProcessor
+	@Test
+	void keyPairBuilderWithRsaWithAlgorithmAndJwkSource() throws JOSEException {
+		RSAKeyGenerator generator = new RSAKeyGenerator(2048);
+		RSAKey key = generator.generate();
+		String keyId = UUID.randomUUID().toString();
+		NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toRSAPublicKey(), key.toRSAPrivateKey())
+			.algorithm(SignatureAlgorithm.RS384)
+			.jwkPostProcessor((builder) -> builder.keyID(keyId))
+			.build();
+		JwtClaimsSet claims = buildClaims();
+		Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims));
 		assertJwt(jwt);
+		assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, SignatureAlgorithm.RS384);
+		assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.KID, keyId);
 	}
 
 	@Test
-	void keyPairBuilderWithEcCustomAlgorithm() throws Exception {
-		KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC");
-		keyPairGenerator.initialize(256);
-		KeyPair keyPair = keyPairGenerator.generateKeyPair();
-		NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair)
-			.keyId(UUID.randomUUID().toString())
-			.signatureAlgorithm(SignatureAlgorithm.ES256)
+	void keyPairBuilderWithEcWithAlgorithmAndJwkSource() throws JOSEException {
+		ECKeyGenerator generator = new ECKeyGenerator(Curve.P_256);
+		ECKey key = generator.generate();
+		String keyId = UUID.randomUUID().toString();
+		Consumer<ECKey.Builder> jwkPostProcessor = (builder) -> builder.keyID(keyId);
+		NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withKeyPair(key.toECPublicKey(), key.toECPrivateKey())
+			.jwkPostProcessor(jwkPostProcessor)
 			.build();
-
 		JwtClaimsSet claims = buildClaims();
-		Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
-
-		assertThat(jwt).isNotNull();
-		assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("ES256");
-		assertThatNoException().isThrownBy(() -> jwt.getClaims());
+		Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims));
 		assertJwt(jwt);
+		assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, SignatureAlgorithm.ES256);
+		assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.KID, keyId);
 	}
 
 	@Test
-	void keyPairBuilderWithKeyId() throws Exception { // d
-		KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
-		keyPairGenerator.initialize(2048);
-		KeyPair keyPair = keyPairGenerator.generateKeyPair();
-		String keyId = "test-key-id";
+	void keyPairBuilderWithSecretKeyWithAlgorithmAndJwkSource() {
+		final String keyStr = UUID.randomUUID().toString();
+		SecretKey key = new SecretKeySpec(keyStr.getBytes(), "HS256");
+		String keyId = UUID.randomUUID().toString();
+		Consumer<OctetSequenceKey.Builder> jwkPostProcessor = (builder) -> builder.keyID(keyId);
+		NimbusJwtEncoder jwtEncoder = NimbusJwtEncoder.withSecretKey(key).jwkPostProcessor(jwkPostProcessor).build();
 		JwtClaimsSet claims = buildClaims();
-
-		NimbusJwtEncoder encoder = NimbusJwtEncoder.withKeyPair(keyPair).keyId(keyId).build();
-		Jwt jwt = encoder.encode(JwtEncoderParameters.from(claims));
-
-		assertThat(jwt).isNotNull();
-		assertThat(jwt.getHeaders().get("kid")).isEqualTo(keyId);
-		assertThat(jwt.getHeaders().get("alg").toString()).isEqualTo("RS256");
-		assertThatNoException().isThrownBy(() -> jwt.getClaims());
+		Jwt jwt = jwtEncoder.encode(JwtEncoderParameters.from(claims));
+		assertJwt(jwt);
+		assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.ALG, MacAlgorithm.HS256);
+		assertThat(jwt.getHeaders()).containsEntry(JoseHeaderNames.KID, keyId);
 	}
 
 	private JwtClaimsSet buildClaims() {