2
0
Эх сурвалжийг харах

Polish AuthnRequest Customization Support

Having the application generate the AuthnRequest fresh allows Spring
Security to back away more gracefully. Using a Consumer implies that
the application will need to undo any values that Spring Security set
that the application doesn't want.

Also, if this does become a configuration burden, it can be simplified
in a separate ticket by exposing the default Converter.

Issue gh-8776
Josh Cummings 5 жил өмнө
parent
commit
af5c55c380

+ 11 - 5
config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

@@ -35,6 +35,7 @@ import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.opensaml.saml.saml2.core.Assertion;
+import org.opensaml.saml.saml2.core.AuthnRequest;
 
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.ConfigurableApplicationContext;
@@ -89,6 +90,7 @@ import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.config.Customizer.withDefaults;
 import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
 import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
 import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
 import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
@@ -176,8 +178,8 @@ public class Saml2LoginConfigurerTests {
 	}
 
 	@Test
-	public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception {
-		this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire();
+	public void authenticationRequestWhenAuthnRequestContextConverterThenUses() throws Exception {
+		this.spring.register(CustomAuthenticationRequestContextConverterResolver.class).autowire();
 
 		MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id"))
 				.andReturn();
@@ -315,7 +317,7 @@ public class Saml2LoginConfigurerTests {
 
 	@EnableWebSecurity
 	@Import(Saml2LoginConfigBeans.class)
-	static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter {
+	static class CustomAuthenticationRequestContextConverterResolver extends WebSecurityConfigurerAdapter {
 
 		@Override
 		protected void configure(HttpSecurity http) throws Exception {
@@ -330,8 +332,12 @@ public class Saml2LoginConfigurerTests {
 		Saml2AuthenticationRequestFactory authenticationRequestFactory() {
 			OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
 					new OpenSamlAuthenticationRequestFactory();
-			authenticationRequestFactory.setAuthnRequestConsumerResolver(
-					context -> authnRequest -> authnRequest.setForceAuthn(true));
+			authenticationRequestFactory.setAuthenticationRequestContextConverter(
+					context -> {
+						AuthnRequest authnRequest = authnRequest();
+						authnRequest.setForceAuthn(true);
+						return authnRequest;
+					});
 			return authenticationRequestFactory;
 		}
 	}

+ 10 - 14
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java

@@ -25,8 +25,6 @@ import java.util.Collection;
 import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.UUID;
-import java.util.function.Consumer;
-import java.util.function.Function;
 
 import net.shibboleth.utilities.java.support.xml.SerializeSupport;
 import org.joda.time.DateTime;
@@ -88,8 +86,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 				return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn();
 			};
 
-	private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
-			= context -> authnRequest -> {};
+	private Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter
+			= this::createAuthnRequest;
 
 	/**
 	 * Creates an {@link OpenSamlAuthenticationRequestFactory}
@@ -124,7 +122,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	 */
 	@Override
 	public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
-		AuthnRequest authnRequest = createAuthnRequest(context);
+		AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
 		String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ?
 			serialize(sign(authnRequest, context.getRelyingPartyRegistration())) :
 			serialize(authnRequest);
@@ -139,7 +137,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	 */
 	@Override
 	public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
-		AuthnRequest authnRequest = createAuthnRequest(context);
+		AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
 		String xml = serialize(authnRequest);
 		Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
 		String deflatedAndEncoded = samlEncode(samlDeflate(xml));
@@ -168,11 +166,9 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	}
 
 	private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
-		AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
+		return createAuthnRequest(context.getIssuer(),
 				context.getDestination(), context.getAssertionConsumerServiceUrl(),
 				this.protocolBindingResolver.convert(context));
-		this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
-		return authnRequest;
 	}
 
 	private AuthnRequest createAuthnRequest
@@ -194,13 +190,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	/**
 	 * Set the {@link AuthnRequest} post-processor resolver
 	 *
-	 * @param authnRequestConsumerResolver
+	 * @param authenticationRequestContextConverter
 	 * @since 5.4
 	 */
-	public void setAuthnRequestConsumerResolver(
-			Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) {
-		Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null");
-		this.authnRequestConsumerResolver = authnRequestConsumerResolver;
+	public void setAuthenticationRequestContextConverter(
+			Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter) {
+		Assert.notNull(authenticationRequestContextConverter, "authenticationRequestContextConverter cannot be null");
+		this.authenticationRequestContextConverter = authenticationRequestContextConverter;
 	}
 
 	/**

+ 17 - 16
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java

@@ -17,8 +17,6 @@
 package org.springframework.security.saml2.provider.service.authentication;
 
 import java.io.ByteArrayInputStream;
-import java.util.function.Consumer;
-import java.util.function.Function;
 
 import org.junit.Assert;
 import org.junit.Before;
@@ -31,6 +29,7 @@ import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
 import org.w3c.dom.Document;
 import org.w3c.dom.Element;
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
@@ -47,6 +46,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getU
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
 import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
 import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
 import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
 import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
 import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
@@ -63,8 +63,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 	private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
 	private RelyingPartyRegistration relyingPartyRegistration;
 
-	private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) getUnmarshallerFactory()
-			.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
+	private AuthnRequestUnmarshaller unmarshaller;
 
 	@Rule
 	public ExpectedException exception = ExpectedException.none();
@@ -84,6 +83,8 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 				.assertionConsumerServiceUrl("https://issuer/sso");
 		context = contextBuilder.build();
 		factory = new OpenSamlAuthenticationRequestFactory();
+		this.unmarshaller =(AuthnRequestUnmarshaller) getUnmarshallerFactory()
+				.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
 	}
 
 	@Test
@@ -182,29 +183,29 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 
 	@Test
 	public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
-		Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
-				mock(Function.class);
-		when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
-		this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
+		Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
+				mock(Converter.class);
+		when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
+		this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
 
 		this.factory.createPostAuthenticationRequest(this.context);
-		verify(authnRequestConsumerResolver).apply(this.context);
+		verify(authenticationRequestContextConverter).convert(this.context);
 	}
 
 	@Test
 	public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
-		Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
-				mock(Function.class);
-		when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
-		this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
+		Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
+				mock(Converter.class);
+		when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
+		this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
 
 		this.factory.createRedirectAuthenticationRequest(this.context);
-		verify(authnRequestConsumerResolver).apply(this.context);
+		verify(authenticationRequestContextConverter).convert(this.context);
 	}
 
 	@Test
-	public void setAuthnRequestConsumerResolverWhenNullThenException() {
-		assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
+	public void setAuthenticationRequestContextConverterWhenNullThenException() {
+		assertThatCode(() -> this.factory.setAuthenticationRequestContextConverter(null))
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 

+ 12 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java

@@ -53,6 +53,7 @@ import org.opensaml.saml.saml2.core.Assertion;
 import org.opensaml.saml.saml2.core.Attribute;
 import org.opensaml.saml.saml2.core.AttributeStatement;
 import org.opensaml.saml.saml2.core.AttributeValue;
+import org.opensaml.saml.saml2.core.AuthnRequest;
 import org.opensaml.saml.saml2.core.Conditions;
 import org.opensaml.saml.saml2.core.EncryptedAssertion;
 import org.opensaml.saml.saml2.core.EncryptedID;
@@ -86,7 +87,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getB
 import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
 import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential;
 
-final class TestOpenSamlObjects {
+public final class TestOpenSamlObjects {
 	static {
 		OpenSamlInitializationService.initialize();
 	}
@@ -188,6 +189,16 @@ final class TestOpenSamlObjects {
 		return conditions;
 	}
 
+	public static AuthnRequest authnRequest() {
+		Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME);
+		issuer.setValue(ASSERTING_PARTY_ENTITY_ID);
+		AuthnRequest authnRequest = build(AuthnRequest.DEFAULT_ELEMENT_NAME);
+		authnRequest.setIssuer(issuer);
+		authnRequest.setDestination(ASSERTING_PARTY_ENTITY_ID + "/SSO.saml2");
+		authnRequest.setAssertionConsumerServiceURL(DESTINATION);
+		return authnRequest;
+	}
+
 	static Credential getSigningCredential(Saml2X509Credential credential, String entityId) {
 		BasicCredential cred = getBasicCredential(credential);
 		cred.setEntityId(entityId);