فهرست منبع

Add HttpServletResponse param to removeAuthorizationRequest

Fixes gh-5313
Joe Grandja 7 سال پیش
والد
کامیت
2c1c2c78c3

+ 13 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationRequestRepository.java

@@ -63,9 +63,22 @@ public interface AuthorizationRequestRepository<T extends OAuth2AuthorizationReq
 	 * Removes and returns the {@link OAuth2AuthorizationRequest} associated to the
 	 * provided {@code HttpServletRequest} or if not available returns {@code null}.
 	 *
+	 * @deprecated Use {@link #removeAuthorizationRequest(HttpServletRequest, HttpServletResponse)} instead
 	 * @param request the {@code HttpServletRequest}
 	 * @return the removed {@link OAuth2AuthorizationRequest} or {@code null} if not available
 	 */
 	T removeAuthorizationRequest(HttpServletRequest request);
 
+	/**
+	 * Removes and returns the {@link OAuth2AuthorizationRequest} associated to the
+	 * provided {@code HttpServletRequest} or if not available returns {@code null}.
+	 *
+	 * @since 5.1
+	 * @param request the {@code HttpServletRequest}
+	 * @param response the {@code HttpServletResponse}
+	 * @return the {@link OAuth2AuthorizationRequest} or {@code null} if not available
+	 */
+	default T removeAuthorizationRequest(HttpServletRequest request, HttpServletResponse response) {
+		return removeAuthorizationRequest(request);
+	}
 }

+ 7 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java

@@ -58,7 +58,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
 		Assert.notNull(request, "request cannot be null");
 		Assert.notNull(response, "response cannot be null");
 		if (authorizationRequest == null) {
-			this.removeAuthorizationRequest(request);
+			this.removeAuthorizationRequest(request, response);
 			return;
 		}
 		String state = authorizationRequest.getState();
@@ -85,6 +85,12 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
 		return originalRequest;
 	}
 
+	@Override
+	public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest request, HttpServletResponse response) {
+		Assert.notNull(response, "response cannot be null");
+		return this.removeAuthorizationRequest(request);
+	}
+
 	/**
 	 * Gets the state parameter from the {@link HttpServletRequest}
 	 * @param request the request to use

+ 2 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java

@@ -158,7 +158,8 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
 	private void processAuthorizationResponse(HttpServletRequest request, HttpServletResponse response)
 		throws ServletException, IOException {
 
-		OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.removeAuthorizationRequest(request);
+		OAuth2AuthorizationRequest authorizationRequest =
+				this.authorizationRequestRepository.removeAuthorizationRequest(request, response);
 
 		String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID);
 		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);

+ 2 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java

@@ -156,7 +156,8 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 		}
 
-		OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.removeAuthorizationRequest(request);
+		OAuth2AuthorizationRequest authorizationRequest =
+				this.authorizationRequestRepository.removeAuthorizationRequest(request, response);
 		if (authorizationRequest == null) {
 			OAuth2Error oauth2Error = new OAuth2Error(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE);
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());

+ 14 - 5
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java

@@ -217,9 +217,16 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 		assertThat(loadedAuthorizationRequest).isNull();
 	}
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void removeAuthorizationRequestWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
-		this.authorizationRequestRepository.removeAuthorizationRequest(null);
+		assertThatThrownBy(() -> this.authorizationRequestRepository.removeAuthorizationRequest(
+				null, new MockHttpServletResponse())).isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void removeAuthorizationRequestWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizationRequestRepository.removeAuthorizationRequest(
+				new MockHttpServletRequest(), null)).isInstanceOf(IllegalArgumentException.class);
 	}
 
 	@Test
@@ -234,7 +241,7 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 
 		request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState());
 		OAuth2AuthorizationRequest removedAuthorizationRequest =
-			this.authorizationRequestRepository.removeAuthorizationRequest(request);
+			this.authorizationRequestRepository.removeAuthorizationRequest(request, response);
 		OAuth2AuthorizationRequest loadedAuthorizationRequest =
 			this.authorizationRequestRepository.loadAuthorizationRequest(request);
 
@@ -255,7 +262,7 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 
 		request.addParameter(OAuth2ParameterNames.STATE, authorizationRequest.getState());
 		OAuth2AuthorizationRequest removedAuthorizationRequest =
-				this.authorizationRequestRepository.removeAuthorizationRequest(request);
+				this.authorizationRequestRepository.removeAuthorizationRequest(request, response);
 
 		String sessionAttributeName = HttpSessionOAuth2AuthorizationRequestRepository.class.getName() +
 				".AUTHORIZATION_REQUEST";
@@ -269,8 +276,10 @@ public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
 		MockHttpServletRequest request = new MockHttpServletRequest();
 		request.addParameter(OAuth2ParameterNames.STATE, "state-1234");
 
+		MockHttpServletResponse response = new MockHttpServletResponse();
+
 		OAuth2AuthorizationRequest removedAuthorizationRequest =
-			this.authorizationRequestRepository.removeAuthorizationRequest(request);
+			this.authorizationRequestRepository.removeAuthorizationRequest(request, response);
 
 		assertThat(removedAuthorizationRequest).isNull();
 	}