浏览代码

Polish LogoutRequest#EncryptedID Support

Issue gh-10663
Josh Cummings 3 年之前
父节点
当前提交
d493598e17

+ 5 - 5
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/LogoutRequestEncryptedIDUtils.java → saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/LogoutRequestEncryptedIdUtils.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -46,16 +46,16 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
  *
  * @author Robert Stoiber
  */
-final class LogoutRequestEncryptedIDUtils {
+final class LogoutRequestEncryptedIdUtils {
 
 	private static final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver(
 			Arrays.asList(new InlineEncryptedKeyResolver(), new EncryptedElementTypeEncryptedKeyResolver(),
 					new SimpleRetrievalMethodEncryptedKeyResolver()));
 
-	static SAMLObject decryptEncryptedID(EncryptedID encryptedID, RelyingPartyRegistration registration) {
+	static SAMLObject decryptEncryptedId(EncryptedID encryptedId, RelyingPartyRegistration registration) {
 		Decrypter decrypter = decrypter(registration);
 		try {
-			return decrypter.decrypt(encryptedID);
+			return decrypter.decrypt(encryptedId);
 
 		}
 		catch (Exception ex) {
@@ -75,7 +75,7 @@ final class LogoutRequestEncryptedIDUtils {
 		return decrypter;
 	}
 
-	private LogoutRequestEncryptedIDUtils() {
+	private LogoutRequestEncryptedIdUtils() {
 	}
 
 }

+ 19 - 14
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSamlLogoutRequestValidator.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -161,25 +161,30 @@ public final class OpenSamlLogoutRequestValidator implements Saml2LogoutRequestV
 			if (authentication == null) {
 				return;
 			}
-			NameID nameId = request.getNameID();
-			EncryptedID encryptedID = request.getEncryptedID();
-			if (nameId == null && encryptedID == null) {
+			NameID nameId = getNameId(request, registration);
+			if (nameId == null) {
 				errors.add(
 						new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND, "Failed to find subject in LogoutRequest"));
 				return;
 			}
 
-			if (nameId != null) {
-				validateNameID(nameId, authentication, errors);
-			}
-			else {
-				final NameID nameIDFromEncryptedID = decryptNameID(encryptedID, registration);
-				validateNameID(nameIDFromEncryptedID, authentication, errors);
-			}
+			validateNameId(nameId, authentication, errors);
 		};
 	}
 
-	private void validateNameID(NameID nameId, Authentication authentication, Collection<Saml2Error> errors) {
+	private NameID getNameId(LogoutRequest request, RelyingPartyRegistration registration) {
+		NameID nameId = request.getNameID();
+		if (nameId != null) {
+			return nameId;
+		}
+		EncryptedID encryptedId = request.getEncryptedID();
+		if (encryptedId == null) {
+			return null;
+		}
+		return decryptNameId(encryptedId, registration);
+	}
+
+	private void validateNameId(NameID nameId, Authentication authentication, Collection<Saml2Error> errors) {
 		String name = nameId.getValue();
 		if (!name.equals(authentication.getName())) {
 			errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_REQUEST,
@@ -187,8 +192,8 @@ public final class OpenSamlLogoutRequestValidator implements Saml2LogoutRequestV
 		}
 	}
 
-	private NameID decryptNameID(EncryptedID encryptedID, RelyingPartyRegistration registration) {
-		final SAMLObject decryptedId = LogoutRequestEncryptedIDUtils.decryptEncryptedID(encryptedID, registration);
+	private NameID decryptNameId(EncryptedID encryptedId, RelyingPartyRegistration registration) {
+		final SAMLObject decryptedId = LogoutRequestEncryptedIdUtils.decryptEncryptedId(encryptedId, registration);
 		if (decryptedId instanceof NameID) {
 			return ((NameID) decryptedId);
 		}

+ 4 - 2
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java

@@ -373,8 +373,10 @@ public final class TestOpenSamlObjects {
 		NameID nameId = nameIdBuilder.buildObject();
 		nameId.setValue("user");
 		logoutRequest.setNameID(null);
-		logoutRequest.setEncryptedID(encrypted(nameId,
-				registration.getAssertingPartyDetails().getEncryptionX509Credentials().stream().findFirst().get()));
+		Saml2X509Credential credential = registration.getAssertingPartyDetails().getEncryptionX509Credentials()
+				.iterator().next();
+		EncryptedID encrypted = encrypted(nameId, credential);
+		logoutRequest.setEncryptedID(encrypted);
 		IssuerBuilder issuerBuilder = new IssuerBuilder();
 		Issuer issuer = issuerBuilder.buildObject();
 		issuer.setValue(registration.getAssertingPartyDetails().getEntityId());

+ 11 - 8
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/logout/OpenSamlLogoutRequestValidatorTests.java

@@ -61,17 +61,16 @@ public class OpenSamlLogoutRequestValidatorTests {
 	}
 
 	@Test
-	public void handleWhenNameIdInEncryptedIdPostThenValidates() {
+	public void handleWhenNameIdIsEncryptedIdPostThenValidates() {
 
-		RelyingPartyRegistration registration = registrationWithEncryption()
-				.assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST)).build();
+		RelyingPartyRegistration registration = decrypting(encrypting(registration())).build();
 		LogoutRequest logoutRequest = TestOpenSamlObjects.assertingPartyLogoutRequestNameIdInEncryptedId(registration);
 		sign(logoutRequest, registration);
 		Saml2LogoutRequest request = post(logoutRequest, registration);
 		Saml2LogoutRequestValidatorParameters parameters = new Saml2LogoutRequestValidatorParameters(request,
 				registration, authentication(registration));
 		Saml2LogoutValidatorResult result = this.manager.validate(parameters);
-		assertThat(result.hasErrors()).withFailMessage(() -> result.getErrors().toString()).isFalse().isFalse();
+		assertThat(result.hasErrors()).withFailMessage(() -> result.getErrors().toString()).isFalse();
 
 	}
 
@@ -149,10 +148,14 @@ public class OpenSamlLogoutRequestValidatorTests {
 				.assertingPartyDetails((party) -> party.singleLogoutServiceBinding(Saml2MessageBinding.POST));
 	}
 
-	private RelyingPartyRegistration.Builder registrationWithEncryption() {
-		return signing(verifying(TestRelyingPartyRegistrations.full()))
-				.assertingPartyDetails((party) -> party.encryptionX509Credentials(
-						(c) -> c.add(TestSaml2X509Credentials.assertingPartyEncryptingCredential())));
+	private RelyingPartyRegistration.Builder decrypting(RelyingPartyRegistration.Builder builder) {
+		return builder
+				.decryptionX509Credentials((c) -> c.add(TestSaml2X509Credentials.relyingPartyDecryptingCredential()));
+	}
+
+	private RelyingPartyRegistration.Builder encrypting(RelyingPartyRegistration.Builder builder) {
+		return builder.assertingPartyDetails((party) -> party.encryptionX509Credentials(
+				(c) -> c.add(TestSaml2X509Credentials.assertingPartyEncryptingCredential())));
 	}
 
 	private RelyingPartyRegistration.Builder verifying(RelyingPartyRegistration.Builder builder) {