ソースを参照

ServletOAuth2AuthorizedClientExchangeFilterFunction support client_credentials

Fixes: gh-5639
Rob Winch 7 年 前
コミット
f5ad4ba0fa

+ 65 - 4
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -26,10 +26,15 @@ import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.util.Assert;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
@@ -107,16 +112,35 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 
 	private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
 
+	private ClientRegistrationRepository clientRegistrationRepository;
+
 	private OAuth2AuthorizedClientRepository authorizedClientRepository;
 
+	private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
+			new DefaultClientCredentialsTokenResponseClient();
+
 	private boolean defaultOAuth2AuthorizedClient;
 
 	public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
 
-	public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientRepository authorizedClientRepository) {
+	public ServletOAuth2AuthorizedClientExchangeFilterFunction(
+			ClientRegistrationRepository clientRegistrationRepository,
+			OAuth2AuthorizedClientRepository authorizedClientRepository) {
+		this.clientRegistrationRepository = clientRegistrationRepository;
 		this.authorizedClientRepository = authorizedClientRepository;
 	}
 
+	/**
+	 * Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
+	 * client_credentials grant.
+	 * @param clientCredentialsTokenResponseClient the client to use
+	 */
+	public void setClientCredentialsTokenResponseClient(
+			OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
+		Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
+		this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
+	}
+
 	/**
 	 * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
 	 * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
@@ -277,18 +301,55 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 			clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
 		}
 		if (clientRegistrationId != null) {
-			HttpServletRequest request = (HttpServletRequest) attrs.get(
-					HTTP_SERVLET_REQUEST_ATTR_NAME);
+			HttpServletRequest request = getRequest(attrs);
 			OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository
 					.loadAuthorizedClient(clientRegistrationId, authentication,
 							request);
 			if (authorizedClient == null) {
-				throw new ClientAuthorizationRequiredException(clientRegistrationId);
+				authorizedClient = getAuthorizedClient(clientRegistrationId, attrs);
 			}
 			oauth2AuthorizedClient(authorizedClient).accept(attrs);
 		}
 	}
 
+	private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, Map<String, Object> attrs) {
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
+		if (clientRegistration == null) {
+			throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId);
+		}
+		if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
+			return getAuthorizedClient(clientRegistration, attrs);
+		}
+		throw new ClientAuthorizationRequiredException(clientRegistrationId);
+	}
+
+
+	private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration,
+			Map<String, Object> attrs) {
+
+		HttpServletRequest request = getRequest(attrs);
+		HttpServletResponse response = getResponse(attrs);
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
+				new OAuth2ClientCredentialsGrantRequest(clientRegistration);
+		OAuth2AccessTokenResponse tokenResponse =
+				this.clientCredentialsTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
+
+		Authentication principal = getAuthentication(attrs);
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				clientRegistration,
+				(principal != null ? principal.getName() : "anonymousUser"),
+				tokenResponse.getAccessToken());
+
+		this.authorizedClientRepository.saveAuthorizedClient(
+				authorizedClient,
+				principal,
+				request,
+				response);
+
+		return authorizedClient;
+	}
+
 	private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
 		if (shouldRefresh(authorizedClient)) {
 			return refreshAuthorizedClient(request, next, authorizedClient);

+ 7 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/TestClientRegistrations.java

@@ -54,4 +54,11 @@ public class TestClientRegistrations {
 				.clientId("client-id-2")
 				.clientSecret("client-secret");
 	}
+
+	public static ClientRegistration.Builder clientCredentials() {
+		return clientRegistration()
+				.registrationId("client-credentials")
+				.clientId("client-id")
+				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS);
+	}
 }

+ 67 - 14
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -46,12 +46,16 @@ import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
@@ -89,6 +93,10 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 	@Mock
 	private OAuth2AuthorizedClientRepository authorizedClientRepository;
 	@Mock
+	private ClientRegistrationRepository clientRegistrationRepository;
+	@Mock
+	private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient;
+	@Mock
 	private WebClient.RequestHeadersSpec<?> spec;
 	@Captor
 	private ArgumentCaptor<Consumer<Map<String, Object>>> attrs;
@@ -148,7 +156,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationSet() {
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 		SecurityContextHolder.getContext().setAuthentication(this.authentication);
 		Map<String, Object> attrs = getDefaultRequestAttributes();
 		assertThat(getAuthentication(attrs)).isEqualTo(this.authentication);
@@ -157,7 +166,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride() {
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
 				"principalName", this.accessToken);
 		oauth2AuthorizedClient(authorizedClient).accept(this.result);
@@ -168,7 +178,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() {
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 		Map<String, Object> attrs = getDefaultRequestAttributes();
 		assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
 		verifyZeroInteractions(this.authorizedClientRepository);
@@ -176,7 +187,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() {
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 		Map<String, Object> attrs = getDefaultRequestAttributes();
 		assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
 		verifyZeroInteractions(this.authorizedClientRepository);
@@ -196,7 +208,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 		this.function.setDefaultOAuth2AuthorizedClient(true);
 		OAuth2User user = mock(OAuth2User.class);
 		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
@@ -214,7 +227,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 		OAuth2User user = mock(OAuth2User.class);
 		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
 		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
@@ -227,7 +241,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() {
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 		OAuth2User user = mock(OAuth2User.class);
 		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
 		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
@@ -245,9 +260,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() {
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
-		OAuth2User user = mock(OAuth2User.class);
-		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
 				"principalName", this.accessToken);
 		when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient);
@@ -259,6 +273,41 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		verify(this.authorizedClientRepository).loadAuthorizedClient(eq("id"), any(), any());
 	}
 
+	@Test
+	public void defaultRequestWhenClientCredentialsThenAuthorizedClient() {
+		this.registration = TestClientRegistrations.clientCredentials().build();
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
+		this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
+		when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration);
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
+				.accessTokenResponse().build();
+		when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(
+				accessTokenResponse);
+
+		clientRegistrationId(this.registration.getRegistrationId()).accept(this.result);
+
+		Map<String, Object> attrs = getDefaultRequestAttributes();
+		OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
+
+		assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
+		assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration);
+		assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser");
+		assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
+	}
+
+	@Test
+	public void defaultRequestWhenClientIdNotFoundThenIllegalArgumentException() {
+		this.registration = TestClientRegistrations.clientCredentials().build();
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
+
+		clientRegistrationId(this.registration.getRegistrationId()).accept(this.result);
+
+		assertThatCode(() -> getDefaultRequestAttributes())
+			.isInstanceOf(IllegalArgumentException.class);
+	}
+
 	private Map<String, Object> getDefaultRequestAttributes() {
 		this.function.defaultRequest().accept(this.spec);
 		verify(this.spec).attributes(this.attrs.capture());
@@ -322,7 +371,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 				this.accessToken.getTokenValue(),
 				issuedAt,
 				accessTokenExpiresAt);
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -368,7 +418,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 				this.accessToken.getTokenValue(),
 				issuedAt,
 				accessTokenExpiresAt);
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -400,7 +451,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
 				"principalName", this.accessToken);
@@ -422,7 +474,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void filterWhenNotExpiredThenShouldRefreshFalse() {
-		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
 
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,