|
@@ -19,23 +19,19 @@ package org.springframework.security.saml2.provider.service.servlet.filter;
|
|
|
import javax.servlet.http.HttpServletRequest;
|
|
|
import javax.servlet.http.HttpServletResponse;
|
|
|
|
|
|
-import org.springframework.http.HttpMethod;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
import org.springframework.security.core.AuthenticationException;
|
|
|
-import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
|
|
|
-import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
|
|
import org.springframework.security.saml2.core.Saml2Error;
|
|
|
-import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
|
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
|
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
|
|
+import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
|
|
|
+import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
|
|
|
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
|
|
|
+import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
|
import org.springframework.security.web.authentication.session.ChangeSessionIdAuthenticationStrategy;
|
|
|
-import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
|
|
-import org.springframework.security.web.util.matcher.RequestMatcher;
|
|
|
import org.springframework.util.Assert;
|
|
|
|
|
|
-import static java.nio.charset.StandardCharsets.UTF_8;
|
|
|
import static org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND;
|
|
|
-import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
|
|
|
import static org.springframework.util.StringUtils.hasText;
|
|
|
|
|
|
/**
|
|
@@ -44,8 +40,7 @@ import static org.springframework.util.StringUtils.hasText;
|
|
|
public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
|
|
|
|
|
|
public static final String DEFAULT_FILTER_PROCESSES_URI = "/login/saml2/sso/{registrationId}";
|
|
|
- private final RequestMatcher matcher;
|
|
|
- private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
|
|
|
+ private final AuthenticationConverter authenticationConverter;
|
|
|
|
|
|
/**
|
|
|
* Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is configured
|
|
@@ -64,16 +59,30 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
|
|
|
public Saml2WebSsoAuthenticationFilter(
|
|
|
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
|
|
|
String filterProcessesUrl) {
|
|
|
- super(filterProcessesUrl);
|
|
|
- Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
|
|
|
- Assert.hasText(filterProcessesUrl, "filterProcessesUrl must contain a URL pattern");
|
|
|
+ this(new Saml2AuthenticationTokenConverter
|
|
|
+ (new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)),
|
|
|
+ filterProcessesUrl);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Creates a {@link Saml2WebSsoAuthenticationFilter} given the provided parameters
|
|
|
+ *
|
|
|
+ * @param authenticationConverter the strategy for converting an {@link HttpServletRequest}
|
|
|
+ * into an {@link Authentication}
|
|
|
+ * @param filterProcessingUrl the processing URL, must contain a {registrationId} variable
|
|
|
+ * @since 5.4
|
|
|
+ */
|
|
|
+ public Saml2WebSsoAuthenticationFilter(
|
|
|
+ AuthenticationConverter authenticationConverter,
|
|
|
+ String filterProcessingUrl) {
|
|
|
+ super(filterProcessingUrl);
|
|
|
+ Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
|
|
|
+ Assert.hasText(filterProcessingUrl, "filterProcessesUrl must contain a URL pattern");
|
|
|
Assert.isTrue(
|
|
|
- filterProcessesUrl.contains("{registrationId}"),
|
|
|
+ filterProcessingUrl.contains("{registrationId}"),
|
|
|
"filterProcessesUrl must contain a {registrationId} match variable"
|
|
|
);
|
|
|
- this.matcher = new AntPathRequestMatcher(filterProcessesUrl);
|
|
|
- setRequiresAuthenticationRequestMatcher(this.matcher);
|
|
|
- this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
|
|
|
+ this.authenticationConverter = authenticationConverter;
|
|
|
setAllowSessionCreation(true);
|
|
|
setSessionAuthenticationStrategy(new ChangeSessionIdAuthenticationStrategy());
|
|
|
}
|
|
@@ -86,37 +95,12 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
|
|
|
@Override
|
|
|
public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
|
|
|
throws AuthenticationException {
|
|
|
- String saml2Response = request.getParameter("SAMLResponse");
|
|
|
- byte[] b = Saml2Utils.samlDecode(saml2Response);
|
|
|
-
|
|
|
- String responseXml = inflateIfRequired(request, b);
|
|
|
- String registrationId = this.matcher.matcher(request).getVariables().get("registrationId");
|
|
|
- RelyingPartyRegistration rp =
|
|
|
- this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
|
|
|
- if (rp == null) {
|
|
|
+ Authentication authentication = this.authenticationConverter.convert(request);
|
|
|
+ if (authentication == null) {
|
|
|
Saml2Error saml2Error = new Saml2Error(RELYING_PARTY_REGISTRATION_NOT_FOUND,
|
|
|
- "Relying Party Registration not found with ID: " + registrationId);
|
|
|
+ "No relying party registration found");
|
|
|
throw new Saml2AuthenticationException(saml2Error);
|
|
|
}
|
|
|
- String applicationUri = Saml2ServletUtils.getApplicationUri(request);
|
|
|
- String relyingPartyEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp);
|
|
|
- String assertionConsumerServiceLocation = Saml2ServletUtils.resolveUrlTemplate(
|
|
|
- rp.getAssertionConsumerServiceLocation(), applicationUri, rp);
|
|
|
- RelyingPartyRegistration relyingPartyRegistration = withRelyingPartyRegistration(rp)
|
|
|
- .entityId(relyingPartyEntityId)
|
|
|
- .assertionConsumerServiceLocation(assertionConsumerServiceLocation)
|
|
|
- .build();
|
|
|
- Saml2AuthenticationToken authentication = new Saml2AuthenticationToken(
|
|
|
- relyingPartyRegistration, responseXml);
|
|
|
return getAuthenticationManager().authenticate(authentication);
|
|
|
}
|
|
|
-
|
|
|
- private String inflateIfRequired(HttpServletRequest request, byte[] b) {
|
|
|
- if (HttpMethod.GET.matches(request.getMethod())) {
|
|
|
- return Saml2Utils.samlInflate(b);
|
|
|
- }
|
|
|
- else {
|
|
|
- return new String(b, UTF_8);
|
|
|
- }
|
|
|
- }
|
|
|
}
|