|  | @@ -30,8 +30,11 @@ import org.springframework.http.MediaType;
 | 
	
		
			
				|  |  |  import org.springframework.security.core.Authentication;
 | 
	
		
			
				|  |  |  import org.springframework.security.core.context.SecurityContextHolder;
 | 
	
		
			
				|  |  |  import org.springframework.security.core.context.SecurityContextHolderStrategy;
 | 
	
		
			
				|  |  | +import org.springframework.security.saml2.core.Saml2Error;
 | 
	
		
			
				|  |  | +import org.springframework.security.saml2.core.Saml2ErrorCodes;
 | 
	
		
			
				|  |  |  import org.springframework.security.saml2.core.Saml2ParameterNames;
 | 
	
		
			
				|  |  |  import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
 | 
	
		
			
				|  |  | +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 | 
	
		
			
				|  |  |  import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
 | 
	
		
			
				|  |  |  import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidator;
 | 
	
		
			
				|  |  |  import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequestValidatorParameters;
 | 
	
	
		
			
				|  | @@ -39,6 +42,8 @@ import org.springframework.security.saml2.provider.service.authentication.logout
 | 
	
		
			
				|  |  |  import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutValidatorResult;
 | 
	
		
			
				|  |  |  import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 | 
	
		
			
				|  |  |  import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 | 
	
		
			
				|  |  | +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers;
 | 
	
		
			
				|  |  | +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
 | 
	
		
			
				|  |  |  import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
 | 
	
		
			
				|  |  |  import org.springframework.security.web.DefaultRedirectStrategy;
 | 
	
		
			
				|  |  |  import org.springframework.security.web.RedirectStrategy;
 | 
	
	
		
			
				|  | @@ -67,9 +72,9 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 | 
	
		
			
				|  |  |  	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
 | 
	
		
			
				|  |  |  			.getContextHolderStrategy();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	private final Saml2LogoutRequestValidator logoutRequestValidator;
 | 
	
		
			
				|  |  | +	private final Saml2LogoutRequestValidatorParametersResolver logoutRequestResolver;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
 | 
	
		
			
				|  |  | +	private final Saml2LogoutRequestValidator logoutRequestValidator;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	private final Saml2LogoutResponseResolver logoutResponseResolver;
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -77,7 +82,14 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	private RequestMatcher logoutRequestMatcher = new AntPathRequestMatcher("/logout/saml2/slo");
 | 
	
		
			
				|  |  | +	public Saml2LogoutRequestFilter(Saml2LogoutRequestValidatorParametersResolver logoutRequestResolver,
 | 
	
		
			
				|  |  | +			Saml2LogoutRequestValidator logoutRequestValidator, Saml2LogoutResponseResolver logoutResponseResolver,
 | 
	
		
			
				|  |  | +			LogoutHandler... handlers) {
 | 
	
		
			
				|  |  | +		this.logoutRequestResolver = logoutRequestResolver;
 | 
	
		
			
				|  |  | +		this.logoutRequestValidator = logoutRequestValidator;
 | 
	
		
			
				|  |  | +		this.logoutResponseResolver = logoutResponseResolver;
 | 
	
		
			
				|  |  | +		this.handler = new CompositeLogoutHandler(handlers);
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	/**
 | 
	
		
			
				|  |  |  	 * Constructs a {@link Saml2LogoutResponseFilter} for accepting SAML 2.0 Logout
 | 
	
	
		
			
				|  | @@ -91,7 +103,7 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 | 
	
		
			
				|  |  |  	public Saml2LogoutRequestFilter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver,
 | 
	
		
			
				|  |  |  			Saml2LogoutRequestValidator logoutRequestValidator, Saml2LogoutResponseResolver logoutResponseResolver,
 | 
	
		
			
				|  |  |  			LogoutHandler... handlers) {
 | 
	
		
			
				|  |  | -		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
 | 
	
		
			
				|  |  | +		this.logoutRequestResolver = new Saml2AssertingPartyLogoutRequestResolver(relyingPartyRegistrationResolver);
 | 
	
		
			
				|  |  |  		this.logoutRequestValidator = logoutRequestValidator;
 | 
	
		
			
				|  |  |  		this.logoutResponseResolver = logoutResponseResolver;
 | 
	
		
			
				|  |  |  		this.handler = new CompositeLogoutHandler(handlers);
 | 
	
	
		
			
				|  | @@ -100,26 +112,21 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 | 
	
		
			
				|  |  |  	@Override
 | 
	
		
			
				|  |  |  	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
 | 
	
		
			
				|  |  |  			throws ServletException, IOException {
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		if (!this.logoutRequestMatcher.matches(request)) {
 | 
	
		
			
				|  |  | -			chain.doFilter(request, response);
 | 
	
		
			
				|  |  | -			return;
 | 
	
		
			
				|  |  | +		Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
 | 
	
		
			
				|  |  | +		Saml2LogoutRequestValidatorParameters parameters;
 | 
	
		
			
				|  |  | +		try {
 | 
	
		
			
				|  |  | +			parameters = this.logoutRequestResolver.resolve(request, authentication);
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		if (request.getParameter(Saml2ParameterNames.SAML_REQUEST) == null) {
 | 
	
		
			
				|  |  | -			chain.doFilter(request, response);
 | 
	
		
			
				|  |  | +		catch (Saml2AuthenticationException ex) {
 | 
	
		
			
				|  |  | +			this.logger.trace("Did not process logout request since failed to find requested RelyingPartyRegistration");
 | 
	
		
			
				|  |  | +			response.sendError(HttpServletResponse.SC_BAD_REQUEST);
 | 
	
		
			
				|  |  |  			return;
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
 | 
	
		
			
				|  |  | -		RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request,
 | 
	
		
			
				|  |  | -				getRegistrationId(authentication));
 | 
	
		
			
				|  |  | -		if (registration == null) {
 | 
	
		
			
				|  |  | -			this.logger
 | 
	
		
			
				|  |  | -					.trace("Did not process logout request since failed to find associated RelyingPartyRegistration");
 | 
	
		
			
				|  |  | -			response.sendError(HttpServletResponse.SC_BAD_REQUEST);
 | 
	
		
			
				|  |  | +		if (parameters == null) {
 | 
	
		
			
				|  |  | +			chain.doFilter(request, response);
 | 
	
		
			
				|  |  |  			return;
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  | +		RelyingPartyRegistration registration = parameters.getRelyingPartyRegistration();
 | 
	
		
			
				|  |  |  		if (registration.getSingleLogoutServiceLocation() == null) {
 | 
	
		
			
				|  |  |  			this.logger.trace(
 | 
	
		
			
				|  |  |  					"Did not process logout request since RelyingPartyRegistration has not been configured with a logout request endpoint");
 | 
	
	
		
			
				|  | @@ -134,17 +141,6 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 | 
	
		
			
				|  |  |  			return;
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST);
 | 
	
		
			
				|  |  | -		Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
 | 
	
		
			
				|  |  | -				.samlRequest(serialized).relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE))
 | 
	
		
			
				|  |  | -				.binding(saml2MessageBinding).location(registration.getSingleLogoutServiceLocation())
 | 
	
		
			
				|  |  | -				.parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG,
 | 
	
		
			
				|  |  | -						request.getParameter(Saml2ParameterNames.SIG_ALG)))
 | 
	
		
			
				|  |  | -				.parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE,
 | 
	
		
			
				|  |  | -						request.getParameter(Saml2ParameterNames.SIGNATURE)))
 | 
	
		
			
				|  |  | -				.parametersQuery((params) -> request.getQueryString()).build();
 | 
	
		
			
				|  |  | -		Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(logoutRequest,
 | 
	
		
			
				|  |  | -				registration, authentication);
 | 
	
		
			
				|  |  |  		Saml2LogoutValidatorResult result = this.logoutRequestValidator.validate(parameters);
 | 
	
		
			
				|  |  |  		if (result.hasErrors()) {
 | 
	
		
			
				|  |  |  			response.sendError(HttpServletResponse.SC_UNAUTHORIZED, result.getErrors().iterator().next().toString());
 | 
	
	
		
			
				|  | @@ -168,7 +164,10 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	public void setLogoutRequestMatcher(RequestMatcher logoutRequestMatcher) {
 | 
	
		
			
				|  |  |  		Assert.notNull(logoutRequestMatcher, "logoutRequestMatcher cannot be null");
 | 
	
		
			
				|  |  | -		this.logoutRequestMatcher = logoutRequestMatcher;
 | 
	
		
			
				|  |  | +		Assert.isInstanceOf(Saml2AssertingPartyLogoutRequestResolver.class, this.logoutRequestResolver,
 | 
	
		
			
				|  |  | +				"saml2LogoutRequestResolver and logoutRequestMatcher cannot both be set. Please set the request matcher in the saml2LogoutRequestResolver itself.");
 | 
	
		
			
				|  |  | +		((Saml2AssertingPartyLogoutRequestResolver) this.logoutRequestResolver)
 | 
	
		
			
				|  |  | +				.setLogoutRequestMatcher(logoutRequestMatcher);
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	/**
 | 
	
	
		
			
				|  | @@ -182,17 +181,6 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 | 
	
		
			
				|  |  |  		this.securityContextHolderStrategy = securityContextHolderStrategy;
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	private String getRegistrationId(Authentication authentication) {
 | 
	
		
			
				|  |  | -		if (authentication == null) {
 | 
	
		
			
				|  |  | -			return null;
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | -		Object principal = authentication.getPrincipal();
 | 
	
		
			
				|  |  | -		if (principal instanceof Saml2AuthenticatedPrincipal) {
 | 
	
		
			
				|  |  | -			return ((Saml2AuthenticatedPrincipal) principal).getRelyingPartyRegistrationId();
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | -		return null;
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  	private void doRedirect(HttpServletRequest request, HttpServletResponse response,
 | 
	
		
			
				|  |  |  			Saml2LogoutResponse logoutResponse) throws IOException {
 | 
	
		
			
				|  |  |  		String location = logoutResponse.getResponseLocation();
 | 
	
	
		
			
				|  | @@ -252,4 +240,73 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 | 
	
		
			
				|  |  |  		return html.toString();
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	private static class Saml2AssertingPartyLogoutRequestResolver
 | 
	
		
			
				|  |  | +			implements Saml2LogoutRequestValidatorParametersResolver {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		private RequestMatcher logoutRequestMatcher = new AntPathRequestMatcher("/logout/saml2/slo");
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		Saml2AssertingPartyLogoutRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
 | 
	
		
			
				|  |  | +			this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		@Override
 | 
	
		
			
				|  |  | +		public Saml2LogoutRequestValidatorParameters resolve(HttpServletRequest request,
 | 
	
		
			
				|  |  | +				Authentication authentication) {
 | 
	
		
			
				|  |  | +			String serialized = request.getParameter(Saml2ParameterNames.SAML_REQUEST);
 | 
	
		
			
				|  |  | +			if (serialized == null) {
 | 
	
		
			
				|  |  | +				return null;
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			RequestMatcher.MatchResult result = this.logoutRequestMatcher.matcher(request);
 | 
	
		
			
				|  |  | +			if (!result.isMatch()) {
 | 
	
		
			
				|  |  | +				return null;
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			String registrationId = getRegistrationId(result, authentication);
 | 
	
		
			
				|  |  | +			RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request,
 | 
	
		
			
				|  |  | +					registrationId);
 | 
	
		
			
				|  |  | +			if (registration == null) {
 | 
	
		
			
				|  |  | +				throw new Saml2AuthenticationException(
 | 
	
		
			
				|  |  | +						new Saml2Error(Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND, "registration not found"),
 | 
	
		
			
				|  |  | +						"registration not found");
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
 | 
	
		
			
				|  |  | +			String entityId = uriResolver.resolve(registration.getEntityId());
 | 
	
		
			
				|  |  | +			String logoutLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
 | 
	
		
			
				|  |  | +			String logoutResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
 | 
	
		
			
				|  |  | +			registration = registration.mutate().entityId(entityId).singleLogoutServiceLocation(logoutLocation)
 | 
	
		
			
				|  |  | +					.singleLogoutServiceResponseLocation(logoutResponseLocation).build();
 | 
	
		
			
				|  |  | +			Saml2MessageBinding saml2MessageBinding = Saml2MessageBindingUtils.resolveBinding(request);
 | 
	
		
			
				|  |  | +			Saml2LogoutRequest logoutRequest = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
 | 
	
		
			
				|  |  | +					.samlRequest(serialized).relayState(request.getParameter(Saml2ParameterNames.RELAY_STATE))
 | 
	
		
			
				|  |  | +					.binding(saml2MessageBinding).location(registration.getSingleLogoutServiceLocation())
 | 
	
		
			
				|  |  | +					.parameters((params) -> params.put(Saml2ParameterNames.SIG_ALG,
 | 
	
		
			
				|  |  | +							request.getParameter(Saml2ParameterNames.SIG_ALG)))
 | 
	
		
			
				|  |  | +					.parameters((params) -> params.put(Saml2ParameterNames.SIGNATURE,
 | 
	
		
			
				|  |  | +							request.getParameter(Saml2ParameterNames.SIGNATURE)))
 | 
	
		
			
				|  |  | +					.parametersQuery((params) -> request.getQueryString()).build();
 | 
	
		
			
				|  |  | +			return new Saml2LogoutRequestValidatorParameters(logoutRequest, registration, authentication);
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		void setLogoutRequestMatcher(RequestMatcher logoutRequestMatcher) {
 | 
	
		
			
				|  |  | +			Assert.notNull(logoutRequestMatcher, "logoutRequestMatcher cannot be null");
 | 
	
		
			
				|  |  | +			this.logoutRequestMatcher = logoutRequestMatcher;
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		private String getRegistrationId(RequestMatcher.MatchResult result, Authentication authentication) {
 | 
	
		
			
				|  |  | +			String registrationId = result.getVariables().get("registrationId");
 | 
	
		
			
				|  |  | +			if (registrationId != null) {
 | 
	
		
			
				|  |  | +				return registrationId;
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			if (authentication == null) {
 | 
	
		
			
				|  |  | +				return null;
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			if (authentication.getPrincipal() instanceof Saml2AuthenticatedPrincipal principal) {
 | 
	
		
			
				|  |  | +				return principal.getRelyingPartyRegistrationId();
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  | +			return null;
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  }
 |