Browse Source

OAuth2AuthorizationEndpointFilter is applied after AuthorizationFilter

Closes gh-18251
Joe Grandja 1 week ago
parent
commit
c53e66a217

+ 29 - 3
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationEndpointConfigurer.java

@@ -16,10 +16,12 @@
 
 
 package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;
 package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;
 
 
+import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.List;
 import java.util.function.Consumer;
 import java.util.function.Consumer;
 
 
+import jakarta.servlet.Filter;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletRequest;
 
 
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpMethod;
@@ -36,10 +38,12 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationValidator;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationValidator;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
 import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
+import org.springframework.security.web.access.intercept.AuthorizationFilter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
@@ -50,6 +54,7 @@ import org.springframework.security.web.servlet.util.matcher.PathPatternRequestM
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
+import org.springframework.util.ReflectionUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.util.StringUtils;
 
 
 /**
 /**
@@ -83,6 +88,8 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
 
 
 	private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authorizationCodeRequestAuthenticationValidator;
 	private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authorizationCodeRequestAuthenticationValidator;
 
 
+	private Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authorizationCodeRequestAuthenticationValidatorComposite;
+
 	private SessionAuthenticationStrategy sessionAuthenticationStrategy;
 	private SessionAuthenticationStrategy sessionAuthenticationStrategy;
 
 
 	/**
 	/**
@@ -248,8 +255,16 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
 			authenticationProviders.addAll(0, this.authenticationProviders);
 			authenticationProviders.addAll(0, this.authenticationProviders);
 		}
 		}
 		this.authenticationProvidersConsumer.accept(authenticationProviders);
 		this.authenticationProvidersConsumer.accept(authenticationProviders);
-		authenticationProviders.forEach(
-				(authenticationProvider) -> httpSecurity.authenticationProvider(postProcess(authenticationProvider)));
+		authenticationProviders.forEach((authenticationProvider) -> {
+			httpSecurity.authenticationProvider(postProcess(authenticationProvider));
+			if (authenticationProvider instanceof OAuth2AuthorizationCodeRequestAuthenticationProvider) {
+				Method method = ReflectionUtils.findMethod(OAuth2AuthorizationCodeRequestAuthenticationProvider.class,
+						"getAuthenticationValidatorComposite");
+				ReflectionUtils.makeAccessible(method);
+				this.authorizationCodeRequestAuthenticationValidatorComposite = (Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext>) ReflectionUtils
+					.invokeMethod(method, authenticationProvider);
+			}
+		});
 	}
 	}
 
 
 	@Override
 	@Override
@@ -282,7 +297,18 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C
 		if (this.sessionAuthenticationStrategy != null) {
 		if (this.sessionAuthenticationStrategy != null) {
 			authorizationEndpointFilter.setSessionAuthenticationStrategy(this.sessionAuthenticationStrategy);
 			authorizationEndpointFilter.setSessionAuthenticationStrategy(this.sessionAuthenticationStrategy);
 		}
 		}
-		httpSecurity.addFilterBefore(postProcess(authorizationEndpointFilter),
+		httpSecurity.addFilterAfter(postProcess(authorizationEndpointFilter), AuthorizationFilter.class);
+		// Create and add
+		// OAuth2AuthorizationEndpointFilter.OAuth2AuthorizationCodeRequestValidatingFilter
+		Method method = ReflectionUtils.findMethod(OAuth2AuthorizationEndpointFilter.class,
+				"createAuthorizationCodeRequestValidatingFilter", RegisteredClientRepository.class, Consumer.class);
+		ReflectionUtils.makeAccessible(method);
+		RegisteredClientRepository registeredClientRepository = OAuth2ConfigurerUtils
+			.getRegisteredClientRepository(httpSecurity);
+		Filter authorizationCodeRequestValidatingFilter = (Filter) ReflectionUtils.invokeMethod(method,
+				authorizationEndpointFilter, registeredClientRepository,
+				this.authorizationCodeRequestAuthenticationValidatorComposite);
+		httpSecurity.addFilterBefore(postProcess(authorizationCodeRequestValidatingFilter),
 				AbstractPreAuthenticatedProcessingFilter.class);
 				AbstractPreAuthenticatedProcessingFilter.class);
 	}
 	}
 
 

+ 15 - 6
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java

@@ -307,8 +307,8 @@ public class OAuth2AuthorizationCodeGrantTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 
 
 		this.mvc
 		this.mvc
-			.perform(
-					get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient)))
+			.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
+				.queryParams(getAuthorizationRequestParameters(registeredClient)))
 			.andExpect(status().isBadRequest())
 			.andExpect(status().isBadRequest())
 			.andReturn();
 			.andReturn();
 	}
 	}
@@ -851,21 +851,31 @@ public class OAuth2AuthorizationCodeGrantTests {
 		this.spring.register(AuthorizationServerConfigurationCustomAuthorizationEndpoint.class).autowire();
 		this.spring.register(AuthorizationServerConfigurationCustomAuthorizationEndpoint.class).autowire();
 
 
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		this.registeredClientRepository.save(registeredClient);
+
 		TestingAuthenticationToken principal = new TestingAuthenticationToken("principalName", "password");
 		TestingAuthenticationToken principal = new TestingAuthenticationToken("principalName", "password");
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE);
+		additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
+		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
+				"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal,
+				registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes(),
+				additionalParameters);
 		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode("code", Instant.now(),
 		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode("code", Instant.now(),
 				Instant.now().plus(5, ChronoUnit.MINUTES));
 				Instant.now().plus(5, ChronoUnit.MINUTES));
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
 				"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, authorizationCode,
 				"https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, authorizationCode,
 				registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED,
 				registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED,
 				registeredClient.getScopes());
 				registeredClient.getScopes());
-		given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthenticationResult);
+		given(authorizationRequestConverter.convert(any())).willReturn(authorizationCodeRequestAuthentication);
 		given(authorizationRequestAuthenticationProvider
 		given(authorizationRequestAuthenticationProvider
 			.supports(eq(OAuth2AuthorizationCodeRequestAuthenticationToken.class))).willReturn(true);
 			.supports(eq(OAuth2AuthorizationCodeRequestAuthenticationToken.class))).willReturn(true);
 		given(authorizationRequestAuthenticationProvider.authenticate(any()))
 		given(authorizationRequestAuthenticationProvider.authenticate(any()))
 			.willReturn(authorizationCodeRequestAuthenticationResult);
 			.willReturn(authorizationCodeRequestAuthenticationResult);
 
 
 		this.mvc
 		this.mvc
-			.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI).params(getAuthorizationRequestParameters(registeredClient))
+			.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
+				.queryParams(getAuthorizationRequestParameters(registeredClient))
 				.with(user("user")))
 				.with(user("user")))
 			.andExpect(status().isOk());
 			.andExpect(status().isOk());
 
 
@@ -880,8 +890,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 				|| converter instanceof OAuth2AuthorizationCodeRequestAuthenticationConverter
 				|| converter instanceof OAuth2AuthorizationCodeRequestAuthenticationConverter
 				|| converter instanceof OAuth2AuthorizationConsentAuthenticationConverter);
 				|| converter instanceof OAuth2AuthorizationConsentAuthenticationConverter);
 
 
-		verify(authorizationRequestAuthenticationProvider)
-			.authenticate(eq(authorizationCodeRequestAuthenticationResult));
+		verify(authorizationRequestAuthenticationProvider).authenticate(eq(authorizationCodeRequestAuthentication));
 
 
 		@SuppressWarnings("unchecked")
 		@SuppressWarnings("unchecked")
 		ArgumentCaptor<List<AuthenticationProvider>> authenticationProvidersCaptor = ArgumentCaptor
 		ArgumentCaptor<List<AuthenticationProvider>> authenticationProvidersCaptor = ArgumentCaptor

+ 39 - 28
oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

@@ -190,33 +190,31 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 		OAuth2AuthorizationCodeRequestAuthenticationContext.Builder authenticationContextBuilder = OAuth2AuthorizationCodeRequestAuthenticationContext
 		OAuth2AuthorizationCodeRequestAuthenticationContext.Builder authenticationContextBuilder = OAuth2AuthorizationCodeRequestAuthenticationContext
 			.with(authorizationCodeRequestAuthentication)
 			.with(authorizationCodeRequestAuthentication)
 			.registeredClient(registeredClient);
 			.registeredClient(registeredClient);
-		OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = authenticationContextBuilder
-			.build();
 
 
-		// grant_type
-		OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_AUTHORIZATION_GRANT_TYPE_VALIDATOR
-			.accept(authenticationContext);
+		if (!authorizationCodeRequestAuthentication.isValidated()) {
+			OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = authenticationContextBuilder
+				.build();
 
 
-		// redirect_uri and scope
-		this.authenticationValidator.accept(authenticationContext);
+			// grant_type
+			OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_AUTHORIZATION_GRANT_TYPE_VALIDATOR
+				.accept(authenticationContext);
 
 
-		// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
-		OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_CODE_CHALLENGE_VALIDATOR
-			.accept(authenticationContext);
+			// redirect_uri and scope
+			this.authenticationValidator.accept(authenticationContext);
 
 
-		// prompt (OPTIONAL for OpenID Connect 1.0 Authentication Request)
-		Set<String> promptValues = Collections.emptySet();
-		if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)) {
-			String prompt = (String) authorizationCodeRequestAuthentication.getAdditionalParameters().get("prompt");
-			if (StringUtils.hasText(prompt)) {
-				OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_PROMPT_VALIDATOR
-					.accept(authenticationContext);
-				promptValues = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(prompt, " ")));
-			}
-		}
+			// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
+			OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_CODE_CHALLENGE_VALIDATOR
+				.accept(authenticationContext);
 
 
-		if (this.logger.isTraceEnabled()) {
-			this.logger.trace("Validated authorization code request parameters");
+			// prompt (OPTIONAL for OpenID Connect 1.0 Authentication Request)
+			OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_PROMPT_VALIDATOR
+				.accept(authenticationContext);
+
+			authorizationCodeRequestAuthentication.setValidated(true);
+
+			if (this.logger.isTraceEnabled()) {
+				this.logger.trace("Validated authorization code request parameters");
+			}
 		}
 		}
 
 
 		// ---------------
 		// ---------------
@@ -224,17 +222,23 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 		// ---------------
 		// ---------------
 
 
 		Authentication principal = (Authentication) authorizationCodeRequestAuthentication.getPrincipal();
 		Authentication principal = (Authentication) authorizationCodeRequestAuthentication.getPrincipal();
+
+		Set<String> promptValues = Collections.emptySet();
+		if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)) {
+			String prompt = (String) authorizationCodeRequestAuthentication.getAdditionalParameters().get("prompt");
+			if (StringUtils.hasText(prompt)) {
+				promptValues = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(prompt, " ")));
+			}
+		}
+
 		if (!isPrincipalAuthenticated(principal)) {
 		if (!isPrincipalAuthenticated(principal)) {
 			if (promptValues.contains(OidcPrompt.NONE)) {
 			if (promptValues.contains(OidcPrompt.NONE)) {
-				// Return an error instead of displaying the login page (via the
-				// configured AuthenticationEntryPoint)
 				throwError("login_required", "prompt", authorizationCodeRequestAuthentication, registeredClient);
 				throwError("login_required", "prompt", authorizationCodeRequestAuthentication, registeredClient);
 			}
 			}
-			if (this.logger.isTraceEnabled()) {
-				this.logger.trace("Did not authenticate authorization code request since principal not authenticated");
+			else {
+				throwError(OAuth2ErrorCodes.INVALID_REQUEST, "principal", authorizationCodeRequestAuthentication,
+						registeredClient);
 			}
 			}
-			// Return the authorization request as-is where isAuthenticated() is false
-			return authorizationCodeRequestAuthentication;
 		}
 		}
 
 
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
@@ -400,6 +404,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 		this.authorizationConsentRequired = authorizationConsentRequired;
 		this.authorizationConsentRequired = authorizationConsentRequired;
 	}
 	}
 
 
+	Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> getAuthenticationValidatorComposite() {
+		return OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_AUTHORIZATION_GRANT_TYPE_VALIDATOR
+			.andThen(this.authenticationValidator)
+			.andThen(OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_CODE_CHALLENGE_VALIDATOR)
+			.andThen(OAuth2AuthorizationCodeRequestAuthenticationValidator.DEFAULT_PROMPT_VALIDATOR);
+	}
+
 	private static boolean isAuthorizationConsentRequired(
 	private static boolean isAuthorizationConsentRequired(
 			OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext) {
 			OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext) {
 		if (!authenticationContext.getRegisteredClient().getClientSettings().isRequireAuthorizationConsent()) {
 		if (!authenticationContext.getRegisteredClient().getClientSettings().isRequireAuthorizationConsent()) {

+ 10 - 0
oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationToken.java

@@ -42,6 +42,8 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken
 
 
 	private final OAuth2AuthorizationCode authorizationCode;
 	private final OAuth2AuthorizationCode authorizationCode;
 
 
+	private boolean validated;
+
 	/**
 	/**
 	 * Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationToken} using the
 	 * Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationToken} using the
 	 * provided parameters.
 	 * provided parameters.
@@ -89,4 +91,12 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken
 		return this.authorizationCode;
 		return this.authorizationCode;
 	}
 	}
 
 
+	final boolean isValidated() {
+		return this.validated;
+	}
+
+	final void setValidated(boolean validated) {
+		this.validated = validated;
+	}
+
 }
 }

+ 122 - 12
oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -17,11 +17,14 @@
 package org.springframework.security.oauth2.server.authorization.web;
 package org.springframework.security.oauth2.server.authorization.web;
 
 
 import java.io.IOException;
 import java.io.IOException;
+import java.lang.reflect.Field;
 import java.nio.charset.StandardCharsets;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.Set;
 import java.util.Set;
+import java.util.function.Consumer;
 
 
+import jakarta.servlet.Filter;
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.ServletException;
 import jakarta.servlet.ServletException;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletRequest;
@@ -38,14 +41,18 @@ import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.session.SessionRegistry;
 import org.springframework.security.core.session.SessionRegistry;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationContext;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationException;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationException;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
 import org.springframework.security.web.DefaultRedirectStrategy;
 import org.springframework.security.web.DefaultRedirectStrategy;
@@ -64,6 +71,7 @@ import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
+import org.springframework.util.ReflectionUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
 import org.springframework.web.filter.OncePerRequestFilter;
 import org.springframework.web.util.UriComponentsBuilder;
 import org.springframework.web.util.UriComponentsBuilder;
@@ -180,21 +188,18 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
 		}
 		}
 
 
 		try {
 		try {
-			Authentication authentication = this.authenticationConverter.convert(request);
-			if (authentication instanceof AbstractAuthenticationToken authenticationToken) {
-				authenticationToken.setDetails(this.authenticationDetailsSource.buildDetails(request));
+			// Get the pre-validated authorization code request (if available),
+			// which was set by OAuth2AuthorizationCodeRequestValidatingFilter
+			Authentication authentication = (Authentication) request
+				.getAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName());
+			if (authentication == null) {
+				authentication = this.authenticationConverter.convert(request);
+				if (authentication instanceof AbstractAuthenticationToken authenticationToken) {
+					authenticationToken.setDetails(this.authenticationDetailsSource.buildDetails(request));
+				}
 			}
 			}
 			Authentication authenticationResult = this.authenticationManager.authenticate(authentication);
 			Authentication authenticationResult = this.authenticationManager.authenticate(authentication);
 
 
-			if (!authenticationResult.isAuthenticated()) {
-				// If the Principal (Resource Owner) is not authenticated then pass
-				// through the chain
-				// with the expectation that the authentication process will commence via
-				// AuthenticationEntryPoint
-				filterChain.doFilter(request, response);
-				return;
-			}
-
 			if (authenticationResult instanceof OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthenticationToken) {
 			if (authenticationResult instanceof OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthenticationToken) {
 				if (this.logger.isTraceEnabled()) {
 				if (this.logger.isTraceEnabled()) {
 					this.logger.trace("Authorization consent is required");
 					this.logger.trace("Authorization consent is required");
@@ -401,4 +406,109 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
 		this.redirectStrategy.sendRedirect(request, response, redirectUri);
 		this.redirectStrategy.sendRedirect(request, response, redirectUri);
 	}
 	}
 
 
+	Filter createAuthorizationCodeRequestValidatingFilter(RegisteredClientRepository registeredClientRepository,
+			Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator) {
+		return new OAuth2AuthorizationCodeRequestValidatingFilter(registeredClientRepository, authenticationValidator);
+	}
+
+	/**
+	 * A {@code Filter} that is applied before {@code OAuth2AuthorizationEndpointFilter}
+	 * and handles the pre-validation of an OAuth 2.0 Authorization Code Request.
+	 */
+	private final class OAuth2AuthorizationCodeRequestValidatingFilter extends OncePerRequestFilter {
+
+		private final RegisteredClientRepository registeredClientRepository;
+
+		private final Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator;
+
+		private final Field setValidatedField;
+
+		private OAuth2AuthorizationCodeRequestValidatingFilter(RegisteredClientRepository registeredClientRepository,
+				Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator) {
+			Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
+			Assert.notNull(authenticationValidator, "authenticationValidator cannot be null");
+			this.registeredClientRepository = registeredClientRepository;
+			this.authenticationValidator = authenticationValidator;
+			this.setValidatedField = ReflectionUtils.findField(OAuth2AuthorizationCodeRequestAuthenticationToken.class,
+					"validated");
+			ReflectionUtils.makeAccessible(this.setValidatedField);
+		}
+
+		@Override
+		protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
+				FilterChain filterChain) throws ServletException, IOException {
+
+			if (!OAuth2AuthorizationEndpointFilter.this.authorizationEndpointMatcher.matches(request)) {
+				filterChain.doFilter(request, response);
+				return;
+			}
+
+			try {
+				Authentication authentication = OAuth2AuthorizationEndpointFilter.this.authenticationConverter
+					.convert(request);
+				if (!(authentication instanceof OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication)) {
+					filterChain.doFilter(request, response);
+					return;
+				}
+
+				String requestUri = (String) authorizationCodeRequestAuthentication.getAdditionalParameters()
+					.get(OAuth2ParameterNames.REQUEST_URI);
+				if (StringUtils.hasText(requestUri)) {
+					filterChain.doFilter(request, response);
+					return;
+				}
+
+				authorizationCodeRequestAuthentication.setDetails(
+						OAuth2AuthorizationEndpointFilter.this.authenticationDetailsSource.buildDetails(request));
+
+				RegisteredClient registeredClient = this.registeredClientRepository
+					.findByClientId(authorizationCodeRequestAuthentication.getClientId());
+				if (registeredClient == null) {
+					String redirectUri = null; // Prevent redirect
+					OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
+							authorizationCodeRequestAuthentication.getAuthorizationUri(),
+							authorizationCodeRequestAuthentication.getClientId(),
+							(Authentication) authorizationCodeRequestAuthentication.getPrincipal(), redirectUri,
+							authorizationCodeRequestAuthentication.getState(),
+							authorizationCodeRequestAuthentication.getScopes(),
+							authorizationCodeRequestAuthentication.getAdditionalParameters());
+
+					OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST,
+							"OAuth 2.0 Parameter: " + OAuth2ParameterNames.CLIENT_ID,
+							"https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1");
+					throw new OAuth2AuthorizationCodeRequestAuthenticationException(error,
+							authorizationCodeRequestAuthenticationResult);
+				}
+
+				OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = OAuth2AuthorizationCodeRequestAuthenticationContext
+					.with(authorizationCodeRequestAuthentication)
+					.registeredClient(registeredClient)
+					.build();
+
+				this.authenticationValidator.accept(authenticationContext);
+
+				ReflectionUtils.setField(this.setValidatedField, authorizationCodeRequestAuthentication, true);
+
+				// Set the validated authorization code request as a request
+				// attribute
+				// to be used upstream by OAuth2AuthorizationEndpointFilter
+				request.setAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName(),
+						authorizationCodeRequestAuthentication);
+
+				filterChain.doFilter(request, response);
+			}
+			catch (OAuth2AuthenticationException ex) {
+				if (this.logger.isTraceEnabled()) {
+					this.logger.trace(LogMessage.format("Authorization request failed: %s", ex.getError()), ex);
+				}
+				OAuth2AuthorizationEndpointFilter.this.authenticationFailureHandler.onAuthenticationFailure(request,
+						response, ex);
+			}
+			finally {
+				request.removeAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName());
+			}
+		}
+
+	}
+
 }
 }

+ 5 - 7
oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java

@@ -428,7 +428,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
 	}
 	}
 
 
 	@Test
 	@Test
-	public void authenticateWhenPrincipalNotAuthenticatedThenReturnAuthorizationCodeRequest() {
+	public void authenticateWhenPrincipalNotAuthenticatedThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		given(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
 		given(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
 			.willReturn(registeredClient);
 			.willReturn(registeredClient);
@@ -438,12 +438,10 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
 		OAuth2AuthorizationCodeRequestAuthenticationToken authentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
 		OAuth2AuthorizationCodeRequestAuthenticationToken authentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
 				AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, redirectUri, STATE,
 				AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, redirectUri, STATE,
 				registeredClient.getScopes(), createPkceParameters());
 				registeredClient.getScopes(), createPkceParameters());
-
-		OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult = (OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider
-			.authenticate(authentication);
-
-		assertThat(authenticationResult).isSameAs(authentication);
-		assertThat(authenticationResult.isAuthenticated()).isFalse();
+		assertThatExceptionOfType(OAuth2AuthorizationCodeRequestAuthenticationException.class)
+			.isThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+			.satisfies((ex) -> assertAuthenticationException(ex, OAuth2ErrorCodes.INVALID_REQUEST, "principal",
+					authentication.getRedirectUri()));
 	}
 	}
 
 
 	@Test
 	@Test

+ 11 - 26
oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -372,7 +372,11 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		given(authenticationConverter.convert(any())).willReturn(authorizationCodeRequestAuthentication);
 		given(authenticationConverter.convert(any())).willReturn(authorizationCodeRequestAuthentication);
 		this.filter.setAuthenticationConverter(authenticationConverter);
 		this.filter.setAuthenticationConverter(authenticationConverter);
 
 
-		given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthentication);
+		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
+				AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, this.authorizationCode,
+				registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes());
+		authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
+		given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthenticationResult);
 
 
 		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
 		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		MockHttpServletResponse response = new MockHttpServletResponse();
@@ -382,7 +386,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 
 
 		verify(authenticationConverter).convert(any());
 		verify(authenticationConverter).convert(any());
 		verify(this.authenticationManager).authenticate(any());
 		verify(this.authenticationManager).authenticate(any());
-		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+		verifyNoInteractions(filterChain);
 	}
 	}
 
 
 	@Test
 	@Test
@@ -461,9 +465,6 @@ public class OAuth2AuthorizationEndpointFilterTests {
 	@Test
 	@Test
 	public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception {
 	public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
-		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
-				AUTHORIZATION_URI, registeredClient.getClientId(), this.principal,
-				registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null);
 		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
 		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
 
 
 		AuthenticationDetailsSource<HttpServletRequest, WebAuthenticationDetails> authenticationDetailsSource = mock(
 		AuthenticationDetailsSource<HttpServletRequest, WebAuthenticationDetails> authenticationDetailsSource = mock(
@@ -472,36 +473,20 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		given(authenticationDetailsSource.buildDetails(request)).willReturn(webAuthenticationDetails);
 		given(authenticationDetailsSource.buildDetails(request)).willReturn(webAuthenticationDetails);
 		this.filter.setAuthenticationDetailsSource(authenticationDetailsSource);
 		this.filter.setAuthenticationDetailsSource(authenticationDetailsSource);
 
 
-		given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthentication);
-
-		MockHttpServletResponse response = new MockHttpServletResponse();
-		FilterChain filterChain = mock(FilterChain.class);
-
-		this.filter.doFilter(request, response, filterChain);
-
-		verify(authenticationDetailsSource).buildDetails(any());
-		verify(this.authenticationManager).authenticate(any());
-		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-	}
-
-	@Test
-	public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception {
-		this.principal.setAuthenticated(false);
-		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
-				AUTHORIZATION_URI, registeredClient.getClientId(), this.principal,
-				registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null);
-		authorizationCodeRequestAuthenticationResult.setAuthenticated(false);
+				AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, this.authorizationCode,
+				registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes());
+		authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
 		given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthenticationResult);
 		given(this.authenticationManager.authenticate(any())).willReturn(authorizationCodeRequestAuthenticationResult);
 
 
-		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 		FilterChain filterChain = mock(FilterChain.class);
 
 
 		this.filter.doFilter(request, response, filterChain);
 		this.filter.doFilter(request, response, filterChain);
 
 
+		verify(authenticationDetailsSource).buildDetails(any());
 		verify(this.authenticationManager).authenticate(any());
 		verify(this.authenticationManager).authenticate(any());
-		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
+		verifyNoInteractions(filterChain);
 	}
 	}
 
 
 	@Test
 	@Test