Browse Source

Allow customizing Jwt claims and headers

Closes gh-173
Joe Grandja 4 years ago
parent
commit
79f1cf5a50

+ 32 - 37
oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java

@@ -26,6 +26,7 @@ import org.springframework.security.config.annotation.web.configurers.AbstractHt
 import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer;
 import org.springframework.security.crypto.key.CryptoKeySource;
 import org.springframework.security.oauth2.jose.jws.NimbusJwsEncoder;
+import org.springframework.security.oauth2.jwt.JwtEncoder;
 import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
@@ -166,7 +167,7 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 						getAuthorizationService(builder));
 		builder.authenticationProvider(postProcess(clientAuthenticationProvider));
 
-		NimbusJwsEncoder jwtEncoder = new NimbusJwsEncoder(getKeySource(builder));
+		JwtEncoder jwtEncoder = getJwtEncoder(builder);
 
 		OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider =
 				new OAuth2AuthorizationCodeAuthenticationProvider(
@@ -253,23 +254,29 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 		builder.addFilterAfter(postProcess(tokenRevocationEndpointFilter), OAuth2TokenEndpointFilter.class);
 	}
 
+	private static void validateProviderSettings(ProviderSettings providerSettings) {
+		if (providerSettings.issuer() != null) {
+			try {
+				new URI(providerSettings.issuer()).toURL();
+			} catch (Exception ex) {
+				throw new IllegalArgumentException("issuer must be a valid URL", ex);
+			}
+		}
+	}
+
 	private static <B extends HttpSecurityBuilder<B>> RegisteredClientRepository getRegisteredClientRepository(B builder) {
 		RegisteredClientRepository registeredClientRepository = builder.getSharedObject(RegisteredClientRepository.class);
 		if (registeredClientRepository == null) {
-			registeredClientRepository = getRegisteredClientRepositoryBean(builder);
+			registeredClientRepository = getBean(builder, RegisteredClientRepository.class);
 			builder.setSharedObject(RegisteredClientRepository.class, registeredClientRepository);
 		}
 		return registeredClientRepository;
 	}
 
-	private static <B extends HttpSecurityBuilder<B>> RegisteredClientRepository getRegisteredClientRepositoryBean(B builder) {
-		return builder.getSharedObject(ApplicationContext.class).getBean(RegisteredClientRepository.class);
-	}
-
 	private static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizationService getAuthorizationService(B builder) {
 		OAuth2AuthorizationService authorizationService = builder.getSharedObject(OAuth2AuthorizationService.class);
 		if (authorizationService == null) {
-			authorizationService = getAuthorizationServiceBean(builder);
+			authorizationService = getOptionalBean(builder, OAuth2AuthorizationService.class);
 			if (authorizationService == null) {
 				authorizationService = new InMemoryOAuth2AuthorizationService();
 			}
@@ -278,34 +285,28 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 		return authorizationService;
 	}
 
-	private static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizationService getAuthorizationServiceBean(B builder) {
-		Map<String, OAuth2AuthorizationService> authorizationServiceMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
-				builder.getSharedObject(ApplicationContext.class), OAuth2AuthorizationService.class);
-		if (authorizationServiceMap.size() > 1) {
-			throw new NoUniqueBeanDefinitionException(OAuth2AuthorizationService.class, authorizationServiceMap.size(),
-					"Expected single matching bean of type '" + OAuth2AuthorizationService.class.getName() + "' but found " +
-							authorizationServiceMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(authorizationServiceMap.keySet()));
+	private static <B extends HttpSecurityBuilder<B>> JwtEncoder getJwtEncoder(B builder) {
+		JwtEncoder jwtEncoder = getOptionalBean(builder, JwtEncoder.class);
+		if (jwtEncoder == null) {
+			CryptoKeySource keySource = getKeySource(builder);
+			jwtEncoder = new NimbusJwsEncoder(keySource);
 		}
-		return (!authorizationServiceMap.isEmpty() ? authorizationServiceMap.values().iterator().next() : null);
+		return jwtEncoder;
 	}
 
 	private static <B extends HttpSecurityBuilder<B>> CryptoKeySource getKeySource(B builder) {
 		CryptoKeySource keySource = builder.getSharedObject(CryptoKeySource.class);
 		if (keySource == null) {
-			keySource = getKeySourceBean(builder);
+			keySource = getBean(builder, CryptoKeySource.class);
 			builder.setSharedObject(CryptoKeySource.class, keySource);
 		}
 		return keySource;
 	}
 
-	private static <B extends HttpSecurityBuilder<B>> CryptoKeySource getKeySourceBean(B builder) {
-		return builder.getSharedObject(ApplicationContext.class).getBean(CryptoKeySource.class);
-	}
-
 	private static <B extends HttpSecurityBuilder<B>> ProviderSettings getProviderSettings(B builder) {
 		ProviderSettings providerSettings = builder.getSharedObject(ProviderSettings.class);
 		if (providerSettings == null) {
-			providerSettings = getProviderSettingsBean(builder);
+			providerSettings = getOptionalBean(builder, ProviderSettings.class);
 			if (providerSettings == null) {
 				providerSettings = new ProviderSettings();
 			}
@@ -314,24 +315,18 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 		return providerSettings;
 	}
 
-	private static <B extends HttpSecurityBuilder<B>> ProviderSettings getProviderSettingsBean(B builder) {
-		Map<String, ProviderSettings> providerSettingsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
-				builder.getSharedObject(ApplicationContext.class), ProviderSettings.class);
-		if (providerSettingsMap.size() > 1) {
-			throw new NoUniqueBeanDefinitionException(ProviderSettings.class, providerSettingsMap.size(),
-					"Expected single matching bean of type '" + ProviderSettings.class.getName() + "' but found " +
-							providerSettingsMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(providerSettingsMap.keySet()));
-		}
-		return (!providerSettingsMap.isEmpty() ? providerSettingsMap.values().iterator().next() : null);
+	private static <B extends HttpSecurityBuilder<B>, T> T getBean(B builder, Class<T> type) {
+		return builder.getSharedObject(ApplicationContext.class).getBean(type);
 	}
 
-	private void validateProviderSettings(ProviderSettings providerSettings) {
-		if (providerSettings.issuer() != null) {
-			try {
-				new URI(providerSettings.issuer()).toURL();
-			} catch (Exception ex) {
-				throw new IllegalArgumentException("issuer must be a valid URL", ex);
-			}
+	private static <B extends HttpSecurityBuilder<B>, T> T getOptionalBean(B builder, Class<T> type) {
+		Map<String, T> beansMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
+				builder.getSharedObject(ApplicationContext.class), type);
+		if (beansMap.size() > 1) {
+			throw new NoUniqueBeanDefinitionException(type, beansMap.size(),
+					"Expected single matching bean of type '" + type.getName() + "' but found " +
+							beansMap.size() + ": " + StringUtils.collectionToCommaDelimitedString(beansMap.keySet()));
 		}
+		return (!beansMap.isEmpty() ? beansMap.values().iterator().next() : null);
 	}
 }

+ 25 - 7
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoder.java

@@ -53,6 +53,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
+import java.util.function.BiConsumer;
 import java.util.stream.Collectors;
 
 /**
@@ -94,6 +95,7 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 	private static final Converter<JoseHeader, JWSHeader> jwsHeaderConverter = new JwsHeaderConverter();
 	private static final Converter<JwtClaimsSet, JWTClaimsSet> jwtClaimsSetConverter = new JwtClaimsSetConverter();
 	private final CryptoKeySource keySource;
+	private BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer = (headers, claims) -> {};
 
 	/**
 	 * Constructs a {@code NimbusJwsEncoder} using the provided parameters.
@@ -105,6 +107,19 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 		this.keySource = keySource;
 	}
 
+	/**
+	 * Sets the {@link Jwt} customizer to be provided the
+	 * {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder}
+	 * allowing for further customizations.
+	 *
+	 * @param jwtCustomizer the {@link Jwt} customizer to be provided the
+	 * {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder}
+	 */
+	public void setJwtCustomizer(BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer) {
+		Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
+		this.jwtCustomizer = jwtCustomizer;
+	}
+
 	@Override
 	public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException {
 		Assert.notNull(headers, "headers cannot be null");
@@ -136,15 +151,18 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 			}
 		}
 
-		headers = JoseHeader.from(headers)
+		JoseHeader.Builder headersBuilder = JoseHeader.from(headers)
 				.type(JOSEObjectType.JWT.getType())
-				.keyId(cryptoKey.getId())
-				.build();
-		JWSHeader jwsHeader = jwsHeaderConverter.convert(headers);
+				.keyId(cryptoKey.getId());
+		JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.from(claims)
+				.id(UUID.randomUUID().toString());
+
+		this.jwtCustomizer.accept(headersBuilder, claimsBuilder);
 
-		claims = JwtClaimsSet.from(claims)
-				.id(UUID.randomUUID().toString())
-				.build();
+		headers = headersBuilder.build();
+		claims = claimsBuilder.build();
+
+		JWSHeader jwsHeader = jwsHeaderConverter.convert(headers);
 		JWTClaimsSet jwtClaimsSet = jwtClaimsSetConverter.convert(claims);
 
 		SignedJWT signedJWT = new SignedJWT(jwsHeader, jwtClaimsSet);

+ 42 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java

@@ -33,6 +33,10 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
+import org.springframework.security.oauth2.jose.JoseHeader;
+import org.springframework.security.oauth2.jose.jws.NimbusJwsEncoder;
+import org.springframework.security.oauth2.jwt.JwtClaimsSet;
+import org.springframework.security.oauth2.jwt.JwtEncoder;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
@@ -53,6 +57,7 @@ import org.springframework.util.StringUtils;
 import java.net.URLEncoder;
 import java.nio.charset.StandardCharsets;
 import java.util.Base64;
+import java.util.function.BiConsumer;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.hamcrest.CoreMatchers.containsString;
@@ -86,6 +91,8 @@ public class OAuth2AuthorizationCodeGrantTests {
 	private static RegisteredClientRepository registeredClientRepository;
 	private static OAuth2AuthorizationService authorizationService;
 	private static CryptoKeySource keySource;
+	private static NimbusJwsEncoder jwtEncoder;
+	private static BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer;
 
 	@Rule
 	public final SpringTestRule spring = new SpringTestRule();
@@ -98,6 +105,9 @@ public class OAuth2AuthorizationCodeGrantTests {
 		registeredClientRepository = mock(RegisteredClientRepository.class);
 		authorizationService = mock(OAuth2AuthorizationService.class);
 		keySource = new StaticKeyGeneratingCryptoKeySource();
+		jwtEncoder = new NimbusJwsEncoder(keySource);
+		jwtCustomizer = mock(BiConsumer.class);
+		jwtEncoder.setJwtCustomizer(jwtCustomizer);
 	}
 
 	@Before
@@ -223,6 +233,28 @@ public class OAuth2AuthorizationCodeGrantTests {
 		verify(authorizationService, times(2)).save(any());
 	}
 
+	@Test
+	public void requestWhenCustomJwtEncoderThenUsed() throws Exception {
+		this.spring.register(AuthorizationServerConfigurationWithJwtEncoder.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		when(authorizationService.findByToken(
+				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
+				eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI)
+				.params(getTokenRequestParameters(registeredClient, authorization))
+				.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
+						registeredClient.getClientId(), registeredClient.getClientSecret())));
+
+		verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class));
+	}
+
 	private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
@@ -270,4 +302,14 @@ public class OAuth2AuthorizationCodeGrantTests {
 			return keySource;
 		}
 	}
+
+	@EnableWebSecurity
+	@Import(OAuth2AuthorizationServerConfiguration.class)
+	static class AuthorizationServerConfigurationWithJwtEncoder extends AuthorizationServerConfiguration {
+
+		@Bean
+		JwtEncoder jwtEncoder() {
+			return jwtEncoder;
+		}
+	}
 }

+ 28 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jose/jws/NimbusJwsEncoderTests.java

@@ -32,11 +32,14 @@ import org.springframework.security.oauth2.jwt.TestJwtClaimsSets;
 import java.security.interfaces.RSAPublicKey;
 import java.util.Collections;
 import java.util.LinkedHashSet;
+import java.util.function.BiConsumer;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 /**
@@ -61,6 +64,13 @@ public class NimbusJwsEncoderTests {
 				.hasMessage("keySource cannot be null");
 	}
 
+	@Test
+	public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.jwtEncoder.setJwtCustomizer(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("jwtCustomizer cannot be null");
+	}
+
 	@Test
 	public void encodeWhenHeadersNullThenThrowIllegalArgumentException() {
 		JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
@@ -128,6 +138,24 @@ public class NimbusJwsEncoderTests {
 		jwtDecoder.decode(jws.getTokenValue());
 	}
 
+	@Test
+	public void encodeWhenCustomizerSetThenCalled() {
+		AsymmetricKey rsaKey = TestCryptoKeys.rsaKey().build();
+		when(this.keySource.getKeys()).thenReturn(Collections.singleton(rsaKey));
+
+		JoseHeader joseHeader = TestJoseHeaders.joseHeader()
+				.headers(headers -> headers.remove(JoseHeaderNames.CRIT))
+				.build();
+		JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
+
+		BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer = mock(BiConsumer.class);
+		this.jwtEncoder.setJwtCustomizer(jwtCustomizer);
+
+		this.jwtEncoder.encode(joseHeader, jwtClaimsSet);
+
+		verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class));
+	}
+
 	@Test
 	public void encodeWhenMultipleActiveKeysThenUseFirst() {
 		AsymmetricKey rsaKey1 = TestCryptoKeys.rsaKey().build();