|
@@ -20,10 +20,12 @@ import static org.assertj.core.api.Assertions.assertThat;
|
|
import static org.mockito.BDDMockito.given;
|
|
import static org.mockito.BDDMockito.given;
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.mock;
|
|
|
|
+import static org.mockito.Mockito.spy;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verifyZeroInteractions;
|
|
import static org.mockito.Mockito.verifyZeroInteractions;
|
|
import static org.mockito.Mockito.when;
|
|
import static org.mockito.Mockito.when;
|
|
import static org.springframework.security.config.Customizer.withDefaults;
|
|
import static org.springframework.security.config.Customizer.withDefaults;
|
|
|
|
+import static org.springframework.test.util.ReflectionTestUtils.getField;
|
|
|
|
|
|
import java.util.Arrays;
|
|
import java.util.Arrays;
|
|
import java.util.List;
|
|
import java.util.List;
|
|
@@ -35,16 +37,20 @@ import org.apache.http.HttpHeaders;
|
|
import org.junit.Before;
|
|
import org.junit.Before;
|
|
import org.junit.Test;
|
|
import org.junit.Test;
|
|
import org.junit.runner.RunWith;
|
|
import org.junit.runner.RunWith;
|
|
|
|
+import org.mockito.ArgumentCaptor;
|
|
import org.mockito.Mock;
|
|
import org.mockito.Mock;
|
|
import org.mockito.junit.MockitoJUnitRunner;
|
|
import org.mockito.junit.MockitoJUnitRunner;
|
|
|
|
|
|
import org.springframework.security.core.Authentication;
|
|
import org.springframework.security.core.Authentication;
|
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
|
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
|
|
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
|
|
|
|
+import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
|
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
|
|
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
|
|
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
|
|
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
|
|
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
|
|
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
|
|
|
|
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
|
|
|
|
+import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
|
|
import reactor.core.publisher.Mono;
|
|
import reactor.core.publisher.Mono;
|
|
import reactor.test.publisher.TestPublisher;
|
|
import reactor.test.publisher.TestPublisher;
|
|
|
|
|
|
@@ -64,7 +70,6 @@ import org.springframework.security.web.server.context.WebSessionServerSecurityC
|
|
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
|
|
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
|
|
import org.springframework.security.web.server.csrf.CsrfWebFilter;
|
|
import org.springframework.security.web.server.csrf.CsrfWebFilter;
|
|
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
|
|
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
|
|
-import org.springframework.test.util.ReflectionTestUtils;
|
|
|
|
import org.springframework.test.web.reactive.server.EntityExchangeResult;
|
|
import org.springframework.test.web.reactive.server.EntityExchangeResult;
|
|
import org.springframework.test.web.reactive.server.FluxExchangeResult;
|
|
import org.springframework.test.web.reactive.server.FluxExchangeResult;
|
|
import org.springframework.test.web.reactive.server.WebTestClient;
|
|
import org.springframework.test.web.reactive.server.WebTestClient;
|
|
@@ -200,7 +205,7 @@ public class ServerHttpSecurityTests {
|
|
.isNotPresent();
|
|
.isNotPresent();
|
|
|
|
|
|
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
|
|
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
|
|
- .map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
|
|
|
|
|
+ .map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
|
|
|
|
|
assertThat(logoutHandler)
|
|
assertThat(logoutHandler)
|
|
.get()
|
|
.get()
|
|
@@ -213,17 +218,17 @@ public class ServerHttpSecurityTests {
|
|
|
|
|
|
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
|
|
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
|
|
.get()
|
|
.get()
|
|
- .extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository"))
|
|
|
|
|
|
+ .extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository"))
|
|
.isEqualTo(this.csrfTokenRepository);
|
|
.isEqualTo(this.csrfTokenRepository);
|
|
|
|
|
|
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
|
|
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
|
|
- .map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
|
|
|
|
|
+ .map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
|
|
|
|
|
|
assertThat(logoutHandler)
|
|
assertThat(logoutHandler)
|
|
.get()
|
|
.get()
|
|
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
|
|
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
|
|
.extracting(delegatingLogoutHandler ->
|
|
.extracting(delegatingLogoutHandler ->
|
|
- ((List<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
|
|
|
|
|
|
+ ((List<ServerLogoutHandler>) getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
|
|
.map(ServerLogoutHandler::getClass)
|
|
.map(ServerLogoutHandler::getClass)
|
|
.collect(Collectors.toList()))
|
|
.collect(Collectors.toList()))
|
|
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
|
|
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
|
|
@@ -479,6 +484,33 @@ public class ServerHttpSecurityTests {
|
|
verify(customServerCsrfTokenRepository).loadToken(any());
|
|
verify(customServerCsrfTokenRepository).loadToken(any());
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Test
|
|
|
|
+ public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() {
|
|
|
|
+ ServerRequestCache requestCache = spy(new WebSessionServerRequestCache());
|
|
|
|
+ ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
|
|
|
|
+
|
|
|
|
+ SecurityWebFilterChain securityFilterChain = this.http
|
|
|
|
+ .oauth2Login()
|
|
|
|
+ .clientRegistrationRepository(clientRegistrationRepository)
|
|
|
|
+ .and()
|
|
|
|
+ .authorizeExchange().anyExchange().authenticated()
|
|
|
|
+ .and()
|
|
|
|
+ .requestCache(c -> c.requestCache(requestCache))
|
|
|
|
+ .build();
|
|
|
|
+
|
|
|
|
+ WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build();
|
|
|
|
+ client.get().uri("/test").exchange();
|
|
|
|
+ ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
|
|
|
|
+ verify(requestCache).saveRequest(captor.capture());
|
|
|
|
+ assertThat(captor.getValue().getRequest().getURI().toString()).isEqualTo("/test");
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ OAuth2LoginAuthenticationWebFilter authenticationWebFilter =
|
|
|
|
+ getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter.class).get();
|
|
|
|
+ Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler");
|
|
|
|
+ assertThat(getField(handler, "requestCache")).isSameAs(requestCache);
|
|
|
|
+ }
|
|
|
|
+
|
|
@Test
|
|
@Test
|
|
public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
|
|
public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
|
|
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
|
|
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
|
|
@@ -503,7 +535,7 @@ public class ServerHttpSecurityTests {
|
|
|
|
|
|
private boolean isX509Filter(WebFilter filter) {
|
|
private boolean isX509Filter(WebFilter filter) {
|
|
try {
|
|
try {
|
|
- Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
|
|
|
|
|
|
+ Object converter = getField(filter, "authenticationConverter");
|
|
return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class);
|
|
return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class);
|
|
} catch (IllegalArgumentException e) {
|
|
} catch (IllegalArgumentException e) {
|
|
// field doesn't exist
|
|
// field doesn't exist
|