Sfoglia il codice sorgente

client_id authentication parameter must have printable ASCII characters

Closes gh-889
Joe Grandja 2 anni fa
parent
commit
8ed0194744

+ 22 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java

@@ -118,6 +118,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 						this.authenticationDetailsSource.buildDetails(request));
 			}
 			if (authenticationRequest != null) {
+				validateClientIdentifier(authenticationRequest);
 				Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest);
 				this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult);
 			}
@@ -201,4 +202,25 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 		this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
 	}
 
+	private static void validateClientIdentifier(Authentication authentication) {
+		if (!(authentication instanceof OAuth2ClientAuthenticationToken)) {
+			return;
+		}
+
+		// As per spec, in Appendix A.1.
+		// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-07#appendix-A.1
+		// The syntax for client_id is *VSCHAR (%x20-7E):
+		// -> Hex 20 -> ASCII 32 -> space
+		// -> Hex 7E -> ASCII 126 -> tilde
+
+		OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
+		String clientId = (String) clientAuthentication.getPrincipal();
+		for (int i = 0; i < clientId.length(); i++) {
+			char charAt = clientId.charAt(i);
+			if (!(charAt >= 32 && charAt <= 126)) {
+				throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
+			}
+		}
+	}
+
 }

+ 46 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilterTests.java

@@ -15,6 +15,8 @@
  */
 package org.springframework.security.oauth2.server.authorization.web;
 
+import java.nio.charset.StandardCharsets;
+
 import javax.servlet.FilterChain;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
@@ -33,6 +35,7 @@ import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.crypto.codec.Hex;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -130,6 +133,7 @@ public class OAuth2ClientAuthenticationFilterTests {
 		this.filter.doFilter(request, response, filterChain);
 
 		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+		verifyNoInteractions(this.authenticationConverter);
 	}
 
 	@Test
@@ -142,6 +146,7 @@ public class OAuth2ClientAuthenticationFilterTests {
 		this.filter.doFilter(request, response, filterChain);
 
 		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+		verifyNoInteractions(this.authenticationManager);
 	}
 
 	@Test
@@ -164,6 +169,46 @@ public class OAuth2ClientAuthenticationFilterTests {
 		assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
 	}
 
+	// gh-889
+	@Test
+	public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenInvalidRequestError() throws Exception {
+		// Hex 00 -> null
+		String clientId = new String(Hex.decode("00"), StandardCharsets.UTF_8);
+		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
+
+		// Hex 0a61 -> line feed + a
+		clientId = new String(Hex.decode("0a61"), StandardCharsets.UTF_8);
+		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
+
+		// Hex 1b -> escape
+		clientId = new String(Hex.decode("1b"), StandardCharsets.UTF_8);
+		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
+
+		// Hex 1b61 -> escape + a
+		clientId = new String(Hex.decode("1b61"), StandardCharsets.UTF_8);
+		assertWhenInvalidClientIdThenInvalidRequestError(clientId);
+	}
+
+	private void assertWhenInvalidClientIdThenInvalidRequestError(String clientId) throws Exception {
+		when(this.authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn(
+				new OAuth2ClientAuthenticationToken(clientId, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, "secret", null));
+
+		MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl);
+		request.setServletPath(this.filterProcessesUrl);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+		verifyNoInteractions(this.authenticationManager);
+
+		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		OAuth2Error error = readError(response);
+		assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
+	}
+
 	@Test
 	public void doFilterWhenRequestMatchesAndBadCredentialsThenInvalidClientError() throws Exception {
 		when(this.authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn(
@@ -179,6 +224,7 @@ public class OAuth2ClientAuthenticationFilterTests {
 		this.filter.doFilter(request, response, filterChain);
 
 		verifyNoInteractions(filterChain);
+		verify(this.authenticationManager).authenticate(any());
 
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());