Browse Source

Customize authenticationDetailsSource of OAuth2TokenEndpointFilter

Closes gh-431
figozhang 3 years ago
parent
commit
c9954af084

+ 12 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java

@@ -104,7 +104,7 @@ public final class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 			new OAuth2AccessTokenResponseHttpMessageConverter();
 	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
 			new OAuth2ErrorHttpMessageConverter();
-	private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource =
+	private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource =
 			new WebAuthenticationDetailsSource();
 	private AuthenticationConverter authenticationConverter;
 	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAccessTokenResponse;
@@ -170,6 +170,17 @@ public final class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 		}
 	}
 
+	/**
+	 * Sets the {@link AuthenticationDetailsSource} used for building an authentication details instance from {@link HttpServletRequest}.
+	 *
+	 * @param authenticationDetailsSource the {@link AuthenticationDetailsSource} used for building an authentication details instance from {@link HttpServletRequest}
+	 */
+	public void setAuthenticationDetailsSource(
+			AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) {
+		Assert.notNull(authenticationDetailsSource, "authenticationDetailsSource cannot be null");
+		this.authenticationDetailsSource = authenticationDetailsSource;
+	}
+
 	/**
 	 * Sets the {@link AuthenticationConverter} used when attempting to extract an Access Token Request from {@link HttpServletRequest}
 	 * to an instance of {@link OAuth2AuthorizationGrantAuthenticationToken} used for authenticating the authorization grant.

+ 43 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java

@@ -37,6 +37,7 @@ import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.mock.http.client.MockClientHttpResponse;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContext;
@@ -115,6 +116,13 @@ public class OAuth2TokenEndpointFilterTests {
 				.hasMessage("tokenEndpointUri cannot be empty");
 	}
 
+	@Test
+	public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.filter.setAuthenticationDetailsSource(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authenticationDetailsSource cannot be null");
+	}
+
 	@Test
 	public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.filter.setAuthenticationConverter(null))
@@ -444,6 +452,41 @@ public class OAuth2TokenEndpointFilterTests {
 		assertThat(refreshTokenResult.getTokenValue()).isEqualTo(refreshToken.getTokenValue());
 	}
 
+	@Test
+	public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(
+				registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
+
+		MockHttpServletRequest request = createAuthorizationCodeTokenRequest(registeredClient);
+
+		AuthenticationDetailsSource<HttpServletRequest, WebAuthenticationDetails> authenticationDetailsSource =
+				mock(AuthenticationDetailsSource.class);
+		WebAuthenticationDetails webAuthenticationDetails = new WebAuthenticationDetails(request);
+		when(authenticationDetailsSource.buildDetails(request)).thenReturn(webAuthenticationDetails);
+		this.filter.setAuthenticationDetailsSource(authenticationDetailsSource);
+
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(
+				OAuth2AccessToken.TokenType.BEARER, "token",
+				Instant.now(), Instant.now().plus(Duration.ofHours(1)),
+				new HashSet<>(Arrays.asList("scope1", "scope2")));
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken);
+
+		when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication);
+
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		securityContext.setAuthentication(clientPrincipal);
+		SecurityContextHolder.setContext(securityContext);
+
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(authenticationDetailsSource).buildDetails(request);
+	}
+
 	@Test
 	public void doFilterWhenCustomAuthenticationConverterThenUsed() throws Exception {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();