瀏覽代碼

Use Optional in case child status code is null

youngkih 1 年之前
父節點
當前提交
994e064412

+ 23 - 20
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java

@@ -29,6 +29,7 @@ import java.util.Map;
 import java.util.Set;
 import java.util.Set;
 import java.util.HashSet;
 import java.util.HashSet;
 import java.util.Arrays;
 import java.util.Arrays;
+import java.util.Optional;
 import java.util.function.Consumer;
 import java.util.function.Consumer;
 
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nonnull;
@@ -173,6 +174,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 
 
 	private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
 	private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
 
 
+	private static final Set<String> includeChildStatusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH));
+
 	/**
 	/**
 	 * Creates an {@link OpenSaml4AuthenticationProvider}
 	 * Creates an {@link OpenSaml4AuthenticationProvider}
 	 */
 	 */
@@ -409,6 +412,26 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		};
 		};
 	}
 	}
 
 
+	private static String getStatusCode(Response response) {
+		if (response.getStatus() == null) {
+			return StatusCode.SUCCESS;
+		}
+		if (response.getStatus().getStatusCode() == null) {
+			return StatusCode.SUCCESS;
+		}
+
+		StatusCode parentStatusCode = response.getStatus().getStatusCode();
+		String parentStatusCodeValue = parentStatusCode.getValue();
+		if (includeChildStatusCodes.contains(parentStatusCodeValue)) {
+			return Optional.ofNullable(parentStatusCode.getStatusCode())
+					.map(StatusCode::getValue)
+					.map(childStatusCodeValue -> parentStatusCodeValue + childStatusCodeValue)
+					.orElse(parentStatusCodeValue);
+		}
+
+		return parentStatusCodeValue;
+	}
+
 	private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
 	private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
 			String inResponseTo) {
 			String inResponseTo) {
 		if (!StringUtils.hasText(inResponseTo)) {
 		if (!StringUtils.hasText(inResponseTo)) {
@@ -619,26 +642,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		};
 		};
 	}
 	}
 
 
-	private static String getStatusCode(Response response) {
-		if (response.getStatus() == null) {
-			return StatusCode.SUCCESS;
-		}
-		if (response.getStatus().getStatusCode() == null) {
-			return StatusCode.SUCCESS;
-		}
-
-		Set<String> statusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH));
-		StatusCode parentStatusCode = response.getStatus().getStatusCode();
-		String parentStatusCodeValue = parentStatusCode.getValue();
-		if (statusCodes.contains(parentStatusCodeValue)) {
-			StatusCode childStatusCode = parentStatusCode.getStatusCode();
-			String childStatusCodeValue = childStatusCode.getValue();
-			return parentStatusCodeValue + childStatusCodeValue;
-		}
-
-		return parentStatusCodeValue;
-	}
-
 	private Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionSignatureValidator() {
 	private Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionSignatureValidator() {
 		return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> {
 		return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> {
 			RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration();
 			RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration();