|
@@ -19,14 +19,19 @@ package org.springframework.security.oauth2.client.web.reactive.function.client;
|
|
|
import org.springframework.http.HttpHeaders;
|
|
|
import org.springframework.http.HttpMethod;
|
|
|
import org.springframework.http.MediaType;
|
|
|
+import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
import org.springframework.security.core.GrantedAuthority;
|
|
|
+import org.springframework.security.core.authority.AuthorityUtils;
|
|
|
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
|
|
import org.springframework.security.core.context.SecurityContext;
|
|
|
+import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
|
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
|
|
+import org.springframework.security.oauth2.client.OAuth2ClientException;
|
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
|
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
|
|
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
|
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
|
|
import org.springframework.util.Assert;
|
|
|
import org.springframework.web.reactive.function.BodyInserters;
|
|
@@ -61,10 +66,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
*/
|
|
|
private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
|
|
|
|
|
|
+ /**
|
|
|
+ * The client request attribute name used to locate the {@link ClientRegistration#getRegistrationId()}
|
|
|
+ */
|
|
|
+ private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID");
|
|
|
+
|
|
|
/**
|
|
|
* The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
|
|
|
*/
|
|
|
private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
|
|
|
+ public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
|
|
|
+ AuthorityUtils.createAuthorityList("ROLE_USER"));
|
|
|
|
|
|
private Clock clock = Clock.systemUTC();
|
|
|
|
|
@@ -74,8 +86,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
|
|
|
public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
|
|
|
|
|
|
- public ServerOAuth2AuthorizedClientExchangeFilterFunction(
|
|
|
- ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
|
|
+ public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
|
|
this.authorizedClientRepository = authorizedClientRepository;
|
|
|
}
|
|
|
|
|
@@ -141,6 +152,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to
|
|
|
+ * be used to look up the {@link OAuth2AuthorizedClient}.
|
|
|
+ *
|
|
|
+ * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to
|
|
|
+ * be used to look up the {@link OAuth2AuthorizedClient}.
|
|
|
+ * @return the {@link Consumer} to populate the attributes
|
|
|
+ */
|
|
|
+ public static Consumer<Map<String, Object>> clientRegistrationId(String clientRegistrationId) {
|
|
|
+ return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* An access token will be considered expired by comparing its expiration to now +
|
|
|
* this skewed Duration. The default is 1 minute.
|
|
@@ -153,17 +176,42 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
|
|
|
@Override
|
|
|
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
|
|
|
- Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
|
|
|
- .map(OAuth2AuthorizedClient.class::cast);
|
|
|
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
|
|
- return Mono.justOrEmpty(attribute)
|
|
|
- .flatMap(authorizedClient -> authorizedClient(next, authorizedClient, exchange))
|
|
|
+ return authorizedClient(request)
|
|
|
+ .flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, exchange))
|
|
|
.map(authorizedClient -> bearer(request, authorizedClient))
|
|
|
.flatMap(next::exchange)
|
|
|
.switchIfEmpty(next.exchange(request));
|
|
|
}
|
|
|
|
|
|
- private Mono<OAuth2AuthorizedClient> authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
|
|
|
+ private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
|
|
|
+ Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
|
|
|
+ .map(OAuth2AuthorizedClient.class::cast);
|
|
|
+ return Mono.justOrEmpty(attribute)
|
|
|
+ .switchIfEmpty(findAuthorizedClientByRegistrationId(request));
|
|
|
+ }
|
|
|
+
|
|
|
+ private Mono<OAuth2AuthorizedClient> findAuthorizedClientByRegistrationId(ClientRequest request) {
|
|
|
+ if (this.authorizedClientRepository == null) {
|
|
|
+ return Mono.empty();
|
|
|
+ }
|
|
|
+ String clientRegistrationId = (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
|
|
|
+ if (clientRegistrationId == null) {
|
|
|
+ return Mono.empty();
|
|
|
+ }
|
|
|
+ ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
|
|
+ return currentAuthentication()
|
|
|
+ .flatMap(principal -> loadAuthorizedClient(clientRegistrationId, exchange, principal)
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
|
|
|
+ ServerWebExchange exchange, Authentication principal) {
|
|
|
+ return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
|
|
|
+ .switchIfEmpty(Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)));
|
|
|
+ }
|
|
|
+
|
|
|
+ private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
|
|
|
if (shouldRefresh(authorizedClient)) {
|
|
|
return refreshAuthorizedClient(next, authorizedClient, exchange);
|
|
|
}
|
|
@@ -184,13 +232,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
return next.exchange(request)
|
|
|
.flatMap(response -> response.body(oauth2AccessTokenResponse()))
|
|
|
.map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()))
|
|
|
- .flatMap(result -> ReactiveSecurityContextHolder.getContext()
|
|
|
- .map(SecurityContext::getAuthentication)
|
|
|
+ .flatMap(result -> currentAuthentication()
|
|
|
.defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()))
|
|
|
.flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange))
|
|
|
.thenReturn(result));
|
|
|
}
|
|
|
|
|
|
+ private Mono<Authentication> currentAuthentication() {
|
|
|
+ return ReactiveSecurityContextHolder.getContext()
|
|
|
+ .map(SecurityContext::getAuthentication)
|
|
|
+ .defaultIfEmpty(ANONYMOUS_USER_TOKEN);
|
|
|
+ }
|
|
|
+
|
|
|
private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
|
|
|
if (this.authorizedClientRepository == null) {
|
|
|
return false;
|