浏览代码

Add authenticationDetailsSource to AuthorizationEndpointFilter

Closes gh-768
Daniel Garnier-Moiroux 3 年之前
父节点
当前提交
ec7ab5c956

+ 16 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2021 the original author or authors.
+ * Copyright 2020-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ import javax.servlet.http.HttpServletResponse;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
+import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
@@ -45,6 +46,7 @@ import org.springframework.security.web.RedirectStrategy;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
+import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
 import org.springframework.security.web.util.RedirectUrlBuilder;
 import org.springframework.security.web.util.UrlUtils;
 import org.springframework.security.web.util.matcher.AndRequestMatcher;
@@ -82,6 +84,7 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
 	private final AuthenticationManager authenticationManager;
 	private final RequestMatcher authorizationEndpointMatcher;
 	private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
+	private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
 	private AuthenticationConverter authenticationConverter;
 	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAuthorizationResponse;
 	private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
@@ -144,6 +147,7 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
 		try {
 			OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
 					(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationConverter.convert(request);
+			authorizationCodeRequestAuthentication.setDetails(this.authenticationDetailsSource.buildDetails(request));
 
 			OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult =
 					(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationManager.authenticate(authorizationCodeRequestAuthentication);
@@ -169,6 +173,17 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
 		}
 	}
 
+	/**
+	 * Sets the {@link AuthenticationDetailsSource} used for building an authentication details instance from {@link HttpServletRequest}.
+	 *
+	 * @param authenticationDetailsSource the {@link AuthenticationDetailsSource} used for building an authentication details instance from {@link HttpServletRequest}
+	 * @since 0.3.1
+	 */
+	public void setAuthenticationDetailsSource(AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) {
+		Assert.notNull(authenticationDetailsSource, "authenticationDetailsSource cannot be null");
+		this.authenticationDetailsSource = authenticationDetailsSource;
+	}
+
 	/**
 	 * Sets the {@link AuthenticationConverter} used when attempting to extract an Authorization Request (or Consent) from {@link HttpServletRequest}
 	 * to an instance of {@link OAuth2AuthorizationCodeRequestAuthenticationToken} used for authenticating the request.

+ 47 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -32,10 +32,12 @@ import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.mockito.ArgumentCaptor;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
@@ -55,10 +57,12 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
+import org.springframework.security.web.authentication.WebAuthenticationDetails;
 import org.springframework.util.StringUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.InstanceOfAssertFactories.type;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.same;
 import static org.mockito.Mockito.mock;
@@ -78,6 +82,7 @@ import static org.mockito.Mockito.when;
  */
 public class OAuth2AuthorizationEndpointFilterTests {
 	private static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";
+	private static final String REMOTE_ADDRESS = "remote-address";
 	private AuthenticationManager authenticationManager;
 	private OAuth2AuthorizationEndpointFilter filter;
 	private TestingAuthenticationToken principal;
@@ -116,6 +121,13 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				.hasMessage("authorizationEndpointUri cannot be empty");
 	}
 
+	@Test
+	public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.filter.setAuthenticationDetailsSource(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authenticationDetailsSource cannot be null");
+	}
+
 	@Test
 	public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.filter.setAuthenticationConverter(null))
@@ -364,6 +376,32 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), same(authenticationException));
 	}
 
+	@Test
+	public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
+				authorizationCodeRequestAuthentication(registeredClient, this.principal).build();
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+
+		AuthenticationDetailsSource<HttpServletRequest, WebAuthenticationDetails> authenticationDetailsSource =
+				mock(AuthenticationDetailsSource.class);
+		WebAuthenticationDetails webAuthenticationDetails = new WebAuthenticationDetails(request);
+		when(authenticationDetailsSource.buildDetails(request)).thenReturn(webAuthenticationDetails);
+		this.filter.setAuthenticationDetailsSource(authenticationDetailsSource);
+
+		when(this.authenticationManager.authenticate(any()))
+				.thenReturn(authorizationCodeRequestAuthentication);
+
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(authenticationDetailsSource).buildDetails(any());
+		verify(this.authenticationManager).authenticate(any());
+		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
 	@Test
 	public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception {
 		this.principal.setAuthenticated(false);
@@ -507,9 +545,15 @@ public class OAuth2AuthorizationEndpointFilterTests {
 
 		this.filter.doFilter(request, response, filterChain);
 
-		verify(this.authenticationManager).authenticate(any());
+		ArgumentCaptor<OAuth2AuthorizationCodeRequestAuthenticationToken> authorizationCodeRequestAuthenticationCaptor =
+				ArgumentCaptor.forClass(OAuth2AuthorizationCodeRequestAuthenticationToken.class);
+		verify(this.authenticationManager).authenticate(authorizationCodeRequestAuthenticationCaptor.capture());
 		verifyNoInteractions(filterChain);
 
+		assertThat(authorizationCodeRequestAuthenticationCaptor.getValue().getDetails())
+				.asInstanceOf(type(WebAuthenticationDetails.class))
+				.extracting(WebAuthenticationDetails::getRemoteAddress)
+				.isEqualTo(REMOTE_ADDRESS);
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
 		assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state");
 	}
@@ -578,6 +622,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
+		request.setRemoteAddr(REMOTE_ADDRESS);
 
 		request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
 		request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
@@ -593,6 +638,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
 		MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
 		request.setServletPath(requestUri);
+		request.setRemoteAddr(REMOTE_ADDRESS);
 
 		request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
 		registeredClient.getScopes().forEach((scope) -> request.addParameter(OAuth2ParameterNames.SCOPE, scope));