|
@@ -32,10 +32,12 @@ import org.junit.After;
|
|
import org.junit.Before;
|
|
import org.junit.Before;
|
|
import org.junit.Test;
|
|
import org.junit.Test;
|
|
|
|
|
|
|
|
+import org.mockito.ArgumentCaptor;
|
|
import org.springframework.http.HttpStatus;
|
|
import org.springframework.http.HttpStatus;
|
|
import org.springframework.http.MediaType;
|
|
import org.springframework.http.MediaType;
|
|
import org.springframework.mock.web.MockHttpServletRequest;
|
|
import org.springframework.mock.web.MockHttpServletRequest;
|
|
import org.springframework.mock.web.MockHttpServletResponse;
|
|
import org.springframework.mock.web.MockHttpServletResponse;
|
|
|
|
+import org.springframework.security.authentication.AuthenticationDetailsSource;
|
|
import org.springframework.security.authentication.AuthenticationManager;
|
|
import org.springframework.security.authentication.AuthenticationManager;
|
|
import org.springframework.security.authentication.TestingAuthenticationToken;
|
|
import org.springframework.security.authentication.TestingAuthenticationToken;
|
|
import org.springframework.security.core.Authentication;
|
|
import org.springframework.security.core.Authentication;
|
|
@@ -55,10 +57,12 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis
|
|
import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
|
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
|
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
|
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
|
|
|
+import org.springframework.security.web.authentication.WebAuthenticationDetails;
|
|
import org.springframework.util.StringUtils;
|
|
import org.springframework.util.StringUtils;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|
|
|
+import static org.assertj.core.api.InstanceOfAssertFactories.type;
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
import static org.mockito.ArgumentMatchers.same;
|
|
import static org.mockito.ArgumentMatchers.same;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.mock;
|
|
@@ -78,6 +82,7 @@ import static org.mockito.Mockito.when;
|
|
*/
|
|
*/
|
|
public class OAuth2AuthorizationEndpointFilterTests {
|
|
public class OAuth2AuthorizationEndpointFilterTests {
|
|
private static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";
|
|
private static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";
|
|
|
|
+ private static final String REMOTE_ADDRESS = "remote-address";
|
|
private AuthenticationManager authenticationManager;
|
|
private AuthenticationManager authenticationManager;
|
|
private OAuth2AuthorizationEndpointFilter filter;
|
|
private OAuth2AuthorizationEndpointFilter filter;
|
|
private TestingAuthenticationToken principal;
|
|
private TestingAuthenticationToken principal;
|
|
@@ -116,6 +121,13 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
.hasMessage("authorizationEndpointUri cannot be empty");
|
|
.hasMessage("authorizationEndpointUri cannot be empty");
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Test
|
|
|
|
+ public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException() {
|
|
|
|
+ assertThatThrownBy(() -> this.filter.setAuthenticationDetailsSource(null))
|
|
|
|
+ .isInstanceOf(IllegalArgumentException.class)
|
|
|
|
+ .hasMessage("authenticationDetailsSource cannot be null");
|
|
|
|
+ }
|
|
|
|
+
|
|
@Test
|
|
@Test
|
|
public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() {
|
|
public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() {
|
|
assertThatThrownBy(() -> this.filter.setAuthenticationConverter(null))
|
|
assertThatThrownBy(() -> this.filter.setAuthenticationConverter(null))
|
|
@@ -364,6 +376,32 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), same(authenticationException));
|
|
verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), same(authenticationException));
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Test
|
|
|
|
+ public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception {
|
|
|
|
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
|
|
|
+ OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
|
|
|
|
+ authorizationCodeRequestAuthentication(registeredClient, this.principal).build();
|
|
|
|
+ MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
|
|
|
+
|
|
|
|
+ AuthenticationDetailsSource<HttpServletRequest, WebAuthenticationDetails> authenticationDetailsSource =
|
|
|
|
+ mock(AuthenticationDetailsSource.class);
|
|
|
|
+ WebAuthenticationDetails webAuthenticationDetails = new WebAuthenticationDetails(request);
|
|
|
|
+ when(authenticationDetailsSource.buildDetails(request)).thenReturn(webAuthenticationDetails);
|
|
|
|
+ this.filter.setAuthenticationDetailsSource(authenticationDetailsSource);
|
|
|
|
+
|
|
|
|
+ when(this.authenticationManager.authenticate(any()))
|
|
|
|
+ .thenReturn(authorizationCodeRequestAuthentication);
|
|
|
|
+
|
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
|
+ FilterChain filterChain = mock(FilterChain.class);
|
|
|
|
+
|
|
|
|
+ this.filter.doFilter(request, response, filterChain);
|
|
|
|
+
|
|
|
|
+ verify(authenticationDetailsSource).buildDetails(any());
|
|
|
|
+ verify(this.authenticationManager).authenticate(any());
|
|
|
|
+ verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
|
+ }
|
|
|
|
+
|
|
@Test
|
|
@Test
|
|
public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception {
|
|
public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception {
|
|
this.principal.setAuthenticated(false);
|
|
this.principal.setAuthenticated(false);
|
|
@@ -507,9 +545,15 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
|
|
|
|
this.filter.doFilter(request, response, filterChain);
|
|
this.filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- verify(this.authenticationManager).authenticate(any());
|
|
|
|
|
|
+ ArgumentCaptor<OAuth2AuthorizationCodeRequestAuthenticationToken> authorizationCodeRequestAuthenticationCaptor =
|
|
|
|
+ ArgumentCaptor.forClass(OAuth2AuthorizationCodeRequestAuthenticationToken.class);
|
|
|
|
+ verify(this.authenticationManager).authenticate(authorizationCodeRequestAuthenticationCaptor.capture());
|
|
verifyNoInteractions(filterChain);
|
|
verifyNoInteractions(filterChain);
|
|
|
|
|
|
|
|
+ assertThat(authorizationCodeRequestAuthenticationCaptor.getValue().getDetails())
|
|
|
|
+ .asInstanceOf(type(WebAuthenticationDetails.class))
|
|
|
|
+ .extracting(WebAuthenticationDetails::getRemoteAddress)
|
|
|
|
+ .isEqualTo(REMOTE_ADDRESS);
|
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
|
|
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
|
|
assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state");
|
|
assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state");
|
|
}
|
|
}
|
|
@@ -578,6 +622,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
|
|
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
|
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
request.setServletPath(requestUri);
|
|
request.setServletPath(requestUri);
|
|
|
|
+ request.setRemoteAddr(REMOTE_ADDRESS);
|
|
|
|
|
|
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
|
|
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
|
|
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
|
|
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
|
|
@@ -593,6 +638,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
|
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
|
|
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
|
|
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
|
|
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
|
|
request.setServletPath(requestUri);
|
|
request.setServletPath(requestUri);
|
|
|
|
+ request.setRemoteAddr(REMOTE_ADDRESS);
|
|
|
|
|
|
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
|
|
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
|
|
registeredClient.getScopes().forEach((scope) -> request.addParameter(OAuth2ParameterNames.SCOPE, scope));
|
|
registeredClient.getScopes().forEach((scope) -> request.addParameter(OAuth2ParameterNames.SCOPE, scope));
|