|
@@ -30,8 +30,11 @@ import org.springframework.http.MediaType;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
import org.springframework.security.core.context.SecurityContextHolder;
|
|
|
import org.springframework.security.core.context.SecurityContextHolderStrategy;
|
|
|
+import org.springframework.security.saml2.core.Saml2Error;
|
|
|
+import org.springframework.security.saml2.core.Saml2ErrorCodes;
|
|
|
import org.springframework.security.saml2.core.Saml2ParameterNames;
|
|
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
|
|
|
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
|
|
|
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
|
|
|
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidator;
|
|
|
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters;
|
|
@@ -39,6 +42,8 @@ import org.springframework.security.saml2.provider.service.authentication.logout
|
|
|
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutValidatorResult;
|
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
|
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
|
|
|
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers;
|
|
|
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
|
|
|
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
|
|
|
import org.springframework.security.web.DefaultRedirectStrategy;
|
|
|
import org.springframework.security.web.RedirectStrategy;
|
|
@@ -67,9 +72,9 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
|
|
|
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
|
|
|
.getContextHolderStrategy();
|
|
|
|
|
|
- private final Saml2LogoutRequestValidator logoutRequestValidator;
|
|
|
+ private final Saml2LogoutRequestValidatorParametersResolver logoutRequestResolver;
|
|
|
|
|
|
- private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
|
|
|
+ private final Saml2LogoutRequestValidator logoutRequestValidator;
|
|
|
|
|
|
private final Saml2LogoutResponseResolver logoutResponseResolver;
|
|
|
|
|
@@ -77,7 +82,14 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
|
|
|
|
|
|
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
|
|
|
|
|
|
- private RequestMatcher logoutRequestMatcher = new AntPathRequestMatcher("/logout/saml2/slo");
|
|
|
+ public Saml2LogoutRequestFilter(Saml2LogoutRequestValidatorParametersResolver logoutRequestResolver,
|
|
|
+ Saml2LogoutRequestValidator logoutRequestValidator, Saml2LogoutResponseResolver logoutResponseResolver,
|
|
|
+ LogoutHandler... handlers) {
|
|
|
+ this.logoutRequestResolver = logoutRequestResolver;
|
|
|
+ this.logoutRequestValidator = logoutRequestValidator;
|
|
|
+ this.logoutResponseResolver = logoutResponseResolver;
|
|
|
+ this.handler = new CompositeLogoutHandler(handlers);
|
|
|
+ }
|
|
|
|
|
|
/**
|
|
|
* Constructs a {@link Saml2LogoutResponseFilter} for accepting SAML 2.0 Logout
|
|
@@ -91,7 +103,7 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
|
|
|
public Saml2LogoutRequestFilter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver,
|
|
|
Saml2LogoutRequestValidator logoutRequestValidator, Saml2LogoutResponseResolver logoutResponseResolver,
|
|
|
LogoutHandler... handlers) {
|
|
|
- this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
|
|
|
+ this.logoutRequestResolver = new Saml2AssertingPartyLogoutRequestResolver(relyingPartyRegistrationResolver);
|
|
|
this.logoutRequestValidator = logoutRequestValidator;
|
|
|
this.logoutResponseResolver = logoutResponseResolver;
|
|
|
this.handler = new CompositeLogoutHandler(handlers);
|
|
@@ -100,26 +112,21 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
|
|
|
@Override
|
|
|
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
|
|
|
throws ServletException, IOException {
|
|
|
-
|
|
|
- if (!this.logoutRequestMatcher.matches(request)) {
|
|
|
- chain.doFilter(request, response);
|
|
|
- return;
|
|
|
+ Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
|
|
|
+ Saml2LogoutRequestValidatorParameters parameters;
|
|
|
+ try {
|
|
|
+ parameters = this.logoutRequestResolver.resolve(request, authentication);
|
|
|
}
|
|
|
-
|
|
|
- if (request.getParameter(Saml2ParameterNames.SAML_REQUEST) == null) {
|
|
|
- chain.doFilter(request, response);
|
|
|
+ catch (Saml2AuthenticationException ex) {
|
|
|
+ this.logger.trace("Did not process logout request since failed to find requested RelyingPartyRegistration");
|
|
|
+ response.sendError(HttpServletResponse.SC_BAD_REQUEST);
|
|
|
return;
|
|
|
}
|
|
|
-
|
|
|
- Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
|
|
|
- RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request,
|
|
|
- getRegistrationId(authentication));
|
|
|
- if (registration == null) {
|
|
|
- this.logger
|
|
|
- .trace("Did not process logout request since failed to find associated RelyingPartyRegistration");
|
|
|
- response.sendError(HttpServletResponse.SC_BAD_REQUEST);
|
|
|
+ if (parameters == null) {
|
|
|
+ chain.doFilter(request, response);
|
|
|
return;
|
|
|
}
|
|
|
+ RelyingPartyRegistration registration = parameters.getRelyingPartyRegistration();
|
|
|
if (registration.getSingleLogoutServiceLocation() == null) {
|
|
|
this.logger.trace(
|
|
|
"Did not process logout request since RelyingPartyRegistration has not been configured with a logout request endpoint");
|
|
@@ -134,17 +141,6 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST);
|
|
|
- Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
|
|
|
- .samlRequest(serialized).relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE))
|
|
|
- .binding(saml2MessageBinding).location(registration.getSingleLogoutServiceLocation())
|
|
|
- .parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG,
|
|
|
- request.getParameter(Saml2ParameterNames.SIG_ALG)))
|
|
|
- .parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE,
|
|
|
- request.getParameter(Saml2ParameterNames.SIGNATURE)))
|
|
|
- .parametersQuery((params) -> request.getQueryString()).build();
|
|
|
- Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(logoutRequest,
|
|
|
- registration, authentication);
|
|
|
Saml2LogoutValidatorResult result = this.logoutRequestValidator.validate(parameters);
|
|
|
if (result.hasErrors()) {
|
|
|
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, result.getErrors().iterator().next().toString());
|
|
@@ -168,7 +164,10 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
|
|
|
|
|
|
public void setLogoutRequestMatcher(RequestMatcher logoutRequestMatcher) {
|
|
|
Assert.notNull(logoutRequestMatcher, "logoutRequestMatcher cannot be null");
|
|
|
- this.logoutRequestMatcher = logoutRequestMatcher;
|
|
|
+ Assert.isInstanceOf(Saml2AssertingPartyLogoutRequestResolver.class, this.logoutRequestResolver,
|
|
|
+ "saml2LogoutRequestResolver and logoutRequestMatcher cannot both be set. Please set the request matcher in the saml2LogoutRequestResolver itself.");
|
|
|
+ ((Saml2AssertingPartyLogoutRequestResolver) this.logoutRequestResolver)
|
|
|
+ .setLogoutRequestMatcher(logoutRequestMatcher);
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -182,17 +181,6 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
|
|
|
this.securityContextHolderStrategy = securityContextHolderStrategy;
|
|
|
}
|
|
|
|
|
|
- private String getRegistrationId(Authentication authentication) {
|
|
|
- if (authentication == null) {
|
|
|
- return null;
|
|
|
- }
|
|
|
- Object principal = authentication.getPrincipal();
|
|
|
- if (principal instanceof Saml2AuthenticatedPrincipal) {
|
|
|
- return ((Saml2AuthenticatedPrincipal) principal).getRelyingPartyRegistrationId();
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
-
|
|
|
private void doRedirect(HttpServletRequest request, HttpServletResponse response,
|
|
|
Saml2LogoutResponse logoutResponse) throws IOException {
|
|
|
String location = logoutResponse.getResponseLocation();
|
|
@@ -252,4 +240,73 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
|
|
|
return html.toString();
|
|
|
}
|
|
|
|
|
|
+ private static class Saml2AssertingPartyLogoutRequestResolver
|
|
|
+ implements Saml2LogoutRequestValidatorParametersResolver {
|
|
|
+
|
|
|
+ private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
|
|
|
+
|
|
|
+ private RequestMatcher logoutRequestMatcher = new AntPathRequestMatcher("/logout/saml2/slo");
|
|
|
+
|
|
|
+ Saml2AssertingPartyLogoutRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
|
|
|
+ this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request,
|
|
|
+ Authentication authentication) {
|
|
|
+ String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST);
|
|
|
+ if (serialized == null) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ RequestMatcher.MatchResult result = this.logoutRequestMatcher.matcher(request);
|
|
|
+ if (!result.isMatch()) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ String registrationId = getRegistrationId(result, authentication);
|
|
|
+ RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request,
|
|
|
+ registrationId);
|
|
|
+ if (registration == null) {
|
|
|
+ throw new Saml2AuthenticationException(
|
|
|
+ new Saml2Error(Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND, "registration not found"),
|
|
|
+ "registration not found");
|
|
|
+ }
|
|
|
+ UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
|
|
|
+ String entityId = uriResolver.resolve(registration.getEntityId());
|
|
|
+ String logoutLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
|
|
|
+ String logoutResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
|
|
|
+ registration = registration.mutate().entityId(entityId).singleLogoutServiceLocation(logoutLocation)
|
|
|
+ .singleLogoutServiceResponseLocation(logoutResponseLocation).build();
|
|
|
+ Saml2MessageBinding saml2MessageBinding = Saml2MessageBindingUtils.resolveBinding(request);
|
|
|
+ Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
|
|
|
+ .samlRequest(serialized).relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE))
|
|
|
+ .binding(saml2MessageBinding).location(registration.getSingleLogoutServiceLocation())
|
|
|
+ .parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG,
|
|
|
+ request.getParameter(Saml2ParameterNames.SIG_ALG)))
|
|
|
+ .parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE,
|
|
|
+ request.getParameter(Saml2ParameterNames.SIGNATURE)))
|
|
|
+ .parametersQuery((params) -> request.getQueryString()).build();
|
|
|
+ return new Saml2LogoutRequestValidatorParameters(logoutRequest, registration, authentication);
|
|
|
+ }
|
|
|
+
|
|
|
+ void setLogoutRequestMatcher(RequestMatcher logoutRequestMatcher) {
|
|
|
+ Assert.notNull(logoutRequestMatcher, "logoutRequestMatcher cannot be null");
|
|
|
+ this.logoutRequestMatcher = logoutRequestMatcher;
|
|
|
+ }
|
|
|
+
|
|
|
+ private String getRegistrationId(RequestMatcher.MatchResult result, Authentication authentication) {
|
|
|
+ String registrationId = result.getVariables().get("registrationId");
|
|
|
+ if (registrationId != null) {
|
|
|
+ return registrationId;
|
|
|
+ }
|
|
|
+ if (authentication == null) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ if (authentication.getPrincipal() instanceof Saml2AuthenticatedPrincipal principal) {
|
|
|
+ return principal.getRelyingPartyRegistrationId();
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
}
|