|
@@ -19,7 +19,6 @@ import org.junit.After;
|
|
|
import org.junit.Before;
|
|
|
import org.junit.Test;
|
|
|
import org.mockito.ArgumentCaptor;
|
|
|
-
|
|
|
import org.springframework.http.HttpStatus;
|
|
|
import org.springframework.http.converter.HttpMessageConverter;
|
|
|
import org.springframework.mock.http.client.MockClientHttpResponse;
|
|
@@ -44,17 +43,15 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken;
|
|
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
|
|
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
|
|
+import org.springframework.util.StringUtils;
|
|
|
|
|
|
import javax.servlet.FilterChain;
|
|
|
-import javax.servlet.ServletException;
|
|
|
import javax.servlet.http.HttpServletRequest;
|
|
|
import javax.servlet.http.HttpServletResponse;
|
|
|
-import java.io.IOException;
|
|
|
import java.time.Duration;
|
|
|
import java.time.Instant;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.HashSet;
|
|
|
-import java.util.function.Consumer;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|
@@ -140,58 +137,77 @@ public class OAuth2TokenEndpointFilterTests {
|
|
|
|
|
|
@Test
|
|
|
public void doFilterWhenTokenRequestMissingGrantTypeThenInvalidRequestError() throws Exception {
|
|
|
+ MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
|
|
|
+ TestRegisteredClients.registeredClient().build());
|
|
|
+ request.removeParameter(OAuth2ParameterNames.GRANT_TYPE);
|
|
|
+
|
|
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
|
|
- OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.removeParameter(OAuth2ParameterNames.GRANT_TYPE));
|
|
|
+ OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, request);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void doFilterWhenTokenRequestMultipleGrantTypeThenInvalidRequestError() throws Exception {
|
|
|
+ MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
|
|
|
+ TestRegisteredClients.registeredClient().build());
|
|
|
+ request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
|
|
|
+
|
|
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
|
|
- OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()));
|
|
|
+ OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.INVALID_REQUEST, request);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void doFilterWhenTokenRequestInvalidGrantTypeThenUnsupportedGrantTypeError() throws Exception {
|
|
|
+ MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
|
|
|
+ TestRegisteredClients.registeredClient().build());
|
|
|
+ request.setParameter(OAuth2ParameterNames.GRANT_TYPE, "invalid-grant-type");
|
|
|
+
|
|
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
|
|
- OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE,
|
|
|
- request -> request.setParameter(OAuth2ParameterNames.GRANT_TYPE, "invalid-grant-type"));
|
|
|
+ OAuth2ParameterNames.GRANT_TYPE, OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, request);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void doFilterWhenTokenRequestMultipleClientIdThenInvalidRequestError() throws Exception {
|
|
|
+ MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
|
|
|
+ TestRegisteredClients.registeredClient().build());
|
|
|
+ request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1");
|
|
|
+ request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2");
|
|
|
+
|
|
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
|
|
- OAuth2ParameterNames.CLIENT_ID, OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> {
|
|
|
- request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1");
|
|
|
- request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2");
|
|
|
- });
|
|
|
+ OAuth2ParameterNames.CLIENT_ID, OAuth2ErrorCodes.INVALID_REQUEST, request);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void doFilterWhenTokenRequestMissingCodeThenInvalidRequestError() throws Exception {
|
|
|
+ MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
|
|
|
+ TestRegisteredClients.registeredClient().build());
|
|
|
+ request.removeParameter(OAuth2ParameterNames.CODE);
|
|
|
+
|
|
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
|
|
- OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.removeParameter(OAuth2ParameterNames.CODE));
|
|
|
+ OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST, request);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void doFilterWhenTokenRequestMultipleCodeThenInvalidRequestError() throws Exception {
|
|
|
+ MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
|
|
|
+ TestRegisteredClients.registeredClient().build());
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code-2");
|
|
|
+
|
|
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
|
|
- OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.addParameter(OAuth2ParameterNames.CODE, "code-2"));
|
|
|
+ OAuth2ParameterNames.CODE, OAuth2ErrorCodes.INVALID_REQUEST, request);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void doFilterWhenTokenRequestMultipleRedirectUriThenInvalidRequestError() throws Exception {
|
|
|
+ MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
|
|
|
+ TestRegisteredClients.registeredClient().build());
|
|
|
+ request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
|
|
|
+
|
|
|
doFilterWhenTokenRequestInvalidParameterThenError(
|
|
|
- OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
- request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"));
|
|
|
+ OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST, request);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenTokenRequestValidThenAccessTokenResponse() throws Exception {
|
|
|
+ public void doFilterWhenAuthorizationCodeTokenRequestValidThenAccessTokenResponse() throws Exception {
|
|
|
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
|
|
Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
|
|
|
OAuth2AccessToken accessToken = new OAuth2AccessToken(
|
|
@@ -208,7 +224,7 @@ public class OAuth2TokenEndpointFilterTests {
|
|
|
securityContext.setAuthentication(clientPrincipal);
|
|
|
SecurityContextHolder.setContext(securityContext);
|
|
|
|
|
|
- MockHttpServletRequest request = createTokenRequest(registeredClient);
|
|
|
+ MockHttpServletRequest request = createAuthorizationCodeTokenRequest(registeredClient);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
@@ -242,38 +258,24 @@ public class OAuth2TokenEndpointFilterTests {
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenGrantTypeIsClientCredentialsThenAuthenticateWithClientCredentialsToken() throws ServletException, IOException {
|
|
|
- RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
|
|
- doFilterForClientCredentialsGrant(registeredClient, null);
|
|
|
+ public void doFilterWhenTokenRequestMultipleScopeThenInvalidRequestError() throws Exception {
|
|
|
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
|
|
|
+ Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
|
|
|
|
|
|
- ArgumentCaptor<Authentication> captor = ArgumentCaptor.forClass(Authentication.class);
|
|
|
- verify(this.authenticationManager).authenticate(captor.capture());
|
|
|
+ SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
|
|
+ securityContext.setAuthentication(clientPrincipal);
|
|
|
+ SecurityContextHolder.setContext(securityContext);
|
|
|
|
|
|
- assertThat(captor.getValue()).isInstanceOf(OAuth2ClientCredentialsAuthenticationToken.class);
|
|
|
- OAuth2ClientCredentialsAuthenticationToken clientAuthenticationToken = (OAuth2ClientCredentialsAuthenticationToken) captor.getValue();
|
|
|
+ MockHttpServletRequest request = createClientCredentialsTokenRequest(registeredClient);
|
|
|
+ request.addParameter(OAuth2ParameterNames.SCOPE, "profile");
|
|
|
|
|
|
- assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(new OAuth2ClientAuthenticationToken(registeredClient));
|
|
|
+ doFilterWhenTokenRequestInvalidParameterThenError(
|
|
|
+ OAuth2ParameterNames.SCOPE, OAuth2ErrorCodes.INVALID_REQUEST, request);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenGrantTypeIsClientCredentialsWithScopeThenIncludeScopeInResponse() throws ServletException, IOException {
|
|
|
- RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
|
|
- doFilterForClientCredentialsGrant(registeredClient, "openid email");
|
|
|
-
|
|
|
- ArgumentCaptor<Authentication> captor = ArgumentCaptor.forClass(Authentication.class);
|
|
|
- verify(this.authenticationManager).authenticate(captor.capture());
|
|
|
-
|
|
|
- assertThat(captor.getValue()).isInstanceOf(OAuth2ClientCredentialsAuthenticationToken.class);
|
|
|
- OAuth2ClientCredentialsAuthenticationToken clientAuthenticationToken = (OAuth2ClientCredentialsAuthenticationToken) captor.getValue();
|
|
|
-
|
|
|
- HashSet<String> expectedScopes = new HashSet<>();
|
|
|
- expectedScopes.add("openid");
|
|
|
- expectedScopes.add("email");
|
|
|
-
|
|
|
- assertThat(clientAuthenticationToken.getScopes()).isEqualTo(expectedScopes);
|
|
|
- }
|
|
|
-
|
|
|
- private void doFilterForClientCredentialsGrant(RegisteredClient registeredClient, String scope) throws ServletException, IOException {
|
|
|
+ public void doFilterWhenClientCredentialsTokenRequestValidThenAccessTokenResponse() throws Exception {
|
|
|
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
|
|
|
Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
|
|
|
OAuth2AccessToken accessToken = new OAuth2AccessToken(
|
|
|
OAuth2AccessToken.TokenType.BEARER, "token",
|
|
@@ -282,35 +284,46 @@ public class OAuth2TokenEndpointFilterTests {
|
|
|
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
|
|
|
new OAuth2AccessTokenAuthenticationToken(
|
|
|
registeredClient, clientPrincipal, accessToken);
|
|
|
- final String clientId = registeredClient.getClientId();
|
|
|
- final String clientSecret = registeredClient.getClientSecret();
|
|
|
-
|
|
|
- MockHttpServletRequest request = new MockHttpServletRequest("POST", OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI);
|
|
|
- request.setServletPath(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI);
|
|
|
- request.addParameter("client_id", clientId);
|
|
|
- request.addParameter("client_secret", clientSecret);
|
|
|
- request.addParameter("grant_type", AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
|
|
|
- if (scope != null) {
|
|
|
- request.addParameter("scope", scope);
|
|
|
- }
|
|
|
|
|
|
when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication);
|
|
|
|
|
|
- SecurityContext context = SecurityContextHolder.createEmptyContext();
|
|
|
- context.setAuthentication(new OAuth2ClientAuthenticationToken(registeredClient));
|
|
|
- SecurityContextHolder.setContext(context);
|
|
|
+ SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
|
|
+ securityContext.setAuthentication(clientPrincipal);
|
|
|
+ SecurityContextHolder.setContext(securityContext);
|
|
|
|
|
|
+ MockHttpServletRequest request = createClientCredentialsTokenRequest(registeredClient);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
- filter.doFilter(request, response, mock(FilterChain.class));
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
+
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
+
|
|
|
+ verifyNoInteractions(filterChain);
|
|
|
+
|
|
|
+ ArgumentCaptor<OAuth2ClientCredentialsAuthenticationToken> clientCredentialsAuthenticationCaptor =
|
|
|
+ ArgumentCaptor.forClass(OAuth2ClientCredentialsAuthenticationToken.class);
|
|
|
+ verify(this.authenticationManager).authenticate(clientCredentialsAuthenticationCaptor.capture());
|
|
|
+
|
|
|
+ OAuth2ClientCredentialsAuthenticationToken clientCredentialsAuthentication =
|
|
|
+ clientCredentialsAuthenticationCaptor.getValue();
|
|
|
+ assertThat(clientCredentialsAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
|
|
|
+ assertThat(clientCredentialsAuthentication.getScopes()).isEqualTo(registeredClient.getScopes());
|
|
|
+
|
|
|
+ assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
|
|
|
+ OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
|
|
|
+
|
|
|
+ OAuth2AccessToken accessTokenResult = accessTokenResponse.getAccessToken();
|
|
|
+ assertThat(accessTokenResult.getTokenType()).isEqualTo(accessToken.getTokenType());
|
|
|
+ assertThat(accessTokenResult.getTokenValue()).isEqualTo(accessToken.getTokenValue());
|
|
|
+ assertThat(accessTokenResult.getIssuedAt()).isBetween(
|
|
|
+ accessToken.getIssuedAt().minusSeconds(1), accessToken.getIssuedAt().plusSeconds(1));
|
|
|
+ assertThat(accessTokenResult.getExpiresAt()).isBetween(
|
|
|
+ accessToken.getExpiresAt().minusSeconds(1), accessToken.getExpiresAt().plusSeconds(1));
|
|
|
+ assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes());
|
|
|
}
|
|
|
|
|
|
private void doFilterWhenTokenRequestInvalidParameterThenError(String parameterName, String errorCode,
|
|
|
- Consumer<MockHttpServletRequest> requestConsumer) throws Exception {
|
|
|
+ MockHttpServletRequest request) throws Exception {
|
|
|
|
|
|
- RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
|
|
-
|
|
|
- MockHttpServletRequest request = createTokenRequest(registeredClient);
|
|
|
- requestConsumer.accept(request);
|
|
|
MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
|
@@ -336,7 +349,7 @@ public class OAuth2TokenEndpointFilterTests {
|
|
|
return this.accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
|
|
|
}
|
|
|
|
|
|
- private static MockHttpServletRequest createTokenRequest(RegisteredClient registeredClient) {
|
|
|
+ private static MockHttpServletRequest createAuthorizationCodeTokenRequest(RegisteredClient registeredClient) {
|
|
|
String[] redirectUris = registeredClient.getRedirectUris().toArray(new String[0]);
|
|
|
|
|
|
String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI;
|
|
@@ -349,4 +362,16 @@ public class OAuth2TokenEndpointFilterTests {
|
|
|
|
|
|
return request;
|
|
|
}
|
|
|
+
|
|
|
+ private static MockHttpServletRequest createClientCredentialsTokenRequest(RegisteredClient registeredClient) {
|
|
|
+ String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI;
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
|
|
|
+ request.setServletPath(requestUri);
|
|
|
+
|
|
|
+ request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
|
|
|
+ request.addParameter(OAuth2ParameterNames.SCOPE,
|
|
|
+ StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
|
|
|
+
|
|
|
+ return request;
|
|
|
+ }
|
|
|
}
|