|
@@ -83,30 +83,41 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
|
|
|
public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code";
|
|
|
private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
|
|
|
private final AuthorizationResponseConverter authorizationResponseConverter = new AuthorizationResponseConverter();
|
|
|
+ private final ClientRegistrationIdentifierStrategy<String> providerIdentifierStrategy = new ProviderIdentifierStrategy();
|
|
|
+ private RequestMatcher authorizationResponseMatcher;
|
|
|
private ClientRegistrationRepository clientRegistrationRepository;
|
|
|
- private RequestMatcher authorizationResponseMatcher = new AuthorizationResponseMatcher();
|
|
|
private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
|
|
|
- private final ClientRegistrationIdentifierStrategy<String> providerIdentifierStrategy = new ProviderIdentifierStrategy();
|
|
|
|
|
|
public AuthorizationCodeAuthenticationFilter() {
|
|
|
- super(new AuthorizationResponseMatcher());
|
|
|
+ this(DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI);
|
|
|
+ }
|
|
|
+
|
|
|
+ public AuthorizationCodeAuthenticationFilter(String authorizationResponseBaseUri) {
|
|
|
+ super(new AuthorizationResponseMatcher(authorizationResponseBaseUri));
|
|
|
+ this.authorizationResponseMatcher = new AuthorizationResponseMatcher(authorizationResponseBaseUri);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void afterPropertiesSet() {
|
|
|
+ super.afterPropertiesSet();
|
|
|
+ Assert.notNull(this.clientRegistrationRepository, "clientRegistrationRepository cannot be null");
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
|
|
|
throws AuthenticationException, IOException, ServletException {
|
|
|
|
|
|
- AuthorizationRequest authorizationRequest = this.getAuthorizationRequestRepository().loadAuthorizationRequest(request);
|
|
|
+ AuthorizationRequest authorizationRequest = this.authorizationRequestRepository.loadAuthorizationRequest(request);
|
|
|
if (authorizationRequest == null) {
|
|
|
OAuth2Error oauth2Error = new OAuth2Error(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE);
|
|
|
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
|
|
|
}
|
|
|
- this.getAuthorizationRequestRepository().removeAuthorizationRequest(request);
|
|
|
+ this.authorizationRequestRepository.removeAuthorizationRequest(request);
|
|
|
|
|
|
AuthorizationResponse authorizationResponse = this.authorizationResponseConverter.apply(request);
|
|
|
|
|
|
String registrationId = (String)authorizationRequest.getAdditionalParameters().get(OAuth2Parameter.REGISTRATION_ID);
|
|
|
- ClientRegistration clientRegistration = this.getClientRegistrationRepository().findByRegistrationId(registrationId);
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
|
|
|
|
|
|
// The clientRegistration.redirectUri may contain Uri template variables, whether it's configured by
|
|
|
// the user or configured by default. In these cases, the redirectUri will be expanded and ultimately changed
|
|
@@ -142,18 +153,14 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
|
|
|
return oauth2UserAuthentication;
|
|
|
}
|
|
|
|
|
|
- public RequestMatcher getAuthorizationResponseMatcher() {
|
|
|
+ public final RequestMatcher getAuthorizationResponseMatcher() {
|
|
|
return this.authorizationResponseMatcher;
|
|
|
}
|
|
|
|
|
|
- public final <T extends RequestMatcher> void setAuthorizationResponseMatcher(T authorizationResponseMatcher) {
|
|
|
- Assert.notNull(authorizationResponseMatcher, "authorizationResponseMatcher cannot be null");
|
|
|
- this.authorizationResponseMatcher = authorizationResponseMatcher;
|
|
|
- this.setRequiresAuthenticationRequestMatcher(authorizationResponseMatcher);
|
|
|
- }
|
|
|
-
|
|
|
- protected ClientRegistrationRepository getClientRegistrationRepository() {
|
|
|
- return this.clientRegistrationRepository;
|
|
|
+ public final void setAuthorizationResponseBaseUri(String authorizationResponseBaseUri) {
|
|
|
+ Assert.hasText(authorizationResponseBaseUri, "authorizationResponseBaseUri cannot be empty");
|
|
|
+ this.authorizationResponseMatcher = new AuthorizationResponseMatcher(authorizationResponseBaseUri);
|
|
|
+ this.setRequiresAuthenticationRequestMatcher(this.authorizationResponseMatcher);
|
|
|
}
|
|
|
|
|
|
public final void setClientRegistrationRepository(ClientRegistrationRepository clientRegistrationRepository) {
|
|
@@ -161,10 +168,6 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
|
|
|
this.clientRegistrationRepository = clientRegistrationRepository;
|
|
|
}
|
|
|
|
|
|
- protected AuthorizationRequestRepository getAuthorizationRequestRepository() {
|
|
|
- return this.authorizationRequestRepository;
|
|
|
- }
|
|
|
-
|
|
|
public final void setAuthorizationRequestRepository(AuthorizationRequestRepository authorizationRequestRepository) {
|
|
|
Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null");
|
|
|
this.authorizationRequestRepository = authorizationRequestRepository;
|
|
@@ -215,10 +218,17 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
|
|
|
}
|
|
|
|
|
|
private static class AuthorizationResponseMatcher implements RequestMatcher {
|
|
|
+ private final String baseUri;
|
|
|
+
|
|
|
+ private AuthorizationResponseMatcher(String baseUri) {
|
|
|
+ Assert.hasText(baseUri, "baseUri cannot be empty");
|
|
|
+ this.baseUri = baseUri;
|
|
|
+ }
|
|
|
|
|
|
@Override
|
|
|
public boolean matches(HttpServletRequest request) {
|
|
|
- return this.successResponse(request) || this.errorResponse(request);
|
|
|
+ return request.getRequestURI().startsWith(this.baseUri) &&
|
|
|
+ (this.successResponse(request) || this.errorResponse(request));
|
|
|
}
|
|
|
|
|
|
private boolean successResponse(HttpServletRequest request) {
|