Bladeren bron

Update to return List of StatusCodes and add Saml2Error to result object and other formatting

YoungKi Hong 1 jaar geleden
bovenliggende
commit
6e45e65cac

+ 32 - 19
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java

@@ -20,16 +20,15 @@ import java.io.ByteArrayInputStream;
 import java.nio.charset.StandardCharsets;
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.HashSet;
-import java.util.Arrays;
-import java.util.Optional;
 import java.util.function.Consumer;
 
 import javax.annotation.Nonnull;
@@ -98,8 +97,6 @@ import org.springframework.util.LinkedMultiValueMap;
 import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 
-import static org.opensaml.saml.saml2.core.StatusCode.*;
-
 /**
  * Implementation of {@link AuthenticationProvider} for SAML authentications when
  * receiving a {@code Response} object containing an {@code Assertion}. This
@@ -174,7 +171,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 
 	private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
 
-	private static final Set<String> includeChildStatusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH));
+	private static final Set<String> includeChildStatusCodes = new HashSet<>(
+			Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH));
 
 	/**
 	 * Creates an {@link OpenSaml4AuthenticationProvider}
@@ -379,11 +377,13 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 			Response response = responseToken.getResponse();
 			Saml2AuthenticationToken token = responseToken.getToken();
 			Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
-			String statusCode = getStatusCode(response);
-			if (!StatusCode.SUCCESS.equals(statusCode)) {
-				String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
-						response.getID());
-				result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
+			List<String> statusCodes = getStatusCodes(response);
+			if (!isSuccess(statusCodes)) {
+				for (String statusCode : statusCodes) {
+					String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
+							response.getID());
+					result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
+				}
 			}
 
 			String inResponseTo = response.getInResponseTo();
@@ -412,24 +412,37 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		};
 	}
 
-	private static String getStatusCode(Response response) {
+	private static List<String> getStatusCodes(Response response) {
 		if (response.getStatus() == null) {
-			return StatusCode.SUCCESS;
+			return Arrays.asList(StatusCode.SUCCESS);
 		}
 		if (response.getStatus().getStatusCode() == null) {
-			return StatusCode.SUCCESS;
+			return Arrays.asList(StatusCode.SUCCESS);
 		}
 
 		StatusCode parentStatusCode = response.getStatus().getStatusCode();
 		String parentStatusCodeValue = parentStatusCode.getValue();
 		if (includeChildStatusCodes.contains(parentStatusCodeValue)) {
-			return Optional.ofNullable(parentStatusCode.getStatusCode())
-					.map(StatusCode::getValue)
-					.map(childStatusCodeValue -> parentStatusCodeValue + childStatusCodeValue)
-					.orElse(parentStatusCodeValue);
+			StatusCode statusCode = parentStatusCode.getStatusCode();
+			if (statusCode != null) {
+				String childStatusCodeValue = statusCode.getValue();
+				if (childStatusCodeValue != null) {
+					return Arrays.asList(parentStatusCodeValue, childStatusCodeValue);
+				}
+			}
+			return Arrays.asList(parentStatusCodeValue);
+		}
+
+		return Arrays.asList(parentStatusCodeValue);
+	}
+
+	private static boolean isSuccess(List<String> statusCodes) {
+		if (statusCodes.size() != 1) {
+			return false;
 		}
 
-		return parentStatusCodeValue;
+		String statusCode = statusCodes.get(0);
+		return StatusCode.SUCCESS.equals(statusCode);
 	}
 
 	private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,

+ 28 - 14
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2023 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -86,8 +86,6 @@ import org.springframework.util.StringUtils;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.atLeastOnce;
@@ -736,7 +734,7 @@ public class OpenSaml4AuthenticationProviderTests {
 	}
 
 	@Test
-	public void setsOnlyParentStatusCodeOnResultDescription() {
+	public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatusCodes() {
 		ResponseToken mockResponseToken = mock(ResponseToken.class);
 		Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
 		given(mockResponseToken.getToken()).willReturn(mockSamlToken);
@@ -744,7 +742,8 @@ public class OpenSaml4AuthenticationProviderTests {
 		RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
 		given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
 
-		RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
+		RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
+				RelyingPartyRegistration.AssertingPartyDetails.class);
 		given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
 
 		Status parentStatus = new StatusBuilder().buildObject();
@@ -763,16 +762,21 @@ public class OpenSaml4AuthenticationProviderTests {
 
 		given(mockResponseToken.getResponse()).willReturn(mockResponse);
 
-		Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
+		Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
+			.createDefaultResponseValidator();
 		Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
 
-		String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue());
-		assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
-		assertFalse(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(childStatusCode.getValue())));
+		String expectedErrorMessage = String.format("Invalid status [%s] for SAML response",
+				parentStatusCode.getValue());
+		assertThat(
+				result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage)));
+		assertThat(result.getErrors()
+			.stream()
+			.noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue())));
 	}
 
 	@Test
-	public void setsParentAndChildStatusCodeOnResultDescription() {
+	public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildStatusCode() {
 		ResponseToken mockResponseToken = mock(ResponseToken.class);
 		Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
 		given(mockResponseToken.getToken()).willReturn(mockSamlToken);
@@ -780,7 +784,8 @@ public class OpenSaml4AuthenticationProviderTests {
 		RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
 		given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
 
-		RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
+		RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
+				RelyingPartyRegistration.AssertingPartyDetails.class);
 		given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
 
 		Status parentStatus = new StatusBuilder().buildObject();
@@ -799,11 +804,20 @@ public class OpenSaml4AuthenticationProviderTests {
 
 		given(mockResponseToken.getResponse()).willReturn(mockResponse);
 
-		Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
+		Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
+			.createDefaultResponseValidator();
 		Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
 
-		String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue() + childStatusCode.getValue());
-		assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
+		String expectedParentErrorMessage = String.format("Invalid status [%s] for SAML response",
+				parentStatusCode.getValue());
+		String expectedChildErrorMessage = String.format("Invalid status [%s] for SAML response",
+				childStatusCode.getValue());
+		assertThat(result.getErrors()
+			.stream()
+			.anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage)));
+		assertThat(result.getErrors()
+			.stream()
+			.anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage)));
 	}
 
 	@Test