فهرست منبع

Add Saml2AuthenticationRequestRepository

Closes gh-9185
Marcus Da Coregio 4 سال پیش
والد
کامیت
16e17d242e
15فایلهای تغییر یافته به همراه655 افزوده شده و 17 حذف شده
  1. 26 3
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java
  2. 48 1
      config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java
  3. 30 0
      docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc
  4. 1 0
      saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle
  5. 37 4
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java
  6. 73 0
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepository.java
  7. 60 0
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/Saml2AuthenticationRequestRepository.java
  8. 32 1
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java
  9. 26 5
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java
  10. 26 1
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java
  11. 38 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationTokens.java
  12. 143 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepositoryTests.java
  13. 50 1
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java
  14. 41 1
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java
  15. 24 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java

+ 26 - 3
config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ import org.springframework.security.config.annotation.web.configurers.AbstractAu
 import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
 import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
@@ -40,6 +41,8 @@ import org.springframework.security.saml2.provider.service.authentication.OpenSa
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
@@ -206,6 +209,7 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		}
 		this.saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter(getAuthenticationConverter(http),
 				this.loginProcessingUrl);
+		setAuthenticationRequestRepository(http, this.saml2WebSsoAuthenticationFilter);
 		setAuthenticationFilter(this.saml2WebSsoAuthenticationFilter);
 		super.loginProcessingUrl(this.loginProcessingUrl);
 		if (StringUtils.hasText(this.loginPage)) {
@@ -252,6 +256,11 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		}
 	}
 
+	private void setAuthenticationRequestRepository(B http,
+			Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter) {
+		saml2WebSsoAuthenticationFilter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
+	}
+
 	private AuthenticationConverter getAuthenticationConverter(B http) {
 		if (this.authenticationConverter == null) {
 			return new Saml2AuthenticationTokenConverter(
@@ -311,6 +320,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		return idps;
 	}
 
+	private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> getAuthenticationRequestRepository(
+			B http) {
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getBeanOrNull(http,
+				Saml2AuthenticationRequestRepository.class);
+		if (repository == null) {
+			return new HttpSessionSaml2AuthenticationRequestRepository();
+		}
+		return repository;
+	}
+
 	private <C> C getSharedOrBean(B http, Class<C> clazz) {
 		C shared = http.getSharedObject(clazz);
 		if (shared != null) {
@@ -348,8 +367,12 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		private Filter build(B http) {
 			Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
 			Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
-			return postProcess(
-					new Saml2WebSsoAuthenticationRequestFilter(contextResolver, authenticationRequestResolver));
+			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getAuthenticationRequestRepository(
+					http);
+			Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(contextResolver,
+					authenticationRequestResolver);
+			filter.setAuthenticationRequestRepository(repository);
+			return postProcess(filter);
 		}
 
 		private Saml2AuthenticationRequestFactory getResolver(B http) {

+ 48 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

@@ -63,6 +63,7 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
 import org.springframework.security.saml2.core.Saml2Utils;
 import org.springframework.security.saml2.core.TestSaml2X509Credentials;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
@@ -76,9 +77,11 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.web.FilterChainProxy;
+import org.springframework.security.web.SecurityFilterChain;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.context.HttpRequestResponseHolder;
@@ -237,6 +240,29 @@ public class Saml2LoginConfigurerTests {
 		assertThat(exception.getCause()).isInstanceOf(IOException.class);
 	}
 
+	@Test
+	public void authenticationRequestWhenCustomAuthnRequestRepositoryThenUses() throws Exception {
+		this.spring.register(CustomAuthenticationRequestRepository.class).autowire();
+		MockHttpServletRequestBuilder request = get("/saml2/authenticate/registration-id");
+		this.mvc.perform(request).andExpect(status().isFound());
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = this.spring.getContext()
+				.getBean(Saml2AuthenticationRequestRepository.class);
+		verify(repository).saveAuthenticationRequest(any(AbstractSaml2AuthenticationRequest.class),
+				any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
+	@Test
+	public void authenticateWhenCustomAuthnRequestRepositoryThenUses() throws Exception {
+		this.spring.register(CustomAuthenticationRequestRepository.class).autowire();
+		MockHttpServletRequestBuilder request = post("/login/saml2/sso/registration-id").param("SAMLResponse",
+				SIGNED_RESPONSE);
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = this.spring.getContext()
+				.getBean(Saml2AuthenticationRequestRepository.class);
+		this.mvc.perform(request);
+		verify(repository).loadAuthenticationRequest(any(HttpServletRequest.class));
+		verify(repository).removeAuthenticationRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
 	private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
 		// get the OpenSamlAuthenticationProvider
 		Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@@ -371,7 +397,7 @@ public class Saml2LoginConfigurerTests {
 
 		@Bean
 		Saml2AuthenticationRequestContextResolver resolver() {
-			return resolver;
+			return this.resolver;
 		}
 
 	}
@@ -420,6 +446,27 @@ public class Saml2LoginConfigurerTests {
 
 	}
 
+	@EnableWebSecurity
+	@Import(Saml2LoginConfigBeans.class)
+	static class CustomAuthenticationRequestRepository {
+
+		private static final Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = mock(
+				Saml2AuthenticationRequestRepository.class);
+
+		@Bean
+		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
+			http.authorizeRequests((authz) -> authz.anyRequest().authenticated());
+			http.saml2Login(withDefaults());
+			return http.build();
+		}
+
+		@Bean
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository() {
+			return this.repository;
+		}
+
+	}
+
 	static class Saml2LoginConfigBeans {
 
 		@Bean

+ 30 - 0
docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc

@@ -1610,3 +1610,33 @@ http {
 The success handler will send logout requests to the asserting party.
 
 The request matcher will detect logout requests from the asserting party.
+
+[[servlet-saml2login-store-authn-request]]
+=== Storing the `AuthnRequest`
+
+The `Saml2AuthenticationRequestRepository` is responsible for the persistence of the `AuthnRequest` from the time the `AuthnRequest` <<servlet-saml2login-sp-initiated-factory,is initiated>> to the time the `SAMLResponse` <<servlet-saml2login-authenticate-responses,is received>>.
+The `Saml2AuthenticationTokenConverter` is responsible for loading the `AuthnRequest` from the `Saml2AuthenticationRequestRepository` and saving it into the `Saml2AuthenticationToken`.
+
+The default implementation of `Saml2AuthenticationRequestRepository` is `HttpSessionSaml2AuthenticationRequestRepository`, which stores the `AuthnRequest` in the `HttpSession`.
+
+If you have a custom implementation of `Saml2AuthenticationRequestRepository`, you may configure it by exposing it as a `@Bean` as shown in the following example:
+
+====
+.Java
+[source,java,role="primary"]
+----
+@Bean
+Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository() {
+	return new CustomSaml2AuthenticationRequestRepository();
+}
+----
+
+.Kotlin
+[source,kotlin,role="secondary"]
+----
+@Bean
+open fun authenticationRequestRepository(): Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> {
+    return CustomSaml2AuthenticationRequestRepository()
+}
+----
+====

+ 1 - 0
saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle

@@ -55,6 +55,7 @@ dependencies {
 	testImplementation "org.junit.jupiter:junit-jupiter-params"
 	testImplementation "org.junit.jupiter:junit-jupiter-engine"
 	testImplementation "org.mockito:mockito-core"
+	testImplementation "org.mockito:mockito-inline"
 	testImplementation "org.mockito:mockito-junit-jupiter"
 	testImplementation "org.springframework:spring-test"
 }

+ 37 - 4
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -38,8 +38,10 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 
 	private final String saml2Response;
 
+	private final AbstractSaml2AuthenticationRequest authenticationRequest;
+
 	/**
-	 * Creates a {@link Saml2AuthenticationToken} with the provided parameters
+	 * Creates a {@link Saml2AuthenticationToken} with the provided parameters.
 	 *
 	 * Note that the given {@link RelyingPartyRegistration} should have all its templates
 	 * resolved at this point. See
@@ -48,15 +50,35 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 	 * @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to
 	 * use
 	 * @param saml2Response the SAML 2.0 response to authenticate
+	 * @param authenticationRequest the {@code AuthNRequest} sent to the asserting party
 	 *
-	 * @since 5.4
+	 * @since 5.6
 	 */
-	public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) {
+	public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response,
+			AbstractSaml2AuthenticationRequest authenticationRequest) {
 		super(Collections.emptyList());
 		Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null");
 		Assert.notNull(saml2Response, "saml2Response cannot be null");
 		this.relyingPartyRegistration = relyingPartyRegistration;
 		this.saml2Response = saml2Response;
+		this.authenticationRequest = authenticationRequest;
+	}
+
+	/**
+	 * Creates a {@link Saml2AuthenticationToken} with the provided parameters
+	 *
+	 * Note that the given {@link RelyingPartyRegistration} should have all its templates
+	 * resolved at this point. See
+	 * {@link org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter}
+	 * for an example of performing that resolution.
+	 * @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to
+	 * use
+	 * @param saml2Response the SAML 2.0 response to authenticate
+	 *
+	 * @since 5.4
+	 */
+	public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) {
+		this(relyingPartyRegistration, saml2Response, null);
 	}
 
 	/**
@@ -81,6 +103,7 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 						.entityId(idpEntityId).singleSignOnServiceLocation(idpEntityId))
 				.build();
 		this.saml2Response = saml2Response;
+		this.authenticationRequest = null;
 	}
 
 	/**
@@ -179,4 +202,14 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 		return this.relyingPartyRegistration.getAssertingPartyDetails().getEntityId();
 	}
 
+	/**
+	 * Returns the authentication request sent to the assertion party or {@code null} if
+	 * no authentication request is present
+	 * @return the authentication request sent to the assertion party
+	 * @since 5.6
+	 */
+	public AbstractSaml2AuthenticationRequest getAuthenticationRequest() {
+		return this.authenticationRequest;
+	}
+
 }

+ 73 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepository.java

@@ -0,0 +1,73 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.servlet;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
+
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+
+/**
+ * A {@link Saml2AuthenticationRequestRepository} implementation that uses
+ * {@link HttpSession} to store and retrieve the
+ * {@link AbstractSaml2AuthenticationRequest}
+ *
+ * @author Marcus Da Coregio
+ * @since 5.6
+ */
+public class HttpSessionSaml2AuthenticationRequestRepository
+		implements Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> {
+
+	private static final String DEFAULT_SAML2_AUTHN_REQUEST_ATTR_NAME = HttpSessionSaml2AuthenticationRequestRepository.class
+			.getName().concat(".SAML2_AUTHN_REQUEST");
+
+	private String saml2AuthnRequestAttributeName = DEFAULT_SAML2_AUTHN_REQUEST_ATTR_NAME;
+
+	@Override
+	public AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
+		HttpSession httpSession = request.getSession(false);
+		if (httpSession == null) {
+			return null;
+		}
+		return (AbstractSaml2AuthenticationRequest) httpSession.getAttribute(this.saml2AuthnRequestAttributeName);
+	}
+
+	@Override
+	public void saveAuthenticationRequest(AbstractSaml2AuthenticationRequest authenticationRequest,
+			HttpServletRequest request, HttpServletResponse response) {
+		if (authenticationRequest == null) {
+			removeAuthenticationRequest(request, response);
+			return;
+		}
+		HttpSession httpSession = request.getSession();
+		httpSession.setAttribute(this.saml2AuthnRequestAttributeName, authenticationRequest);
+	}
+
+	@Override
+	public AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request,
+			HttpServletResponse response) {
+		AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
+		if (authenticationRequest == null) {
+			return null;
+		}
+		HttpSession httpSession = request.getSession();
+		httpSession.removeAttribute(this.saml2AuthnRequestAttributeName);
+		return authenticationRequest;
+	}
+
+}

+ 60 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/Saml2AuthenticationRequestRepository.java

@@ -0,0 +1,60 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.servlet;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+
+/**
+ * A repository for {@link AbstractSaml2AuthenticationRequest}
+ *
+ * @param <T> the type of SAML 2.0 Authentication Request
+ * @author Marcus Da Coregio
+ * @since 5.6
+ */
+public interface Saml2AuthenticationRequestRepository<T extends AbstractSaml2AuthenticationRequest> {
+
+	/**
+	 * Loads the {@link AbstractSaml2AuthenticationRequest} from the request
+	 * @param request the current request
+	 * @return the {@link AbstractSaml2AuthenticationRequest} or {@code null} if it is not
+	 * present
+	 */
+	T loadAuthenticationRequest(HttpServletRequest request);
+
+	/**
+	 * Saves the current authentication request using the {@link HttpServletRequest} and
+	 * {@link HttpServletResponse}
+	 * @param authenticationRequest the {@link AbstractSaml2AuthenticationRequest}
+	 * @param request the current request
+	 * @param response the current response
+	 */
+	void saveAuthenticationRequest(T authenticationRequest, HttpServletRequest request, HttpServletResponse response);
+
+	/**
+	 * Removes the authentication request using the {@link HttpServletRequest} and
+	 * {@link HttpServletResponse}
+	 * @param request the current request
+	 * @param response the current response
+	 * @return the removed {@link AbstractSaml2AuthenticationRequest} or {@code null} if
+	 * it is not present
+	 */
+	T removeAuthenticationRequest(HttpServletRequest request, HttpServletResponse response);
+
+}

+ 32 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,8 +23,11 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.saml2.core.Saml2Error;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
@@ -42,6 +45,8 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
 
 	private final AuthenticationConverter authenticationConverter;
 
+	private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository();
+
 	/**
 	 * Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is
 	 * configured to use the {@link #DEFAULT_FILTER_PROCESSES_URI} processing URL
@@ -100,7 +105,33 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
 					"No relying party registration found");
 			throw new Saml2AuthenticationException(saml2Error);
 		}
+		this.authenticationRequestRepository.removeAuthenticationRequest(request, response);
 		return getAuthenticationManager().authenticate(authentication);
 	}
 
+	/**
+	 * Use the given {@link Saml2AuthenticationRequestRepository} to remove the saved
+	 * authentication request. If the {@link #authenticationConverter} is of the type
+	 * {@link Saml2AuthenticationTokenConverter}, the
+	 * {@link Saml2AuthenticationRequestRepository} will also be set into the
+	 * {@link #authenticationConverter}.
+	 * @param authenticationRequestRepository the
+	 * {@link Saml2AuthenticationRequestRepository} to use
+	 * @since 5.6
+	 */
+	public void setAuthenticationRequestRepository(
+			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
+		Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
+		this.authenticationRequestRepository = authenticationRequestRepository;
+		setAuthenticationRequestRepositoryIntoAuthenticationConverter(authenticationRequestRepository);
+	}
+
+	private void setAuthenticationRequestRepositoryIntoAuthenticationConverter(
+			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
+		if (this.authenticationConverter instanceof Saml2AuthenticationTokenConverter) {
+			Saml2AuthenticationTokenConverter authenticationTokenConverter = (Saml2AuthenticationTokenConverter) this.authenticationConverter;
+			authenticationTokenConverter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		}
+	}
+
 }

+ 26 - 5
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java

@@ -27,6 +27,7 @@ import javax.servlet.http.HttpServletResponse;
 import org.opensaml.core.Version;
 
 import org.springframework.http.MediaType;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
@@ -34,6 +35,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
+import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
@@ -79,6 +82,8 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 
 	private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
 
+	private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository();
+
 	/**
 	 * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided
 	 * parameters
@@ -149,6 +154,19 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 		this.redirectMatcher = redirectMatcher;
 	}
 
+	/**
+	 * Use the given {@link Saml2AuthenticationRequestRepository} to save the
+	 * authentication request
+	 * @param authenticationRequestRepository the
+	 * {@link Saml2AuthenticationRequestRepository} to use
+	 * @since 5.6
+	 */
+	public void setAuthenticationRequestRepository(
+			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
+		Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
+		this.authenticationRequestRepository = authenticationRequestRepository;
+	}
+
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
@@ -165,17 +183,18 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 		}
 		RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
 		if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
-			sendRedirect(response, context);
+			sendRedirect(request, response, context);
 		}
 		else {
-			sendPost(response, context);
+			sendPost(request, response, context);
 		}
 	}
 
-	private void sendRedirect(HttpServletResponse response, Saml2AuthenticationRequestContext context)
-			throws IOException {
+	private void sendRedirect(HttpServletRequest request, HttpServletResponse response,
+			Saml2AuthenticationRequestContext context) throws IOException {
 		Saml2RedirectAuthenticationRequest authenticationRequest = this.authenticationRequestFactory
 				.createRedirectAuthenticationRequest(context);
+		this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response);
 		UriComponentsBuilder uriBuilder = UriComponentsBuilder
 				.fromUriString(authenticationRequest.getAuthenticationRequestUri());
 		addParameter("SAMLRequest", authenticationRequest.getSamlRequest(), uriBuilder);
@@ -194,9 +213,11 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 		}
 	}
 
-	private void sendPost(HttpServletResponse response, Saml2AuthenticationRequestContext context) throws IOException {
+	private void sendPost(HttpServletRequest request, HttpServletResponse response,
+			Saml2AuthenticationRequestContext context) throws IOException {
 		Saml2PostAuthenticationRequest authenticationRequest = this.authenticationRequestFactory
 				.createPostAuthenticationRequest(context);
+		this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response);
 		String html = createSamlPostRequestFormData(authenticationRequest);
 		response.setContentType(MediaType.TEXT_HTML_VALUE);
 		response.getWriter().write(html);

+ 26 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java

@@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.web;
 
 import java.io.ByteArrayOutputStream;
 import java.nio.charset.StandardCharsets;
+import java.util.function.Function;
 import java.util.zip.Inflater;
 import java.util.zip.InflaterOutputStream;
 
@@ -30,9 +31,12 @@ import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpMethod;
 import org.springframework.security.saml2.core.Saml2Error;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.util.Assert;
 
@@ -50,6 +54,8 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 
 	private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
 
+	private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
+
 	/**
 	 * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
 	 * resolving {@link RelyingPartyRegistration}s
@@ -60,6 +66,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 			Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
 		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
 		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
+		this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
 	}
 
 	@Override
@@ -74,7 +81,25 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 		}
 		byte[] b = samlDecode(saml2Response);
 		saml2Response = inflateIfRequired(request, b);
-		return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
+		return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
+	}
+
+	/**
+	 * Use the given {@link Saml2AuthenticationRequestRepository} to load authentication
+	 * request.
+	 * @param authenticationRequestRepository the
+	 * {@link Saml2AuthenticationRequestRepository} to use
+	 * @since 5.6
+	 */
+	public void setAuthenticationRequestRepository(
+			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
+		Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
+		this.loader = authenticationRequestRepository::loadAuthenticationRequest;
+	}
+
+	private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
+		return this.loader.apply(request);
 	}
 
 	private String inflateIfRequired(HttpServletRequest request, byte[] b) {

+ 38 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationTokens.java

@@ -0,0 +1,38 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.authentication;
+
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
+
+/**
+ * Tests instances for {@link Saml2AuthenticationToken}
+ *
+ * @author Marcus Da Coregio
+ */
+public final class TestSaml2AuthenticationTokens {
+
+	private TestSaml2AuthenticationTokens() {
+	}
+
+	public static Saml2AuthenticationToken token() {
+		RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.relyingPartyRegistration()
+				.build();
+		return new Saml2AuthenticationToken(relyingPartyRegistration, "saml2-xml-response-object");
+	}
+
+}

+ 143 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepositoryTests.java

@@ -0,0 +1,143 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.servlet;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.mock.web.MockHttpSession;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+/**
+ * @author Marcus Da Coregio
+ */
+public class HttpSessionSaml2AuthenticationRequestRepositoryTests {
+
+	private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO";
+
+	private MockHttpServletRequest request;
+
+	private MockHttpServletResponse response;
+
+	private HttpSessionSaml2AuthenticationRequestRepository authenticationRequestRepository;
+
+	@BeforeEach
+	public void setup() {
+		this.request = new MockHttpServletRequest();
+		this.response = new MockHttpServletResponse();
+		this.authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository();
+	}
+
+	@Test
+	public void loadAuthenticationRequestWhenInvalidSessionThenNull() {
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest).isNull();
+	}
+
+	@Test
+	public void loadAuthenticationRequestWhenNoAttributeInSessionThenNull() {
+		this.request.getSession();
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest).isNull();
+	}
+
+	@Test
+	public void loadAuthenticationRequestWhenAttributeInSessionThenReturnsAuthenticationRequest() {
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		given(mockAuthenticationRequest.getAuthenticationRequestUri()).willReturn(IDP_SSO_URL);
+		this.request.getSession();
+		this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
+				this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest.getAuthenticationRequestUri()).isEqualTo(IDP_SSO_URL);
+	}
+
+	@Test
+	public void saveAuthenticationRequestWhenSessionDontExistsThenCreateAndSave() {
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
+				this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest).isNotNull();
+	}
+
+	@Test
+	public void saveAuthenticationRequestWhenSessionExistsThenSave() {
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		this.request.getSession();
+		this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
+				this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest).isNotNull();
+	}
+
+	@Test
+	public void saveAuthenticationRequestWhenNullAuthenticationRequestThenDontSave() {
+		this.request.getSession();
+		this.authenticationRequestRepository.saveAuthenticationRequest(null, this.request, this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest).isNull();
+	}
+
+	@Test
+	public void removeAuthenticationRequestWhenInvalidSessionThenReturnNull() {
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.removeAuthenticationRequest(this.request, this.response);
+		assertThat(authenticationRequest).isNull();
+	}
+
+	@Test
+	public void removeAuthenticationRequestWhenAttributeInSessionThenRemoveAuthenticationRequest() {
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		given(mockAuthenticationRequest.getAuthenticationRequestUri()).willReturn(IDP_SSO_URL);
+		this.request.getSession();
+		this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
+				this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.removeAuthenticationRequest(this.request, this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequestAfterRemove = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest.getAuthenticationRequestUri()).isEqualTo(IDP_SSO_URL);
+		assertThat(authenticationRequestAfterRemove).isNull();
+	}
+
+	@Test
+	public void removeAuthenticationRequestWhenValidSessionNoAttributeThenReturnsNull() {
+		MockHttpSession session = mock(MockHttpSession.class);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setSession(session);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.removeAuthenticationRequest(request, this.response);
+		verify(session).getAttribute(anyString());
+		assertThat(authenticationRequest).isNull();
+	}
+
+}

+ 50 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,12 +24,20 @@ import org.junit.jupiter.api.Test;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
+import org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationTokens;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
+import org.springframework.security.web.authentication.AuthenticationConverter;
 
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
 
 public class Saml2WebSsoAuthenticationFilterTests {
 
@@ -84,4 +92,45 @@ public class Saml2WebSsoAuthenticationFilterTests {
 				.withMessage("No relying party registration found");
 	}
 
+	@Test
+	public void attemptAuthenticationWhenSavedAuthnRequestThenRemovesAuthnRequest() {
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+				Saml2AuthenticationRequestRepository.class);
+		AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class);
+		given(authenticationConverter.convert(this.request)).willReturn(TestSaml2AuthenticationTokens.token());
+		this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, "/some/other/path/{registrationId}");
+		this.filter.setAuthenticationManager((authentication) -> null);
+		this.request.setPathInfo("/some/other/path/idp-registration-id");
+		this.filter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		this.filter.attemptAuthentication(this.request, this.response);
+		verify(authenticationRequestRepository).removeAuthenticationRequest(this.request, this.response);
+	}
+
+	@Test
+	public void setAuthenticationRequestRepositoryWhenNullThenThrowsIllegalArgument() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationRequestRepository(null))
+				.withMessage("authenticationRequestRepository cannot be null");
+	}
+
+	@Test
+	public void setAuthenticationRequestRepositoryWhenExpectedAuthenticationConverterTypeThenSetLoaderIntoConverter() {
+		Saml2AuthenticationTokenConverter authenticationConverterMock = mock(Saml2AuthenticationTokenConverter.class);
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+				Saml2AuthenticationRequestRepository.class);
+		this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverterMock,
+				"/some/other/path/{registrationId}");
+		this.filter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		verify(authenticationConverterMock).setAuthenticationRequestRepository(authenticationRequestRepository);
+	}
+
+	@Test
+	public void setAuthenticationRequestRepositoryWhenNotExpectedAuthenticationConverterTypeThenDontSet() {
+		AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class);
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+				Saml2AuthenticationRequestRepository.class);
+		this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, "/some/other/path/{registrationId}");
+		this.filter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		verifyNoInteractions(authenticationConverter);
+	}
+
 }

+ 41 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
@@ -36,6 +37,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.web.util.HtmlUtils;
 import org.springframework.web.util.UriUtils;
@@ -43,6 +45,7 @@ import org.springframework.web.util.UriUtils;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -60,6 +63,9 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
 
 	private Saml2AuthenticationRequestContextResolver resolver = mock(Saml2AuthenticationRequestContextResolver.class);
 
+	private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+			Saml2AuthenticationRequestRepository.class);
+
 	private MockHttpServletRequest request;
 
 	private MockHttpServletResponse response;
@@ -79,6 +85,7 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
 				.providerDetails((c) -> c.entityId("idp-entity-id")).providerDetails((c) -> c.webSsoUrl(IDP_SSO_URL))
 				.assertionConsumerServiceUrlTemplate("template")
 				.credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartyPrivateCredential()));
+		this.filter.setAuthenticationRequestRepository(this.authenticationRequestRepository);
 	}
 
 	@Test
@@ -216,4 +223,37 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
 		assertThat(this.response.getStatus()).isEqualTo(401);
 	}
 
+	@Test
+	public void setAuthenticationRequestRepositoryWhenNullThenException() {
+		Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.resolver,
+				this.factory);
+		assertThatIllegalArgumentException().isThrownBy(() -> filter.setAuthenticationRequestRepository(null));
+	}
+
+	@Test
+	public void doFilterWhenRedirectThenSaveRedirectRequest() throws ServletException, IOException {
+		Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
+		Saml2RedirectAuthenticationRequest request = redirectAuthenticationRequest(context).build();
+		given(this.resolver.resolve(any())).willReturn(context);
+		given(this.factory.createRedirectAuthenticationRequest(any())).willReturn(request);
+		this.filter.doFilterInternal(this.request, this.response, this.filterChain);
+		verify(this.authenticationRequestRepository).saveAuthenticationRequest(
+				any(Saml2RedirectAuthenticationRequest.class), eq(this.request), eq(this.response));
+	}
+
+	@Test
+	public void doFilterWhenPostThenSaveRedirectRequest() throws ServletException, IOException {
+		RelyingPartyRegistration registration = this.rpBuilder
+				.assertingPartyDetails((asserting) -> asserting.singleSignOnServiceBinding(Saml2MessageBinding.POST))
+				.build();
+		Saml2AuthenticationRequestContext context = authenticationRequestContext()
+				.relyingPartyRegistration(registration).build();
+		Saml2PostAuthenticationRequest request = postAuthenticationRequest(context).build();
+		given(this.resolver.resolve(any())).willReturn(context);
+		given(this.factory.createPostAuthenticationRequest(any())).willReturn(request);
+		this.filter.doFilterInternal(this.request, this.response, this.filterChain);
+		verify(this.authenticationRequestRepository).saveAuthenticationRequest(
+				any(Saml2PostAuthenticationRequest.class), eq(this.request), eq(this.response));
+	}
+
 }

+ 24 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java

@@ -31,10 +31,12 @@ import org.springframework.core.io.ClassPathResource;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
 import org.springframework.security.saml2.core.Saml2Utils;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.util.StreamUtils;
 import org.springframework.web.util.UriUtils;
 
@@ -43,6 +45,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
 
 @ExtendWith(MockitoExtension.class)
 public class Saml2AuthenticationTokenConverterTests {
@@ -155,6 +158,27 @@ public class Saml2AuthenticationTokenConverterTests {
 		validateSsoCircleXml(token.getSaml2Response());
 	}
 
+	@Test
+	public void convertWhenSavedAuthenticationRequestThenToken() {
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+				Saml2AuthenticationRequestRepository.class);
+		AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
+				this.relyingPartyRegistrationResolver);
+		converter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
+				.willReturn(this.relyingPartyRegistration);
+		given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class)))
+				.willReturn(authenticationRequest);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
+		Saml2AuthenticationToken token = converter.convert(request);
+		assertThat(token.getSaml2Response()).isEqualTo("response");
+		assertThat(token.getRelyingPartyRegistration().getRegistrationId())
+				.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
+		assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
+	}
+
 	private void validateSsoCircleXml(String xml) {
 		assertThat(xml).contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"")
 				.contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"")