|
@@ -46,12 +46,16 @@ import org.springframework.security.core.authority.AuthorityUtils;
|
|
|
import org.springframework.security.core.context.SecurityContextHolder;
|
|
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
|
|
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
|
|
+import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
|
|
|
+import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
|
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
|
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
|
|
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
|
|
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
|
|
|
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|
|
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
|
|
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
|
|
|
import org.springframework.security.oauth2.core.user.OAuth2User;
|
|
|
import org.springframework.web.context.request.RequestContextHolder;
|
|
|
import org.springframework.web.context.request.ServletRequestAttributes;
|
|
@@ -89,6 +93,10 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
@Mock
|
|
|
private OAuth2AuthorizedClientRepository authorizedClientRepository;
|
|
|
@Mock
|
|
|
+ private ClientRegistrationRepository clientRegistrationRepository;
|
|
|
+ @Mock
|
|
|
+ private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient;
|
|
|
+ @Mock
|
|
|
private WebClient.RequestHeadersSpec<?> spec;
|
|
|
@Captor
|
|
|
private ArgumentCaptor<Consumer<Map<String, Object>>> attrs;
|
|
@@ -148,7 +156,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
|
|
|
@Test
|
|
|
public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationSet() {
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
SecurityContextHolder.getContext().setAuthentication(this.authentication);
|
|
|
Map<String, Object> attrs = getDefaultRequestAttributes();
|
|
|
assertThat(getAuthentication(attrs)).isEqualTo(this.authentication);
|
|
@@ -157,7 +166,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
|
|
|
@Test
|
|
|
public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride() {
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
|
|
"principalName", this.accessToken);
|
|
|
oauth2AuthorizedClient(authorizedClient).accept(this.result);
|
|
@@ -168,7 +178,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
|
|
|
@Test
|
|
|
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() {
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
Map<String, Object> attrs = getDefaultRequestAttributes();
|
|
|
assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
|
|
|
verifyZeroInteractions(this.authorizedClientRepository);
|
|
@@ -176,7 +187,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
|
|
|
@Test
|
|
|
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() {
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
Map<String, Object> attrs = getDefaultRequestAttributes();
|
|
|
assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
|
|
|
verifyZeroInteractions(this.authorizedClientRepository);
|
|
@@ -196,7 +208,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
|
|
|
@Test
|
|
|
public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
this.function.setDefaultOAuth2AuthorizedClient(true);
|
|
|
OAuth2User user = mock(OAuth2User.class);
|
|
|
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
|
@@ -214,7 +227,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
|
|
|
@Test
|
|
|
public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
OAuth2User user = mock(OAuth2User.class);
|
|
|
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
|
|
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
|
|
@@ -227,7 +241,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
|
|
|
@Test
|
|
|
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() {
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
OAuth2User user = mock(OAuth2User.class);
|
|
|
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
|
|
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
|
|
@@ -245,9 +260,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
|
|
|
@Test
|
|
|
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() {
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
- OAuth2User user = mock(OAuth2User.class);
|
|
|
- List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
|
|
"principalName", this.accessToken);
|
|
|
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient);
|
|
@@ -259,6 +273,41 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
verify(this.authorizedClientRepository).loadAuthorizedClient(eq("id"), any(), any());
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void defaultRequestWhenClientCredentialsThenAuthorizedClient() {
|
|
|
+ this.registration = TestClientRegistrations.clientCredentials().build();
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
+ this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
|
|
|
+ when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration);
|
|
|
+ OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
|
|
|
+ .accessTokenResponse().build();
|
|
|
+ when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(
|
|
|
+ accessTokenResponse);
|
|
|
+
|
|
|
+ clientRegistrationId(this.registration.getRegistrationId()).accept(this.result);
|
|
|
+
|
|
|
+ Map<String, Object> attrs = getDefaultRequestAttributes();
|
|
|
+ OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
|
|
|
+
|
|
|
+ assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
|
|
|
+ assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration);
|
|
|
+ assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser");
|
|
|
+ assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void defaultRequestWhenClientIdNotFoundThenIllegalArgumentException() {
|
|
|
+ this.registration = TestClientRegistrations.clientCredentials().build();
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
+
|
|
|
+ clientRegistrationId(this.registration.getRegistrationId()).accept(this.result);
|
|
|
+
|
|
|
+ assertThatCode(() -> getDefaultRequestAttributes())
|
|
|
+ .isInstanceOf(IllegalArgumentException.class);
|
|
|
+ }
|
|
|
+
|
|
|
private Map<String, Object> getDefaultRequestAttributes() {
|
|
|
this.function.defaultRequest().accept(this.spec);
|
|
|
verify(this.spec).attributes(this.attrs.capture());
|
|
@@ -322,7 +371,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
this.accessToken.getTokenValue(),
|
|
|
issuedAt,
|
|
|
accessTokenExpiresAt);
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
|
|
|
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
|
|
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
|
@@ -368,7 +418,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
this.accessToken.getTokenValue(),
|
|
|
issuedAt,
|
|
|
accessTokenExpiresAt);
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
|
|
|
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
|
|
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
|
@@ -400,7 +451,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
|
|
|
@Test
|
|
|
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
|
|
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
|
|
"principalName", this.accessToken);
|
|
@@ -422,7 +474,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
|
|
|
@Test
|
|
|
public void filterWhenNotExpiredThenShouldRefreshFalse() {
|
|
|
- this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
|
|
|
+ this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
|
|
|
+ this.authorizedClientRepository);
|
|
|
|
|
|
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
|
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|