Przeglądaj źródła

ServerOAuth2AuthorizedClientExchangeFilterFunction clientRegistrationId

Issue: gh-4921
Rob Winch 7 lat temu
rodzic
commit
158b8aa6d5

+ 58 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -27,12 +27,15 @@ import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
-import org.springframework.security.oauth2.client.OAuth2ClientException;
+import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.util.Assert;
 import org.springframework.web.reactive.function.BodyInserters;
 import org.springframework.web.reactive.function.client.ClientRequest;
@@ -75,18 +78,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	 * The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
 	 */
 	private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
-	public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
+
+	private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
 			AuthorityUtils.createAuthorityList("ROLE_USER"));
 
 	private Clock clock = Clock.systemUTC();
 
 	private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
 
+	private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
+			new WebClientReactiveClientCredentialsTokenResponseClient();
+
+	private ReactiveClientRegistrationRepository clientRegistrationRepository;
+
 	private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
 
 	public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
 
-	public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+	public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+		this.clientRegistrationRepository = clientRegistrationRepository;
 		this.authorizedClientRepository = authorizedClientRepository;
 	}
 
@@ -164,6 +174,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 		return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
 	}
 
+	/**
+	 * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
+	 * client_credentials grant.
+	 * @param clientCredentialsTokenResponseClient the client to use
+	 */
+	public void setClientCredentialsTokenResponseClient(
+			ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
+		Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
+		this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
+	}
+
 	/**
 	 * An access token will be considered expired by comparing its expiration to now +
 	 * this skewed Duration. The default is 1 minute.
@@ -208,7 +229,39 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
 			ServerWebExchange exchange, Authentication principal) {
 		return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
-			.switchIfEmpty(Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)));
+			.switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange));
+	}
+
+	private Mono<OAuth2AuthorizedClient> authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) {
+		return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
+			.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
+			.flatMap(clientRegistration -> {
+				if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
+					return clientCredentials(clientRegistration, exchange);
+				}
+				return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
+			});
+	}
+
+	private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
+			ClientRegistration clientRegistration, ServerWebExchange exchange) {
+		OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
+		return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
+			.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange));
+	}
+
+	private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) {
+		return currentAuthentication()
+			.flatMap(principal -> {
+				OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+						clientRegistration, (principal != null ?
+						principal.getName() :
+						"anonymousUser"),
+						tokenResponse.getAccessToken());
+
+				return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null)
+						.thenReturn(authorizedClient);
+			});
 	}
 
 	private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {

+ 16 - 11
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -37,6 +37,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -71,7 +72,10 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c
 @RunWith(MockitoJUnitRunner.class)
 public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 	@Mock
-	private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository;
+	private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
+
+	@Mock
+	private ReactiveClientRegistrationRepository clientRegistrationRepository;
 
 	private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
 
@@ -125,7 +129,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void filterWhenRefreshRequiredThenRefresh() {
-		when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
+		when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
 		OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
 				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 				.expiresIn(3600)
@@ -140,7 +144,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 				this.accessToken.getTokenValue(),
 				issuedAt,
 				accessTokenExpiresAt);
-		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
 
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -154,7 +158,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
 				.block();
 
-		verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
+		verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
 
 		List<ClientRequest> requests = this.exchange.getRequests();
 		assertThat(requests).hasSize(2);
@@ -174,7 +178,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
-		when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
+		when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
 		OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
 				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 				.expiresIn(3600)
@@ -189,7 +193,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 				this.accessToken.getTokenValue(),
 				issuedAt,
 				accessTokenExpiresAt);
-		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
 
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -201,7 +205,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		this.function.filter(request, this.exchange)
 				.block();
 
-		verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
+		verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any());
 
 		List<ClientRequest> requests = this.exchange.getRequests();
 		assertThat(requests).hasSize(2);
@@ -221,7 +225,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
-		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
 
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
 				"principalName", this.accessToken);
@@ -243,7 +247,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void filterWhenNotExpiredThenShouldRefreshFalse() {
-		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
 
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -266,12 +270,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
-		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
 
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
 				"principalName", this.accessToken, refreshToken);
-		when(this.auth2AuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
+		when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
+		when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
 		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
 				.attributes(clientRegistrationId(this.registration.getRegistrationId()))
 				.build();

+ 5 - 2
samples/boot/authcodegrant-webflux/src/main/java/sample/config/WebClientConfig.java

@@ -18,7 +18,9 @@ package sample.config;
 
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.web.reactive.function.client.WebClient;
 
 /**
@@ -29,9 +31,10 @@ import org.springframework.web.reactive.function.client.WebClient;
 public class WebClientConfig {
 
 	@Bean
-	WebClient webClient() {
+	WebClient webClient(ReactiveClientRegistrationRepository clientRegistrationRepository,
+			ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
 		return WebClient.builder()
-				.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction())
+				.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository))
 				.build();
 	}
 }