|
@@ -42,8 +42,11 @@ import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
|
|
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.ArgumentMatchers.eq;
|
|
|
+import static org.mockito.ArgumentMatchers.isNull;
|
|
|
import static org.mockito.BDDMockito.given;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.verify;
|
|
|
|
|
|
@ExtendWith(MockitoExtension.class)
|
|
|
public class Saml2AuthenticationTokenConverterTests {
|
|
@@ -69,6 +72,21 @@ public class Saml2AuthenticationTokenConverterTests {
|
|
|
.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void convertWhenSamlResponseWithRelyingPartyRegistrationResolver(
|
|
|
+ @Mock RelyingPartyRegistrationResolver resolver) {
|
|
|
+ Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver);
|
|
|
+ given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration);
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest();
|
|
|
+ request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
|
|
|
+ Saml2Utils.samlEncodeNotRfc2045("response".getBytes(StandardCharsets.UTF_8)));
|
|
|
+ Saml2AuthenticationToken token = converter.convert(request);
|
|
|
+ assertThat(token.getSaml2Response()).isEqualTo("response");
|
|
|
+ assertThat(token.getRelyingPartyRegistration().getRegistrationId())
|
|
|
+ .isEqualTo(this.relyingPartyRegistration.getRegistrationId());
|
|
|
+ verify(resolver).resolve(any(), isNull());
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() {
|
|
|
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
|
|
@@ -157,6 +175,8 @@ public class Saml2AuthenticationTokenConverterTests {
|
|
|
Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
|
|
|
Saml2AuthenticationRequestRepository.class);
|
|
|
AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
|
|
|
+ given(authenticationRequest.getRelyingPartyRegistrationId())
|
|
|
+ .willReturn(this.relyingPartyRegistration.getRegistrationId());
|
|
|
Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
|
|
|
this.relyingPartyRegistrationResolver);
|
|
|
converter.setAuthenticationRequestRepository(authenticationRequestRepository);
|
|
@@ -174,6 +194,30 @@ public class Saml2AuthenticationTokenConverterTests {
|
|
|
assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void convertWhenSavedAuthenticationRequestThenTokenWithRelyingPartyRegistrationResolver(
|
|
|
+ @Mock RelyingPartyRegistrationResolver resolver) {
|
|
|
+ Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
|
|
|
+ Saml2AuthenticationRequestRepository.class);
|
|
|
+ AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
|
|
|
+ given(authenticationRequest.getRelyingPartyRegistrationId())
|
|
|
+ .willReturn(this.relyingPartyRegistration.getRegistrationId());
|
|
|
+ Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver);
|
|
|
+ converter.setAuthenticationRequestRepository(authenticationRequestRepository);
|
|
|
+ given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration);
|
|
|
+ given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class)))
|
|
|
+ .willReturn(authenticationRequest);
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest();
|
|
|
+ request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
|
|
|
+ Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
|
|
|
+ Saml2AuthenticationToken token = converter.convert(request);
|
|
|
+ assertThat(token.getSaml2Response()).isEqualTo("response");
|
|
|
+ assertThat(token.getRelyingPartyRegistration().getRegistrationId())
|
|
|
+ .isEqualTo(this.relyingPartyRegistration.getRegistrationId());
|
|
|
+ assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
|
|
|
+ verify(resolver).resolve(any(), eq(this.relyingPartyRegistration.getRegistrationId()));
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
public void constructorWhenResolverIsNullThenIllegalArgument() {
|
|
|
assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null));
|