Browse Source

Addressed review comments

Signed-off-by: Liviu Gheorghe <liviu.gheorghe.ro@gmail.com>
1livv 2 months ago
parent
commit
edfd7b9b43

+ 2 - 2
config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurerTests.java

@@ -371,7 +371,7 @@ public class Saml2LogoutConfigurerTests {
 	}
 
 	@Test
-	public void saml2LogoutRequestWhenNoRegistrationThen401() throws Exception {
+	public void saml2LogoutRequestWhenNoRegistrationThen400() throws Exception {
 		this.spring.register(Saml2LogoutDefaultsConfig.class).autowire();
 		DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("user",
 				Collections.emptyMap());
@@ -384,7 +384,7 @@ public class Saml2LogoutConfigurerTests {
 				.param("SigAlg", this.apLogoutRequestSigAlg)
 				.param("Signature", this.apLogoutRequestSignature)
 				.with(authentication(user)))
-			.andExpect(status().isUnauthorized());
+			.andExpect(status().isBadRequest());
 		verifyNoInteractions(getBean(LogoutHandler.class));
 	}
 

+ 2 - 2
config/src/test/java/org/springframework/security/config/http/Saml2LogoutBeanDefinitionParserTests.java

@@ -271,7 +271,7 @@ public class Saml2LogoutBeanDefinitionParserTests {
 	}
 
 	@Test
-	public void saml2LogoutRequestWhenNoRegistrationThen401() throws Exception {
+	public void saml2LogoutRequestWhenNoRegistrationThen400() throws Exception {
 		this.spring.configLocations(this.xml("Default")).autowire();
 		DefaultSaml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal("user",
 				Collections.emptyMap());
@@ -284,7 +284,7 @@ public class Saml2LogoutBeanDefinitionParserTests {
 				.param("SigAlg", this.apLogoutRequestSigAlg)
 				.param("Signature", this.apLogoutRequestSignature)
 				.with(authentication(user)))
-			.andExpect(status().isUnauthorized());
+			.andExpect(status().isBadRequest());
 	}
 
 	@Test

+ 0 - 24
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ErrorCodes.java

@@ -130,30 +130,6 @@ public final class Saml2ErrorCodes {
 	 */
 	public static final String INVALID_IN_RESPONSE_TO = "invalid_in_response_to";
 
-	/**
-	 * The RP registration does not have configured a logout request endpoint
-	 * @since 6.3
-	 */
-	public static final String MISSING_LOGOUT_REQUEST_ENDPOINT = "missing_logout_request_endpoint";
-
-	/**
-	 * The saml response or logout request was delivered via an invalid binding
-	 * @since 6.3
-	 */
-	public static final String INVALID_BINDING = "invalid_binding";
-
-	/**
-	 * The saml logout request failed validation
-	 * @since 6.3
-	 */
-	public static final String INVALID_LOGOUT_REQUEST = "invalid_logout_request";
-
-	/**
-	 * The saml logout response could not be generated
-	 * @since 6.3
-	 */
-	public static final String FAILED_TO_GENERATE_LOGOUT_RESPONSE = "failed_to_generate_logout_response";
-
 	private Saml2ErrorCodes() {
 	}
 

+ 2 - 3
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/BaseOpenSamlLogoutResponseResolver.java

@@ -240,9 +240,8 @@ final class BaseOpenSamlLogoutResponseResolver implements Saml2LogoutResponseRes
 	private String getSamlStatus(Saml2AuthenticationException exception) {
 		Saml2Error saml2Error = exception.getSaml2Error();
 		return switch (saml2Error.getErrorCode()) {
-			case Saml2ErrorCodes.MISSING_LOGOUT_REQUEST_ENDPOINT, Saml2ErrorCodes.INVALID_BINDING ->
-				StatusCode.REQUEST_DENIED;
-			case Saml2ErrorCodes.INVALID_LOGOUT_REQUEST -> StatusCode.REQUESTER;
+			case Saml2ErrorCodes.INVALID_DESTINATION -> StatusCode.REQUEST_DENIED;
+			case Saml2ErrorCodes.INVALID_REQUEST -> StatusCode.REQUESTER;
 			default -> StatusCode.RESPONDER;
 		};
 	}

+ 30 - 23
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilter.java

@@ -113,27 +113,44 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
 			throws ServletException, IOException {
 		Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
+		Saml2LogoutRequestValidatorParameters parameters;
 		try {
-			Saml2LogoutRequestValidatorParameters parameters = this.logoutRequestResolver.resolve(request,
-					authentication);
-			if (parameters == null) {
-				chain.doFilter(request, response);
-				return;
-			}
+			parameters = this.logoutRequestResolver.resolve(request, authentication);
+		}
+		catch (Saml2AuthenticationException ex) {
+			this.logger.trace("Did not process logout request since failed to find requested RelyingPartyRegistration");
+			response.sendError(HttpServletResponse.SC_BAD_REQUEST);
+			return;
+		}
+		if (parameters == null) {
+			chain.doFilter(request, response);
+			return;
+		}
 
-			Saml2LogoutResponse logoutResponse = processLogoutRequest(request, response, authentication, parameters);
-			sendLogoutResponse(request, response, logoutResponse);
+		try {
+			validateLogoutRequest(request, parameters);
 		}
 		catch (Saml2AuthenticationException ex) {
 			Saml2LogoutResponse errorLogoutResponse = this.logoutResponseResolver.resolve(request, authentication, ex);
 			if (errorLogoutResponse == null) {
-				this.logger.trace("Returning error since no error logout response could be generated", ex);
+				this.logger.trace(LogMessage.format(
+						"Returning error since no error logout response could be generated: %s", ex.getSaml2Error()));
 				response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
 				return;
 			}
 
 			sendLogoutResponse(request, response, errorLogoutResponse);
+			return;
 		}
+
+		this.handler.logout(request, response, authentication);
+		Saml2LogoutResponse logoutResponse = this.logoutResponseResolver.resolve(request, authentication);
+		if (logoutResponse == null) {
+			this.logger.trace("Returning error since no logout response generated");
+			response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
+			return;
+		}
+		sendLogoutResponse(request, response, logoutResponse);
 	}
 
 	public void setLogoutRequestMatcher(RequestMatcher logoutRequestMatcher) {
@@ -155,13 +172,12 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 		this.securityContextHolderStrategy = securityContextHolderStrategy;
 	}
 
-	private Saml2LogoutResponse processLogoutRequest(HttpServletRequest request, HttpServletResponse response,
-			Authentication authentication, Saml2LogoutRequestValidatorParameters parameters) {
+	private void validateLogoutRequest(HttpServletRequest request, Saml2LogoutRequestValidatorParameters parameters) {
 		RelyingPartyRegistration registration = parameters.getRelyingPartyRegistration();
 		if (registration.getSingleLogoutServiceLocation() == null) {
 			this.logger.trace(
 					"Did not process logout request since RelyingPartyRegistration has not been configured with a logout request endpoint");
-			throw new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.MISSING_LOGOUT_REQUEST_ENDPOINT,
+			throw new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION,
 					"RelyingPartyRegistration has not been configured with a logout request endpoint"));
 		}
 
@@ -169,24 +185,15 @@ public final class Saml2LogoutRequestFilter extends OncePerRequestFilter {
 		if (!registration.getSingleLogoutServiceBindings().contains(saml2MessageBinding)) {
 			this.logger.trace("Did not process logout request since used incorrect binding");
 			throw new Saml2AuthenticationException(
-					new Saml2Error(Saml2ErrorCodes.INVALID_BINDING, "Logout request used invalid binding"));
+					new Saml2Error(Saml2ErrorCodes.INVALID_REQUEST, "Logout request used invalid binding"));
 		}
 
 		Saml2LogoutValidatorResult result = this.logoutRequestValidator.validate(parameters);
 		if (result.hasErrors()) {
 			this.logger.debug(LogMessage.format("Failed to validate LogoutRequest: %s", result.getErrors()));
 			throw new Saml2AuthenticationException(
-					new Saml2Error(Saml2ErrorCodes.INVALID_LOGOUT_REQUEST, "Failed to validate the logout request"));
-		}
-
-		this.handler.logout(request, response, authentication);
-		Saml2LogoutResponse logoutResponse = this.logoutResponseResolver.resolve(request, authentication);
-		if (logoutResponse == null) {
-			this.logger.trace("Returning error since no logout response generated");
-			throw new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.FAILED_TO_GENERATE_LOGOUT_RESPONSE,
-					"Could not generated logout response"));
+					new Saml2Error(Saml2ErrorCodes.INVALID_REQUEST, "Failed to validate the logout request"));
 		}
-		return logoutResponse;
 	}
 
 	private void sendLogoutResponse(HttpServletRequest request, HttpServletResponse response,

+ 4 - 2
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutResponseResolver.java

@@ -53,7 +53,9 @@ public interface Saml2LogoutResponseResolver {
 	 * processed
 	 * @return a signed and serialized SAML 2.0 Logout Response
 	 */
-	Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication,
-			Saml2AuthenticationException authenticationException);
+	default Saml2LogoutResponse resolve(HttpServletRequest request, Authentication authentication,
+			Saml2AuthenticationException authenticationException) {
+		return null;
+	}
 
 }

+ 4 - 13
saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutResponseResolverTests.java

@@ -65,7 +65,7 @@ public class OpenSaml4LogoutResponseResolverTests {
 		logoutResponseResolver.setParametersConsumer(parametersConsumer);
 		MockHttpServletRequest request = new MockHttpServletRequest();
 		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
-			.assertingPartyDetails(
+			.assertingPartyMetadata(
 					(party) -> party.singleLogoutServiceResponseLocation("https://ap.example.com/logout"))
 			.build();
 		Authentication authentication = new TestingAuthenticationToken("user", "password");
@@ -109,19 +109,10 @@ public class OpenSaml4LogoutResponseResolverTests {
 
 	private static Stream<Arguments> provideAuthExceptionAndExpectedSamlStatusCode() {
 		return Stream.of(
-				Arguments.of(
-						new Saml2AuthenticationException(
-								new Saml2Error(Saml2ErrorCodes.MISSING_LOGOUT_REQUEST_ENDPOINT, "")),
-						StatusCode.REQUEST_DENIED),
-				Arguments.of(new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.INVALID_BINDING, "")),
+				Arguments.of(new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, "")),
 						StatusCode.REQUEST_DENIED),
-				Arguments.of(
-						new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.INVALID_LOGOUT_REQUEST, "")),
-						StatusCode.REQUESTER),
-				Arguments.of(
-						new Saml2AuthenticationException(
-								new Saml2Error(Saml2ErrorCodes.FAILED_TO_GENERATE_LOGOUT_RESPONSE, "")),
-						StatusCode.RESPONDER)
+				Arguments.of(new Saml2AuthenticationException(new Saml2Error(Saml2ErrorCodes.INVALID_REQUEST, "")),
+						StatusCode.REQUESTER)
 
 		);
 	}

+ 15 - 15
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/Saml2LogoutRequestFilterTests.java

@@ -99,7 +99,7 @@ public class Saml2LogoutRequestFilterTests {
 	@Test
 	public void doFilterWhenSamlRequestThenPosts() throws Exception {
 		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
-			.assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))
+			.assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))
 			.build();
 		Authentication authentication = new TestingAuthenticationToken("user", "password");
 		given(this.securityContextHolderStrategy.getContext()).willReturn(new SecurityContextImpl(authentication));
@@ -149,7 +149,7 @@ public class Saml2LogoutRequestFilterTests {
 	@Test
 	public void doFilterWhenValidationFailsErrorLogoutResponseIsPosted() throws Exception {
 		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
-			.assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))
+			.assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))
 			.build();
 		Authentication authentication = new TestingAuthenticationToken("user", "password");
 		SecurityContextHolder.getContext().setAuthentication(authentication);
@@ -165,7 +165,7 @@ public class Saml2LogoutRequestFilterTests {
 		given(this.logoutRequestValidator.validate(any()))
 			.willReturn(Saml2LogoutValidatorResult.withErrors(new Saml2Error("error", "description")).build());
 		given(this.logoutResponseResolver.resolve(any(), any(),
-				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_LOGOUT_REQUEST))))
+				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_REQUEST))))
 			.willReturn(logoutResponse);
 
 		this.logoutRequestProcessingFilter.doFilter(request, response, new MockFilterChain());
@@ -173,7 +173,7 @@ public class Saml2LogoutRequestFilterTests {
 		checkResponse(response.getContentAsString(), registration);
 		verify(this.logoutRequestValidator).validate(any());
 		verify(this.logoutResponseResolver).resolve(any(), any(),
-				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_LOGOUT_REQUEST)));
+				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_REQUEST)));
 		verifyNoInteractions(this.logoutHandler);
 	}
 
@@ -186,7 +186,7 @@ public class Saml2LogoutRequestFilterTests {
 		request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request");
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
-			.assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))
+			.assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))
 			.singleLogoutServiceLocation(null)
 			.build();
 		Saml2LogoutResponse logoutResponse = Saml2LogoutResponse.withRelyingPartyRegistration(registration)
@@ -194,15 +194,15 @@ public class Saml2LogoutRequestFilterTests {
 			.build();
 
 		given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration);
-		given(this.logoutResponseResolver.resolve(any(), any(), argThat(
-				(ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.MISSING_LOGOUT_REQUEST_ENDPOINT))))
+		given(this.logoutResponseResolver.resolve(any(), any(),
+				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_DESTINATION))))
 			.willReturn(logoutResponse);
 
 		this.logoutRequestProcessingFilter.doFilterInternal(request, response, new MockFilterChain());
 
 		checkResponse(response.getContentAsString(), registration);
-		verify(this.logoutResponseResolver).resolve(any(), any(), argThat(
-				(ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.MISSING_LOGOUT_REQUEST_ENDPOINT)));
+		verify(this.logoutResponseResolver).resolve(any(), any(),
+				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_DESTINATION)));
 		verifyNoInteractions(this.logoutHandler);
 	}
 
@@ -215,7 +215,7 @@ public class Saml2LogoutRequestFilterTests {
 		request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request");
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
-			.assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))
+			.assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))
 			.singleLogoutServiceBindings((bindings) -> {
 				bindings.clear();
 				bindings.add(Saml2MessageBinding.REDIRECT);
@@ -227,14 +227,14 @@ public class Saml2LogoutRequestFilterTests {
 
 		given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration);
 		given(this.logoutResponseResolver.resolve(any(), any(),
-				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_BINDING))))
+				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_REQUEST))))
 			.willReturn(logoutResponse);
 
 		this.logoutRequestProcessingFilter.doFilterInternal(request, response, new MockFilterChain());
 
 		checkResponse(response.getContentAsString(), registration);
 		verify(this.logoutResponseResolver).resolve(any(), any(),
-				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_BINDING)));
+				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_REQUEST)));
 		verifyNoInteractions(this.logoutHandler);
 	}
 
@@ -247,7 +247,7 @@ public class Saml2LogoutRequestFilterTests {
 		request.setParameter(Saml2ParameterNames.SAML_REQUEST, "request");
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
-			.assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))
+			.assertingPartyMetadata((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST))
 			.singleLogoutServiceBindings((bindings) -> {
 				bindings.clear();
 				bindings.add(Saml2MessageBinding.REDIRECT);
@@ -256,14 +256,14 @@ public class Saml2LogoutRequestFilterTests {
 
 		given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration);
 		given(this.logoutResponseResolver.resolve(any(), any(),
-				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_BINDING))))
+				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_REQUEST))))
 			.willReturn(null);
 
 		this.logoutRequestProcessingFilter.doFilterInternal(request, response, new MockFilterChain());
 
 		assertThat(response.getStatus()).isEqualTo(401);
 		verify(this.logoutResponseResolver).resolve(any(), any(),
-				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_BINDING)));
+				argThat((ex) -> ex.getSaml2Error().getErrorCode().equals(Saml2ErrorCodes.INVALID_REQUEST)));
 		verifyNoInteractions(this.logoutHandler);
 	}