|
@@ -15,32 +15,36 @@
|
|
|
*/
|
|
|
package org.springframework.security.oauth2.client.web;
|
|
|
|
|
|
-import org.assertj.core.api.Assertions;
|
|
|
+import org.junit.Before;
|
|
|
import org.junit.Test;
|
|
|
+import org.junit.runner.RunWith;
|
|
|
import org.mockito.ArgumentCaptor;
|
|
|
-import org.mockito.Matchers;
|
|
|
-import org.mockito.Mockito;
|
|
|
+import org.powermock.core.classloader.annotations.PowerMockIgnore;
|
|
|
+import org.powermock.core.classloader.annotations.PrepareForTest;
|
|
|
+import org.powermock.modules.junit4.PowerMockRunner;
|
|
|
import org.springframework.mock.web.MockHttpServletRequest;
|
|
|
import org.springframework.mock.web.MockHttpServletResponse;
|
|
|
import org.springframework.security.authentication.AuthenticationManager;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
import org.springframework.security.core.AuthenticationException;
|
|
|
import org.springframework.security.core.authority.AuthorityUtils;
|
|
|
-import org.springframework.security.core.context.SecurityContextHolder;
|
|
|
+import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
|
|
|
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
|
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
|
|
|
-import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
|
|
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
|
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
|
|
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
|
|
|
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
|
|
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
|
|
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|
|
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
|
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
|
|
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|
|
import org.springframework.security.oauth2.core.user.OAuth2User;
|
|
|
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
|
|
-import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
|
|
|
|
|
import javax.servlet.FilterChain;
|
|
|
import javax.servlet.http.HttpServletRequest;
|
|
@@ -48,183 +52,235 @@ import javax.servlet.http.HttpServletResponse;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.Map;
|
|
|
|
|
|
-import static org.mockito.Mockito.mock;
|
|
|
-import static org.mockito.Mockito.when;
|
|
|
+import static org.assertj.core.api.Assertions.assertThat;
|
|
|
+import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.Mockito.*;
|
|
|
|
|
|
/**
|
|
|
- * Tests {@link OAuth2LoginAuthenticationFilter}.
|
|
|
+ * Tests for {@link OAuth2LoginAuthenticationFilter}.
|
|
|
*
|
|
|
* @author Joe Grandja
|
|
|
*/
|
|
|
+@PowerMockIgnore("javax.security.*")
|
|
|
+@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class})
|
|
|
+@RunWith(PowerMockRunner.class)
|
|
|
public class OAuth2LoginAuthenticationFilterTests {
|
|
|
+ private ClientRegistration registration1;
|
|
|
+ private ClientRegistration registration2;
|
|
|
+ private String principalName1 = "principal-1";
|
|
|
+ private ClientRegistrationRepository clientRegistrationRepository;
|
|
|
+ private OAuth2AuthorizedClientService authorizedClientService;
|
|
|
+ private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
|
|
|
+ private AuthenticationFailureHandler failureHandler;
|
|
|
+ private AuthenticationManager authenticationManager;
|
|
|
+ private OAuth2LoginAuthenticationFilter filter;
|
|
|
+
|
|
|
+ @Before
|
|
|
+ public void setUp() {
|
|
|
+ this.registration1 = ClientRegistration.withRegistrationId("registration-1")
|
|
|
+ .clientId("client-1")
|
|
|
+ .clientSecret("secret")
|
|
|
+ .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
|
|
+ .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
|
|
+ .redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
|
|
|
+ .scope("user")
|
|
|
+ .authorizationUri("https://provider.com/oauth2/authorize")
|
|
|
+ .tokenUri("https://provider.com/oauth2/token")
|
|
|
+ .userInfoUri("https://provider.com/oauth2/user")
|
|
|
+ .userNameAttributeName("id")
|
|
|
+ .clientName("client-1")
|
|
|
+ .build();
|
|
|
+ this.registration2 = ClientRegistration.withRegistrationId("registration-2")
|
|
|
+ .clientId("client-2")
|
|
|
+ .clientSecret("secret")
|
|
|
+ .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
|
|
+ .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
|
|
+ .redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
|
|
|
+ .scope("openid", "profile", "email")
|
|
|
+ .authorizationUri("https://provider.com/oauth2/authorize")
|
|
|
+ .tokenUri("https://provider.com/oauth2/token")
|
|
|
+ .userInfoUri("https://provider.com/oauth2/userinfo")
|
|
|
+ .jwkSetUri("https://provider.com/oauth2/keys")
|
|
|
+ .clientName("client-2")
|
|
|
+ .build();
|
|
|
+ this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
|
|
|
+ this.registration1, this.registration2);
|
|
|
+ this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
|
|
|
+ this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
|
|
|
+ this.failureHandler = mock(AuthenticationFailureHandler.class);
|
|
|
+ this.authenticationManager = mock(AuthenticationManager.class);
|
|
|
+ this.filter = spy(new OAuth2LoginAuthenticationFilter(
|
|
|
+ this.clientRegistrationRepository, this.authorizedClientService));
|
|
|
+ this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
|
|
|
+ this.filter.setAuthenticationFailureHandler(this.failureHandler);
|
|
|
+ this.filter.setAuthenticationManager(this.authenticationManager);
|
|
|
+ }
|
|
|
|
|
|
- @Test
|
|
|
- public void doFilterWhenNotAuthorizationCodeResponseThenContinueChain() throws Exception {
|
|
|
- ClientRegistration clientRegistration = TestUtil.googleClientRegistration();
|
|
|
+ @Test(expected = IllegalArgumentException.class)
|
|
|
+ public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
|
|
|
+ new OAuth2LoginAuthenticationFilter(null, this.authorizedClientService);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test(expected = IllegalArgumentException.class)
|
|
|
+ public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
|
|
|
+ new OAuth2LoginAuthenticationFilter(this.clientRegistrationRepository, null);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test(expected = IllegalArgumentException.class)
|
|
|
+ public void constructorWhenFilterProcessesUrlIsNullThenThrowIllegalArgumentException() {
|
|
|
+ new OAuth2LoginAuthenticationFilter(null, this.clientRegistrationRepository, this.authorizedClientService);
|
|
|
+ }
|
|
|
|
|
|
- OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
|
|
|
+ @Test(expected = IllegalArgumentException.class)
|
|
|
+ public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() {
|
|
|
+ this.filter.setAuthorizationRequestRepository(null);
|
|
|
+ }
|
|
|
|
|
|
- String requestURI = "/path";
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("GET", requestURI);
|
|
|
- request.setServletPath(requestURI);
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
|
|
|
+ String requestUri = "/path";
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
|
- filter.doFilter(request, response, filterChain);
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- Mockito.verify(filterChain).doFilter(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
|
|
|
- Mockito.verify(filter, Mockito.never()).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
|
|
|
+ verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
+ verify(this.filter, never()).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenAuthorizationCodeErrorResponseThenAuthenticationFailureHandlerIsCalled() throws Exception {
|
|
|
- ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
|
|
|
+ public void doFilterWhenAuthorizationResponseInvalidThenInvalidRequestError() throws Exception {
|
|
|
+ String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ // NOTE:
|
|
|
+ // A valid Authorization Response contains either a 'code' or 'error' parameter.
|
|
|
+ // Don't set it to force an invalid Authorization Response.
|
|
|
|
|
|
- OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
|
|
|
- AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
|
|
|
- filter.setAuthenticationFailureHandler(failureHandler);
|
|
|
-
|
|
|
- MockHttpServletRequest request = this.setupRequest(clientRegistration);
|
|
|
- String errorCode = OAuth2ErrorCodes.INVALID_GRANT;
|
|
|
- request.addParameter(OAuth2ParameterNames.ERROR, errorCode);
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, "some state");
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
|
- filter.doFilter(request, response, filterChain);
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ ArgumentCaptor<AuthenticationException> authenticationExceptionArgCaptor = ArgumentCaptor.forClass(AuthenticationException.class);
|
|
|
+ verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), any(HttpServletResponse.class),
|
|
|
+ authenticationExceptionArgCaptor.capture());
|
|
|
|
|
|
- Mockito.verify(filter).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
|
|
|
- Mockito.verify(failureHandler).onAuthenticationFailure(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
|
|
|
- Matchers.any(AuthenticationException.class));
|
|
|
+ assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class);
|
|
|
+ OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor.getValue();
|
|
|
+ assertThat(authenticationException.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenAuthorizationCodeSuccessResponseThenAuthenticationSuccessHandlerIsCalled() throws Exception {
|
|
|
- ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
|
|
|
- OAuth2User oauth2User = mock(OAuth2User.class);
|
|
|
- when(oauth2User.getName()).thenReturn("principal name");
|
|
|
- OAuth2LoginAuthenticationToken loginAuthentication = mock(OAuth2LoginAuthenticationToken.class);
|
|
|
- when(loginAuthentication.getPrincipal()).thenReturn(oauth2User);
|
|
|
- when(loginAuthentication.getClientRegistration()).thenReturn(clientRegistration);
|
|
|
- when(loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
|
|
|
+ public void doFilterWhenAuthorizationResponseAuthorizationRequestNotFoundThenAuthorizationRequestNotFoundError() throws Exception {
|
|
|
+ String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
|
|
|
- OAuth2AuthenticationToken userAuthentication = new OAuth2AuthenticationToken(
|
|
|
- oauth2User, AuthorityUtils.NO_AUTHORITIES, clientRegistration.getRegistrationId());
|
|
|
- SecurityContextHolder.getContext().setAuthentication(userAuthentication);
|
|
|
- AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
|
|
|
- when(authenticationManager.authenticate(Matchers.any(Authentication.class))).thenReturn(loginAuthentication);
|
|
|
-
|
|
|
- OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(authenticationManager, clientRegistration));
|
|
|
- AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class);
|
|
|
- filter.setAuthenticationSuccessHandler(successHandler);
|
|
|
- AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
|
|
|
- new HttpSessionOAuth2AuthorizationRequestRepository();
|
|
|
- filter.setAuthorizationRequestRepository(authorizationRequestRepository);
|
|
|
-
|
|
|
- MockHttpServletRequest request = this.setupRequest(clientRegistration);
|
|
|
- String authCode = "some code";
|
|
|
- String state = "some state";
|
|
|
- request.addParameter(OAuth2ParameterNames.CODE, authCode);
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, state);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
- setupAuthorizationRequest(authorizationRequestRepository, request, response, clientRegistration, state);
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
|
- filter.doFilter(request, response, filterChain);
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- Mockito.verify(filter).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
|
|
|
+ ArgumentCaptor<AuthenticationException> authenticationExceptionArgCaptor = ArgumentCaptor.forClass(AuthenticationException.class);
|
|
|
+ verify(this.failureHandler).onAuthenticationFailure(any(HttpServletRequest.class), any(HttpServletResponse.class),
|
|
|
+ authenticationExceptionArgCaptor.capture());
|
|
|
|
|
|
- ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
|
|
|
- Mockito.verify(successHandler).onAuthenticationSuccess(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
|
|
|
- authenticationArgCaptor.capture());
|
|
|
- Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(userAuthentication);
|
|
|
+ assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class);
|
|
|
+ OAuth2AuthenticationException authenticationException = (OAuth2AuthenticationException) authenticationExceptionArgCaptor.getValue();
|
|
|
+ assertThat(authenticationException.getError().getErrorCode()).isEqualTo("authorization_request_not_found");
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenAuthorizationCodeSuccessResponseAndNoMatchingAuthorizationRequestThenThrowOAuth2AuthenticationExceptionAuthorizationRequestNotFound() throws Exception {
|
|
|
- ClientRegistration clientRegistration = TestUtil.githubClientRegistration();
|
|
|
-
|
|
|
- OAuth2LoginAuthenticationFilter filter = Mockito.spy(setupFilter(clientRegistration));
|
|
|
- AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class);
|
|
|
- filter.setAuthenticationFailureHandler(failureHandler);
|
|
|
-
|
|
|
- MockHttpServletRequest request = this.setupRequest(clientRegistration);
|
|
|
- String authCode = "some code";
|
|
|
- String state = "some state";
|
|
|
- request.addParameter(OAuth2ParameterNames.CODE, authCode);
|
|
|
- request.addParameter(OAuth2ParameterNames.STATE, state);
|
|
|
+ public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
|
|
|
+ String requestUri = "/login/oauth2/code/" + this.registration2.getRegistrationId();
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
|
- filter.doFilter(request, response, filterChain);
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration2);
|
|
|
+ this.setUpAuthenticationResult(this.registration2);
|
|
|
|
|
|
- verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(filter, failureHandler, "authorization_request_not_found");
|
|
|
- }
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- private void verifyThrowsOAuth2AuthenticationExceptionWithErrorCode(OAuth2LoginAuthenticationFilter filter,
|
|
|
- AuthenticationFailureHandler failureHandler,
|
|
|
- String errorCode) throws Exception {
|
|
|
-
|
|
|
- Mockito.verify(filter).attemptAuthentication(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class));
|
|
|
-
|
|
|
- ArgumentCaptor<AuthenticationException> authenticationExceptionArgCaptor =
|
|
|
- ArgumentCaptor.forClass(AuthenticationException.class);
|
|
|
- Mockito.verify(failureHandler).onAuthenticationFailure(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
|
|
|
- authenticationExceptionArgCaptor.capture());
|
|
|
- Assertions.assertThat(authenticationExceptionArgCaptor.getValue()).isInstanceOf(OAuth2AuthenticationException.class);
|
|
|
- OAuth2AuthenticationException oauth2AuthenticationException =
|
|
|
- (OAuth2AuthenticationException)authenticationExceptionArgCaptor.getValue();
|
|
|
- Assertions.assertThat(oauth2AuthenticationException.getError()).isNotNull();
|
|
|
- Assertions.assertThat(oauth2AuthenticationException.getError().getErrorCode()).isEqualTo(errorCode);
|
|
|
+ assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull();
|
|
|
}
|
|
|
|
|
|
- private OAuth2LoginAuthenticationFilter setupFilter(ClientRegistration... clientRegistrations) throws Exception {
|
|
|
- AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenAuthorizationResponseValidThenAuthorizedClientSaved() throws Exception {
|
|
|
+ String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
+ this.setUpAuthenticationResult(this.registration1);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- return setupFilter(authenticationManager, clientRegistrations);
|
|
|
+ OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
|
|
|
+ this.registration1.getRegistrationId(), this.principalName1);
|
|
|
+ assertThat(authorizedClient).isNotNull();
|
|
|
+ assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
|
|
|
+ assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName1);
|
|
|
+ assertThat(authorizedClient.getAccessToken()).isNotNull();
|
|
|
}
|
|
|
|
|
|
- private OAuth2LoginAuthenticationFilter setupFilter(
|
|
|
- AuthenticationManager authenticationManager, ClientRegistration... clientRegistrations) throws Exception {
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenCustomFilterProcessesUrlThenFilterProcesses() throws Exception {
|
|
|
+ String filterProcessesUrl = "/login/oauth2/custom/*";
|
|
|
+ this.filter = spy(new OAuth2LoginAuthenticationFilter(filterProcessesUrl,
|
|
|
+ this.clientRegistrationRepository, this.authorizedClientService));
|
|
|
+ this.filter.setAuthenticationManager(this.authenticationManager);
|
|
|
+
|
|
|
+ String requestUri = "/login/oauth2/custom/" + this.registration2.getRegistrationId();
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
|
- ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository(clientRegistrations);
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration2);
|
|
|
+ this.setUpAuthenticationResult(this.registration2);
|
|
|
|
|
|
- OAuth2LoginAuthenticationFilter filter = new OAuth2LoginAuthenticationFilter(
|
|
|
- clientRegistrationRepository, mock(OAuth2AuthorizedClientService.class));
|
|
|
- filter.setAuthenticationManager(authenticationManager);
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- return filter;
|
|
|
+ verifyZeroInteractions(filterChain);
|
|
|
+ verify(this.filter).attemptAuthentication(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
}
|
|
|
|
|
|
- private void setupAuthorizationRequest(AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository,
|
|
|
- HttpServletRequest request,
|
|
|
- HttpServletResponse response,
|
|
|
- ClientRegistration clientRegistration,
|
|
|
- String state) {
|
|
|
-
|
|
|
- Map<String,Object> additionalParameters = new HashMap<>();
|
|
|
- additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
|
|
|
-
|
|
|
- OAuth2AuthorizationRequest authorizationRequest =
|
|
|
- OAuth2AuthorizationRequest.authorizationCode()
|
|
|
- .clientId(clientRegistration.getClientId())
|
|
|
- .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
|
|
|
- .redirectUri(clientRegistration.getRedirectUri())
|
|
|
- .scopes(clientRegistration.getScopes())
|
|
|
- .state(state)
|
|
|
- .additionalParameters(additionalParameters)
|
|
|
- .build();
|
|
|
-
|
|
|
- authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
|
|
|
+ private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
|
|
+ ClientRegistration registration) {
|
|
|
+ OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
|
|
+ Map<String, Object> additionalParameters = new HashMap<>();
|
|
|
+ additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
|
|
+ when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);
|
|
|
+ this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
|
|
|
}
|
|
|
|
|
|
- private MockHttpServletRequest setupRequest(ClientRegistration clientRegistration) {
|
|
|
- String requestURI = TestUtil.AUTHORIZE_BASE_URI + "/" + clientRegistration.getRegistrationId();
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("GET", requestURI);
|
|
|
- request.setScheme(TestUtil.DEFAULT_SCHEME);
|
|
|
- request.setServerName(TestUtil.DEFAULT_SERVER_NAME);
|
|
|
- request.setServerPort(TestUtil.DEFAULT_SERVER_PORT);
|
|
|
- request.setServletPath(requestURI);
|
|
|
- return request;
|
|
|
+ private void setUpAuthenticationResult(ClientRegistration registration) {
|
|
|
+ OAuth2User user = mock(OAuth2User.class);
|
|
|
+ when(user.getName()).thenReturn(this.principalName1);
|
|
|
+ OAuth2LoginAuthenticationToken loginAuthentication = mock(OAuth2LoginAuthenticationToken.class);
|
|
|
+ when(loginAuthentication.getPrincipal()).thenReturn(user);
|
|
|
+ when(loginAuthentication.getAuthorities()).thenReturn(AuthorityUtils.createAuthorityList("ROLE_USER"));
|
|
|
+ when(loginAuthentication.getClientRegistration()).thenReturn(registration);
|
|
|
+ when(loginAuthentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class));
|
|
|
+ when(loginAuthentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
|
|
|
+ when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(loginAuthentication);
|
|
|
}
|
|
|
}
|