|
@@ -1,5 +1,5 @@
|
|
|
/*
|
|
|
- * Copyright 2020-2024 the original author or authors.
|
|
|
+ * Copyright 2020-2025 the original author or authors.
|
|
|
*
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -34,6 +34,7 @@ import org.springframework.security.core.Authentication;
|
|
|
import org.springframework.security.core.AuthenticationException;
|
|
|
import org.springframework.security.core.context.SecurityContext;
|
|
|
import org.springframework.security.core.context.SecurityContextHolder;
|
|
|
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
|
|
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
|
import org.springframework.security.oauth2.core.OAuth2Error;
|
|
|
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
|
@@ -53,6 +54,7 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
|
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
|
|
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
|
|
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
|
|
|
+import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint;
|
|
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
|
|
import org.springframework.util.Assert;
|
|
|
import org.springframework.web.filter.OncePerRequestFilter;
|
|
@@ -90,6 +92,8 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
|
|
|
private final AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
|
|
|
|
|
|
+ private final BasicAuthenticationEntryPoint basicAuthenticationEntryPoint = new BasicAuthenticationEntryPoint();
|
|
|
+
|
|
|
private AuthenticationConverter authenticationConverter;
|
|
|
|
|
|
private AuthenticationSuccessHandler authenticationSuccessHandler = this::onAuthenticationSuccess;
|
|
@@ -110,6 +114,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
Assert.notNull(requestMatcher, "requestMatcher cannot be null");
|
|
|
this.authenticationManager = authenticationManager;
|
|
|
this.requestMatcher = requestMatcher;
|
|
|
+ this.basicAuthenticationEntryPoint.setRealmName("default");
|
|
|
// @formatter:off
|
|
|
this.authenticationConverter = new DelegatingAuthenticationConverter(
|
|
|
Arrays.asList(
|
|
@@ -130,8 +135,9 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
+ Authentication authenticationRequest = null;
|
|
|
try {
|
|
|
- Authentication authenticationRequest = this.authenticationConverter.convert(request);
|
|
|
+ authenticationRequest = this.authenticationConverter.convert(request);
|
|
|
if (authenticationRequest instanceof AbstractAuthenticationToken) {
|
|
|
((AbstractAuthenticationToken) authenticationRequest)
|
|
|
.setDetails(this.authenticationDetailsSource.buildDetails(request));
|
|
@@ -148,7 +154,14 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
if (this.logger.isTraceEnabled()) {
|
|
|
this.logger.trace(LogMessage.format("Client authentication failed: %s", ex.getError()), ex);
|
|
|
}
|
|
|
- this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
|
|
|
+ if (authenticationRequest instanceof OAuth2ClientAuthenticationToken clientAuthentication) {
|
|
|
+ this.authenticationFailureHandler.onAuthenticationFailure(request, response,
|
|
|
+ new OAuth2ClientAuthenticationException(ex.getError(), ex, clientAuthentication));
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
|
|
|
+ }
|
|
|
+
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -200,21 +213,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
}
|
|
|
|
|
|
private void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response,
|
|
|
- AuthenticationException exception) throws IOException {
|
|
|
+ AuthenticationException authenticationException) throws IOException {
|
|
|
|
|
|
SecurityContextHolder.clearContext();
|
|
|
|
|
|
- // TODO
|
|
|
- // The authorization server MAY return an HTTP 401 (Unauthorized) status code
|
|
|
- // to indicate which HTTP authentication schemes are supported.
|
|
|
- // If the client attempted to authenticate via the "Authorization" request header
|
|
|
- // field,
|
|
|
- // the authorization server MUST respond with an HTTP 401 (Unauthorized) status
|
|
|
- // code and
|
|
|
- // include the "WWW-Authenticate" response header field
|
|
|
- // matching the authentication scheme used by the client.
|
|
|
-
|
|
|
- OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
|
|
|
+ if (authenticationException instanceof OAuth2ClientAuthenticationException clientAuthenticationException) {
|
|
|
+ OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationException
|
|
|
+ .getClientAuthentication();
|
|
|
+ if (ClientAuthenticationMethod.CLIENT_SECRET_BASIC
|
|
|
+ .equals(clientAuthentication.getClientAuthenticationMethod())) {
|
|
|
+ this.basicAuthenticationEntryPoint.commence(request, response, authenticationException);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
|
|
|
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
|
|
|
if (OAuth2ErrorCodes.INVALID_CLIENT.equals(error.getErrorCode())) {
|
|
|
httpResponse.setStatusCode(HttpStatus.UNAUTHORIZED);
|
|
@@ -249,4 +262,21 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private static final class OAuth2ClientAuthenticationException extends OAuth2AuthenticationException {
|
|
|
+
|
|
|
+ private final OAuth2ClientAuthenticationToken clientAuthentication;
|
|
|
+
|
|
|
+ private OAuth2ClientAuthenticationException(OAuth2Error error, Throwable cause,
|
|
|
+ OAuth2ClientAuthenticationToken clientAuthentication) {
|
|
|
+ super(error, cause);
|
|
|
+ Assert.notNull(clientAuthentication, "clientAuthentication cannot be null");
|
|
|
+ this.clientAuthentication = clientAuthentication;
|
|
|
+ }
|
|
|
+
|
|
|
+ private OAuth2ClientAuthenticationToken getClientAuthentication() {
|
|
|
+ return this.clientAuthentication;
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
}
|