|
@@ -44,8 +44,11 @@ import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
|
|
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.ArgumentMatchers.eq;
|
|
|
+import static org.mockito.ArgumentMatchers.isNull;
|
|
|
import static org.mockito.BDDMockito.given;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.verify;
|
|
|
|
|
|
@ExtendWith(MockitoExtension.class)
|
|
|
public class Saml2AuthenticationTokenConverterTests {
|
|
@@ -71,6 +74,21 @@ public class Saml2AuthenticationTokenConverterTests {
|
|
|
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void convertWhenSamlResponseWithRelyingPartyRegistrationResolver(
|
|
|
+ @Mock RelyingPartyRegistrationResolver resolver) {
|
|
|
+ Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver);
|
|
|
+ given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration);
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest();
|
|
|
+ request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
|
|
|
+ Saml2Utils.samlEncodeNotRfc2045("response".getBytes(StandardCharsets.UTF_8)));
|
|
|
+ Saml2AuthenticationToken token = converter.convert(request);
|
|
|
+ assertThat(token.getSaml2Response()).isEqualTo("response");
|
|
|
+ assertThat(token.getRelyingPartyRegistration().getRegistrationId())
|
|
|
+ .isEqualTo(this.relyingPartyRegistration.getRegistrationId());
|
|
|
+ verify(resolver).resolve(any(), isNull());
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() {
|
|
|
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
|
|
@@ -159,6 +177,8 @@ public class Saml2AuthenticationTokenConverterTests {
|
|
|
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
|
|
|
Saml2AuthenticationRequestRepository.class);
|
|
|
AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
|
|
|
+ given(authenticationRequest.getRelyingPartyRegistrationId())
|
|
|
+ .willReturn(this.relyingPartyRegistration.getRegistrationId());
|
|
|
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
|
|
|
this.relyingPartyRegistrationResolver);
|
|
|
converter.setAuthenticationRequestRepository(authenticationRequestRepository);
|
|
@@ -176,6 +196,30 @@ public class Saml2AuthenticationTokenConverterTests {
|
|
|
assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void convertWhenSavedAuthenticationRequestThenTokenWithRelyingPartyRegistrationResolver(
|
|
|
+ @Mock RelyingPartyRegistrationResolver resolver) {
|
|
|
+ Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
|
|
|
+ Saml2AuthenticationRequestRepository.class);
|
|
|
+ AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
|
|
|
+ given(authenticationRequest.getRelyingPartyRegistrationId())
|
|
|
+ .willReturn(this.relyingPartyRegistration.getRegistrationId());
|
|
|
+ Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver);
|
|
|
+ converter.setAuthenticationRequestRepository(authenticationRequestRepository);
|
|
|
+ given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration);
|
|
|
+ given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class)))
|
|
|
+ .willReturn(authenticationRequest);
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest();
|
|
|
+ request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
|
|
|
+ Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
|
|
|
+ Saml2AuthenticationToken token = converter.convert(request);
|
|
|
+ assertThat(token.getSaml2Response()).isEqualTo("response");
|
|
|
+ assertThat(token.getRelyingPartyRegistration().getRegistrationId())
|
|
|
+ .isEqualTo(this.relyingPartyRegistration.getRegistrationId());
|
|
|
+ assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
|
|
|
+ verify(resolver).resolve(any(), eq(this.relyingPartyRegistration.getRegistrationId()));
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
public void constructorWhenResolverIsNullThenIllegalArgument() {
|
|
|
assertThatIllegalArgumentException()
|