|
@@ -17,8 +17,6 @@
|
|
|
package org.springframework.security.saml2.provider.service.authentication;
|
|
|
|
|
|
import java.io.ByteArrayInputStream;
|
|
|
-import java.util.function.Consumer;
|
|
|
-import java.util.function.Function;
|
|
|
|
|
|
import org.junit.Assert;
|
|
|
import org.junit.Before;
|
|
@@ -31,6 +29,7 @@ import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
|
|
|
import org.w3c.dom.Document;
|
|
|
import org.w3c.dom.Element;
|
|
|
|
|
|
+import org.springframework.core.convert.converter.Converter;
|
|
|
import org.springframework.security.saml2.Saml2Exception;
|
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
|
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
|
|
@@ -47,6 +46,7 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getU
|
|
|
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
|
|
|
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
|
|
|
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
|
|
|
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
|
|
|
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
|
|
|
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
|
|
|
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
|
|
@@ -63,8 +63,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
|
|
|
private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
|
|
|
private RelyingPartyRegistration relyingPartyRegistration;
|
|
|
|
|
|
- private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) getUnmarshallerFactory()
|
|
|
- .getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
|
|
|
+ private AuthnRequestUnmarshaller unmarshaller;
|
|
|
|
|
|
@Rule
|
|
|
public ExpectedException exception = ExpectedException.none();
|
|
@@ -84,6 +83,8 @@ public class OpenSamlAuthenticationRequestFactoryTests {
|
|
|
.assertionConsumerServiceUrl("https://issuer/sso");
|
|
|
context = contextBuilder.build();
|
|
|
factory = new OpenSamlAuthenticationRequestFactory();
|
|
|
+ this.unmarshaller =(AuthnRequestUnmarshaller) getUnmarshallerFactory()
|
|
|
+ .getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -182,29 +183,29 @@ public class OpenSamlAuthenticationRequestFactoryTests {
|
|
|
|
|
|
@Test
|
|
|
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
|
|
- Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
|
|
|
- mock(Function.class);
|
|
|
- when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
|
|
|
- this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
|
|
|
+ Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
|
|
|
+ mock(Converter.class);
|
|
|
+ when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
|
|
|
+ this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
|
|
|
|
|
|
this.factory.createPostAuthenticationRequest(this.context);
|
|
|
- verify(authnRequestConsumerResolver).apply(this.context);
|
|
|
+ verify(authenticationRequestContextConverter).convert(this.context);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
|
|
|
- Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
|
|
|
- mock(Function.class);
|
|
|
- when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
|
|
|
- this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
|
|
|
+ Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
|
|
|
+ mock(Converter.class);
|
|
|
+ when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
|
|
|
+ this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
|
|
|
|
|
|
this.factory.createRedirectAuthenticationRequest(this.context);
|
|
|
- verify(authnRequestConsumerResolver).apply(this.context);
|
|
|
+ verify(authenticationRequestContextConverter).convert(this.context);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void setAuthnRequestConsumerResolverWhenNullThenException() {
|
|
|
- assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
|
|
|
+ public void setAuthenticationRequestContextConverterWhenNullThenException() {
|
|
|
+ assertThatCode(() -> this.factory.setAuthenticationRequestContextConverter(null))
|
|
|
.isInstanceOf(IllegalArgumentException.class);
|
|
|
}
|
|
|
|