|
@@ -15,6 +15,8 @@
|
|
|
*/
|
|
|
package org.springframework.security.oauth2.server.authorization.web;
|
|
|
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
|
+
|
|
|
import jakarta.servlet.FilterChain;
|
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
|
import jakarta.servlet.http.HttpServletResponse;
|
|
@@ -33,6 +35,7 @@ import org.springframework.mock.web.MockHttpServletResponse;
|
|
|
import org.springframework.security.authentication.AuthenticationManager;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
import org.springframework.security.core.context.SecurityContextHolder;
|
|
|
+import org.springframework.security.crypto.codec.Hex;
|
|
|
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
|
|
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
|
import org.springframework.security.oauth2.core.OAuth2Error;
|
|
@@ -130,6 +133,7 @@ public class OAuth2ClientAuthenticationFilterTests {
|
|
|
this.filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
+ verifyNoInteractions(this.authenticationConverter);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -142,6 +146,7 @@ public class OAuth2ClientAuthenticationFilterTests {
|
|
|
this.filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
+ verifyNoInteractions(this.authenticationManager);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -164,6 +169,46 @@ public class OAuth2ClientAuthenticationFilterTests {
|
|
|
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
|
|
|
}
|
|
|
|
|
|
+ // gh-889
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenInvalidRequestError() throws Exception {
|
|
|
+ // Hex 00 -> null
|
|
|
+ String clientId = new String(Hex.decode("00"), StandardCharsets.UTF_8);
|
|
|
+ assertWhenInvalidClientIdThenInvalidRequestError(clientId);
|
|
|
+
|
|
|
+ // Hex 0a61 -> line feed + a
|
|
|
+ clientId = new String(Hex.decode("0a61"), StandardCharsets.UTF_8);
|
|
|
+ assertWhenInvalidClientIdThenInvalidRequestError(clientId);
|
|
|
+
|
|
|
+ // Hex 1b -> escape
|
|
|
+ clientId = new String(Hex.decode("1b"), StandardCharsets.UTF_8);
|
|
|
+ assertWhenInvalidClientIdThenInvalidRequestError(clientId);
|
|
|
+
|
|
|
+ // Hex 1b61 -> escape + a
|
|
|
+ clientId = new String(Hex.decode("1b61"), StandardCharsets.UTF_8);
|
|
|
+ assertWhenInvalidClientIdThenInvalidRequestError(clientId);
|
|
|
+ }
|
|
|
+
|
|
|
+ private void assertWhenInvalidClientIdThenInvalidRequestError(String clientId) throws Exception {
|
|
|
+ when(this.authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn(
|
|
|
+ new OAuth2ClientAuthenticationToken(clientId, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, "secret", null));
|
|
|
+
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl);
|
|
|
+ request.setServletPath(this.filterProcessesUrl);
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ verifyNoInteractions(filterChain);
|
|
|
+ verifyNoInteractions(this.authenticationManager);
|
|
|
+
|
|
|
+ assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
|
|
|
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
|
|
+ OAuth2Error error = readError(response);
|
|
|
+ assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
public void doFilterWhenRequestMatchesAndBadCredentialsThenInvalidClientError() throws Exception {
|
|
|
when(this.authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn(
|
|
@@ -179,6 +224,7 @@ public class OAuth2ClientAuthenticationFilterTests {
|
|
|
this.filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
verifyNoInteractions(filterChain);
|
|
|
+ verify(this.authenticationManager).authenticate(any());
|
|
|
|
|
|
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
|
|
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
|