Browse Source

Configure token-exchange via a bean

Issue gh-5199
Issue gh-11783
Closes gh-14701
Steve Riesenberg 1 year ago
parent
commit
d6382b83dc

+ 29 - 1
config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

@@ -51,11 +51,13 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
@@ -183,7 +185,8 @@ final class OAuth2ClientConfiguration {
 				RefreshTokenOAuth2AuthorizedClientProvider.class,
 				ClientCredentialsOAuth2AuthorizedClientProvider.class,
 				PasswordOAuth2AuthorizedClientProvider.class,
-				JwtBearerOAuth2AuthorizedClientProvider.class
+				JwtBearerOAuth2AuthorizedClientProvider.class,
+				TokenExchangeOAuth2AuthorizedClientProvider.class
 		);
 		// @formatter:on
 
@@ -255,6 +258,12 @@ final class OAuth2ClientConfiguration {
 					authorizedClientProviders.add(jwtBearerAuthorizedClientProvider);
 				}
 
+				OAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider = getTokenExchangeAuthorizedClientProvider(
+						authorizedClientProviderBeans);
+				if (tokenExchangeAuthorizedClientProvider != null) {
+					authorizedClientProviders.add(tokenExchangeAuthorizedClientProvider);
+				}
+
 				authorizedClientProviders.addAll(getAdditionalAuthorizedClientProviders(authorizedClientProviderBeans));
 				authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders);
 			}
@@ -364,6 +373,25 @@ final class OAuth2ClientConfiguration {
 			return authorizedClientProvider;
 		}
 
+		private OAuth2AuthorizedClientProvider getTokenExchangeAuthorizedClientProvider(
+				Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
+			TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
+					authorizedClientProviders, TokenExchangeOAuth2AuthorizedClientProvider.class);
+
+			OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> accessTokenResponseClient = getBeanOfType(
+					ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
+							TokenExchangeGrantRequest.class));
+			if (accessTokenResponseClient != null) {
+				if (authorizedClientProvider == null) {
+					authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
+				}
+
+				authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
+			}
+
+			return authorizedClientProvider;
+		}
+
 		private List<OAuth2AuthorizedClientProvider> getAdditionalAuthorizedClientProviders(
 				Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
 			List<OAuth2AuthorizedClientProvider> additionalAuthorizedClientProviders = new ArrayList<>(

+ 29 - 1
config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java

@@ -44,11 +44,13 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
@@ -76,7 +78,8 @@ final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegi
 			RefreshTokenOAuth2AuthorizedClientProvider.class,
 			ClientCredentialsOAuth2AuthorizedClientProvider.class,
 			PasswordOAuth2AuthorizedClientProvider.class,
-			JwtBearerOAuth2AuthorizedClientProvider.class
+			JwtBearerOAuth2AuthorizedClientProvider.class,
+			TokenExchangeOAuth2AuthorizedClientProvider.class
 	);
 	// @formatter:on
 
@@ -137,6 +140,12 @@ final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegi
 				authorizedClientProviders.add(jwtBearerAuthorizedClientProvider);
 			}
 
+			OAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider = getTokenExchangeAuthorizedClientProvider(
+					authorizedClientProviderBeans);
+			if (tokenExchangeAuthorizedClientProvider != null) {
+				authorizedClientProviders.add(tokenExchangeAuthorizedClientProvider);
+			}
+
 			authorizedClientProviders.addAll(getAdditionalAuthorizedClientProviders(authorizedClientProviderBeans));
 			authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders);
 		}
@@ -245,6 +254,25 @@ final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegi
 		return authorizedClientProvider;
 	}
 
+	private OAuth2AuthorizedClientProvider getTokenExchangeAuthorizedClientProvider(
+			Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
+		TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
+				authorizedClientProviders, TokenExchangeOAuth2AuthorizedClientProvider.class);
+
+		OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> accessTokenResponseClient = getBeanOfType(
+				ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
+						TokenExchangeGrantRequest.class));
+		if (accessTokenResponseClient != null) {
+			if (authorizedClientProvider == null) {
+				authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
+			}
+
+			authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
+		}
+
+		return authorizedClientProvider;
+	}
+
 	private List<OAuth2AuthorizedClientProvider> getAdditionalAuthorizedClientProviders(
 			Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
 		List<OAuth2AuthorizedClientProvider> additionalAuthorizedClientProviders = new ArrayList<>(

+ 73 - 0
config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizedClientManagerConfigurationTests.java

@@ -52,6 +52,7 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
@@ -59,6 +60,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
 import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
 import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
 import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -70,6 +72,7 @@ import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
 import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -327,6 +330,47 @@ public class OAuth2AuthorizedClientManagerConfigurationTests {
 		assertThat(grantRequest.getJwt().getSubject()).isEqualTo("user");
 	}
 
+	@Test
+	public void authorizeWhenTokenExchangeAccessTokenResponseClientBeanThenUsed() {
+		this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire();
+		testTokenExchangeGrant();
+	}
+
+	@Test
+	public void authorizeWhenTokenExchangeAuthorizedClientProviderBeanThenUsed() {
+		this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire();
+		testTokenExchangeGrant();
+	}
+
+	private void testTokenExchangeGrant() {
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(TokenExchangeGrantRequest.class)))
+			.willReturn(accessTokenResponse);
+
+		JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt());
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("auth0");
+		// @formatter:off
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
+				.withClientRegistrationId(clientRegistration.getRegistrationId())
+				.principal(authentication)
+				.attribute(HttpServletRequest.class.getName(), this.request)
+				.attribute(HttpServletResponse.class.getName(), this.response)
+				.build();
+		// @formatter:on
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
+		assertThat(authorizedClient).isNotNull();
+
+		ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
+			.forClass(TokenExchangeGrantRequest.class);
+		verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture());
+
+		TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue();
+		assertThat(grantRequest.getClientRegistration().getRegistrationId())
+			.isEqualTo(clientRegistration.getRegistrationId());
+		assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE);
+		assertThat(grantRequest.getSubjectToken()).isEqualTo(authentication.getToken());
+	}
+
 	private static OAuth2AccessToken getExpiredAccessToken() {
 		Instant expiresAt = Instant.now().minusSeconds(60);
 		Instant issuedAt = expiresAt.minus(Duration.ofDays(1));
@@ -376,6 +420,11 @@ public class OAuth2AuthorizedClientManagerConfigurationTests {
 			return new MockJwtBearerClient();
 		}
 
+		@Bean
+		OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> tokenExchangeTokenResponseClient() {
+			return new MockTokenExchangeClient();
+		}
+
 		@Bean
 		OAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2UserService() {
 			return mock(DefaultOAuth2UserService.class);
@@ -425,6 +474,13 @@ public class OAuth2AuthorizedClientManagerConfigurationTests {
 			return authorizedClientProvider;
 		}
 
+		@Bean
+		TokenExchangeOAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider() {
+			TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
+			authorizedClientProvider.setAccessTokenResponseClient(new MockTokenExchangeClient());
+			return authorizedClientProvider;
+		}
+
 	}
 
 	abstract static class OAuth2ClientBaseConfig {
@@ -463,6 +519,14 @@ public class OAuth2AuthorizedClientManagerConfigurationTests {
 							.clientId("okta-client-id")
 							.clientSecret("okta-client-secret")
 							.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
+							.build(),
+					ClientRegistration.withRegistrationId("auth0")
+							.clientName("Auth0")
+							.clientId("auth0-client-id")
+							.clientSecret("auth0-client-secret")
+							.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
+							.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
+							.scope("user.read", "user.write")
 							.build()));
 			// @formatter:on
 		}
@@ -544,4 +608,13 @@ public class OAuth2AuthorizedClientManagerConfigurationTests {
 
 	}
 
+	private static class MockTokenExchangeClient implements OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> {
+
+		@Override
+		public OAuth2AccessTokenResponse getTokenResponse(TokenExchangeGrantRequest authorizationGrantRequest) {
+			return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest);
+		}
+
+	}
+
 }

+ 71 - 0
config/src/test/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests.java

@@ -49,6 +49,7 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.TokenExchangeOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
@@ -56,11 +57,13 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo
 import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -316,6 +319,47 @@ public class OAuth2AuthorizedClientManagerRegistrarTests {
 		assertThat(grantRequest.getJwt().getSubject()).isEqualTo("user");
 	}
 
+	@Test
+	public void authorizeWhenTokenExchangeAccessTokenResponseClientBeanThenUsed() {
+		this.spring.configLocations(xml("clients")).autowire();
+		testTokenExchangeGrant();
+	}
+
+	@Test
+	public void authorizeWhenTokenExchangeAuthorizedClientProviderBeanThenUsed() {
+		this.spring.configLocations(xml("providers")).autowire();
+		testTokenExchangeGrant();
+	}
+
+	private void testTokenExchangeGrant() {
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(TokenExchangeGrantRequest.class)))
+			.willReturn(accessTokenResponse);
+
+		JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt());
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("auth0");
+		// @formatter:off
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
+				.withClientRegistrationId(clientRegistration.getRegistrationId())
+				.principal(authentication)
+				.attribute(HttpServletRequest.class.getName(), this.request)
+				.attribute(HttpServletResponse.class.getName(), this.response)
+				.build();
+		// @formatter:on
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
+		assertThat(authorizedClient).isNotNull();
+
+		ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
+			.forClass(TokenExchangeGrantRequest.class);
+		verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture());
+
+		TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue();
+		assertThat(grantRequest.getClientRegistration().getRegistrationId())
+			.isEqualTo(clientRegistration.getRegistrationId());
+		assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE);
+		assertThat(grantRequest.getSubjectToken()).isEqualTo(authentication.getToken());
+	}
+
 	private static OAuth2AccessToken getExpiredAccessToken() {
 		Instant expiresAt = Instant.now().minusSeconds(60);
 		Instant issuedAt = expiresAt.minus(Duration.ofDays(1));
@@ -356,6 +400,14 @@ public class OAuth2AuthorizedClientManagerRegistrarTests {
 						.clientId("okta-client-id")
 						.clientSecret("okta-client-secret")
 						.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
+						.build(),
+				ClientRegistration.withRegistrationId("auth0")
+						.clientName("Auth0")
+						.clientId("auth0-client-id")
+						.clientSecret("auth0-client-secret")
+						.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
+						.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
+						.scope("user.read", "user.write")
 						.build());
 		// @formatter:on
 	}
@@ -422,6 +474,16 @@ public class OAuth2AuthorizedClientManagerRegistrarTests {
 		return new MockJwtBearerClient();
 	}
 
+	public static TokenExchangeOAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider() {
+		TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider();
+		authorizedClientProvider.setAccessTokenResponseClient(tokenExchangeAccessTokenResponseClient());
+		return authorizedClientProvider;
+	}
+
+	public static OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> tokenExchangeAccessTokenResponseClient() {
+		return new MockTokenExchangeClient();
+	}
+
 	private static class MockAuthorizationCodeClient
 			implements OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {
 
@@ -472,4 +534,13 @@ public class OAuth2AuthorizedClientManagerRegistrarTests {
 
 	}
 
+	private static class MockTokenExchangeClient implements OAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> {
+
+		@Override
+		public OAuth2AccessTokenResponse getTokenResponse(TokenExchangeGrantRequest authorizationGrantRequest) {
+			return MOCK_RESPONSE_CLIENT.getTokenResponse(authorizationGrantRequest);
+		}
+
+	}
+
 }

+ 3 - 0
config/src/test/resources/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests-clients.xml

@@ -53,4 +53,7 @@
 	<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
 			factory-method="jwtBearerAccessTokenResponseClient"/>
 
+	<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
+			factory-method="tokenExchangeAccessTokenResponseClient"/>
+
 </b:beans>

+ 3 - 0
config/src/test/resources/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrarTests-providers.xml

@@ -56,4 +56,7 @@
 	<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
 			factory-method="jwtBearerAuthorizedClientProvider"/>
 
+	<b:bean class="org.springframework.security.config.http.OAuth2AuthorizedClientManagerRegistrarTests"
+			factory-method="tokenExchangeAuthorizedClientProvider"/>
+
 </b:beans>