|
@@ -53,11 +53,14 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
|
|
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
|
|
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
|
|
import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
|
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
|
|
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
|
|
import org.springframework.util.StringUtils;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.ArgumentMatchers.same;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
import static org.mockito.Mockito.verify;
|
|
|
import static org.mockito.Mockito.verifyNoInteractions;
|
|
@@ -118,6 +121,20 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
.hasMessage("authenticationConverter cannot be null");
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
|
|
|
+ assertThatThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null))
|
|
|
+ .isInstanceOf(IllegalArgumentException.class)
|
|
|
+ .hasMessage("authenticationSuccessHandler cannot be null");
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
|
|
|
+ assertThatThrownBy(() -> this.filter.setAuthenticationFailureHandler(null))
|
|
|
+ .isInstanceOf(IllegalArgumentException.class)
|
|
|
+ .hasMessage("authenticationFailureHandler cannot be null");
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
public void doFilterWhenNotAuthorizationRequestThenNotProcessed() throws Exception {
|
|
|
String requestUri = "/path";
|
|
@@ -275,6 +292,57 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenCustomAuthenticationSuccessHandlerThenUsed() throws Exception {
|
|
|
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
|
|
+ OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult =
|
|
|
+ authorizationCodeRequestAuthentication(registeredClient, this.principal)
|
|
|
+ .authorizationCode(this.authorizationCode)
|
|
|
+ .build();
|
|
|
+ authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
|
|
|
+ when(this.authenticationManager.authenticate(any()))
|
|
|
+ .thenReturn(authorizationCodeRequestAuthenticationResult);
|
|
|
+
|
|
|
+ AuthenticationSuccessHandler authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class);
|
|
|
+ this.filter.setAuthenticationSuccessHandler(authenticationSuccessHandler);
|
|
|
+
|
|
|
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ verify(this.authenticationManager).authenticate(any());
|
|
|
+ verifyNoInteractions(filterChain);
|
|
|
+ verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), same(authorizationCodeRequestAuthenticationResult));
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exception {
|
|
|
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
|
|
+ OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
|
|
|
+ authorizationCodeRequestAuthentication(registeredClient, this.principal)
|
|
|
+ .build();
|
|
|
+ OAuth2Error error = new OAuth2Error("errorCode", "errorDescription", "errorUri");
|
|
|
+ OAuth2AuthorizationCodeRequestAuthenticationException authenticationException =
|
|
|
+ new OAuth2AuthorizationCodeRequestAuthenticationException(error, authorizationCodeRequestAuthentication);
|
|
|
+ when(this.authenticationManager.authenticate(any()))
|
|
|
+ .thenThrow(authenticationException);
|
|
|
+
|
|
|
+ AuthenticationFailureHandler authenticationFailureHandler = mock(AuthenticationFailureHandler.class);
|
|
|
+ this.filter.setAuthenticationFailureHandler(authenticationFailureHandler);
|
|
|
+
|
|
|
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ verify(this.authenticationManager).authenticate(any());
|
|
|
+ verifyNoInteractions(filterChain);
|
|
|
+ verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), same(authenticationException));
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception {
|
|
|
this.principal.setAuthenticated(false);
|