|
@@ -0,0 +1,252 @@
|
|
|
+/*
|
|
|
+ * Copyright 2002-2018 the original author or authors.
|
|
|
+ *
|
|
|
+ * Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
+ * you may not use this file except in compliance with the License.
|
|
|
+ * You may obtain a copy of the License at
|
|
|
+ *
|
|
|
+ * http://www.apache.org/licenses/LICENSE-2.0
|
|
|
+ *
|
|
|
+ * Unless required by applicable law or agreed to in writing, software
|
|
|
+ * distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
+ * See the License for the specific language governing permissions and
|
|
|
+ * limitations under the License.
|
|
|
+ */
|
|
|
+package org.springframework.security.oauth2.client.web;
|
|
|
+
|
|
|
+import org.junit.Before;
|
|
|
+import org.junit.Test;
|
|
|
+import org.junit.runner.RunWith;
|
|
|
+import org.powermock.core.classloader.annotations.PowerMockIgnore;
|
|
|
+import org.powermock.core.classloader.annotations.PrepareForTest;
|
|
|
+import org.powermock.modules.junit4.PowerMockRunner;
|
|
|
+import org.springframework.mock.web.MockHttpServletRequest;
|
|
|
+import org.springframework.mock.web.MockHttpServletResponse;
|
|
|
+import org.springframework.security.authentication.AuthenticationManager;
|
|
|
+import org.springframework.security.authentication.TestingAuthenticationToken;
|
|
|
+import org.springframework.security.core.Authentication;
|
|
|
+import org.springframework.security.core.context.SecurityContext;
|
|
|
+import org.springframework.security.core.context.SecurityContextHolder;
|
|
|
+import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
|
|
|
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
|
|
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
|
|
|
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
|
|
|
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
|
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
|
|
+import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
|
|
|
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
|
|
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
|
|
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|
|
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
|
+import org.springframework.security.oauth2.core.OAuth2Error;
|
|
|
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
|
|
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
|
|
|
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
|
|
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|
|
+
|
|
|
+import javax.servlet.FilterChain;
|
|
|
+import javax.servlet.http.HttpServletRequest;
|
|
|
+import javax.servlet.http.HttpServletResponse;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.Map;
|
|
|
+
|
|
|
+import static org.assertj.core.api.Assertions.assertThat;
|
|
|
+import static org.mockito.Mockito.*;
|
|
|
+
|
|
|
+/**
|
|
|
+ * Tests for {@link OAuth2AuthorizationCodeGrantFilter}.
|
|
|
+ *
|
|
|
+ * @author Joe Grandja
|
|
|
+ */
|
|
|
+@PowerMockIgnore("javax.security.*")
|
|
|
+@PrepareForTest({OAuth2AuthorizationRequest.class, OAuth2AuthorizationExchange.class, OAuth2AuthorizationCodeGrantFilter.class})
|
|
|
+@RunWith(PowerMockRunner.class)
|
|
|
+public class OAuth2AuthorizationCodeGrantFilterTests {
|
|
|
+ private ClientRegistration registration1;
|
|
|
+ private String principalName1 = "principal-1";
|
|
|
+ private ClientRegistrationRepository clientRegistrationRepository;
|
|
|
+ private OAuth2AuthorizedClientService authorizedClientService;
|
|
|
+ private AuthenticationManager authenticationManager;
|
|
|
+ private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
|
|
|
+ private OAuth2AuthorizationCodeGrantFilter filter;
|
|
|
+
|
|
|
+ @Before
|
|
|
+ public void setUp() {
|
|
|
+ this.registration1 = ClientRegistration.withRegistrationId("registration-1")
|
|
|
+ .clientId("client-1")
|
|
|
+ .clientSecret("secret")
|
|
|
+ .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
|
|
|
+ .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
|
|
|
+ .redirectUriTemplate("{baseUrl}/callback/client-1")
|
|
|
+ .scope("user")
|
|
|
+ .authorizationUri("https://provider.com/oauth2/authorize")
|
|
|
+ .tokenUri("https://provider.com/oauth2/token")
|
|
|
+ .userInfoUri("https://provider.com/oauth2/user")
|
|
|
+ .userNameAttributeName("id")
|
|
|
+ .clientName("client-1")
|
|
|
+ .build();
|
|
|
+ this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1);
|
|
|
+ this.authorizedClientService = new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository);
|
|
|
+ this.authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
|
|
|
+ this.authenticationManager = mock(AuthenticationManager.class);
|
|
|
+ this.filter = spy(new OAuth2AuthorizationCodeGrantFilter(
|
|
|
+ this.clientRegistrationRepository, this.authorizedClientService, this.authenticationManager));
|
|
|
+ this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
|
|
|
+
|
|
|
+ SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
|
|
+ securityContext.setAuthentication(new TestingAuthenticationToken(this.principalName1, "password"));
|
|
|
+ SecurityContextHolder.setContext(securityContext);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test(expected = IllegalArgumentException.class)
|
|
|
+ public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
|
|
|
+ new OAuth2AuthorizationCodeGrantFilter(null, this.authorizedClientService, this.authenticationManager);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test(expected = IllegalArgumentException.class)
|
|
|
+ public void constructorWhenAuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
|
|
|
+ new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, null, this.authenticationManager);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test(expected = IllegalArgumentException.class)
|
|
|
+ public void constructorWhenAuthenticationManagerIsNullThenThrowIllegalArgumentException() {
|
|
|
+ new OAuth2AuthorizationCodeGrantFilter(this.clientRegistrationRepository, this.authorizedClientService, null);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test(expected = IllegalArgumentException.class)
|
|
|
+ public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() {
|
|
|
+ this.filter.setAuthorizationRequestRepository(null);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenNotAuthorizationResponseThenNotProcessed() throws Exception {
|
|
|
+ String requestUri = "/path";
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ // NOTE: A valid Authorization Response contains either a 'code' or 'error' parameter.
|
|
|
+
|
|
|
+ HttpServletResponse response = mock(HttpServletResponse.class);
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenAuthorizationRequestNotFoundThenNotProcessed() throws Exception {
|
|
|
+ String requestUri = "/path";
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+
|
|
|
+ HttpServletResponse response = mock(HttpServletResponse.class);
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenAuthorizationResponseValidThenAuthorizationRequestRemoved() throws Exception {
|
|
|
+ String requestUri = "/callback/client-1";
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
+ this.setUpAuthenticationResult(this.registration1);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ assertThat(this.authorizationRequestRepository.loadAuthorizationRequest(request)).isNull();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenAuthenticationFailsThenHandleOAuth2AuthenticationException() throws Exception {
|
|
|
+ String requestUri = "/callback/client-1";
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
+ OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT);
|
|
|
+ when(this.authenticationManager.authenticate(any(Authentication.class)))
|
|
|
+ .thenThrow(new OAuth2AuthenticationException(error, error.toString()));
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1?error=invalid_grant");
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenAuthorizationResponseSuccessThenAuthorizedClientSaved() throws Exception {
|
|
|
+ String requestUri = "/callback/client-1";
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
+ this.setUpAuthenticationResult(this.registration1);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
|
|
|
+ this.registration1.getRegistrationId(), this.principalName1);
|
|
|
+ assertThat(authorizedClient).isNotNull();
|
|
|
+ assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration1);
|
|
|
+ assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName1);
|
|
|
+ assertThat(authorizedClient.getAccessToken()).isNotNull();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenAuthorizationResponseSuccessThenRedirected() throws Exception {
|
|
|
+ String requestUri = "/callback/client-1";
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, "state");
|
|
|
+
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration1);
|
|
|
+ this.setUpAuthenticationResult(this.registration1);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/callback/client-1");
|
|
|
+ }
|
|
|
+
|
|
|
+ private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
|
|
+ ClientRegistration registration) {
|
|
|
+ Map<String, Object> additionalParameters = new HashMap<>();
|
|
|
+ additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
|
|
+ OAuth2AuthorizationRequest authorizationRequest = mock(OAuth2AuthorizationRequest.class);
|
|
|
+ when(authorizationRequest.getAdditionalParameters()).thenReturn(additionalParameters);
|
|
|
+ this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
|
|
|
+ }
|
|
|
+
|
|
|
+ private void setUpAuthenticationResult(ClientRegistration registration) {
|
|
|
+ OAuth2AuthorizationCodeAuthenticationToken authentication = mock(OAuth2AuthorizationCodeAuthenticationToken.class);
|
|
|
+ when(authentication.getClientRegistration()).thenReturn(registration);
|
|
|
+ when(authentication.getAuthorizationExchange()).thenReturn(mock(OAuth2AuthorizationExchange.class));
|
|
|
+ when(authentication.getAccessToken()).thenReturn(mock(OAuth2AccessToken.class));
|
|
|
+ when(this.authenticationManager.authenticate(any(Authentication.class))).thenReturn(authentication);
|
|
|
+ }
|
|
|
+}
|