瀏覽代碼

Add Response to Authentication Conversion Support

Closes gh-8010
Josh Cummings 5 年之前
父節點
當前提交
da7477cd41

+ 83 - 6
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

@@ -28,7 +28,6 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.function.Function;
 import javax.annotation.Nonnull;
 import javax.xml.namespace.QName;
 
@@ -185,8 +184,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	private GrantedAuthoritiesMapper authoritiesMapper = (a -> a);
 	private Duration responseTimeValidationSkew = Duration.ofMinutes(5);
 
-	private Function<Saml2AuthenticationToken, Converter<Response, AbstractAuthenticationToken>> authenticationConverter =
-			token -> response -> {
+	private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter =
+			responseToken -> {
+				Response response = responseToken.response;
+				Saml2AuthenticationToken token = responseToken.token;
 				Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
 				String username = assertion.getSubject().getNameID().getValue();
 				Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
@@ -255,11 +256,42 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		this.assertionValidator = assertionValidator;
 	}
 
+	/**
+	 * Set the {@link Converter} to use for converting a validated {@link Response} into
+	 * an {@link AbstractAuthenticationToken}.
+	 *
+	 * You can delegate to the default behavior by calling {@link #createDefaultResponseAuthenticationConverter()}
+	 * like so:
+	 *
+	 * <pre>
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 * 	Converter&lt;ResponseToken, Saml2Authentication&gt; authenticationConverter =
+	 * 			createDefaultResponseAuthenticationConverter();
+	 *	provider.setResponseAuthenticationConverter(responseToken -> {
+	 *		Saml2Authentication authentication = authenticationConverter.convert(responseToken);
+	 *		User user = myUserRepository.findByUsername(authentication.getName());
+	 *		return new MyAuthentication(authentication, user);
+	 *	});
+	 * </pre>
+	 *
+	 * This method takes precedence over {@link #setAuthoritiesExtractor(Converter)} and
+	 * {@link #setAuthoritiesMapper(GrantedAuthoritiesMapper)}.
+	 *
+	 * @param responseAuthenticationConverter the {@link Converter} to use
+	 * @since 5.4
+	 */
+	public void setResponseAuthenticationConverter(
+			Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter) {
+		Assert.notNull(responseAuthenticationConverter, "responseAuthenticationConverter cannot be null");
+		this.responseAuthenticationConverter = responseAuthenticationConverter;
+	}
+
 	/**
 	 * Sets the {@link Converter} used for extracting assertion attributes that
 	 * can be mapped to authorities.
 	 * @param authoritiesExtractor the {@code Converter} used for mapping the
 	 *                             assertion attributes to authorities
+	 * @deprecated Use {@link #setResponseAuthenticationConverter(Converter)} instead
 	 */
 	public void setAuthoritiesExtractor(Converter<Assertion, Collection<? extends GrantedAuthority>> authoritiesExtractor) {
 		Assert.notNull(authoritiesExtractor, "authoritiesExtractor cannot be null");
@@ -271,6 +303,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	 * to a new set of authorities which will be associated to the {@link Saml2Authentication}.
 	 * Note: This implementation is only retrieving
 	 * @param authoritiesMapper the {@link GrantedAuthoritiesMapper} used for mapping the user's authorities
+	 * @deprecated Use {@link #setResponseAuthenticationConverter(Converter)} instead
 	 */
 	public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
 		notNull(authoritiesMapper, "authoritiesMapper cannot be null");
@@ -286,6 +319,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		this.responseTimeValidationSkew = responseTimeValidationSkew;
 	}
 
+	/**
+	 * Construct a default strategy for converting a SAML 2.0 Response and {@link Authentication}
+	 * token into a {@link Saml2Authentication}
+	 *
+	 * @return the default response authentication converter strategy
+	 * @since 5.4
+	 */
+	public static Converter<ResponseToken, Saml2Authentication>
+			createDefaultResponseAuthenticationConverter() {
+		return responseToken -> {
+			Saml2AuthenticationToken token = responseToken.token;
+			Response response = responseToken.response;
+			Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
+			String username = assertion.getSubject().getNameID().getValue();
+			Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
+			return new Saml2Authentication(
+					new DefaultSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(),
+					Collections.singleton(new SimpleGrantedAuthority("ROLE_USER")));
+		};
+	}
+
 	/**
 	 * @param authentication the authentication request object, must be of type
 	 *                       {@link Saml2AuthenticationToken}
@@ -300,7 +354,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 			String serializedResponse = token.getSaml2Response();
 			Response response = parse(serializedResponse);
 			process(token, response);
-			return this.authenticationConverter.apply(token).convert(response);
+			return this.responseAuthenticationConverter.convert(new ResponseToken(response, token));
 		} catch (Saml2AuthenticationException e) {
 			throw e;
 		} catch (Exception e) {
@@ -496,7 +550,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		}
 	}
 
-	private Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
+	private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
 		Map<String, List<Object>> attributeMap = new LinkedHashMap<>();
 		for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
 			for (Attribute attribute : attributeStatement.getAttributes()) {
@@ -515,7 +569,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		return attributeMap;
 	}
 
-	private Object getXmlObjectValue(XMLObject xmlObject) {
+	private static Object getXmlObjectValue(XMLObject xmlObject) {
 		if (xmlObject instanceof XSAny) {
 			return ((XSAny) xmlObject).getTextContent();
 		}
@@ -706,6 +760,29 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		return new Saml2AuthenticationException(validationError(code, description), cause);
 	}
 
+	/**
+	 * A tuple containing an OpenSAML {@link Response} and its associated authentication token.
+	 *
+	 * @since 5.4
+	 */
+	public static class ResponseToken {
+		private final Saml2AuthenticationToken token;
+		private final Response response;
+
+		ResponseToken(Response response, Saml2AuthenticationToken token) {
+			this.token = token;
+			this.response = response;
+		}
+
+		public Response getResponse() {
+			return this.response;
+		}
+
+		public Saml2AuthenticationToken getToken() {
+			return this.token;
+		}
+	}
+
 	/**
 	 * A tuple containing an OpenSAML {@link Assertion} and its associated authentication token.
 	 *

+ 38 - 2
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

@@ -77,17 +77,20 @@ import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParamete
 import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
 import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION;
 import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE;
+import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.success;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
 import static org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.createDefaultAssertionValidator;
+import static org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.createDefaultResponseAuthenticationConverter;
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response;
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed;
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signedResponseWithOneAssertion;
 import static org.springframework.util.StringUtils.hasText;
 
 /**
@@ -103,6 +106,10 @@ public class OpenSamlAuthenticationProviderTests {
 	private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp";
 
 	private OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	private Saml2AuthenticatedPrincipal principal = new DefaultSaml2AuthenticatedPrincipal
+			("name", Collections.emptyMap());
+	private Saml2Authentication authentication = new Saml2Authentication
+			(this.principal, "response", Collections.emptyList());
 
 	@Rule
 	public ExpectedException exception = ExpectedException.none();
@@ -380,7 +387,7 @@ public class OpenSamlAuthenticationProviderTests {
 		signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
 		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
 		when(validator.convert(any(OpenSamlAuthenticationProvider.AssertionToken.class)))
-			.thenReturn(Saml2ResponseValidatorResult.success());
+			.thenReturn(success());
 		provider.authenticate(token);
 		verify(validator).convert(any(OpenSamlAuthenticationProvider.AssertionToken.class));
 	}
@@ -388,7 +395,7 @@ public class OpenSamlAuthenticationProviderTests {
 	@Test
 	public void authenticateWhenDefaultConditionValidatorNotUsedThenSignatureStillChecked() {
 		OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
-		provider.setAssertionValidator(assertionToken -> Saml2ResponseValidatorResult.success());
+		provider.setAssertionValidator(assertionToken -> success());
 		Response response = response();
 		Assertion assertion = assertion();
 		signed(assertion, relyingPartyDecryptingCredential(), RELYING_PARTY_ENTITY_ID); // broken signature
@@ -424,6 +431,35 @@ public class OpenSamlAuthenticationProviderTests {
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
+	@Test
+	public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() {
+		Response response = signedResponseWithOneAssertion();
+		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
+		OpenSamlAuthenticationProvider.ResponseToken responseToken =
+				new OpenSamlAuthenticationProvider.ResponseToken(response, token);
+		Saml2Authentication authentication = createDefaultResponseAuthenticationConverter()
+				.convert(responseToken);
+		assertThat(authentication.getName()).isEqualTo("test@saml.user");
+	}
+
+	@Test
+	public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() {
+		Converter<OpenSamlAuthenticationProvider.ResponseToken, Saml2Authentication> authenticationConverter =
+				mock(Converter.class);
+		OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+		provider.setResponseAuthenticationConverter(authenticationConverter);
+		Response response = signedResponseWithOneAssertion();
+		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
+		provider.authenticate(token);
+		verify(authenticationConverter).convert(any());
+	}
+
+	@Test
+	public void setResponseAuthenticationConverterWhenNullThenIllegalArgument() {
+		assertThatCode(() -> this.provider.setResponseAuthenticationConverter(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
 	private <T extends XMLObject> T build(QName qName) {
 		return (T) getBuilderFactory().getBuilder(qName).buildObject(qName);
 	}

+ 7 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java

@@ -79,6 +79,7 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService;
 import org.springframework.security.saml2.core.Saml2X509Credential;
 
 import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
+import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential;
 
 final class TestOpenSamlObjects {
 	static {
@@ -107,6 +108,12 @@ final class TestOpenSamlObjects {
 		return response;
 	}
 
+	static Response signedResponseWithOneAssertion() {
+		Response response = response();
+		response.getAssertions().add(assertion());
+		return signed(response, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
+	}
+
 	static Assertion assertion() {
 		return assertion(USERNAME, ASSERTING_PARTY_ENTITY_ID, RELYING_PARTY_ENTITY_ID, DESTINATION);
 	}
@@ -135,7 +142,6 @@ final class TestOpenSamlObjects {
 		return assertion;
 	}
 
-
 	static Issuer issuer(String entityId) {
 		Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME);
 		issuer.setValue(entityId);