|
@@ -17,11 +17,14 @@
|
|
|
package org.springframework.security.oauth2.server.authorization.web;
|
|
package org.springframework.security.oauth2.server.authorization.web;
|
|
|
|
|
|
|
|
import java.io.IOException;
|
|
import java.io.IOException;
|
|
|
|
|
+import java.lang.reflect.Field;
|
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.nio.charset.StandardCharsets;
|
|
|
import java.util.Arrays;
|
|
import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
import java.util.Collections;
|
|
|
import java.util.Set;
|
|
import java.util.Set;
|
|
|
|
|
+import java.util.function.Consumer;
|
|
|
|
|
|
|
|
|
|
+import jakarta.servlet.Filter;
|
|
|
import jakarta.servlet.FilterChain;
|
|
import jakarta.servlet.FilterChain;
|
|
|
import jakarta.servlet.ServletException;
|
|
import jakarta.servlet.ServletException;
|
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
@@ -38,14 +41,18 @@ import org.springframework.security.core.AuthenticationException;
|
|
|
import org.springframework.security.core.session.SessionRegistry;
|
|
import org.springframework.security.core.session.SessionRegistry;
|
|
|
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
|
import org.springframework.security.oauth2.core.OAuth2Error;
|
|
import org.springframework.security.oauth2.core.OAuth2Error;
|
|
|
|
|
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|
|
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
|
|
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
|
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationContext;
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationException;
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationException;
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider;
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider;
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider;
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationToken;
|
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
|
|
|
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
|
|
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeRequestAuthenticationConverter;
|
|
|
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
|
|
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationConsentAuthenticationConverter;
|
|
|
import org.springframework.security.web.DefaultRedirectStrategy;
|
|
import org.springframework.security.web.DefaultRedirectStrategy;
|
|
@@ -64,6 +71,7 @@ import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
|
|
|
import org.springframework.security.web.util.matcher.OrRequestMatcher;
|
|
import org.springframework.security.web.util.matcher.OrRequestMatcher;
|
|
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
|
|
import org.springframework.util.Assert;
|
|
import org.springframework.util.Assert;
|
|
|
|
|
+import org.springframework.util.ReflectionUtils;
|
|
|
import org.springframework.util.StringUtils;
|
|
import org.springframework.util.StringUtils;
|
|
|
import org.springframework.web.filter.OncePerRequestFilter;
|
|
import org.springframework.web.filter.OncePerRequestFilter;
|
|
|
import org.springframework.web.util.UriComponentsBuilder;
|
|
import org.springframework.web.util.UriComponentsBuilder;
|
|
@@ -180,21 +188,18 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
try {
|
|
try {
|
|
|
- Authentication authentication = this.authenticationConverter.convert(request);
|
|
|
|
|
- if (authentication instanceof AbstractAuthenticationToken authenticationToken) {
|
|
|
|
|
- authenticationToken.setDetails(this.authenticationDetailsSource.buildDetails(request));
|
|
|
|
|
|
|
+ // Get the pre-validated authorization code request (if available),
|
|
|
|
|
+ // which was set by OAuth2AuthorizationCodeRequestValidatingFilter
|
|
|
|
|
+ Authentication authentication = (Authentication) request
|
|
|
|
|
+ .getAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName());
|
|
|
|
|
+ if (authentication == null) {
|
|
|
|
|
+ authentication = this.authenticationConverter.convert(request);
|
|
|
|
|
+ if (authentication instanceof AbstractAuthenticationToken authenticationToken) {
|
|
|
|
|
+ authenticationToken.setDetails(this.authenticationDetailsSource.buildDetails(request));
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
Authentication authenticationResult = this.authenticationManager.authenticate(authentication);
|
|
Authentication authenticationResult = this.authenticationManager.authenticate(authentication);
|
|
|
|
|
|
|
|
- if (!authenticationResult.isAuthenticated()) {
|
|
|
|
|
- // If the Principal (Resource Owner) is not authenticated then pass
|
|
|
|
|
- // through the chain
|
|
|
|
|
- // with the expectation that the authentication process will commence via
|
|
|
|
|
- // AuthenticationEntryPoint
|
|
|
|
|
- filterChain.doFilter(request, response);
|
|
|
|
|
- return;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
if (authenticationResult instanceof OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthenticationToken) {
|
|
if (authenticationResult instanceof OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthenticationToken) {
|
|
|
if (this.logger.isTraceEnabled()) {
|
|
if (this.logger.isTraceEnabled()) {
|
|
|
this.logger.trace("Authorization consent is required");
|
|
this.logger.trace("Authorization consent is required");
|
|
@@ -401,4 +406,109 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
|
|
|
this.redirectStrategy.sendRedirect(request, response, redirectUri);
|
|
this.redirectStrategy.sendRedirect(request, response, redirectUri);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ Filter createAuthorizationCodeRequestValidatingFilter(RegisteredClientRepository registeredClientRepository,
|
|
|
|
|
+ Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator) {
|
|
|
|
|
+ return new OAuth2AuthorizationCodeRequestValidatingFilter(registeredClientRepository, authenticationValidator);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /**
|
|
|
|
|
+ * A {@code Filter} that is applied before {@code OAuth2AuthorizationEndpointFilter}
|
|
|
|
|
+ * and handles the pre-validation of an OAuth 2.0 Authorization Code Request.
|
|
|
|
|
+ */
|
|
|
|
|
+ private final class OAuth2AuthorizationCodeRequestValidatingFilter extends OncePerRequestFilter {
|
|
|
|
|
+
|
|
|
|
|
+ private final RegisteredClientRepository registeredClientRepository;
|
|
|
|
|
+
|
|
|
|
|
+ private final Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator;
|
|
|
|
|
+
|
|
|
|
|
+ private final Field setValidatedField;
|
|
|
|
|
+
|
|
|
|
|
+ private OAuth2AuthorizationCodeRequestValidatingFilter(RegisteredClientRepository registeredClientRepository,
|
|
|
|
|
+ Consumer<OAuth2AuthorizationCodeRequestAuthenticationContext> authenticationValidator) {
|
|
|
|
|
+ Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
|
|
|
|
|
+ Assert.notNull(authenticationValidator, "authenticationValidator cannot be null");
|
|
|
|
|
+ this.registeredClientRepository = registeredClientRepository;
|
|
|
|
|
+ this.authenticationValidator = authenticationValidator;
|
|
|
|
|
+ this.setValidatedField = ReflectionUtils.findField(OAuth2AuthorizationCodeRequestAuthenticationToken.class,
|
|
|
|
|
+ "validated");
|
|
|
|
|
+ ReflectionUtils.makeAccessible(this.setValidatedField);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
|
|
|
|
|
+ FilterChain filterChain) throws ServletException, IOException {
|
|
|
|
|
+
|
|
|
|
|
+ if (!OAuth2AuthorizationEndpointFilter.this.authorizationEndpointMatcher.matches(request)) {
|
|
|
|
|
+ filterChain.doFilter(request, response);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ try {
|
|
|
|
|
+ Authentication authentication = OAuth2AuthorizationEndpointFilter.this.authenticationConverter
|
|
|
|
|
+ .convert(request);
|
|
|
|
|
+ if (!(authentication instanceof OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication)) {
|
|
|
|
|
+ filterChain.doFilter(request, response);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ String requestUri = (String) authorizationCodeRequestAuthentication.getAdditionalParameters()
|
|
|
|
|
+ .get(OAuth2ParameterNames.REQUEST_URI);
|
|
|
|
|
+ if (StringUtils.hasText(requestUri)) {
|
|
|
|
|
+ filterChain.doFilter(request, response);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ authorizationCodeRequestAuthentication.setDetails(
|
|
|
|
|
+ OAuth2AuthorizationEndpointFilter.this.authenticationDetailsSource.buildDetails(request));
|
|
|
|
|
+
|
|
|
|
|
+ RegisteredClient registeredClient = this.registeredClientRepository
|
|
|
|
|
+ .findByClientId(authorizationCodeRequestAuthentication.getClientId());
|
|
|
|
|
+ if (registeredClient == null) {
|
|
|
|
|
+ String redirectUri = null; // Prevent redirect
|
|
|
|
|
+ OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken(
|
|
|
|
|
+ authorizationCodeRequestAuthentication.getAuthorizationUri(),
|
|
|
|
|
+ authorizationCodeRequestAuthentication.getClientId(),
|
|
|
|
|
+ (Authentication) authorizationCodeRequestAuthentication.getPrincipal(), redirectUri,
|
|
|
|
|
+ authorizationCodeRequestAuthentication.getState(),
|
|
|
|
|
+ authorizationCodeRequestAuthentication.getScopes(),
|
|
|
|
|
+ authorizationCodeRequestAuthentication.getAdditionalParameters());
|
|
|
|
|
+
|
|
|
|
|
+ OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST,
|
|
|
|
|
+ "OAuth 2.0 Parameter: " + OAuth2ParameterNames.CLIENT_ID,
|
|
|
|
|
+ "https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1");
|
|
|
|
|
+ throw new OAuth2AuthorizationCodeRequestAuthenticationException(error,
|
|
|
|
|
+ authorizationCodeRequestAuthenticationResult);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ OAuth2AuthorizationCodeRequestAuthenticationContext authenticationContext = OAuth2AuthorizationCodeRequestAuthenticationContext
|
|
|
|
|
+ .with(authorizationCodeRequestAuthentication)
|
|
|
|
|
+ .registeredClient(registeredClient)
|
|
|
|
|
+ .build();
|
|
|
|
|
+
|
|
|
|
|
+ this.authenticationValidator.accept(authenticationContext);
|
|
|
|
|
+
|
|
|
|
|
+ ReflectionUtils.setField(this.setValidatedField, authorizationCodeRequestAuthentication, true);
|
|
|
|
|
+
|
|
|
|
|
+ // Set the validated authorization code request as a request
|
|
|
|
|
+ // attribute
|
|
|
|
|
+ // to be used upstream by OAuth2AuthorizationEndpointFilter
|
|
|
|
|
+ request.setAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName(),
|
|
|
|
|
+ authorizationCodeRequestAuthentication);
|
|
|
|
|
+
|
|
|
|
|
+ filterChain.doFilter(request, response);
|
|
|
|
|
+ }
|
|
|
|
|
+ catch (OAuth2AuthenticationException ex) {
|
|
|
|
|
+ if (this.logger.isTraceEnabled()) {
|
|
|
|
|
+ this.logger.trace(LogMessage.format("Authorization request failed: %s", ex.getError()), ex);
|
|
|
|
|
+ }
|
|
|
|
|
+ OAuth2AuthorizationEndpointFilter.this.authenticationFailureHandler.onAuthenticationFailure(request,
|
|
|
|
|
+ response, ex);
|
|
|
|
|
+ }
|
|
|
|
|
+ finally {
|
|
|
|
|
+ request.removeAttribute(OAuth2AuthorizationCodeRequestAuthenticationToken.class.getName());
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
}
|
|
}
|