浏览代码

Saml2AuthenticationToken takes a RelyingPartyRegistration

Closes gh-8845
Josh Cummings 5 年之前
父节点
当前提交
a54e77a3c3

+ 13 - 18
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

@@ -480,13 +480,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 
 		private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) {
 			Set<Credential> credentials = new HashSet<>();
-			for (Saml2X509Credential key : token.getX509Credentials()) {
-				if (!key.isSignatureVerficationCredential()) {
-					continue;
-				}
+			for (Saml2X509Credential key : token.getRelyingPartyRegistration().getVerificationCredentials()) {
 				BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
 				cred.setUsageType(UsageType.SIGNING);
-				cred.setEntityId(token.getIdpEntityId());
+				cred.setEntityId(token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId());
 				credentials.add(cred);
 			}
 			CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
@@ -506,13 +503,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 				Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>();
 
 				String destination = response.getDestination();
-				if (StringUtils.hasText(destination) && !destination.equals(token.getRecipientUri())) {
+				String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
+				if (StringUtils.hasText(destination) && !destination.equals(location)) {
 					String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
 					validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message));
 				}
 
 				String issuer = response.getIssuer().getValue();
-				String assertingPartyEntityId = token.getIdpEntityId();
+				String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId();
 				if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
 					String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
 					validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message));
@@ -538,11 +536,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 			return encrypted -> {
 				Saml2AuthenticationException last =
 						authException(DECRYPTION_ERROR, "No valid decryption credentials found.");
-				List<Saml2X509Credential> decryptionCredentials = token.getX509Credentials();
+				List<Saml2X509Credential> decryptionCredentials = token.getRelyingPartyRegistration().getDecryptionCredentials();
 				for (Saml2X509Credential key : decryptionCredentials) {
-					if (!key.isDecryptionCredential()) {
-						continue;
-					}
 					Decrypter decrypter = getDecrypter(key);
 					try {
 						return decrypter.decrypt(encrypted);
@@ -623,11 +618,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 
 		private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) {
 			Set<Credential> credentials = new HashSet<>();
-			for (Saml2X509Credential key : token.getX509Credentials()) {
-				if (!key.isSignatureVerficationCredential()) continue;
+			for (Saml2X509Credential key : token.getRelyingPartyRegistration().getVerificationCredentials()) {
 				BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
 				cred.setUsageType(UsageType.SIGNING);
-				cred.setEntityId(token.getIdpEntityId());
+				cred.setEntityId(token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId());
 				credentials.add(cred);
 			}
 			CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
@@ -709,10 +703,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 							}
 						},
 						token -> {
+							String audience = token.getRelyingPartyRegistration().getEntityId();
+							String recipient = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
 							Map<String, Object> params = new HashMap<>();
 							params.put(CLOCK_SKEW, Duration.ofMinutes(5).toMillis());
-							params.put(COND_VALID_AUDIENCES, singleton(token.getIdpEntityId()));
-							params.put(SC_VALID_RECIPIENTS, singleton(token.getRecipientUri()));
+							params.put(COND_VALID_AUDIENCES, singleton(audience));
+							params.put(SC_VALID_RECIPIENTS, singleton(recipient));
 							params.putAll(this.validationContextParameters);
 							return new ValidationContext(params);
 						});
@@ -734,9 +730,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 			return encrypted -> {
 				Saml2AuthenticationException last =
 						authException(DECRYPTION_ERROR, "No valid decryption credentials found.");
-				List<Saml2X509Credential> decryptionCredentials = token.getX509Credentials();
+				List<Saml2X509Credential> decryptionCredentials = token.getRelyingPartyRegistration().getDecryptionCredentials();
 				for (Saml2X509Credential key : decryptionCredentials) {
-					if (!key.isDecryptionCredential()) continue;
 					Decrypter decrypter = getDecrypter(key);
 					try {
 						return (NameID) decrypter.decrypt(encrypted);

+ 66 - 14
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,23 +16,51 @@
 
 package org.springframework.security.saml2.provider.service.authentication;
 
+import java.util.Collections;
+import java.util.List;
+
 import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.saml2.credentials.Saml2X509Credential;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.util.Assert;
 
-import java.util.List;
+import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId;
 
 /**
  * Represents an incoming SAML 2.0 response containing an assertion that has not been validated.
  * {@link Saml2AuthenticationToken#isAuthenticated()} will always return false.
+ *
  * @since 5.2
+ * @author Filip Hanik
+ * @author Josh Cummings
  */
 public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 
+	private final RelyingPartyRegistration relyingPartyRegistration;
 	private final String saml2Response;
-	private final String recipientUri;
-	private String idpEntityId;
-	private String localSpEntityId;
-	private List<Saml2X509Credential> credentials;
+
+	/**
+	 * Creates a {@link Saml2AuthenticationToken} with the provided parameters
+	 *
+	 * Note that the given {@link RelyingPartyRegistration} should have all its
+	 * templates resolved at this point. See
+	 * {@link org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter}
+	 * for an example of performing that resolution.
+	 *
+	 * @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to use
+	 * @param saml2Response the SAML 2.0 response to authenticate
+	 *
+	 * @since 5.4
+	 */
+	public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration,
+			String saml2Response) {
+
+		super(Collections.emptyList());
+		Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null");
+		Assert.notNull(saml2Response, "saml2Response cannot be null");
+		this.relyingPartyRegistration = relyingPartyRegistration;
+		this.saml2Response = saml2Response;
+	}
 
 	/**
 	 * Creates an authentication token from an incoming SAML 2 Response object
@@ -41,18 +69,24 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 	 * @param idpEntityId the entity ID of the asserting entity
 	 * @param localSpEntityId the configured local SP, the relying party, entity ID
 	 * @param credentials the credentials configured for signature verification and decryption
+	 * @deprecated Use {@link Saml2AuthenticationToken(RelyingPartyRegistration, String)} instead
 	 */
+	@Deprecated
 	public Saml2AuthenticationToken(String saml2Response,
 									String recipientUri,
 									String idpEntityId,
 									String localSpEntityId,
 									List<Saml2X509Credential> credentials) {
 		super(null);
+		this.relyingPartyRegistration = withRegistrationId(idpEntityId)
+				.entityId(localSpEntityId)
+				.assertionConsumerServiceLocation(recipientUri)
+				.credentials(c -> c.addAll(credentials))
+				.assertingPartyDetails(assertingParty -> assertingParty
+						.entityId(idpEntityId)
+						.singleSignOnServiceLocation(idpEntityId))
+				.build();
 		this.saml2Response = saml2Response;
-		this.recipientUri = recipientUri;
-		this.idpEntityId = idpEntityId;
-		this.localSpEntityId = localSpEntityId;
-		this.credentials = credentials;
 	}
 
 	/**
@@ -73,6 +107,16 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 		return null;
 	}
 
+	/**
+	 * Get the resolved {@link RelyingPartyRegistration} associated with the request
+	 *
+	 * @return the resolved {@link RelyingPartyRegistration}
+	 * @since 5.4
+	 */
+	public RelyingPartyRegistration getRelyingPartyRegistration() {
+		return this.relyingPartyRegistration;
+	}
+
 	/**
 	 * Returns inflated and decoded XML representation of the SAML 2 Response
 	 * @return inflated and decoded XML representation of the SAML 2 Response
@@ -84,25 +128,31 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 	/**
 	 * Returns the URI that the SAML 2 Response object came in on
 	 * @return URI as a string
+	 * @deprecated Use {@link #getRelyingPartyRegistration().getAssertionConsumerServiceLocation()} instead
 	 */
+	@Deprecated
 	public String getRecipientUri() {
-		return this.recipientUri;
+		return this.relyingPartyRegistration.getAssertionConsumerServiceLocation();
 	}
 
 	/**
 	 * Returns the configured entity ID of the receiving relying party, SP
 	 * @return an entityID for the configured local relying party
+	 * @deprecated Use {@link #getRelyingPartyRegistration().getEntityId()} instead
 	 */
+	@Deprecated
 	public String getLocalSpEntityId() {
-		return this.localSpEntityId;
+		return this.relyingPartyRegistration.getEntityId();
 	}
 
 	/**
 	 * Returns all the credentials associated with the relying party configuraiton
 	 * @return
+	 * @deprecated Get the credentials through {@link #getRelyingPartyRegistration()} instead
 	 */
+	@Deprecated
 	public List<Saml2X509Credential> getX509Credentials() {
-		return this.credentials;
+		return this.relyingPartyRegistration.getCredentials();
 	}
 
 	/**
@@ -126,8 +176,10 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 	/**
 	 * Returns the configured IDP, asserting party, entity ID
 	 * @return a string representing the entity ID
+	 * @deprecated Use {@link #getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId()} instead
 	 */
+	@Deprecated
 	public String getIdpEntityId() {
-		return this.idpEntityId;
+		return this.relyingPartyRegistration.getAssertingPartyDetails().getEntityId();
 	}
 }

+ 10 - 8
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java

@@ -35,6 +35,7 @@ import org.springframework.util.Assert;
 
 import static java.nio.charset.StandardCharsets.UTF_8;
 import static org.springframework.security.saml2.core.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND;
+import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
 import static org.springframework.util.StringUtils.hasText;
 
 /**
@@ -98,14 +99,15 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
 			throw new Saml2AuthenticationException(saml2Error);
 		}
 		String applicationUri = Saml2ServletUtils.getApplicationUri(request);
-		String localSpEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp);
-		final Saml2AuthenticationToken authentication = new Saml2AuthenticationToken(
-				responseXml,
-				request.getRequestURL().toString(),
-				rp.getAssertingPartyDetails().getEntityId(),
-				localSpEntityId,
-				rp.getCredentials()
-		);
+		String relyingPartyEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp);
+		String assertionConsumerServiceLocation = Saml2ServletUtils.resolveUrlTemplate(
+				rp.getAssertionConsumerServiceLocation(), applicationUri, rp);
+		RelyingPartyRegistration relyingPartyRegistration = withRelyingPartyRegistration(rp)
+				.entityId(relyingPartyEntityId)
+				.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
+				.build();
+		Saml2AuthenticationToken authentication = new Saml2AuthenticationToken(
+				relyingPartyRegistration, responseXml);
 		return getAuthenticationManager().authenticate(authentication);
 	}
 

+ 4 - 4
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

@@ -111,14 +111,14 @@ public class OpenSamlAuthenticationProviderTests {
 		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS));
 
 		Assertion assertion = this.saml.buildSamlObject(Assertion.DEFAULT_ELEMENT_NAME);
-		this.provider.authenticate(token(this.saml.serialize(assertion)));
+		this.provider.authenticate(token(this.saml.serialize(assertion), relyingPartyVerifyingCredential()));
 	}
 
 	@Test
 	public void authenticateWhenXmlErrorThenThrowAuthenticationException() {
 		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
 
-		Saml2AuthenticationToken token = token("invalid xml");
+		Saml2AuthenticationToken token = token("invalid xml", relyingPartyVerifyingCredential());
 		this.provider.authenticate(token);
 	}
 
@@ -149,7 +149,7 @@ public class OpenSamlAuthenticationProviderTests {
 
 		Response response = response();
 		response.getAssertions().add(assertion());
-		Saml2AuthenticationToken token = token(response);
+		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
 		this.provider.authenticate(token);
 	}
 
@@ -316,7 +316,7 @@ public class OpenSamlAuthenticationProviderTests {
 		Response response = response();
 		EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential());
 		response.getEncryptedAssertions().add(encryptedAssertion);
-		Saml2AuthenticationToken token = token(this.saml.serialize(response));
+		Saml2AuthenticationToken token = token(this.saml.serialize(response), relyingPartyVerifyingCredential());
 		this.provider.authenticate(token);
 	}