|
@@ -24,25 +24,29 @@ import jakarta.servlet.http.HttpServletRequest;
|
|
|
import jakarta.servlet.http.HttpServletResponse;
|
|
|
|
|
|
import org.springframework.core.log.LogMessage;
|
|
|
+import org.springframework.http.HttpStatus;
|
|
|
+import org.springframework.http.converter.HttpMessageConverter;
|
|
|
+import org.springframework.http.server.ServletServerHttpResponse;
|
|
|
import org.springframework.security.authentication.AbstractAuthenticationToken;
|
|
|
import org.springframework.security.authentication.AuthenticationDetailsSource;
|
|
|
import org.springframework.security.authentication.AuthenticationManager;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
+import org.springframework.security.core.AuthenticationException;
|
|
|
import org.springframework.security.core.context.SecurityContext;
|
|
|
import org.springframework.security.core.context.SecurityContextHolder;
|
|
|
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
|
|
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
|
import org.springframework.security.oauth2.core.OAuth2Error;
|
|
|
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
|
|
+import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.ClientSecretAuthenticationProvider;
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.JwtClientAssertionAuthenticationProvider;
|
|
|
-import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationException;
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.PublicClientAuthenticationProvider;
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.X509ClientCertificateAuthenticationProvider;
|
|
|
import org.springframework.security.oauth2.server.authorization.web.authentication.ClientSecretBasicAuthenticationConverter;
|
|
|
import org.springframework.security.oauth2.server.authorization.web.authentication.ClientSecretPostAuthenticationConverter;
|
|
|
import org.springframework.security.oauth2.server.authorization.web.authentication.JwtClientAssertionAuthenticationConverter;
|
|
|
-import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ClientAuthenticationFailureHandler;
|
|
|
import org.springframework.security.oauth2.server.authorization.web.authentication.PublicClientAuthenticationConverter;
|
|
|
import org.springframework.security.oauth2.server.authorization.web.authentication.X509ClientCertificateAuthenticationConverter;
|
|
|
import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
@@ -50,6 +54,7 @@ import org.springframework.security.web.authentication.AuthenticationFailureHand
|
|
|
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
|
|
import org.springframework.security.web.authentication.DelegatingAuthenticationConverter;
|
|
|
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
|
|
|
+import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint;
|
|
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
|
|
import org.springframework.util.Assert;
|
|
|
import org.springframework.web.filter.OncePerRequestFilter;
|
|
@@ -70,7 +75,6 @@ import org.springframework.web.filter.OncePerRequestFilter;
|
|
|
* @see ClientSecretAuthenticationProvider
|
|
|
* @see PublicClientAuthenticationConverter
|
|
|
* @see PublicClientAuthenticationProvider
|
|
|
- * @see OAuth2ClientAuthenticationFailureHandler
|
|
|
* @see <a target="_blank" href=
|
|
|
* "https://datatracker.ietf.org/doc/html/rfc6749#section-2.3">Section 2.3 Client
|
|
|
* Authentication</a>
|
|
@@ -84,13 +88,17 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
|
|
|
private final RequestMatcher requestMatcher;
|
|
|
|
|
|
+ private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter();
|
|
|
+
|
|
|
private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
|
|
|
|
|
|
+ private final BasicAuthenticationEntryPoint basicAuthenticationEntryPoint = new BasicAuthenticationEntryPoint();
|
|
|
+
|
|
|
private AuthenticationConverter authenticationConverter;
|
|
|
|
|
|
private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess;
|
|
|
|
|
|
- private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ClientAuthenticationFailureHandler();
|
|
|
+ private AuthenticationFailureHandler authenticationFailureHandler = this::onAuthenticationFailure;
|
|
|
|
|
|
/**
|
|
|
* Constructs an {@code OAuth2ClientAuthenticationFilter} using the provided
|
|
@@ -106,6 +114,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
Assert.notNull(requestMatcher, "requestMatcher cannot be null");
|
|
|
this.authenticationManager = authenticationManager;
|
|
|
this.requestMatcher = requestMatcher;
|
|
|
+ this.basicAuthenticationEntryPoint.setRealmName("default");
|
|
|
// @formatter:off
|
|
|
this.authenticationConverter = new DelegatingAuthenticationConverter(
|
|
|
Arrays.asList(
|
|
@@ -129,16 +138,16 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
Authentication authenticationRequest = null;
|
|
|
try {
|
|
|
authenticationRequest = this.authenticationConverter.convert(request);
|
|
|
- if (authenticationRequest == null) {
|
|
|
- throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
|
|
|
- }
|
|
|
if (authenticationRequest instanceof AbstractAuthenticationToken authenticationToken) {
|
|
|
authenticationToken.setDetails(this.authenticationDetailsSource.buildDetails(request));
|
|
|
}
|
|
|
- validateClientIdentifier(authenticationRequest);
|
|
|
- Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest);
|
|
|
- this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult);
|
|
|
+ if (authenticationRequest != null) {
|
|
|
+ validateClientIdentifier(authenticationRequest);
|
|
|
+ Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest);
|
|
|
+ this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult);
|
|
|
+ }
|
|
|
filterChain.doFilter(request, response);
|
|
|
+
|
|
|
}
|
|
|
catch (OAuth2AuthenticationException ex) {
|
|
|
if (this.logger.isTraceEnabled()) {
|
|
@@ -151,8 +160,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
else {
|
|
|
this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
|
|
|
}
|
|
|
- }
|
|
|
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -202,6 +211,35 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
|
|
|
+ AuthenticationException authenticationException) throws IOException {
|
|
|
+
|
|
|
+ SecurityContextHolder.clearContext();
|
|
|
+
|
|
|
+ if (authenticationException instanceof OAuth2ClientAuthenticationException clientAuthenticationException) {
|
|
|
+ OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationException
|
|
|
+ .getClientAuthentication();
|
|
|
+ if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC
|
|
|
+ .equals(clientAuthentication.getClientAuthenticationMethod())) {
|
|
|
+ this.basicAuthenticationEntryPoint.commence(request, response, authenticationException);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
|
|
|
+ ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
|
|
|
+ if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
|
|
|
+ httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
|
|
|
+ }
|
|
|
+ // We don't want to reveal too much information to the caller so just return the
|
|
|
+ // error code
|
|
|
+ OAuth2Error errorResponse = new OAuth2Error(error.getErrorCode());
|
|
|
+ this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
|
|
|
+ }
|
|
|
+
|
|
|
private static void validateClientIdentifier(Authentication authentication) {
|
|
|
if (!(authentication instanceof OAuth2ClientAuthenticationToken)) {
|
|
|
return;
|
|
@@ -223,4 +261,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private static final class OAuth2ClientAuthenticationException extends OAuth2AuthenticationException {
|
|
|
+
|
|
|
+ private final OAuth2ClientAuthenticationToken clientAuthentication;
|
|
|
+
|
|
|
+ private OAuth2ClientAuthenticationException(OAuth2Error error, Throwable cause,
|
|
|
+ OAuth2ClientAuthenticationToken clientAuthentication) {
|
|
|
+ super(error, cause);
|
|
|
+ Assert.notNull(clientAuthentication, "clientAuthentication cannot be null");
|
|
|
+ this.clientAuthentication = clientAuthentication;
|
|
|
+ }
|
|
|
+
|
|
|
+ private OAuth2ClientAuthenticationToken getClientAuthentication() {
|
|
|
+ return this.clientAuthentication;
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
}
|