Просмотр исходного кода

When expired retrieve new Client Credentials token.

Once client credentials access token has expired retrieve a new token from the OAuth2 authorization server.
These tokens can't be refreshed because they do not have a refresh token associated with. This is standard behaviour for Oauth 2 client credentails

Fixes gh-5893
Warren Bailey 6 лет назад
Родитель
Сommit
450a20add4

+ 1 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java

@@ -133,7 +133,7 @@ class OAuth2AuthorizedClientResolver {
 			});
 }
 
-	private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
+	Mono<OAuth2AuthorizedClient> clientCredentials(
 			ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
 		OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
 		return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)

+ 27 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -85,8 +85,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	private final OAuth2AuthorizedClientResolver authorizedClientResolver;
 
 	public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+		this(authorizedClientRepository, new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository));
+	}
+
+	ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository, OAuth2AuthorizedClientResolver authorizedClientResolver) {
 		this.authorizedClientRepository = authorizedClientRepository;
-		this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository);
+		this.authorizedClientResolver = authorizedClientResolver;
 	}
 
 	/**
@@ -246,13 +250,30 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	}
 
 	private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
-		if (shouldRefresh(authorizedClient)) {
+		ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
+		if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
+			return createRequest(request)
+					.flatMap(r -> authorizeWithClientCredentials(clientRegistration, r));
+		} else if (shouldRefresh(authorizedClient)) {
 			return createRequest(request)
 				.flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r));
 		}
 		return Mono.just(authorizedClient);
 	}
 
+	private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
+		return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
+	}
+
+	private Mono<OAuth2AuthorizedClient> authorizeWithClientCredentials(ClientRegistration clientRegistration, OAuth2AuthorizedClientResolver.Request request) {
+		Authentication authentication = request.getAuthentication();
+		ServerWebExchange exchange = request.getExchange();
+
+		return this.authorizedClientResolver.clientCredentials(clientRegistration, authentication, exchange).
+				flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
+						.thenReturn(result));
+	}
+
 	private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
 			OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) {
 		ServerWebExchange exchange = r.getExchange();
@@ -285,6 +306,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 		if (refreshToken == null) {
 			return false;
 		}
+		return hasTokenExpired(authorizedClient);
+	}
+
+	private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
 		Instant now = this.clock.instant();
 		Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
 		if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {

+ 13 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -412,6 +412,10 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
 		throw new ClientAuthorizationRequiredException(clientRegistrationId);
 	}
 
+	private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
+		return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
+	}
+
 	private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration,
 			Map<String, Object> attrs) {
 
@@ -439,7 +443,11 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
 	}
 
 	private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
-		if (shouldRefresh(authorizedClient)) {
+		ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
+		if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
+			//Client credentials grant do not have refresh tokens but can expire so we need to get another one
+			return Mono.fromSupplier(() -> getAuthorizedClient(clientRegistration, request.attributes()));
+		} else if (shouldRefresh(authorizedClient)) {
 			return refreshAuthorizedClient(request, next, authorizedClient);
 		}
 		return Mono.just(authorizedClient);
@@ -484,6 +492,10 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
 		if (refreshToken == null) {
 			return false;
 		}
+		return hasTokenExpired(authorizedClient);
+	}
+
+	private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
 		Instant now = this.clock.instant();
 		Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
 		if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {

+ 87 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -44,6 +44,7 @@ import org.springframework.security.oauth2.client.authentication.OAuth2Authentic
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientResolver.Request;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
@@ -69,6 +70,7 @@ import java.util.Optional;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
@@ -88,6 +90,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 	@Mock
 	private ReactiveClientRegistrationRepository clientRegistrationRepository;
 
+	@Mock
+	private OAuth2AuthorizedClientResolver oAuth2AuthorizedClientResolver;
+
 	@Mock
 	private ServerWebExchange serverWebExchange;
 
@@ -149,6 +154,88 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
 	}
 
+	@Test
+	public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
+		ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
+		String clientRegistrationId = registration.getClientId();
+
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);
+
+		OAuth2AccessToken newAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				"new-token",
+				Instant.now(),
+				Instant.now().plus(Duration.ofDays(1)));
+		OAuth2AuthorizedClient newAuthorizedClient = new OAuth2AuthorizedClient(registration,
+				"principalName", newAccessToken, null);
+		Request r = new Request(clientRegistrationId, authentication, null);
+		when(this.oAuth2AuthorizedClientResolver.clientCredentials(any(), any(), any())).thenReturn(Mono.just(newAuthorizedClient));
+		when(this.oAuth2AuthorizedClientResolver.createDefaultedRequest(any(), any(), any())).thenReturn(Mono.just(r));
+
+		when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
+
+		Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
+		Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
+
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
+				this.accessToken.getTokenValue(),
+				issuedAt,
+				accessTokenExpiresAt);
+
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
+				"principalName", accessToken, null);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+
+		this.function.filter(request, this.exchange)
+				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
+				.block();
+
+		verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
+		verify(this.oAuth2AuthorizedClientResolver).clientCredentials(any(), any(), any());
+		verify(this.oAuth2AuthorizedClientResolver).createDefaultedRequest(any(), any(), any());
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+		ClientRequest request1 = requests.get(0);
+		assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer new-token");
+		assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request1.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request1)).isEmpty();
+	}
+
+	@Test
+	public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
+		ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
+
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
+				"principalName", this.accessToken, null);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.build();
+
+		this.function.filter(request, this.exchange)
+				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
+				.block();
+
+		verify(this.oAuth2AuthorizedClientResolver, never()).clientCredentials(any(), any(), any());
+		verify(this.oAuth2AuthorizedClientResolver, never()).createDefaultedRequest(any(), any(), any());
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+		ClientRequest request1 = requests.get(0);
+		assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
+		assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request1.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request1)).isEmpty();
+	}
+
 	@Test
 	public void filterWhenRefreshRequiredThenRefresh() {
 		when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());

+ 80 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -55,6 +55,7 @@ import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepo
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
@@ -80,7 +81,11 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.mockito.Mockito.when;
 import static org.springframework.http.HttpMethod.GET;
 import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
 
@@ -433,6 +438,80 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		assertThat(getBody(request1)).isEmpty();
 	}
 
+	@Test
+	public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
+		this.registration = TestClientRegistrations.clientCredentials().build();
+
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
+		this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, null);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.attributes(authentication(this.authentication))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), eq(this.authentication), any(), any());
+
+		verify(clientCredentialsTokenResponseClient, never()).getTokenResponse(any());
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+
+		ClientRequest request1 = requests.get(0);
+		assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
+		assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request1.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request1)).isEmpty();
+	}
+
+	@Test
+	public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
+		this.registration = TestClientRegistrations.clientCredentials().build();
+
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
+				.accessTokenResponse().build();
+		when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(
+				accessTokenResponse);
+
+		Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
+		Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
+
+		this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
+				this.accessToken.getTokenValue(),
+				issuedAt,
+				accessTokenExpiresAt);
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
+				this.authorizedClientRepository);
+		this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, null);
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(oauth2AuthorizedClient(authorizedClient))
+				.attributes(authentication(this.authentication))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any());
+
+		verify(clientCredentialsTokenResponseClient).getTokenResponse(any());
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+
+		ClientRequest request1 = requests.get(0);
+		assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token");
+		assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request1.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request1)).isEmpty();
+	}
+
 	@Test
 	public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
 		OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")