Browse Source

Polish OAuth2AuthorizationEndpointFilterTests

Issue gh-77
Joe Grandja 5 years ago
parent
commit
02b64f0ef0

+ 57 - 82
core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -26,6 +26,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -41,6 +42,7 @@ import javax.servlet.FilterChain;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import java.util.Set;
+import java.util.function.Consumer;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -130,53 +132,29 @@ public class OAuth2AuthorizationEndpointFilterTests {
 
 	@Test
 	public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception {
-		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
-
-		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
-		request.removeParameter(OAuth2ParameterNames.CLIENT_ID);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verifyNoInteractions(filterChain);
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
+		doFilterWhenAuthorizationRequestInvalidParameterThenError(
+				TestRegisteredClients.registeredClient().build(),
+				OAuth2ParameterNames.CLIENT_ID,
+				OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID));
 	}
 
 	@Test
 	public void doFilterWhenAuthorizationRequestMultipleClientIdThenInvalidRequestError() throws Exception {
-		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
-
-		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
-		request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verifyNoInteractions(filterChain);
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
+		doFilterWhenAuthorizationRequestInvalidParameterThenError(
+				TestRegisteredClients.registeredClient().build(),
+				OAuth2ParameterNames.CLIENT_ID,
+				OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"));
 	}
 
 	@Test
 	public void doFilterWhenAuthorizationRequestInvalidClientIdThenInvalidRequestError() throws Exception {
-		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
-
-		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
-		request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid");
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verifyNoInteractions(filterChain);
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: client_id");
+		doFilterWhenAuthorizationRequestInvalidParameterThenError(
+				TestRegisteredClients.registeredClient().build(),
+				OAuth2ParameterNames.CLIENT_ID,
+				OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.setParameter(OAuth2ParameterNames.CLIENT_ID, "invalid"));
 	}
 
 	@Test
@@ -188,16 +166,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
 				.thenReturn(registeredClient);
 
-		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verifyNoInteractions(filterChain);
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		assertThat(response.getErrorMessage()).isEqualTo("[unauthorized_client] OAuth 2.0 Parameter: client_id");
+		doFilterWhenAuthorizationRequestInvalidParameterThenError(
+				registeredClient,
+				OAuth2ParameterNames.CLIENT_ID,
+				OAuth2ErrorCodes.UNAUTHORIZED_CLIENT);
 	}
 
 	@Test
@@ -206,17 +178,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
 				.thenReturn(registeredClient);
 
-		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
-		request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com");
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verifyNoInteractions(filterChain);
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
+		doFilterWhenAuthorizationRequestInvalidParameterThenError(
+				registeredClient,
+				OAuth2ParameterNames.REDIRECT_URI,
+				OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.setParameter(OAuth2ParameterNames.REDIRECT_URI, "https://invalid-example.com"));
 	}
 
 	@Test
@@ -225,17 +191,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
 				.thenReturn(registeredClient);
 
-		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
-		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verifyNoInteractions(filterChain);
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
+		doFilterWhenAuthorizationRequestInvalidParameterThenError(
+				registeredClient,
+				OAuth2ParameterNames.REDIRECT_URI,
+				OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"));
 	}
 
 	@Test
@@ -244,17 +204,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
 				.thenReturn(registeredClient);
 
-		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
-		request.removeParameter(OAuth2ParameterNames.REDIRECT_URI);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verifyNoInteractions(filterChain);
-
-		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
-		assertThat(response.getErrorMessage()).isEqualTo("[invalid_request] OAuth 2.0 Parameter: redirect_uri");
+		doFilterWhenAuthorizationRequestInvalidParameterThenError(
+				registeredClient,
+				OAuth2ParameterNames.REDIRECT_URI,
+				OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.removeParameter(OAuth2ParameterNames.REDIRECT_URI));
 	}
 
 	@Test
@@ -383,6 +337,27 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		assertThat(authorizationRequest.getAdditionalParameters()).isEmpty();
 	}
 
+	private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient,
+			String parameterName, String errorCode) throws Exception {
+		doFilterWhenAuthorizationRequestInvalidParameterThenError(registeredClient, parameterName, errorCode, request -> {});
+	}
+
+	private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient,
+			String parameterName, String errorCode, Consumer<MockHttpServletRequest> requestConsumer) throws Exception {
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		requestConsumer.accept(request);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		assertThat(response.getErrorMessage()).isEqualTo("[" + errorCode + "] OAuth 2.0 Parameter: " + parameterName);
+	}
+
 	private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
 		String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]);