|
@@ -16,28 +16,21 @@
|
|
|
|
|
|
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
|
|
|
|
|
-import com.sun.security.ntlm.Server;
|
|
|
import org.springframework.http.HttpHeaders;
|
|
|
import org.springframework.http.HttpMethod;
|
|
|
import org.springframework.http.MediaType;
|
|
|
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
-import org.springframework.security.core.GrantedAuthority;
|
|
|
import org.springframework.security.core.authority.AuthorityUtils;
|
|
|
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
|
|
-import org.springframework.security.core.context.SecurityContext;
|
|
|
-import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
|
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
|
|
-import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
|
|
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
|
|
|
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
|
|
|
-import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
|
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
|
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
|
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
|
|
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
|
|
-import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
|
|
import org.springframework.util.Assert;
|
|
|
import org.springframework.web.reactive.function.BodyInserters;
|
|
|
import org.springframework.web.reactive.function.client.ClientRequest;
|
|
@@ -51,9 +44,7 @@ import java.net.URI;
|
|
|
import java.time.Clock;
|
|
|
import java.time.Duration;
|
|
|
import java.time.Instant;
|
|
|
-import java.util.Collection;
|
|
|
import java.util.Map;
|
|
|
-import java.util.Optional;
|
|
|
import java.util.function.Consumer;
|
|
|
|
|
|
import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
|
|
@@ -88,20 +79,13 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
|
|
|
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
|
|
|
|
|
|
- private boolean defaultOAuth2AuthorizedClient;
|
|
|
-
|
|
|
- private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
|
|
|
- new WebClientReactiveClientCredentialsTokenResponseClient();
|
|
|
-
|
|
|
- private ReactiveClientRegistrationRepository clientRegistrationRepository;
|
|
|
-
|
|
|
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
|
|
|
|
|
- public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
|
|
|
+ private final OAuth2AuthorizedClientResolver authorizedClientResolver;
|
|
|
|
|
|
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
|
|
- this.clientRegistrationRepository = clientRegistrationRepository;
|
|
|
this.authorizedClientRepository = authorizedClientRepository;
|
|
|
+ this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository);
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -142,6 +126,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
|
|
|
}
|
|
|
|
|
|
+ private static OAuth2AuthorizedClient oauth2AuthorizedClient(ClientRequest request) {
|
|
|
+ return (OAuth2AuthorizedClient) request.attributes().get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
|
|
|
+ }
|
|
|
|
|
|
/**
|
|
|
* Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
|
|
@@ -166,6 +153,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
|
|
|
}
|
|
|
|
|
|
+ private static ServerWebExchange serverWebExchange(ClientRequest request) {
|
|
|
+ return (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to
|
|
|
* be used to look up the {@link OAuth2AuthorizedClient}.
|
|
@@ -178,6 +169,14 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
|
|
|
}
|
|
|
|
|
|
+ private static String clientRegistrationId(ClientRequest request) {
|
|
|
+ OAuth2AuthorizedClient authorizedClient = oauth2AuthorizedClient(request);
|
|
|
+ if (authorizedClient != null) {
|
|
|
+ return authorizedClient.getClientRegistration().getRegistrationId();
|
|
|
+ }
|
|
|
+ return (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
|
|
|
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
|
|
@@ -186,7 +185,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
* Default is false.
|
|
|
*/
|
|
|
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
|
|
|
- this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
|
|
|
+ this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(defaultOAuth2AuthorizedClient);
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -196,8 +195,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
*/
|
|
|
public void setClientCredentialsTokenResponseClient(
|
|
|
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
|
|
|
- Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
|
|
|
- this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
|
|
|
+ this.authorizedClientResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient);
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -212,128 +210,59 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
|
|
|
@Override
|
|
|
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
|
|
|
- return authorizedClient(request)
|
|
|
- .flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, request))
|
|
|
+ return authorizedClient(request, next)
|
|
|
.map(authorizedClient -> bearer(request, authorizedClient))
|
|
|
.flatMap(next::exchange)
|
|
|
.switchIfEmpty(next.exchange(request));
|
|
|
}
|
|
|
|
|
|
- private Mono<ServerWebExchange> serverWebExchange(ClientRequest request) {
|
|
|
- ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
|
|
- return Mono.justOrEmpty(exchange)
|
|
|
- .switchIfEmpty(serverWebExchange());
|
|
|
+ private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next) {
|
|
|
+ OAuth2AuthorizedClient authorizedClientFromAttrs = oauth2AuthorizedClient(request);
|
|
|
+ return Mono.justOrEmpty(authorizedClientFromAttrs)
|
|
|
+ .switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(request)))
|
|
|
+ .flatMap(authorizedClient -> refreshIfNecessary(request, next, authorizedClient));
|
|
|
}
|
|
|
|
|
|
- private Mono<ServerWebExchange> serverWebExchange() {
|
|
|
- return Mono.subscriberContext()
|
|
|
- .filter(c -> c.hasKey(ServerWebExchange.class))
|
|
|
- .map(c -> c.get(ServerWebExchange.class));
|
|
|
+ private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(ClientRequest request) {
|
|
|
+ return createRequest(request)
|
|
|
+ .flatMap(r -> this.authorizedClientResolver.loadAuthorizedClient(r));
|
|
|
}
|
|
|
|
|
|
- private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
|
|
|
- Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
|
|
|
- .map(OAuth2AuthorizedClient.class::cast);
|
|
|
- return Mono.justOrEmpty(attribute)
|
|
|
- .switchIfEmpty(findAuthorizedClientByRegistrationId(request));
|
|
|
+ private Mono<OAuth2AuthorizedClientResolver.Request> createRequest(ClientRequest request) {
|
|
|
+ String clientRegistrationId = clientRegistrationId(request);
|
|
|
+ Authentication authentication = null;
|
|
|
+ ServerWebExchange exchange = serverWebExchange(request);
|
|
|
+ return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, authentication, exchange);
|
|
|
}
|
|
|
|
|
|
- private Mono<OAuth2AuthorizedClient> findAuthorizedClientByRegistrationId(ClientRequest request) {
|
|
|
- if (this.authorizedClientRepository == null) {
|
|
|
- return Mono.empty();
|
|
|
- }
|
|
|
-
|
|
|
- return currentAuthentication()
|
|
|
- .flatMap(principal -> clientRegistrationId(request, principal)
|
|
|
- .flatMap(clientRegistrationId -> serverWebExchange(request).flatMap(exchange -> loadAuthorizedClient(clientRegistrationId, exchange, principal)))
|
|
|
- );
|
|
|
- }
|
|
|
-
|
|
|
- private Mono<String> clientRegistrationId(ClientRequest request, Authentication authentication) {
|
|
|
- return Mono.justOrEmpty(request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME))
|
|
|
- .cast(String.class)
|
|
|
- .switchIfEmpty(clientRegistrationId(authentication));
|
|
|
- }
|
|
|
-
|
|
|
- private Mono<String> clientRegistrationId(Authentication authentication) {
|
|
|
- return Mono.justOrEmpty(authentication)
|
|
|
- .filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
|
|
|
- .cast(OAuth2AuthenticationToken.class)
|
|
|
- .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
|
|
|
- }
|
|
|
-
|
|
|
- private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
|
|
|
- ServerWebExchange exchange, Authentication principal) {
|
|
|
- return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
|
|
|
- .switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange));
|
|
|
- }
|
|
|
-
|
|
|
- private Mono<OAuth2AuthorizedClient> authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) {
|
|
|
- return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
|
|
|
- .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
|
|
|
- .flatMap(clientRegistration -> {
|
|
|
- if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
|
|
|
- return clientCredentials(clientRegistration, exchange);
|
|
|
- }
|
|
|
- return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
|
|
|
- });
|
|
|
- }
|
|
|
-
|
|
|
- private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
|
|
|
- ClientRegistration clientRegistration, ServerWebExchange exchange) {
|
|
|
- OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
|
|
|
- return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
|
|
|
- .flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange));
|
|
|
- }
|
|
|
-
|
|
|
- private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) {
|
|
|
- return currentAuthentication()
|
|
|
- .flatMap(principal -> {
|
|
|
- OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
|
|
|
- clientRegistration, (principal != null ?
|
|
|
- principal.getName() :
|
|
|
- "anonymousUser"),
|
|
|
- tokenResponse.getAccessToken());
|
|
|
-
|
|
|
- return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null)
|
|
|
- .thenReturn(authorizedClient);
|
|
|
- });
|
|
|
- }
|
|
|
-
|
|
|
- private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
|
|
|
+ private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
|
|
|
if (shouldRefresh(authorizedClient)) {
|
|
|
- return serverWebExchange(request)
|
|
|
- .flatMap(exchange -> refreshAuthorizedClient(next, authorizedClient, exchange));
|
|
|
+ return createRequest(request)
|
|
|
+ .flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r));
|
|
|
}
|
|
|
return Mono.just(authorizedClient);
|
|
|
}
|
|
|
|
|
|
private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
|
|
|
- OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
|
|
|
+ OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) {
|
|
|
+ ServerWebExchange exchange = r.getExchange();
|
|
|
+ Authentication authentication = r.getAuthentication();
|
|
|
ClientRegistration clientRegistration = authorizedClient
|
|
|
.getClientRegistration();
|
|
|
String tokenUri = clientRegistration
|
|
|
.getProviderDetails().getTokenUri();
|
|
|
- ClientRequest request = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri))
|
|
|
+ ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri))
|
|
|
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
|
|
.headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
|
|
|
.body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue()))
|
|
|
.build();
|
|
|
- return next.exchange(request)
|
|
|
- .flatMap(response -> response.body(oauth2AccessTokenResponse()))
|
|
|
+ return next.exchange(refreshRequest)
|
|
|
+ .flatMap(refreshResponse -> refreshResponse.body(oauth2AccessTokenResponse()))
|
|
|
.map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()))
|
|
|
- .flatMap(result -> currentAuthentication()
|
|
|
- .defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()))
|
|
|
- .flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange))
|
|
|
+ .flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
|
|
|
.thenReturn(result));
|
|
|
}
|
|
|
|
|
|
- private Mono<Authentication> currentAuthentication() {
|
|
|
- return ReactiveSecurityContextHolder.getContext()
|
|
|
- .map(SecurityContext::getAuthentication)
|
|
|
- .defaultIfEmpty(ANONYMOUS_USER_TOKEN);
|
|
|
- }
|
|
|
-
|
|
|
private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
|
|
|
if (this.authorizedClientRepository == null) {
|
|
|
return false;
|
|
@@ -361,52 +290,4 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
|
|
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
|
|
|
.with("refresh_token", refreshToken);
|
|
|
}
|
|
|
-
|
|
|
- private static class PrincipalNameAuthentication implements Authentication {
|
|
|
- private final String username;
|
|
|
-
|
|
|
- private PrincipalNameAuthentication(String username) {
|
|
|
- this.username = username;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public Collection<? extends GrantedAuthority> getAuthorities() {
|
|
|
- throw unsupported();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public Object getCredentials() {
|
|
|
- throw unsupported();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public Object getDetails() {
|
|
|
- throw unsupported();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public Object getPrincipal() {
|
|
|
- throw unsupported();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public boolean isAuthenticated() {
|
|
|
- throw unsupported();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void setAuthenticated(boolean isAuthenticated)
|
|
|
- throws IllegalArgumentException {
|
|
|
- throw unsupported();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public String getName() {
|
|
|
- return this.username;
|
|
|
- }
|
|
|
-
|
|
|
- private UnsupportedOperationException unsupported() {
|
|
|
- return new UnsupportedOperationException("Not Supported");
|
|
|
- }
|
|
|
- }
|
|
|
}
|