Explorar el Código

SAML Response Reads EntityId

Closes gh-10243
Josh Cummings hace 2 años
padre
commit
dbdf04f151

+ 48 - 30
docs/modules/ROOT/pages/servlet/saml2/login/authentication.adoc

@@ -13,35 +13,11 @@ You can configure this in a number of ways including:
 
 To configure these, you'll use the `saml2Login#authenticationManager` method in the DSL.
 
-[[relyingpartyregistrationresolver-apply]]
-== Changing `RelyingPartyRegistration` Lookup
-
-`RelyingPartyRegistration` lookup is customized xref:servlet/saml2/login/overview.adoc#servlet-saml2login-rpr-relyingpartyregistrationresolver[in a `RelyingPartyRegistrationResolver`].
-
-To apply a `RelyingPartyRegistrationResolver` when processing `<saml2:Response>` payloads, you should first publish a `Saml2AuthenticationTokenConverter` bean like so:
-
-====
-.Java
-[source,java,role="primary"]
-----
-@Bean
-Saml2AuthenticationTokenConverter authenticationConverter(InMemoryRelyingPartyRegistrationRepository registrations) {
-	return new Saml2AuthenticationTokenConverter(new MyRelyingPartyRegistrationResolver(registrations));
-}
-----
-
-.Kotlin
-[source,kotlin,role="secondary"]
-----
-@Bean
-fun authenticationConverter(val registrations: InMemoryRelyingPartyRegistrationRepository): Saml2AuthenticationTokenConverter {
-	return Saml2AuthenticationTokenConverter(MyRelyingPartyRegistrationResolver(registrations));
-}
-----
-====
+[[saml2-response-processing-endpoint]]
+== Changing the SAML Response Processing Endpoint
 
-Recall that the Assertion Consumer Service URL is `+/saml2/login/sso/{registrationId}+` by default.
-If you are no longer wanting the `registrationId` in the URL, change it in the filter chain and in your relying party metadata:
+The default endpoint is `+/login/saml2/sso/{registrationId}+`.
+You can change this in the DSL and in the associated metadata like so:
 
 ====
 .Java
@@ -82,13 +58,55 @@ and:
 .Java
 [source,java,role="primary"]
 ----
-relyingPartyRegistrationBuilder.assertionConsumerServiceLocation("/saml2/login/sso")
+relyingPartyRegistrationBuilder.assertionConsumerServiceLocation("/saml/SSO")
+----
+
+.Kotlin
+[source,kotlin,role="secondary"]
+----
+relyingPartyRegistrationBuilder.assertionConsumerServiceLocation("/saml/SSO")
+----
+====
+
+[[relyingpartyregistrationresolver-apply]]
+== Changing `RelyingPartyRegistration` lookup
+
+By default, this converter will match against any associated `<saml2:AuthnRequest>` or any `registrationId` it finds in the URL.
+Or, if it cannot find one in either of those cases, then it attempts to look it up by the `<saml2:Response#Issuer>` element.
+
+There are a number of circumstances where you might need something more sophisticated, like if you are supporting `ARTIFACT` binding.
+In those cases, you can customize lookup through a custom `AuthenticationConverter`, which you can customize like so:
+
+====
+.Java
+[source,java,role="primary"]
+----
+@Bean
+SecurityFilterChain securityFilters(HttpSecurity http, AuthenticationConverter authenticationConverter) throws Exception {
+	http
+        // ...
+        .saml2Login((saml2) -> saml2.authenticationConverter(authenticationConverter))
+        // ...
+
+    return http.build();
+}
 ----
 
 .Kotlin
 [source,kotlin,role="secondary"]
 ----
-relyingPartyRegistrationBuilder.assertionConsumerServiceLocation("/saml2/login/sso")
+@Bean
+fun securityFilters(val http: HttpSecurity, val converter: AuthenticationConverter): SecurityFilterChain {
+	http {
+        // ...
+        .saml2Login {
+            authenticationConverter = converter
+        }
+        // ...
+    }
+
+    return http.build()
+}
 ----
 ====
 

+ 142 - 20
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
 
 package org.springframework.security.saml2.provider.service.web;
 
+import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
@@ -25,8 +26,18 @@ import java.util.zip.Inflater;
 import java.util.zip.InflaterOutputStream;
 
 import jakarta.servlet.http.HttpServletRequest;
+import net.shibboleth.utilities.java.support.xml.ParserPool;
+import org.opensaml.core.config.ConfigurationService;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
+import org.opensaml.saml.saml2.core.Response;
+import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller;
+import org.w3c.dom.Document;
+import org.w3c.dom.Element;
 
 import org.springframework.http.HttpMethod;
+import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.OpenSamlInitializationService;
 import org.springframework.security.saml2.core.Saml2Error;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
 import org.springframework.security.saml2.core.Saml2ParameterNames;
@@ -34,7 +45,12 @@ import org.springframework.security.saml2.provider.service.authentication.Abstra
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
 import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
+import org.springframework.security.web.util.matcher.OrRequestMatcher;
+import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
 
 /**
@@ -43,9 +59,13 @@ import org.springframework.util.Assert;
  * {@link org.springframework.security.authentication.AuthenticationManager}.
  *
  * @author Josh Cummings
- * @since 5.4
+ * @since 6.1
  */
-public final class Saml2AuthenticationTokenConverter implements AuthenticationConverter {
+public final class OpenSamlAuthenticationTokenConverter implements AuthenticationConverter {
+
+	static {
+		OpenSamlInitializationService.initialize();
+	}
 
 	// MimeDecoder allows extra line-breaks as well as other non-alphabet values.
 	// This matches the behaviour of the commons-codec decoder.
@@ -53,39 +73,120 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 
 	private static final Base64Checker BASE_64_CHECKER = new Base64Checker();
 
-	private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
+	private final RelyingPartyRegistrationRepository registrations;
+
+	private RequestMatcher requestMatcher = new OrRequestMatcher(
+			new AntPathRequestMatcher("/login/saml2/sso/{registrationId}"),
+			new AntPathRequestMatcher("/login/saml2/sso"));
+
+	private final ParserPool parserPool;
+
+	private final ResponseUnmarshaller unmarshaller;
 
 	private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
 
 	/**
-	 * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
-	 * resolving {@link RelyingPartyRegistration}s
-	 * @param relyingPartyRegistrationResolver the strategy for resolving
+	 * Constructs a {@link OpenSamlAuthenticationTokenConverter} given a repository for
+	 * {@link RelyingPartyRegistration}s
+	 * @param registrations the repository for {@link RelyingPartyRegistration}s
 	 * {@link RelyingPartyRegistration}s
 	 */
-	public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
-		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
-		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
+	public OpenSamlAuthenticationTokenConverter(RelyingPartyRegistrationRepository registrations) {
+		Assert.notNull(registrations, "relyingPartyRegistrationRepository cannot be null");
+		XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
+		this.parserPool = registry.getParserPool();
+		this.unmarshaller = (ResponseUnmarshaller) XMLObjectProviderRegistrySupport.getUnmarshallerFactory()
+				.getUnmarshaller(Response.DEFAULT_ELEMENT_NAME);
+		this.registrations = registrations;
 		this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
 	}
 
+	/**
+	 * Resolve an authentication request from the given {@link HttpServletRequest}.
+	 *
+	 * <p>
+	 * First uses the configured {@link RequestMatcher} to deduce whether an
+	 * authentication request is being made and optionally for which
+	 * {@code registrationId}.
+	 *
+	 * <p>
+	 * If there is an associated {@code <saml2:AuthnRequest>}, then the
+	 * {@code registrationId} is looked up and used.
+	 *
+	 * <p>
+	 * If a {@code registrationId} is found in the request, then it is looked up and used.
+	 * In that case, if none is found a {@link Saml2AuthenticationException} is thrown.
+	 *
+	 * <p>
+	 * Finally, if no {@code registrationId} is found in the request, then the code
+	 * attempts to resolve the {@link RelyingPartyRegistration} from the SAML Response's
+	 * Issuer.
+	 * @param request the HTTP request
+	 * @return the {@link Saml2AuthenticationToken} authentication request
+	 * @throws Saml2AuthenticationException if the {@link RequestMatcher} specifies a
+	 * non-existent {@code registrationId}
+	 */
 	@Override
 	public Saml2AuthenticationToken convert(HttpServletRequest request) {
+		String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
+		if (serialized == null) {
+			return null;
+		}
+		RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
+		if (!result.isMatch()) {
+			return null;
+		}
+		Saml2AuthenticationToken token = tokenByAuthenticationRequest(request);
+		if (token == null) {
+			token = tokenByRegistrationId(request, result);
+		}
+		if (token == null) {
+			token = tokenByEntityId(request);
+		}
+		return token;
+	}
+
+	private Saml2AuthenticationToken tokenByAuthenticationRequest(HttpServletRequest request) {
 		AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
-		String relyingPartyRegistrationId = (authenticationRequest != null)
-				? authenticationRequest.getRelyingPartyRegistrationId() : null;
-		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
-				relyingPartyRegistrationId);
-		if (relyingPartyRegistration == null) {
+		if (authenticationRequest == null) {
+			return null;
+		}
+		String registrationId = authenticationRequest.getRelyingPartyRegistrationId();
+		RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId);
+		return tokenByRegistration(request, registration, authenticationRequest);
+	}
+
+	private Saml2AuthenticationToken tokenByRegistrationId(HttpServletRequest request,
+			RequestMatcher.MatchResult result) {
+		String registrationId = result.getVariables().get("registrationId");
+		if (registrationId == null) {
 			return null;
 		}
-		String saml2Response = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
-		if (saml2Response == null) {
+		RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId);
+		return tokenByRegistration(request, registration, null);
+	}
+
+	private Saml2AuthenticationToken tokenByEntityId(HttpServletRequest request) {
+		String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
+		String decoded = new String(samlDecode(serialized), StandardCharsets.UTF_8);
+		Response response = parse(decoded);
+		String issuer = response.getIssuer().getValue();
+		RelyingPartyRegistration registration = this.registrations.findUniqueByAssertingPartyEntityId(issuer);
+		return tokenByRegistration(request, registration, null);
+	}
+
+	private Saml2AuthenticationToken tokenByRegistration(HttpServletRequest request,
+			RelyingPartyRegistration registration, AbstractSaml2AuthenticationRequest authenticationRequest) {
+		if (registration == null) {
 			return null;
 		}
-		byte[] b = samlDecode(saml2Response);
-		saml2Response = inflateIfRequired(request, b);
-		return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
+		String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
+		String decoded = inflateIfRequired(request, samlDecode(serialized));
+		UriResolver resolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
+		registration = registration.mutate().entityId(resolver.resolve(registration.getEntityId()))
+				.assertionConsumerServiceLocation(resolver.resolve(registration.getAssertionConsumerServiceLocation()))
+				.build();
+		return new Saml2AuthenticationToken(registration, decoded, authenticationRequest);
 	}
 
 	/**
@@ -100,6 +201,15 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 		this.loader = authenticationRequestRepository::loadAuthenticationRequest;
 	}
 
+	/**
+	 * Use the given {@link RequestMatcher} to match the request.
+	 * @param requestMatcher the {@link RequestMatcher} to use
+	 */
+	public void setRequestMatcher(RequestMatcher requestMatcher) {
+		Assert.notNull(requestMatcher, "requestMatcher cannot be null");
+		this.requestMatcher = requestMatcher;
+	}
+
 	private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
 		return this.loader.apply(request);
 	}
@@ -136,6 +246,18 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 		}
 	}
 
+	private Response parse(String request) throws Saml2Exception {
+		try {
+			Document document = this.parserPool
+					.parse(new ByteArrayInputStream(request.getBytes(StandardCharsets.UTF_8)));
+			Element element = document.getDocumentElement();
+			return (Response) this.unmarshaller.unmarshall(element);
+		}
+		catch (Exception ex) {
+			throw new Saml2Exception("Failed to deserialize LogoutRequest", ex);
+		}
+	}
+
 	static class Base64Checker {
 
 		private static final int[] values = genValueMapping();

+ 2 - 2
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java

@@ -105,7 +105,7 @@ public final class TestOpenSamlObjects {
 
 	public static String RELYING_PARTY_ENTITY_ID = "https://localhost/saml2/service-provider-metadata/idp-alias";
 
-	private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp";
+	public static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp";
 
 	private static SecretKey SECRET_KEY = new SecretKeySpec(
 			Base64.getDecoder().decode("shOnwNMoCv88HKMEa91+FlYoD5RNvzMTAL5LGxZKIFk="), "AES");
@@ -113,7 +113,7 @@ public final class TestOpenSamlObjects {
 	private TestOpenSamlObjects() {
 	}
 
-	static Response response() {
+	public static Response response() {
 		return response(DESTINATION, ASSERTING_PARTY_ENTITY_ID);
 	}
 

+ 258 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverterTests.java

@@ -0,0 +1,258 @@
+/*
+ * Copyright 2002-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.time.Instant;
+
+import jakarta.servlet.http.HttpServletRequest;
+import net.shibboleth.utilities.java.support.xml.SerializeSupport;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensaml.core.xml.XMLObject;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
+import org.opensaml.core.xml.io.Marshaller;
+import org.opensaml.core.xml.io.MarshallingException;
+import org.opensaml.saml.common.SignableSAMLObject;
+import org.opensaml.saml.saml2.core.Response;
+import org.w3c.dom.Element;
+
+import org.springframework.core.io.ClassPathResource;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.Saml2ErrorCodes;
+import org.springframework.security.saml2.core.Saml2ParameterNames;
+import org.springframework.security.saml2.core.Saml2Utils;
+import org.springframework.security.saml2.core.TestSaml2X509Credentials;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
+import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
+import org.springframework.util.StreamUtils;
+import org.springframework.web.util.UriUtils;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link OpenSamlAuthenticationTokenConverter}
+ */
+@ExtendWith(MockitoExtension.class)
+public final class OpenSamlAuthenticationTokenConverterTests {
+
+	@Mock
+	RelyingPartyRegistrationRepository registrations;
+
+	RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
+
+	@Test
+	public void convertWhenSamlResponseThenToken() {
+		OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations);
+		given(this.registrations.findByRegistrationId(any())).willReturn(this.registration);
+		MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId());
+		request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
+				Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
+		Saml2AuthenticationToken token = converter.convert(request);
+		assertThat(token.getSaml2Response()).isEqualTo("response");
+		assertThat(token.getRelyingPartyRegistration().getRegistrationId())
+				.isEqualTo(this.registration.getRegistrationId());
+	}
+
+	@Test
+	public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() {
+		OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations);
+		given(this.registrations.findByRegistrationId(any())).willReturn(this.registration);
+		MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId());
+		request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "invalid");
+		assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request))
+				.withCauseInstanceOf(IllegalArgumentException.class)
+				.satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode())
+						.isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE))
+				.satisfies((ex) -> assertThat(ex.getSaml2Error().getDescription())
+						.isEqualTo("Failed to decode SAMLResponse"));
+	}
+
+	@Test
+	public void convertWhenNoSamlResponseThenNull() {
+		OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations);
+		MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId());
+		assertThat(converter.convert(request)).isNull();
+	}
+
+	@Test
+	public void convertWhenNoMatchingRequestThenNull() {
+		OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "ignored");
+		assertThat(converter.convert(request)).isNull();
+	}
+
+	@Test
+	public void convertWhenNoRelyingPartyRegistrationThenNull() {
+		OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations);
+		MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId());
+		String response = Saml2Utils.samlEncode(serialize(signed(response())).getBytes(StandardCharsets.UTF_8));
+		request.setParameter(Saml2ParameterNames.SAML_RESPONSE, response);
+		assertThat(converter.convert(request)).isNull();
+	}
+
+	@Test
+	public void convertWhenGetRequestThenInflates() {
+		OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations);
+		given(this.registrations.findByRegistrationId(any())).willReturn(this.registration);
+		MockHttpServletRequest request = get("/login/saml2/sso/" + this.registration.getRegistrationId());
+		byte[] deflated = Saml2Utils.samlDeflate("response");
+		String encoded = Saml2Utils.samlEncode(deflated);
+		request.setParameter(Saml2ParameterNames.SAML_RESPONSE, encoded);
+		Saml2AuthenticationToken token = converter.convert(request);
+		assertThat(token.getSaml2Response()).isEqualTo("response");
+		assertThat(token.getRelyingPartyRegistration().getRegistrationId())
+				.isEqualTo(this.registration.getRegistrationId());
+	}
+
+	@Test
+	public void convertWhenGetRequestInvalidDeflatedThenSaml2AuthenticationException() {
+		OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations);
+		given(this.registrations.findByRegistrationId(any())).willReturn(this.registration);
+		MockHttpServletRequest request = get("/login/saml2/sso/" + this.registration.getRegistrationId());
+		byte[] invalidDeflated = "invalid".getBytes();
+		String encoded = Saml2Utils.samlEncode(invalidDeflated);
+		request.setParameter(Saml2ParameterNames.SAML_RESPONSE, encoded);
+		assertThatExceptionOfType(Saml2AuthenticationException.class).isThrownBy(() -> converter.convert(request))
+				.withCauseInstanceOf(IOException.class)
+				.satisfies((ex) -> assertThat(ex.getSaml2Error().getErrorCode())
+						.isEqualTo(Saml2ErrorCodes.INVALID_RESPONSE))
+				.satisfies(
+						(ex) -> assertThat(ex.getSaml2Error().getDescription()).isEqualTo("Unable to inflate string"));
+	}
+
+	@Test
+	public void convertWhenUsingSamlUtilsBase64ThenXmlIsValid() throws Exception {
+		OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations);
+		given(this.registrations.findByRegistrationId(any())).willReturn(this.registration);
+		MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId());
+		request.setParameter(Saml2ParameterNames.SAML_RESPONSE, getSsoCircleEncodedXml());
+		Saml2AuthenticationToken token = converter.convert(request);
+		validateSsoCircleXml(token.getSaml2Response());
+	}
+
+	@Test
+	public void convertWhenSavedAuthenticationRequestThenToken() {
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+				Saml2AuthenticationRequestRepository.class);
+		AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		given(authenticationRequest.getRelyingPartyRegistrationId()).willReturn(this.registration.getRegistrationId());
+		OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations);
+		converter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		given(this.registrations.findByRegistrationId(any())).willReturn(this.registration);
+		given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class)))
+				.willReturn(authenticationRequest);
+		MockHttpServletRequest request = post("/login/saml2/sso/" + this.registration.getRegistrationId());
+		request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
+				Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
+		Saml2AuthenticationToken token = converter.convert(request);
+		assertThat(token.getSaml2Response()).isEqualTo("response");
+		assertThat(token.getRelyingPartyRegistration().getRegistrationId())
+				.isEqualTo(this.registration.getRegistrationId());
+		assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
+	}
+
+	@Test
+	public void convertWhenMatchingNoRegistrationIdThenLooksUpByAssertingEntityId() {
+		OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations);
+		String response = serialize(signed(response()));
+		String encoded = Saml2Utils.samlEncode(response.getBytes(StandardCharsets.UTF_8));
+		given(this.registrations.findUniqueByAssertingPartyEntityId(TestOpenSamlObjects.ASSERTING_PARTY_ENTITY_ID))
+				.willReturn(this.registration);
+		MockHttpServletRequest request = post("/login/saml2/sso");
+		request.setParameter(Saml2ParameterNames.SAML_RESPONSE, encoded);
+		Saml2AuthenticationToken token = converter.convert(request);
+		assertThat(token.getSaml2Response()).isEqualTo(response);
+		assertThat(token.getRelyingPartyRegistration().getRegistrationId())
+				.isEqualTo(this.registration.getRegistrationId());
+	}
+
+	@Test
+	public void constructorWhenResolverIsNullThenIllegalArgument() {
+		assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null));
+	}
+
+	@Test
+	public void setAuthenticationRequestRepositoryWhenNullThenIllegalArgument() {
+		OpenSamlAuthenticationTokenConverter converter = new OpenSamlAuthenticationTokenConverter(this.registrations);
+		assertThatExceptionOfType(IllegalArgumentException.class)
+				.isThrownBy(() -> converter.setAuthenticationRequestRepository(null));
+	}
+
+	private void validateSsoCircleXml(String xml) {
+		assertThat(xml).contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"")
+				.contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"")
+				.contains("<saml:Issuer>https://idp.ssocircle.com</saml:Issuer>");
+	}
+
+	private String getSsoCircleEncodedXml() throws IOException {
+		ClassPathResource resource = new ClassPathResource("saml2-response-sso-circle.encoded");
+		String response = StreamUtils.copyToString(resource.getInputStream(), StandardCharsets.UTF_8);
+		return UriUtils.decode(response, StandardCharsets.UTF_8);
+	}
+
+	private MockHttpServletRequest post(String uri) {
+		MockHttpServletRequest request = new MockHttpServletRequest("POST", uri);
+		request.setServletPath(uri);
+		return request;
+	}
+
+	private MockHttpServletRequest get(String uri) {
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", uri);
+		request.setServletPath(uri);
+		return request;
+	}
+
+	private <T extends SignableSAMLObject> T signed(T toSign) {
+		TestOpenSamlObjects.signed(toSign, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				TestOpenSamlObjects.RELYING_PARTY_ENTITY_ID);
+		return toSign;
+	}
+
+	private Response response() {
+		Response response = TestOpenSamlObjects.response();
+		response.setIssueInstant(Instant.now());
+		return response;
+	}
+
+	private String serialize(XMLObject object) {
+		try {
+			Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object);
+			Element element = marshaller.marshall(object);
+			return SerializeSupport.nodeToString(element);
+		}
+		catch (MarshallingException ex) {
+			throw new Saml2Exception(ex);
+		}
+	}
+
+}