Ver código fonte

Validate authorization request before authenticated check

Issue gh-66
Joe Grandja 5 anos atrás
pai
commit
485b7e9319

+ 18 - 14
core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -114,17 +114,14 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
 
-		if (!this.authorizationEndpointMatcher.matches(request) || !isPrincipalAuthenticated()) {
+		if (!this.authorizationEndpointMatcher.matches(request)) {
 			filterChain.doFilter(request, response);
 			return;
 		}
 
-//		TODO
-//		The authorization server validates the request to ensure that all
-//		required parameters are present and valid.  If the request is valid,
-//		the authorization server authenticates the resource owner and obtains
-//		an authorization decision (by asking the resource owner or by
-//		establishing approval via other means).
+		// ---------------
+		// Validate the request to ensure that all required parameters are present and valid
+		// ---------------
 
 		MultiValueMap<String, String> parameters = getParameters(request);
 		String stateParameter = parameters.getFirst(OAuth2ParameterNames.STATE);
@@ -179,7 +176,18 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 			return;
 		}
 
+		// ---------------
+		// The request is valid - ensure the resource owner is authenticated
+		// ---------------
+
 		Authentication principal = SecurityContextHolder.getContext().getAuthentication();
+		if (!isPrincipalAuthenticated(principal)) {
+			// Pass through the chain with the expectation that the authentication process
+			// will commence via AuthenticationEntryPoint
+			filterChain.doFilter(request, response);
+			return;
+		}
+
 		String code = this.codeGenerator.generateKey();
 		OAuth2AuthorizationRequest authorizationRequest = convertAuthorizationRequest(request);
 
@@ -238,8 +246,9 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 		this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString());
 	}
 
-	private static boolean isPrincipalAuthenticated() {
-		return isPrincipalAuthenticated(SecurityContextHolder.getContext().getAuthentication());
+	private static OAuth2Error createError(String errorCode, String parameterName) {
+		return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
+				"https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
 	}
 
 	private static boolean isPrincipalAuthenticated(Authentication principal) {
@@ -248,11 +257,6 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 				principal.isAuthenticated();
 	}
 
-	private static OAuth2Error createError(String errorCode, String parameterName) {
-		return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
-				"https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
-	}
-
 	private static OAuth2AuthorizationRequest convertAuthorizationRequest(HttpServletRequest request) {
 		MultiValueMap<String, String> parameters = getParameters(request);
 

+ 17 - 15
core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -128,21 +128,6 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
-	@Test
-	public void doFilterWhenAuthorizationRequestNotAuthenticatedThenNotProcessed() throws Exception {
-		String requestUri = OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI;
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.authentication.setAuthenticated(false);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-	}
-
 	@Test
 	public void doFilterWhenAuthorizationRequestMissingClientIdThenInvalidRequestError() throws Exception {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@@ -341,6 +326,23 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				"state=state");
 	}
 
+	@Test
+	public void doFilterWhenAuthorizationRequestValidNotAuthenticatedThenContinueChainToCommenceAuthentication() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.authentication.setAuthenticated(false);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
 	@Test
 	public void doFilterWhenAuthorizationRequestValidThenAuthorizationResponse() throws Exception {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();